diff --git a/integration_test_relationship_tools.py b/integration_test_relationship_tools.py deleted file mode 100644 index a2ac3a7fa..000000000 --- a/integration_test_relationship_tools.py +++ /dev/null @@ -1,303 +0,0 @@ -""" -关系追踪工具集成测试脚本 - -注意:此脚本需要在完整的应用环境中运行 -建议通过 bot.py 启动后在交互式环境中测试 -""" - -import asyncio - - -async def test_user_profile_tool(): - """测试用户画像工具""" - print("\n" + "=" * 80) - print("测试 UserProfileTool") - print("=" * 80) - - from src.plugins.built_in.affinity_flow_chatter.user_profile_tool import UserProfileTool - from src.common.database.sqlalchemy_database_api import db_query - from src.common.database.sqlalchemy_models import UserRelationships - - tool = UserProfileTool() - print(f"✅ 工具名称: {tool.name}") - print(f" 工具描述: {tool.description}") - - # 执行工具 - test_user_id = "integration_test_user_001" - result = await tool.execute({ - "target_user_id": test_user_id, - "user_aliases": "测试小明,TestMing,小明君", - "impression_description": "这是一个集成测试用户,性格开朗活泼,喜欢技术讨论,对AI和编程特别感兴趣。经常提出有深度的问题。", - "preference_keywords": "AI,Python,深度学习,游戏开发,科幻小说", - "affection_score": 0.85 - }) - - print(f"\n✅ 工具执行结果:") - print(f" 类型: {result.get('type')}") - print(f" 内容: {result.get('content')}") - - # 验证数据库 - db_data = await db_query( - UserRelationships, - filters={"user_id": test_user_id}, - limit=1 - ) - - if db_data: - data = db_data[0] - print(f"\n✅ 数据库验证:") - print(f" user_id: {data.get('user_id')}") - print(f" user_aliases: {data.get('user_aliases')}") - print(f" relationship_text: {data.get('relationship_text', '')[:80]}...") - print(f" preference_keywords: {data.get('preference_keywords')}") - print(f" relationship_score: {data.get('relationship_score')}") - return True - else: - print(f"\n❌ 数据库中未找到数据") - return False - - -async def test_chat_stream_impression_tool(): - """测试聊天流印象工具""" - print("\n" + "=" * 80) - print("测试 ChatStreamImpressionTool") - print("=" * 80) - - from src.plugins.built_in.affinity_flow_chatter.chat_stream_impression_tool import ChatStreamImpressionTool - from src.common.database.sqlalchemy_database_api import db_query - from src.common.database.sqlalchemy_models import ChatStreams, get_db_session - - # 准备测试数据:先创建一条 ChatStreams 记录 - test_stream_id = "integration_test_stream_001" - print(f"🔧 准备测试数据:创建聊天流记录 {test_stream_id}") - - import time - current_time = time.time() - - async with get_db_session() as session: - new_stream = ChatStreams( - stream_id=test_stream_id, - create_time=current_time, - last_active_time=current_time, - platform="QQ", - user_platform="QQ", - user_id="test_user_123", - user_nickname="测试用户", - group_name="测试技术交流群", - group_platform="QQ", - group_id="test_group_456", - stream_impression_text="", # 初始为空 - stream_chat_style="", - stream_topic_keywords="", - stream_interest_score=0.5 - ) - session.add(new_stream) - await session.commit() - print(f"✅ 测试聊天流记录已创建") - - tool = ChatStreamImpressionTool() - print(f"✅ 工具名称: {tool.name}") - print(f" 工具描述: {tool.description}") - - # 执行工具 - result = await tool.execute({ - "stream_id": test_stream_id, - "impression_description": "这是一个技术交流群,成员主要是程序员和AI爱好者。大家经常分享最新的技术文章,讨论编程问题,氛围友好且专业。", - "chat_style": "专业技术交流,活跃讨论,互帮互助,知识分享", - "topic_keywords": "Python开发,机器学习,AI应用,Web后端,数据分析,开源项目", - "interest_score": 0.90 - }) - - print(f"\n✅ 工具执行结果:") - print(f" 类型: {result.get('type')}") - print(f" 内容: {result.get('content')}") - - # 验证数据库 - db_data = await db_query( - ChatStreams, - filters={"stream_id": test_stream_id}, - limit=1 - ) - - if db_data: - data = db_data[0] - print(f"\n✅ 数据库验证:") - print(f" stream_id: {data.get('stream_id')}") - print(f" stream_impression_text: {data.get('stream_impression_text', '')[:80]}...") - print(f" stream_chat_style: {data.get('stream_chat_style')}") - print(f" stream_topic_keywords: {data.get('stream_topic_keywords')}") - print(f" stream_interest_score: {data.get('stream_interest_score')}") - return True - else: - print(f"\n❌ 数据库中未找到数据") - return False - - -async def test_relationship_info_build(): - """测试关系信息构建""" - print("\n" + "=" * 80) - print("测试关系信息构建(提示词集成)") - print("=" * 80) - - from src.person_info.relationship_fetcher import relationship_fetcher_manager - - test_stream_id = "integration_test_stream_001" - test_person_id = "test_person_999" # 使用一个可能不存在的ID来测试 - - fetcher = relationship_fetcher_manager.get_fetcher(test_stream_id) - print(f"✅ RelationshipFetcher 已创建") - - # 测试聊天流印象构建 - print(f"\n🔍 构建聊天流印象...") - stream_info = await fetcher.build_chat_stream_impression(test_stream_id) - - if stream_info: - print(f"✅ 聊天流印象构建成功") - print(f"\n{'=' * 80}") - print(stream_info) - print(f"{'=' * 80}") - else: - print(f"⚠️ 聊天流印象为空(可能测试数据不存在)") - - return True - - -async def cleanup_test_data(): - """清理测试数据""" - print("\n" + "=" * 80) - print("清理测试数据") - print("=" * 80) - - from src.common.database.sqlalchemy_database_api import db_query - from src.common.database.sqlalchemy_models import UserRelationships, ChatStreams - - try: - # 清理用户数据 - await db_query( - UserRelationships, - query_type="delete", - filters={"user_id": "integration_test_user_001"} - ) - print("✅ 用户测试数据已清理") - - # 清理聊天流数据 - await db_query( - ChatStreams, - query_type="delete", - filters={"stream_id": "integration_test_stream_001"} - ) - print("✅ 聊天流测试数据已清理") - - return True - except Exception as e: - print(f"⚠️ 清理失败: {e}") - return False - - -async def run_all_tests(): - """运行所有测试""" - print("\n" + "=" * 80) - print("关系追踪工具集成测试") - print("=" * 80) - - results = {} - - # 测试1 - try: - results["UserProfileTool"] = await test_user_profile_tool() - except Exception as e: - print(f"\n❌ UserProfileTool 测试失败: {e}") - import traceback - traceback.print_exc() - results["UserProfileTool"] = False - - # 测试2 - try: - results["ChatStreamImpressionTool"] = await test_chat_stream_impression_tool() - except Exception as e: - print(f"\n❌ ChatStreamImpressionTool 测试失败: {e}") - import traceback - traceback.print_exc() - results["ChatStreamImpressionTool"] = False - - # 测试3 - try: - results["RelationshipFetcher"] = await test_relationship_info_build() - except Exception as e: - print(f"\n❌ RelationshipFetcher 测试失败: {e}") - import traceback - traceback.print_exc() - results["RelationshipFetcher"] = False - - # 清理 - try: - await cleanup_test_data() - except Exception as e: - print(f"\n⚠️ 清理测试数据失败: {e}") - - # 总结 - print("\n" + "=" * 80) - print("测试总结") - print("=" * 80) - - passed = sum(1 for r in results.values() if r) - total = len(results) - - for test_name, result in results.items(): - status = "✅ 通过" if result else "❌ 失败" - print(f"{status} - {test_name}") - - print(f"\n总计: {passed}/{total} 测试通过") - - if passed == total: - print("\n🎉 所有测试通过!") - else: - print(f"\n⚠️ {total - passed} 个测试失败") - - return passed == total - - -# 使用说明 -print(""" -============================================================================ -关系追踪工具集成测试脚本 -============================================================================ - -此脚本需要在完整的应用环境中运行。 - -使用方法1: 在 bot.py 中添加测试调用 ------------------------------------ -在 bot.py 的 main() 函数中添加: - - # 测试关系追踪工具 - from tests.integration_test_relationship_tools import run_all_tests - await run_all_tests() - -使用方法2: 在 Python REPL 中运行 ------------------------------------ -启动 bot.py 后,在 Python 调试控制台中执行: - - import asyncio - from tests.integration_test_relationship_tools import run_all_tests - asyncio.create_task(run_all_tests()) - -使用方法3: 直接在此文件底部运行 ------------------------------------ -取消注释下面的代码,然后确保已启动应用环境 -============================================================================ -""") - - -# 如果需要直接运行(需要应用环境已启动) -if __name__ == "__main__": - print("\n⚠️ 警告: 直接运行此脚本可能会失败,因为缺少应用环境") - print("建议在 bot.py 启动后的环境中运行\n") - - try: - asyncio.run(run_all_tests()) - except Exception as e: - print(f"\n❌ 测试失败: {e}") - print("\n建议:") - print("1. 确保已启动 bot.py") - print("2. 在 Python 调试控制台中运行测试") - print("3. 或在 bot.py 中添加测试调用") diff --git a/pyrightconfig.json b/pyrightconfig.json index 3cffac58c..adf9c8dcf 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -27,6 +27,6 @@ "venvPath": ".", "venv": ".venv", "executionEnvironments": [ - {"root": "src"} + {"root": "."} ] } diff --git a/scripts/check_expression_database.py b/scripts/check_expression_database.py index f600cc434..2341c2140 100644 --- a/scripts/check_expression_database.py +++ b/scripts/check_expression_database.py @@ -9,24 +9,25 @@ from pathlib import Path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) -from sqlalchemy import select, func +from sqlalchemy import func, select + from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import Expression async def check_database(): """检查表达方式数据库状态""" - + print("=" * 60) print("表达方式数据库诊断报告") print("=" * 60) - + async with get_db_session() as session: # 1. 统计总数 total_count = await session.execute(select(func.count()).select_from(Expression)) total = total_count.scalar() print(f"\n📊 总表达方式数量: {total}") - + if total == 0: print("\n⚠️ 数据库为空!") print("\n可能的原因:") @@ -38,7 +39,7 @@ async def check_database(): print("- 查看日志中是否有表达学习相关的错误") print("- 确认聊天流的 learn_expression 配置为 true") return - + # 2. 按 chat_id 统计 print("\n📝 按聊天流统计:") chat_counts = await session.execute( @@ -47,7 +48,7 @@ async def check_database(): ) for chat_id, count in chat_counts: print(f" - {chat_id}: {count} 个表达方式") - + # 3. 按 type 统计 print("\n📝 按类型统计:") type_counts = await session.execute( @@ -56,7 +57,7 @@ async def check_database(): ) for expr_type, count in type_counts: print(f" - {expr_type}: {count} 个") - + # 4. 检查 situation 和 style 字段是否有空值 print("\n🔍 字段完整性检查:") null_situation = await session.execute( @@ -69,30 +70,30 @@ async def check_database(): .select_from(Expression) .where(Expression.style == None) ) - + null_sit_count = null_situation.scalar() null_sty_count = null_style.scalar() - + print(f" - situation 为空: {null_sit_count} 个") print(f" - style 为空: {null_sty_count} 个") - + if null_sit_count > 0 or null_sty_count > 0: print(" ⚠️ 发现空值!这会导致匹配失败") - + # 5. 显示一些样例数据 print("\n📋 样例数据 (前10条):") samples = await session.execute( select(Expression) .limit(10) ) - + for i, expr in enumerate(samples.scalars(), 1): print(f"\n [{i}] Chat: {expr.chat_id}") print(f" Type: {expr.type}") print(f" Situation: {expr.situation}") print(f" Style: {expr.style}") print(f" Count: {expr.count}") - + # 6. 检查 style 字段的唯一值 print("\n📋 Style 字段样例 (前20个):") unique_styles = await session.execute( @@ -100,13 +101,13 @@ async def check_database(): .distinct() .limit(20) ) - + styles = [s for s in unique_styles.scalars()] for style in styles: print(f" - {style}") - + print(f"\n (共 {len(styles)} 个不同的 style)") - + print("\n" + "=" * 60) print("诊断完成") print("=" * 60) diff --git a/scripts/check_style_field.py b/scripts/check_style_field.py index c8f5ef1fb..d28c8b240 100644 --- a/scripts/check_style_field.py +++ b/scripts/check_style_field.py @@ -9,27 +9,28 @@ project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) from sqlalchemy import select + from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import Expression async def analyze_style_fields(): """分析 style 字段的内容""" - + print("=" * 60) print("Style 字段内容分析") print("=" * 60) - + async with get_db_session() as session: # 获取所有表达方式 result = await session.execute(select(Expression).limit(30)) expressions = result.scalars().all() - + print(f"\n总共检查 {len(expressions)} 条记录\n") - + # 按类型分类 style_examples = [] - + for expr in expressions: if expr.type == "style": style_examples.append({ @@ -37,7 +38,7 @@ async def analyze_style_fields(): "style": expr.style, "length": len(expr.style) if expr.style else 0 }) - + print("📋 Style 类型样例 (前15条):") print("="*60) for i, ex in enumerate(style_examples[:15], 1): @@ -45,17 +46,17 @@ async def analyze_style_fields(): print(f" Situation: {ex['situation']}") print(f" Style: {ex['style']}") print(f" 长度: {ex['length']} 字符") - + # 判断是具体表达还是风格描述 - if ex['length'] <= 20 and any(word in ex['style'] for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']): + if ex["length"] <= 20 and any(word in ex["style"] for word in ["简洁", "短句", "陈述", "疑问", "感叹", "省略", "完整"]): style_type = "✓ 风格描述" - elif ex['length'] <= 10: + elif ex["length"] <= 10: style_type = "? 可能是具体表达(较短)" else: style_type = "✗ 具体表达内容" - + print(f" 类型判断: {style_type}") - + print("\n" + "="*60) print("分析完成") print("="*60) diff --git a/scripts/debug_style_learner.py b/scripts/debug_style_learner.py index 970ba2532..1c0937ece 100644 --- a/scripts/debug_style_learner.py +++ b/scripts/debug_style_learner.py @@ -16,28 +16,28 @@ logger = get_logger("debug_style_learner") def check_style_learner_status(chat_id: str): """检查指定 chat_id 的 StyleLearner 状态""" - + print("=" * 60) print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}") print("=" * 60) - + # 获取 learner learner = style_learner_manager.get_learner(chat_id) - + # 1. 基本信息 - print(f"\n📊 基本信息:") + print("\n📊 基本信息:") print(f" Chat ID: {learner.chat_id}") print(f" 风格数量: {len(learner.style_to_id)}") print(f" 下一个ID: {learner.next_style_id}") print(f" 最大风格数: {learner.max_styles}") - + # 2. 学习统计 - print(f"\n📈 学习统计:") + print("\n📈 学习统计:") print(f" 总样本数: {learner.learning_stats['total_samples']}") print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}") - + # 3. 风格列表(前20个) - print(f"\n📋 已学习的风格 (前20个):") + print("\n📋 已学习的风格 (前20个):") all_styles = learner.get_all_styles() if not all_styles: print(" ⚠️ 没有任何风格!模型尚未训练") @@ -47,9 +47,9 @@ def check_style_learner_status(chat_id: str): situation = learner.id_to_situation.get(style_id, "N/A") print(f" [{i}] {style}") print(f" (ID: {style_id}, Situation: {situation})") - + # 4. 测试预测 - print(f"\n🔮 测试预测功能:") + print("\n🔮 测试预测功能:") if not all_styles: print(" ⚠️ 无法测试,模型没有训练数据") else: @@ -58,19 +58,19 @@ def check_style_learner_status(chat_id: str): "讨论游戏", "表达赞同" ] - + for test_sit in test_situations: print(f"\n 测试输入: '{test_sit}'") best_style, scores = learner.predict_style(test_sit, top_k=3) - + if best_style: print(f" ✓ 最佳匹配: {best_style}") - print(f" Top 3:") + print(" Top 3:") for style, score in list(scores.items())[:3]: print(f" - {style}: {score:.4f}") else: - print(f" ✗ 预测失败") - + print(" ✗ 预测失败") + print("\n" + "=" * 60) print("诊断完成") print("=" * 60) @@ -82,7 +82,7 @@ if __name__ == "__main__": "52fb94af9f500a01e023ea780e43606e", # 有78个表达方式 "46c8714c8a9b7ee169941fe99fcde07d", # 有22个表达方式 ] - + for chat_id in test_chat_ids: check_style_learner_status(chat_id) print("\n") diff --git a/src/chat/antipromptinjector/processors/message_processor.py b/src/chat/antipromptinjector/processors/message_processor.py index 0e37efc0d..b13baff13 100644 --- a/src/chat/antipromptinjector/processors/message_processor.py +++ b/src/chat/antipromptinjector/processors/message_processor.py @@ -6,7 +6,7 @@ import re -from src.chat.message_receive.message import MessageRecv +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger logger = get_logger("anti_injector.message_processor") @@ -15,7 +15,7 @@ logger = get_logger("anti_injector.message_processor") class MessageProcessor: """消息内容处理器""" - def extract_text_content(self, message: MessageRecv) -> str: + def extract_text_content(self, message: DatabaseMessages) -> str: """提取消息中的文本内容,过滤掉引用的历史内容 Args: @@ -64,7 +64,7 @@ class MessageProcessor: return new_content @staticmethod - def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None: + def check_whitelist(message: DatabaseMessages, whitelist: list) -> tuple | None: """检查用户白名单 Args: @@ -74,8 +74,8 @@ class MessageProcessor: Returns: 如果在白名单中返回结果元组,否则返回None """ - user_id = message.message_info.user_info.user_id - platform = message.message_info.platform + user_id = message.user_info.user_id + platform = message.chat_info.platform # 检查用户白名单:格式为 [[platform, user_id], ...] for whitelist_entry in whitelist: diff --git a/src/chat/energy_system/energy_manager.py b/src/chat/energy_system/energy_manager.py index 982cfccce..079147812 100644 --- a/src/chat/energy_system/energy_manager.py +++ b/src/chat/energy_system/energy_manager.py @@ -201,15 +201,16 @@ class RelationshipEnergyCalculator(EnergyCalculator): # 从数据库获取聊天流兴趣分数 try: + from sqlalchemy import select + from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import ChatStreams - from sqlalchemy import select async with get_db_session() as session: stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id) result = await session.execute(stmt) stream = result.scalar_one_or_none() - + if stream and stream.stream_interest_score is not None: interest_score = float(stream.stream_interest_score) logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}") diff --git a/src/chat/express/express_utils.py b/src/chat/express/express_utils.py index bd7f41e2d..0d1baded1 100644 --- a/src/chat/express/express_utils.py +++ b/src/chat/express/express_utils.py @@ -5,14 +5,14 @@ import difflib import random import re -from typing import Any, Dict, List, Optional +from typing import Any from src.common.logger import get_logger logger = get_logger("express_utils") -def filter_message_content(content: Optional[str]) -> str: +def filter_message_content(content: str | None) -> str: """ 过滤消息内容,移除回复、@、图片等格式 @@ -51,7 +51,7 @@ def calculate_similarity(text1: str, text2: str) -> float: return difflib.SequenceMatcher(None, text1, text2).ratio() -def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]: +def weighted_sample(population: list[dict], k: int, weight_key: str | None = None) -> list[dict]: """ 加权随机抽样函数 @@ -108,7 +108,7 @@ def normalize_text(text: str) -> str: return text.strip() -def extract_keywords(text: str, max_keywords: int = 10) -> List[str]: +def extract_keywords(text: str, max_keywords: int = 10) -> list[str]: """ 简单的关键词提取(基于词频) @@ -135,7 +135,7 @@ def extract_keywords(text: str, max_keywords: int = 10) -> List[str]: return words[:max_keywords] -def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str: +def format_expression_pair(situation: str, style: str, index: int | None = None) -> str: """ 格式化表达方式对 @@ -153,7 +153,7 @@ def format_expression_pair(situation: str, style: str, index: Optional[int] = No return f'当"{situation}"时,使用"{style}"' -def parse_expression_pair(text: str) -> Optional[tuple[str, str]]: +def parse_expression_pair(text: str) -> tuple[str, str] | None: """ 解析表达方式对文本 @@ -170,7 +170,7 @@ def parse_expression_pair(text: str) -> Optional[tuple[str, str]]: return None -def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]: +def batch_filter_duplicates(expressions: list[dict[str, Any]], key_fields: list[str]) -> list[dict[str, Any]]: """ 批量去重表达方式 @@ -219,8 +219,8 @@ def calculate_time_weight(last_active_time: float, current_time: float, half_lif def merge_expressions_from_multiple_chats( - expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100 -) -> List[Dict[str, Any]]: + expressions_dict: dict[str, list[dict[str, Any]]], max_total: int = 100 +) -> list[dict[str, Any]]: """ 合并多个聊天室的表达方式 diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 75864be40..2cfe2ed8d 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -438,9 +438,9 @@ class ExpressionLearner: try: # 获取 StyleLearner 实例 learner = style_learner_manager.get_learner(chat_id) - + logger.info(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}") - + # 为每个学习到的表达方式训练模型 # 使用 situation 作为输入,style 作为目标 # 这是最符合语义的方式:场景 -> 表达方式 @@ -448,25 +448,25 @@ class ExpressionLearner: for expr in expr_list: situation = expr["situation"] style = expr["style"] - + # 训练映射关系: situation -> style if learner.learn_mapping(situation, style): success_count += 1 else: logger.warning(f"训练失败: {situation} -> {style}") - + logger.info( f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, " f"当前风格总数={len(learner.get_all_styles())}, " f"总样本数={learner.learning_stats['total_samples']}" ) - + # 保存模型 if learner.save(style_learner_manager.model_save_path): logger.info(f"StyleLearner 模型保存成功: {chat_id}") else: logger.error(f"StyleLearner 模型保存失败: {chat_id}") - + except Exception as e: logger.error(f"训练 StyleLearner 失败: {e}", exc_info=True) @@ -527,7 +527,7 @@ class ExpressionLearner: logger.debug(f"学习{type_str}的response: {response}") expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id) - + if not expressions: logger.warning(f"从LLM响应中未能解析出任何{type_str}。请检查LLM输出格式是否正确。") logger.info(f"LLM完整响应:\n{response}") @@ -542,26 +542,26 @@ class ExpressionLearner: """ expressions: list[tuple[str, str, str]] = [] failed_lines = [] - + for line_num, line in enumerate(response.splitlines(), 1): line = line.strip() if not line: continue - + # 替换中文引号为英文引号,便于统一处理 line_normalized = line.replace('"', '"').replace('"', '"').replace("'", '"').replace("'", '"') - + # 查找"当"和下一个引号 idx_when = line_normalized.find('当"') if idx_when == -1: # 尝试不带引号的格式: 当xxx时 - idx_when = line_normalized.find('当') + idx_when = line_normalized.find("当") if idx_when == -1: failed_lines.append((line_num, line, "找不到'当'关键字")) continue - + # 提取"当"和"时"之间的内容 - idx_shi = line_normalized.find('时', idx_when) + idx_shi = line_normalized.find("时", idx_when) if idx_shi == -1: failed_lines.append((line_num, line, "找不到'时'关键字")) continue @@ -575,20 +575,20 @@ class ExpressionLearner: continue situation = line_normalized[idx_quote1 + 1 : idx_quote2] search_start = idx_quote2 - + # 查找"使用"或"可以" idx_use = line_normalized.find('使用"', search_start) if idx_use == -1: idx_use = line_normalized.find('可以"', search_start) if idx_use == -1: # 尝试不带引号的格式 - idx_use = line_normalized.find('使用', search_start) + idx_use = line_normalized.find("使用", search_start) if idx_use == -1: - idx_use = line_normalized.find('可以', search_start) + idx_use = line_normalized.find("可以", search_start) if idx_use == -1: failed_lines.append((line_num, line, "找不到'使用'或'可以'关键字")) continue - + # 提取剩余部分作为style style = line_normalized[idx_use + 2:].strip('"\'"",。') if not style: @@ -610,24 +610,24 @@ class ExpressionLearner: style = line_normalized[idx_quote3 + 1:].strip('"\'""') else: style = line_normalized[idx_quote3 + 1 : idx_quote4] - + # 清理并验证 situation = situation.strip() style = style.strip() - + if not situation or not style: failed_lines.append((line_num, line, f"situation或style为空: situation='{situation}', style='{style}'")) continue - + expressions.append((chat_id, situation, style)) - + # 记录解析失败的行 if failed_lines: logger.warning(f"解析表达方式时有 {len(failed_lines)} 行失败:") for line_num, line, reason in failed_lines[:5]: # 只显示前5个 logger.warning(f" 行{line_num}: {reason}") logger.debug(f" 原文: {line}") - + if not expressions: logger.warning(f"LLM返回了内容但无法解析任何表达方式。响应预览:\n{response[:500]}") else: diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 1dbf7e08e..568cde3c3 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -267,11 +267,11 @@ class ExpressionSelector: chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history]) else: chat_info = chat_history - + # 根据配置选择模式 mode = global_config.expression.mode logger.debug(f"[ExpressionSelector] 使用模式: {mode}") - + if mode == "exp_model": return await self._select_expressions_model_only( chat_id=chat_id, @@ -288,7 +288,7 @@ class ExpressionSelector: max_num=max_num, min_num=min_num ) - + async def _select_expressions_classic( self, chat_id: str, @@ -298,7 +298,7 @@ class ExpressionSelector: min_num: int = 5, ) -> list[dict[str, Any]]: """经典模式:随机抽样 + LLM评估""" - logger.debug(f"[Classic模式] 使用LLM评估表达方式") + logger.debug("[Classic模式] 使用LLM评估表达方式") return await self.select_suitable_expressions_llm( chat_id=chat_id, chat_info=chat_info, @@ -306,7 +306,7 @@ class ExpressionSelector: min_num=min_num, target_message=target_message ) - + async def _select_expressions_model_only( self, chat_id: str, @@ -316,22 +316,22 @@ class ExpressionSelector: min_num: int = 5, ) -> list[dict[str, Any]]: """模型预测模式:先提取情境,再使用StyleLearner预测表达风格""" - logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式") - + logger.debug("[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式") + # 检查是否允许在此聊天流中使用表达 if not self.can_use_expression_for_chat(chat_id): logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表") return [] - + # 步骤1: 提取聊天情境 situations = await situation_extractor.extract_situations( chat_history=chat_info, target_message=target_message, max_situations=3 ) - + if not situations: - logger.warning(f"无法提取聊天情境,回退到经典模式") + logger.warning("无法提取聊天情境,回退到经典模式") return await self._select_expressions_classic( chat_id=chat_id, chat_info=chat_info, @@ -339,17 +339,17 @@ class ExpressionSelector: max_num=max_num, min_num=min_num ) - + logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}") - + # 步骤2: 使用 StyleLearner 为每个情境预测合适的表达方式 learner = style_learner_manager.get_learner(chat_id) - + all_predicted_styles = {} for i, situation in enumerate(situations, 1): logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}") best_style, scores = learner.predict_style(situation, top_k=max_num) - + if best_style and scores: logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}") # 合并分数(取最高分) @@ -357,10 +357,10 @@ class ExpressionSelector: if style not in all_predicted_styles or score > all_predicted_styles[style]: all_predicted_styles[style] = score else: - logger.debug(f" 该情境未返回预测结果") - + logger.debug(" 该情境未返回预测结果") + if not all_predicted_styles: - logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式") + logger.warning("[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式") return await self._select_expressions_classic( chat_id=chat_id, chat_info=chat_info, @@ -368,22 +368,22 @@ class ExpressionSelector: max_num=max_num, min_num=min_num ) - + # 将分数字典转换为列表格式 [(style, score), ...] predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True) - + logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}") - + # 步骤3: 根据预测的风格从数据库获取表达方式 - logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式") + logger.debug("[Exp_model模式] 步骤3 - 从数据库查询表达方式") expressions = await self.get_model_predicted_expressions( chat_id=chat_id, predicted_styles=predicted_styles, max_num=max_num ) - + if not expressions: - logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式") + logger.warning("[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式") return await self._select_expressions_classic( chat_id=chat_id, chat_info=chat_info, @@ -391,10 +391,10 @@ class ExpressionSelector: max_num=max_num, min_num=min_num ) - + logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式") return expressions - + async def get_model_predicted_expressions( self, chat_id: str, @@ -414,15 +414,15 @@ class ExpressionSelector: """ if not predicted_styles: return [] - + # 提取风格名称(前3个最佳匹配) style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]] logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}") - + # 🔥 使用 get_related_chat_ids 获取所有相关的 chat_id(支持共享表达方式) related_chat_ids = self.get_related_chat_ids(chat_id) logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}") - + async with get_db_session() as session: # 🔍 先检查数据库中实际有哪些 chat_id 的数据 db_chat_ids_result = await session.execute( @@ -432,7 +432,7 @@ class ExpressionSelector: ) db_chat_ids = [cid for cid in db_chat_ids_result.scalars()] logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}") - + # 获取所有相关 chat_id 的表达方式(用于模糊匹配) all_expressions_result = await session.execute( select(Expression) @@ -440,51 +440,51 @@ class ExpressionSelector: .where(Expression.type == "style") ) all_expressions = list(all_expressions_result.scalars()) - + logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}") - + # 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id if not all_expressions: - logger.info(f"相关chat_id没有数据,尝试从所有chat_id查询") + logger.info("相关chat_id没有数据,尝试从所有chat_id查询") all_expressions_result = await session.execute( select(Expression) .where(Expression.type == "style") ) all_expressions = list(all_expressions_result.scalars()) logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}") - + if not all_expressions: - logger.warning(f"数据库中完全没有任何表达方式,需要先学习") + logger.warning("数据库中完全没有任何表达方式,需要先学习") return [] - + # 🔥 使用模糊匹配而不是精确匹配 # 计算每个预测style与数据库style的相似度 from difflib import SequenceMatcher - + matched_expressions = [] for expr in all_expressions: db_style = expr.style or "" max_similarity = 0.0 best_predicted = "" - + # 与每个预测的style计算相似度 for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测 # 计算字符串相似度 similarity = SequenceMatcher(None, predicted_style, db_style).ratio() - + # 也检查包含关系(如果一个是另一个的子串,给更高分) if len(predicted_style) >= 2 and len(db_style) >= 2: if predicted_style in db_style or db_style in predicted_style: similarity = max(similarity, 0.7) - + if similarity > max_similarity: max_similarity = similarity best_predicted = predicted_style - + # 🔥 降低阈值到30%,因为StyleLearner预测质量较差 if max_similarity >= 0.3: # 30%相似度阈值 matched_expressions.append((expr, max_similarity, expr.count, best_predicted)) - + if not matched_expressions: # 收集数据库中的style样例用于调试 all_styles = [e.style for e in all_expressions[:10]] @@ -495,11 +495,11 @@ class ExpressionSelector: f" 提示: StyleLearner预测质量差,建议重新训练或使用classic模式" ) return [] - + # 按照相似度*count排序,选择最佳匹配 matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True) expressions_objs = [e[0] for e in matched_expressions[:max_num]] - + # 显示最佳匹配的详细信息 top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]] logger.info( @@ -507,7 +507,7 @@ class ExpressionSelector: f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n" f" Top3匹配: {top_matches}" ) - + # 转换为字典格式 expressions = [] for expr in expressions_objs: @@ -518,7 +518,7 @@ class ExpressionSelector: "count": float(expr.count) if expr.count else 0.0, "last_active_time": expr.last_active_time or 0.0 }) - + logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式") return expressions diff --git a/src/chat/express/expressor_model/model.py b/src/chat/express/expressor_model/model.py index 8c18240a8..c2b665878 100644 --- a/src/chat/express/expressor_model/model.py +++ b/src/chat/express/expressor_model/model.py @@ -5,7 +5,6 @@ import os import pickle from collections import Counter, defaultdict -from typing import Dict, Optional, Tuple from src.common.logger import get_logger @@ -36,14 +35,14 @@ class ExpressorModel: self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size) # 候选表达管理 - self._candidates: Dict[str, str] = {} # cid -> text (style) - self._situations: Dict[str, str] = {} # cid -> situation (不参与计算) + self._candidates: dict[str, str] = {} # cid -> text (style) + self._situations: dict[str, str] = {} # cid -> situation (不参与计算) logger.info( f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})" ) - def add_candidate(self, cid: str, text: str, situation: Optional[str] = None): + def add_candidate(self, cid: str, text: str, situation: str | None = None): """ 添加候选文本和对应的situation @@ -62,7 +61,7 @@ class ExpressorModel: if cid not in self.nb.token_counts: self.nb.token_counts[cid] = defaultdict(float) - def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]: + def predict(self, text: str, k: int = None) -> tuple[str | None, dict[str, float]]: """ 直接对所有候选进行朴素贝叶斯评分 @@ -113,7 +112,7 @@ class ExpressorModel: tf = Counter(toks) self.nb.update_positive(tf, cid) - def decay(self, factor: Optional[float] = None): + def decay(self, factor: float | None = None): """ 应用知识衰减 @@ -122,7 +121,7 @@ class ExpressorModel: """ self.nb.decay(factor) - def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]: + def get_candidate_info(self, cid: str) -> tuple[str | None, str | None]: """ 获取候选信息 @@ -136,7 +135,7 @@ class ExpressorModel: situation = self._situations.get(cid) return style, situation - def get_all_candidates(self) -> Dict[str, Tuple[str, str]]: + def get_all_candidates(self) -> dict[str, tuple[str, str]]: """ 获取所有候选 @@ -205,7 +204,7 @@ class ExpressorModel: logger.info(f"模型已从 {path} 加载") - def get_stats(self) -> Dict: + def get_stats(self) -> dict: """获取模型统计信息""" nb_stats = self.nb.get_stats() return { diff --git a/src/chat/express/expressor_model/online_nb.py b/src/chat/express/expressor_model/online_nb.py index 39bd0d1cd..06230bdf7 100644 --- a/src/chat/express/expressor_model/online_nb.py +++ b/src/chat/express/expressor_model/online_nb.py @@ -4,7 +4,6 @@ """ import math from collections import Counter, defaultdict -from typing import Dict, List, Optional from src.common.logger import get_logger @@ -28,15 +27,15 @@ class OnlineNaiveBayes: self.V = vocab_size # 类别统计 - self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count - self.token_counts: Dict[str, Dict[str, float]] = defaultdict( + self.cls_counts: dict[str, float] = defaultdict(float) # cid -> total token count + self.token_counts: dict[str, dict[str, float]] = defaultdict( lambda: defaultdict(float) ) # cid -> term -> count # 缓存 - self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα) + self._logZ: dict[str, float] = {} # cache log(∑counts + Vα) - def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]: + def score_batch(self, tf: Counter, cids: list[str]) -> dict[str, float]: """ 批量计算候选的贝叶斯分数 @@ -51,7 +50,7 @@ class OnlineNaiveBayes: n_cls = max(1, len(self.cls_counts)) denom_prior = math.log(total_cls + self.beta * n_cls) - out: Dict[str, float] = {} + out: dict[str, float] = {} for cid in cids: # 计算先验概率 log P(c) prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior @@ -88,7 +87,7 @@ class OnlineNaiveBayes: self.cls_counts[cid] += inc self._invalidate(cid) - def decay(self, factor: Optional[float] = None): + def decay(self, factor: float | None = None): """ 知识衰减(遗忘机制) @@ -133,7 +132,7 @@ class OnlineNaiveBayes: if cid in self._logZ: del self._logZ[cid] - def get_stats(self) -> Dict: + def get_stats(self) -> dict: """获取统计信息""" return { "n_classes": len(self.cls_counts), diff --git a/src/chat/express/expressor_model/tokenizer.py b/src/chat/express/expressor_model/tokenizer.py index e25f780d4..b12cdc713 100644 --- a/src/chat/express/expressor_model/tokenizer.py +++ b/src/chat/express/expressor_model/tokenizer.py @@ -1,7 +1,6 @@ """ 文本分词器,支持中文Jieba分词 """ -from typing import List from src.common.logger import get_logger @@ -30,7 +29,7 @@ class Tokenizer: logger.warning("Jieba未安装,将使用字符级分词") self.use_jieba = False - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: """ 分词并返回token列表 diff --git a/src/chat/express/situation_extractor.py b/src/chat/express/situation_extractor.py index 8ebe0a8bd..1393d5a1b 100644 --- a/src/chat/express/situation_extractor.py +++ b/src/chat/express/situation_extractor.py @@ -2,7 +2,6 @@ 情境提取器 从聊天历史中提取当前的情境(situation),用于 StyleLearner 预测 """ -from typing import Optional from src.chat.utils.prompt import Prompt, global_prompt_manager from src.common.logger import get_logger @@ -41,17 +40,17 @@ def init_prompt(): class SituationExtractor: """情境提取器,从聊天历史中提取当前情境""" - + def __init__(self): self.llm_model = LLMRequest( model_set=model_config.model_task_config.utils_small, request_type="expression.situation_extractor" ) - + async def extract_situations( self, chat_history: list | str, - target_message: Optional[str] = None, + target_message: str | None = None, max_situations: int = 3 ) -> list[str]: """ @@ -68,18 +67,18 @@ class SituationExtractor: # 转换chat_history为字符串 if isinstance(chat_history, list): chat_info = "\n".join([ - f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" + f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history ]) else: chat_info = chat_history - + # 构建目标消息信息 if target_message: target_message_info = f",现在你想要回复消息:{target_message}" else: target_message_info = "" - + # 构建 prompt try: prompt = (await global_prompt_manager.get_prompt_async("situation_extraction_prompt")).format( @@ -87,31 +86,31 @@ class SituationExtractor: chat_history=chat_info, target_message_info=target_message_info ) - + # 调用 LLM response, _ = await self.llm_model.generate_response_async( prompt=prompt, temperature=0.3 ) - + if not response or not response.strip(): logger.warning("LLM返回空响应,无法提取情境") return [] - + # 解析响应 situations = self._parse_situations(response, max_situations) - + if situations: logger.debug(f"提取到 {len(situations)} 个情境: {situations}") else: logger.warning(f"无法从LLM响应中解析出情境。响应:\n{response}") - + return situations - + except Exception as e: logger.error(f"提取情境失败: {e}") return [] - + @staticmethod def _parse_situations(response: str, max_situations: int) -> list[str]: """ @@ -125,33 +124,33 @@ class SituationExtractor: 情境描述列表 """ situations = [] - + for line in response.splitlines(): line = line.strip() if not line: continue - + # 移除可能的序号、引号等 line = line.lstrip('0123456789.、-*>))】] \t"\'""''') line = line.rstrip('"\'""''') line = line.strip() - + if not line: continue - + # 过滤掉明显不是情境描述的内容 if len(line) > 30: # 太长 continue if len(line) < 2: # 太短 continue - if any(keyword in line.lower() for keyword in ['例如', '注意', '请', '分析', '总结']): + if any(keyword in line.lower() for keyword in ["例如", "注意", "请", "分析", "总结"]): continue - + situations.append(line) - + if len(situations) >= max_situations: break - + return situations diff --git a/src/chat/express/style_learner.py b/src/chat/express/style_learner.py index c254ef98c..1ea54dd83 100644 --- a/src/chat/express/style_learner.py +++ b/src/chat/express/style_learner.py @@ -5,7 +5,6 @@ """ import os import time -from typing import Dict, List, Optional, Tuple from src.common.logger import get_logger @@ -17,7 +16,7 @@ logger = get_logger("expressor.style_learner") class StyleLearner: """单个聊天室的表达风格学习器""" - def __init__(self, chat_id: str, model_config: Optional[Dict] = None): + def __init__(self, chat_id: str, model_config: dict | None = None): """ Args: chat_id: 聊天室ID @@ -37,9 +36,9 @@ class StyleLearner: # 动态风格管理 self.max_styles = 2000 # 每个chat_id最多2000个风格 - self.style_to_id: Dict[str, str] = {} # style文本 -> style_id - self.id_to_style: Dict[str, str] = {} # style_id -> style文本 - self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本 + self.style_to_id: dict[str, str] = {} # style文本 -> style_id + self.id_to_style: dict[str, str] = {} # style_id -> style文本 + self.id_to_situation: dict[str, str] = {} # style_id -> situation文本 self.next_style_id = 0 # 学习统计 @@ -51,7 +50,7 @@ class StyleLearner: logger.info(f"StyleLearner初始化成功: chat_id={chat_id}") - def add_style(self, style: str, situation: Optional[str] = None) -> bool: + def add_style(self, style: str, situation: str | None = None) -> bool: """ 动态添加一个新的风格 @@ -130,7 +129,7 @@ class StyleLearner: logger.error(f"学习映射失败: {e}") return False - def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: + def predict_style(self, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]: """ 根据up_content预测最合适的style @@ -146,7 +145,7 @@ class StyleLearner: if not self.style_to_id: logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}") return None, {} - + best_style_id, scores = self.expressor.predict(up_content, k=top_k) if best_style_id is None: @@ -155,7 +154,7 @@ class StyleLearner: # 将style_id转换为style文本 best_style = self.id_to_style.get(best_style_id) - + if best_style is None: logger.warning( f"style_id无法转换为style文本: style_id={best_style_id}, " @@ -171,7 +170,7 @@ class StyleLearner: style_scores[style_text] = score else: logger.warning(f"跳过无法转换的style_id: {sid}") - + logger.debug( f"预测成功: up_content={up_content[:30]}..., " f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}" @@ -183,7 +182,7 @@ class StyleLearner: logger.error(f"预测style失败: {e}", exc_info=True) return None, {} - def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]: + def get_style_info(self, style: str) -> tuple[str | None, str | None]: """ 获取style的完整信息 @@ -200,7 +199,7 @@ class StyleLearner: situation = self.id_to_situation.get(style_id) return style_id, situation - def get_all_styles(self) -> List[str]: + def get_all_styles(self) -> list[str]: """ 获取所有风格列表 @@ -209,7 +208,7 @@ class StyleLearner: """ return list(self.style_to_id.keys()) - def apply_decay(self, factor: Optional[float] = None): + def apply_decay(self, factor: float | None = None): """ 应用知识衰减 @@ -304,7 +303,7 @@ class StyleLearner: logger.error(f"加载StyleLearner失败: {e}") return False - def get_stats(self) -> Dict: + def get_stats(self) -> dict: """获取统计信息""" model_stats = self.expressor.get_stats() return { @@ -324,7 +323,7 @@ class StyleLearnerManager: Args: model_save_path: 模型保存路径 """ - self.learners: Dict[str, StyleLearner] = {} + self.learners: dict[str, StyleLearner] = {} self.model_save_path = model_save_path # 确保保存目录存在 @@ -332,7 +331,7 @@ class StyleLearnerManager: logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}") - def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner: + def get_learner(self, chat_id: str, model_config: dict | None = None) -> StyleLearner: """ 获取或创建指定chat_id的学习器 @@ -369,7 +368,7 @@ class StyleLearnerManager: learner = self.get_learner(chat_id) return learner.learn_mapping(up_content, style) - def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: + def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]: """ 预测最合适的风格 @@ -399,7 +398,7 @@ class StyleLearnerManager: logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}") return success - def apply_decay_all(self, factor: Optional[float] = None): + def apply_decay_all(self, factor: float | None = None): """ 对所有学习器应用知识衰减 @@ -409,9 +408,9 @@ class StyleLearnerManager: for learner in self.learners.values(): learner.apply_decay(factor) - logger.info(f"对所有StyleLearner应用知识衰减") + logger.info("对所有StyleLearner应用知识衰减") - def get_all_stats(self) -> Dict[str, Dict]: + def get_all_stats(self) -> dict[str, dict]: """ 获取所有学习器的统计信息 diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index b2d9a93cd..3f29081c8 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -169,6 +169,7 @@ class BotInterestManager: 2. 每个标签都有权重(0.1-1.0),表示对该兴趣的喜好程度 3. 生成15-25个不等的标签 4. 标签应该是具体的关键词,而不是抽象概念 +5. 每个标签的长度不超过4个字符 请以JSON格式返回,格式如下: {{ @@ -207,6 +208,11 @@ class BotInterestManager: tag_name = tag_data.get("name", f"标签_{i}") weight = tag_data.get("weight", 0.5) + # 检查标签长度,如果过长则截断 + if len(tag_name) > 10: + logger.warning(f"⚠️ 标签 '{tag_name}' 过长,将截断为10个字符") + tag_name = tag_name[:10] + tag = BotInterestTag(tag_name=tag_name, weight=weight) bot_interests.interest_tags.append(tag) @@ -355,6 +361,8 @@ class BotInterestManager: # 使用LLMRequest获取embedding logger.debug(f"🔄 正在获取embedding: '{text[:30]}...'") + if not self.embedding_request: + raise RuntimeError("❌ Embedding客户端未初始化") embedding, model_name = await self.embedding_request.get_embedding(text) if embedding and len(embedding) > 0: @@ -504,7 +512,7 @@ class BotInterestManager: ) # 添加直接关键词匹配奖励 - keyword_bonus = self._calculate_keyword_match_bonus(keywords, result.matched_tags) + keyword_bonus = self._calculate_keyword_match_bonus(keywords or [], result.matched_tags) logger.debug(f"🎯 关键词直接匹配奖励: {keyword_bonus}") # 应用关键词奖励到匹配分数 @@ -616,17 +624,18 @@ class BotInterestManager: def _calculate_cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float: """计算余弦相似度""" try: - vec1 = np.array(vec1) - vec2 = np.array(vec2) + np_vec1 = np.array(vec1) + np_vec2 = np.array(vec2) - dot_product = np.dot(vec1, vec2) - norm1 = np.linalg.norm(vec1) - norm2 = np.linalg.norm(vec2) + dot_product = np.dot(np_vec1, np_vec2) + norm1 = np.linalg.norm(np_vec1) + norm2 = np.linalg.norm(np_vec2) if norm1 == 0 or norm2 == 0: return 0.0 - return dot_product / (norm1 * norm2) + similarity = dot_product / (norm1 * norm2) + return float(similarity) except Exception as e: logger.error(f"计算余弦相似度失败: {e}") @@ -758,7 +767,7 @@ class BotInterestManager: if existing_record: # 更新现有记录 logger.info("🔄 更新现有的兴趣标签配置") - existing_record.interest_tags = json_data + existing_record.interest_tags = json_data.decode("utf-8") existing_record.personality_description = interests.personality_description existing_record.embedding_model = interests.embedding_model existing_record.version = interests.version @@ -772,7 +781,7 @@ class BotInterestManager: new_record = DBBotPersonalityInterests( personality_id=interests.personality_id, personality_description=interests.personality_description, - interest_tags=json_data, + interest_tags=json_data.decode("utf-8"), embedding_model=interests.embedding_model, version=interests.version, last_updated=interests.last_updated, diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index cc436a1c2..53ad47e84 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -503,7 +503,7 @@ class MemorySystem: existing_id = self._memory_fingerprints.get(fingerprint_key) if existing_id and existing_id not in new_memory_ids: candidate_ids.add(existing_id) - except Exception as exc: # noqa: PERF203 + except Exception as exc: logger.debug("构建记忆指纹失败,跳过候选收集: %s", exc) # 基于主体索引的候选(使用统一存储) @@ -1739,10 +1739,8 @@ def get_memory_system() -> MemorySystem: if memory_system is None: logger.warning("Global memory_system is None. Creating new uninitialized instance. This might be a problem.") memory_system = MemorySystem() - logger.info(f"get_memory_system() called, returning instance with id: {id(memory_system)}") return memory_system - async def initialize_memory_system(llm_model: LLMRequest | None = None): """初始化全局记忆系统""" global memory_system diff --git a/src/chat/message_manager/adaptive_stream_manager.py b/src/chat/message_manager/adaptive_stream_manager.py deleted file mode 100644 index fa0a97de5..000000000 --- a/src/chat/message_manager/adaptive_stream_manager.py +++ /dev/null @@ -1,482 +0,0 @@ -""" -自适应流管理器 - 动态并发限制和异步流池管理 -根据系统负载和流优先级动态调整并发限制 -""" - -import asyncio -import time -from dataclasses import dataclass, field -from enum import Enum - -import psutil - -from src.common.logger import get_logger - -logger = get_logger("adaptive_stream_manager") - - -class StreamPriority(Enum): - """流优先级""" - - LOW = 1 - NORMAL = 2 - HIGH = 3 - CRITICAL = 4 - - -@dataclass -class SystemMetrics: - """系统指标""" - - cpu_usage: float = 0.0 - memory_usage: float = 0.0 - active_coroutines: int = 0 - event_loop_lag: float = 0.0 - timestamp: float = field(default_factory=time.time) - - -@dataclass -class StreamMetrics: - """流指标""" - - stream_id: str - priority: StreamPriority - message_rate: float = 0.0 # 消息速率(消息/分钟) - response_time: float = 0.0 # 平均响应时间 - last_activity: float = field(default_factory=time.time) - consecutive_failures: int = 0 - is_active: bool = True - - -class AdaptiveStreamManager: - """自适应流管理器""" - - def __init__( - self, - base_concurrent_limit: int = 50, - max_concurrent_limit: int = 200, - min_concurrent_limit: int = 10, - metrics_window: float = 60.0, # 指标窗口时间 - adjustment_interval: float = 30.0, # 调整间隔 - cpu_threshold_high: float = 0.8, # CPU高负载阈值 - cpu_threshold_low: float = 0.3, # CPU低负载阈值 - memory_threshold_high: float = 0.85, # 内存高负载阈值 - ): - self.base_concurrent_limit = base_concurrent_limit - self.max_concurrent_limit = max_concurrent_limit - self.min_concurrent_limit = min_concurrent_limit - self.metrics_window = metrics_window - self.adjustment_interval = adjustment_interval - self.cpu_threshold_high = cpu_threshold_high - self.cpu_threshold_low = cpu_threshold_low - self.memory_threshold_high = memory_threshold_high - - # 当前状态 - self.current_limit = base_concurrent_limit - self.active_streams: set[str] = set() - self.pending_streams: set[str] = set() - self.stream_metrics: dict[str, StreamMetrics] = {} - - # 异步信号量 - self.semaphore = asyncio.Semaphore(base_concurrent_limit) - self.priority_semaphore = asyncio.Semaphore(5) # 高优先级专用信号量 - - # 系统监控 - self.system_metrics: list[SystemMetrics] = [] - self.last_adjustment_time = 0.0 - - # 统计信息 - self.stats = { - "total_requests": 0, - "accepted_requests": 0, - "rejected_requests": 0, - "priority_accepts": 0, - "limit_adjustments": 0, - "avg_concurrent_streams": 0, - "peak_concurrent_streams": 0, - } - - # 监控任务 - self.monitor_task: asyncio.Task | None = None - self.adjustment_task: asyncio.Task | None = None - self.is_running = False - - logger.info(f"自适应流管理器初始化完成 (base_limit={base_concurrent_limit}, max_limit={max_concurrent_limit})") - - async def start(self): - """启动自适应管理器""" - if self.is_running: - logger.warning("自适应流管理器已经在运行") - return - - self.is_running = True - self.monitor_task = asyncio.create_task(self._system_monitor_loop(), name="system_monitor") - self.adjustment_task = asyncio.create_task(self._adjustment_loop(), name="limit_adjustment") - - async def stop(self): - """停止自适应管理器""" - if not self.is_running: - return - - self.is_running = False - - # 停止监控任务 - if self.monitor_task and not self.monitor_task.done(): - self.monitor_task.cancel() - try: - await asyncio.wait_for(self.monitor_task, timeout=10.0) - except asyncio.TimeoutError: - logger.warning("系统监控任务停止超时") - except Exception as e: - logger.error(f"停止系统监控任务时出错: {e}") - - if self.adjustment_task and not self.adjustment_task.done(): - self.adjustment_task.cancel() - try: - await asyncio.wait_for(self.adjustment_task, timeout=10.0) - except asyncio.TimeoutError: - logger.warning("限制调整任务停止超时") - except Exception as e: - logger.error(f"停止限制调整任务时出错: {e}") - - logger.info("自适应流管理器已停止") - - async def acquire_stream_slot( - self, stream_id: str, priority: StreamPriority = StreamPriority.NORMAL, force: bool = False - ) -> bool: - """ - 获取流处理槽位 - - Args: - stream_id: 流ID - priority: 优先级 - force: 是否强制获取(突破限制) - - Returns: - bool: 是否成功获取槽位 - """ - # 检查管理器是否已启动 - if not self.is_running: - logger.warning(f"自适应流管理器未运行,直接允许流 {stream_id}") - return True - - self.stats["total_requests"] += 1 - current_time = time.time() - - # 更新流指标 - if stream_id not in self.stream_metrics: - self.stream_metrics[stream_id] = StreamMetrics(stream_id=stream_id, priority=priority) - self.stream_metrics[stream_id].last_activity = current_time - - # 检查是否已经活跃 - if stream_id in self.active_streams: - logger.debug(f"流 {stream_id} 已经在活跃列表中") - return True - - # 优先级处理 - if priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]: - return await self._acquire_priority_slot(stream_id, priority, force) - - # 检查是否需要强制分发(消息积压) - if not force and self._should_force_dispatch(stream_id): - force = True - logger.info(f"流 {stream_id} 消息积压严重,强制分发") - - # 尝试获取常规信号量 - try: - # 使用wait_for实现非阻塞获取 - acquired = await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001) - if acquired: - self.active_streams.add(stream_id) - self.stats["accepted_requests"] += 1 - logger.debug(f"流 {stream_id} 获取常规槽位成功 (当前活跃: {len(self.active_streams)})") - return True - except asyncio.TimeoutError: - logger.debug(f"常规信号量已满: {stream_id}") - except Exception as e: - logger.warning(f"获取常规槽位时出错: {e}") - - # 如果强制分发,尝试突破限制 - if force: - return await self._force_acquire_slot(stream_id) - - # 无法获取槽位 - self.stats["rejected_requests"] += 1 - logger.debug(f"流 {stream_id} 获取槽位失败,当前限制: {self.current_limit}, 活跃流: {len(self.active_streams)}") - return False - - async def _acquire_priority_slot(self, stream_id: str, priority: StreamPriority, force: bool) -> bool: - """获取优先级槽位""" - try: - # 优先级信号量有少量槽位 - acquired = await asyncio.wait_for(self.priority_semaphore.acquire(), timeout=0.001) - if acquired: - self.active_streams.add(stream_id) - self.stats["priority_accepts"] += 1 - self.stats["accepted_requests"] += 1 - logger.debug(f"流 {stream_id} 获取优先级槽位成功 (优先级: {priority.name})") - return True - except asyncio.TimeoutError: - logger.debug(f"优先级信号量已满: {stream_id}") - except Exception as e: - logger.warning(f"获取优先级槽位时出错: {e}") - - # 如果优先级槽位也满了,检查是否强制 - if force or priority == StreamPriority.CRITICAL: - return await self._force_acquire_slot(stream_id) - - return False - - async def _force_acquire_slot(self, stream_id: str) -> bool: - """强制获取槽位(突破限制)""" - # 检查是否超过最大限制 - if len(self.active_streams) >= self.max_concurrent_limit: - logger.warning(f"达到最大并发限制 {self.max_concurrent_limit},无法为流 {stream_id} 强制分发") - return False - - # 强制添加到活跃列表 - self.active_streams.add(stream_id) - self.stats["accepted_requests"] += 1 - logger.warning(f"流 {stream_id} 突破并发限制强制分发 (当前活跃: {len(self.active_streams)})") - return True - - def release_stream_slot(self, stream_id: str): - """释放流处理槽位""" - if stream_id in self.active_streams: - self.active_streams.remove(stream_id) - - # 释放相应的信号量 - metrics = self.stream_metrics.get(stream_id) - if metrics and metrics.priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]: - self.priority_semaphore.release() - else: - self.semaphore.release() - - logger.debug(f"流 {stream_id} 释放槽位 (当前活跃: {len(self.active_streams)})") - - def _should_force_dispatch(self, stream_id: str) -> bool: - """判断是否应该强制分发""" - # 这里可以实现基于消息积压的判断逻辑 - # 简化版本:基于流的历史活跃度和优先级 - metrics = self.stream_metrics.get(stream_id) - if not metrics: - return False - - # 如果是高优先级流,更容易强制分发 - if metrics.priority == StreamPriority.HIGH: - return True - - # 如果最近有活跃且响应时间较长,可能需要强制分发 - current_time = time.time() - if ( - current_time - metrics.last_activity < 300 # 5分钟内有活动 - and metrics.response_time > 5.0 - ): # 响应时间超过5秒 - return True - - return False - - async def _system_monitor_loop(self): - """系统监控循环""" - logger.info("系统监控循环启动") - - while self.is_running: - try: - await asyncio.sleep(5.0) # 每5秒监控一次 - await self._collect_system_metrics() - except asyncio.CancelledError: - logger.info("系统监控循环被取消") - break - except Exception as e: - logger.error(f"系统监控出错: {e}") - - logger.info("系统监控循环结束") - - async def _collect_system_metrics(self): - """收集系统指标""" - try: - # CPU使用率 - cpu_usage = psutil.cpu_percent(interval=None) / 100.0 - - # 内存使用率 - memory = psutil.virtual_memory() - memory_usage = memory.percent / 100.0 - - # 活跃协程数量 - try: - active_coroutines = len(asyncio.all_tasks()) - except: - active_coroutines = 0 - - # 事件循环延迟 - event_loop_lag = 0.0 - try: - asyncio.get_running_loop() - start_time = time.time() - await asyncio.sleep(0) - event_loop_lag = time.time() - start_time - except: - pass - - metrics = SystemMetrics( - cpu_usage=cpu_usage, - memory_usage=memory_usage, - active_coroutines=active_coroutines, - event_loop_lag=event_loop_lag, - timestamp=time.time(), - ) - - self.system_metrics.append(metrics) - - # 保持指标窗口大小 - cutoff_time = time.time() - self.metrics_window - self.system_metrics = [m for m in self.system_metrics if m.timestamp > cutoff_time] - - # 更新统计信息 - self.stats["avg_concurrent_streams"] = ( - self.stats["avg_concurrent_streams"] * 0.9 + len(self.active_streams) * 0.1 - ) - self.stats["peak_concurrent_streams"] = max(self.stats["peak_concurrent_streams"], len(self.active_streams)) - - except Exception as e: - logger.error(f"收集系统指标失败: {e}") - - async def _adjustment_loop(self): - """限制调整循环""" - logger.info("限制调整循环启动") - - while self.is_running: - try: - await asyncio.sleep(self.adjustment_interval) - await self._adjust_concurrent_limit() - except asyncio.CancelledError: - logger.info("限制调整循环被取消") - break - except Exception as e: - logger.error(f"限制调整出错: {e}") - - logger.info("限制调整循环结束") - - async def _adjust_concurrent_limit(self): - """调整并发限制""" - if not self.system_metrics: - return - - current_time = time.time() - if current_time - self.last_adjustment_time < self.adjustment_interval: - return - - # 计算平均系统指标 - recent_metrics = self.system_metrics[-10:] if len(self.system_metrics) >= 10 else self.system_metrics - if not recent_metrics: - return - - avg_cpu = sum(m.cpu_usage for m in recent_metrics) / len(recent_metrics) - avg_memory = sum(m.memory_usage for m in recent_metrics) / len(recent_metrics) - avg_coroutines = sum(m.active_coroutines for m in recent_metrics) / len(recent_metrics) - - # 调整策略 - old_limit = self.current_limit - adjustment_factor = 1.0 - - # CPU负载调整 - if avg_cpu > self.cpu_threshold_high: - adjustment_factor *= 0.8 # 减少20% - elif avg_cpu < self.cpu_threshold_low: - adjustment_factor *= 1.2 # 增加20% - - # 内存负载调整 - if avg_memory > self.memory_threshold_high: - adjustment_factor *= 0.7 # 减少30% - - # 协程数量调整 - if avg_coroutines > 1000: - adjustment_factor *= 0.9 # 减少10% - - # 应用调整 - new_limit = int(self.current_limit * adjustment_factor) - new_limit = max(self.min_concurrent_limit, min(self.max_concurrent_limit, new_limit)) - - # 检查是否需要调整信号量 - if new_limit != self.current_limit: - await self._adjust_semaphore(self.current_limit, new_limit) - self.current_limit = new_limit - self.stats["limit_adjustments"] += 1 - self.last_adjustment_time = current_time - - logger.info( - f"并发限制调整: {old_limit} -> {new_limit} " - f"(CPU: {avg_cpu:.2f}, 内存: {avg_memory:.2f}, 协程: {avg_coroutines:.0f})" - ) - - async def _adjust_semaphore(self, old_limit: int, new_limit: int): - """调整信号量大小""" - if new_limit > old_limit: - # 增加信号量槽位 - for _ in range(new_limit - old_limit): - self.semaphore.release() - elif new_limit < old_limit: - # 减少信号量槽位(通过等待槽位被释放) - reduction = old_limit - new_limit - for _ in range(reduction): - try: - await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001) - except: - # 如果无法立即获取,说明当前使用量接近限制 - break - - def update_stream_metrics(self, stream_id: str, **kwargs): - """更新流指标""" - if stream_id not in self.stream_metrics: - return - - metrics = self.stream_metrics[stream_id] - for key, value in kwargs.items(): - if hasattr(metrics, key): - setattr(metrics, key, value) - - def get_stats(self) -> dict: - """获取统计信息""" - stats = self.stats.copy() - stats.update( - { - "current_limit": self.current_limit, - "active_streams": len(self.active_streams), - "pending_streams": len(self.pending_streams), - "is_running": self.is_running, - "system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0, - "system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0, - } - ) - - # 计算接受率 - if stats["total_requests"] > 0: - stats["acceptance_rate"] = stats["accepted_requests"] / stats["total_requests"] - else: - stats["acceptance_rate"] = 0 - - return stats - - -# 全局自适应管理器实例 -_adaptive_manager: AdaptiveStreamManager | None = None - - -def get_adaptive_stream_manager() -> AdaptiveStreamManager: - """获取自适应流管理器实例""" - global _adaptive_manager - if _adaptive_manager is None: - _adaptive_manager = AdaptiveStreamManager() - return _adaptive_manager - - -async def init_adaptive_stream_manager(): - """初始化自适应流管理器""" - manager = get_adaptive_stream_manager() - await manager.start() - - -async def shutdown_adaptive_stream_manager(): - """关闭自适应流管理器""" - manager = get_adaptive_stream_manager() - await manager.stop() diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py index 41bf47781..bd74925c7 100644 --- a/src/chat/message_manager/context_manager.py +++ b/src/chat/message_manager/context_manager.py @@ -29,7 +29,6 @@ class SingleStreamContextManager: # 配置参数 self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100) - self.context_ttl = getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时 # 元数据 self.created_time = time.time() @@ -37,7 +36,13 @@ class SingleStreamContextManager: self.access_count = 0 self.total_messages = 0 - logger.debug(f"单流上下文管理器初始化: {stream_id}") + # 标记是否已初始化历史消息 + self._history_initialized = False + + logger.info(f"[新建] 单流上下文管理器初始化: {stream_id} (id={id(self)})") + + # 异步初始化历史消息(不阻塞构造函数) + asyncio.create_task(self._initialize_history_from_db()) def get_context(self) -> StreamContext: """获取流上下文""" @@ -93,27 +98,24 @@ class SingleStreamContextManager: return True else: logger.warning(f"消息缓存系统添加失败,回退到直接添加: {self.stream_id}") - - except ImportError: - logger.debug("MessageManager不可用,使用直接添加模式") except Exception as e: logger.warning(f"消息缓存系统异常,回退到直接添加: {self.stream_id}, error={e}") - # 回退方案:直接添加到未读消息 - message.is_read = False - self.context.unread_messages.append(message) + # 回退方案:直接添加到未读消息 + message.is_read = False + self.context.unread_messages.append(message) - # 自动检测和更新chat type - self._detect_chat_type(message) + # 自动检测和更新chat type + self._detect_chat_type(message) - # 在上下文管理器中计算兴趣值 - await self._calculate_message_interest(message) - self.total_messages += 1 - self.last_access_time = time.time() - # 启动流的循环任务(如果还未启动) - asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id)) - logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}") - return True + # 在上下文管理器中计算兴趣值 + await self._calculate_message_interest(message) + self.total_messages += 1 + self.last_access_time = time.time() + # 启动流的循环任务(如果还未启动) + asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id)) + logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}") + return True except Exception as e: logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True) return False @@ -298,6 +300,59 @@ class SingleStreamContextManager: self.last_access_time = time.time() self.access_count += 1 + async def _initialize_history_from_db(self): + """从数据库初始化历史消息到context中""" + if self._history_initialized: + logger.info(f"历史消息已初始化,跳过: {self.stream_id}") + return + + # 立即设置标志,防止并发重复加载 + logger.info(f"设置历史初始化标志: {self.stream_id}") + self._history_initialized = True + + try: + logger.info(f"开始从数据库加载历史消息: {self.stream_id}") + + from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat + + # 加载历史消息(限制数量为max_context_size的2倍,用于丰富上下文) + db_messages = await get_raw_msg_before_timestamp_with_chat( + chat_id=self.stream_id, + timestamp=time.time(), + limit=self.max_context_size * 2, + ) + + if db_messages: + # 将数据库消息转换为 DatabaseMessages 对象并添加到历史 + for msg_dict in db_messages: + try: + # 使用 ** 解包字典作为关键字参数 + db_msg = DatabaseMessages(**msg_dict) + + # 标记为已读 + db_msg.is_read = True + + # 添加到历史消息 + self.context.history_messages.append(db_msg) + + except Exception as e: + logger.warning(f"转换历史消息失败 (message_id={msg_dict.get('message_id', 'unknown')}): {e}") + continue + + logger.info(f"成功从数据库加载 {len(self.context.history_messages)} 条历史消息到内存: {self.stream_id}") + else: + logger.debug(f"没有历史消息需要加载: {self.stream_id}") + + except Exception as e: + logger.error(f"从数据库初始化历史消息失败: {self.stream_id}, {e}", exc_info=True) + # 加载失败时重置标志,允许重试 + self._history_initialized = False + + async def ensure_history_initialized(self): + """确保历史消息已初始化(供外部调用)""" + if not self._history_initialized: + await self._initialize_history_from_db() + async def _calculate_message_interest(self, message: DatabaseMessages) -> float: """ 在上下文管理器中计算消息的兴趣度 diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index c111bf8b4..c3496b79b 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -9,7 +9,6 @@ from typing import Any from src.chat.chatter_manager import ChatterManager from src.chat.energy_system import energy_manager -from src.chat.message_manager.adaptive_stream_manager import StreamPriority from src.common.data_models.message_manager_data_model import StreamContext from src.common.logger import get_logger from src.config.config import global_config @@ -70,10 +69,10 @@ class StreamLoopManager: try: # 获取所有活跃的流 from src.plugin_system.apis.chat_api import get_chat_manager - + chat_manager = get_chat_manager() all_streams = await chat_manager.get_all_streams() - + # 创建任务列表以便并发取消 cancel_tasks = [] for chat_stream in all_streams: @@ -117,38 +116,13 @@ class StreamLoopManager: logger.debug(f"流 {stream_id} 循环已在运行") return True - # 使用自适应流管理器获取槽位 - try: - from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager - - adaptive_manager = get_adaptive_stream_manager() - - if adaptive_manager.is_running: - # 确定流优先级 - priority = self._determine_stream_priority(stream_id) - - # 获取处理槽位 - slot_acquired = await adaptive_manager.acquire_stream_slot( - stream_id=stream_id, priority=priority, force=force - ) - - if slot_acquired: - logger.debug(f"成功获取流处理槽位: {stream_id} (优先级: {priority.name})") - else: - logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案") - else: - logger.debug("自适应管理器未运行") - - except Exception as e: - logger.debug(f"自适应管理器获取槽位失败: {e}") - # 创建流循环任务 try: loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}") - + # 将任务记录到 StreamContext 中 context.stream_loop_task = loop_task - + # 更新统计信息 self.stats["active_streams"] += 1 self.stats["total_loops"] += 1 @@ -158,35 +132,8 @@ class StreamLoopManager: except Exception as e: logger.error(f"启动流循环任务失败 {stream_id}: {e}") - # 释放槽位 - from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager - - adaptive_manager = get_adaptive_stream_manager() - adaptive_manager.release_stream_slot(stream_id) - return False - def _determine_stream_priority(self, stream_id: str) -> "StreamPriority": - """确定流优先级""" - try: - from src.chat.message_manager.adaptive_stream_manager import StreamPriority - - # 这里可以基于流的历史数据、用户身份等确定优先级 - # 简化版本:基于流ID的哈希值分配优先级 - hash_value = hash(stream_id) % 10 - - if hash_value >= 8: # 20% 高优先级 - return StreamPriority.HIGH - elif hash_value >= 5: # 30% 中等优先级 - return StreamPriority.NORMAL - else: # 50% 低优先级 - return StreamPriority.LOW - - except Exception: - from src.chat.message_manager.adaptive_stream_manager import StreamPriority - - return StreamPriority.NORMAL - async def stop_stream_loop(self, stream_id: str) -> bool: """停止指定流的循环任务 @@ -222,7 +169,7 @@ class StreamLoopManager: # 清空 StreamContext 中的任务记录 context.stream_loop_task = None - + logger.info(f"停止流循环: {stream_id}") return True @@ -248,31 +195,18 @@ class StreamLoopManager: unread_count = self._get_unread_count(context) force_dispatch = self._needs_force_dispatch_for_context(context, unread_count) - # 3. 更新自适应管理器指标 - try: - from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager - - adaptive_manager = get_adaptive_stream_manager() - adaptive_manager.update_stream_metrics( - stream_id, - message_rate=unread_count / 5.0 if unread_count > 0 else 0.0, # 简化计算 - last_activity=time.time(), - ) - except Exception as e: - logger.debug(f"更新流指标失败: {e}") - has_messages = force_dispatch or await self._has_messages_to_process(context) if has_messages: if force_dispatch: logger.info("流 %s 未读消息 %d 条,触发强制分发", stream_id, unread_count) - + # 3. 在处理前更新能量值(用于下次间隔计算) try: await self._update_stream_energy(stream_id, context) except Exception as e: logger.debug(f"更新流能量失败 {stream_id}: {e}") - + # 4. 激活chatter处理 success = await self._process_stream_messages(stream_id, context) @@ -313,16 +247,6 @@ class StreamLoopManager: except Exception as e: logger.debug(f"清理 StreamContext 任务记录失败: {e}") - # 释放自适应管理器的槽位 - try: - from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager - - adaptive_manager = get_adaptive_stream_manager() - adaptive_manager.release_stream_slot(stream_id) - logger.debug(f"释放自适应流处理槽位: {stream_id}") - except Exception as e: - logger.debug(f"释放自适应流处理槽位失败: {e}") - # 清理间隔记录 self._last_intervals.pop(stream_id, None) @@ -447,7 +371,7 @@ class StreamLoopManager: # 清除 Chatter 处理标志 context.is_chatter_processing = False logger.debug(f"清除 Chatter 处理标志: {stream_id}") - + # 无论成功或失败,都要设置处理状态为未处理 self._set_stream_processing_status(stream_id, False) @@ -508,48 +432,48 @@ class StreamLoopManager: """ try: from src.chat.message_receive.chat_stream import get_chat_manager - + # 获取聊天流 chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) - + if not chat_stream: logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新") return - + # 从 context_manager 获取消息(包括未读和历史消息) # 合并未读消息和历史消息 all_messages = [] - + # 添加历史消息 history_messages = context.get_history_messages(limit=global_config.chat.max_context_size) all_messages.extend(history_messages) - + # 添加未读消息 unread_messages = context.get_unread_messages() all_messages.extend(unread_messages) - + # 按时间排序并限制数量 all_messages.sort(key=lambda m: m.time) messages = all_messages[-global_config.chat.max_context_size:] - + # 获取用户ID user_id = None if context.triggering_user_id: user_id = context.triggering_user_id - + # 使用能量管理器计算并缓存能量值 energy = await energy_manager.calculate_focus_energy( stream_id=stream_id, messages=messages, user_id=user_id ) - + # 同步更新到 ChatStream chat_stream._focus_energy = energy - + logger.debug(f"已更新流 {stream_id} 的能量值: {energy:.3f}") - + except Exception as e: logger.warning(f"更新流能量失败 {stream_id}: {e}", exc_info=False) @@ -746,7 +670,7 @@ class StreamLoopManager: # 使用 start_stream_loop 重新创建流循环任务 success = await self.start_stream_loop(stream_id, force=True) - + if success: logger.info(f"已创建强制分发流循环: {stream_id}") else: diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 49c169640..a06e07be0 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -71,29 +71,9 @@ class MessageManager: except Exception as e: logger.error(f"启动批量数据库写入器失败: {e}") - # 启动流缓存管理器 - try: - from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager - - await init_stream_cache_manager() - except Exception as e: - logger.error(f"启动流缓存管理器失败: {e}") - # 启动消息缓存系统(内置) logger.info("📦 消息缓存系统已启动") - # 启动自适应流管理器 - try: - from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager - - await init_adaptive_stream_manager() - logger.info("🎯 自适应流管理器已启动") - except Exception as e: - logger.error(f"启动自适应流管理器失败: {e}") - - # 启动睡眠和唤醒管理器 - # 睡眠系统的定时任务启动移至 main.py - # 启动流循环管理器并设置chatter_manager await stream_loop_manager.start() stream_loop_manager.set_chatter_manager(self.chatter_manager) @@ -116,30 +96,11 @@ class MessageManager: except Exception as e: logger.error(f"停止批量数据库写入器失败: {e}") - # 停止流缓存管理器 - try: - from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager - - await shutdown_stream_cache_manager() - logger.info("🗄️ 流缓存管理器已停止") - except Exception as e: - logger.error(f"停止流缓存管理器失败: {e}") - # 停止消息缓存系统(内置) self.message_caches.clear() self.stream_processing_status.clear() logger.info("📦 消息缓存系统已停止") - # 停止自适应流管理器 - try: - from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager - - await shutdown_adaptive_stream_manager() - logger.info("🎯 自适应流管理器已停止") - except Exception as e: - logger.error(f"停止自适应流管理器失败: {e}") - - # 停止流循环管理器 await stream_loop_manager.stop() @@ -152,7 +113,7 @@ class MessageManager: # 检查是否为notice消息 if self._is_notice_message(message): # Notice消息处理 - 添加到全局管理器 - logger.info(f"📢 检测到notice消息: message_id={message.message_id}, is_notify={message.is_notify}, notice_type={getattr(message, 'notice_type', None)}") + logger.info(f"📢 检测到notice消息: notice_type={getattr(message, 'notice_type', None)}") await self._handle_notice_message(stream_id, message) # 根据配置决定是否继续处理(触发聊天流程) @@ -206,39 +167,6 @@ class MessageManager: except Exception as e: logger.error(f"更新消息 {message_id} 时发生错误: {e}") - async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int: - """批量更新消息信息,降低更新频率""" - if not updates: - return 0 - - try: - chat_manager = get_chat_manager() - chat_stream = await chat_manager.get_stream(stream_id) - if not chat_stream: - logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在") - return 0 - - updated_count = 0 - for item in updates: - message_id = item.get("message_id") - if not message_id: - continue - - payload = {key: value for key, value in item.items() if key != "message_id" and value is not None} - - if not payload: - continue - - success = await chat_stream.context_manager.update_message(message_id, payload) - if success: - updated_count += 1 - - if updated_count: - logger.debug(f"批量更新消息 {updated_count} 条 (stream={stream_id})") - return updated_count - except Exception as e: - logger.error(f"批量更新聊天流 {stream_id} 消息失败: {e}") - return 0 async def add_action(self, stream_id: str, message_id: str, action: str): """添加动作到消息""" @@ -266,7 +194,7 @@ class MessageManager: logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在") return - context = chat_stream.stream_context + context = chat_stream.context_manager.context context.is_active = False # 取消处理任务 @@ -288,7 +216,7 @@ class MessageManager: logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在") return - context = chat_stream.stream_context + context = chat_stream.context_manager.context context.is_active = True logger.info(f"激活聊天流: {stream_id}") @@ -304,7 +232,7 @@ class MessageManager: if not chat_stream: return None - context = chat_stream.stream_context + context = chat_stream.context_manager.context unread_count = len(chat_stream.context_manager.get_unread_messages()) return StreamStats( @@ -379,7 +307,7 @@ class MessageManager: # 检查上下文 context = chat_stream.context_manager.context - + # 只有当 Chatter 真正在处理时才检查打断 if not context.is_chatter_processing: logger.debug(f"聊天流 {chat_stream.stream_id} Chatter 未在处理,跳过打断检查") @@ -387,7 +315,7 @@ class MessageManager: # 检查是否有 stream_loop_task 在运行 stream_loop_task = context.stream_loop_task - + if stream_loop_task and not stream_loop_task.done(): # 检查触发用户ID triggering_user_id = context.triggering_user_id @@ -447,7 +375,7 @@ class MessageManager: await asyncio.sleep(0.1) # 获取当前的stream context - context = chat_stream.stream_context + context = chat_stream.context_manager.context # 确保有未读消息需要处理 unread_messages = context.get_unread_messages() @@ -459,7 +387,7 @@ class MessageManager: # 重新创建 stream_loop 任务 success = await stream_loop_manager.start_stream_loop(stream_id, force=True) - + if success: logger.info(f"✅ 成功重新创建流循环任务: {stream_id}") else: diff --git a/src/chat/message_manager/stream_cache_manager.py b/src/chat/message_manager/stream_cache_manager.py deleted file mode 100644 index ea85c3855..000000000 --- a/src/chat/message_manager/stream_cache_manager.py +++ /dev/null @@ -1,377 +0,0 @@ -""" -流缓存管理器 - 使用优化版聊天流和智能缓存策略 -提供分层缓存和自动清理功能 -""" - -import asyncio -import time -from collections import OrderedDict -from dataclasses import dataclass - -from maim_message import GroupInfo, UserInfo - -from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream -from src.common.logger import get_logger - -logger = get_logger("stream_cache_manager") - - -@dataclass -class StreamCacheStats: - """缓存统计信息""" - - hot_cache_size: int = 0 - warm_storage_size: int = 0 - cold_storage_size: int = 0 - total_memory_usage: int = 0 # 估算的内存使用(字节) - cache_hits: int = 0 - cache_misses: int = 0 - evictions: int = 0 - last_cleanup_time: float = 0 - - -class TieredStreamCache: - """分层流缓存管理器""" - - def __init__( - self, - max_hot_size: int = 100, - max_warm_size: int = 500, - max_cold_size: int = 2000, - cleanup_interval: float = 300.0, # 5分钟清理一次 - hot_timeout: float = 1800.0, # 30分钟未访问降级到warm - warm_timeout: float = 7200.0, # 2小时未访问降级到cold - cold_timeout: float = 86400.0, # 24小时未访问删除 - ): - self.max_hot_size = max_hot_size - self.max_warm_size = max_warm_size - self.max_cold_size = max_cold_size - self.cleanup_interval = cleanup_interval - self.hot_timeout = hot_timeout - self.warm_timeout = warm_timeout - self.cold_timeout = cold_timeout - - # 三层缓存存储 - self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据(LRU) - self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间) - self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间) - - # 统计信息 - self.stats = StreamCacheStats() - - # 清理任务 - self.cleanup_task: asyncio.Task | None = None - self.is_running = False - - logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})") - - async def start(self): - """启动缓存管理器""" - if self.is_running: - logger.warning("缓存管理器已经在运行") - return - - self.is_running = True - self.cleanup_task = asyncio.create_task(self._cleanup_loop(), name="stream_cache_cleanup") - - async def stop(self): - """停止缓存管理器""" - if not self.is_running: - return - - self.is_running = False - - if self.cleanup_task and not self.cleanup_task.done(): - self.cleanup_task.cancel() - try: - await asyncio.wait_for(self.cleanup_task, timeout=10.0) - except asyncio.TimeoutError: - logger.warning("缓存清理任务停止超时") - except Exception as e: - logger.error(f"停止缓存清理任务时出错: {e}") - - logger.info("分层流缓存管理器已停止") - - async def get_or_create_stream( - self, - stream_id: str, - platform: str, - user_info: UserInfo, - group_info: GroupInfo | None = None, - data: dict | None = None, - ) -> OptimizedChatStream: - """获取或创建流 - 优化版本""" - current_time = time.time() - - # 1. 检查热缓存 - if stream_id in self.hot_cache: - stream = self.hot_cache[stream_id] - # 移动到末尾(LRU更新) - self.hot_cache.move_to_end(stream_id) - self.stats.cache_hits += 1 - logger.debug(f"热缓存命中: {stream_id}") - return stream.create_snapshot() - - # 2. 检查温存储 - if stream_id in self.warm_storage: - stream, last_access = self.warm_storage[stream_id] - self.warm_storage[stream_id] = (stream, current_time) - self.stats.cache_hits += 1 - logger.debug(f"温缓存命中: {stream_id}") - # 提升到热缓存 - await self._promote_to_hot(stream_id, stream) - return stream.create_snapshot() - - # 3. 检查冷存储 - if stream_id in self.cold_storage: - stream, last_access = self.cold_storage[stream_id] - self.cold_storage[stream_id] = (stream, current_time) - self.stats.cache_hits += 1 - logger.debug(f"冷缓存命中: {stream_id}") - # 提升到温缓存 - await self._promote_to_warm(stream_id, stream) - return stream.create_snapshot() - - # 4. 缓存未命中,创建新流 - self.stats.cache_misses += 1 - stream = create_optimized_chat_stream( - stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data - ) - logger.debug(f"缓存未命中,创建新流: {stream_id}") - - # 添加到热缓存 - await self._add_to_hot(stream_id, stream) - - return stream - - async def _add_to_hot(self, stream_id: str, stream: OptimizedChatStream): - """添加到热缓存""" - # 检查是否需要驱逐 - if len(self.hot_cache) >= self.max_hot_size: - await self._evict_from_hot() - - self.hot_cache[stream_id] = stream - self.stats.hot_cache_size = len(self.hot_cache) - - async def _promote_to_hot(self, stream_id: str, stream: OptimizedChatStream): - """提升到热缓存""" - # 从温存储中移除 - if stream_id in self.warm_storage: - del self.warm_storage[stream_id] - self.stats.warm_storage_size = len(self.warm_storage) - - # 添加到热缓存 - await self._add_to_hot(stream_id, stream) - logger.debug(f"流 {stream_id} 提升到热缓存") - - async def _promote_to_warm(self, stream_id: str, stream: OptimizedChatStream): - """提升到温缓存""" - # 从冷存储中移除 - if stream_id in self.cold_storage: - del self.cold_storage[stream_id] - self.stats.cold_storage_size = len(self.cold_storage) - - # 添加到温存储 - if len(self.warm_storage) >= self.max_warm_size: - await self._evict_from_warm() - - current_time = time.time() - self.warm_storage[stream_id] = (stream, current_time) - self.stats.warm_storage_size = len(self.warm_storage) - logger.debug(f"流 {stream_id} 提升到温缓存") - - async def _evict_from_hot(self): - """从热缓存驱逐最久未使用的流""" - if not self.hot_cache: - return - - # LRU驱逐 - stream_id, stream = self.hot_cache.popitem(last=False) - self.stats.evictions += 1 - logger.debug(f"从热缓存驱逐: {stream_id}") - - # 移动到温存储 - if len(self.warm_storage) < self.max_warm_size: - current_time = time.time() - self.warm_storage[stream_id] = (stream, current_time) - self.stats.warm_storage_size = len(self.warm_storage) - else: - # 温存储也满了,直接删除 - logger.debug(f"温存储已满,删除流: {stream_id}") - - self.stats.hot_cache_size = len(self.hot_cache) - - async def _evict_from_warm(self): - """从温存储驱逐最久未使用的流""" - if not self.warm_storage: - return - - # 找到最久未访问的流 - oldest_stream_id = min(self.warm_storage.keys(), key=lambda k: self.warm_storage[k][1]) - stream, last_access = self.warm_storage.pop(oldest_stream_id) - self.stats.evictions += 1 - logger.debug(f"从温存储驱逐: {oldest_stream_id}") - - # 移动到冷存储 - if len(self.cold_storage) < self.max_cold_size: - current_time = time.time() - self.cold_storage[oldest_stream_id] = (stream, current_time) - self.stats.cold_storage_size = len(self.cold_storage) - else: - # 冷存储也满了,直接删除 - logger.debug(f"冷存储已满,删除流: {oldest_stream_id}") - - self.stats.warm_storage_size = len(self.warm_storage) - - async def _cleanup_loop(self): - """清理循环""" - logger.info("流缓存清理循环启动") - - while self.is_running: - try: - await asyncio.sleep(self.cleanup_interval) - await self._perform_cleanup() - except asyncio.CancelledError: - logger.info("流缓存清理循环被取消") - break - except Exception as e: - logger.error(f"流缓存清理出错: {e}") - - logger.info("流缓存清理循环结束") - - async def _perform_cleanup(self): - """执行清理操作""" - current_time = time.time() - cleanup_stats = { - "hot_to_warm": 0, - "warm_to_cold": 0, - "cold_removed": 0, - } - - # 1. 检查热缓存超时 - hot_to_demote = [] - for stream_id, stream in self.hot_cache.items(): - # 获取最后访问时间(简化:使用创建时间作为近似) - last_access = getattr(stream, "last_active_time", stream.create_time) - if current_time - last_access > self.hot_timeout: - hot_to_demote.append(stream_id) - - for stream_id in hot_to_demote: - stream = self.hot_cache.pop(stream_id) - current_time_local = time.time() - self.warm_storage[stream_id] = (stream, current_time_local) - cleanup_stats["hot_to_warm"] += 1 - - # 2. 检查温存储超时 - warm_to_demote = [] - for stream_id, (stream, last_access) in self.warm_storage.items(): - if current_time - last_access > self.warm_timeout: - warm_to_demote.append(stream_id) - - for stream_id in warm_to_demote: - stream, last_access = self.warm_storage.pop(stream_id) - self.cold_storage[stream_id] = (stream, last_access) - cleanup_stats["warm_to_cold"] += 1 - - # 3. 检查冷存储超时 - cold_to_remove = [] - for stream_id, (stream, last_access) in self.cold_storage.items(): - if current_time - last_access > self.cold_timeout: - cold_to_remove.append(stream_id) - - for stream_id in cold_to_remove: - self.cold_storage.pop(stream_id) - cleanup_stats["cold_removed"] += 1 - - # 更新统计信息 - self.stats.hot_cache_size = len(self.hot_cache) - self.stats.warm_storage_size = len(self.warm_storage) - self.stats.cold_storage_size = len(self.cold_storage) - self.stats.last_cleanup_time = current_time - - # 估算内存使用(粗略估计) - self.stats.total_memory_usage = ( - len(self.hot_cache) * 1024 # 每个热流约1KB - + len(self.warm_storage) * 512 # 每个温流约512B - + len(self.cold_storage) * 256 # 每个冷流约256B - ) - - if sum(cleanup_stats.values()) > 0: - logger.info( - f"缓存清理完成: {cleanup_stats['hot_to_warm']}热→温, " - f"{cleanup_stats['warm_to_cold']}温→冷, " - f"{cleanup_stats['cold_removed']}冷删除" - ) - - def get_stats(self) -> StreamCacheStats: - """获取缓存统计信息""" - # 计算命中率 - total_requests = self.stats.cache_hits + self.stats.cache_misses - hit_rate = self.stats.cache_hits / total_requests if total_requests > 0 else 0 - - stats_copy = StreamCacheStats( - hot_cache_size=self.stats.hot_cache_size, - warm_storage_size=self.stats.warm_storage_size, - cold_storage_size=self.stats.cold_storage_size, - total_memory_usage=self.stats.total_memory_usage, - cache_hits=self.stats.cache_hits, - cache_misses=self.stats.cache_misses, - evictions=self.stats.evictions, - last_cleanup_time=self.stats.last_cleanup_time, - ) - - # 添加命中率信息 - stats_copy.hit_rate = hit_rate - - return stats_copy - - def clear_cache(self): - """清空所有缓存""" - self.hot_cache.clear() - self.warm_storage.clear() - self.cold_storage.clear() - - self.stats.hot_cache_size = 0 - self.stats.warm_storage_size = 0 - self.stats.cold_storage_size = 0 - self.stats.total_memory_usage = 0 - - logger.info("所有缓存已清空") - - async def get_stream_snapshot(self, stream_id: str) -> OptimizedChatStream | None: - """获取流的快照(不修改缓存状态)""" - if stream_id in self.hot_cache: - return self.hot_cache[stream_id].create_snapshot() - elif stream_id in self.warm_storage: - return self.warm_storage[stream_id][0].create_snapshot() - elif stream_id in self.cold_storage: - return self.cold_storage[stream_id][0].create_snapshot() - return None - - def get_cached_stream_ids(self) -> set[str]: - """获取所有缓存的流ID""" - return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys()) - - -# 全局缓存管理器实例 -_cache_manager: TieredStreamCache | None = None - - -def get_stream_cache_manager() -> TieredStreamCache: - """获取流缓存管理器实例""" - global _cache_manager - if _cache_manager is None: - _cache_manager = TieredStreamCache() - return _cache_manager - - -async def init_stream_cache_manager(): - """初始化流缓存管理器""" - manager = get_stream_cache_manager() - await manager.start() - - -async def shutdown_stream_cache_manager(): - """关闭流缓存管理器""" - manager = get_stream_cache_manager() - await manager.stop() diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 544dec94f..710a1872d 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -9,13 +9,12 @@ from maim_message import UserInfo from src.chat.antipromptinjector import initialize_anti_injector from src.chat.message_manager import message_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from src.chat.message_receive.message import MessageRecv, MessageRecvS4U from src.chat.message_receive.storage import MessageStorage -from src.chat.utils.prompt import create_prompt_async, global_prompt_manager +from src.chat.utils.prompt import global_prompt_manager from src.chat.utils.utils import is_mentioned_bot_in_message +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config -from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.mood.mood_manager import mood_manager # 导入情绪管理器 from src.plugin_system.base import BaseCommand, EventType from src.plugin_system.core import component_registry, event_manager, global_announcement_manager @@ -73,9 +72,6 @@ class ChatBot: self.bot = None # bot 实例引用 self._started = False self.mood_manager = mood_manager # 获取情绪管理器单例 - # 亲和力流消息处理器 - 直接使用全局afc_manager - - self.s4u_message_processor = S4UMessageProcessor() # 初始化反注入系统 self._initialize_anti_injector() @@ -109,10 +105,10 @@ class ChatBot: self._started = True - async def _process_plus_commands(self, message: MessageRecv): + async def _process_plus_commands(self, message: DatabaseMessages, chat: ChatStream): """独立处理PlusCommand系统""" try: - text = message.processed_plain_text + text = message.processed_plain_text or "" # 获取配置的命令前缀 from src.config.config import global_config @@ -170,10 +166,10 @@ class ChatBot: # 检查命令是否被禁用 if ( - message.chat_stream - and message.chat_stream.stream_id + chat + and chat.stream_id and plus_command_name - in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id) + in global_announcement_manager.get_disabled_chat_commands(chat.stream_id) ): logger.info("用户禁用的PlusCommand,跳过处理") return False, None, True @@ -186,10 +182,13 @@ class ChatBot: # 创建PlusCommand实例 plus_command_instance = plus_command_class(message, plugin_config) + # 为插件实例设置 chat_stream 运行时属性 + setattr(plus_command_instance, "chat_stream", chat) + try: # 检查聊天类型限制 if not plus_command_instance.is_chat_type_allowed(): - is_group = message.message_info.group_info + is_group = chat.group_info is not None logger.info( f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" ) @@ -229,11 +228,11 @@ class ChatBot: logger.error(f"处理PlusCommand时出错: {e}") return False, None, True # 出错时继续处理消息 - async def _process_commands_with_new_system(self, message: MessageRecv): + async def _process_commands_with_new_system(self, message: DatabaseMessages, chat: ChatStream): # sourcery skip: use-named-expression """使用新插件系统处理命令""" try: - text = message.processed_plain_text + text = message.processed_plain_text or "" # 使用新的组件注册中心查找命令 command_result = component_registry.find_command_by_text(text) @@ -242,10 +241,10 @@ class ChatBot: plugin_name = command_info.plugin_name command_name = command_info.name if ( - message.chat_stream - and message.chat_stream.stream_id + chat + and chat.stream_id and command_name - in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id) + in global_announcement_manager.get_disabled_chat_commands(chat.stream_id) ): logger.info("用户禁用的命令,跳过处理") return False, None, True @@ -259,10 +258,13 @@ class ChatBot: command_instance: BaseCommand = command_class(message, plugin_config) command_instance.set_matched_groups(matched_groups) + # 为插件实例设置 chat_stream 运行时属性 + setattr(command_instance, "chat_stream", chat) + try: # 检查聊天类型限制 if not command_instance.is_chat_type_allowed(): - is_group = message.message_info.group_info + is_group = chat.group_info is not None logger.info( f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" ) @@ -299,92 +301,6 @@ class ChatBot: logger.error(f"处理命令时出错: {e}") return False, None, True # 出错时继续处理消息 - async def handle_notice_message(self, message: MessageRecv): - """处理notice消息 - - notice消息是系统事件通知(如禁言、戳一戳等),具有以下特点: - 1. 默认不触发聊天流程,只记录 - 2. 可通过配置开启触发聊天流程 - 3. 会在提示词中展示 - """ - # 检查是否是notice消息 - if message.is_notify: - logger.info(f"收到notice消息: {message.notice_type}") - - # 根据配置决定是否触发聊天流程 - if not global_config.notice.enable_notice_trigger_chat: - logger.debug("notice消息不触发聊天流程(配置已关闭)") - return True # 返回True表示已处理,不继续后续流程 - else: - logger.debug("notice消息触发聊天流程(配置已开启)") - return False # 返回False表示继续处理,触发聊天流程 - - # 兼容旧的notice判断方式 - if message.message_info.message_id == "notice": - message.is_notify = True - logger.info("旧格式notice消息") - - # 同样根据配置决定 - if not global_config.notice.enable_notice_trigger_chat: - return True - else: - return False - - # 处理适配器响应消息 - if hasattr(message, "message_segment") and message.message_segment: - if message.message_segment.type == "adapter_response": - await self.handle_adapter_response(message) - return True - elif message.message_segment.type == "adapter_command": - # 适配器命令消息不需要进一步处理 - logger.debug("收到适配器命令消息,跳过后续处理") - return True - - return False - - async def handle_adapter_response(self, message: MessageRecv): - """处理适配器命令响应""" - try: - from src.plugin_system.apis.send_api import put_adapter_response - - seg_data = message.message_segment.data - if isinstance(seg_data, dict): - request_id = seg_data.get("request_id") - response_data = seg_data.get("response") - else: - request_id = None - response_data = None - - if request_id and response_data: - logger.debug(f"收到适配器响应: request_id={request_id}") - put_adapter_response(request_id, response_data) - else: - logger.warning("适配器响应消息格式不正确") - - except Exception as e: - logger.error(f"处理适配器响应时出错: {e}") - - async def do_s4u(self, message_data: dict[str, Any]): - message = MessageRecvS4U(message_data) - group_info = message.message_info.group_info - user_info = message.message_info.user_info - - get_chat_manager().register_message(message) - chat = await get_chat_manager().get_or_create_stream( - platform=message.message_info.platform, # type: ignore - user_info=user_info, # type: ignore - group_info=group_info, - ) - - message.update_chat_stream(chat) - - # 处理消息内容 - await message.process() - - await self.s4u_message_processor.process_message(message) - - return - async def message_process(self, message_data: dict[str, Any]) -> None: """处理转化后的统一格式消息""" try: @@ -406,9 +322,6 @@ class ChatBot: await self._ensure_started() # 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError - if not isinstance(message_data, dict): - logger.warning(f"收到无法解析的消息类型: {type(message_data)},已跳过") - return message_info = message_data.get("message_info") if not isinstance(message_info, dict): logger.debug( @@ -417,12 +330,6 @@ class ChatBot: ) return - platform = message_info.get("platform") - - if platform == "amaidesu_default": - await self.do_s4u(message_data) - return - if message_info.get("group_info") is not None: message_info["group_info"]["group_id"] = str( message_info["group_info"]["group_id"] @@ -433,156 +340,71 @@ class ChatBot: ) # print(message_data) # logger.debug(str(message_data)) - message = MessageRecv(message_data) - group_info = message.message_info.group_info - user_info = message.message_info.user_info - if message.message_info.additional_config: - sent_message = message.message_info.additional_config.get("echo", False) + # 先提取基础信息检查是否是自身消息上报 + from maim_message import BaseMessageInfo + temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {})) + if temp_message_info.additional_config: + sent_message = temp_message_info.additional_config.get("echo", False) if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题 - await MessageStorage.update_message(message) + # 直接使用消息字典更新,不再需要创建 MessageRecv + await MessageStorage.update_message(message_data) return - get_chat_manager().register_message(message) + group_info = temp_message_info.group_info + user_info = temp_message_info.user_info + # 获取或创建聊天流 chat = await get_chat_manager().get_or_create_stream( - platform=message.message_info.platform, # type: ignore + platform=temp_message_info.platform, # type: ignore user_info=user_info, # type: ignore group_info=group_info, ) - message.update_chat_stream(chat) + # 使用新的消息处理器直接生成 DatabaseMessages + from src.chat.message_receive.message_processor import process_message_from_dict + message = await process_message_from_dict( + message_dict=message_data, + stream_id=chat.stream_id, + platform=chat.platform + ) - # 处理消息内容,生成纯文本 - await message.process() + # 填充聊天流时间信息 + message.chat_info.create_time = chat.create_time + message.chat_info.last_active_time = chat.last_active_time + # 注册消息到聊天管理器 + get_chat_manager().register_message(message) + + # 检测是否提及机器人 message.is_mentioned, _ = is_mentioned_bot_in_message(message) # 在这里打印[所见]日志,确保在所有处理和过滤之前记录 chat_name = chat.group_info.group_name if chat.group_info else "私聊" - if message.message_info.user_info: - logger.info( - f"[{chat_name}]{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m" - ) + user_nickname = message.user_info.user_nickname if message.user_info else "未知用户" + logger.info( + f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m" + ) # 在此添加硬编码过滤,防止回复图片处理失败的消息 failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"] - if any(keyword in message.processed_plain_text for keyword in failure_keywords): - logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({message.processed_plain_text}),消息被静默处理。") - return - - # 处理notice消息 - notice_handled = await self.handle_notice_message(message) - if notice_handled: - # notice消息已处理,需要先添加到message_manager再存储 - try: - import time - - from src.common.data_models.database_data_model import DatabaseMessages - - message_info = message.message_info - msg_user_info = getattr(message_info, "user_info", None) - stream_user_info = getattr(message.chat_stream, "user_info", None) - group_info = getattr(message.chat_stream, "group_info", None) - - message_id = message_info.message_id or "" - message_time = message_info.time if message_info.time is not None else time.time() - - user_id = "" - user_nickname = "" - user_cardname = None - user_platform = "" - if msg_user_info: - user_id = str(getattr(msg_user_info, "user_id", "") or "") - user_nickname = getattr(msg_user_info, "user_nickname", "") or "" - user_cardname = getattr(msg_user_info, "user_cardname", None) - user_platform = getattr(msg_user_info, "platform", "") or "" - elif stream_user_info: - user_id = str(getattr(stream_user_info, "user_id", "") or "") - user_nickname = getattr(stream_user_info, "user_nickname", "") or "" - user_cardname = getattr(stream_user_info, "user_cardname", None) - user_platform = getattr(stream_user_info, "platform", "") or "" - - chat_user_id = str(getattr(stream_user_info, "user_id", "") or "") - chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or "" - chat_user_cardname = getattr(stream_user_info, "user_cardname", None) - chat_user_platform = getattr(stream_user_info, "platform", "") or "" - - group_id = getattr(group_info, "group_id", None) - group_name = getattr(group_info, "group_name", None) - group_platform = getattr(group_info, "platform", None) - - # 构建additional_config,确保包含is_notice标志 - import json - additional_config_dict = { - "is_notice": True, - "notice_type": message.notice_type or "unknown", - "is_public_notice": bool(message.is_public_notice), - } - - # 如果message_info有additional_config,合并进来 - if hasattr(message_info, "additional_config") and message_info.additional_config: - if isinstance(message_info.additional_config, dict): - additional_config_dict.update(message_info.additional_config) - elif isinstance(message_info.additional_config, str): - try: - existing_config = json.loads(message_info.additional_config) - additional_config_dict.update(existing_config) - except Exception: - pass - - additional_config_json = json.dumps(additional_config_dict) - - # 创建数据库消息对象 - db_message = DatabaseMessages( - message_id=message_id, - time=float(message_time), - chat_id=message.chat_stream.stream_id, - processed_plain_text=message.processed_plain_text, - display_message=message.processed_plain_text, - is_notify=bool(message.is_notify), - is_public_notice=bool(message.is_public_notice), - notice_type=message.notice_type, - additional_config=additional_config_json, - user_id=user_id, - user_nickname=user_nickname, - user_cardname=user_cardname, - user_platform=user_platform, - chat_info_stream_id=message.chat_stream.stream_id, - chat_info_platform=message.chat_stream.platform, - chat_info_create_time=float(message.chat_stream.create_time), - chat_info_last_active_time=float(message.chat_stream.last_active_time), - chat_info_user_id=chat_user_id, - chat_info_user_nickname=chat_user_nickname, - chat_info_user_cardname=chat_user_cardname, - chat_info_user_platform=chat_user_platform, - chat_info_group_id=group_id, - chat_info_group_name=group_name, - chat_info_group_platform=group_platform, - ) - - # 添加到message_manager(这会将notice添加到全局notice管理器) - await message_manager.add_message(message.chat_stream.stream_id, db_message) - logger.info(f"✅ Notice消息已添加到message_manager: type={message.notice_type}, stream={message.chat_stream.stream_id}") - - except Exception as e: - logger.error(f"Notice消息添加到message_manager失败: {e}", exc_info=True) - - # 存储后直接返回 - await MessageStorage.store_message(message, chat) - logger.debug("notice消息已存储,跳过后续处理") + processed_text = message.processed_plain_text or "" + if any(keyword in processed_text for keyword in failure_keywords): + logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。") return # 过滤检查 + # DatabaseMessages 使用 display_message 作为原始消息表示 + raw_text = message.display_message or message.processed_plain_text or "" if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore - message.raw_message, # type: ignore + raw_text, chat, user_info, # type: ignore ): return # 命令处理 - 首先尝试PlusCommand独立处理 - is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message) + is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat) # 如果是PlusCommand且不需要继续处理,则直接返回 if is_plus_command and not plus_continue_process: @@ -592,7 +414,7 @@ class ChatBot: # 如果不是PlusCommand,尝试传统的BaseCommand处理 if not is_plus_command: - is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message) + is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat) # 如果是命令且不需要继续处理,则直接返回 if is_command and not continue_process: @@ -604,138 +426,14 @@ class ChatBot: if result and not result.all_continue_process(): raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理") - # TODO:暂不可用 + # TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info # 确认从接口发来的message是否有自定义的prompt模板信息 - if message.message_info.template_info and not message.message_info.template_info.template_default: - template_group_name: str | None = message.message_info.template_info.template_name # type: ignore - template_items = message.message_info.template_info.template_items - async with global_prompt_manager.async_message_scope(template_group_name): - if isinstance(template_items, dict): - for k in template_items.keys(): - await create_prompt_async(template_items[k], k) - logger.debug(f"注册{template_items[k]},{k}") - else: - template_group_name = None + # 这个功能需要在 adapter 层通过 additional_config 传递 + template_group_name = None async def preprocess(): - import time - - from src.common.data_models.database_data_model import DatabaseMessages - - message_info = message.message_info - msg_user_info = getattr(message_info, "user_info", None) - stream_user_info = getattr(message.chat_stream, "user_info", None) - group_info = getattr(message.chat_stream, "group_info", None) - - message_id = message_info.message_id or "" - message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time() - is_mentioned = None - if isinstance(message.is_mentioned, bool): - is_mentioned = message.is_mentioned - elif isinstance(message.is_mentioned, int | float): - is_mentioned = message.is_mentioned != 0 - - user_id = "" - user_nickname = "" - user_cardname = None - user_platform = "" - if msg_user_info: - user_id = str(getattr(msg_user_info, "user_id", "") or "") - user_nickname = getattr(msg_user_info, "user_nickname", "") or "" - user_cardname = getattr(msg_user_info, "user_cardname", None) - user_platform = getattr(msg_user_info, "platform", "") or "" - elif stream_user_info: - user_id = str(getattr(stream_user_info, "user_id", "") or "") - user_nickname = getattr(stream_user_info, "user_nickname", "") or "" - user_cardname = getattr(stream_user_info, "user_cardname", None) - user_platform = getattr(stream_user_info, "platform", "") or "" - - chat_user_id = str(getattr(stream_user_info, "user_id", "") or "") - chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or "" - chat_user_cardname = getattr(stream_user_info, "user_cardname", None) - chat_user_platform = getattr(stream_user_info, "platform", "") or "" - - group_id = getattr(group_info, "group_id", None) - group_name = getattr(group_info, "group_name", None) - group_platform = getattr(group_info, "platform", None) - - # 准备 additional_config,将 format_info 嵌入其中 - additional_config_str = None - try: - import orjson - - additional_config_data = {} - - # 首先获取adapter传递的additional_config - if hasattr(message_info, 'additional_config') and message_info.additional_config: - if isinstance(message_info.additional_config, dict): - additional_config_data = message_info.additional_config.copy() - elif isinstance(message_info.additional_config, str): - try: - additional_config_data = orjson.loads(message_info.additional_config) - except Exception as e: - logger.warning(f"无法解析 additional_config JSON: {e}") - additional_config_data = {} - - # 然后添加format_info到additional_config中 - if hasattr(message_info, 'format_info') and message_info.format_info: - try: - format_info_dict = message_info.format_info.to_dict() - additional_config_data["format_info"] = format_info_dict - logger.debug(f"[bot.py] 嵌入 format_info 到 additional_config: {format_info_dict}") - except Exception as e: - logger.warning(f"将 format_info 转换为字典失败: {e}") - else: - logger.warning(f"[bot.py] [问题] 消息缺少 format_info: message_id={message_id}") - - # 序列化为JSON字符串 - if additional_config_data: - additional_config_str = orjson.dumps(additional_config_data).decode("utf-8") - except Exception as e: - logger.error(f"准备 additional_config 失败: {e}") - - # 创建数据库消息对象 - db_message = DatabaseMessages( - message_id=message_id, - time=float(message_time), - chat_id=message.chat_stream.stream_id, - processed_plain_text=message.processed_plain_text, - display_message=message.processed_plain_text, - is_mentioned=is_mentioned, - is_at=bool(message.is_at) if message.is_at is not None else None, - is_emoji=bool(message.is_emoji), - is_picid=bool(message.is_picid), - is_command=bool(message.is_command), - is_notify=bool(message.is_notify), - is_public_notice=bool(message.is_public_notice), - notice_type=message.notice_type, - additional_config=additional_config_str, - user_id=user_id, - user_nickname=user_nickname, - user_cardname=user_cardname, - user_platform=user_platform, - chat_info_stream_id=message.chat_stream.stream_id, - chat_info_platform=message.chat_stream.platform, - chat_info_create_time=float(message.chat_stream.create_time), - chat_info_last_active_time=float(message.chat_stream.last_active_time), - chat_info_user_id=chat_user_id, - chat_info_user_nickname=chat_user_nickname, - chat_info_user_cardname=chat_user_cardname, - chat_info_user_platform=chat_user_platform, - chat_info_group_id=group_id, - chat_info_group_name=group_name, - chat_info_group_platform=group_platform, - ) - - # 兼容历史逻辑:显式设置群聊相关属性,便于后续逻辑通过 hasattr 判断 - if group_info: - setattr(db_message, "chat_info_group_id", group_id) - setattr(db_message, "chat_info_group_name", group_name) - setattr(db_message, "chat_info_group_platform", group_platform) - else: - setattr(db_message, "chat_info_group_id", None) - setattr(db_message, "chat_info_group_name", None) - setattr(db_message, "chat_info_group_platform", None) + # message 已经是 DatabaseMessages,直接使用 + group_info = chat.group_info # 先交给消息管理器处理,计算兴趣度等衍生数据 try: @@ -752,31 +450,15 @@ class ChatBot: should_process_in_manager = False if should_process_in_manager: - await message_manager.add_message(message.chat_stream.stream_id, db_message) - logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}") + await message_manager.add_message(chat.stream_id, message) + logger.debug(f"消息已添加到消息管理器: {chat.stream_id}") except Exception as e: logger.error(f"消息添加到消息管理器失败: {e}") - # 将兴趣度结果同步回原始消息,便于后续流程使用 - message.interest_value = getattr(db_message, "interest_value", getattr(message, "interest_value", 0.0)) - setattr( - message, - "should_reply", - getattr(db_message, "should_reply", getattr(message, "should_reply", False)), - ) - setattr(message, "should_act", getattr(db_message, "should_act", getattr(message, "should_act", False))) - # 存储消息到数据库,只进行一次写入 try: - await MessageStorage.store_message(message, message.chat_stream) - logger.debug( - "消息已存储到数据库: %s (interest=%.3f, should_reply=%s, should_act=%s)", - message.message_info.message_id, - getattr(message, "interest_value", -1.0), - getattr(message, "should_reply", None), - getattr(message, "should_act", None), - ) + await MessageStorage.store_message(message, chat) except Exception as e: logger.error(f"存储消息到数据库失败: {e}") traceback.print_exc() @@ -785,13 +467,13 @@ class ChatBot: try: if global_config.mood.enable_mood: # 获取兴趣度用于情绪更新 - interest_rate = getattr(message, "interest_value", 0.0) + interest_rate = message.interest_value if interest_rate is None: interest_rate = 0.0 logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}") # 获取当前聊天的情绪对象并更新情绪状态 - chat_mood = mood_manager.get_mood_by_chat_id(message.chat_stream.stream_id) + chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id) await chat_mood.update_mood_by_message(message, interest_rate) logger.debug("情绪状态更新完成") except Exception as e: diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index c22d755fb..049d0fda1 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -1,8 +1,6 @@ import asyncio -import copy import hashlib import time -from typing import TYPE_CHECKING from maim_message import GroupInfo, UserInfo from rich.traceback import install @@ -10,16 +8,12 @@ from sqlalchemy import select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert +from src.common.data_models.database_data_model import DatabaseMessages from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import ChatStreams # 新增导入 from src.common.logger import get_logger from src.config.config import global_config # 新增导入 -# 避免循环导入,使用TYPE_CHECKING进行类型提示 -if TYPE_CHECKING: - from .message import MessageRecv - - install(extra_lines=3) @@ -33,7 +27,7 @@ class ChatStream: self, stream_id: str, platform: str, - user_info: UserInfo, + user_info: UserInfo | None = None, group_info: GroupInfo | None = None, data: dict | None = None, ): @@ -46,20 +40,18 @@ class ChatStream: self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0 self.saved = False - # 使用StreamContext替代ChatMessageContext + # 创建单流上下文管理器(包含StreamContext) + from src.chat.message_manager.context_manager import SingleStreamContextManager from src.common.data_models.message_manager_data_model import StreamContext from src.plugin_system.base.component_types import ChatMode, ChatType - # 创建StreamContext - self.stream_context: StreamContext = StreamContext( - stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL - ) - - # 创建单流上下文管理器 - from src.chat.message_manager.context_manager import SingleStreamContextManager - self.context_manager: SingleStreamContextManager = SingleStreamContextManager( - stream_id=stream_id, context=self.stream_context + stream_id=stream_id, + context=StreamContext( + stream_id=stream_id, + chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, + chat_mode=ChatMode.NORMAL, + ), ) # 基础参数 @@ -67,37 +59,6 @@ class ChatStream: self._focus_energy = 0.5 # 内部存储的focus_energy值 self.no_reply_consecutive = 0 - def __deepcopy__(self, memo): - """自定义深拷贝方法,避免复制不可序列化的 asyncio.Task 对象""" - import copy - - # 创建新的实例 - new_stream = ChatStream( - stream_id=self.stream_id, - platform=self.platform, - user_info=copy.deepcopy(self.user_info, memo), - group_info=copy.deepcopy(self.group_info, memo), - ) - - # 复制基本属性 - new_stream.create_time = self.create_time - new_stream.last_active_time = self.last_active_time - new_stream.sleep_pressure = self.sleep_pressure - new_stream.saved = self.saved - new_stream.base_interest_energy = self.base_interest_energy - new_stream._focus_energy = self._focus_energy - new_stream.no_reply_consecutive = self.no_reply_consecutive - - # 复制 stream_context,但跳过 processing_task - new_stream.stream_context = copy.deepcopy(self.stream_context, memo) - if hasattr(new_stream.stream_context, "processing_task"): - new_stream.stream_context.processing_task = None - - # 复制 context_manager - new_stream.context_manager = copy.deepcopy(self.context_manager, memo) - - return new_stream - def to_dict(self) -> dict: """转换为字典格式""" return { @@ -111,11 +72,11 @@ class ChatStream: "focus_energy": self.focus_energy, # 基础兴趣度 "base_interest_energy": self.base_interest_energy, - # stream_context基本信息 - "stream_context_chat_type": self.stream_context.chat_type.value, - "stream_context_chat_mode": self.stream_context.chat_mode.value, + # stream_context基本信息(通过context_manager访问) + "stream_context_chat_type": self.context_manager.context.chat_type.value, + "stream_context_chat_mode": self.context_manager.context.chat_mode.value, # 统计信息 - "interruption_count": self.stream_context.interruption_count, + "interruption_count": self.context_manager.context.interruption_count, } @classmethod @@ -132,27 +93,19 @@ class ChatStream: data=data, ) - # 恢复stream_context信息 + # 恢复stream_context信息(通过context_manager访问) if "stream_context_chat_type" in data: from src.plugin_system.base.component_types import ChatMode, ChatType - instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"]) + instance.context_manager.context.chat_type = ChatType(data["stream_context_chat_type"]) if "stream_context_chat_mode" in data: from src.plugin_system.base.component_types import ChatMode, ChatType - instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"]) + instance.context_manager.context.chat_mode = ChatMode(data["stream_context_chat_mode"]) # 恢复interruption_count信息 if "interruption_count" in data: - instance.stream_context.interruption_count = data["interruption_count"] - - # 确保 context_manager 已初始化 - if not hasattr(instance, "context_manager"): - from src.chat.message_manager.context_manager import SingleStreamContextManager - - instance.context_manager = SingleStreamContextManager( - stream_id=instance.stream_id, context=instance.stream_context - ) + instance.context_manager.context.interruption_count = data["interruption_count"] return instance @@ -160,159 +113,47 @@ class ChatStream: """获取原始的、未哈希的聊天流ID字符串""" if self.group_info: return f"{self.platform}:{self.group_info.group_id}:group" - else: + elif self.user_info: return f"{self.platform}:{self.user_info.user_id}:private" + else: + return f"{self.platform}:unknown:private" def update_active_time(self): """更新最后活跃时间""" self.last_active_time = time.time() self.saved = False - async def set_context(self, message: "MessageRecv"): - """设置聊天消息上下文""" - # 将MessageRecv转换为DatabaseMessages并设置到stream_context - import json - - from src.common.data_models.database_data_model import DatabaseMessages - - # 安全获取message_info中的数据 - message_info = getattr(message, "message_info", {}) - user_info = getattr(message_info, "user_info", {}) - group_info = getattr(message_info, "group_info", {}) - - # 提取reply_to信息(从message_segment中查找reply类型的段) - reply_to = None - if hasattr(message, "message_segment") and message.message_segment: - reply_to = self._extract_reply_from_segment(message.message_segment) - - # 完整的数据转移逻辑 - db_message = DatabaseMessages( - # 基础消息信息 - message_id=getattr(message, "message_id", ""), - time=getattr(message, "time", time.time()), - chat_id=self._generate_chat_id(message_info), - reply_to=reply_to, - # 兴趣度相关 - interest_value=getattr(message, "interest_value", 0.0), - # 关键词 - key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False) - if getattr(message, "key_words", None) - else None, - key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False) - if getattr(message, "key_words_lite", None) - else None, - # 消息状态标记 - is_mentioned=getattr(message, "is_mentioned", None), - is_at=getattr(message, "is_at", False), - is_emoji=getattr(message, "is_emoji", False), - is_picid=getattr(message, "is_picid", False), - is_voice=getattr(message, "is_voice", False), - is_video=getattr(message, "is_video", False), - is_command=getattr(message, "is_command", False), - is_notify=getattr(message, "is_notify", False), - is_public_notice=getattr(message, "is_public_notice", False), - notice_type=getattr(message, "notice_type", None), - # 消息内容 - processed_plain_text=getattr(message, "processed_plain_text", ""), - display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text - # 优先级信息 - priority_mode=getattr(message, "priority_mode", None), - priority_info=json.dumps(getattr(message, "priority_info", None)) - if getattr(message, "priority_info", None) - else None, - # 额外配置 - 需要将 format_info 嵌入到 additional_config 中 - additional_config=self._prepare_additional_config(message_info), - # 用户信息 - user_id=str(getattr(user_info, "user_id", "")), - user_nickname=getattr(user_info, "user_nickname", ""), - user_cardname=getattr(user_info, "user_cardname", None), - user_platform=getattr(user_info, "platform", ""), - # 群组信息 - chat_info_group_id=getattr(group_info, "group_id", None), - chat_info_group_name=getattr(group_info, "group_name", None), - chat_info_group_platform=getattr(group_info, "platform", None), - # 聊天流信息 - chat_info_user_id=str(getattr(user_info, "user_id", "")), - chat_info_user_nickname=getattr(user_info, "user_nickname", ""), - chat_info_user_cardname=getattr(user_info, "user_cardname", None), - chat_info_user_platform=getattr(user_info, "platform", ""), - chat_info_stream_id=self.stream_id, - chat_info_platform=self.platform, - chat_info_create_time=self.create_time, - chat_info_last_active_time=self.last_active_time, - # 新增兴趣度系统字段 - 添加安全处理 - actions=self._safe_get_actions(message), - should_reply=getattr(message, "should_reply", False), - should_act=getattr(message, "should_act", False), - ) - - self.stream_context.set_current_message(db_message) - self.stream_context.priority_mode = getattr(message, "priority_mode", None) - self.stream_context.priority_info = getattr(message, "priority_info", None) - - # 调试日志:记录数据转移情况 - logger.debug( - f"消息数据转移完成 - message_id: {db_message.message_id}, " - f"chat_id: {db_message.chat_id}, " - f"is_mentioned: {db_message.is_mentioned}, " - f"is_emoji: {db_message.is_emoji}, " - f"is_picid: {db_message.is_picid}, " - f"interest_value: {db_message.interest_value}" - ) - - def _prepare_additional_config(self, message_info) -> str | None: - """ - 准备 additional_config,将 format_info 嵌入其中 - - 这个方法模仿 storage.py 中的逻辑,确保 DatabaseMessages 中的 additional_config - 包含 format_info,使得 action_modifier 能够正确获取适配器支持的消息类型 - + async def set_context(self, message: DatabaseMessages): + """设置聊天消息上下文 + Args: - message_info: BaseMessageInfo 对象 - - Returns: - str | None: JSON 字符串格式的 additional_config,如果为空则返回 None + message: DatabaseMessages 对象,直接使用不需要转换 """ - import orjson + # 直接使用传入的 DatabaseMessages,设置到上下文中 + self.context_manager.context.set_current_message(message) - # 首先获取adapter传递的additional_config - additional_config_data = {} - if hasattr(message_info, 'additional_config') and message_info.additional_config: - if isinstance(message_info.additional_config, dict): - additional_config_data = message_info.additional_config.copy() - elif isinstance(message_info.additional_config, str): - # 如果是字符串,尝试解析 - try: - additional_config_data = orjson.loads(message_info.additional_config) - except Exception as e: - logger.warning(f"无法解析 additional_config JSON: {e}") - additional_config_data = {} - - # 然后添加format_info到additional_config中 - if hasattr(message_info, 'format_info') and message_info.format_info: - try: - format_info_dict = message_info.format_info.to_dict() - additional_config_data["format_info"] = format_info_dict - logger.debug(f"嵌入 format_info 到 additional_config: {format_info_dict}") - except Exception as e: - logger.warning(f"将 format_info 转换为字典失败: {e}") - else: - logger.warning(f"[问题] 消息缺少 format_info: message_id={getattr(message_info, 'message_id', 'unknown')}") - logger.warning("[问题] 这可能导致 Action 无法正确检查适配器支持的类型") - - # 序列化为JSON字符串 - if additional_config_data: - try: - return orjson.dumps(additional_config_data).decode("utf-8") - except Exception as e: - logger.error(f"序列化 additional_config 失败: {e}") - return None - return None + # 设置优先级信息(如果存在) + priority_mode = getattr(message, "priority_mode", None) + priority_info = getattr(message, "priority_info", None) + if priority_mode: + self.context_manager.context.priority_mode = priority_mode + if priority_info: + self.context_manager.context.priority_info = priority_info - def _safe_get_actions(self, message: "MessageRecv") -> list | None: + # 调试日志 + logger.debug( + f"消息上下文已设置 - message_id: {message.message_id}, " + f"chat_id: {message.chat_id}, " + f"is_mentioned: {message.is_mentioned}, " + f"is_emoji: {message.is_emoji}, " + f"is_picid: {message.is_picid}, " + f"interest_value: {message.interest_value}" + ) + + def _safe_get_actions(self, message: DatabaseMessages) -> list | None: """安全获取消息的actions字段""" import json - + try: actions = getattr(message, "actions", None) if actions is None: @@ -380,23 +221,6 @@ class ChatStream: if hasattr(db_message, "should_act"): db_message.should_act = False - def _extract_reply_from_segment(self, segment) -> str | None: - """从消息段中提取reply_to信息""" - try: - if hasattr(segment, "type") and segment.type == "seglist": - # 递归搜索seglist中的reply段 - if hasattr(segment, "data") and segment.data: - for seg in segment.data: - reply_id = self._extract_reply_from_segment(seg) - if reply_id: - return reply_id - elif hasattr(segment, "type") and segment.type == "reply": - # 找到reply段,返回message_id - return str(segment.data) if segment.data else None - except Exception as e: - logger.warning(f"提取reply_to信息失败: {e}") - return None - def _generate_chat_id(self, message_info) -> str: """生成chat_id,基于群组或用户信息""" try: @@ -493,8 +317,10 @@ class ChatManager: def __init__(self): if not self._initialized: + from src.common.data_models.database_data_model import DatabaseMessages + self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream - self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message + self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message # try: # async with get_db_session() as session: # db.connect(reuse_if_open=True) @@ -528,12 +354,30 @@ class ChatManager: except Exception as e: logger.error(f"聊天流自动保存失败: {e!s}") - def register_message(self, message: "MessageRecv"): + def register_message(self, message: DatabaseMessages): """注册消息到聊天流""" + # 从 DatabaseMessages 提取平台和用户/群组信息 + from maim_message import GroupInfo, UserInfo + + user_info = UserInfo( + platform=message.user_info.platform, + user_id=message.user_info.user_id, + user_nickname=message.user_info.user_nickname, + user_cardname=message.user_info.user_cardname or "" + ) + + group_info = None + if message.group_info: + group_info = GroupInfo( + platform=message.group_info.group_platform or "", + group_id=message.group_info.group_id, + group_name=message.group_info.group_name + ) + stream_id = self._generate_stream_id( - message.message_info.platform, # type: ignore - message.message_info.user_info, - message.message_info.group_info, + message.chat_info.platform, + user_info, + group_info, ) self.last_messages[stream_id] = message # logger.debug(f"注册消息到聊天流: {stream_id}") @@ -578,49 +422,23 @@ class ChatManager: try: stream_id = self._generate_stream_id(platform, user_info, group_info) - # 优先使用缓存管理器(优化版本) - try: - from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager - - cache_manager = get_stream_cache_manager() - - if cache_manager.is_running: - optimized_stream = await cache_manager.get_or_create_stream( - stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info - ) - - # 设置消息上下文 - from .message import MessageRecv - - if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): - optimized_stream.set_context(self.last_messages[stream_id]) - - # 转换为原始ChatStream以保持兼容性 - original_stream = self._convert_to_original_stream(optimized_stream) - - return original_stream - - except Exception as e: - logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}") - - # 回退到原始方法 # 检查内存中是否存在 if stream_id in self.streams: stream = self.streams[stream_id] # 更新用户信息和群组信息 stream.update_active_time() - stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存 if user_info.platform and user_info.user_id: stream.user_info = user_info if group_info: stream.group_info = group_info - from .message import MessageRecv # 延迟导入,避免循环引用 - if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): + # 检查是否有最后一条消息(现在使用 DatabaseMessages) + from src.common.data_models.database_data_model import DatabaseMessages + if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages): await stream.set_context(self.last_messages[stream_id]) else: - logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") + logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息") return stream # 检查数据库中是否存在 @@ -678,20 +496,30 @@ class ChatManager: logger.error(f"获取或创建聊天流失败: {e}", exc_info=True) raise e - stream = copy.deepcopy(stream) - from .message import MessageRecv # 延迟导入,避免循环引用 + from src.common.data_models.database_data_model import DatabaseMessages - if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): + if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages): await stream.set_context(self.last_messages[stream_id]) else: - logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") + logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") # 确保 ChatStream 有自己的 context_manager - if not hasattr(stream, "context_manager"): - # 创建新的单流上下文管理器 + if not hasattr(stream, "context_manager") or stream.context_manager is None: from src.chat.message_manager.context_manager import SingleStreamContextManager + from src.common.data_models.message_manager_data_model import StreamContext + from src.plugin_system.base.component_types import ChatMode, ChatType - stream.context_manager = SingleStreamContextManager(stream_id=stream_id, context=stream.stream_context) + logger.info(f"为 stream {stream_id} 创建新的 context_manager") + stream.context_manager = SingleStreamContextManager( + stream_id=stream_id, + context=StreamContext( + stream_id=stream_id, + chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE, + chat_mode=ChatMode.NORMAL, + ), + ) + else: + logger.info(f"stream {stream_id} 已有 context_manager,跳过创建") # 保存到内存和数据库 self.streams[stream_id] = stream @@ -700,10 +528,12 @@ class ChatManager: async def get_stream(self, stream_id: str) -> ChatStream | None: """通过stream_id获取聊天流""" + from src.common.data_models.database_data_model import DatabaseMessages + stream = self.streams.get(stream_id) if not stream: return None - if stream_id in self.last_messages: + if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages): await stream.set_context(self.last_messages[stream_id]) return stream @@ -919,12 +749,22 @@ class ChatManager: # await stream.set_context(self.last_messages[stream.stream_id]) # 确保 ChatStream 有自己的 context_manager - if not hasattr(stream, "context_manager"): + if not hasattr(stream, "context_manager") or stream.context_manager is None: from src.chat.message_manager.context_manager import SingleStreamContextManager + from src.common.data_models.message_manager_data_model import StreamContext + from src.plugin_system.base.component_types import ChatMode, ChatType + logger.debug(f"为加载的 stream {stream.stream_id} 创建新的 context_manager") stream.context_manager = SingleStreamContextManager( - stream_id=stream.stream_id, context=stream.stream_context + stream_id=stream.stream_id, + context=StreamContext( + stream_id=stream.stream_id, + chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE, + chat_mode=ChatMode.NORMAL, + ), ) + else: + logger.debug(f"加载的 stream {stream.stream_id} 已有 context_manager") except Exception as e: logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True) @@ -932,46 +772,6 @@ class ChatManager: chat_manager = None -def _convert_to_original_stream(self, optimized_stream) -> "ChatStream": - """将OptimizedChatStream转换为原始ChatStream以保持兼容性""" - try: - # 创建原始ChatStream实例 - original_stream = ChatStream( - stream_id=optimized_stream.stream_id, - platform=optimized_stream.platform, - user_info=optimized_stream._get_effective_user_info(), - group_info=optimized_stream._get_effective_group_info(), - ) - - # 复制状态 - original_stream.create_time = optimized_stream.create_time - original_stream.last_active_time = optimized_stream.last_active_time - original_stream.sleep_pressure = optimized_stream.sleep_pressure - original_stream.base_interest_energy = optimized_stream.base_interest_energy - original_stream._focus_energy = optimized_stream._focus_energy - original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive - original_stream.saved = optimized_stream.saved - - # 复制上下文信息(如果存在) - if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context: - original_stream.stream_context = optimized_stream._stream_context - - if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager: - original_stream.context_manager = optimized_stream._context_manager - - return original_stream - - except Exception as e: - logger.error(f"转换OptimizedChatStream失败: {e}") - # 如果转换失败,创建一个新的原始流 - return ChatStream( - stream_id=optimized_stream.stream_id, - platform=optimized_stream.platform, - user_info=optimized_stream._get_effective_user_info(), - group_info=optimized_stream._get_effective_group_info(), - ) - - def get_chat_manager(): global chat_manager if chat_manager is None: diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 98b12d694..68fc4f1bf 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -1,8 +1,7 @@ -import base64 import time from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Any, Optional +from typing import Optional import urllib3 from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo @@ -11,8 +10,8 @@ from rich.traceback import install from src.chat.message_receive.chat_stream import ChatStream from src.chat.utils.self_voice_cache import consume_self_voice_text from src.chat.utils.utils_image import get_image_manager -from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available from src.chat.utils.utils_voice import get_voice_text +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config @@ -43,7 +42,7 @@ class Message(MessageBase, metaclass=ABCMeta): user_info: UserInfo, message_segment: Seg | None = None, timestamp: float | None = None, - reply: Optional["MessageRecv"] = None, + reply: Optional["DatabaseMessages"] = None, processed_plain_text: str = "", ): # 使用传入的时间戳或当前时间 @@ -95,418 +94,12 @@ class Message(MessageBase, metaclass=ABCMeta): @dataclass -class MessageRecv(Message): - """接收消息类,用于处理从MessageCQ序列化的消息""" - def __init__(self, message_dict: dict[str, Any]): - """从MessageCQ的字典初始化 - - Args: - message_dict: MessageCQ序列化后的字典 - """ - # Manually initialize attributes from MessageBase and Message - self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) - self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) - self.raw_message = message_dict.get("raw_message") - - self.chat_stream = None - self.reply = None - self.processed_plain_text = message_dict.get("processed_plain_text", "") - self.memorized_times = 0 - - # MessageRecv specific attributes - self.is_emoji = False - self.has_emoji = False - self.is_picid = False - self.has_picid = False - self.is_voice = False - self.is_video = False - self.is_mentioned = None - self.is_notify = False # 是否为notice消息 - self.is_public_notice = False # 是否为公共notice - self.notice_type = None # notice类型 - self.is_at = False - self.is_command = False - - self.priority_mode = "interest" - self.priority_info = None - self.interest_value: float = 0.0 - - self.key_words = [] - self.key_words_lite = [] - - # 解析additional_config中的notice信息 - if self.message_info.additional_config and isinstance(self.message_info.additional_config, dict): - self.is_notify = self.message_info.additional_config.get("is_notice", False) - self.is_public_notice = self.message_info.additional_config.get("is_public_notice", False) - self.notice_type = self.message_info.additional_config.get("notice_type") - - def update_chat_stream(self, chat_stream: "ChatStream"): - self.chat_stream = chat_stream - - async def process(self) -> None: - """处理消息内容,生成纯文本和详细文本 - - 这个方法必须在创建实例后显式调用,因为它包含异步操作。 - """ - self.processed_plain_text = await self._process_message_segments(self.message_segment) - - async def _process_single_segment(self, segment: Seg) -> str: - """处理单个消息段 - - Args: - segment: 消息段 - - Returns: - str: 处理后的文本 - """ - try: - if segment.type == "text": - self.is_picid = False - self.is_emoji = False - self.is_video = False - return segment.data # type: ignore - elif segment.type == "at": - self.is_picid = False - self.is_emoji = False - self.is_video = False - # 处理at消息,格式为"昵称:QQ号" - if isinstance(segment.data, str) and ":" in segment.data: - nickname, qq_id = segment.data.split(":", 1) - return f"@{nickname}" - return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户" - elif segment.type == "image": - # 如果是base64图片数据 - if isinstance(segment.data, str): - self.has_picid = True - self.is_picid = True - self.is_emoji = False - self.is_video = False - image_manager = get_image_manager() - # print(f"segment.data: {segment.data}") - _, processed_text = await image_manager.process_image(segment.data) - return processed_text - return "[发了一张图片,网卡了加载不出来]" - elif segment.type == "emoji": - self.has_emoji = True - self.is_emoji = True - self.is_picid = False - self.is_voice = False - self.is_video = False - if isinstance(segment.data, str): - return await get_image_manager().get_emoji_description(segment.data) - return "[发了一个表情包,网卡了加载不出来]" - elif segment.type == "voice": - self.is_picid = False - self.is_emoji = False - self.is_voice = True - self.is_video = False - - # 检查消息是否由机器人自己发送 - if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account): - logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。") - if isinstance(segment.data, str): - cached_text = consume_self_voice_text(segment.data) - if cached_text: - logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'") - return f"[语音:{cached_text}]" - else: - logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。") - - # 标准语音识别流程 (也作为缓存未命中的后备方案) - if isinstance(segment.data, str): - return await get_voice_text(segment.data) - return "[发了一段语音,网卡了加载不出来]" - elif segment.type == "mention_bot": - self.is_picid = False - self.is_emoji = False - self.is_voice = False - self.is_video = False - self.is_mentioned = float(segment.data) # type: ignore - return "" - elif segment.type == "priority_info": - self.is_picid = False - self.is_emoji = False - self.is_voice = False - if isinstance(segment.data, dict): - # 处理优先级信息 - self.priority_mode = "priority" - self.priority_info = segment.data - """ - { - 'message_type': 'vip', # vip or normal - 'message_priority': 1.0, # 优先级,大为优先,float - } - """ - return "" - elif segment.type == "file": - if isinstance(segment.data, dict): - file_name = segment.data.get('name', '未知文件') - file_size = segment.data.get('size', '未知大小') - return f"[文件:{file_name} ({file_size}字节)]" - return "[收到一个文件]" - elif segment.type == "video": - self.is_picid = False - self.is_emoji = False - self.is_voice = False - self.is_video = True - logger.info(f"接收到视频消息,数据类型: {type(segment.data)}") - - # 检查视频分析功能是否可用 - if not is_video_analysis_available(): - logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析") - return "[视频]" - - if global_config.video_analysis.enable: - logger.info("已启用视频识别,开始识别") - if isinstance(segment.data, dict): - try: - # 从Adapter接收的视频数据 - video_base64 = segment.data.get("base64") - filename = segment.data.get("filename", "video.mp4") - - logger.info(f"视频文件名: {filename}") - logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}") - - if video_base64: - # 解码base64视频数据 - video_bytes = base64.b64decode(video_base64) - logger.info(f"解码后视频大小: {len(video_bytes)} 字节") - - # 使用video analyzer分析视频 - video_analyzer = get_video_analyzer() - result = await video_analyzer.analyze_video_from_bytes( - video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt - ) - - logger.info(f"视频分析结果: {result}") - - # 返回视频分析结果 - summary = result.get("summary", "") - if summary: - return f"[视频内容] {summary}" - else: - return "[已收到视频,但分析失败]" - else: - logger.warning("视频消息中没有base64数据") - return "[收到视频消息,但数据异常]" - except Exception as e: - logger.error(f"视频处理失败: {e!s}") - import traceback - - logger.error(f"错误详情: {traceback.format_exc()}") - return "[收到视频,但处理时出现错误]" - else: - logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}") - return "[发了一个视频,但格式不支持]" - else: - return "" - else: - logger.warning(f"未知的消息段类型: {segment.type}") - return f"[{segment.type} 消息]" - except Exception as e: - logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}") - return f"[处理失败的{segment.type}消息]" - - -@dataclass -class MessageRecvS4U(MessageRecv): - def __init__(self, message_dict: dict[str, Any]): - super().__init__(message_dict) - self.is_gift = False - self.is_fake_gift = False - self.is_superchat = False - self.gift_info = None - self.gift_name = None - self.gift_count: int | None = None - self.superchat_info = None - self.superchat_price = None - self.superchat_message_text = None - self.is_screen = False - self.is_internal = False - self.voice_done = None - - self.chat_info = None - - async def process(self) -> None: - self.processed_plain_text = await self._process_message_segments(self.message_segment) - - async def _process_single_segment(self, segment: Seg) -> str: - """处理单个消息段 - - Args: - segment: 消息段 - - Returns: - str: 处理后的文本 - """ - try: - if segment.type == "text": - self.is_voice = False - self.is_picid = False - self.is_emoji = False - return segment.data # type: ignore - elif segment.type == "image": - self.is_voice = False - # 如果是base64图片数据 - if isinstance(segment.data, str): - self.has_picid = True - self.is_picid = True - self.is_emoji = False - image_manager = get_image_manager() - # print(f"segment.data: {segment.data}") - _, processed_text = await image_manager.process_image(segment.data) - return processed_text - return "[发了一张图片,网卡了加载不出来]" - elif segment.type == "emoji": - self.has_emoji = True - self.is_emoji = True - self.is_picid = False - if isinstance(segment.data, str): - return await get_image_manager().get_emoji_description(segment.data) - return "[发了一个表情包,网卡了加载不出来]" - elif segment.type == "voice": - self.has_picid = False - self.is_picid = False - self.is_emoji = False - self.is_voice = True - - # 检查消息是否由机器人自己发送 - # 检查消息是否由机器人自己发送 - if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account): - logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。") - if isinstance(segment.data, str): - cached_text = consume_self_voice_text(segment.data) - if cached_text: - logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'") - return f"[语音:{cached_text}]" - else: - logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。") - - # 标准语音识别流程 (也作为缓存未命中的后备方案) - if isinstance(segment.data, str): - return await get_voice_text(segment.data) - return "[发了一段语音,网卡了加载不出来]" - elif segment.type == "mention_bot": - self.is_voice = False - self.is_picid = False - self.is_emoji = False - self.is_mentioned = float(segment.data) # type: ignore - return "" - elif segment.type == "priority_info": - self.is_voice = False - self.is_picid = False - self.is_emoji = False - if isinstance(segment.data, dict): - # 处理优先级信息 - self.priority_mode = "priority" - self.priority_info = segment.data - """ - { - 'message_type': 'vip', # vip or normal - 'message_priority': 1.0, # 优先级,大为优先,float - } - """ - return "" - elif segment.type == "gift": - self.is_voice = False - self.is_gift = True - # 解析gift_info,格式为"名称:数量" - name, count = segment.data.split(":", 1) # type: ignore - self.gift_info = segment.data - self.gift_name = name.strip() - self.gift_count = int(count.strip()) - return "" - elif segment.type == "voice_done": - msg_id = segment.data - logger.info(f"voice_done: {msg_id}") - self.voice_done = msg_id - return "" - elif segment.type == "superchat": - self.is_superchat = True - self.superchat_info = segment.data - price, message_text = segment.data.split(":", 1) # type: ignore - self.superchat_price = price.strip() - self.superchat_message_text = message_text.strip() - - self.processed_plain_text = str(self.superchat_message_text) - self.processed_plain_text += ( - f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)" - ) - - return self.processed_plain_text - elif segment.type == "screen": - self.is_screen = True - self.screen_info = segment.data - return "屏幕信息" - elif segment.type == "file": - if isinstance(segment.data, dict): - file_name = segment.data.get('name', '未知文件') - file_size = segment.data.get('size', '未知大小') - return f"[文件:{file_name} ({file_size}字节)]" - return "[收到一个文件]" - elif segment.type == "video": - self.is_voice = False - self.is_picid = False - self.is_emoji = False - - logger.info(f"接收到视频消息,数据类型: {type(segment.data)}") - - # 检查视频分析功能是否可用 - if not is_video_analysis_available(): - logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析") - return "[视频]" - - if global_config.video_analysis.enable: - logger.info("已启用视频识别,开始识别") - if isinstance(segment.data, dict): - try: - # 从Adapter接收的视频数据 - video_base64 = segment.data.get("base64") - filename = segment.data.get("filename", "video.mp4") - - logger.info(f"视频文件名: {filename}") - logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}") - - if video_base64: - # 解码base64视频数据 - video_bytes = base64.b64decode(video_base64) - logger.info(f"解码后视频大小: {len(video_bytes)} 字节") - - # 使用video analyzer分析视频 - video_analyzer = get_video_analyzer() - result = await video_analyzer.analyze_video_from_bytes( - video_bytes, filename - ) - - logger.info(f"视频分析结果: {result}") - - # 返回视频分析结果 - summary = result.get("summary", "") - if summary: - return f"[视频内容] {summary}" - else: - return "[已收到视频,但分析失败]" - else: - logger.warning("视频消息中没有base64数据") - return "[收到视频消息,但数据异常]" - except Exception as e: - logger.error(f"视频处理失败: {e!s}") - import traceback - - logger.error(f"错误详情: {traceback.format_exc()}") - return "[收到视频,但处理时出现错误]" - else: - logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}") - return "[发了一个视频,但格式不支持]" - else: - return "" - else: - logger.warning(f"未知的消息段类型: {segment.type}") - return f"[{segment.type} 消息]" - except Exception as e: - logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}") - return f"[处理失败的{segment.type}消息]" +# MessageRecv 类已被完全移除,现在统一使用 DatabaseMessages +# 如需从消息字典创建 DatabaseMessages,请使用: +# from src.chat.message_receive.message_processor import process_message_from_dict +# +# 迁移完成日期: 2025-10-31 @dataclass @@ -519,7 +112,7 @@ class MessageProcessBase(Message): chat_stream: "ChatStream", bot_user_info: UserInfo, message_segment: Seg | None = None, - reply: Optional["MessageRecv"] = None, + reply: Optional["DatabaseMessages"] = None, thinking_start_time: float = 0, timestamp: float | None = None, ): @@ -565,7 +158,7 @@ class MessageProcessBase(Message): return "[表情,网卡了加载不出来]" elif seg.type == "voice": # 检查消息是否由机器人自己发送 - # 检查消息是否由机器人自己发送 + # self.message_info 来自 MessageBase,指当前消息的信息 if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account): logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。") if isinstance(seg.data, str): @@ -587,10 +180,24 @@ class MessageProcessBase(Message): return f"@{nickname}" return f"@{seg.data}" if isinstance(seg.data, str) else "@未知用户" elif seg.type == "reply": - if self.reply and hasattr(self.reply, "processed_plain_text"): - # print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}") - # print(f"reply: {self.reply}") - return f"[回复<{self.reply.message_info.user_info.user_nickname}({self.reply.message_info.user_info.user_id})> 的消息:{self.reply.processed_plain_text}]" # type: ignore + # 处理回复消息段 + if self.reply: + # 检查 reply 对象是否有必要的属性 + if hasattr(self.reply, "processed_plain_text") and self.reply.processed_plain_text: + # DatabaseMessages 使用 user_info 而不是 message_info.user_info + user_nickname = self.reply.user_info.user_nickname if self.reply.user_info else "未知用户" + user_id = self.reply.user_info.user_id if self.reply.user_info else "" + return f"[回复<{user_nickname}({user_id})> 的消息:{self.reply.processed_plain_text}]" + else: + # reply 对象存在但没有 processed_plain_text,返回简化的回复标识 + logger.debug(f"reply 消息段没有 processed_plain_text 属性,message_id: {getattr(self.reply, 'message_id', 'unknown')}") + return "[回复消息]" + else: + # 没有 reply 对象,但有 reply 消息段(可能是机器人自己发送的消息) + # 这种情况下 seg.data 应该包含被回复消息的 message_id + if isinstance(seg.data, str): + logger.debug(f"处理 reply 消息段,但 self.reply 为 None,reply_to message_id: {seg.data}") + return f"[回复消息 {seg.data}]" return None else: return f"[{seg.type}:{seg.data!s}]" @@ -620,7 +227,7 @@ class MessageSending(MessageProcessBase): sender_info: UserInfo | None, # 用来记录发送者信息 message_segment: Seg, display_message: str = "", - reply: Optional["MessageRecv"] = None, + reply: Optional["DatabaseMessages"] = None, is_head: bool = False, is_emoji: bool = False, thinking_start_time: float = 0, @@ -639,7 +246,11 @@ class MessageSending(MessageProcessBase): # 发送状态特有属性 self.sender_info = sender_info - self.reply_to_message_id = reply.message_info.message_id if reply else None + # 从 DatabaseMessages 获取 message_id + if reply: + self.reply_to_message_id = reply.message_id + else: + self.reply_to_message_id = None self.is_head = is_head self.is_emoji = is_emoji self.apply_set_reply_logic = apply_set_reply_logic @@ -654,14 +265,18 @@ class MessageSending(MessageProcessBase): def build_reply(self): """设置回复消息""" if self.reply: - self.reply_to_message_id = self.reply.message_info.message_id - self.message_segment = Seg( - type="seglist", - data=[ - Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore - self.message_segment, - ], - ) + # 从 DatabaseMessages 获取 message_id + message_id = self.reply.message_id + + if message_id: + self.reply_to_message_id = message_id + self.message_segment = Seg( + type="seglist", + data=[ + Seg(type="reply", data=message_id), # type: ignore + self.message_segment, + ], + ) async def process(self) -> None: """处理消息内容,生成纯文本和详细文本""" @@ -679,103 +294,5 @@ class MessageSending(MessageProcessBase): return self.message_info.group_info is None or self.message_info.group_info.group_id is None -@dataclass -class MessageSet: - """消息集合类,可以存储多个发送消息""" - - def __init__(self, chat_stream: "ChatStream", message_id: str): - self.chat_stream = chat_stream - self.message_id = message_id - self.messages: list[MessageSending] = [] - self.time = round(time.time(), 3) # 保留3位小数 - - def add_message(self, message: MessageSending) -> None: - """添加消息到集合""" - if not isinstance(message, MessageSending): - raise TypeError("MessageSet只能添加MessageSending类型的消息") - self.messages.append(message) - self.messages.sort(key=lambda x: x.message_info.time) # type: ignore - - def get_message_by_index(self, index: int) -> MessageSending | None: - """通过索引获取消息""" - return self.messages[index] if 0 <= index < len(self.messages) else None - - def get_message_by_time(self, target_time: float) -> MessageSending | None: - """获取最接近指定时间的消息""" - if not self.messages: - return None - - left, right = 0, len(self.messages) - 1 - while left < right: - mid = (left + right) // 2 - if self.messages[mid].message_info.time < target_time: # type: ignore - left = mid + 1 - else: - right = mid - - return self.messages[left] - - def clear_messages(self) -> None: - """清空所有消息""" - self.messages.clear() - - def remove_message(self, message: MessageSending) -> bool: - """移除指定消息""" - if message in self.messages: - self.messages.remove(message) - return True - return False - - def __str__(self) -> str: - return f"MessageSet(id={self.message_id}, count={len(self.messages)})" - - def __len__(self) -> int: - return len(self.messages) - - -def message_recv_from_dict(message_dict: dict) -> MessageRecv: - return MessageRecv(message_dict) - - -def message_from_db_dict(db_dict: dict) -> MessageRecv: - """从数据库字典创建MessageRecv实例""" - # 转换扁平的数据库字典为嵌套结构 - message_info_dict = { - "platform": db_dict.get("chat_info_platform"), - "message_id": db_dict.get("message_id"), - "time": db_dict.get("time"), - "group_info": { - "platform": db_dict.get("chat_info_group_platform"), - "group_id": db_dict.get("chat_info_group_id"), - "group_name": db_dict.get("chat_info_group_name"), - }, - "user_info": { - "platform": db_dict.get("user_platform"), - "user_id": db_dict.get("user_id"), - "user_nickname": db_dict.get("user_nickname"), - "user_cardname": db_dict.get("user_cardname"), - }, - } - - processed_text = db_dict.get("processed_plain_text", "") - - # 构建 MessageRecv 需要的字典 - recv_dict = { - "message_info": message_info_dict, - "message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段 - "raw_message": None, # 数据库中未存储原始消息 - "processed_plain_text": processed_text, - } - - # 创建 MessageRecv 实例 - msg = MessageRecv(recv_dict) - - # 从数据库字典中填充其他可选字段 - msg.interest_value = db_dict.get("interest_value", 0.0) - msg.is_mentioned = db_dict.get("is_mentioned") - msg.priority_mode = db_dict.get("priority_mode", "interest") - msg.priority_info = db_dict.get("priority_info") - msg.is_emoji = db_dict.get("is_emoji", False) - msg.is_picid = db_dict.get("is_picid", False) - - return msg +# message_recv_from_dict 和 message_from_db_dict 函数已被移除 +# 请使用: from src.chat.message_receive.message_processor import process_message_from_dict diff --git a/src/chat/message_receive/message_processor.py b/src/chat/message_receive/message_processor.py new file mode 100644 index 000000000..10e7213de --- /dev/null +++ b/src/chat/message_receive/message_processor.py @@ -0,0 +1,489 @@ +"""消息处理工具模块 +将原 MessageRecv 的消息处理逻辑提取为独立函数, +直接从适配器消息字典生成 DatabaseMessages +""" +import base64 +import time +from typing import Any + +import orjson +from maim_message import BaseMessageInfo, Seg + +from src.chat.utils.self_voice_cache import consume_self_voice_text +from src.chat.utils.utils_image import get_image_manager +from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available +from src.chat.utils.utils_voice import get_voice_text +from src.common.data_models.database_data_model import DatabaseMessages +from src.common.logger import get_logger +from src.config.config import global_config + +logger = get_logger("message_processor") + + +async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str, platform: str) -> DatabaseMessages: + """从适配器消息字典处理并生成 DatabaseMessages + + 这个函数整合了原 MessageRecv 的所有处理逻辑: + 1. 解析 message_segment 并异步处理内容(图片、语音、视频等) + 2. 提取所有消息元数据 + 3. 直接构造 DatabaseMessages 对象 + + Args: + message_dict: MessageCQ序列化后的字典 + stream_id: 聊天流ID + platform: 平台标识 + + Returns: + DatabaseMessages: 处理完成的数据库消息对象 + """ + # 解析基础信息 + message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) + message_segment = Seg.from_dict(message_dict.get("message_segment", {})) + + # 初始化处理状态 + processing_state = { + "is_emoji": False, + "has_emoji": False, + "is_picid": False, + "has_picid": False, + "is_voice": False, + "is_video": False, + "is_mentioned": None, + "is_at": False, + "priority_mode": "interest", + "priority_info": None, + } + + # 异步处理消息段,生成纯文本 + processed_plain_text = await _process_message_segments(message_segment, processing_state, message_info) + + # 解析 notice 信息 + is_notify = False + is_public_notice = False + notice_type = None + if message_info.additional_config and isinstance(message_info.additional_config, dict): + is_notify = message_info.additional_config.get("is_notice", False) + is_public_notice = message_info.additional_config.get("is_public_notice", False) + notice_type = message_info.additional_config.get("notice_type") + + # 提取用户信息 + user_info = message_info.user_info + user_id = str(user_info.user_id) if user_info and user_info.user_id else "" + user_nickname = (user_info.user_nickname or "") if user_info else "" + user_cardname = user_info.user_cardname if user_info else None + user_platform = (user_info.platform or "") if user_info else "" + + # 提取群组信息 + group_info = message_info.group_info + group_id = group_info.group_id if group_info else None + group_name = group_info.group_name if group_info else None + group_platform = group_info.platform if group_info else None + + # chat_id 应该直接使用 stream_id(与数据库存储格式一致) + # stream_id 是通过 platform + user_id/group_id 的 SHA-256 哈希生成的 + chat_id = stream_id + + # 准备 additional_config + additional_config_str = _prepare_additional_config(message_info, is_notify, is_public_notice, notice_type) + + # 提取 reply_to + reply_to = _extract_reply_from_segment(message_segment) + + # 构造 DatabaseMessages + message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time() + message_id = message_info.message_id or "" + + # 处理 is_mentioned + is_mentioned = None + mentioned_value = processing_state.get("is_mentioned") + if isinstance(mentioned_value, bool): + is_mentioned = mentioned_value + elif isinstance(mentioned_value, (int, float)): + is_mentioned = mentioned_value != 0 + + db_message = DatabaseMessages( + message_id=message_id, + time=float(message_time), + chat_id=chat_id, + reply_to=reply_to, + processed_plain_text=processed_plain_text, + display_message=processed_plain_text, + is_mentioned=is_mentioned, + is_at=bool(processing_state.get("is_at", False)), + is_emoji=bool(processing_state.get("is_emoji", False)), + is_picid=bool(processing_state.get("is_picid", False)), + is_command=False, # 将在后续处理中设置 + is_notify=bool(is_notify), + is_public_notice=bool(is_public_notice), + notice_type=notice_type, + additional_config=additional_config_str, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + user_platform=user_platform, + chat_info_stream_id=stream_id, + chat_info_platform=platform, + chat_info_create_time=0.0, # 将由 ChatStream 填充 + chat_info_last_active_time=0.0, # 将由 ChatStream 填充 + chat_info_user_id=user_id, + chat_info_user_nickname=user_nickname, + chat_info_user_cardname=user_cardname, + chat_info_user_platform=user_platform, + chat_info_group_id=group_id, + chat_info_group_name=group_name, + chat_info_group_platform=group_platform, + ) + + # 设置优先级信息 + if processing_state.get("priority_mode"): + setattr(db_message, "priority_mode", processing_state["priority_mode"]) + if processing_state.get("priority_info"): + setattr(db_message, "priority_info", processing_state["priority_info"]) + + # 设置其他运行时属性 + setattr(db_message, "is_voice", bool(processing_state.get("is_voice", False))) + setattr(db_message, "is_video", bool(processing_state.get("is_video", False))) + setattr(db_message, "has_emoji", bool(processing_state.get("has_emoji", False))) + setattr(db_message, "has_picid", bool(processing_state.get("has_picid", False))) + + return db_message + + +async def _process_message_segments(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str: + """递归处理消息段,转换为文字描述 + + Args: + segment: 要处理的消息段 + state: 处理状态字典(用于记录消息类型标记) + message_info: 消息基础信息(用于某些处理逻辑) + + Returns: + str: 处理后的文本 + """ + if segment.type == "seglist": + # 处理消息段列表 + segments_text = [] + for seg in segment.data: + processed = await _process_message_segments(seg, state, message_info) + if processed: + segments_text.append(processed) + return " ".join(segments_text) + else: + # 处理单个消息段 + return await _process_single_segment(segment, state, message_info) + + +async def _process_single_segment(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str: + """处理单个消息段 + + Args: + segment: 消息段 + state: 处理状态字典 + message_info: 消息基础信息 + + Returns: + str: 处理后的文本 + """ + try: + if segment.type == "text": + state["is_picid"] = False + state["is_emoji"] = False + state["is_video"] = False + return segment.data + + elif segment.type == "at": + state["is_picid"] = False + state["is_emoji"] = False + state["is_video"] = False + state["is_at"] = True + # 处理at消息,格式为"昵称:QQ号" + if isinstance(segment.data, str) and ":" in segment.data: + nickname, qq_id = segment.data.split(":", 1) + return f"@{nickname}" + return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户" + + elif segment.type == "image": + # 如果是base64图片数据 + if isinstance(segment.data, str): + state["has_picid"] = True + state["is_picid"] = True + state["is_emoji"] = False + state["is_video"] = False + image_manager = get_image_manager() + _, processed_text = await image_manager.process_image(segment.data) + return processed_text + return "[发了一张图片,网卡了加载不出来]" + + elif segment.type == "emoji": + state["has_emoji"] = True + state["is_emoji"] = True + state["is_picid"] = False + state["is_voice"] = False + state["is_video"] = False + if isinstance(segment.data, str): + return await get_image_manager().get_emoji_description(segment.data) + return "[发了一个表情包,网卡了加载不出来]" + + elif segment.type == "voice": + state["is_picid"] = False + state["is_emoji"] = False + state["is_voice"] = True + state["is_video"] = False + + # 检查消息是否由机器人自己发送 + if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account): + logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。") + if isinstance(segment.data, str): + cached_text = consume_self_voice_text(segment.data) + if cached_text: + logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'") + return f"[语音:{cached_text}]" + else: + logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。") + + # 标准语音识别流程 + if isinstance(segment.data, str): + return await get_voice_text(segment.data) + return "[发了一段语音,网卡了加载不出来]" + + elif segment.type == "mention_bot": + state["is_picid"] = False + state["is_emoji"] = False + state["is_voice"] = False + state["is_video"] = False + state["is_mentioned"] = float(segment.data) + return "" + + elif segment.type == "priority_info": + state["is_picid"] = False + state["is_emoji"] = False + state["is_voice"] = False + if isinstance(segment.data, dict): + # 处理优先级信息 + state["priority_mode"] = "priority" + state["priority_info"] = segment.data + return "" + + elif segment.type == "file": + if isinstance(segment.data, dict): + file_name = segment.data.get("name", "未知文件") + file_size = segment.data.get("size", "未知大小") + return f"[文件:{file_name} ({file_size}字节)]" + return "[收到一个文件]" + + elif segment.type == "video": + state["is_picid"] = False + state["is_emoji"] = False + state["is_voice"] = False + state["is_video"] = True + logger.info(f"接收到视频消息,数据类型: {type(segment.data)}") + + # 检查视频分析功能是否可用 + if not is_video_analysis_available(): + logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析") + return "[视频]" + + if global_config.video_analysis.enable: + logger.info("已启用视频识别,开始识别") + if isinstance(segment.data, dict): + try: + # 从Adapter接收的视频数据 + video_base64 = segment.data.get("base64") + filename = segment.data.get("filename", "video.mp4") + + logger.info(f"视频文件名: {filename}") + logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}") + + if video_base64: + # 解码base64视频数据 + video_bytes = base64.b64decode(video_base64) + logger.info(f"解码后视频大小: {len(video_bytes)} 字节") + + # 使用video analyzer分析视频 + video_analyzer = get_video_analyzer() + result = await video_analyzer.analyze_video_from_bytes( + video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt + ) + + logger.info(f"视频分析结果: {result}") + + # 返回视频分析结果 + summary = result.get("summary", "") + if summary: + return f"[视频内容] {summary}" + else: + return "[已收到视频,但分析失败]" + else: + logger.warning("视频消息中没有base64数据") + return "[收到视频消息,但数据异常]" + except Exception as e: + logger.error(f"视频处理失败: {e!s}") + import traceback + logger.error(f"错误详情: {traceback.format_exc()}") + return "[收到视频,但处理时出现错误]" + else: + logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}") + return "[发了一个视频,但格式不支持]" + else: + return "" + else: + logger.warning(f"未知的消息段类型: {segment.type}") + return f"[{segment.type} 消息]" + + except Exception as e: + logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}") + return f"[处理失败的{segment.type}消息]" + + +def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, is_public_notice: bool, notice_type: str | None) -> str | None: + """准备 additional_config,包含 format_info 和 notice 信息 + + Args: + message_info: 消息基础信息 + is_notify: 是否为notice消息 + is_public_notice: 是否为公共notice + notice_type: notice类型 + + Returns: + str | None: JSON 字符串格式的 additional_config,如果为空则返回 None + """ + try: + additional_config_data = {} + + # 首先获取adapter传递的additional_config + if hasattr(message_info, "additional_config") and message_info.additional_config: + if isinstance(message_info.additional_config, dict): + additional_config_data = message_info.additional_config.copy() + elif isinstance(message_info.additional_config, str): + try: + additional_config_data = orjson.loads(message_info.additional_config) + except Exception as e: + logger.warning(f"无法解析 additional_config JSON: {e}") + additional_config_data = {} + + # 添加notice相关标志 + if is_notify: + additional_config_data["is_notice"] = True + additional_config_data["notice_type"] = notice_type or "unknown" + additional_config_data["is_public_notice"] = bool(is_public_notice) + + # 添加format_info到additional_config中 + if hasattr(message_info, "format_info") and message_info.format_info: + try: + format_info_dict = message_info.format_info.to_dict() + additional_config_data["format_info"] = format_info_dict + logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}") + except Exception as e: + logger.warning(f"将 format_info 转换为字典失败: {e}") + + # 序列化为JSON字符串 + if additional_config_data: + return orjson.dumps(additional_config_data).decode("utf-8") + except Exception as e: + logger.error(f"准备 additional_config 失败: {e}") + + return None + + +def _extract_reply_from_segment(segment: Seg) -> str | None: + """从消息段中提取reply_to信息 + + Args: + segment: 消息段 + + Returns: + str | None: 回复的消息ID,如果没有则返回None + """ + try: + if hasattr(segment, "type") and segment.type == "seglist": + # 递归搜索seglist中的reply段 + if hasattr(segment, "data") and segment.data: + for seg in segment.data: + reply_id = _extract_reply_from_segment(seg) + if reply_id: + return reply_id + elif hasattr(segment, "type") and segment.type == "reply": + # 找到reply段,返回message_id + return str(segment.data) if segment.data else None + except Exception as e: + logger.warning(f"提取reply_to信息失败: {e}") + return None + + +# ============================================================================= +# DatabaseMessages 扩展工具函数 +# ============================================================================= + +def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessageInfo: + """从 DatabaseMessages 重建 BaseMessageInfo(用于需要 message_info 的遗留代码) + + Args: + db_message: DatabaseMessages 对象 + + Returns: + BaseMessageInfo: 重建的消息信息对象 + """ + from maim_message import GroupInfo, UserInfo + + # 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo + user_info = UserInfo( + platform=db_message.user_info.platform, + user_id=db_message.user_info.user_id, + user_nickname=db_message.user_info.user_nickname, + user_cardname=db_message.user_info.user_cardname or "" + ) + + # 从 DatabaseMessages 的 group_info 转换为 maim_message.GroupInfo(如果存在) + group_info = None + if db_message.group_info: + group_info = GroupInfo( + platform=db_message.group_info.group_platform or "", + group_id=db_message.group_info.group_id, + group_name=db_message.group_info.group_name + ) + + # 解析 additional_config(从 JSON 字符串到字典) + additional_config = None + if db_message.additional_config: + try: + additional_config = orjson.loads(db_message.additional_config) + except Exception: + # 如果解析失败,保持为字符串 + pass + + # 创建 BaseMessageInfo + message_info = BaseMessageInfo( + platform=db_message.chat_info.platform, + message_id=db_message.message_id, + time=db_message.time, + user_info=user_info, + group_info=group_info, + additional_config=additional_config # type: ignore + ) + + return message_info + + +def set_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, value: Any) -> None: + """安全地为 DatabaseMessages 设置运行时属性 + + Args: + db_message: DatabaseMessages 对象 + attr_name: 属性名 + value: 属性值 + """ + setattr(db_message, attr_name, value) + + +def get_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, default: Any = None) -> Any: + """安全地获取 DatabaseMessages 的运行时属性 + + Args: + db_message: DatabaseMessages 对象 + attr_name: 属性名 + default: 默认值 + + Returns: + 属性值或默认值 + """ + return getattr(db_message, attr_name, default) diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index edf9bb9c8..314472845 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -5,12 +5,13 @@ import traceback import orjson from sqlalchemy import desc, select, update +from src.common.data_models.database_data_model import DatabaseMessages from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import Images, Messages from src.common.logger import get_logger from .chat_stream import ChatStream -from .message import MessageRecv, MessageSending +from .message import MessageSending logger = get_logger("message_storage") @@ -34,97 +35,166 @@ class MessageStorage: return [] @staticmethod - async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None: + async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream) -> None: """存储消息到数据库""" try: # 过滤敏感信息的正则模式 pattern = r".*?|.*?|.*?" - processed_plain_text = message.processed_plain_text - - if processed_plain_text: - processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) - # 增加对None的防御性处理 - safe_processed_plain_text = processed_plain_text or "" - filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL) - else: - filtered_processed_plain_text = "" - - if isinstance(message, MessageSending): - display_message = message.display_message - if display_message: - filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) + # 如果是 DatabaseMessages,直接使用它的字段 + if isinstance(message, DatabaseMessages): + processed_plain_text = message.processed_plain_text + if processed_plain_text: + processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) + safe_processed_plain_text = processed_plain_text or "" + filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL) else: - # 如果没有设置display_message,使用processed_plain_text作为显示消息 - filtered_display_message = ( - re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL) - ) - interest_value = 0 - is_mentioned = False - reply_to = message.reply_to - priority_mode = "" - priority_info = {} - is_emoji = False - is_picid = False - is_notify = False - is_command = False - key_words = "" - key_words_lite = "" - else: - filtered_display_message = "" - interest_value = message.interest_value + filtered_processed_plain_text = "" + + display_message = message.display_message or message.processed_plain_text or "" + filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) + + # 直接从 DatabaseMessages 获取所有字段 + msg_id = message.message_id + msg_time = message.time + chat_id = message.chat_id + reply_to = "" # DatabaseMessages 没有 reply_to 字段 is_mentioned = message.is_mentioned - reply_to = "" - priority_mode = message.priority_mode - priority_info = message.priority_info - is_emoji = message.is_emoji - is_picid = message.is_picid - is_notify = message.is_notify - is_command = message.is_command - # 序列化关键词列表为JSON字符串 - key_words = MessageStorage._serialize_keywords(message.key_words) - key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) + interest_value = message.interest_value or 0.0 + priority_mode = "" # DatabaseMessages 没有 priority_mode + priority_info_json = None # DatabaseMessages 没有 priority_info + is_emoji = message.is_emoji or False + is_picid = message.is_picid or False + is_notify = message.is_notify or False + is_command = message.is_command or False + key_words = "" # DatabaseMessages 没有 key_words + key_words_lite = "" + memorized_times = 0 # DatabaseMessages 没有 memorized_times - chat_info_dict = chat_stream.to_dict() - user_info_dict = message.message_info.user_info.to_dict() # type: ignore + # 使用 DatabaseMessages 中的嵌套对象信息 + user_platform = message.user_info.platform if message.user_info else "" + user_id = message.user_info.user_id if message.user_info else "" + user_nickname = message.user_info.user_nickname if message.user_info else "" + user_cardname = message.user_info.user_cardname if message.user_info else None - # message_id 现在是 TextField,直接使用字符串值 - msg_id = message.message_info.message_id + chat_info_stream_id = message.chat_info.stream_id if message.chat_info else "" + chat_info_platform = message.chat_info.platform if message.chat_info else "" + chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0 + chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0 + chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else "" + chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else "" + chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else "" + chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None + chat_info_group_platform = message.group_info.group_platform if message.group_info else None + chat_info_group_id = message.group_info.group_id if message.group_info else None + chat_info_group_name = message.group_info.group_name if message.group_info else None - # 安全地获取 group_info, 如果为 None 则视为空字典 - group_info_from_chat = chat_info_dict.get("group_info") or {} - # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一) - user_info_from_chat = chat_info_dict.get("user_info") or {} + else: + # MessageSending 处理逻辑 + processed_plain_text = message.processed_plain_text - # 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段 - priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None + if processed_plain_text: + processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) + # 增加对None的防御性处理 + safe_processed_plain_text = processed_plain_text or "" + filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL) + else: + filtered_processed_plain_text = "" + + if isinstance(message, MessageSending): + display_message = message.display_message + if display_message: + filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) + else: + # 如果没有设置display_message,使用processed_plain_text作为显示消息 + filtered_display_message = ( + re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL) + ) + interest_value = 0 + is_mentioned = False + reply_to = message.reply_to + priority_mode = "" + priority_info = {} + is_emoji = False + is_picid = False + is_notify = False + is_command = False + key_words = "" + key_words_lite = "" + else: + filtered_display_message = "" + interest_value = message.interest_value + is_mentioned = message.is_mentioned + reply_to = "" + priority_mode = message.priority_mode + priority_info = message.priority_info + is_emoji = message.is_emoji + is_picid = message.is_picid + is_notify = message.is_notify + is_command = message.is_command + # 序列化关键词列表为JSON字符串 + key_words = MessageStorage._serialize_keywords(message.key_words) + key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) + + chat_info_dict = chat_stream.to_dict() + user_info_dict = message.message_info.user_info.to_dict() # type: ignore + + # message_id 现在是 TextField,直接使用字符串值 + msg_id = message.message_info.message_id + msg_time = float(message.message_info.time or time.time()) + chat_id = chat_stream.stream_id + memorized_times = message.memorized_times + + # 安全地获取 group_info, 如果为 None 则视为空字典 + group_info_from_chat = chat_info_dict.get("group_info") or {} + # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一) + user_info_from_chat = chat_info_dict.get("user_info") or {} + + # 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段 + priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None + + user_platform = user_info_dict.get("platform") + user_id = user_info_dict.get("user_id") + user_nickname = user_info_dict.get("user_nickname") + user_cardname = user_info_dict.get("user_cardname") + + chat_info_stream_id = chat_info_dict.get("stream_id") + chat_info_platform = chat_info_dict.get("platform") + chat_info_create_time = float(chat_info_dict.get("create_time", 0.0)) + chat_info_last_active_time = float(chat_info_dict.get("last_active_time", 0.0)) + chat_info_user_platform = user_info_from_chat.get("platform") + chat_info_user_id = user_info_from_chat.get("user_id") + chat_info_user_nickname = user_info_from_chat.get("user_nickname") + chat_info_user_cardname = user_info_from_chat.get("user_cardname") + chat_info_group_platform = group_info_from_chat.get("platform") + chat_info_group_id = group_info_from_chat.get("group_id") + chat_info_group_name = group_info_from_chat.get("group_name") # 获取数据库会话 - new_message = Messages( message_id=msg_id, - time=float(message.message_info.time or time.time()), - chat_id=chat_stream.stream_id, + time=msg_time, + chat_id=chat_id, reply_to=reply_to, is_mentioned=is_mentioned, - chat_info_stream_id=chat_info_dict.get("stream_id"), - chat_info_platform=chat_info_dict.get("platform"), - chat_info_user_platform=user_info_from_chat.get("platform"), - chat_info_user_id=user_info_from_chat.get("user_id"), - chat_info_user_nickname=user_info_from_chat.get("user_nickname"), - chat_info_user_cardname=user_info_from_chat.get("user_cardname"), - chat_info_group_platform=group_info_from_chat.get("platform"), - chat_info_group_id=group_info_from_chat.get("group_id"), - chat_info_group_name=group_info_from_chat.get("group_name"), - chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)), - chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)), - user_platform=user_info_dict.get("platform"), - user_id=user_info_dict.get("user_id"), - user_nickname=user_info_dict.get("user_nickname"), - user_cardname=user_info_dict.get("user_cardname"), + chat_info_stream_id=chat_info_stream_id, + chat_info_platform=chat_info_platform, + chat_info_user_platform=chat_info_user_platform, + chat_info_user_id=chat_info_user_id, + chat_info_user_nickname=chat_info_user_nickname, + chat_info_user_cardname=chat_info_user_cardname, + chat_info_group_platform=chat_info_group_platform, + chat_info_group_id=chat_info_group_id, + chat_info_group_name=chat_info_group_name, + chat_info_create_time=chat_info_create_time, + chat_info_last_active_time=chat_info_last_active_time, + user_platform=user_platform, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, processed_plain_text=filtered_processed_plain_text, display_message=filtered_display_message, - memorized_times=message.memorized_times, + memorized_times=memorized_times, interest_value=interest_value, priority_mode=priority_mode, priority_info=priority_info_json, @@ -145,36 +215,43 @@ class MessageStorage: traceback.print_exc() @staticmethod - async def update_message(message): - """更新消息ID""" + async def update_message(message_data: dict): + """更新消息ID(从消息字典)""" try: - mmc_message_id = message.message_info.message_id + # 从字典中提取信息 + message_info = message_data.get("message_info", {}) + mmc_message_id = message_info.get("message_id") + + message_segment = message_data.get("message_segment", {}) + segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None + segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {} + qq_message_id = None - logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {message.message_segment.type}") + logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}") # 根据消息段类型提取message_id - if message.message_segment.type == "notify": - qq_message_id = message.message_segment.data.get("id") - elif message.message_segment.type == "text": - qq_message_id = message.message_segment.data.get("id") - elif message.message_segment.type == "reply": - qq_message_id = message.message_segment.data.get("id") + if segment_type == "notify": + qq_message_id = segment_data.get("id") + elif segment_type == "text": + qq_message_id = segment_data.get("id") + elif segment_type == "reply": + qq_message_id = segment_data.get("id") if qq_message_id: logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}") - elif message.message_segment.type == "adapter_response": + elif segment_type == "adapter_response": logger.debug("适配器响应消息,不需要更新ID") return - elif message.message_segment.type == "adapter_command": + elif segment_type == "adapter_command": logger.debug("适配器命令消息,不需要更新ID") return else: - logger.debug(f"未知的消息段类型: {message.message_segment.type},跳过ID更新") + logger.debug(f"未知的消息段类型: {segment_type},跳过ID更新") return if not qq_message_id: - logger.debug(f"消息段类型 {message.message_segment.type} 中未找到有效的message_id,跳过更新") - logger.debug(f"消息段数据: {message.message_segment.data}") + logger.debug(f"消息段类型 {segment_type} 中未找到有效的message_id,跳过更新") + logger.debug(f"消息段数据: {segment_data}") return # 使用上下文管理器确保session正确管理 diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 3a1204f23..20f927419 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -23,35 +23,35 @@ async def send_message(message: MessageSending, show_log=True) -> bool: await get_global_api().send_message(message) if show_log: logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'") - + # 触发 AFTER_SEND 事件 try: - from src.plugin_system.core.event_manager import event_manager from src.plugin_system.base.component_types import EventType - + from src.plugin_system.core.event_manager import event_manager + if message.chat_stream: logger.info(f"[发送完成] 准备触发 AFTER_SEND 事件,stream_id={message.chat_stream.stream_id}") - + # 使用 asyncio.create_task 来异步触发事件,避免阻塞 async def trigger_event_async(): try: - logger.info(f"[事件触发] 开始异步触发 AFTER_SEND 事件") + logger.info("[事件触发] 开始异步触发 AFTER_SEND 事件") await event_manager.trigger_event( EventType.AFTER_SEND, permission_group="SYSTEM", stream_id=message.chat_stream.stream_id, message=message, ) - logger.info(f"[事件触发] AFTER_SEND 事件触发完成") + logger.info("[事件触发] AFTER_SEND 事件触发完成") except Exception as e: logger.error(f"[事件触发] 异步触发事件失败: {e}", exc_info=True) - + # 创建异步任务,不等待完成 asyncio.create_task(trigger_event_async()) - logger.info(f"[发送完成] AFTER_SEND 事件已提交到异步任务") + logger.info("[发送完成] AFTER_SEND 事件已提交到异步任务") except Exception as event_error: logger.error(f"触发 AFTER_SEND 事件时出错: {event_error}", exc_info=True) - + return True except Exception as e: diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index e15dab72a..f52e40657 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -270,7 +270,7 @@ class ChatterActionManager: msg_text = target_message.get("processed_plain_text", "未知消息") else: msg_text = "未知消息" - + logger.info(f"对 {msg_text} 的回复生成失败") return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} except asyncio.CancelledError: diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 35a17d675..7ea2b4785 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -137,7 +137,7 @@ class ActionModifier: logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用") # === 第二阶段:检查动作的关联类型 === - chat_context = self.chat_stream.stream_context + chat_context = self.chat_stream.context_manager.context current_actions_s2 = self.action_manager.get_using_actions() type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index ef94cf2e3..bbe05e718 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -13,7 +13,7 @@ from typing import Any from src.chat.express.expression_selector import expression_selector from src.chat.message_receive.chat_stream import ChatStream -from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo +from src.chat.message_receive.message import MessageSending, Seg, UserInfo from src.chat.message_receive.uni_message_sender import HeartFCSender from src.chat.utils.chat_message_builder import ( build_readable_messages, @@ -32,10 +32,6 @@ from src.common.logger import get_logger from src.config.config import global_config, model_config from src.individuality.individuality import get_individuality from src.llm_models.utils_model import LLMRequest -from src.mais4u.mai_think import mai_thinking_manager - -# 旧记忆系统已被移除 -# 旧记忆系统已被移除 from src.mood.mood_manager import mood_manager from src.person_info.person_info import get_person_info_manager from src.plugin_system.apis import llm_api @@ -945,40 +941,24 @@ class DefaultReplyer: chat_stream = await chat_manager.get_stream(chat_id) if chat_stream: stream_context = chat_stream.context_manager - # 使用真正的已读和未读消息 - read_messages = stream_context.context.history_messages # 已读消息 + + # 确保历史消息已从数据库加载 + await stream_context.ensure_history_initialized() + + # 直接使用内存中的已读和未读消息,无需再查询数据库 + read_messages = stream_context.context.history_messages # 已读消息(已从数据库加载) unread_messages = stream_context.get_unread_messages() # 未读消息 # 构建已读历史消息 prompt read_history_prompt = "" - # 总是从数据库加载历史记录,并与会话历史合并 - logger.info("正在从数据库加载上下文并与会话历史合并...") - db_messages_raw = await get_raw_msg_before_timestamp_with_chat( - chat_id=chat_id, - timestamp=time.time(), - limit=global_config.chat.max_context_size, - ) + if read_messages: + # 将 DatabaseMessages 对象转换为字典格式,以便使用 build_readable_messages + read_messages_dicts = [msg.flatten() for msg in read_messages] - # 合并和去重 - combined_messages = {} - # 首先添加数据库消息 - for msg in db_messages_raw: - if msg.get("message_id"): - combined_messages[msg["message_id"]] = msg - - # 然后用会话消息覆盖/添加,以确保它们是最新的 - for msg_obj in read_messages: - msg_dict = msg_obj.flatten() - if msg_dict.get("message_id"): - combined_messages[msg_dict["message_id"]] = msg_dict - - # 按时间排序 - sorted_messages = sorted(combined_messages.values(), key=lambda x: x.get("time", 0)) + # 按时间排序并限制数量 + sorted_messages = sorted(read_messages_dicts, key=lambda x: x.get("time", 0)) + final_history = sorted_messages[-50:] # 限制最多50条 - read_history_prompt = "" - if sorted_messages: - # 限制最终用于prompt的历史消息数量 - final_history = sorted_messages[-50:] read_content = await build_readable_messages( final_history, replace_bot_name=True, @@ -986,8 +966,10 @@ class DefaultReplyer: truncate=True, ) read_history_prompt = f"这是已读历史消息,仅作为当前聊天情景的参考:\n{read_content}" + logger.debug(f"使用内存中的 {len(final_history)} 条历史消息构建prompt") else: read_history_prompt = "暂无已读历史消息" + logger.debug("内存中没有历史消息") # 构建未读历史消息 prompt unread_history_prompt = "" @@ -1161,50 +1143,6 @@ class DefaultReplyer: return interest_scores - def build_mai_think_context( - self, - chat_id: str, - memory_block: str, - relation_info: str, - time_block: str, - chat_target_1: str, - chat_target_2: str, - mood_prompt: str, - identity_block: str, - sender: str, - target: str, - chat_info: str, - ) -> Any: - """构建 mai_think 上下文信息 - - Args: - chat_id: 聊天ID - memory_block: 记忆块内容 - relation_info: 关系信息 - time_block: 时间块内容 - chat_target_1: 聊天目标1 - chat_target_2: 聊天目标2 - mood_prompt: 情绪提示 - identity_block: 身份块内容 - sender: 发送者名称 - target: 目标消息内容 - chat_info: 聊天信息 - - Returns: - Any: mai_think 实例 - """ - mai_think = mai_thinking_manager.get_mai_think(chat_id) - mai_think.memory_block = memory_block - mai_think.relation_info_block = relation_info - mai_think.time_block = time_block - mai_think.chat_target = chat_target_1 - mai_think.chat_target_2 = chat_target_2 - mai_think.chat_info = chat_info - mai_think.mood_state = mood_prompt - mai_think.identity = identity_block - mai_think.sender = sender - mai_think.target = target - return mai_think async def build_prompt_reply_context( self, @@ -1254,7 +1192,7 @@ class DefaultReplyer: if reply_message is None: logger.warning("reply_message 为 None,无法构建prompt") return "" - + # 统一处理 DatabaseMessages 对象和字典 if isinstance(reply_message, DatabaseMessages): platform = reply_message.chat_info.platform @@ -1268,7 +1206,7 @@ class DefaultReplyer: user_nickname = reply_message.get("user_nickname") user_cardname = reply_message.get("user_cardname") processed_plain_text = reply_message.get("processed_plain_text") - + person_id = person_info_manager.get_person_id( platform, # type: ignore user_id, # type: ignore @@ -1320,17 +1258,41 @@ class DefaultReplyer: action_descriptions += f"- {action_name}: {action_description}\n" action_descriptions += "\n" - message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat( - chat_id=chat_id, - timestamp=time.time(), - limit=global_config.chat.max_context_size * 2, - ) + # 从内存获取历史消息,避免重复查询数据库 + from src.plugin_system.apis.chat_api import get_chat_manager + + chat_manager = get_chat_manager() + chat_stream_obj = await chat_manager.get_stream(chat_id) + + if chat_stream_obj: + # 确保历史消息已初始化 + await chat_stream_obj.context_manager.ensure_history_initialized() + + # 获取所有消息(历史+未读) + all_messages = ( + chat_stream_obj.context_manager.context.history_messages + + chat_stream_obj.context_manager.get_unread_messages() + ) + + # 转换为字典格式 + message_list_before_now_long = [msg.flatten() for msg in all_messages[-(global_config.chat.max_context_size * 2):]] + message_list_before_short = [msg.flatten() for msg in all_messages[-int(global_config.chat.max_context_size * 0.33):]] + + logger.debug(f"使用内存中的消息: long={len(message_list_before_now_long)}, short={len(message_list_before_short)}") + else: + # 回退到数据库查询 + logger.warning(f"无法获取chat_stream,回退到数据库查询: {chat_id}") + message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat( + chat_id=chat_id, + timestamp=time.time(), + limit=global_config.chat.max_context_size * 2, + ) + message_list_before_short = await get_raw_msg_before_timestamp_with_chat( + chat_id=chat_id, + timestamp=time.time(), + limit=int(global_config.chat.max_context_size * 0.33), + ) - message_list_before_short = await get_raw_msg_before_timestamp_with_chat( - chat_id=chat_id, - timestamp=time.time(), - limit=int(global_config.chat.max_context_size * 0.33), - ) chat_talking_prompt_short = await build_readable_messages( message_list_before_short, replace_bot_name=True, @@ -1668,11 +1630,36 @@ class DefaultReplyer: else: mood_prompt = "" - message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat( - chat_id=chat_id, - timestamp=time.time(), - limit=min(int(global_config.chat.max_context_size * 0.33), 15), - ) + # 从内存获取历史消息,避免重复查询数据库 + from src.plugin_system.apis.chat_api import get_chat_manager + + chat_manager = get_chat_manager() + chat_stream_obj = await chat_manager.get_stream(chat_id) + + if chat_stream_obj: + # 确保历史消息已初始化 + await chat_stream_obj.context_manager.ensure_history_initialized() + + # 获取所有消息(历史+未读) + all_messages = ( + chat_stream_obj.context_manager.context.history_messages + + chat_stream_obj.context_manager.get_unread_messages() + ) + + # 转换为字典格式,限制数量 + limit = min(int(global_config.chat.max_context_size * 0.33), 15) + message_list_before_now_half = [msg.flatten() for msg in all_messages[-limit:]] + + logger.debug(f"Rewrite使用内存中的 {len(message_list_before_now_half)} 条消息") + else: + # 回退到数据库查询 + logger.warning(f"无法获取chat_stream,回退到数据库查询: {chat_id}") + message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat( + chat_id=chat_id, + timestamp=time.time(), + limit=min(int(global_config.chat.max_context_size * 0.33), 15), + ) + chat_talking_prompt_half = await build_readable_messages( message_list_before_now_half, replace_bot_name=True, @@ -1779,7 +1766,7 @@ class DefaultReplyer: is_emoji: bool, thinking_start_time: float, display_message: str, - anchor_message: MessageRecv | None = None, + anchor_message: DatabaseMessages | None = None, ) -> MessageSending: """构建单个发送消息""" @@ -1789,8 +1776,11 @@ class DefaultReplyer: platform=self.chat_stream.platform, ) - # await anchor_message.process() - sender_info = anchor_message.message_info.user_info if anchor_message else None + # 从 DatabaseMessages 获取 sender_info + if anchor_message: + sender_info = anchor_message.user_info + else: + sender_info = None return MessageSending( message_id=message_id, # 使用片段的唯一ID @@ -1826,7 +1816,7 @@ class DefaultReplyer: # 循环移除,以处理模型可能生成的嵌套回复头/尾 # 使用更健壮的正则表达式,通过非贪婪匹配和向后查找来定位真正的消息内容 pattern = re.compile(r"^\s*\[回复<.+?>\s*(?:的消息)?:(?P.*)\](?:,?说:)?\s*$", re.DOTALL) - + temp_content = cleaned_content while True: match = pattern.match(temp_content) @@ -1838,7 +1828,7 @@ class DefaultReplyer: temp_content = new_content else: break # 没有匹配到,退出循环 - + # 在循环处理后,再使用 rsplit 来处理日志中观察到的特殊情况 # 这可以作为处理复杂嵌套的最后一道防线 final_split = temp_content.rsplit("],说:", 1) @@ -1846,7 +1836,7 @@ class DefaultReplyer: final_content = final_split[1].strip() else: final_content = temp_content - + if final_content != content: logger.debug(f"清理了模型生成的多余内容,原始内容: '{content}', 清理后: '{final_content}'") content = final_content @@ -2083,12 +2073,35 @@ class DefaultReplyer: memory_context = {key: value for key, value in memory_context.items() if value} - # 构建聊天历史用于存储 - message_list_before_short = await get_raw_msg_before_timestamp_with_chat( - chat_id=stream.stream_id, - timestamp=time.time(), - limit=int(global_config.chat.max_context_size * 0.33), - ) + # 从内存获取聊天历史用于存储,避免重复查询数据库 + from src.plugin_system.apis.chat_api import get_chat_manager + + chat_manager = get_chat_manager() + chat_stream_obj = await chat_manager.get_stream(stream.stream_id) + + if chat_stream_obj: + # 确保历史消息已初始化 + await chat_stream_obj.context_manager.ensure_history_initialized() + + # 获取所有消息(历史+未读) + all_messages = ( + chat_stream_obj.context_manager.context.history_messages + + chat_stream_obj.context_manager.get_unread_messages() + ) + + # 转换为字典格式,限制数量 + limit = int(global_config.chat.max_context_size * 0.33) + message_list_before_short = [msg.flatten() for msg in all_messages[-limit:]] + + logger.debug(f"记忆存储使用内存中的 {len(message_list_before_short)} 条消息") + else: + # 回退到数据库查询 + logger.warning(f"记忆存储:无法获取chat_stream,回退到数据库查询: {stream.stream_id}") + message_list_before_short = await get_raw_msg_before_timestamp_with_chat( + chat_id=stream.stream_id, + timestamp=time.time(), + limit=int(global_config.chat.max_context_size * 0.33), + ) chat_history = await build_readable_messages( message_list_before_short, replace_bot_name=True, diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 2e141e6ad..c10056bf2 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -1112,14 +1112,14 @@ class Prompt: # 使用关系提取器构建用户关系信息和聊天流印象 user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5) stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id) - + # 组合两部分信息 info_parts = [] if user_relation_info: info_parts.append(user_relation_info) if stream_impression: info_parts.append(stream_impression) - + return "\n\n".join(info_parts) if info_parts else "" def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]: diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 496e50673..f0d5e2529 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -11,7 +11,8 @@ import rjieba from maim_message import UserInfo from src.chat.message_receive.chat_stream import get_chat_manager -from src.chat.message_receive.message import MessageRecv + +# MessageRecv 已被移除,现在使用 DatabaseMessages from src.common.logger import get_logger from src.common.message_repository import count_messages, find_messages from src.config.config import global_config, model_config @@ -41,34 +42,58 @@ def db_message_to_str(message_dict: dict) -> str: return result -def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: - """检查消息是否提到了机器人""" +def is_mentioned_bot_in_message(message) -> tuple[bool, float]: + """检查消息是否提到了机器人 + + Args: + message: DatabaseMessages 消息对象 + + Returns: + tuple[bool, float]: (是否提及, 提及概率) + """ keywords = [global_config.bot.nickname] nicknames = global_config.bot.alias_names reply_probability = 0.0 is_at = False is_mentioned = False - if message.is_mentioned is not None: - return bool(message.is_mentioned), message.is_mentioned - if ( - message.message_info.additional_config is not None - and message.message_info.additional_config.get("is_mentioned") is not None - ): + + # 检查 is_mentioned 属性 + mentioned_attr = getattr(message, "is_mentioned", None) + if mentioned_attr is not None: try: - reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore + return bool(mentioned_attr), float(mentioned_attr) + except (ValueError, TypeError): + pass + + # 检查 additional_config + additional_config = None + + # DatabaseMessages: additional_config 是 JSON 字符串 + if message.additional_config: + try: + import orjson + additional_config = orjson.loads(message.additional_config) + except Exception: + pass + + if additional_config and additional_config.get("is_mentioned") is not None: + try: + reply_probability = float(additional_config.get("is_mentioned")) # type: ignore is_mentioned = True return is_mentioned, reply_probability except Exception as e: logger.warning(str(e)) logger.warning( - f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}" + f"消息中包含不合理的设置 is_mentioned: {additional_config.get('is_mentioned')}" ) - if global_config.bot.nickname in message.processed_plain_text: + # 检查消息文本内容 + processed_text = message.processed_plain_text or "" + if global_config.bot.nickname in processed_text: is_mentioned = True for alias_name in global_config.bot.alias_names: - if alias_name in message.processed_plain_text: + if alias_name in processed_text: is_mentioned = True # 判断是否被@ @@ -110,7 +135,6 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: logger.debug("被提及,回复概率设置为100%") return is_mentioned, reply_probability - async def get_embedding(text, request_type="embedding") -> list[float] | None: """获取文本的embedding向量""" # 每次都创建新的LLMRequest实例以避免事件循环冲突 diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index eb29b3302..5eb7f0f7b 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -7,7 +7,7 @@ import asyncio import time from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from src.common.logger import get_logger from src.plugin_system.base.component_types import ChatMode, ChatType @@ -64,7 +64,7 @@ class StreamContext(BaseDataModel): triggering_user_id: str | None = None # 触发当前聊天流的用户ID is_replying: bool = False # 是否正在生成回复 processing_message_id: str | None = None # 当前正在规划/处理的目标消息ID,用于防止重复回复 - decision_history: List["DecisionRecord"] = field(default_factory=list) # 决策历史 + decision_history: list["DecisionRecord"] = field(default_factory=list) # 决策历史 def add_action_to_message(self, message_id: str, action: str): """ @@ -260,7 +260,7 @@ class StreamContext(BaseDataModel): if requested_type not in accept_format: logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的类型: {accept_format}") return False - logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)") + logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)") return True # 方法2: 检查content_format字段(向后兼容) @@ -279,7 +279,7 @@ class StreamContext(BaseDataModel): if requested_type not in content_format: logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的内容格式: {content_format}") return False - logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)") + logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)") return True else: logger.warning("[check_types] [问题] additional_config 中没有 format_info 字段") diff --git a/src/common/database/db_migration.py b/src/common/database/db_migration.py index 2ab7ba13e..fad348bf9 100644 --- a/src/common/database/db_migration.py +++ b/src/common/database/db_migration.py @@ -9,15 +9,18 @@ from src.common.logger import get_logger logger = get_logger("db_migration") -async def check_and_migrate_database(): +async def check_and_migrate_database(existing_engine=None): """ 异步检查数据库结构并自动迁移。 - 自动创建不存在的表。 - 自动为现有表添加缺失的列。 - 自动为现有表创建缺失的索引。 + + Args: + existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎。 """ logger.info("正在检查数据库结构并执行自动迁移...") - engine = await get_engine() + engine = existing_engine if existing_engine is not None else await get_engine() async with engine.connect() as connection: # 在同步上下文中运行inspector操作 diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 9f03aa43c..287f0fc29 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -780,12 +780,8 @@ async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[Async _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False) # 迁移 - try: - from src.common.database.db_migration import check_and_migrate_database - await check_and_migrate_database(existing_engine=_engine) - except TypeError: - from src.common.database.db_migration import check_and_migrate_database as _legacy_migrate - await _legacy_migrate() + from src.common.database.db_migration import check_and_migrate_database + await check_and_migrate_database(existing_engine=_engine) if config.database_type == "sqlite": await enable_sqlite_wal_mode(_engine) diff --git a/src/config/config.py b/src/config/config.py index efd57be69..b22674893 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -26,7 +26,6 @@ from src.config.official_configs import ( EmojiConfig, ExperimentalConfig, ExpressionConfig, - ReactionConfig, LPMMKnowledgeConfig, MaimMessageConfig, MemoryConfig, @@ -38,6 +37,7 @@ from src.config.official_configs import ( PersonalityConfig, PlanningSystemConfig, ProactiveThinkingConfig, + ReactionConfig, ResponsePostProcessConfig, ResponseSplitterConfig, ToolConfig, diff --git a/src/config/official_configs.py b/src/config/official_configs.py index cc9885b8c..24957cd30 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -188,7 +188,7 @@ class ExpressionConfig(ValidatedConfigBase): """表达配置类""" mode: Literal["classic", "exp_model"] = Field( - default="classic", + default="classic", description="表达方式选择模式: classic=经典LLM评估, exp_model=机器学习模型预测" ) rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则") @@ -761,35 +761,35 @@ class ProactiveThinkingConfig(ValidatedConfigBase): cold_start_cooldown: int = Field( default=86400, description="冷启动后,该私聊的下一次主动思考需要等待的最小时间(秒)" ) - + # --- 新增:间隔配置 --- base_interval: int = Field(default=1800, ge=60, description="基础触发间隔(秒),默认30分钟") min_interval: int = Field(default=600, ge=60, description="最小触发间隔(秒),默认10分钟。兴趣分数高时会接近此值") max_interval: int = Field(default=7200, ge=60, description="最大触发间隔(秒),默认2小时。兴趣分数低时会接近此值") - + # --- 新增:动态调整配置 --- use_interest_score: bool = Field(default=True, description="是否根据兴趣分数动态调整间隔。关闭则使用固定base_interval") interest_score_factor: float = Field(default=2.0, ge=1.0, le=3.0, description="兴趣分数影响因子。公式: interval = base * (factor - score)") - + # --- 新增:黑白名单配置 --- whitelist_mode: bool = Field(default=False, description="是否启用白名单模式。启用后只对白名单中的聊天流生效") blacklist_mode: bool = Field(default=False, description="是否启用黑名单模式。启用后排除黑名单中的聊天流") - + whitelist_private: list[str] = Field( - default_factory=list, + default_factory=list, description='私聊白名单,格式: ["platform:user_id:private", "qq:12345:private"]' ) whitelist_group: list[str] = Field( - default_factory=list, + default_factory=list, description='群聊白名单,格式: ["platform:group_id:group", "qq:123456:group"]' ) - + blacklist_private: list[str] = Field( - default_factory=list, + default_factory=list, description='私聊黑名单,格式: ["platform:user_id:private", "qq:12345:private"]' ) blacklist_group: list[str] = Field( - default_factory=list, + default_factory=list, description='群聊黑名单,格式: ["platform:group_id:group", "qq:123456:group"]' ) @@ -802,17 +802,17 @@ class ProactiveThinkingConfig(ValidatedConfigBase): quiet_hours_start: str = Field(default="00:00", description='安静时段开始时间,格式: "HH:MM"') quiet_hours_end: str = Field(default="07:00", description='安静时段结束时间,格式: "HH:MM"') active_hours_multiplier: float = Field(default=0.7, ge=0.1, le=2.0, description="活跃时段间隔倍数,<1表示更频繁,>1表示更稀疏") - + # --- 新增:冷却与限制 --- reply_reset_enabled: bool = Field(default=True, description="bot回复后是否重置定时器(避免回复后立即又主动发言)") topic_throw_cooldown: int = Field(default=3600, ge=0, description="抛出话题后的冷却时间(秒),期间暂停主动思考") max_daily_proactive: int = Field(default=0, ge=0, description="每个聊天流每天最多主动发言次数,0表示不限制") - + # --- 新增:决策权重配置 --- do_nothing_weight: float = Field(default=0.4, ge=0.0, le=1.0, description="do_nothing动作的基础权重") simple_bubble_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="simple_bubble动作的基础权重") throw_topic_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="throw_topic动作的基础权重") - + # --- 新增:调试与监控 --- enable_statistics: bool = Field(default=True, description="是否启用统计功能(记录触发次数、决策分布等)") log_decisions: bool = Field(default=False, description="是否记录每次决策的详细日志(用于调试)") diff --git a/src/main.py b/src/main.py index c23d887b3..c11180e43 100644 --- a/src/main.py +++ b/src/main.py @@ -429,7 +429,7 @@ MoFox_Bot(第三方修改版) await initialize_scheduler() except Exception as e: logger.error(f"统一调度器初始化失败: {e}") - + # 加载所有插件 plugin_manager.load_all_plugins() diff --git a/src/mais4u/config/old/s4u_config_20250715_141713.toml b/src/mais4u/config/old/s4u_config_20250715_141713.toml deleted file mode 100644 index 538fcd88a..000000000 --- a/src/mais4u/config/old/s4u_config_20250715_141713.toml +++ /dev/null @@ -1,36 +0,0 @@ -[inner] -version = "1.0.0" - -#----以下是S4U聊天系统配置文件---- -# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块 -# 支持优先级队列、消息中断、VIP用户等高级功能 -# -# 如果你想要修改配置文件,请在修改后将version的值进行变更 -# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类 -# -# 版本格式:主版本号.次版本号.修订号 -#----S4U配置说明结束---- - -[s4u] -# 消息管理配置 -message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃 -recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除 - -# 优先级系统配置 -at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数 -vip_queue_priority = true # 是否启用VIP队列优先级系统 -enable_message_interruption = true # 是否允许高优先级消息中断当前回复 - -# 打字效果配置 -typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度 -enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟 - -# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效) -chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟 -min_typing_delay = 0.2 # 最小打字延迟(秒) -max_typing_delay = 2.0 # 最大打字延迟(秒) - -# 系统功能开关 -enable_old_message_cleanup = true # 是否自动清理过旧的普通消息 -enable_loading_indicator = true # 是否显示加载提示 - diff --git a/src/mais4u/config/s4u_config.toml b/src/mais4u/config/s4u_config.toml deleted file mode 100644 index 26fdef449..000000000 --- a/src/mais4u/config/s4u_config.toml +++ /dev/null @@ -1,132 +0,0 @@ -[inner] -version = "1.1.0" - -#----以下是S4U聊天系统配置文件---- -# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块 -# 支持优先级队列、消息中断、VIP用户等高级功能 -# -# 如果你想要修改配置文件,请在修改后将version的值进行变更 -# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类 -# -# 版本格式:主版本号.次版本号.修订号 -#----S4U配置说明结束---- - -[s4u] -# 消息管理配置 -message_timeout_seconds = 80 # 普通消息存活时间(秒),超过此时间的消息将被丢弃 -recent_message_keep_count = 8 # 保留最近N条消息,超出范围的普通消息将被移除 - -# 优先级系统配置 -at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数 -vip_queue_priority = true # 是否启用VIP队列优先级系统 -enable_message_interruption = true # 是否允许高优先级消息中断当前回复 - -# 打字效果配置 -typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度 -enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟 - -# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效) -chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟 -min_typing_delay = 0.2 # 最小打字延迟(秒) -max_typing_delay = 2.0 # 最大打字延迟(秒) - -# 系统功能开关 -enable_old_message_cleanup = true # 是否自动清理过旧的普通消息 -enable_loading_indicator = true # 是否显示加载提示 - -enable_streaming_output = false # 是否启用流式输出,false时全部生成后一次性发送 - -max_context_message_length = 30 -max_core_message_length = 20 - -# 模型配置 -[models] -# 主要对话模型配置 -[models.chat] -name = "qwen3-8b" -provider = "BAILIAN" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 -enable_thinking = false - -# 规划模型配置 -[models.motion] -name = "qwen3-8b" -provider = "BAILIAN" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 -enable_thinking = false - -# 情感分析模型配置 -[models.emotion] -name = "qwen3-8b" -provider = "BAILIAN" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 - -# 记忆模型配置 -[models.memory] -name = "qwen3-8b" -provider = "BAILIAN" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 - -# 工具使用模型配置 -[models.tool_use] -name = "qwen3-8b" -provider = "BAILIAN" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 - -# 嵌入模型配置 -[models.embedding] -name = "text-embedding-v1" -provider = "OPENAI" -dimension = 1024 - -# 视觉语言模型配置 -[models.vlm] -name = "qwen-vl-plus" -provider = "BAILIAN" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 - -# 知识库模型配置 -[models.knowledge] -name = "qwen3-8b" -provider = "BAILIAN" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 - -# 实体提取模型配置 -[models.entity_extract] -name = "qwen3-8b" -provider = "BAILIAN" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 - -# 问答模型配置 -[models.qa] -name = "qwen3-8b" -provider = "BAILIAN" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 - -# 兼容性配置(已废弃,请使用models.motion) -[model_motion] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型 -# 强烈建议使用免费的小模型 -name = "qwen3-8b" -provider = "BAILIAN" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 -enable_thinking = false # 是否启用思考 \ No newline at end of file diff --git a/src/mais4u/config/s4u_config_template.toml b/src/mais4u/config/s4u_config_template.toml deleted file mode 100644 index 40adb1f63..000000000 --- a/src/mais4u/config/s4u_config_template.toml +++ /dev/null @@ -1,67 +0,0 @@ -[inner] -version = "1.1.0" - -#----以下是S4U聊天系统配置文件---- -# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块 -# 支持优先级队列、消息中断、VIP用户等高级功能 -# -# 如果你想要修改配置文件,请在修改后将version的值进行变更 -# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类 -# -# 版本格式:主版本号.次版本号.修订号 -#----S4U配置说明结束---- - -[s4u] -# 消息管理配置 -message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃 -recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除 - -# 优先级系统配置 -at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数 -vip_queue_priority = true # 是否启用VIP队列优先级系统 -enable_message_interruption = true # 是否允许高优先级消息中断当前回复 - -# 打字效果配置 -typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度 -enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟 - -# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效) -chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟 -min_typing_delay = 0.2 # 最小打字延迟(秒) -max_typing_delay = 2.0 # 最大打字延迟(秒) - -# 系统功能开关 -enable_old_message_cleanup = true # 是否自动清理过旧的普通消息 - -enable_streaming_output = true # 是否启用流式输出,false时全部生成后一次性发送 - -max_context_message_length = 20 -max_core_message_length = 30 - -# 模型配置 -[models] -# 主要对话模型配置 -[models.chat] -name = "qwen3-8b" -provider = "BAILIAN" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 -enable_thinking = false - -# 规划模型配置 -[models.motion] -name = "qwen3-32b" -provider = "BAILIAN" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 -enable_thinking = false - -# 情感分析模型配置 -[models.emotion] -name = "qwen3-8b" -provider = "BAILIAN" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 diff --git a/src/mais4u/constant_s4u.py b/src/mais4u/constant_s4u.py deleted file mode 100644 index eda7aa375..000000000 --- a/src/mais4u/constant_s4u.py +++ /dev/null @@ -1 +0,0 @@ -ENABLE_S4U = False diff --git a/src/mais4u/mai_think.py b/src/mais4u/mai_think.py deleted file mode 100644 index 9fbb5767d..000000000 --- a/src/mais4u/mai_think.py +++ /dev/null @@ -1,178 +0,0 @@ -import time - -from src.chat.message_receive.chat_stream import get_chat_manager -from src.chat.message_receive.message import MessageRecvS4U -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.common.logger import get_logger -from src.config.config import model_config -from src.llm_models.utils_model import LLMRequest -from src.mais4u.mais4u_chat.internal_manager import internal_manager -from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor - -logger = get_logger(__name__) - - -def init_prompt(): - Prompt( - """ -你之前的内心想法是:{mind} - -{memory_block} -{relation_info_block} - -{chat_target} -{time_block} -{chat_info} -{identity} - -你刚刚在{chat_target_2},你你刚刚的心情是:{mood_state} ---------------------- -在这样的情况下,你对上面的内容,你对 {sender} 发送的 消息 “{target}” 进行了回复 -你刚刚选择回复的内容是:{reponse} -现在,根据你之前的想法和回复的内容,推测你现在的想法,思考你现在的想法是什么,为什么做出上面的回复内容 -请不要浮夸和夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出想法:""", - "after_response_think_prompt", - ) - - -class MaiThinking: - def __init__(self, chat_id): - self.chat_id = chat_id - # 这些将在异步初始化中设置 - self.chat_stream = None # type: ignore - self.platform = None - self.is_group = False - self._initialized = False - - self.s4u_message_processor = S4UMessageProcessor() - - self.mind = "" - - self.memory_block = "" - self.relation_info_block = "" - self.time_block = "" - self.chat_target = "" - self.chat_target_2 = "" - self.chat_info = "" - self.mood_state = "" - self.identity = "" - self.sender = "" - self.target = "" - - self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="thinking") - - async def _initialize(self): - """异步初始化方法""" - if not self._initialized: - self.chat_stream = await get_chat_manager().get_stream(self.chat_id) - if self.chat_stream: - self.platform = self.chat_stream.platform - self.is_group = bool(self.chat_stream.group_info) - self._initialized = True - - async def do_think_before_response(self): - pass - - async def do_think_after_response(self, reponse: str): - prompt = await global_prompt_manager.format_prompt( - "after_response_think_prompt", - mind=self.mind, - reponse=reponse, - memory_block=self.memory_block, - relation_info_block=self.relation_info_block, - time_block=self.time_block, - chat_target=self.chat_target, - chat_target_2=self.chat_target_2, - chat_info=self.chat_info, - mood_state=self.mood_state, - identity=self.identity, - sender=self.sender, - target=self.target, - ) - - result, _ = await self.thinking_model.generate_response_async(prompt) - self.mind = result - - logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}") - # logger.info(f"[{self.chat_id}] 思考前prompt:{prompt}") - logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}") - - msg_recv = await self.build_internal_message_recv(self.mind) - await self.s4u_message_processor.process_message(msg_recv) - internal_manager.set_internal_state(self.mind) - - async def do_think_when_receive_message(self): - pass - - async def build_internal_message_recv(self, message_text: str): - # 初始化 - await self._initialize() - - msg_id = f"internal_{time.time()}" - - message_dict = { - "message_info": { - "message_id": msg_id, - "time": time.time(), - "user_info": { - "user_id": "internal", # 内部用户ID - "user_nickname": "内心", # 内部昵称 - "platform": self.platform, # 平台标记为 internal - # 其他 user_info 字段按需补充 - }, - "platform": self.platform, # 平台 - # 其他 message_info 字段按需补充 - }, - "message_segment": { - "type": "text", # 消息类型 - "data": message_text, # 消息内容 - # 其他 segment 字段按需补充 - }, - "raw_message": message_text, # 原始消息内容 - "processed_plain_text": message_text, # 处理后的纯文本 - # 下面这些字段可选,根据 MessageRecv 需要 - "is_emoji": False, - "has_emoji": False, - "is_picid": False, - "has_picid": False, - "is_voice": False, - "is_mentioned": False, - "is_command": False, - "is_internal": True, - "priority_mode": "interest", - "priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级 - "interest_value": 1.0, - } - - if self.is_group: - message_dict["message_info"]["group_info"] = { - "platform": self.platform, - "group_id": self.chat_stream.group_info.group_id, - "group_name": self.chat_stream.group_info.group_name, - } - - msg_recv = MessageRecvS4U(message_dict) - msg_recv.chat_info = self.chat_info - msg_recv.chat_stream = self.chat_stream - msg_recv.is_internal = True - - return msg_recv - - -class MaiThinkingManager: - def __init__(self): - self.mai_think_list = [] - - def get_mai_think(self, chat_id): - for mai_think in self.mai_think_list: - if mai_think.chat_id == chat_id: - return mai_think - mai_think = MaiThinking(chat_id) - self.mai_think_list.append(mai_think) - return mai_think - - -mai_thinking_manager = MaiThinkingManager() - - -init_prompt() diff --git a/src/mais4u/mais4u_chat/body_emotion_action_manager.py b/src/mais4u/mais4u_chat/body_emotion_action_manager.py deleted file mode 100644 index 423eeaf16..000000000 --- a/src/mais4u/mais4u_chat/body_emotion_action_manager.py +++ /dev/null @@ -1,306 +0,0 @@ -import time - -import orjson -from json_repair import repair_json - -from src.chat.message_receive.message import MessageRecv -from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest -from src.mais4u.s4u_config import s4u_config -from src.manager.async_task_manager import AsyncTask, async_task_manager -from src.plugin_system.apis import send_api - -logger = get_logger("action") - -HEAD_CODE = { - "看向上方": "(0,0.5,0)", - "看向下方": "(0,-0.5,0)", - "看向左边": "(-1,0,0)", - "看向右边": "(1,0,0)", - "随意朝向": "random", - "看向摄像机": "camera", - "注视对方": "(0,0,0)", - "看向正前方": "(0,0,0)", -} - -BODY_CODE = { - "双手背后向前弯腰": "010_0070", - "歪头双手合十": "010_0100", - "标准文静站立": "010_0101", - "双手交叠腹部站立": "010_0150", - "帅气的姿势": "010_0190", - "另一个帅气的姿势": "010_0191", - "手掌朝前可爱": "010_0210", - "平静,双手后放": "平静,双手后放", - "思考": "思考", - "优雅,左手放在腰上": "优雅,左手放在腰上", - "一般": "一般", - "可爱,双手前放": "可爱,双手前放", -} - - -def init_prompt(): - Prompt( - """ -{chat_talking_prompt} -以上是群里正在进行的聊天记录 - -{indentify_block} -你现在的动作状态是: -- 身体动作:{body_action} - -现在,因为你发送了消息,或者群里其他人发送了消息,引起了你的注意,你对其进行了阅读和思考,请你更新你的动作状态。 -身体动作可选: -{all_actions} - -请只按照以下json格式输出,描述你新的动作状态,确保每个字段都存在: -{{ - "body_action": "..." -}} -""", - "change_action_prompt", - ) - Prompt( - """ -{chat_talking_prompt} -以上是群里最近的聊天记录 - -{indentify_block} -你之前的动作状态是 -- 身体动作:{body_action} - -身体动作可选: -{all_actions} - -距离你上次关注群里消息已经过去了一段时间,你冷静了下来,你的动作会趋于平缓或静止,请你输出你现在新的动作状态,用中文。 -请只按照以下json格式输出,描述你新的动作状态,确保每个字段都存在: -{{ - "body_action": "..." -}} -""", - "regress_action_prompt", - ) - - -class ChatAction: - def __init__(self, chat_id: str): - self.chat_id: str = chat_id - self.body_action: str = "一般" - self.head_action: str = "注视摄像机" - - self.regression_count: int = 0 - # 新增:body_action冷却池,key为动作名,value为剩余冷却次数 - self.body_action_cooldown: dict[str, int] = {} - - print(s4u_config.models.motion) - print(model_config.model_task_config.emotion) - - self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion") - - self.last_change_time: float = 0 - - async def send_action_update(self): - """发送动作更新到前端""" - - body_code = BODY_CODE.get(self.body_action, "") - await send_api.custom_to_stream( - message_type="body_action", - content=body_code, - stream_id=self.chat_id, - storage_message=False, - show_log=True, - ) - - async def update_action_by_message(self, message: MessageRecv): - self.regression_count = 0 - - message_time: float = message.message_info.time # type: ignore - message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( - chat_id=self.chat_id, - timestamp_start=self.last_change_time, - timestamp_end=message_time, - limit=15, - limit_mode="last", - ) - chat_talking_prompt = await build_readable_messages( - message_list_before_now, - replace_bot_name=True, - merge_messages=False, - timestamp_mode="normal_no_YMD", - read_mark=0.0, - truncate=True, - show_actions=True, - ) - - bot_name = global_config.bot.nickname - if global_config.bot.alias_names: - bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}" - else: - bot_nickname = "" - - prompt_personality = global_config.personality.personality_core - indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" - - try: - # 冷却池处理:过滤掉冷却中的动作 - self._update_body_action_cooldown() - available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown] - all_actions = "\n".join(available_actions) - - prompt = await global_prompt_manager.format_prompt( - "change_action_prompt", - chat_talking_prompt=chat_talking_prompt, - indentify_block=indentify_block, - body_action=self.body_action, - all_actions=all_actions, - ) - - logger.info(f"prompt: {prompt}") - response, (reasoning_content, _, _) = await self.action_model.generate_response_async( - prompt=prompt, temperature=0.7 - ) - logger.info(f"response: {response}") - logger.info(f"reasoning_content: {reasoning_content}") - - if action_data := orjson.loads(repair_json(response)): - # 记录原动作,切换后进入冷却 - prev_body_action = self.body_action - new_body_action = action_data.get("body_action", self.body_action) - if new_body_action != prev_body_action and prev_body_action: - self.body_action_cooldown[prev_body_action] = 3 - self.body_action = new_body_action - self.head_action = action_data.get("head_action", self.head_action) - # 发送动作更新 - await self.send_action_update() - - self.last_change_time = message_time - except Exception as e: - logger.error(f"update_action_by_message error: {e}") - - async def regress_action(self): - message_time = time.time() - message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( - chat_id=self.chat_id, - timestamp_start=self.last_change_time, - timestamp_end=message_time, - limit=10, - limit_mode="last", - ) - chat_talking_prompt = await build_readable_messages( - message_list_before_now, - replace_bot_name=True, - merge_messages=False, - timestamp_mode="normal_no_YMD", - read_mark=0.0, - truncate=True, - show_actions=True, - ) - - bot_name = global_config.bot.nickname - if global_config.bot.alias_names: - bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}" - else: - bot_nickname = "" - - prompt_personality = global_config.personality.personality_core - indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" - try: - # 冷却池处理:过滤掉冷却中的动作 - self._update_body_action_cooldown() - available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown] - all_actions = "\n".join(available_actions) - - prompt = await global_prompt_manager.format_prompt( - "regress_action_prompt", - chat_talking_prompt=chat_talking_prompt, - indentify_block=indentify_block, - body_action=self.body_action, - all_actions=all_actions, - ) - - logger.info(f"prompt: {prompt}") - response, (reasoning_content, _, _) = await self.action_model.generate_response_async( - prompt=prompt, temperature=0.7 - ) - logger.info(f"response: {response}") - logger.info(f"reasoning_content: {reasoning_content}") - - if action_data := orjson.loads(repair_json(response)): - prev_body_action = self.body_action - new_body_action = action_data.get("body_action", self.body_action) - if new_body_action != prev_body_action and prev_body_action: - self.body_action_cooldown[prev_body_action] = 6 - self.body_action = new_body_action - # 发送动作更新 - await self.send_action_update() - - self.regression_count += 1 - self.last_change_time = message_time - except Exception as e: - logger.error(f"regress_action error: {e}") - - # 新增:冷却池维护方法 - def _update_body_action_cooldown(self): - remove_keys = [] - for k in self.body_action_cooldown: - self.body_action_cooldown[k] -= 1 - if self.body_action_cooldown[k] <= 0: - remove_keys.append(k) - for k in remove_keys: - del self.body_action_cooldown[k] - - -class ActionRegressionTask(AsyncTask): - def __init__(self, action_manager: "ActionManager"): - super().__init__(task_name="ActionRegressionTask", run_interval=3) - self.action_manager = action_manager - - async def run(self): - logger.debug("Running action regression task...") - now = time.time() - for action_state in self.action_manager.action_state_list: - if action_state.last_change_time == 0: - continue - - if now - action_state.last_change_time > 10: - if action_state.regression_count >= 3: - continue - - logger.info(f"chat {action_state.chat_id} 开始动作回归, 这是第 {action_state.regression_count + 1} 次") - await action_state.regress_action() - - -class ActionManager: - def __init__(self): - self.action_state_list: list[ChatAction] = [] - """当前动作状态""" - self.task_started: bool = False - - async def start(self): - """启动动作回归后台任务""" - if self.task_started: - return - - logger.info("启动动作回归任务...") - task = ActionRegressionTask(self) - await async_task_manager.add_task(task) - self.task_started = True - logger.info("动作回归任务已启动") - - def get_action_state_by_chat_id(self, chat_id: str) -> ChatAction: - for action_state in self.action_state_list: - if action_state.chat_id == chat_id: - return action_state - - new_action_state = ChatAction(chat_id) - self.action_state_list.append(new_action_state) - return new_action_state - - -init_prompt() - -action_manager = ActionManager() -"""全局动作管理器""" diff --git a/src/mais4u/mais4u_chat/context_web_manager.py b/src/mais4u/mais4u_chat/context_web_manager.py deleted file mode 100644 index e615b88b0..000000000 --- a/src/mais4u/mais4u_chat/context_web_manager.py +++ /dev/null @@ -1,692 +0,0 @@ -import asyncio -from collections import deque -from datetime import datetime - -import aiohttp_cors -import orjson -from aiohttp import WSMsgType, web - -from src.chat.message_receive.message import MessageRecv -from src.common.logger import get_logger - -logger = get_logger("context_web") - - -class ContextMessage: - """上下文消息类""" - - def __init__(self, message: MessageRecv): - self.user_name = message.message_info.user_info.user_nickname - self.user_id = message.message_info.user_info.user_id - self.content = message.processed_plain_text - self.timestamp = datetime.now() - self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊" - - # 识别消息类型 - self.is_gift = getattr(message, "is_gift", False) - self.is_superchat = getattr(message, "is_superchat", False) - - # 添加礼物和SC相关信息 - if self.is_gift: - self.gift_name = getattr(message, "gift_name", "") - self.gift_count = getattr(message, "gift_count", "1") - self.content = f"送出了 {self.gift_name} x{self.gift_count}" - elif self.is_superchat: - self.superchat_price = getattr(message, "superchat_price", "0") - self.superchat_message = getattr(message, "superchat_message_text", "") - if self.superchat_message: - self.content = f"[¥{self.superchat_price}] {self.superchat_message}" - else: - self.content = f"[¥{self.superchat_price}] {self.content}" - - def to_dict(self): - return { - "user_name": self.user_name, - "user_id": self.user_id, - "content": self.content, - "timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"), - "group_name": self.group_name, - "is_gift": self.is_gift, - "is_superchat": self.is_superchat, - } - - -class ContextWebManager: - """上下文网页管理器""" - - def __init__(self, max_messages: int = 10, port: int = 8765): - self.max_messages = max_messages - self.port = port - self.contexts: dict[str, deque] = {} # chat_id -> deque of ContextMessage - self.websockets: list[web.WebSocketResponse] = [] - self.app = None - self.runner = None - self.site = None - self._server_starting = False # 添加启动标志防止并发 - - async def start_server(self): - """启动web服务器""" - if self.site is not None: - logger.debug("Web服务器已经启动,跳过重复启动") - return - - if self._server_starting: - logger.debug("Web服务器正在启动中,等待启动完成...") - # 等待启动完成 - while self._server_starting and self.site is None: - await asyncio.sleep(0.1) - return - - self._server_starting = True - - try: - self.app = web.Application() - - # 设置CORS - cors = aiohttp_cors.setup( - self.app, - defaults={ - "*": aiohttp_cors.ResourceOptions( - allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*" - ) - }, - ) - - # 添加路由 - self.app.router.add_get("/", self.index_handler) - self.app.router.add_get("/ws", self.websocket_handler) - self.app.router.add_get("/api/contexts", self.get_contexts_handler) - self.app.router.add_get("/debug", self.debug_handler) - - # 为所有路由添加CORS - for route in list(self.app.router.routes()): - cors.add(route) - - self.runner = web.AppRunner(self.app) - await self.runner.setup() - - self.site = web.TCPSite(self.runner, "localhost", self.port) - await self.site.start() - - logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}") - - except Exception as e: - logger.error(f"❌ 启动Web服务器失败: {e}") - # 清理部分启动的资源 - if self.runner: - await self.runner.cleanup() - self.app = None - self.runner = None - self.site = None - raise - finally: - self._server_starting = False - - async def stop_server(self): - """停止web服务器""" - if self.site: - await self.site.stop() - if self.runner: - await self.runner.cleanup() - self.app = None - self.runner = None - self.site = None - self._server_starting = False - - async def index_handler(self, request): - """主页处理器""" - html_content = ( - """ - - - - - 聊天上下文 - - - -
- 🔧 调试 -
-
暂无消息
-
-
- - - - - """ - ) - return web.Response(text=html_content, content_type="text/html") - - async def websocket_handler(self, request): - """WebSocket处理器""" - ws = web.WebSocketResponse() - await ws.prepare(request) - - self.websockets.append(ws) - logger.debug(f"WebSocket连接建立,当前连接数: {len(self.websockets)}") - - # 发送初始数据 - await self.send_contexts_to_websocket(ws) - - async for msg in ws: - if msg.type == WSMsgType.ERROR: - logger.error(f"WebSocket错误: {ws.exception()}") - break - - # 清理断开的连接 - if ws in self.websockets: - self.websockets.remove(ws) - logger.debug(f"WebSocket连接断开,当前连接数: {len(self.websockets)}") - - return ws - - async def get_contexts_handler(self, request): - """获取上下文API""" - all_context_msgs = [] - for contexts in self.contexts.values(): - all_context_msgs.extend(list(contexts)) - - # 按时间排序,最新的在最后 - all_context_msgs.sort(key=lambda x: x.timestamp) - - # 转换为字典格式 - contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]] - - logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息") - return web.json_response({"contexts": contexts_data}) - - async def debug_handler(self, request): - """调试信息处理器""" - debug_info = { - "server_status": "running", - "websocket_connections": len(self.websockets), - "total_chats": len(self.contexts), - "total_messages": sum(len(contexts) for contexts in self.contexts.values()), - } - - # 构建聊天详情HTML - chats_html = "" - for chat_id, contexts in self.contexts.items(): - messages_html = "" - for msg in contexts: - timestamp = msg.timestamp.strftime("%H:%M:%S") - content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content - messages_html += f'
[{timestamp}] {msg.user_name}: {content}
' - - chats_html += f""" -
-

聊天 {chat_id} ({len(contexts)} 条消息)

- {messages_html} -
- """ - - html_content = f""" - - - - - 调试信息 - - - -

上下文网页管理器调试信息

- -
-

服务器状态

-

状态: {debug_info["server_status"]}

-

WebSocket连接数: {debug_info["websocket_connections"]}

-

聊天总数: {debug_info["total_chats"]}

-

消息总数: {debug_info["total_messages"]}

-
- -
-

聊天详情

- {chats_html} -
- -
-

操作

- - - -
- - - - - """ - - return web.Response(text=html_content, content_type="text/html") - - async def add_message(self, chat_id: str, message: MessageRecv): - """添加新消息到上下文""" - if chat_id not in self.contexts: - self.contexts[chat_id] = deque(maxlen=self.max_messages) - logger.debug(f"为聊天 {chat_id} 创建新的上下文队列") - - context_msg = ContextMessage(message) - self.contexts[chat_id].append(context_msg) - - # 统计当前总消息数 - total_messages = sum(len(contexts) for contexts in self.contexts.values()) - - logger.info( - f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}" - ) - - # 调试:打印当前所有消息 - logger.info("📝 当前上下文中的所有消息:") - for cid, contexts in self.contexts.items(): - logger.info(f" 聊天 {cid}: {len(contexts)} 条消息") - for i, msg in enumerate(contexts): - logger.info( - f" {i + 1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}..." - ) - - # 广播更新给所有WebSocket连接 - await self.broadcast_contexts() - - async def send_contexts_to_websocket(self, ws: web.WebSocketResponse): - """向单个WebSocket发送上下文数据""" - all_context_msgs = [] - for contexts in self.contexts.values(): - all_context_msgs.extend(list(contexts)) - - # 按时间排序,最新的在最后 - all_context_msgs.sort(key=lambda x: x.timestamp) - - # 转换为字典格式 - contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]] - - data = {"contexts": contexts_data} - await ws.send_str(orjson.dumps(data).decode("utf-8")) - - async def broadcast_contexts(self): - """向所有WebSocket连接广播上下文更新""" - if not self.websockets: - logger.debug("没有WebSocket连接,跳过广播") - return - - all_context_msgs = [] - for contexts in self.contexts.values(): - all_context_msgs.extend(list(contexts)) - - # 按时间排序,最新的在最后 - all_context_msgs.sort(key=lambda x: x.timestamp) - - # 转换为字典格式 - contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]] - - data = {"contexts": contexts_data} - message = orjson.dumps(data).decode("utf-8") - - logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接") - - # 创建WebSocket列表的副本,避免在遍历时修改 - websockets_copy = self.websockets.copy() - removed_count = 0 - - for ws in websockets_copy: - if ws.closed: - if ws in self.websockets: - self.websockets.remove(ws) - removed_count += 1 - else: - try: - await ws.send_str(message) - logger.debug("消息发送成功") - except Exception as e: - logger.error(f"发送WebSocket消息失败: {e}") - if ws in self.websockets: - self.websockets.remove(ws) - removed_count += 1 - - if removed_count > 0: - logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接") - - -# 全局实例 -_context_web_manager: ContextWebManager | None = None - - -def get_context_web_manager() -> ContextWebManager: - """获取上下文网页管理器实例""" - global _context_web_manager - if _context_web_manager is None: - _context_web_manager = ContextWebManager() - return _context_web_manager - - -async def init_context_web_manager(): - """初始化上下文网页管理器""" - manager = get_context_web_manager() - await manager.start_server() - return manager diff --git a/src/mais4u/mais4u_chat/gift_manager.py b/src/mais4u/mais4u_chat/gift_manager.py deleted file mode 100644 index 976476225..000000000 --- a/src/mais4u/mais4u_chat/gift_manager.py +++ /dev/null @@ -1,147 +0,0 @@ -import asyncio -from collections.abc import Callable -from dataclasses import dataclass - -from src.chat.message_receive.message import MessageRecvS4U -from src.common.logger import get_logger - -logger = get_logger("gift_manager") - - -@dataclass -class PendingGift: - """等待中的礼物消息""" - - message: MessageRecvS4U - total_count: int - timer_task: asyncio.Task - callback: Callable[[MessageRecvS4U], None] - - -class GiftManager: - """礼物管理器,提供防抖功能""" - - def __init__(self): - """初始化礼物管理器""" - self.pending_gifts: dict[tuple[str, str], PendingGift] = {} - self.debounce_timeout = 5.0 # 3秒防抖时间 - - async def handle_gift( - self, message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None = None - ) -> bool: - """处理礼物消息,返回是否应该立即处理 - - Args: - message: 礼物消息 - callback: 防抖完成后的回调函数 - - Returns: - bool: False表示消息被暂存等待防抖,True表示应该立即处理 - """ - if not message.is_gift: - return True - - # 构建礼物的唯一键:(发送人ID, 礼物名称) - gift_key = (message.message_info.user_info.user_id, message.gift_name) - - # 如果已经有相同的礼物在等待中,则合并 - if gift_key in self.pending_gifts: - await self._merge_gift(gift_key, message) - return False - - # 创建新的等待礼物 - await self._create_pending_gift(gift_key, message, callback) - return False - - async def _merge_gift(self, gift_key: tuple[str, str], new_message: MessageRecvS4U) -> None: - """合并礼物消息""" - pending_gift = self.pending_gifts[gift_key] - - # 取消之前的定时器 - if not pending_gift.timer_task.cancelled(): - pending_gift.timer_task.cancel() - - # 累加礼物数量 - try: - new_count = int(new_message.gift_count) - pending_gift.total_count += new_count - - # 更新消息为最新的(保留最新的消息,但累加数量) - pending_gift.message = new_message - pending_gift.message.gift_count = str(pending_gift.total_count) - pending_gift.message.gift_info = f"{pending_gift.message.gift_name}:{pending_gift.total_count}" - - except ValueError: - logger.warning(f"无法解析礼物数量: {new_message.gift_count}") - # 如果无法解析数量,保持原有数量不变 - - # 重新创建定时器 - pending_gift.timer_task = asyncio.create_task(self._gift_timeout(gift_key)) - - logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}") - - async def _create_pending_gift( - self, gift_key: tuple[str, str], message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None - ) -> None: - """创建新的等待礼物""" - try: - initial_count = int(message.gift_count) - except ValueError: - initial_count = 1 - logger.warning(f"无法解析礼物数量: {message.gift_count},默认设为1") - - # 创建定时器任务 - timer_task = asyncio.create_task(self._gift_timeout(gift_key)) - - # 创建等待礼物对象 - pending_gift = PendingGift(message=message, total_count=initial_count, timer_task=timer_task, callback=callback) - - self.pending_gifts[gift_key] = pending_gift - - logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}") - - async def _gift_timeout(self, gift_key: tuple[str, str]) -> None: - """礼物防抖超时处理""" - try: - # 等待防抖时间 - await asyncio.sleep(self.debounce_timeout) - - # 获取等待中的礼物 - if gift_key not in self.pending_gifts: - return - - pending_gift = self.pending_gifts.pop(gift_key) - - logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {pending_gift.total_count}") - - message = pending_gift.message - message.processed_plain_text = f"用户{message.message_info.user_info.user_nickname}送出了礼物{message.gift_name} x{pending_gift.total_count}" - - # 执行回调 - if pending_gift.callback: - try: - pending_gift.callback(message) - except Exception as e: - logger.error(f"礼物回调执行失败: {e}", exc_info=True) - - except asyncio.CancelledError: - # 定时器被取消,不需要处理 - pass - except Exception as e: - logger.error(f"礼物防抖处理异常: {e}", exc_info=True) - - def get_pending_count(self) -> int: - """获取当前等待中的礼物数量""" - return len(self.pending_gifts) - - async def flush_all(self) -> None: - """立即处理所有等待中的礼物""" - for gift_key in list(self.pending_gifts.keys()): - pending_gift = self.pending_gifts.get(gift_key) - if pending_gift and not pending_gift.timer_task.cancelled(): - pending_gift.timer_task.cancel() - await self._gift_timeout(gift_key) - - -# 创建全局礼物管理器实例 -gift_manager = GiftManager() diff --git a/src/mais4u/mais4u_chat/internal_manager.py b/src/mais4u/mais4u_chat/internal_manager.py deleted file mode 100644 index 3e4a518d4..000000000 --- a/src/mais4u/mais4u_chat/internal_manager.py +++ /dev/null @@ -1,15 +0,0 @@ -class InternalManager: - def __init__(self): - self.now_internal_state = "" - - def set_internal_state(self, internal_state: str): - self.now_internal_state = internal_state - - def get_internal_state(self): - return self.now_internal_state - - def get_internal_state_str(self): - return f"你今天的直播内容是直播QQ水群,你正在一边回复弹幕,一边在QQ群聊天,你在QQ群聊天中产生的想法是:{self.now_internal_state}" - - -internal_manager = InternalManager() diff --git a/src/mais4u/mais4u_chat/s4u_chat.py b/src/mais4u/mais4u_chat/s4u_chat.py deleted file mode 100644 index 919e7e60c..000000000 --- a/src/mais4u/mais4u_chat/s4u_chat.py +++ /dev/null @@ -1,611 +0,0 @@ -import asyncio -import random -import time -import traceback - -import orjson -from maim_message import Seg, UserInfo - -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from src.chat.message_receive.message import MessageRecv, MessageRecvS4U, MessageSending -from src.chat.message_receive.storage import MessageStorage -from src.common.logger import get_logger -from src.common.message.api import get_global_api -from src.config.config import global_config -from src.mais4u.constant_s4u import ENABLE_S4U -from src.mais4u.s4u_config import s4u_config -from src.person_info.person_info import PersonInfoManager -from src.person_info.relationship_builder_manager import relationship_builder_manager - -from .s4u_mood_manager import mood_manager -from .s4u_stream_generator import S4UStreamGenerator -from .s4u_watching_manager import watching_manager -from .super_chat_manager import get_super_chat_manager -from .yes_or_no import yes_or_no_head - -logger = get_logger("S4U_chat") - - -class MessageSenderContainer: - """一个简单的容器,用于按顺序发送消息并模拟打字效果。""" - - def __init__(self, chat_stream: ChatStream, original_message: MessageRecv): - self.chat_stream = chat_stream - self.original_message = original_message - self.queue = asyncio.Queue() - self.storage = MessageStorage() - self._task: asyncio.Task | None = None - self._paused_event = asyncio.Event() - self._paused_event.set() # 默认设置为非暂停状态 - - self.msg_id = "" - - self.last_msg_id = "" - - self.voice_done = "" - - async def add_message(self, chunk: str): - """向队列中添加一个消息块。""" - await self.queue.put(chunk) - - async def close(self): - """表示没有更多消息了,关闭队列。""" - await self.queue.put(None) # Sentinel - - def pause(self): - """暂停发送。""" - self._paused_event.clear() - - def resume(self): - """恢复发送。""" - self._paused_event.set() - - @staticmethod - def _calculate_typing_delay(text: str) -> float: - """根据文本长度计算模拟打字延迟。""" - chars_per_second = s4u_config.chars_per_second - min_delay = s4u_config.min_typing_delay - max_delay = s4u_config.max_typing_delay - - delay = len(text) / chars_per_second - return max(min_delay, min(delay, max_delay)) - - async def _send_worker(self): - """从队列中取出消息并发送。""" - while True: - try: - # This structure ensures that task_done() is called for every item retrieved, - # even if the worker is cancelled while processing the item. - chunk = await self.queue.get() - except asyncio.CancelledError: - break - - try: - if chunk is None: - break - - # Check for pause signal *after* getting an item. - await self._paused_event.wait() - - # 根据配置选择延迟模式 - if s4u_config.enable_dynamic_typing_delay: - delay = self._calculate_typing_delay(chunk) - else: - delay = s4u_config.typing_delay - await asyncio.sleep(delay) - - message_segment = Seg(type="tts_text", data=f"{self.msg_id}:{chunk}") - bot_message = MessageSending( - message_id=self.msg_id, - chat_stream=self.chat_stream, - bot_user_info=UserInfo( - user_id=global_config.bot.qq_account, - user_nickname=global_config.bot.nickname, - platform=self.original_message.message_info.platform, - ), - sender_info=self.original_message.message_info.user_info, - message_segment=message_segment, - reply=self.original_message, - is_emoji=False, - apply_set_reply_logic=True, - reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}", - ) - - await bot_message.process() - - await get_global_api().send_message(bot_message) - logger.info(f"已将消息 '{self.msg_id}:{chunk}' 发往平台 '{bot_message.message_info.platform}'") - - message_segment = Seg(type="text", data=chunk) - bot_message = MessageSending( - message_id=self.msg_id, - chat_stream=self.chat_stream, - bot_user_info=UserInfo( - user_id=global_config.bot.qq_account, - user_nickname=global_config.bot.nickname, - platform=self.original_message.message_info.platform, - ), - sender_info=self.original_message.message_info.user_info, - message_segment=message_segment, - reply=self.original_message, - is_emoji=False, - apply_set_reply_logic=True, - reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}", - ) - await bot_message.process() - - await self.storage.store_message(bot_message, self.chat_stream) - - except Exception as e: - logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True) - - finally: - # CRUCIAL: Always call task_done() for any item that was successfully retrieved. - self.queue.task_done() - - def start(self): - """启动发送任务。""" - if self._task is None: - self._task = asyncio.create_task(self._send_worker()) - - async def join(self): - """等待所有消息发送完毕。""" - if self._task: - await self._task - - @property - def task(self): - return self._task - - -class S4UChatManager: - def __init__(self): - self.s4u_chats: dict[str, "S4UChat"] = {} - - async def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat": - if chat_stream.stream_id not in self.s4u_chats: - stream_name = await get_chat_manager().get_stream_name(chat_stream.stream_id) or chat_stream.stream_id - logger.info(f"Creating new S4UChat for stream: {stream_name}") - self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream) - return self.s4u_chats[chat_stream.stream_id] - - -if not ENABLE_S4U: - s4u_chat_manager = None -else: - s4u_chat_manager = S4UChatManager() - - -def get_s4u_chat_manager() -> S4UChatManager: - return s4u_chat_manager - - -class S4UChat: - def __init__(self, chat_stream: ChatStream): - """初始化 S4UChat 实例。""" - - self.last_msg_id = self.msg_id - self.chat_stream = chat_stream - self.stream_id = chat_stream.stream_id - self.stream_name = self.stream_id # 初始化时使用stream_id,稍后异步更新 - self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id) - - # 两个消息队列 - self._vip_queue = asyncio.PriorityQueue() - self._normal_queue = asyncio.PriorityQueue() - - self._entry_counter = 0 # 保证FIFO的全局计数器 - self._new_message_event = asyncio.Event() # 用于唤醒处理器 - - self._processing_task = asyncio.create_task(self._message_processor()) - self._current_generation_task: asyncio.Task | None = None - # 当前消息的元数据:(队列类型, 优先级分数, 计数器, 消息对象) - self._current_message_being_replied: tuple[str, float, int, MessageRecv] | None = None - - self._is_replying = False - self.gpt = S4UStreamGenerator() - self.gpt.chat_stream = self.chat_stream - self.interest_dict: dict[str, float] = {} # 用户兴趣分 - - self.internal_message: list[MessageRecvS4U] = [] - - self.msg_id = "" - self.voice_done = "" - - logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.") - self._stream_name_initialized = False - - async def _initialize_stream_name(self): - """异步初始化stream_name""" - if not self._stream_name_initialized: - self.stream_name = await get_chat_manager().get_stream_name(self.stream_id) or self.stream_id - self._stream_name_initialized = True - - @staticmethod - def _get_priority_info(message: MessageRecv) -> dict: - """安全地从消息中提取和解析 priority_info""" - priority_info_raw = message.priority_info - priority_info = {} - if isinstance(priority_info_raw, str): - try: - priority_info = orjson.loads(priority_info_raw) - except orjson.JSONDecodeError: - logger.warning(f"Failed to parse priority_info JSON: {priority_info_raw}") - elif isinstance(priority_info_raw, dict): - priority_info = priority_info_raw - return priority_info - - @staticmethod - def _is_vip(priority_info: dict) -> bool: - """检查消息是否来自VIP用户。""" - return priority_info.get("message_type") == "vip" - - def _get_interest_score(self, user_id: str) -> float: - """获取用户的兴趣分,默认为1.0""" - return self.interest_dict.get(user_id, 1.0) - - def go_processing(self): - if self.voice_done == self.last_msg_id: - return True - return False - - def _calculate_base_priority_score(self, message: MessageRecv, priority_info: dict) -> float: - """ - 为消息计算基础优先级分数。分数越高,优先级越高。 - """ - score = 0.0 - - # 加上消息自带的优先级 - score += priority_info.get("message_priority", 0.0) - - # 加上用户的固有兴趣分 - score += self._get_interest_score(message.message_info.user_info.user_id) - return score - - def decay_interest_score(self): - for person_id, score in self.interest_dict.items(): - if score > 0: - self.interest_dict[person_id] = score * 0.95 - else: - self.interest_dict[person_id] = 0 - - async def add_message(self, message: MessageRecvS4U | MessageRecv) -> None: - # 初始化stream_name - await self._initialize_stream_name() - - self.decay_interest_score() - - """根据VIP状态和中断逻辑将消息放入相应队列。""" - user_id = message.message_info.user_info.user_id - platform = message.message_info.platform - person_id = PersonInfoManager.get_person_id(platform, user_id) - - try: - is_gift = message.is_gift - is_superchat = message.is_superchat - # print(is_gift) - # print(is_superchat) - if is_gift: - await self.relationship_builder.build_relation(immediate_build=person_id) - # 安全地增加兴趣分,如果person_id不存在则先初始化为1.0 - current_score = self.interest_dict.get(person_id, 1.0) - self.interest_dict[person_id] = current_score + 0.1 * message.gift_count - elif is_superchat: - await self.relationship_builder.build_relation(immediate_build=person_id) - # 安全地增加兴趣分,如果person_id不存在则先初始化为1.0 - current_score = self.interest_dict.get(person_id, 1.0) - self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price) - - # 添加SuperChat到管理器 - super_chat_manager = get_super_chat_manager() - await super_chat_manager.add_superchat(message) - else: - await self.relationship_builder.build_relation(20) - except Exception: - traceback.print_exc() - - logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}") - - priority_info = self._get_priority_info(message) - is_vip = self._is_vip(priority_info) - new_priority_score = self._calculate_base_priority_score(message, priority_info) - - should_interrupt = False - if ( - s4u_config.enable_message_interruption - and self._current_generation_task - and not self._current_generation_task.done() - ): - if self._current_message_being_replied: - current_queue, current_priority, _, current_msg = self._current_message_being_replied - - # 规则:VIP从不被打断 - if current_queue == "vip": - pass # Do nothing - - # 规则:普通消息可以被打断 - elif current_queue == "normal": - # VIP消息可以打断普通消息 - if is_vip: - should_interrupt = True - logger.info(f"[{self.stream_name}] VIP message received, interrupting current normal task.") - # 普通消息的内部打断逻辑 - else: - new_sender_id = message.message_info.user_info.user_id - current_sender_id = current_msg.message_info.user_info.user_id - # 新消息优先级更高 - if new_priority_score > current_priority: - should_interrupt = True - logger.info(f"[{self.stream_name}] New normal message has higher priority, interrupting.") - # 同用户,新消息的优先级不能更低 - elif new_sender_id == current_sender_id and new_priority_score >= current_priority: - should_interrupt = True - logger.info(f"[{self.stream_name}] Same user sent new message, interrupting.") - - if should_interrupt: - if self.gpt.partial_response: - logger.warning( - f"[{self.stream_name}] Interrupting reply. Already generated: '{self.gpt.partial_response}'" - ) - self._current_generation_task.cancel() - - # asyncio.PriorityQueue 是最小堆,所以我们存入分数的相反数 - # 这样,原始分数越高的消息,在队列中的优先级数字越小,越靠前 - item = (-new_priority_score, self._entry_counter, time.time(), message) - - if is_vip and s4u_config.vip_queue_priority: - await self._vip_queue.put(item) - logger.info(f"[{self.stream_name}] VIP message added to queue.") - else: - await self._normal_queue.put(item) - - self._entry_counter += 1 - self._new_message_event.set() # 唤醒处理器 - - def _cleanup_old_normal_messages(self): - """清理普通队列中不在最近N条消息范围内的消息""" - if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty(): - return - - # 计算阈值:保留最近 recent_message_keep_count 条消息 - cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count) - - # 临时存储需要保留的消息 - temp_messages = [] - removed_count = 0 - - # 取出所有普通队列中的消息 - while not self._normal_queue.empty(): - try: - item = self._normal_queue.get_nowait() - neg_priority, entry_count, timestamp, message = item - - # 如果消息在最近N条消息范围内,保留它 - logger.info( - f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}" - ) - - if entry_count >= cutoff_counter: - temp_messages.append(item) - else: - removed_count += 1 - self._normal_queue.task_done() # 标记被移除的任务为完成 - - except asyncio.QueueEmpty: - break - - # 将保留的消息重新放入队列 - for item in temp_messages: - self._normal_queue.put_nowait(item) - - if removed_count > 0: - logger.info( - f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}条,现在counter:{self._entry_counter}被移除" - ) - logger.info( - f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range." - ) - - async def _message_processor(self): - """调度器:优先处理VIP队列,然后处理普通队列。""" - while True: - try: - # 等待有新消息的信号,避免空转 - await self._new_message_event.wait() - self._new_message_event.clear() - - # 清理普通队列中的过旧消息 - self._cleanup_old_normal_messages() - - # 优先处理VIP队列 - if not self._vip_queue.empty(): - neg_priority, entry_count, _, message = self._vip_queue.get_nowait() - priority = -neg_priority - queue_name = "vip" - # 其次处理普通队列 - elif not self._normal_queue.empty(): - neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait() - priority = -neg_priority - # 检查普通消息是否超时 - if time.time() - timestamp > s4u_config.message_timeout_seconds: - logger.info( - f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}..." - ) - self._normal_queue.task_done() - continue # 处理下一条 - queue_name = "normal" - else: - if self.internal_message: - message = self.internal_message[-1] - self.internal_message = [] - - priority = 0 - neg_priority = 0 - entry_count = 0 - queue_name = "internal" - - logger.info( - f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}..." - ) - else: - continue # 没有消息了,回去等事件 - - self._current_message_being_replied = (queue_name, priority, entry_count, message) - self._current_generation_task = asyncio.create_task(self._generate_and_send(message)) - - try: - await self._current_generation_task - except asyncio.CancelledError: - logger.info( - f"[{self.stream_name}] Reply generation was interrupted externally for {queue_name} message. The message will be discarded." - ) - # 被中断的消息应该被丢弃,而不是重新排队,以响应最新的用户输入。 - # 旧的重新入队逻辑会导致所有中断的消息最终都被回复。 - - except Exception as e: - logger.error(f"[{self.stream_name}] _generate_and_send task error: {e}", exc_info=True) - finally: - self._current_generation_task = None - self._current_message_being_replied = None - # 标记任务完成 - if queue_name == "vip": - self._vip_queue.task_done() - elif queue_name == "internal": - # 如果使用 internal_message 生成回复,则不从 normal 队列中移除 - pass - else: - self._normal_queue.task_done() - - # 检查是否还有任务,有则立即再次触发事件 - if not self._vip_queue.empty() or not self._normal_queue.empty(): - self._new_message_event.set() - - except asyncio.CancelledError: - logger.info(f"[{self.stream_name}] Message processor is shutting down.") - break - except Exception as e: - logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True) - await asyncio.sleep(1) - - def get_processing_message_id(self): - self.msg_id = f"{time.time()}_{random.randint(1000, 9999)}" - - async def _generate_and_send(self, message: MessageRecv): - """为单个消息生成文本回复。整个过程可以被中断。""" - self._is_replying = True - total_chars_sent = 0 # 跟踪发送的总字符数 - - self.get_processing_message_id() - - # 视线管理:开始生成回复时切换视线状态 - chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id) - - if message.is_internal: - await chat_watching.on_internal_message_start() - else: - await chat_watching.on_reply_start() - - sender_container = MessageSenderContainer(self.chat_stream, message) - sender_container.start() - - async def generate_and_send_inner(): - nonlocal total_chars_sent - logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'") - - if s4u_config.enable_streaming_output: - logger.info("[S4U] 开始流式输出") - # 流式输出,边生成边发送 - gen = self.gpt.generate_response(message, "") - async for chunk in gen: - sender_container.msg_id = self.msg_id - await sender_container.add_message(chunk) - total_chars_sent += len(chunk) - else: - logger.info("[S4U] 开始一次性输出") - # 一次性输出,先收集所有chunk - all_chunks = [] - gen = self.gpt.generate_response(message, "") - async for chunk in gen: - all_chunks.append(chunk) - total_chars_sent += len(chunk) - # 一次性发送 - sender_container.msg_id = self.msg_id - await sender_container.add_message("".join(all_chunks)) - - try: - try: - await asyncio.wait_for(generate_and_send_inner(), timeout=10) - except asyncio.TimeoutError: - logger.warning(f"[{self.stream_name}] 回复生成超时,发送默认回复。") - sender_container.msg_id = self.msg_id - await sender_container.add_message("麦麦不知道哦") - total_chars_sent = len("麦麦不知道哦") - - mood = mood_manager.get_mood_by_chat_id(self.stream_id) - await yes_or_no_head( - text=total_chars_sent, - emotion=mood.mood_state, - chat_history=message.processed_plain_text, - chat_id=self.stream_id, - ) - - # 等待所有文本消息发送完成 - await sender_container.close() - await sender_container.join() - - await chat_watching.on_thinking_finished() - - start_time = time.time() - logged = False - while not self.go_processing(): - if time.time() - start_time > 60: - logger.warning(f"[{self.stream_name}] 等待消息发送超时(60秒),强制跳出循环。") - break - if not logged: - logger.info(f"[{self.stream_name}] 等待消息发送完成...") - logged = True - await asyncio.sleep(0.2) - - logger.info(f"[{self.stream_name}] 所有文本块处理完毕。") - - except asyncio.CancelledError: - logger.info(f"[{self.stream_name}] 回复流程(文本)被中断。") - raise # 将取消异常向上传播 - except Exception as e: - traceback.print_exc() - logger.error(f"[{self.stream_name}] 回复生成过程中出现错误: {e}", exc_info=True) - # 回复生成实时展示:清空内容(出错时) - finally: - self._is_replying = False - - # 视线管理:回复结束时切换视线状态 - chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id) - await chat_watching.on_reply_finished() - - # 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的) - sender_container.resume() - if not sender_container.task.done(): - await sender_container.close() - await sender_container.join() - logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。") - - async def shutdown(self): - """平滑关闭处理任务。""" - logger.info(f"正在关闭 S4UChat: {self.stream_name}") - - # 取消正在运行的任务 - if self._current_generation_task and not self._current_generation_task.done(): - self._current_generation_task.cancel() - - if self._processing_task and not self._processing_task.done(): - self._processing_task.cancel() - - # 等待任务响应取消 - try: - await self._processing_task - except asyncio.CancelledError: - logger.info(f"处理任务已成功取消: {self.stream_name}") - - @property - def new_message_event(self): - return self._new_message_event diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py deleted file mode 100644 index 2031f7c56..000000000 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ /dev/null @@ -1,458 +0,0 @@ -import asyncio -import time - -import orjson - -from src.chat.message_receive.message import MessageRecv -from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest -from src.mais4u.constant_s4u import ENABLE_S4U -from src.manager.async_task_manager import AsyncTask, async_task_manager -from src.plugin_system.apis import send_api - -""" -情绪管理系统使用说明: - -1. 情绪数值系统: - - 情绪包含四个维度:joy(喜), anger(怒), sorrow(哀), fear(惧) - - 每个维度的取值范围为1-10 - - 当情绪发生变化时,会自动发送到ws端处理 - -2. 情绪更新机制: - - 接收到新消息时会更新情绪状态 - - 定期进行情绪回归(冷静下来) - - 每次情绪变化都会发送到ws端,格式为: - type: "emotion" - data: {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1} - -3. ws端处理: - - 本地只负责情绪计算和发送情绪数值 - - 表情渲染和动作由ws端根据情绪数值处理 -""" - -logger = get_logger("mood") - - -def init_prompt(): - Prompt( - """ -{chat_talking_prompt} -以上是直播间里正在进行的对话 - -{indentify_block} -你刚刚的情绪状态是:{mood_state} - -现在,发送了消息,引起了你的注意,你对其进行了阅读和思考,请你输出一句话描述你新的情绪状态,不要输出任何其他内容 -请只输出情绪状态,不要输出其他内容: -""", - "change_mood_prompt_vtb", - ) - Prompt( - """ -{chat_talking_prompt} -以上是直播间里最近的对话 - -{indentify_block} -你之前的情绪状态是:{mood_state} - -距离你上次关注直播间消息已经过去了一段时间,你冷静了下来,请你输出一句话描述你现在的情绪状态 -请只输出情绪状态,不要输出其他内容: -""", - "regress_mood_prompt_vtb", - ) - Prompt( - """ -{chat_talking_prompt} -以上是直播间里正在进行的对话 - -{indentify_block} -你刚刚的情绪状态是:{mood_state} -具体来说,从1-10分,你的情绪状态是: -喜(Joy): {joy} -怒(Anger): {anger} -哀(Sorrow): {sorrow} -惧(Fear): {fear} - -现在,发送了消息,引起了你的注意,你对其进行了阅读和思考。请基于对话内容,评估你新的情绪状态。 -请以JSON格式输出你新的情绪状态,包含"喜怒哀惧"四个维度,每个维度的取值范围为1-10。 -键值请使用英文: "joy", "anger", "sorrow", "fear". -例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}} -不要输出任何其他内容,只输出JSON。 -""", - "change_mood_numerical_prompt", - ) - Prompt( - """ -{chat_talking_prompt} -以上是直播间里最近的对话 - -{indentify_block} -你之前的情绪状态是:{mood_state} -具体来说,从1-10分,你的情绪状态是: -喜(Joy): {joy} -怒(Anger): {anger} -哀(Sorrow): {sorrow} -惧(Fear): {fear} - -距离你上次关注直播间消息已经过去了一段时间,你冷静了下来。请基于此,评估你现在的情绪状态。 -请以JSON格式输出你新的情绪状态,包含"喜怒哀惧"四个维度,每个维度的取值范围为1-10。 -键值请使用英文: "joy", "anger", "sorrow", "fear". -例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}} -不要输出任何其他内容,只输出JSON。 -""", - "regress_mood_numerical_prompt", - ) - - -class ChatMood: - def __init__(self, chat_id: str): - self.chat_id: str = chat_id - self.mood_state: str = "感觉很平静" - self.mood_values: dict[str, int] = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1} - - self.regression_count: int = 0 - - self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text") - self.mood_model_numerical = LLMRequest( - model_set=model_config.model_task_config.emotion, request_type="mood_numerical" - ) - - self.last_change_time: float = 0 - - # 发送初始情绪状态到ws端 - asyncio.create_task(self.send_emotion_update(self.mood_values)) - - @staticmethod - def _parse_numerical_mood(response: str) -> dict[str, int] | None: - try: - # The LLM might output markdown with json inside - if "```json" in response: - response = response.split("```json")[1].split("```")[0] - elif "```" in response: - response = response.split("```")[1].split("```")[0] - - data = orjson.loads(response) - - # Validate - required_keys = {"joy", "anger", "sorrow", "fear"} - if not required_keys.issubset(data.keys()): - logger.warning(f"Numerical mood response missing keys: {response}") - return None - - for key in required_keys: - value = data[key] - if not isinstance(value, int) or not (1 <= value <= 10): - logger.warning(f"Numerical mood response invalid value for {key}: {value} in {response}") - return None - - return {key: data[key] for key in required_keys} - - except orjson.JSONDecodeError: - logger.warning(f"Failed to parse numerical mood JSON: {response}") - return None - except Exception as e: - logger.error(f"Error parsing numerical mood: {e}, response: {response}") - return None - - async def update_mood_by_message(self, message: MessageRecv): - self.regression_count = 0 - - message_time: float = message.message_info.time # type: ignore - message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( - chat_id=self.chat_id, - timestamp_start=self.last_change_time, - timestamp_end=message_time, - limit=10, - limit_mode="last", - ) - chat_talking_prompt = await build_readable_messages( - message_list_before_now, - replace_bot_name=True, - merge_messages=False, - timestamp_mode="normal_no_YMD", - read_mark=0.0, - truncate=True, - show_actions=True, - ) - - bot_name = global_config.bot.nickname - if global_config.bot.alias_names: - bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}" - else: - bot_nickname = "" - - prompt_personality = global_config.personality.personality_core - indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" - - async def _update_text_mood(): - prompt = await global_prompt_manager.format_prompt( - "change_mood_prompt_vtb", - chat_talking_prompt=chat_talking_prompt, - indentify_block=indentify_block, - mood_state=self.mood_state, - ) - logger.debug(f"text mood prompt: {prompt}") - response, (reasoning_content, _, _) = await self.mood_model.generate_response_async( - prompt=prompt, temperature=0.7 - ) - logger.info(f"text mood response: {response}") - logger.debug(f"text mood reasoning_content: {reasoning_content}") - return response - - async def _update_numerical_mood(): - prompt = await global_prompt_manager.format_prompt( - "change_mood_numerical_prompt", - chat_talking_prompt=chat_talking_prompt, - indentify_block=indentify_block, - mood_state=self.mood_state, - joy=self.mood_values["joy"], - anger=self.mood_values["anger"], - sorrow=self.mood_values["sorrow"], - fear=self.mood_values["fear"], - ) - logger.debug(f"numerical mood prompt: {prompt}") - response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async( - prompt=prompt, temperature=0.4 - ) - logger.info(f"numerical mood response: {response}") - logger.debug(f"numerical mood reasoning_content: {reasoning_content}") - return self._parse_numerical_mood(response) - - results = await asyncio.gather(_update_text_mood(), _update_numerical_mood()) - text_mood_response, numerical_mood_response = results - - if text_mood_response: - self.mood_state = text_mood_response - - if numerical_mood_response: - _old_mood_values = self.mood_values.copy() - self.mood_values = numerical_mood_response - - # 发送情绪更新到ws端 - await self.send_emotion_update(self.mood_values) - - logger.info(f"[{self.chat_id}] 情绪变化: {_old_mood_values} -> {self.mood_values}") - - self.last_change_time = message_time - - async def regress_mood(self): - message_time = time.time() - message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( - chat_id=self.chat_id, - timestamp_start=self.last_change_time, - timestamp_end=message_time, - limit=5, - limit_mode="last", - ) - chat_talking_prompt = await build_readable_messages( - message_list_before_now, - replace_bot_name=True, - merge_messages=False, - timestamp_mode="normal_no_YMD", - read_mark=0.0, - truncate=True, - show_actions=True, - ) - - bot_name = global_config.bot.nickname - if global_config.bot.alias_names: - bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}" - else: - bot_nickname = "" - - prompt_personality = global_config.personality.personality_core - indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" - - async def _regress_text_mood(): - prompt = await global_prompt_manager.format_prompt( - "regress_mood_prompt_vtb", - chat_talking_prompt=chat_talking_prompt, - indentify_block=indentify_block, - mood_state=self.mood_state, - ) - logger.debug(f"text regress prompt: {prompt}") - response, (reasoning_content, _, _) = await self.mood_model.generate_response_async( - prompt=prompt, temperature=0.7 - ) - logger.info(f"text regress response: {response}") - logger.debug(f"text regress reasoning_content: {reasoning_content}") - return response - - async def _regress_numerical_mood(): - prompt = await global_prompt_manager.format_prompt( - "regress_mood_numerical_prompt", - chat_talking_prompt=chat_talking_prompt, - indentify_block=indentify_block, - mood_state=self.mood_state, - joy=self.mood_values["joy"], - anger=self.mood_values["anger"], - sorrow=self.mood_values["sorrow"], - fear=self.mood_values["fear"], - ) - logger.debug(f"numerical regress prompt: {prompt}") - response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async( - prompt=prompt, - temperature=0.4, - ) - logger.info(f"numerical regress response: {response}") - logger.debug(f"numerical regress reasoning_content: {reasoning_content}") - return self._parse_numerical_mood(response) - - results = await asyncio.gather(_regress_text_mood(), _regress_numerical_mood()) - text_mood_response, numerical_mood_response = results - - if text_mood_response: - self.mood_state = text_mood_response - - if numerical_mood_response: - _old_mood_values = self.mood_values.copy() - self.mood_values = numerical_mood_response - - # 发送情绪更新到ws端 - await self.send_emotion_update(self.mood_values) - - logger.info(f"[{self.chat_id}] 情绪回归: {_old_mood_values} -> {self.mood_values}") - - self.regression_count += 1 - - async def send_emotion_update(self, mood_values: dict[str, int]): - """发送情绪更新到ws端""" - emotion_data = { - "joy": mood_values.get("joy", 5), - "anger": mood_values.get("anger", 1), - "sorrow": mood_values.get("sorrow", 1), - "fear": mood_values.get("fear", 1), - } - - await send_api.custom_to_stream( - message_type="emotion", - content=emotion_data, - stream_id=self.chat_id, - storage_message=False, - show_log=True, - ) - - logger.info(f"[{self.chat_id}] 发送情绪更新: {emotion_data}") - - -class MoodRegressionTask(AsyncTask): - def __init__(self, mood_manager: "MoodManager"): - super().__init__(task_name="MoodRegressionTask", run_interval=30) - self.mood_manager = mood_manager - self.run_count = 0 - - async def run(self): - self.run_count += 1 - logger.info(f"[回归任务] 第{self.run_count}次检查,当前管理{len(self.mood_manager.mood_list)}个聊天的情绪状态") - - now = time.time() - regression_executed = 0 - - for mood in self.mood_manager.mood_list: - chat_info = f"chat {mood.chat_id}" - - if mood.last_change_time == 0: - logger.debug(f"[回归任务] {chat_info} 尚未有情绪变化,跳过回归") - continue - - time_since_last_change = now - mood.last_change_time - - # 检查是否有极端情绪需要快速回归 - high_emotions = {k: v for k, v in mood.mood_values.items() if v >= 8} - has_extreme_emotion = len(high_emotions) > 0 - - # 回归条件:1. 正常时间间隔(120s) 或 2. 有极端情绪且距上次变化>=30s - should_regress = False - regress_reason = "" - - if time_since_last_change > 120: - should_regress = True - regress_reason = f"常规回归(距上次变化{int(time_since_last_change)}秒)" - elif has_extreme_emotion and time_since_last_change > 30: - should_regress = True - high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()]) - regress_reason = f"极端情绪快速回归({high_emotion_str}, 距上次变化{int(time_since_last_change)}秒)" - - if should_regress: - if mood.regression_count >= 3: - logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归") - continue - - logger.info( - f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)" - ) - await mood.regress_mood() - regression_executed += 1 - else: - if has_extreme_emotion: - remaining_time = 5 - time_since_last_change - high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()]) - logger.debug( - f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}秒" - ) - else: - remaining_time = 120 - time_since_last_change - logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}秒") - - if regression_executed > 0: - logger.info(f"[回归任务] 本次执行了{regression_executed}个聊天的情绪回归") - else: - logger.debug("[回归任务] 本次没有符合回归条件的聊天") - - -class MoodManager: - def __init__(self): - self.mood_list: list[ChatMood] = [] - """当前情绪状态""" - self.task_started: bool = False - - async def start(self): - """启动情绪回归后台任务""" - if self.task_started: - return - - logger.info("启动情绪管理任务...") - - # 启动情绪回归任务 - regression_task = MoodRegressionTask(self) - await async_task_manager.add_task(regression_task) - - self.task_started = True - logger.info("情绪管理任务已启动(情绪回归)") - - def get_mood_by_chat_id(self, chat_id: str) -> ChatMood: - for mood in self.mood_list: - if mood.chat_id == chat_id: - return mood - - new_mood = ChatMood(chat_id) - self.mood_list.append(new_mood) - return new_mood - - def reset_mood_by_chat_id(self, chat_id: str): - for mood in self.mood_list: - if mood.chat_id == chat_id: - mood.mood_state = "感觉很平静" - mood.mood_values = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1} - mood.regression_count = 0 - # 发送重置后的情绪状态到ws端 - asyncio.create_task(mood.send_emotion_update(mood.mood_values)) - return - - # 如果没有找到现有的mood,创建新的 - new_mood = ChatMood(chat_id) - self.mood_list.append(new_mood) - # 发送初始情绪状态到ws端 - asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values)) - - -if ENABLE_S4U: - init_prompt() - mood_manager = MoodManager() -else: - mood_manager = None - -"""全局情绪管理器""" diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py deleted file mode 100644 index 2560f4e1a..000000000 --- a/src/mais4u/mais4u_chat/s4u_msg_processor.py +++ /dev/null @@ -1,282 +0,0 @@ -import asyncio -import math - -from maim_message.message_base import GroupInfo - -from src.chat.message_receive.chat_stream import get_chat_manager - -# 旧的Hippocampus系统已被移除,现在使用增强记忆系统 -# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager -from src.chat.message_receive.message import MessageRecv, MessageRecvS4U -from src.chat.message_receive.storage import MessageStorage -from src.chat.utils.timer_calculator import Timer -from src.chat.utils.utils import is_mentioned_bot_in_message -from src.common.logger import get_logger -from src.config.config import global_config -from src.mais4u.mais4u_chat.body_emotion_action_manager import action_manager -from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager -from src.mais4u.mais4u_chat.gift_manager import gift_manager -from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager -from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager -from src.mais4u.mais4u_chat.screen_manager import screen_manager - -from .s4u_chat import get_s4u_chat_manager - -# from ..message_receive.message_buffer import message_buffer - -logger = get_logger("chat") - - -async def _calculate_interest(message: MessageRecv) -> tuple[float, bool]: - """计算消息的兴趣度 - - Args: - message: 待处理的消息对象 - - Returns: - Tuple[float, bool]: (兴趣度, 是否被提及) - """ - is_mentioned, _ = is_mentioned_bot_in_message(message) - interested_rate = 0.0 - - if global_config.memory.enable_memory: - with Timer("记忆激活"): - # 使用新的统一记忆系统计算兴趣度 - try: - from src.chat.memory_system import get_memory_system - - memory_system = get_memory_system() - enhanced_memories = await memory_system.retrieve_relevant_memories( - query_text=message.processed_plain_text, - user_id=str(message.user_info.user_id), - scope_id=message.chat_id, - limit=5, - ) - - # 基于检索结果计算兴趣度 - if enhanced_memories: - # 有相关记忆,兴趣度基于相似度计算 - max_score = max(getattr(memory, "relevance_score", 0.5) for memory in enhanced_memories) - interested_rate = min(max_score, 1.0) # 限制在0-1之间 - else: - # 没有相关记忆,给予基础兴趣度 - interested_rate = 0.1 - - logger.debug(f"增强记忆系统兴趣度: {interested_rate:.2f}") - - except Exception as e: - logger.warning(f"增强记忆系统兴趣度计算失败: {e}") - interested_rate = 0.1 # 默认基础兴趣度 - - text_len = len(message.processed_plain_text) - # 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算 - # 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%) - - if text_len == 0: - base_interest = 0.01 # 空消息最低兴趣度 - elif text_len <= 5: - # 1-5字符:线性增长 0.01 -> 0.03 - base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4 - elif text_len <= 10: - # 6-10字符:线性增长 0.03 -> 0.06 - base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5 - elif text_len <= 20: - # 11-20字符:线性增长 0.06 -> 0.12 - base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10 - elif text_len <= 30: - # 21-30字符:线性增长 0.12 -> 0.18 - base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10 - elif text_len <= 50: - # 31-50字符:线性增长 0.18 -> 0.22 - base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20 - elif text_len <= 100: - # 51-100字符:线性增长 0.22 -> 0.26 - base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50 - else: - # 100+字符:对数增长 0.26 -> 0.3,增长率递减 - base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901 - - # 确保在范围内 - base_interest = min(max(base_interest, 0.01), 0.3) - - interested_rate += base_interest - - if is_mentioned: - interest_increase_on_mention = 1 - interested_rate += interest_increase_on_mention - - return interested_rate, is_mentioned - - -class S4UMessageProcessor: - """心流处理器,负责处理接收到的消息并计算兴趣度""" - - def __init__(self): - """初始化心流处理器,创建消息存储实例""" - self.storage = MessageStorage() - - async def process_message(self, message: MessageRecvS4U, skip_gift_debounce: bool = False) -> None: - """处理接收到的原始消息数据 - - 主要流程: - 1. 消息解析与初始化 - 2. 消息缓冲处理 - 3. 过滤检查 - 4. 兴趣度计算 - 5. 关系处理 - - Args: - message_data: 原始消息字符串 - """ - - # 1. 消息解析与初始化 - groupinfo = message.message_info.group_info - userinfo = message.message_info.user_info - message_info = message.message_info - - chat = await get_chat_manager().get_or_create_stream( - platform=message_info.platform, - user_info=userinfo, - group_info=groupinfo, - ) - - if await self.handle_internal_message(message): - return - - if await self.hadle_if_voice_done(message): - return - - # 处理礼物消息,如果消息被暂存则停止当前处理流程 - if not skip_gift_debounce and not await self.handle_if_gift(message): - return - await self.check_if_fake_gift(message) - - # 处理屏幕消息 - if await self.handle_screen_message(message): - return - - await self.storage.store_message(message, chat) - - s4u_chat = await get_s4u_chat_manager().get_or_create_chat(chat) - - await s4u_chat.add_message(message) - - _interested_rate, _ = await _calculate_interest(message) - - await mood_manager.start() - - # 一系列llm驱动的前处理 - chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id) - asyncio.create_task(chat_mood.update_mood_by_message(message)) - chat_action = action_manager.get_action_state_by_chat_id(chat.stream_id) - asyncio.create_task(chat_action.update_action_by_message(message)) - # 视线管理:收到消息时切换视线状态 - chat_watching = watching_manager.get_watching_by_chat_id(chat.stream_id) - await chat_watching.on_message_received() - - # 上下文网页管理:启动独立task处理消息上下文 - asyncio.create_task(self._handle_context_web_update(chat.stream_id, message)) - - # 日志记录 - if message.is_gift: - logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}") - else: - logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}") - - @staticmethod - async def handle_internal_message(message: MessageRecvS4U): - if message.is_internal: - group_info = GroupInfo(platform="amaidesu_default", group_id=660154, group_name="内心") - - chat = await get_chat_manager().get_or_create_stream( - platform="amaidesu_default", user_info=message.message_info.user_info, group_info=group_info - ) - s4u_chat = await get_s4u_chat_manager().get_or_create_chat(chat) - message.message_info.group_info = s4u_chat.chat_stream.group_info - message.message_info.platform = s4u_chat.chat_stream.platform - - s4u_chat.internal_message.append(message) - s4u_chat.new_message_event.set() - - logger.info( - f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}" - ) - - return True - return False - - @staticmethod - async def handle_screen_message(message: MessageRecvS4U): - if message.is_screen: - screen_manager.set_screen(message.screen_info) - return True - return False - - @staticmethod - async def hadle_if_voice_done(message: MessageRecvS4U): - if message.voice_done: - s4u_chat = await get_s4u_chat_manager().get_or_create_chat(message.chat_stream) - s4u_chat.voice_done = message.voice_done - return True - return False - - @staticmethod - async def check_if_fake_gift(message: MessageRecvS4U) -> bool: - """检查消息是否为假礼物""" - if message.is_gift: - return False - - gift_keywords = ["送出了礼物", "礼物", "送出了", "投喂"] - if any(keyword in message.processed_plain_text for keyword in gift_keywords): - message.is_fake_gift = True - return True - - return False - - async def handle_if_gift(self, message: MessageRecvS4U) -> bool: - """处理礼物消息 - - Returns: - bool: True表示应该继续处理消息,False表示消息已被暂存不需要继续处理 - """ - if message.is_gift: - # 定义防抖完成后的回调函数 - def gift_callback(merged_message: MessageRecvS4U): - """礼物防抖完成后的回调""" - # 创建异步任务来处理合并后的礼物消息,跳过防抖处理 - asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True)) - - # 交给礼物管理器处理,并传入回调函数 - # 对于礼物消息,handle_gift 总是返回 False(消息被暂存) - await gift_manager.handle_gift(message, gift_callback) - return False # 消息被暂存,不继续处理 - - return True # 非礼物消息,继续正常处理 - - @staticmethod - async def _handle_context_web_update(chat_id: str, message: MessageRecv): - """处理上下文网页更新的独立task - - Args: - chat_id: 聊天ID - message: 消息对象 - """ - try: - logger.debug(f"🔄 开始处理上下文网页更新: {message.message_info.user_info.user_nickname}") - - context_manager = get_context_web_manager() - - # 只在服务器未启动时启动(避免重复启动) - if context_manager.site is None: - logger.info("🚀 首次启动上下文网页服务器...") - await context_manager.start_server() - - # 添加消息到上下文并更新网页 - await asyncio.sleep(1.5) - - await context_manager.add_message(chat_id, message) - - logger.debug(f"✅ 上下文网页更新完成: {message.message_info.user_info.user_nickname}") - - except Exception as e: - logger.error(f"❌ 处理上下文网页更新失败: {e}", exc_info=True) diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py deleted file mode 100644 index eba734184..000000000 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ /dev/null @@ -1,443 +0,0 @@ -import asyncio - -# 旧的Hippocampus系统已被移除,现在使用增强记忆系统 -# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager -import random -import time -from datetime import datetime - -from src.chat.express.expression_selector import expression_selector -from src.chat.message_receive.chat_stream import ChatStream -from src.chat.message_receive.message import MessageRecvS4U -from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.utils.utils import get_recent_group_speaker -from src.common.logger import get_logger -from src.config.config import global_config -from src.mais4u.mais4u_chat.internal_manager import internal_manager -from src.mais4u.mais4u_chat.screen_manager import screen_manager -from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager -from src.mais4u.s4u_config import s4u_config -from src.person_info.person_info import PersonInfoManager, get_person_info_manager -from src.person_info.relationship_fetcher import relationship_fetcher_manager - -from .s4u_mood_manager import mood_manager - -logger = get_logger("prompt") - - -def init_prompt(): - Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt") - Prompt("\n关于你们的关系,你需要知道:\n{relation_info}\n", "relation_prompt") - Prompt("你回想起了一些事情:\n{memory_info}\n", "memory_prompt") - - Prompt( - """ -你的名字是麦麦, 是千石可乐开发的程序,可以在QQ,微信等平台发言,你现在正在哔哩哔哩作为虚拟主播进行直播 -虽然是个AI,但你在努力扮演真人, 人设的配色是橙色和紫色和绿色 -你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复 -你可以看见用户发送的弹幕,礼物和superchat -{screen_info} -{internal_state} - -{relation_info_block} -{memory_block} -{expression_habits_block} - -你现在的主要任务是和 {sender_name} 发送的弹幕聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。 - -{sc_info} - -{background_dialogue_prompt} --------------------------------- -{time_block} -这是你和{sender_name}的对话,你们正在交流中: -{core_dialogue_prompt} - -对方最新发送的内容:{message_txt} -{gift_info} -回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞。 -表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state} -不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。 -你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。 -你的发言: -""", - "s4u_prompt", # New template for private CHAT chat - ) - - Prompt( - """ -你的名字是麦麦, 是千石可乐开发的程序,可以在QQ,微信等平台发言,你现在正在哔哩哔哩作为虚拟主播进行直播 -虽然是个AI,但你在努力扮演真人, 人设的配色是橙色和紫色和绿色 -你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复 -你可以看见用户发送的弹幕,礼物和superchat -你可以看见面前的屏幕,目前屏幕的内容是: -{screen_info} - -{memory_block} -{expression_habits_block} - -{sc_info} - -{time_block} -{chat_info_danmu} --------------------------------- -以上是你和弹幕的对话,与此同时,你在与QQ群友聊天,聊天记录如下: -{chat_info_qq} --------------------------------- -你刚刚回复了QQ群,你内心的想法是:{mind} -请根据你内心的想法,组织一条回复,在直播间进行发言,可以点名吐槽对象,让观众知道你在说谁 -{gift_info} -回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格。不要浮夸,有逻辑和条理。 -表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state} -不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。 -你的发言: -""", - "s4u_prompt_internal", # New template for private CHAT chat - ) - - -class PromptBuilder: - def __init__(self): - self.prompt_built = "" - self.activate_messages = "" - - @staticmethod - async def build_expression_habits(chat_stream: ChatStream, chat_history, target): - style_habits = [] - grammar_habits = [] - - # 使用统一的表达方式选择入口(支持classic和exp_model模式) - selected_expressions = await expression_selector.select_suitable_expressions( - chat_id=chat_stream.stream_id, - chat_history=chat_history, - target_message=target, - max_num=12, - min_num=5 - ) - - if selected_expressions: - logger.debug(f" 使用处理器选中的{len(selected_expressions)}个表达方式") - for expr in selected_expressions: - if isinstance(expr, dict) and "situation" in expr and "style" in expr: - expr_type = expr.get("type", "style") - if expr_type == "grammar": - grammar_habits.append(f"当{expr['situation']}时,使用 {expr['style']}") - else: - style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}") - else: - logger.debug("没有从处理器获得表达方式,将使用空的表达方式") - # 不再在replyer中进行随机选择,全部交给处理器处理 - - style_habits_str = "\n".join(style_habits) - grammar_habits_str = "\n".join(grammar_habits) - - # 动态构建expression habits块 - expression_habits_block = "" - if style_habits_str.strip(): - expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n" - if grammar_habits_str.strip(): - expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habits_str}\n" - - return expression_habits_block - - @staticmethod - async def build_relation_info(chat_stream) -> str: - is_group_chat = bool(chat_stream.group_info) - who_chat_in_group = [] - if is_group_chat: - who_chat_in_group = get_recent_group_speaker( - chat_stream.stream_id, - (chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None, - limit=global_config.chat.max_context_size, - ) - elif chat_stream.user_info: - who_chat_in_group.append( - (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname) - ) - - relation_prompt = "" - if global_config.affinity_flow.enable_relationship_tracking and who_chat_in_group: - relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_stream.stream_id) - - # 将 (platform, user_id, nickname) 转换为 person_id - person_ids = [] - for person in who_chat_in_group: - person_id = PersonInfoManager.get_person_id(person[0], person[1]) - person_ids.append(person_id) - - # 构建用户关系信息和聊天流印象信息 - user_relation_tasks = [relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids] - stream_impression_task = relationship_fetcher.build_chat_stream_impression(chat_stream.stream_id) - - # 并行获取所有信息 - results = await asyncio.gather(*user_relation_tasks, stream_impression_task) - relation_info_list = results[:-1] # 用户关系信息 - stream_impression = results[-1] # 聊天流印象 - - # 组合用户关系信息和聊天流印象 - combined_info_parts = [] - if user_relation_info := "".join(relation_info_list): - combined_info_parts.append(user_relation_info) - if stream_impression: - combined_info_parts.append(stream_impression) - - if combined_info := "\n\n".join(combined_info_parts): - relation_prompt = await global_prompt_manager.format_prompt( - "relation_prompt", relation_info=combined_info - ) - return relation_prompt - - @staticmethod - async def build_memory_block(text: str) -> str: - # 使用新的统一记忆系统检索记忆 - try: - from src.chat.memory_system import get_memory_system - - memory_system = get_memory_system() - enhanced_memories = await memory_system.retrieve_relevant_memories( - query_text=text, - user_id="system", # 系统查询 - scope_id="system", - limit=5, - ) - - related_memory_info = "" - if enhanced_memories: - for memory_chunk in enhanced_memories: - related_memory_info += memory_chunk.display or memory_chunk.text_content or "" - return await global_prompt_manager.format_prompt( - "memory_prompt", memory_info=related_memory_info.strip() - ) - return "" - - except Exception as e: - logger.warning(f"增强记忆系统检索失败: {e}") - return "" - - @staticmethod - async def build_chat_history_prompts(chat_stream: ChatStream, message: MessageRecvS4U): - message_list_before_now = await get_raw_msg_before_timestamp_with_chat( - chat_id=chat_stream.stream_id, - timestamp=time.time(), - limit=300, - ) - - talk_type = f"{message.message_info.platform}:{message.chat_stream.user_info.user_id!s}" - - core_dialogue_list = [] - background_dialogue_list = [] - bot_id = str(global_config.bot.qq_account) - target_user_id = str(message.chat_stream.user_info.user_id) - - for msg_dict in message_list_before_now: - try: - msg_user_id = str(msg_dict.get("user_id")) - if msg_user_id == bot_id: - if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"): - core_dialogue_list.append(msg_dict) - elif msg_dict.get("reply_to") and talk_type != msg_dict.get("reply_to"): - background_dialogue_list.append(msg_dict) - # else: - # background_dialogue_list.append(msg_dict) - elif msg_user_id == target_user_id: - core_dialogue_list.append(msg_dict) - else: - background_dialogue_list.append(msg_dict) - except Exception as e: - logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}") - - background_dialogue_prompt = "" - if background_dialogue_list: - context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :] - background_dialogue_prompt_str = await build_readable_messages( - context_msgs, - timestamp_mode="normal_no_YMD", - show_pic=False, - ) - background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt_str}" - - core_msg_str = "" - if core_dialogue_list: - core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length :] - - first_msg = core_dialogue_list[0] - start_speaking_user_id = first_msg.get("user_id") - if start_speaking_user_id == bot_id: - last_speaking_user_id = bot_id - msg_seg_str = "你的发言:\n" - else: - start_speaking_user_id = target_user_id - last_speaking_user_id = start_speaking_user_id - msg_seg_str = "对方的发言:\n" - - msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n" - - all_msg_seg_list = [] - for msg in core_dialogue_list[1:]: - speaker = msg.get("user_id") - if speaker == last_speaking_user_id: - msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n" - else: - msg_seg_str = f"{msg_seg_str}\n" - all_msg_seg_list.append(msg_seg_str) - - if speaker == bot_id: - msg_seg_str = "你的发言:\n" - else: - msg_seg_str = "对方的发言:\n" - - msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n" - last_speaking_user_id = speaker - - all_msg_seg_list.append(msg_seg_str) - for msg in all_msg_seg_list: - core_msg_str += msg - - all_dialogue_prompt = await get_raw_msg_before_timestamp_with_chat( - chat_id=chat_stream.stream_id, - timestamp=time.time(), - limit=20, - ) - all_dialogue_prompt_str = await build_readable_messages( - all_dialogue_prompt, - timestamp_mode="normal_no_YMD", - show_pic=False, - ) - - return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str - - @staticmethod - def build_gift_info(message: MessageRecvS4U): - if message.is_gift: - return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户" - else: - if message.is_fake_gift: - return f"{message.processed_plain_text}(注意:这是一条普通弹幕信息,对方没有真的发送礼物,不是礼物信息,注意区分,如果对方在发假的礼物骗你,请反击)" - - return "" - - @staticmethod - def build_sc_info(message: MessageRecvS4U): - super_chat_manager = get_super_chat_manager() - return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id) - - async def build_prompt_normal( - self, - message: MessageRecvS4U, - message_txt: str, - ) -> str: - chat_stream = message.chat_stream - - person_id = PersonInfoManager.get_person_id( - message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id - ) - person_info_manager = get_person_info_manager() - person_name = await person_info_manager.get_value(person_id, "person_name") - - if message.chat_stream.user_info.user_nickname: - if person_name: - sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})" - else: - sender_name = f"[{message.chat_stream.user_info.user_nickname}]" - else: - sender_name = f"用户({message.chat_stream.user_info.user_id})" - - relation_info_block, memory_block, expression_habits_block = await asyncio.gather( - self.build_relation_info(chat_stream), - self.build_memory_block(message_txt), - self.build_expression_habits(chat_stream, message_txt, sender_name), - ) - - core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = await self.build_chat_history_prompts( - chat_stream, message - ) - - gift_info = self.build_gift_info(message) - - sc_info = self.build_sc_info(message) - - screen_info = screen_manager.get_screen_str() - - internal_state = internal_manager.get_internal_state_str() - - time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - - mood = mood_manager.get_mood_by_chat_id(chat_stream.stream_id) - - template_name = "s4u_prompt" - - if not message.is_internal: - prompt = await global_prompt_manager.format_prompt( - template_name, - time_block=time_block, - expression_habits_block=expression_habits_block, - relation_info_block=relation_info_block, - memory_block=memory_block, - screen_info=screen_info, - internal_state=internal_state, - gift_info=gift_info, - sc_info=sc_info, - sender_name=sender_name, - core_dialogue_prompt=core_dialogue_prompt, - background_dialogue_prompt=background_dialogue_prompt, - message_txt=message_txt, - mood_state=mood.mood_state, - ) - else: - prompt = await global_prompt_manager.format_prompt( - "s4u_prompt_internal", - time_block=time_block, - expression_habits_block=expression_habits_block, - relation_info_block=relation_info_block, - memory_block=memory_block, - screen_info=screen_info, - gift_info=gift_info, - sc_info=sc_info, - chat_info_danmu=all_dialogue_prompt, - chat_info_qq=message.chat_info, - mind=message.processed_plain_text, - mood_state=mood.mood_state, - ) - - # print(prompt) - - return prompt - - -def weighted_sample_no_replacement(items, weights, k) -> list: - """ - 加权且不放回地随机抽取k个元素。 - - 参数: - items: 待抽取的元素列表 - weights: 每个元素对应的权重(与items等长,且为正数) - k: 需要抽取的元素个数 - 返回: - selected: 按权重加权且不重复抽取的k个元素组成的列表 - - 如果 items 中的元素不足 k 个,就只会返回所有可用的元素 - - 实现思路: - 每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。 - 这样保证了: - 1. count越大被选中概率越高 - 2. 不会重复选中同一个元素 - """ - selected = [] - pool = list(zip(items, weights, strict=False)) - for _ in range(min(k, len(pool))): - total = sum(w for _, w in pool) - r = random.uniform(0, total) - upto = 0 - for idx, (item, weight) in enumerate(pool): - upto += weight - if upto >= r: - selected.append(item) - pool.pop(idx) - break - return selected - - -init_prompt() -prompt_builder = PromptBuilder() diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py deleted file mode 100644 index 3f2ac4a80..000000000 --- a/src/mais4u/mais4u_chat/s4u_stream_generator.py +++ /dev/null @@ -1,168 +0,0 @@ -import asyncio -import re -from collections.abc import AsyncGenerator - -from src.chat.message_receive.message import MessageRecvS4U -from src.common.logger import get_logger -from src.config.config import model_config -from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder -from src.mais4u.openai_client import AsyncOpenAIClient - -logger = get_logger("s4u_stream_generator") - - -class S4UStreamGenerator: - def __init__(self): - replyer_config = model_config.model_task_config.replyer - model_to_use = replyer_config.model_list[0] - model_info = model_config.get_model_info(model_to_use) - if not model_info: - logger.error(f"模型 {model_to_use} 在配置中未找到") - raise ValueError(f"模型 {model_to_use} 在配置中未找到") - provider_name = model_info.api_provider - provider_info = model_config.get_provider(provider_name) - if not provider_info: - logger.error("`replyer` 找不到对应的Provider") - raise ValueError("`replyer` 找不到对应的Provider") - - api_key = provider_info.api_key - base_url = provider_info.base_url - - if not api_key: - logger.error(f"{provider_name}没有配置API KEY") - raise ValueError(f"{provider_name}没有配置API KEY") - - self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url) - self.model_1_name = model_to_use - self.replyer_config = replyer_config - - self.current_model_name = "unknown model" - self.partial_response = "" - - # 正则表达式用于按句子切分,同时处理各种标点和边缘情况 - # 匹配常见的句子结束符,但会忽略引号内和数字中的标点 - self.sentence_split_pattern = re.compile( - r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容 - r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))', # 匹配直到句子结束符 - re.UNICODE | re.DOTALL, - ) - - self.chat_stream = None - - @staticmethod - async def build_last_internal_message(message: MessageRecvS4U, previous_reply_context: str = ""): - # person_id = PersonInfoManager.get_person_id( - # message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id - # ) - # person_info_manager = get_person_info_manager() - # person_name = await person_info_manager.get_value(person_id, "person_name") - - # if message.chat_stream.user_info.user_nickname: - # if person_name: - # sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})" - # else: - # sender_name = f"[{message.chat_stream.user_info.user_nickname}]" - # else: - # sender_name = f"用户({message.chat_stream.user_info.user_id})" - - # 构建prompt - if previous_reply_context: - message_txt = f""" - 你正在回复用户的消息,但中途被打断了。这是已有的对话上下文: - [你已经对上一条消息说的话]: {previous_reply_context} - --- - [这是用户发来的新消息, 你需要结合上下文,对此进行回复]: - {message.processed_plain_text} - """ - return True, message_txt - else: - message_txt = message.processed_plain_text - return False, message_txt - - async def generate_response( - self, message: MessageRecvS4U, previous_reply_context: str = "" - ) -> AsyncGenerator[str, None]: - """根据当前模型类型选择对应的生成函数""" - # 从global_config中获取模型概率值并选择模型 - self.partial_response = "" - message_txt = message.processed_plain_text - if not message.is_internal: - interupted, message_txt_added = await self.build_last_internal_message(message, previous_reply_context) - if interupted: - message_txt = message_txt_added - - message.chat_stream = self.chat_stream - prompt = await prompt_builder.build_prompt_normal( - message=message, - message_txt=message_txt, - ) - - logger.info( - f"{self.current_model_name}思考:{message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}" - ) - - current_client = self.client_1 - self.current_model_name = self.model_1_name - - extra_kwargs = {} - if self.replyer_config.get("enable_thinking") is not None: - extra_kwargs["enable_thinking"] = self.replyer_config.get("enable_thinking") - if self.replyer_config.get("thinking_budget") is not None: - extra_kwargs["thinking_budget"] = self.replyer_config.get("thinking_budget") - - async for chunk in self._generate_response_with_model( - prompt, current_client, self.current_model_name, **extra_kwargs - ): - yield chunk - - async def _generate_response_with_model( - self, - prompt: str, - client: AsyncOpenAIClient, - model_name: str, - **kwargs, - ) -> AsyncGenerator[str, None]: - buffer = "" - delimiters = ",。!?,.!?\n\r" # For final trimming - punctuation_buffer = "" - - async for content in client.get_stream_content( - messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs - ): - buffer += content - - # 使用正则表达式匹配句子 - last_match_end = 0 - for match in self.sentence_split_pattern.finditer(buffer): - sentence = match.group(0).strip() - if sentence: - # 如果句子看起来完整(即不只是等待更多内容),则发送 - if match.end(0) < len(buffer) or sentence.endswith(tuple(delimiters)): - # 检查是否只是一个标点符号 - if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]: - punctuation_buffer += sentence - else: - # 发送之前累积的标点和当前句子 - to_yield = punctuation_buffer + sentence - if to_yield.endswith((",", ",")): - to_yield = to_yield.rstrip(",,") - - self.partial_response += to_yield - yield to_yield - punctuation_buffer = "" # 清空标点符号缓冲区 - await asyncio.sleep(0) # 允许其他任务运行 - - last_match_end = match.end(0) - - # 从缓冲区移除已发送的部分 - if last_match_end > 0: - buffer = buffer[last_match_end:] - - # 发送缓冲区中剩余的任何内容 - to_yield = (punctuation_buffer + buffer).strip() - if to_yield: - if to_yield.endswith((",", ",")): - to_yield = to_yield.rstrip(",,") - if to_yield: - self.partial_response += to_yield - yield to_yield diff --git a/src/mais4u/mais4u_chat/s4u_watching_manager.py b/src/mais4u/mais4u_chat/s4u_watching_manager.py deleted file mode 100644 index 90c01545b..000000000 --- a/src/mais4u/mais4u_chat/s4u_watching_manager.py +++ /dev/null @@ -1,106 +0,0 @@ -from src.common.logger import get_logger -from src.plugin_system.apis import send_api - -""" -视线管理系统使用说明: - -1. 视线状态: - - wandering: 随意看 - - danmu: 看弹幕 - - lens: 看镜头 - -2. 状态切换逻辑: - - 收到消息时 → 切换为看弹幕,立即发送更新 - - 开始生成回复时 → 切换为看镜头或随意,立即发送更新 - - 生成完毕后 → 看弹幕1秒,然后回到看镜头直到有新消息,状态变化时立即发送更新 - -3. 使用方法: - # 获取视线管理器 - watching = watching_manager.get_watching_by_chat_id(chat_id) - - # 收到消息时调用 - await watching.on_message_received() - - # 开始生成回复时调用 - await watching.on_reply_start() - - # 生成回复完毕时调用 - await watching.on_reply_finished() - -4. 自动更新系统: - - 状态变化时立即发送type为"watching",data为状态值的websocket消息 - - 使用定时器自动处理状态转换(如看弹幕时间结束后自动切换到看镜头) - - 无需定期检查,所有状态变化都是事件驱动的 -""" - -logger = get_logger("watching") - -HEAD_CODE = { - "看向上方": "(0,0.5,0)", - "看向下方": "(0,-0.5,0)", - "看向左边": "(-1,0,0)", - "看向右边": "(1,0,0)", - "随意朝向": "random", - "看向摄像机": "camera", - "注视对方": "(0,0,0)", - "看向正前方": "(0,0,0)", -} - - -class ChatWatching: - def __init__(self, chat_id: str): - self.chat_id: str = chat_id - - async def on_reply_start(self): - """开始生成回复时调用""" - await send_api.custom_to_stream( - message_type="state", content="start_thinking", stream_id=self.chat_id, storage_message=False - ) - - async def on_reply_finished(self): - """生成回复完毕时调用""" - await send_api.custom_to_stream( - message_type="state", content="finish_reply", stream_id=self.chat_id, storage_message=False - ) - - async def on_thinking_finished(self): - """思考完毕时调用""" - await send_api.custom_to_stream( - message_type="state", content="finish_thinking", stream_id=self.chat_id, storage_message=False - ) - - async def on_message_received(self): - """收到消息时调用""" - await send_api.custom_to_stream( - message_type="state", content="start_viewing", stream_id=self.chat_id, storage_message=False - ) - - async def on_internal_message_start(self): - """收到消息时调用""" - await send_api.custom_to_stream( - message_type="state", content="start_internal_thinking", stream_id=self.chat_id, storage_message=False - ) - - -class WatchingManager: - def __init__(self): - self.watching_list: list[ChatWatching] = [] - """当前视线状态列表""" - self.task_started: bool = False - - def get_watching_by_chat_id(self, chat_id: str) -> ChatWatching: - """获取或创建聊天对应的视线管理器""" - for watching in self.watching_list: - if watching.chat_id == chat_id: - return watching - - new_watching = ChatWatching(chat_id) - self.watching_list.append(new_watching) - logger.info(f"为chat {chat_id}创建新的视线管理器") - - return new_watching - - -# 全局视线管理器实例 -watching_manager = WatchingManager() -"""全局视线管理器""" diff --git a/src/mais4u/mais4u_chat/screen_manager.py b/src/mais4u/mais4u_chat/screen_manager.py deleted file mode 100644 index 60a7f914d..000000000 --- a/src/mais4u/mais4u_chat/screen_manager.py +++ /dev/null @@ -1,15 +0,0 @@ -class ScreenManager: - def __init__(self): - self.now_screen = "" - - def set_screen(self, screen_str: str): - self.now_screen = screen_str - - def get_screen(self): - return self.now_screen - - def get_screen_str(self): - return f"你可以看见面前的屏幕,目前屏幕的内容是:现在千石可乐在和你一起直播,这是他正在操作的屏幕内容:{self.now_screen}" - - -screen_manager = ScreenManager() diff --git a/src/mais4u/mais4u_chat/super_chat_manager.py b/src/mais4u/mais4u_chat/super_chat_manager.py deleted file mode 100644 index df6245746..000000000 --- a/src/mais4u/mais4u_chat/super_chat_manager.py +++ /dev/null @@ -1,304 +0,0 @@ -import asyncio -import time -from dataclasses import dataclass - -from src.chat.message_receive.message import MessageRecvS4U -from src.common.logger import get_logger - -# 全局SuperChat管理器实例 -from src.mais4u.constant_s4u import ENABLE_S4U - -logger = get_logger("super_chat_manager") - - -@dataclass -class SuperChatRecord: - """SuperChat记录数据类""" - - user_id: str - user_nickname: str - platform: str - chat_id: str - price: float - message_text: str - timestamp: float - expire_time: float - group_name: str | None = None - - def is_expired(self) -> bool: - """检查SuperChat是否已过期""" - return time.time() > self.expire_time - - def remaining_time(self) -> float: - """获取剩余时间(秒)""" - return max(0, self.expire_time - time.time()) - - def to_dict(self) -> dict: - """转换为字典格式""" - return { - "user_id": self.user_id, - "user_nickname": self.user_nickname, - "platform": self.platform, - "chat_id": self.chat_id, - "price": self.price, - "message_text": self.message_text, - "timestamp": self.timestamp, - "expire_time": self.expire_time, - "group_name": self.group_name, - "remaining_time": self.remaining_time(), - } - - -class SuperChatManager: - """SuperChat管理器,负责管理和跟踪SuperChat消息""" - - def __init__(self): - self.super_chats: dict[str, list[SuperChatRecord]] = {} # chat_id -> SuperChat列表 - self._cleanup_task: asyncio.Task | None = None - self._is_initialized = False - logger.info("SuperChat管理器已初始化") - - def _ensure_cleanup_task_started(self): - """确保清理任务已启动(延迟启动)""" - if self._cleanup_task is None or self._cleanup_task.done(): - try: - loop = asyncio.get_running_loop() - self._cleanup_task = loop.create_task(self._cleanup_expired_superchats()) - self._is_initialized = True - logger.info("SuperChat清理任务已启动") - except RuntimeError: - # 没有运行的事件循环,稍后再启动 - logger.debug("当前没有运行的事件循环,将在需要时启动清理任务") - - def _start_cleanup_task(self): - """启动清理任务(已弃用,保留向后兼容)""" - self._ensure_cleanup_task_started() - - async def _cleanup_expired_superchats(self): - """定期清理过期的SuperChat""" - while True: - try: - total_removed = 0 - - for chat_id in list(self.super_chats.keys()): - original_count = len(self.super_chats[chat_id]) - # 移除过期的SuperChat - self.super_chats[chat_id] = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()] - - removed_count = original_count - len(self.super_chats[chat_id]) - total_removed += removed_count - - if removed_count > 0: - logger.info(f"从聊天 {chat_id} 中清理了 {removed_count} 个过期的SuperChat") - - # 如果列表为空,删除该聊天的记录 - if not self.super_chats[chat_id]: - del self.super_chats[chat_id] - - if total_removed > 0: - logger.info(f"总共清理了 {total_removed} 个过期的SuperChat") - - # 每30秒检查一次 - await asyncio.sleep(30) - - except Exception as e: - logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True) - await asyncio.sleep(60) # 出错时等待更长时间 - - @staticmethod - def _calculate_expire_time(price: float) -> float: - """根据SuperChat金额计算过期时间""" - current_time = time.time() - - # 根据金额阶梯设置不同的存活时间 - if price >= 500: - # 500元以上:保持4小时 - duration = 4 * 3600 - elif price >= 200: - # 200-499元:保持2小时 - duration = 2 * 3600 - elif price >= 100: - # 100-199元:保持1小时 - duration = 1 * 3600 - elif price >= 50: - # 50-99元:保持30分钟 - duration = 30 * 60 - elif price >= 20: - # 20-49元:保持15分钟 - duration = 15 * 60 - elif price >= 10: - # 10-19元:保持10分钟 - duration = 10 * 60 - else: - # 10元以下:保持5分钟 - duration = 5 * 60 - - return current_time + duration - - async def add_superchat(self, message: MessageRecvS4U) -> None: - """添加新的SuperChat记录""" - # 确保清理任务已启动 - self._ensure_cleanup_task_started() - - if not message.is_superchat or not message.superchat_price: - logger.warning("尝试添加非SuperChat消息到SuperChat管理器") - return - - try: - price = float(message.superchat_price) - except (ValueError, TypeError): - logger.error(f"无效的SuperChat价格: {message.superchat_price}") - return - - user_info = message.message_info.user_info - group_info = message.message_info.group_info - chat_id = getattr(message, "chat_stream", None) - if chat_id: - chat_id = chat_id.stream_id - else: - # 生成chat_id的备用方法 - chat_id = f"{message.message_info.platform}_{user_info.user_id}" - if group_info: - chat_id = f"{message.message_info.platform}_{group_info.group_id}" - - expire_time = self._calculate_expire_time(price) - - record = SuperChatRecord( - user_id=user_info.user_id, - user_nickname=user_info.user_nickname, - platform=message.message_info.platform, - chat_id=chat_id, - price=price, - message_text=message.superchat_message_text or "", - timestamp=message.message_info.time, - expire_time=expire_time, - group_name=group_info.group_name if group_info else None, - ) - - # 添加到对应聊天的SuperChat列表 - if chat_id not in self.super_chats: - self.super_chats[chat_id] = [] - - self.super_chats[chat_id].append(record) - - # 按价格降序排序(价格高的在前) - self.super_chats[chat_id].sort(key=lambda x: x.price, reverse=True) - - logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}") - - def get_superchats_by_chat(self, chat_id: str) -> list[SuperChatRecord]: - """获取指定聊天的所有有效SuperChat""" - # 确保清理任务已启动 - self._ensure_cleanup_task_started() - - if chat_id not in self.super_chats: - return [] - - # 过滤掉过期的SuperChat - valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()] - return valid_superchats - - def get_all_valid_superchats(self) -> dict[str, list[SuperChatRecord]]: - """获取所有有效的SuperChat""" - # 确保清理任务已启动 - self._ensure_cleanup_task_started() - - result = {} - for chat_id, superchats in self.super_chats.items(): - valid_superchats = [sc for sc in superchats if not sc.is_expired()] - if valid_superchats: - result[chat_id] = valid_superchats - return result - - def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str: - """构建SuperChat显示字符串""" - superchats = self.get_superchats_by_chat(chat_id) - - if not superchats: - return "" - - # 限制显示数量 - display_superchats = superchats[:max_count] - - lines = ["📢 当前有效超级弹幕:"] - for i, sc in enumerate(display_superchats, 1): - remaining_minutes = int(sc.remaining_time() / 60) - remaining_seconds = int(sc.remaining_time() % 60) - - time_display = ( - f"{remaining_minutes}分{remaining_seconds}秒" if remaining_minutes > 0 else f"{remaining_seconds}秒" - ) - - line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}" - if len(line) > 100: # 限制单行长度 - line = f"{line[:97]}..." - line += f" (剩余{time_display})" - lines.append(line) - - if len(superchats) > max_count: - lines.append(f"... 还有{len(superchats) - max_count}条SuperChat") - - return "\n".join(lines) - - def build_superchat_summary_string(self, chat_id: str) -> str: - """构建SuperChat摘要字符串""" - superchats = self.get_superchats_by_chat(chat_id) - - if not superchats: - return "当前没有有效的超级弹幕" - lines = [] - for sc in superchats: - single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}" - if len(single_sc_str) > 100: - single_sc_str = f"{single_sc_str[:97]}..." - single_sc_str += f" (剩余{int(sc.remaining_time())}秒)" - lines.append(single_sc_str) - - total_amount = sum(sc.price for sc in superchats) - count = len(superchats) - highest_amount = max(sc.price for sc in superchats) - - final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}元" - if lines: - final_str += "\n" + "\n".join(lines) - return final_str - - def get_superchat_statistics(self, chat_id: str) -> dict: - """获取SuperChat统计信息""" - superchats = self.get_superchats_by_chat(chat_id) - - if not superchats: - return {"count": 0, "total_amount": 0, "average_amount": 0, "highest_amount": 0, "lowest_amount": 0} - - amounts = [sc.price for sc in superchats] - - return { - "count": len(superchats), - "total_amount": sum(amounts), - "average_amount": sum(amounts) / len(amounts), - "highest_amount": max(amounts), - "lowest_amount": min(amounts), - } - - async def shutdown(self): # sourcery skip: use-contextlib-suppress - """关闭管理器,清理资源""" - if self._cleanup_task and not self._cleanup_task.done(): - self._cleanup_task.cancel() - try: - await self._cleanup_task - except asyncio.CancelledError: - pass - logger.info("SuperChat管理器已关闭") - - -# sourcery skip: assign-if-exp -if ENABLE_S4U: - super_chat_manager = SuperChatManager() -else: - super_chat_manager = None - - -def get_super_chat_manager() -> SuperChatManager: - """获取全局SuperChat管理器实例""" - - return super_chat_manager diff --git a/src/mais4u/mais4u_chat/yes_or_no.py b/src/mais4u/mais4u_chat/yes_or_no.py deleted file mode 100644 index 51fba0416..000000000 --- a/src/mais4u/mais4u_chat/yes_or_no.py +++ /dev/null @@ -1,46 +0,0 @@ -from src.common.logger import get_logger -from src.config.config import model_config -from src.llm_models.utils_model import LLMRequest -from src.plugin_system.apis import send_api - -logger = get_logger(__name__) - -head_actions_list = ["不做额外动作", "点头一次", "点头两次", "摇头", "歪脑袋", "低头望向一边"] - - -async def yes_or_no_head(text: str, emotion: str = "", chat_history: str = "", chat_id: str = ""): - prompt = f""" -{chat_history} -以上是对方的发言: - -对这个发言,你的心情是:{emotion} -对上面的发言,你的回复是:{text} -请判断时是否要伴随回复做头部动作,你可以选择: - -不做额外动作 -点头一次 -点头两次 -摇头 -歪脑袋 -低头望向一边 - -请从上面的动作中选择一个,并输出,请只输出你选择的动作就好,不要输出其他内容。""" - model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion") - - try: - # logger.info(f"prompt: {prompt}") - response, _ = await model.generate_response_async(prompt=prompt, temperature=0.7) - logger.info(f"response: {response}") - - head_action = response if response in head_actions_list else "不做额外动作" - await send_api.custom_to_stream( - message_type="head_action", - content=head_action, - stream_id=chat_id, - storage_message=False, - show_log=True, - ) - - except Exception as e: - logger.error(f"yes_or_no_head error: {e}") - return "不做额外动作" diff --git a/src/mais4u/openai_client.py b/src/mais4u/openai_client.py deleted file mode 100644 index 6f5e0484e..000000000 --- a/src/mais4u/openai_client.py +++ /dev/null @@ -1,287 +0,0 @@ -from collections.abc import AsyncGenerator -from dataclasses import dataclass - -from openai import AsyncOpenAI -from openai.types.chat import ChatCompletion, ChatCompletionChunk - - -@dataclass -class ChatMessage: - """聊天消息数据类""" - - role: str - content: str - - def to_dict(self) -> dict[str, str]: - return {"role": self.role, "content": self.content} - - -class AsyncOpenAIClient: - """异步OpenAI客户端,支持流式传输""" - - def __init__(self, api_key: str, base_url: str | None = None): - """ - 初始化客户端 - - Args: - api_key: OpenAI API密钥 - base_url: 可选的API基础URL,用于自定义端点 - """ - self.client = AsyncOpenAI( - api_key=api_key, - base_url=base_url, - timeout=10.0, # 设置60秒的全局超时 - ) - - async def chat_completion( - self, - messages: list[ChatMessage | dict[str, str]], - model: str = "gpt-3.5-turbo", - temperature: float = 0.7, - max_tokens: int | None = None, - **kwargs, - ) -> ChatCompletion: - """ - 非流式聊天完成 - - Args: - messages: 消息列表 - model: 模型名称 - temperature: 温度参数 - max_tokens: 最大token数 - **kwargs: 其他参数 - - Returns: - 完整的聊天回复 - """ - # 转换消息格式 - formatted_messages = [] - for msg in messages: - if isinstance(msg, ChatMessage): - formatted_messages.append(msg.to_dict()) - else: - formatted_messages.append(msg) - - extra_body = {} - if kwargs.get("enable_thinking") is not None: - extra_body["enable_thinking"] = kwargs.pop("enable_thinking") - if kwargs.get("thinking_budget") is not None: - extra_body["thinking_budget"] = kwargs.pop("thinking_budget") - - response = await self.client.chat.completions.create( - model=model, - messages=formatted_messages, - temperature=temperature, - max_tokens=max_tokens, - stream=False, - extra_body=extra_body if extra_body else None, - **kwargs, - ) - - return response - - async def chat_completion_stream( - self, - messages: list[ChatMessage | dict[str, str]], - model: str = "gpt-3.5-turbo", - temperature: float = 0.7, - max_tokens: int | None = None, - **kwargs, - ) -> AsyncGenerator[ChatCompletionChunk, None]: - """ - 流式聊天完成 - - Args: - messages: 消息列表 - model: 模型名称 - temperature: 温度参数 - max_tokens: 最大token数 - **kwargs: 其他参数 - - Yields: - ChatCompletionChunk: 流式响应块 - """ - # 转换消息格式 - formatted_messages = [] - for msg in messages: - if isinstance(msg, ChatMessage): - formatted_messages.append(msg.to_dict()) - else: - formatted_messages.append(msg) - - extra_body = {} - if kwargs.get("enable_thinking") is not None: - extra_body["enable_thinking"] = kwargs.pop("enable_thinking") - if kwargs.get("thinking_budget") is not None: - extra_body["thinking_budget"] = kwargs.pop("thinking_budget") - - stream = await self.client.chat.completions.create( - model=model, - messages=formatted_messages, - temperature=temperature, - max_tokens=max_tokens, - stream=True, - extra_body=extra_body if extra_body else None, - **kwargs, - ) - - async for chunk in stream: - yield chunk - - async def get_stream_content( - self, - messages: list[ChatMessage | dict[str, str]], - model: str = "gpt-3.5-turbo", - temperature: float = 0.7, - max_tokens: int | None = None, - **kwargs, - ) -> AsyncGenerator[str, None]: - """ - 获取流式内容(只返回文本内容) - - Args: - messages: 消息列表 - model: 模型名称 - temperature: 温度参数 - max_tokens: 最大token数 - **kwargs: 其他参数 - - Yields: - str: 文本内容片段 - """ - async for chunk in self.chat_completion_stream( - messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs - ): - if chunk.choices and chunk.choices[0].delta.content: - yield chunk.choices[0].delta.content - - async def collect_stream_response( - self, - messages: list[ChatMessage | dict[str, str]], - model: str = "gpt-3.5-turbo", - temperature: float = 0.7, - max_tokens: int | None = None, - **kwargs, - ) -> str: - """ - 收集完整的流式响应 - - Args: - messages: 消息列表 - model: 模型名称 - temperature: 温度参数 - max_tokens: 最大token数 - **kwargs: 其他参数 - - Returns: - str: 完整的响应文本 - """ - full_response = "" - async for content in self.get_stream_content( - messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs - ): - full_response += content - - return full_response - - async def close(self): - """关闭客户端""" - await self.client.close() - - async def __aenter__(self): - """异步上下文管理器入口""" - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """异步上下文管理器退出""" - await self.close() - - -class ConversationManager: - """对话管理器,用于管理对话历史""" - - def __init__(self, client: AsyncOpenAIClient, system_prompt: str | None = None): - """ - 初始化对话管理器 - - Args: - client: OpenAI客户端实例 - system_prompt: 系统提示词 - """ - self.client = client - self.messages: list[ChatMessage] = [] - - if system_prompt: - self.messages.append(ChatMessage(role="system", content=system_prompt)) - - def add_user_message(self, content: str): - """添加用户消息""" - self.messages.append(ChatMessage(role="user", content=content)) - - def add_assistant_message(self, content: str): - """添加助手消息""" - self.messages.append(ChatMessage(role="assistant", content=content)) - - async def send_message_stream( - self, content: str, model: str = "gpt-3.5-turbo", **kwargs - ) -> AsyncGenerator[str, None]: - """ - 发送消息并获取流式响应 - - Args: - content: 用户消息内容 - model: 模型名称 - **kwargs: 其他参数 - - Yields: - str: 响应内容片段 - """ - self.add_user_message(content) - - response_content = "" - async for chunk in self.client.get_stream_content(messages=self.messages, model=model, **kwargs): - response_content += chunk - yield chunk - - self.add_assistant_message(response_content) - - async def send_message(self, content: str, model: str = "gpt-3.5-turbo", **kwargs) -> str: - """ - 发送消息并获取完整响应 - - Args: - content: 用户消息内容 - model: 模型名称 - **kwargs: 其他参数 - - Returns: - str: 完整响应 - """ - self.add_user_message(content) - - response = await self.client.chat_completion(messages=self.messages, model=model, **kwargs) - - response_content = response.choices[0].message.content - self.add_assistant_message(response_content) - - return response_content - - def clear_history(self, keep_system: bool = True): - """ - 清除对话历史 - - Args: - keep_system: 是否保留系统消息 - """ - if keep_system and self.messages and self.messages[0].role == "system": - self.messages = [self.messages[0]] - else: - self.messages = [] - - def get_message_count(self) -> int: - """获取消息数量""" - return len(self.messages) - - def get_conversation_history(self) -> list[dict[str, str]]: - """获取对话历史""" - return [msg.to_dict() for msg in self.messages] diff --git a/src/mais4u/s4u_config.py b/src/mais4u/s4u_config.py deleted file mode 100644 index ce3abe47e..000000000 --- a/src/mais4u/s4u_config.py +++ /dev/null @@ -1,373 +0,0 @@ -import os -import shutil -from dataclasses import MISSING, dataclass, field, fields -from datetime import datetime -from typing import Any, Literal, TypeVar, get_args, get_origin - -import tomlkit -from tomlkit import TOMLDocument -from tomlkit.items import Table -from typing_extensions import Self - -from src.common.logger import get_logger -from src.mais4u.constant_s4u import ENABLE_S4U - -logger = get_logger("s4u_config") - - -# 新增:兼容dict和tomlkit Table -def is_dict_like(obj): - return isinstance(obj, dict | Table) - - -# 新增:递归将Table转为dict -def table_to_dict(obj): - if isinstance(obj, Table): - return {k: table_to_dict(v) for k, v in obj.items()} - elif isinstance(obj, dict): - return {k: table_to_dict(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [table_to_dict(i) for i in obj] - else: - return obj - - -# 获取mais4u模块目录 -MAIS4U_ROOT = os.path.dirname(__file__) -CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config") -TEMPLATE_PATH = os.path.join(CONFIG_DIR, "s4u_config_template.toml") -CONFIG_PATH = os.path.join(CONFIG_DIR, "s4u_config.toml") - -# S4U配置版本 -S4U_VERSION = "1.1.0" - -T = TypeVar("T", bound="S4UConfigBase") - - -@dataclass -class S4UConfigBase: - """S4U配置类的基类""" - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - """从字典加载配置字段""" - data = table_to_dict(data) # 递归转dict,兼容tomlkit Table - if not is_dict_like(data): - raise TypeError(f"Expected a dictionary, got {type(data).__name__}") - - init_args: dict[str, Any] = {} - - for f in fields(cls): - field_name = f.name - - if field_name.startswith("_"): - # 跳过以 _ 开头的字段 - continue - - if field_name not in data: - if f.default is not MISSING or f.default_factory is not MISSING: - # 跳过未提供且有默认值/默认构造方法的字段 - continue - else: - raise ValueError(f"Missing required field: '{field_name}'") - - value = data[field_name] - field_type = f.type - - try: - init_args[field_name] = cls._convert_field(value, field_type) # type: ignore - except TypeError as e: - raise TypeError(f"Field '{field_name}' has a type error: {e}") from e - except Exception as e: - raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e - - return cls() - - @classmethod - def _convert_field(cls, value: Any, field_type: type[Any]) -> Any: - """转换字段值为指定类型""" - # 如果是嵌套的 dataclass,递归调用 from_dict 方法 - if isinstance(field_type, type) and issubclass(field_type, S4UConfigBase): - if not is_dict_like(value): - raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}") - return field_type.from_dict(value) - - # 处理泛型集合类型(list, set, tuple) - field_origin_type = get_origin(field_type) - field_type_args = get_args(field_type) - - if field_origin_type in {list, set, tuple}: - if not isinstance(value, list): - raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}") - - if field_origin_type is list: - if ( - field_type_args - and isinstance(field_type_args[0], type) - and issubclass(field_type_args[0], S4UConfigBase) - ): - return [field_type_args[0].from_dict(item) for item in value] - return [cls._convert_field(item, field_type_args[0]) for item in value] - elif field_origin_type is set: - return {cls._convert_field(item, field_type_args[0]) for item in value} - elif field_origin_type is tuple: - if len(value) != len(field_type_args): - raise TypeError( - f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}" - ) - return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False)) - - if field_origin_type is dict: - if not is_dict_like(value): - raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}") - - if len(field_type_args) != 2: - raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}") - key_type, value_type = field_type_args - - return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()} - - # 处理基础类型,例如 int, str 等 - if field_origin_type is type(None) and value is None: # 处理Optional类型 - return None - - # 处理Literal类型 - if field_origin_type is Literal or get_origin(field_type) is Literal: - allowed_values = get_args(field_type) - if value in allowed_values: - return value - else: - raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type") - - if field_type is Any or isinstance(value, field_type): - return value - - # 其他类型,尝试直接转换 - try: - return field_type(value) - except (ValueError, TypeError) as e: - raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e - - -@dataclass -class S4UModelConfig(S4UConfigBase): - """S4U模型配置类""" - - # 主要对话模型配置 - chat: dict[str, Any] = field(default_factory=lambda: {}) - """主要对话模型配置""" - - # 规划模型配置(原model_motion) - motion: dict[str, Any] = field(default_factory=lambda: {}) - """规划模型配置""" - - # 情感分析模型配置 - emotion: dict[str, Any] = field(default_factory=lambda: {}) - """情感分析模型配置""" - - # 记忆模型配置 - memory: dict[str, Any] = field(default_factory=lambda: {}) - """记忆模型配置""" - - # 工具使用模型配置 - tool_use: dict[str, Any] = field(default_factory=lambda: {}) - """工具使用模型配置""" - - # 嵌入模型配置 - embedding: dict[str, Any] = field(default_factory=lambda: {}) - """嵌入模型配置""" - - # 视觉语言模型配置 - vlm: dict[str, Any] = field(default_factory=lambda: {}) - """视觉语言模型配置""" - - # 知识库模型配置 - knowledge: dict[str, Any] = field(default_factory=lambda: {}) - """知识库模型配置""" - - # 实体提取模型配置 - entity_extract: dict[str, Any] = field(default_factory=lambda: {}) - """实体提取模型配置""" - - # 问答模型配置 - qa: dict[str, Any] = field(default_factory=lambda: {}) - """问答模型配置""" - - -@dataclass -class S4UConfig(S4UConfigBase): - """S4U聊天系统配置类""" - - message_timeout_seconds: int = 120 - """普通消息存活时间(秒),超过此时间的消息将被丢弃""" - - at_bot_priority_bonus: float = 100.0 - """@机器人时的优先级加成分数""" - - recent_message_keep_count: int = 6 - """保留最近N条消息,超出范围的普通消息将被移除""" - - typing_delay: float = 0.1 - """打字延迟时间(秒),模拟真实打字速度""" - - chars_per_second: float = 15.0 - """每秒字符数,用于计算动态打字延迟""" - - min_typing_delay: float = 0.2 - """最小打字延迟(秒)""" - - max_typing_delay: float = 2.0 - """最大打字延迟(秒)""" - - enable_dynamic_typing_delay: bool = False - """是否启用基于文本长度的动态打字延迟""" - - vip_queue_priority: bool = True - """是否启用VIP队列优先级系统""" - - enable_message_interruption: bool = True - """是否允许高优先级消息中断当前回复""" - - enable_old_message_cleanup: bool = True - """是否自动清理过旧的普通消息""" - - enable_streaming_output: bool = True - """是否启用流式输出,false时全部生成后一次性发送""" - - max_context_message_length: int = 20 - """上下文消息最大长度""" - - max_core_message_length: int = 30 - """核心消息最大长度""" - - # 模型配置 - models: S4UModelConfig = field(default_factory=S4UModelConfig) - """S4U模型配置""" - - # 兼容性字段,保持向后兼容 - - -@dataclass -class S4UGlobalConfig(S4UConfigBase): - """S4U总配置类""" - - s4u: S4UConfig - S4U_VERSION: str = S4U_VERSION - - -def update_s4u_config(): - """更新S4U配置文件""" - # 创建配置目录(如果不存在) - os.makedirs(CONFIG_DIR, exist_ok=True) - - # 检查模板文件是否存在 - if not os.path.exists(TEMPLATE_PATH): - logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}") - logger.error("请确保模板文件存在后重新运行") - raise FileNotFoundError(f"S4U配置模板文件不存在: {TEMPLATE_PATH}") - - # 检查配置文件是否存在 - if not os.path.exists(CONFIG_PATH): - logger.info("S4U配置文件不存在,从模板创建新配置") - shutil.copy2(TEMPLATE_PATH, CONFIG_PATH) - logger.info(f"已创建S4U配置文件: {CONFIG_PATH}") - return - - # 读取旧配置文件和模板文件 - with open(CONFIG_PATH, encoding="utf-8") as f: - old_config = tomlkit.load(f) - with open(TEMPLATE_PATH, encoding="utf-8") as f: - new_config = tomlkit.load(f) - - # 检查version是否相同 - if old_config and "inner" in old_config and "inner" in new_config: - old_version = old_config["inner"].get("version") # type: ignore - new_version = new_config["inner"].get("version") # type: ignore - if old_version and new_version and old_version == new_version: - logger.info(f"检测到S4U配置文件版本号相同 (v{old_version}),跳过更新") - return - else: - logger.info(f"检测到S4U配置版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}") - else: - logger.info("S4U配置文件未检测到版本号,可能是旧版本。将进行更新") - - # 创建备份目录 - old_config_dir = os.path.join(CONFIG_DIR, "old") - os.makedirs(old_config_dir, exist_ok=True) - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - old_backup_path = os.path.join(old_config_dir, f"s4u_config_{timestamp}.toml") - - # 移动旧配置文件到old目录 - shutil.move(CONFIG_PATH, old_backup_path) - logger.info(f"已备份旧S4U配置文件到: {old_backup_path}") - - # 复制模板文件到配置目录 - shutil.copy2(TEMPLATE_PATH, CONFIG_PATH) - logger.info(f"已创建新S4U配置文件: {CONFIG_PATH}") - - def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): - """ - 将source字典的值更新到target字典中(如果target中存在相同的键) - """ - for key, value in source.items(): - # 跳过version字段的更新 - if key == "version": - continue - if key in target: - target_value = target[key] - if isinstance(value, dict) and isinstance(target_value, dict | Table): - update_dict(target_value, value) - else: - try: - # 对数组类型进行特殊处理 - if isinstance(value, list): - target[key] = tomlkit.array(str(value)) if value else tomlkit.array() - else: - # 其他类型使用item方法创建新值 - target[key] = tomlkit.item(value) - except (TypeError, ValueError): - # 如果转换失败,直接赋值 - target[key] = value - - # 将旧配置的值更新到新配置中 - logger.info("开始合并S4U新旧配置...") - update_dict(new_config, old_config) - - # 保存更新后的配置(保留注释和格式) - with open(CONFIG_PATH, "w", encoding="utf-8") as f: - f.write(tomlkit.dumps(new_config)) - - logger.info("S4U配置文件更新完成") - - -def load_s4u_config(config_path: str) -> S4UGlobalConfig: - """ - 加载S4U配置文件 - :param config_path: 配置文件路径 - :return: S4UGlobalConfig对象 - """ - # 读取配置文件 - with open(config_path, encoding="utf-8") as f: - config_data = tomlkit.load(f) - - # 创建S4UGlobalConfig对象 - try: - return S4UGlobalConfig.from_dict(config_data) - except Exception as e: - logger.critical("S4U配置文件解析失败") - raise e - - -if not ENABLE_S4U: - s4u_config = None - s4u_config_main = None -else: - # 初始化S4U配置 - logger.info(f"S4U当前版本: {S4U_VERSION}") - update_s4u_config() - - logger.info("正在加载S4U配置文件...") - s4u_config_main = load_s4u_config(config_path=CONFIG_PATH) - logger.info("S4U配置文件加载完成!") - - s4u_config: S4UConfig = s4u_config_main.s4u diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index 14f1dfef5..ef52b93a1 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -2,7 +2,6 @@ import math import random import time -from src.chat.message_receive.message import MessageRecv from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.prompt import Prompt, global_prompt_manager from src.common.data_models.database_data_model import DatabaseMessages @@ -98,7 +97,7 @@ class ChatMood: if not hasattr(self, "last_change_time"): self.last_change_time = 0 - async def update_mood_by_message(self, message: MessageRecv | DatabaseMessages, interested_rate: float): + async def update_mood_by_message(self, message: DatabaseMessages, interested_rate: float): # 确保异步初始化已完成 await self._initialize() @@ -109,11 +108,8 @@ class ChatMood: self.regression_count = 0 - # 处理不同类型的消息对象 - if isinstance(message, MessageRecv): - message_time = message.message_info.time - else: # DatabaseMessages - message_time = message.time + # 使用 DatabaseMessages 的时间字段 + message_time = message.time # 防止负时间差 during_last_time = max(0, message_time - self.last_change_time) diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 2f20ea5be..c9776df64 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -123,7 +123,7 @@ class RelationshipFetcher: # 获取用户特征点 current_points = await person_info_manager.get_value(person_id, "points") or [] forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or [] - + # 确保 points 是列表类型(可能从数据库返回字符串) if not isinstance(current_points, list): current_points = [] @@ -195,25 +195,25 @@ class RelationshipFetcher: if relationships: # db_query 返回字典列表,使用字典访问方式 rel_data = relationships[0] - + # 5.1 用户别名 if rel_data.get("user_aliases"): aliases_list = [alias.strip() for alias in rel_data["user_aliases"].split(",") if alias.strip()] if aliases_list: aliases_str = "、".join(aliases_list) relation_parts.append(f"{person_name}的别名有:{aliases_str}") - + # 5.2 关系印象文本(主观认知) if rel_data.get("relationship_text"): relation_parts.append(f"你对{person_name}的整体认知:{rel_data['relationship_text']}") - + # 5.3 用户偏好关键词 if rel_data.get("preference_keywords"): keywords_list = [kw.strip() for kw in rel_data["preference_keywords"].split(",") if kw.strip()] if keywords_list: keywords_str = "、".join(keywords_list) relation_parts.append(f"{person_name}的偏好和兴趣:{keywords_str}") - + # 5.4 关系亲密程度(好感分数) if rel_data.get("relationship_score") is not None: score_desc = self._get_relationship_score_description(rel_data["relationship_score"]) diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 96f0e4b09..429be54c8 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -55,7 +55,7 @@ async def file_to_stream( if not file_name: file_name = Path(file_path).name - + params = { "file": file_path, "name": file_name, @@ -68,7 +68,7 @@ async def file_to_stream( else: action = "upload_private_file" params["user_id"] = target_stream.user_info.user_id - + response = await adapter_command_to_stream( action=action, params=params, @@ -86,13 +86,16 @@ async def file_to_stream( import asyncio import time import traceback -from typing import Any +from typing import TYPE_CHECKING, Any from maim_message import Seg, UserInfo +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages + # 导入依赖 from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from src.chat.message_receive.message import MessageRecv, MessageSending +from src.chat.message_receive.message import MessageSending from src.chat.message_receive.uni_message_sender import HeartFCSender from src.common.logger import get_logger from src.config.config import global_config @@ -104,84 +107,53 @@ logger = get_logger("send_api") _adapter_response_pool: dict[str, asyncio.Future] = {} -def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None: - """查找要回复的消息 +def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessages | None": + """从消息字典构建 DatabaseMessages 对象 Args: message_dict: 消息字典或 DatabaseMessages 对象 Returns: - Optional[MessageRecv]: 找到的消息,如果没找到则返回None + Optional[DatabaseMessages]: 构建的消息对象,如果构建失败则返回None """ - # 兼容 DatabaseMessages 对象和字典 - if isinstance(message_dict, dict): - user_platform = message_dict.get("user_platform", "") - user_id = message_dict.get("user_id", "") - user_nickname = message_dict.get("user_nickname", "") - user_cardname = message_dict.get("user_cardname", "") - chat_info_group_id = message_dict.get("chat_info_group_id") - chat_info_group_platform = message_dict.get("chat_info_group_platform", "") - chat_info_group_name = message_dict.get("chat_info_group_name", "") - chat_info_platform = message_dict.get("chat_info_platform", "") - message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id") - time_val = message_dict.get("time") - additional_config = message_dict.get("additional_config") - processed_plain_text = message_dict.get("processed_plain_text") - else: - # DatabaseMessages 对象 - user_platform = getattr(message_dict, "user_platform", "") - user_id = getattr(message_dict, "user_id", "") - user_nickname = getattr(message_dict, "user_nickname", "") - user_cardname = getattr(message_dict, "user_cardname", "") - chat_info_group_id = getattr(message_dict, "chat_info_group_id", None) - chat_info_group_platform = getattr(message_dict, "chat_info_group_platform", "") - chat_info_group_name = getattr(message_dict, "chat_info_group_name", "") - chat_info_platform = getattr(message_dict, "chat_info_platform", "") - message_id = getattr(message_dict, "message_id", None) - time_val = getattr(message_dict, "time", None) - additional_config = getattr(message_dict, "additional_config", None) - processed_plain_text = getattr(message_dict, "processed_plain_text", "") - - # 构建MessageRecv对象 - user_info = { - "platform": user_platform, - "user_id": user_id, - "user_nickname": user_nickname, - "user_cardname": user_cardname, - } + from src.common.data_models.database_data_model import DatabaseMessages - group_info = {} - if chat_info_group_id: - group_info = { - "platform": chat_info_group_platform, - "group_id": chat_info_group_id, - "group_name": chat_info_group_name, - } + # 如果已经是 DatabaseMessages,直接返回 + if isinstance(message_dict, DatabaseMessages): + return message_dict - format_info = {"content_format": "", "accept_format": ""} - template_info = {"template_items": {}} + # 从字典提取信息 + user_platform = message_dict.get("user_platform", "") + user_id = message_dict.get("user_id", "") + user_nickname = message_dict.get("user_nickname", "") + user_cardname = message_dict.get("user_cardname", "") + chat_info_group_id = message_dict.get("chat_info_group_id") + chat_info_group_platform = message_dict.get("chat_info_group_platform", "") + chat_info_group_name = message_dict.get("chat_info_group_name", "") + chat_info_platform = message_dict.get("chat_info_platform", "") + message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id") + time_val = message_dict.get("time", time.time()) + additional_config = message_dict.get("additional_config") + processed_plain_text = message_dict.get("processed_plain_text", "") - message_info = { - "platform": chat_info_platform, - "message_id": message_id, - "time": time_val, - "group_info": group_info, - "user_info": user_info, - "additional_config": additional_config, - "format_info": format_info, - "template_info": template_info, - } + # DatabaseMessages 使用扁平参数构造 + db_message = DatabaseMessages( + message_id=message_id or "temp_reply_id", + time=time_val, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + user_platform=user_platform, + chat_info_group_id=chat_info_group_id, + chat_info_group_name=chat_info_group_name, + chat_info_group_platform=chat_info_group_platform, + chat_info_platform=chat_info_platform, + processed_plain_text=processed_plain_text, + additional_config=additional_config + ) - new_message_dict = { - "message_info": message_info, - "raw_message": processed_plain_text, - "processed_plain_text": processed_plain_text, - } - - message_recv = MessageRecv(new_message_dict) - - logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {user_nickname}") - return message_recv + logger.info(f"[SendAPI] 构建回复消息对象,发送者: {user_nickname}") + return db_message def put_adapter_response(request_id: str, response_data: dict) -> None: @@ -285,17 +257,17 @@ async def _send_to_target( "message_id": "temp_reply_id", # 临时ID "time": time.time() } - anchor_message = message_dict_to_message_recv(message_dict=temp_message_dict) + anchor_message = message_dict_to_db_message(message_dict=temp_message_dict) else: anchor_message = None reply_to_platform_id = f"{target_stream.platform}:{sender_id}" if anchor_message else None elif reply_to_message: - anchor_message = message_dict_to_message_recv(message_dict=reply_to_message) + anchor_message = message_dict_to_db_message(message_dict=reply_to_message) if anchor_message: - anchor_message.update_chat_stream(target_stream) + # DatabaseMessages 不需要 update_chat_stream,它是纯数据对象 reply_to_platform_id = ( - f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}" + f"{anchor_message.chat_info.platform}:{anchor_message.user_info.user_id}" ) else: reply_to_platform_id = None diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index b5071e578..e102b55cc 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -192,7 +192,7 @@ class BaseAction(ABC): self.group_name = self.action_message.get("chat_info_group_name", None) self.user_id = str(self.action_message.get("user_id", None)) self.user_nickname = self.action_message.get("user_nickname", None) - + if self.group_id: self.is_group = True self.target_id = self.group_id diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 9cb41ed04..df604cbc0 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -1,10 +1,14 @@ from abc import ABC, abstractmethod +from typing import TYPE_CHECKING -from src.chat.message_receive.message import MessageRecv +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.plugin_system.apis import send_api from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType +if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream + logger = get_logger("base_command") @@ -29,11 +33,11 @@ class BaseCommand(ABC): chat_type_allow: ChatType = ChatType.ALL """允许的聊天类型,默认为所有类型""" - def __init__(self, message: MessageRecv, plugin_config: dict | None = None): + def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None): """初始化Command组件 Args: - message: 接收到的消息对象 + message: 接收到的消息对象(DatabaseMessages) plugin_config: 插件配置字典 """ self.message = message @@ -42,6 +46,9 @@ class BaseCommand(ABC): self.log_prefix = "[Command]" + # chat_stream 会在运行时被 bot.py 设置 + self.chat_stream: "ChatStream | None" = None + # 从类属性获取chat_type_allow设置 self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL) @@ -49,7 +56,7 @@ class BaseCommand(ABC): # 验证聊天类型限制 if not self._validate_chat_type(): - is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message + is_group = message.group_info is not None logger.warning( f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: " f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}" @@ -72,8 +79,8 @@ class BaseCommand(ABC): if self.chat_type_allow == ChatType.ALL: return True - # 检查是否为群聊消息 - is_group = self.message.message_info.group_info + # 检查是否为群聊消息(DatabaseMessages使用group_info来判断) + is_group = self.message.group_info is not None if self.chat_type_allow == ChatType.GROUP and is_group: return True @@ -137,12 +144,11 @@ class BaseCommand(ABC): bool: 是否发送成功 """ # 获取聊天流信息 - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to) + return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to) async def send_type( self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = "" @@ -160,15 +166,14 @@ class BaseCommand(ABC): bool: 是否发送成功 """ # 获取聊天流信息 - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False return await send_api.custom_to_stream( message_type=message_type, content=content, - stream_id=chat_stream.stream_id, + stream_id=self.chat_stream.stream_id, display_message=display_message, typing=typing, reply_to=reply_to, @@ -190,8 +195,7 @@ class BaseCommand(ABC): """ try: # 获取聊天流信息 - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False @@ -200,7 +204,7 @@ class BaseCommand(ABC): success = await send_api.command_to_stream( command=command_data, - stream_id=chat_stream.stream_id, + stream_id=self.chat_stream.stream_id, storage_message=storage_message, display_message=display_message, ) @@ -225,12 +229,11 @@ class BaseCommand(ABC): Returns: bool: 是否发送成功 """ - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id) + return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id) async def send_image(self, image_base64: str) -> bool: """发送图片 @@ -241,12 +244,11 @@ class BaseCommand(ABC): Returns: bool: 是否发送成功 """ - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.image_to_stream(image_base64, chat_stream.stream_id) + return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id) @classmethod def get_command_info(cls) -> "CommandInfo": diff --git a/src/plugin_system/base/plus_command.py b/src/plugin_system/base/plus_command.py index e442d76c1..525819763 100644 --- a/src/plugin_system/base/plus_command.py +++ b/src/plugin_system/base/plus_command.py @@ -5,8 +5,9 @@ import re from abc import ABC, abstractmethod +from typing import TYPE_CHECKING -from src.chat.message_receive.message import MessageRecv +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config from src.plugin_system.apis import send_api @@ -14,6 +15,9 @@ from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.command_args import CommandArgs from src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo +if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream + logger = get_logger("plus_command") @@ -50,23 +54,26 @@ class PlusCommand(ABC): intercept_message: bool = False """是否拦截消息,不进行后续处理""" - def __init__(self, message: MessageRecv, plugin_config: dict | None = None): + def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None): """初始化命令组件 Args: - message: 接收到的消息对象 + message: 接收到的消息对象(DatabaseMessages) plugin_config: 插件配置字典 """ self.message = message self.plugin_config = plugin_config or {} self.log_prefix = "[PlusCommand]" + # chat_stream 会在运行时被 bot.py 设置 + self.chat_stream: "ChatStream | None" = None + # 解析命令参数 self._parse_command() # 验证聊天类型限制 if not self._validate_chat_type(): - is_group = self.message.message_info.group_info.group_id + is_group = message.group_info is not None logger.warning( f"{self.log_prefix} 命令 '{self.command_name}' 不支持当前聊天类型: " f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}" @@ -124,8 +131,8 @@ class PlusCommand(ABC): if self.chat_type_allow == ChatType.ALL: return True - # 检查是否为群聊消息 - is_group = hasattr(self.message.message_info, "group_info") and self.message.message_info.group_info + # 检查是否为群聊消息(DatabaseMessages使用group_info判断) + is_group = self.message.group_info is not None if self.chat_type_allow == ChatType.GROUP and is_group: return True @@ -152,7 +159,7 @@ class PlusCommand(ABC): def _is_exact_command_call(self) -> bool: """检查是否是精确的命令调用(无参数)""" - if not hasattr(self.message, "plain_text") or not self.message.processed_plain_text: + if not self.message.processed_plain_text: return False plain_text = self.message.processed_plain_text.strip() @@ -218,12 +225,11 @@ class PlusCommand(ABC): bool: 是否发送成功 """ # 获取聊天流信息 - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to) + return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to) async def send_type( self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = "" @@ -241,15 +247,14 @@ class PlusCommand(ABC): bool: 是否发送成功 """ # 获取聊天流信息 - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False return await send_api.custom_to_stream( message_type=message_type, content=content, - stream_id=chat_stream.stream_id, + stream_id=self.chat_stream.stream_id, display_message=display_message, typing=typing, reply_to=reply_to, @@ -264,12 +269,11 @@ class PlusCommand(ABC): Returns: bool: 是否发送成功 """ - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id) + return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id) async def send_image(self, image_base64: str) -> bool: """发送图片 @@ -280,12 +284,11 @@ class PlusCommand(ABC): Returns: bool: 是否发送成功 """ - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.image_to_stream(image_base64, chat_stream.stream_id) + return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id) @classmethod def get_plus_command_info(cls) -> "PlusCommandInfo": @@ -340,12 +343,12 @@ class PlusCommandAdapter(BaseCommand): 将PlusCommand适配到现有的插件系统,继承BaseCommand """ - def __init__(self, plus_command_class, message: MessageRecv, plugin_config: dict | None = None): + def __init__(self, plus_command_class, message: DatabaseMessages, plugin_config: dict | None = None): """初始化适配器 Args: plus_command_class: PlusCommand子类 - message: 消息对象 + message: 消息对象(DatabaseMessages) plugin_config: 插件配置 """ # 先设置必要的类属性 @@ -400,7 +403,7 @@ def create_plus_command_adapter(plus_command_class): command_pattern = plus_command_class._generate_command_pattern() chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL) - def __init__(self, message: MessageRecv, plugin_config: dict | None = None): + def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None): super().__init__(message, plugin_config) self.plus_command = plus_command_class(message, plugin_config) self.priority = getattr(plus_command_class, "priority", 0) diff --git a/src/plugin_system/core/event_manager.py b/src/plugin_system/core/event_manager.py index e54861b15..64468b958 100644 --- a/src/plugin_system/core/event_manager.py +++ b/src/plugin_system/core/event_manager.py @@ -40,7 +40,7 @@ class EventManager: self._events: dict[str, BaseEvent] = {} self._event_handlers: dict[str, type[BaseEventHandler]] = {} self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅 - self._scheduler_callback: Optional[Any] = None # scheduler 回调函数 + self._scheduler_callback: Any | None = None # scheduler 回调函数 self._initialized = True logger.info("EventManager 单例初始化完成") diff --git a/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py b/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py index a49f7f36e..87f1abfce 100644 --- a/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py +++ b/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py @@ -5,7 +5,6 @@ """ import json -import time from typing import Any from sqlalchemy import select @@ -22,7 +21,7 @@ logger = get_logger("chat_stream_impression_tool") class ChatStreamImpressionTool(BaseTool): """聊天流印象更新工具 - + 使用二步调用机制: 1. LLM决定是否调用工具并传入初步参数(stream_id会自动传入) 2. 工具内部调用LLM,结合现有数据和传入参数,决定最终更新内容 @@ -31,27 +30,52 @@ class ChatStreamImpressionTool(BaseTool): name = "update_chat_stream_impression" description = "当你通过观察聊天记录对当前聊天环境(群聊或私聊)产生了整体印象或认识时使用此工具,更新对这个聊天流的看法。包括:环境氛围、聊天风格、常见话题、你的兴趣程度。调用时机:当你发现这个聊天环境有明显的氛围特点(如很活跃、很专业、很闲聊)、群成员经常讨论某类话题、或者你对这个环境的感受发生变化时。注意:这是对整个聊天环境的印象,而非对单个用户。" parameters = [ - ("impression_description", ToolParamType.STRING, "你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群'、'大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)", False, None), - ("chat_style", ToolParamType.STRING, "这个聊天环境的风格特征,如'活跃热闹,互帮互助'、'严肃专业,深度讨论'、'轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)", False, None), - ("topic_keywords", ToolParamType.STRING, "这个聊天环境中经常出现的话题,如'编程,AI,技术分享'或'游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)", False, None), - ("interest_score", ToolParamType.FLOAT, "你对这个聊天环境的兴趣和喜欢程度,0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)", False, None), + ( + "impression_description", + ToolParamType.STRING, + "你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群'、'大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)", + False, + None, + ), + ( + "chat_style", + ToolParamType.STRING, + "这个聊天环境的风格特征,如'活跃热闹,互帮互助'、'严肃专业,深度讨论'、'轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)", + False, + None, + ), + ( + "topic_keywords", + ToolParamType.STRING, + "这个聊天环境中经常出现的话题,如'编程,AI,技术分享'或'游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)", + False, + None, + ), + ( + "interest_score", + ToolParamType.FLOAT, + "你对这个聊天环境的兴趣和喜欢程度,0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)", + False, + None, + ), ] available_for_llm = True history_ttl = 5 def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None): super().__init__(plugin_config, chat_stream) - + # 初始化用于二步调用的LLM try: self.impression_llm = LLMRequest( model_set=model_config.model_task_config.relationship_tracker, - request_type="chat_stream_impression_update" + request_type="chat_stream_impression_update", ) except AttributeError: # 降级处理 available_models = [ - attr for attr in dir(model_config.model_task_config) + attr + for attr in dir(model_config.model_task_config) if not attr.startswith("_") and attr != "model_dump" ] if available_models: @@ -59,7 +83,7 @@ class ChatStreamImpressionTool(BaseTool): logger.warning(f"relationship_tracker配置不存在,使用降级模型: {fallback_model}") self.impression_llm = LLMRequest( model_set=getattr(model_config.model_task_config, fallback_model), - request_type="chat_stream_impression_update" + request_type="chat_stream_impression_update", ) else: logger.error("无可用的模型配置") @@ -67,17 +91,17 @@ class ChatStreamImpressionTool(BaseTool): async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """执行聊天流印象更新 - + Args: function_args: 工具参数 - + Returns: dict: 执行结果 """ try: # 优先从 function_args 获取 stream_id stream_id = function_args.get("stream_id") - + # 如果没有,从 chat_stream 对象获取 if not stream_id and self.chat_stream: try: @@ -85,61 +109,49 @@ class ChatStreamImpressionTool(BaseTool): logger.debug(f"从 chat_stream 获取到 stream_id: {stream_id}") except AttributeError: logger.warning("chat_stream 对象没有 stream_id 属性") - + # 如果还是没有,返回错误 if not stream_id: logger.error("无法获取 stream_id:function_args 和 chat_stream 都没有提供") - return { - "type": "error", - "id": "chat_stream_impression", - "content": "错误:无法获取当前聊天流ID" - } - + return {"type": "error", "id": "chat_stream_impression", "content": "错误:无法获取当前聊天流ID"} + # 从LLM传入的参数 new_impression = function_args.get("impression_description", "") new_style = function_args.get("chat_style", "") new_topics = function_args.get("topic_keywords", "") new_score = function_args.get("interest_score") - + # 从数据库获取现有聊天流印象 existing_impression = await self._get_stream_impression(stream_id) - + # 如果LLM没有传入任何有效参数,返回提示 if not any([new_impression, new_style, new_topics, new_score is not None]): return { "type": "info", "id": stream_id, - "content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)" + "content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)", } - + # 调用LLM进行二步决策 if self.impression_llm is None: logger.error("LLM未正确初始化,无法执行二步调用") - return { - "type": "error", - "id": stream_id, - "content": "系统错误:LLM未正确初始化" - } - + return {"type": "error", "id": stream_id, "content": "系统错误:LLM未正确初始化"} + final_impression = await self._llm_decide_final_impression( stream_id=stream_id, existing_impression=existing_impression, new_impression=new_impression, new_style=new_style, new_topics=new_topics, - new_score=new_score + new_score=new_score, ) - + if not final_impression: - return { - "type": "error", - "id": stream_id, - "content": "LLM决策失败,无法更新聊天流印象" - } - + return {"type": "error", "id": stream_id, "content": "LLM决策失败,无法更新聊天流印象"} + # 更新数据库 await self._update_stream_impression_in_db(stream_id, final_impression) - + # 构建返回信息 updates = [] if final_impression.get("stream_impression_text"): @@ -150,30 +162,26 @@ class ChatStreamImpressionTool(BaseTool): updates.append(f"话题: {final_impression['stream_topic_keywords']}") if final_impression.get("stream_interest_score") is not None: updates.append(f"兴趣分: {final_impression['stream_interest_score']:.2f}") - + result_text = f"已更新聊天流 {stream_id} 的印象:\n" + "\n".join(updates) logger.info(f"聊天流印象更新成功: {stream_id}") - - return { - "type": "chat_stream_impression_update", - "id": stream_id, - "content": result_text - } - + + return {"type": "chat_stream_impression_update", "id": stream_id, "content": result_text} + except Exception as e: logger.error(f"聊天流印象更新失败: {e}", exc_info=True) return { "type": "error", "id": function_args.get("stream_id", "unknown"), - "content": f"聊天流印象更新失败: {str(e)}" + "content": f"聊天流印象更新失败: {e!s}", } async def _get_stream_impression(self, stream_id: str) -> dict[str, Any]: """从数据库获取聊天流现有印象 - + Args: stream_id: 聊天流ID - + Returns: dict: 聊天流印象数据 """ @@ -182,13 +190,15 @@ class ChatStreamImpressionTool(BaseTool): stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id) result = await session.execute(stmt) stream = result.scalar_one_or_none() - + if stream: return { "stream_impression_text": stream.stream_impression_text or "", "stream_chat_style": stream.stream_chat_style or "", "stream_topic_keywords": stream.stream_topic_keywords or "", - "stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score is not None else 0.5, + "stream_interest_score": float(stream.stream_interest_score) + if stream.stream_interest_score is not None + else 0.5, "group_name": stream.group_name or "私聊", } else: @@ -217,10 +227,10 @@ class ChatStreamImpressionTool(BaseTool): new_impression: str, new_style: str, new_topics: str, - new_score: float | None + new_score: float | None, ) -> dict[str, Any] | None: """使用LLM决策最终的聊天流印象内容 - + Args: stream_id: 聊天流ID existing_impression: 现有印象数据 @@ -228,33 +238,34 @@ class ChatStreamImpressionTool(BaseTool): new_style: LLM传入的新风格 new_topics: LLM传入的新话题 new_score: LLM传入的新分数 - + Returns: dict: 最终决定的印象数据,如果失败返回None """ try: # 获取bot人设 from src.individuality.individuality import Individuality + individuality = Individuality() bot_personality = await individuality.get_personality_block() - + prompt = f""" 你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality} 你正在更新对聊天流 {stream_id} 的整体印象。 【当前聊天流信息】 -- 聊天环境: {existing_impression.get('group_name', '未知')} -- 当前印象: {existing_impression.get('stream_impression_text', '暂无印象')} -- 聊天风格: {existing_impression.get('stream_chat_style', '未知')} -- 常见话题: {existing_impression.get('stream_topic_keywords', '未知')} -- 当前兴趣分: {existing_impression.get('stream_interest_score', 0.5):.2f} +- 聊天环境: {existing_impression.get("group_name", "未知")} +- 当前印象: {existing_impression.get("stream_impression_text", "暂无印象")} +- 聊天风格: {existing_impression.get("stream_chat_style", "未知")} +- 常见话题: {existing_impression.get("stream_topic_keywords", "未知")} +- 当前兴趣分: {existing_impression.get("stream_interest_score", 0.5):.2f} 【本次想要更新的内容】 -- 新的印象描述: {new_impression if new_impression else '不更新'} -- 新的聊天风格: {new_style if new_style else '不更新'} -- 新的话题关键词: {new_topics if new_topics else '不更新'} -- 新的兴趣分数: {new_score if new_score is not None else '不更新'} +- 新的印象描述: {new_impression if new_impression else "不更新"} +- 新的聊天风格: {new_style if new_style else "不更新"} +- 新的话题关键词: {new_topics if new_topics else "不更新"} +- 新的兴趣分数: {new_score if new_score is not None else "不更新"} 请综合考虑现有信息和新信息,决定最终的聊天流印象内容。注意: 1. 印象描述:如果提供了新印象,应该综合现有印象和新印象,形成对这个聊天环境的整体认知(100-200字) @@ -271,31 +282,50 @@ class ChatStreamImpressionTool(BaseTool): "reasoning": "你的决策理由" }} """ - + # 调用LLM + if not self.impression_llm: + logger.info("未初始化impression_llm") + return None llm_response, _ = await self.impression_llm.generate_response_async(prompt=prompt) - + if not llm_response: logger.warning("LLM未返回有效响应") return None - + # 清理并解析响应 cleaned_response = self._clean_llm_json_response(llm_response) response_data = json.loads(cleaned_response) - + # 提取最终决定的数据 final_impression = { - "stream_impression_text": response_data.get("stream_impression_text", existing_impression.get("stream_impression_text", "")), - "stream_chat_style": response_data.get("stream_chat_style", existing_impression.get("stream_chat_style", "")), - "stream_topic_keywords": response_data.get("stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")), - "stream_interest_score": max(0.0, min(1.0, float(response_data.get("stream_interest_score", existing_impression.get("stream_interest_score", 0.5))))), + "stream_impression_text": response_data.get( + "stream_impression_text", existing_impression.get("stream_impression_text", "") + ), + "stream_chat_style": response_data.get( + "stream_chat_style", existing_impression.get("stream_chat_style", "") + ), + "stream_topic_keywords": response_data.get( + "stream_topic_keywords", existing_impression.get("stream_topic_keywords", "") + ), + "stream_interest_score": max( + 0.0, + min( + 1.0, + float( + response_data.get( + "stream_interest_score", existing_impression.get("stream_interest_score", 0.5) + ) + ), + ), + ), } - + logger.info(f"LLM决策完成: {stream_id}") logger.debug(f"决策理由: {response_data.get('reasoning', '无')}") - + return final_impression - + except json.JSONDecodeError as e: logger.error(f"LLM响应JSON解析失败: {e}") logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}") @@ -306,7 +336,7 @@ class ChatStreamImpressionTool(BaseTool): async def _update_stream_impression_in_db(self, stream_id: str, impression: dict[str, Any]): """更新数据库中的聊天流印象 - + Args: stream_id: 聊天流ID impression: 印象数据 @@ -316,14 +346,14 @@ class ChatStreamImpressionTool(BaseTool): stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id) result = await session.execute(stmt) existing = result.scalar_one_or_none() - + if existing: # 更新现有记录 existing.stream_impression_text = impression.get("stream_impression_text", "") existing.stream_chat_style = impression.get("stream_chat_style", "") existing.stream_topic_keywords = impression.get("stream_topic_keywords", "") existing.stream_interest_score = impression.get("stream_interest_score", 0.5) - + await session.commit() logger.info(f"聊天流印象已更新到数据库: {stream_id}") else: @@ -331,40 +361,40 @@ class ChatStreamImpressionTool(BaseTool): logger.error(error_msg) # 注意:通常聊天流应该在消息处理时就已创建,这里不创建新记录 raise ValueError(error_msg) - + except Exception as e: logger.error(f"更新聊天流印象到数据库失败: {e}", exc_info=True) raise def _clean_llm_json_response(self, response: str) -> str: """清理LLM响应,移除可能的JSON格式标记 - + Args: response: LLM原始响应 - + Returns: str: 清理后的JSON字符串 """ try: import re - + cleaned = response.strip() - + # 移除 ```json 或 ``` 等标记 cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE) cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE) - + # 尝试找到JSON对象的开始和结束 json_start = cleaned.find("{") json_end = cleaned.rfind("}") - + if json_start != -1 and json_end != -1 and json_end > json_start: - cleaned = cleaned[json_start:json_end + 1] - + cleaned = cleaned[json_start : json_end + 1] + cleaned = cleaned.strip() - + return cleaned - + except Exception as e: logger.warning(f"清理LLM响应失败: {e}") return response diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py index 3af389f9f..4359b3f66 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py @@ -231,11 +231,11 @@ class ChatterPlanExecutor: except Exception as e: error_message = str(e) logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}") - + # 将机器人回复添加到已读消息中 if success and action_info.action_message: await self._add_bot_reply_to_read_messages(action_info, plan, reply_content) - + execution_time = time.time() - start_time self.execution_stats["execution_times"].append(execution_time) @@ -381,13 +381,11 @@ class ChatterPlanExecutor: is_picid=False, is_command=False, is_notify=False, - # 用户信息 user_id=bot_user_id, user_nickname=bot_nickname, user_cardname=bot_nickname, user_platform="qq", - # 聊天上下文信息 chat_info_user_id=chat_stream.user_info.user_id if chat_stream.user_info else bot_user_id, chat_info_user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else bot_nickname, @@ -397,24 +395,21 @@ class ChatterPlanExecutor: chat_info_platform=chat_stream.platform, chat_info_create_time=chat_stream.create_time, chat_info_last_active_time=chat_stream.last_active_time, - # 群组信息(如果是群聊) chat_info_group_id=chat_stream.group_info.group_id if chat_stream.group_info else None, chat_info_group_name=chat_stream.group_info.group_name if chat_stream.group_info else None, - chat_info_group_platform=getattr(chat_stream.group_info, "platform", None) if chat_stream.group_info else None, - + chat_info_group_platform=getattr(chat_stream.group_info, "platform", None) + if chat_stream.group_info + else None, # 动作信息 actions=["bot_reply"], should_reply=False, - should_act=False + should_act=False, ) # 添加到chat_stream的已读消息中 - if hasattr(chat_stream, "stream_context") and chat_stream.stream_context: - chat_stream.stream_context.history_messages.append(bot_message) - logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...") - else: - logger.warning("chat_stream没有stream_context,无法添加已读消息") + chat_stream.context_manager.context.history_messages.append(bot_message) + logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...") except Exception as e: logger.error(f"添加机器人回复到已读消息时出错: {e}") diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py index 3013afaa4..afe2241a2 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py @@ -60,7 +60,7 @@ class ChatterPlanFilter: prompt, used_message_id_list = await self._build_prompt(plan) plan.llm_prompt = prompt if global_config.debug.show_prompt: - logger.info(f"规划器原始提示词:{prompt}") #叫你不要改你耳朵聋吗😡😡😡😡😡 + logger.info(f"规划器原始提示词:{prompt}") # 叫你不要改你耳朵聋吗😡😡😡😡😡 llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt) @@ -104,24 +104,26 @@ class ChatterPlanFilter: # 预解析 action_type 来进行判断 thinking = item.get("thinking", "未提供思考过程") actions_obj = item.get("actions", {}) - + # 记录决策历史 - if hasattr(global_config.chat, "enable_decision_history") and global_config.chat.enable_decision_history: + if ( + hasattr(global_config.chat, "enable_decision_history") + and global_config.chat.enable_decision_history + ): action_types_to_log = [] actions_to_process_for_log = [] if isinstance(actions_obj, dict): actions_to_process_for_log.append(actions_obj) elif isinstance(actions_obj, list): actions_to_process_for_log.extend(actions_obj) - + for single_action in actions_to_process_for_log: if isinstance(single_action, dict): action_types_to_log.append(single_action.get("action_type", "no_action")) - + if thinking != "未提供思考过程" and action_types_to_log: await self._add_decision_to_history(plan, thinking, ", ".join(action_types_to_log)) - # 处理actions字段可能是字典或列表的情况 if isinstance(actions_obj, dict): action_type = actions_obj.get("action_type", "no_action") @@ -579,15 +581,15 @@ class ChatterPlanFilter: ): reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}" action = "no_action" - #TODO:把逻辑迁移到DatabaseMessages(如果没人做下个星期我自己来) - #from src.common.data_models.database_data_model import DatabaseMessages + # TODO:把逻辑迁移到DatabaseMessages(如果没人做下个星期我自己来) + # from src.common.data_models.database_data_model import DatabaseMessages - #action_message_obj = None - #if target_message_obj: - #try: - #action_message_obj = DatabaseMessages(**target_message_obj) - #except Exception: - #logger.warning("无法将目标消息转换为DatabaseMessages对象") + # action_message_obj = None + # if target_message_obj: + # try: + # action_message_obj = DatabaseMessages(**target_message_obj) + # except Exception: + # logger.warning("无法将目标消息转换为DatabaseMessages对象") parsed_actions.append( ActionPlannerInfo( diff --git a/src/plugins/built_in/affinity_flow_chatter/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner.py index a8ae019a0..8fc75b4ef 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner.py @@ -17,7 +17,6 @@ from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPla if TYPE_CHECKING: from src.chat.planner_actions.action_manager import ChatterActionManager - from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.info_data_model import Plan from src.common.data_models.message_manager_data_model import StreamContext @@ -100,11 +99,11 @@ class ChatterActionPlanner: if context: context.chat_mode = ChatMode.FOCUS await self._sync_chat_mode_to_stream(context) - + # Normal模式下使用简化流程 if chat_mode == ChatMode.NORMAL: return await self._normal_mode_flow(context) - + # 在规划前,先进行动作修改 from src.chat.planner_actions.action_modifier import ActionModifier action_modifier = ActionModifier(self.action_manager, self.chat_id) @@ -184,12 +183,12 @@ class ChatterActionPlanner: for action in filtered_plan.decided_actions: if action.action_type in ["reply", "proactive_reply"] and action.action_message: # 提取目标消息ID - if hasattr(action.action_message, 'message_id'): + if hasattr(action.action_message, "message_id"): target_message_id = action.action_message.message_id elif isinstance(action.action_message, dict): - target_message_id = action.action_message.get('message_id') + target_message_id = action.action_message.get("message_id") break - + # 如果找到目标消息ID,检查是否已经在处理中 if target_message_id and context: if context.processing_message_id == target_message_id: @@ -215,7 +214,7 @@ class ChatterActionPlanner: # 6. 根据执行结果更新统计信息 self._update_stats_from_execution_result(execution_result) - + # 7. Focus模式下如果执行了reply动作,切换到Normal模式 if chat_mode == ChatMode.FOCUS and context: if filtered_plan.decided_actions: @@ -233,7 +232,7 @@ class ChatterActionPlanner: # 8. 清理处理标记 if context: context.processing_message_id = None - logger.debug(f"已清理处理标记,完成规划流程") + logger.debug("已清理处理标记,完成规划流程") # 9. 返回结果 return self._build_return_result(filtered_plan) @@ -262,7 +261,7 @@ class ChatterActionPlanner: return await self._enhanced_plan_flow(context) try: unread_messages = context.get_unread_messages() if context else [] - + if not unread_messages: logger.debug("Normal模式: 没有未读消息") from src.common.data_models.info_data_model import ActionPlannerInfo @@ -273,11 +272,11 @@ class ChatterActionPlanner: action_message=None, ) return [asdict(no_action)], None - + # 检查是否有消息达到reply阈值 should_reply = False target_message = None - + for message in unread_messages: message_should_reply = getattr(message, "should_reply", False) if message_should_reply: @@ -285,7 +284,7 @@ class ChatterActionPlanner: target_message = message logger.info(f"Normal模式: 消息 {message.message_id} 达到reply阈值") break - + if should_reply and target_message: # 检查是否正在处理相同的目标消息,防止重复回复 target_message_id = target_message.message_id @@ -302,26 +301,26 @@ class ChatterActionPlanner: action_message=None, ) return [asdict(no_action)], None - + # 记录当前正在处理的消息ID if context: context.processing_message_id = target_message_id logger.debug(f"Normal模式: 开始处理目标消息: {target_message_id}") - + # 达到reply阈值,直接进入回复流程 from src.common.data_models.info_data_model import ActionPlannerInfo, Plan from src.plugin_system.base.component_types import ChatType - + # 构建目标消息字典 - 使用 flatten() 方法获取扁平化的字典 target_message_dict = target_message.flatten() - + reply_action = ActionPlannerInfo( action_type="reply", reasoning="Normal模式: 兴趣度达到阈值,直接回复", action_data={"target_message_id": target_message.message_id}, action_message=target_message, ) - + # Normal模式下直接构建最小化的Plan,跳过generator和action_modifier # 这样可以显著降低延迟 minimal_plan = Plan( @@ -330,25 +329,25 @@ class ChatterActionPlanner: mode=ChatMode.NORMAL, decided_actions=[reply_action], ) - + # 执行reply动作 execution_result = await self.executor.execute(minimal_plan) self._update_stats_from_execution_result(execution_result) - + logger.info("Normal模式: 执行reply动作完成") - + # 清理处理标记 if context: context.processing_message_id = None - logger.debug(f"Normal模式: 已清理处理标记") - + logger.debug("Normal模式: 已清理处理标记") + # 无论是否回复,都进行退出normal模式的判定 await self._check_exit_normal_mode(context) - + return [asdict(reply_action)], target_message_dict else: # 未达到reply阈值 - logger.debug(f"Normal模式: 未达到reply阈值") + logger.debug("Normal模式: 未达到reply阈值") from src.common.data_models.info_data_model import ActionPlannerInfo no_action = ActionPlannerInfo( action_type="no_action", @@ -356,12 +355,12 @@ class ChatterActionPlanner: action_data={}, action_message=None, ) - + # 无论是否回复,都进行退出normal模式的判定 await self._check_exit_normal_mode(context) - + return [asdict(no_action)], None - + except Exception as e: logger.error(f"Normal模式流程出错: {e}") self.planner_stats["failed_plans"] += 1 @@ -378,16 +377,16 @@ class ChatterActionPlanner: """ if not context: return - + try: from src.chat.message_receive.chat_stream import get_chat_manager - + chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(self.chat_id) if chat_manager else None - + if not chat_stream: return - + focus_energy = chat_stream.focus_energy # focus_energy越低,退出normal模式的概率越高 # 使用反比例函数: 退出概率 = 1 - focus_energy @@ -395,7 +394,7 @@ class ChatterActionPlanner: # 当focus_energy = 0.5时,退出概率 = 50% # 当focus_energy = 0.9时,退出概率 = 10% exit_probability = 1.0 - focus_energy - + import random if random.random() < exit_probability: logger.info(f"Normal模式: focus_energy={focus_energy:.3f}, 退出概率={exit_probability:.3f}, 切换回focus模式") @@ -404,7 +403,7 @@ class ChatterActionPlanner: await self._sync_chat_mode_to_stream(context) else: logger.debug(f"Normal模式: focus_energy={focus_energy:.3f}, 退出概率={exit_probability:.3f}, 保持normal模式") - + except Exception as e: logger.warning(f"检查退出Normal模式失败: {e}") @@ -412,7 +411,7 @@ class ChatterActionPlanner: """同步chat_mode到ChatStream""" try: from src.chat.message_receive.chat_stream import get_chat_manager - + chat_manager = get_chat_manager() if chat_manager: chat_stream = await chat_manager.get_stream(context.stream_id) diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py index 2a719da83..b7f45b749 100644 --- a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py +++ b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py @@ -15,57 +15,57 @@ logger = get_logger("proactive_thinking_event") class ProactiveThinkingReplyHandler(BaseEventHandler): """Reply事件处理器 - + 当bot回复某个聊天流后: 1. 如果该聊天流的主动思考被暂停(因为抛出了话题),则恢复它 2. 无论是否暂停,都重置定时任务,重新开始计时 """ - + handler_name: str = "proactive_thinking_reply_handler" handler_description: str = "监听reply事件,重置主动思考定时任务" init_subscribe: list[EventType | str] = [EventType.AFTER_SEND] - + async def execute(self, kwargs: dict | None) -> HandlerResult: """处理reply事件 - + Args: kwargs: 事件参数,应包含 stream_id - + Returns: HandlerResult: 处理结果 """ logger.debug("[主动思考事件] ProactiveThinkingReplyHandler 开始执行") logger.debug(f"[主动思考事件] 接收到的参数: {kwargs}") - + if not kwargs: logger.debug("[主动思考事件] kwargs 为空,跳过处理") return HandlerResult(success=True, continue_process=True, message=None) - + stream_id = kwargs.get("stream_id") if not stream_id: - logger.debug(f"[主动思考事件] Reply事件缺少stream_id参数") + logger.debug("[主动思考事件] Reply事件缺少stream_id参数") return HandlerResult(success=True, continue_process=True, message=None) - + logger.debug(f"[主动思考事件] 收到 AFTER_SEND 事件,stream_id={stream_id}") - + try: from src.config.config import global_config - + # 检查是否启用reply重置 if not global_config.proactive_thinking.reply_reset_enabled: - logger.debug(f"[主动思考事件] reply_reset_enabled 为 False,跳过重置") + logger.debug("[主动思考事件] reply_reset_enabled 为 False,跳过重置") return HandlerResult(success=True, continue_process=True, message=None) - + # 检查是否被暂停 was_paused = await proactive_thinking_scheduler.is_paused(stream_id) logger.debug(f"[主动思考事件] 聊天流 {stream_id} 暂停状态: {was_paused}") - + if was_paused: logger.debug(f"[主动思考事件] 检测到reply事件,聊天流 {stream_id} 之前因抛出话题而暂停,现在恢复") - + # 重置定时任务(这会自动清除暂停标记并创建新任务) success = await proactive_thinking_scheduler.schedule_proactive_thinking(stream_id) - + if success: if was_paused: logger.info(f"✅ 聊天流 {stream_id} 主动思考已恢复并重置") @@ -73,82 +73,82 @@ class ProactiveThinkingReplyHandler(BaseEventHandler): logger.debug(f"✅ 聊天流 {stream_id} 主动思考任务已重置") else: logger.warning(f"❌ 重置聊天流 {stream_id} 主动思考任务失败") - + except Exception as e: logger.error(f"❌ 处理reply事件时出错: {e}", exc_info=True) - + # 总是继续处理其他handler return HandlerResult(success=True, continue_process=True, message=None) class ProactiveThinkingMessageHandler(BaseEventHandler): """消息事件处理器 - + 当收到消息时,如果该聊天流还没有主动思考任务,则创建一个 这样可以确保新的聊天流也能获得主动思考功能 """ - + handler_name: str = "proactive_thinking_message_handler" handler_description: str = "监听消息事件,为新聊天流创建主动思考任务" init_subscribe: list[EventType | str] = [EventType.ON_MESSAGE] - + async def execute(self, kwargs: dict | None) -> HandlerResult: """处理消息事件 - + Args: - kwargs: 事件参数,格式为 {"message": MessageRecv} - + kwargs: 事件参数,格式为 {"message": DatabaseMessages} + Returns: HandlerResult: 处理结果 """ if not kwargs: return HandlerResult(success=True, continue_process=True, message=None) - - # 从 kwargs 中获取 MessageRecv 对象 + + # 从 kwargs 中获取 DatabaseMessages 对象 message = kwargs.get("message") if not message or not hasattr(message, "chat_stream"): return HandlerResult(success=True, continue_process=True, message=None) - + # 从 chat_stream 获取 stream_id chat_stream = message.chat_stream if not chat_stream or not hasattr(chat_stream, "stream_id"): return HandlerResult(success=True, continue_process=True, message=None) - + stream_id = chat_stream.stream_id - + try: from src.config.config import global_config - + # 检查是否启用主动思考 if not global_config.proactive_thinking.enable: return HandlerResult(success=True, continue_process=True, message=None) - + # 检查该聊天流是否已经有任务 task_info = await proactive_thinking_scheduler.get_task_info(stream_id) if task_info: # 已经有任务,不需要创建 return HandlerResult(success=True, continue_process=True, message=None) - + # 从 message_info 获取平台和聊天ID信息 message_info = message.message_info platform = message_info.platform is_group = message_info.group_info is not None chat_id = message_info.group_info.group_id if is_group else message_info.user_info.user_id # type: ignore - + # 构造配置字符串 stream_config = f"{platform}:{chat_id}:{'group' if is_group else 'private'}" - + # 检查黑白名单 if not proactive_thinking_scheduler._check_whitelist_blacklist(stream_config): return HandlerResult(success=True, continue_process=True, message=None) - + # 创建主动思考任务 success = await proactive_thinking_scheduler.schedule_proactive_thinking(stream_id) if success: logger.info(f"为新聊天流 {stream_id} 创建了主动思考任务") - + except Exception as e: logger.error(f"处理消息事件时出错: {e}", exc_info=True) - + # 总是继续处理其他handler return HandlerResult(success=True, continue_process=True, message=None) diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py index 80de51f5f..cea22211c 100644 --- a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py @@ -5,11 +5,10 @@ import json from datetime import datetime -from typing import Any, Literal, Optional +from typing import Any, Literal from sqlalchemy import select -from src.chat.express.expression_learner import expression_learner_manager from src.chat.express.expression_selector import expression_selector from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import ChatStreams @@ -17,42 +16,40 @@ from src.common.logger import get_logger from src.config.config import global_config, model_config from src.individuality.individuality import Individuality from src.llm_models.utils_model import LLMRequest -from src.plugin_system.apis import chat_api, message_api, send_api +from src.plugin_system.apis import message_api, send_api logger = get_logger("proactive_thinking_executor") class ProactiveThinkingPlanner: """主动思考规划器 - + 负责: 1. 搜集信息(聊天流印象、话题关键词、历史聊天记录) 2. 调用LLM决策:什么都不做/简单冒泡/抛出话题 3. 根据决策生成回复内容 """ - + def __init__(self): """初始化规划器""" try: self.decision_llm = LLMRequest( - model_set=model_config.model_task_config.utils, - request_type="proactive_thinking_decision" + model_set=model_config.model_task_config.utils, request_type="proactive_thinking_decision" ) self.reply_llm = LLMRequest( - model_set=model_config.model_task_config.replyer, - request_type="proactive_thinking_reply" + model_set=model_config.model_task_config.replyer, request_type="proactive_thinking_reply" ) except Exception as e: logger.error(f"初始化LLM失败: {e}") self.decision_llm = None self.reply_llm = None - - async def gather_context(self, stream_id: str) -> Optional[dict[str, Any]]: + + async def gather_context(self, stream_id: str) -> dict[str, Any] | None: """搜集聊天流的上下文信息 - + Args: stream_id: 聊天流ID - + Returns: dict: 包含所有上下文信息的字典,失败返回None """ @@ -62,27 +59,28 @@ class ProactiveThinkingPlanner: if not stream_data: logger.warning(f"无法获取聊天流 {stream_id} 的印象数据") return None - + # 2. 获取最近的聊天记录 recent_messages = await message_api.get_recent_messages( chat_id=stream_id, - limit=20, + limit=40, limit_mode="latest", hours=24 ) - + recent_chat_history = "" if recent_messages: recent_chat_history = await message_api.build_readable_messages_to_str(recent_messages) - + # 3. 获取bot人设 individuality = Individuality() bot_personality = await individuality.get_personality_block() - + # 4. 获取当前心情 current_mood = "感觉很平静" # 默认心情 try: from src.mood.mood_manager import mood_manager + mood_obj = mood_manager.get_mood_by_chat_id(stream_id) if mood_obj: await mood_obj._initialize() # 确保已初始化 @@ -90,19 +88,20 @@ class ProactiveThinkingPlanner: logger.debug(f"获取到聊天流 {stream_id} 的心情: {current_mood}") except Exception as e: logger.warning(f"获取心情失败,使用默认值: {e}") - + # 5. 获取上次决策 last_decision = None try: from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_scheduler import ( proactive_thinking_scheduler, ) + last_decision = proactive_thinking_scheduler.get_last_decision(stream_id) if last_decision: logger.debug(f"获取到聊天流 {stream_id} 的上次决策: {last_decision.get('action')}") except Exception as e: logger.warning(f"获取上次决策失败: {e}") - + # 6. 构建上下文 context = { "stream_id": stream_id, @@ -117,45 +116,45 @@ class ProactiveThinkingPlanner: "current_mood": current_mood, "last_decision": last_decision, } - + logger.debug(f"成功搜集聊天流 {stream_id} 的上下文信息") return context - + except Exception as e: logger.error(f"搜集上下文信息失败: {e}", exc_info=True) return None - - async def _get_stream_impression(self, stream_id: str) -> Optional[dict[str, Any]]: + + async def _get_stream_impression(self, stream_id: str) -> dict[str, Any] | None: """从数据库获取聊天流印象数据""" try: async with get_db_session() as session: stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id) result = await session.execute(stmt) stream = result.scalar_one_or_none() - + if not stream: return None - + return { "stream_name": stream.group_name or "私聊", "stream_impression_text": stream.stream_impression_text or "", "stream_chat_style": stream.stream_chat_style or "", "stream_topic_keywords": stream.stream_topic_keywords or "", - "stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score else 0.5, + "stream_interest_score": float(stream.stream_interest_score) + if stream.stream_interest_score + else 0.5, } - + except Exception as e: logger.error(f"获取聊天流印象失败: {e}") return None - - async def make_decision( - self, context: dict[str, Any] - ) -> Optional[dict[str, Any]]: + + async def make_decision(self, context: dict[str, Any]) -> dict[str, Any] | None: """使用LLM进行决策 - + Args: context: 上下文信息 - + Returns: dict: 决策结果,包含: - action: "do_nothing" | "simple_bubble" | "throw_topic" @@ -165,30 +164,28 @@ class ProactiveThinkingPlanner: if not self.decision_llm: logger.error("决策LLM未初始化") return None - + response = None try: decision_prompt = self._build_decision_prompt(context) - + if global_config.debug.show_prompt: logger.info(f"决策提示词:\n{decision_prompt}") - + response, _ = await self.decision_llm.generate_response_async(prompt=decision_prompt) - + if not response: logger.warning("LLM未返回有效响应") return None - + # 清理并解析JSON响应 cleaned_response = self._clean_json_response(response) decision = json.loads(cleaned_response) - - logger.info( - f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}" - ) - + + logger.info(f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}") + return decision - + except json.JSONDecodeError as e: logger.error(f"解析决策JSON失败: {e}") if response: @@ -197,18 +194,18 @@ class ProactiveThinkingPlanner: except Exception as e: logger.error(f"决策过程失败: {e}", exc_info=True) return None - + def _build_decision_prompt(self, context: dict[str, Any]) -> str: """构建决策提示词""" # 构建上次决策信息 last_decision_text = "" - if context.get('last_decision'): - last_dec = context['last_decision'] - last_action = last_dec.get('action', '未知') - last_reasoning = last_dec.get('reasoning', '无') - last_topic = last_dec.get('topic') - last_time = last_dec.get('timestamp', '未知') - + if context.get("last_decision"): + last_dec = context["last_decision"] + last_action = last_dec.get("action", "未知") + last_reasoning = last_dec.get("reasoning", "无") + last_topic = last_dec.get("topic") + last_time = last_dec.get("timestamp", "未知") + last_decision_text = f""" 【上次主动思考的决策】 - 时间: {last_time} @@ -217,103 +214,100 @@ class ProactiveThinkingPlanner: if last_topic: last_decision_text += f"\n- 话题: {last_topic}" - return f"""你是一个有着独特个性的AI助手。你的人设是: + return f"""你的人设是: {context['bot_personality']} -现在是 {context['current_time']},你正在考虑是否要主动在 "{context['stream_name']}" 中说些什么。 +现在是 {context['current_time']},你正在考虑是否要在与 "{context['stream_name']}" 的对话中主动说些什么。 【你当前的心情】 -{context.get('current_mood', '感觉很平静')} +{context.get("current_mood", "感觉很平静")} 【聊天环境信息】 -- 整体印象: {context['stream_impression']} -- 聊天风格: {context['chat_style']} -- 常见话题: {context['topic_keywords'] or '暂无'} -- 你的兴趣程度: {context['interest_score']:.2f}/1.0 +- 整体印象: {context["stream_impression"]} +- 聊天风格: {context["chat_style"]} +- 常见话题: {context["topic_keywords"] or "暂无"} +- 你的兴趣程度: {context["interest_score"]:.2f}/1.0 {last_decision_text} 【最近的聊天记录】 -{context['recent_chat_history']} +{context["recent_chat_history"]} -请根据以上信息(包括你的心情和上次决策),决定你现在应该做什么: +请根据以上信息,决定你现在应该做什么: **选项1:什么都不做 (do_nothing)** -- 适用场景:现在可能是休息时间、工作时间,或者气氛不适合说话 -- 也可能是:最近聊天很活跃不需要你主动、没什么特别想说的、此时说话会显得突兀 -- 心情影响:如果心情不好(如生气、难过),可能更倾向于保持沉默 +- 适用场景:气氛不适合说话、最近对话很活跃、没什么特别想说的、或者此时说话会显得突兀。 +- 心情影响:如果心情不好(如生气、难过),可能更倾向于保持沉默。 **选项2:简单冒个泡 (simple_bubble)** -- 适用场景:群里有点冷清,你想引起注意或活跃气氛 -- 方式:简单问个好、发个表情、说句无关紧要的话,没有深意,就是刷个存在感 -- 心情影响:心情好时可能更活跃;心情不好时也可能需要倾诉或找人陪伴 +- 适用场景:对话有些冷清,你想缓和气氛或开启新的互动。 +- 方式:说一句轻松随意的话,旨在建立或维持连接。 +- 心情影响:心情会影响你冒泡的方式和内容。 -**选项3:抛出一个话题 (throw_topic)** -- 适用场景:历史消息中有未讨论完的话题、你有自己的想法、或者想深入聊某个主题 -- 方式:明确提出一个话题,希望得到回应和讨论 -- 心情影响:心情会影响你想聊的话题类型和语气 +**选项3:发起一次有目的的互动 (throw_topic)** +- 适用场景:你想延续对话、表达关心、或深入讨论某个具体话题。 +- **【互动类型1:延续约定或提醒】(最高优先级)**:检查最近的聊天记录,是否存在可以延续的互动。例如,如果昨晚的最后一条消息是“晚安”,现在是早上,一个“早安”的回应是绝佳的选择。如果之前提到过某个约定(如“待会聊”),现在可以主动跟进。 +- **【互动类型2:展现真诚的关心】(次高优先级)**:如果不存在可延续的约定,请仔细阅读聊天记录,寻找对方提及的个人状况(如天气、出行、身体、情绪、工作学习等),并主动表达关心。 +- **【互动类型3:开启新话题】**:当以上两点都不适用时,可以考虑开启一个你感兴趣的新话题。 +- 心情影响:心情会影响你想发起互动的方式和内容。 请以JSON格式回复你的决策: {{ "action": "do_nothing" | "simple_bubble" | "throw_topic", - "reasoning": "你的决策理由,说明为什么选择这个行动(要结合你的心情和上次决策考虑)", - "topic": "(仅当action=throw_topic时填写)你想抛出的具体话题" + "reasoning": "你的决策理由(请结合你的心情、聊天环境和对话历史进行分析)", + "topic": "(仅当action=throw_topic时填写)你的互动意图(如:回应晚安并说早安、关心对方的考试情况、讨论新游戏)" }} 注意: -1. 如果最近聊天很活跃(不到1小时),倾向于选择 do_nothing -2. 如果你对这个环境兴趣不高(<0.4),倾向于选择 do_nothing 或 simple_bubble -3. 考虑你的心情:心情会影响你的行动倾向和表达方式 -4. 参考上次决策:避免重复相同的话题,也可以根据上次效果调整策略 -3. 只有在真的有话题想聊时才选择 throw_topic -4. 符合你的人设,不要太过热情或冷淡 +1. 兴趣度较低(<0.4)时或者最近聊天很活跃(不到1小时),倾向于 `do_nothing` 或 `simple_bubble`。 +2. 你的心情会影响你的行动倾向和表达方式。 +3. 参考上次决策,避免重复,并可根据上次的互动效果调整策略。 +4. 只有在真的有感而发时才选择 `throw_topic`。 +5. 保持你的人设,确保行为一致性。 """ - + async def generate_reply( - self, - context: dict[str, Any], - action: Literal["simple_bubble", "throw_topic"], - topic: Optional[str] = None - ) -> Optional[str]: + self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None = None + ) -> str | None: """生成回复内容 - + Args: context: 上下文信息 action: 动作类型 topic: (可选) 话题内容,当action=throw_topic时必须提供 - + Returns: str: 生成的回复文本,失败返回None """ if not self.reply_llm: logger.error("回复LLM未初始化") return None - + try: reply_prompt = await self._build_reply_prompt(context, action, topic) - + if global_config.debug.show_prompt: logger.info(f"回复提示词:\n{reply_prompt}") - + response, _ = await self.reply_llm.generate_response_async(prompt=reply_prompt) - + if not response: logger.warning("LLM未返回有效回复") return None - + logger.info(f"生成回复成功: {response[:50]}...") return response.strip() - + except Exception as e: logger.error(f"生成回复失败: {e}", exc_info=True) return None - + async def _get_expression_habits(self, stream_id: str, chat_history: str) -> str: """获取表达方式参考 - + Args: stream_id: 聊天流ID chat_history: 聊天历史 - + Returns: str: 格式化的表达方式参考文本 """ @@ -324,15 +318,15 @@ class ProactiveThinkingPlanner: chat_history=chat_history, target_message=None, # 主动思考没有target message max_num=6, # 主动思考时使用较少的表达方式 - min_num=2 + min_num=2, ) - + if not selected_expressions: return "" - + style_habits = [] grammar_habits = [] - + for expr in selected_expressions: if isinstance(expr, dict) and "situation" in expr and "style" in expr: expr_type = expr.get("type", "style") @@ -340,7 +334,7 @@ class ProactiveThinkingPlanner: grammar_habits.append(f"当{expr['situation']}时,使用 {expr['style']}") else: style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}") - + expression_block = "" if style_habits or grammar_habits: expression_block = "\n【表达方式参考】\n" @@ -349,97 +343,98 @@ class ProactiveThinkingPlanner: if grammar_habits: expression_block += "句法特点:\n" + "\n".join(grammar_habits) + "\n" expression_block += "注意:仅在情景合适时自然地使用这些表达,不要生硬套用。\n" - + return expression_block - + except Exception as e: logger.warning(f"获取表达方式失败: {e}") return "" - + async def _build_reply_prompt( - self, - context: dict[str, Any], - action: Literal["simple_bubble", "throw_topic"], - topic: Optional[str] + self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None ) -> str: """构建回复提示词""" # 获取表达方式参考 expression_habits = await self._get_expression_habits( - stream_id=context.get('stream_id', ''), - chat_history=context.get('recent_chat_history', '') + stream_id=context.get("stream_id", ""), chat_history=context.get("recent_chat_history", "") ) - + if action == "simple_bubble": - return f"""你是一个有着独特个性的AI助手。你的人设是: + return f"""你的人设是: {context['bot_personality']} -现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中简单冒个泡。 +距离上次对话已经有一段时间了,你决定主动说些什么,轻松地开启新的互动。 【你当前的心情】 -{context.get('current_mood', '感觉很平静')} +{context.get("current_mood", "感觉很平静")} 【聊天环境】 -- 整体印象: {context['stream_impression']} -- 聊天风格: {context['chat_style']} +- 整体印象: {context["stream_impression"]} +- 聊天风格: {context["chat_style"]} 【最近的聊天记录】 -{context['recent_chat_history']} +{context["recent_chat_history"]} {expression_habits} -请生成一条简短的消息,用于水群。要求: -1. 非常简短(5-15字) -2. 轻松随意,不要有明确的话题或问题 -3. 可以是:问候、表达心情、随口一句话 -4. 符合你的人设和当前聊天风格 -5. **你的心情应该影响消息的内容和语气**(比如心情好时可能更活泼,心情不好时可能更低落) -6. 如果有表达方式参考,在合适时自然使用 -7. 合理参考历史记录 +请生成一条简短的消息,用于水群。 +【要求】 +1. 风格简短随意(5-20字) +2. 不要提出明确的话题或问题,可以是问候、表达心情或一句随口的话。 +3. 符合你的人设和当前聊天风格。 +4. **你的心情应该影响消息的内容和语气**。 +5. 如果有表达方式参考,在合适时自然使用。 +6. 合理参考历史记录。 直接输出消息内容,不要解释:""" - + else: # throw_topic - return f"""你是一个有着独特个性的AI助手。你的人设是: + return f"""你的人设是: {context['bot_personality']} -现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中抛出一个话题。 +现在是 {context['current_time']},你决定在与 "{context['stream_name']}" 的对话中主动发起一次互动。 【你当前的心情】 -{context.get('current_mood', '感觉很平静')} +{context.get("current_mood", "感觉很平静")} 【聊天环境】 -- 整体印象: {context['stream_impression']} -- 聊天风格: {context['chat_style']} -- 常见话题: {context['topic_keywords'] or '暂无'} +- 整体印象: {context["stream_impression"]} +- 聊天风格: {context["chat_style"]} +- 常见话题: {context["topic_keywords"] or "暂无"} 【最近的聊天记录】 -{context['recent_chat_history']} +{context["recent_chat_history"]} -【你想抛出的话题】 +【你的互动意图】 {topic} {expression_habits} -请根据这个话题生成一条消息,要求: -1. 明确提出话题,引导讨论 -2. 长度适中(20-50字) -3. 自然地引入话题,不要生硬 -4. 可以结合最近的聊天记录 -5. 符合你的人设和当前聊天风格 -6. **你的心情应该影响话题的选择和表达方式**(比如心情好时可能更积极,心情不好时可能需要倾诉或寻求安慰) -7. 如果有表达方式参考,在合适时自然使用 +【构思指南】 +请根据你的互动意图,生成一条有温度的消息。 +- 如果意图是**延续约定**(如回应“晚安”),请直接生成对应的问候。 +- 如果意图是**表达关心**(如跟进对方提到的事),请生成自然、真诚的关心话语。 +- 如果意图是**开启新话题**,请自然地引入话题。 + +请根据这个意图,生成一条消息,要求: +1. 自然地引入话题或表达关心。 +2. 长度适中(20-50字)。 +3. 可以结合最近的聊天记录,使对话更连贯。 +4. 符合你的人设和当前聊天风格。 +5. **你的心情会影响你的表达方式**。 +6. 如果有表达方式参考,在合适时自然使用。 直接输出消息内容,不要解释:""" - + def _clean_json_response(self, response: str) -> str: """清理LLM响应中的JSON格式标记""" import re - + cleaned = response.strip() cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE) cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE) - + json_start = cleaned.find("{") json_end = cleaned.rfind("}") - + if json_start != -1 and json_end != -1 and json_end > json_start: - cleaned = cleaned[json_start:json_end + 1] - + cleaned = cleaned[json_start : json_end + 1] + return cleaned.strip() @@ -452,7 +447,7 @@ _statistics: dict[str, dict[str, Any]] = {} def _update_statistics(stream_id: str, action: str): """更新统计数据 - + Args: stream_id: 聊天流ID action: 执行的动作 @@ -465,18 +460,18 @@ def _update_statistics(stream_id: str, action: str): "throw_topic_count": 0, "last_execution_time": None, } - + _statistics[stream_id]["total_executions"] += 1 _statistics[stream_id][f"{action}_count"] += 1 _statistics[stream_id]["last_execution_time"] = datetime.now().isoformat() -def get_statistics(stream_id: Optional[str] = None) -> dict[str, Any]: +def get_statistics(stream_id: str | None = None) -> dict[str, Any]: """获取统计数据 - + Args: stream_id: 聊天流ID,None表示获取所有统计 - + Returns: 统计数据字典 """ @@ -487,7 +482,7 @@ def get_statistics(stream_id: Optional[str] = None) -> dict[str, Any]: async def execute_proactive_thinking(stream_id: str): """执行主动思考(被调度器调用的回调函数) - + Args: stream_id: 聊天流ID """ @@ -495,125 +490,125 @@ async def execute_proactive_thinking(stream_id: str): from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_scheduler import ( proactive_thinking_scheduler, ) - + config = global_config.proactive_thinking - + logger.debug(f"🤔 开始主动思考 {stream_id}") - + try: # 0. 前置检查 if proactive_thinking_scheduler._is_in_quiet_hours(): - logger.debug(f"安静时段,跳过") + logger.debug("安静时段,跳过") return - + if not proactive_thinking_scheduler._check_daily_limit(stream_id): - logger.debug(f"今日发言达上限") + logger.debug("今日发言达上限") return - + # 1. 搜集信息 - logger.debug(f"步骤1: 搜集上下文") + logger.debug("步骤1: 搜集上下文") context = await _planner.gather_context(stream_id) if not context: - logger.warning(f"无法搜集上下文,跳过") + logger.warning("无法搜集上下文,跳过") return # 检查兴趣分数阈值 - interest_score = context.get('interest_score', 0.5) + interest_score = context.get("interest_score", 0.5) if not proactive_thinking_scheduler._check_interest_score_threshold(interest_score): - logger.debug(f"兴趣分数不在阈值范围内") + logger.debug("兴趣分数不在阈值范围内") return - + # 2. 进行决策 - logger.debug(f"步骤2: LLM决策") + logger.debug("步骤2: LLM决策") decision = await _planner.make_decision(context) if not decision: - logger.warning(f"决策失败,跳过") + logger.warning("决策失败,跳过") return - + action = decision.get("action", "do_nothing") reasoning = decision.get("reasoning", "无") - + # 记录决策日志 if config.log_decisions: logger.debug(f"决策: action={action}, reasoning={reasoning}") - + # 3. 根据决策执行相应动作 if action == "do_nothing": logger.debug(f"决策:什么都不做。理由:{reasoning}") proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, None) return - + elif action == "simple_bubble": logger.info(f"💬 决策:冒个泡。理由:{reasoning}") - + proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, None) - + # 生成简单的消息 - logger.debug(f"步骤3: 生成冒泡回复") + logger.debug("步骤3: 生成冒泡回复") reply = await _planner.generate_reply(context, "simple_bubble") if reply: await send_api.text_to_stream( stream_id=stream_id, text=reply, ) - logger.info(f"✅ 已发送冒泡消息") - + logger.info("✅ 已发送冒泡消息") + # 增加每日计数 proactive_thinking_scheduler._increment_daily_count(stream_id) - + # 更新统计 if config.enable_statistics: _update_statistics(stream_id, action) - + # 冒泡后暂停主动思考,等待用户回复 # 使用与 topic_throw 相同的冷却时间配置 if config.topic_throw_cooldown > 0: - logger.info(f"[主动思考] 步骤5:暂停任务") + logger.info("[主动思考] 步骤5:暂停任务") await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已冒泡") logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复") - logger.info(f"[主动思考] simple_bubble 执行完成") - + logger.info("[主动思考] simple_bubble 执行完成") + elif action == "throw_topic": topic = decision.get("topic", "") logger.info(f"[主动思考] 决策:抛出话题。理由:{reasoning},话题:{topic}") - + # 记录决策 proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, topic) - + if not topic: logger.warning("[主动思考] 选择了抛出话题但未提供话题内容,降级为冒泡") - logger.info(f"[主动思考] 步骤3:生成降级冒泡回复") + logger.info("[主动思考] 步骤3:生成降级冒泡回复") reply = await _planner.generate_reply(context, "simple_bubble") else: # 生成基于话题的消息 - logger.info(f"[主动思考] 步骤3:生成话题回复") + logger.info("[主动思考] 步骤3:生成话题回复") reply = await _planner.generate_reply(context, "throw_topic", topic) - + if reply: - logger.info(f"[主动思考] 步骤4:发送消息") + logger.info("[主动思考] 步骤4:发送消息") await send_api.text_to_stream( stream_id=stream_id, text=reply, ) logger.info(f"[主动思考] 已发送话题消息到 {stream_id}") - + # 增加每日计数 proactive_thinking_scheduler._increment_daily_count(stream_id) - + # 更新统计 if config.enable_statistics: _update_statistics(stream_id, action) - + # 抛出话题后暂停主动思考(如果配置了冷却时间) if config.topic_throw_cooldown > 0: - logger.info(f"[主动思考] 步骤5:暂停任务") + logger.info("[主动思考] 步骤5:暂停任务") await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已抛出话题") logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复") - logger.info(f"[主动思考] throw_topic 执行完成") + logger.info("[主动思考] throw_topic 执行完成") logger.info(f"[主动思考] 聊天流 {stream_id} 的主动思考执行完成") - + except Exception as e: logger.error(f"[主动思考] 执行主动思考失败: {e}", exc_info=True) diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_scheduler.py b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_scheduler.py index 33e90654d..47ed467cd 100644 --- a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_scheduler.py +++ b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_scheduler.py @@ -6,20 +6,17 @@ import asyncio from datetime import datetime, timedelta -from typing import Any, Optional +from typing import Any -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import ChatStreams from src.common.logger import get_logger from src.schedule.unified_scheduler import TriggerType, unified_scheduler -from sqlalchemy import select logger = get_logger("proactive_thinking_scheduler") class ProactiveThinkingScheduler: """主动思考调度器 - + 负责为每个聊天流创建和管理主动思考任务。 特点: 1. 根据聊天流的兴趣分数动态计算触发间隔 @@ -32,27 +29,28 @@ class ProactiveThinkingScheduler: self._stream_schedules: dict[str, str] = {} # stream_id -> schedule_id self._paused_streams: set[str] = set() # 因抛出话题而暂停的聊天流 self._lock = asyncio.Lock() - + # 统计数据 self._statistics: dict[str, dict[str, Any]] = {} # stream_id -> 统计信息 self._daily_counts: dict[str, dict[str, int]] = {} # stream_id -> {date: count} - + # 历史决策记录:stream_id -> 上次决策信息 self._last_decisions: dict[str, dict[str, Any]] = {} - + # 从全局配置加载(延迟导入避免循环依赖) from src.config.config import global_config + self.config = global_config.proactive_thinking - + def _calculate_interval(self, focus_energy: float) -> int: """根据 focus_energy 计算触发间隔 - + Args: focus_energy: 聊天流的 focus_energy 值 (0.0-1.0) - + Returns: int: 触发间隔(秒) - + 公式: - focus_energy 越高,间隔越短(更频繁思考) - interval = base_interval * (factor - focus_energy) @@ -63,26 +61,26 @@ class ProactiveThinkingScheduler: # 如果不使用 focus_energy,直接返回基础间隔 if not self.config.use_interest_score: return self.config.base_interval - + # 确保值在有效范围内 focus_energy = max(0.0, min(1.0, focus_energy)) - + # 计算间隔:focus_energy 越高,系数越小,间隔越短 factor = self.config.interest_score_factor - focus_energy interval = int(self.config.base_interval * factor) - + # 限制在最小和最大间隔之间 interval = max(self.config.min_interval, min(self.config.max_interval, interval)) - - logger.debug(f"Focus Energy {focus_energy:.3f} -> 触发间隔 {interval}秒 ({interval/60:.1f}分钟)") + + logger.debug(f"Focus Energy {focus_energy:.3f} -> 触发间隔 {interval}秒 ({interval / 60:.1f}分钟)") return interval - + def _check_whitelist_blacklist(self, stream_config: str) -> bool: """检查聊天流是否通过黑白名单验证 - + Args: stream_config: 聊天流配置字符串,格式: "platform:id:type" - + Returns: bool: True表示允许主动思考,False表示拒绝 """ @@ -91,148 +89,148 @@ class ProactiveThinkingScheduler: if len(parts) != 3: logger.warning(f"无效的stream_config格式: {stream_config}") return False - + is_private = parts[2] == "private" - + # 检查基础开关 if is_private and not self.config.enable_in_private: return False if not is_private and not self.config.enable_in_group: return False - + # 黑名单检查(优先级高) if self.config.blacklist_mode: blacklist = self.config.blacklist_private if is_private else self.config.blacklist_group if stream_config in blacklist: logger.debug(f"聊天流 {stream_config} 在黑名单中,拒绝主动思考") return False - + # 白名单检查 if self.config.whitelist_mode: whitelist = self.config.whitelist_private if is_private else self.config.whitelist_group if stream_config not in whitelist: logger.debug(f"聊天流 {stream_config} 不在白名单中,拒绝主动思考") return False - + return True - + def _check_interest_score_threshold(self, interest_score: float) -> bool: """检查兴趣分数是否在阈值范围内 - + Args: interest_score: 兴趣分数 - + Returns: bool: True表示在范围内 """ if interest_score < self.config.min_interest_score: logger.debug(f"兴趣分数 {interest_score:.2f} 低于最低阈值 {self.config.min_interest_score}") return False - + if interest_score > self.config.max_interest_score: logger.debug(f"兴趣分数 {interest_score:.2f} 高于最高阈值 {self.config.max_interest_score}") return False - + return True - + def _check_daily_limit(self, stream_id: str) -> bool: """检查今日主动发言次数是否超限 - + Args: stream_id: 聊天流ID - + Returns: bool: True表示未超限 """ if self.config.max_daily_proactive == 0: return True # 不限制 - + today = datetime.now().strftime("%Y-%m-%d") - + if stream_id not in self._daily_counts: self._daily_counts[stream_id] = {} - + # 清理过期日期的数据 for date in list(self._daily_counts[stream_id].keys()): if date != today: del self._daily_counts[stream_id][date] - + count = self._daily_counts[stream_id].get(today, 0) - + if count >= self.config.max_daily_proactive: logger.debug(f"聊天流 {stream_id} 今日主动发言次数已达上限 ({count}/{self.config.max_daily_proactive})") return False - + return True - + def _increment_daily_count(self, stream_id: str): """增加今日主动发言计数""" today = datetime.now().strftime("%Y-%m-%d") - + if stream_id not in self._daily_counts: self._daily_counts[stream_id] = {} - + self._daily_counts[stream_id][today] = self._daily_counts[stream_id].get(today, 0) + 1 - + def _is_in_quiet_hours(self) -> bool: """检查当前是否在安静时段 - + Returns: bool: True表示在安静时段 """ if not self.config.enable_time_strategy: return False - + now = datetime.now() current_time = now.strftime("%H:%M") - + start = self.config.quiet_hours_start end = self.config.quiet_hours_end - + # 处理跨日的情况(如23:00-07:00) if start <= end: return start <= current_time <= end else: return current_time >= start or current_time <= end - + async def _get_stream_focus_energy(self, stream_id: str) -> float: """获取聊天流的 focus_energy - + Args: stream_id: 聊天流ID - + Returns: float: focus_energy 值,默认0.5 """ try: # 从聊天管理器获取聊天流 from src.chat.message_receive.chat_stream import get_chat_manager - - logger.debug(f"[调度器] 获取聊天管理器") + + logger.debug("[调度器] 获取聊天管理器") chat_manager = get_chat_manager() logger.debug(f"[调度器] 从聊天管理器获取聊天流 {stream_id}") chat_stream = await chat_manager.get_stream(stream_id) - + if chat_stream: # 计算并获取最新的 focus_energy - logger.debug(f"[调度器] 找到聊天流,开始计算 focus_energy") + logger.debug("[调度器] 找到聊天流,开始计算 focus_energy") focus_energy = await chat_stream.calculate_focus_energy() logger.info(f"[调度器] 聊天流 {stream_id} 的 focus_energy: {focus_energy:.3f}") return focus_energy else: logger.warning(f"[调度器] ⚠️ 未找到聊天流 {stream_id},使用默认 focus_energy=0.5") return 0.5 - + except Exception as e: logger.error(f"[调度器] ❌ 获取聊天流 {stream_id} 的 focus_energy 失败: {e}", exc_info=True) return 0.5 - + async def schedule_proactive_thinking(self, stream_id: str) -> bool: """为聊天流创建或重置主动思考任务 - + Args: stream_id: 聊天流ID - + Returns: bool: 是否成功创建/重置任务 """ @@ -243,25 +241,25 @@ class ProactiveThinkingScheduler: if stream_id in self._paused_streams: logger.debug(f"[调度器] 清除聊天流 {stream_id} 的暂停标记") self._paused_streams.discard(stream_id) - + # 如果已经有任务,先移除 if stream_id in self._stream_schedules: old_schedule_id = self._stream_schedules[stream_id] logger.debug(f"[调度器] 移除聊天流 {stream_id} 的旧任务") await unified_scheduler.remove_schedule(old_schedule_id) - + # 获取 focus_energy 并计算间隔 focus_energy = await self._get_stream_focus_energy(stream_id) logger.debug(f"[调度器] focus_energy={focus_energy:.3f}") - + interval_seconds = self._calculate_interval(focus_energy) - logger.debug(f"[调度器] 触发间隔={interval_seconds}秒 ({interval_seconds/60:.1f}分钟)") - + logger.debug(f"[调度器] 触发间隔={interval_seconds}秒 ({interval_seconds / 60:.1f}分钟)") + # 导入回调函数(延迟导入避免循环依赖) from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_executor import ( execute_proactive_thinking, ) - + # 创建新任务 schedule_id = await unified_scheduler.create_schedule( callback=execute_proactive_thinking, @@ -273,34 +271,34 @@ class ProactiveThinkingScheduler: task_name=f"ProactiveThinking-{stream_id}", callback_args=(stream_id,), ) - + self._stream_schedules[stream_id] = schedule_id - + # 计算下次触发时间 next_run_time = datetime.now() + timedelta(seconds=interval_seconds) - + logger.info( f"✅ 聊天流 {stream_id} 主动思考任务已创建 | " f"Focus: {focus_energy:.3f} | " - f"间隔: {interval_seconds/60:.1f}分钟 | " + f"间隔: {interval_seconds / 60:.1f}分钟 | " f"下次: {next_run_time.strftime('%H:%M:%S')}" ) return True - + except Exception as e: logger.error(f"❌ 创建主动思考任务失败 {stream_id}: {e}", exc_info=True) return False - + async def pause_proactive_thinking(self, stream_id: str, reason: str = "抛出话题") -> bool: """暂停聊天流的主动思考任务 - + 当选择"抛出话题"后,应该暂停该聊天流的主动思考, 直到bot至少执行过一次reply后才恢复。 - + Args: stream_id: 聊天流ID reason: 暂停原因 - + Returns: bool: 是否成功暂停 """ @@ -309,26 +307,26 @@ class ProactiveThinkingScheduler: if stream_id not in self._stream_schedules: logger.warning(f"尝试暂停不存在的任务: {stream_id}") return False - + schedule_id = self._stream_schedules[stream_id] success = await unified_scheduler.pause_schedule(schedule_id) - + if success: self._paused_streams.add(stream_id) logger.info(f"⏸️ 暂停主动思考 {stream_id},原因: {reason}") - + return success - - except Exception as e: + + except Exception: # 错误日志已在上面记录 return False - + async def resume_proactive_thinking(self, stream_id: str) -> bool: """恢复聊天流的主动思考任务 - + Args: stream_id: 聊天流ID - + Returns: bool: 是否成功恢复 """ @@ -337,26 +335,26 @@ class ProactiveThinkingScheduler: if stream_id not in self._stream_schedules: logger.warning(f"尝试恢复不存在的任务: {stream_id}") return False - + schedule_id = self._stream_schedules[stream_id] success = await unified_scheduler.resume_schedule(schedule_id) - + if success: self._paused_streams.discard(stream_id) logger.info(f"▶️ 恢复主动思考 {stream_id}") - + return success - + except Exception as e: logger.error(f"❌ 恢复主动思考失败 {stream_id}: {e}", exc_info=True) return False - + async def cancel_proactive_thinking(self, stream_id: str) -> bool: """取消聊天流的主动思考任务 - + Args: stream_id: 聊天流ID - + Returns: bool: 是否成功取消 """ @@ -364,55 +362,55 @@ class ProactiveThinkingScheduler: async with self._lock: if stream_id not in self._stream_schedules: return True # 已经不存在,视为成功 - + schedule_id = self._stream_schedules.pop(stream_id) self._paused_streams.discard(stream_id) - + success = await unified_scheduler.remove_schedule(schedule_id) logger.debug(f"⏹️ 取消主动思考 {stream_id}") - + return success - + except Exception as e: logger.error(f"❌ 取消主动思考失败 {stream_id}: {e}", exc_info=True) return False - + async def is_paused(self, stream_id: str) -> bool: """检查聊天流的主动思考是否被暂停 - + Args: stream_id: 聊天流ID - + Returns: bool: 是否暂停中 """ async with self._lock: return stream_id in self._paused_streams - - async def get_task_info(self, stream_id: str) -> Optional[dict[str, Any]]: + + async def get_task_info(self, stream_id: str) -> dict[str, Any] | None: """获取聊天流的主动思考任务信息 - + Args: stream_id: 聊天流ID - + Returns: dict: 任务信息,如果不存在返回None """ async with self._lock: if stream_id not in self._stream_schedules: return None - + schedule_id = self._stream_schedules[stream_id] task_info = await unified_scheduler.get_task_info(schedule_id) - + if task_info: task_info["is_paused_for_topic"] = stream_id in self._paused_streams - + return task_info - + async def list_all_tasks(self) -> list[dict[str, Any]]: """列出所有主动思考任务 - + Returns: list: 任务信息列表 """ @@ -425,10 +423,10 @@ class ProactiveThinkingScheduler: task_info["is_paused_for_topic"] = stream_id in self._paused_streams tasks.append(task_info) return tasks - + def get_statistics(self) -> dict[str, Any]: """获取调度器统计信息 - + Returns: dict: 统计信息 """ @@ -437,51 +435,48 @@ class ProactiveThinkingScheduler: "paused_for_topic": len(self._paused_streams), "active_tasks": len(self._stream_schedules) - len(self._paused_streams), } - + async def log_next_trigger_times(self, max_streams: int = 10): """在日志中输出聊天流的下次触发时间 - + Args: max_streams: 最多显示多少个聊天流,0表示全部 """ logger.info("=" * 60) logger.info("主动思考任务状态") logger.info("=" * 60) - + tasks = await self.list_all_tasks() - + if not tasks: logger.info("当前没有活跃的主动思考任务") logger.info("=" * 60) return - + # 按下次触发时间排序 - tasks_sorted = sorted( - tasks, - key=lambda x: x.get("next_run_time", datetime.max) or datetime.max - ) - + tasks_sorted = sorted(tasks, key=lambda x: x.get("next_run_time", datetime.max) or datetime.max) + # 限制显示数量 if max_streams > 0: tasks_sorted = tasks_sorted[:max_streams] - + logger.info(f"共有 {len(self._stream_schedules)} 个任务,显示前 {len(tasks_sorted)} 个") logger.info("") - + for i, task in enumerate(tasks_sorted, 1): stream_id = task.get("stream_id", "Unknown") next_run = task.get("next_run_time") is_paused = task.get("is_paused_for_topic", False) - + # 获取聊天流名称(如果可能) stream_name = stream_id[:16] + "..." if len(stream_id) > 16 else stream_id - + if next_run: # 计算剩余时间 now = datetime.now() remaining = next_run - now remaining_seconds = int(remaining.total_seconds()) - + if remaining_seconds < 0: time_str = "已过期(待执行)" elif remaining_seconds < 60: @@ -492,28 +487,25 @@ class ProactiveThinkingScheduler: hours = remaining_seconds // 3600 minutes = (remaining_seconds % 3600) // 60 time_str = f"{hours}小时{minutes}分钟后" - + status = "⏸️ 暂停中" if is_paused else "✅ 活跃" - + logger.info( f"[{i:2d}] {status} | {stream_name}\n" f" 下次触发: {next_run.strftime('%Y-%m-%d %H:%M:%S')} ({time_str})" ) else: - logger.info( - f"[{i:2d}] ⚠️ 未知 | {stream_name}\n" - f" 下次触发: 未设置" - ) - + logger.info(f"[{i:2d}] ⚠️ 未知 | {stream_name}\n 下次触发: 未设置") + logger.info("") logger.info("=" * 60) - - def get_last_decision(self, stream_id: str) -> Optional[dict[str, Any]]: + + def get_last_decision(self, stream_id: str) -> dict[str, Any] | None: """获取聊天流的上次主动思考决策 - + Args: stream_id: 聊天流ID - + Returns: dict: 上次决策信息,包含: - action: "do_nothing" | "simple_bubble" | "throw_topic" @@ -523,16 +515,10 @@ class ProactiveThinkingScheduler: None: 如果没有历史决策 """ return self._last_decisions.get(stream_id) - - def record_decision( - self, - stream_id: str, - action: str, - reasoning: str, - topic: Optional[str] = None - ) -> None: + + def record_decision(self, stream_id: str, action: str, reasoning: str, topic: str | None = None) -> None: """记录聊天流的主动思考决策 - + Args: stream_id: 聊天流ID action: 决策动作 diff --git a/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py b/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py index b4fc68526..00240b024 100644 --- a/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py +++ b/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py @@ -4,10 +4,10 @@ 通过LLM二步调用机制更新用户画像信息,包括别名、主观印象、偏好关键词和好感分数 """ -import orjson import time from typing import Any +import orjson from sqlalchemy import select from src.common.database.sqlalchemy_database_api import get_db_session @@ -42,7 +42,7 @@ class UserProfileTool(BaseTool): def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None): super().__init__(plugin_config, chat_stream) - + # 初始化用于二步调用的LLM try: self.profile_llm = LLMRequest( @@ -84,24 +84,24 @@ class UserProfileTool(BaseTool): "id": "user_profile_update", "content": "错误:必须提供目标用户ID" } - + # 从LLM传入的参数 new_aliases = function_args.get("user_aliases", "") new_impression = function_args.get("impression_description", "") new_keywords = function_args.get("preference_keywords", "") new_score = function_args.get("affection_score") - + # 从数据库获取现有用户画像 existing_profile = await self._get_user_profile(target_user_id) - + # 如果LLM没有传入任何有效参数,返回提示 if not any([new_aliases, new_impression, new_keywords, new_score is not None]): return { "type": "info", "id": target_user_id, - "content": f"提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)" + "content": "提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)" } - + # 调用LLM进行二步决策 if self.profile_llm is None: logger.error("LLM未正确初始化,无法执行二步调用") @@ -110,7 +110,7 @@ class UserProfileTool(BaseTool): "id": target_user_id, "content": "系统错误:LLM未正确初始化" } - + final_profile = await self._llm_decide_final_profile( target_user_id=target_user_id, existing_profile=existing_profile, @@ -119,17 +119,17 @@ class UserProfileTool(BaseTool): new_keywords=new_keywords, new_score=new_score ) - + if not final_profile: return { "type": "error", "id": target_user_id, "content": "LLM决策失败,无法更新用户画像" } - + # 更新数据库 await self._update_user_profile_in_db(target_user_id, final_profile) - + # 构建返回信息 updates = [] if final_profile.get("user_aliases"): @@ -140,22 +140,22 @@ class UserProfileTool(BaseTool): updates.append(f"偏好: {final_profile['preference_keywords']}") if final_profile.get("relationship_score") is not None: updates.append(f"好感分: {final_profile['relationship_score']:.2f}") - + result_text = f"已更新用户 {target_user_id} 的画像:\n" + "\n".join(updates) logger.info(f"用户画像更新成功: {target_user_id}") - + return { "type": "user_profile_update", "id": target_user_id, "content": result_text } - + except Exception as e: logger.error(f"用户画像更新失败: {e}", exc_info=True) return { "type": "error", "id": function_args.get("target_user_id", "unknown"), - "content": f"用户画像更新失败: {str(e)}" + "content": f"用户画像更新失败: {e!s}" } async def _get_user_profile(self, user_id: str) -> dict[str, Any]: @@ -172,7 +172,7 @@ class UserProfileTool(BaseTool): stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) result = await session.execute(stmt) profile = result.scalar_one_or_none() - + if profile: return { "user_name": profile.user_name or user_id, @@ -227,7 +227,7 @@ class UserProfileTool(BaseTool): from src.individuality.individuality import Individuality individuality = Individuality() bot_personality = await individuality.get_personality_block() - + prompt = f""" 你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality} @@ -261,18 +261,18 @@ class UserProfileTool(BaseTool): "reasoning": "你的决策理由" }} """ - + # 调用LLM llm_response, _ = await self.profile_llm.generate_response_async(prompt=prompt) - + if not llm_response: logger.warning("LLM未返回有效响应") return None - + # 清理并解析响应 cleaned_response = self._clean_llm_json_response(llm_response) response_data = orjson.loads(cleaned_response) - + # 提取最终决定的数据 final_profile = { "user_aliases": response_data.get("user_aliases", existing_profile.get("user_aliases", "")), @@ -280,12 +280,12 @@ class UserProfileTool(BaseTool): "preference_keywords": response_data.get("preference_keywords", existing_profile.get("preference_keywords", "")), "relationship_score": max(0.0, min(1.0, float(response_data.get("relationship_score", existing_profile.get("relationship_score", 0.3))))), } - + logger.info(f"LLM决策完成: {target_user_id}") logger.debug(f"决策理由: {response_data.get('reasoning', '无')}") - + return final_profile - + except orjson.JSONDecodeError as e: logger.error(f"LLM响应JSON解析失败: {e}") logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}") @@ -303,12 +303,12 @@ class UserProfileTool(BaseTool): """ try: current_time = time.time() - + async with get_db_session() as session: stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) result = await session.execute(stmt) existing = result.scalar_one_or_none() - + if existing: # 更新现有记录 existing.user_aliases = profile.get("user_aliases", "") @@ -328,10 +328,10 @@ class UserProfileTool(BaseTool): last_updated=current_time ) session.add(new_profile) - + await session.commit() logger.info(f"用户画像已更新到数据库: {user_id}") - + except Exception as e: logger.error(f"更新用户画像到数据库失败: {e}", exc_info=True) raise @@ -347,24 +347,24 @@ class UserProfileTool(BaseTool): """ try: import re - + cleaned = response.strip() - + # 移除 ```json 或 ``` 等标记 cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE) cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE) - + # 尝试找到JSON对象的开始和结束 json_start = cleaned.find("{") json_end = cleaned.rfind("}") - + if json_start != -1 and json_end != -1 and json_end > json_start: cleaned = cleaned[json_start:json_end + 1] - + cleaned = cleaned.strip() - + return cleaned - + except Exception as e: logger.warning(f"清理LLM响应失败: {e}") return response diff --git a/src/plugins/built_in/social_toolkit_plugin/plugin.py b/src/plugins/built_in/social_toolkit_plugin/plugin.py index 8d75ca2fd..179d7997a 100644 --- a/src/plugins/built_in/social_toolkit_plugin/plugin.py +++ b/src/plugins/built_in/social_toolkit_plugin/plugin.py @@ -261,7 +261,7 @@ class SetEmojiLikeAction(BaseAction): elif isinstance(self.action_message, dict): message_id = self.action_message.get("message_id") logger.info(f"获取到的消息ID: {message_id}") - + if not message_id: logger.error("未提供有效的消息或消息ID") await self.store_action_info(action_prompt_display="贴表情失败: 未提供消息ID", action_done=False) @@ -279,7 +279,7 @@ class SetEmojiLikeAction(BaseAction): context_text = self.action_message.processed_plain_text or "" else: context_text = self.action_message.get("processed_plain_text", "") - + if not context_text: logger.error("无法找到动作选择的原始消息文本") return False, "无法找到动作选择的原始消息文本" diff --git a/src/plugins/built_in/web_search_tool/plugin.py b/src/plugins/built_in/web_search_tool/plugin.py index 44e0082e0..a47a41ea1 100644 --- a/src/plugins/built_in/web_search_tool/plugin.py +++ b/src/plugins/built_in/web_search_tool/plugin.py @@ -5,7 +5,7 @@ Web Search Tool Plugin """ from src.common.logger import get_logger -from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin +from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin from src.plugin_system.apis import config_api from .tools.url_parser import URLParserTool diff --git a/src/schedule/unified_scheduler.py b/src/schedule/unified_scheduler.py index 0a2fc859b..aff48ee83 100644 --- a/src/schedule/unified_scheduler.py +++ b/src/schedule/unified_scheduler.py @@ -5,9 +5,10 @@ import asyncio import uuid -from datetime import datetime, timedelta +from collections.abc import Awaitable, Callable +from datetime import datetime from enum import Enum -from typing import Any, Awaitable, Callable, Optional +from typing import Any from src.common.logger import get_logger from src.plugin_system.base.component_types import EventType @@ -33,9 +34,9 @@ class ScheduleTask: trigger_type: TriggerType, trigger_config: dict[str, Any], is_recurring: bool = False, - task_name: Optional[str] = None, - callback_args: Optional[tuple] = None, - callback_kwargs: Optional[dict] = None, + task_name: str | None = None, + callback_args: tuple | None = None, + callback_kwargs: dict | None = None, ): self.schedule_id = schedule_id self.callback = callback @@ -46,7 +47,7 @@ class ScheduleTask: self.callback_args = callback_args or () self.callback_kwargs = callback_kwargs or {} self.created_at = datetime.now() - self.last_triggered_at: Optional[datetime] = None + self.last_triggered_at: datetime | None = None self.trigger_count = 0 self.is_active = True @@ -77,7 +78,7 @@ class UnifiedScheduler: def __init__(self): self._tasks: dict[str, ScheduleTask] = {} self._running = False - self._check_task: Optional[asyncio.Task] = None + self._check_task: asyncio.Task | None = None self._lock = asyncio.Lock() self._event_subscriptions: set[str] = set() # 追踪已订阅的事件 @@ -111,7 +112,7 @@ class UnifiedScheduler: for task in event_tasks: try: logger.debug(f"[调度器] 执行事件任务: {task.task_name}") - + # 执行回调,传入事件参数 if event_params: if asyncio.iscoroutinefunction(task.callback): @@ -127,7 +128,7 @@ class UnifiedScheduler: # 如果不是循环任务,标记为删除 if not task.is_recurring: tasks_to_remove.append(task.schedule_id) - + logger.debug(f"[调度器] 事件任务 {task.task_name} 执行完成") except Exception as e: @@ -204,11 +205,11 @@ class UnifiedScheduler: 注意:为了避免死锁,回调执行必须在锁外进行 """ current_time = datetime.now() - + # 第一阶段:在锁内快速收集需要触发的任务 async with self._lock: tasks_to_trigger = [] - + for schedule_id, task in list(self._tasks.items()): if not task.is_active: continue @@ -219,14 +220,14 @@ class UnifiedScheduler: tasks_to_trigger.append(task) except Exception as e: logger.error(f"检查任务 {task.task_name} 时发生错误: {e}", exc_info=True) - + # 第二阶段:在锁外执行回调(避免死锁) tasks_to_remove = [] - + for task in tasks_to_trigger: try: logger.debug(f"[调度器] 触发定时任务: {task.task_name}") - + # 执行回调 await self._execute_callback(task) @@ -339,9 +340,9 @@ class UnifiedScheduler: trigger_type: TriggerType, trigger_config: dict[str, Any], is_recurring: bool = False, - task_name: Optional[str] = None, - callback_args: Optional[tuple] = None, - callback_kwargs: Optional[dict] = None, + task_name: str | None = None, + callback_args: tuple | None = None, + callback_kwargs: dict | None = None, ) -> str: """创建调度任务(详细注释见文档)""" schedule_id = str(uuid.uuid4()) @@ -430,7 +431,7 @@ class UnifiedScheduler: logger.info(f"恢复任务: {task.task_name} (ID: {schedule_id[:8]}...)") return True - async def get_task_info(self, schedule_id: str) -> Optional[dict[str, Any]]: + async def get_task_info(self, schedule_id: str) -> dict[str, Any] | None: """获取任务信息""" async with self._lock: task = self._tasks.get(schedule_id) @@ -449,7 +450,7 @@ class UnifiedScheduler: "trigger_config": task.trigger_config.copy(), } - async def list_tasks(self, trigger_type: Optional[TriggerType] = None) -> list[dict[str, Any]]: + async def list_tasks(self, trigger_type: TriggerType | None = None) -> list[dict[str, Any]]: """列出所有任务或指定类型的任务""" async with self._lock: tasks = [] @@ -499,11 +500,11 @@ async def initialize_scheduler(): logger.info("正在启动统一调度器...") await unified_scheduler.start() logger.info("统一调度器启动成功") - + # 获取初始统计信息 stats = unified_scheduler.get_statistics() logger.info(f"调度器状态: {stats}") - + except Exception as e: logger.error(f"启动统一调度器失败: {e}", exc_info=True) raise @@ -516,20 +517,20 @@ async def shutdown_scheduler(): """ try: logger.info("正在关闭统一调度器...") - + # 显示最终统计 stats = unified_scheduler.get_statistics() logger.info(f"调度器最终统计: {stats}") - + # 列出剩余任务 remaining_tasks = await unified_scheduler.list_tasks() if remaining_tasks: logger.warning(f"检测到 {len(remaining_tasks)} 个未清理的任务:") for task in remaining_tasks: logger.warning(f" - {task['task_name']} (ID: {task['schedule_id'][:8]}...)") - + await unified_scheduler.stop() logger.info("统一调度器已关闭") - + except Exception as e: - logger.error(f"关闭统一调度器失败: {e}", exc_info=True) \ No newline at end of file + logger.error(f"关闭统一调度器失败: {e}", exc_info=True)