From cfa642cf0af34c1e62ca6fb5106041a538d322c1 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 30 Oct 2025 11:16:30 +0800 Subject: [PATCH] =?UTF-8?q?feat(expression):=20=E5=A2=9E=E5=BC=BA=E8=A1=A8?= =?UTF-8?q?=E8=BE=BE=E5=AD=A6=E4=B9=A0=E4=B8=8E=E9=80=89=E6=8B=A9=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F=E7=9A=84=E5=81=A5=E5=A3=AE=E6=80=A7=E5=92=8C=E6=99=BA?= =?UTF-8?q?=E8=83=BD=E5=8C=B9=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 改进表达学习器的提示词格式规范,增强LLM输出解析的容错性 - 优化表达选择器的模型预测模式,添加情境提取和模糊匹配机制 - 增强StyleLearner的错误处理和日志记录,提高训练和预测的稳定性 - 改进流循环管理器的日志输出,避免重复信息刷屏 - 扩展SendAPI的消息查找功能,支持DatabaseMessages对象兼容 - 添加智能回退机制,当模型预测失败时自动切换到经典模式 - 优化数据库查询逻辑,支持跨聊天流的表达方式共享 BREAKING CHANGE: 表达选择器的模型预测模式现在需要情境提取器配合使用,旧版本配置可能需要更新依赖关系 --- scripts/check_expression_database.py | 116 ++++++++++++ scripts/check_style_field.py | 65 +++++++ scripts/debug_style_learner.py | 88 +++++++++ src/chat/express/expression_learner.py | 178 +++++++++++++++--- src/chat/express/expression_selector.py | 172 +++++++++++++---- src/chat/express/situation_extractor.py | 162 ++++++++++++++++ src/chat/express/style_learner.py | 22 ++- .../message_manager/distribution_manager.py | 12 +- src/plugin_system/apis/send_api.py | 63 +++++-- 9 files changed, 795 insertions(+), 83 deletions(-) create mode 100644 scripts/check_expression_database.py create mode 100644 scripts/check_style_field.py create mode 100644 scripts/debug_style_learner.py create mode 100644 src/chat/express/situation_extractor.py diff --git a/scripts/check_expression_database.py b/scripts/check_expression_database.py new file mode 100644 index 000000000..f600cc434 --- /dev/null +++ b/scripts/check_expression_database.py @@ -0,0 +1,116 @@ +""" +检查表达方式数据库状态的诊断脚本 +""" +import asyncio +import sys +from pathlib import Path + +# 添加项目根目录到路径 +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from sqlalchemy import select, func +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import Expression + + +async def check_database(): + """检查表达方式数据库状态""" + + print("=" * 60) + print("表达方式数据库诊断报告") + print("=" * 60) + + async with get_db_session() as session: + # 1. 统计总数 + total_count = await session.execute(select(func.count()).select_from(Expression)) + total = total_count.scalar() + print(f"\n📊 总表达方式数量: {total}") + + if total == 0: + print("\n⚠️ 数据库为空!") + print("\n可能的原因:") + print("1. 还没有进行过表达学习") + print("2. 配置中禁用了表达学习") + print("3. 学习过程中发生了错误") + print("\n建议:") + print("- 检查 bot_config.toml 中的 [expression] 配置") + print("- 查看日志中是否有表达学习相关的错误") + print("- 确认聊天流的 learn_expression 配置为 true") + return + + # 2. 按 chat_id 统计 + print("\n📝 按聊天流统计:") + chat_counts = await session.execute( + select(Expression.chat_id, func.count()) + .group_by(Expression.chat_id) + ) + for chat_id, count in chat_counts: + print(f" - {chat_id}: {count} 个表达方式") + + # 3. 按 type 统计 + print("\n📝 按类型统计:") + type_counts = await session.execute( + select(Expression.type, func.count()) + .group_by(Expression.type) + ) + for expr_type, count in type_counts: + print(f" - {expr_type}: {count} 个") + + # 4. 检查 situation 和 style 字段是否有空值 + print("\n🔍 字段完整性检查:") + null_situation = await session.execute( + select(func.count()) + .select_from(Expression) + .where(Expression.situation == None) + ) + null_style = await session.execute( + select(func.count()) + .select_from(Expression) + .where(Expression.style == None) + ) + + null_sit_count = null_situation.scalar() + null_sty_count = null_style.scalar() + + print(f" - situation 为空: {null_sit_count} 个") + print(f" - style 为空: {null_sty_count} 个") + + if null_sit_count > 0 or null_sty_count > 0: + print(" ⚠️ 发现空值!这会导致匹配失败") + + # 5. 显示一些样例数据 + print("\n📋 样例数据 (前10条):") + samples = await session.execute( + select(Expression) + .limit(10) + ) + + for i, expr in enumerate(samples.scalars(), 1): + print(f"\n [{i}] Chat: {expr.chat_id}") + print(f" Type: {expr.type}") + print(f" Situation: {expr.situation}") + print(f" Style: {expr.style}") + print(f" Count: {expr.count}") + + # 6. 检查 style 字段的唯一值 + print("\n📋 Style 字段样例 (前20个):") + unique_styles = await session.execute( + select(Expression.style) + .distinct() + .limit(20) + ) + + styles = [s for s in unique_styles.scalars()] + for style in styles: + print(f" - {style}") + + print(f"\n (共 {len(styles)} 个不同的 style)") + + print("\n" + "=" * 60) + print("诊断完成") + print("=" * 60) + + +if __name__ == "__main__": + asyncio.run(check_database()) diff --git a/scripts/check_style_field.py b/scripts/check_style_field.py new file mode 100644 index 000000000..c8f5ef1fb --- /dev/null +++ b/scripts/check_style_field.py @@ -0,0 +1,65 @@ +""" +检查数据库中 style 字段的内容特征 +""" +import asyncio +import sys +from pathlib import Path + +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from sqlalchemy import select +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import Expression + + +async def analyze_style_fields(): + """分析 style 字段的内容""" + + print("=" * 60) + print("Style 字段内容分析") + print("=" * 60) + + async with get_db_session() as session: + # 获取所有表达方式 + result = await session.execute(select(Expression).limit(30)) + expressions = result.scalars().all() + + print(f"\n总共检查 {len(expressions)} 条记录\n") + + # 按类型分类 + style_examples = [] + + for expr in expressions: + if expr.type == "style": + style_examples.append({ + "situation": expr.situation, + "style": expr.style, + "length": len(expr.style) if expr.style else 0 + }) + + print("📋 Style 类型样例 (前15条):") + print("="*60) + for i, ex in enumerate(style_examples[:15], 1): + print(f"\n[{i}]") + print(f" Situation: {ex['situation']}") + print(f" Style: {ex['style']}") + print(f" 长度: {ex['length']} 字符") + + # 判断是具体表达还是风格描述 + if ex['length'] <= 20 and any(word in ex['style'] for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']): + style_type = "✓ 风格描述" + elif ex['length'] <= 10: + style_type = "? 可能是具体表达(较短)" + else: + style_type = "✗ 具体表达内容" + + print(f" 类型判断: {style_type}") + + print("\n" + "="*60) + print("分析完成") + print("="*60) + + +if __name__ == "__main__": + asyncio.run(analyze_style_fields()) diff --git a/scripts/debug_style_learner.py b/scripts/debug_style_learner.py new file mode 100644 index 000000000..970ba2532 --- /dev/null +++ b/scripts/debug_style_learner.py @@ -0,0 +1,88 @@ +""" +检查 StyleLearner 模型状态的诊断脚本 +""" +import sys +from pathlib import Path + +# 添加项目根目录到路径 +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from src.chat.express.style_learner import style_learner_manager +from src.common.logger import get_logger + +logger = get_logger("debug_style_learner") + + +def check_style_learner_status(chat_id: str): + """检查指定 chat_id 的 StyleLearner 状态""" + + print("=" * 60) + print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}") + print("=" * 60) + + # 获取 learner + learner = style_learner_manager.get_learner(chat_id) + + # 1. 基本信息 + print(f"\n📊 基本信息:") + print(f" Chat ID: {learner.chat_id}") + print(f" 风格数量: {len(learner.style_to_id)}") + print(f" 下一个ID: {learner.next_style_id}") + print(f" 最大风格数: {learner.max_styles}") + + # 2. 学习统计 + print(f"\n📈 学习统计:") + print(f" 总样本数: {learner.learning_stats['total_samples']}") + print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}") + + # 3. 风格列表(前20个) + print(f"\n📋 已学习的风格 (前20个):") + all_styles = learner.get_all_styles() + if not all_styles: + print(" ⚠️ 没有任何风格!模型尚未训练") + else: + for i, style in enumerate(all_styles[:20], 1): + style_id = learner.style_to_id.get(style) + situation = learner.id_to_situation.get(style_id, "N/A") + print(f" [{i}] {style}") + print(f" (ID: {style_id}, Situation: {situation})") + + # 4. 测试预测 + print(f"\n🔮 测试预测功能:") + if not all_styles: + print(" ⚠️ 无法测试,模型没有训练数据") + else: + test_situations = [ + "表示惊讶", + "讨论游戏", + "表达赞同" + ] + + for test_sit in test_situations: + print(f"\n 测试输入: '{test_sit}'") + best_style, scores = learner.predict_style(test_sit, top_k=3) + + if best_style: + print(f" ✓ 最佳匹配: {best_style}") + print(f" Top 3:") + for style, score in list(scores.items())[:3]: + print(f" - {style}: {score:.4f}") + else: + print(f" ✗ 预测失败") + + print("\n" + "=" * 60) + print("诊断完成") + print("=" * 60) + + +if __name__ == "__main__": + # 从诊断报告中看到的 chat_id + test_chat_ids = [ + "52fb94af9f500a01e023ea780e43606e", # 有78个表达方式 + "46c8714c8a9b7ee169941fe99fcde07d", # 有22个表达方式 + ] + + for chat_id in test_chat_ids: + check_style_learner_status(chat_id) + print("\n") diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index a3299bfa1..75864be40 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -46,17 +46,29 @@ def init_prompt() -> None: 3. 语言风格包含特殊内容和情感 4. 思考有没有特殊的梗,一并总结成语言风格 5. 例子仅供参考,请严格根据群聊内容总结!!! -注意:总结成如下格式的规律,总结的内容要详细,但具有概括性: -例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。 + +**重要:必须严格按照以下格式输出,每行一条规律:** +当"xxx"时,使用"xxx" + +格式说明: +- 必须以"当"开头 +- 场景描述用双引号包裹,不超过20个字 +- 必须包含"使用"或"可以" +- 表达风格用双引号包裹,不超过20个字 +- 每条规律独占一行 例如: 当"对某件事表示十分惊叹,有些意外"时,使用"我嘞个xxxx" 当"表示讽刺的赞同,不想讲道理"时,使用"对对对" -当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂" -当"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!" +当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契"时,使用"懂的都懂" +当"涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!" -请注意:不要总结你自己(SELF)的发言 -现在请你概括 +注意: +1. 不要总结你自己(SELF)的发言 +2. 如果聊天内容中没有明显的特殊风格,请只输出1-2条最明显的特点 +3. 不要输出其他解释性文字,只输出符合格式的规律 + +现在请你概括: """ Prompt(learn_style_prompt, "learn_style_prompt") @@ -68,16 +80,28 @@ def init_prompt() -> None: 2.不要涉及具体的人名,只考虑语法和句法特点, 3.语法和句法特点要包括,句子长短(具体字数),有何种语病,如何拆分句子。 4. 例子仅供参考,请严格根据群聊内容总结!!! -总结成如下格式的规律,总结的内容要简洁,不浮夸: -当"xxx"时,可以"xxx" + +**重要:必须严格按照以下格式输出,每行一条规律:** +当"xxx"时,使用"xxx" + +格式说明: +- 必须以"当"开头 +- 场景描述用双引号包裹 +- 必须包含"使用"或"可以" +- 句法特点用双引号包裹 +- 每条规律独占一行 例如: 当"表达观点较复杂"时,使用"省略主语(3-6个字)"的句法 当"不用详细说明的一般表达"时,使用"非常简洁的句子"的句法 当"需要单纯简单的确认"时,使用"单字或几个字的肯定(1-2个字)"的句法 -注意不要总结你自己(SELF)的发言 -现在请你概括 +注意: +1. 不要总结你自己(SELF)的发言 +2. 如果聊天内容中没有明显的句法特点,请只输出1-2条最明显的特点 +3. 不要输出其他解释性文字,只输出符合格式的规律 + +现在请你概括: """ Prompt(learn_grammar_prompt, "learn_grammar_prompt") @@ -408,28 +432,43 @@ class ExpressionLearner: for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: await session.delete(expr) - # 🔥 新增:训练 StyleLearner + # 🔥 训练 StyleLearner # 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型) if type == "style": try: # 获取 StyleLearner 实例 learner = style_learner_manager.get_learner(chat_id) + logger.info(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}") + # 为每个学习到的表达方式训练模型 - # 这里使用 situation 作为前置内容(context),style 作为目标风格 + # 使用 situation 作为输入,style 作为目标 + # 这是最符合语义的方式:场景 -> 表达方式 + success_count = 0 for expr in expr_list: situation = expr["situation"] style = expr["style"] # 训练映射关系: situation -> style - learner.learn_mapping(situation, style) + if learner.learn_mapping(situation, style): + success_count += 1 + else: + logger.warning(f"训练失败: {situation} -> {style}") - logger.debug(f"已将 {len(expr_list)} 个表达方式训练到 StyleLearner") + logger.info( + f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, " + f"当前风格总数={len(learner.get_all_styles())}, " + f"总样本数={learner.learning_stats['total_samples']}" + ) # 保存模型 - learner.save(style_learner_manager.model_save_path) + if learner.save(style_learner_manager.model_save_path): + logger.info(f"StyleLearner 模型保存成功: {chat_id}") + else: + logger.error(f"StyleLearner 模型保存失败: {chat_id}") + except Exception as e: - logger.error(f"训练 StyleLearner 失败: {e}") + logger.error(f"训练 StyleLearner 失败: {e}", exc_info=True) return learnt_expressions return None @@ -481,9 +520,17 @@ class ExpressionLearner: logger.error(f"学习{type_str}失败: {e}") return None + if not response or not response.strip(): + logger.warning(f"LLM返回空响应,无法学习{type_str}") + return None + logger.debug(f"学习{type_str}的response: {response}") expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id) + + if not expressions: + logger.warning(f"从LLM响应中未能解析出任何{type_str}。请检查LLM输出格式是否正确。") + logger.info(f"LLM完整响应:\n{response}") return expressions, chat_id @@ -491,31 +538,100 @@ class ExpressionLearner: def parse_expression_response(response: str, chat_id: str) -> list[tuple[str, str, str]]: """ 解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组 + 支持多种引号格式:"" 和 "" """ expressions: list[tuple[str, str, str]] = [] - for line in response.splitlines(): + failed_lines = [] + + for line_num, line in enumerate(response.splitlines(), 1): line = line.strip() if not line: continue + + # 替换中文引号为英文引号,便于统一处理 + line_normalized = line.replace('"', '"').replace('"', '"').replace("'", '"').replace("'", '"') + # 查找"当"和下一个引号 - idx_when = line.find('当"') + idx_when = line_normalized.find('当"') if idx_when == -1: - continue - idx_quote1 = idx_when + 1 - idx_quote2 = line.find('"', idx_quote1 + 1) - if idx_quote2 == -1: - continue - situation = line[idx_quote1 + 1 : idx_quote2] - # 查找"使用" - idx_use = line.find('使用"', idx_quote2) + # 尝试不带引号的格式: 当xxx时 + idx_when = line_normalized.find('当') + if idx_when == -1: + failed_lines.append((line_num, line, "找不到'当'关键字")) + continue + + # 提取"当"和"时"之间的内容 + idx_shi = line_normalized.find('时', idx_when) + if idx_shi == -1: + failed_lines.append((line_num, line, "找不到'时'关键字")) + continue + situation = line_normalized[idx_when + 1:idx_shi].strip('"\'""') + search_start = idx_shi + else: + idx_quote1 = idx_when + 1 + idx_quote2 = line_normalized.find('"', idx_quote1 + 1) + if idx_quote2 == -1: + failed_lines.append((line_num, line, "situation部分引号不匹配")) + continue + situation = line_normalized[idx_quote1 + 1 : idx_quote2] + search_start = idx_quote2 + + # 查找"使用"或"可以" + idx_use = line_normalized.find('使用"', search_start) if idx_use == -1: + idx_use = line_normalized.find('可以"', search_start) + if idx_use == -1: + # 尝试不带引号的格式 + idx_use = line_normalized.find('使用', search_start) + if idx_use == -1: + idx_use = line_normalized.find('可以', search_start) + if idx_use == -1: + failed_lines.append((line_num, line, "找不到'使用'或'可以'关键字")) + continue + + # 提取剩余部分作为style + style = line_normalized[idx_use + 2:].strip('"\'"",。') + if not style: + failed_lines.append((line_num, line, "style部分为空")) + continue + else: + idx_quote3 = idx_use + 2 + idx_quote4 = line_normalized.find('"', idx_quote3 + 1) + if idx_quote4 == -1: + # 如果没有结束引号,取到行尾 + style = line_normalized[idx_quote3 + 1:].strip('"\'""') + else: + style = line_normalized[idx_quote3 + 1 : idx_quote4] + else: + idx_quote3 = idx_use + 2 + idx_quote4 = line_normalized.find('"', idx_quote3 + 1) + if idx_quote4 == -1: + # 如果没有结束引号,取到行尾 + style = line_normalized[idx_quote3 + 1:].strip('"\'""') + else: + style = line_normalized[idx_quote3 + 1 : idx_quote4] + + # 清理并验证 + situation = situation.strip() + style = style.strip() + + if not situation or not style: + failed_lines.append((line_num, line, f"situation或style为空: situation='{situation}', style='{style}'")) continue - idx_quote3 = idx_use + 2 - idx_quote4 = line.find('"', idx_quote3 + 1) - if idx_quote4 == -1: - continue - style = line[idx_quote3 + 1 : idx_quote4] + expressions.append((chat_id, situation, style)) + + # 记录解析失败的行 + if failed_lines: + logger.warning(f"解析表达方式时有 {len(failed_lines)} 行失败:") + for line_num, line, reason in failed_lines[:5]: # 只显示前5个 + logger.warning(f" 行{line_num}: {reason}") + logger.debug(f" 原文: {line}") + + if not expressions: + logger.warning(f"LLM返回了内容但无法解析任何表达方式。响应预览:\n{response[:500]}") + else: + logger.debug(f"成功解析 {len(expressions)} 个表达方式") return expressions diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index dc511055a..1dbf7e08e 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -15,7 +15,8 @@ from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -# 导入StyleLearner管理器 +# 导入StyleLearner管理器和情境提取器 +from .situation_extractor import situation_extractor from .style_learner import style_learner_manager logger = get_logger("expression_selector") @@ -130,17 +131,18 @@ class ExpressionSelector: current_group = rule.group break - if not current_group: - return [chat_id] + # 🔥 始终包含当前 chat_id(确保至少能查到自己的数据) + related_chat_ids = [chat_id] - # 找出同一组的所有chat_id - related_chat_ids = [] - for rule in rules: - if rule.group == current_group and rule.chat_stream_id: - if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id): - related_chat_ids.append(chat_id_candidate) + if current_group: + # 找出同一组的所有chat_id + for rule in rules: + if rule.group == current_group and rule.chat_stream_id: + if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id): + if chat_id_candidate not in related_chat_ids: + related_chat_ids.append(chat_id_candidate) - return related_chat_ids if related_chat_ids else [chat_id] + return related_chat_ids async def get_random_expressions( self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float @@ -313,22 +315,52 @@ class ExpressionSelector: max_num: int = 10, min_num: int = 5, ) -> list[dict[str, Any]]: - """模型预测模式:使用StyleLearner预测最合适的表达风格""" - logger.debug(f"[Exp_model模式] 使用StyleLearner预测表达方式") + """模型预测模式:先提取情境,再使用StyleLearner预测表达风格""" + logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式") # 检查是否允许在此聊天流中使用表达 if not self.can_use_expression_for_chat(chat_id): logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表") return [] - # 获取或创建StyleLearner实例 + # 步骤1: 提取聊天情境 + situations = await situation_extractor.extract_situations( + chat_history=chat_info, + target_message=target_message, + max_situations=3 + ) + + if not situations: + logger.warning(f"无法提取聊天情境,回退到经典模式") + return await self._select_expressions_classic( + chat_id=chat_id, + chat_info=chat_info, + target_message=target_message, + max_num=max_num, + min_num=min_num + ) + + logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}") + + # 步骤2: 使用 StyleLearner 为每个情境预测合适的表达方式 learner = style_learner_manager.get_learner(chat_id) - # 使用StyleLearner预测最合适的风格 - best_style, all_scores = learner.predict_style(chat_info, top_k=max_num) + all_predicted_styles = {} + for i, situation in enumerate(situations, 1): + logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}") + best_style, scores = learner.predict_style(situation, top_k=max_num) + + if best_style and scores: + logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}") + # 合并分数(取最高分) + for style, score in scores.items(): + if style not in all_predicted_styles or score > all_predicted_styles[style]: + all_predicted_styles[style] = score + else: + logger.debug(f" 该情境未返回预测结果") - if not best_style or not all_scores: - logger.warning(f"StyleLearner未返回预测结果(可能模型未训练),回退到经典模式") + if not all_predicted_styles: + logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式") return await self._select_expressions_classic( chat_id=chat_id, chat_info=chat_info, @@ -338,9 +370,12 @@ class ExpressionSelector: ) # 将分数字典转换为列表格式 [(style, score), ...] - predicted_styles = sorted(all_scores.items(), key=lambda x: x[1], reverse=True) + predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True) - # 根据预测的风格从数据库获取表达方式 + logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}") + + # 步骤3: 根据预测的风格从数据库获取表达方式 + logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式") expressions = await self.get_model_predicted_expressions( chat_id=chat_id, predicted_styles=predicted_styles, @@ -348,7 +383,7 @@ class ExpressionSelector: ) if not expressions: - logger.warning(f"未找到匹配预测风格的表达方式,回退到经典模式") + logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式") return await self._select_expressions_classic( chat_id=chat_id, chat_info=chat_info, @@ -357,7 +392,7 @@ class ExpressionSelector: min_num=min_num ) - logger.debug(f"[Exp_model模式] 成功返回 {len(expressions)} 个表达方式") + logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式") return expressions async def get_model_predicted_expressions( @@ -384,22 +419,95 @@ class ExpressionSelector: style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]] logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}") + # 🔥 使用 get_related_chat_ids 获取所有相关的 chat_id(支持共享表达方式) + related_chat_ids = self.get_related_chat_ids(chat_id) + logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}") + async with get_db_session() as session: - # 查询匹配这些风格的表达方式 - stmt = ( - select(Expression) - .where(Expression.chat_id == chat_id) - .where(Expression.style.in_(style_names)) - .order_by(Expression.count.desc()) - .limit(max_num) + # 🔍 先检查数据库中实际有哪些 chat_id 的数据 + db_chat_ids_result = await session.execute( + select(Expression.chat_id) + .where(Expression.type == "style") + .distinct() ) - result = await session.execute(stmt) - expressions_objs = result.scalars().all() + db_chat_ids = [cid for cid in db_chat_ids_result.scalars()] + logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}") - if not expressions_objs: - logger.debug(f"数据库中没有找到风格 {style_names} 的表达方式") + # 获取所有相关 chat_id 的表达方式(用于模糊匹配) + all_expressions_result = await session.execute( + select(Expression) + .where(Expression.chat_id.in_(related_chat_ids)) + .where(Expression.type == "style") + ) + all_expressions = list(all_expressions_result.scalars()) + + logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}") + + # 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id + if not all_expressions: + logger.info(f"相关chat_id没有数据,尝试从所有chat_id查询") + all_expressions_result = await session.execute( + select(Expression) + .where(Expression.type == "style") + ) + all_expressions = list(all_expressions_result.scalars()) + logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}") + + if not all_expressions: + logger.warning(f"数据库中完全没有任何表达方式,需要先学习") return [] + # 🔥 使用模糊匹配而不是精确匹配 + # 计算每个预测style与数据库style的相似度 + from difflib import SequenceMatcher + + matched_expressions = [] + for expr in all_expressions: + db_style = expr.style or "" + max_similarity = 0.0 + best_predicted = "" + + # 与每个预测的style计算相似度 + for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测 + # 计算字符串相似度 + similarity = SequenceMatcher(None, predicted_style, db_style).ratio() + + # 也检查包含关系(如果一个是另一个的子串,给更高分) + if len(predicted_style) >= 2 and len(db_style) >= 2: + if predicted_style in db_style or db_style in predicted_style: + similarity = max(similarity, 0.7) + + if similarity > max_similarity: + max_similarity = similarity + best_predicted = predicted_style + + # 🔥 降低阈值到30%,因为StyleLearner预测质量较差 + if max_similarity >= 0.3: # 30%相似度阈值 + matched_expressions.append((expr, max_similarity, expr.count, best_predicted)) + + if not matched_expressions: + # 收集数据库中的style样例用于调试 + all_styles = [e.style for e in all_expressions[:10]] + logger.warning( + f"数据库中没有找到匹配的表达方式(相似度阈值30%):\n" + f" 预测的style (前3个): {style_names}\n" + f" 数据库中存在的style样例: {all_styles}\n" + f" 提示: StyleLearner预测质量差,建议重新训练或使用classic模式" + ) + return [] + + # 按照相似度*count排序,选择最佳匹配 + matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True) + expressions_objs = [e[0] for e in matched_expressions[:max_num]] + + # 显示最佳匹配的详细信息 + top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]] + logger.info( + f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式\n" + f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n" + f" Top3匹配: {top_matches}" + ) + # 转换为字典格式 expressions = [] for expr in expressions_objs: diff --git a/src/chat/express/situation_extractor.py b/src/chat/express/situation_extractor.py new file mode 100644 index 000000000..8ebe0a8bd --- /dev/null +++ b/src/chat/express/situation_extractor.py @@ -0,0 +1,162 @@ +""" +情境提取器 +从聊天历史中提取当前的情境(situation),用于 StyleLearner 预测 +""" +from typing import Optional + +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest + +logger = get_logger("situation_extractor") + + +def init_prompt(): + situation_extraction_prompt = """ +以下是正在进行的聊天内容: +{chat_history} + +你的名字是{bot_name}{target_message_info} + +请分析当前聊天的情境特征,提取出最能描述当前情境的1-3个关键场景描述。 + +场景描述应该: +1. 简洁明了(每个不超过20个字) +2. 聚焦情绪、话题、氛围 +3. 不涉及具体人名 +4. 类似于"表示惊讶"、"讨论游戏"、"表达赞同"这样的格式 + +请以纯文本格式输出,每行一个场景描述,不要有序号、引号或其他格式: + +例如: +表示惊讶和意外 +讨论技术问题 +表达友好的赞同 + +现在请提取当前聊天的情境: +""" + Prompt(situation_extraction_prompt, "situation_extraction_prompt") + + +class SituationExtractor: + """情境提取器,从聊天历史中提取当前情境""" + + def __init__(self): + self.llm_model = LLMRequest( + model_set=model_config.model_task_config.utils_small, + request_type="expression.situation_extractor" + ) + + async def extract_situations( + self, + chat_history: list | str, + target_message: Optional[str] = None, + max_situations: int = 3 + ) -> list[str]: + """ + 从聊天历史中提取情境 + + Args: + chat_history: 聊天历史(列表或字符串) + target_message: 目标消息(可选) + max_situations: 最多提取的情境数量 + + Returns: + 情境描述列表 + """ + # 转换chat_history为字符串 + if isinstance(chat_history, list): + chat_info = "\n".join([ + f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" + for msg in chat_history + ]) + else: + chat_info = chat_history + + # 构建目标消息信息 + if target_message: + target_message_info = f",现在你想要回复消息:{target_message}" + else: + target_message_info = "" + + # 构建 prompt + try: + prompt = (await global_prompt_manager.get_prompt_async("situation_extraction_prompt")).format( + bot_name=global_config.bot.nickname, + chat_history=chat_info, + target_message_info=target_message_info + ) + + # 调用 LLM + response, _ = await self.llm_model.generate_response_async( + prompt=prompt, + temperature=0.3 + ) + + if not response or not response.strip(): + logger.warning("LLM返回空响应,无法提取情境") + return [] + + # 解析响应 + situations = self._parse_situations(response, max_situations) + + if situations: + logger.debug(f"提取到 {len(situations)} 个情境: {situations}") + else: + logger.warning(f"无法从LLM响应中解析出情境。响应:\n{response}") + + return situations + + except Exception as e: + logger.error(f"提取情境失败: {e}") + return [] + + @staticmethod + def _parse_situations(response: str, max_situations: int) -> list[str]: + """ + 解析 LLM 返回的情境描述 + + Args: + response: LLM 响应 + max_situations: 最多返回的情境数量 + + Returns: + 情境描述列表 + """ + situations = [] + + for line in response.splitlines(): + line = line.strip() + if not line: + continue + + # 移除可能的序号、引号等 + line = line.lstrip('0123456789.、-*>))】] \t"\'""''') + line = line.rstrip('"\'""''') + line = line.strip() + + if not line: + continue + + # 过滤掉明显不是情境描述的内容 + if len(line) > 30: # 太长 + continue + if len(line) < 2: # 太短 + continue + if any(keyword in line.lower() for keyword in ['例如', '注意', '请', '分析', '总结']): + continue + + situations.append(line) + + if len(situations) >= max_situations: + break + + return situations + + +# 初始化 prompt +init_prompt() + +# 全局单例 +situation_extractor = SituationExtractor() diff --git a/src/chat/express/style_learner.py b/src/chat/express/style_learner.py index fa0302fb1..c254ef98c 100644 --- a/src/chat/express/style_learner.py +++ b/src/chat/express/style_learner.py @@ -142,13 +142,26 @@ class StyleLearner: (最佳style文本, 所有候选的分数字典) """ try: + # 先检查是否有训练数据 + if not self.style_to_id: + logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}") + return None, {} + best_style_id, scores = self.expressor.predict(up_content, k=top_k) if best_style_id is None: + logger.debug(f"ExpressorModel未返回预测结果: chat_id={self.chat_id}, up_content={up_content[:50]}...") return None, {} # 将style_id转换为style文本 best_style = self.id_to_style.get(best_style_id) + + if best_style is None: + logger.warning( + f"style_id无法转换为style文本: style_id={best_style_id}, " + f"已知的id_to_style数量={len(self.id_to_style)}" + ) + return None, {} # 转换所有分数 style_scores = {} @@ -156,11 +169,18 @@ class StyleLearner: style_text = self.id_to_style.get(sid) if style_text: style_scores[style_text] = score + else: + logger.warning(f"跳过无法转换的style_id: {sid}") + + logger.debug( + f"预测成功: up_content={up_content[:30]}..., " + f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}" + ) return best_style, style_scores except Exception as e: - logger.error(f"预测style失败: {e}") + logger.error(f"预测style失败: {e}", exc_info=True) return None, {} def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]: diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index 5bf13081f..c111bf8b4 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -46,6 +46,9 @@ class StreamLoopManager: # 状态控制 self.is_running = False + # 每个流的上一次间隔值(用于日志去重) + self._last_intervals: dict[str, float] = {} + logger.info(f"流循环管理器初始化完成 (最大并发流数: {self.max_concurrent_streams})") async def start(self) -> None: @@ -285,7 +288,11 @@ class StreamLoopManager: interval = await self._calculate_interval(stream_id, has_messages) # 6. sleep等待下次检查 - logger.info(f"流 {stream_id} 等待 {interval:.2f}s") + # 只在间隔发生变化时输出日志,避免刷屏 + last_interval = self._last_intervals.get(stream_id) + if last_interval is None or abs(interval - last_interval) > 0.01: + logger.info(f"流 {stream_id} 等待周期变化: {interval:.2f}s") + self._last_intervals[stream_id] = interval await asyncio.sleep(interval) except asyncio.CancelledError: @@ -316,6 +323,9 @@ class StreamLoopManager: except Exception as e: logger.debug(f"释放自适应流处理槽位失败: {e}") + # 清理间隔记录 + self._last_intervals.pop(stream_id, None) + logger.info(f"流循环结束: {stream_id}") async def _get_stream_context(self, stream_id: str) -> Any | None: diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index feb18848b..191bbc16d 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -108,52 +108,79 @@ def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | """查找要回复的消息 Args: - message_dict: 消息字典 + message_dict: 消息字典或 DatabaseMessages 对象 Returns: Optional[MessageRecv]: 找到的消息,如果没找到则返回None """ + # 兼容 DatabaseMessages 对象和字典 + if isinstance(message_dict, dict): + user_platform = message_dict.get("user_platform", "") + user_id = message_dict.get("user_id", "") + user_nickname = message_dict.get("user_nickname", "") + user_cardname = message_dict.get("user_cardname", "") + chat_info_group_id = message_dict.get("chat_info_group_id") + chat_info_group_platform = message_dict.get("chat_info_group_platform", "") + chat_info_group_name = message_dict.get("chat_info_group_name", "") + chat_info_platform = message_dict.get("chat_info_platform", "") + message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id") + time_val = message_dict.get("time") + additional_config = message_dict.get("additional_config") + processed_plain_text = message_dict.get("processed_plain_text") + else: + # DatabaseMessages 对象 + user_platform = getattr(message_dict, "user_platform", "") + user_id = getattr(message_dict, "user_id", "") + user_nickname = getattr(message_dict, "user_nickname", "") + user_cardname = getattr(message_dict, "user_cardname", "") + chat_info_group_id = getattr(message_dict, "chat_info_group_id", None) + chat_info_group_platform = getattr(message_dict, "chat_info_group_platform", "") + chat_info_group_name = getattr(message_dict, "chat_info_group_name", "") + chat_info_platform = getattr(message_dict, "chat_info_platform", "") + message_id = getattr(message_dict, "message_id", None) + time_val = getattr(message_dict, "time", None) + additional_config = getattr(message_dict, "additional_config", None) + processed_plain_text = getattr(message_dict, "processed_plain_text", "") + # 构建MessageRecv对象 user_info = { - "platform": message_dict.get("user_platform", ""), - "user_id": message_dict.get("user_id", ""), - "user_nickname": message_dict.get("user_nickname", ""), - "user_cardname": message_dict.get("user_cardname", ""), + "platform": user_platform, + "user_id": user_id, + "user_nickname": user_nickname, + "user_cardname": user_cardname, } group_info = {} - if message_dict.get("chat_info_group_id"): + if chat_info_group_id: group_info = { - "platform": message_dict.get("chat_info_group_platform", ""), - "group_id": message_dict.get("chat_info_group_id", ""), - "group_name": message_dict.get("chat_info_group_name", ""), + "platform": chat_info_group_platform, + "group_id": chat_info_group_id, + "group_name": chat_info_group_name, } format_info = {"content_format": "", "accept_format": ""} template_info = {"template_items": {}} message_info = { - "platform": message_dict.get("chat_info_platform", ""), - "message_id": message_dict.get("message_id") - or message_dict.get("chat_info_message_id") - or message_dict.get("id"), - "time": message_dict.get("time"), + "platform": chat_info_platform, + "message_id": message_id, + "time": time_val, "group_info": group_info, "user_info": user_info, - "additional_config": message_dict.get("additional_config"), + "additional_config": additional_config, "format_info": format_info, "template_info": template_info, } new_message_dict = { "message_info": message_info, - "raw_message": message_dict.get("processed_plain_text"), - "processed_plain_text": message_dict.get("processed_plain_text"), + "raw_message": processed_plain_text, + "processed_plain_text": processed_plain_text, } message_recv = MessageRecv(new_message_dict) - logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}") + logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {user_nickname}") return message_recv