feat(expression): add expression selection modes and improve DatabaseMessages compatibility

- Add a unified expression-selection entry point supporting two modes: classic and exp_model
- Add a StyleLearner prediction mode that picks expression styles with a machine-learning model
- Improve compatibility with the DatabaseMessages data model across several modules
- Streamline message handling so dicts and DatabaseMessages objects are treated uniformly
- Add an expression.mode config field to control the selection mode
This commit is contained in:
Windpicker-owo
2025-10-29 22:52:32 +08:00
parent f2d7af6d87
commit f6349f278d
16 changed files with 1419 additions and 54 deletions

View File

@@ -0,0 +1,254 @@
"""
表达系统工具函数
提供消息过滤、文本相似度计算、加权随机抽样等功能
"""
import difflib
import random
import re
from typing import Any, Dict, List, Optional
from src.common.logger import get_logger
logger = get_logger("express_utils")
def filter_message_content(content: Optional[str]) -> str:
"""
过滤消息内容,移除回复、@、图片等格式
Args:
content: 原始消息内容
Returns:
过滤后的纯文本内容
"""
if not content:
return ""
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
# 移除@<...>格式的内容
content = re.sub(r"@<[^>]*>", "", content)
# 移除[图片:...]格式的图片ID
content = re.sub(r"\[图片:[^\]]*\]", "", content)
# 移除[表情包:...]格式的内容
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
return content.strip()
def calculate_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度返回0-1之间的值
Args:
text1: 第一个文本
text2: 第二个文本
Returns:
相似度值 (0-1)
"""
return difflib.SequenceMatcher(None, text1, text2).ratio()
def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]:
"""
加权随机抽样函数
Args:
population: 待抽样的数据列表
k: 抽样数量
weight_key: 权重字段名如果为None则等概率抽样
Returns:
抽样结果列表
"""
if not population or k <= 0:
return []
if len(population) <= k:
return population.copy()
# 如果指定了权重字段
if weight_key and all(weight_key in item for item in population):
try:
# 获取权重
weights = [float(item.get(weight_key, 1.0)) for item in population]
# 使用random.choices进行加权抽样
return random.choices(population, weights=weights, k=k)
except (ValueError, TypeError) as e:
logger.warning(f"加权抽样失败,使用等概率抽样: {e}")
# 等概率抽样
selected = []
population_copy = population.copy()
for _ in range(k):
if not population_copy:
break
# 随机选择一个元素
idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(idx))
return selected
def normalize_text(text: str) -> str:
"""
标准化文本,移除多余空白字符
Args:
text: 输入文本
Returns:
标准化后的文本
"""
# 替换多个连续空白字符为单个空格
text = re.sub(r"\s+", " ", text)
return text.strip()
def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
"""
简单的关键词提取(基于词频)
Args:
text: 输入文本
max_keywords: 最大关键词数量
Returns:
关键词列表
"""
if not text:
return []
try:
import jieba.analyse
# 使用TF-IDF提取关键词
keywords = jieba.analyse.extract_tags(text, topK=max_keywords)
return keywords
except ImportError:
logger.warning("jieba未安装无法提取关键词")
# 简单分词
words = text.split()
return words[:max_keywords]
def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str:
"""
格式化表达方式对
Args:
situation: 情境
style: 风格
index: 序号(可选)
Returns:
格式化后的字符串
"""
if index is not None:
return f'{index}. 当"{situation}"时,使用"{style}"'
else:
return f'"{situation}"时,使用"{style}"'
def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
"""
解析表达方式对文本
Args:
text: 格式化的表达方式对文本
Returns:
(situation, style) 或 None
"""
# 匹配格式:当"..."时,使用"..."
match = re.search(r'"(.+?)"时,使用"(.+?)"', text)
if match:
return match.group(1), match.group(2)
return None
def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]:
"""
批量去重表达方式
Args:
expressions: 表达方式列表
key_fields: 用于去重的字段名列表
Returns:
去重后的表达方式列表
"""
seen = set()
unique_expressions = []
for expr in expressions:
# 构建去重key
key_values = tuple(expr.get(field, "") for field in key_fields)
if key_values not in seen:
seen.add(key_values)
unique_expressions.append(expr)
return unique_expressions
def calculate_time_weight(last_active_time: float, current_time: float, half_life_days: int = 30) -> float:
"""
根据时间计算权重(时间衰减)
Args:
last_active_time: 最后活跃时间戳
current_time: 当前时间戳
half_life_days: 半衰期天数
Returns:
权重值 (0-1)
"""
time_diff_days = (current_time - last_active_time) / 86400 # 转换为天数
if time_diff_days < 0:
return 1.0
# 使用指数衰减公式
weight = max(0.01, min(1.0, 0.5 ** (time_diff_days / half_life_days)))  # weight halves every half_life_days days
return weight
def merge_expressions_from_multiple_chats(
expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100
) -> List[Dict[str, Any]]:
"""
合并多个聊天室的表达方式
Args:
expressions_dict: {chat_id: [expressions]}
max_total: 最大合并数量
Returns:
合并后的表达方式列表
"""
all_expressions = []
# 收集所有表达方式
for chat_id, expressions in expressions_dict.items():
for expr in expressions:
# 添加source_id标识
expr_with_source = expr.copy()
expr_with_source["source_id"] = chat_id
all_expressions.append(expr_with_source)
# 按count或last_active_time排序
if all_expressions and "count" in all_expressions[0]:
all_expressions.sort(key=lambda x: x.get("count", 0), reverse=True)
elif all_expressions and "last_active_time" in all_expressions[0]:
all_expressions.sort(key=lambda x: x.get("last_active_time", 0), reverse=True)
# 去重基于situation和style
all_expressions = batch_filter_duplicates(all_expressions, ["situation", "style"])
# 限制数量
return all_expressions[:max_total]
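As an illustration (not part of this commit), a quick worked example of the time-decay weighting used above: a weight that halves every half_life_days days and is clamped to the [0.01, 1.0] range. The helper below is a standalone sketch with the formula inlined, since the module path is not shown in this diff; the timestamps are arbitrary.

def time_weight(last_active: float, now: float, half_life_days: int = 30) -> float:
    # Exponential decay with a true half-life of half_life_days, clamped to [0.01, 1.0]
    days = (now - last_active) / 86400
    if days < 0:
        return 1.0
    return max(0.01, min(1.0, 0.5 ** (days / half_life_days)))

now = 1_700_000_000.0
print(time_weight(now - 30 * 86400, now))   # 0.5   -> one half-life old
print(time_weight(now - 90 * 86400, now))   # 0.125 -> three half-lives old
print(time_weight(now - 365 * 86400, now))  # clamped to the 0.01 floor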

View File

@@ -16,6 +16,9 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
# 导入 StyleLearner 管理器
from .style_learner import style_learner_manager
MAX_EXPRESSION_COUNT = 300
DECAY_DAYS = 30  # 30天衰减到0.01
DECAY_MIN = 0.01  # 最小衰减值
@@ -405,6 +408,29 @@ class ExpressionLearner:
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
await session.delete(expr)
# 🔥 新增:训练 StyleLearner
# 只对 style 类型的表达方式进行训练grammar 不需要训练到模型)
if type == "style":
try:
# 获取 StyleLearner 实例
learner = style_learner_manager.get_learner(chat_id)
# 为每个学习到的表达方式训练模型
# 这里使用 situation 作为前置内容contextstyle 作为目标风格
for expr in expr_list:
situation = expr["situation"]
style = expr["style"]
# 训练映射关系: situation -> style
learner.learn_mapping(situation, style)
logger.debug(f"已将 {len(expr_list)} 个表达方式训练到 StyleLearner")
# 保存模型
learner.save(style_learner_manager.model_save_path)
except Exception as e:
logger.error(f"训练 StyleLearner 失败: {e}")
return learnt_expressions
return None
@@ -522,12 +548,12 @@ class ExpressionLearnerManager:
os.path.join(base_dir, "learnt_grammar"),
]
-try:
-    for directory in directories_to_create:
+for directory in directories_to_create:
+    try:
        os.makedirs(directory, exist_ok=True)
        logger.debug(f"确保目录存在: {directory}")
-except Exception as e:
-    logger.error(f"创建目录失败 {directory}: {e}")
+    except Exception as e:
+        logger.error(f"创建目录失败 {directory}: {e}")
@staticmethod
async def _auto_migrate_json_to_db():
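As an illustration (not part of this commit), a minimal sketch of exercising the training hook added above from outside the learner: the manager attributes used here (get_learner, learn_mapping, save, model_save_path) come from the new style_learner module later in this commit, while the import path and the sample data are assumptions.

from src.chat.express.style_learner import style_learner_manager  # assumed module path

expr_list = [
    {"situation": "对方表示感谢", "style": "轻松地回一句不客气"},
    {"situation": "对方在抱怨工作", "style": "先共情再吐槽两句"},
]

learner = style_learner_manager.get_learner("qq:123456:group")
for expr in expr_list:
    # situation is the context, style is the learning target, mirroring the hunk above
    learner.learn_mapping(expr["situation"], expr["style"])
learner.save(style_learner_manager.model_save_path)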

View File

@@ -15,6 +15,9 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
# 导入StyleLearner管理器
from .style_learner import style_learner_manager
logger = get_logger("expression_selector")
@@ -236,6 +239,181 @@ class ExpressionSelector:
)
await session.commit()
async def select_suitable_expressions(
self,
chat_id: str,
chat_history: list | str,
target_message: str | None = None,
max_num: int = 10,
min_num: int = 5,
) -> list[dict[str, Any]]:
"""
统一的表达方式选择入口,根据配置自动选择模式
Args:
chat_id: 聊天ID
chat_history: 聊天历史(列表或字符串)
target_message: 目标消息
max_num: 最多返回数量
min_num: 最少返回数量
Returns:
选中的表达方式列表
"""
# 转换chat_history为字符串
if isinstance(chat_history, list):
chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history])
else:
chat_info = chat_history
# 根据配置选择模式
mode = global_config.expression.mode
logger.debug(f"[ExpressionSelector] 使用模式: {mode}")
if mode == "exp_model":
return await self._select_expressions_model_only(
chat_id=chat_id,
chat_info=chat_info,
target_message=target_message,
max_num=max_num,
min_num=min_num
)
else: # classic mode
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
target_message=target_message,
max_num=max_num,
min_num=min_num
)
async def _select_expressions_classic(
self,
chat_id: str,
chat_info: str,
target_message: str | None = None,
max_num: int = 10,
min_num: int = 5,
) -> list[dict[str, Any]]:
"""经典模式:随机抽样 + LLM评估"""
logger.debug(f"[Classic模式] 使用LLM评估表达方式")
return await self.select_suitable_expressions_llm(
chat_id=chat_id,
chat_info=chat_info,
max_num=max_num,
min_num=min_num,
target_message=target_message
)
async def _select_expressions_model_only(
self,
chat_id: str,
chat_info: str,
target_message: str | None = None,
max_num: int = 10,
min_num: int = 5,
) -> list[dict[str, Any]]:
"""模型预测模式使用StyleLearner预测最合适的表达风格"""
logger.debug(f"[Exp_model模式] 使用StyleLearner预测表达方式")
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return []
# 获取或创建StyleLearner实例
learner = style_learner_manager.get_learner(chat_id)
# 使用StyleLearner预测最合适的风格
best_style, all_scores = learner.predict_style(chat_info, top_k=max_num)
if not best_style or not all_scores:
logger.warning(f"StyleLearner未返回预测结果可能模型未训练回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
target_message=target_message,
max_num=max_num,
min_num=min_num
)
# 将分数字典转换为列表格式 [(style, score), ...]
predicted_styles = sorted(all_scores.items(), key=lambda x: x[1], reverse=True)
# 根据预测的风格从数据库获取表达方式
expressions = await self.get_model_predicted_expressions(
chat_id=chat_id,
predicted_styles=predicted_styles,
max_num=max_num
)
if not expressions:
logger.warning(f"未找到匹配预测风格的表达方式,回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
target_message=target_message,
max_num=max_num,
min_num=min_num
)
logger.debug(f"[Exp_model模式] 成功返回 {len(expressions)} 个表达方式")
return expressions
async def get_model_predicted_expressions(
self,
chat_id: str,
predicted_styles: list[tuple[str, float]],
max_num: int = 10
) -> list[dict[str, Any]]:
"""
根据StyleLearner预测的风格获取表达方式
Args:
chat_id: 聊天ID
predicted_styles: 预测的风格列表,格式: [(style, score), ...]
max_num: 最多返回数量
Returns:
表达方式列表
"""
if not predicted_styles:
return []
# 提取风格名称前3个最佳匹配
style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]]
logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}")
async with get_db_session() as session:
# 查询匹配这些风格的表达方式
stmt = (
select(Expression)
.where(Expression.chat_id == chat_id)
.where(Expression.style.in_(style_names))
.order_by(Expression.count.desc())
.limit(max_num)
)
result = await session.execute(stmt)
expressions_objs = result.scalars().all()
if not expressions_objs:
logger.debug(f"数据库中没有找到风格 {style_names} 的表达方式")
return []
# 转换为字典格式
expressions = []
for expr in expressions_objs:
expressions.append({
"situation": expr.situation or "",
"style": expr.style or "",
"type": expr.type or "style",
"count": float(expr.count) if expr.count else 0.0,
"last_active_time": expr.last_active_time or 0.0
})
logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式")
return expressions
async def select_suitable_expressions_llm(
self,
chat_id: str,
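As an illustration (not part of this commit), a hypothetical call site for the new unified entry point. The method and parameter names come from the code above; the module path, the constructor argument, and the asyncio wrapper are assumptions.

import asyncio

from src.chat.express.expression_selector import ExpressionSelector  # assumed module path

async def pick_expressions() -> list[dict]:
    selector = ExpressionSelector("qq:123456:group")  # constructor argument assumed, as in a later hunk of this commit
    history = [
        {"sender": "小A", "content": "今天加班到十点,好累"},
        {"sender": "小B", "content": "摸摸,早点休息"},
    ]
    # Dispatches to classic (LLM scoring) or exp_model (StyleLearner prediction)
    # depending on global_config.expression.mode
    return await selector.select_suitable_expressions(
        chat_id="qq:123456:group",
        chat_history=history,
        target_message="今天加班到十点,好累",
        max_num=8,
        min_num=2,
    )

expressions = asyncio.run(pick_expressions())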

View File

@@ -0,0 +1,9 @@
"""
表达模型包
包含基于Online Naive Bayes的机器学习模型
"""
from .model import ExpressorModel
from .online_nb import OnlineNaiveBayes
from .tokenizer import Tokenizer
__all__ = ["ExpressorModel", "OnlineNaiveBayes", "Tokenizer"]

View File

@@ -0,0 +1,216 @@
"""
基于Online Naive Bayes的表达模型
支持候选表达的动态添加和在线学习
"""
import os
import pickle
from collections import Counter, defaultdict
from typing import Dict, Optional, Tuple
from src.common.logger import get_logger
from .online_nb import OnlineNaiveBayes
from .tokenizer import Tokenizer
logger = get_logger("expressor.model")
class ExpressorModel:
"""直接使用朴素贝叶斯精排(可在线学习)"""
def __init__(
self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000, use_jieba: bool = True
):
"""
Args:
alpha: 词频平滑参数
beta: 类别先验平滑参数
gamma: 衰减因子
vocab_size: 词汇表大小
use_jieba: 是否使用jieba分词
"""
# 初始化分词器
self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
# 初始化在线朴素贝叶斯模型
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
# 候选表达管理
self._candidates: Dict[str, str] = {} # cid -> text (style)
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
logger.info(
f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})"
)
def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
"""
添加候选文本和对应的situation
Args:
cid: 候选ID
text: 表达文本 (style)
situation: 情境文本
"""
self._candidates[cid] = text
if situation is not None:
self._situations[cid] = situation
# 确保在nb模型中初始化该候选的计数
if cid not in self.nb.cls_counts:
self.nb.cls_counts[cid] = 0.0
if cid not in self.nb.token_counts:
self.nb.token_counts[cid] = defaultdict(float)
def predict(self, text: str, k: Optional[int] = None) -> Tuple[Optional[str], Dict[str, float]]:
"""
直接对所有候选进行朴素贝叶斯评分
Args:
text: 查询文本
k: 返回前k个候选如果为None则返回所有
Returns:
(最佳候选ID, 所有候选的分数字典)
"""
# 1. 分词
toks = self.tokenizer.tokenize(text)
if not toks or not self._candidates:
return None, {}
# 2. 计算词频
tf = Counter(toks)
all_cids = list(self._candidates.keys())
# 3. 批量评分
scores = self.nb.score_batch(tf, all_cids)
if not scores:
return None, {}
# 4. 根据k参数限制返回的候选数量
if k is not None and k > 0:
sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
limited_scores = dict(sorted_scores[:k])
best = sorted_scores[0][0] if sorted_scores else None
return best, limited_scores
else:
best = max(scores.items(), key=lambda x: x[1])[0]
return best, scores
def update_positive(self, text: str, cid: str):
"""
更新正反馈学习
Args:
text: 输入文本
cid: 目标类别ID
"""
toks = self.tokenizer.tokenize(text)
if not toks:
return
tf = Counter(toks)
self.nb.update_positive(tf, cid)
def decay(self, factor: Optional[float] = None):
"""
应用知识衰减
Args:
factor: 衰减因子如果为None则使用模型配置的gamma
"""
self.nb.decay(factor)
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
"""
获取候选信息
Args:
cid: 候选ID
Returns:
(style文本, situation文本)
"""
style = self._candidates.get(cid)
situation = self._situations.get(cid)
return style, situation
def get_all_candidates(self) -> Dict[str, Tuple[str, str]]:
"""
获取所有候选
Returns:
{cid: (style, situation)}
"""
result = {}
for cid in self._candidates.keys():
style, situation = self.get_candidate_info(cid)
result[cid] = (style, situation)
return result
def save(self, path: str):
"""
保存模型到文件
Args:
path: 保存路径
"""
os.makedirs(os.path.dirname(path), exist_ok=True)
data = {
"candidates": self._candidates,
"situations": self._situations,
"nb_cls_counts": dict(self.nb.cls_counts),
"nb_token_counts": {k: dict(v) for k, v in self.nb.token_counts.items()},
"nb_alpha": self.nb.alpha,
"nb_beta": self.nb.beta,
"nb_gamma": self.nb.gamma,
"nb_V": self.nb.V,
}
with open(path, "wb") as f:
pickle.dump(data, f)
logger.info(f"模型已保存到 {path}")
def load(self, path: str):
"""
从文件加载模型
Args:
path: 加载路径
"""
if not os.path.exists(path):
logger.warning(f"模型文件不存在: {path}")
return
with open(path, "rb") as f:
data = pickle.load(f)
self._candidates = data["candidates"]
self._situations = data["situations"]
# 恢复nb模型的参数
self.nb.alpha = data["nb_alpha"]
self.nb.beta = data["nb_beta"]
self.nb.gamma = data["nb_gamma"]
self.nb.V = data["nb_V"]
# 恢复统计数据
self.nb.cls_counts = defaultdict(float, data["nb_cls_counts"])
self.nb.token_counts = defaultdict(lambda: defaultdict(float))
for cid, tc in data["nb_token_counts"].items():
self.nb.token_counts[cid] = defaultdict(float, tc)
logger.info(f"模型已从 {path} 加载")
def get_stats(self) -> Dict:
"""获取模型统计信息"""
nb_stats = self.nb.get_stats()
return {
"n_candidates": len(self._candidates),
"n_classes": nb_stats["n_classes"],
"n_tokens": nb_stats["n_tokens"],
"total_counts": nb_stats["total_counts"],
}
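As an illustration (not part of this commit), a minimal end-to-end sketch of the ExpressorModel API defined above: register candidates, reinforce one with positive feedback, then rank them for a new query. The import path is an assumption (the package __init__ shown earlier re-exports ExpressorModel); use_jieba is disabled so the example has no external dependency.

from src.chat.express.expressor.model import ExpressorModel  # assumed module path

model = ExpressorModel(alpha=0.5, beta=0.5, gamma=0.99, vocab_size=200000, use_jieba=False)

model.add_candidate("c1", "轻松口语化的语气", situation="朋友闲聊")
model.add_candidate("c2", "正式书面的语气", situation="工作汇报")

# Positive feedback: queries like this one should map to candidate c1
model.update_positive("晚上一起吃饭吗", "c1")

best, scores = model.predict("今晚吃什么", k=2)
print(best, scores)   # c1 is expected to rank first after the update
model.decay()         # apply forgetting with the configured gamma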

View File

@@ -0,0 +1,142 @@
"""
在线朴素贝叶斯分类器
支持增量学习和知识衰减
"""
import math
from collections import Counter, defaultdict
from typing import Dict, List, Optional
from src.common.logger import get_logger
logger = get_logger("expressor.online_nb")
class OnlineNaiveBayes:
"""在线朴素贝叶斯分类器"""
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
"""
Args:
alpha: 词频平滑参数
beta: 类别先验平滑参数
gamma: 衰减因子 (0-1之间1表示不衰减)
vocab_size: 词汇表大小
"""
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.V = vocab_size
# 类别统计
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(
lambda: defaultdict(float)
) # cid -> term -> count
# 缓存
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
"""
批量计算候选的贝叶斯分数
Args:
tf: 查询文本的词频Counter
cids: 候选类别ID列表
Returns:
每个候选的分数字典
"""
total_cls = sum(self.cls_counts.values())
n_cls = max(1, len(self.cls_counts))
denom_prior = math.log(total_cls + self.beta * n_cls)
out: Dict[str, float] = {}
for cid in cids:
# 计算先验概率 log P(c)
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
s = prior
# 计算似然概率 log P(w|c)
logZ = self._logZ_c(cid)
tc = self.token_counts[cid]
for term, qtf in tf.items():
num = tc.get(term, 0.0) + self.alpha
s += qtf * (math.log(num) - logZ)
out[cid] = s
return out
def update_positive(self, tf: Counter, cid: str):
"""
正反馈更新
Args:
tf: 词频Counter
cid: 类别ID
"""
inc = 0.0
tc = self.token_counts[cid]
# 更新词频统计
for term, c in tf.items():
tc[term] += float(c)
inc += float(c)
# 更新类别统计
self.cls_counts[cid] += inc
self._invalidate(cid)
def decay(self, factor: Optional[float] = None):
"""
知识衰减(遗忘机制)
Args:
factor: 衰减因子如果为None则使用self.gamma
"""
g = self.gamma if factor is None else factor
if g >= 1.0:
return
# 对所有统计进行衰减
for cid in list(self.cls_counts.keys()):
self.cls_counts[cid] *= g
for term in list(self.token_counts[cid].keys()):
self.token_counts[cid][term] *= g
self._invalidate(cid)
logger.debug(f"应用知识衰减,衰减因子: {g}")
def _logZ_c(self, cid: str) -> float:
"""
计算归一化因子logZ
Args:
cid: 类别ID
Returns:
log(Z_c)
"""
if cid not in self._logZ:
Z = self.cls_counts[cid] + self.V * self.alpha
self._logZ[cid] = math.log(max(Z, 1e-12))
return self._logZ[cid]
def _invalidate(self, cid: str):
"""
使缓存失效
Args:
cid: 类别ID
"""
if cid in self._logZ:
del self._logZ[cid]
def get_stats(self) -> Dict:
"""获取统计信息"""
return {
"n_classes": len(self.cls_counts),
"n_tokens": sum(len(tc) for tc in self.token_counts.values()),
"total_counts": sum(self.cls_counts.values()),
}
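As an illustration (not part of this commit), a hand-computed check of the scoring rule implemented in score_batch above: a log prior with beta smoothing plus a per-term log likelihood with alpha smoothing, using two classes and a deliberately tiny vocabulary size so the numbers are easy to follow.

import math
from collections import Counter

alpha, beta, V = 0.5, 0.5, 10
cls_counts = {"c1": 3.0, "c2": 1.0}                      # total token mass per class
token_counts = {"c1": {"hi": 2.0, "yo": 1.0}, "c2": {"ok": 1.0}}
tf = Counter({"hi": 1})                                  # query term frequencies

total, n_cls = sum(cls_counts.values()), len(cls_counts)

def score(cid: str) -> float:
    prior = math.log(cls_counts[cid] + beta) - math.log(total + beta * n_cls)
    log_z = math.log(cls_counts[cid] + V * alpha)        # same normaliser as _logZ_c
    likelihood = sum(
        q * (math.log(token_counts[cid].get(t, 0.0) + alpha) - log_z) for t, q in tf.items()
    )
    return prior + likelihood

print({c: round(score(c), 3) for c in cls_counts})       # {'c1': -1.52, 'c2': -3.689}: "hi" favours c1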

View File

@@ -0,0 +1,62 @@
"""
文本分词器支持中文Jieba分词
"""
from typing import List
from src.common.logger import get_logger
logger = get_logger("expressor.tokenizer")
class Tokenizer:
"""文本分词器支持中文Jieba分词"""
def __init__(self, stopwords: set = None, use_jieba: bool = True):
"""
Args:
stopwords: 停用词集合
use_jieba: 是否使用jieba分词
"""
self.stopwords = stopwords or set()
self.use_jieba = use_jieba
if use_jieba:
try:
import jieba
jieba.initialize()
logger.info("Jieba分词器初始化成功")
except ImportError:
logger.warning("Jieba未安装将使用字符级分词")
self.use_jieba = False
def tokenize(self, text: str) -> List[str]:
"""
分词并返回token列表
Args:
text: 输入文本
Returns:
token列表
"""
if not text:
return []
# 使用jieba分词
if self.use_jieba:
try:
import jieba
tokens = list(jieba.cut(text))
except Exception as e:
logger.warning(f"Jieba分词失败使用字符级分词: {e}")
tokens = list(text)
else:
# 简单按字符分词
tokens = list(text)
# 过滤停用词和空字符串
tokens = [token.strip() for token in tokens if token.strip() and token not in self.stopwords]
return tokens
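As an illustration (not part of this commit), a small example of the fallback behaviour described above: with use_jieba=False the tokenizer degrades to character-level tokens, dropping whitespace and stopwords. The import path is assumed.

from src.chat.express.expressor.tokenizer import Tokenizer  # assumed module path

tok = Tokenizer(stopwords={"的"}, use_jieba=False)
print(tok.tokenize("今天 的 天气 真好"))  # ['今', '天', '天', '气', '真', '好']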

View File

@@ -0,0 +1,405 @@
"""
风格学习引擎
基于ExpressorModel实现的表达风格学习和预测系统
支持多聊天室独立建模和在线学习
"""
import os
import time
from typing import Dict, List, Optional, Tuple
from src.common.logger import get_logger
from .expressor_model import ExpressorModel
logger = get_logger("expressor.style_learner")
class StyleLearner:
"""单个聊天室的表达风格学习器"""
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
"""
Args:
chat_id: 聊天室ID
model_config: 模型配置
"""
self.chat_id = chat_id
self.model_config = model_config or {
"alpha": 0.5,
"beta": 0.5,
"gamma": 0.99, # 衰减因子,支持遗忘
"vocab_size": 200000,
"use_jieba": True,
}
# 初始化表达模型
self.expressor = ExpressorModel(**self.model_config)
# 动态风格管理
self.max_styles = 2000 # 每个chat_id最多2000个风格
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
self.next_style_id = 0
# 学习统计
self.learning_stats = {
"total_samples": 0,
"style_counts": {},
"last_update": time.time(),
}
logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")
def add_style(self, style: str, situation: Optional[str] = None) -> bool:
"""
动态添加一个新的风格
Args:
style: 风格文本
situation: 情境文本
Returns:
是否添加成功
"""
try:
# 检查是否已存在
if style in self.style_to_id:
return True
# 检查是否超过最大限制
if len(self.style_to_id) >= self.max_styles:
logger.warning(f"已达到最大风格数量限制 ({self.max_styles})")
return False
# 生成新的style_id
style_id = f"style_{self.next_style_id}"
self.next_style_id += 1
# 添加到映射
self.style_to_id[style] = style_id
self.id_to_style[style_id] = style
if situation:
self.id_to_situation[style_id] = situation
# 添加到expressor模型
self.expressor.add_candidate(style_id, style, situation)
# 初始化统计
self.learning_stats["style_counts"][style_id] = 0
logger.debug(f"添加风格成功: {style_id} -> {style}")
return True
except Exception as e:
logger.error(f"添加风格失败: {e}")
return False
def learn_mapping(self, up_content: str, style: str) -> bool:
"""
学习一个up_content到style的映射
Args:
up_content: 前置内容
style: 目标风格
Returns:
是否学习成功
"""
try:
# 如果style不存在先添加它
if style not in self.style_to_id:
if not self.add_style(style):
return False
# 获取style_id
style_id = self.style_to_id[style]
# 使用正反馈学习
self.expressor.update_positive(up_content, style_id)
# 更新统计
self.learning_stats["total_samples"] += 1
self.learning_stats["style_counts"][style_id] += 1
self.learning_stats["last_update"] = time.time()
logger.debug(f"学习映射成功: {up_content[:20]}... -> {style}")
return True
except Exception as e:
logger.error(f"学习映射失败: {e}")
return False
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
"""
根据up_content预测最合适的style
Args:
up_content: 前置内容
top_k: 返回前k个候选
Returns:
(最佳style文本, 所有候选的分数字典)
"""
try:
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
if best_style_id is None:
return None, {}
# 将style_id转换为style文本
best_style = self.id_to_style.get(best_style_id)
# 转换所有分数
style_scores = {}
for sid, score in scores.items():
style_text = self.id_to_style.get(sid)
if style_text:
style_scores[style_text] = score
return best_style, style_scores
except Exception as e:
logger.error(f"预测style失败: {e}")
return None, {}
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
"""
获取style的完整信息
Args:
style: 风格文本
Returns:
(style_id, situation)
"""
style_id = self.style_to_id.get(style)
if not style_id:
return None, None
situation = self.id_to_situation.get(style_id)
return style_id, situation
def get_all_styles(self) -> List[str]:
"""
获取所有风格列表
Returns:
风格文本列表
"""
return list(self.style_to_id.keys())
def apply_decay(self, factor: Optional[float] = None):
"""
应用知识衰减
Args:
factor: 衰减因子
"""
self.expressor.decay(factor)
logger.debug(f"应用知识衰减: chat_id={self.chat_id}")
def save(self, base_path: str) -> bool:
"""
保存学习器到文件
Args:
base_path: 基础保存路径
Returns:
是否保存成功
"""
try:
# 创建保存目录
save_dir = os.path.join(base_path, self.chat_id)
os.makedirs(save_dir, exist_ok=True)
# 保存expressor模型
model_path = os.path.join(save_dir, "expressor_model.pkl")
self.expressor.save(model_path)
# 保存映射关系和统计信息
import pickle
meta_path = os.path.join(save_dir, "meta.pkl")
meta_data = {
"style_to_id": self.style_to_id,
"id_to_style": self.id_to_style,
"id_to_situation": self.id_to_situation,
"next_style_id": self.next_style_id,
"learning_stats": self.learning_stats,
}
with open(meta_path, "wb") as f:
pickle.dump(meta_data, f)
logger.info(f"StyleLearner保存成功: {save_dir}")
return True
except Exception as e:
logger.error(f"保存StyleLearner失败: {e}")
return False
def load(self, base_path: str) -> bool:
"""
从文件加载学习器
Args:
base_path: 基础加载路径
Returns:
是否加载成功
"""
try:
save_dir = os.path.join(base_path, self.chat_id)
# 检查目录是否存在
if not os.path.exists(save_dir):
logger.debug(f"StyleLearner保存目录不存在: {save_dir}")
return False
# 加载expressor模型
model_path = os.path.join(save_dir, "expressor_model.pkl")
if os.path.exists(model_path):
self.expressor.load(model_path)
# 加载映射关系和统计信息
import pickle
meta_path = os.path.join(save_dir, "meta.pkl")
if os.path.exists(meta_path):
with open(meta_path, "rb") as f:
meta_data = pickle.load(f)
self.style_to_id = meta_data["style_to_id"]
self.id_to_style = meta_data["id_to_style"]
self.id_to_situation = meta_data["id_to_situation"]
self.next_style_id = meta_data["next_style_id"]
self.learning_stats = meta_data["learning_stats"]
logger.info(f"StyleLearner加载成功: {save_dir}")
return True
except Exception as e:
logger.error(f"加载StyleLearner失败: {e}")
return False
def get_stats(self) -> Dict:
"""获取统计信息"""
model_stats = self.expressor.get_stats()
return {
"chat_id": self.chat_id,
"n_styles": len(self.style_to_id),
"total_samples": self.learning_stats["total_samples"],
"last_update": self.learning_stats["last_update"],
"model_stats": model_stats,
}
class StyleLearnerManager:
"""多聊天室表达风格学习管理器"""
def __init__(self, model_save_path: str = "data/expression/style_models"):
"""
Args:
model_save_path: 模型保存路径
"""
self.learners: Dict[str, StyleLearner] = {}
self.model_save_path = model_save_path
# 确保保存目录存在
os.makedirs(model_save_path, exist_ok=True)
logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
"""
获取或创建指定chat_id的学习器
Args:
chat_id: 聊天室ID
model_config: 模型配置
Returns:
StyleLearner实例
"""
if chat_id not in self.learners:
# 创建新的学习器
learner = StyleLearner(chat_id, model_config)
# 尝试加载已保存的模型
learner.load(self.model_save_path)
self.learners[chat_id] = learner
return self.learners[chat_id]
def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
"""
学习一个映射关系
Args:
chat_id: 聊天室ID
up_content: 前置内容
style: 目标风格
Returns:
是否学习成功
"""
learner = self.get_learner(chat_id)
return learner.learn_mapping(up_content, style)
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
"""
预测最合适的风格
Args:
chat_id: 聊天室ID
up_content: 前置内容
top_k: 返回前k个候选
Returns:
(最佳style, 分数字典)
"""
learner = self.get_learner(chat_id)
return learner.predict_style(up_content, top_k)
def save_all(self) -> bool:
"""
保存所有学习器
Returns:
是否全部保存成功
"""
success = True
for chat_id, learner in self.learners.items():
if not learner.save(self.model_save_path):
success = False
logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
return success
def apply_decay_all(self, factor: Optional[float] = None):
"""
对所有学习器应用知识衰减
Args:
factor: 衰减因子
"""
for learner in self.learners.values():
learner.apply_decay(factor)
logger.info(f"对所有StyleLearner应用知识衰减")
def get_all_stats(self) -> Dict[str, Dict]:
"""
获取所有学习器的统计信息
Returns:
{chat_id: stats}
"""
return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}
# 全局单例
style_learner_manager = StyleLearnerManager()
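As an illustration (not part of this commit), a sketch of the manager-level API defined above. The method names and the global style_learner_manager singleton come from this file; the import path and the chat_id value (following the "platform:id:type" format mentioned in the config comments later in this commit) are assumptions.

from src.chat.express.style_learner import style_learner_manager  # assumed module path

chat_id = "qq:123456:group"

# Online learning: one (situation -> style) sample at a time
style_learner_manager.learn_mapping(chat_id, "对方在求助", "耐心地给出建议")
style_learner_manager.learn_mapping(chat_id, "对方在开玩笑", "顺着梗轻松地接话")

# Prediction returns the best style plus a score for each candidate style
best_style, scores = style_learner_manager.predict_style(chat_id, "这个报错怎么解决呀", top_k=5)

# Periodic maintenance: persist all per-chat models and apply forgetting
style_learner_manager.save_all()
style_learner_manager.apply_decay_all(0.99)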

View File

@@ -5,6 +5,7 @@ from typing import Any
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.timer_calculator import Timer
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
from src.person_info.person_info import get_person_info_manager
@@ -142,7 +143,7 @@ class ChatterActionManager:
self,
action_name: str,
chat_id: str,
-target_message: dict | None = None,
+target_message: dict | DatabaseMessages | None = None,
reasoning: str = "",
action_data: dict | None = None,
thinking_id: str | None = None,
@@ -262,9 +263,15 @@ class ChatterActionManager:
from_plugin=False,
)
if not success or not response_set:
-    logger.info(
-        f"{target_message.get('processed_plain_text') if target_message else '未知消息'} 的回复生成失败"
-    )
+    # 安全地获取 processed_plain_text
+    if isinstance(target_message, DatabaseMessages):
+        msg_text = target_message.processed_plain_text or "未知消息"
+    elif target_message:
+        msg_text = target_message.get("processed_plain_text", "未知消息")
+    else:
+        msg_text = "未知消息"
+    logger.info(f"{msg_text} 的回复生成失败")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
except asyncio.CancelledError:
logger.debug(f"{log_prefix} 并行执行:回复生成任务已被取消")
@@ -322,8 +329,11 @@ class ChatterActionManager:
# 获取目标消息ID
target_message_id = None
-if target_message and isinstance(target_message, dict):
-    target_message_id = target_message.get("message_id")
+if target_message:
+    if isinstance(target_message, DatabaseMessages):
+        target_message_id = target_message.message_id
+    elif isinstance(target_message, dict):
+        target_message_id = target_message.get("message_id")
elif action_data and isinstance(action_data, dict):
target_message_id = action_data.get("target_message_id")
@@ -488,14 +498,19 @@ class ChatterActionManager:
person_info_manager = get_person_info_manager()
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
-platform = action_message.get("chat_info_platform")
-if platform is None:
-    platform = getattr(chat_stream, "platform", "unknown")
+if isinstance(action_message, DatabaseMessages):
+    platform = action_message.chat_info.platform
+    user_id = action_message.user_info.user_id
+else:
+    platform = action_message.get("chat_info_platform")
+    if platform is None:
+        platform = getattr(chat_stream, "platform", "unknown")
+    user_id = action_message.get("user_id", "")
# 获取用户信息并生成回复提示
person_id = person_info_manager.get_person_id(
platform,
-    action_message.get("user_id", ""),
+    user_id,
)
person_name = await person_info_manager.get_value(person_id, "person_name")
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
@@ -565,7 +580,14 @@ class ChatterActionManager:
# 根据新消息数量决定是否需要引用回复
reply_text = ""
-is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True
+# 检查是否为主动思考消息
+if isinstance(message_data, DatabaseMessages):
+    # DatabaseMessages 对象没有 message_type 字段,默认为 False
+    is_proactive_thinking = False
+elif message_data:
+    is_proactive_thinking = message_data.get("message_type") == "proactive_thinking"
+else:
+    is_proactive_thinking = True
logger.debug(f"[send_response] message_data: {message_data}")

View File

@@ -27,6 +27,7 @@ from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.prompt_params import PromptParameters
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.individuality.individuality import get_individuality
@@ -474,10 +475,13 @@ class DefaultReplyer:
style_habits = []
grammar_habits = []
-# 使用从处理器传来的选中表达方式
-# LLM模式:调用LLM选择5-10个,然后随机选5个
-selected_expressions = await expression_selector.select_suitable_expressions_llm(
-    self.chat_stream.stream_id, chat_history, max_num=8, min_num=2, target_message=target
+# 使用统一的表达方式选择入口,支持classic和exp_model模式
+selected_expressions = await expression_selector.select_suitable_expressions(
+    chat_id=self.chat_stream.stream_id,
+    chat_history=chat_history,
+    target_message=target,
+    max_num=8,
+    min_num=2
)
if selected_expressions:
@@ -1206,7 +1210,7 @@ class DefaultReplyer:
extra_info: str = "", extra_info: str = "",
available_actions: dict[str, ActionInfo] | None = None, available_actions: dict[str, ActionInfo] | None = None,
enable_tool: bool = True, enable_tool: bool = True,
reply_message: dict[str, Any] | None = None, reply_message: dict[str, Any] | DatabaseMessages | None = None,
) -> str: ) -> str:
""" """
构建回复器上下文 构建回复器上下文
@@ -1248,10 +1252,24 @@ class DefaultReplyer:
if reply_message is None:
logger.warning("reply_message 为 None,无法构建prompt")
return ""
-platform = reply_message.get("chat_info_platform")
+# 统一处理 DatabaseMessages 对象和字典
+if isinstance(reply_message, DatabaseMessages):
+    platform = reply_message.chat_info.platform
+    user_id = reply_message.user_info.user_id
+    user_nickname = reply_message.user_info.user_nickname
+    user_cardname = reply_message.user_info.user_cardname
+    processed_plain_text = reply_message.processed_plain_text
+else:
+    platform = reply_message.get("chat_info_platform")
+    user_id = reply_message.get("user_id")
+    user_nickname = reply_message.get("user_nickname")
+    user_cardname = reply_message.get("user_cardname")
+    processed_plain_text = reply_message.get("processed_plain_text")
person_id = person_info_manager.get_person_id(
platform,  # type: ignore
-    reply_message.get("user_id"),  # type: ignore
+    user_id,  # type: ignore
)
person_name = await person_info_manager.get_value(person_id, "person_name")
@@ -1260,22 +1278,22 @@ class DefaultReplyer:
# 尝试从reply_message获取用户名
await person_info_manager.first_knowing_some_one(
platform,  # type: ignore
-    reply_message.get("user_id"),  # type: ignore
-    reply_message.get("user_nickname") or "",
-    reply_message.get("user_cardname") or "",
+    user_id,  # type: ignore
+    user_nickname or "",
+    user_cardname or "",
)
# 检查是否是bot自己的名字,如果是则替换为"(你)"
bot_user_id = str(global_config.bot.qq_account)
current_user_id = await person_info_manager.get_value(person_id, "user_id")
-current_platform = reply_message.get("chat_info_platform")
+current_platform = platform
if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
sender = f"{person_name}(你)"
else:
# 如果不是bot自己,直接使用person_name
sender = person_name
-target = reply_message.get("processed_plain_text")
+target = processed_plain_text
# 最终的空值检查,确保sender和target不为None
if sender is None:
@@ -1609,15 +1627,22 @@ class DefaultReplyer:
raw_reply: str,
reason: str,
reply_to: str,
-reply_message: dict[str, Any] | None = None,
+reply_message: dict[str, Any] | DatabaseMessages | None = None,
) -> str:  # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info)
if reply_message:
-    sender = reply_message.get("sender")
-    target = reply_message.get("target")
+    if isinstance(reply_message, DatabaseMessages):
+        # 从 DatabaseMessages 对象获取 sender 和 target
+        # 注意: DatabaseMessages 没有直接的 sender/target 字段
+        # 需要根据实际情况构造
+        sender = reply_message.user_info.user_nickname or reply_message.user_info.user_id
+        target = reply_message.processed_plain_text or ""
+    else:
+        sender = reply_message.get("sender")
+        target = reply_message.get("target")
else:
sender, target = self._parse_reply_target(reply_to)

View File

@@ -606,11 +606,11 @@ class Prompt:
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
-# 使用LLM选择与当前情景匹配的表达习惯
+# 使用统一的表达方式选择入口,支持classic和exp_model模式
expression_selector = ExpressionSelector(self.parameters.chat_id)
-selected_expressions = await expression_selector.select_suitable_expressions_llm(
+selected_expressions = await expression_selector.select_suitable_expressions(
    chat_id=self.parameters.chat_id,
-    chat_info=chat_history,
+    chat_history=chat_history,
    target_message=self.parameters.target,
)

View File

@@ -183,6 +183,10 @@ class ExpressionRule(ValidatedConfigBase):
class ExpressionConfig(ValidatedConfigBase):
"""表达配置类"""
mode: Literal["classic", "exp_model"] = Field(
default="classic",
description="表达方式选择模式: classic=经典LLM评估, exp_model=机器学习模型预测"
)
rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则")
@staticmethod

View File

@@ -107,10 +107,13 @@ class PromptBuilder:
style_habits = []
grammar_habits = []
-# 使用从处理器传来的选中表达方式
-# LLM模式:调用LLM选择5-10个,然后随机选5个
-selected_expressions = await expression_selector.select_suitable_expressions_llm(
-    chat_stream.stream_id, chat_history, max_num=12, min_num=5, target_message=target
+# 使用统一的表达方式选择入口,支持classic和exp_model模式
+selected_expressions = await expression_selector.select_suitable_expressions(
+    chat_id=chat_stream.stream_id,
+    chat_history=chat_history,
+    target_message=target,
+    max_num=12,
+    min_num=5
)
if selected_expressions:

View File

@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from src.chat.message_receive.chat_stream import ChatStream
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.plugin_system.apis import database_api, message_api, send_api
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ChatMode, ChatType, ComponentType
@@ -180,11 +181,18 @@ class BaseAction(ABC):
if self.has_action_message:
if self.action_name != "no_reply":
-    self.group_id = str(self.action_message.get("chat_info_group_id", None))
-    self.group_name = self.action_message.get("chat_info_group_name", None)
-    self.user_id = str(self.action_message.get("user_id", None))
-    self.user_nickname = self.action_message.get("user_nickname", None)
+    # 统一处理 DatabaseMessages 对象和字典
+    if isinstance(self.action_message, DatabaseMessages):
+        self.group_id = str(self.action_message.group_info.group_id if self.action_message.group_info else None)
+        self.group_name = self.action_message.group_info.group_name if self.action_message.group_info else None
+        self.user_id = str(self.action_message.user_info.user_id)
+        self.user_nickname = self.action_message.user_info.user_nickname
+    else:
+        self.group_id = str(self.action_message.get("chat_info_group_id", None))
+        self.group_name = self.action_message.get("chat_info_group_name", None)
+        self.user_id = str(self.action_message.get("user_id", None))
+        self.user_nickname = self.action_message.get("user_nickname", None)
if self.group_id:
self.is_group = True
self.target_id = self.group_id

View File

@@ -6,6 +6,7 @@ from typing import ClassVar
from dateutil.parser import parse as parse_datetime
from src.chat.message_receive.chat_stream import ChatStream
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.person_info.person_info import get_person_info_manager
@@ -253,19 +254,19 @@ class SetEmojiLikeAction(BaseAction):
message_id = None
set_like = self.action_data.get("set", True)
-if self.has_action_message and isinstance(self.action_message, dict):
-    message_id = self.action_message.get("message_id")
-    logger.info(f"获取到的消息ID: {message_id}")
-else:
+if self.has_action_message:
+    if isinstance(self.action_message, DatabaseMessages):
+        message_id = self.action_message.message_id
+        logger.info(f"获取到的消息ID: {message_id}")
+    elif isinstance(self.action_message, dict):
+        message_id = self.action_message.get("message_id")
+        logger.info(f"获取到的消息ID: {message_id}")
+if not message_id:
logger.error("未提供有效的消息或消息ID")
await self.store_action_info(action_prompt_display="贴表情失败: 未提供消息ID", action_done=False)
return False, "未提供消息ID"
-if not message_id:
-    logger.error("消息ID为空")
-    await self.store_action_info(action_prompt_display="贴表情失败: 消息ID为空", action_done=False)
-    return False, "消息ID为空"
available_models = llm_api.get_available_models()
if "utils_small" not in available_models:
logger.error("未找到 'utils_small' 模型配置,无法选择表情")
@@ -273,7 +274,12 @@ class SetEmojiLikeAction(BaseAction):
model_to_use = available_models["utils_small"]
-context_text = self.action_message.get("processed_plain_text", "")
+# 统一处理 DatabaseMessages 和字典
+if isinstance(self.action_message, DatabaseMessages):
+    context_text = self.action_message.processed_plain_text or ""
+else:
+    context_text = self.action_message.get("processed_plain_text", "")
if not context_text:
logger.error("无法找到动作选择的原始消息文本")
return False, "无法找到动作选择的原始消息文本"

View File

@@ -1,5 +1,5 @@
[inner]
-version = "7.5.1"
+version = "7.5.2"
#----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
#如果你想要修改配置文件,请递增version的值
@@ -92,6 +92,11 @@ compress_identity = true # 是否压缩身份,压缩后会精简身份信息
[expression]
# 表达学习配置
# mode: 表达方式模式,可选:
# - "classic": 经典模式,随机抽样 + LLM选择
# - "exp_model": 表达模型模式,使用机器学习模型预测最合适的表达
mode = "classic"
# rules是一个列表,每个元素都是一个学习规则
# chat_stream_id: 聊天流ID,格式为 "platform:id:type",例如 "qq:123456:private"。空字符串""表示全局配置
# use_expression: 是否使用学到的表达 (true/false)