From f6349f278d830bf639a674892fa0a24653a6bf9e Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Wed, 29 Oct 2025 22:52:32 +0800 Subject: [PATCH] =?UTF-8?q?feat(expression):=20=E6=B7=BB=E5=8A=A0=E8=A1=A8?= =?UTF-8?q?=E8=BE=BE=E6=96=B9=E5=BC=8F=E9=80=89=E6=8B=A9=E6=A8=A1=E5=BC=8F?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E4=B8=8EDatabaseMessages=E5=85=BC=E5=AE=B9?= =?UTF-8?q?=E6=80=A7=E6=94=B9=E8=BF=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增统一的表达方式选择入口,支持classic和exp_model两种模式 - 添加StyleLearner模型预测模式,可基于机器学习模型选择表达风格 - 改进多个模块对DatabaseMessages数据模型的兼容性处理 - 优化消息处理逻辑,统一处理字典和DatabaseMessages对象 - 在配置中添加expression.mode字段控制表达选择模式 --- src/chat/express/express_utils.py | 254 +++++++++++ src/chat/express/expression_learner.py | 36 +- src/chat/express/expression_selector.py | 178 ++++++++ src/chat/express/expressor_model/__init__.py | 9 + src/chat/express/expressor_model/model.py | 216 ++++++++++ src/chat/express/expressor_model/online_nb.py | 142 ++++++ src/chat/express/expressor_model/tokenizer.py | 62 +++ src/chat/express/style_learner.py | 405 ++++++++++++++++++ src/chat/planner_actions/action_manager.py | 44 +- src/chat/replyer/default_generator.py | 55 ++- src/chat/utils/prompt.py | 6 +- src/config/official_configs.py | 4 + src/mais4u/mais4u_chat/s4u_prompt.py | 11 +- src/plugin_system/base/base_action.py | 18 +- .../built_in/social_toolkit_plugin/plugin.py | 26 +- template/bot_config_template.toml | 7 +- 16 files changed, 1419 insertions(+), 54 deletions(-) create mode 100644 src/chat/express/express_utils.py create mode 100644 src/chat/express/expressor_model/__init__.py create mode 100644 src/chat/express/expressor_model/model.py create mode 100644 src/chat/express/expressor_model/online_nb.py create mode 100644 src/chat/express/expressor_model/tokenizer.py create mode 100644 src/chat/express/style_learner.py diff --git a/src/chat/express/express_utils.py b/src/chat/express/express_utils.py new file mode 100644 index 000000000..bd7f41e2d --- /dev/null +++ b/src/chat/express/express_utils.py @@ -0,0 +1,254 @@ +""" +表达系统工具函数 +提供消息过滤、文本相似度计算、加权随机抽样等功能 +""" +import difflib +import random +import re +from typing import Any, Dict, List, Optional + +from src.common.logger import get_logger + +logger = get_logger("express_utils") + + +def filter_message_content(content: Optional[str]) -> str: + """ + 过滤消息内容,移除回复、@、图片等格式 + + Args: + content: 原始消息内容 + + Returns: + 过滤后的纯文本内容 + """ + if not content: + return "" + + # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分 + content = re.sub(r"\[回复.*?\],说:\s*", "", content) + # 移除@<...>格式的内容 + content = re.sub(r"@<[^>]*>", "", content) + # 移除[图片:...]格式的图片ID + content = re.sub(r"\[图片:[^\]]*\]", "", content) + # 移除[表情包:...]格式的内容 + content = re.sub(r"\[表情包:[^\]]*\]", "", content) + + return content.strip() + + +def calculate_similarity(text1: str, text2: str) -> float: + """ + 计算两个文本的相似度,返回0-1之间的值 + + Args: + text1: 第一个文本 + text2: 第二个文本 + + Returns: + 相似度值 (0-1) + """ + return difflib.SequenceMatcher(None, text1, text2).ratio() + + +def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]: + """ + 加权随机抽样函数 + + Args: + population: 待抽样的数据列表 + k: 抽样数量 + weight_key: 权重字段名,如果为None则等概率抽样 + + Returns: + 抽样结果列表 + """ + if not population or k <= 0: + return [] + + if len(population) <= k: + return population.copy() + + # 如果指定了权重字段 + if weight_key and all(weight_key in item for item in population): + try: + # 获取权重 + weights = [float(item.get(weight_key, 1.0)) for item in population] + # 
使用random.choices进行加权抽样 + return random.choices(population, weights=weights, k=k) + except (ValueError, TypeError) as e: + logger.warning(f"加权抽样失败,使用等概率抽样: {e}") + + # 等概率抽样 + selected = [] + population_copy = population.copy() + + for _ in range(k): + if not population_copy: + break + # 随机选择一个元素 + idx = random.randint(0, len(population_copy) - 1) + selected.append(population_copy.pop(idx)) + + return selected + + +def normalize_text(text: str) -> str: + """ + 标准化文本,移除多余空白字符 + + Args: + text: 输入文本 + + Returns: + 标准化后的文本 + """ + # 替换多个连续空白字符为单个空格 + text = re.sub(r"\s+", " ", text) + return text.strip() + + +def extract_keywords(text: str, max_keywords: int = 10) -> List[str]: + """ + 简单的关键词提取(基于词频) + + Args: + text: 输入文本 + max_keywords: 最大关键词数量 + + Returns: + 关键词列表 + """ + if not text: + return [] + + try: + import jieba.analyse + + # 使用TF-IDF提取关键词 + keywords = jieba.analyse.extract_tags(text, topK=max_keywords) + return keywords + except ImportError: + logger.warning("jieba未安装,无法提取关键词") + # 简单分词 + words = text.split() + return words[:max_keywords] + + +def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str: + """ + 格式化表达方式对 + + Args: + situation: 情境 + style: 风格 + index: 序号(可选) + + Returns: + 格式化后的字符串 + """ + if index is not None: + return f'{index}. 当"{situation}"时,使用"{style}"' + else: + return f'当"{situation}"时,使用"{style}"' + + +def parse_expression_pair(text: str) -> Optional[tuple[str, str]]: + """ + 解析表达方式对文本 + + Args: + text: 格式化的表达方式对文本 + + Returns: + (situation, style) 或 None + """ + # 匹配格式:当"..."时,使用"..." + match = re.search(r'当"(.+?)"时,使用"(.+?)"', text) + if match: + return match.group(1), match.group(2) + return None + + +def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]: + """ + 批量去重表达方式 + + Args: + expressions: 表达方式列表 + key_fields: 用于去重的字段名列表 + + Returns: + 去重后的表达方式列表 + """ + seen = set() + unique_expressions = [] + + for expr in expressions: + # 构建去重key + key_values = tuple(expr.get(field, "") for field in key_fields) + + if key_values not in seen: + seen.add(key_values) + unique_expressions.append(expr) + + return unique_expressions + + +def calculate_time_weight(last_active_time: float, current_time: float, half_life_days: int = 30) -> float: + """ + 根据时间计算权重(时间衰减) + + Args: + last_active_time: 最后活跃时间戳 + current_time: 当前时间戳 + half_life_days: 半衰期天数 + + Returns: + 权重值 (0-1) + """ + time_diff_days = (current_time - last_active_time) / 86400 # 转换为天数 + if time_diff_days < 0: + return 1.0 + + # 使用指数衰减公式 + decay_rate = 0.693 / half_life_days # ln(2) / half_life + weight = max(0.01, min(1.0, 2 ** (-decay_rate * time_diff_days))) + + return weight + + +def merge_expressions_from_multiple_chats( + expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100 +) -> List[Dict[str, Any]]: + """ + 合并多个聊天室的表达方式 + + Args: + expressions_dict: {chat_id: [expressions]} + max_total: 最大合并数量 + + Returns: + 合并后的表达方式列表 + """ + all_expressions = [] + + # 收集所有表达方式 + for chat_id, expressions in expressions_dict.items(): + for expr in expressions: + # 添加source_id标识 + expr_with_source = expr.copy() + expr_with_source["source_id"] = chat_id + all_expressions.append(expr_with_source) + + # 按count或last_active_time排序 + if all_expressions and "count" in all_expressions[0]: + all_expressions.sort(key=lambda x: x.get("count", 0), reverse=True) + elif all_expressions and "last_active_time" in all_expressions[0]: + all_expressions.sort(key=lambda x: x.get("last_active_time", 0), reverse=True) + + # 
去重(基于situation和style) + all_expressions = batch_filter_duplicates(all_expressions, ["situation", "style"]) + + # 限制数量 + return all_expressions[:max_total] diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 0c25b9fc6..a3299bfa1 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -16,6 +16,9 @@ from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest +# 导入 StyleLearner 管理器 +from .style_learner import style_learner_manager + MAX_EXPRESSION_COUNT = 300 DECAY_DAYS = 30 # 30天衰减到0.01 DECAY_MIN = 0.01 # 最小衰减值 @@ -405,6 +408,29 @@ class ExpressionLearner: for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: await session.delete(expr) + # 🔥 新增:训练 StyleLearner + # 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型) + if type == "style": + try: + # 获取 StyleLearner 实例 + learner = style_learner_manager.get_learner(chat_id) + + # 为每个学习到的表达方式训练模型 + # 这里使用 situation 作为前置内容(context),style 作为目标风格 + for expr in expr_list: + situation = expr["situation"] + style = expr["style"] + + # 训练映射关系: situation -> style + learner.learn_mapping(situation, style) + + logger.debug(f"已将 {len(expr_list)} 个表达方式训练到 StyleLearner") + + # 保存模型 + learner.save(style_learner_manager.model_save_path) + except Exception as e: + logger.error(f"训练 StyleLearner 失败: {e}") + return learnt_expressions return None @@ -522,12 +548,12 @@ class ExpressionLearnerManager: os.path.join(base_dir, "learnt_grammar"), ] - try: - for directory in directories_to_create: + for directory in directories_to_create: + try: os.makedirs(directory, exist_ok=True) - logger.debug(f"确保目录存在: {directory}") - except Exception as e: - logger.error(f"创建目录失败 {directory}: {e}") + logger.debug(f"确保目录存在: {directory}") + except Exception as e: + logger.error(f"创建目录失败 {directory}: {e}") @staticmethod async def _auto_migrate_json_to_db(): diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index eee737f3e..dc511055a 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -15,6 +15,9 @@ from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest +# 导入StyleLearner管理器 +from .style_learner import style_learner_manager + logger = get_logger("expression_selector") @@ -236,6 +239,181 @@ class ExpressionSelector: ) await session.commit() + async def select_suitable_expressions( + self, + chat_id: str, + chat_history: list | str, + target_message: str | None = None, + max_num: int = 10, + min_num: int = 5, + ) -> list[dict[str, Any]]: + """ + 统一的表达方式选择入口,根据配置自动选择模式 + + Args: + chat_id: 聊天ID + chat_history: 聊天历史(列表或字符串) + target_message: 目标消息 + max_num: 最多返回数量 + min_num: 最少返回数量 + + Returns: + 选中的表达方式列表 + """ + # 转换chat_history为字符串 + if isinstance(chat_history, list): + chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history]) + else: + chat_info = chat_history + + # 根据配置选择模式 + mode = global_config.expression.mode + logger.debug(f"[ExpressionSelector] 使用模式: {mode}") + + if mode == "exp_model": + return await self._select_expressions_model_only( + chat_id=chat_id, + chat_info=chat_info, + target_message=target_message, + max_num=max_num, + min_num=min_num + ) + else: # classic mode + return await self._select_expressions_classic( + chat_id=chat_id, + chat_info=chat_info, + 
target_message=target_message, + max_num=max_num, + min_num=min_num + ) + + async def _select_expressions_classic( + self, + chat_id: str, + chat_info: str, + target_message: str | None = None, + max_num: int = 10, + min_num: int = 5, + ) -> list[dict[str, Any]]: + """经典模式:随机抽样 + LLM评估""" + logger.debug(f"[Classic模式] 使用LLM评估表达方式") + return await self.select_suitable_expressions_llm( + chat_id=chat_id, + chat_info=chat_info, + max_num=max_num, + min_num=min_num, + target_message=target_message + ) + + async def _select_expressions_model_only( + self, + chat_id: str, + chat_info: str, + target_message: str | None = None, + max_num: int = 10, + min_num: int = 5, + ) -> list[dict[str, Any]]: + """模型预测模式:使用StyleLearner预测最合适的表达风格""" + logger.debug(f"[Exp_model模式] 使用StyleLearner预测表达方式") + + # 检查是否允许在此聊天流中使用表达 + if not self.can_use_expression_for_chat(chat_id): + logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表") + return [] + + # 获取或创建StyleLearner实例 + learner = style_learner_manager.get_learner(chat_id) + + # 使用StyleLearner预测最合适的风格 + best_style, all_scores = learner.predict_style(chat_info, top_k=max_num) + + if not best_style or not all_scores: + logger.warning(f"StyleLearner未返回预测结果(可能模型未训练),回退到经典模式") + return await self._select_expressions_classic( + chat_id=chat_id, + chat_info=chat_info, + target_message=target_message, + max_num=max_num, + min_num=min_num + ) + + # 将分数字典转换为列表格式 [(style, score), ...] + predicted_styles = sorted(all_scores.items(), key=lambda x: x[1], reverse=True) + + # 根据预测的风格从数据库获取表达方式 + expressions = await self.get_model_predicted_expressions( + chat_id=chat_id, + predicted_styles=predicted_styles, + max_num=max_num + ) + + if not expressions: + logger.warning(f"未找到匹配预测风格的表达方式,回退到经典模式") + return await self._select_expressions_classic( + chat_id=chat_id, + chat_info=chat_info, + target_message=target_message, + max_num=max_num, + min_num=min_num + ) + + logger.debug(f"[Exp_model模式] 成功返回 {len(expressions)} 个表达方式") + return expressions + + async def get_model_predicted_expressions( + self, + chat_id: str, + predicted_styles: list[tuple[str, float]], + max_num: int = 10 + ) -> list[dict[str, Any]]: + """ + 根据StyleLearner预测的风格获取表达方式 + + Args: + chat_id: 聊天ID + predicted_styles: 预测的风格列表,格式: [(style, score), ...] 
+ max_num: 最多返回数量 + + Returns: + 表达方式列表 + """ + if not predicted_styles: + return [] + + # 提取风格名称(前3个最佳匹配) + style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]] + logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}") + + async with get_db_session() as session: + # 查询匹配这些风格的表达方式 + stmt = ( + select(Expression) + .where(Expression.chat_id == chat_id) + .where(Expression.style.in_(style_names)) + .order_by(Expression.count.desc()) + .limit(max_num) + ) + result = await session.execute(stmt) + expressions_objs = result.scalars().all() + + if not expressions_objs: + logger.debug(f"数据库中没有找到风格 {style_names} 的表达方式") + return [] + + # 转换为字典格式 + expressions = [] + for expr in expressions_objs: + expressions.append({ + "situation": expr.situation or "", + "style": expr.style or "", + "type": expr.type or "style", + "count": float(expr.count) if expr.count else 0.0, + "last_active_time": expr.last_active_time or 0.0 + }) + + logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式") + return expressions + async def select_suitable_expressions_llm( self, chat_id: str, diff --git a/src/chat/express/expressor_model/__init__.py b/src/chat/express/expressor_model/__init__.py new file mode 100644 index 000000000..a13656a85 --- /dev/null +++ b/src/chat/express/expressor_model/__init__.py @@ -0,0 +1,9 @@ +""" +表达模型包 +包含基于Online Naive Bayes的机器学习模型 +""" +from .model import ExpressorModel +from .online_nb import OnlineNaiveBayes +from .tokenizer import Tokenizer + +__all__ = ["ExpressorModel", "OnlineNaiveBayes", "Tokenizer"] diff --git a/src/chat/express/expressor_model/model.py b/src/chat/express/expressor_model/model.py new file mode 100644 index 000000000..8c18240a8 --- /dev/null +++ b/src/chat/express/expressor_model/model.py @@ -0,0 +1,216 @@ +""" +基于Online Naive Bayes的表达模型 +支持候选表达的动态添加和在线学习 +""" +import os +import pickle +from collections import Counter, defaultdict +from typing import Dict, Optional, Tuple + +from src.common.logger import get_logger + +from .online_nb import OnlineNaiveBayes +from .tokenizer import Tokenizer + +logger = get_logger("expressor.model") + + +class ExpressorModel: + """直接使用朴素贝叶斯精排(可在线学习)""" + + def __init__( + self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000, use_jieba: bool = True + ): + """ + Args: + alpha: 词频平滑参数 + beta: 类别先验平滑参数 + gamma: 衰减因子 + vocab_size: 词汇表大小 + use_jieba: 是否使用jieba分词 + """ + # 初始化分词器 + self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba) + + # 初始化在线朴素贝叶斯模型 + self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size) + + # 候选表达管理 + self._candidates: Dict[str, str] = {} # cid -> text (style) + self._situations: Dict[str, str] = {} # cid -> situation (不参与计算) + + logger.info( + f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})" + ) + + def add_candidate(self, cid: str, text: str, situation: Optional[str] = None): + """ + 添加候选文本和对应的situation + + Args: + cid: 候选ID + text: 表达文本 (style) + situation: 情境文本 + """ + self._candidates[cid] = text + if situation is not None: + self._situations[cid] = situation + + # 确保在nb模型中初始化该候选的计数 + if cid not in self.nb.cls_counts: + self.nb.cls_counts[cid] = 0.0 + if cid not in self.nb.token_counts: + self.nb.token_counts[cid] = defaultdict(float) + + def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]: + """ + 直接对所有候选进行朴素贝叶斯评分 + + Args: + text: 查询文本 + k: 返回前k个候选,如果为None则返回所有 + + 
Returns: + (最佳候选ID, 所有候选的分数字典) + """ + # 1. 分词 + toks = self.tokenizer.tokenize(text) + if not toks or not self._candidates: + return None, {} + + # 2. 计算词频 + tf = Counter(toks) + all_cids = list(self._candidates.keys()) + + # 3. 批量评分 + scores = self.nb.score_batch(tf, all_cids) + + if not scores: + return None, {} + + # 4. 根据k参数限制返回的候选数量 + if k is not None and k > 0: + sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) + limited_scores = dict(sorted_scores[:k]) + best = sorted_scores[0][0] if sorted_scores else None + return best, limited_scores + else: + best = max(scores.items(), key=lambda x: x[1])[0] + return best, scores + + def update_positive(self, text: str, cid: str): + """ + 更新正反馈学习 + + Args: + text: 输入文本 + cid: 目标类别ID + """ + toks = self.tokenizer.tokenize(text) + if not toks: + return + + tf = Counter(toks) + self.nb.update_positive(tf, cid) + + def decay(self, factor: Optional[float] = None): + """ + 应用知识衰减 + + Args: + factor: 衰减因子,如果为None则使用模型配置的gamma + """ + self.nb.decay(factor) + + def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]: + """ + 获取候选信息 + + Args: + cid: 候选ID + + Returns: + (style文本, situation文本) + """ + style = self._candidates.get(cid) + situation = self._situations.get(cid) + return style, situation + + def get_all_candidates(self) -> Dict[str, Tuple[str, str]]: + """ + 获取所有候选 + + Returns: + {cid: (style, situation)} + """ + result = {} + for cid in self._candidates.keys(): + style, situation = self.get_candidate_info(cid) + result[cid] = (style, situation) + return result + + def save(self, path: str): + """ + 保存模型到文件 + + Args: + path: 保存路径 + """ + os.makedirs(os.path.dirname(path), exist_ok=True) + + data = { + "candidates": self._candidates, + "situations": self._situations, + "nb_cls_counts": dict(self.nb.cls_counts), + "nb_token_counts": {k: dict(v) for k, v in self.nb.token_counts.items()}, + "nb_alpha": self.nb.alpha, + "nb_beta": self.nb.beta, + "nb_gamma": self.nb.gamma, + "nb_V": self.nb.V, + } + + with open(path, "wb") as f: + pickle.dump(data, f) + + logger.info(f"模型已保存到 {path}") + + def load(self, path: str): + """ + 从文件加载模型 + + Args: + path: 加载路径 + """ + if not os.path.exists(path): + logger.warning(f"模型文件不存在: {path}") + return + + with open(path, "rb") as f: + data = pickle.load(f) + + self._candidates = data["candidates"] + self._situations = data["situations"] + + # 恢复nb模型的参数 + self.nb.alpha = data["nb_alpha"] + self.nb.beta = data["nb_beta"] + self.nb.gamma = data["nb_gamma"] + self.nb.V = data["nb_V"] + + # 恢复统计数据 + self.nb.cls_counts = defaultdict(float, data["nb_cls_counts"]) + self.nb.token_counts = defaultdict(lambda: defaultdict(float)) + for cid, tc in data["nb_token_counts"].items(): + self.nb.token_counts[cid] = defaultdict(float, tc) + + logger.info(f"模型已从 {path} 加载") + + def get_stats(self) -> Dict: + """获取模型统计信息""" + nb_stats = self.nb.get_stats() + return { + "n_candidates": len(self._candidates), + "n_classes": nb_stats["n_classes"], + "n_tokens": nb_stats["n_tokens"], + "total_counts": nb_stats["total_counts"], + } diff --git a/src/chat/express/expressor_model/online_nb.py b/src/chat/express/expressor_model/online_nb.py new file mode 100644 index 000000000..39bd0d1cd --- /dev/null +++ b/src/chat/express/expressor_model/online_nb.py @@ -0,0 +1,142 @@ +""" +在线朴素贝叶斯分类器 +支持增量学习和知识衰减 +""" +import math +from collections import Counter, defaultdict +from typing import Dict, List, Optional + +from src.common.logger import get_logger + +logger = get_logger("expressor.online_nb") + + 
+class OnlineNaiveBayes: + """在线朴素贝叶斯分类器""" + + def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000): + """ + Args: + alpha: 词频平滑参数 + beta: 类别先验平滑参数 + gamma: 衰减因子 (0-1之间,1表示不衰减) + vocab_size: 词汇表大小 + """ + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.V = vocab_size + + # 类别统计 + self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count + self.token_counts: Dict[str, Dict[str, float]] = defaultdict( + lambda: defaultdict(float) + ) # cid -> term -> count + + # 缓存 + self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα) + + def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]: + """ + 批量计算候选的贝叶斯分数 + + Args: + tf: 查询文本的词频Counter + cids: 候选类别ID列表 + + Returns: + 每个候选的分数字典 + """ + total_cls = sum(self.cls_counts.values()) + n_cls = max(1, len(self.cls_counts)) + denom_prior = math.log(total_cls + self.beta * n_cls) + + out: Dict[str, float] = {} + for cid in cids: + # 计算先验概率 log P(c) + prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior + s = prior + + # 计算似然概率 log P(w|c) + logZ = self._logZ_c(cid) + tc = self.token_counts[cid] + + for term, qtf in tf.items(): + num = tc.get(term, 0.0) + self.alpha + s += qtf * (math.log(num) - logZ) + + out[cid] = s + return out + + def update_positive(self, tf: Counter, cid: str): + """ + 正反馈更新 + + Args: + tf: 词频Counter + cid: 类别ID + """ + inc = 0.0 + tc = self.token_counts[cid] + + # 更新词频统计 + for term, c in tf.items(): + tc[term] += float(c) + inc += float(c) + + # 更新类别统计 + self.cls_counts[cid] += inc + self._invalidate(cid) + + def decay(self, factor: Optional[float] = None): + """ + 知识衰减(遗忘机制) + + Args: + factor: 衰减因子,如果为None则使用self.gamma + """ + g = self.gamma if factor is None else factor + if g >= 1.0: + return + + # 对所有统计进行衰减 + for cid in list(self.cls_counts.keys()): + self.cls_counts[cid] *= g + for term in list(self.token_counts[cid].keys()): + self.token_counts[cid][term] *= g + self._invalidate(cid) + + logger.debug(f"应用知识衰减,衰减因子: {g}") + + def _logZ_c(self, cid: str) -> float: + """ + 计算归一化因子logZ + + Args: + cid: 类别ID + + Returns: + log(Z_c) + """ + if cid not in self._logZ: + Z = self.cls_counts[cid] + self.V * self.alpha + self._logZ[cid] = math.log(max(Z, 1e-12)) + return self._logZ[cid] + + def _invalidate(self, cid: str): + """ + 使缓存失效 + + Args: + cid: 类别ID + """ + if cid in self._logZ: + del self._logZ[cid] + + def get_stats(self) -> Dict: + """获取统计信息""" + return { + "n_classes": len(self.cls_counts), + "n_tokens": sum(len(tc) for tc in self.token_counts.values()), + "total_counts": sum(self.cls_counts.values()), + } diff --git a/src/chat/express/expressor_model/tokenizer.py b/src/chat/express/expressor_model/tokenizer.py new file mode 100644 index 000000000..e25f780d4 --- /dev/null +++ b/src/chat/express/expressor_model/tokenizer.py @@ -0,0 +1,62 @@ +""" +文本分词器,支持中文Jieba分词 +""" +from typing import List + +from src.common.logger import get_logger + +logger = get_logger("expressor.tokenizer") + + +class Tokenizer: + """文本分词器,支持中文Jieba分词""" + + def __init__(self, stopwords: set = None, use_jieba: bool = True): + """ + Args: + stopwords: 停用词集合 + use_jieba: 是否使用jieba分词 + """ + self.stopwords = stopwords or set() + self.use_jieba = use_jieba + + if use_jieba: + try: + import jieba + + jieba.initialize() + logger.info("Jieba分词器初始化成功") + except ImportError: + logger.warning("Jieba未安装,将使用字符级分词") + self.use_jieba = False + + def tokenize(self, text: str) -> List[str]: + """ + 分词并返回token列表 + + Args: + text: 输入文本 + + 
Returns: + token列表 + """ + if not text: + return [] + + # 使用jieba分词 + if self.use_jieba: + try: + import jieba + + tokens = list(jieba.cut(text)) + except Exception as e: + logger.warning(f"Jieba分词失败,使用字符级分词: {e}") + tokens = list(text) + else: + # 简单按字符分词 + tokens = list(text) + + # 过滤停用词和空字符串 + tokens = [token.strip() for token in tokens if token.strip() and token not in self.stopwords] + + return tokens diff --git a/src/chat/express/style_learner.py b/src/chat/express/style_learner.py new file mode 100644 index 000000000..fa0302fb1 --- /dev/null +++ b/src/chat/express/style_learner.py @@ -0,0 +1,405 @@ +""" +风格学习引擎 +基于ExpressorModel实现的表达风格学习和预测系统 +支持多聊天室独立建模和在线学习 +""" +import os +import time +from typing import Dict, List, Optional, Tuple + +from src.common.logger import get_logger + +from .expressor_model import ExpressorModel + +logger = get_logger("expressor.style_learner") + + +class StyleLearner: + """单个聊天室的表达风格学习器""" + + def __init__(self, chat_id: str, model_config: Optional[Dict] = None): + """ + Args: + chat_id: 聊天室ID + model_config: 模型配置 + """ + self.chat_id = chat_id + self.model_config = model_config or { + "alpha": 0.5, + "beta": 0.5, + "gamma": 0.99, # 衰减因子,支持遗忘 + "vocab_size": 200000, + "use_jieba": True, + } + + # 初始化表达模型 + self.expressor = ExpressorModel(**self.model_config) + + # 动态风格管理 + self.max_styles = 2000 # 每个chat_id最多2000个风格 + self.style_to_id: Dict[str, str] = {} # style文本 -> style_id + self.id_to_style: Dict[str, str] = {} # style_id -> style文本 + self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本 + self.next_style_id = 0 + + # 学习统计 + self.learning_stats = { + "total_samples": 0, + "style_counts": {}, + "last_update": time.time(), + } + + logger.info(f"StyleLearner初始化成功: chat_id={chat_id}") + + def add_style(self, style: str, situation: Optional[str] = None) -> bool: + """ + 动态添加一个新的风格 + + Args: + style: 风格文本 + situation: 情境文本 + + Returns: + 是否添加成功 + """ + try: + # 检查是否已存在 + if style in self.style_to_id: + return True + + # 检查是否超过最大限制 + if len(self.style_to_id) >= self.max_styles: + logger.warning(f"已达到最大风格数量限制 ({self.max_styles})") + return False + + # 生成新的style_id + style_id = f"style_{self.next_style_id}" + self.next_style_id += 1 + + # 添加到映射 + self.style_to_id[style] = style_id + self.id_to_style[style_id] = style + if situation: + self.id_to_situation[style_id] = situation + + # 添加到expressor模型 + self.expressor.add_candidate(style_id, style, situation) + + # 初始化统计 + self.learning_stats["style_counts"][style_id] = 0 + + logger.debug(f"添加风格成功: {style_id} -> {style}") + return True + + except Exception as e: + logger.error(f"添加风格失败: {e}") + return False + + def learn_mapping(self, up_content: str, style: str) -> bool: + """ + 学习一个up_content到style的映射 + + Args: + up_content: 前置内容 + style: 目标风格 + + Returns: + 是否学习成功 + """ + try: + # 如果style不存在,先添加它 + if style not in self.style_to_id: + if not self.add_style(style): + return False + + # 获取style_id + style_id = self.style_to_id[style] + + # 使用正反馈学习 + self.expressor.update_positive(up_content, style_id) + + # 更新统计 + self.learning_stats["total_samples"] += 1 + self.learning_stats["style_counts"][style_id] += 1 + self.learning_stats["last_update"] = time.time() + + logger.debug(f"学习映射成功: {up_content[:20]}... 
-> {style}") + return True + + except Exception as e: + logger.error(f"学习映射失败: {e}") + return False + + def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: + """ + 根据up_content预测最合适的style + + Args: + up_content: 前置内容 + top_k: 返回前k个候选 + + Returns: + (最佳style文本, 所有候选的分数字典) + """ + try: + best_style_id, scores = self.expressor.predict(up_content, k=top_k) + + if best_style_id is None: + return None, {} + + # 将style_id转换为style文本 + best_style = self.id_to_style.get(best_style_id) + + # 转换所有分数 + style_scores = {} + for sid, score in scores.items(): + style_text = self.id_to_style.get(sid) + if style_text: + style_scores[style_text] = score + + return best_style, style_scores + + except Exception as e: + logger.error(f"预测style失败: {e}") + return None, {} + + def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]: + """ + 获取style的完整信息 + + Args: + style: 风格文本 + + Returns: + (style_id, situation) + """ + style_id = self.style_to_id.get(style) + if not style_id: + return None, None + + situation = self.id_to_situation.get(style_id) + return style_id, situation + + def get_all_styles(self) -> List[str]: + """ + 获取所有风格列表 + + Returns: + 风格文本列表 + """ + return list(self.style_to_id.keys()) + + def apply_decay(self, factor: Optional[float] = None): + """ + 应用知识衰减 + + Args: + factor: 衰减因子 + """ + self.expressor.decay(factor) + logger.debug(f"应用知识衰减: chat_id={self.chat_id}") + + def save(self, base_path: str) -> bool: + """ + 保存学习器到文件 + + Args: + base_path: 基础保存路径 + + Returns: + 是否保存成功 + """ + try: + # 创建保存目录 + save_dir = os.path.join(base_path, self.chat_id) + os.makedirs(save_dir, exist_ok=True) + + # 保存expressor模型 + model_path = os.path.join(save_dir, "expressor_model.pkl") + self.expressor.save(model_path) + + # 保存映射关系和统计信息 + import pickle + + meta_path = os.path.join(save_dir, "meta.pkl") + meta_data = { + "style_to_id": self.style_to_id, + "id_to_style": self.id_to_style, + "id_to_situation": self.id_to_situation, + "next_style_id": self.next_style_id, + "learning_stats": self.learning_stats, + } + + with open(meta_path, "wb") as f: + pickle.dump(meta_data, f) + + logger.info(f"StyleLearner保存成功: {save_dir}") + return True + + except Exception as e: + logger.error(f"保存StyleLearner失败: {e}") + return False + + def load(self, base_path: str) -> bool: + """ + 从文件加载学习器 + + Args: + base_path: 基础加载路径 + + Returns: + 是否加载成功 + """ + try: + save_dir = os.path.join(base_path, self.chat_id) + + # 检查目录是否存在 + if not os.path.exists(save_dir): + logger.debug(f"StyleLearner保存目录不存在: {save_dir}") + return False + + # 加载expressor模型 + model_path = os.path.join(save_dir, "expressor_model.pkl") + if os.path.exists(model_path): + self.expressor.load(model_path) + + # 加载映射关系和统计信息 + import pickle + + meta_path = os.path.join(save_dir, "meta.pkl") + if os.path.exists(meta_path): + with open(meta_path, "rb") as f: + meta_data = pickle.load(f) + + self.style_to_id = meta_data["style_to_id"] + self.id_to_style = meta_data["id_to_style"] + self.id_to_situation = meta_data["id_to_situation"] + self.next_style_id = meta_data["next_style_id"] + self.learning_stats = meta_data["learning_stats"] + + logger.info(f"StyleLearner加载成功: {save_dir}") + return True + + except Exception as e: + logger.error(f"加载StyleLearner失败: {e}") + return False + + def get_stats(self) -> Dict: + """获取统计信息""" + model_stats = self.expressor.get_stats() + return { + "chat_id": self.chat_id, + "n_styles": len(self.style_to_id), + "total_samples": self.learning_stats["total_samples"], + 
"last_update": self.learning_stats["last_update"], + "model_stats": model_stats, + } + + +class StyleLearnerManager: + """多聊天室表达风格学习管理器""" + + def __init__(self, model_save_path: str = "data/expression/style_models"): + """ + Args: + model_save_path: 模型保存路径 + """ + self.learners: Dict[str, StyleLearner] = {} + self.model_save_path = model_save_path + + # 确保保存目录存在 + os.makedirs(model_save_path, exist_ok=True) + + logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}") + + def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner: + """ + 获取或创建指定chat_id的学习器 + + Args: + chat_id: 聊天室ID + model_config: 模型配置 + + Returns: + StyleLearner实例 + """ + if chat_id not in self.learners: + # 创建新的学习器 + learner = StyleLearner(chat_id, model_config) + + # 尝试加载已保存的模型 + learner.load(self.model_save_path) + + self.learners[chat_id] = learner + + return self.learners[chat_id] + + def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool: + """ + 学习一个映射关系 + + Args: + chat_id: 聊天室ID + up_content: 前置内容 + style: 目标风格 + + Returns: + 是否学习成功 + """ + learner = self.get_learner(chat_id) + return learner.learn_mapping(up_content, style) + + def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: + """ + 预测最合适的风格 + + Args: + chat_id: 聊天室ID + up_content: 前置内容 + top_k: 返回前k个候选 + + Returns: + (最佳style, 分数字典) + """ + learner = self.get_learner(chat_id) + return learner.predict_style(up_content, top_k) + + def save_all(self) -> bool: + """ + 保存所有学习器 + + Returns: + 是否全部保存成功 + """ + success = True + for chat_id, learner in self.learners.items(): + if not learner.save(self.model_save_path): + success = False + + logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}") + return success + + def apply_decay_all(self, factor: Optional[float] = None): + """ + 对所有学习器应用知识衰减 + + Args: + factor: 衰减因子 + """ + for learner in self.learners.values(): + learner.apply_decay(factor) + + logger.info(f"对所有StyleLearner应用知识衰减") + + def get_all_stats(self) -> Dict[str, Dict]: + """ + 获取所有学习器的统计信息 + + Returns: + {chat_id: stats} + """ + return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()} + + +# 全局单例 +style_learner_manager = StyleLearnerManager() diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 854ca615a..e15dab72a 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -5,6 +5,7 @@ from typing import Any from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.utils.timer_calculator import Timer +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config from src.person_info.person_info import get_person_info_manager @@ -142,7 +143,7 @@ class ChatterActionManager: self, action_name: str, chat_id: str, - target_message: dict | None = None, + target_message: dict | DatabaseMessages | None = None, reasoning: str = "", action_data: dict | None = None, thinking_id: str | None = None, @@ -262,9 +263,15 @@ class ChatterActionManager: from_plugin=False, ) if not success or not response_set: - logger.info( - f"对 {target_message.get('processed_plain_text') if target_message else '未知消息'} 的回复生成失败" - ) + # 安全地获取 processed_plain_text + if isinstance(target_message, DatabaseMessages): + msg_text = target_message.processed_plain_text or "未知消息" + elif target_message: + msg_text 
= target_message.get("processed_plain_text", "未知消息") + else: + msg_text = "未知消息" + + logger.info(f"对 {msg_text} 的回复生成失败") return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} except asyncio.CancelledError: logger.debug(f"{log_prefix} 并行执行:回复生成任务已被取消") @@ -322,8 +329,11 @@ class ChatterActionManager: # 获取目标消息ID target_message_id = None - if target_message and isinstance(target_message, dict): - target_message_id = target_message.get("message_id") + if target_message: + if isinstance(target_message, DatabaseMessages): + target_message_id = target_message.message_id + elif isinstance(target_message, dict): + target_message_id = target_message.get("message_id") elif action_data and isinstance(action_data, dict): target_message_id = action_data.get("target_message_id") @@ -488,14 +498,19 @@ class ChatterActionManager: person_info_manager = get_person_info_manager() # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 - platform = action_message.get("chat_info_platform") - if platform is None: - platform = getattr(chat_stream, "platform", "unknown") + if isinstance(action_message, DatabaseMessages): + platform = action_message.chat_info.platform + user_id = action_message.user_info.user_id + else: + platform = action_message.get("chat_info_platform") + if platform is None: + platform = getattr(chat_stream, "platform", "unknown") + user_id = action_message.get("user_id", "") # 获取用户信息并生成回复提示 person_id = person_info_manager.get_person_id( platform, - action_message.get("user_id", ""), + user_id, ) person_name = await person_info_manager.get_value(person_id, "person_name") action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" @@ -565,7 +580,14 @@ class ChatterActionManager: # 根据新消息数量决定是否需要引用回复 reply_text = "" - is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True + # 检查是否为主动思考消息 + if isinstance(message_data, DatabaseMessages): + # DatabaseMessages 对象没有 message_type 字段,默认为 False + is_proactive_thinking = False + elif message_data: + is_proactive_thinking = message_data.get("message_type") == "proactive_thinking" + else: + is_proactive_thinking = True logger.debug(f"[send_response] message_data: {message_data}") diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 6578fd215..3a90d36a0 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -27,6 +27,7 @@ from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.prompt_params import PromptParameters from src.chat.utils.timer_calculator import Timer from src.chat.utils.utils import get_chat_type_and_target_info +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config, model_config from src.individuality.individuality import get_individuality @@ -474,10 +475,13 @@ class DefaultReplyer: style_habits = [] grammar_habits = [] - # 使用从处理器传来的选中表达方式 - # LLM模式:调用LLM选择5-10个,然后随机选5个 - selected_expressions = await expression_selector.select_suitable_expressions_llm( - self.chat_stream.stream_id, chat_history, max_num=8, min_num=2, target_message=target + # 使用统一的表达方式选择入口(支持classic和exp_model模式) + selected_expressions = await expression_selector.select_suitable_expressions( + chat_id=self.chat_stream.stream_id, + chat_history=chat_history, + target_message=target, + max_num=8, + min_num=2 ) if selected_expressions: @@ -1206,7 +1210,7 @@ class DefaultReplyer: 
extra_info: str = "", available_actions: dict[str, ActionInfo] | None = None, enable_tool: bool = True, - reply_message: dict[str, Any] | None = None, + reply_message: dict[str, Any] | DatabaseMessages | None = None, ) -> str: """ 构建回复器上下文 @@ -1248,10 +1252,24 @@ class DefaultReplyer: if reply_message is None: logger.warning("reply_message 为 None,无法构建prompt") return "" - platform = reply_message.get("chat_info_platform") + + # 统一处理 DatabaseMessages 对象和字典 + if isinstance(reply_message, DatabaseMessages): + platform = reply_message.chat_info.platform + user_id = reply_message.user_info.user_id + user_nickname = reply_message.user_info.user_nickname + user_cardname = reply_message.user_info.user_cardname + processed_plain_text = reply_message.processed_plain_text + else: + platform = reply_message.get("chat_info_platform") + user_id = reply_message.get("user_id") + user_nickname = reply_message.get("user_nickname") + user_cardname = reply_message.get("user_cardname") + processed_plain_text = reply_message.get("processed_plain_text") + person_id = person_info_manager.get_person_id( platform, # type: ignore - reply_message.get("user_id"), # type: ignore + user_id, # type: ignore ) person_name = await person_info_manager.get_value(person_id, "person_name") @@ -1260,22 +1278,22 @@ class DefaultReplyer: # 尝试从reply_message获取用户名 await person_info_manager.first_knowing_some_one( platform, # type: ignore - reply_message.get("user_id"), # type: ignore - reply_message.get("user_nickname") or "", - reply_message.get("user_cardname") or "", + user_id, # type: ignore + user_nickname or "", + user_cardname or "", ) # 检查是否是bot自己的名字,如果是则替换为"(你)" bot_user_id = str(global_config.bot.qq_account) current_user_id = await person_info_manager.get_value(person_id, "user_id") - current_platform = reply_message.get("chat_info_platform") + current_platform = platform if current_user_id == bot_user_id and current_platform == global_config.bot.platform: sender = f"{person_name}(你)" else: # 如果不是bot自己,直接使用person_name sender = person_name - target = reply_message.get("processed_plain_text") + target = processed_plain_text # 最终的空值检查,确保sender和target不为None if sender is None: @@ -1609,15 +1627,22 @@ class DefaultReplyer: raw_reply: str, reason: str, reply_to: str, - reply_message: dict[str, Any] | None = None, + reply_message: dict[str, Any] | DatabaseMessages | None = None, ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if chat_stream = self.chat_stream chat_id = chat_stream.stream_id is_group_chat = bool(chat_stream.group_info) if reply_message: - sender = reply_message.get("sender") - target = reply_message.get("target") + if isinstance(reply_message, DatabaseMessages): + # 从 DatabaseMessages 对象获取 sender 和 target + # 注意: DatabaseMessages 没有直接的 sender/target 字段 + # 需要根据实际情况构造 + sender = reply_message.user_info.user_nickname or reply_message.user_info.user_id + target = reply_message.processed_plain_text or "" + else: + sender = reply_message.get("sender") + target = reply_message.get("target") else: sender, target = self._parse_reply_target(reply_to) diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 09c0dad95..0d83f7cfd 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -606,11 +606,11 @@ class Prompt: recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True ) - # 使用LLM选择与当前情景匹配的表达习惯 + # 使用统一的表达方式选择入口(支持classic和exp_model模式) expression_selector = ExpressionSelector(self.parameters.chat_id) - selected_expressions = await 
expression_selector.select_suitable_expressions_llm( + selected_expressions = await expression_selector.select_suitable_expressions( chat_id=self.parameters.chat_id, - chat_info=chat_history, + chat_history=chat_history, target_message=self.parameters.target, ) diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 0ef26f33c..3d1529420 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -183,6 +183,10 @@ class ExpressionRule(ValidatedConfigBase): class ExpressionConfig(ValidatedConfigBase): """表达配置类""" + mode: Literal["classic", "exp_model"] = Field( + default="classic", + description="表达方式选择模式: classic=经典LLM评估, exp_model=机器学习模型预测" + ) rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则") @staticmethod diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index 7b0dca370..c612fab48 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -107,10 +107,13 @@ class PromptBuilder: style_habits = [] grammar_habits = [] - # 使用从处理器传来的选中表达方式 - # LLM模式:调用LLM选择5-10个,然后随机选5个 - selected_expressions = await expression_selector.select_suitable_expressions_llm( - chat_stream.stream_id, chat_history, max_num=12, min_num=5, target_message=target + # 使用统一的表达方式选择入口(支持classic和exp_model模式) + selected_expressions = await expression_selector.select_suitable_expressions( + chat_id=chat_stream.stream_id, + chat_history=chat_history, + target_message=target, + max_num=12, + min_num=5 ) if selected_expressions: diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 5790d2312..b5071e578 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING from src.chat.message_receive.chat_stream import ChatStream +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.plugin_system.apis import database_api, message_api, send_api from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ChatMode, ChatType, ComponentType @@ -180,11 +181,18 @@ class BaseAction(ABC): if self.has_action_message: if self.action_name != "no_reply": - self.group_id = str(self.action_message.get("chat_info_group_id", None)) - self.group_name = self.action_message.get("chat_info_group_name", None) - - self.user_id = str(self.action_message.get("user_id", None)) - self.user_nickname = self.action_message.get("user_nickname", None) + # 统一处理 DatabaseMessages 对象和字典 + if isinstance(self.action_message, DatabaseMessages): + self.group_id = str(self.action_message.group_info.group_id if self.action_message.group_info else None) + self.group_name = self.action_message.group_info.group_name if self.action_message.group_info else None + self.user_id = str(self.action_message.user_info.user_id) + self.user_nickname = self.action_message.user_info.user_nickname + else: + self.group_id = str(self.action_message.get("chat_info_group_id", None)) + self.group_name = self.action_message.get("chat_info_group_name", None) + self.user_id = str(self.action_message.get("user_id", None)) + self.user_nickname = self.action_message.get("user_nickname", None) + if self.group_id: self.is_group = True self.target_id = self.group_id diff --git a/src/plugins/built_in/social_toolkit_plugin/plugin.py b/src/plugins/built_in/social_toolkit_plugin/plugin.py index 
05005c173..8d75ca2fd 100644 --- a/src/plugins/built_in/social_toolkit_plugin/plugin.py +++ b/src/plugins/built_in/social_toolkit_plugin/plugin.py @@ -6,6 +6,7 @@ from typing import ClassVar from dateutil.parser import parse as parse_datetime from src.chat.message_receive.chat_stream import ChatStream +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.manager.async_task_manager import AsyncTask, async_task_manager from src.person_info.person_info import get_person_info_manager @@ -253,19 +254,19 @@ class SetEmojiLikeAction(BaseAction): message_id = None set_like = self.action_data.get("set", True) - if self.has_action_message and isinstance(self.action_message, dict): - message_id = self.action_message.get("message_id") - logger.info(f"获取到的消息ID: {message_id}") - else: + if self.has_action_message: + if isinstance(self.action_message, DatabaseMessages): + message_id = self.action_message.message_id + logger.info(f"获取到的消息ID: {message_id}") + elif isinstance(self.action_message, dict): + message_id = self.action_message.get("message_id") + logger.info(f"获取到的消息ID: {message_id}") + + if not message_id: logger.error("未提供有效的消息或消息ID") await self.store_action_info(action_prompt_display="贴表情失败: 未提供消息ID", action_done=False) return False, "未提供消息ID" - if not message_id: - logger.error("消息ID为空") - await self.store_action_info(action_prompt_display="贴表情失败: 消息ID为空", action_done=False) - return False, "消息ID为空" - available_models = llm_api.get_available_models() if "utils_small" not in available_models: logger.error("未找到 'utils_small' 模型配置,无法选择表情") @@ -273,7 +274,12 @@ class SetEmojiLikeAction(BaseAction): model_to_use = available_models["utils_small"] - context_text = self.action_message.get("processed_plain_text", "") + # 统一处理 DatabaseMessages 和字典 + if isinstance(self.action_message, DatabaseMessages): + context_text = self.action_message.processed_plain_text or "" + else: + context_text = self.action_message.get("processed_plain_text", "") + if not context_text: logger.error("无法找到动作选择的原始消息文本") return False, "无法找到动作选择的原始消息文本" diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index cb3813945..ddfdada13 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.5.1" +version = "7.5.2" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -92,6 +92,11 @@ compress_identity = true # 是否压缩身份,压缩后会精简身份信息 [expression] # 表达学习配置 +# mode: 表达方式模式,可选: +# - "classic": 经典模式,随机抽样 + LLM选择 +# - "exp_model": 表达模型模式,使用机器学习模型预测最合适的表达 +mode = "classic" + # rules是一个列表,每个元素都是一个学习规则 # chat_stream_id: 聊天流ID,格式为 "platform:id:type",例如 "qq:123456:private"。空字符串""表示全局配置 # use_expression: 是否使用学到的表达 (true/false)
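
---

Usage sketch (not part of the patch): a minimal illustration of how the unified selection entry point and the StyleLearner manager added in this diff are expected to fit together. The names `ExpressionSelector.select_suitable_expressions`, `style_learner_manager.learn_mapping` / `predict_style`, and the `expression.mode` switch all come from the diff above; the concrete `chat_id` and the sample messages are hypothetical placeholders, and the snippet assumes a configured bot environment (database, global_config) is already available.

```python
# Illustrative only -- assumes the modules introduced in this patch are importable
# and that the runtime config/database the selector depends on is set up.
import asyncio

from src.chat.express.expression_selector import ExpressionSelector
from src.chat.express.style_learner import style_learner_manager


async def demo():
    chat_id = "qq:123456:group"  # hypothetical chat stream id

    # 1. Online training: feed a learnt (situation -> style) pair into the model.
    #    In the patch this happens inside ExpressionLearner when type == "style".
    style_learner_manager.learn_mapping(chat_id, "对方在开玩笑", "用夸张的语气回应")

    # 2. Direct prediction: when expression.mode = "exp_model", the selector calls
    #    predict_style() and maps the top-k styles back to Expression rows,
    #    falling back to the classic LLM path if the model has no training data.
    best_style, scores = style_learner_manager.predict_style(chat_id, "大家在聊游戏", top_k=5)
    print(best_style, scores)

    # 3. Unified entry point used by the replyer and prompt builders; the mode
    #    (classic vs exp_model) is resolved internally from global_config.expression.mode.
    selector = ExpressionSelector(chat_id)
    expressions = await selector.select_suitable_expressions(
        chat_id=chat_id,
        chat_history="甲: 今晚打游戏吗\n乙: 好啊",
        target_message="今晚打游戏吗",
        max_num=8,
        min_num=2,
    )
    print(expressions)


asyncio.run(demo())
```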