From a29be4809125d5a1b1b085500d88bf3dd1cc0730 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 31 Oct 2025 20:56:17 +0800 Subject: [PATCH] =?UTF-8?q?refactor(core):=20=E4=BC=98=E5=8C=96=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E6=8F=90=E7=A4=BA=E4=B8=8E=E4=BB=A3=E7=A0=81=E9=A3=8E?= =?UTF-8?q?=E6=A0=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本次提交对项目代码进行了广泛的重构,主要集中在以下几个方面: 1. **类型提示现代化**: - 将 `typing` 模块中的 `Optional[T]`、`List[T]`、`Dict[K, V]` 等旧式类型提示更新为现代的 `T | None`、`list[T]`、`dict[K, V]` 语法。 - 这提高了代码的可读性,并与较新 Python 版本的风格保持一致。 2. **代码风格统一**: - 移除了多余的空行和不必要的空格,使代码更加紧凑和规范。 - 统一了部分日志输出的格式,增强了日志的可读性。 3. **导入语句优化**: - 调整了部分模块的 `import` 语句顺序,使其符合 PEP 8 规范。 这些更改不涉及任何功能性变动,旨在提升代码库的整体质量、可维护性和开发体验。 --- scripts/check_expression_database.py | 33 +- scripts/check_style_field.py | 23 +- scripts/debug_style_learner.py | 32 +- src/chat/energy_system/energy_manager.py | 5 +- src/chat/express/express_utils.py | 18 +- src/chat/express/expression_learner.py | 44 +-- src/chat/express/expression_selector.py | 90 ++--- src/chat/express/expressor_model/model.py | 17 +- src/chat/express/expressor_model/online_nb.py | 15 +- src/chat/express/expressor_model/tokenizer.py | 3 +- src/chat/express/situation_extractor.py | 43 ++- src/chat/express/style_learner.py | 39 ++- src/chat/memory_system/memory_system.py | 2 +- src/chat/message_manager/context_manager.py | 28 +- .../message_manager/distribution_manager.py | 40 +-- src/chat/message_manager/message_manager.py | 6 +- src/chat/message_receive/bot.py | 16 +- src/chat/message_receive/chat_stream.py | 23 +- src/chat/message_receive/message.py | 6 +- src/chat/message_receive/message_processor.py | 94 +++--- src/chat/message_receive/storage.py | 23 +- .../message_receive/uni_message_sender.py | 18 +- src/chat/planner_actions/action_manager.py | 2 +- src/chat/replyer/default_generator.py | 54 ++- src/chat/utils/prompt.py | 4 +- src/chat/utils/utils.py | 7 +- .../data_models/message_manager_data_model.py | 8 +- src/config/config.py | 2 +- src/config/official_configs.py | 26 +- src/main.py | 2 +- src/person_info/relationship_fetcher.py | 10 +- src/plugin_system/apis/send_api.py | 14 +- src/plugin_system/base/base_action.py | 2 +- src/plugin_system/base/base_command.py | 2 +- src/plugin_system/base/plus_command.py | 2 +- src/plugin_system/core/event_manager.py | 2 +- .../chat_stream_impression_tool.py | 207 +++++++----- .../affinity_flow_chatter/plan_executor.py | 15 +- .../affinity_flow_chatter/plan_filter.py | 30 +- .../built_in/affinity_flow_chatter/planner.py | 67 ++-- .../proactive_thinking_event.py | 70 ++-- .../proactive_thinking_executor.py | 315 +++++++++--------- .../proactive_thinking_scheduler.py | 268 +++++++-------- .../user_profile_tool.py | 70 ++-- .../built_in/social_toolkit_plugin/plugin.py | 4 +- .../built_in/web_search_tool/plugin.py | 2 +- src/schedule/unified_scheduler.py | 53 +-- 47 files changed, 923 insertions(+), 933 deletions(-) diff --git a/scripts/check_expression_database.py b/scripts/check_expression_database.py index f600cc434..2341c2140 100644 --- a/scripts/check_expression_database.py +++ b/scripts/check_expression_database.py @@ -9,24 +9,25 @@ from pathlib import Path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) -from sqlalchemy import select, func +from sqlalchemy import func, select + from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import Expression async def check_database(): """检查表达方式数据库状态""" - + print("=" * 60) print("表达方式数据库诊断报告") print("=" * 60) - + async with get_db_session() as session: # 1. 统计总数 total_count = await session.execute(select(func.count()).select_from(Expression)) total = total_count.scalar() print(f"\n📊 总表达方式数量: {total}") - + if total == 0: print("\n⚠️ 数据库为空!") print("\n可能的原因:") @@ -38,7 +39,7 @@ async def check_database(): print("- 查看日志中是否有表达学习相关的错误") print("- 确认聊天流的 learn_expression 配置为 true") return - + # 2. 按 chat_id 统计 print("\n📝 按聊天流统计:") chat_counts = await session.execute( @@ -47,7 +48,7 @@ async def check_database(): ) for chat_id, count in chat_counts: print(f" - {chat_id}: {count} 个表达方式") - + # 3. 按 type 统计 print("\n📝 按类型统计:") type_counts = await session.execute( @@ -56,7 +57,7 @@ async def check_database(): ) for expr_type, count in type_counts: print(f" - {expr_type}: {count} 个") - + # 4. 检查 situation 和 style 字段是否有空值 print("\n🔍 字段完整性检查:") null_situation = await session.execute( @@ -69,30 +70,30 @@ async def check_database(): .select_from(Expression) .where(Expression.style == None) ) - + null_sit_count = null_situation.scalar() null_sty_count = null_style.scalar() - + print(f" - situation 为空: {null_sit_count} 个") print(f" - style 为空: {null_sty_count} 个") - + if null_sit_count > 0 or null_sty_count > 0: print(" ⚠️ 发现空值!这会导致匹配失败") - + # 5. 显示一些样例数据 print("\n📋 样例数据 (前10条):") samples = await session.execute( select(Expression) .limit(10) ) - + for i, expr in enumerate(samples.scalars(), 1): print(f"\n [{i}] Chat: {expr.chat_id}") print(f" Type: {expr.type}") print(f" Situation: {expr.situation}") print(f" Style: {expr.style}") print(f" Count: {expr.count}") - + # 6. 检查 style 字段的唯一值 print("\n📋 Style 字段样例 (前20个):") unique_styles = await session.execute( @@ -100,13 +101,13 @@ async def check_database(): .distinct() .limit(20) ) - + styles = [s for s in unique_styles.scalars()] for style in styles: print(f" - {style}") - + print(f"\n (共 {len(styles)} 个不同的 style)") - + print("\n" + "=" * 60) print("诊断完成") print("=" * 60) diff --git a/scripts/check_style_field.py b/scripts/check_style_field.py index c8f5ef1fb..d28c8b240 100644 --- a/scripts/check_style_field.py +++ b/scripts/check_style_field.py @@ -9,27 +9,28 @@ project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) from sqlalchemy import select + from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import Expression async def analyze_style_fields(): """分析 style 字段的内容""" - + print("=" * 60) print("Style 字段内容分析") print("=" * 60) - + async with get_db_session() as session: # 获取所有表达方式 result = await session.execute(select(Expression).limit(30)) expressions = result.scalars().all() - + print(f"\n总共检查 {len(expressions)} 条记录\n") - + # 按类型分类 style_examples = [] - + for expr in expressions: if expr.type == "style": style_examples.append({ @@ -37,7 +38,7 @@ async def analyze_style_fields(): "style": expr.style, "length": len(expr.style) if expr.style else 0 }) - + print("📋 Style 类型样例 (前15条):") print("="*60) for i, ex in enumerate(style_examples[:15], 1): @@ -45,17 +46,17 @@ async def analyze_style_fields(): print(f" Situation: {ex['situation']}") print(f" Style: {ex['style']}") print(f" 长度: {ex['length']} 字符") - + # 判断是具体表达还是风格描述 - if ex['length'] <= 20 and any(word in ex['style'] for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']): + if ex["length"] <= 20 and any(word in ex["style"] for word in ["简洁", "短句", "陈述", "疑问", "感叹", "省略", "完整"]): style_type = "✓ 风格描述" - elif ex['length'] <= 10: + elif ex["length"] <= 10: style_type = "? 可能是具体表达(较短)" else: style_type = "✗ 具体表达内容" - + print(f" 类型判断: {style_type}") - + print("\n" + "="*60) print("分析完成") print("="*60) diff --git a/scripts/debug_style_learner.py b/scripts/debug_style_learner.py index 970ba2532..1c0937ece 100644 --- a/scripts/debug_style_learner.py +++ b/scripts/debug_style_learner.py @@ -16,28 +16,28 @@ logger = get_logger("debug_style_learner") def check_style_learner_status(chat_id: str): """检查指定 chat_id 的 StyleLearner 状态""" - + print("=" * 60) print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}") print("=" * 60) - + # 获取 learner learner = style_learner_manager.get_learner(chat_id) - + # 1. 基本信息 - print(f"\n📊 基本信息:") + print("\n📊 基本信息:") print(f" Chat ID: {learner.chat_id}") print(f" 风格数量: {len(learner.style_to_id)}") print(f" 下一个ID: {learner.next_style_id}") print(f" 最大风格数: {learner.max_styles}") - + # 2. 学习统计 - print(f"\n📈 学习统计:") + print("\n📈 学习统计:") print(f" 总样本数: {learner.learning_stats['total_samples']}") print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}") - + # 3. 风格列表(前20个) - print(f"\n📋 已学习的风格 (前20个):") + print("\n📋 已学习的风格 (前20个):") all_styles = learner.get_all_styles() if not all_styles: print(" ⚠️ 没有任何风格!模型尚未训练") @@ -47,9 +47,9 @@ def check_style_learner_status(chat_id: str): situation = learner.id_to_situation.get(style_id, "N/A") print(f" [{i}] {style}") print(f" (ID: {style_id}, Situation: {situation})") - + # 4. 测试预测 - print(f"\n🔮 测试预测功能:") + print("\n🔮 测试预测功能:") if not all_styles: print(" ⚠️ 无法测试,模型没有训练数据") else: @@ -58,19 +58,19 @@ def check_style_learner_status(chat_id: str): "讨论游戏", "表达赞同" ] - + for test_sit in test_situations: print(f"\n 测试输入: '{test_sit}'") best_style, scores = learner.predict_style(test_sit, top_k=3) - + if best_style: print(f" ✓ 最佳匹配: {best_style}") - print(f" Top 3:") + print(" Top 3:") for style, score in list(scores.items())[:3]: print(f" - {style}: {score:.4f}") else: - print(f" ✗ 预测失败") - + print(" ✗ 预测失败") + print("\n" + "=" * 60) print("诊断完成") print("=" * 60) @@ -82,7 +82,7 @@ if __name__ == "__main__": "52fb94af9f500a01e023ea780e43606e", # 有78个表达方式 "46c8714c8a9b7ee169941fe99fcde07d", # 有22个表达方式 ] - + for chat_id in test_chat_ids: check_style_learner_status(chat_id) print("\n") diff --git a/src/chat/energy_system/energy_manager.py b/src/chat/energy_system/energy_manager.py index 982cfccce..079147812 100644 --- a/src/chat/energy_system/energy_manager.py +++ b/src/chat/energy_system/energy_manager.py @@ -201,15 +201,16 @@ class RelationshipEnergyCalculator(EnergyCalculator): # 从数据库获取聊天流兴趣分数 try: + from sqlalchemy import select + from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import ChatStreams - from sqlalchemy import select async with get_db_session() as session: stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id) result = await session.execute(stmt) stream = result.scalar_one_or_none() - + if stream and stream.stream_interest_score is not None: interest_score = float(stream.stream_interest_score) logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}") diff --git a/src/chat/express/express_utils.py b/src/chat/express/express_utils.py index bd7f41e2d..0d1baded1 100644 --- a/src/chat/express/express_utils.py +++ b/src/chat/express/express_utils.py @@ -5,14 +5,14 @@ import difflib import random import re -from typing import Any, Dict, List, Optional +from typing import Any from src.common.logger import get_logger logger = get_logger("express_utils") -def filter_message_content(content: Optional[str]) -> str: +def filter_message_content(content: str | None) -> str: """ 过滤消息内容,移除回复、@、图片等格式 @@ -51,7 +51,7 @@ def calculate_similarity(text1: str, text2: str) -> float: return difflib.SequenceMatcher(None, text1, text2).ratio() -def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]: +def weighted_sample(population: list[dict], k: int, weight_key: str | None = None) -> list[dict]: """ 加权随机抽样函数 @@ -108,7 +108,7 @@ def normalize_text(text: str) -> str: return text.strip() -def extract_keywords(text: str, max_keywords: int = 10) -> List[str]: +def extract_keywords(text: str, max_keywords: int = 10) -> list[str]: """ 简单的关键词提取(基于词频) @@ -135,7 +135,7 @@ def extract_keywords(text: str, max_keywords: int = 10) -> List[str]: return words[:max_keywords] -def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str: +def format_expression_pair(situation: str, style: str, index: int | None = None) -> str: """ 格式化表达方式对 @@ -153,7 +153,7 @@ def format_expression_pair(situation: str, style: str, index: Optional[int] = No return f'当"{situation}"时,使用"{style}"' -def parse_expression_pair(text: str) -> Optional[tuple[str, str]]: +def parse_expression_pair(text: str) -> tuple[str, str] | None: """ 解析表达方式对文本 @@ -170,7 +170,7 @@ def parse_expression_pair(text: str) -> Optional[tuple[str, str]]: return None -def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]: +def batch_filter_duplicates(expressions: list[dict[str, Any]], key_fields: list[str]) -> list[dict[str, Any]]: """ 批量去重表达方式 @@ -219,8 +219,8 @@ def calculate_time_weight(last_active_time: float, current_time: float, half_lif def merge_expressions_from_multiple_chats( - expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100 -) -> List[Dict[str, Any]]: + expressions_dict: dict[str, list[dict[str, Any]]], max_total: int = 100 +) -> list[dict[str, Any]]: """ 合并多个聊天室的表达方式 diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 75864be40..2cfe2ed8d 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -438,9 +438,9 @@ class ExpressionLearner: try: # 获取 StyleLearner 实例 learner = style_learner_manager.get_learner(chat_id) - + logger.info(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}") - + # 为每个学习到的表达方式训练模型 # 使用 situation 作为输入,style 作为目标 # 这是最符合语义的方式:场景 -> 表达方式 @@ -448,25 +448,25 @@ class ExpressionLearner: for expr in expr_list: situation = expr["situation"] style = expr["style"] - + # 训练映射关系: situation -> style if learner.learn_mapping(situation, style): success_count += 1 else: logger.warning(f"训练失败: {situation} -> {style}") - + logger.info( f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, " f"当前风格总数={len(learner.get_all_styles())}, " f"总样本数={learner.learning_stats['total_samples']}" ) - + # 保存模型 if learner.save(style_learner_manager.model_save_path): logger.info(f"StyleLearner 模型保存成功: {chat_id}") else: logger.error(f"StyleLearner 模型保存失败: {chat_id}") - + except Exception as e: logger.error(f"训练 StyleLearner 失败: {e}", exc_info=True) @@ -527,7 +527,7 @@ class ExpressionLearner: logger.debug(f"学习{type_str}的response: {response}") expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id) - + if not expressions: logger.warning(f"从LLM响应中未能解析出任何{type_str}。请检查LLM输出格式是否正确。") logger.info(f"LLM完整响应:\n{response}") @@ -542,26 +542,26 @@ class ExpressionLearner: """ expressions: list[tuple[str, str, str]] = [] failed_lines = [] - + for line_num, line in enumerate(response.splitlines(), 1): line = line.strip() if not line: continue - + # 替换中文引号为英文引号,便于统一处理 line_normalized = line.replace('"', '"').replace('"', '"').replace("'", '"').replace("'", '"') - + # 查找"当"和下一个引号 idx_when = line_normalized.find('当"') if idx_when == -1: # 尝试不带引号的格式: 当xxx时 - idx_when = line_normalized.find('当') + idx_when = line_normalized.find("当") if idx_when == -1: failed_lines.append((line_num, line, "找不到'当'关键字")) continue - + # 提取"当"和"时"之间的内容 - idx_shi = line_normalized.find('时', idx_when) + idx_shi = line_normalized.find("时", idx_when) if idx_shi == -1: failed_lines.append((line_num, line, "找不到'时'关键字")) continue @@ -575,20 +575,20 @@ class ExpressionLearner: continue situation = line_normalized[idx_quote1 + 1 : idx_quote2] search_start = idx_quote2 - + # 查找"使用"或"可以" idx_use = line_normalized.find('使用"', search_start) if idx_use == -1: idx_use = line_normalized.find('可以"', search_start) if idx_use == -1: # 尝试不带引号的格式 - idx_use = line_normalized.find('使用', search_start) + idx_use = line_normalized.find("使用", search_start) if idx_use == -1: - idx_use = line_normalized.find('可以', search_start) + idx_use = line_normalized.find("可以", search_start) if idx_use == -1: failed_lines.append((line_num, line, "找不到'使用'或'可以'关键字")) continue - + # 提取剩余部分作为style style = line_normalized[idx_use + 2:].strip('"\'"",。') if not style: @@ -610,24 +610,24 @@ class ExpressionLearner: style = line_normalized[idx_quote3 + 1:].strip('"\'""') else: style = line_normalized[idx_quote3 + 1 : idx_quote4] - + # 清理并验证 situation = situation.strip() style = style.strip() - + if not situation or not style: failed_lines.append((line_num, line, f"situation或style为空: situation='{situation}', style='{style}'")) continue - + expressions.append((chat_id, situation, style)) - + # 记录解析失败的行 if failed_lines: logger.warning(f"解析表达方式时有 {len(failed_lines)} 行失败:") for line_num, line, reason in failed_lines[:5]: # 只显示前5个 logger.warning(f" 行{line_num}: {reason}") logger.debug(f" 原文: {line}") - + if not expressions: logger.warning(f"LLM返回了内容但无法解析任何表达方式。响应预览:\n{response[:500]}") else: diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 1dbf7e08e..568cde3c3 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -267,11 +267,11 @@ class ExpressionSelector: chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history]) else: chat_info = chat_history - + # 根据配置选择模式 mode = global_config.expression.mode logger.debug(f"[ExpressionSelector] 使用模式: {mode}") - + if mode == "exp_model": return await self._select_expressions_model_only( chat_id=chat_id, @@ -288,7 +288,7 @@ class ExpressionSelector: max_num=max_num, min_num=min_num ) - + async def _select_expressions_classic( self, chat_id: str, @@ -298,7 +298,7 @@ class ExpressionSelector: min_num: int = 5, ) -> list[dict[str, Any]]: """经典模式:随机抽样 + LLM评估""" - logger.debug(f"[Classic模式] 使用LLM评估表达方式") + logger.debug("[Classic模式] 使用LLM评估表达方式") return await self.select_suitable_expressions_llm( chat_id=chat_id, chat_info=chat_info, @@ -306,7 +306,7 @@ class ExpressionSelector: min_num=min_num, target_message=target_message ) - + async def _select_expressions_model_only( self, chat_id: str, @@ -316,22 +316,22 @@ class ExpressionSelector: min_num: int = 5, ) -> list[dict[str, Any]]: """模型预测模式:先提取情境,再使用StyleLearner预测表达风格""" - logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式") - + logger.debug("[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式") + # 检查是否允许在此聊天流中使用表达 if not self.can_use_expression_for_chat(chat_id): logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表") return [] - + # 步骤1: 提取聊天情境 situations = await situation_extractor.extract_situations( chat_history=chat_info, target_message=target_message, max_situations=3 ) - + if not situations: - logger.warning(f"无法提取聊天情境,回退到经典模式") + logger.warning("无法提取聊天情境,回退到经典模式") return await self._select_expressions_classic( chat_id=chat_id, chat_info=chat_info, @@ -339,17 +339,17 @@ class ExpressionSelector: max_num=max_num, min_num=min_num ) - + logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}") - + # 步骤2: 使用 StyleLearner 为每个情境预测合适的表达方式 learner = style_learner_manager.get_learner(chat_id) - + all_predicted_styles = {} for i, situation in enumerate(situations, 1): logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}") best_style, scores = learner.predict_style(situation, top_k=max_num) - + if best_style and scores: logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}") # 合并分数(取最高分) @@ -357,10 +357,10 @@ class ExpressionSelector: if style not in all_predicted_styles or score > all_predicted_styles[style]: all_predicted_styles[style] = score else: - logger.debug(f" 该情境未返回预测结果") - + logger.debug(" 该情境未返回预测结果") + if not all_predicted_styles: - logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式") + logger.warning("[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式") return await self._select_expressions_classic( chat_id=chat_id, chat_info=chat_info, @@ -368,22 +368,22 @@ class ExpressionSelector: max_num=max_num, min_num=min_num ) - + # 将分数字典转换为列表格式 [(style, score), ...] predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True) - + logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}") - + # 步骤3: 根据预测的风格从数据库获取表达方式 - logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式") + logger.debug("[Exp_model模式] 步骤3 - 从数据库查询表达方式") expressions = await self.get_model_predicted_expressions( chat_id=chat_id, predicted_styles=predicted_styles, max_num=max_num ) - + if not expressions: - logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式") + logger.warning("[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式") return await self._select_expressions_classic( chat_id=chat_id, chat_info=chat_info, @@ -391,10 +391,10 @@ class ExpressionSelector: max_num=max_num, min_num=min_num ) - + logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式") return expressions - + async def get_model_predicted_expressions( self, chat_id: str, @@ -414,15 +414,15 @@ class ExpressionSelector: """ if not predicted_styles: return [] - + # 提取风格名称(前3个最佳匹配) style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]] logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}") - + # 🔥 使用 get_related_chat_ids 获取所有相关的 chat_id(支持共享表达方式) related_chat_ids = self.get_related_chat_ids(chat_id) logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}") - + async with get_db_session() as session: # 🔍 先检查数据库中实际有哪些 chat_id 的数据 db_chat_ids_result = await session.execute( @@ -432,7 +432,7 @@ class ExpressionSelector: ) db_chat_ids = [cid for cid in db_chat_ids_result.scalars()] logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}") - + # 获取所有相关 chat_id 的表达方式(用于模糊匹配) all_expressions_result = await session.execute( select(Expression) @@ -440,51 +440,51 @@ class ExpressionSelector: .where(Expression.type == "style") ) all_expressions = list(all_expressions_result.scalars()) - + logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}") - + # 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id if not all_expressions: - logger.info(f"相关chat_id没有数据,尝试从所有chat_id查询") + logger.info("相关chat_id没有数据,尝试从所有chat_id查询") all_expressions_result = await session.execute( select(Expression) .where(Expression.type == "style") ) all_expressions = list(all_expressions_result.scalars()) logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}") - + if not all_expressions: - logger.warning(f"数据库中完全没有任何表达方式,需要先学习") + logger.warning("数据库中完全没有任何表达方式,需要先学习") return [] - + # 🔥 使用模糊匹配而不是精确匹配 # 计算每个预测style与数据库style的相似度 from difflib import SequenceMatcher - + matched_expressions = [] for expr in all_expressions: db_style = expr.style or "" max_similarity = 0.0 best_predicted = "" - + # 与每个预测的style计算相似度 for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测 # 计算字符串相似度 similarity = SequenceMatcher(None, predicted_style, db_style).ratio() - + # 也检查包含关系(如果一个是另一个的子串,给更高分) if len(predicted_style) >= 2 and len(db_style) >= 2: if predicted_style in db_style or db_style in predicted_style: similarity = max(similarity, 0.7) - + if similarity > max_similarity: max_similarity = similarity best_predicted = predicted_style - + # 🔥 降低阈值到30%,因为StyleLearner预测质量较差 if max_similarity >= 0.3: # 30%相似度阈值 matched_expressions.append((expr, max_similarity, expr.count, best_predicted)) - + if not matched_expressions: # 收集数据库中的style样例用于调试 all_styles = [e.style for e in all_expressions[:10]] @@ -495,11 +495,11 @@ class ExpressionSelector: f" 提示: StyleLearner预测质量差,建议重新训练或使用classic模式" ) return [] - + # 按照相似度*count排序,选择最佳匹配 matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True) expressions_objs = [e[0] for e in matched_expressions[:max_num]] - + # 显示最佳匹配的详细信息 top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]] logger.info( @@ -507,7 +507,7 @@ class ExpressionSelector: f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n" f" Top3匹配: {top_matches}" ) - + # 转换为字典格式 expressions = [] for expr in expressions_objs: @@ -518,7 +518,7 @@ class ExpressionSelector: "count": float(expr.count) if expr.count else 0.0, "last_active_time": expr.last_active_time or 0.0 }) - + logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式") return expressions diff --git a/src/chat/express/expressor_model/model.py b/src/chat/express/expressor_model/model.py index 8c18240a8..c2b665878 100644 --- a/src/chat/express/expressor_model/model.py +++ b/src/chat/express/expressor_model/model.py @@ -5,7 +5,6 @@ import os import pickle from collections import Counter, defaultdict -from typing import Dict, Optional, Tuple from src.common.logger import get_logger @@ -36,14 +35,14 @@ class ExpressorModel: self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size) # 候选表达管理 - self._candidates: Dict[str, str] = {} # cid -> text (style) - self._situations: Dict[str, str] = {} # cid -> situation (不参与计算) + self._candidates: dict[str, str] = {} # cid -> text (style) + self._situations: dict[str, str] = {} # cid -> situation (不参与计算) logger.info( f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})" ) - def add_candidate(self, cid: str, text: str, situation: Optional[str] = None): + def add_candidate(self, cid: str, text: str, situation: str | None = None): """ 添加候选文本和对应的situation @@ -62,7 +61,7 @@ class ExpressorModel: if cid not in self.nb.token_counts: self.nb.token_counts[cid] = defaultdict(float) - def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]: + def predict(self, text: str, k: int = None) -> tuple[str | None, dict[str, float]]: """ 直接对所有候选进行朴素贝叶斯评分 @@ -113,7 +112,7 @@ class ExpressorModel: tf = Counter(toks) self.nb.update_positive(tf, cid) - def decay(self, factor: Optional[float] = None): + def decay(self, factor: float | None = None): """ 应用知识衰减 @@ -122,7 +121,7 @@ class ExpressorModel: """ self.nb.decay(factor) - def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]: + def get_candidate_info(self, cid: str) -> tuple[str | None, str | None]: """ 获取候选信息 @@ -136,7 +135,7 @@ class ExpressorModel: situation = self._situations.get(cid) return style, situation - def get_all_candidates(self) -> Dict[str, Tuple[str, str]]: + def get_all_candidates(self) -> dict[str, tuple[str, str]]: """ 获取所有候选 @@ -205,7 +204,7 @@ class ExpressorModel: logger.info(f"模型已从 {path} 加载") - def get_stats(self) -> Dict: + def get_stats(self) -> dict: """获取模型统计信息""" nb_stats = self.nb.get_stats() return { diff --git a/src/chat/express/expressor_model/online_nb.py b/src/chat/express/expressor_model/online_nb.py index 39bd0d1cd..06230bdf7 100644 --- a/src/chat/express/expressor_model/online_nb.py +++ b/src/chat/express/expressor_model/online_nb.py @@ -4,7 +4,6 @@ """ import math from collections import Counter, defaultdict -from typing import Dict, List, Optional from src.common.logger import get_logger @@ -28,15 +27,15 @@ class OnlineNaiveBayes: self.V = vocab_size # 类别统计 - self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count - self.token_counts: Dict[str, Dict[str, float]] = defaultdict( + self.cls_counts: dict[str, float] = defaultdict(float) # cid -> total token count + self.token_counts: dict[str, dict[str, float]] = defaultdict( lambda: defaultdict(float) ) # cid -> term -> count # 缓存 - self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα) + self._logZ: dict[str, float] = {} # cache log(∑counts + Vα) - def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]: + def score_batch(self, tf: Counter, cids: list[str]) -> dict[str, float]: """ 批量计算候选的贝叶斯分数 @@ -51,7 +50,7 @@ class OnlineNaiveBayes: n_cls = max(1, len(self.cls_counts)) denom_prior = math.log(total_cls + self.beta * n_cls) - out: Dict[str, float] = {} + out: dict[str, float] = {} for cid in cids: # 计算先验概率 log P(c) prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior @@ -88,7 +87,7 @@ class OnlineNaiveBayes: self.cls_counts[cid] += inc self._invalidate(cid) - def decay(self, factor: Optional[float] = None): + def decay(self, factor: float | None = None): """ 知识衰减(遗忘机制) @@ -133,7 +132,7 @@ class OnlineNaiveBayes: if cid in self._logZ: del self._logZ[cid] - def get_stats(self) -> Dict: + def get_stats(self) -> dict: """获取统计信息""" return { "n_classes": len(self.cls_counts), diff --git a/src/chat/express/expressor_model/tokenizer.py b/src/chat/express/expressor_model/tokenizer.py index e25f780d4..b12cdc713 100644 --- a/src/chat/express/expressor_model/tokenizer.py +++ b/src/chat/express/expressor_model/tokenizer.py @@ -1,7 +1,6 @@ """ 文本分词器,支持中文Jieba分词 """ -from typing import List from src.common.logger import get_logger @@ -30,7 +29,7 @@ class Tokenizer: logger.warning("Jieba未安装,将使用字符级分词") self.use_jieba = False - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: """ 分词并返回token列表 diff --git a/src/chat/express/situation_extractor.py b/src/chat/express/situation_extractor.py index 8ebe0a8bd..1393d5a1b 100644 --- a/src/chat/express/situation_extractor.py +++ b/src/chat/express/situation_extractor.py @@ -2,7 +2,6 @@ 情境提取器 从聊天历史中提取当前的情境(situation),用于 StyleLearner 预测 """ -from typing import Optional from src.chat.utils.prompt import Prompt, global_prompt_manager from src.common.logger import get_logger @@ -41,17 +40,17 @@ def init_prompt(): class SituationExtractor: """情境提取器,从聊天历史中提取当前情境""" - + def __init__(self): self.llm_model = LLMRequest( model_set=model_config.model_task_config.utils_small, request_type="expression.situation_extractor" ) - + async def extract_situations( self, chat_history: list | str, - target_message: Optional[str] = None, + target_message: str | None = None, max_situations: int = 3 ) -> list[str]: """ @@ -68,18 +67,18 @@ class SituationExtractor: # 转换chat_history为字符串 if isinstance(chat_history, list): chat_info = "\n".join([ - f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" + f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history ]) else: chat_info = chat_history - + # 构建目标消息信息 if target_message: target_message_info = f",现在你想要回复消息:{target_message}" else: target_message_info = "" - + # 构建 prompt try: prompt = (await global_prompt_manager.get_prompt_async("situation_extraction_prompt")).format( @@ -87,31 +86,31 @@ class SituationExtractor: chat_history=chat_info, target_message_info=target_message_info ) - + # 调用 LLM response, _ = await self.llm_model.generate_response_async( prompt=prompt, temperature=0.3 ) - + if not response or not response.strip(): logger.warning("LLM返回空响应,无法提取情境") return [] - + # 解析响应 situations = self._parse_situations(response, max_situations) - + if situations: logger.debug(f"提取到 {len(situations)} 个情境: {situations}") else: logger.warning(f"无法从LLM响应中解析出情境。响应:\n{response}") - + return situations - + except Exception as e: logger.error(f"提取情境失败: {e}") return [] - + @staticmethod def _parse_situations(response: str, max_situations: int) -> list[str]: """ @@ -125,33 +124,33 @@ class SituationExtractor: 情境描述列表 """ situations = [] - + for line in response.splitlines(): line = line.strip() if not line: continue - + # 移除可能的序号、引号等 line = line.lstrip('0123456789.、-*>))】] \t"\'""''') line = line.rstrip('"\'""''') line = line.strip() - + if not line: continue - + # 过滤掉明显不是情境描述的内容 if len(line) > 30: # 太长 continue if len(line) < 2: # 太短 continue - if any(keyword in line.lower() for keyword in ['例如', '注意', '请', '分析', '总结']): + if any(keyword in line.lower() for keyword in ["例如", "注意", "请", "分析", "总结"]): continue - + situations.append(line) - + if len(situations) >= max_situations: break - + return situations diff --git a/src/chat/express/style_learner.py b/src/chat/express/style_learner.py index c254ef98c..1ea54dd83 100644 --- a/src/chat/express/style_learner.py +++ b/src/chat/express/style_learner.py @@ -5,7 +5,6 @@ """ import os import time -from typing import Dict, List, Optional, Tuple from src.common.logger import get_logger @@ -17,7 +16,7 @@ logger = get_logger("expressor.style_learner") class StyleLearner: """单个聊天室的表达风格学习器""" - def __init__(self, chat_id: str, model_config: Optional[Dict] = None): + def __init__(self, chat_id: str, model_config: dict | None = None): """ Args: chat_id: 聊天室ID @@ -37,9 +36,9 @@ class StyleLearner: # 动态风格管理 self.max_styles = 2000 # 每个chat_id最多2000个风格 - self.style_to_id: Dict[str, str] = {} # style文本 -> style_id - self.id_to_style: Dict[str, str] = {} # style_id -> style文本 - self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本 + self.style_to_id: dict[str, str] = {} # style文本 -> style_id + self.id_to_style: dict[str, str] = {} # style_id -> style文本 + self.id_to_situation: dict[str, str] = {} # style_id -> situation文本 self.next_style_id = 0 # 学习统计 @@ -51,7 +50,7 @@ class StyleLearner: logger.info(f"StyleLearner初始化成功: chat_id={chat_id}") - def add_style(self, style: str, situation: Optional[str] = None) -> bool: + def add_style(self, style: str, situation: str | None = None) -> bool: """ 动态添加一个新的风格 @@ -130,7 +129,7 @@ class StyleLearner: logger.error(f"学习映射失败: {e}") return False - def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: + def predict_style(self, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]: """ 根据up_content预测最合适的style @@ -146,7 +145,7 @@ class StyleLearner: if not self.style_to_id: logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}") return None, {} - + best_style_id, scores = self.expressor.predict(up_content, k=top_k) if best_style_id is None: @@ -155,7 +154,7 @@ class StyleLearner: # 将style_id转换为style文本 best_style = self.id_to_style.get(best_style_id) - + if best_style is None: logger.warning( f"style_id无法转换为style文本: style_id={best_style_id}, " @@ -171,7 +170,7 @@ class StyleLearner: style_scores[style_text] = score else: logger.warning(f"跳过无法转换的style_id: {sid}") - + logger.debug( f"预测成功: up_content={up_content[:30]}..., " f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}" @@ -183,7 +182,7 @@ class StyleLearner: logger.error(f"预测style失败: {e}", exc_info=True) return None, {} - def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]: + def get_style_info(self, style: str) -> tuple[str | None, str | None]: """ 获取style的完整信息 @@ -200,7 +199,7 @@ class StyleLearner: situation = self.id_to_situation.get(style_id) return style_id, situation - def get_all_styles(self) -> List[str]: + def get_all_styles(self) -> list[str]: """ 获取所有风格列表 @@ -209,7 +208,7 @@ class StyleLearner: """ return list(self.style_to_id.keys()) - def apply_decay(self, factor: Optional[float] = None): + def apply_decay(self, factor: float | None = None): """ 应用知识衰减 @@ -304,7 +303,7 @@ class StyleLearner: logger.error(f"加载StyleLearner失败: {e}") return False - def get_stats(self) -> Dict: + def get_stats(self) -> dict: """获取统计信息""" model_stats = self.expressor.get_stats() return { @@ -324,7 +323,7 @@ class StyleLearnerManager: Args: model_save_path: 模型保存路径 """ - self.learners: Dict[str, StyleLearner] = {} + self.learners: dict[str, StyleLearner] = {} self.model_save_path = model_save_path # 确保保存目录存在 @@ -332,7 +331,7 @@ class StyleLearnerManager: logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}") - def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner: + def get_learner(self, chat_id: str, model_config: dict | None = None) -> StyleLearner: """ 获取或创建指定chat_id的学习器 @@ -369,7 +368,7 @@ class StyleLearnerManager: learner = self.get_learner(chat_id) return learner.learn_mapping(up_content, style) - def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: + def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]: """ 预测最合适的风格 @@ -399,7 +398,7 @@ class StyleLearnerManager: logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}") return success - def apply_decay_all(self, factor: Optional[float] = None): + def apply_decay_all(self, factor: float | None = None): """ 对所有学习器应用知识衰减 @@ -409,9 +408,9 @@ class StyleLearnerManager: for learner in self.learners.values(): learner.apply_decay(factor) - logger.info(f"对所有StyleLearner应用知识衰减") + logger.info("对所有StyleLearner应用知识衰减") - def get_all_stats(self) -> Dict[str, Dict]: + def get_all_stats(self) -> dict[str, dict]: """ 获取所有学习器的统计信息 diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index 774fe5769..53ad47e84 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -503,7 +503,7 @@ class MemorySystem: existing_id = self._memory_fingerprints.get(fingerprint_key) if existing_id and existing_id not in new_memory_ids: candidate_ids.add(existing_id) - except Exception as exc: # noqa: PERF203 + except Exception as exc: logger.debug("构建记忆指纹失败,跳过候选收集: %s", exc) # 基于主体索引的候选(使用统一存储) diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py index aabf375fd..bd74925c7 100644 --- a/src/chat/message_manager/context_manager.py +++ b/src/chat/message_manager/context_manager.py @@ -35,12 +35,12 @@ class SingleStreamContextManager: self.last_access_time = time.time() self.access_count = 0 self.total_messages = 0 - + # 标记是否已初始化历史消息 self._history_initialized = False logger.info(f"[新建] 单流上下文管理器初始化: {stream_id} (id={id(self)})") - + # 异步初始化历史消息(不阻塞构造函数) asyncio.create_task(self._initialize_history_from_db()) @@ -299,55 +299,55 @@ class SingleStreamContextManager: """更新访问统计""" self.last_access_time = time.time() self.access_count += 1 - + async def _initialize_history_from_db(self): """从数据库初始化历史消息到context中""" if self._history_initialized: logger.info(f"历史消息已初始化,跳过: {self.stream_id}") return - + # 立即设置标志,防止并发重复加载 logger.info(f"设置历史初始化标志: {self.stream_id}") self._history_initialized = True - + try: logger.info(f"开始从数据库加载历史消息: {self.stream_id}") - + from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat - + # 加载历史消息(限制数量为max_context_size的2倍,用于丰富上下文) db_messages = await get_raw_msg_before_timestamp_with_chat( chat_id=self.stream_id, timestamp=time.time(), limit=self.max_context_size * 2, ) - + if db_messages: # 将数据库消息转换为 DatabaseMessages 对象并添加到历史 for msg_dict in db_messages: try: # 使用 ** 解包字典作为关键字参数 db_msg = DatabaseMessages(**msg_dict) - + # 标记为已读 db_msg.is_read = True - + # 添加到历史消息 self.context.history_messages.append(db_msg) - + except Exception as e: logger.warning(f"转换历史消息失败 (message_id={msg_dict.get('message_id', 'unknown')}): {e}") continue - + logger.info(f"成功从数据库加载 {len(self.context.history_messages)} 条历史消息到内存: {self.stream_id}") else: logger.debug(f"没有历史消息需要加载: {self.stream_id}") - + except Exception as e: logger.error(f"从数据库初始化历史消息失败: {self.stream_id}, {e}", exc_info=True) # 加载失败时重置标志,允许重试 self._history_initialized = False - + async def ensure_history_initialized(self): """确保历史消息已初始化(供外部调用)""" if not self._history_initialized: diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index ec2de4c83..c3496b79b 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -69,10 +69,10 @@ class StreamLoopManager: try: # 获取所有活跃的流 from src.plugin_system.apis.chat_api import get_chat_manager - + chat_manager = get_chat_manager() all_streams = await chat_manager.get_all_streams() - + # 创建任务列表以便并发取消 cancel_tasks = [] for chat_stream in all_streams: @@ -119,10 +119,10 @@ class StreamLoopManager: # 创建流循环任务 try: loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}") - + # 将任务记录到 StreamContext 中 context.stream_loop_task = loop_task - + # 更新统计信息 self.stats["active_streams"] += 1 self.stats["total_loops"] += 1 @@ -169,7 +169,7 @@ class StreamLoopManager: # 清空 StreamContext 中的任务记录 context.stream_loop_task = None - + logger.info(f"停止流循环: {stream_id}") return True @@ -200,13 +200,13 @@ class StreamLoopManager: if has_messages: if force_dispatch: logger.info("流 %s 未读消息 %d 条,触发强制分发", stream_id, unread_count) - + # 3. 在处理前更新能量值(用于下次间隔计算) try: await self._update_stream_energy(stream_id, context) except Exception as e: logger.debug(f"更新流能量失败 {stream_id}: {e}") - + # 4. 激活chatter处理 success = await self._process_stream_messages(stream_id, context) @@ -371,7 +371,7 @@ class StreamLoopManager: # 清除 Chatter 处理标志 context.is_chatter_processing = False logger.debug(f"清除 Chatter 处理标志: {stream_id}") - + # 无论成功或失败,都要设置处理状态为未处理 self._set_stream_processing_status(stream_id, False) @@ -432,48 +432,48 @@ class StreamLoopManager: """ try: from src.chat.message_receive.chat_stream import get_chat_manager - + # 获取聊天流 chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) - + if not chat_stream: logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新") return - + # 从 context_manager 获取消息(包括未读和历史消息) # 合并未读消息和历史消息 all_messages = [] - + # 添加历史消息 history_messages = context.get_history_messages(limit=global_config.chat.max_context_size) all_messages.extend(history_messages) - + # 添加未读消息 unread_messages = context.get_unread_messages() all_messages.extend(unread_messages) - + # 按时间排序并限制数量 all_messages.sort(key=lambda m: m.time) messages = all_messages[-global_config.chat.max_context_size:] - + # 获取用户ID user_id = None if context.triggering_user_id: user_id = context.triggering_user_id - + # 使用能量管理器计算并缓存能量值 energy = await energy_manager.calculate_focus_energy( stream_id=stream_id, messages=messages, user_id=user_id ) - + # 同步更新到 ChatStream chat_stream._focus_energy = energy - + logger.debug(f"已更新流 {stream_id} 的能量值: {energy:.3f}") - + except Exception as e: logger.warning(f"更新流能量失败 {stream_id}: {e}", exc_info=False) @@ -670,7 +670,7 @@ class StreamLoopManager: # 使用 start_stream_loop 重新创建流循环任务 success = await self.start_stream_loop(stream_id, force=True) - + if success: logger.info(f"已创建强制分发流循环: {stream_id}") else: diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index b3eb66ffc..a06e07be0 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -307,7 +307,7 @@ class MessageManager: # 检查上下文 context = chat_stream.context_manager.context - + # 只有当 Chatter 真正在处理时才检查打断 if not context.is_chatter_processing: logger.debug(f"聊天流 {chat_stream.stream_id} Chatter 未在处理,跳过打断检查") @@ -315,7 +315,7 @@ class MessageManager: # 检查是否有 stream_loop_task 在运行 stream_loop_task = context.stream_loop_task - + if stream_loop_task and not stream_loop_task.done(): # 检查触发用户ID triggering_user_id = context.triggering_user_id @@ -387,7 +387,7 @@ class MessageManager: # 重新创建 stream_loop 任务 success = await stream_loop_manager.start_stream_loop(stream_id, force=True) - + if success: logger.info(f"✅ 成功重新创建流循环任务: {stream_id}") else: diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 4ff4626dd..710a1872d 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -10,7 +10,7 @@ from src.chat.antipromptinjector import initialize_anti_injector from src.chat.message_manager import message_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.storage import MessageStorage -from src.chat.utils.prompt import create_prompt_async, global_prompt_manager +from src.chat.utils.prompt import global_prompt_manager from src.chat.utils.utils import is_mentioned_bot_in_message from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger @@ -181,7 +181,7 @@ class ChatBot: # 创建PlusCommand实例 plus_command_instance = plus_command_class(message, plugin_config) - + # 为插件实例设置 chat_stream 运行时属性 setattr(plus_command_instance, "chat_stream", chat) @@ -257,7 +257,7 @@ class ChatBot: # 创建命令实例 command_instance: BaseCommand = command_class(message, plugin_config) command_instance.set_matched_groups(matched_groups) - + # 为插件实例设置 chat_stream 运行时属性 setattr(command_instance, "chat_stream", chat) @@ -340,7 +340,7 @@ class ChatBot: ) # print(message_data) # logger.debug(str(message_data)) - + # 先提取基础信息检查是否是自身消息上报 from maim_message import BaseMessageInfo temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {})) @@ -350,7 +350,7 @@ class ChatBot: # 直接使用消息字典更新,不再需要创建 MessageRecv await MessageStorage.update_message(message_data) return - + group_info = temp_message_info.group_info user_info = temp_message_info.user_info @@ -368,14 +368,14 @@ class ChatBot: stream_id=chat.stream_id, platform=chat.platform ) - + # 填充聊天流时间信息 message.chat_info.create_time = chat.create_time message.chat_info.last_active_time = chat.last_active_time - + # 注册消息到聊天管理器 get_chat_manager().register_message(message) - + # 检测是否提及机器人 message.is_mentioned, _ = is_mentioned_bot_in_message(message) diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 8f6dd37e1..049d0fda1 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -1,8 +1,6 @@ import asyncio -import copy import hashlib import time -from typing import TYPE_CHECKING from maim_message import GroupInfo, UserInfo from rich.traceback import install @@ -10,13 +8,12 @@ from sqlalchemy import select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert +from src.common.data_models.database_data_model import DatabaseMessages from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import ChatStreams # 新增导入 -from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config # 新增导入 - install(extra_lines=3) @@ -134,7 +131,7 @@ class ChatStream: """ # 直接使用传入的 DatabaseMessages,设置到上下文中 self.context_manager.context.set_current_message(message) - + # 设置优先级信息(如果存在) priority_mode = getattr(message, "priority_mode", None) priority_info = getattr(message, "priority_info", None) @@ -156,7 +153,7 @@ class ChatStream: def _safe_get_actions(self, message: DatabaseMessages) -> list | None: """安全获取消息的actions字段""" import json - + try: actions = getattr(message, "actions", None) if actions is None: @@ -321,7 +318,7 @@ class ChatManager: def __init__(self): if not self._initialized: from src.common.data_models.database_data_model import DatabaseMessages - + self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message # try: @@ -360,15 +357,15 @@ class ChatManager: def register_message(self, message: DatabaseMessages): """注册消息到聊天流""" # 从 DatabaseMessages 提取平台和用户/群组信息 - from maim_message import UserInfo, GroupInfo - + from maim_message import GroupInfo, UserInfo + user_info = UserInfo( platform=message.user_info.platform, user_id=message.user_info.user_id, user_nickname=message.user_info.user_nickname, user_cardname=message.user_info.user_cardname or "" ) - + group_info = None if message.group_info: group_info = GroupInfo( @@ -376,7 +373,7 @@ class ChatManager: group_id=message.group_info.group_id, group_name=message.group_info.group_name ) - + stream_id = self._generate_stream_id( message.chat_info.platform, user_info, @@ -435,7 +432,7 @@ class ChatManager: stream.user_info = user_info if group_info: stream.group_info = group_info - + # 检查是否有最后一条消息(现在使用 DatabaseMessages) from src.common.data_models.database_data_model import DatabaseMessages if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages): @@ -532,7 +529,7 @@ class ChatManager: async def get_stream(self, stream_id: str) -> ChatStream | None: """通过stream_id获取聊天流""" from src.common.data_models.database_data_model import DatabaseMessages - + stream = self.streams.get(stream_id) if not stream: return None diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 4790bab98..68fc4f1bf 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -1,8 +1,7 @@ -import base64 import time from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Optional import urllib3 from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo @@ -11,7 +10,6 @@ from rich.traceback import install from src.chat.message_receive.chat_stream import ChatStream from src.chat.utils.self_voice_cache import consume_self_voice_text from src.chat.utils.utils_image import get_image_manager -from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available from src.chat.utils.utils_voice import get_voice_text from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger @@ -269,7 +267,7 @@ class MessageSending(MessageProcessBase): if self.reply: # 从 DatabaseMessages 获取 message_id message_id = self.reply.message_id - + if message_id: self.reply_to_message_id = message_id self.message_segment = Seg( diff --git a/src/chat/message_receive/message_processor.py b/src/chat/message_receive/message_processor.py index b6c66f144..10e7213de 100644 --- a/src/chat/message_receive/message_processor.py +++ b/src/chat/message_receive/message_processor.py @@ -39,7 +39,7 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str # 解析基础信息 message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) message_segment = Seg.from_dict(message_dict.get("message_segment", {})) - + # 初始化处理状态 processing_state = { "is_emoji": False, @@ -53,10 +53,10 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str "priority_mode": "interest", "priority_info": None, } - + # 异步处理消息段,生成纯文本 processed_plain_text = await _process_message_segments(message_segment, processing_state, message_info) - + # 解析 notice 信息 is_notify = False is_public_notice = False @@ -65,34 +65,34 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str is_notify = message_info.additional_config.get("is_notice", False) is_public_notice = message_info.additional_config.get("is_public_notice", False) notice_type = message_info.additional_config.get("notice_type") - + # 提取用户信息 user_info = message_info.user_info user_id = str(user_info.user_id) if user_info and user_info.user_id else "" user_nickname = (user_info.user_nickname or "") if user_info else "" user_cardname = user_info.user_cardname if user_info else None user_platform = (user_info.platform or "") if user_info else "" - + # 提取群组信息 group_info = message_info.group_info group_id = group_info.group_id if group_info else None group_name = group_info.group_name if group_info else None group_platform = group_info.platform if group_info else None - + # chat_id 应该直接使用 stream_id(与数据库存储格式一致) # stream_id 是通过 platform + user_id/group_id 的 SHA-256 哈希生成的 chat_id = stream_id - + # 准备 additional_config additional_config_str = _prepare_additional_config(message_info, is_notify, is_public_notice, notice_type) - + # 提取 reply_to reply_to = _extract_reply_from_segment(message_segment) - + # 构造 DatabaseMessages message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time() message_id = message_info.message_id or "" - + # 处理 is_mentioned is_mentioned = None mentioned_value = processing_state.get("is_mentioned") @@ -100,7 +100,7 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str is_mentioned = mentioned_value elif isinstance(mentioned_value, (int, float)): is_mentioned = mentioned_value != 0 - + db_message = DatabaseMessages( message_id=message_id, time=float(message_time), @@ -133,19 +133,19 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str chat_info_group_name=group_name, chat_info_group_platform=group_platform, ) - + # 设置优先级信息 if processing_state.get("priority_mode"): setattr(db_message, "priority_mode", processing_state["priority_mode"]) if processing_state.get("priority_info"): setattr(db_message, "priority_info", processing_state["priority_info"]) - + # 设置其他运行时属性 setattr(db_message, "is_voice", bool(processing_state.get("is_voice", False))) setattr(db_message, "is_video", bool(processing_state.get("is_video", False))) setattr(db_message, "has_emoji", bool(processing_state.get("has_emoji", False))) setattr(db_message, "has_picid", bool(processing_state.get("has_picid", False))) - + return db_message @@ -190,7 +190,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM state["is_emoji"] = False state["is_video"] = False return segment.data - + elif segment.type == "at": state["is_picid"] = False state["is_emoji"] = False @@ -201,7 +201,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM nickname, qq_id = segment.data.split(":", 1) return f"@{nickname}" return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户" - + elif segment.type == "image": # 如果是base64图片数据 if isinstance(segment.data, str): @@ -213,7 +213,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM _, processed_text = await image_manager.process_image(segment.data) return processed_text return "[发了一张图片,网卡了加载不出来]" - + elif segment.type == "emoji": state["has_emoji"] = True state["is_emoji"] = True @@ -223,13 +223,13 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM if isinstance(segment.data, str): return await get_image_manager().get_emoji_description(segment.data) return "[发了一个表情包,网卡了加载不出来]" - + elif segment.type == "voice": state["is_picid"] = False state["is_emoji"] = False state["is_voice"] = True state["is_video"] = False - + # 检查消息是否由机器人自己发送 if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account): logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。") @@ -240,12 +240,12 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM return f"[语音:{cached_text}]" else: logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。") - + # 标准语音识别流程 if isinstance(segment.data, str): return await get_voice_text(segment.data) return "[发了一段语音,网卡了加载不出来]" - + elif segment.type == "mention_bot": state["is_picid"] = False state["is_emoji"] = False @@ -253,7 +253,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM state["is_video"] = False state["is_mentioned"] = float(segment.data) return "" - + elif segment.type == "priority_info": state["is_picid"] = False state["is_emoji"] = False @@ -263,26 +263,26 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM state["priority_mode"] = "priority" state["priority_info"] = segment.data return "" - + elif segment.type == "file": if isinstance(segment.data, dict): - file_name = segment.data.get('name', '未知文件') - file_size = segment.data.get('size', '未知大小') + file_name = segment.data.get("name", "未知文件") + file_size = segment.data.get("size", "未知大小") return f"[文件:{file_name} ({file_size}字节)]" return "[收到一个文件]" - + elif segment.type == "video": state["is_picid"] = False state["is_emoji"] = False state["is_voice"] = False state["is_video"] = True logger.info(f"接收到视频消息,数据类型: {type(segment.data)}") - + # 检查视频分析功能是否可用 if not is_video_analysis_available(): logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析") return "[视频]" - + if global_config.video_analysis.enable: logger.info("已启用视频识别,开始识别") if isinstance(segment.data, dict): @@ -290,23 +290,23 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM # 从Adapter接收的视频数据 video_base64 = segment.data.get("base64") filename = segment.data.get("filename", "video.mp4") - + logger.info(f"视频文件名: {filename}") logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}") - + if video_base64: # 解码base64视频数据 video_bytes = base64.b64decode(video_base64) logger.info(f"解码后视频大小: {len(video_bytes)} 字节") - + # 使用video analyzer分析视频 video_analyzer = get_video_analyzer() result = await video_analyzer.analyze_video_from_bytes( video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt ) - + logger.info(f"视频分析结果: {result}") - + # 返回视频分析结果 summary = result.get("summary", "") if summary: @@ -329,7 +329,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM else: logger.warning(f"未知的消息段类型: {segment.type}") return f"[{segment.type} 消息]" - + except Exception as e: logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}") return f"[处理失败的{segment.type}消息]" @@ -349,9 +349,9 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i """ try: additional_config_data = {} - + # 首先获取adapter传递的additional_config - if hasattr(message_info, 'additional_config') and message_info.additional_config: + if hasattr(message_info, "additional_config") and message_info.additional_config: if isinstance(message_info.additional_config, dict): additional_config_data = message_info.additional_config.copy() elif isinstance(message_info.additional_config, str): @@ -360,28 +360,28 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i except Exception as e: logger.warning(f"无法解析 additional_config JSON: {e}") additional_config_data = {} - + # 添加notice相关标志 if is_notify: additional_config_data["is_notice"] = True additional_config_data["notice_type"] = notice_type or "unknown" additional_config_data["is_public_notice"] = bool(is_public_notice) - + # 添加format_info到additional_config中 - if hasattr(message_info, 'format_info') and message_info.format_info: + if hasattr(message_info, "format_info") and message_info.format_info: try: format_info_dict = message_info.format_info.to_dict() additional_config_data["format_info"] = format_info_dict logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}") except Exception as e: logger.warning(f"将 format_info 转换为字典失败: {e}") - + # 序列化为JSON字符串 if additional_config_data: return orjson.dumps(additional_config_data).decode("utf-8") except Exception as e: logger.error(f"准备 additional_config 失败: {e}") - + return None @@ -423,8 +423,8 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag Returns: BaseMessageInfo: 重建的消息信息对象 """ - from maim_message import UserInfo, GroupInfo - + from maim_message import GroupInfo, UserInfo + # 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo user_info = UserInfo( platform=db_message.user_info.platform, @@ -432,7 +432,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag user_nickname=db_message.user_info.user_nickname, user_cardname=db_message.user_info.user_cardname or "" ) - + # 从 DatabaseMessages 的 group_info 转换为 maim_message.GroupInfo(如果存在) group_info = None if db_message.group_info: @@ -441,7 +441,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag group_id=db_message.group_info.group_id, group_name=db_message.group_info.group_name ) - + # 解析 additional_config(从 JSON 字符串到字典) additional_config = None if db_message.additional_config: @@ -450,7 +450,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag except Exception: # 如果解析失败,保持为字符串 pass - + # 创建 BaseMessageInfo message_info = BaseMessageInfo( platform=db_message.chat_info.platform, @@ -460,7 +460,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag group_info=group_info, additional_config=additional_config # type: ignore ) - + return message_info diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 9b2b54991..314472845 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -5,12 +5,11 @@ import traceback import orjson from sqlalchemy import desc, select, update +from src.common.data_models.database_data_model import DatabaseMessages from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import Images, Messages from src.common.logger import get_logger -from src.common.data_models.database_data_model import DatabaseMessages - from .chat_stream import ChatStream from .message import MessageSending @@ -51,10 +50,10 @@ class MessageStorage: filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL) else: filtered_processed_plain_text = "" - + display_message = message.display_message or message.processed_plain_text or "" filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) - + # 直接从 DatabaseMessages 获取所有字段 msg_id = message.message_id msg_time = message.time @@ -71,13 +70,13 @@ class MessageStorage: key_words = "" # DatabaseMessages 没有 key_words key_words_lite = "" memorized_times = 0 # DatabaseMessages 没有 memorized_times - + # 使用 DatabaseMessages 中的嵌套对象信息 user_platform = message.user_info.platform if message.user_info else "" user_id = message.user_info.user_id if message.user_info else "" user_nickname = message.user_info.user_nickname if message.user_info else "" user_cardname = message.user_info.user_cardname if message.user_info else None - + chat_info_stream_id = message.chat_info.stream_id if message.chat_info else "" chat_info_platform = message.chat_info.platform if message.chat_info else "" chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0 @@ -89,7 +88,7 @@ class MessageStorage: chat_info_group_platform = message.group_info.group_platform if message.group_info else None chat_info_group_id = message.group_info.group_id if message.group_info else None chat_info_group_name = message.group_info.group_name if message.group_info else None - + else: # MessageSending 处理逻辑 processed_plain_text = message.processed_plain_text @@ -145,7 +144,7 @@ class MessageStorage: msg_time = float(message.message_info.time or time.time()) chat_id = chat_stream.stream_id memorized_times = message.memorized_times - + # 安全地获取 group_info, 如果为 None 则视为空字典 group_info_from_chat = chat_info_dict.get("group_info") or {} # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一) @@ -153,12 +152,12 @@ class MessageStorage: # 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段 priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None - + user_platform = user_info_dict.get("platform") user_id = user_info_dict.get("user_id") user_nickname = user_info_dict.get("user_nickname") user_cardname = user_info_dict.get("user_cardname") - + chat_info_stream_id = chat_info_dict.get("stream_id") chat_info_platform = chat_info_dict.get("platform") chat_info_create_time = float(chat_info_dict.get("create_time", 0.0)) @@ -222,11 +221,11 @@ class MessageStorage: # 从字典中提取信息 message_info = message_data.get("message_info", {}) mmc_message_id = message_info.get("message_id") - + message_segment = message_data.get("message_segment", {}) segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {} - + qq_message_id = None logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}") diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 3a1204f23..20f927419 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -23,35 +23,35 @@ async def send_message(message: MessageSending, show_log=True) -> bool: await get_global_api().send_message(message) if show_log: logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'") - + # 触发 AFTER_SEND 事件 try: - from src.plugin_system.core.event_manager import event_manager from src.plugin_system.base.component_types import EventType - + from src.plugin_system.core.event_manager import event_manager + if message.chat_stream: logger.info(f"[发送完成] 准备触发 AFTER_SEND 事件,stream_id={message.chat_stream.stream_id}") - + # 使用 asyncio.create_task 来异步触发事件,避免阻塞 async def trigger_event_async(): try: - logger.info(f"[事件触发] 开始异步触发 AFTER_SEND 事件") + logger.info("[事件触发] 开始异步触发 AFTER_SEND 事件") await event_manager.trigger_event( EventType.AFTER_SEND, permission_group="SYSTEM", stream_id=message.chat_stream.stream_id, message=message, ) - logger.info(f"[事件触发] AFTER_SEND 事件触发完成") + logger.info("[事件触发] AFTER_SEND 事件触发完成") except Exception as e: logger.error(f"[事件触发] 异步触发事件失败: {e}", exc_info=True) - + # 创建异步任务,不等待完成 asyncio.create_task(trigger_event_async()) - logger.info(f"[发送完成] AFTER_SEND 事件已提交到异步任务") + logger.info("[发送完成] AFTER_SEND 事件已提交到异步任务") except Exception as event_error: logger.error(f"触发 AFTER_SEND 事件时出错: {event_error}", exc_info=True) - + return True except Exception as e: diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index e15dab72a..f52e40657 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -270,7 +270,7 @@ class ChatterActionManager: msg_text = target_message.get("processed_plain_text", "未知消息") else: msg_text = "未知消息" - + logger.info(f"对 {msg_text} 的回复生成失败") return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} except asyncio.CancelledError: diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 6eddb354c..bbe05e718 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -32,8 +32,6 @@ from src.common.logger import get_logger from src.config.config import global_config, model_config from src.individuality.individuality import get_individuality from src.llm_models.utils_model import LLMRequest - - from src.mood.mood_manager import mood_manager from src.person_info.person_info import get_person_info_manager from src.plugin_system.apis import llm_api @@ -943,10 +941,10 @@ class DefaultReplyer: chat_stream = await chat_manager.get_stream(chat_id) if chat_stream: stream_context = chat_stream.context_manager - + # 确保历史消息已从数据库加载 await stream_context.ensure_history_initialized() - + # 直接使用内存中的已读和未读消息,无需再查询数据库 read_messages = stream_context.context.history_messages # 已读消息(已从数据库加载) unread_messages = stream_context.get_unread_messages() # 未读消息 @@ -956,11 +954,11 @@ class DefaultReplyer: if read_messages: # 将 DatabaseMessages 对象转换为字典格式,以便使用 build_readable_messages read_messages_dicts = [msg.flatten() for msg in read_messages] - + # 按时间排序并限制数量 sorted_messages = sorted(read_messages_dicts, key=lambda x: x.get("time", 0)) final_history = sorted_messages[-50:] # 限制最多50条 - + read_content = await build_readable_messages( final_history, replace_bot_name=True, @@ -1194,7 +1192,7 @@ class DefaultReplyer: if reply_message is None: logger.warning("reply_message 为 None,无法构建prompt") return "" - + # 统一处理 DatabaseMessages 对象和字典 if isinstance(reply_message, DatabaseMessages): platform = reply_message.chat_info.platform @@ -1208,7 +1206,7 @@ class DefaultReplyer: user_nickname = reply_message.get("user_nickname") user_cardname = reply_message.get("user_cardname") processed_plain_text = reply_message.get("processed_plain_text") - + person_id = person_info_manager.get_person_id( platform, # type: ignore user_id, # type: ignore @@ -1262,24 +1260,24 @@ class DefaultReplyer: # 从内存获取历史消息,避免重复查询数据库 from src.plugin_system.apis.chat_api import get_chat_manager - + chat_manager = get_chat_manager() chat_stream_obj = await chat_manager.get_stream(chat_id) - + if chat_stream_obj: # 确保历史消息已初始化 await chat_stream_obj.context_manager.ensure_history_initialized() - + # 获取所有消息(历史+未读) all_messages = ( chat_stream_obj.context_manager.context.history_messages + chat_stream_obj.context_manager.get_unread_messages() ) - + # 转换为字典格式 message_list_before_now_long = [msg.flatten() for msg in all_messages[-(global_config.chat.max_context_size * 2):]] message_list_before_short = [msg.flatten() for msg in all_messages[-int(global_config.chat.max_context_size * 0.33):]] - + logger.debug(f"使用内存中的消息: long={len(message_list_before_now_long)}, short={len(message_list_before_short)}") else: # 回退到数据库查询 @@ -1294,7 +1292,7 @@ class DefaultReplyer: timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.33), ) - + chat_talking_prompt_short = await build_readable_messages( message_list_before_short, replace_bot_name=True, @@ -1634,24 +1632,24 @@ class DefaultReplyer: # 从内存获取历史消息,避免重复查询数据库 from src.plugin_system.apis.chat_api import get_chat_manager - + chat_manager = get_chat_manager() chat_stream_obj = await chat_manager.get_stream(chat_id) - + if chat_stream_obj: # 确保历史消息已初始化 await chat_stream_obj.context_manager.ensure_history_initialized() - + # 获取所有消息(历史+未读) all_messages = ( chat_stream_obj.context_manager.context.history_messages + chat_stream_obj.context_manager.get_unread_messages() ) - + # 转换为字典格式,限制数量 limit = min(int(global_config.chat.max_context_size * 0.33), 15) message_list_before_now_half = [msg.flatten() for msg in all_messages[-limit:]] - + logger.debug(f"Rewrite使用内存中的 {len(message_list_before_now_half)} 条消息") else: # 回退到数据库查询 @@ -1661,7 +1659,7 @@ class DefaultReplyer: timestamp=time.time(), limit=min(int(global_config.chat.max_context_size * 0.33), 15), ) - + chat_talking_prompt_half = await build_readable_messages( message_list_before_now_half, replace_bot_name=True, @@ -1818,7 +1816,7 @@ class DefaultReplyer: # 循环移除,以处理模型可能生成的嵌套回复头/尾 # 使用更健壮的正则表达式,通过非贪婪匹配和向后查找来定位真正的消息内容 pattern = re.compile(r"^\s*\[回复<.+?>\s*(?:的消息)?:(?P.*)\](?:,?说:)?\s*$", re.DOTALL) - + temp_content = cleaned_content while True: match = pattern.match(temp_content) @@ -1830,7 +1828,7 @@ class DefaultReplyer: temp_content = new_content else: break # 没有匹配到,退出循环 - + # 在循环处理后,再使用 rsplit 来处理日志中观察到的特殊情况 # 这可以作为处理复杂嵌套的最后一道防线 final_split = temp_content.rsplit("],说:", 1) @@ -1838,7 +1836,7 @@ class DefaultReplyer: final_content = final_split[1].strip() else: final_content = temp_content - + if final_content != content: logger.debug(f"清理了模型生成的多余内容,原始内容: '{content}', 清理后: '{final_content}'") content = final_content @@ -2077,24 +2075,24 @@ class DefaultReplyer: # 从内存获取聊天历史用于存储,避免重复查询数据库 from src.plugin_system.apis.chat_api import get_chat_manager - + chat_manager = get_chat_manager() chat_stream_obj = await chat_manager.get_stream(stream.stream_id) - + if chat_stream_obj: # 确保历史消息已初始化 await chat_stream_obj.context_manager.ensure_history_initialized() - + # 获取所有消息(历史+未读) all_messages = ( chat_stream_obj.context_manager.context.history_messages + chat_stream_obj.context_manager.get_unread_messages() ) - + # 转换为字典格式,限制数量 limit = int(global_config.chat.max_context_size * 0.33) message_list_before_short = [msg.flatten() for msg in all_messages[-limit:]] - + logger.debug(f"记忆存储使用内存中的 {len(message_list_before_short)} 条消息") else: # 回退到数据库查询 diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 2e141e6ad..c10056bf2 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -1112,14 +1112,14 @@ class Prompt: # 使用关系提取器构建用户关系信息和聊天流印象 user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5) stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id) - + # 组合两部分信息 info_parts = [] if user_relation_info: info_parts.append(user_relation_info) if stream_impression: info_parts.append(stream_impression) - + return "\n\n".join(info_parts) if info_parts else "" def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]: diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index bc0641c9d..f0d5e2529 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -11,6 +11,7 @@ import rjieba from maim_message import UserInfo from src.chat.message_receive.chat_stream import get_chat_manager + # MessageRecv 已被移除,现在使用 DatabaseMessages from src.common.logger import get_logger from src.common.message_repository import count_messages, find_messages @@ -49,13 +50,13 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]: Returns: tuple[bool, float]: (是否提及, 提及概率) - """ + """ keywords = [global_config.bot.nickname] nicknames = global_config.bot.alias_names reply_probability = 0.0 is_at = False is_mentioned = False - + # 检查 is_mentioned 属性 mentioned_attr = getattr(message, "is_mentioned", None) if mentioned_attr is not None: @@ -63,7 +64,7 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]: return bool(mentioned_attr), float(mentioned_attr) except (ValueError, TypeError): pass - + # 检查 additional_config additional_config = None diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index eb29b3302..5eb7f0f7b 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -7,7 +7,7 @@ import asyncio import time from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from src.common.logger import get_logger from src.plugin_system.base.component_types import ChatMode, ChatType @@ -64,7 +64,7 @@ class StreamContext(BaseDataModel): triggering_user_id: str | None = None # 触发当前聊天流的用户ID is_replying: bool = False # 是否正在生成回复 processing_message_id: str | None = None # 当前正在规划/处理的目标消息ID,用于防止重复回复 - decision_history: List["DecisionRecord"] = field(default_factory=list) # 决策历史 + decision_history: list["DecisionRecord"] = field(default_factory=list) # 决策历史 def add_action_to_message(self, message_id: str, action: str): """ @@ -260,7 +260,7 @@ class StreamContext(BaseDataModel): if requested_type not in accept_format: logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的类型: {accept_format}") return False - logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)") + logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)") return True # 方法2: 检查content_format字段(向后兼容) @@ -279,7 +279,7 @@ class StreamContext(BaseDataModel): if requested_type not in content_format: logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的内容格式: {content_format}") return False - logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)") + logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)") return True else: logger.warning("[check_types] [问题] additional_config 中没有 format_info 字段") diff --git a/src/config/config.py b/src/config/config.py index efd57be69..b22674893 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -26,7 +26,6 @@ from src.config.official_configs import ( EmojiConfig, ExperimentalConfig, ExpressionConfig, - ReactionConfig, LPMMKnowledgeConfig, MaimMessageConfig, MemoryConfig, @@ -38,6 +37,7 @@ from src.config.official_configs import ( PersonalityConfig, PlanningSystemConfig, ProactiveThinkingConfig, + ReactionConfig, ResponsePostProcessConfig, ResponseSplitterConfig, ToolConfig, diff --git a/src/config/official_configs.py b/src/config/official_configs.py index cc9885b8c..24957cd30 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -188,7 +188,7 @@ class ExpressionConfig(ValidatedConfigBase): """表达配置类""" mode: Literal["classic", "exp_model"] = Field( - default="classic", + default="classic", description="表达方式选择模式: classic=经典LLM评估, exp_model=机器学习模型预测" ) rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则") @@ -761,35 +761,35 @@ class ProactiveThinkingConfig(ValidatedConfigBase): cold_start_cooldown: int = Field( default=86400, description="冷启动后,该私聊的下一次主动思考需要等待的最小时间(秒)" ) - + # --- 新增:间隔配置 --- base_interval: int = Field(default=1800, ge=60, description="基础触发间隔(秒),默认30分钟") min_interval: int = Field(default=600, ge=60, description="最小触发间隔(秒),默认10分钟。兴趣分数高时会接近此值") max_interval: int = Field(default=7200, ge=60, description="最大触发间隔(秒),默认2小时。兴趣分数低时会接近此值") - + # --- 新增:动态调整配置 --- use_interest_score: bool = Field(default=True, description="是否根据兴趣分数动态调整间隔。关闭则使用固定base_interval") interest_score_factor: float = Field(default=2.0, ge=1.0, le=3.0, description="兴趣分数影响因子。公式: interval = base * (factor - score)") - + # --- 新增:黑白名单配置 --- whitelist_mode: bool = Field(default=False, description="是否启用白名单模式。启用后只对白名单中的聊天流生效") blacklist_mode: bool = Field(default=False, description="是否启用黑名单模式。启用后排除黑名单中的聊天流") - + whitelist_private: list[str] = Field( - default_factory=list, + default_factory=list, description='私聊白名单,格式: ["platform:user_id:private", "qq:12345:private"]' ) whitelist_group: list[str] = Field( - default_factory=list, + default_factory=list, description='群聊白名单,格式: ["platform:group_id:group", "qq:123456:group"]' ) - + blacklist_private: list[str] = Field( - default_factory=list, + default_factory=list, description='私聊黑名单,格式: ["platform:user_id:private", "qq:12345:private"]' ) blacklist_group: list[str] = Field( - default_factory=list, + default_factory=list, description='群聊黑名单,格式: ["platform:group_id:group", "qq:123456:group"]' ) @@ -802,17 +802,17 @@ class ProactiveThinkingConfig(ValidatedConfigBase): quiet_hours_start: str = Field(default="00:00", description='安静时段开始时间,格式: "HH:MM"') quiet_hours_end: str = Field(default="07:00", description='安静时段结束时间,格式: "HH:MM"') active_hours_multiplier: float = Field(default=0.7, ge=0.1, le=2.0, description="活跃时段间隔倍数,<1表示更频繁,>1表示更稀疏") - + # --- 新增:冷却与限制 --- reply_reset_enabled: bool = Field(default=True, description="bot回复后是否重置定时器(避免回复后立即又主动发言)") topic_throw_cooldown: int = Field(default=3600, ge=0, description="抛出话题后的冷却时间(秒),期间暂停主动思考") max_daily_proactive: int = Field(default=0, ge=0, description="每个聊天流每天最多主动发言次数,0表示不限制") - + # --- 新增:决策权重配置 --- do_nothing_weight: float = Field(default=0.4, ge=0.0, le=1.0, description="do_nothing动作的基础权重") simple_bubble_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="simple_bubble动作的基础权重") throw_topic_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="throw_topic动作的基础权重") - + # --- 新增:调试与监控 --- enable_statistics: bool = Field(default=True, description="是否启用统计功能(记录触发次数、决策分布等)") log_decisions: bool = Field(default=False, description="是否记录每次决策的详细日志(用于调试)") diff --git a/src/main.py b/src/main.py index c23d887b3..c11180e43 100644 --- a/src/main.py +++ b/src/main.py @@ -429,7 +429,7 @@ MoFox_Bot(第三方修改版) await initialize_scheduler() except Exception as e: logger.error(f"统一调度器初始化失败: {e}") - + # 加载所有插件 plugin_manager.load_all_plugins() diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 2f20ea5be..c9776df64 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -123,7 +123,7 @@ class RelationshipFetcher: # 获取用户特征点 current_points = await person_info_manager.get_value(person_id, "points") or [] forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or [] - + # 确保 points 是列表类型(可能从数据库返回字符串) if not isinstance(current_points, list): current_points = [] @@ -195,25 +195,25 @@ class RelationshipFetcher: if relationships: # db_query 返回字典列表,使用字典访问方式 rel_data = relationships[0] - + # 5.1 用户别名 if rel_data.get("user_aliases"): aliases_list = [alias.strip() for alias in rel_data["user_aliases"].split(",") if alias.strip()] if aliases_list: aliases_str = "、".join(aliases_list) relation_parts.append(f"{person_name}的别名有:{aliases_str}") - + # 5.2 关系印象文本(主观认知) if rel_data.get("relationship_text"): relation_parts.append(f"你对{person_name}的整体认知:{rel_data['relationship_text']}") - + # 5.3 用户偏好关键词 if rel_data.get("preference_keywords"): keywords_list = [kw.strip() for kw in rel_data["preference_keywords"].split(",") if kw.strip()] if keywords_list: keywords_str = "、".join(keywords_list) relation_parts.append(f"{person_name}的偏好和兴趣:{keywords_str}") - + # 5.4 关系亲密程度(好感分数) if rel_data.get("relationship_score") is not None: score_desc = self._get_relationship_score_description(rel_data["relationship_score"]) diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 71786562d..429be54c8 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -55,7 +55,7 @@ async def file_to_stream( if not file_name: file_name = Path(file_path).name - + params = { "file": file_path, "name": file_name, @@ -68,7 +68,7 @@ async def file_to_stream( else: action = "upload_private_file" params["user_id"] = target_stream.user_info.user_id - + response = await adapter_command_to_stream( action=action, params=params, @@ -86,7 +86,7 @@ async def file_to_stream( import asyncio import time import traceback -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any from maim_message import Seg, UserInfo @@ -117,11 +117,11 @@ def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessage Optional[DatabaseMessages]: 构建的消息对象,如果构建失败则返回None """ from src.common.data_models.database_data_model import DatabaseMessages - + # 如果已经是 DatabaseMessages,直接返回 if isinstance(message_dict, DatabaseMessages): return message_dict - + # 从字典提取信息 user_platform = message_dict.get("user_platform", "") user_id = message_dict.get("user_id", "") @@ -135,7 +135,7 @@ def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessage time_val = message_dict.get("time", time.time()) additional_config = message_dict.get("additional_config") processed_plain_text = message_dict.get("processed_plain_text", "") - + # DatabaseMessages 使用扁平参数构造 db_message = DatabaseMessages( message_id=message_id or "temp_reply_id", @@ -151,7 +151,7 @@ def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessage processed_plain_text=processed_plain_text, additional_config=additional_config ) - + logger.info(f"[SendAPI] 构建回复消息对象,发送者: {user_nickname}") return db_message diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index b5071e578..e102b55cc 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -192,7 +192,7 @@ class BaseAction(ABC): self.group_name = self.action_message.get("chat_info_group_name", None) self.user_id = str(self.action_message.get("user_id", None)) self.user_nickname = self.action_message.get("user_nickname", None) - + if self.group_id: self.is_group = True self.target_id = self.group_id diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 7076bbba6..df604cbc0 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -45,7 +45,7 @@ class BaseCommand(ABC): self.plugin_config = plugin_config or {} # 直接存储插件配置字典 self.log_prefix = "[Command]" - + # chat_stream 会在运行时被 bot.py 设置 self.chat_stream: "ChatStream | None" = None diff --git a/src/plugin_system/base/plus_command.py b/src/plugin_system/base/plus_command.py index b53846fc2..525819763 100644 --- a/src/plugin_system/base/plus_command.py +++ b/src/plugin_system/base/plus_command.py @@ -64,7 +64,7 @@ class PlusCommand(ABC): self.message = message self.plugin_config = plugin_config or {} self.log_prefix = "[PlusCommand]" - + # chat_stream 会在运行时被 bot.py 设置 self.chat_stream: "ChatStream | None" = None diff --git a/src/plugin_system/core/event_manager.py b/src/plugin_system/core/event_manager.py index e54861b15..64468b958 100644 --- a/src/plugin_system/core/event_manager.py +++ b/src/plugin_system/core/event_manager.py @@ -40,7 +40,7 @@ class EventManager: self._events: dict[str, BaseEvent] = {} self._event_handlers: dict[str, type[BaseEventHandler]] = {} self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅 - self._scheduler_callback: Optional[Any] = None # scheduler 回调函数 + self._scheduler_callback: Any | None = None # scheduler 回调函数 self._initialized = True logger.info("EventManager 单例初始化完成") diff --git a/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py b/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py index a49f7f36e..d4649dbfa 100644 --- a/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py +++ b/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py @@ -5,7 +5,6 @@ """ import json -import time from typing import Any from sqlalchemy import select @@ -22,7 +21,7 @@ logger = get_logger("chat_stream_impression_tool") class ChatStreamImpressionTool(BaseTool): """聊天流印象更新工具 - + 使用二步调用机制: 1. LLM决定是否调用工具并传入初步参数(stream_id会自动传入) 2. 工具内部调用LLM,结合现有数据和传入参数,决定最终更新内容 @@ -31,27 +30,52 @@ class ChatStreamImpressionTool(BaseTool): name = "update_chat_stream_impression" description = "当你通过观察聊天记录对当前聊天环境(群聊或私聊)产生了整体印象或认识时使用此工具,更新对这个聊天流的看法。包括:环境氛围、聊天风格、常见话题、你的兴趣程度。调用时机:当你发现这个聊天环境有明显的氛围特点(如很活跃、很专业、很闲聊)、群成员经常讨论某类话题、或者你对这个环境的感受发生变化时。注意:这是对整个聊天环境的印象,而非对单个用户。" parameters = [ - ("impression_description", ToolParamType.STRING, "你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群'、'大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)", False, None), - ("chat_style", ToolParamType.STRING, "这个聊天环境的风格特征,如'活跃热闹,互帮互助'、'严肃专业,深度讨论'、'轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)", False, None), - ("topic_keywords", ToolParamType.STRING, "这个聊天环境中经常出现的话题,如'编程,AI,技术分享'或'游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)", False, None), - ("interest_score", ToolParamType.FLOAT, "你对这个聊天环境的兴趣和喜欢程度,0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)", False, None), + ( + "impression_description", + ToolParamType.STRING, + "你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群'、'大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)", + False, + None, + ), + ( + "chat_style", + ToolParamType.STRING, + "这个聊天环境的风格特征,如'活跃热闹,互帮互助'、'严肃专业,深度讨论'、'轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)", + False, + None, + ), + ( + "topic_keywords", + ToolParamType.STRING, + "这个聊天环境中经常出现的话题,如'编程,AI,技术分享'或'游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)", + False, + None, + ), + ( + "interest_score", + ToolParamType.FLOAT, + "你对这个聊天环境的兴趣和喜欢程度,0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)", + False, + None, + ), ] available_for_llm = True history_ttl = 5 def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None): super().__init__(plugin_config, chat_stream) - + # 初始化用于二步调用的LLM try: self.impression_llm = LLMRequest( model_set=model_config.model_task_config.relationship_tracker, - request_type="chat_stream_impression_update" + request_type="chat_stream_impression_update", ) except AttributeError: # 降级处理 available_models = [ - attr for attr in dir(model_config.model_task_config) + attr + for attr in dir(model_config.model_task_config) if not attr.startswith("_") and attr != "model_dump" ] if available_models: @@ -59,7 +83,7 @@ class ChatStreamImpressionTool(BaseTool): logger.warning(f"relationship_tracker配置不存在,使用降级模型: {fallback_model}") self.impression_llm = LLMRequest( model_set=getattr(model_config.model_task_config, fallback_model), - request_type="chat_stream_impression_update" + request_type="chat_stream_impression_update", ) else: logger.error("无可用的模型配置") @@ -67,17 +91,17 @@ class ChatStreamImpressionTool(BaseTool): async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """执行聊天流印象更新 - + Args: function_args: 工具参数 - + Returns: dict: 执行结果 """ try: # 优先从 function_args 获取 stream_id stream_id = function_args.get("stream_id") - + # 如果没有,从 chat_stream 对象获取 if not stream_id and self.chat_stream: try: @@ -85,61 +109,49 @@ class ChatStreamImpressionTool(BaseTool): logger.debug(f"从 chat_stream 获取到 stream_id: {stream_id}") except AttributeError: logger.warning("chat_stream 对象没有 stream_id 属性") - + # 如果还是没有,返回错误 if not stream_id: logger.error("无法获取 stream_id:function_args 和 chat_stream 都没有提供") - return { - "type": "error", - "id": "chat_stream_impression", - "content": "错误:无法获取当前聊天流ID" - } - + return {"type": "error", "id": "chat_stream_impression", "content": "错误:无法获取当前聊天流ID"} + # 从LLM传入的参数 new_impression = function_args.get("impression_description", "") new_style = function_args.get("chat_style", "") new_topics = function_args.get("topic_keywords", "") new_score = function_args.get("interest_score") - + # 从数据库获取现有聊天流印象 existing_impression = await self._get_stream_impression(stream_id) - + # 如果LLM没有传入任何有效参数,返回提示 if not any([new_impression, new_style, new_topics, new_score is not None]): return { "type": "info", "id": stream_id, - "content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)" + "content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)", } - + # 调用LLM进行二步决策 if self.impression_llm is None: logger.error("LLM未正确初始化,无法执行二步调用") - return { - "type": "error", - "id": stream_id, - "content": "系统错误:LLM未正确初始化" - } - + return {"type": "error", "id": stream_id, "content": "系统错误:LLM未正确初始化"} + final_impression = await self._llm_decide_final_impression( stream_id=stream_id, existing_impression=existing_impression, new_impression=new_impression, new_style=new_style, new_topics=new_topics, - new_score=new_score + new_score=new_score, ) - + if not final_impression: - return { - "type": "error", - "id": stream_id, - "content": "LLM决策失败,无法更新聊天流印象" - } - + return {"type": "error", "id": stream_id, "content": "LLM决策失败,无法更新聊天流印象"} + # 更新数据库 await self._update_stream_impression_in_db(stream_id, final_impression) - + # 构建返回信息 updates = [] if final_impression.get("stream_impression_text"): @@ -150,30 +162,26 @@ class ChatStreamImpressionTool(BaseTool): updates.append(f"话题: {final_impression['stream_topic_keywords']}") if final_impression.get("stream_interest_score") is not None: updates.append(f"兴趣分: {final_impression['stream_interest_score']:.2f}") - + result_text = f"已更新聊天流 {stream_id} 的印象:\n" + "\n".join(updates) logger.info(f"聊天流印象更新成功: {stream_id}") - - return { - "type": "chat_stream_impression_update", - "id": stream_id, - "content": result_text - } - + + return {"type": "chat_stream_impression_update", "id": stream_id, "content": result_text} + except Exception as e: logger.error(f"聊天流印象更新失败: {e}", exc_info=True) return { "type": "error", "id": function_args.get("stream_id", "unknown"), - "content": f"聊天流印象更新失败: {str(e)}" + "content": f"聊天流印象更新失败: {e!s}", } async def _get_stream_impression(self, stream_id: str) -> dict[str, Any]: """从数据库获取聊天流现有印象 - + Args: stream_id: 聊天流ID - + Returns: dict: 聊天流印象数据 """ @@ -182,13 +190,15 @@ class ChatStreamImpressionTool(BaseTool): stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id) result = await session.execute(stmt) stream = result.scalar_one_or_none() - + if stream: return { "stream_impression_text": stream.stream_impression_text or "", "stream_chat_style": stream.stream_chat_style or "", "stream_topic_keywords": stream.stream_topic_keywords or "", - "stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score is not None else 0.5, + "stream_interest_score": float(stream.stream_interest_score) + if stream.stream_interest_score is not None + else 0.5, "group_name": stream.group_name or "私聊", } else: @@ -217,10 +227,10 @@ class ChatStreamImpressionTool(BaseTool): new_impression: str, new_style: str, new_topics: str, - new_score: float | None + new_score: float | None, ) -> dict[str, Any] | None: """使用LLM决策最终的聊天流印象内容 - + Args: stream_id: 聊天流ID existing_impression: 现有印象数据 @@ -228,33 +238,34 @@ class ChatStreamImpressionTool(BaseTool): new_style: LLM传入的新风格 new_topics: LLM传入的新话题 new_score: LLM传入的新分数 - + Returns: dict: 最终决定的印象数据,如果失败返回None """ try: # 获取bot人设 from src.individuality.individuality import Individuality + individuality = Individuality() bot_personality = await individuality.get_personality_block() - + prompt = f""" 你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality} 你正在更新对聊天流 {stream_id} 的整体印象。 【当前聊天流信息】 -- 聊天环境: {existing_impression.get('group_name', '未知')} -- 当前印象: {existing_impression.get('stream_impression_text', '暂无印象')} -- 聊天风格: {existing_impression.get('stream_chat_style', '未知')} -- 常见话题: {existing_impression.get('stream_topic_keywords', '未知')} -- 当前兴趣分: {existing_impression.get('stream_interest_score', 0.5):.2f} +- 聊天环境: {existing_impression.get("group_name", "未知")} +- 当前印象: {existing_impression.get("stream_impression_text", "暂无印象")} +- 聊天风格: {existing_impression.get("stream_chat_style", "未知")} +- 常见话题: {existing_impression.get("stream_topic_keywords", "未知")} +- 当前兴趣分: {existing_impression.get("stream_interest_score", 0.5):.2f} 【本次想要更新的内容】 -- 新的印象描述: {new_impression if new_impression else '不更新'} -- 新的聊天风格: {new_style if new_style else '不更新'} -- 新的话题关键词: {new_topics if new_topics else '不更新'} -- 新的兴趣分数: {new_score if new_score is not None else '不更新'} +- 新的印象描述: {new_impression if new_impression else "不更新"} +- 新的聊天风格: {new_style if new_style else "不更新"} +- 新的话题关键词: {new_topics if new_topics else "不更新"} +- 新的兴趣分数: {new_score if new_score is not None else "不更新"} 请综合考虑现有信息和新信息,决定最终的聊天流印象内容。注意: 1. 印象描述:如果提供了新印象,应该综合现有印象和新印象,形成对这个聊天环境的整体认知(100-200字) @@ -271,31 +282,47 @@ class ChatStreamImpressionTool(BaseTool): "reasoning": "你的决策理由" }} """ - + # 调用LLM llm_response, _ = await self.impression_llm.generate_response_async(prompt=prompt) - + if not llm_response: logger.warning("LLM未返回有效响应") return None - + # 清理并解析响应 cleaned_response = self._clean_llm_json_response(llm_response) response_data = json.loads(cleaned_response) - + # 提取最终决定的数据 final_impression = { - "stream_impression_text": response_data.get("stream_impression_text", existing_impression.get("stream_impression_text", "")), - "stream_chat_style": response_data.get("stream_chat_style", existing_impression.get("stream_chat_style", "")), - "stream_topic_keywords": response_data.get("stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")), - "stream_interest_score": max(0.0, min(1.0, float(response_data.get("stream_interest_score", existing_impression.get("stream_interest_score", 0.5))))), + "stream_impression_text": response_data.get( + "stream_impression_text", existing_impression.get("stream_impression_text", "") + ), + "stream_chat_style": response_data.get( + "stream_chat_style", existing_impression.get("stream_chat_style", "") + ), + "stream_topic_keywords": response_data.get( + "stream_topic_keywords", existing_impression.get("stream_topic_keywords", "") + ), + "stream_interest_score": max( + 0.0, + min( + 1.0, + float( + response_data.get( + "stream_interest_score", existing_impression.get("stream_interest_score", 0.5) + ) + ), + ), + ), } - + logger.info(f"LLM决策完成: {stream_id}") logger.debug(f"决策理由: {response_data.get('reasoning', '无')}") - + return final_impression - + except json.JSONDecodeError as e: logger.error(f"LLM响应JSON解析失败: {e}") logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}") @@ -306,7 +333,7 @@ class ChatStreamImpressionTool(BaseTool): async def _update_stream_impression_in_db(self, stream_id: str, impression: dict[str, Any]): """更新数据库中的聊天流印象 - + Args: stream_id: 聊天流ID impression: 印象数据 @@ -316,14 +343,14 @@ class ChatStreamImpressionTool(BaseTool): stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id) result = await session.execute(stmt) existing = result.scalar_one_or_none() - + if existing: # 更新现有记录 existing.stream_impression_text = impression.get("stream_impression_text", "") existing.stream_chat_style = impression.get("stream_chat_style", "") existing.stream_topic_keywords = impression.get("stream_topic_keywords", "") existing.stream_interest_score = impression.get("stream_interest_score", 0.5) - + await session.commit() logger.info(f"聊天流印象已更新到数据库: {stream_id}") else: @@ -331,40 +358,40 @@ class ChatStreamImpressionTool(BaseTool): logger.error(error_msg) # 注意:通常聊天流应该在消息处理时就已创建,这里不创建新记录 raise ValueError(error_msg) - + except Exception as e: logger.error(f"更新聊天流印象到数据库失败: {e}", exc_info=True) raise def _clean_llm_json_response(self, response: str) -> str: """清理LLM响应,移除可能的JSON格式标记 - + Args: response: LLM原始响应 - + Returns: str: 清理后的JSON字符串 """ try: import re - + cleaned = response.strip() - + # 移除 ```json 或 ``` 等标记 cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE) cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE) - + # 尝试找到JSON对象的开始和结束 json_start = cleaned.find("{") json_end = cleaned.rfind("}") - + if json_start != -1 and json_end != -1 and json_end > json_start: - cleaned = cleaned[json_start:json_end + 1] - + cleaned = cleaned[json_start : json_end + 1] + cleaned = cleaned.strip() - + return cleaned - + except Exception as e: logger.warning(f"清理LLM响应失败: {e}") return response diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py index e150e7e62..4359b3f66 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py @@ -231,11 +231,11 @@ class ChatterPlanExecutor: except Exception as e: error_message = str(e) logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}") - + # 将机器人回复添加到已读消息中 if success and action_info.action_message: await self._add_bot_reply_to_read_messages(action_info, plan, reply_content) - + execution_time = time.time() - start_time self.execution_stats["execution_times"].append(execution_time) @@ -381,13 +381,11 @@ class ChatterPlanExecutor: is_picid=False, is_command=False, is_notify=False, - # 用户信息 user_id=bot_user_id, user_nickname=bot_nickname, user_cardname=bot_nickname, user_platform="qq", - # 聊天上下文信息 chat_info_user_id=chat_stream.user_info.user_id if chat_stream.user_info else bot_user_id, chat_info_user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else bot_nickname, @@ -397,23 +395,22 @@ class ChatterPlanExecutor: chat_info_platform=chat_stream.platform, chat_info_create_time=chat_stream.create_time, chat_info_last_active_time=chat_stream.last_active_time, - # 群组信息(如果是群聊) chat_info_group_id=chat_stream.group_info.group_id if chat_stream.group_info else None, chat_info_group_name=chat_stream.group_info.group_name if chat_stream.group_info else None, - chat_info_group_platform=getattr(chat_stream.group_info, "platform", None) if chat_stream.group_info else None, - + chat_info_group_platform=getattr(chat_stream.group_info, "platform", None) + if chat_stream.group_info + else None, # 动作信息 actions=["bot_reply"], should_reply=False, - should_act=False + should_act=False, ) # 添加到chat_stream的已读消息中 chat_stream.context_manager.context.history_messages.append(bot_message) logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...") - except Exception as e: logger.error(f"添加机器人回复到已读消息时出错: {e}") logger.debug(f"plan.chat_id: {plan.chat_id}") diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py index 3013afaa4..afe2241a2 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py @@ -60,7 +60,7 @@ class ChatterPlanFilter: prompt, used_message_id_list = await self._build_prompt(plan) plan.llm_prompt = prompt if global_config.debug.show_prompt: - logger.info(f"规划器原始提示词:{prompt}") #叫你不要改你耳朵聋吗😡😡😡😡😡 + logger.info(f"规划器原始提示词:{prompt}") # 叫你不要改你耳朵聋吗😡😡😡😡😡 llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt) @@ -104,24 +104,26 @@ class ChatterPlanFilter: # 预解析 action_type 来进行判断 thinking = item.get("thinking", "未提供思考过程") actions_obj = item.get("actions", {}) - + # 记录决策历史 - if hasattr(global_config.chat, "enable_decision_history") and global_config.chat.enable_decision_history: + if ( + hasattr(global_config.chat, "enable_decision_history") + and global_config.chat.enable_decision_history + ): action_types_to_log = [] actions_to_process_for_log = [] if isinstance(actions_obj, dict): actions_to_process_for_log.append(actions_obj) elif isinstance(actions_obj, list): actions_to_process_for_log.extend(actions_obj) - + for single_action in actions_to_process_for_log: if isinstance(single_action, dict): action_types_to_log.append(single_action.get("action_type", "no_action")) - + if thinking != "未提供思考过程" and action_types_to_log: await self._add_decision_to_history(plan, thinking, ", ".join(action_types_to_log)) - # 处理actions字段可能是字典或列表的情况 if isinstance(actions_obj, dict): action_type = actions_obj.get("action_type", "no_action") @@ -579,15 +581,15 @@ class ChatterPlanFilter: ): reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}" action = "no_action" - #TODO:把逻辑迁移到DatabaseMessages(如果没人做下个星期我自己来) - #from src.common.data_models.database_data_model import DatabaseMessages + # TODO:把逻辑迁移到DatabaseMessages(如果没人做下个星期我自己来) + # from src.common.data_models.database_data_model import DatabaseMessages - #action_message_obj = None - #if target_message_obj: - #try: - #action_message_obj = DatabaseMessages(**target_message_obj) - #except Exception: - #logger.warning("无法将目标消息转换为DatabaseMessages对象") + # action_message_obj = None + # if target_message_obj: + # try: + # action_message_obj = DatabaseMessages(**target_message_obj) + # except Exception: + # logger.warning("无法将目标消息转换为DatabaseMessages对象") parsed_actions.append( ActionPlannerInfo( diff --git a/src/plugins/built_in/affinity_flow_chatter/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner.py index a8ae019a0..8fc75b4ef 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner.py @@ -17,7 +17,6 @@ from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPla if TYPE_CHECKING: from src.chat.planner_actions.action_manager import ChatterActionManager - from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.info_data_model import Plan from src.common.data_models.message_manager_data_model import StreamContext @@ -100,11 +99,11 @@ class ChatterActionPlanner: if context: context.chat_mode = ChatMode.FOCUS await self._sync_chat_mode_to_stream(context) - + # Normal模式下使用简化流程 if chat_mode == ChatMode.NORMAL: return await self._normal_mode_flow(context) - + # 在规划前,先进行动作修改 from src.chat.planner_actions.action_modifier import ActionModifier action_modifier = ActionModifier(self.action_manager, self.chat_id) @@ -184,12 +183,12 @@ class ChatterActionPlanner: for action in filtered_plan.decided_actions: if action.action_type in ["reply", "proactive_reply"] and action.action_message: # 提取目标消息ID - if hasattr(action.action_message, 'message_id'): + if hasattr(action.action_message, "message_id"): target_message_id = action.action_message.message_id elif isinstance(action.action_message, dict): - target_message_id = action.action_message.get('message_id') + target_message_id = action.action_message.get("message_id") break - + # 如果找到目标消息ID,检查是否已经在处理中 if target_message_id and context: if context.processing_message_id == target_message_id: @@ -215,7 +214,7 @@ class ChatterActionPlanner: # 6. 根据执行结果更新统计信息 self._update_stats_from_execution_result(execution_result) - + # 7. Focus模式下如果执行了reply动作,切换到Normal模式 if chat_mode == ChatMode.FOCUS and context: if filtered_plan.decided_actions: @@ -233,7 +232,7 @@ class ChatterActionPlanner: # 8. 清理处理标记 if context: context.processing_message_id = None - logger.debug(f"已清理处理标记,完成规划流程") + logger.debug("已清理处理标记,完成规划流程") # 9. 返回结果 return self._build_return_result(filtered_plan) @@ -262,7 +261,7 @@ class ChatterActionPlanner: return await self._enhanced_plan_flow(context) try: unread_messages = context.get_unread_messages() if context else [] - + if not unread_messages: logger.debug("Normal模式: 没有未读消息") from src.common.data_models.info_data_model import ActionPlannerInfo @@ -273,11 +272,11 @@ class ChatterActionPlanner: action_message=None, ) return [asdict(no_action)], None - + # 检查是否有消息达到reply阈值 should_reply = False target_message = None - + for message in unread_messages: message_should_reply = getattr(message, "should_reply", False) if message_should_reply: @@ -285,7 +284,7 @@ class ChatterActionPlanner: target_message = message logger.info(f"Normal模式: 消息 {message.message_id} 达到reply阈值") break - + if should_reply and target_message: # 检查是否正在处理相同的目标消息,防止重复回复 target_message_id = target_message.message_id @@ -302,26 +301,26 @@ class ChatterActionPlanner: action_message=None, ) return [asdict(no_action)], None - + # 记录当前正在处理的消息ID if context: context.processing_message_id = target_message_id logger.debug(f"Normal模式: 开始处理目标消息: {target_message_id}") - + # 达到reply阈值,直接进入回复流程 from src.common.data_models.info_data_model import ActionPlannerInfo, Plan from src.plugin_system.base.component_types import ChatType - + # 构建目标消息字典 - 使用 flatten() 方法获取扁平化的字典 target_message_dict = target_message.flatten() - + reply_action = ActionPlannerInfo( action_type="reply", reasoning="Normal模式: 兴趣度达到阈值,直接回复", action_data={"target_message_id": target_message.message_id}, action_message=target_message, ) - + # Normal模式下直接构建最小化的Plan,跳过generator和action_modifier # 这样可以显著降低延迟 minimal_plan = Plan( @@ -330,25 +329,25 @@ class ChatterActionPlanner: mode=ChatMode.NORMAL, decided_actions=[reply_action], ) - + # 执行reply动作 execution_result = await self.executor.execute(minimal_plan) self._update_stats_from_execution_result(execution_result) - + logger.info("Normal模式: 执行reply动作完成") - + # 清理处理标记 if context: context.processing_message_id = None - logger.debug(f"Normal模式: 已清理处理标记") - + logger.debug("Normal模式: 已清理处理标记") + # 无论是否回复,都进行退出normal模式的判定 await self._check_exit_normal_mode(context) - + return [asdict(reply_action)], target_message_dict else: # 未达到reply阈值 - logger.debug(f"Normal模式: 未达到reply阈值") + logger.debug("Normal模式: 未达到reply阈值") from src.common.data_models.info_data_model import ActionPlannerInfo no_action = ActionPlannerInfo( action_type="no_action", @@ -356,12 +355,12 @@ class ChatterActionPlanner: action_data={}, action_message=None, ) - + # 无论是否回复,都进行退出normal模式的判定 await self._check_exit_normal_mode(context) - + return [asdict(no_action)], None - + except Exception as e: logger.error(f"Normal模式流程出错: {e}") self.planner_stats["failed_plans"] += 1 @@ -378,16 +377,16 @@ class ChatterActionPlanner: """ if not context: return - + try: from src.chat.message_receive.chat_stream import get_chat_manager - + chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(self.chat_id) if chat_manager else None - + if not chat_stream: return - + focus_energy = chat_stream.focus_energy # focus_energy越低,退出normal模式的概率越高 # 使用反比例函数: 退出概率 = 1 - focus_energy @@ -395,7 +394,7 @@ class ChatterActionPlanner: # 当focus_energy = 0.5时,退出概率 = 50% # 当focus_energy = 0.9时,退出概率 = 10% exit_probability = 1.0 - focus_energy - + import random if random.random() < exit_probability: logger.info(f"Normal模式: focus_energy={focus_energy:.3f}, 退出概率={exit_probability:.3f}, 切换回focus模式") @@ -404,7 +403,7 @@ class ChatterActionPlanner: await self._sync_chat_mode_to_stream(context) else: logger.debug(f"Normal模式: focus_energy={focus_energy:.3f}, 退出概率={exit_probability:.3f}, 保持normal模式") - + except Exception as e: logger.warning(f"检查退出Normal模式失败: {e}") @@ -412,7 +411,7 @@ class ChatterActionPlanner: """同步chat_mode到ChatStream""" try: from src.chat.message_receive.chat_stream import get_chat_manager - + chat_manager = get_chat_manager() if chat_manager: chat_stream = await chat_manager.get_stream(context.stream_id) diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py index 7b8ffdad1..b7f45b749 100644 --- a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py +++ b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py @@ -15,57 +15,57 @@ logger = get_logger("proactive_thinking_event") class ProactiveThinkingReplyHandler(BaseEventHandler): """Reply事件处理器 - + 当bot回复某个聊天流后: 1. 如果该聊天流的主动思考被暂停(因为抛出了话题),则恢复它 2. 无论是否暂停,都重置定时任务,重新开始计时 """ - + handler_name: str = "proactive_thinking_reply_handler" handler_description: str = "监听reply事件,重置主动思考定时任务" init_subscribe: list[EventType | str] = [EventType.AFTER_SEND] - + async def execute(self, kwargs: dict | None) -> HandlerResult: """处理reply事件 - + Args: kwargs: 事件参数,应包含 stream_id - + Returns: HandlerResult: 处理结果 """ logger.debug("[主动思考事件] ProactiveThinkingReplyHandler 开始执行") logger.debug(f"[主动思考事件] 接收到的参数: {kwargs}") - + if not kwargs: logger.debug("[主动思考事件] kwargs 为空,跳过处理") return HandlerResult(success=True, continue_process=True, message=None) - + stream_id = kwargs.get("stream_id") if not stream_id: - logger.debug(f"[主动思考事件] Reply事件缺少stream_id参数") + logger.debug("[主动思考事件] Reply事件缺少stream_id参数") return HandlerResult(success=True, continue_process=True, message=None) - + logger.debug(f"[主动思考事件] 收到 AFTER_SEND 事件,stream_id={stream_id}") - + try: from src.config.config import global_config - + # 检查是否启用reply重置 if not global_config.proactive_thinking.reply_reset_enabled: - logger.debug(f"[主动思考事件] reply_reset_enabled 为 False,跳过重置") + logger.debug("[主动思考事件] reply_reset_enabled 为 False,跳过重置") return HandlerResult(success=True, continue_process=True, message=None) - + # 检查是否被暂停 was_paused = await proactive_thinking_scheduler.is_paused(stream_id) logger.debug(f"[主动思考事件] 聊天流 {stream_id} 暂停状态: {was_paused}") - + if was_paused: logger.debug(f"[主动思考事件] 检测到reply事件,聊天流 {stream_id} 之前因抛出话题而暂停,现在恢复") - + # 重置定时任务(这会自动清除暂停标记并创建新任务) success = await proactive_thinking_scheduler.schedule_proactive_thinking(stream_id) - + if success: if was_paused: logger.info(f"✅ 聊天流 {stream_id} 主动思考已恢复并重置") @@ -73,82 +73,82 @@ class ProactiveThinkingReplyHandler(BaseEventHandler): logger.debug(f"✅ 聊天流 {stream_id} 主动思考任务已重置") else: logger.warning(f"❌ 重置聊天流 {stream_id} 主动思考任务失败") - + except Exception as e: logger.error(f"❌ 处理reply事件时出错: {e}", exc_info=True) - + # 总是继续处理其他handler return HandlerResult(success=True, continue_process=True, message=None) class ProactiveThinkingMessageHandler(BaseEventHandler): """消息事件处理器 - + 当收到消息时,如果该聊天流还没有主动思考任务,则创建一个 这样可以确保新的聊天流也能获得主动思考功能 """ - + handler_name: str = "proactive_thinking_message_handler" handler_description: str = "监听消息事件,为新聊天流创建主动思考任务" init_subscribe: list[EventType | str] = [EventType.ON_MESSAGE] - + async def execute(self, kwargs: dict | None) -> HandlerResult: """处理消息事件 - + Args: kwargs: 事件参数,格式为 {"message": DatabaseMessages} - + Returns: HandlerResult: 处理结果 """ if not kwargs: return HandlerResult(success=True, continue_process=True, message=None) - + # 从 kwargs 中获取 DatabaseMessages 对象 message = kwargs.get("message") if not message or not hasattr(message, "chat_stream"): return HandlerResult(success=True, continue_process=True, message=None) - + # 从 chat_stream 获取 stream_id chat_stream = message.chat_stream if not chat_stream or not hasattr(chat_stream, "stream_id"): return HandlerResult(success=True, continue_process=True, message=None) - + stream_id = chat_stream.stream_id - + try: from src.config.config import global_config - + # 检查是否启用主动思考 if not global_config.proactive_thinking.enable: return HandlerResult(success=True, continue_process=True, message=None) - + # 检查该聊天流是否已经有任务 task_info = await proactive_thinking_scheduler.get_task_info(stream_id) if task_info: # 已经有任务,不需要创建 return HandlerResult(success=True, continue_process=True, message=None) - + # 从 message_info 获取平台和聊天ID信息 message_info = message.message_info platform = message_info.platform is_group = message_info.group_info is not None chat_id = message_info.group_info.group_id if is_group else message_info.user_info.user_id # type: ignore - + # 构造配置字符串 stream_config = f"{platform}:{chat_id}:{'group' if is_group else 'private'}" - + # 检查黑白名单 if not proactive_thinking_scheduler._check_whitelist_blacklist(stream_config): return HandlerResult(success=True, continue_process=True, message=None) - + # 创建主动思考任务 success = await proactive_thinking_scheduler.schedule_proactive_thinking(stream_id) if success: logger.info(f"为新聊天流 {stream_id} 创建了主动思考任务") - + except Exception as e: logger.error(f"处理消息事件时出错: {e}", exc_info=True) - + # 总是继续处理其他handler return HandlerResult(success=True, continue_process=True, message=None) diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py index 80de51f5f..26425d989 100644 --- a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py @@ -5,11 +5,10 @@ import json from datetime import datetime -from typing import Any, Literal, Optional +from typing import Any, Literal from sqlalchemy import select -from src.chat.express.expression_learner import expression_learner_manager from src.chat.express.expression_selector import expression_selector from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import ChatStreams @@ -17,42 +16,40 @@ from src.common.logger import get_logger from src.config.config import global_config, model_config from src.individuality.individuality import Individuality from src.llm_models.utils_model import LLMRequest -from src.plugin_system.apis import chat_api, message_api, send_api +from src.plugin_system.apis import message_api, send_api logger = get_logger("proactive_thinking_executor") class ProactiveThinkingPlanner: """主动思考规划器 - + 负责: 1. 搜集信息(聊天流印象、话题关键词、历史聊天记录) 2. 调用LLM决策:什么都不做/简单冒泡/抛出话题 3. 根据决策生成回复内容 """ - + def __init__(self): """初始化规划器""" try: self.decision_llm = LLMRequest( - model_set=model_config.model_task_config.utils, - request_type="proactive_thinking_decision" + model_set=model_config.model_task_config.utils, request_type="proactive_thinking_decision" ) self.reply_llm = LLMRequest( - model_set=model_config.model_task_config.replyer, - request_type="proactive_thinking_reply" + model_set=model_config.model_task_config.replyer, request_type="proactive_thinking_reply" ) except Exception as e: logger.error(f"初始化LLM失败: {e}") self.decision_llm = None self.reply_llm = None - - async def gather_context(self, stream_id: str) -> Optional[dict[str, Any]]: + + async def gather_context(self, stream_id: str) -> dict[str, Any] | None: """搜集聊天流的上下文信息 - + Args: stream_id: 聊天流ID - + Returns: dict: 包含所有上下文信息的字典,失败返回None """ @@ -62,27 +59,25 @@ class ProactiveThinkingPlanner: if not stream_data: logger.warning(f"无法获取聊天流 {stream_id} 的印象数据") return None - + # 2. 获取最近的聊天记录 recent_messages = await message_api.get_recent_messages( - chat_id=stream_id, - limit=20, - limit_mode="latest", - hours=24 + chat_id=stream_id, limit=20, limit_mode="latest", hours=24 ) - + recent_chat_history = "" if recent_messages: recent_chat_history = await message_api.build_readable_messages_to_str(recent_messages) - + # 3. 获取bot人设 individuality = Individuality() bot_personality = await individuality.get_personality_block() - + # 4. 获取当前心情 current_mood = "感觉很平静" # 默认心情 try: from src.mood.mood_manager import mood_manager + mood_obj = mood_manager.get_mood_by_chat_id(stream_id) if mood_obj: await mood_obj._initialize() # 确保已初始化 @@ -90,19 +85,20 @@ class ProactiveThinkingPlanner: logger.debug(f"获取到聊天流 {stream_id} 的心情: {current_mood}") except Exception as e: logger.warning(f"获取心情失败,使用默认值: {e}") - + # 5. 获取上次决策 last_decision = None try: from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_scheduler import ( proactive_thinking_scheduler, ) + last_decision = proactive_thinking_scheduler.get_last_decision(stream_id) if last_decision: logger.debug(f"获取到聊天流 {stream_id} 的上次决策: {last_decision.get('action')}") except Exception as e: logger.warning(f"获取上次决策失败: {e}") - + # 6. 构建上下文 context = { "stream_id": stream_id, @@ -117,45 +113,45 @@ class ProactiveThinkingPlanner: "current_mood": current_mood, "last_decision": last_decision, } - + logger.debug(f"成功搜集聊天流 {stream_id} 的上下文信息") return context - + except Exception as e: logger.error(f"搜集上下文信息失败: {e}", exc_info=True) return None - - async def _get_stream_impression(self, stream_id: str) -> Optional[dict[str, Any]]: + + async def _get_stream_impression(self, stream_id: str) -> dict[str, Any] | None: """从数据库获取聊天流印象数据""" try: async with get_db_session() as session: stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id) result = await session.execute(stmt) stream = result.scalar_one_or_none() - + if not stream: return None - + return { "stream_name": stream.group_name or "私聊", "stream_impression_text": stream.stream_impression_text or "", "stream_chat_style": stream.stream_chat_style or "", "stream_topic_keywords": stream.stream_topic_keywords or "", - "stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score else 0.5, + "stream_interest_score": float(stream.stream_interest_score) + if stream.stream_interest_score + else 0.5, } - + except Exception as e: logger.error(f"获取聊天流印象失败: {e}") return None - - async def make_decision( - self, context: dict[str, Any] - ) -> Optional[dict[str, Any]]: + + async def make_decision(self, context: dict[str, Any]) -> dict[str, Any] | None: """使用LLM进行决策 - + Args: context: 上下文信息 - + Returns: dict: 决策结果,包含: - action: "do_nothing" | "simple_bubble" | "throw_topic" @@ -165,30 +161,28 @@ class ProactiveThinkingPlanner: if not self.decision_llm: logger.error("决策LLM未初始化") return None - + response = None try: decision_prompt = self._build_decision_prompt(context) - + if global_config.debug.show_prompt: logger.info(f"决策提示词:\n{decision_prompt}") - + response, _ = await self.decision_llm.generate_response_async(prompt=decision_prompt) - + if not response: logger.warning("LLM未返回有效响应") return None - + # 清理并解析JSON响应 cleaned_response = self._clean_json_response(response) decision = json.loads(cleaned_response) - - logger.info( - f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}" - ) - + + logger.info(f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}") + return decision - + except json.JSONDecodeError as e: logger.error(f"解析决策JSON失败: {e}") if response: @@ -197,18 +191,18 @@ class ProactiveThinkingPlanner: except Exception as e: logger.error(f"决策过程失败: {e}", exc_info=True) return None - + def _build_decision_prompt(self, context: dict[str, Any]) -> str: """构建决策提示词""" # 构建上次决策信息 last_decision_text = "" - if context.get('last_decision'): - last_dec = context['last_decision'] - last_action = last_dec.get('action', '未知') - last_reasoning = last_dec.get('reasoning', '无') - last_topic = last_dec.get('topic') - last_time = last_dec.get('timestamp', '未知') - + if context.get("last_decision"): + last_dec = context["last_decision"] + last_action = last_dec.get("action", "未知") + last_reasoning = last_dec.get("reasoning", "无") + last_topic = last_dec.get("topic") + last_time = last_dec.get("timestamp", "未知") + last_decision_text = f""" 【上次主动思考的决策】 - 时间: {last_time} @@ -216,24 +210,24 @@ class ProactiveThinkingPlanner: - 理由: {last_reasoning}""" if last_topic: last_decision_text += f"\n- 话题: {last_topic}" - - return f"""你是一个有着独特个性的AI助手。你的人设是: -{context['bot_personality']} -现在是 {context['current_time']},你正在考虑是否要主动在 "{context['stream_name']}" 中说些什么。 + return f"""你是一个有着独特个性的AI助手。你的人设是: +{context["bot_personality"]} + +现在是 {context["current_time"]},你正在考虑是否要主动在 "{context["stream_name"]}" 中说些什么。 【你当前的心情】 -{context.get('current_mood', '感觉很平静')} +{context.get("current_mood", "感觉很平静")} 【聊天环境信息】 -- 整体印象: {context['stream_impression']} -- 聊天风格: {context['chat_style']} -- 常见话题: {context['topic_keywords'] or '暂无'} -- 你的兴趣程度: {context['interest_score']:.2f}/1.0 +- 整体印象: {context["stream_impression"]} +- 聊天风格: {context["chat_style"]} +- 常见话题: {context["topic_keywords"] or "暂无"} +- 你的兴趣程度: {context["interest_score"]:.2f}/1.0 {last_decision_text} 【最近的聊天记录】 -{context['recent_chat_history']} +{context["recent_chat_history"]} 请根据以上信息(包括你的心情和上次决策),决定你现在应该做什么: @@ -267,53 +261,50 @@ class ProactiveThinkingPlanner: 3. 只有在真的有话题想聊时才选择 throw_topic 4. 符合你的人设,不要太过热情或冷淡 """ - + async def generate_reply( - self, - context: dict[str, Any], - action: Literal["simple_bubble", "throw_topic"], - topic: Optional[str] = None - ) -> Optional[str]: + self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None = None + ) -> str | None: """生成回复内容 - + Args: context: 上下文信息 action: 动作类型 topic: (可选) 话题内容,当action=throw_topic时必须提供 - + Returns: str: 生成的回复文本,失败返回None """ if not self.reply_llm: logger.error("回复LLM未初始化") return None - + try: reply_prompt = await self._build_reply_prompt(context, action, topic) - + if global_config.debug.show_prompt: logger.info(f"回复提示词:\n{reply_prompt}") - + response, _ = await self.reply_llm.generate_response_async(prompt=reply_prompt) - + if not response: logger.warning("LLM未返回有效回复") return None - + logger.info(f"生成回复成功: {response[:50]}...") return response.strip() - + except Exception as e: logger.error(f"生成回复失败: {e}", exc_info=True) return None - + async def _get_expression_habits(self, stream_id: str, chat_history: str) -> str: """获取表达方式参考 - + Args: stream_id: 聊天流ID chat_history: 聊天历史 - + Returns: str: 格式化的表达方式参考文本 """ @@ -324,15 +315,15 @@ class ProactiveThinkingPlanner: chat_history=chat_history, target_message=None, # 主动思考没有target message max_num=6, # 主动思考时使用较少的表达方式 - min_num=2 + min_num=2, ) - + if not selected_expressions: return "" - + style_habits = [] grammar_habits = [] - + for expr in selected_expressions: if isinstance(expr, dict) and "situation" in expr and "style" in expr: expr_type = expr.get("type", "style") @@ -340,7 +331,7 @@ class ProactiveThinkingPlanner: grammar_habits.append(f"当{expr['situation']}时,使用 {expr['style']}") else: style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}") - + expression_block = "" if style_habits or grammar_habits: expression_block = "\n【表达方式参考】\n" @@ -349,41 +340,37 @@ class ProactiveThinkingPlanner: if grammar_habits: expression_block += "句法特点:\n" + "\n".join(grammar_habits) + "\n" expression_block += "注意:仅在情景合适时自然地使用这些表达,不要生硬套用。\n" - + return expression_block - + except Exception as e: logger.warning(f"获取表达方式失败: {e}") return "" - + async def _build_reply_prompt( - self, - context: dict[str, Any], - action: Literal["simple_bubble", "throw_topic"], - topic: Optional[str] + self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None ) -> str: """构建回复提示词""" # 获取表达方式参考 expression_habits = await self._get_expression_habits( - stream_id=context.get('stream_id', ''), - chat_history=context.get('recent_chat_history', '') + stream_id=context.get("stream_id", ""), chat_history=context.get("recent_chat_history", "") ) - + if action == "simple_bubble": return f"""你是一个有着独特个性的AI助手。你的人设是: -{context['bot_personality']} +{context["bot_personality"]} -现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中简单冒个泡。 +现在是 {context["current_time"]},你决定在 "{context["stream_name"]}" 中简单冒个泡。 【你当前的心情】 -{context.get('current_mood', '感觉很平静')} +{context.get("current_mood", "感觉很平静")} 【聊天环境】 -- 整体印象: {context['stream_impression']} -- 聊天风格: {context['chat_style']} +- 整体印象: {context["stream_impression"]} +- 聊天风格: {context["chat_style"]} 【最近的聊天记录】 -{context['recent_chat_history']} +{context["recent_chat_history"]} {expression_habits} 请生成一条简短的消息,用于水群。要求: 1. 非常简短(5-15字) @@ -394,23 +381,23 @@ class ProactiveThinkingPlanner: 6. 如果有表达方式参考,在合适时自然使用 7. 合理参考历史记录 直接输出消息内容,不要解释:""" - + else: # throw_topic return f"""你是一个有着独特个性的AI助手。你的人设是: -{context['bot_personality']} +{context["bot_personality"]} -现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中抛出一个话题。 +现在是 {context["current_time"]},你决定在 "{context["stream_name"]}" 中抛出一个话题。 【你当前的心情】 -{context.get('current_mood', '感觉很平静')} +{context.get("current_mood", "感觉很平静")} 【聊天环境】 -- 整体印象: {context['stream_impression']} -- 聊天风格: {context['chat_style']} -- 常见话题: {context['topic_keywords'] or '暂无'} +- 整体印象: {context["stream_impression"]} +- 聊天风格: {context["chat_style"]} +- 常见话题: {context["topic_keywords"] or "暂无"} 【最近的聊天记录】 -{context['recent_chat_history']} +{context["recent_chat_history"]} 【你想抛出的话题】 {topic} @@ -425,21 +412,21 @@ class ProactiveThinkingPlanner: 7. 如果有表达方式参考,在合适时自然使用 直接输出消息内容,不要解释:""" - + def _clean_json_response(self, response: str) -> str: """清理LLM响应中的JSON格式标记""" import re - + cleaned = response.strip() cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE) cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE) - + json_start = cleaned.find("{") json_end = cleaned.rfind("}") - + if json_start != -1 and json_end != -1 and json_end > json_start: - cleaned = cleaned[json_start:json_end + 1] - + cleaned = cleaned[json_start : json_end + 1] + return cleaned.strip() @@ -452,7 +439,7 @@ _statistics: dict[str, dict[str, Any]] = {} def _update_statistics(stream_id: str, action: str): """更新统计数据 - + Args: stream_id: 聊天流ID action: 执行的动作 @@ -465,18 +452,18 @@ def _update_statistics(stream_id: str, action: str): "throw_topic_count": 0, "last_execution_time": None, } - + _statistics[stream_id]["total_executions"] += 1 _statistics[stream_id][f"{action}_count"] += 1 _statistics[stream_id]["last_execution_time"] = datetime.now().isoformat() -def get_statistics(stream_id: Optional[str] = None) -> dict[str, Any]: +def get_statistics(stream_id: str | None = None) -> dict[str, Any]: """获取统计数据 - + Args: stream_id: 聊天流ID,None表示获取所有统计 - + Returns: 统计数据字典 """ @@ -487,7 +474,7 @@ def get_statistics(stream_id: Optional[str] = None) -> dict[str, Any]: async def execute_proactive_thinking(stream_id: str): """执行主动思考(被调度器调用的回调函数) - + Args: stream_id: 聊天流ID """ @@ -495,125 +482,125 @@ async def execute_proactive_thinking(stream_id: str): from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_scheduler import ( proactive_thinking_scheduler, ) - + config = global_config.proactive_thinking - + logger.debug(f"🤔 开始主动思考 {stream_id}") - + try: # 0. 前置检查 if proactive_thinking_scheduler._is_in_quiet_hours(): - logger.debug(f"安静时段,跳过") + logger.debug("安静时段,跳过") return - + if not proactive_thinking_scheduler._check_daily_limit(stream_id): - logger.debug(f"今日发言达上限") + logger.debug("今日发言达上限") return - + # 1. 搜集信息 - logger.debug(f"步骤1: 搜集上下文") + logger.debug("步骤1: 搜集上下文") context = await _planner.gather_context(stream_id) if not context: - logger.warning(f"无法搜集上下文,跳过") + logger.warning("无法搜集上下文,跳过") return # 检查兴趣分数阈值 - interest_score = context.get('interest_score', 0.5) + interest_score = context.get("interest_score", 0.5) if not proactive_thinking_scheduler._check_interest_score_threshold(interest_score): - logger.debug(f"兴趣分数不在阈值范围内") + logger.debug("兴趣分数不在阈值范围内") return - + # 2. 进行决策 - logger.debug(f"步骤2: LLM决策") + logger.debug("步骤2: LLM决策") decision = await _planner.make_decision(context) if not decision: - logger.warning(f"决策失败,跳过") + logger.warning("决策失败,跳过") return - + action = decision.get("action", "do_nothing") reasoning = decision.get("reasoning", "无") - + # 记录决策日志 if config.log_decisions: logger.debug(f"决策: action={action}, reasoning={reasoning}") - + # 3. 根据决策执行相应动作 if action == "do_nothing": logger.debug(f"决策:什么都不做。理由:{reasoning}") proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, None) return - + elif action == "simple_bubble": logger.info(f"💬 决策:冒个泡。理由:{reasoning}") - + proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, None) - + # 生成简单的消息 - logger.debug(f"步骤3: 生成冒泡回复") + logger.debug("步骤3: 生成冒泡回复") reply = await _planner.generate_reply(context, "simple_bubble") if reply: await send_api.text_to_stream( stream_id=stream_id, text=reply, ) - logger.info(f"✅ 已发送冒泡消息") - + logger.info("✅ 已发送冒泡消息") + # 增加每日计数 proactive_thinking_scheduler._increment_daily_count(stream_id) - + # 更新统计 if config.enable_statistics: _update_statistics(stream_id, action) - + # 冒泡后暂停主动思考,等待用户回复 # 使用与 topic_throw 相同的冷却时间配置 if config.topic_throw_cooldown > 0: - logger.info(f"[主动思考] 步骤5:暂停任务") + logger.info("[主动思考] 步骤5:暂停任务") await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已冒泡") logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复") - logger.info(f"[主动思考] simple_bubble 执行完成") - + logger.info("[主动思考] simple_bubble 执行完成") + elif action == "throw_topic": topic = decision.get("topic", "") logger.info(f"[主动思考] 决策:抛出话题。理由:{reasoning},话题:{topic}") - + # 记录决策 proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, topic) - + if not topic: logger.warning("[主动思考] 选择了抛出话题但未提供话题内容,降级为冒泡") - logger.info(f"[主动思考] 步骤3:生成降级冒泡回复") + logger.info("[主动思考] 步骤3:生成降级冒泡回复") reply = await _planner.generate_reply(context, "simple_bubble") else: # 生成基于话题的消息 - logger.info(f"[主动思考] 步骤3:生成话题回复") + logger.info("[主动思考] 步骤3:生成话题回复") reply = await _planner.generate_reply(context, "throw_topic", topic) - + if reply: - logger.info(f"[主动思考] 步骤4:发送消息") + logger.info("[主动思考] 步骤4:发送消息") await send_api.text_to_stream( stream_id=stream_id, text=reply, ) logger.info(f"[主动思考] 已发送话题消息到 {stream_id}") - + # 增加每日计数 proactive_thinking_scheduler._increment_daily_count(stream_id) - + # 更新统计 if config.enable_statistics: _update_statistics(stream_id, action) - + # 抛出话题后暂停主动思考(如果配置了冷却时间) if config.topic_throw_cooldown > 0: - logger.info(f"[主动思考] 步骤5:暂停任务") + logger.info("[主动思考] 步骤5:暂停任务") await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已抛出话题") logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复") - logger.info(f"[主动思考] throw_topic 执行完成") + logger.info("[主动思考] throw_topic 执行完成") logger.info(f"[主动思考] 聊天流 {stream_id} 的主动思考执行完成") - + except Exception as e: logger.error(f"[主动思考] 执行主动思考失败: {e}", exc_info=True) diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_scheduler.py b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_scheduler.py index 33e90654d..47ed467cd 100644 --- a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_scheduler.py +++ b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_scheduler.py @@ -6,20 +6,17 @@ import asyncio from datetime import datetime, timedelta -from typing import Any, Optional +from typing import Any -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import ChatStreams from src.common.logger import get_logger from src.schedule.unified_scheduler import TriggerType, unified_scheduler -from sqlalchemy import select logger = get_logger("proactive_thinking_scheduler") class ProactiveThinkingScheduler: """主动思考调度器 - + 负责为每个聊天流创建和管理主动思考任务。 特点: 1. 根据聊天流的兴趣分数动态计算触发间隔 @@ -32,27 +29,28 @@ class ProactiveThinkingScheduler: self._stream_schedules: dict[str, str] = {} # stream_id -> schedule_id self._paused_streams: set[str] = set() # 因抛出话题而暂停的聊天流 self._lock = asyncio.Lock() - + # 统计数据 self._statistics: dict[str, dict[str, Any]] = {} # stream_id -> 统计信息 self._daily_counts: dict[str, dict[str, int]] = {} # stream_id -> {date: count} - + # 历史决策记录:stream_id -> 上次决策信息 self._last_decisions: dict[str, dict[str, Any]] = {} - + # 从全局配置加载(延迟导入避免循环依赖) from src.config.config import global_config + self.config = global_config.proactive_thinking - + def _calculate_interval(self, focus_energy: float) -> int: """根据 focus_energy 计算触发间隔 - + Args: focus_energy: 聊天流的 focus_energy 值 (0.0-1.0) - + Returns: int: 触发间隔(秒) - + 公式: - focus_energy 越高,间隔越短(更频繁思考) - interval = base_interval * (factor - focus_energy) @@ -63,26 +61,26 @@ class ProactiveThinkingScheduler: # 如果不使用 focus_energy,直接返回基础间隔 if not self.config.use_interest_score: return self.config.base_interval - + # 确保值在有效范围内 focus_energy = max(0.0, min(1.0, focus_energy)) - + # 计算间隔:focus_energy 越高,系数越小,间隔越短 factor = self.config.interest_score_factor - focus_energy interval = int(self.config.base_interval * factor) - + # 限制在最小和最大间隔之间 interval = max(self.config.min_interval, min(self.config.max_interval, interval)) - - logger.debug(f"Focus Energy {focus_energy:.3f} -> 触发间隔 {interval}秒 ({interval/60:.1f}分钟)") + + logger.debug(f"Focus Energy {focus_energy:.3f} -> 触发间隔 {interval}秒 ({interval / 60:.1f}分钟)") return interval - + def _check_whitelist_blacklist(self, stream_config: str) -> bool: """检查聊天流是否通过黑白名单验证 - + Args: stream_config: 聊天流配置字符串,格式: "platform:id:type" - + Returns: bool: True表示允许主动思考,False表示拒绝 """ @@ -91,148 +89,148 @@ class ProactiveThinkingScheduler: if len(parts) != 3: logger.warning(f"无效的stream_config格式: {stream_config}") return False - + is_private = parts[2] == "private" - + # 检查基础开关 if is_private and not self.config.enable_in_private: return False if not is_private and not self.config.enable_in_group: return False - + # 黑名单检查(优先级高) if self.config.blacklist_mode: blacklist = self.config.blacklist_private if is_private else self.config.blacklist_group if stream_config in blacklist: logger.debug(f"聊天流 {stream_config} 在黑名单中,拒绝主动思考") return False - + # 白名单检查 if self.config.whitelist_mode: whitelist = self.config.whitelist_private if is_private else self.config.whitelist_group if stream_config not in whitelist: logger.debug(f"聊天流 {stream_config} 不在白名单中,拒绝主动思考") return False - + return True - + def _check_interest_score_threshold(self, interest_score: float) -> bool: """检查兴趣分数是否在阈值范围内 - + Args: interest_score: 兴趣分数 - + Returns: bool: True表示在范围内 """ if interest_score < self.config.min_interest_score: logger.debug(f"兴趣分数 {interest_score:.2f} 低于最低阈值 {self.config.min_interest_score}") return False - + if interest_score > self.config.max_interest_score: logger.debug(f"兴趣分数 {interest_score:.2f} 高于最高阈值 {self.config.max_interest_score}") return False - + return True - + def _check_daily_limit(self, stream_id: str) -> bool: """检查今日主动发言次数是否超限 - + Args: stream_id: 聊天流ID - + Returns: bool: True表示未超限 """ if self.config.max_daily_proactive == 0: return True # 不限制 - + today = datetime.now().strftime("%Y-%m-%d") - + if stream_id not in self._daily_counts: self._daily_counts[stream_id] = {} - + # 清理过期日期的数据 for date in list(self._daily_counts[stream_id].keys()): if date != today: del self._daily_counts[stream_id][date] - + count = self._daily_counts[stream_id].get(today, 0) - + if count >= self.config.max_daily_proactive: logger.debug(f"聊天流 {stream_id} 今日主动发言次数已达上限 ({count}/{self.config.max_daily_proactive})") return False - + return True - + def _increment_daily_count(self, stream_id: str): """增加今日主动发言计数""" today = datetime.now().strftime("%Y-%m-%d") - + if stream_id not in self._daily_counts: self._daily_counts[stream_id] = {} - + self._daily_counts[stream_id][today] = self._daily_counts[stream_id].get(today, 0) + 1 - + def _is_in_quiet_hours(self) -> bool: """检查当前是否在安静时段 - + Returns: bool: True表示在安静时段 """ if not self.config.enable_time_strategy: return False - + now = datetime.now() current_time = now.strftime("%H:%M") - + start = self.config.quiet_hours_start end = self.config.quiet_hours_end - + # 处理跨日的情况(如23:00-07:00) if start <= end: return start <= current_time <= end else: return current_time >= start or current_time <= end - + async def _get_stream_focus_energy(self, stream_id: str) -> float: """获取聊天流的 focus_energy - + Args: stream_id: 聊天流ID - + Returns: float: focus_energy 值,默认0.5 """ try: # 从聊天管理器获取聊天流 from src.chat.message_receive.chat_stream import get_chat_manager - - logger.debug(f"[调度器] 获取聊天管理器") + + logger.debug("[调度器] 获取聊天管理器") chat_manager = get_chat_manager() logger.debug(f"[调度器] 从聊天管理器获取聊天流 {stream_id}") chat_stream = await chat_manager.get_stream(stream_id) - + if chat_stream: # 计算并获取最新的 focus_energy - logger.debug(f"[调度器] 找到聊天流,开始计算 focus_energy") + logger.debug("[调度器] 找到聊天流,开始计算 focus_energy") focus_energy = await chat_stream.calculate_focus_energy() logger.info(f"[调度器] 聊天流 {stream_id} 的 focus_energy: {focus_energy:.3f}") return focus_energy else: logger.warning(f"[调度器] ⚠️ 未找到聊天流 {stream_id},使用默认 focus_energy=0.5") return 0.5 - + except Exception as e: logger.error(f"[调度器] ❌ 获取聊天流 {stream_id} 的 focus_energy 失败: {e}", exc_info=True) return 0.5 - + async def schedule_proactive_thinking(self, stream_id: str) -> bool: """为聊天流创建或重置主动思考任务 - + Args: stream_id: 聊天流ID - + Returns: bool: 是否成功创建/重置任务 """ @@ -243,25 +241,25 @@ class ProactiveThinkingScheduler: if stream_id in self._paused_streams: logger.debug(f"[调度器] 清除聊天流 {stream_id} 的暂停标记") self._paused_streams.discard(stream_id) - + # 如果已经有任务,先移除 if stream_id in self._stream_schedules: old_schedule_id = self._stream_schedules[stream_id] logger.debug(f"[调度器] 移除聊天流 {stream_id} 的旧任务") await unified_scheduler.remove_schedule(old_schedule_id) - + # 获取 focus_energy 并计算间隔 focus_energy = await self._get_stream_focus_energy(stream_id) logger.debug(f"[调度器] focus_energy={focus_energy:.3f}") - + interval_seconds = self._calculate_interval(focus_energy) - logger.debug(f"[调度器] 触发间隔={interval_seconds}秒 ({interval_seconds/60:.1f}分钟)") - + logger.debug(f"[调度器] 触发间隔={interval_seconds}秒 ({interval_seconds / 60:.1f}分钟)") + # 导入回调函数(延迟导入避免循环依赖) from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_executor import ( execute_proactive_thinking, ) - + # 创建新任务 schedule_id = await unified_scheduler.create_schedule( callback=execute_proactive_thinking, @@ -273,34 +271,34 @@ class ProactiveThinkingScheduler: task_name=f"ProactiveThinking-{stream_id}", callback_args=(stream_id,), ) - + self._stream_schedules[stream_id] = schedule_id - + # 计算下次触发时间 next_run_time = datetime.now() + timedelta(seconds=interval_seconds) - + logger.info( f"✅ 聊天流 {stream_id} 主动思考任务已创建 | " f"Focus: {focus_energy:.3f} | " - f"间隔: {interval_seconds/60:.1f}分钟 | " + f"间隔: {interval_seconds / 60:.1f}分钟 | " f"下次: {next_run_time.strftime('%H:%M:%S')}" ) return True - + except Exception as e: logger.error(f"❌ 创建主动思考任务失败 {stream_id}: {e}", exc_info=True) return False - + async def pause_proactive_thinking(self, stream_id: str, reason: str = "抛出话题") -> bool: """暂停聊天流的主动思考任务 - + 当选择"抛出话题"后,应该暂停该聊天流的主动思考, 直到bot至少执行过一次reply后才恢复。 - + Args: stream_id: 聊天流ID reason: 暂停原因 - + Returns: bool: 是否成功暂停 """ @@ -309,26 +307,26 @@ class ProactiveThinkingScheduler: if stream_id not in self._stream_schedules: logger.warning(f"尝试暂停不存在的任务: {stream_id}") return False - + schedule_id = self._stream_schedules[stream_id] success = await unified_scheduler.pause_schedule(schedule_id) - + if success: self._paused_streams.add(stream_id) logger.info(f"⏸️ 暂停主动思考 {stream_id},原因: {reason}") - + return success - - except Exception as e: + + except Exception: # 错误日志已在上面记录 return False - + async def resume_proactive_thinking(self, stream_id: str) -> bool: """恢复聊天流的主动思考任务 - + Args: stream_id: 聊天流ID - + Returns: bool: 是否成功恢复 """ @@ -337,26 +335,26 @@ class ProactiveThinkingScheduler: if stream_id not in self._stream_schedules: logger.warning(f"尝试恢复不存在的任务: {stream_id}") return False - + schedule_id = self._stream_schedules[stream_id] success = await unified_scheduler.resume_schedule(schedule_id) - + if success: self._paused_streams.discard(stream_id) logger.info(f"▶️ 恢复主动思考 {stream_id}") - + return success - + except Exception as e: logger.error(f"❌ 恢复主动思考失败 {stream_id}: {e}", exc_info=True) return False - + async def cancel_proactive_thinking(self, stream_id: str) -> bool: """取消聊天流的主动思考任务 - + Args: stream_id: 聊天流ID - + Returns: bool: 是否成功取消 """ @@ -364,55 +362,55 @@ class ProactiveThinkingScheduler: async with self._lock: if stream_id not in self._stream_schedules: return True # 已经不存在,视为成功 - + schedule_id = self._stream_schedules.pop(stream_id) self._paused_streams.discard(stream_id) - + success = await unified_scheduler.remove_schedule(schedule_id) logger.debug(f"⏹️ 取消主动思考 {stream_id}") - + return success - + except Exception as e: logger.error(f"❌ 取消主动思考失败 {stream_id}: {e}", exc_info=True) return False - + async def is_paused(self, stream_id: str) -> bool: """检查聊天流的主动思考是否被暂停 - + Args: stream_id: 聊天流ID - + Returns: bool: 是否暂停中 """ async with self._lock: return stream_id in self._paused_streams - - async def get_task_info(self, stream_id: str) -> Optional[dict[str, Any]]: + + async def get_task_info(self, stream_id: str) -> dict[str, Any] | None: """获取聊天流的主动思考任务信息 - + Args: stream_id: 聊天流ID - + Returns: dict: 任务信息,如果不存在返回None """ async with self._lock: if stream_id not in self._stream_schedules: return None - + schedule_id = self._stream_schedules[stream_id] task_info = await unified_scheduler.get_task_info(schedule_id) - + if task_info: task_info["is_paused_for_topic"] = stream_id in self._paused_streams - + return task_info - + async def list_all_tasks(self) -> list[dict[str, Any]]: """列出所有主动思考任务 - + Returns: list: 任务信息列表 """ @@ -425,10 +423,10 @@ class ProactiveThinkingScheduler: task_info["is_paused_for_topic"] = stream_id in self._paused_streams tasks.append(task_info) return tasks - + def get_statistics(self) -> dict[str, Any]: """获取调度器统计信息 - + Returns: dict: 统计信息 """ @@ -437,51 +435,48 @@ class ProactiveThinkingScheduler: "paused_for_topic": len(self._paused_streams), "active_tasks": len(self._stream_schedules) - len(self._paused_streams), } - + async def log_next_trigger_times(self, max_streams: int = 10): """在日志中输出聊天流的下次触发时间 - + Args: max_streams: 最多显示多少个聊天流,0表示全部 """ logger.info("=" * 60) logger.info("主动思考任务状态") logger.info("=" * 60) - + tasks = await self.list_all_tasks() - + if not tasks: logger.info("当前没有活跃的主动思考任务") logger.info("=" * 60) return - + # 按下次触发时间排序 - tasks_sorted = sorted( - tasks, - key=lambda x: x.get("next_run_time", datetime.max) or datetime.max - ) - + tasks_sorted = sorted(tasks, key=lambda x: x.get("next_run_time", datetime.max) or datetime.max) + # 限制显示数量 if max_streams > 0: tasks_sorted = tasks_sorted[:max_streams] - + logger.info(f"共有 {len(self._stream_schedules)} 个任务,显示前 {len(tasks_sorted)} 个") logger.info("") - + for i, task in enumerate(tasks_sorted, 1): stream_id = task.get("stream_id", "Unknown") next_run = task.get("next_run_time") is_paused = task.get("is_paused_for_topic", False) - + # 获取聊天流名称(如果可能) stream_name = stream_id[:16] + "..." if len(stream_id) > 16 else stream_id - + if next_run: # 计算剩余时间 now = datetime.now() remaining = next_run - now remaining_seconds = int(remaining.total_seconds()) - + if remaining_seconds < 0: time_str = "已过期(待执行)" elif remaining_seconds < 60: @@ -492,28 +487,25 @@ class ProactiveThinkingScheduler: hours = remaining_seconds // 3600 minutes = (remaining_seconds % 3600) // 60 time_str = f"{hours}小时{minutes}分钟后" - + status = "⏸️ 暂停中" if is_paused else "✅ 活跃" - + logger.info( f"[{i:2d}] {status} | {stream_name}\n" f" 下次触发: {next_run.strftime('%Y-%m-%d %H:%M:%S')} ({time_str})" ) else: - logger.info( - f"[{i:2d}] ⚠️ 未知 | {stream_name}\n" - f" 下次触发: 未设置" - ) - + logger.info(f"[{i:2d}] ⚠️ 未知 | {stream_name}\n 下次触发: 未设置") + logger.info("") logger.info("=" * 60) - - def get_last_decision(self, stream_id: str) -> Optional[dict[str, Any]]: + + def get_last_decision(self, stream_id: str) -> dict[str, Any] | None: """获取聊天流的上次主动思考决策 - + Args: stream_id: 聊天流ID - + Returns: dict: 上次决策信息,包含: - action: "do_nothing" | "simple_bubble" | "throw_topic" @@ -523,16 +515,10 @@ class ProactiveThinkingScheduler: None: 如果没有历史决策 """ return self._last_decisions.get(stream_id) - - def record_decision( - self, - stream_id: str, - action: str, - reasoning: str, - topic: Optional[str] = None - ) -> None: + + def record_decision(self, stream_id: str, action: str, reasoning: str, topic: str | None = None) -> None: """记录聊天流的主动思考决策 - + Args: stream_id: 聊天流ID action: 决策动作 diff --git a/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py b/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py index b4fc68526..00240b024 100644 --- a/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py +++ b/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py @@ -4,10 +4,10 @@ 通过LLM二步调用机制更新用户画像信息,包括别名、主观印象、偏好关键词和好感分数 """ -import orjson import time from typing import Any +import orjson from sqlalchemy import select from src.common.database.sqlalchemy_database_api import get_db_session @@ -42,7 +42,7 @@ class UserProfileTool(BaseTool): def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None): super().__init__(plugin_config, chat_stream) - + # 初始化用于二步调用的LLM try: self.profile_llm = LLMRequest( @@ -84,24 +84,24 @@ class UserProfileTool(BaseTool): "id": "user_profile_update", "content": "错误:必须提供目标用户ID" } - + # 从LLM传入的参数 new_aliases = function_args.get("user_aliases", "") new_impression = function_args.get("impression_description", "") new_keywords = function_args.get("preference_keywords", "") new_score = function_args.get("affection_score") - + # 从数据库获取现有用户画像 existing_profile = await self._get_user_profile(target_user_id) - + # 如果LLM没有传入任何有效参数,返回提示 if not any([new_aliases, new_impression, new_keywords, new_score is not None]): return { "type": "info", "id": target_user_id, - "content": f"提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)" + "content": "提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)" } - + # 调用LLM进行二步决策 if self.profile_llm is None: logger.error("LLM未正确初始化,无法执行二步调用") @@ -110,7 +110,7 @@ class UserProfileTool(BaseTool): "id": target_user_id, "content": "系统错误:LLM未正确初始化" } - + final_profile = await self._llm_decide_final_profile( target_user_id=target_user_id, existing_profile=existing_profile, @@ -119,17 +119,17 @@ class UserProfileTool(BaseTool): new_keywords=new_keywords, new_score=new_score ) - + if not final_profile: return { "type": "error", "id": target_user_id, "content": "LLM决策失败,无法更新用户画像" } - + # 更新数据库 await self._update_user_profile_in_db(target_user_id, final_profile) - + # 构建返回信息 updates = [] if final_profile.get("user_aliases"): @@ -140,22 +140,22 @@ class UserProfileTool(BaseTool): updates.append(f"偏好: {final_profile['preference_keywords']}") if final_profile.get("relationship_score") is not None: updates.append(f"好感分: {final_profile['relationship_score']:.2f}") - + result_text = f"已更新用户 {target_user_id} 的画像:\n" + "\n".join(updates) logger.info(f"用户画像更新成功: {target_user_id}") - + return { "type": "user_profile_update", "id": target_user_id, "content": result_text } - + except Exception as e: logger.error(f"用户画像更新失败: {e}", exc_info=True) return { "type": "error", "id": function_args.get("target_user_id", "unknown"), - "content": f"用户画像更新失败: {str(e)}" + "content": f"用户画像更新失败: {e!s}" } async def _get_user_profile(self, user_id: str) -> dict[str, Any]: @@ -172,7 +172,7 @@ class UserProfileTool(BaseTool): stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) result = await session.execute(stmt) profile = result.scalar_one_or_none() - + if profile: return { "user_name": profile.user_name or user_id, @@ -227,7 +227,7 @@ class UserProfileTool(BaseTool): from src.individuality.individuality import Individuality individuality = Individuality() bot_personality = await individuality.get_personality_block() - + prompt = f""" 你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality} @@ -261,18 +261,18 @@ class UserProfileTool(BaseTool): "reasoning": "你的决策理由" }} """ - + # 调用LLM llm_response, _ = await self.profile_llm.generate_response_async(prompt=prompt) - + if not llm_response: logger.warning("LLM未返回有效响应") return None - + # 清理并解析响应 cleaned_response = self._clean_llm_json_response(llm_response) response_data = orjson.loads(cleaned_response) - + # 提取最终决定的数据 final_profile = { "user_aliases": response_data.get("user_aliases", existing_profile.get("user_aliases", "")), @@ -280,12 +280,12 @@ class UserProfileTool(BaseTool): "preference_keywords": response_data.get("preference_keywords", existing_profile.get("preference_keywords", "")), "relationship_score": max(0.0, min(1.0, float(response_data.get("relationship_score", existing_profile.get("relationship_score", 0.3))))), } - + logger.info(f"LLM决策完成: {target_user_id}") logger.debug(f"决策理由: {response_data.get('reasoning', '无')}") - + return final_profile - + except orjson.JSONDecodeError as e: logger.error(f"LLM响应JSON解析失败: {e}") logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}") @@ -303,12 +303,12 @@ class UserProfileTool(BaseTool): """ try: current_time = time.time() - + async with get_db_session() as session: stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) result = await session.execute(stmt) existing = result.scalar_one_or_none() - + if existing: # 更新现有记录 existing.user_aliases = profile.get("user_aliases", "") @@ -328,10 +328,10 @@ class UserProfileTool(BaseTool): last_updated=current_time ) session.add(new_profile) - + await session.commit() logger.info(f"用户画像已更新到数据库: {user_id}") - + except Exception as e: logger.error(f"更新用户画像到数据库失败: {e}", exc_info=True) raise @@ -347,24 +347,24 @@ class UserProfileTool(BaseTool): """ try: import re - + cleaned = response.strip() - + # 移除 ```json 或 ``` 等标记 cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE) cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE) - + # 尝试找到JSON对象的开始和结束 json_start = cleaned.find("{") json_end = cleaned.rfind("}") - + if json_start != -1 and json_end != -1 and json_end > json_start: cleaned = cleaned[json_start:json_end + 1] - + cleaned = cleaned.strip() - + return cleaned - + except Exception as e: logger.warning(f"清理LLM响应失败: {e}") return response diff --git a/src/plugins/built_in/social_toolkit_plugin/plugin.py b/src/plugins/built_in/social_toolkit_plugin/plugin.py index 8d75ca2fd..179d7997a 100644 --- a/src/plugins/built_in/social_toolkit_plugin/plugin.py +++ b/src/plugins/built_in/social_toolkit_plugin/plugin.py @@ -261,7 +261,7 @@ class SetEmojiLikeAction(BaseAction): elif isinstance(self.action_message, dict): message_id = self.action_message.get("message_id") logger.info(f"获取到的消息ID: {message_id}") - + if not message_id: logger.error("未提供有效的消息或消息ID") await self.store_action_info(action_prompt_display="贴表情失败: 未提供消息ID", action_done=False) @@ -279,7 +279,7 @@ class SetEmojiLikeAction(BaseAction): context_text = self.action_message.processed_plain_text or "" else: context_text = self.action_message.get("processed_plain_text", "") - + if not context_text: logger.error("无法找到动作选择的原始消息文本") return False, "无法找到动作选择的原始消息文本" diff --git a/src/plugins/built_in/web_search_tool/plugin.py b/src/plugins/built_in/web_search_tool/plugin.py index 44e0082e0..a47a41ea1 100644 --- a/src/plugins/built_in/web_search_tool/plugin.py +++ b/src/plugins/built_in/web_search_tool/plugin.py @@ -5,7 +5,7 @@ Web Search Tool Plugin """ from src.common.logger import get_logger -from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin +from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin from src.plugin_system.apis import config_api from .tools.url_parser import URLParserTool diff --git a/src/schedule/unified_scheduler.py b/src/schedule/unified_scheduler.py index 0a2fc859b..aff48ee83 100644 --- a/src/schedule/unified_scheduler.py +++ b/src/schedule/unified_scheduler.py @@ -5,9 +5,10 @@ import asyncio import uuid -from datetime import datetime, timedelta +from collections.abc import Awaitable, Callable +from datetime import datetime from enum import Enum -from typing import Any, Awaitable, Callable, Optional +from typing import Any from src.common.logger import get_logger from src.plugin_system.base.component_types import EventType @@ -33,9 +34,9 @@ class ScheduleTask: trigger_type: TriggerType, trigger_config: dict[str, Any], is_recurring: bool = False, - task_name: Optional[str] = None, - callback_args: Optional[tuple] = None, - callback_kwargs: Optional[dict] = None, + task_name: str | None = None, + callback_args: tuple | None = None, + callback_kwargs: dict | None = None, ): self.schedule_id = schedule_id self.callback = callback @@ -46,7 +47,7 @@ class ScheduleTask: self.callback_args = callback_args or () self.callback_kwargs = callback_kwargs or {} self.created_at = datetime.now() - self.last_triggered_at: Optional[datetime] = None + self.last_triggered_at: datetime | None = None self.trigger_count = 0 self.is_active = True @@ -77,7 +78,7 @@ class UnifiedScheduler: def __init__(self): self._tasks: dict[str, ScheduleTask] = {} self._running = False - self._check_task: Optional[asyncio.Task] = None + self._check_task: asyncio.Task | None = None self._lock = asyncio.Lock() self._event_subscriptions: set[str] = set() # 追踪已订阅的事件 @@ -111,7 +112,7 @@ class UnifiedScheduler: for task in event_tasks: try: logger.debug(f"[调度器] 执行事件任务: {task.task_name}") - + # 执行回调,传入事件参数 if event_params: if asyncio.iscoroutinefunction(task.callback): @@ -127,7 +128,7 @@ class UnifiedScheduler: # 如果不是循环任务,标记为删除 if not task.is_recurring: tasks_to_remove.append(task.schedule_id) - + logger.debug(f"[调度器] 事件任务 {task.task_name} 执行完成") except Exception as e: @@ -204,11 +205,11 @@ class UnifiedScheduler: 注意:为了避免死锁,回调执行必须在锁外进行 """ current_time = datetime.now() - + # 第一阶段:在锁内快速收集需要触发的任务 async with self._lock: tasks_to_trigger = [] - + for schedule_id, task in list(self._tasks.items()): if not task.is_active: continue @@ -219,14 +220,14 @@ class UnifiedScheduler: tasks_to_trigger.append(task) except Exception as e: logger.error(f"检查任务 {task.task_name} 时发生错误: {e}", exc_info=True) - + # 第二阶段:在锁外执行回调(避免死锁) tasks_to_remove = [] - + for task in tasks_to_trigger: try: logger.debug(f"[调度器] 触发定时任务: {task.task_name}") - + # 执行回调 await self._execute_callback(task) @@ -339,9 +340,9 @@ class UnifiedScheduler: trigger_type: TriggerType, trigger_config: dict[str, Any], is_recurring: bool = False, - task_name: Optional[str] = None, - callback_args: Optional[tuple] = None, - callback_kwargs: Optional[dict] = None, + task_name: str | None = None, + callback_args: tuple | None = None, + callback_kwargs: dict | None = None, ) -> str: """创建调度任务(详细注释见文档)""" schedule_id = str(uuid.uuid4()) @@ -430,7 +431,7 @@ class UnifiedScheduler: logger.info(f"恢复任务: {task.task_name} (ID: {schedule_id[:8]}...)") return True - async def get_task_info(self, schedule_id: str) -> Optional[dict[str, Any]]: + async def get_task_info(self, schedule_id: str) -> dict[str, Any] | None: """获取任务信息""" async with self._lock: task = self._tasks.get(schedule_id) @@ -449,7 +450,7 @@ class UnifiedScheduler: "trigger_config": task.trigger_config.copy(), } - async def list_tasks(self, trigger_type: Optional[TriggerType] = None) -> list[dict[str, Any]]: + async def list_tasks(self, trigger_type: TriggerType | None = None) -> list[dict[str, Any]]: """列出所有任务或指定类型的任务""" async with self._lock: tasks = [] @@ -499,11 +500,11 @@ async def initialize_scheduler(): logger.info("正在启动统一调度器...") await unified_scheduler.start() logger.info("统一调度器启动成功") - + # 获取初始统计信息 stats = unified_scheduler.get_statistics() logger.info(f"调度器状态: {stats}") - + except Exception as e: logger.error(f"启动统一调度器失败: {e}", exc_info=True) raise @@ -516,20 +517,20 @@ async def shutdown_scheduler(): """ try: logger.info("正在关闭统一调度器...") - + # 显示最终统计 stats = unified_scheduler.get_statistics() logger.info(f"调度器最终统计: {stats}") - + # 列出剩余任务 remaining_tasks = await unified_scheduler.list_tasks() if remaining_tasks: logger.warning(f"检测到 {len(remaining_tasks)} 个未清理的任务:") for task in remaining_tasks: logger.warning(f" - {task['task_name']} (ID: {task['schedule_id'][:8]}...)") - + await unified_scheduler.stop() logger.info("统一调度器已关闭") - + except Exception as e: - logger.error(f"关闭统一调度器失败: {e}", exc_info=True) \ No newline at end of file + logger.error(f"关闭统一调度器失败: {e}", exc_info=True)