diff --git a/integration_test_relationship_tools.py b/integration_test_relationship_tools.py new file mode 100644 index 000000000..a2ac3a7fa --- /dev/null +++ b/integration_test_relationship_tools.py @@ -0,0 +1,303 @@ +""" +关系追踪工具集成测试脚本 + +注意:此脚本需要在完整的应用环境中运行 +建议通过 bot.py 启动后在交互式环境中测试 +""" + +import asyncio + + +async def test_user_profile_tool(): + """测试用户画像工具""" + print("\n" + "=" * 80) + print("测试 UserProfileTool") + print("=" * 80) + + from src.plugins.built_in.affinity_flow_chatter.user_profile_tool import UserProfileTool + from src.common.database.sqlalchemy_database_api import db_query + from src.common.database.sqlalchemy_models import UserRelationships + + tool = UserProfileTool() + print(f"✅ 工具名称: {tool.name}") + print(f" 工具描述: {tool.description}") + + # 执行工具 + test_user_id = "integration_test_user_001" + result = await tool.execute({ + "target_user_id": test_user_id, + "user_aliases": "测试小明,TestMing,小明君", + "impression_description": "这是一个集成测试用户,性格开朗活泼,喜欢技术讨论,对AI和编程特别感兴趣。经常提出有深度的问题。", + "preference_keywords": "AI,Python,深度学习,游戏开发,科幻小说", + "affection_score": 0.85 + }) + + print(f"\n✅ 工具执行结果:") + print(f" 类型: {result.get('type')}") + print(f" 内容: {result.get('content')}") + + # 验证数据库 + db_data = await db_query( + UserRelationships, + filters={"user_id": test_user_id}, + limit=1 + ) + + if db_data: + data = db_data[0] + print(f"\n✅ 数据库验证:") + print(f" user_id: {data.get('user_id')}") + print(f" user_aliases: {data.get('user_aliases')}") + print(f" relationship_text: {data.get('relationship_text', '')[:80]}...") + print(f" preference_keywords: {data.get('preference_keywords')}") + print(f" relationship_score: {data.get('relationship_score')}") + return True + else: + print(f"\n❌ 数据库中未找到数据") + return False + + +async def test_chat_stream_impression_tool(): + """测试聊天流印象工具""" + print("\n" + "=" * 80) + print("测试 ChatStreamImpressionTool") + print("=" * 80) + + from src.plugins.built_in.affinity_flow_chatter.chat_stream_impression_tool import ChatStreamImpressionTool + from src.common.database.sqlalchemy_database_api import db_query + from src.common.database.sqlalchemy_models import ChatStreams, get_db_session + + # 准备测试数据:先创建一条 ChatStreams 记录 + test_stream_id = "integration_test_stream_001" + print(f"🔧 准备测试数据:创建聊天流记录 {test_stream_id}") + + import time + current_time = time.time() + + async with get_db_session() as session: + new_stream = ChatStreams( + stream_id=test_stream_id, + create_time=current_time, + last_active_time=current_time, + platform="QQ", + user_platform="QQ", + user_id="test_user_123", + user_nickname="测试用户", + group_name="测试技术交流群", + group_platform="QQ", + group_id="test_group_456", + stream_impression_text="", # 初始为空 + stream_chat_style="", + stream_topic_keywords="", + stream_interest_score=0.5 + ) + session.add(new_stream) + await session.commit() + print(f"✅ 测试聊天流记录已创建") + + tool = ChatStreamImpressionTool() + print(f"✅ 工具名称: {tool.name}") + print(f" 工具描述: {tool.description}") + + # 执行工具 + result = await tool.execute({ + "stream_id": test_stream_id, + "impression_description": "这是一个技术交流群,成员主要是程序员和AI爱好者。大家经常分享最新的技术文章,讨论编程问题,氛围友好且专业。", + "chat_style": "专业技术交流,活跃讨论,互帮互助,知识分享", + "topic_keywords": "Python开发,机器学习,AI应用,Web后端,数据分析,开源项目", + "interest_score": 0.90 + }) + + print(f"\n✅ 工具执行结果:") + print(f" 类型: {result.get('type')}") + print(f" 内容: {result.get('content')}") + + # 验证数据库 + db_data = await db_query( + ChatStreams, + filters={"stream_id": test_stream_id}, + limit=1 + ) + + if db_data: + data = db_data[0] + print(f"\n✅ 数据库验证:") + print(f" stream_id: {data.get('stream_id')}") + print(f" stream_impression_text: {data.get('stream_impression_text', '')[:80]}...") + print(f" stream_chat_style: {data.get('stream_chat_style')}") + print(f" stream_topic_keywords: {data.get('stream_topic_keywords')}") + print(f" stream_interest_score: {data.get('stream_interest_score')}") + return True + else: + print(f"\n❌ 数据库中未找到数据") + return False + + +async def test_relationship_info_build(): + """测试关系信息构建""" + print("\n" + "=" * 80) + print("测试关系信息构建(提示词集成)") + print("=" * 80) + + from src.person_info.relationship_fetcher import relationship_fetcher_manager + + test_stream_id = "integration_test_stream_001" + test_person_id = "test_person_999" # 使用一个可能不存在的ID来测试 + + fetcher = relationship_fetcher_manager.get_fetcher(test_stream_id) + print(f"✅ RelationshipFetcher 已创建") + + # 测试聊天流印象构建 + print(f"\n🔍 构建聊天流印象...") + stream_info = await fetcher.build_chat_stream_impression(test_stream_id) + + if stream_info: + print(f"✅ 聊天流印象构建成功") + print(f"\n{'=' * 80}") + print(stream_info) + print(f"{'=' * 80}") + else: + print(f"⚠️ 聊天流印象为空(可能测试数据不存在)") + + return True + + +async def cleanup_test_data(): + """清理测试数据""" + print("\n" + "=" * 80) + print("清理测试数据") + print("=" * 80) + + from src.common.database.sqlalchemy_database_api import db_query + from src.common.database.sqlalchemy_models import UserRelationships, ChatStreams + + try: + # 清理用户数据 + await db_query( + UserRelationships, + query_type="delete", + filters={"user_id": "integration_test_user_001"} + ) + print("✅ 用户测试数据已清理") + + # 清理聊天流数据 + await db_query( + ChatStreams, + query_type="delete", + filters={"stream_id": "integration_test_stream_001"} + ) + print("✅ 聊天流测试数据已清理") + + return True + except Exception as e: + print(f"⚠️ 清理失败: {e}") + return False + + +async def run_all_tests(): + """运行所有测试""" + print("\n" + "=" * 80) + print("关系追踪工具集成测试") + print("=" * 80) + + results = {} + + # 测试1 + try: + results["UserProfileTool"] = await test_user_profile_tool() + except Exception as e: + print(f"\n❌ UserProfileTool 测试失败: {e}") + import traceback + traceback.print_exc() + results["UserProfileTool"] = False + + # 测试2 + try: + results["ChatStreamImpressionTool"] = await test_chat_stream_impression_tool() + except Exception as e: + print(f"\n❌ ChatStreamImpressionTool 测试失败: {e}") + import traceback + traceback.print_exc() + results["ChatStreamImpressionTool"] = False + + # 测试3 + try: + results["RelationshipFetcher"] = await test_relationship_info_build() + except Exception as e: + print(f"\n❌ RelationshipFetcher 测试失败: {e}") + import traceback + traceback.print_exc() + results["RelationshipFetcher"] = False + + # 清理 + try: + await cleanup_test_data() + except Exception as e: + print(f"\n⚠️ 清理测试数据失败: {e}") + + # 总结 + print("\n" + "=" * 80) + print("测试总结") + print("=" * 80) + + passed = sum(1 for r in results.values() if r) + total = len(results) + + for test_name, result in results.items(): + status = "✅ 通过" if result else "❌ 失败" + print(f"{status} - {test_name}") + + print(f"\n总计: {passed}/{total} 测试通过") + + if passed == total: + print("\n🎉 所有测试通过!") + else: + print(f"\n⚠️ {total - passed} 个测试失败") + + return passed == total + + +# 使用说明 +print(""" +============================================================================ +关系追踪工具集成测试脚本 +============================================================================ + +此脚本需要在完整的应用环境中运行。 + +使用方法1: 在 bot.py 中添加测试调用 +----------------------------------- +在 bot.py 的 main() 函数中添加: + + # 测试关系追踪工具 + from tests.integration_test_relationship_tools import run_all_tests + await run_all_tests() + +使用方法2: 在 Python REPL 中运行 +----------------------------------- +启动 bot.py 后,在 Python 调试控制台中执行: + + import asyncio + from tests.integration_test_relationship_tools import run_all_tests + asyncio.create_task(run_all_tests()) + +使用方法3: 直接在此文件底部运行 +----------------------------------- +取消注释下面的代码,然后确保已启动应用环境 +============================================================================ +""") + + +# 如果需要直接运行(需要应用环境已启动) +if __name__ == "__main__": + print("\n⚠️ 警告: 直接运行此脚本可能会失败,因为缺少应用环境") + print("建议在 bot.py 启动后的环境中运行\n") + + try: + asyncio.run(run_all_tests()) + except Exception as e: + print(f"\n❌ 测试失败: {e}") + print("\n建议:") + print("1. 确保已启动 bot.py") + print("2. 在 Python 调试控制台中运行测试") + print("3. 或在 bot.py 中添加测试调用") diff --git a/scripts/check_expression_database.py b/scripts/check_expression_database.py new file mode 100644 index 000000000..f600cc434 --- /dev/null +++ b/scripts/check_expression_database.py @@ -0,0 +1,116 @@ +""" +检查表达方式数据库状态的诊断脚本 +""" +import asyncio +import sys +from pathlib import Path + +# 添加项目根目录到路径 +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from sqlalchemy import select, func +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import Expression + + +async def check_database(): + """检查表达方式数据库状态""" + + print("=" * 60) + print("表达方式数据库诊断报告") + print("=" * 60) + + async with get_db_session() as session: + # 1. 统计总数 + total_count = await session.execute(select(func.count()).select_from(Expression)) + total = total_count.scalar() + print(f"\n📊 总表达方式数量: {total}") + + if total == 0: + print("\n⚠️ 数据库为空!") + print("\n可能的原因:") + print("1. 还没有进行过表达学习") + print("2. 配置中禁用了表达学习") + print("3. 学习过程中发生了错误") + print("\n建议:") + print("- 检查 bot_config.toml 中的 [expression] 配置") + print("- 查看日志中是否有表达学习相关的错误") + print("- 确认聊天流的 learn_expression 配置为 true") + return + + # 2. 按 chat_id 统计 + print("\n📝 按聊天流统计:") + chat_counts = await session.execute( + select(Expression.chat_id, func.count()) + .group_by(Expression.chat_id) + ) + for chat_id, count in chat_counts: + print(f" - {chat_id}: {count} 个表达方式") + + # 3. 按 type 统计 + print("\n📝 按类型统计:") + type_counts = await session.execute( + select(Expression.type, func.count()) + .group_by(Expression.type) + ) + for expr_type, count in type_counts: + print(f" - {expr_type}: {count} 个") + + # 4. 检查 situation 和 style 字段是否有空值 + print("\n🔍 字段完整性检查:") + null_situation = await session.execute( + select(func.count()) + .select_from(Expression) + .where(Expression.situation == None) + ) + null_style = await session.execute( + select(func.count()) + .select_from(Expression) + .where(Expression.style == None) + ) + + null_sit_count = null_situation.scalar() + null_sty_count = null_style.scalar() + + print(f" - situation 为空: {null_sit_count} 个") + print(f" - style 为空: {null_sty_count} 个") + + if null_sit_count > 0 or null_sty_count > 0: + print(" ⚠️ 发现空值!这会导致匹配失败") + + # 5. 显示一些样例数据 + print("\n📋 样例数据 (前10条):") + samples = await session.execute( + select(Expression) + .limit(10) + ) + + for i, expr in enumerate(samples.scalars(), 1): + print(f"\n [{i}] Chat: {expr.chat_id}") + print(f" Type: {expr.type}") + print(f" Situation: {expr.situation}") + print(f" Style: {expr.style}") + print(f" Count: {expr.count}") + + # 6. 检查 style 字段的唯一值 + print("\n📋 Style 字段样例 (前20个):") + unique_styles = await session.execute( + select(Expression.style) + .distinct() + .limit(20) + ) + + styles = [s for s in unique_styles.scalars()] + for style in styles: + print(f" - {style}") + + print(f"\n (共 {len(styles)} 个不同的 style)") + + print("\n" + "=" * 60) + print("诊断完成") + print("=" * 60) + + +if __name__ == "__main__": + asyncio.run(check_database()) diff --git a/scripts/check_style_field.py b/scripts/check_style_field.py new file mode 100644 index 000000000..c8f5ef1fb --- /dev/null +++ b/scripts/check_style_field.py @@ -0,0 +1,65 @@ +""" +检查数据库中 style 字段的内容特征 +""" +import asyncio +import sys +from pathlib import Path + +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from sqlalchemy import select +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import Expression + + +async def analyze_style_fields(): + """分析 style 字段的内容""" + + print("=" * 60) + print("Style 字段内容分析") + print("=" * 60) + + async with get_db_session() as session: + # 获取所有表达方式 + result = await session.execute(select(Expression).limit(30)) + expressions = result.scalars().all() + + print(f"\n总共检查 {len(expressions)} 条记录\n") + + # 按类型分类 + style_examples = [] + + for expr in expressions: + if expr.type == "style": + style_examples.append({ + "situation": expr.situation, + "style": expr.style, + "length": len(expr.style) if expr.style else 0 + }) + + print("📋 Style 类型样例 (前15条):") + print("="*60) + for i, ex in enumerate(style_examples[:15], 1): + print(f"\n[{i}]") + print(f" Situation: {ex['situation']}") + print(f" Style: {ex['style']}") + print(f" 长度: {ex['length']} 字符") + + # 判断是具体表达还是风格描述 + if ex['length'] <= 20 and any(word in ex['style'] for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']): + style_type = "✓ 风格描述" + elif ex['length'] <= 10: + style_type = "? 可能是具体表达(较短)" + else: + style_type = "✗ 具体表达内容" + + print(f" 类型判断: {style_type}") + + print("\n" + "="*60) + print("分析完成") + print("="*60) + + +if __name__ == "__main__": + asyncio.run(analyze_style_fields()) diff --git a/scripts/debug_style_learner.py b/scripts/debug_style_learner.py new file mode 100644 index 000000000..970ba2532 --- /dev/null +++ b/scripts/debug_style_learner.py @@ -0,0 +1,88 @@ +""" +检查 StyleLearner 模型状态的诊断脚本 +""" +import sys +from pathlib import Path + +# 添加项目根目录到路径 +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from src.chat.express.style_learner import style_learner_manager +from src.common.logger import get_logger + +logger = get_logger("debug_style_learner") + + +def check_style_learner_status(chat_id: str): + """检查指定 chat_id 的 StyleLearner 状态""" + + print("=" * 60) + print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}") + print("=" * 60) + + # 获取 learner + learner = style_learner_manager.get_learner(chat_id) + + # 1. 基本信息 + print(f"\n📊 基本信息:") + print(f" Chat ID: {learner.chat_id}") + print(f" 风格数量: {len(learner.style_to_id)}") + print(f" 下一个ID: {learner.next_style_id}") + print(f" 最大风格数: {learner.max_styles}") + + # 2. 学习统计 + print(f"\n📈 学习统计:") + print(f" 总样本数: {learner.learning_stats['total_samples']}") + print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}") + + # 3. 风格列表(前20个) + print(f"\n📋 已学习的风格 (前20个):") + all_styles = learner.get_all_styles() + if not all_styles: + print(" ⚠️ 没有任何风格!模型尚未训练") + else: + for i, style in enumerate(all_styles[:20], 1): + style_id = learner.style_to_id.get(style) + situation = learner.id_to_situation.get(style_id, "N/A") + print(f" [{i}] {style}") + print(f" (ID: {style_id}, Situation: {situation})") + + # 4. 测试预测 + print(f"\n🔮 测试预测功能:") + if not all_styles: + print(" ⚠️ 无法测试,模型没有训练数据") + else: + test_situations = [ + "表示惊讶", + "讨论游戏", + "表达赞同" + ] + + for test_sit in test_situations: + print(f"\n 测试输入: '{test_sit}'") + best_style, scores = learner.predict_style(test_sit, top_k=3) + + if best_style: + print(f" ✓ 最佳匹配: {best_style}") + print(f" Top 3:") + for style, score in list(scores.items())[:3]: + print(f" - {style}: {score:.4f}") + else: + print(f" ✗ 预测失败") + + print("\n" + "=" * 60) + print("诊断完成") + print("=" * 60) + + +if __name__ == "__main__": + # 从诊断报告中看到的 chat_id + test_chat_ids = [ + "52fb94af9f500a01e023ea780e43606e", # 有78个表达方式 + "46c8714c8a9b7ee169941fe99fcde07d", # 有22个表达方式 + ] + + for chat_id in test_chat_ids: + check_style_learner_status(chat_id) + print("\n") diff --git a/src/chat/express/express_utils.py b/src/chat/express/express_utils.py new file mode 100644 index 000000000..bd7f41e2d --- /dev/null +++ b/src/chat/express/express_utils.py @@ -0,0 +1,254 @@ +""" +表达系统工具函数 +提供消息过滤、文本相似度计算、加权随机抽样等功能 +""" +import difflib +import random +import re +from typing import Any, Dict, List, Optional + +from src.common.logger import get_logger + +logger = get_logger("express_utils") + + +def filter_message_content(content: Optional[str]) -> str: + """ + 过滤消息内容,移除回复、@、图片等格式 + + Args: + content: 原始消息内容 + + Returns: + 过滤后的纯文本内容 + """ + if not content: + return "" + + # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分 + content = re.sub(r"\[回复.*?\],说:\s*", "", content) + # 移除@<...>格式的内容 + content = re.sub(r"@<[^>]*>", "", content) + # 移除[图片:...]格式的图片ID + content = re.sub(r"\[图片:[^\]]*\]", "", content) + # 移除[表情包:...]格式的内容 + content = re.sub(r"\[表情包:[^\]]*\]", "", content) + + return content.strip() + + +def calculate_similarity(text1: str, text2: str) -> float: + """ + 计算两个文本的相似度,返回0-1之间的值 + + Args: + text1: 第一个文本 + text2: 第二个文本 + + Returns: + 相似度值 (0-1) + """ + return difflib.SequenceMatcher(None, text1, text2).ratio() + + +def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]: + """ + 加权随机抽样函数 + + Args: + population: 待抽样的数据列表 + k: 抽样数量 + weight_key: 权重字段名,如果为None则等概率抽样 + + Returns: + 抽样结果列表 + """ + if not population or k <= 0: + return [] + + if len(population) <= k: + return population.copy() + + # 如果指定了权重字段 + if weight_key and all(weight_key in item for item in population): + try: + # 获取权重 + weights = [float(item.get(weight_key, 1.0)) for item in population] + # 使用random.choices进行加权抽样 + return random.choices(population, weights=weights, k=k) + except (ValueError, TypeError) as e: + logger.warning(f"加权抽样失败,使用等概率抽样: {e}") + + # 等概率抽样 + selected = [] + population_copy = population.copy() + + for _ in range(k): + if not population_copy: + break + # 随机选择一个元素 + idx = random.randint(0, len(population_copy) - 1) + selected.append(population_copy.pop(idx)) + + return selected + + +def normalize_text(text: str) -> str: + """ + 标准化文本,移除多余空白字符 + + Args: + text: 输入文本 + + Returns: + 标准化后的文本 + """ + # 替换多个连续空白字符为单个空格 + text = re.sub(r"\s+", " ", text) + return text.strip() + + +def extract_keywords(text: str, max_keywords: int = 10) -> List[str]: + """ + 简单的关键词提取(基于词频) + + Args: + text: 输入文本 + max_keywords: 最大关键词数量 + + Returns: + 关键词列表 + """ + if not text: + return [] + + try: + import jieba.analyse + + # 使用TF-IDF提取关键词 + keywords = jieba.analyse.extract_tags(text, topK=max_keywords) + return keywords + except ImportError: + logger.warning("jieba未安装,无法提取关键词") + # 简单分词 + words = text.split() + return words[:max_keywords] + + +def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str: + """ + 格式化表达方式对 + + Args: + situation: 情境 + style: 风格 + index: 序号(可选) + + Returns: + 格式化后的字符串 + """ + if index is not None: + return f'{index}. 当"{situation}"时,使用"{style}"' + else: + return f'当"{situation}"时,使用"{style}"' + + +def parse_expression_pair(text: str) -> Optional[tuple[str, str]]: + """ + 解析表达方式对文本 + + Args: + text: 格式化的表达方式对文本 + + Returns: + (situation, style) 或 None + """ + # 匹配格式:当"..."时,使用"..." + match = re.search(r'当"(.+?)"时,使用"(.+?)"', text) + if match: + return match.group(1), match.group(2) + return None + + +def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]: + """ + 批量去重表达方式 + + Args: + expressions: 表达方式列表 + key_fields: 用于去重的字段名列表 + + Returns: + 去重后的表达方式列表 + """ + seen = set() + unique_expressions = [] + + for expr in expressions: + # 构建去重key + key_values = tuple(expr.get(field, "") for field in key_fields) + + if key_values not in seen: + seen.add(key_values) + unique_expressions.append(expr) + + return unique_expressions + + +def calculate_time_weight(last_active_time: float, current_time: float, half_life_days: int = 30) -> float: + """ + 根据时间计算权重(时间衰减) + + Args: + last_active_time: 最后活跃时间戳 + current_time: 当前时间戳 + half_life_days: 半衰期天数 + + Returns: + 权重值 (0-1) + """ + time_diff_days = (current_time - last_active_time) / 86400 # 转换为天数 + if time_diff_days < 0: + return 1.0 + + # 使用指数衰减公式 + decay_rate = 0.693 / half_life_days # ln(2) / half_life + weight = max(0.01, min(1.0, 2 ** (-decay_rate * time_diff_days))) + + return weight + + +def merge_expressions_from_multiple_chats( + expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100 +) -> List[Dict[str, Any]]: + """ + 合并多个聊天室的表达方式 + + Args: + expressions_dict: {chat_id: [expressions]} + max_total: 最大合并数量 + + Returns: + 合并后的表达方式列表 + """ + all_expressions = [] + + # 收集所有表达方式 + for chat_id, expressions in expressions_dict.items(): + for expr in expressions: + # 添加source_id标识 + expr_with_source = expr.copy() + expr_with_source["source_id"] = chat_id + all_expressions.append(expr_with_source) + + # 按count或last_active_time排序 + if all_expressions and "count" in all_expressions[0]: + all_expressions.sort(key=lambda x: x.get("count", 0), reverse=True) + elif all_expressions and "last_active_time" in all_expressions[0]: + all_expressions.sort(key=lambda x: x.get("last_active_time", 0), reverse=True) + + # 去重(基于situation和style) + all_expressions = batch_filter_duplicates(all_expressions, ["situation", "style"]) + + # 限制数量 + return all_expressions[:max_total] diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 0c25b9fc6..75864be40 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -16,6 +16,9 @@ from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest +# 导入 StyleLearner 管理器 +from .style_learner import style_learner_manager + MAX_EXPRESSION_COUNT = 300 DECAY_DAYS = 30 # 30天衰减到0.01 DECAY_MIN = 0.01 # 最小衰减值 @@ -43,17 +46,29 @@ def init_prompt() -> None: 3. 语言风格包含特殊内容和情感 4. 思考有没有特殊的梗,一并总结成语言风格 5. 例子仅供参考,请严格根据群聊内容总结!!! -注意:总结成如下格式的规律,总结的内容要详细,但具有概括性: -例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。 + +**重要:必须严格按照以下格式输出,每行一条规律:** +当"xxx"时,使用"xxx" + +格式说明: +- 必须以"当"开头 +- 场景描述用双引号包裹,不超过20个字 +- 必须包含"使用"或"可以" +- 表达风格用双引号包裹,不超过20个字 +- 每条规律独占一行 例如: 当"对某件事表示十分惊叹,有些意外"时,使用"我嘞个xxxx" 当"表示讽刺的赞同,不想讲道理"时,使用"对对对" -当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂" -当"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!" +当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契"时,使用"懂的都懂" +当"涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!" -请注意:不要总结你自己(SELF)的发言 -现在请你概括 +注意: +1. 不要总结你自己(SELF)的发言 +2. 如果聊天内容中没有明显的特殊风格,请只输出1-2条最明显的特点 +3. 不要输出其他解释性文字,只输出符合格式的规律 + +现在请你概括: """ Prompt(learn_style_prompt, "learn_style_prompt") @@ -65,16 +80,28 @@ def init_prompt() -> None: 2.不要涉及具体的人名,只考虑语法和句法特点, 3.语法和句法特点要包括,句子长短(具体字数),有何种语病,如何拆分句子。 4. 例子仅供参考,请严格根据群聊内容总结!!! -总结成如下格式的规律,总结的内容要简洁,不浮夸: -当"xxx"时,可以"xxx" + +**重要:必须严格按照以下格式输出,每行一条规律:** +当"xxx"时,使用"xxx" + +格式说明: +- 必须以"当"开头 +- 场景描述用双引号包裹 +- 必须包含"使用"或"可以" +- 句法特点用双引号包裹 +- 每条规律独占一行 例如: 当"表达观点较复杂"时,使用"省略主语(3-6个字)"的句法 当"不用详细说明的一般表达"时,使用"非常简洁的句子"的句法 当"需要单纯简单的确认"时,使用"单字或几个字的肯定(1-2个字)"的句法 -注意不要总结你自己(SELF)的发言 -现在请你概括 +注意: +1. 不要总结你自己(SELF)的发言 +2. 如果聊天内容中没有明显的句法特点,请只输出1-2条最明显的特点 +3. 不要输出其他解释性文字,只输出符合格式的规律 + +现在请你概括: """ Prompt(learn_grammar_prompt, "learn_grammar_prompt") @@ -405,6 +432,44 @@ class ExpressionLearner: for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: await session.delete(expr) + # 🔥 训练 StyleLearner + # 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型) + if type == "style": + try: + # 获取 StyleLearner 实例 + learner = style_learner_manager.get_learner(chat_id) + + logger.info(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}") + + # 为每个学习到的表达方式训练模型 + # 使用 situation 作为输入,style 作为目标 + # 这是最符合语义的方式:场景 -> 表达方式 + success_count = 0 + for expr in expr_list: + situation = expr["situation"] + style = expr["style"] + + # 训练映射关系: situation -> style + if learner.learn_mapping(situation, style): + success_count += 1 + else: + logger.warning(f"训练失败: {situation} -> {style}") + + logger.info( + f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, " + f"当前风格总数={len(learner.get_all_styles())}, " + f"总样本数={learner.learning_stats['total_samples']}" + ) + + # 保存模型 + if learner.save(style_learner_manager.model_save_path): + logger.info(f"StyleLearner 模型保存成功: {chat_id}") + else: + logger.error(f"StyleLearner 模型保存失败: {chat_id}") + + except Exception as e: + logger.error(f"训练 StyleLearner 失败: {e}", exc_info=True) + return learnt_expressions return None @@ -455,9 +520,17 @@ class ExpressionLearner: logger.error(f"学习{type_str}失败: {e}") return None + if not response or not response.strip(): + logger.warning(f"LLM返回空响应,无法学习{type_str}") + return None + logger.debug(f"学习{type_str}的response: {response}") expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id) + + if not expressions: + logger.warning(f"从LLM响应中未能解析出任何{type_str}。请检查LLM输出格式是否正确。") + logger.info(f"LLM完整响应:\n{response}") return expressions, chat_id @@ -465,31 +538,100 @@ class ExpressionLearner: def parse_expression_response(response: str, chat_id: str) -> list[tuple[str, str, str]]: """ 解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组 + 支持多种引号格式:"" 和 "" """ expressions: list[tuple[str, str, str]] = [] - for line in response.splitlines(): + failed_lines = [] + + for line_num, line in enumerate(response.splitlines(), 1): line = line.strip() if not line: continue + + # 替换中文引号为英文引号,便于统一处理 + line_normalized = line.replace('"', '"').replace('"', '"').replace("'", '"').replace("'", '"') + # 查找"当"和下一个引号 - idx_when = line.find('当"') + idx_when = line_normalized.find('当"') if idx_when == -1: - continue - idx_quote1 = idx_when + 1 - idx_quote2 = line.find('"', idx_quote1 + 1) - if idx_quote2 == -1: - continue - situation = line[idx_quote1 + 1 : idx_quote2] - # 查找"使用" - idx_use = line.find('使用"', idx_quote2) + # 尝试不带引号的格式: 当xxx时 + idx_when = line_normalized.find('当') + if idx_when == -1: + failed_lines.append((line_num, line, "找不到'当'关键字")) + continue + + # 提取"当"和"时"之间的内容 + idx_shi = line_normalized.find('时', idx_when) + if idx_shi == -1: + failed_lines.append((line_num, line, "找不到'时'关键字")) + continue + situation = line_normalized[idx_when + 1:idx_shi].strip('"\'""') + search_start = idx_shi + else: + idx_quote1 = idx_when + 1 + idx_quote2 = line_normalized.find('"', idx_quote1 + 1) + if idx_quote2 == -1: + failed_lines.append((line_num, line, "situation部分引号不匹配")) + continue + situation = line_normalized[idx_quote1 + 1 : idx_quote2] + search_start = idx_quote2 + + # 查找"使用"或"可以" + idx_use = line_normalized.find('使用"', search_start) if idx_use == -1: + idx_use = line_normalized.find('可以"', search_start) + if idx_use == -1: + # 尝试不带引号的格式 + idx_use = line_normalized.find('使用', search_start) + if idx_use == -1: + idx_use = line_normalized.find('可以', search_start) + if idx_use == -1: + failed_lines.append((line_num, line, "找不到'使用'或'可以'关键字")) + continue + + # 提取剩余部分作为style + style = line_normalized[idx_use + 2:].strip('"\'"",。') + if not style: + failed_lines.append((line_num, line, "style部分为空")) + continue + else: + idx_quote3 = idx_use + 2 + idx_quote4 = line_normalized.find('"', idx_quote3 + 1) + if idx_quote4 == -1: + # 如果没有结束引号,取到行尾 + style = line_normalized[idx_quote3 + 1:].strip('"\'""') + else: + style = line_normalized[idx_quote3 + 1 : idx_quote4] + else: + idx_quote3 = idx_use + 2 + idx_quote4 = line_normalized.find('"', idx_quote3 + 1) + if idx_quote4 == -1: + # 如果没有结束引号,取到行尾 + style = line_normalized[idx_quote3 + 1:].strip('"\'""') + else: + style = line_normalized[idx_quote3 + 1 : idx_quote4] + + # 清理并验证 + situation = situation.strip() + style = style.strip() + + if not situation or not style: + failed_lines.append((line_num, line, f"situation或style为空: situation='{situation}', style='{style}'")) continue - idx_quote3 = idx_use + 2 - idx_quote4 = line.find('"', idx_quote3 + 1) - if idx_quote4 == -1: - continue - style = line[idx_quote3 + 1 : idx_quote4] + expressions.append((chat_id, situation, style)) + + # 记录解析失败的行 + if failed_lines: + logger.warning(f"解析表达方式时有 {len(failed_lines)} 行失败:") + for line_num, line, reason in failed_lines[:5]: # 只显示前5个 + logger.warning(f" 行{line_num}: {reason}") + logger.debug(f" 原文: {line}") + + if not expressions: + logger.warning(f"LLM返回了内容但无法解析任何表达方式。响应预览:\n{response[:500]}") + else: + logger.debug(f"成功解析 {len(expressions)} 个表达方式") return expressions @@ -522,12 +664,12 @@ class ExpressionLearnerManager: os.path.join(base_dir, "learnt_grammar"), ] - try: - for directory in directories_to_create: + for directory in directories_to_create: + try: os.makedirs(directory, exist_ok=True) - logger.debug(f"确保目录存在: {directory}") - except Exception as e: - logger.error(f"创建目录失败 {directory}: {e}") + logger.debug(f"确保目录存在: {directory}") + except Exception as e: + logger.error(f"创建目录失败 {directory}: {e}") @staticmethod async def _auto_migrate_json_to_db(): diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index eee737f3e..1dbf7e08e 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -15,6 +15,10 @@ from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest +# 导入StyleLearner管理器和情境提取器 +from .situation_extractor import situation_extractor +from .style_learner import style_learner_manager + logger = get_logger("expression_selector") @@ -127,17 +131,18 @@ class ExpressionSelector: current_group = rule.group break - if not current_group: - return [chat_id] + # 🔥 始终包含当前 chat_id(确保至少能查到自己的数据) + related_chat_ids = [chat_id] - # 找出同一组的所有chat_id - related_chat_ids = [] - for rule in rules: - if rule.group == current_group and rule.chat_stream_id: - if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id): - related_chat_ids.append(chat_id_candidate) + if current_group: + # 找出同一组的所有chat_id + for rule in rules: + if rule.group == current_group and rule.chat_stream_id: + if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id): + if chat_id_candidate not in related_chat_ids: + related_chat_ids.append(chat_id_candidate) - return related_chat_ids if related_chat_ids else [chat_id] + return related_chat_ids async def get_random_expressions( self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float @@ -236,6 +241,287 @@ class ExpressionSelector: ) await session.commit() + async def select_suitable_expressions( + self, + chat_id: str, + chat_history: list | str, + target_message: str | None = None, + max_num: int = 10, + min_num: int = 5, + ) -> list[dict[str, Any]]: + """ + 统一的表达方式选择入口,根据配置自动选择模式 + + Args: + chat_id: 聊天ID + chat_history: 聊天历史(列表或字符串) + target_message: 目标消息 + max_num: 最多返回数量 + min_num: 最少返回数量 + + Returns: + 选中的表达方式列表 + """ + # 转换chat_history为字符串 + if isinstance(chat_history, list): + chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history]) + else: + chat_info = chat_history + + # 根据配置选择模式 + mode = global_config.expression.mode + logger.debug(f"[ExpressionSelector] 使用模式: {mode}") + + if mode == "exp_model": + return await self._select_expressions_model_only( + chat_id=chat_id, + chat_info=chat_info, + target_message=target_message, + max_num=max_num, + min_num=min_num + ) + else: # classic mode + return await self._select_expressions_classic( + chat_id=chat_id, + chat_info=chat_info, + target_message=target_message, + max_num=max_num, + min_num=min_num + ) + + async def _select_expressions_classic( + self, + chat_id: str, + chat_info: str, + target_message: str | None = None, + max_num: int = 10, + min_num: int = 5, + ) -> list[dict[str, Any]]: + """经典模式:随机抽样 + LLM评估""" + logger.debug(f"[Classic模式] 使用LLM评估表达方式") + return await self.select_suitable_expressions_llm( + chat_id=chat_id, + chat_info=chat_info, + max_num=max_num, + min_num=min_num, + target_message=target_message + ) + + async def _select_expressions_model_only( + self, + chat_id: str, + chat_info: str, + target_message: str | None = None, + max_num: int = 10, + min_num: int = 5, + ) -> list[dict[str, Any]]: + """模型预测模式:先提取情境,再使用StyleLearner预测表达风格""" + logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式") + + # 检查是否允许在此聊天流中使用表达 + if not self.can_use_expression_for_chat(chat_id): + logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表") + return [] + + # 步骤1: 提取聊天情境 + situations = await situation_extractor.extract_situations( + chat_history=chat_info, + target_message=target_message, + max_situations=3 + ) + + if not situations: + logger.warning(f"无法提取聊天情境,回退到经典模式") + return await self._select_expressions_classic( + chat_id=chat_id, + chat_info=chat_info, + target_message=target_message, + max_num=max_num, + min_num=min_num + ) + + logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}") + + # 步骤2: 使用 StyleLearner 为每个情境预测合适的表达方式 + learner = style_learner_manager.get_learner(chat_id) + + all_predicted_styles = {} + for i, situation in enumerate(situations, 1): + logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}") + best_style, scores = learner.predict_style(situation, top_k=max_num) + + if best_style and scores: + logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}") + # 合并分数(取最高分) + for style, score in scores.items(): + if style not in all_predicted_styles or score > all_predicted_styles[style]: + all_predicted_styles[style] = score + else: + logger.debug(f" 该情境未返回预测结果") + + if not all_predicted_styles: + logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式") + return await self._select_expressions_classic( + chat_id=chat_id, + chat_info=chat_info, + target_message=target_message, + max_num=max_num, + min_num=min_num + ) + + # 将分数字典转换为列表格式 [(style, score), ...] + predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True) + + logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}") + + # 步骤3: 根据预测的风格从数据库获取表达方式 + logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式") + expressions = await self.get_model_predicted_expressions( + chat_id=chat_id, + predicted_styles=predicted_styles, + max_num=max_num + ) + + if not expressions: + logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式") + return await self._select_expressions_classic( + chat_id=chat_id, + chat_info=chat_info, + target_message=target_message, + max_num=max_num, + min_num=min_num + ) + + logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式") + return expressions + + async def get_model_predicted_expressions( + self, + chat_id: str, + predicted_styles: list[tuple[str, float]], + max_num: int = 10 + ) -> list[dict[str, Any]]: + """ + 根据StyleLearner预测的风格获取表达方式 + + Args: + chat_id: 聊天ID + predicted_styles: 预测的风格列表,格式: [(style, score), ...] + max_num: 最多返回数量 + + Returns: + 表达方式列表 + """ + if not predicted_styles: + return [] + + # 提取风格名称(前3个最佳匹配) + style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]] + logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}") + + # 🔥 使用 get_related_chat_ids 获取所有相关的 chat_id(支持共享表达方式) + related_chat_ids = self.get_related_chat_ids(chat_id) + logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}") + + async with get_db_session() as session: + # 🔍 先检查数据库中实际有哪些 chat_id 的数据 + db_chat_ids_result = await session.execute( + select(Expression.chat_id) + .where(Expression.type == "style") + .distinct() + ) + db_chat_ids = [cid for cid in db_chat_ids_result.scalars()] + logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}") + + # 获取所有相关 chat_id 的表达方式(用于模糊匹配) + all_expressions_result = await session.execute( + select(Expression) + .where(Expression.chat_id.in_(related_chat_ids)) + .where(Expression.type == "style") + ) + all_expressions = list(all_expressions_result.scalars()) + + logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}") + + # 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id + if not all_expressions: + logger.info(f"相关chat_id没有数据,尝试从所有chat_id查询") + all_expressions_result = await session.execute( + select(Expression) + .where(Expression.type == "style") + ) + all_expressions = list(all_expressions_result.scalars()) + logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}") + + if not all_expressions: + logger.warning(f"数据库中完全没有任何表达方式,需要先学习") + return [] + + # 🔥 使用模糊匹配而不是精确匹配 + # 计算每个预测style与数据库style的相似度 + from difflib import SequenceMatcher + + matched_expressions = [] + for expr in all_expressions: + db_style = expr.style or "" + max_similarity = 0.0 + best_predicted = "" + + # 与每个预测的style计算相似度 + for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测 + # 计算字符串相似度 + similarity = SequenceMatcher(None, predicted_style, db_style).ratio() + + # 也检查包含关系(如果一个是另一个的子串,给更高分) + if len(predicted_style) >= 2 and len(db_style) >= 2: + if predicted_style in db_style or db_style in predicted_style: + similarity = max(similarity, 0.7) + + if similarity > max_similarity: + max_similarity = similarity + best_predicted = predicted_style + + # 🔥 降低阈值到30%,因为StyleLearner预测质量较差 + if max_similarity >= 0.3: # 30%相似度阈值 + matched_expressions.append((expr, max_similarity, expr.count, best_predicted)) + + if not matched_expressions: + # 收集数据库中的style样例用于调试 + all_styles = [e.style for e in all_expressions[:10]] + logger.warning( + f"数据库中没有找到匹配的表达方式(相似度阈值30%):\n" + f" 预测的style (前3个): {style_names}\n" + f" 数据库中存在的style样例: {all_styles}\n" + f" 提示: StyleLearner预测质量差,建议重新训练或使用classic模式" + ) + return [] + + # 按照相似度*count排序,选择最佳匹配 + matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True) + expressions_objs = [e[0] for e in matched_expressions[:max_num]] + + # 显示最佳匹配的详细信息 + top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]] + logger.info( + f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式\n" + f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n" + f" Top3匹配: {top_matches}" + ) + + # 转换为字典格式 + expressions = [] + for expr in expressions_objs: + expressions.append({ + "situation": expr.situation or "", + "style": expr.style or "", + "type": expr.type or "style", + "count": float(expr.count) if expr.count else 0.0, + "last_active_time": expr.last_active_time or 0.0 + }) + + logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式") + return expressions + async def select_suitable_expressions_llm( self, chat_id: str, diff --git a/src/chat/express/expressor_model/__init__.py b/src/chat/express/expressor_model/__init__.py new file mode 100644 index 000000000..a13656a85 --- /dev/null +++ b/src/chat/express/expressor_model/__init__.py @@ -0,0 +1,9 @@ +""" +表达模型包 +包含基于Online Naive Bayes的机器学习模型 +""" +from .model import ExpressorModel +from .online_nb import OnlineNaiveBayes +from .tokenizer import Tokenizer + +__all__ = ["ExpressorModel", "OnlineNaiveBayes", "Tokenizer"] diff --git a/src/chat/express/expressor_model/model.py b/src/chat/express/expressor_model/model.py new file mode 100644 index 000000000..8c18240a8 --- /dev/null +++ b/src/chat/express/expressor_model/model.py @@ -0,0 +1,216 @@ +""" +基于Online Naive Bayes的表达模型 +支持候选表达的动态添加和在线学习 +""" +import os +import pickle +from collections import Counter, defaultdict +from typing import Dict, Optional, Tuple + +from src.common.logger import get_logger + +from .online_nb import OnlineNaiveBayes +from .tokenizer import Tokenizer + +logger = get_logger("expressor.model") + + +class ExpressorModel: + """直接使用朴素贝叶斯精排(可在线学习)""" + + def __init__( + self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000, use_jieba: bool = True + ): + """ + Args: + alpha: 词频平滑参数 + beta: 类别先验平滑参数 + gamma: 衰减因子 + vocab_size: 词汇表大小 + use_jieba: 是否使用jieba分词 + """ + # 初始化分词器 + self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba) + + # 初始化在线朴素贝叶斯模型 + self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size) + + # 候选表达管理 + self._candidates: Dict[str, str] = {} # cid -> text (style) + self._situations: Dict[str, str] = {} # cid -> situation (不参与计算) + + logger.info( + f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})" + ) + + def add_candidate(self, cid: str, text: str, situation: Optional[str] = None): + """ + 添加候选文本和对应的situation + + Args: + cid: 候选ID + text: 表达文本 (style) + situation: 情境文本 + """ + self._candidates[cid] = text + if situation is not None: + self._situations[cid] = situation + + # 确保在nb模型中初始化该候选的计数 + if cid not in self.nb.cls_counts: + self.nb.cls_counts[cid] = 0.0 + if cid not in self.nb.token_counts: + self.nb.token_counts[cid] = defaultdict(float) + + def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]: + """ + 直接对所有候选进行朴素贝叶斯评分 + + Args: + text: 查询文本 + k: 返回前k个候选,如果为None则返回所有 + + Returns: + (最佳候选ID, 所有候选的分数字典) + """ + # 1. 分词 + toks = self.tokenizer.tokenize(text) + if not toks or not self._candidates: + return None, {} + + # 2. 计算词频 + tf = Counter(toks) + all_cids = list(self._candidates.keys()) + + # 3. 批量评分 + scores = self.nb.score_batch(tf, all_cids) + + if not scores: + return None, {} + + # 4. 根据k参数限制返回的候选数量 + if k is not None and k > 0: + sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) + limited_scores = dict(sorted_scores[:k]) + best = sorted_scores[0][0] if sorted_scores else None + return best, limited_scores + else: + best = max(scores.items(), key=lambda x: x[1])[0] + return best, scores + + def update_positive(self, text: str, cid: str): + """ + 更新正反馈学习 + + Args: + text: 输入文本 + cid: 目标类别ID + """ + toks = self.tokenizer.tokenize(text) + if not toks: + return + + tf = Counter(toks) + self.nb.update_positive(tf, cid) + + def decay(self, factor: Optional[float] = None): + """ + 应用知识衰减 + + Args: + factor: 衰减因子,如果为None则使用模型配置的gamma + """ + self.nb.decay(factor) + + def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]: + """ + 获取候选信息 + + Args: + cid: 候选ID + + Returns: + (style文本, situation文本) + """ + style = self._candidates.get(cid) + situation = self._situations.get(cid) + return style, situation + + def get_all_candidates(self) -> Dict[str, Tuple[str, str]]: + """ + 获取所有候选 + + Returns: + {cid: (style, situation)} + """ + result = {} + for cid in self._candidates.keys(): + style, situation = self.get_candidate_info(cid) + result[cid] = (style, situation) + return result + + def save(self, path: str): + """ + 保存模型到文件 + + Args: + path: 保存路径 + """ + os.makedirs(os.path.dirname(path), exist_ok=True) + + data = { + "candidates": self._candidates, + "situations": self._situations, + "nb_cls_counts": dict(self.nb.cls_counts), + "nb_token_counts": {k: dict(v) for k, v in self.nb.token_counts.items()}, + "nb_alpha": self.nb.alpha, + "nb_beta": self.nb.beta, + "nb_gamma": self.nb.gamma, + "nb_V": self.nb.V, + } + + with open(path, "wb") as f: + pickle.dump(data, f) + + logger.info(f"模型已保存到 {path}") + + def load(self, path: str): + """ + 从文件加载模型 + + Args: + path: 加载路径 + """ + if not os.path.exists(path): + logger.warning(f"模型文件不存在: {path}") + return + + with open(path, "rb") as f: + data = pickle.load(f) + + self._candidates = data["candidates"] + self._situations = data["situations"] + + # 恢复nb模型的参数 + self.nb.alpha = data["nb_alpha"] + self.nb.beta = data["nb_beta"] + self.nb.gamma = data["nb_gamma"] + self.nb.V = data["nb_V"] + + # 恢复统计数据 + self.nb.cls_counts = defaultdict(float, data["nb_cls_counts"]) + self.nb.token_counts = defaultdict(lambda: defaultdict(float)) + for cid, tc in data["nb_token_counts"].items(): + self.nb.token_counts[cid] = defaultdict(float, tc) + + logger.info(f"模型已从 {path} 加载") + + def get_stats(self) -> Dict: + """获取模型统计信息""" + nb_stats = self.nb.get_stats() + return { + "n_candidates": len(self._candidates), + "n_classes": nb_stats["n_classes"], + "n_tokens": nb_stats["n_tokens"], + "total_counts": nb_stats["total_counts"], + } diff --git a/src/chat/express/expressor_model/online_nb.py b/src/chat/express/expressor_model/online_nb.py new file mode 100644 index 000000000..39bd0d1cd --- /dev/null +++ b/src/chat/express/expressor_model/online_nb.py @@ -0,0 +1,142 @@ +""" +在线朴素贝叶斯分类器 +支持增量学习和知识衰减 +""" +import math +from collections import Counter, defaultdict +from typing import Dict, List, Optional + +from src.common.logger import get_logger + +logger = get_logger("expressor.online_nb") + + +class OnlineNaiveBayes: + """在线朴素贝叶斯分类器""" + + def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000): + """ + Args: + alpha: 词频平滑参数 + beta: 类别先验平滑参数 + gamma: 衰减因子 (0-1之间,1表示不衰减) + vocab_size: 词汇表大小 + """ + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.V = vocab_size + + # 类别统计 + self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count + self.token_counts: Dict[str, Dict[str, float]] = defaultdict( + lambda: defaultdict(float) + ) # cid -> term -> count + + # 缓存 + self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα) + + def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]: + """ + 批量计算候选的贝叶斯分数 + + Args: + tf: 查询文本的词频Counter + cids: 候选类别ID列表 + + Returns: + 每个候选的分数字典 + """ + total_cls = sum(self.cls_counts.values()) + n_cls = max(1, len(self.cls_counts)) + denom_prior = math.log(total_cls + self.beta * n_cls) + + out: Dict[str, float] = {} + for cid in cids: + # 计算先验概率 log P(c) + prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior + s = prior + + # 计算似然概率 log P(w|c) + logZ = self._logZ_c(cid) + tc = self.token_counts[cid] + + for term, qtf in tf.items(): + num = tc.get(term, 0.0) + self.alpha + s += qtf * (math.log(num) - logZ) + + out[cid] = s + return out + + def update_positive(self, tf: Counter, cid: str): + """ + 正反馈更新 + + Args: + tf: 词频Counter + cid: 类别ID + """ + inc = 0.0 + tc = self.token_counts[cid] + + # 更新词频统计 + for term, c in tf.items(): + tc[term] += float(c) + inc += float(c) + + # 更新类别统计 + self.cls_counts[cid] += inc + self._invalidate(cid) + + def decay(self, factor: Optional[float] = None): + """ + 知识衰减(遗忘机制) + + Args: + factor: 衰减因子,如果为None则使用self.gamma + """ + g = self.gamma if factor is None else factor + if g >= 1.0: + return + + # 对所有统计进行衰减 + for cid in list(self.cls_counts.keys()): + self.cls_counts[cid] *= g + for term in list(self.token_counts[cid].keys()): + self.token_counts[cid][term] *= g + self._invalidate(cid) + + logger.debug(f"应用知识衰减,衰减因子: {g}") + + def _logZ_c(self, cid: str) -> float: + """ + 计算归一化因子logZ + + Args: + cid: 类别ID + + Returns: + log(Z_c) + """ + if cid not in self._logZ: + Z = self.cls_counts[cid] + self.V * self.alpha + self._logZ[cid] = math.log(max(Z, 1e-12)) + return self._logZ[cid] + + def _invalidate(self, cid: str): + """ + 使缓存失效 + + Args: + cid: 类别ID + """ + if cid in self._logZ: + del self._logZ[cid] + + def get_stats(self) -> Dict: + """获取统计信息""" + return { + "n_classes": len(self.cls_counts), + "n_tokens": sum(len(tc) for tc in self.token_counts.values()), + "total_counts": sum(self.cls_counts.values()), + } diff --git a/src/chat/express/expressor_model/tokenizer.py b/src/chat/express/expressor_model/tokenizer.py new file mode 100644 index 000000000..e25f780d4 --- /dev/null +++ b/src/chat/express/expressor_model/tokenizer.py @@ -0,0 +1,62 @@ +""" +文本分词器,支持中文Jieba分词 +""" +from typing import List + +from src.common.logger import get_logger + +logger = get_logger("expressor.tokenizer") + + +class Tokenizer: + """文本分词器,支持中文Jieba分词""" + + def __init__(self, stopwords: set = None, use_jieba: bool = True): + """ + Args: + stopwords: 停用词集合 + use_jieba: 是否使用jieba分词 + """ + self.stopwords = stopwords or set() + self.use_jieba = use_jieba + + if use_jieba: + try: + import jieba + + jieba.initialize() + logger.info("Jieba分词器初始化成功") + except ImportError: + logger.warning("Jieba未安装,将使用字符级分词") + self.use_jieba = False + + def tokenize(self, text: str) -> List[str]: + """ + 分词并返回token列表 + + Args: + text: 输入文本 + + Returns: + token列表 + """ + if not text: + return [] + + # 使用jieba分词 + if self.use_jieba: + try: + import jieba + + tokens = list(jieba.cut(text)) + except Exception as e: + logger.warning(f"Jieba分词失败,使用字符级分词: {e}") + tokens = list(text) + else: + # 简单按字符分词 + tokens = list(text) + + # 过滤停用词和空字符串 + tokens = [token.strip() for token in tokens if token.strip() and token not in self.stopwords] + + return tokens diff --git a/src/chat/express/situation_extractor.py b/src/chat/express/situation_extractor.py new file mode 100644 index 000000000..8ebe0a8bd --- /dev/null +++ b/src/chat/express/situation_extractor.py @@ -0,0 +1,162 @@ +""" +情境提取器 +从聊天历史中提取当前的情境(situation),用于 StyleLearner 预测 +""" +from typing import Optional + +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest + +logger = get_logger("situation_extractor") + + +def init_prompt(): + situation_extraction_prompt = """ +以下是正在进行的聊天内容: +{chat_history} + +你的名字是{bot_name}{target_message_info} + +请分析当前聊天的情境特征,提取出最能描述当前情境的1-3个关键场景描述。 + +场景描述应该: +1. 简洁明了(每个不超过20个字) +2. 聚焦情绪、话题、氛围 +3. 不涉及具体人名 +4. 类似于"表示惊讶"、"讨论游戏"、"表达赞同"这样的格式 + +请以纯文本格式输出,每行一个场景描述,不要有序号、引号或其他格式: + +例如: +表示惊讶和意外 +讨论技术问题 +表达友好的赞同 + +现在请提取当前聊天的情境: +""" + Prompt(situation_extraction_prompt, "situation_extraction_prompt") + + +class SituationExtractor: + """情境提取器,从聊天历史中提取当前情境""" + + def __init__(self): + self.llm_model = LLMRequest( + model_set=model_config.model_task_config.utils_small, + request_type="expression.situation_extractor" + ) + + async def extract_situations( + self, + chat_history: list | str, + target_message: Optional[str] = None, + max_situations: int = 3 + ) -> list[str]: + """ + 从聊天历史中提取情境 + + Args: + chat_history: 聊天历史(列表或字符串) + target_message: 目标消息(可选) + max_situations: 最多提取的情境数量 + + Returns: + 情境描述列表 + """ + # 转换chat_history为字符串 + if isinstance(chat_history, list): + chat_info = "\n".join([ + f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" + for msg in chat_history + ]) + else: + chat_info = chat_history + + # 构建目标消息信息 + if target_message: + target_message_info = f",现在你想要回复消息:{target_message}" + else: + target_message_info = "" + + # 构建 prompt + try: + prompt = (await global_prompt_manager.get_prompt_async("situation_extraction_prompt")).format( + bot_name=global_config.bot.nickname, + chat_history=chat_info, + target_message_info=target_message_info + ) + + # 调用 LLM + response, _ = await self.llm_model.generate_response_async( + prompt=prompt, + temperature=0.3 + ) + + if not response or not response.strip(): + logger.warning("LLM返回空响应,无法提取情境") + return [] + + # 解析响应 + situations = self._parse_situations(response, max_situations) + + if situations: + logger.debug(f"提取到 {len(situations)} 个情境: {situations}") + else: + logger.warning(f"无法从LLM响应中解析出情境。响应:\n{response}") + + return situations + + except Exception as e: + logger.error(f"提取情境失败: {e}") + return [] + + @staticmethod + def _parse_situations(response: str, max_situations: int) -> list[str]: + """ + 解析 LLM 返回的情境描述 + + Args: + response: LLM 响应 + max_situations: 最多返回的情境数量 + + Returns: + 情境描述列表 + """ + situations = [] + + for line in response.splitlines(): + line = line.strip() + if not line: + continue + + # 移除可能的序号、引号等 + line = line.lstrip('0123456789.、-*>))】] \t"\'""''') + line = line.rstrip('"\'""''') + line = line.strip() + + if not line: + continue + + # 过滤掉明显不是情境描述的内容 + if len(line) > 30: # 太长 + continue + if len(line) < 2: # 太短 + continue + if any(keyword in line.lower() for keyword in ['例如', '注意', '请', '分析', '总结']): + continue + + situations.append(line) + + if len(situations) >= max_situations: + break + + return situations + + +# 初始化 prompt +init_prompt() + +# 全局单例 +situation_extractor = SituationExtractor() diff --git a/src/chat/express/style_learner.py b/src/chat/express/style_learner.py new file mode 100644 index 000000000..c254ef98c --- /dev/null +++ b/src/chat/express/style_learner.py @@ -0,0 +1,425 @@ +""" +风格学习引擎 +基于ExpressorModel实现的表达风格学习和预测系统 +支持多聊天室独立建模和在线学习 +""" +import os +import time +from typing import Dict, List, Optional, Tuple + +from src.common.logger import get_logger + +from .expressor_model import ExpressorModel + +logger = get_logger("expressor.style_learner") + + +class StyleLearner: + """单个聊天室的表达风格学习器""" + + def __init__(self, chat_id: str, model_config: Optional[Dict] = None): + """ + Args: + chat_id: 聊天室ID + model_config: 模型配置 + """ + self.chat_id = chat_id + self.model_config = model_config or { + "alpha": 0.5, + "beta": 0.5, + "gamma": 0.99, # 衰减因子,支持遗忘 + "vocab_size": 200000, + "use_jieba": True, + } + + # 初始化表达模型 + self.expressor = ExpressorModel(**self.model_config) + + # 动态风格管理 + self.max_styles = 2000 # 每个chat_id最多2000个风格 + self.style_to_id: Dict[str, str] = {} # style文本 -> style_id + self.id_to_style: Dict[str, str] = {} # style_id -> style文本 + self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本 + self.next_style_id = 0 + + # 学习统计 + self.learning_stats = { + "total_samples": 0, + "style_counts": {}, + "last_update": time.time(), + } + + logger.info(f"StyleLearner初始化成功: chat_id={chat_id}") + + def add_style(self, style: str, situation: Optional[str] = None) -> bool: + """ + 动态添加一个新的风格 + + Args: + style: 风格文本 + situation: 情境文本 + + Returns: + 是否添加成功 + """ + try: + # 检查是否已存在 + if style in self.style_to_id: + return True + + # 检查是否超过最大限制 + if len(self.style_to_id) >= self.max_styles: + logger.warning(f"已达到最大风格数量限制 ({self.max_styles})") + return False + + # 生成新的style_id + style_id = f"style_{self.next_style_id}" + self.next_style_id += 1 + + # 添加到映射 + self.style_to_id[style] = style_id + self.id_to_style[style_id] = style + if situation: + self.id_to_situation[style_id] = situation + + # 添加到expressor模型 + self.expressor.add_candidate(style_id, style, situation) + + # 初始化统计 + self.learning_stats["style_counts"][style_id] = 0 + + logger.debug(f"添加风格成功: {style_id} -> {style}") + return True + + except Exception as e: + logger.error(f"添加风格失败: {e}") + return False + + def learn_mapping(self, up_content: str, style: str) -> bool: + """ + 学习一个up_content到style的映射 + + Args: + up_content: 前置内容 + style: 目标风格 + + Returns: + 是否学习成功 + """ + try: + # 如果style不存在,先添加它 + if style not in self.style_to_id: + if not self.add_style(style): + return False + + # 获取style_id + style_id = self.style_to_id[style] + + # 使用正反馈学习 + self.expressor.update_positive(up_content, style_id) + + # 更新统计 + self.learning_stats["total_samples"] += 1 + self.learning_stats["style_counts"][style_id] += 1 + self.learning_stats["last_update"] = time.time() + + logger.debug(f"学习映射成功: {up_content[:20]}... -> {style}") + return True + + except Exception as e: + logger.error(f"学习映射失败: {e}") + return False + + def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: + """ + 根据up_content预测最合适的style + + Args: + up_content: 前置内容 + top_k: 返回前k个候选 + + Returns: + (最佳style文本, 所有候选的分数字典) + """ + try: + # 先检查是否有训练数据 + if not self.style_to_id: + logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}") + return None, {} + + best_style_id, scores = self.expressor.predict(up_content, k=top_k) + + if best_style_id is None: + logger.debug(f"ExpressorModel未返回预测结果: chat_id={self.chat_id}, up_content={up_content[:50]}...") + return None, {} + + # 将style_id转换为style文本 + best_style = self.id_to_style.get(best_style_id) + + if best_style is None: + logger.warning( + f"style_id无法转换为style文本: style_id={best_style_id}, " + f"已知的id_to_style数量={len(self.id_to_style)}" + ) + return None, {} + + # 转换所有分数 + style_scores = {} + for sid, score in scores.items(): + style_text = self.id_to_style.get(sid) + if style_text: + style_scores[style_text] = score + else: + logger.warning(f"跳过无法转换的style_id: {sid}") + + logger.debug( + f"预测成功: up_content={up_content[:30]}..., " + f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}" + ) + + return best_style, style_scores + + except Exception as e: + logger.error(f"预测style失败: {e}", exc_info=True) + return None, {} + + def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]: + """ + 获取style的完整信息 + + Args: + style: 风格文本 + + Returns: + (style_id, situation) + """ + style_id = self.style_to_id.get(style) + if not style_id: + return None, None + + situation = self.id_to_situation.get(style_id) + return style_id, situation + + def get_all_styles(self) -> List[str]: + """ + 获取所有风格列表 + + Returns: + 风格文本列表 + """ + return list(self.style_to_id.keys()) + + def apply_decay(self, factor: Optional[float] = None): + """ + 应用知识衰减 + + Args: + factor: 衰减因子 + """ + self.expressor.decay(factor) + logger.debug(f"应用知识衰减: chat_id={self.chat_id}") + + def save(self, base_path: str) -> bool: + """ + 保存学习器到文件 + + Args: + base_path: 基础保存路径 + + Returns: + 是否保存成功 + """ + try: + # 创建保存目录 + save_dir = os.path.join(base_path, self.chat_id) + os.makedirs(save_dir, exist_ok=True) + + # 保存expressor模型 + model_path = os.path.join(save_dir, "expressor_model.pkl") + self.expressor.save(model_path) + + # 保存映射关系和统计信息 + import pickle + + meta_path = os.path.join(save_dir, "meta.pkl") + meta_data = { + "style_to_id": self.style_to_id, + "id_to_style": self.id_to_style, + "id_to_situation": self.id_to_situation, + "next_style_id": self.next_style_id, + "learning_stats": self.learning_stats, + } + + with open(meta_path, "wb") as f: + pickle.dump(meta_data, f) + + logger.info(f"StyleLearner保存成功: {save_dir}") + return True + + except Exception as e: + logger.error(f"保存StyleLearner失败: {e}") + return False + + def load(self, base_path: str) -> bool: + """ + 从文件加载学习器 + + Args: + base_path: 基础加载路径 + + Returns: + 是否加载成功 + """ + try: + save_dir = os.path.join(base_path, self.chat_id) + + # 检查目录是否存在 + if not os.path.exists(save_dir): + logger.debug(f"StyleLearner保存目录不存在: {save_dir}") + return False + + # 加载expressor模型 + model_path = os.path.join(save_dir, "expressor_model.pkl") + if os.path.exists(model_path): + self.expressor.load(model_path) + + # 加载映射关系和统计信息 + import pickle + + meta_path = os.path.join(save_dir, "meta.pkl") + if os.path.exists(meta_path): + with open(meta_path, "rb") as f: + meta_data = pickle.load(f) + + self.style_to_id = meta_data["style_to_id"] + self.id_to_style = meta_data["id_to_style"] + self.id_to_situation = meta_data["id_to_situation"] + self.next_style_id = meta_data["next_style_id"] + self.learning_stats = meta_data["learning_stats"] + + logger.info(f"StyleLearner加载成功: {save_dir}") + return True + + except Exception as e: + logger.error(f"加载StyleLearner失败: {e}") + return False + + def get_stats(self) -> Dict: + """获取统计信息""" + model_stats = self.expressor.get_stats() + return { + "chat_id": self.chat_id, + "n_styles": len(self.style_to_id), + "total_samples": self.learning_stats["total_samples"], + "last_update": self.learning_stats["last_update"], + "model_stats": model_stats, + } + + +class StyleLearnerManager: + """多聊天室表达风格学习管理器""" + + def __init__(self, model_save_path: str = "data/expression/style_models"): + """ + Args: + model_save_path: 模型保存路径 + """ + self.learners: Dict[str, StyleLearner] = {} + self.model_save_path = model_save_path + + # 确保保存目录存在 + os.makedirs(model_save_path, exist_ok=True) + + logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}") + + def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner: + """ + 获取或创建指定chat_id的学习器 + + Args: + chat_id: 聊天室ID + model_config: 模型配置 + + Returns: + StyleLearner实例 + """ + if chat_id not in self.learners: + # 创建新的学习器 + learner = StyleLearner(chat_id, model_config) + + # 尝试加载已保存的模型 + learner.load(self.model_save_path) + + self.learners[chat_id] = learner + + return self.learners[chat_id] + + def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool: + """ + 学习一个映射关系 + + Args: + chat_id: 聊天室ID + up_content: 前置内容 + style: 目标风格 + + Returns: + 是否学习成功 + """ + learner = self.get_learner(chat_id) + return learner.learn_mapping(up_content, style) + + def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: + """ + 预测最合适的风格 + + Args: + chat_id: 聊天室ID + up_content: 前置内容 + top_k: 返回前k个候选 + + Returns: + (最佳style, 分数字典) + """ + learner = self.get_learner(chat_id) + return learner.predict_style(up_content, top_k) + + def save_all(self) -> bool: + """ + 保存所有学习器 + + Returns: + 是否全部保存成功 + """ + success = True + for chat_id, learner in self.learners.items(): + if not learner.save(self.model_save_path): + success = False + + logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}") + return success + + def apply_decay_all(self, factor: Optional[float] = None): + """ + 对所有学习器应用知识衰减 + + Args: + factor: 衰减因子 + """ + for learner in self.learners.values(): + learner.apply_decay(factor) + + logger.info(f"对所有StyleLearner应用知识衰减") + + def get_all_stats(self) -> Dict[str, Dict]: + """ + 获取所有学习器的统计信息 + + Returns: + {chat_id: stats} + """ + return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()} + + +# 全局单例 +style_learner_manager = StyleLearnerManager() diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index 5bf13081f..c111bf8b4 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -46,6 +46,9 @@ class StreamLoopManager: # 状态控制 self.is_running = False + # 每个流的上一次间隔值(用于日志去重) + self._last_intervals: dict[str, float] = {} + logger.info(f"流循环管理器初始化完成 (最大并发流数: {self.max_concurrent_streams})") async def start(self) -> None: @@ -285,7 +288,11 @@ class StreamLoopManager: interval = await self._calculate_interval(stream_id, has_messages) # 6. sleep等待下次检查 - logger.info(f"流 {stream_id} 等待 {interval:.2f}s") + # 只在间隔发生变化时输出日志,避免刷屏 + last_interval = self._last_intervals.get(stream_id) + if last_interval is None or abs(interval - last_interval) > 0.01: + logger.info(f"流 {stream_id} 等待周期变化: {interval:.2f}s") + self._last_intervals[stream_id] = interval await asyncio.sleep(interval) except asyncio.CancelledError: @@ -316,6 +323,9 @@ class StreamLoopManager: except Exception as e: logger.debug(f"释放自适应流处理槽位失败: {e}") + # 清理间隔记录 + self._last_intervals.pop(stream_id, None) + logger.info(f"流循环结束: {stream_id}") async def _get_stream_context(self, stream_id: str) -> Any | None: diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 854ca615a..e15dab72a 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -5,6 +5,7 @@ from typing import Any from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.utils.timer_calculator import Timer +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config from src.person_info.person_info import get_person_info_manager @@ -142,7 +143,7 @@ class ChatterActionManager: self, action_name: str, chat_id: str, - target_message: dict | None = None, + target_message: dict | DatabaseMessages | None = None, reasoning: str = "", action_data: dict | None = None, thinking_id: str | None = None, @@ -262,9 +263,15 @@ class ChatterActionManager: from_plugin=False, ) if not success or not response_set: - logger.info( - f"对 {target_message.get('processed_plain_text') if target_message else '未知消息'} 的回复生成失败" - ) + # 安全地获取 processed_plain_text + if isinstance(target_message, DatabaseMessages): + msg_text = target_message.processed_plain_text or "未知消息" + elif target_message: + msg_text = target_message.get("processed_plain_text", "未知消息") + else: + msg_text = "未知消息" + + logger.info(f"对 {msg_text} 的回复生成失败") return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} except asyncio.CancelledError: logger.debug(f"{log_prefix} 并行执行:回复生成任务已被取消") @@ -322,8 +329,11 @@ class ChatterActionManager: # 获取目标消息ID target_message_id = None - if target_message and isinstance(target_message, dict): - target_message_id = target_message.get("message_id") + if target_message: + if isinstance(target_message, DatabaseMessages): + target_message_id = target_message.message_id + elif isinstance(target_message, dict): + target_message_id = target_message.get("message_id") elif action_data and isinstance(action_data, dict): target_message_id = action_data.get("target_message_id") @@ -488,14 +498,19 @@ class ChatterActionManager: person_info_manager = get_person_info_manager() # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 - platform = action_message.get("chat_info_platform") - if platform is None: - platform = getattr(chat_stream, "platform", "unknown") + if isinstance(action_message, DatabaseMessages): + platform = action_message.chat_info.platform + user_id = action_message.user_info.user_id + else: + platform = action_message.get("chat_info_platform") + if platform is None: + platform = getattr(chat_stream, "platform", "unknown") + user_id = action_message.get("user_id", "") # 获取用户信息并生成回复提示 person_id = person_info_manager.get_person_id( platform, - action_message.get("user_id", ""), + user_id, ) person_name = await person_info_manager.get_value(person_id, "person_name") action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" @@ -565,7 +580,14 @@ class ChatterActionManager: # 根据新消息数量决定是否需要引用回复 reply_text = "" - is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True + # 检查是否为主动思考消息 + if isinstance(message_data, DatabaseMessages): + # DatabaseMessages 对象没有 message_type 字段,默认为 False + is_proactive_thinking = False + elif message_data: + is_proactive_thinking = message_data.get("message_type") == "proactive_thinking" + else: + is_proactive_thinking = True logger.debug(f"[send_response] message_data: {message_data}") diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 572f88ec1..ef94cf2e3 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -27,6 +27,7 @@ from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.prompt_params import PromptParameters from src.chat.utils.timer_calculator import Timer from src.chat.utils.utils import get_chat_type_and_target_info +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config, model_config from src.individuality.individuality import get_individuality @@ -474,10 +475,13 @@ class DefaultReplyer: style_habits = [] grammar_habits = [] - # 使用从处理器传来的选中表达方式 - # LLM模式:调用LLM选择5-10个,然后随机选5个 - selected_expressions = await expression_selector.select_suitable_expressions_llm( - self.chat_stream.stream_id, chat_history, max_num=8, min_num=2, target_message=target + # 使用统一的表达方式选择入口(支持classic和exp_model模式) + selected_expressions = await expression_selector.select_suitable_expressions( + chat_id=self.chat_stream.stream_id, + chat_history=chat_history, + target_message=target, + max_num=8, + min_num=2 ) if selected_expressions: @@ -1208,7 +1212,7 @@ class DefaultReplyer: extra_info: str = "", available_actions: dict[str, ActionInfo] | None = None, enable_tool: bool = True, - reply_message: dict[str, Any] | None = None, + reply_message: dict[str, Any] | DatabaseMessages | None = None, ) -> str: """ 构建回复器上下文 @@ -1250,10 +1254,24 @@ class DefaultReplyer: if reply_message is None: logger.warning("reply_message 为 None,无法构建prompt") return "" - platform = reply_message.get("chat_info_platform") + + # 统一处理 DatabaseMessages 对象和字典 + if isinstance(reply_message, DatabaseMessages): + platform = reply_message.chat_info.platform + user_id = reply_message.user_info.user_id + user_nickname = reply_message.user_info.user_nickname + user_cardname = reply_message.user_info.user_cardname + processed_plain_text = reply_message.processed_plain_text + else: + platform = reply_message.get("chat_info_platform") + user_id = reply_message.get("user_id") + user_nickname = reply_message.get("user_nickname") + user_cardname = reply_message.get("user_cardname") + processed_plain_text = reply_message.get("processed_plain_text") + person_id = person_info_manager.get_person_id( platform, # type: ignore - reply_message.get("user_id"), # type: ignore + user_id, # type: ignore ) person_name = await person_info_manager.get_value(person_id, "person_name") @@ -1262,22 +1280,22 @@ class DefaultReplyer: # 尝试从reply_message获取用户名 await person_info_manager.first_knowing_some_one( platform, # type: ignore - reply_message.get("user_id"), # type: ignore - reply_message.get("user_nickname") or "", - reply_message.get("user_cardname") or "", + user_id, # type: ignore + user_nickname or "", + user_cardname or "", ) # 检查是否是bot自己的名字,如果是则替换为"(你)" bot_user_id = str(global_config.bot.qq_account) current_user_id = await person_info_manager.get_value(person_id, "user_id") - current_platform = reply_message.get("chat_info_platform") + current_platform = platform if current_user_id == bot_user_id and current_platform == global_config.bot.platform: sender = f"{person_name}(你)" else: # 如果不是bot自己,直接使用person_name sender = person_name - target = reply_message.get("processed_plain_text") + target = processed_plain_text # 最终的空值检查,确保sender和target不为None if sender is None: @@ -1611,15 +1629,22 @@ class DefaultReplyer: raw_reply: str, reason: str, reply_to: str, - reply_message: dict[str, Any] | None = None, + reply_message: dict[str, Any] | DatabaseMessages | None = None, ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if chat_stream = self.chat_stream chat_id = chat_stream.stream_id is_group_chat = bool(chat_stream.group_info) if reply_message: - sender = reply_message.get("sender") - target = reply_message.get("target") + if isinstance(reply_message, DatabaseMessages): + # 从 DatabaseMessages 对象获取 sender 和 target + # 注意: DatabaseMessages 没有直接的 sender/target 字段 + # 需要根据实际情况构造 + sender = reply_message.user_info.user_nickname or reply_message.user_info.user_id + target = reply_message.processed_plain_text or "" + else: + sender = reply_message.get("sender") + target = reply_message.get("target") else: sender, target = self._parse_reply_target(reply_to) @@ -1891,42 +1916,64 @@ class DefaultReplyer: logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取") return f"你完全不认识{sender},不理解ta的相关信息。" - # 使用统一评分API获取关系信息 + # 使用 RelationshipFetcher 获取完整关系信息(包含新字段) try: - from src.plugin_system.apis.scoring_api import scoring_api + from src.person_info.relationship_fetcher import relationship_fetcher_manager - # 获取用户信息以获取真实的user_id - user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"]) - user_id = user_info.get("user_id", "unknown") + # 获取 chat_id + chat_id = self.chat_stream.stream_id - # 从统一API获取关系数据 - relationship_data = await scoring_api.get_user_relationship_data(user_id) - if relationship_data: - relationship_text = relationship_data.get("relationship_text", "") - relationship_score = relationship_data.get("relationship_score", 0.3) + # 获取 RelationshipFetcher 实例 + relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id) - # 构建丰富的关系信息描述 - if relationship_text: - # 转换关系分数为描述性文本 - if relationship_score >= 0.8: - relationship_level = "非常亲密的朋友" - elif relationship_score >= 0.6: - relationship_level = "好朋友" - elif relationship_score >= 0.4: - relationship_level = "普通朋友" - elif relationship_score >= 0.2: - relationship_level = "认识的人" - else: - relationship_level = "陌生人" + # 构建用户关系信息(包含别名、偏好关键词等新字段) + user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5) - return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}" - else: - return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。" + # 构建聊天流印象信息 + stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id) + + # 组合两部分信息 + if user_relation_info and stream_impression: + return "\n\n".join([user_relation_info, stream_impression]) + elif user_relation_info: + return user_relation_info + elif stream_impression: + return stream_impression else: return f"你完全不认识{sender},这是第一次互动。" except Exception as e: logger.error(f"获取关系信息失败: {e}") + # 降级到基本信息 + try: + from src.plugin_system.apis.scoring_api import scoring_api + + user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"]) + user_id = user_info.get("user_id", "unknown") + + relationship_data = await scoring_api.get_user_relationship_data(user_id) + if relationship_data: + relationship_text = relationship_data.get("relationship_text", "") + relationship_score = relationship_data.get("relationship_score", 0.3) + + if relationship_text: + if relationship_score >= 0.8: + relationship_level = "非常亲密的朋友" + elif relationship_score >= 0.6: + relationship_level = "好朋友" + elif relationship_score >= 0.4: + relationship_level = "普通朋友" + elif relationship_score >= 0.2: + relationship_level = "认识的人" + else: + relationship_level = "陌生人" + + return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}" + else: + return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。" + except Exception: + pass + return f"你与{sender}是普通朋友关系。" async def _store_chat_memory_async(self, reply_to: str, reply_message: dict[str, Any] | None = None): diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 09c0dad95..2e141e6ad 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -606,11 +606,11 @@ class Prompt: recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True ) - # 使用LLM选择与当前情景匹配的表达习惯 + # 使用统一的表达方式选择入口(支持classic和exp_model模式) expression_selector = ExpressionSelector(self.parameters.chat_id) - selected_expressions = await expression_selector.select_suitable_expressions_llm( + selected_expressions = await expression_selector.select_suitable_expressions( chat_id=self.parameters.chat_id, - chat_info=chat_history, + chat_history=chat_history, target_message=self.parameters.target, ) @@ -1109,8 +1109,18 @@ class Prompt: logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取") return f"你完全不认识{sender},不理解ta的相关信息。" - # 使用关系提取器构建关系信息 - return await relationship_fetcher.build_relation_info(person_id, points_num=5) + # 使用关系提取器构建用户关系信息和聊天流印象 + user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5) + stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id) + + # 组合两部分信息 + info_parts = [] + if user_relation_info: + info_parts.append(user_relation_info) + if stream_impression: + info_parts.append(stream_impression) + + return "\n\n".join(info_parts) if info_parts else "" def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]: """为超时或失败的异步构建任务提供一个安全的默认返回值. diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 4d8046e16..9f03aa43c 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -140,6 +140,11 @@ class ChatStreams(Base): consecutive_no_reply: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0) # 消息打断系统字段 interruption_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0) + # 聊天流印象字段 + stream_impression_text: Mapped[str | None] = mapped_column(Text, nullable=True) # 对聊天流的主观印象描述 + stream_chat_style: Mapped[str | None] = mapped_column(Text, nullable=True) # 聊天流的总体风格 + stream_topic_keywords: Mapped[str | None] = mapped_column(Text, nullable=True) # 话题关键词,逗号分隔 + stream_interest_score: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5) # 对聊天流的兴趣程度(0-1) __table_args__ = ( Index("idx_chatstreams_stream_id", "stream_id"), @@ -877,7 +882,9 @@ class UserRelationships(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, unique=True, index=True) user_name: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True) + user_aliases: Mapped[str | None] = mapped_column(Text, nullable=True) # 用户别名,逗号分隔 relationship_text: Mapped[str | None] = mapped_column(Text, nullable=True) + preference_keywords: Mapped[str | None] = mapped_column(Text, nullable=True) # 用户偏好关键词,逗号分隔 relationship_score: Mapped[float] = mapped_column(Float, nullable=False, default=0.3) # 关系分数(0-1) last_updated: Mapped[float] = mapped_column(Float, nullable=False, default=time.time) created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False) diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 3f74fca82..3f7622115 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -187,6 +187,10 @@ class ExpressionRule(ValidatedConfigBase): class ExpressionConfig(ValidatedConfigBase): """表达配置类""" + mode: Literal["classic", "exp_model"] = Field( + default="classic", + description="表达方式选择模式: classic=经典LLM评估, exp_model=机器学习模型预测" + ) rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则") @staticmethod diff --git a/src/main.py b/src/main.py index 941814435..1400b3568 100644 --- a/src/main.py +++ b/src/main.py @@ -432,20 +432,6 @@ MoFox_Bot(第三方修改版) get_emoji_manager().initialize() logger.info("表情包管理器初始化成功") - """ - # 初始化回复后关系追踪系统 - try: - from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system - from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker - - relationship_tracker = ChatterRelationshipTracker(interest_scoring_system=chatter_interest_scoring_system) - chatter_interest_scoring_system.relationship_tracker = relationship_tracker - logger.info("回复后关系追踪系统初始化成功") - except Exception as e: - logger.error(f"回复后关系追踪系统初始化失败: {e}") - relationship_tracker = None - """ - # 启动情绪管理器 await mood_manager.start() logger.info("情绪管理器初始化成功") diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index 7b0dca370..eba734184 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -107,10 +107,13 @@ class PromptBuilder: style_habits = [] grammar_habits = [] - # 使用从处理器传来的选中表达方式 - # LLM模式:调用LLM选择5-10个,然后随机选5个 - selected_expressions = await expression_selector.select_suitable_expressions_llm( - chat_stream.stream_id, chat_history, max_num=12, min_num=5, target_message=target + # 使用统一的表达方式选择入口(支持classic和exp_model模式) + selected_expressions = await expression_selector.select_suitable_expressions( + chat_id=chat_stream.stream_id, + chat_history=chat_history, + target_message=target, + max_num=12, + min_num=5 ) if selected_expressions: @@ -163,13 +166,25 @@ class PromptBuilder: person_id = PersonInfoManager.get_person_id(person[0], person[1]) person_ids.append(person_id) - # 使用 RelationshipFetcher 的 build_relation_info 方法,设置 points_num=3 保持与原来相同的行为 - relation_info_list = await asyncio.gather( - *[relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids] - ) - if relation_info := "".join(relation_info_list): + # 构建用户关系信息和聊天流印象信息 + user_relation_tasks = [relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids] + stream_impression_task = relationship_fetcher.build_chat_stream_impression(chat_stream.stream_id) + + # 并行获取所有信息 + results = await asyncio.gather(*user_relation_tasks, stream_impression_task) + relation_info_list = results[:-1] # 用户关系信息 + stream_impression = results[-1] # 聊天流印象 + + # 组合用户关系信息和聊天流印象 + combined_info_parts = [] + if user_relation_info := "".join(relation_info_list): + combined_info_parts.append(user_relation_info) + if stream_impression: + combined_info_parts.append(stream_impression) + + if combined_info := "\n\n".join(combined_info_parts): relation_prompt = await global_prompt_manager.format_prompt( - "relation_prompt", relation_info=relation_info + "relation_prompt", relation_info=combined_info ) return relation_prompt diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 8783d5e7f..2f20ea5be 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -120,13 +120,15 @@ class RelationshipFetcher: know_since = await person_info_manager.get_value(person_id, "know_since") last_know = await person_info_manager.get_value(person_id, "last_know") - # 如果用户没有基本信息,返回默认描述 - if person_name == nickname_str and not short_impression and not full_impression: - return f"你完全不认识{person_name},这是你们第一次交流。" - # 获取用户特征点 current_points = await person_info_manager.get_value(person_id, "points") or [] forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or [] + + # 确保 points 是列表类型(可能从数据库返回字符串) + if not isinstance(current_points, list): + current_points = [] + if not isinstance(forgotten_points, list): + forgotten_points = [] # 按时间排序并选择最有代表性的特征点 all_points = current_points + forgotten_points @@ -177,28 +179,48 @@ class RelationshipFetcher: if points_text: relation_parts.append(f"你记得关于{person_name}的一些事情:\n{points_text}") - # 5. 从UserRelationships表获取额外关系信息 + # 5. 从UserRelationships表获取完整关系信息(新系统) try: from src.common.database.sqlalchemy_database_api import db_query from src.common.database.sqlalchemy_models import UserRelationships - # 查询用户关系数据 + # 查询用户关系数据(修复:添加 await) + user_id = str(await person_info_manager.get_value(person_id, "user_id")) relationships = await db_query( UserRelationships, - filters=[UserRelationships.user_id == str(person_info_manager.get_value(person_id, "user_id"))], + filters={"user_id": user_id}, limit=1, ) if relationships: + # db_query 返回字典列表,使用字典访问方式 rel_data = relationships[0] - if rel_data.relationship_text: - relation_parts.append(f"关系记录:{rel_data.relationship_text}") - if rel_data.relationship_score: - score_desc = self._get_relationship_score_description(rel_data.relationship_score) - relation_parts.append(f"关系亲密程度:{score_desc}") + + # 5.1 用户别名 + if rel_data.get("user_aliases"): + aliases_list = [alias.strip() for alias in rel_data["user_aliases"].split(",") if alias.strip()] + if aliases_list: + aliases_str = "、".join(aliases_list) + relation_parts.append(f"{person_name}的别名有:{aliases_str}") + + # 5.2 关系印象文本(主观认知) + if rel_data.get("relationship_text"): + relation_parts.append(f"你对{person_name}的整体认知:{rel_data['relationship_text']}") + + # 5.3 用户偏好关键词 + if rel_data.get("preference_keywords"): + keywords_list = [kw.strip() for kw in rel_data["preference_keywords"].split(",") if kw.strip()] + if keywords_list: + keywords_str = "、".join(keywords_list) + relation_parts.append(f"{person_name}的偏好和兴趣:{keywords_str}") + + # 5.4 关系亲密程度(好感分数) + if rel_data.get("relationship_score") is not None: + score_desc = self._get_relationship_score_description(rel_data["relationship_score"]) + relation_parts.append(f"你们的关系程度:{score_desc}({rel_data['relationship_score']:.2f})") except Exception as e: - logger.debug(f"查询UserRelationships表失败: {e}") + logger.error(f"查询UserRelationships表失败: {e}", exc_info=True) # 构建最终的关系信息字符串 if relation_parts: @@ -206,10 +228,90 @@ class RelationshipFetcher: [f"• {part}" for part in relation_parts] ) else: - relation_info = f"你对{person_name}了解不多,这是比较初步的交流。" + # 只有当所有数据源都没有信息时才返回默认文本 + relation_info = f"你完全不认识{person_name},这是你们第一次交流。" return relation_info + async def build_chat_stream_impression(self, stream_id: str) -> str: + """构建聊天流的印象信息 + + Args: + stream_id: 聊天流ID + + Returns: + str: 格式化后的聊天流印象字符串 + """ + try: + from src.common.database.sqlalchemy_database_api import db_query + from src.common.database.sqlalchemy_models import ChatStreams + + # 查询聊天流数据 + streams = await db_query( + ChatStreams, + filters={"stream_id": stream_id}, + limit=1, + ) + + if not streams: + return "" + + # db_query 返回字典列表,使用字典访问方式 + stream_data = streams[0] + impression_parts = [] + + # 1. 聊天环境基本信息 + if stream_data.get("group_name"): + impression_parts.append(f"这是一个名为「{stream_data['group_name']}」的群聊") + else: + impression_parts.append("这是一个私聊对话") + + # 2. 聊天流的主观印象 + if stream_data.get("stream_impression_text"): + impression_parts.append(f"你对这个聊天环境的印象:{stream_data['stream_impression_text']}") + + # 3. 聊天风格 + if stream_data.get("stream_chat_style"): + impression_parts.append(f"这里的聊天风格:{stream_data['stream_chat_style']}") + + # 4. 常见话题 + if stream_data.get("stream_topic_keywords"): + topics_list = [topic.strip() for topic in stream_data["stream_topic_keywords"].split(",") if topic.strip()] + if topics_list: + topics_str = "、".join(topics_list) + impression_parts.append(f"这里常讨论的话题:{topics_str}") + + # 5. 兴趣程度 + if stream_data.get("stream_interest_score") is not None: + interest_desc = self._get_interest_score_description(stream_data["stream_interest_score"]) + impression_parts.append(f"你对这个聊天环境的兴趣程度:{interest_desc}({stream_data['stream_interest_score']:.2f})") + + # 构建最终的印象信息字符串 + if impression_parts: + impression_info = "关于当前的聊天环境:\n" + "\n".join( + [f"• {part}" for part in impression_parts] + ) + return impression_info + else: + return "" + + except Exception as e: + logger.debug(f"查询ChatStreams表失败: {e}") + return "" + + def _get_interest_score_description(self, score: float) -> str: + """根据兴趣分数返回描述性文字""" + if score >= 0.8: + return "非常感兴趣,很喜欢这里的氛围" + elif score >= 0.6: + return "比较感兴趣,愿意积极参与" + elif score >= 0.4: + return "一般兴趣,会适度参与" + elif score >= 0.2: + return "兴趣不大,较少主动参与" + else: + return "不太感兴趣,很少参与" + def _get_attitude_description(self, attitude: int) -> str: """根据态度分数返回描述性文字""" if attitude >= 80: diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 3e52cf4c5..96f0e4b09 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -108,52 +108,79 @@ def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | """查找要回复的消息 Args: - message_dict: 消息字典 + message_dict: 消息字典或 DatabaseMessages 对象 Returns: Optional[MessageRecv]: 找到的消息,如果没找到则返回None """ + # 兼容 DatabaseMessages 对象和字典 + if isinstance(message_dict, dict): + user_platform = message_dict.get("user_platform", "") + user_id = message_dict.get("user_id", "") + user_nickname = message_dict.get("user_nickname", "") + user_cardname = message_dict.get("user_cardname", "") + chat_info_group_id = message_dict.get("chat_info_group_id") + chat_info_group_platform = message_dict.get("chat_info_group_platform", "") + chat_info_group_name = message_dict.get("chat_info_group_name", "") + chat_info_platform = message_dict.get("chat_info_platform", "") + message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id") + time_val = message_dict.get("time") + additional_config = message_dict.get("additional_config") + processed_plain_text = message_dict.get("processed_plain_text") + else: + # DatabaseMessages 对象 + user_platform = getattr(message_dict, "user_platform", "") + user_id = getattr(message_dict, "user_id", "") + user_nickname = getattr(message_dict, "user_nickname", "") + user_cardname = getattr(message_dict, "user_cardname", "") + chat_info_group_id = getattr(message_dict, "chat_info_group_id", None) + chat_info_group_platform = getattr(message_dict, "chat_info_group_platform", "") + chat_info_group_name = getattr(message_dict, "chat_info_group_name", "") + chat_info_platform = getattr(message_dict, "chat_info_platform", "") + message_id = getattr(message_dict, "message_id", None) + time_val = getattr(message_dict, "time", None) + additional_config = getattr(message_dict, "additional_config", None) + processed_plain_text = getattr(message_dict, "processed_plain_text", "") + # 构建MessageRecv对象 user_info = { - "platform": message_dict.get("user_platform", ""), - "user_id": message_dict.get("user_id", ""), - "user_nickname": message_dict.get("user_nickname", ""), - "user_cardname": message_dict.get("user_cardname", ""), + "platform": user_platform, + "user_id": user_id, + "user_nickname": user_nickname, + "user_cardname": user_cardname, } group_info = {} - if message_dict.get("chat_info_group_id"): + if chat_info_group_id: group_info = { - "platform": message_dict.get("chat_info_group_platform", ""), - "group_id": message_dict.get("chat_info_group_id", ""), - "group_name": message_dict.get("chat_info_group_name", ""), + "platform": chat_info_group_platform, + "group_id": chat_info_group_id, + "group_name": chat_info_group_name, } format_info = {"content_format": "", "accept_format": ""} template_info = {"template_items": {}} message_info = { - "platform": message_dict.get("chat_info_platform", ""), - "message_id": message_dict.get("message_id") - or message_dict.get("chat_info_message_id") - or message_dict.get("id"), - "time": message_dict.get("time"), + "platform": chat_info_platform, + "message_id": message_id, + "time": time_val, "group_info": group_info, "user_info": user_info, - "additional_config": message_dict.get("additional_config"), + "additional_config": additional_config, "format_info": format_info, "template_info": template_info, } new_message_dict = { "message_info": message_info, - "raw_message": message_dict.get("processed_plain_text"), - "processed_plain_text": message_dict.get("processed_plain_text"), + "raw_message": processed_plain_text, + "processed_plain_text": processed_plain_text, } message_recv = MessageRecv(new_message_dict) - logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}") + logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {user_nickname}") return message_recv diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py index 01ce4c7dc..2eac60402 100644 --- a/src/plugin_system/apis/tool_api.py +++ b/src/plugin_system/apis/tool_api.py @@ -7,8 +7,16 @@ from src.plugin_system.base.component_types import ComponentType logger = get_logger("tool_api") -def get_tool_instance(tool_name: str) -> BaseTool | None: - """获取公开工具实例""" +def get_tool_instance(tool_name: str, chat_stream: Any = None) -> BaseTool | None: + """获取公开工具实例 + + Args: + tool_name: 工具名称 + chat_stream: 聊天流对象,用于提供上下文信息 + + Returns: + BaseTool: 工具实例,如果工具不存在则返回None + """ from src.plugin_system.core import component_registry # 获取插件配置 @@ -19,7 +27,7 @@ def get_tool_instance(tool_name: str) -> BaseTool | None: plugin_config = None tool_class: type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore - return tool_class(plugin_config) if tool_class else None + return tool_class(plugin_config, chat_stream) if tool_class else None def get_llm_available_tool_definitions() -> list[dict[str, Any]]: diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 5790d2312..b5071e578 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING from src.chat.message_receive.chat_stream import ChatStream +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.plugin_system.apis import database_api, message_api, send_api from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ChatMode, ChatType, ComponentType @@ -180,11 +181,18 @@ class BaseAction(ABC): if self.has_action_message: if self.action_name != "no_reply": - self.group_id = str(self.action_message.get("chat_info_group_id", None)) - self.group_name = self.action_message.get("chat_info_group_name", None) - - self.user_id = str(self.action_message.get("user_id", None)) - self.user_nickname = self.action_message.get("user_nickname", None) + # 统一处理 DatabaseMessages 对象和字典 + if isinstance(self.action_message, DatabaseMessages): + self.group_id = str(self.action_message.group_info.group_id if self.action_message.group_info else None) + self.group_name = self.action_message.group_info.group_name if self.action_message.group_info else None + self.user_id = str(self.action_message.user_info.user_id) + self.user_nickname = self.action_message.user_info.user_nickname + else: + self.group_id = str(self.action_message.get("chat_info_group_id", None)) + self.group_name = self.action_message.get("chat_info_group_name", None) + self.user_id = str(self.action_message.get("user_id", None)) + self.user_nickname = self.action_message.get("user_nickname", None) + if self.group_id: self.is_group = True self.target_id = self.group_id diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py index 5cd04b485..5ad4c6dbc 100644 --- a/src/plugin_system/base/base_tool.py +++ b/src/plugin_system/base/base_tool.py @@ -47,8 +47,9 @@ class BaseTool(ABC): sub_tools: list[tuple[str, str, list[tuple[str, ToolParamType, str, bool, list[str] | None]]]] = [] """子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用""" - def __init__(self, plugin_config: dict | None = None): + def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None): self.plugin_config = plugin_config or {} # 直接存储插件配置字典 + self.chat_stream = chat_stream # 存储聊天流信息,可用于获取上下文 @classmethod def get_tool_definition(cls) -> dict[str, Any]: diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index 44d47eb9f..14e6fcd7c 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -226,7 +226,7 @@ class ToolExecutor: """执行单个工具调用,并处理缓存""" function_args = tool_call.args or {} - tool_instance = tool_instance or get_tool_instance(tool_call.func_name) + tool_instance = tool_instance or get_tool_instance(tool_call.func_name, self.chat_stream) # 如果工具不存在或未启用缓存,则直接执行 if not tool_instance or not tool_instance.enable_cache: @@ -320,7 +320,7 @@ class ToolExecutor: parts = function_name.split("_", 1) if len(parts) == 2: base_tool_name, sub_tool_name = parts - base_tool_instance = get_tool_instance(base_tool_name) + base_tool_instance = get_tool_instance(base_tool_name, self.chat_stream) if base_tool_instance and base_tool_instance.is_two_step_tool: logger.info(f"{self.log_prefix}执行二步工具第二步: {base_tool_name}.{sub_tool_name}") @@ -340,7 +340,7 @@ class ToolExecutor: } # 获取对应工具实例 - tool_instance = tool_instance or get_tool_instance(function_name) + tool_instance = tool_instance or get_tool_instance(function_name, self.chat_stream) if not tool_instance: logger.warning(f"未知工具名称: {function_name}") return None diff --git a/src/plugins/built_in/affinity_flow_chatter/affinity_interest_calculator.py b/src/plugins/built_in/affinity_flow_chatter/affinity_interest_calculator.py index 524dfb80d..abf581203 100644 --- a/src/plugins/built_in/affinity_flow_chatter/affinity_interest_calculator.py +++ b/src/plugins/built_in/affinity_flow_chatter/affinity_interest_calculator.py @@ -209,13 +209,13 @@ class AffinityInterestCalculator(BaseInterestCalculator): relationship_value = self.user_relationships[user_id] return min(relationship_value, 1.0) - # 如果内存中没有,尝试从关系追踪器获取 + # 如果内存中没有,尝试从统一的评分API获取 try: - from .relationship_tracker import ChatterRelationshipTracker + from src.plugin_system.apis.scoring_api import scoring_api - global_tracker = ChatterRelationshipTracker() - if global_tracker: - relationship_score = await global_tracker.get_user_relationship_score(user_id) + relationship_data = await scoring_api.get_user_relationship_data(user_id) + if relationship_data: + relationship_score = relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score) # 同时更新内存缓存 self.user_relationships[user_id] = relationship_score return relationship_score diff --git a/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py b/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py new file mode 100644 index 000000000..06be93119 --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py @@ -0,0 +1,363 @@ +""" +聊天流印象更新工具 + +通过LLM二步调用机制更新对聊天流(如QQ群)的整体印象,包括主观描述、聊天风格、话题关键词和兴趣分数 +""" + +import json +import time +from typing import Any + +from sqlalchemy import select + +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import ChatStreams +from src.common.logger import get_logger +from src.config.config import model_config +from src.llm_models.utils_model import LLMRequest +from src.plugin_system import BaseTool, ToolParamType + +logger = get_logger("chat_stream_impression_tool") + + +class ChatStreamImpressionTool(BaseTool): + """聊天流印象更新工具 + + 使用二步调用机制: + 1. LLM决定是否调用工具并传入初步参数(stream_id会自动传入) + 2. 工具内部调用LLM,结合现有数据和传入参数,决定最终更新内容 + """ + + name = "update_chat_stream_impression" + description = "当你通过观察聊天记录对当前聊天环境(群聊或私聊)产生了整体印象或认识时使用此工具,更新对这个聊天流的看法。包括:环境氛围、聊天风格、常见话题、你的兴趣程度。调用时机:当你发现这个聊天环境有明显的氛围特点(如很活跃、很专业、很闲聊)、群成员经常讨论某类话题、或者你对这个环境的感受发生变化时。注意:这是对整个聊天环境的印象,而非对单个用户。" + parameters = [ + ("impression_description", ToolParamType.STRING, "你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群'、'大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)", False, None), + ("chat_style", ToolParamType.STRING, "这个聊天环境的风格特征,如'活跃热闹,互帮互助'、'严肃专业,深度讨论'、'轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)", False, None), + ("topic_keywords", ToolParamType.STRING, "这个聊天环境中经常出现的话题,如'编程,AI,技术分享'或'游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)", False, None), + ("interest_score", ToolParamType.FLOAT, "你对这个聊天环境的兴趣和喜欢程度,0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)", False, None), + ] + available_for_llm = True + history_ttl = 5 + + def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None): + super().__init__(plugin_config, chat_stream) + + # 初始化用于二步调用的LLM + try: + self.impression_llm = LLMRequest( + model_set=model_config.model_task_config.relationship_tracker, + request_type="chat_stream_impression_update" + ) + except AttributeError: + # 降级处理 + available_models = [ + attr for attr in dir(model_config.model_task_config) + if not attr.startswith("_") and attr != "model_dump" + ] + if available_models: + fallback_model = available_models[0] + logger.warning(f"relationship_tracker配置不存在,使用降级模型: {fallback_model}") + self.impression_llm = LLMRequest( + model_set=getattr(model_config.model_task_config, fallback_model), + request_type="chat_stream_impression_update" + ) + else: + logger.error("无可用的模型配置") + self.impression_llm = None + + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: + """执行聊天流印象更新 + + Args: + function_args: 工具参数,stream_id会由系统自动注入 + + Returns: + dict: 执行结果 + """ + try: + # stream_id应该由调用方(如工具执行器)自动注入 + # 如果没有注入,尝试从上下文获取 + stream_id = function_args.get("stream_id") + if not stream_id: + # 尝试从其他可能的来源获取 + logger.warning("stream_id未自动注入,尝试从其他来源获取") + # 这里可以添加从上下文获取的逻辑 + return { + "type": "error", + "id": "chat_stream_impression", + "content": "错误:无法获取当前聊天流ID" + } + + # 从LLM传入的参数 + new_impression = function_args.get("impression_description", "") + new_style = function_args.get("chat_style", "") + new_topics = function_args.get("topic_keywords", "") + new_score = function_args.get("interest_score") + + # 从数据库获取现有聊天流印象 + existing_impression = await self._get_stream_impression(stream_id) + + # 如果LLM没有传入任何有效参数,返回提示 + if not any([new_impression, new_style, new_topics, new_score is not None]): + return { + "type": "info", + "id": stream_id, + "content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)" + } + + # 调用LLM进行二步决策 + if self.impression_llm is None: + logger.error("LLM未正确初始化,无法执行二步调用") + return { + "type": "error", + "id": stream_id, + "content": "系统错误:LLM未正确初始化" + } + + final_impression = await self._llm_decide_final_impression( + stream_id=stream_id, + existing_impression=existing_impression, + new_impression=new_impression, + new_style=new_style, + new_topics=new_topics, + new_score=new_score + ) + + if not final_impression: + return { + "type": "error", + "id": stream_id, + "content": "LLM决策失败,无法更新聊天流印象" + } + + # 更新数据库 + await self._update_stream_impression_in_db(stream_id, final_impression) + + # 构建返回信息 + updates = [] + if final_impression.get("stream_impression_text"): + updates.append(f"印象: {final_impression['stream_impression_text'][:50]}...") + if final_impression.get("stream_chat_style"): + updates.append(f"风格: {final_impression['stream_chat_style']}") + if final_impression.get("stream_topic_keywords"): + updates.append(f"话题: {final_impression['stream_topic_keywords']}") + if final_impression.get("stream_interest_score") is not None: + updates.append(f"兴趣分: {final_impression['stream_interest_score']:.2f}") + + result_text = f"已更新聊天流 {stream_id} 的印象:\n" + "\n".join(updates) + logger.info(f"聊天流印象更新成功: {stream_id}") + + return { + "type": "chat_stream_impression_update", + "id": stream_id, + "content": result_text + } + + except Exception as e: + logger.error(f"聊天流印象更新失败: {e}", exc_info=True) + return { + "type": "error", + "id": function_args.get("stream_id", "unknown"), + "content": f"聊天流印象更新失败: {str(e)}" + } + + async def _get_stream_impression(self, stream_id: str) -> dict[str, Any]: + """从数据库获取聊天流现有印象 + + Args: + stream_id: 聊天流ID + + Returns: + dict: 聊天流印象数据 + """ + try: + async with get_db_session() as session: + stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id) + result = await session.execute(stmt) + stream = result.scalar_one_or_none() + + if stream: + return { + "stream_impression_text": stream.stream_impression_text or "", + "stream_chat_style": stream.stream_chat_style or "", + "stream_topic_keywords": stream.stream_topic_keywords or "", + "stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score is not None else 0.5, + "group_name": stream.group_name or "私聊", + } + else: + # 聊天流不存在,返回默认值 + return { + "stream_impression_text": "", + "stream_chat_style": "", + "stream_topic_keywords": "", + "stream_interest_score": 0.5, + "group_name": "未知", + } + except Exception as e: + logger.error(f"获取聊天流印象失败: {e}") + return { + "stream_impression_text": "", + "stream_chat_style": "", + "stream_topic_keywords": "", + "stream_interest_score": 0.5, + "group_name": "未知", + } + + async def _llm_decide_final_impression( + self, + stream_id: str, + existing_impression: dict[str, Any], + new_impression: str, + new_style: str, + new_topics: str, + new_score: float | None + ) -> dict[str, Any] | None: + """使用LLM决策最终的聊天流印象内容 + + Args: + stream_id: 聊天流ID + existing_impression: 现有印象数据 + new_impression: LLM传入的新印象 + new_style: LLM传入的新风格 + new_topics: LLM传入的新话题 + new_score: LLM传入的新分数 + + Returns: + dict: 最终决定的印象数据,如果失败返回None + """ + try: + # 获取bot人设 + from src.individuality.individuality import Individuality + individuality = Individuality() + bot_personality = await individuality.get_personality_block() + + prompt = f""" +你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality} + +你正在更新对聊天流 {stream_id} 的整体印象。 + +【当前聊天流信息】 +- 聊天环境: {existing_impression.get('group_name', '未知')} +- 当前印象: {existing_impression.get('stream_impression_text', '暂无印象')} +- 聊天风格: {existing_impression.get('stream_chat_style', '未知')} +- 常见话题: {existing_impression.get('stream_topic_keywords', '未知')} +- 当前兴趣分: {existing_impression.get('stream_interest_score', 0.5):.2f} + +【本次想要更新的内容】 +- 新的印象描述: {new_impression if new_impression else '不更新'} +- 新的聊天风格: {new_style if new_style else '不更新'} +- 新的话题关键词: {new_topics if new_topics else '不更新'} +- 新的兴趣分数: {new_score if new_score is not None else '不更新'} + +请综合考虑现有信息和新信息,决定最终的聊天流印象内容。注意: +1. 印象描述:如果提供了新印象,应该综合现有印象和新印象,形成对这个聊天环境的整体认知(100-200字) +2. 聊天风格:如果提供了新风格,应该用简洁的词语概括,如"活跃轻松"、"严肃专业"、"幽默随性"等 +3. 话题关键词:如果提供了新话题,应该与现有话题合并(去重),保留最核心和频繁的话题 +4. 兴趣分数:如果提供了新分数,需要结合现有分数合理调整(0.0表示完全不感兴趣,1.0表示非常感兴趣) + +请以JSON格式返回最终决定: +{{ + "stream_impression_text": "最终的印象描述(100-200字),整体性的对这个聊天环境的认知", + "stream_chat_style": "最终的聊天风格,简洁概括", + "stream_topic_keywords": "最终的话题关键词,逗号分隔", + "stream_interest_score": 最终的兴趣分数(0.0-1.0), + "reasoning": "你的决策理由" +}} +""" + + # 调用LLM + llm_response, _ = await self.impression_llm.generate_response_async(prompt=prompt) + + if not llm_response: + logger.warning("LLM未返回有效响应") + return None + + # 清理并解析响应 + cleaned_response = self._clean_llm_json_response(llm_response) + response_data = json.loads(cleaned_response) + + # 提取最终决定的数据 + final_impression = { + "stream_impression_text": response_data.get("stream_impression_text", existing_impression.get("stream_impression_text", "")), + "stream_chat_style": response_data.get("stream_chat_style", existing_impression.get("stream_chat_style", "")), + "stream_topic_keywords": response_data.get("stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")), + "stream_interest_score": max(0.0, min(1.0, float(response_data.get("stream_interest_score", existing_impression.get("stream_interest_score", 0.5))))), + } + + logger.info(f"LLM决策完成: {stream_id}") + logger.debug(f"决策理由: {response_data.get('reasoning', '无')}") + + return final_impression + + except json.JSONDecodeError as e: + logger.error(f"LLM响应JSON解析失败: {e}") + logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}") + return None + except Exception as e: + logger.error(f"LLM决策失败: {e}", exc_info=True) + return None + + async def _update_stream_impression_in_db(self, stream_id: str, impression: dict[str, Any]): + """更新数据库中的聊天流印象 + + Args: + stream_id: 聊天流ID + impression: 印象数据 + """ + try: + async with get_db_session() as session: + stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id) + result = await session.execute(stmt) + existing = result.scalar_one_or_none() + + if existing: + # 更新现有记录 + existing.stream_impression_text = impression.get("stream_impression_text", "") + existing.stream_chat_style = impression.get("stream_chat_style", "") + existing.stream_topic_keywords = impression.get("stream_topic_keywords", "") + existing.stream_interest_score = impression.get("stream_interest_score", 0.5) + + await session.commit() + logger.info(f"聊天流印象已更新到数据库: {stream_id}") + else: + error_msg = f"聊天流 {stream_id} 不存在于数据库中,无法更新印象" + logger.error(error_msg) + # 注意:通常聊天流应该在消息处理时就已创建,这里不创建新记录 + raise ValueError(error_msg) + + except Exception as e: + logger.error(f"更新聊天流印象到数据库失败: {e}", exc_info=True) + raise + + def _clean_llm_json_response(self, response: str) -> str: + """清理LLM响应,移除可能的JSON格式标记 + + Args: + response: LLM原始响应 + + Returns: + str: 清理后的JSON字符串 + """ + try: + import re + + cleaned = response.strip() + + # 移除 ```json 或 ``` 等标记 + cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE) + cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE) + + # 尝试找到JSON对象的开始和结束 + json_start = cleaned.find("{") + json_end = cleaned.rfind("}") + + if json_start != -1 and json_end != -1 and json_end > json_start: + cleaned = cleaned[json_start:json_end + 1] + + cleaned = cleaned.strip() + + return cleaned + + except Exception as e: + logger.warning(f"清理LLM响应失败: {e}") + return response diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py index 53c327561..3af389f9f 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py @@ -45,13 +45,6 @@ class ChatterPlanExecutor: "execution_times": [], } - # 用户关系追踪引用 - self.relationship_tracker = None - - def set_relationship_tracker(self, relationship_tracker): - """设置关系追踪器""" - self.relationship_tracker = relationship_tracker - async def execute(self, plan: Plan) -> dict[str, Any]: """ 遍历并执行Plan对象中`decided_actions`列表里的所有动作。 @@ -238,19 +231,11 @@ class ChatterPlanExecutor: except Exception as e: error_message = str(e) logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}") - # 记录用户关系追踪 - 使用后台异步执行,防止阻塞主流程 + + # 将机器人回复添加到已读消息中 if success and action_info.action_message: - logger.debug(f"准备执行关系追踪: success={success}, action_message存在={bool(action_info.action_message)}") - logger.debug(f"关系追踪器状态: {self.relationship_tracker is not None}") - - # 直接使用后台异步任务执行关系追踪,避免阻塞主回复流程 - import asyncio - asyncio.create_task(self._track_user_interaction(action_info, plan, reply_content)) - logger.debug("关系追踪已启动为后台异步任务") - else: - logger.debug(f"跳过关系追踪: success={success}, action_message存在={bool(action_info.action_message)}") - # 将机器人回复添加到已读消息中 await self._add_bot_reply_to_read_messages(action_info, plan, reply_content) + execution_time = time.time() - start_time self.execution_stats["execution_times"].append(execution_time) @@ -356,81 +341,6 @@ class ChatterPlanExecutor: "reasoning": action_info.reasoning, } - async def _track_user_interaction(self, action_info: ActionPlannerInfo, plan: Plan, reply_content: str): - """追踪用户交互 - 集成回复后关系追踪""" - try: - logger.debug("🔍 开始执行用户交互追踪") - - if not action_info.action_message: - logger.debug("❌ 跳过追踪:action_message为空") - return - - # 获取用户信息 - 处理DatabaseMessages对象 - if hasattr(action_info.action_message, "user_id"): - # DatabaseMessages对象情况 - user_id = action_info.action_message.user_id - user_name = action_info.action_message.user_nickname or user_id - # 使用processed_plain_text作为消息内容,如果没有则使用display_message - user_message = ( - action_info.action_message.processed_plain_text - or action_info.action_message.display_message - or "" - ) - logger.debug(f"📝 从DatabaseMessages获取用户信息: user_id={user_id}, user_name={user_name}") - else: - # 字典情况(向后兼容)- 适配扁平化消息字典结构 - # 首先尝试从扁平化结构直接获取用户信息 - user_id = action_info.action_message.get("user_id") - user_name = action_info.action_message.get("user_nickname") or user_id - - # 如果扁平化结构中没有用户信息,再尝试从嵌套的user_info获取 - if not user_id: - user_info = action_info.action_message.get("user_info", {}) - user_id = user_info.get("user_id") - user_name = user_info.get("user_nickname") or user_id - logger.debug(f"📝 从嵌套user_info获取用户信息: user_id={user_id}, user_name={user_name}") - else: - logger.debug(f"📝 从扁平化结构获取用户信息: user_id={user_id}, user_name={user_name}") - - # 获取消息内容,优先使用processed_plain_text - user_message = ( - action_info.action_message.get("processed_plain_text", "") - or action_info.action_message.get("display_message", "") - or action_info.action_message.get("content", "") - ) - - if not user_id: - logger.debug("❌ 跳过追踪:缺少用户ID") - return - - # 如果有设置关系追踪器,执行回复后关系追踪 - if self.relationship_tracker: - logger.debug(f"✅ 关系追踪器存在,开始为用户 {user_id} 执行追踪") - - # 记录基础交互信息(保持向后兼容) - self.relationship_tracker.add_interaction( - user_id=user_id, - user_name=user_name, - user_message=user_message, - bot_reply=reply_content, - reply_timestamp=time.time(), - ) - logger.debug(f"📊 已添加基础交互信息: {user_name}({user_id})") - - # 执行新的回复后关系追踪 - await self.relationship_tracker.track_reply_relationship( - user_id=user_id, user_name=user_name, bot_reply_content=reply_content, reply_timestamp=time.time() - ) - logger.debug(f"🎯 已执行回复后关系追踪: {user_id}") - - else: - logger.debug("❌ 关系追踪器不存在,跳过追踪") - - except Exception as e: - logger.error(f"追踪用户交互时出错: {e}") - logger.debug(f"action_message类型: {type(action_info.action_message)}") - logger.debug(f"action_message内容: {action_info.action_message}") - async def _add_bot_reply_to_read_messages(self, action_info: ActionPlannerInfo, plan: Plan, reply_content: str): """将机器人回复添加到已读消息中""" try: @@ -491,7 +401,7 @@ class ChatterPlanExecutor: # 群组信息(如果是群聊) chat_info_group_id=chat_stream.group_info.group_id if chat_stream.group_info else None, chat_info_group_name=chat_stream.group_info.group_name if chat_stream.group_info else None, - chat_info_group_platform=chat_stream.group_info.group_platform if chat_stream.group_info else None, + chat_info_group_platform=getattr(chat_stream.group_info, "platform", None) if chat_stream.group_info else None, # 动作信息 actions=["bot_reply"], diff --git a/src/plugins/built_in/affinity_flow_chatter/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner.py index a24059c05..a8ae019a0 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner.py @@ -51,16 +51,6 @@ class ChatterActionPlanner: self.generator = ChatterPlanGenerator(chat_id) self.executor = ChatterPlanExecutor(action_manager) - # 初始化关系追踪器 - if global_config.affinity_flow.enable_relationship_tracking: - from .relationship_tracker import ChatterRelationshipTracker - self.relationship_tracker = ChatterRelationshipTracker() - self.executor.set_relationship_tracker(self.relationship_tracker) - logger.info(f"关系追踪器已初始化 (chat_id: {chat_id})") - else: - self.relationship_tracker = None - logger.info(f"关系系统已禁用,跳过关系追踪器初始化 (chat_id: {chat_id})") - # 使用新的统一兴趣度管理系统 # 规划器统计 diff --git a/src/plugins/built_in/affinity_flow_chatter/plugin.py b/src/plugins/built_in/affinity_flow_chatter/plugin.py index 26b83a696..6a8ee7fdb 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plugin.py +++ b/src/plugins/built_in/affinity_flow_chatter/plugin.py @@ -52,4 +52,20 @@ class AffinityChatterPlugin(BasePlugin): except Exception as e: logger.error(f"加载 AffinityInterestCalculator 时出错: {e}") + try: + # 延迟导入 UserProfileTool + from .user_profile_tool import UserProfileTool + + components.append((UserProfileTool.get_tool_info(), UserProfileTool)) + except Exception as e: + logger.error(f"加载 UserProfileTool 时出错: {e}") + + try: + # 延迟导入 ChatStreamImpressionTool + from .chat_stream_impression_tool import ChatStreamImpressionTool + + components.append((ChatStreamImpressionTool.get_tool_info(), ChatStreamImpressionTool)) + except Exception as e: + logger.error(f"加载 ChatStreamImpressionTool 时出错: {e}") + return components diff --git a/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py b/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py deleted file mode 100644 index 5a0433028..000000000 --- a/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py +++ /dev/null @@ -1,820 +0,0 @@ -""" -用户关系追踪器 -负责追踪用户交互历史,并通过LLM分析更新用户关系分 -支持数据库持久化存储和回复后自动关系更新 -""" - -import random -import time - -from sqlalchemy import desc, select - -from src.common.data_models.database_data_model import DatabaseMessages -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import Messages, UserRelationships -from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest - -logger = get_logger("chatter_relationship_tracker") - - -class ChatterRelationshipTracker: - """用户关系追踪器""" - - def __init__(self, interest_scoring_system=None): - self.tracking_users: dict[str, dict] = {} # user_id -> interaction_data - self.max_tracking_users = 3 - self.update_interval_minutes = 30 - self.last_update_time = time.time() - self.relationship_history: list[dict] = [] - - # 兼容性:保留参数但不直接使用,转而使用统一API - self.interest_scoring_system = None # 废弃,不再使用 - - # 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float}) - self.user_relationship_cache: dict[str, dict] = {} - self.cache_expiry_hours = 1 # 缓存过期时间(小时) - - # 关系更新LLM - try: - self.relationship_llm = LLMRequest( - model_set=model_config.model_task_config.relationship_tracker, request_type="relationship_tracker" - ) - except AttributeError: - # 如果relationship_tracker配置不存在,尝试其他可用的模型配置 - available_models = [ - attr - for attr in dir(model_config.model_task_config) - if not attr.startswith("_") and attr != "model_dump" - ] - - if available_models: - # 使用第一个可用的模型配置 - fallback_model = available_models[0] - logger.warning(f"relationship_tracker model configuration not found, using fallback: {fallback_model}") - self.relationship_llm = LLMRequest( - model_set=getattr(model_config.model_task_config, fallback_model), - request_type="relationship_tracker", - ) - else: - # 如果没有任何模型配置,创建一个简单的LLMRequest - logger.warning("No model configurations found, creating basic LLMRequest") - self.relationship_llm = LLMRequest( - model_set="gpt-3.5-turbo", # 默认模型 - request_type="relationship_tracker", - ) - - def set_interest_scoring_system(self, interest_scoring_system): - """设置兴趣度评分系统引用(已废弃,使用统一API)""" - # 不再需要设置,直接使用统一API - logger.info("set_interest_scoring_system 已废弃,现在使用统一评分API") - - def add_interaction(self, user_id: str, user_name: str, user_message: str, bot_reply: str, reply_timestamp: float): - """添加用户交互记录""" - if len(self.tracking_users) >= self.max_tracking_users: - # 移除最旧的记录 - oldest_user = min( - self.tracking_users.keys(), key=lambda k: self.tracking_users[k].get("reply_timestamp", 0) - ) - del self.tracking_users[oldest_user] - - # 获取当前关系分 - 使用缓存数据 - current_relationship_score = global_config.affinity_flow.base_relationship_score # 默认值 - if user_id in self.user_relationship_cache: - current_relationship_score = self.user_relationship_cache[user_id].get("relationship_score", current_relationship_score) - - self.tracking_users[user_id] = { - "user_id": user_id, - "user_name": user_name, - "user_message": user_message, - "bot_reply": bot_reply, - "reply_timestamp": reply_timestamp, - "current_relationship_score": current_relationship_score, - } - - logger.debug(f"添加用户交互追踪: {user_id}") - - async def check_and_update_relationships(self) -> list[dict]: - """检查并更新用户关系""" - current_time = time.time() - if current_time - self.last_update_time < self.update_interval_minutes * 60: - return [] - - updates = [] - for user_id, interaction in list(self.tracking_users.items()): - if current_time - interaction["reply_timestamp"] > 60 * 5: # 5分钟 - update = await self._update_user_relationship(interaction) - if update: - updates.append(update) - del self.tracking_users[user_id] - - self.last_update_time = current_time - return updates - - async def _update_user_relationship(self, interaction: dict) -> dict | None: - """更新单个用户的关系""" - try: - # 获取bot人设信息 - from src.individuality.individuality import Individuality - - individuality = Individuality() - bot_personality = await individuality.get_personality_block() - - prompt = f""" -你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality} - -请以你独特的性格视角,严格按现实逻辑分析以下用户交互,更新用户关系: - -用户ID: {interaction["user_id"]} -用户名: {interaction["user_name"]} -用户消息: {interaction["user_message"]} -你的回复: {interaction["bot_reply"]} -当前关系分: {interaction["current_relationship_score"]} - -【重要】关系分数档次定义: -- 0.0-0.2:陌生人/初次认识 - 仅礼貌性交流 -- 0.2-0.4:普通网友 - 有基本互动但不熟悉 -- 0.4-0.6:熟悉网友 - 经常交流,有一定了解 -- 0.6-0.8:朋友 - 可以分享心情,互相关心 -- 0.8-1.0:好朋友/知己 - 深度信任,亲密无间 - -【严格要求】: -1. 加分必须符合现实关系发展逻辑 - 不能因为对方态度好就盲目加分到不符合当前关系档次的分数 -2. 关系提升需要足够的互动积累和时间验证 -3. 即使是朋友关系,单次互动加分通常不超过0.05-0.1 -4. 人物印象描述应该是泛化的、整体的理解,从你的视角对用户整体性格特质的描述: - - 描述用户的整体性格特点(如:温柔、幽默、理性、感性等) - - 用户给你的整体感觉和印象 - - 你们关系的整体状态和氛围 - - 避免描述具体事件或对话内容,而是基于这些事件形成的整体认知 - -根据你的人设性格,思考: -1. 从你的性格视角,这个用户给你什么样的整体印象? -2. 用户的性格特质和行为模式是否符合你的喜好? -3. 基于这次互动,你对用户的整体认知有什么变化? -4. 这个用户在你心中的整体形象是怎样的? - -请以JSON格式返回更新结果: -{{ - "new_relationship_score": 0.0~1.0的数值(必须符合现实逻辑), - "reasoning": "从你的性格角度说明更新理由,重点说明是否符合现实关系发展逻辑", - "interaction_summary": "基于你性格的用户整体印象描述,包含用户的整体性格特质、给你的整体感觉,避免具体事件描述" -}} -""" - - # 调用LLM进行分析 - 添加超时保护 - import asyncio - try: - llm_response, _ = await asyncio.wait_for( - self.relationship_llm.generate_response_async(prompt=prompt), - timeout=30.0 # 30秒超时 - ) - except asyncio.TimeoutError: - logger.warning(f"初次见面LLM调用超时: user_id={user_id}, 跳过此次追踪") - return - except Exception as e: - logger.error(f"初次见面LLM调用失败: user_id={user_id}, 错误: {e}") - return - - if llm_response: - import json - - try: - # 清理LLM响应,移除可能的格式标记 - cleaned_response = self._clean_llm_json_response(llm_response) - response_data = json.loads(cleaned_response) - new_score = max( - 0.0, - min( - 1.0, - float( - response_data.get( - "new_relationship_score", global_config.affinity_flow.base_relationship_score - ) - ), - ), - ) - - # 使用统一API更新关系分 - from src.plugin_system.apis.scoring_api import scoring_api - await scoring_api.update_user_relationship( - interaction["user_id"], new_score - ) - - return { - "user_id": interaction["user_id"], - "new_relationship_score": new_score, - "reasoning": response_data.get("reasoning", ""), - "interaction_summary": response_data.get("interaction_summary", ""), - } - - except json.JSONDecodeError as e: - logger.error(f"LLM响应JSON解析失败: {e}") - logger.debug(f"LLM原始响应: {llm_response}") - except Exception as e: - logger.error(f"处理关系更新数据失败: {e}") - - except Exception as e: - logger.error(f"更新用户关系时出错: {e}") - - return None - - def get_tracking_users(self) -> dict[str, dict]: - """获取正在追踪的用户""" - return self.tracking_users.copy() - - def get_user_interaction(self, user_id: str) -> dict | None: - """获取特定用户的交互记录""" - return self.tracking_users.get(user_id) - - def remove_user_tracking(self, user_id: str): - """移除用户追踪""" - if user_id in self.tracking_users: - del self.tracking_users[user_id] - logger.debug(f"移除用户追踪: {user_id}") - - def clear_all_tracking(self): - """清空所有追踪""" - self.tracking_users.clear() - logger.info("清空所有用户追踪") - - def get_relationship_history(self) -> list[dict]: - """获取关系历史记录""" - return self.relationship_history.copy() - - def add_to_history(self, relationship_update: dict): - """添加到关系历史""" - self.relationship_history.append({**relationship_update, "update_time": time.time()}) - - # 限制历史记录数量 - if len(self.relationship_history) > 100: - self.relationship_history = self.relationship_history[-100:] - - def get_tracker_stats(self) -> dict: - """获取追踪器统计""" - return { - "tracking_users": len(self.tracking_users), - "max_tracking_users": self.max_tracking_users, - "update_interval_minutes": self.update_interval_minutes, - "relationship_history": len(self.relationship_history), - "last_update_time": self.last_update_time, - } - - def update_config(self, max_tracking_users: int | None = None, update_interval_minutes: int | None = None): - """更新配置""" - if max_tracking_users is not None: - self.max_tracking_users = max_tracking_users - logger.info(f"更新最大追踪用户数: {max_tracking_users}") - - if update_interval_minutes is not None: - self.update_interval_minutes = update_interval_minutes - logger.info(f"更新关系更新间隔: {update_interval_minutes} 分钟") - - async def force_update_relationship(self, user_id: str, new_score: float, reasoning: str = ""): - """强制更新用户关系分""" - if user_id in self.tracking_users: - current_score = self.tracking_users[user_id]["current_relationship_score"] - - # 使用统一API更新关系分 - from src.plugin_system.apis.scoring_api import scoring_api - await scoring_api.update_user_relationship(user_id, new_score) - - update_info = { - "user_id": user_id, - "new_relationship_score": new_score, - "reasoning": reasoning or "手动更新", - "interaction_summary": "手动更新关系分", - } - self.add_to_history(update_info) - logger.info(f"强制更新用户关系: {user_id} -> {new_score:.2f}") - - def get_user_summary(self, user_id: str) -> dict: - """获取用户交互总结""" - if user_id not in self.tracking_users: - return {} - - interaction = self.tracking_users[user_id] - return { - "user_id": user_id, - "user_name": interaction["user_name"], - "current_relationship_score": interaction["current_relationship_score"], - "interaction_count": 1, # 简化版本,每次追踪只记录一次交互 - "last_interaction": interaction["reply_timestamp"], - "recent_message": interaction["user_message"][:100] + "..." - if len(interaction["user_message"]) > 100 - else interaction["user_message"], - } - - # ===== 数据库支持方法 ===== - - async def get_user_relationship_score(self, user_id: str) -> float: - """获取用户关系分""" - # 先检查缓存 - if user_id in self.user_relationship_cache: - cache_data = self.user_relationship_cache[user_id] - # 检查缓存是否过期 - cache_time = cache_data.get("last_tracked", 0) - if time.time() - cache_time < self.cache_expiry_hours * 3600: - return cache_data.get("relationship_score", global_config.affinity_flow.base_relationship_score) - - # 缓存过期或不存在,从数据库获取 - relationship_data = await self._get_user_relationship_from_db(user_id) - if relationship_data: - # 更新缓存 - self.user_relationship_cache[user_id] = { - "relationship_text": relationship_data.get("relationship_text", ""), - "relationship_score": relationship_data.get( - "relationship_score", global_config.affinity_flow.base_relationship_score - ), - "last_tracked": time.time(), - } - return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score) - - # 数据库中也没有,返回默认值 - return global_config.affinity_flow.base_relationship_score - - async def _get_user_relationship_from_db(self, user_id: str) -> dict | None: - """从数据库获取用户关系数据""" - try: - async with get_db_session() as session: - # 查询用户关系表 - stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) - result = await session.execute(stmt) - relationship = result.scalar_one_or_none() - - if relationship: - return { - "relationship_text": relationship.relationship_text or "", - "relationship_score": float(relationship.relationship_score) - if relationship.relationship_score is not None - else 0.3, - "last_updated": relationship.last_updated, - } - except Exception as e: - logger.error(f"从数据库获取用户关系失败: {e}") - - return None - - async def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float): - """更新数据库中的用户关系""" - try: - current_time = time.time() - - async with get_db_session() as session: - # 检查是否已存在关系记录 - stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) - result = await session.execute(stmt) - existing = result.scalar_one_or_none() - - if existing: - # 更新现有记录 - existing.relationship_text = relationship_text - existing.relationship_score = relationship_score - existing.last_updated = current_time - existing.user_name = existing.user_name or user_id # 更新用户名如果为空 - else: - # 插入新记录 - new_relationship = UserRelationships( - user_id=user_id, - user_name=user_id, - relationship_text=relationship_text, - relationship_score=relationship_score, - last_updated=current_time, - ) - session.add(new_relationship) - - await session.commit() - logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}") - - except Exception as e: - logger.error(f"更新数据库用户关系失败: {e}") - - # ===== 回复后关系追踪方法 ===== - - async def track_reply_relationship( - self, user_id: str, user_name: str, bot_reply_content: str, reply_timestamp: float - ): - """回复后关系追踪 - 主要入口点""" - try: - # 首先检查是否启用关系追踪 - if not global_config.affinity_flow.enable_relationship_tracking: - logger.debug(f"🚫 [RelationshipTracker] 关系追踪系统已禁用,跳过用户 {user_id}") - return - - # 概率筛选 - 减少API调用压力 - tracking_probability = global_config.affinity_flow.relationship_tracking_probability - if random.random() > tracking_probability: - logger.debug( - f"🎲 [RelationshipTracker] 概率筛选未通过 ({tracking_probability:.2f}),跳过用户 {user_id} 的关系追踪" - ) - return - - logger.info(f"🔄 [RelationshipTracker] 开始回复后关系追踪: {user_id} (概率通过: {tracking_probability:.2f})") - - # 检查上次追踪时间 - 使用配置的冷却时间 - last_tracked_time = await self._get_last_tracked_time(user_id) - cooldown_hours = global_config.affinity_flow.relationship_tracking_cooldown_hours - cooldown_seconds = cooldown_hours * 3600 - time_diff = reply_timestamp - last_tracked_time - - # 使用配置的最小间隔时间 - min_interval = global_config.affinity_flow.relationship_tracking_interval_min - required_interval = max(min_interval, cooldown_seconds) - - if time_diff < required_interval: - logger.debug( - f"⏱️ [RelationshipTracker] 用户 {user_id} 距离上次追踪时间不足 {required_interval/60:.1f} 分钟 " - f"(实际: {time_diff/60:.1f} 分钟),跳过" - ) - return - - # 获取上次bot回复该用户的消息 - last_bot_reply = await self._get_last_bot_reply_to_user(user_id) - if not last_bot_reply: - logger.info(f"👋 [RelationshipTracker] 未找到用户 {user_id} 的历史回复记录,启动'初次见面'逻辑") - await self._handle_first_interaction(user_id, user_name, bot_reply_content) - return - - # 获取用户后续的反应消息 - user_reactions = await self._get_user_reactions_after_reply(user_id, last_bot_reply.time) - logger.debug(f"💬 [RelationshipTracker] 找到用户 {user_id} 在上次回复后的 {len(user_reactions)} 条反应消息") - - # 获取当前关系数据 - current_relationship = await self._get_user_relationship_from_db(user_id) - current_score = ( - current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score) - if current_relationship - else global_config.affinity_flow.base_relationship_score - ) - current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户" - - # 使用LLM分析并更新关系 - logger.debug(f"🧠 [RelationshipTracker] 开始为用户 {user_id} 分析并更新关系") - await self._analyze_and_update_relationship( - user_id, user_name, last_bot_reply, user_reactions, current_text, current_score, bot_reply_content - ) - - except Exception as e: - logger.error(f"回复后关系追踪失败: {e}") - logger.debug("错误详情:", exc_info=True) - - async def _get_last_tracked_time(self, user_id: str) -> float: - """获取上次追踪时间""" - # 先检查缓存 - if user_id in self.user_relationship_cache: - return self.user_relationship_cache[user_id].get("last_tracked", 0) - - # 从数据库获取 - relationship_data = await self._get_user_relationship_from_db(user_id) - if relationship_data: - return relationship_data.get("last_updated", 0) - - return 0 - - async def _get_last_bot_reply_to_user(self, user_id: str) -> DatabaseMessages | None: - """获取上次bot回复该用户的消息""" - try: - async with get_db_session() as session: - # 查询bot回复给该用户的最新消息 - stmt = ( - select(Messages) - .where(Messages.user_id == user_id) - .where(Messages.reply_to.isnot(None)) - .order_by(desc(Messages.time)) - .limit(1) - ) - - result = await session.execute(stmt) - message = result.scalar_one_or_none() - if message: - # 将SQLAlchemy模型转换为DatabaseMessages对象 - return self._sqlalchemy_to_database_messages(message) - - except Exception as e: - logger.error(f"获取上次回复消息失败: {e}") - - return None - - async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> list[DatabaseMessages]: - """获取用户在bot回复后的反应消息""" - try: - async with get_db_session() as session: - # 查询用户在回复时间之后的5分钟内的消息 - end_time = reply_time + 5 * 60 # 5分钟 - - stmt = ( - select(Messages) - .where(Messages.user_id == user_id) - .where(Messages.time > reply_time) - .where(Messages.time <= end_time) - .order_by(Messages.time) - ) - - result = await session.execute(stmt) - messages = result.scalars().all() - if messages: - return [self._sqlalchemy_to_database_messages(message) for message in messages] - - except Exception as e: - logger.error(f"获取用户反应消息失败: {e}") - - return [] - - def _sqlalchemy_to_database_messages(self, sqlalchemy_message) -> DatabaseMessages: - """将SQLAlchemy消息模型转换为DatabaseMessages对象""" - try: - return DatabaseMessages( - message_id=sqlalchemy_message.message_id or "", - time=float(sqlalchemy_message.time) if sqlalchemy_message.time is not None else 0.0, - chat_id=sqlalchemy_message.chat_id or "", - reply_to=sqlalchemy_message.reply_to, - processed_plain_text=sqlalchemy_message.processed_plain_text or "", - user_id=sqlalchemy_message.user_id or "", - user_nickname=sqlalchemy_message.user_nickname or "", - user_platform=sqlalchemy_message.user_platform or "", - ) - except Exception as e: - logger.error(f"SQLAlchemy消息转换失败: {e}") - # 返回一个基本的消息对象 - return DatabaseMessages( - message_id="", - time=0.0, - chat_id="", - processed_plain_text="", - user_id="", - user_nickname="", - user_platform="", - ) - - async def _analyze_and_update_relationship( - self, - user_id: str, - user_name: str, - last_bot_reply: DatabaseMessages, - user_reactions: list[DatabaseMessages], - current_text: str, - current_score: float, - current_reply: str, - ): - """使用LLM分析并更新用户关系""" - try: - # 构建分析提示 - user_reactions_text = "\n".join([f"- {msg.processed_plain_text}" for msg in user_reactions]) - - # 获取bot人设信息 - from src.individuality.individuality import Individuality - - individuality = Individuality() - bot_personality = await individuality.get_personality_block() - - prompt = f""" -你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality} - -请以你独特的性格视角,严格按现实逻辑分析以下用户交互,更新用户关系印象和分数: - -用户信息: -- 用户ID: {user_id} -- 用户名: {user_name} - -你上次的回复: {last_bot_reply.processed_plain_text} - -用户反应消息: -{user_reactions_text} - -你当前的回复: {current_reply} - -当前关系印象: {current_text} -当前关系分数: {current_score:.3f} - -【重要】关系分数档次定义: -- 0.0-0.2:陌生人/初次认识 - 仅礼貌性交流 -- 0.2-0.4:普通网友 - 有基本互动但不熟悉 -- 0.4-0.6:熟悉网友 - 经常交流,有一定了解 -- 0.6-0.8:朋友 - 可以分享心情,互相关心 -- 0.8-1.0:好朋友/知己 - 深度信任,亲密无间 - -【严格要求】: -1. 加分必须符合现实关系发展逻辑 - 不能因为用户反应好就盲目加分 -2. 关系提升需要足够的互动积累和时间验证,单次互动加分通常不超过0.05-0.1 -3. 必须考虑当前关系档次,不能跳跃式提升(比如从0.3直接到0.7) -4. 人物印象描述应该是泛化的、整体的理解(100-200字),从你的视角对用户整体性格特质的描述: - - 描述用户的整体性格特点和行为模式(如:温柔体贴、幽默风趣、理性稳重等) - - 用户给你的整体感觉和印象氛围 - - 你们关系的整体状态和发展阶段 - - 基于所有互动形成的用户整体形象认知 - - 避免提及具体事件或对话内容,而是总结形成的整体印象 -5. 在撰写人物印象时,请根据已有信息自然地融入用户的性别。如果性别不确定,请使用中性描述。 - -性格视角深度分析: -1. 从你的性格视角,基于这次互动,你对用户的整体印象有什么新的认识? -2. 用户的整体性格特质和行为模式符合你的喜好吗? -3. 从现实角度看,这次互动是否足以让关系提升到下一个档次?为什么? -4. 基于你们的互动历史,用户在你心中的整体形象是怎样的? -5. 这个用户给你带来的整体感受和情绪体验是怎样的? - -请以JSON格式返回更新结果: -{{ - "relationship_text": "泛化的用户整体印象描述(100-200字),其中自然地体现用户的性别,包含用户的整体性格特质、给你的整体感觉和印象氛围,避免具体事件描述", - "relationship_score": 0.0~1.0的新分数(必须严格符合现实逻辑), - "analysis_reasoning": "从你性格角度的深度分析,重点说明分数调整的现实合理性", - "interaction_quality": "high/medium/low" -}} -""" - - # 调用LLM进行分析 - 添加超时保护 - import asyncio - try: - llm_response, _ = await asyncio.wait_for( - self.relationship_llm.generate_response_async(prompt=prompt), - timeout=30.0 # 30秒超时 - ) - except asyncio.TimeoutError: - logger.warning(f"关系追踪LLM调用超时: user_id={user_id}, 跳过此次追踪") - return - except Exception as e: - logger.error(f"关系追踪LLM调用失败: user_id={user_id}, 错误: {e}") - return - - if llm_response: - import json - - try: - # 清理LLM响应,移除可能的格式标记 - cleaned_response = self._clean_llm_json_response(llm_response) - response_data = json.loads(cleaned_response) - - new_text = response_data.get("relationship_text", current_text) - new_score = max(0.0, min(1.0, float(response_data.get("relationship_score", current_score)))) - reasoning = response_data.get("analysis_reasoning", "") - quality = response_data.get("interaction_quality", "medium") - - # 更新数据库 - await self._update_user_relationship_in_db(user_id, new_text, new_score) - - # 更新缓存 - self.user_relationship_cache[user_id] = { - "relationship_text": new_text, - "relationship_score": new_score, - "last_tracked": time.time(), - } - - # 使用统一API更新关系分(内存缓存已通过数据库更新自动处理) - # 数据库更新后,缓存会在下次访问时自动同步 - - # 记录分析历史 - analysis_record = { - "user_id": user_id, - "timestamp": time.time(), - "old_score": current_score, - "new_score": new_score, - "old_text": current_text, - "new_text": new_text, - "reasoning": reasoning, - "quality": quality, - "user_reactions_count": len(user_reactions), - } - self.relationship_history.append(analysis_record) - - # 限制历史记录数量 - if len(self.relationship_history) > 100: - self.relationship_history = self.relationship_history[-100:] - - logger.info(f"✅ 关系分析完成: {user_id}") - logger.info(f" 📝 印象: '{current_text}' -> '{new_text}'") - logger.info(f" 💝 分数: {current_score:.3f} -> {new_score:.3f}") - logger.info(f" 🎯 质量: {quality}") - - except json.JSONDecodeError as e: - logger.error(f"LLM响应JSON解析失败: {e}") - logger.debug(f"LLM原始响应: {llm_response}") - else: - logger.warning("LLM未返回有效响应") - - except Exception as e: - logger.error(f"关系分析失败: {e}") - logger.debug("错误详情:", exc_info=True) - - async def _handle_first_interaction(self, user_id: str, user_name: str, bot_reply_content: str): - """处理与用户的初次交互""" - try: - # 初次交互也进行概率检查,但使用更高的通过率 - first_interaction_probability = min(1.0, global_config.affinity_flow.relationship_tracking_probability * 1.5) - if random.random() > first_interaction_probability: - logger.debug( - f"🎲 [RelationshipTracker] 初次交互概率筛选未通过 ({first_interaction_probability:.2f}),跳过用户 {user_id}" - ) - return - - logger.info(f"✨ [RelationshipTracker] 正在处理与用户 {user_id} 的初次交互 (概率通过: {first_interaction_probability:.2f})") - - # 获取bot人设信息 - from src.individuality.individuality import Individuality - - individuality = Individuality() - bot_personality = await individuality.get_personality_block() - - prompt = f""" -你现在是:{bot_personality} - -你正在与一个新用户进行初次有效互动。请根据你对TA的第一印象,建立初始关系档案。 - -用户信息: -- 用户ID: {user_id} -- 用户名: {user_name} - -你的首次回复: {bot_reply_content} - -【严格要求】: -1. 建立一个初始关系分数,通常在0.2-0.4之间(普通网友)。 -2. 初始关系印象描述要简洁地记录你对用户的整体初步看法(50-100字)。请在描述中自然地融入你对用户性别的初步判断(例如“他似乎是...”或“感觉她...”),如果完全无法判断,则使用中性描述。 - - 基于用户名和初次互动,用户给你的整体感觉 - - 你感受到的用户整体性格特质倾向 - - 你对与这个用户建立关系的整体期待和感觉 - - 避免描述具体的事件细节,而是整体的直觉印象 - -请以JSON格式返回结果: -{{ - "relationship_text": "简洁的用户整体初始印象描述(50-100字),其中自然地体现对用户性别的初步判断", - "relationship_score": 0.2~0.4的新分数, - "analysis_reasoning": "从你性格角度说明建立此初始印象的理由" -}} -""" - # 调用LLM进行分析 - llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt) - if not llm_response: - logger.warning(f"初次交互分析时LLM未返回有效响应: {user_id}") - return - - import json - - cleaned_response = self._clean_llm_json_response(llm_response) - response_data = json.loads(cleaned_response) - - new_text = response_data.get("relationship_text", "初次见面") - new_score = max( - 0.0, - min( - 1.0, - float(response_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)), - ), - ) - - # 更新数据库和缓存 - await self._update_user_relationship_in_db(user_id, new_text, new_score) - self.user_relationship_cache[user_id] = { - "relationship_text": new_text, - "relationship_score": new_score, - "last_tracked": time.time(), - } - - logger.info(f"✅ [RelationshipTracker] 已成功为新用户 {user_id} 建立初始关系档案,分数为 {new_score:.3f}") - - except Exception as e: - logger.error(f"处理初次交互失败: {user_id}, 错误: {e}") - logger.debug("错误详情:", exc_info=True) - - def _clean_llm_json_response(self, response: str) -> str: - """ - 清理LLM响应,移除可能的JSON格式标记 - - Args: - response: LLM原始响应 - - Returns: - 清理后的JSON字符串 - """ - try: - import re - - # 移除常见的JSON格式标记 - cleaned = response.strip() - - # 移除 ```json 或 ``` 等标记 - cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE) - cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE) - - # 移除可能的Markdown代码块标记 - cleaned = re.sub(r"^`|`$", "", cleaned, flags=re.MULTILINE) - - # 尝试找到JSON对象的开始和结束 - json_start = cleaned.find("{") - json_end = cleaned.rfind("}") - - if json_start != -1 and json_end != -1 and json_end > json_start: - # 提取JSON部分 - cleaned = cleaned[json_start : json_end + 1] - - # 移除多余的空白字符 - cleaned = cleaned.strip() - - logger.debug(f"LLM响应清理: 原始长度={len(response)}, 清理后长度={len(cleaned)}") - if cleaned != response: - logger.debug(f"清理前: {response[:200]}...") - logger.debug(f"清理后: {cleaned[:200]}...") - - return cleaned - - except Exception as e: - logger.warning(f"清理LLM响应失败: {e}") - return response # 清理失败时返回原始响应 diff --git a/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py b/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py new file mode 100644 index 000000000..b4fc68526 --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py @@ -0,0 +1,370 @@ +""" +用户画像更新工具 + +通过LLM二步调用机制更新用户画像信息,包括别名、主观印象、偏好关键词和好感分数 +""" + +import orjson +import time +from typing import Any + +from sqlalchemy import select + +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import UserRelationships +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest +from src.plugin_system import BaseTool, ToolParamType + +logger = get_logger("user_profile_tool") + + +class UserProfileTool(BaseTool): + """用户画像更新工具 + + 使用二步调用机制: + 1. LLM决定是否调用工具并传入初步参数 + 2. 工具内部调用LLM,结合现有数据和传入参数,决定最终更新内容 + """ + + name = "update_user_profile" + description = "当你通过聊天记录对某个用户产生了新的认识或印象时使用此工具,更新该用户的画像信息。包括:用户别名、你对TA的主观印象、TA的偏好兴趣、你对TA的好感程度。调用时机:当你发现用户透露了新的个人信息、展现了性格特点、表达了兴趣偏好,或者你们的互动让你对TA的看法发生变化时。" + parameters = [ + ("target_user_id", ToolParamType.STRING, "目标用户的ID(必须)", True, None), + ("user_aliases", ToolParamType.STRING, "该用户的昵称或别名,如果发现用户自称或被他人称呼的其他名字时填写,多个别名用逗号分隔(可选)", False, None), + ("impression_description", ToolParamType.STRING, "你对该用户的整体印象和性格感受,例如'这个用户很幽默开朗'、'TA对技术很有热情'等。当你通过对话了解到用户的性格、态度、行为特点时填写(可选)", False, None), + ("preference_keywords", ToolParamType.STRING, "该用户表现出的兴趣爱好或偏好,如'编程,游戏,动漫'。当用户谈论自己喜欢的事物时填写,多个关键词用逗号分隔(可选)", False, None), + ("affection_score", ToolParamType.FLOAT, "你对该用户的好感程度,0.0(陌生/不喜欢)到1.0(很喜欢/好友)。当你们的互动让你对TA的感觉发生变化时更新(可选)", False, None), + ] + available_for_llm = True + history_ttl = 5 + + def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None): + super().__init__(plugin_config, chat_stream) + + # 初始化用于二步调用的LLM + try: + self.profile_llm = LLMRequest( + model_set=model_config.model_task_config.relationship_tracker, + request_type="user_profile_update" + ) + except AttributeError: + # 降级处理 + available_models = [ + attr for attr in dir(model_config.model_task_config) + if not attr.startswith("_") and attr != "model_dump" + ] + if available_models: + fallback_model = available_models[0] + logger.warning(f"relationship_tracker配置不存在,使用降级模型: {fallback_model}") + self.profile_llm = LLMRequest( + model_set=getattr(model_config.model_task_config, fallback_model), + request_type="user_profile_update" + ) + else: + logger.error("无可用的模型配置") + self.profile_llm = None + + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: + """执行用户画像更新 + + Args: + function_args: 工具参数 + + Returns: + dict: 执行结果 + """ + try: + # 提取参数 + target_user_id = function_args.get("target_user_id") + if not target_user_id: + return { + "type": "error", + "id": "user_profile_update", + "content": "错误:必须提供目标用户ID" + } + + # 从LLM传入的参数 + new_aliases = function_args.get("user_aliases", "") + new_impression = function_args.get("impression_description", "") + new_keywords = function_args.get("preference_keywords", "") + new_score = function_args.get("affection_score") + + # 从数据库获取现有用户画像 + existing_profile = await self._get_user_profile(target_user_id) + + # 如果LLM没有传入任何有效参数,返回提示 + if not any([new_aliases, new_impression, new_keywords, new_score is not None]): + return { + "type": "info", + "id": target_user_id, + "content": f"提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)" + } + + # 调用LLM进行二步决策 + if self.profile_llm is None: + logger.error("LLM未正确初始化,无法执行二步调用") + return { + "type": "error", + "id": target_user_id, + "content": "系统错误:LLM未正确初始化" + } + + final_profile = await self._llm_decide_final_profile( + target_user_id=target_user_id, + existing_profile=existing_profile, + new_aliases=new_aliases, + new_impression=new_impression, + new_keywords=new_keywords, + new_score=new_score + ) + + if not final_profile: + return { + "type": "error", + "id": target_user_id, + "content": "LLM决策失败,无法更新用户画像" + } + + # 更新数据库 + await self._update_user_profile_in_db(target_user_id, final_profile) + + # 构建返回信息 + updates = [] + if final_profile.get("user_aliases"): + updates.append(f"别名: {final_profile['user_aliases']}") + if final_profile.get("relationship_text"): + updates.append(f"印象: {final_profile['relationship_text'][:50]}...") + if final_profile.get("preference_keywords"): + updates.append(f"偏好: {final_profile['preference_keywords']}") + if final_profile.get("relationship_score") is not None: + updates.append(f"好感分: {final_profile['relationship_score']:.2f}") + + result_text = f"已更新用户 {target_user_id} 的画像:\n" + "\n".join(updates) + logger.info(f"用户画像更新成功: {target_user_id}") + + return { + "type": "user_profile_update", + "id": target_user_id, + "content": result_text + } + + except Exception as e: + logger.error(f"用户画像更新失败: {e}", exc_info=True) + return { + "type": "error", + "id": function_args.get("target_user_id", "unknown"), + "content": f"用户画像更新失败: {str(e)}" + } + + async def _get_user_profile(self, user_id: str) -> dict[str, Any]: + """从数据库获取用户现有画像 + + Args: + user_id: 用户ID + + Returns: + dict: 用户画像数据 + """ + try: + async with get_db_session() as session: + stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) + result = await session.execute(stmt) + profile = result.scalar_one_or_none() + + if profile: + return { + "user_name": profile.user_name or user_id, + "user_aliases": profile.user_aliases or "", + "relationship_text": profile.relationship_text or "", + "preference_keywords": profile.preference_keywords or "", + "relationship_score": float(profile.relationship_score) if profile.relationship_score is not None else global_config.affinity_flow.base_relationship_score, + } + else: + # 用户不存在,返回默认值 + return { + "user_name": user_id, + "user_aliases": "", + "relationship_text": "", + "preference_keywords": "", + "relationship_score": global_config.affinity_flow.base_relationship_score, + } + except Exception as e: + logger.error(f"获取用户画像失败: {e}") + return { + "user_name": user_id, + "user_aliases": "", + "relationship_text": "", + "preference_keywords": "", + "relationship_score": global_config.affinity_flow.base_relationship_score, + } + + async def _llm_decide_final_profile( + self, + target_user_id: str, + existing_profile: dict[str, Any], + new_aliases: str, + new_impression: str, + new_keywords: str, + new_score: float | None + ) -> dict[str, Any] | None: + """使用LLM决策最终的用户画像内容 + + Args: + target_user_id: 目标用户ID + existing_profile: 现有画像数据 + new_aliases: LLM传入的新别名 + new_impression: LLM传入的新印象 + new_keywords: LLM传入的新关键词 + new_score: LLM传入的新分数 + + Returns: + dict: 最终决定的画像数据,如果失败返回None + """ + try: + # 获取bot人设 + from src.individuality.individuality import Individuality + individuality = Individuality() + bot_personality = await individuality.get_personality_block() + + prompt = f""" +你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality} + +你正在更新对用户 {target_user_id} 的画像认识。 + +【当前画像信息】 +- 用户名: {existing_profile.get('user_name', target_user_id)} +- 已知别名: {existing_profile.get('user_aliases', '无')} +- 当前印象: {existing_profile.get('relationship_text', '暂无印象')} +- 偏好关键词: {existing_profile.get('preference_keywords', '未知')} +- 当前好感分: {existing_profile.get('relationship_score', 0.3):.2f} + +【本次想要更新的内容】 +- 新增/更新别名: {new_aliases if new_aliases else '不更新'} +- 新的印象描述: {new_impression if new_impression else '不更新'} +- 新的偏好关键词: {new_keywords if new_keywords else '不更新'} +- 新的好感分数: {new_score if new_score is not None else '不更新'} + +请综合考虑现有信息和新信息,决定最终的用户画像内容。注意: +1. 别名:如果提供了新别名,应该与现有别名合并(去重),而不是替换 +2. 印象描述:如果提供了新印象,应该综合现有印象和新印象,形成更完整的认识(100-200字) +3. 偏好关键词:如果提供了新关键词,应该与现有关键词合并(去重),每个关键词简短 +4. 好感分数:如果提供了新分数,需要结合现有分数合理调整(变化不宜过大,遵循现实逻辑) + +请以JSON格式返回最终决定: +{{ + "user_aliases": "最终的别名列表,逗号分隔", + "relationship_text": "最终的印象描述(100-200字),整体性、泛化的理解", + "preference_keywords": "最终的偏好关键词,逗号分隔", + "relationship_score": 最终的好感分数(0.0-1.0), + "reasoning": "你的决策理由" +}} +""" + + # 调用LLM + llm_response, _ = await self.profile_llm.generate_response_async(prompt=prompt) + + if not llm_response: + logger.warning("LLM未返回有效响应") + return None + + # 清理并解析响应 + cleaned_response = self._clean_llm_json_response(llm_response) + response_data = orjson.loads(cleaned_response) + + # 提取最终决定的数据 + final_profile = { + "user_aliases": response_data.get("user_aliases", existing_profile.get("user_aliases", "")), + "relationship_text": response_data.get("relationship_text", existing_profile.get("relationship_text", "")), + "preference_keywords": response_data.get("preference_keywords", existing_profile.get("preference_keywords", "")), + "relationship_score": max(0.0, min(1.0, float(response_data.get("relationship_score", existing_profile.get("relationship_score", 0.3))))), + } + + logger.info(f"LLM决策完成: {target_user_id}") + logger.debug(f"决策理由: {response_data.get('reasoning', '无')}") + + return final_profile + + except orjson.JSONDecodeError as e: + logger.error(f"LLM响应JSON解析失败: {e}") + logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}") + return None + except Exception as e: + logger.error(f"LLM决策失败: {e}", exc_info=True) + return None + + async def _update_user_profile_in_db(self, user_id: str, profile: dict[str, Any]): + """更新数据库中的用户画像 + + Args: + user_id: 用户ID + profile: 画像数据 + """ + try: + current_time = time.time() + + async with get_db_session() as session: + stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) + result = await session.execute(stmt) + existing = result.scalar_one_or_none() + + if existing: + # 更新现有记录 + existing.user_aliases = profile.get("user_aliases", "") + existing.relationship_text = profile.get("relationship_text", "") + existing.preference_keywords = profile.get("preference_keywords", "") + existing.relationship_score = profile.get("relationship_score", global_config.affinity_flow.base_relationship_score) + existing.last_updated = current_time + else: + # 创建新记录 + new_profile = UserRelationships( + user_id=user_id, + user_name=user_id, + user_aliases=profile.get("user_aliases", ""), + relationship_text=profile.get("relationship_text", ""), + preference_keywords=profile.get("preference_keywords", ""), + relationship_score=profile.get("relationship_score", global_config.affinity_flow.base_relationship_score), + last_updated=current_time + ) + session.add(new_profile) + + await session.commit() + logger.info(f"用户画像已更新到数据库: {user_id}") + + except Exception as e: + logger.error(f"更新用户画像到数据库失败: {e}", exc_info=True) + raise + + def _clean_llm_json_response(self, response: str) -> str: + """清理LLM响应,移除可能的JSON格式标记 + + Args: + response: LLM原始响应 + + Returns: + str: 清理后的JSON字符串 + """ + try: + import re + + cleaned = response.strip() + + # 移除 ```json 或 ``` 等标记 + cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE) + cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE) + + # 尝试找到JSON对象的开始和结束 + json_start = cleaned.find("{") + json_end = cleaned.rfind("}") + + if json_start != -1 and json_end != -1 and json_end > json_start: + cleaned = cleaned[json_start:json_end + 1] + + cleaned = cleaned.strip() + + return cleaned + + except Exception as e: + logger.warning(f"清理LLM响应失败: {e}") + return response diff --git a/src/plugins/built_in/social_toolkit_plugin/plugin.py b/src/plugins/built_in/social_toolkit_plugin/plugin.py index 05005c173..8d75ca2fd 100644 --- a/src/plugins/built_in/social_toolkit_plugin/plugin.py +++ b/src/plugins/built_in/social_toolkit_plugin/plugin.py @@ -6,6 +6,7 @@ from typing import ClassVar from dateutil.parser import parse as parse_datetime from src.chat.message_receive.chat_stream import ChatStream +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.manager.async_task_manager import AsyncTask, async_task_manager from src.person_info.person_info import get_person_info_manager @@ -253,19 +254,19 @@ class SetEmojiLikeAction(BaseAction): message_id = None set_like = self.action_data.get("set", True) - if self.has_action_message and isinstance(self.action_message, dict): - message_id = self.action_message.get("message_id") - logger.info(f"获取到的消息ID: {message_id}") - else: + if self.has_action_message: + if isinstance(self.action_message, DatabaseMessages): + message_id = self.action_message.message_id + logger.info(f"获取到的消息ID: {message_id}") + elif isinstance(self.action_message, dict): + message_id = self.action_message.get("message_id") + logger.info(f"获取到的消息ID: {message_id}") + + if not message_id: logger.error("未提供有效的消息或消息ID") await self.store_action_info(action_prompt_display="贴表情失败: 未提供消息ID", action_done=False) return False, "未提供消息ID" - if not message_id: - logger.error("消息ID为空") - await self.store_action_info(action_prompt_display="贴表情失败: 消息ID为空", action_done=False) - return False, "消息ID为空" - available_models = llm_api.get_available_models() if "utils_small" not in available_models: logger.error("未找到 'utils_small' 模型配置,无法选择表情") @@ -273,7 +274,12 @@ class SetEmojiLikeAction(BaseAction): model_to_use = available_models["utils_small"] - context_text = self.action_message.get("processed_plain_text", "") + # 统一处理 DatabaseMessages 和字典 + if isinstance(self.action_message, DatabaseMessages): + context_text = self.action_message.processed_plain_text or "" + else: + context_text = self.action_message.get("processed_plain_text", "") + if not context_text: logger.error("无法找到动作选择的原始消息文本") return False, "无法找到动作选择的原始消息文本" diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 05ac7274e..e824467d2 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.5.1" +version = "7.5.2" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -92,6 +92,11 @@ compress_identity = true # 是否压缩身份,压缩后会精简身份信息 [expression] # 表达学习配置 +# mode: 表达方式模式,可选: +# - "classic": 经典模式,随机抽样 + LLM选择 +# - "exp_model": 表达模型模式,使用机器学习模型预测最合适的表达 +mode = "classic" + # rules是一个列表,每个元素都是一个学习规则 # chat_stream_id: 聊天流ID,格式为 "platform:id:type",例如 "qq:123456:private"。空字符串""表示全局配置 # use_expression: 是否使用学到的表达 (true/false)