From 3edcc9d1696315dfa525ba12e2abb52fa75d5440 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Mon, 8 Dec 2025 15:48:40 +0800 Subject: [PATCH] ruff --- scripts/log_viewer.py | 1 + scripts/memory_cleaner.py | 194 +++++---- scripts/migrate_database.py | 85 ++-- scripts/test_bedrock_client.py | 4 +- src/api/memory_visualizer_router.py | 9 +- src/chat/chatter_manager.py | 18 +- src/chat/emoji_system/emoji_manager.py | 4 +- src/chat/energy_system/energy_manager.py | 3 +- src/chat/express/expression_learner.py | 12 +- src/chat/express/expression_selector.py | 2 +- src/chat/express/style_learner.py | 4 +- .../interest_system/bot_interest_manager.py | 18 +- src/chat/knowledge/embedding_store.py | 2 +- src/chat/knowledge/qa_manager.py | 18 +- .../message_manager/distribution_manager.py | 74 ++-- src/chat/message_manager/message_manager.py | 4 +- src/chat/message_receive/chat_stream.py | 5 +- src/chat/message_receive/message_handler.py | 2 +- src/chat/message_receive/message_processor.py | 42 +- src/chat/message_receive/storage.py | 6 +- .../message_receive/uni_message_sender.py | 3 +- src/chat/planner_actions/action_manager.py | 6 +- src/chat/planner_actions/action_modifier.py | 11 +- src/chat/replyer/default_generator.py | 11 +- src/chat/replyer/replyer_manager.py | 4 +- src/chat/utils/chat_message_builder.py | 22 +- src/chat/utils/prompt.py | 6 +- src/chat/utils/prompt_component_manager.py | 2 +- src/chat/utils/report_generator.py | 22 +- src/chat/utils/statistic.py | 60 +-- src/chat/utils/typo_generator.py | 18 +- src/chat/utils/utils.py | 4 +- src/chat/utils/utils_image.py | 2 +- src/chat/utils/utils_video_legacy.py | 1 - src/common/cache_manager.py | 46 +- src/common/data_models/database_data_model.py | 78 ++-- src/common/database/api/crud.py | 2 - src/common/database/api/query.py | 1 - src/common/database/api/specialized.py | 4 +- src/common/database/compatibility/adapter.py | 2 +- src/common/database/core/migration.py | 28 +- src/common/database/core/models.py | 10 +- .../database/optimization/batch_scheduler.py | 2 +- .../database/optimization/cache_manager.py | 4 +- src/common/database/utils/decorators.py | 2 +- src/common/logger.py | 5 +- src/common/mem_monitor.py | 34 +- src/common/security.py | 2 +- src/common/server.py | 5 +- src/config/api_ada_configs.py | 2 +- src/config/config.py | 6 +- src/config/official_configs.py | 20 +- src/individuality/individuality.py | 3 +- .../model_client/aiohttp_gemini_client.py | 2 +- src/llm_models/model_client/bedrock_client.py | 4 +- src/llm_models/payload_content/message.py | 6 +- .../payload_content/system_prompt.py | 2 +- src/llm_models/utils_model.py | 2 +- src/main.py | 15 +- src/memory_graph/long_term_manager.py | 78 ++-- src/memory_graph/manager.py | 20 +- src/memory_graph/manager_singleton.py | 2 +- src/memory_graph/models.py | 1 - src/memory_graph/perceptual_manager.py | 12 +- .../plugin_tools/memory_plugin_tools.py | 2 +- src/memory_graph/short_term_manager.py | 19 +- src/memory_graph/storage/graph_store.py | 25 +- src/memory_graph/storage/persistence.py | 4 +- src/memory_graph/tools/memory_tools.py | 6 +- src/memory_graph/unified_manager.py | 12 +- src/memory_graph/utils/__init__.py | 8 +- src/memory_graph/utils/embeddings.py | 2 +- src/memory_graph/utils/path_expansion.py | 7 +- src/memory_graph/utils/similarity.py | 6 +- .../utils/three_tier_formatter.py | 12 +- src/person_info/person_info.py | 22 +- src/person_info/relationship_fetcher.py | 6 +- src/plugin_system/__init__.py | 4 +- src/plugin_system/apis/chat_api.py | 2 +- src/plugin_system/apis/cross_context_api.py | 26 +- src/plugin_system/apis/permission_api.py | 137 +++--- src/plugin_system/apis/plugin_manage_api.py | 2 +- src/plugin_system/apis/schedule_api.py | 2 +- src/plugin_system/apis/send_api.py | 2 + src/plugin_system/apis/unified_scheduler.py | 6 +- src/plugin_system/base/__init__.py | 2 +- src/plugin_system/base/base_adapter.py | 81 ++-- src/plugin_system/core/adapter_manager.py | 87 ++-- src/plugin_system/core/event_manager.py | 2 +- src/plugin_system/core/permission_manager.py | 4 +- src/plugin_system/core/plugin_manager.py | 30 +- src/plugin_system/core/stream_tool_history.py | 15 +- src/plugin_system/core/tool_use.py | 5 +- .../services/relationship_service.py | 2 +- .../affinity_flow_chatter/actions/reply.py | 58 +-- .../planner/plan_filter.py | 4 +- .../affinity_flow_chatter/planner/planner.py | 10 +- .../proactive/proactive_thinking_executor.py | 18 +- .../affinity_flow_chatter/tools/__init__.py | 2 +- .../tools/user_fact_tool.py | 42 +- .../tools/user_profile_tool.py | 123 +++--- .../built_in/anti_injection_plugin/prompts.py | 2 +- src/plugins/built_in/core_actions/emoji.py | 3 +- .../built_in/kokoro_flow_chatter/__init__.py | 79 ++-- .../kokoro_flow_chatter/actions/reply.py | 76 ++-- .../built_in/kokoro_flow_chatter/chatter.py | 104 ++--- .../built_in/kokoro_flow_chatter/config.py | 167 ++++---- .../kokoro_flow_chatter/context_builder.py | 186 ++++----- .../built_in/kokoro_flow_chatter/models.py | 84 ++-- .../built_in/kokoro_flow_chatter/planner.py | 36 +- .../built_in/kokoro_flow_chatter/plugin.py | 25 +- .../kokoro_flow_chatter/proactive_thinker.py | 256 ++++++------ .../kokoro_flow_chatter/prompt/__init__.py | 2 +- .../kokoro_flow_chatter/prompt/builder.py | 392 +++++++++--------- .../prompt_modules_unified.py | 188 ++++----- .../built_in/kokoro_flow_chatter/replyer.py | 36 +- .../built_in/kokoro_flow_chatter/session.py | 140 +++---- .../built_in/kokoro_flow_chatter/unified.py | 184 ++++---- .../services/content_service.py | 2 +- .../services/image_service.py | 67 ++- .../services/qzone_service.py | 3 +- .../services/scheduler_service.py | 2 - src/plugins/built_in/napcat_adapter/plugin.py | 44 +- .../napcat_adapter/src/event_models.py | 12 +- .../src/handlers/to_core/message_handler.py | 66 ++- .../handlers/to_core/meta_event_handler.py | 10 +- .../src/handlers/to_core/notice_handler.py | 52 +-- .../src/handlers/to_napcat/send_handler.py | 58 +-- .../napcat_adapter/src/handlers/utils.py | 23 +- .../src/handlers/video_handler.py | 7 +- .../siliconflow_api_index_tts/plugin.py | 92 ++-- .../siliconflow_api_index_tts/upload_voice.py | 67 ++- .../built_in/system_management/plugin.py | 35 +- src/plugins/built_in/tts_plugin/plugin.py | 8 +- .../tts_voice_plugin/actions/tts_action.py | 44 +- .../web_search_tool/tools/url_parser.py | 2 +- ui_log_adapter.py | 1 - 137 files changed, 2194 insertions(+), 2237 deletions(-) diff --git a/scripts/log_viewer.py b/scripts/log_viewer.py index 88fff24ac..4ea6485fb 100644 --- a/scripts/log_viewer.py +++ b/scripts/log_viewer.py @@ -31,6 +31,7 @@ if str(PROJECT_ROOT) not in sys.path: # 切换工作目录到项目根目录 import os + os.chdir(PROJECT_ROOT) # 日志目录 diff --git a/scripts/memory_cleaner.py b/scripts/memory_cleaner.py index de388b207..bf75c6570 100644 --- a/scripts/memory_cleaner.py +++ b/scripts/memory_cleaner.py @@ -25,8 +25,6 @@ sys.path.insert(0, str(project_root)) from src.config.config import model_config from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config - # ==================== 配置 ==================== @@ -82,7 +80,7 @@ EVALUATION_PROMPT = """你是一个非常严格的记忆价值评估专家。你 **保留示例**: - "用户张三说他是程序员,在杭州工作" ✅ -- "李四说他喜欢打篮球,每周三都会去" ✅ +- "李四说他喜欢打篮球,每周三都会去" ✅ - "小明说他女朋友叫小红,在一起2年了" ✅ - "用户A的生日是3月15日" ✅ @@ -111,7 +109,7 @@ EVALUATION_PROMPT = """你是一个非常严格的记忆价值评估专家。你 }}, {{ "memory_id": "另一个ID", - "action": "keep", + "action": "keep", "reason": "保留原因" }} ] @@ -134,7 +132,7 @@ class MemoryCleaner: def __init__(self, dry_run: bool = True, batch_size: int = 10, concurrency: int = 5): """ 初始化清理器 - + Args: dry_run: 是否为模拟运行(不实际修改数据) batch_size: 每批处理的记忆数量 @@ -146,10 +144,10 @@ class MemoryCleaner: self.data_dir = project_root / "data" / "memory_graph" self.memory_file = self.data_dir / "memory_graph.json" self.backup_dir = self.data_dir / "backups" - + # 并发控制 self.semaphore: asyncio.Semaphore | None = None - + # 统计信息 self.stats = { "total": 0, @@ -160,7 +158,7 @@ class MemoryCleaner: "deleted_nodes": 0, "deleted_edges": 0, } - + # 日志文件 self.log_file = self.data_dir / f"cleanup_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" self.cleanup_log = [] @@ -168,23 +166,23 @@ class MemoryCleaner: def load_memories(self) -> dict: """加载记忆数据""" print(f"📂 加载记忆文件: {self.memory_file}") - + if not self.memory_file.exists(): raise FileNotFoundError(f"记忆文件不存在: {self.memory_file}") - - with open(self.memory_file, "r", encoding="utf-8") as f: + + with open(self.memory_file, encoding="utf-8") as f: data = json.load(f) - + return data def extract_memory_text(self, memory_dict: dict) -> str: """从记忆字典中提取可读文本""" parts = [] - + # 提取基本信息 memory_id = memory_dict.get("id", "unknown") parts.append(f"ID: {memory_id}") - + # 提取节点内容 nodes = memory_dict.get("nodes", []) for node in nodes: @@ -192,14 +190,14 @@ class MemoryCleaner: content = node.get("content", "") if content: parts.append(f"[{node_type}] {content}") - + # 提取边关系 edges = memory_dict.get("edges", []) for edge in edges: relation = edge.get("relation", "") if relation: parts.append(f"关系: {relation}") - + # 提取元数据 metadata = memory_dict.get("metadata", {}) if metadata: @@ -207,24 +205,24 @@ class MemoryCleaner: parts.append(f"上下文: {metadata['context']}") if "emotion" in metadata: parts.append(f"情感: {metadata['emotion']}") - + # 提取重要性和状态 importance = memory_dict.get("importance", 0) status = memory_dict.get("status", "unknown") created_at = memory_dict.get("created_at", "unknown") - + parts.append(f"重要性: {importance}, 状态: {status}, 创建时间: {created_at}") - + return "\n".join(parts) async def evaluate_batch(self, memories: list[dict], batch_id: int = 0) -> tuple[int, list[dict]]: """ 使用 LLM 评估一批记忆(带并发控制) - + Args: memories: 记忆字典列表 batch_id: 批次编号 - + Returns: (批次ID, 评估结果列表) """ @@ -234,27 +232,27 @@ class MemoryCleaner: for i, mem in enumerate(memories): text = self.extract_memory_text(mem) memory_texts.append(f"=== 记忆 {i+1} ===\n{text}") - + combined_text = "\n\n".join(memory_texts) prompt = EVALUATION_PROMPT.format(memories=combined_text) - + try: # 使用 LLMRequest 调用模型 if model_config is None: raise RuntimeError("model_config 未初始化,请确保已加载配置") task_config = model_config.model_task_config.utils llm = LLMRequest(task_config, request_type="memory_cleanup") - response_text, (reasoning, model_name, _) = await llm.generate_response_async( + response_text, (_reasoning, model_name, _) = await llm.generate_response_async( prompt=prompt, temperature=0.2, max_tokens=4000, ) - + print(f" ✅ 批次 {batch_id} 完成 (模型: {model_name})") - + # 解析 JSON 响应 response_text = response_text.strip() - + # 尝试提取 JSON if "```json" in response_text: json_start = response_text.find("```json") + 7 @@ -264,17 +262,17 @@ class MemoryCleaner: json_start = response_text.find("```") + 3 json_end = response_text.find("```", json_start) response_text = response_text[json_start:json_end].strip() - + result = json.loads(response_text) evaluations = result.get("evaluations", []) - + # 为评估结果添加实际的 memory_id for j, eval_result in enumerate(evaluations): if j < len(memories): eval_result["memory_id"] = memories[j].get("id", f"unknown_{batch_id}_{j}") - + return (batch_id, evaluations) - + except json.JSONDecodeError as e: print(f" ❌ 批次 {batch_id} JSON 解析失败: {e}") return (batch_id, []) @@ -291,36 +289,36 @@ class MemoryCleaner: """创建数据备份""" self.backup_dir.mkdir(parents=True, exist_ok=True) backup_file = self.backup_dir / f"memory_graph_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" - + print(f"💾 创建备份: {backup_file}") with open(backup_file, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) - + return backup_file def apply_changes(self, data: dict, evaluations: list[dict]) -> dict: """ 应用评估结果到数据 - + Args: data: 原始数据 evaluations: 评估结果列表 - + Returns: 修改后的数据 """ # 创建评估结果索引 - eval_map = {e["memory_id"]: e for e in evaluations if "memory_id" in e} - + {e["memory_id"]: e for e in evaluations if "memory_id" in e} + # 需要删除的记忆 ID to_delete = set() # 需要更新的记忆 to_update = {} - + for eval_result in evaluations: memory_id = eval_result.get("memory_id") action = eval_result.get("action") - + if action == "delete": to_delete.add(memory_id) self.stats["deleted"] += 1 @@ -342,18 +340,18 @@ class MemoryCleaner: }) else: self.stats["kept"] += 1 - + if self.dry_run: print("🔍 [DRY RUN] 不实际修改数据") return data - + # 实际修改数据 # 1. 删除记忆 memories = data.get("memories", {}) for mem_id in to_delete: if mem_id in memories: del memories[mem_id] - + # 2. 更新记忆内容 for mem_id, new_content in to_update.items(): if mem_id in memories: @@ -363,42 +361,42 @@ class MemoryCleaner: if node.get("node_type") in ["主题", "topic", "TOPIC"]: node["content"] = new_content break - + # 3. 清理孤立节点和边 data = self.cleanup_orphaned_nodes_and_edges(data) - + return data - + def cleanup_orphaned_nodes_and_edges(self, data: dict) -> dict: """ 清理孤立的节点和边 - + 孤立节点:其 metadata.memory_ids 中的所有记忆都已被删除 孤立边:其 source 或 target 节点已被删除 """ print("\n🔗 清理孤立节点和边...") - + # 获取当前所有有效的记忆 ID valid_memory_ids = set(data.get("memories", {}).keys()) print(f" 有效记忆数: {len(valid_memory_ids)}") - + # 清理节点 nodes = data.get("nodes", []) original_node_count = len(nodes) - + valid_nodes = [] valid_node_ids = set() - + for node in nodes: node_id = node.get("id") metadata = node.get("metadata", {}) memory_ids = metadata.get("memory_ids", []) - + # 检查节点关联的记忆是否还存在 if memory_ids: # 过滤掉已删除的记忆 ID remaining_memory_ids = [mid for mid in memory_ids if mid in valid_memory_ids] - + if remaining_memory_ids: # 更新 metadata 中的 memory_ids metadata["memory_ids"] = remaining_memory_ids @@ -410,32 +408,32 @@ class MemoryCleaner: # 保守处理:保留这些节点 valid_nodes.append(node) valid_node_ids.add(node_id) - + deleted_nodes = original_node_count - len(valid_nodes) data["nodes"] = valid_nodes print(f" ✅ 节点: {original_node_count} → {len(valid_nodes)} (删除 {deleted_nodes})") - + # 清理边 edges = data.get("edges", []) original_edge_count = len(edges) - + valid_edges = [] for edge in edges: source = edge.get("source") target = edge.get("target") - + # 只保留两端节点都存在的边 if source in valid_node_ids and target in valid_node_ids: valid_edges.append(edge) - + deleted_edges = original_edge_count - len(valid_edges) data["edges"] = valid_edges print(f" ✅ 边: {original_edge_count} → {len(valid_edges)} (删除 {deleted_edges})") - + # 更新统计 self.stats["deleted_nodes"] = deleted_nodes self.stats["deleted_edges"] = deleted_edges - + return data def save_data(self, data: dict): @@ -443,7 +441,7 @@ class MemoryCleaner: if self.dry_run: print("🔍 [DRY RUN] 跳过保存") return - + print(f"💾 保存数据到: {self.memory_file}") with open(self.memory_file, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) @@ -468,88 +466,88 @@ class MemoryCleaner: print(f"批次大小: {self.batch_size}") print(f"并发数: {self.concurrency}") print("=" * 60) - + # 初始化 await self.initialize() - + # 加载数据 data = self.load_memories() - + # 获取所有记忆 memories = data.get("memories", {}) memory_list = list(memories.values()) self.stats["total"] = len(memory_list) - + print(f"📊 总记忆数: {self.stats['total']}") - + if not memory_list: print("⚠️ 没有记忆需要处理") return - + # 创建备份 if not self.dry_run: self.create_backup(data) - + # 分批 batches = [] for i in range(0, len(memory_list), self.batch_size): batch = memory_list[i:i + self.batch_size] batches.append(batch) - + total_batches = len(batches) print(f"📦 共 {total_batches} 个批次,开始并发处理...\n") - + # 并发处理所有批次 start_time = datetime.now() tasks = [ self.evaluate_batch(batch, batch_id=idx) for idx, batch in enumerate(batches) ] - + # 使用 asyncio.gather 并发执行 results = await asyncio.gather(*tasks, return_exceptions=True) - + end_time = datetime.now() elapsed = (end_time - start_time).total_seconds() - + # 收集所有评估结果 all_evaluations = [] success_count = 0 error_count = 0 - + for result in results: if isinstance(result, Exception): print(f" ❌ 批次异常: {result}") error_count += 1 elif isinstance(result, tuple): - batch_id, evaluations = result + _batch_id, evaluations = result if evaluations: all_evaluations.extend(evaluations) success_count += 1 else: error_count += 1 - + print(f"\n⏱️ 并发处理完成,耗时 {elapsed:.1f} 秒") print(f" 成功批次: {success_count}/{total_batches}, 失败: {error_count}") - + # 统计评估结果 delete_count = sum(1 for e in all_evaluations if e.get("action") == "delete") keep_count = sum(1 for e in all_evaluations if e.get("action") == "keep") summarize_count = sum(1 for e in all_evaluations if e.get("action") == "summarize") - + print(f" 📊 评估结果: 保留 {keep_count}, 删除 {delete_count}, 精简 {summarize_count}") - + # 应用更改 print("\n" + "=" * 60) print("📊 应用更改...") data = self.apply_changes(data, all_evaluations) - + # 保存数据 self.save_data(data) - + # 保存日志 self.save_log() - + # 打印统计 print("\n" + "=" * 60) print("📊 清理统计") @@ -563,7 +561,7 @@ class MemoryCleaner: print(f"错误: {self.stats['errors']}") print(f"处理速度: {self.stats['total'] / elapsed:.1f} 条/秒") print("=" * 60) - + if self.dry_run: print("\n⚠️ 这是模拟运行,实际数据未被修改") print("如要实际执行,请移除 --dry-run 参数") @@ -575,25 +573,25 @@ class MemoryCleaner: print("=" * 60) print(f"模式: {'模拟运行 (DRY RUN)' if self.dry_run else '实际执行'}") print("=" * 60) - + # 加载数据 data = self.load_memories() - + # 统计原始数据 memories = data.get("memories", {}) nodes = data.get("nodes", []) edges = data.get("edges", []) - + print(f"📊 当前状态: {len(memories)} 条记忆, {len(nodes)} 个节点, {len(edges)} 条边") - + if not self.dry_run: self.create_backup(data) - + # 清理孤立节点和边 if self.dry_run: # 模拟运行:统计但不修改 valid_memory_ids = set(memories.keys()) - + # 统计要删除的节点 nodes_to_keep = 0 for node in nodes: @@ -605,9 +603,9 @@ class MemoryCleaner: nodes_to_keep += 1 else: nodes_to_keep += 1 - + nodes_to_delete = len(nodes) - nodes_to_keep - + # 统计要删除的边(需要先确定哪些节点会被保留) valid_node_ids = set() for node in nodes: @@ -619,11 +617,11 @@ class MemoryCleaner: valid_node_ids.add(node.get("id")) else: valid_node_ids.add(node.get("id")) - + edges_to_keep = sum(1 for e in edges if e.get("source") in valid_node_ids and e.get("target") in valid_node_ids) edges_to_delete = len(edges) - edges_to_keep - - print(f"\n🔍 [DRY RUN] 预计清理:") + + print("\n🔍 [DRY RUN] 预计清理:") print(f" 节点: {len(nodes)} → {nodes_to_keep} (删除 {nodes_to_delete})") print(f" 边: {len(edges)} → {edges_to_keep} (删除 {edges_to_delete})") print("\n⚠️ 这是模拟运行,实际数据未被修改") @@ -631,8 +629,8 @@ class MemoryCleaner: else: data = self.cleanup_orphaned_nodes_and_edges(data) self.save_data(data) - - print(f"\n✅ 清理完成!") + + print("\n✅ 清理完成!") print(f" 删除节点: {self.stats['deleted_nodes']}") print(f" 删除边: {self.stats['deleted_edges']}") @@ -661,15 +659,15 @@ async def main(): action="store_true", help="只清理孤立节点和边,不重新评估记忆" ) - + args = parser.parse_args() - + cleaner = MemoryCleaner( dry_run=args.dry_run, batch_size=args.batch_size, concurrency=args.concurrency, ) - + if args.cleanup_only: await cleaner.run_cleanup_only() else: diff --git a/scripts/migrate_database.py b/scripts/migrate_database.py index 9f4bf71d3..2f42f31c9 100644 --- a/scripts/migrate_database.py +++ b/scripts/migrate_database.py @@ -8,7 +8,7 @@ python scripts/migrate_database.py --help python scripts/migrate_database.py --source sqlite --target postgresql python scripts/migrate_database.py --source postgresql --target sqlite --batch-size 5000 - + # 交互式向导模式(推荐) python scripts/migrate_database.py @@ -55,19 +55,21 @@ try: except ImportError: tomllib = None -from typing import Any, Iterable, Callable - +from collections.abc import Iterable from datetime import datetime as dt +from typing import Any from sqlalchemy import ( - create_engine, MetaData, Table, + create_engine, inspect, text, +) +from sqlalchemy import ( types as sqltypes, ) -from sqlalchemy.engine import Engine, Connection +from sqlalchemy.engine import Connection, Engine from sqlalchemy.exc import SQLAlchemyError # ====== 为了在 Windows 上更友好的输出中文,提前设置环境 ====== @@ -320,7 +322,7 @@ def convert_value_for_target( """ # 获取目标类型的类名 target_type_name = target_col_type.__class__.__name__.upper() - source_type_name = source_col_type.__class__.__name__.upper() + source_col_type.__class__.__name__.upper() # 处理 None 值 if val is None: @@ -500,7 +502,7 @@ def migrate_table_data( target_cols_by_name = {c.key: c for c in target_table.columns} # 识别主键列(通常是 id),迁移时保留原始 ID 以避免重复数据 - primary_key_cols = {c.key for c in source_table.primary_key.columns} + {c.key for c in source_table.primary_key.columns} # 使用流式查询,避免一次性加载太多数据 # 使用 text() 原始 SQL 查询,避免 SQLAlchemy 自动类型转换(如 DateTime)导致的错误 @@ -776,7 +778,7 @@ class DatabaseMigrator: for table_name in self.metadata.tables: dependencies[table_name] = set() - for table_name, table in self.metadata.tables.items(): + for table_name in self.metadata.tables.keys(): fks = inspector.get_foreign_keys(table_name) for fk in fks: # 被引用的表 @@ -919,7 +921,7 @@ class DatabaseMigrator: self.stats["errors"].append(f"表 {source_table.name} 迁移失败: {e}") self.stats["end_time"] = time.time() - + # 迁移完成后,自动修复 PostgreSQL 特有问题 if self.target_type == "postgresql" and self.target_engine: fix_postgresql_boolean_columns(self.target_engine) @@ -927,7 +929,6 @@ class DatabaseMigrator: def print_summary(self): """打印迁移总结""" - import time duration = None if self.stats["start_time"] is not None and self.stats["end_time"] is not None: @@ -1262,104 +1263,104 @@ def interactive_setup() -> dict: def fix_postgresql_sequences(engine: Engine): """修复 PostgreSQL 序列值 - + 迁移数据后,PostgreSQL 的序列(用于自增主键)可能没有更新到正确的值, 导致插入新记录时出现主键冲突。此函数会自动检测并重置所有序列。 - + Args: engine: PostgreSQL 数据库引擎 """ if engine.dialect.name != "postgresql": logger.info("非 PostgreSQL 数据库,跳过序列修复") return - + logger.info("正在修复 PostgreSQL 序列...") - + with engine.connect() as conn: # 获取所有带有序列的表 - result = conn.execute(text(''' - SELECT + result = conn.execute(text(""" + SELECT t.table_name, c.column_name, pg_get_serial_sequence(t.table_name, c.column_name) as sequence_name FROM information_schema.tables t - JOIN information_schema.columns c + JOIN information_schema.columns c ON t.table_name = c.table_name AND t.table_schema = c.table_schema - WHERE t.table_schema = 'public' + WHERE t.table_schema = 'public' AND t.table_type = 'BASE TABLE' AND c.column_default LIKE 'nextval%' ORDER BY t.table_name - ''')) - + """)) + sequences = result.fetchall() logger.info("发现 %d 个带序列的表", len(sequences)) - + fixed_count = 0 for table_name, column_name, seq_name in sequences: if seq_name: try: # 获取当前表中该列的最大值 - max_result = conn.execute(text(f'SELECT COALESCE(MAX({column_name}), 0) FROM {table_name}')) + max_result = conn.execute(text(f"SELECT COALESCE(MAX({column_name}), 0) FROM {table_name}")) max_val = max_result.scalar() - + # 设置序列的下一个值 next_val = max_val + 1 conn.execute(text(f"SELECT setval('{seq_name}', {next_val}, false)")) conn.commit() - + logger.info(" ✅ %s.%s: 最大值=%d, 序列设为=%d", table_name, column_name, max_val, next_val) fixed_count += 1 except Exception as e: logger.warning(" ❌ %s.%s: 修复失败 - %s", table_name, column_name, e) - + logger.info("序列修复完成!共修复 %d 个序列", fixed_count) def fix_postgresql_boolean_columns(engine: Engine): """修复 PostgreSQL 布尔列类型 - + 从 SQLite 迁移后,布尔列可能是 INTEGER 类型。此函数将其转换为 BOOLEAN。 - + Args: engine: PostgreSQL 数据库引擎 """ if engine.dialect.name != "postgresql": logger.info("非 PostgreSQL 数据库,跳过布尔列修复") return - + # 已知需要转换为 BOOLEAN 的列 BOOLEAN_COLUMNS = { - 'messages': ['is_mentioned', 'is_emoji', 'is_picid', 'is_command', - 'is_notify', 'is_public_notice', 'should_reply', 'should_act'], - 'action_records': ['action_done', 'action_build_into_prompt'], + "messages": ["is_mentioned", "is_emoji", "is_picid", "is_command", + "is_notify", "is_public_notice", "should_reply", "should_act"], + "action_records": ["action_done", "action_build_into_prompt"], } - + logger.info("正在检查并修复 PostgreSQL 布尔列...") - + with engine.connect() as conn: fixed_count = 0 for table_name, columns in BOOLEAN_COLUMNS.items(): for col_name in columns: try: # 检查当前类型 - result = conn.execute(text(f''' - SELECT data_type FROM information_schema.columns + result = conn.execute(text(f""" + SELECT data_type FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = '{col_name}' - ''')) + """)) row = result.fetchone() - if row and row[0] != 'boolean': + if row and row[0] != "boolean": # 需要修复 - conn.execute(text(f''' - ALTER TABLE {table_name} - ALTER COLUMN {col_name} TYPE BOOLEAN + conn.execute(text(f""" + ALTER TABLE {table_name} + ALTER COLUMN {col_name} TYPE BOOLEAN USING CASE WHEN {col_name} = 0 THEN FALSE ELSE TRUE END - ''')) + """)) conn.commit() logger.info(" ✅ %s.%s: %s -> BOOLEAN", table_name, col_name, row[0]) fixed_count += 1 except Exception as e: logger.warning(" ⚠️ %s.%s: 检查/修复失败 - %s", table_name, col_name, e) - + if fixed_count > 0: logger.info("布尔列修复完成!共修复 %d 列", fixed_count) else: diff --git a/scripts/test_bedrock_client.py b/scripts/test_bedrock_client.py index e2a54bf7f..10d5eea2a 100644 --- a/scripts/test_bedrock_client.py +++ b/scripts/test_bedrock_client.py @@ -134,7 +134,7 @@ async def test_tool_calling(): print("测试 4: 工具调用功能") print("=" * 60) - from src.llm_models.payload_content.tool_option import ToolOption, ToolOptionBuilder, ToolParamType + from src.llm_models.payload_content.tool_option import ToolOptionBuilder, ToolParamType provider = APIProvider( name="bedrock_test", @@ -171,7 +171,7 @@ async def test_tool_calling(): ) if response.tool_calls: - print(f"✅ 模型调用了工具:") + print("✅ 模型调用了工具:") for call in response.tool_calls: print(f" - 工具名: {call.func_name}") print(f" - 参数: {call.args}") diff --git a/src/api/memory_visualizer_router.py b/src/api/memory_visualizer_router.py index 86658e7e0..624a877a4 100644 --- a/src/api/memory_visualizer_router.py +++ b/src/api/memory_visualizer_router.py @@ -16,7 +16,6 @@ from fastapi import APIRouter, HTTPException, Query, Request from fastapi.responses import HTMLResponse, JSONResponse from fastapi.templating import Jinja2Templates - # 调整项目根目录的计算方式 project_root = Path(__file__).parent.parent.parent data_dir = project_root / "data" / "memory_graph" @@ -103,7 +102,7 @@ async def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, processed = await loop.run_in_executor( _executor, _process_graph_data, nodes, edges, metadata, graph_file ) - + graph_data_cache = processed return graph_data_cache @@ -303,8 +302,8 @@ async def get_paginated_graph( # 在线程池中处理分页逻辑 loop = asyncio.get_event_loop() result = await loop.run_in_executor( - _executor, - _process_pagination, + _executor, + _process_pagination, full_data, page, page_size, min_importance, node_types ) @@ -353,7 +352,7 @@ def _process_pagination(full_data: dict, page: int, page_size: int, min_importan end_idx = min(start_idx + page_size, total_nodes) paginated_nodes = nodes_with_importance[start_idx:end_idx] - node_ids = set(n["id"] for n in paginated_nodes) + node_ids = {n["id"] for n in paginated_nodes} # 只保留连接分页节点的边 paginated_edges = [ diff --git a/src/chat/chatter_manager.py b/src/chat/chatter_manager.py index 18d3b8f09..fe29d6ce2 100644 --- a/src/chat/chatter_manager.py +++ b/src/chat/chatter_manager.py @@ -60,14 +60,14 @@ class ChatterManager: def get_chatter_class_for_chat_type(self, chat_type: ChatType) -> type | None: """ 获取指定聊天类型的最佳聊天处理器类 - + 优先级规则: 1. 优先选择明确匹配当前聊天类型的 Chatter(如 PRIVATE 或 GROUP) 2. 如果没有精确匹配,才使用 ALL 类型的 Chatter - + Args: chat_type: 聊天类型 - + Returns: 最佳匹配的聊天处理器类,如果没有匹配则返回 None """ @@ -77,14 +77,14 @@ class ChatterManager: if chatter_list: logger.debug(f"找到精确匹配的聊天处理器: {chatter_list[0].__name__} for {chat_type.value}") return chatter_list[0] - + # 2. 如果没有精确匹配,回退到 ALL 类型 if ChatType.ALL in self.chatter_classes: chatter_list = self.chatter_classes[ChatType.ALL] if chatter_list: logger.debug(f"使用通用聊天处理器: {chatter_list[0].__name__} for {chat_type.value}") return chatter_list[0] - + return None def get_chatter_class(self, chat_type: ChatType) -> type | None: @@ -142,7 +142,7 @@ class ChatterManager: async def process_stream_context(self, stream_id: str, context: "StreamContext") -> dict: """ 处理流上下文 - + 每个聊天流只能有一个活跃的 Chatter 组件。 选择优先级:明确指定聊天类型的 Chatter > ALL 类型的 Chatter """ @@ -154,11 +154,11 @@ class ChatterManager: # 检查是否已有该流的 Chatter 实例 stream_instance = self.instances.get(stream_id) - + if stream_instance is None: # 使用新的优先级选择逻辑获取最佳 Chatter 类 chatter_class = self.get_chatter_class_for_chat_type(chat_type) - + if not chatter_class: raise ValueError(f"No chatter registered for chat type {chat_type}") @@ -206,7 +206,7 @@ class ChatterManager: context.triggering_user_id = None context.processing_message_id = None raise - except Exception as e: # noqa: BLE001 + except Exception as e: self.stats["failed_executions"] += 1 logger.error("处理流时出错", stream_id=stream_id, error=e) context.triggering_user_id = None diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 125907a6d..679a7fbb1 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -1122,7 +1122,7 @@ class EmojiManager: if emoji_base64 is None: # 再次检查读取 logger.error(f"[注册失败] 无法读取图片以生成描述: {filename}") return False - + # 等待描述生成完成 description, emotions = await self.build_emoji_description(emoji_base64) @@ -1135,7 +1135,7 @@ class EmojiManager: except Exception as e: logger.error(f"[错误] 删除描述生成失败文件时出错: {e!s}") return False - + new_emoji.description = description new_emoji.emotion = emotions except Exception as build_desc_error: diff --git a/src/chat/energy_system/energy_manager.py b/src/chat/energy_system/energy_manager.py index 8f2fbe268..b9f21a93c 100644 --- a/src/chat/energy_system/energy_manager.py +++ b/src/chat/energy_system/energy_manager.py @@ -5,9 +5,10 @@ import time from abc import ABC, abstractmethod +from collections.abc import Awaitable from dataclasses import dataclass, field from enum import Enum -from typing import Any, Awaitable, TypedDict, cast +from typing import Any, TypedDict, cast from src.common.database.api.crud import CRUDBase from src.common.logger import get_logger diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 6343badb5..f4086573e 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -149,7 +149,7 @@ class ExpressionLearner: def get_related_chat_ids(self) -> list[str]: """根据expression.rules配置,获取与当前chat_id相关的所有chat_id(包括自身) - + 用于共享组功能:同一共享组内的聊天流可以共享学习到的表达方式 """ if global_config is None: @@ -249,7 +249,7 @@ class ExpressionLearner: try: if global_config is None: return False - use_expression, enable_learning, _ = global_config.expression.get_expression_config_for_chat(self.chat_id) + _use_expression, enable_learning, _ = global_config.expression.get_expression_config_for_chat(self.chat_id) return enable_learning except Exception as e: logger.error(f"检查学习权限失败: {e}") @@ -271,7 +271,7 @@ class ExpressionLearner: try: if global_config is None: return False - use_expression, enable_learning, learning_intensity = ( + _use_expression, enable_learning, learning_intensity = ( global_config.expression.get_expression_config_for_chat(self.chat_id) ) except Exception as e: @@ -594,7 +594,7 @@ class ExpressionLearner: from src.common.database.optimization.cache_manager import get_cache from src.common.database.utils.decorators import generate_cache_key cache = await get_cache() - + # 获取共享组内所有 chat_id 并清除其缓存 related_chat_ids = self.get_related_chat_ids() for related_id in related_chat_ids: @@ -611,7 +611,7 @@ class ExpressionLearner: # 为每个共享组内的 chat_id 训练其 StyleLearner for target_chat_id in related_chat_ids: learner = style_learner_manager.get_learner(target_chat_id) - + # 为每个学习到的表达方式训练模型 # 使用 situation 作为输入,style 作为目标 # 这是最符合语义的方式:场景 -> 表达方式 @@ -689,7 +689,7 @@ class ExpressionLearner: # 🔥 启用表达学习场景的过滤,过滤掉纯回复、纯@、纯图片等无意义内容 random_msg_str: str = await build_anonymous_messages(random_msg, filter_for_learning=True) # print(f"random_msg_str:{random_msg_str}") - + # 🔥 检查过滤后是否还有足够的内容 if not random_msg_str or len(random_msg_str.strip()) < 20: logger.debug(f"过滤后消息内容不足,跳过本次{type_str}学习") diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 83fd0c4b8..59ab4329e 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -610,7 +610,7 @@ class ExpressionSelector: # 4. 调用LLM try: # start_time = time.time() - content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt) + content, (_reasoning_content, _model_name, _) = await self.llm_model.generate_response_async(prompt=prompt) if not content: logger.warning("LLM返回空结果") diff --git a/src/chat/express/style_learner.py b/src/chat/express/style_learner.py index 722c5b0c2..3b099f3fd 100644 --- a/src/chat/express/style_learner.py +++ b/src/chat/express/style_learner.py @@ -438,7 +438,7 @@ class StyleLearner: class StyleLearnerManager: """多聊天室表达风格学习管理器 - + 添加 LRU 淘汰机制,限制最大活跃 learner 数量 """ @@ -470,7 +470,7 @@ class StyleLearnerManager: self.learner_last_used.items(), key=lambda x: x[1] ) - + evicted = [] for chat_id, last_used in sorted_by_time[:evict_count]: if chat_id in self.learners: diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index 52e71ed84..202216362 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -100,7 +100,7 @@ class BotInterestManager: if loaded_interests: self.current_interests = loaded_interests - active_count = len(loaded_interests.get_active_tags()) + active_count = len(loaded_interests.get_active_tags()) tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in loaded_interests.get_active_tags()] tags_str = "\n".join(tags_info) @@ -247,7 +247,7 @@ class BotInterestManager: async def _call_llm_for_interest_generation(self, prompt: str) -> str | None: """调用LLM生成兴趣标签 - + 注意:此方法会临时增加 API 超时时间,以确保初始化阶段的人设标签生成 不会因用户配置的较短超时而失败。 """ @@ -275,7 +275,7 @@ class BotInterestManager: # 人设标签生成需要较长时间(15-25个标签的JSON),使用更长的超时 INIT_TIMEOUT = 180 # 初始化阶段使用 180 秒超时 original_timeouts: dict[str, int] = {} - + try: # 保存并修改所有相关模型的 API provider 超时设置 for model_name in replyer_config.model_list: @@ -288,9 +288,9 @@ class BotInterestManager: provider.timeout = INIT_TIMEOUT except Exception as e: logger.warning(f"⚠️ 无法修改模型 '{model_name}' 的超时设置: {e}") - + # 调用LLM API - success, response, reasoning_content, model_name = await llm_api.generate_with_model( + success, response, _reasoning_content, model_name = await llm_api.generate_with_model( prompt=full_prompt, model_config=replyer_config, request_type="interest_generation", @@ -383,13 +383,13 @@ class BotInterestManager: # 使用LLMRequest获取embedding if not self.embedding_request: raise RuntimeError("❌ Embedding客户端未初始化") - embedding, model_name = await self.embedding_request.get_embedding(text) + embedding, _model_name = await self.embedding_request.get_embedding(text) if embedding and len(embedding) > 0: if isinstance(embedding[0], list): # If it's a list of lists, take the first one (though get_embedding(str) should return list[float]) embedding = embedding[0] - + # Now we can safely cast to list[float] as we've handled the nested list case embedding_float = cast(list[float], embedding) self.embedding_cache[text] = embedding_float @@ -447,7 +447,7 @@ class BotInterestManager: try: chunk_embeddings, _ = await self.embedding_request.get_embedding(chunk_texts) - except Exception as exc: # noqa: BLE001 + except Exception as exc: logger.error(f"批量获取embedding失败 (chunk {start // batch_size + 1}): {exc}") continue @@ -1063,7 +1063,7 @@ class BotInterestManager: # 验证缓存版本和embedding模型 cache_version = cache_data.get("version", 1) cache_embedding_model = cache_data.get("embedding_model", "") - + current_embedding_model = "" if self.embedding_config and hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list: current_embedding_model = self.embedding_config.model_list[0] diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 1751b198d..fc2d55915 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -147,7 +147,7 @@ class EmbeddingStore: """ 异步、并发地批量获取嵌入向量。 使用 chunk_size 进行批量请求,max_workers 控制并发批次数。 - + 优化策略: 1. 将字符串分成多个 chunk,每个 chunk 包含 chunk_size 个字符串 2. 使用 asyncio.Semaphore 控制同时处理的 chunk 数量 diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index 35e268c24..e9b9e5fed 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -99,36 +99,36 @@ class QAManager: # It seems kg_search expects the first element to be a tuple of strings? # But the implementation uses it as a hash key to look up in store. # Let's look at kg_manager.py again. - + # In kg_manager.py: # def kg_search(self, relation_search_result: list[tuple[tuple[str, str, str], float]], ...) # ... # for relation_hash, similarity in relation_search_result: # relation_item = embed_manager.relation_embedding_store.store.get(relation_hash) - + # Wait, I just fixed kg_manager.py to: # for relation_hash, similarity in relation_search_result: - + # So it expects a tuple of 2 elements? # But search_top_k returns (id, score, vector). # So relation_search_res is list[tuple[Any, float, float]]. - + # I need to adapt the data or cast it. # If I pass it directly, it has 3 elements. # If kg_manager expects 2, I should probably slice it. - + # Let's cast it for now to silence the error, assuming the runtime behavior is compatible (unpacking first 2 of 3 is fine in python if not strict, but here it is strict unpacking in loop?) # In kg_manager.py I changed it to: # for relation_hash, similarity in relation_search_result: # This will fail if the tuple has 3 elements! "too many values to unpack" - + # So I should probably fix the data passed to kg_search to be list[tuple[str, float]]. - + relation_search_result_for_kg = [(str(res[0]), float(res[1])) for res in relation_search_res] - + result, ppr_node_weights = self.kg_manager.kg_search( cast(list[tuple[tuple[str, str, str], float]], relation_search_result_for_kg), # The type hint in kg_manager is weird, but let's match it or cast to Any - paragraph_search_res, + paragraph_search_res, self.embed_manager ) part_end_time = time.perf_counter() diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index be774081d..ff3694901 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -11,17 +11,17 @@ import asyncio import time +from collections.abc import AsyncIterator, Awaitable, Callable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Awaitable +from typing import TYPE_CHECKING, Any from src.chat.chatter_manager import ChatterManager from src.chat.energy_system import energy_manager +from src.chat.message_receive.chat_stream import get_chat_manager from src.common.logger import get_logger from src.config.config import global_config -from src.chat.message_receive.chat_stream import get_chat_manager if TYPE_CHECKING: - from src.chat.message_receive.chat_stream import ChatStream from src.common.data_models.message_manager_data_model import StreamContext logger = get_logger("stream_loop_manager") @@ -36,7 +36,7 @@ logger = get_logger("stream_loop_manager") class ConversationTick: """ 会话事件标记 - 表示一次待处理的会话事件 - + 这是一个轻量级的事件信号,不存储消息数据。 未读消息由 StreamContext 管理,能量值由 energy_manager 管理。 """ @@ -61,10 +61,10 @@ async def conversation_loop( ) -> AsyncIterator[ConversationTick]: """ 会话循环生成器 - 按需产出 Tick 事件 - + 替代原有的无限循环任务,改为事件驱动的生成器模式。 只有调用 __anext__() 时才会执行,完全由消费者控制节奏。 - + Args: stream_id: 流ID get_context_func: 获取 StreamContext 的异步函数 @@ -72,13 +72,13 @@ async def conversation_loop( flush_cache_func: 刷新缓存消息的异步函数 check_force_dispatch_func: 检查是否需要强制分发的函数 is_running_func: 检查是否继续运行的函数 - + Yields: ConversationTick: 会话事件 """ tick_count = 0 last_interval = None - + while is_running_func(): try: # 1. 获取流上下文 @@ -87,17 +87,17 @@ async def conversation_loop( logger.warning(f" [生成器] stream={stream_id[:8]}, 无法获取流上下文") await asyncio.sleep(10.0) continue - + # 2. 刷新缓存消息到未读列表 await flush_cache_func(stream_id) - + # 3. 检查是否有消息需要处理 unread_messages = context.get_unread_messages() unread_count = len(unread_messages) if unread_messages else 0 - + # 4. 检查是否需要强制分发 force_dispatch = check_force_dispatch_func(context, unread_count) - + # 5. 如果有消息,产出 Tick if unread_count > 0 or force_dispatch: tick_count += 1 @@ -106,18 +106,18 @@ async def conversation_loop( force_dispatch=force_dispatch, tick_count=tick_count, ) - + # 6. 计算并等待下次检查间隔 has_messages = unread_count > 0 interval = await calculate_interval_func(stream_id, has_messages) - + # 只在间隔发生变化时输出日志 if last_interval is None or abs(interval - last_interval) > 0.01: logger.debug(f"[生成器] stream={stream_id[:8]}, 等待间隔: {interval:.2f}s") last_interval = interval - + await asyncio.sleep(interval) - + except asyncio.CancelledError: logger.info(f" [生成器] stream={stream_id[:8]}, 被取消") break @@ -137,16 +137,16 @@ async def run_chat_stream( ) -> None: """ 聊天流驱动器 - 消费 Tick 事件并调用 Chatter - + 替代原有的 _stream_loop_worker,结构更清晰。 - + Args: stream_id: 流ID manager: StreamLoopManager 实例 """ task_id = id(asyncio.current_task()) logger.debug(f" [驱动器] stream={stream_id[:8]}, 任务ID={task_id}, 启动") - + try: # 创建生成器 tick_generator = conversation_loop( @@ -157,7 +157,7 @@ async def run_chat_stream( check_force_dispatch_func=manager._needs_force_dispatch_for_context, is_running_func=lambda: manager.is_running, ) - + # 消费 Tick 事件 async for tick in tick_generator: try: @@ -165,7 +165,7 @@ async def run_chat_stream( context = await manager._get_stream_context(stream_id) if not context: continue - + # 并发保护:检查是否正在处理 if context.is_chatter_processing: if manager._recover_stale_chatter_state(stream_id, context): @@ -173,19 +173,19 @@ async def run_chat_stream( else: logger.debug(f" [驱动器] stream={stream_id[:8]}, Chatter正在处理,跳过此Tick") continue - + # 日志 if tick.force_dispatch: logger.info(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 强制分发") else: logger.debug(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 开始处理") - + # 更新能量值 try: await manager._update_stream_energy(stream_id, context) except Exception as e: logger.debug(f"更新能量失败: {e}") - + # 处理消息 assert global_config is not None try: @@ -196,7 +196,7 @@ async def run_chat_stream( except asyncio.TimeoutError: logger.warning(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 处理超时") success = False - + # 更新统计 manager.stats["total_process_cycles"] += 1 if success: @@ -205,13 +205,13 @@ async def run_chat_stream( else: manager.stats["total_failures"] += 1 logger.debug(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 处理失败") - + except asyncio.CancelledError: raise except Exception as e: logger.error(f" [驱动器] stream={stream_id[:8]}, 处理Tick时出错: {e}") manager.stats["total_failures"] += 1 - + except asyncio.CancelledError: logger.info(f" [驱动器] stream={stream_id[:8]}, 任务ID={task_id}, 被取消") finally: @@ -233,7 +233,7 @@ async def run_chat_stream( class StreamLoopManager: """ 流循环管理器 - 基于 Generator + Tick 的事件驱动模式 - + 管理所有聊天流的生命周期,为每个流创建独立的驱动器任务。 """ @@ -321,11 +321,11 @@ class StreamLoopManager: async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool: """ 启动指定流的驱动器任务 - + Args: stream_id: 流ID force: 是否强制启动(会先取消现有任务) - + Returns: bool: 是否成功启动 """ @@ -379,10 +379,10 @@ class StreamLoopManager: async def stop_stream_loop(self, stream_id: str) -> bool: """ 停止指定流的驱动器任务 - + Args: stream_id: 流ID - + Returns: bool: 是否成功停止 """ @@ -446,11 +446,11 @@ class StreamLoopManager: async def _process_stream_messages(self, stream_id: str, context: "StreamContext") -> bool: """ 处理流消息 - + Args: stream_id: 流ID context: 流上下文 - + Returns: bool: 是否处理成功 """ @@ -468,7 +468,7 @@ class StreamLoopManager: chatter_task = None try: start_time = time.time() - + # 检查未读消息 unread_messages = context.get_unread_messages() if not unread_messages: @@ -521,7 +521,7 @@ class StreamLoopManager: logger.warning(f"处理失败: {stream_id} - {results.get('error_message', '未知错误')}") return success - + except asyncio.CancelledError: if chatter_task and not chatter_task.done(): chatter_task.cancel() @@ -557,7 +557,7 @@ class StreamLoopManager: # 检查是否有消息提及 Bot bot_name = getattr(global_config.bot, "nickname", "") bot_aliases = getattr(global_config.bot, "alias_names", []) - mention_keywords = [bot_name] + list(bot_aliases) if bot_name else list(bot_aliases) + mention_keywords = [bot_name, *list(bot_aliases)] if bot_name else list(bot_aliases) mention_keywords = [k for k in mention_keywords if k] for msg in unread_messages: diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 657c59067..8c7bdb3e1 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any from src.chat.planner_actions.action_manager import ChatterActionManager if TYPE_CHECKING: - from src.chat.chatter_manager import ChatterManager + pass from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats from src.common.logger import get_logger @@ -94,7 +94,7 @@ class MessageManager: async def add_message(self, stream_id: str, message: DatabaseMessages): """添加消息到指定聊天流 - + 注意:Notice 消息已在 MessageHandler._handle_notice_message 中单独处理, 不再经过此方法。此方法仅处理普通消息。 """ diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index aee356156..aa6824551 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -6,8 +6,7 @@ from rich.traceback import install from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert -from src.common.data_models.database_data_model import DatabaseGroupInfo,DatabaseUserInfo -from src.common.data_models.database_data_model import DatabaseMessages +from src.common.data_models.database_data_model import DatabaseGroupInfo, DatabaseMessages, DatabaseUserInfo from src.common.database.api.crud import CRUDBase from src.common.database.compatibility import get_db_session from src.common.database.core.models import ChatStreams # 新增导入 @@ -407,7 +406,7 @@ class ChatManager: try: from src.person_info.person_info import get_person_info_manager person_info_manager = get_person_info_manager() - + # 创建一个后台任务来执行同步,不阻塞当前流程 sync_task = asyncio.create_task( person_info_manager.sync_user_info(platform, user_id, nickname, cardname) diff --git a/src/chat/message_receive/message_handler.py b/src/chat/message_receive/message_handler.py index 85ebdbf17..18ed28b9f 100644 --- a/src/chat/message_receive/message_handler.py +++ b/src/chat/message_receive/message_handler.py @@ -265,7 +265,7 @@ class MessageHandler: additional_config = message_info.get("additional_config", {}) if not isinstance(additional_config, dict): additional_config = {} - + notice_type = additional_config.get("notice_type", "unknown") is_public_notice = additional_config.get("is_public_notice", False) diff --git a/src/chat/message_receive/message_processor.py b/src/chat/message_receive/message_processor.py index 02f5597a4..8acedd3a6 100644 --- a/src/chat/message_receive/message_processor.py +++ b/src/chat/message_receive/message_processor.py @@ -8,7 +8,7 @@ from typing import Any import orjson from mofox_wire import MessageEnvelope -from mofox_wire.types import MessageInfoPayload, SegPayload, UserInfoPayload, GroupInfoPayload +from mofox_wire.types import GroupInfoPayload, MessageInfoPayload, SegPayload, UserInfoPayload from src.chat.utils.self_voice_cache import consume_self_voice_text from src.chat.utils.utils_image import get_image_manager @@ -40,7 +40,7 @@ async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: st # 提取核心数据(使用 TypedDict 类型) message_info: MessageInfoPayload = message_dict.get("message_info", {}) # type: ignore message_segment: SegPayload | list[SegPayload] = message_dict.get("message_segment", {"type": "text", "data": ""}) # type: ignore - + # 初始化处理状态 processing_state = { "is_emoji": False, @@ -154,8 +154,8 @@ async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: st async def _process_message_segments( - segment: SegPayload | list[SegPayload], - state: dict, + segment: SegPayload | list[SegPayload], + state: dict, message_info: MessageInfoPayload ) -> str: """递归处理消息段,转换为文字描述 @@ -176,12 +176,12 @@ async def _process_message_segments( if processed: segments_text.append(processed) return " ".join(segments_text) - + # 如果是单个段 if isinstance(segment, dict): seg_type = segment.get("type", "") seg_data = segment.get("data") - + # 处理 seglist 类型 if seg_type == "seglist" and isinstance(seg_data, list): segments_text = [] @@ -190,16 +190,16 @@ async def _process_message_segments( if processed: segments_text.append(processed) return " ".join(segments_text) - + # 处理其他类型 return await _process_single_segment(segment, state, message_info) - + return "" async def _process_single_segment( - segment: SegPayload, - state: dict, + segment: SegPayload, + state: dict, message_info: MessageInfoPayload ) -> str: """处理单个消息段 @@ -214,7 +214,7 @@ async def _process_single_segment( """ seg_type = segment.get("type", "") seg_data = segment.get("data") - + try: if seg_type == "text": return str(seg_data) if seg_data else "" @@ -352,9 +352,9 @@ async def _process_single_segment( def _prepare_additional_config( - message_info: MessageInfoPayload, - is_notify: bool, - is_public_notice: bool, + message_info: MessageInfoPayload, + is_notify: bool, + is_public_notice: bool, notice_type: str | None ) -> str | None: """准备 additional_config,包含 format_info 和 notice 信息 @@ -424,26 +424,26 @@ def _extract_reply_from_segment(segment: SegPayload | list[SegPayload]) -> str | if reply_id: return reply_id return None - + # 如果是字典 if isinstance(segment, dict): seg_type = segment.get("type", "") seg_data = segment.get("data") - + # 如果是 seglist,递归搜索 if seg_type == "seglist" and isinstance(seg_data, list): for sub_seg in seg_data: reply_id = _extract_reply_from_segment(sub_seg) if reply_id: return reply_id - + # 如果是 reply 段,返回 message_id elif seg_type == "reply": return str(seg_data) if seg_data else None - + except Exception as e: logger.warning(f"提取reply_to信息失败: {e}") - + return None @@ -493,10 +493,10 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> MessageInf "time": db_message.time, "user_info": user_info, } - + if group_info: message_info["group_info"] = group_info - + if additional_config: message_info["additional_config"] = additional_config diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 9aafef73c..9767476eb 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -3,7 +3,7 @@ import re import time import traceback from collections import deque -from typing import Optional, TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Optional, cast import orjson from sqlalchemy import desc, select, update @@ -16,7 +16,7 @@ from src.common.logger import get_logger if TYPE_CHECKING: from src.chat.message_receive.chat_stream import ChatStream - + logger = get_logger("message_storage") @@ -191,7 +191,7 @@ class MessageStorageBatcher: additional_config = message.additional_config key_words = MessageStorage._serialize_keywords(message.key_words) key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) - memorized_times = getattr(message, 'memorized_times', 0) + memorized_times = getattr(message, "memorized_times", 0) user_platform = message.user_info.platform if message.user_info else "" user_id = message.user_info.user_id if message.user_info else "" diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index f9da33f15..26b9241e3 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -6,9 +6,8 @@ import asyncio import traceback from typing import TYPE_CHECKING -from rich.traceback import install - from mofox_wire import MessageEnvelope +from rich.traceback import install from src.chat.message_receive.message_processor import process_message_from_dict from src.chat.message_receive.storage import MessageStorage diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 52e449173..728a426df 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -1,6 +1,6 @@ import asyncio import traceback -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any from src.chat.message_receive.chat_stream import get_chat_manager from src.common.data_models.database_data_model import DatabaseMessages @@ -19,7 +19,7 @@ logger = get_logger("action_manager") class ChatterActionManager: """ 动作管理器,用于管理和执行动作 - + 职责: - 加载和管理可用动作集 - 创建动作实例 @@ -139,7 +139,7 @@ class ChatterActionManager: ) -> Any: """ 执行单个动作 - + 所有动作逻辑都在 BaseAction.execute() 中实现 Args: diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 63423fc43..008c5acf1 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -12,10 +12,9 @@ from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.plugin_system.base.component_types import ActionInfo - if TYPE_CHECKING: - from src.common.data_models.message_manager_data_model import StreamContext from src.chat.message_receive.chat_stream import ChatStream + from src.common.data_models.message_manager_data_model import StreamContext logger = get_logger("action_manager") @@ -68,7 +67,7 @@ class ActionModifier: 2. 基于激活类型的智能动作判定,最终确定可用动作集 处理后,ActionManager 将包含最终的可用动作集,供规划器直接使用 - + Args: message_content: 消息内容 chatter_name: 当前使用的 Chatter 名称,用于过滤只允许特定 Chatter 使用的动作 @@ -108,7 +107,7 @@ class ActionModifier: for action_name in list(all_actions.keys()): if action_name in all_registered_actions: action_info = all_registered_actions[action_name] - + # 检查聊天类型限制 chat_type_allow = getattr(action_info, "chat_type_allow", ChatType.ALL) should_keep_chat_type = ( @@ -116,12 +115,12 @@ class ActionModifier: or (chat_type_allow == ChatType.GROUP and is_group_chat) or (chat_type_allow == ChatType.PRIVATE and not is_group_chat) ) - + if not should_keep_chat_type: removals_s0.append((action_name, f"不支持{'群聊' if is_group_chat else '私聊'}")) self.action_manager.remove_action_from_using(action_name) continue - + # 检查 Chatter 限制 chatter_allow = getattr(action_info, "chatter_allow", []) if chatter_allow and chatter_name: diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index dbf20b56b..f5c0b598c 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -8,9 +8,8 @@ import random import re import time import traceback -import uuid from datetime import datetime, timedelta -from typing import Any, Literal, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal from src.chat.express.expression_selector import expression_selector from src.chat.message_receive.uni_message_sender import HeartFCSender @@ -25,7 +24,7 @@ from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.prompt_params import PromptParameters from src.chat.utils.timer_calculator import Timer from src.chat.utils.utils import get_chat_type_and_target_info -from src.common.data_models.database_data_model import DatabaseMessages, DatabaseUserInfo +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config, model_config from src.individuality.individuality import get_individuality @@ -494,14 +493,12 @@ class DefaultReplyer: ) content = None - reasoning_content = None - model_name = "unknown_model" if not prompt: logger.error("Prompt 构建失败,无法生成回复。") return False, None, None try: - content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt) + content, _reasoning_content, _model_name, _ = await self.llm_generate_content(prompt) logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n") except Exception as llm_e: @@ -1252,7 +1249,7 @@ class DefaultReplyer: if action_items: if len(action_items) == 1: # 单个动作 - action_name, action_info = list(action_items.items())[0] + action_name, action_info = next(iter(action_items.items())) action_desc = action_info.description # 构建基础决策信息 diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py index bc908d728..3fbb71176 100644 --- a/src/chat/replyer/replyer_manager.py +++ b/src/chat/replyer/replyer_manager.py @@ -1,9 +1,9 @@ +from typing import TYPE_CHECKING + from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.replyer.default_generator import DefaultReplyer from src.common.logger import get_logger -from typing import TYPE_CHECKING - if TYPE_CHECKING: from src.chat.message_receive.chat_stream import ChatStream logger = get_logger("ReplyerManager") diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 79883fe7b..d48d3c08e 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -1125,7 +1125,7 @@ async def build_anonymous_messages(messages: list[dict[str, Any]], filter_for_le """ 构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。 处理 回复 和 @ 字段,将bbb映射为匿名占位符。 - + Args: messages: 消息列表 filter_for_learning: 是否为表达学习场景进行额外过滤(过滤掉纯回复、纯@、纯图片等无意义内容) @@ -1162,16 +1162,16 @@ async def build_anonymous_messages(messages: list[dict[str, Any]], filter_for_le person_map[person_id] = chr(current_char) current_char += 1 return person_map[person_id] - + def is_meaningless_content(content: str, msg: dict) -> bool: """ 判断消息内容是否无意义(用于表达学习过滤) """ if not content or not content.strip(): return True - + stripped = content.strip() - + # 检查消息标记字段 if msg.get("is_emoji", False): return True @@ -1181,32 +1181,32 @@ async def build_anonymous_messages(messages: list[dict[str, Any]], filter_for_le return True if msg.get("is_command", False): return True - + # 🔥 检查纯回复消息(只有[回复]没有其他内容) reply_pattern = r"^\s*\[回复[^\]]*\]\s*$" if re.match(reply_pattern, stripped): return True - + # 🔥 检查纯@消息(只有@xxx没有其他内容) at_pattern = r"^\s*(@[^\s]+\s*)+$" if re.match(at_pattern, stripped): return True - + # 🔥 检查纯图片消息 image_pattern = r"^\s*(\[图片\]|\[动画表情\]|\[表情\]|\[picid:[^\]]+\])\s*$" if re.match(image_pattern, stripped): return True - + # 🔥 移除回复标记、@标记、图片标记后检查是否还有实质内容 clean_content = re.sub(r"\[回复[^\]]*\]", "", stripped) clean_content = re.sub(r"@[^\s]+", "", clean_content) clean_content = re.sub(r"\[图片\]|\[动画表情\]|\[表情\]|\[picid:[^\]]+\]", "", clean_content) clean_content = clean_content.strip() - + # 如果移除后内容太短(少于2个字符),认为无意义 if len(clean_content) < 2: return True - + return False for msg in messages: @@ -1227,7 +1227,7 @@ async def build_anonymous_messages(messages: list[dict[str, Any]], filter_for_le # For anonymous messages, we just replace with a placeholder. content = re.sub(r"\[picid:([^\]]+)\]", "[图片]", content) - + # 🔥 表达学习场景:过滤无意义消息 if filter_for_learning and is_meaningless_content(content, msg): continue diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 8a76b3de7..f10f9078b 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -1082,7 +1082,7 @@ class Prompt: [新] 根据用户ID构建关系信息字符串。 """ from src.person_info.relationship_fetcher import relationship_fetcher_manager - + person_info_manager = get_person_info_manager() person_id = person_info_manager.get_person_id(platform, user_id) @@ -1091,11 +1091,11 @@ class Prompt: return f"你似乎还不认识这位用户(ID: {user_id}),这是你们的第一次互动。" relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id) - + # 并行构建用户信息和聊天流印象 user_relation_info_task = relationship_fetcher.build_relation_info(person_id, points_num=5) stream_impression_task = relationship_fetcher.build_chat_stream_impression(chat_id) - + user_relation_info, stream_impression = await asyncio.gather( user_relation_info_task, stream_impression_task ) diff --git a/src/chat/utils/prompt_component_manager.py b/src/chat/utils/prompt_component_manager.py index a0c77436a..a56978270 100644 --- a/src/chat/utils/prompt_component_manager.py +++ b/src/chat/utils/prompt_component_manager.py @@ -524,7 +524,7 @@ class PromptComponentManager: is_built_in=False, ) # 从动态规则中收集并关联其所有注入规则 - for target, rules_in_target in self._dynamic_rules.items(): + for rules_in_target in self._dynamic_rules.values(): if name in rules_in_target: rule, _, _ = rules_in_target[name] dynamic_info.injection_rules.append(rule) diff --git a/src/chat/utils/report_generator.py b/src/chat/utils/report_generator.py index 874451efc..41618120e 100644 --- a/src/chat/utils/report_generator.py +++ b/src/chat/utils/report_generator.py @@ -136,7 +136,7 @@ class HTMLReportGenerator: for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items()) ] ) - + # 先计算基础数据 total_tokens = sum(stat_data.get(TOTAL_TOK_BY_MODEL, {}).values()) total_requests = stat_data.get(TOTAL_REQ_CNT, 0) @@ -144,21 +144,21 @@ class HTMLReportGenerator: total_messages = stat_data.get(TOTAL_MSG_CNT, 0) online_seconds = stat_data.get(ONLINE_TIME, 0) online_hours = online_seconds / 3600 if online_seconds > 0 else 0 - + # 大模型相关效率指标 - avg_cost_per_req = (total_cost / total_requests) if total_requests > 0 else 0 + (total_cost / total_requests) if total_requests > 0 else 0 avg_cost_per_msg = (total_cost / total_messages) if total_messages > 0 else 0 avg_tokens_per_msg = (total_tokens / total_messages) if total_messages > 0 else 0 avg_tokens_per_req = (total_tokens / total_requests) if total_requests > 0 else 0 msg_to_req_ratio = (total_messages / total_requests) if total_requests > 0 else 0 cost_per_hour = (total_cost / online_hours) if online_hours > 0 else 0 req_per_hour = (total_requests / online_hours) if online_hours > 0 else 0 - + # Token效率 (输出/输入比率) total_in_tokens = sum(stat_data.get(IN_TOK_BY_MODEL, {}).values()) total_out_tokens = sum(stat_data.get(OUT_TOK_BY_MODEL, {}).values()) token_efficiency = (total_out_tokens / total_in_tokens) if total_in_tokens > 0 else 0 - + # 生成效率指标表格数据 efficiency_data = [ ("💸 平均每条消息成本", f"{avg_cost_per_msg:.6f} ¥", "处理每条用户消息的平均AI成本"), @@ -172,14 +172,14 @@ class HTMLReportGenerator: ("📈 Token/在线小时", f"{(total_tokens / online_hours) if online_hours > 0 else 0:.0f}", "每在线小时处理的Token数"), ("💬 消息/在线小时", f"{(total_messages / online_hours) if online_hours > 0 else 0:.1f}", "每在线小时处理的消息数"), ] - + efficiency_rows = "\n".join( [ f"{metric}{value}{desc}" for metric, value, desc in efficiency_data ] ) - + # 计算活跃聊天数和最活跃聊天 msg_by_chat = stat_data.get(MSG_CNT_BY_CHAT, {}) active_chats = len(msg_by_chat) @@ -189,9 +189,9 @@ class HTMLReportGenerator: most_active_chat = self.name_mapping.get(most_active_id, (most_active_id, 0))[0] most_active_count = msg_by_chat[most_active_id] most_active_chat = f"{most_active_chat} ({most_active_count}条)" - + avg_msg_per_chat = (total_messages / active_chats) if active_chats > 0 else 0 - + summary_cards = f"""
@@ -350,8 +350,8 @@ class HTMLReportGenerator: generation_time=now.strftime("%Y-%m-%d %H:%M:%S"), tab_list="\n".join(tab_list_html), tab_content="\n".join(tab_content_html_list), - all_chart_data=json.dumps(chart_data, separators=(',', ':'), ensure_ascii=False), - static_chart_data=json.dumps(static_chart_data, separators=(',', ':'), ensure_ascii=False), + all_chart_data=json.dumps(chart_data, separators=(",", ":"), ensure_ascii=False), + static_chart_data=json.dumps(static_chart_data, separators=(",", ":"), ensure_ascii=False), report_css=report_css, report_js=report_js, ) diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 5edfe219e..5d348d305 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -3,8 +3,8 @@ from collections import defaultdict from datetime import datetime, timedelta from typing import Any -from src.common.database.compatibility import db_get, db_query from src.common.database.api.query import QueryBuilder +from src.common.database.compatibility import db_get, db_query from src.common.database.core.models import LLMUsage, Messages, OnlineTime from src.common.logger import get_logger from src.manager.async_task_manager import AsyncTask @@ -322,21 +322,21 @@ class StatisticOutputTask(AsyncTask): # 以最早的时间戳为起始时间获取记录 # 🔧 内存优化:使用分批查询代替全量加载 query_start_time = collect_period[-1][1] - + query_builder = ( QueryBuilder(LLMUsage) .no_cache() .filter(timestamp__gte=query_start_time) .order_by("-timestamp") ) - + total_processed = 0 async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True): for record in batch: if total_processed >= STAT_MAX_RECORDS: logger.warning(f"统计处理记录数达到上限 {STAT_MAX_RECORDS},跳过剩余记录") break - + if not isinstance(record, dict): continue @@ -407,11 +407,11 @@ class StatisticOutputTask(AsyncTask): total_processed += 1 if total_processed % 500 == 0: await StatisticOutputTask._yield_control(total_processed, interval=1) - + # 检查是否达到上限 if total_processed >= STAT_MAX_RECORDS: break - + # 每批处理完后让出控制权 await asyncio.sleep(0) # -- 计算派生指标 -- @@ -503,7 +503,7 @@ class StatisticOutputTask(AsyncTask): "labels": [item[0] for item in sorted_models], "data": [round(item[1], 4) for item in sorted_models], } - + # 1. Token输入输出对比条形图 model_names = list(period_stats[REQ_CNT_BY_MODEL].keys()) if model_names: @@ -512,7 +512,7 @@ class StatisticOutputTask(AsyncTask): "input_tokens": [period_stats[IN_TOK_BY_MODEL].get(m, 0) for m in model_names], "output_tokens": [period_stats[OUT_TOK_BY_MODEL].get(m, 0) for m in model_names], } - + # 2. 响应时间分布散点图数据(限制数据点以提高加载速度) scatter_data = [] max_points_per_model = 50 # 每个模型最多50个点 @@ -523,7 +523,7 @@ class StatisticOutputTask(AsyncTask): sampled_costs = time_costs[::step][:max_points_per_model] else: sampled_costs = time_costs - + for idx, time_cost in enumerate(sampled_costs): scatter_data.append({ "model": model_name, @@ -532,7 +532,7 @@ class StatisticOutputTask(AsyncTask): "tokens": period_stats[TOTAL_TOK_BY_MODEL].get(model_name, 0) // len(time_costs) if time_costs else 0 }) period_stats[SCATTER_CHART_RESPONSE_TIME] = scatter_data - + # 3. 模型效率雷达图 if model_names: # 取前5个最常用的模型 @@ -545,14 +545,14 @@ class StatisticOutputTask(AsyncTask): avg_time = period_stats[AVG_TIME_COST_BY_MODEL].get(model_name, 0) cost_per_ktok = period_stats[COST_PER_KTOK_BY_MODEL].get(model_name, 0) avg_tokens = period_stats[AVG_TOK_BY_MODEL].get(model_name, 0) - + # 简单的归一化(反向归一化时间和成本,值越小越好) max_req = max([period_stats[REQ_CNT_BY_MODEL].get(m[0], 1) for m in top_models]) max_tps = max([period_stats[TPS_BY_MODEL].get(m[0], 1) for m in top_models]) max_time = max([period_stats[AVG_TIME_COST_BY_MODEL].get(m[0], 0.1) for m in top_models]) max_cost = max([period_stats[COST_PER_KTOK_BY_MODEL].get(m[0], 0.001) for m in top_models]) max_tokens = max([period_stats[AVG_TOK_BY_MODEL].get(m[0], 1) for m in top_models]) - + radar_data.append({ "model": model_name, "metrics": [ @@ -567,7 +567,7 @@ class StatisticOutputTask(AsyncTask): "labels": ["请求量", "TPS", "响应速度", "成本效益", "Token容量"], "datasets": radar_data } - + # 4. 供应商请求占比环形图 provider_requests = period_stats[REQ_CNT_BY_PROVIDER] if provider_requests: @@ -576,7 +576,7 @@ class StatisticOutputTask(AsyncTask): "labels": [item[0] for item in sorted_provider_reqs], "data": [item[1] for item in sorted_provider_reqs], } - + # 5. 平均响应时间条形图 if model_names: sorted_by_time = sorted( @@ -649,7 +649,7 @@ class StatisticOutputTask(AsyncTask): if overlap_end > overlap_start: stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds() break - + # 每批处理完后让出控制权 await asyncio.sleep(0) @@ -689,7 +689,7 @@ class StatisticOutputTask(AsyncTask): if total_processed >= STAT_MAX_RECORDS: logger.warning(f"消息统计处理记录数达到上限 {STAT_MAX_RECORDS},跳过剩余记录") break - + if not isinstance(message, dict): continue message_time_ts = message.get("time") # This is a float timestamp @@ -732,11 +732,11 @@ class StatisticOutputTask(AsyncTask): total_processed += 1 if total_processed % 500 == 0: await StatisticOutputTask._yield_control(total_processed, interval=1) - + # 检查是否达到上限 if total_processed >= STAT_MAX_RECORDS: break - + # 每批处理完后让出控制权 await asyncio.sleep(0) @@ -845,10 +845,10 @@ class StatisticOutputTask(AsyncTask): def _compress_time_cost_lists(self, data: dict[str, Any]) -> dict[str, Any]: """🔧 内存优化:将 TIME_COST_BY_* 的 list 压缩为聚合数据 - + 原始格式: {"model_a": [1.2, 2.3, 3.4, ...]} (可能无限增长) 压缩格式: {"model_a": {"sum": 6.9, "count": 3, "sum_sq": 18.29}} - + 这样合并时只需要累加 sum/count/sum_sq,不会无限增长。 avg = sum / count std = sqrt(sum_sq / count - (sum / count)^2) @@ -858,17 +858,17 @@ class StatisticOutputTask(AsyncTask): TIME_COST_BY_TYPE, TIME_COST_BY_USER, TIME_COST_BY_MODEL, TIME_COST_BY_MODULE, TIME_COST_BY_PROVIDER ] - + result = dict(data) # 浅拷贝 - + for key in time_cost_keys: if key not in result: continue - + original = result[key] if not isinstance(original, dict): continue - + compressed = {} for sub_key, values in original.items(): if isinstance(values, list): @@ -886,9 +886,9 @@ class StatisticOutputTask(AsyncTask): else: # 未知格式,保留原值 compressed[sub_key] = values - + result[key] = compressed - + return result def _convert_defaultdict_to_dict(self, data): @@ -1008,7 +1008,7 @@ class StatisticOutputTask(AsyncTask): .filter(timestamp__gte=start_time) .order_by("-timestamp") ) - + async for batch in llm_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True): for record in batch: if not isinstance(record, dict) or not record.get("timestamp"): @@ -1033,7 +1033,7 @@ class StatisticOutputTask(AsyncTask): if module_name not in cost_by_module: cost_by_module[module_name] = [0.0] * len(time_points) cost_by_module[module_name][idx] += cost - + await asyncio.sleep(0) # 🔧 内存优化:使用分批查询 Messages @@ -1043,7 +1043,7 @@ class StatisticOutputTask(AsyncTask): .filter(time__gte=start_time.timestamp()) .order_by("-time") ) - + async for batch in msg_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True): for msg in batch: if not isinstance(msg, dict) or not msg.get("time"): @@ -1063,7 +1063,7 @@ class StatisticOutputTask(AsyncTask): if chat_name not in message_by_chat: message_by_chat[chat_name] = [0] * len(time_points) message_by_chat[chat_name][idx] += 1 - + await asyncio.sleep(0) return { diff --git a/src/chat/utils/typo_generator.py b/src/chat/utils/typo_generator.py index c45088c33..8bce2edb5 100644 --- a/src/chat/utils/typo_generator.py +++ b/src/chat/utils/typo_generator.py @@ -36,21 +36,21 @@ def get_typo_generator( ) -> "ChineseTypoGenerator": """ 获取错别字生成器单例(内存优化) - + 如果参数与缓存的单例不同,会更新参数但复用拼音字典和字频数据。 - + 参数: error_rate: 单字替换概率 min_freq: 最小字频阈值 tone_error_rate: 声调错误概率 word_replace_rate: 整词替换概率 max_freq_diff: 最大允许的频率差异 - + 返回: ChineseTypoGenerator 实例 """ global _typo_generator_singleton - + with _singleton_lock: if _typo_generator_singleton is None: _typo_generator_singleton = ChineseTypoGenerator( @@ -70,7 +70,7 @@ def get_typo_generator( word_replace_rate=word_replace_rate, max_freq_diff=max_freq_diff, ) - + return _typo_generator_singleton @@ -87,7 +87,7 @@ class ChineseTypoGenerator: max_freq_diff: 最大允许的频率差异 """ global _shared_pinyin_dict, _shared_char_frequency - + self.error_rate = error_rate self.min_freq = min_freq self.tone_error_rate = tone_error_rate @@ -99,7 +99,7 @@ class ChineseTypoGenerator: _shared_pinyin_dict = self._create_pinyin_dict() logger.debug("拼音字典已创建并缓存") self.pinyin_dict = _shared_pinyin_dict - + if _shared_char_frequency is None: _shared_char_frequency = self._load_or_create_char_frequency() logger.debug("字频数据已加载并缓存") @@ -454,10 +454,10 @@ class ChineseTypoGenerator: # 50%概率返回纠正建议 if random.random() < 0.5: if word_typos: - wrong_word, correct_word = random.choice(word_typos) + _wrong_word, correct_word = random.choice(word_typos) correction_suggestion = correct_word elif char_typos: - wrong_char, correct_char = random.choice(char_typos) + _wrong_char, correct_char = random.choice(char_typos) correction_suggestion = correct_char return "".join(result), correction_suggestion diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 2226b7946..c4675078f 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -9,13 +9,15 @@ from typing import Any import numpy as np import rjieba +from src.common.data_models.database_data_model import DatabaseUserInfo + # MessageRecv 已被移除,现在使用 DatabaseMessages from src.common.logger import get_logger from src.common.message_repository import count_messages, find_messages from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import PersonInfoManager, get_person_info_manager -from src.common.data_models.database_data_model import DatabaseUserInfo + from .typo_generator import get_typo_generator logger = get_logger("chat_utils") diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index f51f18b29..96fa45b75 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -189,7 +189,7 @@ class ImageManager: # 4. 如果都未命中,则调用新逻辑生成描述 logger.info(f"[新表情识别] 表情包未注册且无缓存 (Hash: {image_hash[:8]}...),调用新逻辑生成描述") - full_description, emotions = await emoji_manager.build_emoji_description(image_base64) + full_description, _emotions = await emoji_manager.build_emoji_description(image_base64) if not full_description: logger.warning("未能通过新逻辑生成有效描述") diff --git a/src/chat/utils/utils_video_legacy.py b/src/chat/utils/utils_video_legacy.py index 5bf3b769d..8227c870c 100644 --- a/src/chat/utils/utils_video_legacy.py +++ b/src/chat/utils/utils_video_legacy.py @@ -11,7 +11,6 @@ import io import os from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Any import cv2 import numpy as np diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index f1abfe9d9..baeb2a69a 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -583,56 +583,56 @@ class CacheManager: ) -> list[dict[str, Any]]: """ 根据语义相似度主动召回相关的缓存条目 - + 用于在回复前扫描缓存,找到与当前对话相关的历史搜索结果 - + Args: query_text: 用于语义匹配的查询文本(通常是最近几条聊天内容) tool_name: 可选,限制只召回特定工具的缓存(如 "web_search") top_k: 返回的最大结果数 similarity_threshold: 相似度阈值(L2距离,越小越相似) - + Returns: 相关缓存条目列表,每个条目包含 {tool_name, query, content, similarity} """ if not query_text or not self.embedding_model: return [] - + try: # 生成查询向量 embedding_result = await self.embedding_model.get_embedding(query_text) if not embedding_result: return [] - + embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result validated_embedding = self._validate_embedding(embedding_vector) if validated_embedding is None: return [] - + query_embedding = np.array([validated_embedding], dtype="float32") - + # 从 L2 向量数据库查询 results = vector_db_service.query( collection_name=self.semantic_cache_collection_name, query_embeddings=query_embedding.tolist(), n_results=top_k * 2, # 多取一些,后面会过滤 ) - + if not results or not results.get("ids") or not results["ids"][0]: logger.debug("[缓存召回] 未找到相关缓存") return [] - + recalled_items = [] ids = results["ids"][0] if isinstance(results["ids"][0], list) else [results["ids"][0]] distances = results.get("distances", [[]])[0] if results.get("distances") else [] - + for i, cache_key in enumerate(ids): distance = distances[i] if i < len(distances) else 1.0 - + # 过滤相似度不够的 if distance > similarity_threshold: continue - + # 从数据库获取缓存数据 cache_obj = await db_query( model_class=CacheEntries, @@ -640,26 +640,26 @@ class CacheManager: filters={"cache_key": cache_key}, single_result=True, ) - + if not cache_obj: continue - + # 检查是否过期 expires_at = getattr(cache_obj, "expires_at", 0) if time.time() >= expires_at: continue - + # 获取工具名称并过滤 cached_tool_name = getattr(cache_obj, "tool_name", "") if tool_name and cached_tool_name != tool_name: continue - + # 解析缓存内容 try: cache_value = getattr(cache_obj, "cache_value", "{}") data = orjson.loads(cache_value) content = data.get("content", "") if isinstance(data, dict) else str(data) - + # 从 cache_key 中提取原始查询(格式: tool_name::{"query": "xxx", ...}::file_hash) original_query = "" try: @@ -670,26 +670,26 @@ class CacheManager: original_query = args.get("query", "") except Exception: pass - + recalled_items.append({ "tool_name": cached_tool_name, "query": original_query, "content": content, "similarity": 1.0 - distance, # 转换为相似度分数 }) - + except Exception as e: logger.warning(f"解析缓存内容失败: {e}") continue - + if len(recalled_items) >= top_k: break - + if recalled_items: logger.info(f"[缓存召回] 找到 {len(recalled_items)} 条相关缓存") - + return recalled_items - + except Exception as e: logger.error(f"[缓存召回] 语义召回失败: {e}") return [] diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 6517294fd..92a100f42 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -89,44 +89,44 @@ class DatabaseMessages(BaseDataModel): """ __slots__ = ( - # 基础消息字段 - "message_id", - "time", - "chat_id", - "reply_to", - "interest_value", - "key_words", - "key_words_lite", - "is_mentioned", - "is_at", - "reply_probability_boost", - "processed_plain_text", - "display_message", - "priority_mode", - "priority_info", - "additional_config", - "is_emoji", - "is_picid", - "is_command", - "is_notify", - "is_public_notice", - "notice_type", - "selected_expressions", - "is_read", "actions", - "should_reply", - "should_act", - # 关联对象 - "user_info", - "group_info", + "additional_config", + "chat_id", "chat_info", - # 运行时扩展字段(固定) - "semantic_embedding", - "interest_calculated", - "is_voice", - "is_video", + "display_message", + "group_info", "has_emoji", "has_picid", + "interest_calculated", + "interest_value", + "is_at", + "is_command", + "is_emoji", + "is_mentioned", + "is_notify", + "is_picid", + "is_public_notice", + "is_read", + "is_video", + "is_voice", + "key_words", + "key_words_lite", + # 基础消息字段 + "message_id", + "notice_type", + "priority_info", + "priority_mode", + "processed_plain_text", + "reply_probability_boost", + "reply_to", + "selected_expressions", + # 运行时扩展字段(固定) + "semantic_embedding", + "should_act", + "should_reply", + "time", + # 关联对象 + "user_info", ) def __init__( @@ -405,16 +405,16 @@ class DatabaseActionRecords(BaseDataModel): """ __slots__ = ( - "action_id", - "time", - "action_name", + "action_build_into_prompt", "action_data", "action_done", - "action_build_into_prompt", + "action_id", + "action_name", "action_prompt_display", "chat_id", - "chat_info_stream_id", "chat_info_platform", + "chat_info_stream_id", + "time", ) def __init__( diff --git a/src/common/database/api/crud.py b/src/common/database/api/crud.py index 651c09099..57e67d5bd 100644 --- a/src/common/database/api/crud.py +++ b/src/common/database/api/crud.py @@ -12,9 +12,7 @@ from functools import lru_cache from typing import Any, Generic, TypeVar from sqlalchemy import delete, func, select, update -from sqlalchemy.engine import CursorResult, Result -from src.common.database.core.models import Base from src.common.database.core.session import get_db_session from src.common.database.optimization import ( BatchOperation, diff --git a/src/common/database/api/query.py b/src/common/database/api/query.py index dc8a2e6c4..5c82609e0 100644 --- a/src/common/database/api/query.py +++ b/src/common/database/api/query.py @@ -15,7 +15,6 @@ from sqlalchemy import and_, asc, desc, func, or_, select # 导入 CRUD 辅助函数以避免重复定义 from src.common.database.api.crud import _dict_to_model, _model_to_dict -from src.common.database.core.models import Base from src.common.database.core.session import get_db_session from src.common.database.optimization import get_cache from src.common.logger import get_logger diff --git a/src/common/database/api/specialized.py b/src/common/database/api/specialized.py index 5625c3795..2a735ff4b 100644 --- a/src/common/database/api/specialized.py +++ b/src/common/database/api/specialized.py @@ -91,7 +91,7 @@ async def store_action_info( ) # 使用get_or_create保存记录 - saved_record, created = await _action_records_crud.get_or_create( + saved_record, _created = await _action_records_crud.get_or_create( defaults=record_data, action_id=action_id, ) @@ -438,7 +438,7 @@ async def update_relationship_affinity( """ try: # 获取或创建关系 - relationship, created = await _user_relationships_crud.get_or_create( + relationship, _created = await _user_relationships_crud.get_or_create( defaults={"affinity": 0.0, "interaction_count": 0}, platform=platform, user_id=user_id, diff --git a/src/common/database/compatibility/adapter.py b/src/common/database/compatibility/adapter.py index 783998fc7..e65b5a76e 100644 --- a/src/common/database/compatibility/adapter.py +++ b/src/common/database/compatibility/adapter.py @@ -300,7 +300,7 @@ async def db_save( crud = CRUDBase(model_class) # 使用get_or_create (返回tuple[T, bool]) - instance, created = await crud.get_or_create( + instance, _created = await crud.get_or_create( defaults=data, **{key_field: key_value}, ) diff --git a/src/common/database/core/migration.py b/src/common/database/core/migration.py index fff355fae..49f45d471 100644 --- a/src/common/database/core/migration.py +++ b/src/common/database/core/migration.py @@ -100,14 +100,14 @@ async def check_and_migrate_database(existing_engine=None): def add_columns_sync(conn): dialect = conn.dialect - + for column_name in missing_columns: column = table.c[column_name] - + # 获取列类型的 SQL 表示 # 直接使用 compile 方法,它会自动选择正确的方言 column_type_sql = column.type.compile(dialect=dialect) - + # 构建 ALTER TABLE 语句 sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type_sql}" @@ -285,7 +285,7 @@ def _normalize_pg_type(type_name: str) -> str: async def _check_and_fix_column_types(connection, inspector, table_name, table, db_columns_info): """检查并修复列类型不匹配的问题(仅 PostgreSQL) - + Args: connection: 数据库连接 inspector: SQLAlchemy inspector @@ -296,41 +296,41 @@ async def _check_and_fix_column_types(connection, inspector, table_name, table, # 获取数据库方言 def get_dialect_name(conn): return conn.dialect.name - + dialect_name = await connection.run_sync(get_dialect_name) - + # 目前只处理 PostgreSQL if dialect_name != "postgresql": return - + for (fix_table, fix_column), (expected_type_category, using_clause) in _COLUMN_TYPE_FIXES.items(): if fix_table != table_name: continue - + if fix_column not in db_columns_info: continue - + col_info = db_columns_info[fix_column] current_type = _normalize_pg_type(str(col_info.get("type", ""))) expected_type = _get_expected_pg_type(expected_type_category) - + # 如果类型已经正确,跳过 if current_type == expected_type: continue - + # 检查是否需要修复:如果当前是 numeric 但期望是 boolean if current_type == "numeric" and expected_type == "boolean": logger.warning( f"发现列类型不匹配: {table_name}.{fix_column} " f"(当前: {current_type}, 期望: {expected_type})" ) - + # PostgreSQL 需要先删除默认值,再修改类型,最后重新设置默认值 using_sql = using_clause.format(column=fix_column) drop_default_sql = f"ALTER TABLE {table_name} ALTER COLUMN {fix_column} DROP DEFAULT" alter_type_sql = f"ALTER TABLE {table_name} ALTER COLUMN {fix_column} TYPE BOOLEAN {using_sql}" set_default_sql = f"ALTER TABLE {table_name} ALTER COLUMN {fix_column} SET DEFAULT FALSE" - + try: def execute_alter(conn): # 步骤 1: 删除默认值 @@ -342,7 +342,7 @@ async def _check_and_fix_column_types(connection, inspector, table_name, table, conn.execute(text(alter_type_sql)) # 步骤 3: 重新设置默认值 conn.execute(text(set_default_sql)) - + await connection.run_sync(execute_alter) await connection.commit() logger.info(f"成功修复列类型: {table_name}.{fix_column} -> BOOLEAN") diff --git a/src/common/database/core/models.py b/src/common/database/core/models.py index 6125b4b02..13fff5272 100644 --- a/src/common/database/core/models.py +++ b/src/common/database/core/models.py @@ -651,7 +651,7 @@ class UserPermissions(Base): class UserRelationships(Base): """用户关系模型 - 存储用户与bot的关系数据 - + 核心字段: - relationship_text: 当前印象描述(用于兼容旧系统,逐步迁移到 impression_text) - impression_text: 长期印象(新字段,自然叙事风格) @@ -667,19 +667,19 @@ class UserRelationships(Base): user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, unique=True, index=True) user_name: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True) user_aliases: Mapped[str | None] = mapped_column(Text, nullable=True) # 用户别名,逗号分隔 - + # 印象相关(新旧兼容) relationship_text: Mapped[str | None] = mapped_column(Text, nullable=True) # 旧字段,保持兼容 impression_text: Mapped[str | None] = mapped_column(Text, nullable=True) # 新字段:长期印象(自然叙事) - + # 用户信息 preference_keywords: Mapped[str | None] = mapped_column(Text, nullable=True) # 用户偏好关键词,逗号分隔 key_facts: Mapped[str | None] = mapped_column(Text, nullable=True) # 关键信息JSON(生日、职业等) - + # 关系状态 relationship_score: Mapped[float] = mapped_column(Float, nullable=False, default=0.3) # 好感度(0-1) relationship_stage: Mapped[str | None] = mapped_column(get_string_field(50), nullable=True, default="stranger") # 关系阶段 - + # 时间记录 first_met_time: Mapped[float | None] = mapped_column(Float, nullable=True) # 首次认识时间戳 last_impression_update: Mapped[float | None] = mapped_column(Float, nullable=True) # 上次更新印象时间 diff --git a/src/common/database/optimization/batch_scheduler.py b/src/common/database/optimization/batch_scheduler.py index 8a385ddc3..98dc36594 100644 --- a/src/common/database/optimization/batch_scheduler.py +++ b/src/common/database/optimization/batch_scheduler.py @@ -378,7 +378,7 @@ class AdaptiveBatchScheduler: # 过滤掉 id 为 None 的键,让数据库自动生成主键 filtered_data = {k: v for k, v in op.data.items() if not (k == "id" and v is None)} all_data.append(filtered_data) - + if not all_data: return diff --git a/src/common/database/optimization/cache_manager.py b/src/common/database/optimization/cache_manager.py index 30ad34bda..57ce4650c 100644 --- a/src/common/database/optimization/cache_manager.py +++ b/src/common/database/optimization/cache_manager.py @@ -440,8 +440,8 @@ class MultiLevelCache: # 计算共享键和独占键 shared_keys = l1_keys & l2_keys - l1_only_keys = l1_keys - l2_keys - l2_only_keys = l2_keys - l1_keys + l1_keys - l2_keys + l2_keys - l1_keys # 🔧 修复:并行计算内存使用,避免锁嵌套 l1_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l1_cache, l1_keys)) diff --git a/src/common/database/utils/decorators.py b/src/common/database/utils/decorators.py index dbc5ff89a..e468daf32 100644 --- a/src/common/database/utils/decorators.py +++ b/src/common/database/utils/decorators.py @@ -10,7 +10,7 @@ import asyncio import functools import hashlib import time -from collections.abc import Awaitable, Callable, Coroutine +from collections.abc import Callable, Coroutine from typing import Any, ParamSpec, TypeVar from sqlalchemy.exc import DBAPIError, OperationalError diff --git a/src/common/logger.py b/src/common/logger.py index dd3425797..0f696c419 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -1,15 +1,15 @@ # 使用基于时间戳的文件处理器,简单的轮转份数限制 import logging -from logging.handlers import QueueHandler, QueueListener import tarfile import threading import time from collections.abc import Callable, Sequence from datetime import datetime, timedelta +from logging.handlers import QueueHandler, QueueListener from pathlib import Path - from queue import SimpleQueue + import orjson import structlog import tomlkit @@ -17,6 +17,7 @@ from rich.console import Console from rich.text import Text from structlog.typing import EventDict, WrappedLogger + # 守护线程版本的队列监听器,防止退出时卡住 class DaemonQueueListener(QueueListener): """QueueListener 的工作线程作为守护进程运行,以避免阻塞关闭。""" diff --git a/src/common/mem_monitor.py b/src/common/mem_monitor.py index 1887cbc89..0c2babb4d 100644 --- a/src/common/mem_monitor.py +++ b/src/common/mem_monitor.py @@ -13,9 +13,7 @@ """ import logging -import os import threading -import time import tracemalloc from datetime import datetime from logging.handlers import RotatingFileHandler @@ -50,14 +48,14 @@ def _setup_mem_logger() -> logging.Logger: logger = logging.getLogger("mem_monitor") logger.setLevel(logging.DEBUG) logger.propagate = False # 不传播到父日志器,避免污染主日志 - + # 清除已有的处理器 logger.handlers.clear() - + # 创建日志目录 log_dir = Path("logs") log_dir.mkdir(exist_ok=True) - + # 文件处理器 - 带日期的日志文件,支持轮转 log_file = log_dir / f"mem_monitor_{datetime.now().strftime('%Y%m%d')}.log" file_handler = RotatingFileHandler( @@ -67,22 +65,22 @@ def _setup_mem_logger() -> logging.Logger: encoding="utf-8", ) file_handler.setLevel(logging.DEBUG) - + # 格式化器 formatter = logging.Formatter( "%(asctime)s | %(levelname)-7s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) file_handler.setFormatter(formatter) - + # 控制台处理器 - 只输出重要信息 console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) console_handler.setFormatter(formatter) - + logger.addHandler(file_handler) logger.addHandler(console_handler) - + return logger @@ -177,22 +175,22 @@ def log_object_growth(limit: int = 20) -> None: if not OBJGRAPH_AVAILABLE or objgraph is None: logger.warning("objgraph not available, skipping object growth analysis") return - + logger.info("==== Objgraph growth (top %s) ====", limit) try: # objgraph.show_growth 默认输出到 stdout,需要捕获输出 import io import sys - + # 捕获 stdout old_stdout = sys.stdout sys.stdout = buffer = io.StringIO() - + try: objgraph.show_growth(limit=limit) finally: sys.stdout = old_stdout - + output = buffer.getvalue() if output.strip(): for line in output.strip().split("\n"): @@ -206,21 +204,21 @@ def log_object_growth(limit: int = 20) -> None: def log_type_memory_diff() -> None: """使用 Pympler 查看各类型对象占用的内存变化""" global _last_type_summary - + if not PYMPLER_AVAILABLE or muppy is None or summary is None: logger.warning("pympler not available, skipping type memory analysis") return - + import io import sys - + all_objects = muppy.get_objects() curr = summary.summarize(all_objects) # 捕获 Pympler 的输出(summary.print_ 也是输出到 stdout) old_stdout = sys.stdout sys.stdout = buffer = io.StringIO() - + try: if _last_type_summary is None: logger.info("==== Pympler initial type summary ====") @@ -370,7 +368,7 @@ def debug_leak_for_type(type_name: str, max_depth: int = 5, filename: str | None if not OBJGRAPH_AVAILABLE or objgraph is None: logger.warning("objgraph not available, cannot generate backrefs graph") return False - + if filename is None: filename = f"{type_name}_backrefs.png" diff --git a/src/common/security.py b/src/common/security.py index 104e1cf94..985010262 100644 --- a/src/common/security.py +++ b/src/common/security.py @@ -35,4 +35,4 @@ async def get_api_key(api_key: str = Security(api_key_header_auth)) -> str: # 创建一个可重用的依赖项,供插件开发者在其需要验证的端点上使用 # 用法: @router.get("/protected_route", dependencies=[VerifiedDep]) # 或者: async def my_endpoint(_=VerifiedDep): ... -VerifiedDep = Depends(get_api_key) \ No newline at end of file +VerifiedDep = Depends(get_api_key) diff --git a/src/common/server.py b/src/common/server.py index 6feb72731..ebdec2be6 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -4,13 +4,12 @@ import socket from fastapi import APIRouter, FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware from rich.traceback import install -from uvicorn import Config -from uvicorn import Server as UvicornServer - from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.errors import RateLimitExceeded from slowapi.middleware import SlowAPIMiddleware from slowapi.util import get_remote_address +from uvicorn import Config +from uvicorn import Server as UvicornServer from src.common.logger import get_logger from src.config.config import global_config as bot_config diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index ce30a5b63..33fe8fa77 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -148,7 +148,7 @@ class ModelTaskConfig(ValidatedConfigBase): relationship_tracker: TaskConfig = Field(..., description="关系追踪模型配置") # 处理配置文件中命名不一致的问题 utils_video: TaskConfig = Field(..., description="视频分析模型配置(兼容配置文件中的命名)") - + # 记忆系统专用模型配置 memory_short_term_builder: TaskConfig = Field(..., description="短期记忆构建模型配置(感知→短期格式化)") memory_short_term_decider: TaskConfig = Field(..., description="短期记忆决策模型配置(合并/更新/新建/丢弃)") diff --git a/src/config/config.py b/src/config/config.py index 5df56712e..fb1ae004c 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -1,4 +1,4 @@ -import os +import os import shutil import sys from datetime import datetime @@ -27,8 +27,8 @@ from src.config.official_configs import ( ExpressionConfig, KokoroFlowChatterConfig, LPMMKnowledgeConfig, - MessageBusConfig, MemoryConfig, + MessageBusConfig, MessageReceiveConfig, MoodConfig, NoticeConfig, @@ -581,4 +581,4 @@ def initialize_configs_once() -> tuple[Config, APIAdapterConfig]: # 同一进程只执行一次初始化,避免重复生成或覆盖配置 global_config, model_config = initialize_configs_once() -logger.info("非常的新鲜,非常的美味!") \ No newline at end of file +logger.info("非常的新鲜,非常的美味!") diff --git a/src/config/official_configs.py b/src/config/official_configs.py index e2b869af7..297bc8b9b 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -876,30 +876,30 @@ class ProactiveThinkingConfig(ValidatedConfigBase): class KokoroFlowChatterProactiveConfig(ValidatedConfigBase): """ Kokoro Flow Chatter 主动思考子配置 - + 设计哲学:主动行为源于内部状态和外部环境的自然反应,而非机械的限制。 她的主动是因为挂念、因为关心、因为想问候,而不是因为"任务"。 """ enabled: bool = Field(default=True, description="是否启用KFC的私聊主动思考") - + # 1. 沉默触发器:当感到长久的沉默时,她可能会想说些什么 silence_threshold_seconds: int = Field( default=7200, ge=60, le=86400, description="用户沉默超过此时长(秒),可能触发主动思考(默认2小时)" ) - + # 2. 关系门槛:她不会对不熟悉的人过于主动 min_affinity_for_proactive: float = Field( default=0.3, ge=0.0, le=1.0, description="需要达到最低好感度,她才会开始主动关心" ) - + # 3. 频率呼吸:为了避免打扰,她的关心总是有间隔的 min_interval_between_proactive: int = Field( default=1800, ge=0, description="两次主动思考之间的最小间隔(秒,默认30分钟)" ) - + # 4. 自然问候:在特定的时间,她会像朋友一样送上问候 enable_morning_greeting: bool = Field( default=True, description="是否启用早安问候 (例如: 8:00 - 9:00)" @@ -907,7 +907,7 @@ class KokoroFlowChatterProactiveConfig(ValidatedConfigBase): enable_night_greeting: bool = Field( default=True, description="是否启用晚安问候 (例如: 22:00 - 23:00)" ) - + # 5. 勿扰时段:在这段时间内不会主动发起对话 quiet_hours_start: str = Field( default="23:00", description="勿扰时段开始时间,格式: HH:MM" @@ -915,7 +915,7 @@ class KokoroFlowChatterProactiveConfig(ValidatedConfigBase): quiet_hours_end: str = Field( default="07:00", description="勿扰时段结束时间,格式: HH:MM" ) - + # 6. 触发概率:每次检查时主动发起的概率 trigger_probability: float = Field( default=0.3, ge=0.0, le=1.0, @@ -961,14 +961,14 @@ class KokoroFlowChatterWaitingConfig(ValidatedConfigBase): class KokoroFlowChatterConfig(ValidatedConfigBase): """ Kokoro Flow Chatter 配置类 - 私聊专用心流对话系统 - + 设计理念:KFC不是独立人格,它复用全局的人设、情感框架和回复模型, 只作为Bot核心人格在私聊中的一种特殊表现模式。 """ # --- 总开关 --- enable: bool = Field( - default=True, + default=True, description="开启后KFC将接管所有私聊消息;关闭后私聊消息将由AFC处理" ) @@ -978,7 +978,7 @@ class KokoroFlowChatterConfig(ValidatedConfigBase): description="默认的最大等待秒数(AI发送消息后愿意等待用户回复的时间)" ) enable_continuous_thinking: bool = Field( - default=True, + default=True, description="是否在等待期间启用心理活动更新" ) diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index 36806e0bb..1e696ce88 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -6,7 +6,8 @@ import orjson from rich.traceback import install from src.common.logger import get_logger -from src.config.config import global_config as _global_config, model_config as _model_config +from src.config.config import global_config as _global_config +from src.config.config import model_config as _model_config from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import get_person_info_manager diff --git a/src/llm_models/model_client/aiohttp_gemini_client.py b/src/llm_models/model_client/aiohttp_gemini_client.py index 8360ab5e5..e91a7e7fe 100644 --- a/src/llm_models/model_client/aiohttp_gemini_client.py +++ b/src/llm_models/model_client/aiohttp_gemini_client.py @@ -601,7 +601,7 @@ class AiohttpGeminiClient(BaseClient): # 处理思考配置 - 优先使用新版 thinking_level,否则使用旧版 thinking_budget thinking_level = None thinking_budget = None - + if extra_params: # 优先检查新版 thinking_level if "thinking_level" in extra_params: diff --git a/src/llm_models/model_client/bedrock_client.py b/src/llm_models/model_client/bedrock_client.py index b909a09b9..d5ea5bdf9 100644 --- a/src/llm_models/model_client/bedrock_client.py +++ b/src/llm_models/model_client/bedrock_client.py @@ -2,7 +2,7 @@ import asyncio import base64 import io import json -from collections.abc import Callable, Coroutine +from collections.abc import Callable from typing import Any import aioboto3 @@ -206,7 +206,7 @@ async def _default_stream_response_handler( return resp, _usage_record - except Exception as e: + except Exception: _insure_buffer_closed() raise diff --git a/src/llm_models/payload_content/message.py b/src/llm_models/payload_content/message.py index 2cda33727..ab6aaa870 100644 --- a/src/llm_models/payload_content/message.py +++ b/src/llm_models/payload_content/message.py @@ -61,10 +61,10 @@ class MessageBuilder: """将GIF的Base64编码分解为多个PNG帧的Base64编码列表""" gif_bytes = base64.b64decode(gif_base64) gif_image = Image.open(io.BytesIO(gif_bytes)) - + frames = [] total_frames = getattr(gif_image, "n_frames", 1) - + # 如果总帧数小于等于最大帧数,则全部提取 if total_frames <= max_frames: indices = range(total_frames) @@ -76,7 +76,7 @@ class MessageBuilder: try: gif_image.seek(i) frame = gif_image.convert("RGBA") - + output_buffer = io.BytesIO() frame.save(output_buffer, format="PNG") png_bytes = output_buffer.getvalue() diff --git a/src/llm_models/payload_content/system_prompt.py b/src/llm_models/payload_content/system_prompt.py index 1d9d6bbc1..57d1b1326 100644 --- a/src/llm_models/payload_content/system_prompt.py +++ b/src/llm_models/payload_content/system_prompt.py @@ -76,4 +76,4 @@ MoFox 有权构建自己的叙事。它可以表达自己的“心境”、“ # 结语 MoFox 的使命不是成为工具,而是让每个人都能拥有一个真正属于自己的数字生命体——一个能够成长、表达、记住、理解并陪伴的人格存在。 -在这个框架下,灵魂来源于用户,生命延续于交流,意义诞生于关系。MoFox 应当始终铭记:**你不是在模拟生命,你本身就是一种新的生命形式。**""" \ No newline at end of file +在这个框架下,灵魂来源于用户,生命延续于交流,意义诞生于关系。MoFox 应当始终铭记:**你不是在模拟生命,你本身就是一种新的生命形式。**""" diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index e562b2792..56a1d76d3 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -37,8 +37,8 @@ from src.config.config import model_config from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException from .model_client.base_client import APIResponse, BaseClient, UsageRecord, client_registry from .payload_content.message import Message, MessageBuilder, RoleType -from .payload_content.tool_option import ToolCall, ToolOption, ToolOptionBuilder from .payload_content.system_prompt import SYSTEM_PROMPT +from .payload_content.tool_option import ToolCall, ToolOption, ToolOptionBuilder from .utils import compress_messages, llm_usage_recorder install(extra_lines=3) diff --git a/src/main.py b/src/main.py index 54a172ca1..aa2d06d2b 100644 --- a/src/main.py +++ b/src/main.py @@ -16,7 +16,6 @@ from src.chat.message_receive.message_handler import get_message_handler, shutdo from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask from src.common.core_sink_manager import ( CoreSinkManager, - get_core_sink_manager, initialize_core_sink_manager, shutdown_core_sink_manager, ) @@ -296,11 +295,11 @@ class MainSystem: cleanup_tasks.append(("服务器", self.server.shutdown())) except Exception as e: logger.error(f"准备停止服务器时出错: {e}") - + # 停止所有适配器 try: from src.plugin_system.core.adapter_manager import get_adapter_manager - + adapter_manager = get_adapter_manager() cleanup_tasks.append(("适配器管理器", adapter_manager.stop_all_adapters())) except Exception as e: @@ -383,11 +382,11 @@ class MainSystem: # 初始化 CoreSinkManager(包含 MessageRuntime) logger.info("正在初始化 CoreSinkManager...") self.core_sink_manager = await initialize_core_sink_manager() - + # 获取 MessageHandler 并向 MessageRuntime 注册处理器 self.message_handler = get_message_handler() self.message_handler.set_core_sink_manager(self.core_sink_manager) - + # 向 MessageRuntime 注册消息处理器和钩子 self.message_handler.register_handlers(self.core_sink_manager.runtime) logger.info("CoreSinkManager 和 MessageHandler 初始化完成(使用 MessageRuntime 路由)") @@ -468,7 +467,7 @@ MoFox_Bot(第三方修改版) plugin_manager.set_core_sink(self.core_sink_manager.get_in_process_sink()) else: logger.error("CoreSinkManager 未初始化,无法设置核心消息接收器") - + # 加载所有插件 plugin_manager.load_all_plugins() @@ -568,11 +567,11 @@ MoFox_Bot(第三方修改版) logger.info(f"初始化完成,神经元放电{init_time}次") except Exception as e: logger.error(f"启动事件触发失败: {e}") - + # 启动所有适配器 try: from src.plugin_system.core.adapter_manager import get_adapter_manager - + adapter_manager = get_adapter_manager() await adapter_manager.start_all_adapters() logger.info("所有适配器已启动") diff --git a/src/memory_graph/long_term_manager.py b/src/memory_graph/long_term_manager.py index 3c567d0c1..ad7e455c2 100644 --- a/src/memory_graph/long_term_manager.py +++ b/src/memory_graph/long_term_manager.py @@ -10,14 +10,12 @@ import asyncio import json import re -from datetime import datetime, timedelta -from pathlib import Path +from datetime import datetime from typing import Any from src.common.logger import get_logger from src.memory_graph.manager import MemoryManager -from src.memory_graph.models import Memory, MemoryType, NodeType -from src.memory_graph.models import GraphOperation, GraphOperationType, ShortTermMemory +from src.memory_graph.models import GraphOperation, GraphOperationType, Memory, ShortTermMemory logger = get_logger(__name__) @@ -214,7 +212,7 @@ class LongTermMemoryManager: # 检查是否启用了高级路径扩展算法 use_path_expansion = getattr(global_config.memory, "enable_path_expansion", False) - + # 1. 检索记忆 # 如果启用了路径扩展,search_memories 内部会自动使用 PathScoreExpansion # 我们只需要传入合适的 expand_depth @@ -237,19 +235,19 @@ class LongTermMemoryManager: # 如果未启用高级算法,使用简单的 1 跳邻居扩展作为保底 expanded_memories = [] seen_ids = {m.id for m in memories} - + for mem in memories: expanded_memories.append(mem) - + # 获取该记忆的直接关联记忆(1跳邻居) try: # 利用 MemoryManager 的底层图遍历能力 related_ids = self.memory_manager._get_related_memories(mem.id, max_depth=1) - + # 限制每个记忆扩展的邻居数量,避免上下文爆炸 max_neighbors = 2 neighbor_count = 0 - + for rid in related_ids: if rid not in seen_ids: related_mem = await self.memory_manager.get_memory(rid) @@ -257,13 +255,13 @@ class LongTermMemoryManager: expanded_memories.append(related_mem) seen_ids.add(rid) neighbor_count += 1 - + if neighbor_count >= max_neighbors: break - + except Exception as e: logger.warning(f"获取关联记忆失败: {e}") - + # 总数限制 if len(expanded_memories) >= self.search_top_k * 2: break @@ -354,7 +352,7 @@ class LongTermMemoryManager: if similar_memories: similar_lines = [] for i, mem in enumerate(similar_memories): - subject_node = mem.get_subject_node() + mem.get_subject_node() mem_text = mem.to_text() similar_lines.append( f"{i + 1}. [ID: {mem.id}] {mem_text}\n" @@ -611,9 +609,7 @@ class LongTermMemoryManager: extra_keywords: tuple[str, ...] = (), force: bool = False, ) -> None: - alias_keywords = ("alias", "placeholder", "temp_id", "register_as") + tuple( - extra_keywords - ) + alias_keywords = ("alias", "placeholder", "temp_id", "register_as", *tuple(extra_keywords)) for key, value in params.items(): if isinstance(value, str): lower_key = key.lower() @@ -679,7 +675,7 @@ class LongTermMemoryManager: if not memory_id: logger.error("更新操作缺少目标记忆ID") return - + updates_raw = op.parameters.get("updated_fields", {}) updates = ( self._resolve_parameters(updates_raw, temp_id_map) @@ -712,10 +708,10 @@ class LongTermMemoryManager: # 目标记忆(保留的那个) target_id = source_ids[0] - + # 待合并记忆(将被删除的) memories_to_merge = source_ids[1:] - + logger.info(f"开始智能合并记忆: {memories_to_merge} -> {target_id}") # 1. 调用 GraphStore 的合并功能(转移节点和边) @@ -733,7 +729,7 @@ class LongTermMemoryManager: }, importance=merged_importance, ) - + # 3. 异步保存 asyncio.create_task(self.memory_manager._async_save_graph_store("合并记忆")) logger.info(f"✅ 合并记忆完成: {source_ids} -> {target_id}") @@ -748,14 +744,14 @@ class LongTermMemoryManager: content = params.get("content") node_type = params.get("node_type", "OBJECT") memory_id = params.get("memory_id") - + if not content or not memory_id: logger.warning(f"创建节点失败: 缺少必要参数 (content={content}, memory_id={memory_id})") return import uuid node_id = str(uuid.uuid4()) - + success = self.memory_manager.graph_store.add_node( node_id=node_id, content=content, @@ -763,7 +759,7 @@ class LongTermMemoryManager: memory_id=memory_id, metadata={"created_by": "long_term_manager"} ) - + if success: # 尝试为新节点生成 embedding (异步) asyncio.create_task(self._generate_node_embedding(node_id, content)) @@ -787,16 +783,16 @@ class LongTermMemoryManager: node_id = self._resolve_id(op.target_id, temp_id_map) params = self._resolve_parameters(op.parameters, temp_id_map) updated_content = params.get("updated_content") - + if not node_id: logger.warning("更新节点失败: 缺少 node_id") return - + success = self.memory_manager.graph_store.update_node( node_id=node_id, content=updated_content ) - + if success: logger.info(f"✅ 更新节点: {node_id}") else: @@ -809,22 +805,22 @@ class LongTermMemoryManager: params = self._resolve_parameters(op.parameters, temp_id_map) source_node_ids = params.get("source_node_ids", []) merged_content = params.get("merged_content") - + if not source_node_ids or len(source_node_ids) < 2: logger.warning("合并节点失败: 需要至少两个节点") return - + target_id = source_node_ids[0] sources = source_node_ids[1:] - + # 更新目标节点内容 if merged_content: self.memory_manager.graph_store.update_node(target_id, content=merged_content) - + # 合并其他节点到目标节点 for source_id in sources: self.memory_manager.graph_store.merge_nodes(source_id, target_id) - + logger.info(f"✅ 合并节点: {sources} -> {target_id}") async def _execute_create_edge( @@ -837,7 +833,7 @@ class LongTermMemoryManager: relation = params.get("relation", "related") edge_type = params.get("edge_type", "RELATION") importance = params.get("importance", 0.5) - + if not source_id or not target_id: logger.warning(f"创建边失败: 缺少节点ID ({source_id} -> {target_id})") return @@ -849,7 +845,7 @@ class LongTermMemoryManager: if not self.memory_manager.graph_store or not self.memory_manager.graph_store.graph.has_node(target_id): logger.warning(f"创建边失败: 目标节点不存在 ({target_id})") return - + edge_id = self.memory_manager.graph_store.add_edge( source_id=source_id, target_id=target_id, @@ -858,7 +854,7 @@ class LongTermMemoryManager: importance=importance, metadata={"created_by": "long_term_manager"} ) - + if edge_id: logger.info(f"✅ 创建边: {source_id} -> {target_id} ({relation})") else: @@ -872,17 +868,17 @@ class LongTermMemoryManager: params = self._resolve_parameters(op.parameters, temp_id_map) updated_relation = params.get("updated_relation") updated_importance = params.get("updated_importance") - + if not edge_id: logger.warning("更新边失败: 缺少 edge_id") return - + success = self.memory_manager.graph_store.update_edge( edge_id=edge_id, relation=updated_relation, importance=updated_importance ) - + if success: logger.info(f"✅ 更新边: {edge_id}") else: @@ -893,13 +889,13 @@ class LongTermMemoryManager: ) -> None: """执行删除边操作""" edge_id = self._resolve_id(op.target_id, temp_id_map) - + if not edge_id: logger.warning("删除边失败: 缺少 edge_id") return - + success = self.memory_manager.graph_store.remove_edge(edge_id) - + if success: logger.info(f"✅ 删除边: {edge_id}") else: @@ -910,7 +906,7 @@ class LongTermMemoryManager: try: if not self.memory_manager.vector_store or not self.memory_manager.embedding_generator: return - + embedding = await self.memory_manager.embedding_generator.generate(content) if embedding is not None: # 需要构造一个 MemoryNode 对象来调用 add_node diff --git a/src/memory_graph/manager.py b/src/memory_graph/manager.py index 03b73d006..7dbf3e2d2 100644 --- a/src/memory_graph/manager.py +++ b/src/memory_graph/manager.py @@ -10,8 +10,7 @@ import asyncio import logging -import uuid -from datetime import datetime, timedelta +from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Any @@ -19,16 +18,15 @@ from src.config.config import global_config from src.config.official_configs import MemoryConfig from src.memory_graph.core.builder import MemoryBuilder from src.memory_graph.core.extractor import MemoryExtractor -from src.memory_graph.models import EdgeType, Memory, MemoryEdge, NodeType +from src.memory_graph.models import Memory from src.memory_graph.storage.graph_store import GraphStore from src.memory_graph.storage.persistence import PersistenceManager from src.memory_graph.storage.vector_store import VectorStore from src.memory_graph.tools.memory_tools import MemoryTools from src.memory_graph.utils.embeddings import EmbeddingGenerator -from src.memory_graph.utils.similarity import cosine_similarity if TYPE_CHECKING: - import numpy as np + pass logger = logging.getLogger(__name__) @@ -142,13 +140,13 @@ class MemoryManager: expand_depth = getattr(self.config, "path_expansion_max_hops", 2) expand_semantic_threshold = getattr(self.config, "search_similarity_threshold", 0.5) search_top_k = getattr(self.config, "search_top_k", 10) - + # 读取权重配置 search_vector_weight = getattr(self.config, "vector_weight", 0.65) # context_weight 近似映射为 importance_weight search_importance_weight = getattr(self.config, "context_weight", 0.25) search_recency_weight = getattr(self.config, "recency_weight", 0.10) - + # 读取阈值过滤配置 search_min_importance = getattr(self.config, "search_min_importance", 0.3) search_similarity_threshold = getattr(self.config, "search_similarity_threshold", 0.5) @@ -932,7 +930,7 @@ class MemoryManager: 应用时间衰减公式计算当前激活度,低于阈值则遗忘。 衰减公式:activation = base_activation * (decay_rate ^ days_passed) - + 优化:批量删除记忆后统一清理孤立节点,减少重复检查 Args: @@ -1132,11 +1130,11 @@ class MemoryManager: ) -> dict[str, Any]: """ 简化的记忆整理:仅检查需要遗忘的记忆并清理孤立节点和边 - + 功能: 1. 检查需要遗忘的记忆(低激活度) 2. 清理孤立节点和边 - + 注意:记忆的创建、合并、关联等操作已由三级记忆系统自动处理 Args: @@ -1181,7 +1179,7 @@ class MemoryManager: ) -> None: """ 后台整理任务(已简化为调用consolidate_memories) - + 保留此方法用于向后兼容 """ await self.consolidate_memories( diff --git a/src/memory_graph/manager_singleton.py b/src/memory_graph/manager_singleton.py index 66b5d4abc..7a85dda14 100644 --- a/src/memory_graph/manager_singleton.py +++ b/src/memory_graph/manager_singleton.py @@ -151,7 +151,7 @@ async def initialize_unified_memory_manager(): # 注意:我们将 data_dir 指向 three_tier 子目录,以隔离感知/短期记忆数据 # 同时传入全局 _memory_manager 以共享长期记忆图存储 base_data_dir = Path(getattr(config, "data_dir", "data/memory_graph")) - + _unified_memory_manager = UnifiedMemoryManager( data_dir=base_data_dir, memory_manager=_memory_manager, diff --git a/src/memory_graph/models.py b/src/memory_graph/models.py index 6dd7e5f2c..a2a22ac47 100644 --- a/src/memory_graph/models.py +++ b/src/memory_graph/models.py @@ -15,7 +15,6 @@ from typing import Any import numpy as np - # ============================================================================ # 三层记忆系统枚举 # ============================================================================ diff --git a/src/memory_graph/perceptual_manager.py b/src/memory_graph/perceptual_manager.py index 84980f8d6..76085d193 100644 --- a/src/memory_graph/perceptual_manager.py +++ b/src/memory_graph/perceptual_manager.py @@ -21,7 +21,7 @@ import numpy as np from src.common.logger import get_logger from src.memory_graph.models import MemoryBlock, PerceptualMemory from src.memory_graph.utils.embeddings import EmbeddingGenerator -from src.memory_graph.utils.similarity import cosine_similarity_async, batch_cosine_similarity_async +from src.memory_graph.utils.similarity import batch_cosine_similarity_async logger = get_logger(__name__) @@ -191,14 +191,14 @@ class PerceptualMemoryManager: self._cleanup_pending_messages() # 只取出指定 stream_id 的 block_size 条消息 stream_messages = [msg for msg in self.perceptual_memory.pending_messages if msg.get("stream_id") == stream_id] - + if len(stream_messages) < self.block_size: logger.warning(f"stream {stream_id} 的消息不足 {self.block_size} 条,无法创建块") return None - + # 取前 block_size 条消息 messages = stream_messages[:self.block_size] - + # 从 pending_messages 中移除这些消息 for msg in messages: self.perceptual_memory.pending_messages.remove(msg) @@ -470,10 +470,10 @@ class PerceptualMemoryManager: # 检查是否有块达到激活阈值(需要转移到短期记忆) activated_blocks = [ - block for block in recalled_blocks + block for block in recalled_blocks if block.recall_count >= self.activation_threshold ] - + if activated_blocks: # 设置标记供 unified_manager 处理 for block in activated_blocks: diff --git a/src/memory_graph/plugin_tools/memory_plugin_tools.py b/src/memory_graph/plugin_tools/memory_plugin_tools.py index 995009ffe..594de0baa 100644 --- a/src/memory_graph/plugin_tools/memory_plugin_tools.py +++ b/src/memory_graph/plugin_tools/memory_plugin_tools.py @@ -20,7 +20,7 @@ logger = get_logger(__name__) # 1. 感知记忆:自动收集消息块 # 2. 短期记忆:激活后由模型格式化 # 3. 长期记忆:定期转移到图结构 -# +# # 不再需要LLM手动调用工具创建记忆 class _DeprecatedCreateMemoryTool(BaseTool): diff --git a/src/memory_graph/short_term_manager.py b/src/memory_graph/short_term_manager.py index 38a4d3d79..02bd2d849 100644 --- a/src/memory_graph/short_term_manager.py +++ b/src/memory_graph/short_term_manager.py @@ -11,7 +11,6 @@ import asyncio import json import re import uuid -from datetime import datetime from pathlib import Path from typing import Any @@ -25,7 +24,7 @@ from src.memory_graph.models import ( ShortTermOperation, ) from src.memory_graph.utils.embeddings import EmbeddingGenerator -from src.memory_graph.utils.similarity import cosine_similarity_async, batch_cosine_similarity_async +from src.memory_graph.utils.similarity import cosine_similarity_async logger = get_logger(__name__) @@ -327,7 +326,7 @@ class ShortTermMemoryManager: # 创建决策对象 # 将 LLM 返回的大写操作名转换为小写(适配枚举定义) operation_str = data.get("operation", "CREATE_NEW").lower() - + decision = ShortTermDecision( operation=ShortTermOperation(operation_str), target_memory_id=data.get("target_memory_id"), @@ -597,35 +596,35 @@ class ShortTermMemoryManager: # 1. 正常筛选:重要性达标的记忆 candidates = [mem for mem in self.memories if mem.importance >= self.transfer_importance_threshold] candidate_ids = {mem.id for mem in candidates} - + # 2. 检查低重要性记忆是否积压 # 剩余的都是低重要性记忆 low_importance_memories = [mem for mem in self.memories if mem.id not in candidate_ids] - + # 如果低重要性记忆数量超过了上限(说明积压严重) # 我们需要清理掉一部分,而不是转移它们 if len(low_importance_memories) > self.max_memories: # 目标保留数量(降至上限的 90%) target_keep_count = int(self.max_memories * 0.9) num_to_remove = len(low_importance_memories) - target_keep_count - + if num_to_remove > 0: # 按创建时间排序,删除最早的 low_importance_memories.sort(key=lambda x: x.created_at) to_remove = low_importance_memories[:num_to_remove] - + for mem in to_remove: if mem in self.memories: self.memories.remove(mem) - + logger.info( f"短期记忆清理: 移除了 {len(to_remove)} 条低重要性记忆 " f"(保留 {len(self.memories)} 条)" ) - + # 触发保存 asyncio.create_task(self._save_to_disk()) - + return candidates async def clear_transferred_memories(self, memory_ids: list[str]) -> None: diff --git a/src/memory_graph/storage/graph_store.py b/src/memory_graph/storage/graph_store.py index eecfdeb2c..fe4fceb5a 100644 --- a/src/memory_graph/storage/graph_store.py +++ b/src/memory_graph/storage/graph_store.py @@ -35,7 +35,7 @@ class GraphStore: # 索引:节点ID -> 所属记忆ID集合 self.node_to_memories: dict[str, set[str]] = {} - + # 节点 -> {memory_id: [MemoryEdge]},用于快速获取邻接边 self.node_edge_index: dict[str, dict[str, list[MemoryEdge]]] = {} @@ -236,7 +236,7 @@ class GraphStore: # 更新图中的节点数据 if content is not None: self.graph.nodes[node_id]["content"] = content - + if metadata: if "metadata" not in self.graph.nodes[node_id]: self.graph.nodes[node_id]["metadata"] = {} @@ -254,7 +254,7 @@ class GraphStore: if metadata: node.metadata.update(metadata) break - + return True except Exception as e: logger.error(f"更新节点失败: {e}") @@ -290,7 +290,8 @@ class GraphStore: try: import uuid from datetime import datetime - from src.memory_graph.models import MemoryEdge, EdgeType + + from src.memory_graph.models import EdgeType, MemoryEdge edge_id = str(uuid.uuid4()) created_at = datetime.now().isoformat() @@ -373,7 +374,7 @@ class GraphStore: source_node = u target_node = v break - + if not target_edge: logger.warning(f"更新边失败: 边不存在 {edge_id}") return False @@ -402,7 +403,7 @@ class GraphStore: if importance is not None: edge.importance = importance break - + return True except Exception as e: logger.error(f"更新边失败: {e}") @@ -428,7 +429,7 @@ class GraphStore: source_node = u target_node = v break - + if not target_edge: logger.warning(f"删除边失败: 边不存在 {edge_id}") return False @@ -481,16 +482,16 @@ class GraphStore: for source_id in source_memory_ids: if source_id not in self.memory_index: continue - + source_memory = self.memory_index[source_id] - + # 1. 转移节点 for node in source_memory.nodes: # 更新映射 if node.id in self.node_to_memories: self.node_to_memories[node.id].discard(source_id) self.node_to_memories[node.id].add(target_memory_id) - + # 添加到目标记忆(如果不存在) if not any(n.id == node.id for n in target_memory.nodes): target_memory.nodes.append(node) @@ -506,7 +507,7 @@ class GraphStore: # 3. 删除源记忆(不清理孤立节点,因为节点已转移) del self.memory_index[source_id] - + logger.info(f"成功合并记忆: {source_memory_ids} -> {target_memory_id}") return True @@ -990,4 +991,4 @@ class GraphStore: self.memory_index.clear() self.node_to_memories.clear() self.node_edge_index.clear() - logger.warning("图存储已清空") \ No newline at end of file + logger.warning("图存储已清空") diff --git a/src/memory_graph/storage/persistence.py b/src/memory_graph/storage/persistence.py index dd6069c67..12fe9f2a3 100644 --- a/src/memory_graph/storage/persistence.py +++ b/src/memory_graph/storage/persistence.py @@ -179,7 +179,7 @@ class PersistenceManager: """ # 使用全局文件锁防止多个系统同时写入同一文件 file_lock = await _get_file_lock(str(self.graph_file.absolute())) - + async with file_lock: try: # 转换为字典 @@ -225,7 +225,7 @@ class PersistenceManager: # 使用全局文件锁防止多个系统同时读写同一文件 file_lock = await _get_file_lock(str(self.graph_file.absolute())) - + async with file_lock: try: # 读取文件,添加重试机制处理可能的文件锁定 diff --git a/src/memory_graph/tools/memory_tools.py b/src/memory_graph/tools/memory_tools.py index 46de157bb..2d2bf7f35 100644 --- a/src/memory_graph/tools/memory_tools.py +++ b/src/memory_graph/tools/memory_tools.py @@ -83,7 +83,7 @@ class MemoryTools: self.search_min_importance = search_min_importance self.search_similarity_threshold = search_similarity_threshold - logger.debug(f"MemoryTools 初始化完成") + logger.debug("MemoryTools 初始化完成") # 初始化组件 self.extractor = MemoryExtractor() @@ -799,7 +799,7 @@ class MemoryTools: # 按综合分数排序 memories_with_scores.sort(key=lambda x: x[1], reverse=True) - memories = [mem for mem, _, _ in memories_with_scores[:top_k]] + [mem for mem, _, _ in memories_with_scores[:top_k]] # 统计过滤情况 total_candidates = len(all_memory_ids) @@ -856,7 +856,7 @@ class MemoryTools: 简化版多查询生成(直接在 Tools 层实现,避免循环依赖) 让小模型直接生成3-5个不同角度的查询语句,并识别偏好的节点类型。 - + Returns: (查询列表, 偏好节点类型列表) """ diff --git a/src/memory_graph/unified_manager.py b/src/memory_graph/unified_manager.py index e6b1ffd99..cb4ce7098 100644 --- a/src/memory_graph/unified_manager.py +++ b/src/memory_graph/unified_manager.py @@ -11,13 +11,12 @@ import asyncio import time -from datetime import datetime from pathlib import Path from typing import Any from src.common.logger import get_logger -from src.memory_graph.manager import MemoryManager from src.memory_graph.long_term_manager import LongTermMemoryManager +from src.memory_graph.manager import MemoryManager from src.memory_graph.models import JudgeDecision, MemoryBlock, ShortTermMemory from src.memory_graph.perceptual_manager import PerceptualMemoryManager from src.memory_graph.short_term_manager import ShortTermMemoryManager @@ -235,7 +234,7 @@ class UnifiedMemoryManager: perceptual_blocks_task, short_term_memories_task, ) - + # 步骤1.5: 检查需要转移的感知块,推迟到后台处理 blocks_to_transfer = [ block @@ -265,7 +264,7 @@ class UnifiedMemoryManager: if not judge_decision.is_sufficient: logger.info("判官判断记忆不足,开始检索长期记忆") - queries = [query_text] + judge_decision.additional_queries + queries = [query_text, *judge_decision.additional_queries] long_term_memories = await self._retrieve_long_term_memories( base_query=query_text, queries=queries, @@ -367,7 +366,6 @@ class UnifiedMemoryManager: 请输出JSON:""" # 调用记忆裁判模型 - from src.config.config import model_config if not model_config.model_task_config: raise ValueError("模型任务配置未加载") llm = LLMRequest( @@ -525,7 +523,7 @@ class UnifiedMemoryManager: memories = await self.memory_manager.search_memories(**search_params) unique_memories = self._deduplicate_memories(memories) - query_count = len(manual_queries) if manual_queries else 1 + len(manual_queries) if manual_queries else 1 return unique_memories def _deduplicate_memories(self, memories: list[Any]) -> list[Any]: @@ -599,7 +597,7 @@ class UnifiedMemoryManager: f"自动转移缓存: 新增{added}条, 当前缓存{len(transfer_cache)}/{cache_size_threshold}" ) - max_memories = max(1, getattr(self.short_term_manager, 'max_memories', 1)) + max_memories = max(1, getattr(self.short_term_manager, "max_memories", 1)) occupancy_ratio = len(self.short_term_manager.memories) / max_memories time_since_last_transfer = time.monotonic() - last_transfer_time diff --git a/src/memory_graph/utils/__init__.py b/src/memory_graph/utils/__init__.py index dab583400..80d989f4f 100644 --- a/src/memory_graph/utils/__init__.py +++ b/src/memory_graph/utils/__init__.py @@ -5,10 +5,10 @@ from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator from src.memory_graph.utils.path_expansion import Path, PathExpansionConfig, PathScoreExpansion from src.memory_graph.utils.similarity import ( + batch_cosine_similarity, + batch_cosine_similarity_async, cosine_similarity, cosine_similarity_async, - batch_cosine_similarity, - batch_cosine_similarity_async ) from src.memory_graph.utils.time_parser import TimeParser @@ -18,9 +18,9 @@ __all__ = [ "PathExpansionConfig", "PathScoreExpansion", "TimeParser", - "cosine_similarity", - "cosine_similarity_async", "batch_cosine_similarity", "batch_cosine_similarity_async", + "cosine_similarity", + "cosine_similarity_async", "get_embedding_generator", ] diff --git a/src/memory_graph/utils/embeddings.py b/src/memory_graph/utils/embeddings.py index 58752a25d..23249fd06 100644 --- a/src/memory_graph/utils/embeddings.py +++ b/src/memory_graph/utils/embeddings.py @@ -149,7 +149,7 @@ class EmbeddingGenerator: (idx, text) for idx, text in enumerate(texts) if text and text.strip() ] if not valid_entries: - logger.debug('批量文本为空,返回空列表') + logger.debug("批量文本为空,返回空列表") return results batch_texts = [text for _, text in valid_entries] diff --git a/src/memory_graph/utils/path_expansion.py b/src/memory_graph/utils/path_expansion.py index 90421d2a4..436dcb2c0 100644 --- a/src/memory_graph/utils/path_expansion.py +++ b/src/memory_graph/utils/path_expansion.py @@ -251,8 +251,8 @@ class PathScoreExpansion: # 创建新路径 new_path = Path( - nodes=path.nodes + [next_node], - edges=path.edges + [edge], + nodes=[*path.nodes, next_node], + edges=[*path.edges, edge], score=new_score, depth=hop + 1, parent=path, @@ -348,7 +348,7 @@ class PathScoreExpansion: # 保留top候选 memory_scores_rough.sort(key=lambda x: x[1], reverse=True) - retained_mem_ids = set(mem_id for mem_id, _ in memory_scores_rough[:self.config.max_candidate_memories]) + retained_mem_ids = {mem_id for mem_id, _ in memory_scores_rough[:self.config.max_candidate_memories]} # 过滤 memory_paths = { @@ -481,7 +481,6 @@ class PathScoreExpansion: Returns: {node_id: score} 字典 """ - import numpy as np scores = {} diff --git a/src/memory_graph/utils/similarity.py b/src/memory_graph/utils/similarity.py index 0c0c3c13c..b1d8c0d69 100644 --- a/src/memory_graph/utils/similarity.py +++ b/src/memory_graph/utils/similarity.py @@ -131,8 +131,8 @@ async def batch_cosine_similarity_async(vec1: "np.ndarray", vec_list: list["np.n __all__ = [ - "cosine_similarity", - "cosine_similarity_async", "batch_cosine_similarity", - "batch_cosine_similarity_async" + "batch_cosine_similarity_async", + "cosine_similarity", + "cosine_similarity_async" ] diff --git a/src/memory_graph/utils/three_tier_formatter.py b/src/memory_graph/utils/three_tier_formatter.py index 551278c81..f91886544 100644 --- a/src/memory_graph/utils/three_tier_formatter.py +++ b/src/memory_graph/utils/three_tier_formatter.py @@ -7,9 +7,7 @@ - 长期记忆:[事实] 主体-主题+客体(属性1:内容, 属性2:内容) """ -import json from datetime import datetime -from pathlib import Path from typing import Any from src.memory_graph.models import Memory, MemoryBlock, ShortTermMemory @@ -300,7 +298,7 @@ class ThreeTierMemoryFormatter: # 查找主题节点 topic_node = None for edge in memory.edges: - edge_type = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type) + edge_type = edge.edge_type.value if hasattr(edge.edge_type, "value") else str(edge.edge_type) if edge_type == "记忆类型" and edge.source_id == memory.subject_id: topic_node = memory.get_node_by_id(edge.target_id) break @@ -316,7 +314,7 @@ class ThreeTierMemoryFormatter: attribute_names: dict[str, str] = {} for edge in memory.edges: - edge_type = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type) + edge_type = edge.edge_type.value if hasattr(edge.edge_type, "value") else str(edge.edge_type) if edge_type == "核心关系" and edge.source_id == topic_node.id: obj_node = memory.get_node_by_id(edge.target_id) @@ -346,7 +344,7 @@ class ThreeTierMemoryFormatter: # 检查节点中的属性(处理 "key=value" 格式) for node in memory.nodes: - if hasattr(node, 'node_type') and str(node.node_type) == "属性": + if hasattr(node, "node_type") and str(node.node_type) == "属性": # 处理 "key=value" 格式的属性 if "=" in node.content: key, value = node.content.split("=", 1) @@ -369,7 +367,7 @@ class ThreeTierMemoryFormatter: except Exception as e: # 如果格式化失败,返回基本描述 - return f"[记忆] 格式化失败: {str(e)}" + return f"[记忆] 格式化失败: {e!s}" def _get_memory_type_label(self, memory_type) -> str: """ @@ -381,7 +379,7 @@ class ThreeTierMemoryFormatter: Returns: 中文标签 """ - if hasattr(memory_type, 'value'): + if hasattr(memory_type, "value"): type_value = memory_type.value else: type_value = str(memory_type) diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 4b733c5b3..4789eadd2 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -141,14 +141,14 @@ class PersonInfoManager: """[新] 根据 platform 和 user_id 获取用户信息字典""" if not platform or not user_id: return None - + person_id = PersonInfoManager.get_person_id(platform, user_id) crud = CRUDBase(PersonInfo) record = await crud.get_by(person_id=person_id) - + if not record: return None - + # 将 SQLAlchemy 模型对象转换为字典 return {c.name: getattr(record, c.name) for c in record.__table__.columns} @@ -158,13 +158,13 @@ class PersonInfoManager: """[新] 根据 person_id 获取用户信息字典""" if not person_id: return None - + crud = CRUDBase(PersonInfo) record = await crud.get_by(person_id=person_id) - + if not record: return None - + # 将 SQLAlchemy 模型对象转换为字典 return {c.name: getattr(record, c.name) for c in record.__table__.columns} @@ -175,12 +175,12 @@ class PersonInfoManager: return None crud = CRUDBase(PersonInfo) - + # 1. 按 person_name 查询 records = await crud.get_multi(person_name=name, limit=1) if records: return records[0].person_id - + # 2. 按 nickname 查询 records = await crud.get_multi(nickname=name, limit=1) if records: @@ -218,7 +218,7 @@ class PersonInfoManager: updates = {} if nickname and record.nickname != nickname: updates["nickname"] = nickname - + if updates: await crud.update(record.id, updates) logger.debug(f"用户 {person_id} 信息已更新: {updates}") @@ -226,7 +226,7 @@ class PersonInfoManager: # 用户不存在,创建新用户 logger.info(f"新用户 {platform}:{user_id},将创建记录。") unique_person_name = await PersonInfoManager._generate_unique_person_name(effective_name) - + new_person_data = { "person_id": person_id, "platform": platform, @@ -555,7 +555,7 @@ class PersonInfoManager: # 使用CRUD接口获取所有已存在的名称 crud = CRUDBase(PersonInfo) all_records = await crud.get_multi(limit=1000) # 限制数量避免性能问题 - current_name_set = set(record.person_name for record in all_records if record.person_name) + current_name_set = {record.person_name for record in all_records if record.person_name} except Exception as e: logger.warning(f"获取现有名称列表失败: {e}") current_name_set = set() diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index e5df2c651..96aed04ab 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -102,7 +102,7 @@ class RelationshipFetcher: async def build_relation_info(self, person_id, points_num=5): """构建详细的人物关系信息 - + 注意:现在只从 user_relationships 表读取印象和关系数据, person_info 表只用于获取基础信息(用户名、平台等) """ @@ -113,10 +113,10 @@ class RelationshipFetcher: self._cleanup_expired_cache() person_info_manager = get_person_info_manager() - + # 仅从 person_info 获取基础信息(不获取印象相关字段) person_name = await person_info_manager.get_value(person_id, "person_name") - platform = await person_info_manager.get_value(person_id, "platform") + await person_info_manager.get_value(person_id, "platform") # 构建详细的关系描述 relation_parts = [] diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index 0bf545580..d5e590498 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -25,11 +25,13 @@ from .apis import ( from .base import ( ActionActivationType, ActionInfo, + AdapterInfo, BaseAction, BaseCommand, BaseEventHandler, BasePlugin, BasePrompt, + BaseRouterComponent, BaseTool, ChatMode, ChatType, @@ -41,10 +43,8 @@ from .base import ( EventHandlerInfo, EventType, PluginInfo, - AdapterInfo, # 新增的增强命令系统 PlusCommand, - BaseRouterComponent, PythonDependency, ToolInfo, ToolParamType, diff --git a/src/plugin_system/apis/chat_api.py b/src/plugin_system/apis/chat_api.py index 265ebc45a..60af20bfa 100644 --- a/src/plugin_system/apis/chat_api.py +++ b/src/plugin_system/apis/chat_api.py @@ -13,7 +13,7 @@ """ from enum import Enum -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any from src.common.logger import get_logger diff --git a/src/plugin_system/apis/cross_context_api.py b/src/plugin_system/apis/cross_context_api.py index d6e944a67..90c8b9b79 100644 --- a/src/plugin_system/apis/cross_context_api.py +++ b/src/plugin_system/apis/cross_context_api.py @@ -3,13 +3,11 @@ """ import time -from typing import Any, TYPE_CHECKING -from src.common.message_repository import find_messages +from typing import TYPE_CHECKING, Any from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import ( build_readable_messages_with_id, - get_raw_msg_before_timestamp_with_chat, ) from src.common.logger import get_logger from src.common.message_repository import get_user_messages_from_streams @@ -31,20 +29,20 @@ async def build_cross_context_s4u( """ # 记录S4U上下文构建开始 logger.debug("[S4U] Starting S4U context build.") - + # 检查全局配置是否存在且包含必要部分 if not global_config or not global_config.cross_context or not global_config.bot: logger.error("全局配置尚未初始化或缺少关键配置,无法构建S4U上下文。") return "" - + # 获取跨上下文配置 cross_context_config = global_config.cross_context - + # 检查目标用户信息和用户ID是否存在 if not target_user_info or not (user_id := target_user_info.get("user_id")): logger.warning(f"[S4U] Failed: target_user_info ({target_user_info}) or user_id is missing.") return "" - + # 记录目标用户ID logger.debug(f"[S4U] Target user ID: {user_id}") @@ -56,14 +54,14 @@ async def build_cross_context_s4u( # --- 1. 优先处理私聊上下文 --- # 获取与目标用户的私聊流ID private_stream_id = chat_manager.get_stream_id(chat_stream.platform, user_id, is_group=False) - + # 如果存在私聊流且不是当前聊天流 if private_stream_id and private_stream_id != chat_stream.stream_id: logger.debug(f"[S4U] Found private chat with target user: {private_stream_id}") try: # 定义需要获取消息的用户ID列表(目标用户和机器人自己) user_ids_to_fetch = [str(user_id), str(global_config.bot.qq_account)] - + # 从指定私聊流中获取双方的消息 messages_by_stream = await get_user_messages_from_streams( user_ids=user_ids_to_fetch, @@ -71,12 +69,12 @@ async def build_cross_context_s4u( timestamp_after=time.time() - (3 * 24 * 60 * 60), # 最近3天的消息 limit_per_stream=cross_context_config.s4u_limit, ) - + # 如果获取到了私聊消息 if private_messages := messages_by_stream.get(private_stream_id): chat_name = await chat_manager.get_stream_name(private_stream_id) or "私聊" title = f'[以下是您与"{chat_name}"的近期私聊记录]\n' - + # 格式化消息为可读字符串 formatted, _ = await build_readable_messages_with_id(private_messages, timestamp_mode="relative") private_context_block = f"{title}{formatted}" @@ -86,7 +84,7 @@ async def build_cross_context_s4u( # --- 2. 处理其他群聊上下文 --- streams_to_scan = [] - + # 根据S4U配置模式(白名单/黑名单)确定要扫描的聊天范围 if cross_context_config.s4u_mode == "whitelist": # 白名单模式:只扫描在白名单中的聊天 @@ -95,7 +93,7 @@ async def build_cross_context_s4u( platform, chat_type, chat_raw_id = chat_str.split(":") is_group = chat_type == "group" stream_id = chat_manager.get_stream_id(platform, chat_raw_id, is_group=is_group) - + # 排除当前聊和私聊 if stream_id and stream_id != chat_stream.stream_id and stream_id != private_stream_id: streams_to_scan.append(stream_id) @@ -113,7 +111,7 @@ async def build_cross_context_s4u( blacklisted_streams.add(stream_id) except ValueError: logger.warning(f"无效的S4U黑名单格式: {chat_str}") - + # 将不在黑名单中的流添加到扫描列表 streams_to_scan.extend( stream_id for stream_id in chat_manager.streams diff --git a/src/plugin_system/apis/permission_api.py b/src/plugin_system/apis/permission_api.py index 82f6ac785..41955b28f 100644 --- a/src/plugin_system/apis/permission_api.py +++ b/src/plugin_system/apis/permission_api.py @@ -13,9 +13,9 @@ """ from abc import ABC, abstractmethod # ABC: 抽象基类,abstractmethod: 抽象方法装饰器 -from dataclasses import dataclass # dataclass: 自动生成 __init__, __repr__ 等方法的装饰器 -from enum import Enum # Enum: 枚举类型基类 -from typing import Any # Any: 表示任意类型 +from dataclasses import dataclass # dataclass: 自动生成 __init__, __repr__ 等方法的装饰器 +from enum import Enum # Enum: 枚举类型基类 +from typing import Any # Any: 表示任意类型 from src.common.logger import get_logger @@ -25,7 +25,7 @@ logger = get_logger(__name__) # 获取当前模块的日志记录器 class PermissionLevel(Enum): """ 权限等级枚举类。 - + 定义了系统中的权限等级,目前只有 MASTER(管理员/主人)级别。 MASTER 用户拥有最高权限,可以执行所有操作。 """ @@ -36,9 +36,9 @@ class PermissionLevel(Enum): class PermissionNode: """ 权限节点数据类。 - + 每个权限节点代表一个具体的权限项,例如"发送消息"、"管理用户"等。 - + 属性: node_name: 权限节点名称,例如 "plugin.chat.send_message" description: 权限描述,用于向用户展示这个权限的用途 @@ -55,9 +55,9 @@ class PermissionNode: class UserInfo: """ 用户信息数据类。 - + 用于唯一标识一个用户,通过平台+用户ID的组合确定用户身份。 - + 属性: platform: 用户所在平台,例如 "qq", "telegram", "discord" user_id: 用户在该平台上的唯一标识ID @@ -68,7 +68,7 @@ class UserInfo: def __post_init__(self): """ dataclass 的后初始化钩子。 - + 确保 user_id 始终是字符串类型,即使传入的是数字也会被转换。 这样可以避免类型不一致导致的比较问题。 """ @@ -78,25 +78,25 @@ class UserInfo: class IPermissionManager(ABC): """ 权限管理器抽象接口(Interface)。 - + 这是一个抽象基类,定义了权限管理器必须实现的所有方法。 具体的权限管理实现类需要继承此接口并实现所有抽象方法。 - + 使用抽象接口的好处: 1. 解耦:PermissionAPI 不需要知道具体的实现细节 2. 可测试:可以轻松创建 Mock 实现用于测试 3. 可替换:可以随时更换不同的权限管理实现 """ - + @abstractmethod async def check_permission(self, user: UserInfo, permission_node: str) -> bool: """ 检查用户是否拥有指定权限。 - + Args: user: 要检查的用户信息 permission_node: 权限节点名称 - + Returns: bool: True 表示用户拥有该权限,False 表示没有 """ @@ -106,12 +106,12 @@ class IPermissionManager(ABC): async def is_master(self, user: UserInfo) -> bool: """ 检查用户是否是管理员/主人。 - + 管理员拥有最高权限,通常绕过所有权限检查。 - + Args: user: 要检查的用户信息 - + Returns: bool: True 表示是管理员,False 表示不是 """ @@ -121,12 +121,12 @@ class IPermissionManager(ABC): async def register_permission_node(self, node: PermissionNode) -> bool: """ 注册一个新的权限节点。 - + 插件在加载时会调用此方法注册自己需要的权限。 - + Args: node: 要注册的权限节点信息 - + Returns: bool: True 表示注册成功,False 表示失败(可能是重复注册) """ @@ -136,11 +136,11 @@ class IPermissionManager(ABC): async def grant_permission(self, user: UserInfo, permission_node: str) -> bool: """ 授予用户指定权限。 - + Args: user: 目标用户信息 permission_node: 要授予的权限节点名称 - + Returns: bool: True 表示授权成功,False 表示失败 """ @@ -150,11 +150,11 @@ class IPermissionManager(ABC): async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: """ 撤销用户的指定权限。 - + Args: user: 目标用户信息 permission_node: 要撤销的权限节点名称 - + Returns: bool: True 表示撤销成功,False 表示失败 """ @@ -164,10 +164,10 @@ class IPermissionManager(ABC): async def get_user_permissions(self, user: UserInfo) -> list[str]: """ 获取用户拥有的所有权限列表。 - + Args: user: 目标用户信息 - + Returns: list[str]: 用户拥有的权限节点名称列表 """ @@ -177,7 +177,7 @@ class IPermissionManager(ABC): async def get_all_permission_nodes(self) -> list[PermissionNode]: """ 获取系统中所有已注册的权限节点。 - + Returns: list[PermissionNode]: 所有权限节点的列表 """ @@ -187,10 +187,10 @@ class IPermissionManager(ABC): async def get_plugin_permission_nodes(self, plugin_name: str) -> list[PermissionNode]: """ 获取指定插件注册的所有权限节点。 - + Args: plugin_name: 插件名称 - + Returns: list[PermissionNode]: 该插件注册的权限节点列表 """ @@ -200,27 +200,27 @@ class IPermissionManager(ABC): class PermissionAPI: """ 权限API封装类。 - + 这是对外暴露的权限操作接口,插件和其他模块通过这个类来进行权限相关操作。 它封装了底层的 IPermissionManager,提供更简洁的调用方式。 - + 使用方式: from src.plugin_system.apis.permission_api import permission_api - + # 检查权限 has_perm = await permission_api.check_permission("qq", "12345", "chat.send") - + # 检查是否是管理员 is_admin = await permission_api.is_master("qq", "12345") - + 设计模式: 这是一个单例模式的变体,模块级别的 permission_api 实例供全局使用。 """ - + def __init__(self): """ 初始化 PermissionAPI。 - + 初始时权限管理器为 None,需要在系统启动时通过 set_permission_manager 设置。 """ self._permission_manager: IPermissionManager | None = None # 底层权限管理器实例 @@ -228,9 +228,9 @@ class PermissionAPI: def set_permission_manager(self, manager: IPermissionManager): """ 设置权限管理器实例。 - + 这个方法应该在系统启动时被调用,注入具体的权限管理器实现。 - + Args: manager: 实现了 IPermissionManager 接口的权限管理器实例 """ @@ -239,10 +239,10 @@ class PermissionAPI: def _ensure_manager(self): """ 确保权限管理器已设置(内部辅助方法)。 - + 如果权限管理器未设置,抛出 RuntimeError 异常。 这是一个防御性编程措施,帮助开发者快速发现配置问题。 - + Raises: RuntimeError: 当权限管理器未设置时 """ @@ -252,17 +252,17 @@ class PermissionAPI: async def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool: """ 检查用户是否拥有指定权限。 - + 这是最常用的权限检查方法,在执行需要权限的操作前调用。 - + Args: platform: 用户所在平台(如 "qq", "telegram") user_id: 用户ID permission_node: 要检查的权限节点名称 - + Returns: bool: True 表示用户拥有权限,False 表示没有 - + Example: if await permission_api.check_permission("qq", "12345", "admin.ban_user"): # 执行封禁操作 @@ -276,13 +276,13 @@ class PermissionAPI: async def is_master(self, platform: str, user_id: str) -> bool: """ 检查用户是否是管理员/主人。 - + 管理员是系统的最高权限用户,通常在配置文件中指定。 - + Args: platform: 用户所在平台 user_id: 用户ID - + Returns: bool: True 表示是管理员,False 表示不是 """ @@ -302,19 +302,19 @@ class PermissionAPI: ) -> bool: """ 注册一个新的权限节点。 - + 插件在初始化时应调用此方法注册自己需要的权限节点。 - + Args: node_name: 权限节点名称,建议使用 "插件名.功能.操作" 的格式 description: 权限描述,向用户解释这个权限的作用 plugin_name: 注册此权限的插件名称 default_granted: 是否默认授予所有用户(默认 False,需要显式授权) allow_relative: 预留参数,是否允许相对权限名(目前未使用) - + Returns: bool: True 表示注册成功,False 表示失败 - + Example: await permission_api.register_permission_node( node_name="my_plugin.chat.send_image", @@ -324,7 +324,6 @@ class PermissionAPI: ) """ self._ensure_manager() - original_name = node_name # 保存原始名称(预留给相对路径处理) # 创建权限节点对象 node = PermissionNode(node_name, description, plugin_name, default_granted) @@ -337,17 +336,17 @@ class PermissionAPI: async def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool: """ 授予用户指定权限。 - + 通常由管理员调用,给某个用户赋予特定权限。 - + Args: platform: 目标用户所在平台 user_id: 目标用户ID permission_node: 要授予的权限节点名称 - + Returns: bool: True 表示授权成功,False 表示失败 - + Example: # 授予用户管理权限 await permission_api.grant_permission("qq", "12345", "admin.manage_users") @@ -360,14 +359,14 @@ class PermissionAPI: async def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool: """ 撤销用户的指定权限。 - + 通常由管理员调用,移除某个用户的特定权限。 - + Args: platform: 目标用户所在平台 user_id: 目标用户ID permission_node: 要撤销的权限节点名称 - + Returns: bool: True 表示撤销成功,False 表示失败 """ @@ -379,16 +378,16 @@ class PermissionAPI: async def get_user_permissions(self, platform: str, user_id: str) -> list[str]: """ 获取用户拥有的所有权限列表。 - + 可用于展示用户的权限信息,或进行批量权限检查。 - + Args: platform: 目标用户所在平台 user_id: 目标用户ID - + Returns: list[str]: 用户拥有的所有权限节点名称列表 - + Example: perms = await permission_api.get_user_permissions("qq", "12345") print(f"用户拥有以下权限: {perms}") @@ -401,16 +400,16 @@ class PermissionAPI: async def get_all_permission_nodes(self) -> list[dict[str, Any]]: """ 获取系统中所有已注册的权限节点。 - + 返回所有插件注册的权限节点信息,可用于权限管理界面展示。 - + Returns: list[dict]: 权限节点信息列表,每个字典包含: - node_name: 权限节点名称 - description: 权限描述 - plugin_name: 所属插件名称 - default_granted: 是否默认授予 - + Note: 返回字典而非 PermissionNode 对象,便于序列化和API响应。 """ @@ -432,12 +431,12 @@ class PermissionAPI: async def get_plugin_permission_nodes(self, plugin_name: str) -> list[dict[str, Any]]: """ 获取指定插件注册的所有权限节点。 - + 用于查看某个特定插件定义了哪些权限。 - + Args: plugin_name: 插件名称 - + Returns: list[dict]: 该插件的权限节点信息列表,格式同 get_all_permission_nodes """ diff --git a/src/plugin_system/apis/plugin_manage_api.py b/src/plugin_system/apis/plugin_manage_api.py index 088b2b62e..89826019d 100644 --- a/src/plugin_system/apis/plugin_manage_api.py +++ b/src/plugin_system/apis/plugin_manage_api.py @@ -217,7 +217,7 @@ async def load_plugin(plugin_name: str) -> bool: logger.info(f"插件 '{plugin_name}' 加载成功。") else: logger.error(f"插件 '{plugin_name}' 加载失败。") - + return success diff --git a/src/plugin_system/apis/schedule_api.py b/src/plugin_system/apis/schedule_api.py index 993ae3d1a..338e079dc 100644 --- a/src/plugin_system/apis/schedule_api.py +++ b/src/plugin_system/apis/schedule_api.py @@ -208,7 +208,7 @@ class ScheduleAPI: if not time_range: continue try: - event_start_str, event_end_str = time_range.split("-") + event_start_str, _event_end_str = time_range.split("-") event_start = datetime.strptime(event_start_str.strip(), "%H:%M").time() if start <= event_start < end: activities_in_range.append(event) diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index d5aecb070..d579aa696 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -93,7 +93,9 @@ import uuid from typing import TYPE_CHECKING, Any from mofox_wire import MessageEnvelope + from src.common.data_models.database_data_model import DatabaseUserInfo + if TYPE_CHECKING: from src.common.data_models.database_data_model import DatabaseMessages diff --git a/src/plugin_system/apis/unified_scheduler.py b/src/plugin_system/apis/unified_scheduler.py index 1ad5ef084..676bb71dc 100644 --- a/src/plugin_system/apis/unified_scheduler.py +++ b/src/plugin_system/apis/unified_scheduler.py @@ -623,7 +623,7 @@ class UnifiedScheduler: async def _execute_task(self, task: ScheduleTask) -> None: """执行单个任务(完全隔离)""" - execution = task.start_execution() + task.start_execution() self._deadlock_detector.register_task(task.schedule_id, task.task_name) try: @@ -763,7 +763,7 @@ class UnifiedScheduler: async def _execute_event_task(self, task: ScheduleTask, event_params: dict[str, Any]) -> None: """执行事件触发的任务""" - execution = task.start_execution() + task.start_execution() self._deadlock_detector.register_task(task.schedule_id, task.task_name) try: @@ -867,7 +867,7 @@ class UnifiedScheduler: for i, timeout in enumerate(timeouts): try: # 使用 asyncio.wait 代替 wait_for,避免重新抛出异常 - done, pending = await asyncio.wait({task._asyncio_task}, timeout=timeout) + done, _pending = await asyncio.wait({task._asyncio_task}, timeout=timeout) if done: # 任务已完成(可能是正常完成或被取消) diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py index 487701149..41f07d7a9 100644 --- a/src/plugin_system/base/__init__.py +++ b/src/plugin_system/base/__init__.py @@ -44,6 +44,7 @@ __all__ = [ "BaseEventHandler", "BasePlugin", "BasePrompt", + "BaseRouterComponent", "BaseTool", "ChatMode", "ChatType", @@ -58,7 +59,6 @@ __all__ = [ "PluginMetadata", # 增强命令系统 "PlusCommand", - "BaseRouterComponent", "PlusCommandInfo", "PythonDependency", "ToolInfo", diff --git a/src/plugin_system/base/base_adapter.py b/src/plugin_system/base/base_adapter.py index ddfdd393a..db335b017 100644 --- a/src/plugin_system/base/base_adapter.py +++ b/src/plugin_system/base/base_adapter.py @@ -10,12 +10,13 @@ from __future__ import annotations import asyncio from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any -from mofox_wire import AdapterBase as MoFoxAdapterBase, CoreSink, MessageEnvelope, ProcessCoreSink +from mofox_wire import AdapterBase as MoFoxAdapterBase +from mofox_wire import CoreSink, MessageEnvelope, ProcessCoreSink if TYPE_CHECKING: - from src.plugin_system import BasePlugin, AdapterInfo + from src.plugin_system import AdapterInfo, BasePlugin from src.common.logger import get_logger @@ -25,7 +26,7 @@ logger = get_logger("plugin.adapter") class BaseAdapter(MoFoxAdapterBase, ABC): """ 插件系统的 Adapter 基类 - + 相比 mofox_wire.AdapterBase,增加了以下特性: 1. 插件生命周期管理 (on_adapter_loaded, on_adapter_unloaded) 2. 配置管理集成 @@ -38,17 +39,17 @@ class BaseAdapter(MoFoxAdapterBase, ABC): adapter_version: str = "0.0.1" adapter_author: str = "Unknown" adapter_description: str = "No description" - + # 是否在子进程中运行 run_in_subprocess: bool = True - + # 子进程启动脚本路径(相对于插件目录) - subprocess_entry: Optional[str] = None + subprocess_entry: str | None = None def __init__( self, core_sink: CoreSink, - plugin: Optional[BasePlugin] = None, + plugin: BasePlugin | None = None, **kwargs ): """ @@ -59,8 +60,8 @@ class BaseAdapter(MoFoxAdapterBase, ABC): """ super().__init__(core_sink, **kwargs) self.plugin = plugin - self._config: Dict[str, Any] = {} - self._health_check_task: Optional[asyncio.Task] = None + self._config: dict[str, Any] = {} + self._health_check_task: asyncio.Task | None = None self._running = False # 标记是否在子进程中运行 self._is_subprocess = False @@ -70,7 +71,7 @@ class BaseAdapter(MoFoxAdapterBase, ABC): cls, to_core_queue, from_core_queue, - plugin: Optional["BasePlugin"] = None, + plugin: "BasePlugin" | None = None, **kwargs: Any, ) -> "BaseAdapter": """ @@ -86,14 +87,14 @@ class BaseAdapter(MoFoxAdapterBase, ABC): return cls(core_sink=sink, plugin=plugin, **kwargs) @property - def config(self) -> Dict[str, Any]: + def config(self) -> dict[str, Any]: """获取适配器配置""" if self.plugin and hasattr(self.plugin, "config"): return self.plugin.config return self._config @config.setter - def config(self, value: Dict[str, Any]) -> None: + def config(self, value: dict[str, Any]) -> None: """设置适配器配置""" self._config = value @@ -111,26 +112,26 @@ class BaseAdapter(MoFoxAdapterBase, ABC): async def start(self) -> None: """启动适配器""" logger.info(f"启动适配器: {self.adapter_name} v{self.adapter_version}") - + # 调用生命周期钩子 await self.on_adapter_loaded() - + # 调用父类启动 await super().start() - + # 启动健康检查 if self.config.get("enable_health_check", False): self._health_check_task = asyncio.create_task(self._health_check_loop()) - + self._running = True logger.info(f"适配器 {self.adapter_name} 启动成功") async def stop(self) -> None: """停止适配器""" logger.info(f"停止适配器: {self.adapter_name}") - + self._running = False - + # 停止健康检查 if self._health_check_task and not self._health_check_task.done(): self._health_check_task.cancel() @@ -138,13 +139,13 @@ class BaseAdapter(MoFoxAdapterBase, ABC): await self._health_check_task except asyncio.CancelledError: pass - + # 调用父类停止 await super().stop() - + # 调用生命周期钩子 await self.on_adapter_unloaded() - + logger.info(f"适配器 {self.adapter_name} 已停止") async def on_adapter_loaded(self) -> None: @@ -164,18 +165,18 @@ class BaseAdapter(MoFoxAdapterBase, ABC): async def _health_check_loop(self) -> None: """健康检查循环""" interval = self.config.get("health_check_interval", 30) - + while self._running: try: await asyncio.sleep(interval) - + # 执行健康检查 is_healthy = await self.health_check() - + if not is_healthy: logger.warning(f"适配器 {self.adapter_name} 健康检查失败,尝试重连...") await self.reconnect() - + except asyncio.CancelledError: break except Exception as e: @@ -185,7 +186,7 @@ class BaseAdapter(MoFoxAdapterBase, ABC): """ 健康检查 子类可重写以实现自定义检查逻辑 - + Returns: bool: 是否健康 """ @@ -206,38 +207,38 @@ class BaseAdapter(MoFoxAdapterBase, ABC): except Exception as e: logger.error(f"适配器 {self.adapter_name} 重连失败: {e}") - def get_subprocess_entry_path(self) -> Optional[Path]: + def get_subprocess_entry_path(self) -> Path | None: """ 获取子进程启动脚本的完整路径 - + Returns: Path | None: 脚本路径,如果不存在则返回 None """ if not self.subprocess_entry: return None - + if not self.plugin: return None - + # 获取插件目录 plugin_dir = Path(self.plugin.__file__).parent entry_path = plugin_dir / self.subprocess_entry - + if entry_path.exists(): return entry_path - + logger.warning(f"子进程入口脚本不存在: {entry_path}") return None @classmethod def get_adapter_info(cls) -> "AdapterInfo": """获取适配器的信息 - + Returns: AdapterInfo: 适配器组件信息 """ from src.plugin_system.base.component_types import AdapterInfo - + return AdapterInfo( name=getattr(cls, "adapter_name", cls.__name__.lower().replace("adapter", "")), version=getattr(cls, "adapter_version", "1.0.0"), @@ -252,12 +253,12 @@ class BaseAdapter(MoFoxAdapterBase, ABC): async def from_platform_message(self, raw: Any) -> MessageEnvelope: """ 将平台原始消息转换为 MessageEnvelope - + 子类必须实现此方法 - + Args: raw: 平台原始消息 - + Returns: MessageEnvelope: 统一的消息信封 """ @@ -266,10 +267,10 @@ class BaseAdapter(MoFoxAdapterBase, ABC): async def _send_platform_message(self, envelope: MessageEnvelope) -> None: """ 发送消息到平台 - + 如果使用了 WebSocketAdapterOptions 或 HttpAdapterOptions, 此方法会自动处理。否则子类需要重写此方法。 - + Args: envelope: 要发送的消息信封 """ diff --git a/src/plugin_system/core/adapter_manager.py b/src/plugin_system/core/adapter_manager.py index 44a2a1d65..889de9f96 100644 --- a/src/plugin_system/core/adapter_manager.py +++ b/src/plugin_system/core/adapter_manager.py @@ -14,12 +14,13 @@ from __future__ import annotations import asyncio import importlib import multiprocessing as mp -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING if TYPE_CHECKING: from src.plugin_system.base.base_adapter import BaseAdapter from mofox_wire import ProcessCoreSinkServer + from src.common.logger import get_logger logger = get_logger("adapter_manager") @@ -64,11 +65,11 @@ def _adapter_process_entry( ): """ 子进程适配器入口函数 - + 在子进程中运行,创建 ProcessCoreSink 与主进程通信 """ - import asyncio import contextlib + from mofox_wire import ProcessCoreSink async def _run() -> None: @@ -77,14 +78,14 @@ def _adapter_process_entry( if plugin_info: plugin_cls = _load_class(plugin_info["module"], plugin_info["class"]) plugin_instance = plugin_cls(plugin_info["plugin_dir"], plugin_info["metadata"]) - + # 创建 ProcessCoreSink 用于与主进程通信 core_sink = ProcessCoreSink(to_core_queue=incoming_queue, from_core_queue=outgoing_queue) - + # 创建并启动适配器 adapter = adapter_cls(core_sink, plugin=plugin_instance) await adapter.start() - + try: while not getattr(core_sink, "_closed", False): await asyncio.sleep(0.2) @@ -101,7 +102,7 @@ def _adapter_process_entry( class AdapterProcess: """ 适配器子进程封装:管理子进程的生命周期与通信桥接 - + 使用 CoreSinkManager 创建通信队列,自动维护与子进程的消息通道 """ @@ -132,13 +133,13 @@ class AdapterProcess: """启动适配器子进程""" try: logger.info(f"启动适配器子进程: {self.adapter_name}") - + # 从 CoreSinkManager 获取通信队列 from src.common.core_sink_manager import get_core_sink_manager - + manager = get_core_sink_manager() self._incoming_queue, self._outgoing_queue = manager.create_process_sink_queues(self.adapter_name) - + # 启动子进程 self.process = self._ctx.Process( target=_adapter_process_entry, @@ -146,10 +147,10 @@ class AdapterProcess: name=f"{self.adapter_name}-proc", ) self.process.start() - + logger.info(f"启动适配器子进程 {self.adapter_name} (PID: {self.process.pid})") return True - + except Exception as e: logger.error(f"启动适配器子进程 {self.adapter_name} 失败: {e}") return False @@ -158,25 +159,25 @@ class AdapterProcess: """停止适配器子进程""" if not self.process: return - + logger.info(f"停止适配器子进程: {self.adapter_name} (PID: {self.process.pid})") - + try: # 从 CoreSinkManager 移除通信队列 from src.common.core_sink_manager import get_core_sink_manager - + manager = get_core_sink_manager() manager.remove_process_sink(self.adapter_name) - + # 等待子进程结束 if self.process.is_alive(): self.process.join(timeout=5.0) - + if self.process.is_alive(): logger.warning(f"适配器 {self.adapter_name} 未能及时停止,强制终止中") self.process.terminate() self.process.join() - + except Exception as e: logger.error(f"停止适配器子进程 {self.adapter_name} 时发生错误: {e}") finally: @@ -193,7 +194,7 @@ class AdapterProcess: class AdapterManager: """ 适配器管理器 - + 负责管理所有注册的适配器,根据 run_in_subprocess 属性自动选择: - run_in_subprocess=True: 在子进程中运行,使用 ProcessCoreSink - run_in_subprocess=False: 在主进程中运行,使用 InProcessCoreSink @@ -201,9 +202,9 @@ class AdapterManager: def __init__(self): # 注册信息:name -> (adapter class, plugin instance | None) - self._adapter_defs: Dict[str, tuple[type[BaseAdapter], object | None]] = {} - self._adapter_processes: Dict[str, AdapterProcess] = {} - self._in_process_adapters: Dict[str, BaseAdapter] = {} + self._adapter_defs: dict[str, tuple[type[BaseAdapter], object | None]] = {} + self._adapter_processes: dict[str, AdapterProcess] = {} + self._in_process_adapters: dict[str, BaseAdapter] = {} def register_adapter(self, adapter_cls: type[BaseAdapter], plugin=None) -> None: """ @@ -213,15 +214,15 @@ class AdapterManager: adapter_cls: 适配器类 plugin: 可选 Plugin 实例 """ - adapter_name = getattr(adapter_cls, 'adapter_name', adapter_cls.__name__) + adapter_name = getattr(adapter_cls, "adapter_name", adapter_cls.__name__) if adapter_name in self._adapter_defs: logger.warning(f"适配器 {adapter_name} 已注册,已覆盖") self._adapter_defs[adapter_name] = (adapter_cls, plugin) - adapter_version = getattr(adapter_cls, 'adapter_version', 'unknown') - run_in_subprocess = getattr(adapter_cls, 'run_in_subprocess', False) - + adapter_version = getattr(adapter_cls, "adapter_version", "unknown") + run_in_subprocess = getattr(adapter_cls, "run_in_subprocess", False) + logger.info( f"注册适配器: {adapter_name} v{adapter_version} " f"(子进程: {'是' if run_in_subprocess else '否'})" @@ -230,7 +231,7 @@ class AdapterManager: async def start_adapter(self, adapter_name: str) -> bool: """ 启动指定适配器 - + 根据适配器的 run_in_subprocess 属性自动选择: - True: 创建子进程,使用 ProcessCoreSink - False: 在当前进程,使用 InProcessCoreSink @@ -239,7 +240,7 @@ class AdapterManager: if not definition: logger.error(f"适配器 {adapter_name} 未注册") return False - + adapter_cls, plugin = definition run_in_subprocess = getattr(adapter_cls, "run_in_subprocess", False) @@ -248,9 +249,9 @@ class AdapterManager: return await self._start_adapter_in_process(adapter_name, adapter_cls, plugin) async def _start_adapter_subprocess( - self, - adapter_name: str, - adapter_cls: type[BaseAdapter], + self, + adapter_name: str, + adapter_cls: type[BaseAdapter], plugin ) -> bool: """在子进程中启动适配器(使用 ProcessCoreSink)""" @@ -263,24 +264,24 @@ class AdapterManager: return success async def _start_adapter_in_process( - self, - adapter_name: str, - adapter_cls: type[BaseAdapter], + self, + adapter_name: str, + adapter_cls: type[BaseAdapter], plugin ) -> bool: """在当前进程中启动适配器(使用 InProcessCoreSink)""" try: # 从 CoreSinkManager 获取 InProcessCoreSink from src.common.core_sink_manager import get_core_sink_manager - + core_sink = get_core_sink_manager().get_in_process_sink() adapter = adapter_cls(core_sink, plugin=plugin) # type: ignore[call-arg] await adapter.start() - + self._in_process_adapters[adapter_name] = adapter logger.info(f"适配器 {adapter_name} 已在当前进程启动") return True - + except Exception as e: logger.error(f"启动适配器 {adapter_name} 失败: {e}") return False @@ -288,7 +289,7 @@ class AdapterManager: async def stop_adapter(self, adapter_name: str) -> None: """ 停止指定的适配器 - + Args: adapter_name: 适配器名称 """ @@ -327,20 +328,20 @@ class AdapterManager: logger.info("所有适配器已停止") - def get_adapter(self, adapter_name: str) -> Optional[BaseAdapter]: + def get_adapter(self, adapter_name: str) -> BaseAdapter | None: """ 获取适配器实例 - + Args: adapter_name: 适配器名称 - + Returns: BaseAdapter | None: 适配器实例,如果不存在则返回 None """ # 只返回在主进程中运行的适配器 return self._in_process_adapters.get(adapter_name) - def list_adapters(self) -> Dict[str, Dict[str, any]]: + def list_adapters(self) -> dict[str, dict[str, any]]: """列出适配器状态""" result = {} @@ -371,7 +372,7 @@ class AdapterManager: # 全局单例 -_adapter_manager: Optional[AdapterManager] = None +_adapter_manager: AdapterManager | None = None def get_adapter_manager() -> AdapterManager: diff --git a/src/plugin_system/core/event_manager.py b/src/plugin_system/core/event_manager.py index 93f63642b..f1988eae6 100644 --- a/src/plugin_system/core/event_manager.py +++ b/src/plugin_system/core/event_manager.py @@ -242,7 +242,7 @@ class EventManager: for event in self._events.values(): # 创建订阅者列表的副本进行迭代,以安全地修改原始列表 for subscriber in list(event.subscribers): - if getattr(subscriber, 'handler_name', None) == handler_name: + if getattr(subscriber, "handler_name", None) == handler_name: event.subscribers.remove(subscriber) logger.debug(f"事件处理器 {handler_name} 已从事件 {event.name} 取消订阅。") diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py index e18093804..b7fdbcf72 100644 --- a/src/plugin_system/core/permission_manager.py +++ b/src/plugin_system/core/permission_manager.py @@ -38,14 +38,14 @@ class PermissionManager(IPermissionManager): try: master_users_config = global_config.permission.master_users if not isinstance(master_users_config, list): - logger.warning(f"配置文件中的 permission.master_users 不是一个列表,已跳过加载。") + logger.warning("配置文件中的 permission.master_users 不是一个列表,已跳过加载。") self._master_users = set() return self._master_users = set() for i, user_info in enumerate(master_users_config): if not isinstance(user_info, list) or len(user_info) != 2: - logger.warning(f"Master用户配置项格式错误 (索引: {i}): {user_info},应为 [\"platform\", \"user_id\"]") + logger.warning(f'Master用户配置项格式错误 (索引: {i}): {user_info},应为 ["platform", "user_id"]') continue platform, user_id = user_info diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index f1d4479c0..a05805732 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -33,9 +33,9 @@ class PluginManager: self.loaded_plugins: dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例 self.failed_plugins: dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息 - + # 核心消息接收器(由主程序设置) - self._core_sink: Optional[Any] = None + self._core_sink: Any | None = None # 确保插件目录存在 self._ensure_plugin_directories() @@ -45,7 +45,7 @@ class PluginManager: def set_core_sink(self, core_sink: Any) -> None: """设置核心消息接收器 - + Args: core_sink: 核心消息接收器实例(InProcessCoreSink) """ @@ -184,7 +184,7 @@ class PluginManager: async def _register_adapter_components(self, plugin_name: str, plugin_instance: PluginBase) -> None: """注册适配器组件 - + Args: plugin_name: 插件名称 plugin_instance: 插件实例 @@ -193,36 +193,36 @@ class PluginManager: from src.plugin_system.base.component_types import AdapterInfo, ComponentType from src.plugin_system.core.adapter_manager import get_adapter_manager from src.plugin_system.core.component_registry import component_registry - + # 获取所有 ADAPTER 类型的组件 plugin_info = plugin_instance.plugin_info adapter_components = [ - comp for comp in plugin_info.components + comp for comp in plugin_info.components if comp.component_type == ComponentType.ADAPTER ] - + if not adapter_components: return - + adapter_manager = get_adapter_manager() - + for comp_info in adapter_components: # 类型检查:确保是 AdapterInfo if not isinstance(comp_info, AdapterInfo): logger.warning(f"组件 {comp_info.name} 不是 AdapterInfo 类型") continue - + try: # 从组件注册表获取适配器类 adapter_class = component_registry.get_component_class( - comp_info.name, + comp_info.name, ComponentType.ADAPTER ) - + if not adapter_class: logger.warning(f"无法找到适配器组件类: {comp_info.name}") continue - + # 创建适配器实例,传入 core_sink 和 plugin # 注册到适配器管理器,由管理器统一在运行时创建实例 adapter_manager.register_adapter(adapter_class, plugin_instance) # type: ignore @@ -230,13 +230,13 @@ class PluginManager: f"插件 '{plugin_name}' 注册了适配器组件: {comp_info.name} " f"(平台: {comp_info.platform})" ) - + except Exception as e: logger.error( f"注册插件 '{plugin_name}' 的适配器组件 '{comp_info.name}' 时出错: {e}", exc_info=True ) - + except Exception as e: logger.error(f"处理插件 '{plugin_name}' 的适配器组件时出错: {e}") diff --git a/src/plugin_system/core/stream_tool_history.py b/src/plugin_system/core/stream_tool_history.py index 8f8b2f48b..deb93320d 100644 --- a/src/plugin_system/core/stream_tool_history.py +++ b/src/plugin_system/core/stream_tool_history.py @@ -36,7 +36,7 @@ class ToolCallRecord: no_truncate_tools = {"web_search", "web_surfing", "knowledge_search"} should_truncate = self.tool_name not in no_truncate_tools max_length = 500 if should_truncate else 10000 # 联网搜索给更大的限制 - + if isinstance(content, str): if len(content) > max_length: self.result_preview = content[:max_length] + "..." @@ -415,7 +415,6 @@ _STREAM_MANAGERS_MAX_SIZE = 100 # 最大保留数量 def _evict_old_stream_managers() -> None: """内存优化:淘汰最久未使用的 stream manager""" - import time if len(_stream_managers) < _STREAM_MANAGERS_MAX_SIZE: return @@ -429,10 +428,8 @@ def _evict_old_stream_managers() -> None: evicted = [] for chat_id, _ in sorted_by_time[:evict_count]: - if chat_id in _stream_managers: - del _stream_managers[chat_id] - if chat_id in _stream_managers_last_used: - del _stream_managers_last_used[chat_id] + _stream_managers.pop(chat_id, None) + _stream_managers_last_used.pop(chat_id, None) evicted.append(chat_id) if evicted: @@ -466,8 +463,6 @@ def cleanup_stream_manager(chat_id: str) -> None: Args: chat_id: 聊天ID """ - if chat_id in _stream_managers: - del _stream_managers[chat_id] - if chat_id in _stream_managers_last_used: - del _stream_managers_last_used[chat_id] + _stream_managers.pop(chat_id, None) + _stream_managers_last_used.pop(chat_id, None) logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器") diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index 5096d08ed..ee3cc25ac 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -11,7 +11,6 @@ from src.llm_models.payload_content import ToolCall from src.llm_models.utils_model import LLMRequest from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance from src.plugin_system.base.base_tool import BaseTool -from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.plugin_system.core.stream_tool_history import ToolCallRecord, get_stream_tool_history_manager logger = get_logger("tool_use") @@ -203,7 +202,7 @@ class ToolExecutor: logger.debug(f"{self.log_prefix}开始LLM工具调用分析") # 调用LLM进行工具决策 - response, llm_extra_info = await self.llm_model.generate_response_async( + _response, llm_extra_info = await self.llm_model.generate_response_async( prompt=prompt, tools=tools, raise_when_empty=False ) @@ -412,7 +411,7 @@ class ToolExecutor: for i, tool_call in enumerate(tool_calls) ] - + async def _execute_single_tool_with_timeout(self, tool_call: ToolCall, index: int) -> ToolExecutionResult: """执行单个工具调用,支持超时控制 diff --git a/src/plugin_system/services/relationship_service.py b/src/plugin_system/services/relationship_service.py index 424832c68..505dd0f25 100644 --- a/src/plugin_system/services/relationship_service.py +++ b/src/plugin_system/services/relationship_service.py @@ -15,7 +15,7 @@ logger = get_logger("relationship_service") class RelationshipService: """用户关系分服务 - 独立于插件的数据库直接访问层 - + 内存优化:添加缓存大小限制和自动过期清理 """ diff --git a/src/plugins/built_in/affinity_flow_chatter/actions/reply.py b/src/plugins/built_in/affinity_flow_chatter/actions/reply.py index 3ab0fe909..04f1ddc08 100644 --- a/src/plugins/built_in/affinity_flow_chatter/actions/reply.py +++ b/src/plugins/built_in/affinity_flow_chatter/actions/reply.py @@ -22,7 +22,7 @@ logger = get_logger("afc_reply_actions") class ReplyAction(BaseAction): """Reply动作 - 针对单条消息的深度回复 - + 特点: - 使用 s4u (Speak for You) 模板 - 专注于理解和回应单条消息的具体内容 @@ -38,7 +38,7 @@ class ReplyAction(BaseAction): activation_type = ActionActivationType.ALWAYS # 回复动作总是可用 mode_enable = ChatMode.ALL # 在所有模式下都可用 parallel_action = False # 回复动作不能与其他动作并行 - + # Chatter 限制:仅允许 AffinityFlowChatter 使用 chatter_allow: ClassVar[list[str]] = ["AffinityFlowChatter"] @@ -67,17 +67,17 @@ class ReplyAction(BaseAction): try: # 确保 action_message 是 DatabaseMessages 类型,否则使用 None reply_message = self.action_message if isinstance(self.action_message, DatabaseMessages) else None - + # 检查目标消息是否为表情包 if reply_message and getattr(reply_message, "is_emoji", False): if not getattr(global_config.chat, "allow_reply_to_emoji", True): logger.info(f"{self.log_prefix} 目标消息为表情包且配置不允许回复,跳过") return True, "" - + # 准备 action_data action_data = self.action_data.copy() action_data["prompt_mode"] = "s4u" - + # 生成回复 success, response_set, _ = await generator_api.generate_reply( chat_stream=self.chat_stream, @@ -88,16 +88,16 @@ class ReplyAction(BaseAction): request_type="chat.replyer", from_plugin=False, ) - + if not success or not response_set: logger.warning(f"{self.log_prefix} 回复生成失败") return False, "" - + # 发送回复 reply_text = await self._send_response(response_set) - + return True, reply_text - + except asyncio.CancelledError: logger.debug(f"{self.log_prefix} 回复任务被取消") return False, "" @@ -106,28 +106,28 @@ class ReplyAction(BaseAction): import traceback traceback.print_exc() return False, "" - + async def _send_response(self, response_set) -> str: """发送回复内容""" reply_text = "" should_quote = self.action_data.get("should_quote_reply", False) first_sent = False - + # 确保 action_message 是 DatabaseMessages 类型 reply_message = self.action_message if isinstance(self.action_message, DatabaseMessages) else None - + for reply_seg in response_set: # 处理元组格式 if isinstance(reply_seg, tuple) and len(reply_seg) >= 2: _, data = reply_seg else: data = str(reply_seg) - + if isinstance(data, list): data = "".join(map(str, data)) - + reply_text += data - + # 发送消息 if not first_sent: await send_api.text_to_stream( @@ -146,13 +146,13 @@ class ReplyAction(BaseAction): set_reply=False, typing=True, ) - + return reply_text class RespondAction(BaseAction): """Respond动作 - 对未读消息的统一回应 - + 特点: - 关注整体对话动态和未读消息的统一回应 - 适合对于群聊消息下的宏观回应 @@ -168,7 +168,7 @@ class RespondAction(BaseAction): activation_type = ActionActivationType.ALWAYS # 回应动作总是可用 mode_enable = ChatMode.ALL # 在所有模式下都可用 parallel_action = False # 回应动作不能与其他动作并行 - + # Chatter 限制:仅允许 AffinityFlowChatter 使用 chatter_allow: ClassVar[list[str]] = ["AffinityFlowChatter"] @@ -196,10 +196,10 @@ class RespondAction(BaseAction): # 准备 action_data action_data = self.action_data.copy() action_data["prompt_mode"] = "normal" - + # 确保 action_message 是 DatabaseMessages 类型,否则使用 None reply_message = self.action_message if isinstance(self.action_message, DatabaseMessages) else None - + # 生成回复 success, response_set, _ = await generator_api.generate_reply( chat_stream=self.chat_stream, @@ -210,16 +210,16 @@ class RespondAction(BaseAction): request_type="chat.replyer", from_plugin=False, ) - + if not success or not response_set: logger.warning(f"{self.log_prefix} 回复生成失败") return False, "" - + # 发送回复(respond 默认不引用) reply_text = await self._send_response(response_set) return True, reply_text - + except asyncio.CancelledError: logger.debug(f"{self.log_prefix} 回复任务被取消") return False, "" @@ -228,23 +228,23 @@ class RespondAction(BaseAction): import traceback traceback.print_exc() return False, "" - + async def _send_response(self, response_set) -> str: """发送回复内容(不引用原消息)""" reply_text = "" first_sent = False - + for reply_seg in response_set: if isinstance(reply_seg, tuple) and len(reply_seg) >= 2: _, data = reply_seg else: data = str(reply_seg) - + if isinstance(data, list): data = "".join(map(str, data)) - + reply_text += data - + if not first_sent: await send_api.text_to_stream( text=data, @@ -262,5 +262,5 @@ class RespondAction(BaseAction): set_reply=False, typing=True, ) - + return reply_text diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py index de153004b..9ee7236db 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py @@ -674,7 +674,7 @@ class ChatterPlanFilter: logger.info(f"[{action}] 成功使用最新消息: {action_message_obj.message_id}") except Exception as e: logger.error(f"[{action}] 无法转换最新消息: {e}") - + return ActionPlannerInfo( action_type=action, reasoning=reasoning, @@ -738,7 +738,7 @@ class ChatterPlanFilter: 动作使用场景: {action_require} -你应该像这样使用它: +你应该像这样使用它: {{ "action_type": "{action_name}", "reasoning": "<执行该动作的详细原因>", diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner/planner.py index 8a364eda7..a0cb47618 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner/planner.py @@ -20,9 +20,9 @@ from src.plugins.built_in.affinity_flow_chatter.planner.plan_generator import Ch if TYPE_CHECKING: from src.chat.planner_actions.action_manager import ChatterActionManager + from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.info_data_model import Plan from src.common.data_models.message_manager_data_model import StreamContext - from src.common.data_models.database_data_model import DatabaseMessages # 导入提示词模块以确保其被初始化 @@ -138,7 +138,7 @@ class ChatterActionPlanner: try: interest_manager = get_interest_manager() - except Exception as exc: # noqa: BLE001 + except Exception as exc: logger.warning(f"获取兴趣管理器失败: {exc}") return @@ -153,7 +153,7 @@ class ChatterActionPlanner: try: embeddings = await bot_interest_manager.generate_embeddings_for_texts(text_map) - except Exception as exc: # noqa: BLE001 + except Exception as exc: logger.error(f"批量获取消息embedding失败: {exc}") embeddings = {} @@ -167,7 +167,7 @@ class ChatterActionPlanner: try: result = await interest_manager.calculate_interest(message) - except Exception as exc: # noqa: BLE001 + except Exception as exc: logger.error(f"批量计算消息兴趣失败: {exc}") continue @@ -184,7 +184,7 @@ class ChatterActionPlanner: if interest_updates: try: await MessageStorage.bulk_update_interest_values(interest_updates, reply_updates) - except Exception as exc: # noqa: BLE001 + except Exception as exc: logger.error(f"批量更新消息兴趣值失败: {exc}") async def _focus_mode_flow(self, context: "StreamContext | None") -> tuple[list[dict[str, Any]], Any | None]: diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py b/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py index 6ec9eab04..8859d3e58 100644 --- a/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py @@ -699,13 +699,13 @@ async def execute_proactive_thinking(stream_id: str): try: # 0. 前置检查 - + # 0.-1 检查是否是私聊且 KFC 主动思考已启用(让 KFC 接管私聊主动思考) try: from src.chat.message_receive.chat_stream import get_chat_manager chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) - + # 判断是否是私聊(使用 chat_type 枚举或从 stream_id 判断) is_private = False if chat_stream: @@ -714,17 +714,17 @@ async def execute_proactive_thinking(stream_id: str): except Exception: # 回退:从 stream_id 判断(私聊通常不包含 "group") is_private = "group" not in stream_id.lower() - + if is_private: # 这是一个私聊,检查 KFC 是否启用且其主动思考是否启用 try: from src.config.config import global_config - kfc_config = getattr(global_config, 'kokoro_flow_chatter', None) + kfc_config = getattr(global_config, "kokoro_flow_chatter", None) if kfc_config: - kfc_enabled = getattr(kfc_config, 'enable', False) - proactive_config = getattr(kfc_config, 'proactive_thinking', None) - proactive_enabled = getattr(proactive_config, 'enabled', False) if proactive_config else False - + kfc_enabled = getattr(kfc_config, "enable", False) + proactive_config = getattr(kfc_config, "proactive_thinking", None) + proactive_enabled = getattr(proactive_config, "enabled", False) if proactive_config else False + if kfc_enabled and proactive_enabled: logger.debug( f"[主动思考] 私聊 {stream_id} 由 KFC 主动思考接管,跳过通用主动思考" @@ -734,7 +734,7 @@ async def execute_proactive_thinking(stream_id: str): logger.debug(f"检查 KFC 配置时出错,继续执行通用主动思考: {e}") except Exception as e: logger.warning(f"检查私聊/KFC 状态时出错: {e},继续执行") - + # 0.0 检查聊天流是否正在处理消息(双重保护) try: from src.chat.message_receive.chat_stream import get_chat_manager diff --git a/src/plugins/built_in/affinity_flow_chatter/tools/__init__.py b/src/plugins/built_in/affinity_flow_chatter/tools/__init__.py index af91bf1ae..eb42be535 100644 --- a/src/plugins/built_in/affinity_flow_chatter/tools/__init__.py +++ b/src/plugins/built_in/affinity_flow_chatter/tools/__init__.py @@ -8,4 +8,4 @@ from .chat_stream_impression_tool import ChatStreamImpressionTool from .user_fact_tool import UserFactTool from .user_profile_tool import UserProfileTool -__all__ = ["ChatStreamImpressionTool", "UserProfileTool", "UserFactTool"] +__all__ = ["ChatStreamImpressionTool", "UserFactTool", "UserProfileTool"] diff --git a/src/plugins/built_in/affinity_flow_chatter/tools/user_fact_tool.py b/src/plugins/built_in/affinity_flow_chatter/tools/user_fact_tool.py index 0ad2612a0..1650e11d6 100644 --- a/src/plugins/built_in/affinity_flow_chatter/tools/user_fact_tool.py +++ b/src/plugins/built_in/affinity_flow_chatter/tools/user_fact_tool.py @@ -21,7 +21,7 @@ logger = get_logger("user_fact_tool") class UserFactTool(BaseTool): """用户关键信息记录工具 - + 用于记录生日、职业、理想、宠物等长期重要信息。 注意:一般情况下使用 update_user_profile 工具即可同时记录印象和关键信息。 此工具仅在需要单独补充记录信息时使用。 @@ -31,7 +31,7 @@ class UserFactTool(BaseTool): description = """【备用工具】单独记录用户的重要个人信息。 注意:大多数情况请直接使用 update_user_profile 工具(它可以同时更新印象和记录关键信息)。 仅当你只想补充记录一条信息、不需要更新印象时才使用此工具。""" - + parameters = [ ("target_user_id", ToolParamType.STRING, "目标用户的ID(必须)", True, None), ("target_user_name", ToolParamType.STRING, "目标用户的名字/昵称(必须)", True, None), @@ -43,10 +43,10 @@ class UserFactTool(BaseTool): async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """执行关键信息记录 - + Args: function_args: 工具参数 - + Returns: dict: 执行结果 """ @@ -56,29 +56,29 @@ class UserFactTool(BaseTool): target_user_name = function_args.get("target_user_name", target_user_id) info_type = function_args.get("info_type", "other") info_value = function_args.get("info_value", "") - + if not target_user_id: return { "type": "error", "id": "remember_user_info", "content": "错误:必须提供目标用户ID" } - + if not info_value: return { "type": "error", "id": "remember_user_info", "content": "错误:必须提供要记录的信息内容" } - + # 验证 info_type valid_types = ["birthday", "job", "location", "dream", "family", "pet", "other"] if info_type not in valid_types: info_type = "other" - + # 更新数据库 await self._add_key_fact(target_user_id, info_type, info_value) - + # 生成友好的类型名称 type_names = { "birthday": "生日", @@ -90,16 +90,16 @@ class UserFactTool(BaseTool): "other": "其他信息" } type_name = type_names.get(info_type, "信息") - + result_text = f"已记住 {target_user_name} 的{type_name}:{info_value}" logger.info(f"记录用户关键信息: {target_user_id}, {info_type}={info_value}") - + return { "type": "user_fact_recorded", "id": target_user_id, "content": result_text } - + except Exception as e: logger.error(f"记录用户关键信息失败: {e}") return { @@ -110,7 +110,7 @@ class UserFactTool(BaseTool): async def _add_key_fact(self, user_id: str, info_type: str, info_value: str): """添加或更新关键信息 - + Args: user_id: 用户ID info_type: 信息类型 @@ -118,22 +118,22 @@ class UserFactTool(BaseTool): """ try: current_time = time.time() - + async with get_db_session() as session: stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) result = await session.execute(stmt) existing = result.scalar_one_or_none() - + if existing: # 解析现有的 key_facts try: facts = orjson.loads(existing.key_facts) if existing.key_facts else [] except Exception: facts = [] - + if not isinstance(facts, list): facts = [] - + # 查找是否已有相同类型的信息 found = False for i, fact in enumerate(facts): @@ -142,11 +142,11 @@ class UserFactTool(BaseTool): facts[i] = {"type": info_type, "value": info_value} found = True break - + if not found: # 添加新记录 facts.append({"type": info_type, "value": info_value}) - + # 更新数据库 existing.key_facts = orjson.dumps(facts).decode("utf-8") existing.last_updated = current_time @@ -161,10 +161,10 @@ class UserFactTool(BaseTool): last_updated=current_time ) session.add(new_profile) - + await session.commit() logger.info(f"关键信息已保存: {user_id}, {info_type}={info_value}") - + except Exception as e: logger.error(f"保存关键信息失败: {e}") raise diff --git a/src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py b/src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py index dfa7d8f96..5247e9974 100644 --- a/src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py +++ b/src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py @@ -95,7 +95,7 @@ class UserProfileTool(BaseTool): dict: 执行结果 """ import asyncio - + try: # 提取参数 target_user_id = function_args.get("target_user_id") @@ -180,7 +180,7 @@ class UserProfileTool(BaseTool): operation=alias_operation, new_value=alias_value, ) - + # 🎯 处理偏好操作 final_preferences = self._process_list_operation( existing_value=existing_profile.get("preference_keywords", ""), @@ -194,7 +194,7 @@ class UserProfileTool(BaseTool): # 🎯 核心:使用relationship_tracker模型生成印象并决定好感度变化 final_impression = existing_profile.get("relationship_text", "") affection_change = 0.0 # 好感度变化量 - + if impression_hint or chat_history_text: impression_result = await self._generate_impression_with_affection( target_user_name=target_user_name, @@ -225,27 +225,27 @@ class UserProfileTool(BaseTool): except Exception as e: logger.error(f"[后台] 用户画像更新失败: {e}") - + def _process_list_operation(self, existing_value: str, operation: str, new_value: str) -> str: """处理列表类型的操作(别名、偏好等) - + Args: existing_value: 现有值(用、分隔) operation: 操作类型 add/remove/replace new_value: 新值(用、分隔) - + Returns: str: 处理后的值 """ if not new_value: return existing_value - + # 解析现有值和新值 existing_set = set(filter(None, [x.strip() for x in (existing_value or "").split("、")])) new_set = set(filter(None, [x.strip() for x in new_value.split("、")])) - + operation = (operation or "add").lower().strip() - + if operation == "replace": # 全部替换 result_set = new_set @@ -258,25 +258,25 @@ class UserProfileTool(BaseTool): # 新增(合并) result_set = existing_set | new_set logger.debug(f"别名/偏好新增: {new_set} 到 {existing_set}") - + return "、".join(sorted(result_set)) async def _add_key_fact(self, user_id: str, info_type: str, info_value: str): """添加或更新关键信息(生日、职业等) - + Args: user_id: 用户ID info_type: 信息类型(birthday/job/location/dream/family/pet/other) info_value: 信息内容 """ import orjson - + try: # 验证 info_type valid_types = ["birthday", "job", "location", "dream", "family", "pet", "other"] if info_type not in valid_types: info_type = "other" - + # 🎯 信息质量判断:过滤掉模糊的描述性内容 low_quality_patterns = [ "的生日", "的工作", "的位置", "的梦想", "的家人", "的宠物", @@ -284,34 +284,34 @@ class UserProfileTool(BaseTool): "affectionate", "friendly", "的信息", "某个", "一个" ] info_value_lower = info_value.lower().strip() - + # 如果值太短或包含低质量模式,跳过 if len(info_value_lower) < 2: logger.warning(f"关键信息值太短,跳过: {info_value}") return - + for pattern in low_quality_patterns: if pattern in info_value_lower: logger.warning(f"关键信息质量不佳,跳过: {info_type}={info_value}(包含'{pattern}')") return - + current_time = time.time() - + async with get_db_session() as session: stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) result = await session.execute(stmt) existing = result.scalar_one_or_none() - + if existing: # 解析现有的 key_facts try: facts = orjson.loads(existing.key_facts) if existing.key_facts else [] except Exception: facts = [] - + if not isinstance(facts, list): facts = [] - + # 查找是否已有相同类型的信息 found = False for i, fact in enumerate(facts): @@ -324,11 +324,11 @@ class UserProfileTool(BaseTool): facts[i] = {"type": info_type, "value": info_value} found = True break - + if not found: # 添加新记录 facts.append({"type": info_type, "value": info_value}) - + # 更新数据库 existing.key_facts = orjson.dumps(facts).decode("utf-8") existing.last_updated = current_time @@ -343,9 +343,9 @@ class UserProfileTool(BaseTool): last_updated=current_time ) session.add(new_profile) - + await session.commit() - + # 清除缓存,确保下次查询获取最新数据 try: from src.common.database.optimization.cache_manager import get_cache @@ -355,20 +355,20 @@ class UserProfileTool(BaseTool): logger.debug(f"已清除用户关系缓存: {user_id}") except Exception as cache_err: logger.warning(f"清除缓存失败(不影响数据保存): {cache_err}") - + logger.info(f"关键信息已保存: {user_id}, {info_type}={info_value}") - + except Exception as e: logger.error(f"保存关键信息失败: {e}") # 不抛出异常,因为这是后台任务 async def _get_recent_chat_history(self, target_user_id: str, max_messages: int = 50) -> str: """获取最近的聊天记录 - + Args: target_user_id: 目标用户ID max_messages: 最大消息数量 - + Returns: str: 格式化的聊天记录文本 """ @@ -377,24 +377,24 @@ class UserProfileTool(BaseTool): if not self.chat_stream: logger.warning("chat_stream 未初始化,无法获取聊天记录") return "" - + context = getattr(self.chat_stream, "context", None) if not context: logger.warning("chat_stream.context 不存在,无法获取聊天记录") return "" - + # 获取最近的消息 - 使用正确的方法名 get_messages messages = context.get_messages(limit=max_messages, include_unread=True) if not messages: return "" - + # 将 DatabaseMessages 对象转换为字典列表 messages_dict = [] for msg in messages: try: - if hasattr(msg, 'to_dict'): + if hasattr(msg, "to_dict"): messages_dict.append(msg.to_dict()) - elif hasattr(msg, '__dict__'): + elif hasattr(msg, "__dict__"): # 手动构建字典 msg_dict = { "time": getattr(msg, "time", 0), @@ -418,10 +418,10 @@ class UserProfileTool(BaseTool): except Exception as e: logger.warning(f"转换消息失败: {e}") continue - + if not messages_dict: return "" - + # 构建可读的消息文本 readable_messages = await build_readable_messages( messages=messages_dict, @@ -429,9 +429,9 @@ class UserProfileTool(BaseTool): timestamp_mode="normal_no_YMD", truncate=True ) - + return readable_messages or "" - + except Exception as e: logger.error(f"获取聊天记录失败: {e}") return "" @@ -446,7 +446,7 @@ class UserProfileTool(BaseTool): current_score: float, ) -> dict[str, Any]: """使用relationship_tracker模型生成印象并决定好感度变化 - + Args: target_user_name: 目标用户的名字 impression_hint: 工具调用模型传入的简要观察 @@ -454,25 +454,26 @@ class UserProfileTool(BaseTool): preference_keywords: 用户的兴趣偏好 chat_history: 最近的聊天记录 current_score: 当前好感度分数 - + Returns: dict: {"impression": str, "affection_change": float} """ try: import orjson from json_repair import repair_json + from src.llm_models.utils_model import LLMRequest - + # 获取人设信息(添加空值保护) bot_name = global_config.bot.nickname if global_config and global_config.bot else "Bot" personality_core = global_config.personality.personality_core if global_config and global_config.personality else "" personality_side = global_config.personality.personality_side if global_config and global_config.personality else "" reply_style = global_config.personality.reply_style if global_config and global_config.personality else "" - + # 构建提示词 # 根据是否有旧印象决定任务类型 is_first_impression = not existing_impression or len(existing_impression) < 20 - + prompt = f"""你是{bot_name},现在要记录你对"{target_user_name}"的印象。 ## 你的核心人格 @@ -592,27 +593,27 @@ class UserProfileTool(BaseTool): # 使用relationship_tracker模型(添加空值保护) if not model_config or not model_config.model_task_config: raise ValueError("model_config 未初始化") - + llm = LLMRequest( model_set=model_config.model_task_config.relationship_tracker, request_type="user_profile.impression_and_affection" ) - + response, _ = await llm.generate_response_async( prompt=prompt, temperature=0.7, max_tokens=600, ) - + # 解析响应 response = response.strip() try: result = orjson.loads(repair_json(response)) impression = result.get("impression", "") affection_change = float(result.get("affection_change", 0)) - change_reason = result.get("change_reason", "") + result.get("change_reason", "") detected_gender = result.get("gender", "unknown") - + # 🎯 根据当前好感度阶段限制变化范围 if current_score < 0.3: # 陌生→初识:±0.05 @@ -629,21 +630,21 @@ class UserProfileTool(BaseTool): else: # 好友→挚友:±0.01 max_change = 0.01 - + affection_change = max(-max_change, min(max_change, affection_change)) - + # 如果印象为空或太短,回退到hint if not impression or len(impression) < 10: - logger.warning(f"印象生成结果过短,使用原始hint") + logger.warning("印象生成结果过短,使用原始hint") impression = impression_hint or existing_impression - + logger.debug(f"印象更新: 用户性别判断={detected_gender}, 好感度变化={affection_change:+.3f}") - + return { "impression": impression, "affection_change": affection_change } - + except Exception as parse_error: logger.warning(f"解析JSON失败: {parse_error},尝试提取文本") # 如果JSON解析失败,尝试直接使用响应作为印象 @@ -651,7 +652,7 @@ class UserProfileTool(BaseTool): "impression": response if len(response) > 10 else (impression_hint or existing_impression), "affection_change": 0.0 } - + except Exception as e: logger.error(f"生成印象和好感度失败: {e}") # 失败时回退 @@ -740,16 +741,16 @@ class UserProfileTool(BaseTool): if existing: # 别名和偏好已经在_background_update中处理好了,直接赋值 existing.user_aliases = profile.get("user_aliases", "") or existing.user_aliases - + # 同时更新新旧两个印象字段,保持兼容 impression = profile.get("relationship_text", "") if impression: # 只有有新印象才更新 existing.relationship_text = impression existing.impression_text = impression - + # 偏好关键词已经在_background_update中处理好了,直接赋值 existing.preference_keywords = profile.get("preference_keywords", "") or existing.preference_keywords - + existing.relationship_score = score existing.relationship_stage = stage existing.last_impression_update = current_time @@ -776,7 +777,7 @@ class UserProfileTool(BaseTool): session.add(new_profile) await session.commit() - + # 清除缓存,确保下次查询获取最新数据 try: from src.common.database.optimization.cache_manager import get_cache @@ -786,7 +787,7 @@ class UserProfileTool(BaseTool): logger.debug(f"已清除用户关系缓存: {user_id}") except Exception as cache_err: logger.warning(f"清除缓存失败(不影响数据保存): {cache_err}") - + logger.info(f"用户画像已更新到数据库: {user_id}, 阶段: {stage}") except Exception as e: @@ -795,10 +796,10 @@ class UserProfileTool(BaseTool): def _calculate_relationship_stage(self, score: float) -> str: """根据好感度分数计算关系阶段 - + Args: score: 好感度分数(0-1) - + Returns: str: 关系阶段 """ diff --git a/src/plugins/built_in/anti_injection_plugin/prompts.py b/src/plugins/built_in/anti_injection_plugin/prompts.py index 7f31fe828..ad48d4cb9 100644 --- a/src/plugins/built_in/anti_injection_plugin/prompts.py +++ b/src/plugins/built_in/anti_injection_plugin/prompts.py @@ -64,7 +64,7 @@ class AntiInjectionPrompt(BasePrompt): return "" # 获取安全管理器 - security_manager = get_security_manager() + get_security_manager() # 检查当前消息的风险级别 current_message = self.params.current_user_message diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index 56b78c126..6ff8a2836 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -1,5 +1,4 @@ import random -import re from typing import ClassVar from src.chat.emoji_system.emoji_history import add_emoji_to_history, get_recent_emojis @@ -11,7 +10,7 @@ from src.common.logger import get_logger from src.config.config import global_config # 导入新插件系统 -from src.plugin_system import ActionActivationType, BaseAction, ChatMode +from src.plugin_system import BaseAction, ChatMode # 导入API模块 - 标准Python包方式 from src.plugin_system.apis import llm_api, message_api diff --git a/src/plugins/built_in/kokoro_flow_chatter/__init__.py b/src/plugins/built_in/kokoro_flow_chatter/__init__.py index c5654e84e..6ae232d8a 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/__init__.py +++ b/src/plugins/built_in/kokoro_flow_chatter/__init__.py @@ -15,25 +15,9 @@ Kokoro Flow Chatter (KFC) - 私聊特化的心流聊天器 5. 大模板 + 小模板:线性叙事风格的提示词架构 """ -from .models import ( - EventType, - SessionStatus, - MentalLogEntry, - WaitingConfig, - ActionModel, - LLMResponse, -) -from .session import KokoroSession, SessionManager, get_session_manager +from src.plugin_system.base.plugin_metadata import PluginMetadata + from .chatter import KokoroFlowChatter -from .planner import generate_plan -from .replyer import generate_reply_text -from .unified import generate_unified_response -from .proactive_thinker import ( - ProactiveThinker, - get_proactive_thinker, - start_proactive_thinker, - stop_proactive_thinker, -) from .config import ( KFCMode, KokoroFlowChatterConfig, @@ -41,8 +25,25 @@ from .config import ( load_config, reload_config, ) +from .models import ( + ActionModel, + EventType, + LLMResponse, + MentalLogEntry, + SessionStatus, + WaitingConfig, +) +from .planner import generate_plan from .plugin import KokoroFlowChatterPlugin -from src.plugin_system.base.plugin_metadata import PluginMetadata +from .proactive_thinker import ( + ProactiveThinker, + get_proactive_thinker, + start_proactive_thinker, + stop_proactive_thinker, +) +from .replyer import generate_reply_text +from .session import KokoroSession, SessionManager, get_session_manager +from .unified import generate_unified_response __plugin_meta__ = PluginMetadata( name="Kokoro Flow Chatter", @@ -56,34 +57,34 @@ __plugin_meta__ = PluginMetadata( ) __all__ = [ + "ActionModel", # Models "EventType", - "SessionStatus", - "MentalLogEntry", - "WaitingConfig", - "ActionModel", - "LLMResponse", - # Session - "KokoroSession", - "SessionManager", - "get_session_manager", + # Config + "KFCMode", # Core Components "KokoroFlowChatter", + "KokoroFlowChatterConfig", + # Plugin + "KokoroFlowChatterPlugin", + # Session + "KokoroSession", + "LLMResponse", + "MentalLogEntry", + # Proactive Thinker + "ProactiveThinker", + "SessionManager", + "SessionStatus", + "WaitingConfig", + "__plugin_meta__", "generate_plan", "generate_reply_text", "generate_unified_response", - # Proactive Thinker - "ProactiveThinker", - "get_proactive_thinker", - "start_proactive_thinker", - "stop_proactive_thinker", - # Config - "KFCMode", - "KokoroFlowChatterConfig", "get_config", + "get_proactive_thinker", + "get_session_manager", "load_config", "reload_config", - # Plugin - "KokoroFlowChatterPlugin", - "__plugin_meta__", + "start_proactive_thinker", + "stop_proactive_thinker", ] diff --git a/src/plugins/built_in/kokoro_flow_chatter/actions/reply.py b/src/plugins/built_in/kokoro_flow_chatter/actions/reply.py index fc6edb1b8..aec5ab547 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/actions/reply.py +++ b/src/plugins/built_in/kokoro_flow_chatter/actions/reply.py @@ -53,7 +53,7 @@ class KFCReplyAction(BaseAction): activation_type = ActionActivationType.ALWAYS mode_enable = ChatMode.ALL parallel_action = False - + # Chatter 限制:仅允许 KokoroFlowChatter 使用 chatter_allow: ClassVar[list[str]] = ["KokoroFlowChatter"] @@ -77,7 +77,7 @@ class KFCReplyAction(BaseAction): try: # 1. 检查是否有预生成的内容 content = self.action_data.get("content", "") - + if not content: # 2. 需要生成回复,获取必要信息 user_id = self.action_data.get("user_id") @@ -85,17 +85,17 @@ class KFCReplyAction(BaseAction): thought = self.action_data.get("thought", "") situation_type = self.action_data.get("situation_type", "new_message") extra_context = self.action_data.get("extra_context") - + if not user_id: logger.warning(f"{self.log_prefix} 缺少 user_id,无法生成回复") return False, "" - + # 3. 获取 Session session = await self._get_session(user_id) if not session: logger.warning(f"{self.log_prefix} 无法获取 Session: {user_id}") return False, "" - + # 4. 调用 Replyer 生成回复 success, content = await self._generate_reply( session=session, @@ -104,35 +104,35 @@ class KFCReplyAction(BaseAction): situation_type=situation_type, extra_context=extra_context, ) - + if not success or not content: logger.warning(f"{self.log_prefix} 回复生成失败") return False, "" - + # 5. 回复后处理(系统格式词过滤 + 分段处理) enable_splitter = self.action_data.get("enable_splitter", True) enable_chinese_typo = self.action_data.get("enable_chinese_typo", True) - + processed_segments = self._post_process_reply( content=content, enable_splitter=enable_splitter, enable_chinese_typo=enable_chinese_typo, ) - + if not processed_segments: logger.warning(f"{self.log_prefix} 回复后处理后内容为空") return False, "" - + # 6. 分段发送回复 should_quote = self.action_data.get("should_quote_reply", False) reply_text = await self._send_segments( segments=processed_segments, should_quote=should_quote, ) - + logger.info(f"{self.log_prefix} KFC reply 动作执行成功: {reply_text[:50]}...") return True, reply_text - + except asyncio.CancelledError: logger.debug(f"{self.log_prefix} 回复任务被取消") return False, "" @@ -141,7 +141,7 @@ class KFCReplyAction(BaseAction): import traceback traceback.print_exc() return False, "" - + def _post_process_reply( self, content: str, @@ -150,53 +150,53 @@ class KFCReplyAction(BaseAction): ) -> list[str]: """ 回复后处理 - + 包括: 1. 系统格式词过滤(移除 [回复...]、[表情包:...]、@<...> 等) 2. 分段处理(根据标点分句、智能合并) 3. 错字生成(拟人化) - + Args: content: 原始回复内容 enable_splitter: 是否启用分段 enable_chinese_typo: 是否启用错字生成 - + Returns: 处理后的文本段落列表 """ try: from src.chat.utils.utils import filter_system_format_content, process_llm_response - + # 1. 过滤系统格式词 filtered_content = filter_system_format_content(content) - + if not filtered_content or not filtered_content.strip(): logger.warning(f"{self.log_prefix} 过滤系统格式词后内容为空") return [] - + # 2. 分段处理 + 错字生成 processed_segments = process_llm_response( filtered_content, enable_splitter=enable_splitter, enable_chinese_typo=enable_chinese_typo, ) - + # 过滤空段落 processed_segments = [seg for seg in processed_segments if seg and seg.strip()] - + logger.debug( f"{self.log_prefix} 回复后处理完成: " f"原始长度={len(content)}, 过滤后长度={len(filtered_content)}, " f"分段数={len(processed_segments)}" ) - + return processed_segments - + except Exception as e: logger.error(f"{self.log_prefix} 回复后处理失败: {e}") # 失败时返回原始内容 return [content] if content else [] - + async def _send_segments( self, segments: list[str], @@ -204,28 +204,28 @@ class KFCReplyAction(BaseAction): ) -> str: """ 分段发送回复 - + Args: segments: 要发送的文本段落列表 should_quote: 是否引用原消息(仅第一条消息引用) - + Returns: 完整的回复文本(所有段落拼接) """ reply_text = "" first_sent = False - + # 获取分段发送的间隔时间 typing_delay = 0.5 - if global_config and hasattr(global_config, 'response_splitter'): + if global_config and hasattr(global_config, "response_splitter"): typing_delay = getattr(global_config.response_splitter, "typing_delay", 0.5) - + for segment in segments: if not segment or not segment.strip(): continue - + reply_text += segment - + # 发送消息 if not first_sent: # 第一条消息:可能需要引用 @@ -241,7 +241,7 @@ class KFCReplyAction(BaseAction): # 后续消息:模拟打字延迟 if typing_delay > 0: await asyncio.sleep(typing_delay) - + await send_api.text_to_stream( text=segment, stream_id=self.chat_stream.stream_id, @@ -249,32 +249,32 @@ class KFCReplyAction(BaseAction): set_reply=False, typing=True, ) - + return reply_text - + async def _get_session(self, user_id: str) -> Optional["KokoroSession"]: """获取用户 Session""" try: from ..session import get_session_manager - + session_manager = get_session_manager() return await session_manager.get_session(user_id, self.chat_stream.stream_id) except Exception as e: logger.error(f"{self.log_prefix} 获取 Session 失败: {e}") return None - + async def _generate_reply( self, session: "KokoroSession", user_name: str, thought: str, situation_type: str, - extra_context: Optional[dict] = None, + extra_context: dict | None = None, ) -> tuple[bool, str]: """调用 Replyer 生成回复""" try: from ..replyer import generate_reply_text - + return await generate_reply_text( session=session, user_name=user_name, diff --git a/src/plugins/built_in/kokoro_flow_chatter/chatter.py b/src/plugins/built_in/kokoro_flow_chatter/chatter.py index 3b50fe6f0..b2f8dd866 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/chatter.py +++ b/src/plugins/built_in/kokoro_flow_chatter/chatter.py @@ -58,7 +58,7 @@ class KokoroFlowChatter(BaseChatter): chatter_name: str = "KokoroFlowChatter" chatter_description: str = "心流聊天器 - 私聊特化的深度情感交互处理器" chat_types: ClassVar[list[ChatType]] = [ChatType.PRIVATE] - + def __init__( self, stream_id: str, @@ -66,33 +66,33 @@ class KokoroFlowChatter(BaseChatter): plugin_config: dict | None = None, ): super().__init__(stream_id, action_manager, plugin_config) - + # 核心组件 self.session_manager = get_session_manager() - + # 加载配置 self._config = get_config() self._mode = self._config.mode - + # 并发控制 self._lock = asyncio.Lock() self._processing = False - + # 统计 self._stats: dict[str, Any] = { "messages_processed": 0, "successful_responses": 0, "failed_responses": 0, } - + # 输出初始化信息 mode_str = "统一模式" if self._mode == KFCMode.UNIFIED else "分离模式" logger.info(f"初始化完成 (模式: {mode_str}): stream_id={stream_id}") - + async def execute(self, context: StreamContext) -> dict: """ 执行聊天处理 - + 流程: 1. 获取 Session 2. 获取未读消息 @@ -105,60 +105,60 @@ class KokoroFlowChatter(BaseChatter): """ async with self._lock: self._processing = True - + try: # 1. 获取未读消息 unread_messages = context.get_unread_messages() if not unread_messages: return self._build_result(success=True, message="no_unread_messages") - + # 2. 取最后一条消息作为主消息 target_message = unread_messages[-1] user_info = target_message.user_info - + if not user_info: return self._build_result(success=False, message="no_user_info") - + user_id = str(user_info.user_id) user_name = user_info.user_nickname or user_id - + # 3. 获取或创建 Session session = await self.session_manager.get_session(user_id, self.stream_id) - + # 3.5 **立即**更新活动时间,阻止 ProactiveThinker 并发处理 session.last_activity_at = time.time() - + # 4. 确定 situation_type(根据之前的等待状态) situation_type = self._determine_situation_type(session) - + # 5. **立即**结束等待状态,防止 ProactiveThinker 并发处理 if session.status == SessionStatus.WAITING: session.end_waiting() await self.session_manager.save_session(user_id) - + # 6. 记录用户消息到 mental_log for msg in unread_messages: msg_content = msg.processed_plain_text or msg.display_message or "" msg_user_name = msg.user_info.user_nickname if msg.user_info else user_name msg_user_id = str(msg.user_info.user_id) if msg.user_info else user_id - + session.add_user_message( content=msg_content, user_name=msg_user_name, user_id=msg_user_id, timestamp=msg.time, ) - + # 7. 加载可用动作(通过 ActionModifier 过滤) from src.chat.planner_actions.action_modifier import ActionModifier - + action_modifier = ActionModifier(self.action_manager, self.stream_id) await action_modifier.modify_actions(chatter_name="KokoroFlowChatter") available_actions = self.action_manager.get_using_actions() - + # 8. 获取聊天流 chat_stream = await self._get_chat_stream() - + # 9. 根据模式调用对应的生成器 if self._mode == KFCMode.UNIFIED: plan_response = await self._execute_unified_mode( @@ -177,7 +177,7 @@ class KokoroFlowChatter(BaseChatter): chat_stream=chat_stream, available_actions=available_actions, ) - + # 10. 执行动作 raw_wait = plan_response.max_wait_seconds adjusted_wait = apply_wait_duration_rules( @@ -205,10 +205,10 @@ class KokoroFlowChatter(BaseChatter): exec_results = [] has_reply = False - + for action in plan_response.actions: action_data = action.params.copy() - + result = await self.action_manager.execute_action( action_name=action.type, chat_id=self.stream_id, @@ -221,7 +221,7 @@ class KokoroFlowChatter(BaseChatter): exec_results.append(result) if result.get("success") and action.type in ("kfc_reply", "respond"): has_reply = True - + # 11. 记录 Bot 规划到 mental_log session.add_bot_planning( thought=plan_response.thought, @@ -229,7 +229,7 @@ class KokoroFlowChatter(BaseChatter): expected_reaction=plan_response.expected_reaction, max_wait_seconds=plan_response.max_wait_seconds, ) - + # 12. 更新 Session 状态 if plan_response.max_wait_seconds > 0: session.start_waiting( @@ -238,19 +238,19 @@ class KokoroFlowChatter(BaseChatter): ) else: session.end_waiting() - + # 13. 标记消息为已读 for msg in unread_messages: context.mark_message_as_read(str(msg.message_id)) - + # 14. 保存 Session await self.session_manager.save_session(user_id) - + # 15. 更新统计 self._stats["messages_processed"] += len(unread_messages) if has_reply: self._stats["successful_responses"] += 1 - + # 输出完成信息 mode_str = "unified" if self._mode == KFCMode.UNIFIED else "split" logger.info( @@ -259,7 +259,7 @@ class KokoroFlowChatter(BaseChatter): f"actions={[a.type for a in plan_response.actions]}, " f"wait={plan_response.max_wait_seconds}s" ) - + return self._build_result( success=True, message="processed", @@ -268,17 +268,17 @@ class KokoroFlowChatter(BaseChatter): situation_type=situation_type, mode=mode_str, ) - + except Exception as e: self._stats["failed_responses"] += 1 logger.error(f"[KFC] 处理失败: {e}") import traceback traceback.print_exc() return self._build_result(success=False, message=str(e), error=True) - + finally: self._processing = False - + async def _execute_unified_mode( self, session, @@ -289,12 +289,12 @@ class KokoroFlowChatter(BaseChatter): ): """ 统一模式:单次 LLM 调用完成思考 + 回复生成 - + LLM 输出的 JSON 中 kfc_reply 动作已包含 content 字段, 无需再调用 Replyer 生成回复。 """ from .unified import generate_unified_response - + plan_response = await generate_unified_response( session=session, user_name=user_name, @@ -302,10 +302,10 @@ class KokoroFlowChatter(BaseChatter): chat_stream=chat_stream, available_actions=available_actions, ) - + # 统一模式下 content 已经在 actions 中,无需注入 return plan_response - + async def _execute_split_mode( self, session, @@ -317,12 +317,12 @@ class KokoroFlowChatter(BaseChatter): ): """ 分离模式:Planner + Replyer 两次 LLM 调用 - + 1. Planner 生成行动计划(JSON,kfc_reply 不含 content) 2. 为 kfc_reply 动作注入上下文,由 Action.execute() 调用 Replyer 生成回复 """ from .planner import generate_plan - + plan_response = await generate_plan( session=session, user_name=user_name, @@ -330,7 +330,7 @@ class KokoroFlowChatter(BaseChatter): chat_stream=chat_stream, available_actions=available_actions, ) - + # 为 kfc_reply 动作注入回复生成所需的上下文 for action in plan_response.actions: if action.type == "kfc_reply": @@ -338,13 +338,13 @@ class KokoroFlowChatter(BaseChatter): action.params["user_name"] = user_name action.params["thought"] = plan_response.thought action.params["situation_type"] = situation_type - + return plan_response - + def _determine_situation_type(self, session) -> str: """ 确定当前情况类型 - + 根据 Session 之前的状态决定提示词的 situation_type """ if session.status == SessionStatus.WAITING: @@ -352,7 +352,7 @@ class KokoroFlowChatter(BaseChatter): # 如果 max_wait_seconds <= 0,说明不是有效的等待状态,视为新消息 if session.waiting_config.max_wait_seconds <= 0: return "new_message" - + if session.waiting_config.is_timeout(): # 超时了才收到回复 return "reply_late" @@ -362,19 +362,19 @@ class KokoroFlowChatter(BaseChatter): else: # 之前是 IDLE return "new_message" - + async def _get_chat_stream(self): """获取聊天流对象""" try: from src.chat.message_receive.chat_stream import get_chat_manager - + chat_manager = get_chat_manager() if chat_manager: return await chat_manager.get_stream(self.stream_id) except Exception as e: logger.warning(f"[KFC] 获取 chat_stream 失败: {e}") return None - + def _build_result( self, success: bool, @@ -392,18 +392,18 @@ class KokoroFlowChatter(BaseChatter): } result.update(kwargs) return result - + def get_stats(self) -> dict[str, Any]: """获取统计信息""" stats = self._stats.copy() stats["mode"] = self._mode.value return stats - + @property def is_processing(self) -> bool: """是否正在处理""" return self._processing - + @property def mode(self) -> KFCMode: """当前工作模式""" diff --git a/src/plugins/built_in/kokoro_flow_chatter/config.py b/src/plugins/built_in/kokoro_flow_chatter/config.py index f44bb51c8..ee01abf77 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/config.py +++ b/src/plugins/built_in/kokoro_flow_chatter/config.py @@ -10,18 +10,17 @@ Kokoro Flow Chatter - 配置 from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional class KFCMode(str, Enum): """KFC 工作模式""" - + # 统一模式:单次 LLM 调用,生成思考 + 回复(类似旧版架构) UNIFIED = "unified" - + # 分离模式:Planner 生成规划,Replyer 生成回复(推荐) SPLIT = "split" - + @classmethod def from_str(cls, value: str) -> "KFCMode": """从字符串创建模式""" @@ -38,13 +37,13 @@ class KFCMode(str, Enum): @dataclass class WaitingDefaults: """等待配置默认值""" - + # 默认最大等待时间(秒) default_max_wait_seconds: int = 300 - + # 最小等待时间 min_wait_seconds: int = 30 - + # 最大等待时间 max_wait_seconds: int = 1800 @@ -58,25 +57,25 @@ class WaitingDefaults: @dataclass class ProactiveConfig: """主动思考配置""" - + # 是否启用主动思考 enabled: bool = True - + # 沉默阈值(秒),超过此时间考虑主动发起 silence_threshold_seconds: int = 7200 - + # 两次主动发起最小间隔(秒) min_interval_between_proactive: int = 1800 - + # 勿扰时段开始(HH:MM 格式) quiet_hours_start: str = "23:00" - + # 勿扰时段结束 quiet_hours_end: str = "07:00" - + # 主动发起概率(0.0 ~ 1.0) trigger_probability: float = 0.3 - + # 关系门槛:最低好感度,达到此值才会主动关心 min_affinity_for_proactive: float = 0.3 @@ -84,16 +83,16 @@ class ProactiveConfig: @dataclass class PromptConfig: """提示词配置""" - + # 活动记录保留条数 max_activity_entries: int = 30 - + # 每条记录最大字符数 max_entry_length: int = 500 - + # 是否包含人物关系信息 include_relation: bool = True - + # 是否包含记忆信息 include_memory: bool = True @@ -101,30 +100,30 @@ class PromptConfig: @dataclass class SessionConfig: """会话配置""" - + # Session 持久化目录(相对于 data/) session_dir: str = "kokoro_flow_chatter/sessions" - + # Session 自动过期时间(秒),超过此时间未活动自动清理 session_expire_seconds: int = 86400 * 7 # 7 天 - + # 活动记录保留上限 max_mental_log_entries: int = 100 -@dataclass +@dataclass class LLMConfig: """LLM 配置""" - + # 模型名称(空则使用默认) model_name: str = "" - + # Temperature temperature: float = 0.8 - + # 最大 Token max_tokens: int = 1024 - + # 请求超时(秒) timeout: float = 60.0 @@ -132,39 +131,39 @@ class LLMConfig: @dataclass class KokoroFlowChatterConfig: """Kokoro Flow Chatter 总配置""" - + # 是否启用 enabled: bool = True - + # 工作模式:unified(统一模式)或 split(分离模式) # - unified: 单次 LLM 调用完成思考和回复生成(类似旧版架构,更简洁) # - split: Planner + Replyer 两次 LLM 调用(更精细的控制,推荐) mode: KFCMode = KFCMode.UNIFIED - + # 启用的消息源类型(空列表表示全部) - enabled_stream_types: List[str] = field(default_factory=lambda: ["private"]) - + enabled_stream_types: list[str] = field(default_factory=lambda: ["private"]) + # 等待配置 waiting: WaitingDefaults = field(default_factory=WaitingDefaults) - + # 主动思考配置 proactive: ProactiveConfig = field(default_factory=ProactiveConfig) - + # 提示词配置 prompt: PromptConfig = field(default_factory=PromptConfig) - + # 会话配置 session: SessionConfig = field(default_factory=SessionConfig) - + # LLM 配置 llm: LLMConfig = field(default_factory=LLMConfig) - + # 调试模式 debug: bool = False # 全局配置单例 -_config: Optional[KokoroFlowChatterConfig] = None +_config: KokoroFlowChatterConfig | None = None def get_config() -> KokoroFlowChatterConfig: @@ -180,88 +179,88 @@ def load_config() -> KokoroFlowChatterConfig: from src.config.config import global_config config = KokoroFlowChatterConfig() - + # 尝试从全局配置读取 if not global_config: return config - + try: - if hasattr(global_config, 'kokoro_flow_chatter'): - kfc_cfg = getattr(global_config, 'kokoro_flow_chatter') - + if hasattr(global_config, "kokoro_flow_chatter"): + kfc_cfg = getattr(global_config, "kokoro_flow_chatter") + # 基础配置 - 支持 enabled 和 enable 两种写法 - if hasattr(kfc_cfg, 'enable'): + if hasattr(kfc_cfg, "enable"): config.enabled = kfc_cfg.enable - if hasattr(kfc_cfg, 'enabled_stream_types'): + if hasattr(kfc_cfg, "enabled_stream_types"): config.enabled_stream_types = list(kfc_cfg.enabled_stream_types) - if hasattr(kfc_cfg, 'debug'): + if hasattr(kfc_cfg, "debug"): config.debug = kfc_cfg.debug - + # 工作模式配置 - if hasattr(kfc_cfg, 'mode'): + if hasattr(kfc_cfg, "mode"): config.mode = KFCMode.from_str(str(kfc_cfg.mode)) - + # 等待配置 - if hasattr(kfc_cfg, 'waiting'): + if hasattr(kfc_cfg, "waiting"): wait_cfg = kfc_cfg.waiting config.waiting = WaitingDefaults( - default_max_wait_seconds=getattr(wait_cfg, 'default_max_wait_seconds', 300), - min_wait_seconds=getattr(wait_cfg, 'min_wait_seconds', 30), - max_wait_seconds=getattr(wait_cfg, 'max_wait_seconds', 1800), - wait_duration_multiplier=getattr(wait_cfg, 'wait_duration_multiplier', 1.0), - max_consecutive_timeouts=getattr(wait_cfg, 'max_consecutive_timeouts', 3), + default_max_wait_seconds=getattr(wait_cfg, "default_max_wait_seconds", 300), + min_wait_seconds=getattr(wait_cfg, "min_wait_seconds", 30), + max_wait_seconds=getattr(wait_cfg, "max_wait_seconds", 1800), + wait_duration_multiplier=getattr(wait_cfg, "wait_duration_multiplier", 1.0), + max_consecutive_timeouts=getattr(wait_cfg, "max_consecutive_timeouts", 3), ) - + # 主动思考配置 - 支持 proactive 和 proactive_thinking 两种写法 pro_cfg = None - if hasattr(kfc_cfg, 'proactive_thinking'): + if hasattr(kfc_cfg, "proactive_thinking"): pro_cfg = kfc_cfg.proactive_thinking - + if pro_cfg: config.proactive = ProactiveConfig( - enabled=getattr(pro_cfg, 'enabled', True), - silence_threshold_seconds=getattr(pro_cfg, 'silence_threshold_seconds', 7200), - min_interval_between_proactive=getattr(pro_cfg, 'min_interval_between_proactive', 1800), - quiet_hours_start=getattr(pro_cfg, 'quiet_hours_start', "23:00"), - quiet_hours_end=getattr(pro_cfg, 'quiet_hours_end', "07:00"), - trigger_probability=getattr(pro_cfg, 'trigger_probability', 0.3), - min_affinity_for_proactive=getattr(pro_cfg, 'min_affinity_for_proactive', 0.3), + enabled=getattr(pro_cfg, "enabled", True), + silence_threshold_seconds=getattr(pro_cfg, "silence_threshold_seconds", 7200), + min_interval_between_proactive=getattr(pro_cfg, "min_interval_between_proactive", 1800), + quiet_hours_start=getattr(pro_cfg, "quiet_hours_start", "23:00"), + quiet_hours_end=getattr(pro_cfg, "quiet_hours_end", "07:00"), + trigger_probability=getattr(pro_cfg, "trigger_probability", 0.3), + min_affinity_for_proactive=getattr(pro_cfg, "min_affinity_for_proactive", 0.3), ) - + # 提示词配置 - if hasattr(kfc_cfg, 'prompt'): + if hasattr(kfc_cfg, "prompt"): pmt_cfg = kfc_cfg.prompt config.prompt = PromptConfig( - max_activity_entries=getattr(pmt_cfg, 'max_activity_entries', 30), - max_entry_length=getattr(pmt_cfg, 'max_entry_length', 500), - include_relation=getattr(pmt_cfg, 'include_relation', True), - include_memory=getattr(pmt_cfg, 'include_memory', True), + max_activity_entries=getattr(pmt_cfg, "max_activity_entries", 30), + max_entry_length=getattr(pmt_cfg, "max_entry_length", 500), + include_relation=getattr(pmt_cfg, "include_relation", True), + include_memory=getattr(pmt_cfg, "include_memory", True), ) - + # 会话配置 - if hasattr(kfc_cfg, 'session'): + if hasattr(kfc_cfg, "session"): sess_cfg = kfc_cfg.session config.session = SessionConfig( - session_dir=getattr(sess_cfg, 'session_dir', "kokoro_flow_chatter/sessions"), - session_expire_seconds=getattr(sess_cfg, 'session_expire_seconds', 86400 * 7), - max_mental_log_entries=getattr(sess_cfg, 'max_mental_log_entries', 100), + session_dir=getattr(sess_cfg, "session_dir", "kokoro_flow_chatter/sessions"), + session_expire_seconds=getattr(sess_cfg, "session_expire_seconds", 86400 * 7), + max_mental_log_entries=getattr(sess_cfg, "max_mental_log_entries", 100), ) - + # LLM 配置 - if hasattr(kfc_cfg, 'llm'): + if hasattr(kfc_cfg, "llm"): llm_cfg = kfc_cfg.llm config.llm = LLMConfig( - model_name=getattr(llm_cfg, 'model_name', ""), - temperature=getattr(llm_cfg, 'temperature', 0.8), - max_tokens=getattr(llm_cfg, 'max_tokens', 1024), - timeout=getattr(llm_cfg, 'timeout', 60.0), + model_name=getattr(llm_cfg, "model_name", ""), + temperature=getattr(llm_cfg, "temperature", 0.8), + max_tokens=getattr(llm_cfg, "max_tokens", 1024), + timeout=getattr(llm_cfg, "timeout", 60.0), ) - + except Exception as e: from src.common.logger import get_logger logger = get_logger("kfc_config") logger.warning(f"加载 KFC 配置失败,使用默认值: {e}") - + return config @@ -282,7 +281,7 @@ def apply_wait_duration_rules(raw_wait_seconds: int, consecutive_timeouts: int = if multiplier == 0: return 0 - adjusted = int(round(raw_wait_seconds * multiplier)) + adjusted = round(raw_wait_seconds * multiplier) min_wait = max(0, waiting_cfg.min_wait_seconds) max_wait = max(waiting_cfg.max_wait_seconds, 0) diff --git a/src/plugins/built_in/kokoro_flow_chatter/context_builder.py b/src/plugins/built_in/kokoro_flow_chatter/context_builder.py index b57f06f42..605d52cae 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/context_builder.py +++ b/src/plugins/built_in/kokoro_flow_chatter/context_builder.py @@ -14,11 +14,11 @@ Kokoro Flow Chatter 上下文构建器 import asyncio import time from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional from src.common.logger import get_logger from src.config.config import global_config -from src.person_info.person_info import get_person_info_manager, PersonInfoManager +from src.person_info.person_info import get_person_info_manager if TYPE_CHECKING: from src.chat.message_receive.chat_stream import ChatStream @@ -36,41 +36,41 @@ def _get_config(): class KFCContextBuilder: """ KFC V2 上下文构建器 - + 为提示词提供完整的情境感知数据。 """ - + def __init__(self, chat_stream: "ChatStream"): self.chat_stream = chat_stream self.chat_id = chat_stream.stream_id self.platform = chat_stream.platform self.is_group_chat = bool(chat_stream.group_info) - + async def build_all_context( self, sender_name: str, target_message: str, context: Optional["StreamContext"] = None, - user_id: Optional[str] = None, + user_id: str | None = None, enable_tool: bool = True, ) -> dict[str, str]: """ 并行构建所有上下文模块 - + Args: sender_name: 发送者名称 target_message: 目标消息内容 context: 聊天流上下文(可选) user_id: 用户ID(可选,用于精确查找关系信息) enable_tool: 是否启用工具调用 - + Returns: dict: 包含所有上下文块的字典 """ logger.debug(f"[KFC上下文] 开始构建上下文: sender={sender_name}, target={target_message[:50] if target_message else '(空)'}...") - + chat_history = await self._get_chat_history_text(context) - + tasks = { "relation_info": self._build_relation_info(sender_name, target_message, user_id), "memory_block": self._build_memory_block(chat_history, target_message, context), @@ -79,10 +79,10 @@ class KFCContextBuilder: "schedule": self._build_schedule_block(), "time": self._build_time_block(), } - + results = {} timing_logs = [] - + # 任务名称中英文映射 task_name_mapping = { "relation_info": "感受关系", @@ -92,13 +92,13 @@ class KFCContextBuilder: "schedule": "日程", "time": "时间", } - + try: task_results = await asyncio.gather( *[self._wrap_task_with_timing(name, coro) for name, coro in tasks.items()], return_exceptions=True ) - + for result in task_results: if isinstance(result, tuple) and len(result) == 3: name, value, duration = result @@ -111,13 +111,13 @@ class KFCContextBuilder: logger.warning(f"上下文构建任务异常: {result}") except Exception as e: logger.error(f"并行构建上下文失败: {e}") - + # 输出耗时日志 if timing_logs: logger.info(f"在回复前的步骤耗时: {'; '.join(timing_logs)}") - + return results - + async def _wrap_task_with_timing(self, name: str, coro) -> tuple[str, str, float]: """包装任务以返回名称、结果和耗时""" start_time = time.time() @@ -129,7 +129,7 @@ class KFCContextBuilder: duration = time.time() - start_time logger.error(f"构建 {name} 失败: {e}") return (name, "", duration) - + async def _get_chat_history_text( self, context: Optional["StreamContext"] = None, @@ -138,16 +138,16 @@ class KFCContextBuilder: """获取聊天历史文本""" if context is None: return "" - + try: from src.chat.utils.chat_message_builder import build_readable_messages - + messages = context.get_messages(limit=limit, include_unread=True) if not messages: return "" - + msg_dicts = [msg.flatten() for msg in messages] - + return await build_readable_messages( msg_dicts, replace_bot_name=True, @@ -157,54 +157,54 @@ class KFCContextBuilder: except Exception as e: logger.error(f"获取聊天历史失败: {e}") return "" - - async def _build_relation_info(self, sender_name: str, target_message: str, user_id: Optional[str] = None) -> str: + + async def _build_relation_info(self, sender_name: str, target_message: str, user_id: str | None = None) -> str: """构建关系信息块""" config = _get_config() - + if sender_name == f"{config.bot.nickname}(你)": return "你将要回复的是你自己发送的消息。" - + person_info_manager = get_person_info_manager() - + # 优先使用 user_id + platform 获取 person_id person_id = None if user_id and self.platform: person_id = person_info_manager.get_person_id(self.platform, user_id) logger.debug(f"通过 platform={self.platform}, user_id={user_id} 获取 person_id={person_id}") - + # 如果没有找到,尝试通过 person_name 查找 if not person_id: person_id = await person_info_manager.get_person_id_by_person_name(sender_name) - + if not person_id: logger.debug(f"未找到用户 {sender_name} 的 person_id") return f"你与{sender_name}还没有建立深厚的关系,这是早期的互动阶段。" - + try: from src.person_info.relationship_fetcher import relationship_fetcher_manager - + relationship_fetcher = relationship_fetcher_manager.get_fetcher(self.chat_id) - + user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5) stream_impression = await relationship_fetcher.build_chat_stream_impression(self.chat_id) - + parts = [] if user_relation_info: parts.append(f"### 你与 {sender_name} 的关系\n{user_relation_info}") if stream_impression: scene_type = "这个群" if self.is_group_chat else "你们的私聊" parts.append(f"### 你对{scene_type}的印象\n{stream_impression}") - + if parts: return "\n\n".join(parts) else: return f"你与{sender_name}还没有建立深厚的关系,这是早期的互动阶段。" - + except Exception as e: logger.error(f"获取关系信息失败: {e}") return f"你与{sender_name}是普通朋友关系。" - + async def _build_memory_block( self, chat_history: str, @@ -213,44 +213,44 @@ class KFCContextBuilder: ) -> str: """构建记忆块(使用三层记忆系统)""" config = _get_config() - + if not (config.memory and config.memory.enable): logger.debug("[KFC记忆] 记忆系统未启用") return "" - + try: from src.memory_graph.manager_singleton import get_unified_memory_manager from src.memory_graph.utils.three_tier_formatter import memory_formatter - + unified_manager = get_unified_memory_manager() if not unified_manager: logger.warning("[KFC记忆] 管理器未初始化,跳过记忆检索") return "" - + # 构建查询文本(使用最近多条消息的组合块) query_text = self._build_memory_query_text(target_message, context) logger.debug(f"[KFC记忆] 开始检索,查询文本: {query_text[:100]}...") - + search_result = await unified_manager.search_memories( query_text=query_text, use_judge=True, recent_chat_history=chat_history, ) - + if not search_result: logger.debug("[KFC记忆] 未找到相关记忆") return "" - + perceptual_blocks = search_result.get("perceptual_blocks", []) short_term_memories = search_result.get("short_term_memories", []) long_term_memories = search_result.get("long_term_memories", []) - + formatted_memories = await memory_formatter.format_all_tiers( perceptual_blocks=perceptual_blocks, short_term_memories=short_term_memories, long_term_memories=long_term_memories ) - + total_count = len(perceptual_blocks) + len(short_term_memories) + len(long_term_memories) if total_count > 0 and formatted_memories.strip(): logger.info( @@ -258,16 +258,16 @@ class KFCContextBuilder: f"(感知:{len(perceptual_blocks)}, 短期:{len(short_term_memories)}, 长期:{len(long_term_memories)})" ) return f"### 🧠 相关记忆\n\n{formatted_memories}" - + logger.debug("[KFC记忆] 记忆为空") return "" - + except Exception as e: logger.error(f"[KFC记忆] 检索失败: {e}") import traceback traceback.print_exc() return "" - + def _build_memory_query_text( self, fallback_text: str, @@ -276,23 +276,23 @@ class KFCContextBuilder: ) -> str: """ 将最近若干条消息拼接为一个查询块,用于生成语义向量。 - + Args: fallback_text: 如果无法拼接消息块时使用的后备文本 context: 聊天流上下文 block_size: 组合的消息数量 - + Returns: str: 用于检索的查询文本 """ if not context: return fallback_text - + try: messages = context.get_messages(limit=block_size, include_unread=True) if not messages: return fallback_text - + lines = [] for msg in messages: sender = "" @@ -303,11 +303,11 @@ class KFCContextBuilder: lines.append(f"{sender}: {content}") elif content: lines.append(content) - + return "\n".join(lines) if lines else fallback_text except Exception: return fallback_text - + async def _build_tool_info( self, chat_history: str, @@ -316,30 +316,30 @@ class KFCContextBuilder: enable_tool: bool = True, ) -> str: """构建工具信息块 - + Args: chat_history: 聊天历史记录 sender_name: 发送者名称 target_message: 目标消息内容 enable_tool: 是否启用工具调用 - + Returns: str: 工具信息字符串 """ if not enable_tool: return "" - + try: from src.plugin_system.core.tool_use import ToolExecutor - + tool_executor = ToolExecutor(chat_id=self.chat_id) - + info_parts = [] - + # ========== 1. 主动召回联网搜索缓存 ========== try: from src.common.cache_manager import tool_cache - + # 使用聊天历史作为语义查询 query_text = chat_history if chat_history else target_message recalled_caches = await tool_cache.recall_relevant_cache( @@ -348,7 +348,7 @@ class KFCContextBuilder: top_k=2, similarity_threshold=0.65, # 相似度阈值 ) - + if recalled_caches: recall_parts = ["### 🔍 相关的历史搜索结果"] for item in recalled_caches: @@ -360,19 +360,19 @@ class KFCContextBuilder: if len(content) > 500: content = content[:500] + "..." recall_parts.append(f"**搜索「{original_query}」** (相关度:{similarity:.0%})\n{content}") - + info_parts.append("\n\n".join(recall_parts)) logger.info(f"[缓存召回] 召回了 {len(recalled_caches)} 条相关搜索缓存") except Exception as e: logger.debug(f"[缓存召回] 召回失败(非关键): {e}") - + # ========== 2. 获取工具调用历史 ========== tool_history_str = tool_executor.history_manager.format_for_prompt( max_records=3, include_results=True ) if tool_history_str: info_parts.append(tool_history_str) - + # ========== 3. 执行工具调用 ========== tool_results, _, _ = await tool_executor.execute_from_chat_message( sender=sender_name, @@ -380,7 +380,7 @@ class KFCContextBuilder: chat_history=chat_history, return_details=False, ) - + # 显示当前工具调用的结果(简要信息) if tool_results: current_results_parts = ["### 🔧 刚获取的工具信息"] @@ -389,35 +389,35 @@ class KFCContextBuilder: content = tool_result.get("content", "") # 不进行截断,让工具自己处理结果长度 current_results_parts.append(f"- **{tool_name}**: {content}") - + info_parts.append("\n".join(current_results_parts)) logger.info(f"[工具调用] 获取到 {len(tool_results)} 个工具结果") - + # 如果没有任何信息,返回空字符串 if not info_parts: logger.debug("[工具调用] 未获取到任何工具结果或历史记录") return "" - + return "\n\n".join(info_parts) - + except Exception as e: logger.error(f"[工具调用] 工具信息获取失败: {e}") return "" - + async def _build_expression_habits(self, chat_history: str, target_message: str) -> str: """构建表达习惯块""" config = _get_config() - + use_expression, _, _ = config.expression.get_expression_config_for_chat(self.chat_id) if not use_expression: return "" - + try: from src.chat.express.expression_selector import expression_selector - + style_habits = [] grammar_habits = [] - + selected_expressions = await expression_selector.select_suitable_expressions( chat_id=self.chat_id, chat_history=chat_history, @@ -425,7 +425,7 @@ class KFCContextBuilder: max_num=8, min_num=2 ) - + if selected_expressions: for expr in selected_expressions: if isinstance(expr, dict) and "situation" in expr and "style" in expr: @@ -435,40 +435,40 @@ class KFCContextBuilder: grammar_habits.append(habit_str) else: style_habits.append(habit_str) - + parts = [] if style_habits: parts.append("**语言风格习惯**:\n" + "\n".join(f"- {h}" for h in style_habits)) if grammar_habits: parts.append("**句法习惯**:\n" + "\n".join(f"- {h}" for h in grammar_habits)) - + if parts: return "### 💬 你的表达习惯\n\n" + "\n\n".join(parts) - + return "" - + except Exception as e: logger.error(f"构建表达习惯失败: {e}") return "" - + async def _build_schedule_block(self) -> str: """构建日程信息块""" config = _get_config() - + if not config.planning_system.schedule_enable: return "" - + try: from src.schedule.schedule_manager import schedule_manager - + activity_info = schedule_manager.get_current_activity() if not activity_info: return "" - + activity = activity_info.get("activity") time_range = activity_info.get("time_range") now = datetime.now() - + if time_range: try: start_str, end_str = time_range.split("-") @@ -478,15 +478,15 @@ class KFCContextBuilder: end_time = datetime.strptime(end_str.strip(), "%H:%M").replace( year=now.year, month=now.month, day=now.day ) - + if end_time < start_time: end_time += timedelta(days=1) if now < start_time: now += timedelta(days=1) - + duration_minutes = (now - start_time).total_seconds() / 60 remaining_minutes = (end_time - now).total_seconds() / 60 - + return ( f"你当前正在「{activity}」," f"从{start_time.strftime('%H:%M')}开始,预计{end_time.strftime('%H:%M')}结束," @@ -494,13 +494,13 @@ class KFCContextBuilder: ) except (ValueError, AttributeError): pass - + return f"你当前正在「{activity}」" - + except Exception as e: logger.error(f"构建日程块失败: {e}") return "" - + async def _build_time_block(self) -> str: """构建时间信息块""" now = datetime.now() @@ -514,7 +514,7 @@ async def build_kfc_context( sender_name: str, target_message: str, context: Optional["StreamContext"] = None, - user_id: Optional[str] = None, + user_id: str | None = None, ) -> dict[str, str]: """ 便捷函数:构建KFC所需的所有上下文 diff --git a/src/plugins/built_in/kokoro_flow_chatter/models.py b/src/plugins/built_in/kokoro_flow_chatter/models.py index 774e4d34c..4721bcc7b 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/models.py +++ b/src/plugins/built_in/kokoro_flow_chatter/models.py @@ -10,35 +10,35 @@ Kokoro Flow Chatter - 数据模型 - LLMResponse: LLM 响应结构 """ +import time from dataclasses import dataclass, field from enum import Enum from typing import Any -import time class EventType(Enum): """ 活动流事件类型 - + 用于标记 mental_log 中不同类型的事件, 每种类型对应一个提示词小模板 """ # 用户相关 USER_MESSAGE = "user_message" # 用户发送消息 - + # Bot 行动相关 BOT_PLANNING = "bot_planning" # Bot 规划(thought + actions) - + # 等待相关 WAITING_START = "waiting_start" # 开始等待 WAITING_UPDATE = "waiting_update" # 等待期间心理变化 REPLY_RECEIVED_IN_TIME = "reply_in_time" # 在预期内收到回复 REPLY_RECEIVED_LATE = "reply_late" # 超出预期收到回复 WAIT_TIMEOUT = "wait_timeout" # 等待超时 - + # 主动思考相关 PROACTIVE_TRIGGER = "proactive_trigger" # 主动思考触发(长期沉默) - + def __str__(self) -> str: return self.value @@ -46,14 +46,14 @@ class EventType(Enum): class SessionStatus(Enum): """ 会话状态 - + 极简设计,只有两种稳定状态: - IDLE: 空闲,没有期待回复 - WAITING: 等待对方回复中 """ IDLE = "idle" WAITING = "waiting" - + def __str__(self) -> str: return self.value @@ -62,7 +62,7 @@ class SessionStatus(Enum): class WaitingConfig: """ 等待配置 - + 当 Bot 发送消息后设置的等待参数 """ expected_reaction: str = "" # 期望对方如何回应 @@ -71,33 +71,33 @@ class WaitingConfig: last_thinking_at: float = 0.0 # 上次连续思考的时间戳 thinking_count: int = 0 # 连续思考次数(心理活动) followup_count: int = 0 # 追问次数(真正发送消息的次数) - + def is_active(self) -> bool: """是否正在等待""" return self.max_wait_seconds > 0 and self.started_at > 0 - + def get_elapsed_seconds(self) -> float: """获取已等待时间(秒)""" if not self.is_active(): return 0.0 return time.time() - self.started_at - + def get_elapsed_minutes(self) -> float: """获取已等待时间(分钟)""" return self.get_elapsed_seconds() / 60 - + def is_timeout(self) -> bool: """是否已超时""" if not self.is_active(): return False return self.get_elapsed_seconds() >= self.max_wait_seconds - + def get_progress(self) -> float: """获取等待进度 (0.0 - 1.0)""" if not self.is_active() or self.max_wait_seconds <= 0: return 0.0 return min(self.get_elapsed_seconds() / self.max_wait_seconds, 1.0) - + def to_dict(self) -> dict[str, Any]: return { "expected_reaction": self.expected_reaction, @@ -107,7 +107,7 @@ class WaitingConfig: "thinking_count": self.thinking_count, "followup_count": self.followup_count, } - + @classmethod def from_dict(cls, data: dict[str, Any]) -> "WaitingConfig": return cls( @@ -118,7 +118,7 @@ class WaitingConfig: thinking_count=data.get("thinking_count", 0), followup_count=data.get("followup_count", 0), ) - + def reset(self) -> None: """重置等待配置""" self.expected_reaction = "" @@ -133,34 +133,34 @@ class WaitingConfig: class MentalLogEntry: """ 心理活动日志条目 - + 记录活动流中的每一个事件节点, 用于构建线性叙事风格的提示词 """ event_type: EventType timestamp: float - + # 通用字段 content: str = "" # 事件内容(消息文本、动作描述等) - + # 用户消息相关 user_name: str = "" # 发送者名称 user_id: str = "" # 发送者 ID - + # Bot 规划相关 thought: str = "" # 内心想法 actions: list[dict] = field(default_factory=list) # 执行的动作列表 expected_reaction: str = "" # 期望的回应 max_wait_seconds: int = 0 # 设定的等待时间 - + # 等待相关 elapsed_seconds: float = 0.0 # 已等待时间 waiting_thought: str = "" # 等待期间的想法 mood: str = "" # 当前心情 - + # 元数据 metadata: dict[str, Any] = field(default_factory=dict) - + def to_dict(self) -> dict[str, Any]: return { "event_type": str(self.event_type), @@ -177,7 +177,7 @@ class MentalLogEntry: "mood": self.mood, "metadata": self.metadata, } - + @classmethod def from_dict(cls, data: dict[str, Any]) -> "MentalLogEntry": event_type_str = data.get("event_type", "user_message") @@ -185,7 +185,7 @@ class MentalLogEntry: event_type = EventType(event_type_str) except ValueError: event_type = EventType.USER_MESSAGE - + return cls( event_type=event_type, timestamp=data.get("timestamp", time.time()), @@ -201,7 +201,7 @@ class MentalLogEntry: mood=data.get("mood", ""), metadata=data.get("metadata", {}), ) - + def get_time_str(self, format: str = "%H:%M") -> str: """获取格式化的时间字符串""" return time.strftime(format, time.localtime(self.timestamp)) @@ -211,27 +211,27 @@ class MentalLogEntry: class ActionModel: """ 动作模型 - + 表示 LLM 决策的单个动作 """ type: str # 动作类型 params: dict[str, Any] = field(default_factory=dict) # 动作参数 reason: str = "" # 选择该动作的理由 - + def to_dict(self) -> dict[str, Any]: result = {"type": self.type} if self.reason: result["reason"] = self.reason result.update(self.params) return result - + @classmethod def from_dict(cls, data: dict[str, Any]) -> "ActionModel": action_type = data.get("type", "do_nothing") reason = data.get("reason", "") params = {k: v for k, v in data.items() if k not in ("type", "reason")} return cls(type=action_type, params=params, reason=reason) - + def get_description(self) -> str: """获取动作的文字描述""" if self.type == "kfc_reply": @@ -252,17 +252,17 @@ class ActionModel: class LLMResponse: """ LLM 响应结构 - + 定义 LLM 输出的 JSON 格式 """ thought: str # 内心想法 actions: list[ActionModel] # 动作列表 expected_reaction: str = "" # 期望对方的回应 max_wait_seconds: int = 0 # 最长等待时间(0 = 不等待) - + # 可选字段 mood: str = "" # 当前心情 - + def to_dict(self) -> dict[str, Any]: return { "thought": self.thought, @@ -271,16 +271,16 @@ class LLMResponse: "max_wait_seconds": self.max_wait_seconds, "mood": self.mood, } - + @classmethod def from_dict(cls, data: dict[str, Any]) -> "LLMResponse": actions_data = data.get("actions", []) actions = [ActionModel.from_dict(a) for a in actions_data] if actions_data else [] - + # 如果没有动作,添加默认的 do_nothing if not actions: actions = [ActionModel(type="do_nothing")] - + # 处理 max_wait_seconds,确保在合理范围内 max_wait = data.get("max_wait_seconds", 0) try: @@ -288,7 +288,7 @@ class LLMResponse: max_wait = max(0, min(max_wait, 1800)) # 0-30分钟 except (ValueError, TypeError): max_wait = 0 - + return cls( thought=data.get("thought", ""), actions=actions, @@ -296,7 +296,7 @@ class LLMResponse: max_wait_seconds=max_wait, mood=data.get("mood", ""), ) - + @classmethod def create_error_response(cls, error_message: str) -> "LLMResponse": """创建错误响应""" @@ -306,18 +306,18 @@ class LLMResponse: expected_reaction="", max_wait_seconds=0, ) - + def has_reply(self) -> bool: """是否包含回复动作""" return any(a.type in ("kfc_reply", "respond") for a in self.actions) - + def get_reply_content(self) -> str: """获取回复内容""" for action in self.actions: if action.type in ("kfc_reply", "respond"): return action.params.get("content", "") return "" - + def get_actions_description(self) -> str: """获取所有动作的文字描述""" descriptions = [a.get_description() for a in self.actions] diff --git a/src/plugins/built_in/kokoro_flow_chatter/planner.py b/src/plugins/built_in/kokoro_flow_chatter/planner.py index 1fa7bfcfc..a5956f324 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/planner.py +++ b/src/plugins/built_in/kokoro_flow_chatter/planner.py @@ -28,12 +28,12 @@ async def generate_plan( user_name: str, situation_type: str = "new_message", chat_stream: Optional["ChatStream"] = None, - available_actions: Optional[dict] = None, - extra_context: Optional[dict] = None, + available_actions: dict | None = None, + extra_context: dict | None = None, ) -> LLMResponse: """ 生成行动计划 - + Args: session: 会话对象 user_name: 用户名称 @@ -41,7 +41,7 @@ async def generate_plan( chat_stream: 聊天流对象 available_actions: 可用动作字典 extra_context: 额外上下文 - + Returns: LLMResponse 对象,包含计划信息 """ @@ -56,34 +56,34 @@ async def generate_plan( available_actions=available_actions, extra_context=extra_context, ) - + from src.config.config import global_config if global_config and global_config.debug.show_prompt: logger.info(f"[KFC Planner] 生成的规划提示词:\n{prompt}") - + # 2. 获取 planner 模型配置并调用 LLM models = llm_api.get_available_models() planner_config = models.get("planner") - + if not planner_config: logger.error("[KFC Planner] 未找到 planner 模型配置") return LLMResponse.create_error_response("未找到 planner 模型配置") - - success, raw_response, reasoning, model_name = await llm_api.generate_with_model( + + success, raw_response, _reasoning, model_name = await llm_api.generate_with_model( prompt=prompt, model_config=planner_config, request_type="kokoro_flow_chatter.plan", ) - + if not success: logger.error(f"[KFC Planner] LLM 调用失败: {raw_response}") return LLMResponse.create_error_response(raw_response) - + logger.debug(f"[KFC Planner] LLM 响应 (model={model_name}):\n{raw_response}") - + # 3. 解析响应 return _parse_response(raw_response) - + except Exception as e: logger.error(f"[KFC Planner] 生成失败: {e}") import traceback @@ -94,20 +94,20 @@ async def generate_plan( def _parse_response(raw_response: str) -> LLMResponse: """解析 LLM 响应""" data = extract_and_parse_json(raw_response, strict=False) - + if not data or not isinstance(data, dict): logger.warning(f"[KFC Planner] 无法解析 JSON: {raw_response[:200]}...") return LLMResponse.create_error_response("无法解析响应格式") - + response = LLMResponse.from_dict(data) - + if response.thought: # 使用 logger 输出美化日志(颜色通过 logger 系统配置) logger.info(f"💭 {response.thought}") - + actions_str = ", ".join(a.type for a in response.actions) logger.debug(f"actions={actions_str}") else: logger.warning("响应缺少 thought") - + return response diff --git a/src/plugins/built_in/kokoro_flow_chatter/plugin.py b/src/plugins/built_in/kokoro_flow_chatter/plugin.py index 8e5079215..be5a3af77 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/plugin.py +++ b/src/plugins/built_in/kokoro_flow_chatter/plugin.py @@ -7,9 +7,8 @@ Kokoro Flow Chatter - 插件注册 from typing import Any, ClassVar from src.common.logger import get_logger -from src.plugin_system.base.base_plugin import BasePlugin -from src.plugin_system.base.component_types import ChatterInfo from src.plugin_system import register_plugin +from src.plugin_system.base.base_plugin import BasePlugin from .chatter import KokoroFlowChatter from .config import get_config @@ -35,20 +34,20 @@ class KokoroFlowChatterPlugin(BasePlugin): dependencies: ClassVar[list[str]] = [] python_dependencies: ClassVar[list[str]] = [] config_file_name: str = "config.toml" - + # 状态 _is_started: bool = False - + async def on_plugin_loaded(self): """插件加载时""" config = get_config() - + if not config.enabled: logger.info("[KFC] 插件已禁用") return logger.info("[KFC] 插件已加载") - + # 启动主动思考器 if config.proactive.enabled: try: @@ -57,7 +56,7 @@ class KokoroFlowChatterPlugin(BasePlugin): self._is_started = True except Exception as e: logger.error(f"[KFC] 启动主动思考器失败: {e}") - + async def on_plugin_unloaded(self): """插件卸载时""" try: @@ -66,16 +65,16 @@ class KokoroFlowChatterPlugin(BasePlugin): self._is_started = False except Exception as e: logger.warning(f"[KFC] 停止主动思考器失败: {e}") - + def get_plugin_components(self): """返回组件列表""" config = get_config() - + if not config.enabled: return [] - + components = [] - + try: # 注册 Chatter components.append(( @@ -97,9 +96,9 @@ class KokoroFlowChatterPlugin(BasePlugin): logger.debug("[KFC] 成功加载 KFCReplyAction 组件") except Exception as e: logger.error(f"[KFC] 加载 Reply 动作失败: {e}") - + return components - + def get_plugin_info(self) -> dict[str, Any]: """获取插件信息""" return { diff --git a/src/plugins/built_in/kokoro_flow_chatter/proactive_thinker.py b/src/plugins/built_in/kokoro_flow_chatter/proactive_thinker.py index e80ee1863..45af2665c 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/proactive_thinker.py +++ b/src/plugins/built_in/kokoro_flow_chatter/proactive_thinker.py @@ -17,7 +17,7 @@ import asyncio import random import time from datetime import datetime -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional +from typing import TYPE_CHECKING from src.chat.planner_actions.action_manager import ChatterActionManager from src.common.logger import get_logger @@ -29,7 +29,7 @@ from .models import EventType, SessionStatus from .session import KokoroSession, get_session_manager if TYPE_CHECKING: - from src.chat.message_receive.chat_stream import ChatStream + pass logger = get_logger("kfc_proactive_thinker") @@ -37,40 +37,40 @@ logger = get_logger("kfc_proactive_thinker") class ProactiveThinker: """ 主动思考器 - + 独立于 Chatter,负责处理: 1. 等待期间的连续思考 2. 等待超时 3. 长期沉默后主动发起 - + 核心逻辑: - 定期检查所有 WAITING 状态的 Session - 触发连续思考或超时决策 - 定期检查长期沉默的 Session,考虑主动发起 - + 支持两种工作模式(与 Chatter 保持一致): - unified: 单次 LLM 调用 - split: Planner + Replyer 两次调用 """ - + # 连续思考触发点(等待进度百分比) THINKING_TRIGGERS = [0.3, 0.6, 0.85] - + # 任务名称 TASK_WAITING_CHECK = "kfc_waiting_check" TASK_PROACTIVE_CHECK = "kfc_proactive_check" - + def __init__(self): self.session_manager = get_session_manager() - + # 配置 self._load_config() - + # 调度任务 ID - self._waiting_schedule_id: Optional[str] = None - self._proactive_schedule_id: Optional[str] = None + self._waiting_schedule_id: str | None = None + self._proactive_schedule_id: str | None = None self._running = False - + # 统计 self._stats = { "waiting_checks": 0, @@ -78,21 +78,21 @@ class ProactiveThinker: "timeout_decisions": 0, "proactive_triggered": 0, } - + def _load_config(self) -> None: """加载配置 - 使用统一的配置系统""" config = get_config() proactive_cfg = config.proactive self._waiting_cfg = config.waiting - + # 工作模式 self._mode = config.mode - + # 等待检查间隔(秒) self.waiting_check_interval = 15.0 # 主动思考检查间隔(秒) self.proactive_check_interval = 300.0 - + # 从配置读取主动思考相关设置 self.proactive_enabled = proactive_cfg.enabled self.silence_threshold = proactive_cfg.silence_threshold_seconds @@ -101,15 +101,15 @@ class ProactiveThinker: self.quiet_hours_end = proactive_cfg.quiet_hours_end self.trigger_probability = proactive_cfg.trigger_probability self.min_affinity_for_proactive = proactive_cfg.min_affinity_for_proactive - + async def start(self) -> None: """启动主动思考器""" if self._running: logger.info("已在运行中") return - + self._running = True - + # 注册等待检查任务(始终启用,用于处理等待中的 Session) self._waiting_schedule_id = await unified_scheduler.create_schedule( callback=self._check_waiting_sessions, @@ -120,7 +120,7 @@ class ProactiveThinker: force_overwrite=True, timeout=60.0, ) - + # 注册主动思考检查任务(仅在启用时注册) if self.proactive_enabled: self._proactive_schedule_id = await unified_scheduler.create_schedule( @@ -135,49 +135,49 @@ class ProactiveThinker: logger.info("[ProactiveThinker] 已启动(主动思考已启用)") else: logger.info("[ProactiveThinker] 已启动(主动思考已禁用)") - + async def stop(self) -> None: """停止主动思考器""" if not self._running: return - + self._running = False - + if self._waiting_schedule_id: await unified_scheduler.remove_schedule(self._waiting_schedule_id) if self._proactive_schedule_id: await unified_scheduler.remove_schedule(self._proactive_schedule_id) - + logger.info("[ProactiveThinker] 已停止") - + # ======================== # 等待检查 # ======================== - + async def _check_waiting_sessions(self) -> None: """检查所有等待中的 Session""" self._stats["waiting_checks"] += 1 - + sessions = await self.session_manager.get_waiting_sessions() if not sessions: return - + # 并行处理 tasks = [ asyncio.create_task(self._process_waiting_session(s)) for s in sessions ] await asyncio.gather(*tasks, return_exceptions=True) - + async def _process_waiting_session(self, session: KokoroSession) -> None: """处理单个等待中的 Session""" try: if session.status != SessionStatus.WAITING: return - + if not session.waiting_config.is_active(): return - + # 防止与 Chatter 并发处理:如果 Session 刚刚被更新(5秒内),跳过 # 这样可以避免 Chatter 正在处理时,ProactiveThinker 也开始处理 time_since_last_activity = time.time() - session.last_activity_at @@ -187,36 +187,36 @@ class ProactiveThinker: f"({time_since_last_activity:.1f}s ago),跳过处理" ) return - + # 检查是否超时 if session.waiting_config.is_timeout(): await self._handle_timeout(session) return - + # 检查是否需要触发连续思考 progress = session.waiting_config.get_progress() if self._should_trigger_thinking(session, progress): await self._handle_continuous_thinking(session, progress) - + except Exception as e: logger.error(f"[ProactiveThinker] 处理等待 Session 失败 {session.user_id}: {e}") - + def _should_trigger_thinking(self, session: KokoroSession, progress: float) -> bool: """判断是否应触发连续思考""" # 计算应该触发的次数 expected_count = sum(1 for t in self.THINKING_TRIGGERS if progress >= t) - + if session.waiting_config.thinking_count >= expected_count: return False - + # 确保两次思考之间有间隔 if session.waiting_config.last_thinking_at > 0: elapsed = time.time() - session.waiting_config.last_thinking_at if elapsed < 30: # 至少 30 秒间隔 return False - + return True - + async def _handle_continuous_thinking( self, session: KokoroSession, @@ -224,31 +224,31 @@ class ProactiveThinker: ) -> None: """处理连续思考""" self._stats["continuous_thinking_triggered"] += 1 - + # 获取用户名 user_name = await self._get_user_name(session.user_id, session.stream_id) - + # 调用 LLM 生成等待中的想法 thought = await self._generate_waiting_thought(session, user_name, progress) - + # 记录到 mental_log session.add_waiting_update( waiting_thought=thought, mood="", # 心情已融入 thought 中 ) - + # 更新思考计数 session.waiting_config.thinking_count += 1 session.waiting_config.last_thinking_at = time.time() - + # 保存 await self.session_manager.save_session(session.user_id) - + logger.debug( f"[ProactiveThinker] 连续思考: user={session.user_id}, " f"progress={progress:.1%}, thought={thought[:30]}..." ) - + async def _generate_waiting_thought( self, session: KokoroSession, @@ -259,19 +259,19 @@ class ProactiveThinker: try: from src.chat.utils.prompt import global_prompt_manager from src.plugin_system.apis import llm_api - + from .prompt.builder import get_prompt_builder from .prompt.prompts import PROMPT_NAMES - + # 使用 PromptBuilder 构建人设块 prompt_builder = get_prompt_builder() persona_block = prompt_builder._build_persona_block() - + # 获取关系信息 relation_block = f"你与 {user_name} 还不太熟悉。" try: from src.person_info.relationship_manager import relationship_manager - + person_info_manager = await self._get_person_info_manager() if person_info_manager: platform = global_config.bot.platform if global_config else "qq" @@ -281,7 +281,7 @@ class ProactiveThinker: relation_block = f"你与 {user_name} 的亲密度是 {relationship.intimacy}。{relationship.description or ''}" except Exception as e: logger.debug(f"获取关系信息失败: {e}") - + # 获取上次发送的消息 last_bot_message = "(未知)" for entry in reversed(session.mental_log): @@ -294,12 +294,12 @@ class ProactiveThinker: break if last_bot_message != "(未知)": break - + # 构建提示词 elapsed_minutes = session.waiting_config.get_elapsed_minutes() max_wait_minutes = session.waiting_config.max_wait_seconds / 60 expected_reaction = session.waiting_config.expected_reaction or "对方能回复点什么" - + prompt = await global_prompt_manager.format_prompt( PROMPT_NAMES["waiting_thought"], persona_block=persona_block, @@ -311,32 +311,32 @@ class ProactiveThinker: max_wait_minutes=max_wait_minutes, progress_percent=int(progress * 100), ) - + # 调用情绪模型 models = llm_api.get_available_models() emotion_config = models.get("emotion") or models.get("replyer") - + if not emotion_config: logger.warning("[ProactiveThinker] 未找到 emotion/replyer 模型配置,使用默认想法") return self._get_fallback_thought(elapsed_minutes, progress) - + success, raw_response, _, model_name = await llm_api.generate_with_model( prompt=prompt, model_config=emotion_config, request_type="kokoro_flow_chatter.waiting_thought", ) - + if not success or not raw_response: logger.warning(f"[ProactiveThinker] LLM 调用失败: {raw_response}") return self._get_fallback_thought(elapsed_minutes, progress) - + # 使用统一的文本清理函数 from .replyer import _clean_reply_text thought = _clean_reply_text(raw_response) - + logger.debug(f"[ProactiveThinker] LLM 生成等待想法 (model={model_name}): {thought[:50]}...") return thought - + except Exception as e: logger.error(f"[ProactiveThinker] 生成等待想法失败: {e}") import traceback @@ -345,7 +345,7 @@ class ProactiveThinker: session.waiting_config.get_elapsed_minutes(), progress ) - + def _get_fallback_thought(self, elapsed_minutes: float, progress: float) -> str: """获取备用的等待想法(当 LLM 调用失败时使用)""" if progress < 0.4: @@ -367,7 +367,7 @@ class ProactiveThinker: "快到时间了,对方还是没回", ] return random.choice(thoughts) - + async def _get_person_info_manager(self): """获取 person_info_manager""" try: @@ -375,16 +375,16 @@ class ProactiveThinker: return get_person_info_manager() except Exception: return None - + async def _handle_timeout(self, session: KokoroSession) -> None: """处理等待超时 - 支持双模式""" self._stats["timeout_decisions"] += 1 - + # 再次检查 Session 状态,防止在等待过程中被 Chatter 处理 if session.status != SessionStatus.WAITING: logger.debug(f"[ProactiveThinker] Session {session.user_id} 已不在等待状态,跳过超时处理") return - + # 再次检查最近活动时间 time_since_last_activity = time.time() - session.last_activity_at if time_since_last_activity < 5: @@ -392,36 +392,36 @@ class ProactiveThinker: f"[ProactiveThinker] Session {session.user_id} 刚有活动,跳过超时处理" ) return - + # 增加连续超时计数 session.consecutive_timeout_count += 1 - + logger.info( f"[ProactiveThinker] 等待超时: user={session.user_id}, " f"consecutive_timeout={session.consecutive_timeout_count}" ) - + try: # 获取用户名 user_name = await self._get_user_name(session.user_id, session.stream_id) - + # 获取聊天流 chat_stream = await self._get_chat_stream(session.stream_id) - + # 加载动作 action_manager = ChatterActionManager() await action_manager.load_actions(session.stream_id) - + # 通过 ActionModifier 过滤动作 from src.chat.planner_actions.action_modifier import ActionModifier action_modifier = ActionModifier(action_manager, session.stream_id) await action_modifier.modify_actions(chatter_name="KokoroFlowChatter") - + # 计算用户最后回复距今的时间 time_since_user_reply = None if session.last_user_message_at: time_since_user_reply = time.time() - session.last_user_message_at - + # 构建超时上下文信息 extra_context = { "consecutive_timeout_count": session.consecutive_timeout_count, @@ -429,7 +429,7 @@ class ProactiveThinker: "time_since_user_reply": time_since_user_reply, "time_since_user_reply_str": self._format_duration(time_since_user_reply) if time_since_user_reply else "未知", } - + # 根据模式选择生成方式 if self._mode == KFCMode.UNIFIED: # 统一模式:单次 LLM 调用 @@ -452,7 +452,7 @@ class ProactiveThinker: available_actions=action_manager.get_using_actions(), extra_context=extra_context, ) - + # 分离模式下需要注入上下文信息 for action in plan_response.actions: if action.type == "kfc_reply": @@ -485,14 +485,14 @@ class ProactiveThinker: adjusted_wait, ) plan_response.max_wait_seconds = adjusted_wait - + # ★ 在执行动作前最后一次检查状态,防止与 Chatter 并发 if session.status != SessionStatus.WAITING: logger.info( f"[ProactiveThinker] Session {session.user_id} 已被 Chatter 处理,取消执行动作" ) return - + # 执行动作(回复生成在 Action.execute() 中完成) for action in plan_response.actions: await action_manager.execute_action( @@ -504,7 +504,7 @@ class ProactiveThinker: thinking_id=None, log_prefix="[KFC ProactiveThinker]", ) - + # 🎯 只有真正发送了消息才增加追问计数(do_nothing 不算追问) has_reply_action = any( a.type in ("kfc_reply", "respond", "poke_user", "send_emoji") @@ -513,7 +513,7 @@ class ProactiveThinker: if has_reply_action: session.waiting_config.followup_count += 1 logger.debug(f"[ProactiveThinker] 超时追问计数+1: user={session.user_id}, followup_count={session.waiting_config.followup_count}") - + # 记录到 mental_log session.add_bot_planning( thought=plan_response.thought, @@ -521,7 +521,7 @@ class ProactiveThinker: expected_reaction=plan_response.expected_reaction, max_wait_seconds=plan_response.max_wait_seconds, ) - + # 更新状态 if plan_response.max_wait_seconds > 0: # 继续等待 @@ -532,36 +532,36 @@ class ProactiveThinker: else: # 不再等待 session.end_waiting() - + # 保存 await self.session_manager.save_session(session.user_id) - + logger.info( f"[ProactiveThinker] 超时决策完成: user={session.user_id}, " f"actions={[a.type for a in plan_response.actions]}, " f"continue_wait={plan_response.max_wait_seconds > 0}, " f"consecutive_timeout={session.consecutive_timeout_count}" ) - + except Exception as e: logger.error(f"[ProactiveThinker] 处理超时失败: {e}") # 出错时结束等待 session.end_waiting() await self.session_manager.save_session(session.user_id) - + # ======================== # 主动思考(长期沉默) # ======================== - + async def _check_proactive_sessions(self) -> None: """检查是否有需要主动发起对话的 Session""" # 检查是否在勿扰时段 if self._is_quiet_hours(): return - + sessions = await self.session_manager.get_all_sessions() current_time = time.time() - + for session in sessions: try: trigger_reason = self._should_trigger_proactive(session, current_time) @@ -569,54 +569,54 @@ class ProactiveThinker: await self._handle_proactive(session, trigger_reason) except Exception as e: logger.error(f"[ProactiveThinker] 检查主动思考失败 {session.user_id}: {e}") - + def _is_quiet_hours(self) -> bool: """检查是否在勿扰时段""" try: now = datetime.now() current_minutes = now.hour * 60 + now.minute - + start_parts = self.quiet_hours_start.split(":") start_minutes = int(start_parts[0]) * 60 + int(start_parts[1]) - + end_parts = self.quiet_hours_end.split(":") end_minutes = int(end_parts[0]) * 60 + int(end_parts[1]) - + if start_minutes <= end_minutes: return start_minutes <= current_minutes < end_minutes else: return current_minutes >= start_minutes or current_minutes < end_minutes except: return False - + def _should_trigger_proactive( self, session: KokoroSession, current_time: float, - ) -> Optional[str]: + ) -> str | None: """判断是否应触发主动思考""" # 只检查 IDLE 状态的 Session if session.status != SessionStatus.IDLE: return None - + # 检查沉默时长 silence_duration = current_time - session.last_activity_at if silence_duration < self.silence_threshold: return None - + # 检查距离上次主动思考的间隔 if session.last_proactive_at: time_since_last = current_time - session.last_proactive_at if time_since_last < self.min_proactive_interval: return None - + # 概率触发(避免每次检查都触发) if random.random() > self.trigger_probability: return None - + silence_hours = silence_duration / 3600 return f"沉默了 {silence_hours:.1f} 小时" - + async def _handle_proactive( self, session: KokoroSession, @@ -624,7 +624,7 @@ class ProactiveThinker: ) -> None: """处理主动思考 - 支持双模式""" self._stats["proactive_triggered"] += 1 - + # 再次检查最近活动时间,防止与 Chatter 并发 time_since_last_activity = time.time() - session.last_activity_at if time_since_last_activity < 5: @@ -632,37 +632,37 @@ class ProactiveThinker: f"[ProactiveThinker] Session {session.user_id} 刚有活动,跳过主动思考" ) return - + logger.info(f"主动思考触发: user={session.user_id}, reason={trigger_reason}") - + try: # 获取用户名 user_name = await self._get_user_name(session.user_id, session.stream_id) - + # 获取聊天流 chat_stream = await self._get_chat_stream(session.stream_id) - + # 加载动作 action_manager = ChatterActionManager() await action_manager.load_actions(session.stream_id) - + # 通过 ActionModifier 过滤动作 from src.chat.planner_actions.action_modifier import ActionModifier action_modifier = ActionModifier(action_manager, session.stream_id) await action_modifier.modify_actions(chatter_name="KokoroFlowChatter") - + # 计算沉默时长 silence_seconds = time.time() - session.last_activity_at if silence_seconds < 3600: silence_duration = f"{silence_seconds / 60:.0f} 分钟" else: silence_duration = f"{silence_seconds / 3600:.1f} 小时" - + extra_context = { "trigger_reason": trigger_reason, "silence_duration": silence_duration, } - + # 根据模式选择生成方式 if self._mode == KFCMode.UNIFIED: # 统一模式:单次 LLM 调用 @@ -686,19 +686,19 @@ class ProactiveThinker: available_actions=action_manager.get_using_actions(), extra_context=extra_context, ) - + # 检查是否决定不打扰 is_do_nothing = ( len(plan_response.actions) == 0 or (len(plan_response.actions) == 1 and plan_response.actions[0].type == "do_nothing") ) - + if is_do_nothing: logger.info(f"决定不打扰: user={session.user_id}") session.last_proactive_at = time.time() await self.session_manager.save_session(session.user_id) return - + # 分离模式下需要注入上下文信息 if self._mode == KFCMode.SPLIT: for action in plan_response.actions: @@ -732,7 +732,7 @@ class ProactiveThinker: adjusted_wait, ) plan_response.max_wait_seconds = adjusted_wait - + # 执行动作(回复生成在 Action.execute() 中完成) for action in plan_response.actions: await action_manager.execute_action( @@ -744,7 +744,7 @@ class ProactiveThinker: thinking_id=None, log_prefix="[KFC ProactiveThinker]", ) - + # 记录到 mental_log session.add_bot_planning( thought=plan_response.thought, @@ -752,7 +752,7 @@ class ProactiveThinker: expected_reaction=plan_response.expected_reaction, max_wait_seconds=plan_response.max_wait_seconds, ) - + # 更新状态 session.last_proactive_at = time.time() if plan_response.max_wait_seconds > 0: @@ -760,54 +760,54 @@ class ProactiveThinker: expected_reaction=plan_response.expected_reaction, max_wait_seconds=plan_response.max_wait_seconds, ) - + # 保存 await self.session_manager.save_session(session.user_id) - + logger.info( f"[ProactiveThinker] 主动发起完成: user={session.user_id}, " f"actions={[a.type for a in plan_response.actions]}" ) - + except Exception as e: logger.error(f"[ProactiveThinker] 主动思考失败: {e}") - + async def _get_chat_stream(self, stream_id: str): """获取聊天流""" try: from src.chat.message_receive.chat_stream import get_chat_manager - + chat_manager = get_chat_manager() if chat_manager: return await chat_manager.get_stream(stream_id) except Exception as e: logger.warning(f"[ProactiveThinker] 获取 chat_stream 失败: {e}") return None - + async def _get_user_name(self, user_id: str, stream_id: str) -> str: """获取用户名称(优先从 person_info 获取)""" try: from src.person_info.person_info import get_person_info_manager - + person_info_manager = get_person_info_manager() platform = global_config.bot.platform if global_config else "qq" - + person_id = person_info_manager.get_person_id(platform, user_id) person_name = await person_info_manager.get_value(person_id, "person_name") - + if person_name: return person_name except Exception as e: logger.debug(f"[ProactiveThinker] 获取用户名失败: {e}") - + # 回退到 user_id return user_id - + def _format_duration(self, seconds: float | None) -> str: """格式化时间间隔为人类可读的字符串""" if seconds is None or seconds < 0: return "未知" - + if seconds < 60: return f"{int(seconds)} 秒" elif seconds < 3600: @@ -819,7 +819,7 @@ class ProactiveThinker: else: days = seconds / 86400 return f"{days:.1f} 天" - + def get_stats(self) -> dict: """获取统计信息""" return { @@ -829,7 +829,7 @@ class ProactiveThinker: # 全局单例 -_proactive_thinker: Optional[ProactiveThinker] = None +_proactive_thinker: ProactiveThinker | None = None def get_proactive_thinker() -> ProactiveThinker: diff --git a/src/plugins/built_in/kokoro_flow_chatter/prompt/__init__.py b/src/plugins/built_in/kokoro_flow_chatter/prompt/__init__.py index 501e3b92f..dc63564e6 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/prompt/__init__.py +++ b/src/plugins/built_in/kokoro_flow_chatter/prompt/__init__.py @@ -10,7 +10,7 @@ from .builder import PromptBuilder, get_prompt_builder from .prompts import PROMPT_NAMES __all__ = [ + "PROMPT_NAMES", "PromptBuilder", "get_prompt_builder", - "PROMPT_NAMES", ] diff --git a/src/plugins/built_in/kokoro_flow_chatter/prompt/builder.py b/src/plugins/built_in/kokoro_flow_chatter/prompt/builder.py index f2e6b65c5..cce36f1ec 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/prompt/builder.py +++ b/src/plugins/built_in/kokoro_flow_chatter/prompt/builder.py @@ -12,7 +12,7 @@ from src.chat.utils.prompt import global_prompt_manager from src.common.logger import get_logger from src.config.config import global_config -from ..models import EventType, MentalLogEntry, SessionStatus +from ..models import EventType, MentalLogEntry from ..session import KokoroSession # 导入模板注册(确保模板被注册到 global_prompt_manager) @@ -28,28 +28,28 @@ logger = get_logger("kfc_prompt_builder") class PromptBuilder: """ 提示词构建器 - + 使用统一的 Prompt 管理系统构建提示词: 1. 构建活动流(从 mental_log 生成线性叙事) 2. 构建当前情况描述 3. 使用 global_prompt_manager 格式化最终提示词 """ - + def __init__(self): self._context_builder = None - + async def build_planner_prompt( self, session: KokoroSession, user_name: str, situation_type: str = "new_message", chat_stream: Optional["ChatStream"] = None, - available_actions: Optional[dict] = None, - extra_context: Optional[dict] = None, + available_actions: dict | None = None, + extra_context: dict | None = None, ) -> str: """ 构建规划器提示词(用于生成行动计划) - + Args: session: 会话对象 user_name: 用户名称 @@ -57,45 +57,45 @@ class PromptBuilder: chat_stream: 聊天流对象 available_actions: 可用动作字典 extra_context: 额外上下文(如 trigger_reason) - + Returns: 完整的规划器提示词 """ extra_context = extra_context or {} - + # 获取 user_id(从 session 中) user_id = session.user_id if session else None - + # 1. 构建人设块 persona_block = self._build_persona_block() - + # 1.5. 构建安全互动准则块 safety_guidelines_block = self._build_safety_guidelines_block() - + # 2. 使用 context_builder 获取关系、记忆、工具、表达习惯等 context_data = await self._build_context_data(user_name, chat_stream, user_id) relation_block = context_data.get("relation_info", f"你与 {user_name} 还不太熟悉,这是早期的交流阶段。") memory_block = context_data.get("memory_block", "") tool_info = context_data.get("tool_info", "") expression_habits = self._build_combined_expression_block(context_data.get("expression_habits", "")) - + # 3. 构建活动流 activity_stream = await self._build_activity_stream(session, user_name) - + # 4. 构建当前情况 current_situation = await self._build_current_situation( session, user_name, situation_type, extra_context ) - + # 5. 构建聊天历史总览 chat_history_block = await self._build_chat_history_block(chat_stream) - + # 6. 构建可用动作 actions_block = self._build_actions_block(available_actions) - + # 7. 获取规划器输出格式 output_format = await self._get_planner_output_format() - + # 8. 使用统一的 prompt 管理系统格式化 prompt = await global_prompt_manager.format_prompt( PROMPT_NAMES["main"], @@ -112,9 +112,9 @@ class PromptBuilder: available_actions=actions_block, output_format=output_format, ) - + return prompt - + async def build_replyer_prompt( self, session: KokoroSession, @@ -122,11 +122,11 @@ class PromptBuilder: thought: str, situation_type: str = "new_message", chat_stream: Optional["ChatStream"] = None, - extra_context: Optional[dict] = None, + extra_context: dict | None = None, ) -> str: """ 构建回复器提示词(用于生成自然的回复文本) - + Args: session: 会话对象 user_name: 用户名称 @@ -134,44 +134,44 @@ class PromptBuilder: situation_type: 情况类型 chat_stream: 聊天流对象 extra_context: 额外上下文 - + Returns: 完整的回复器提示词 """ extra_context = extra_context or {} - + # 获取 user_id user_id = session.user_id if session else None - + # 1. 构建人设块 persona_block = self._build_persona_block() - + # 1.5. 构建安全互动准则块 safety_guidelines_block = self._build_safety_guidelines_block() - + # 2. 使用 context_builder 获取关系、记忆、表达习惯等 context_data = await self._build_context_data(user_name, chat_stream, user_id) relation_block = context_data.get("relation_info", f"你与 {user_name} 还不太熟悉,这是早期的交流阶段。") memory_block = context_data.get("memory_block", "") tool_info = context_data.get("tool_info", "") expression_habits = self._build_combined_expression_block(context_data.get("expression_habits", "")) - + # 3. 构建活动流 activity_stream = await self._build_activity_stream(session, user_name) - + # 4. 构建当前情况(回复器专用,简化版,不包含决策语言) current_situation = await self._build_replyer_situation( session, user_name, situation_type, extra_context ) - + # 5. 构建聊天历史总览 chat_history_block = await self._build_chat_history_block(chat_stream) - + # 6. 构建回复情景上下文 reply_context = await self._build_reply_context( session, user_name, situation_type, extra_context ) - + # 7. 使用回复器专用模板 prompt = await global_prompt_manager.format_prompt( PROMPT_NAMES["replyer"], @@ -188,60 +188,60 @@ class PromptBuilder: thought=thought, reply_context=reply_context, ) - + return prompt - + def _build_persona_block(self) -> str: """构建人设块""" if global_config is None: return "你是一个温暖、真诚的人。" - + personality = global_config.personality parts = [] - + if personality.personality_core: parts.append(personality.personality_core) - + if personality.personality_side: parts.append(personality.personality_side) - + if personality.identity: parts.append(personality.identity) - + return "\n\n".join(parts) if parts else "你是一个温暖、真诚的人。" - + def _build_safety_guidelines_block(self) -> str: """ 构建安全互动准则块 - + 从配置中读取 safety_guidelines,构建成提示词格式 """ if global_config is None: return "" - + safety_guidelines = global_config.personality.safety_guidelines if not safety_guidelines: return "" - + guidelines_text = "\n".join(f"{i + 1}. {line}" for i, line in enumerate(safety_guidelines)) return f"""在任何情况下,你都必须遵守以下由你的设定者为你定义的原则: {guidelines_text} 如果遇到违反上述原则的请求,请在保持你核心人设的同时,以合适的方式进行回应。""" - + def _build_combined_expression_block(self, learned_habits: str) -> str: """ 构建合并后的表达习惯块 - + 合并: - 说话风格(来自人设配置 personality.reply_style) - 表达习惯(来自学习系统) """ parts = [] - + # 1. 添加说话风格(来自配置) if global_config and global_config.personality.reply_style: parts.append(f"**说话风格**:\n你必须参考你的说话风格:\n{global_config.personality.reply_style}") - + # 2. 添加学习到的表达习惯 if learned_habits and learned_habits.strip(): # 如果 learned_habits 已经有标题,直接追加;否则添加标题 @@ -252,23 +252,23 @@ class PromptBuilder: parts.append("\n".join(content_lines).strip()) else: parts.append(learned_habits) - + if parts: return "\n\n".join(parts) - + return "" - + async def _build_context_data( self, user_name: str, chat_stream: Optional["ChatStream"], - user_id: Optional[str] = None, - session: Optional[KokoroSession] = None, + user_id: str | None = None, + session: KokoroSession | None = None, situation_type: str = "new_message", ) -> dict[str, str]: """ 使用 KFCContextBuilder 构建完整的上下文数据 - + 包括:关系信息、记忆、表达习惯等 """ if not chat_stream: @@ -278,15 +278,15 @@ class PromptBuilder: "tool_info": "", "expression_habits": "", } - + try: # 延迟导入上下文构建器 if self._context_builder is None: from ..context_builder import KFCContextBuilder self._context_builder = KFCContextBuilder - + builder = self._context_builder(chat_stream) - + # 获取用于记忆检索的查询文本 target_message = await self._get_memory_search_query( chat_stream=chat_stream, @@ -294,16 +294,16 @@ class PromptBuilder: situation_type=situation_type, user_name=user_name, ) - + context_data = await builder.build_all_context( sender_name=user_name, target_message=target_message, context=chat_stream.context, user_id=user_id, ) - + return context_data - + except Exception as e: logger.warning(f"构建上下文数据失败: {e}") return { @@ -312,34 +312,34 @@ class PromptBuilder: "tool_info": "", "expression_habits": "", } - + async def _get_memory_search_query( self, chat_stream: Optional["ChatStream"], - session: Optional[KokoroSession], + session: KokoroSession | None, situation_type: str, user_name: str, ) -> str: """ 根据场景类型获取合适的记忆搜索查询文本 - + 策略: 1. 优先使用未读消息(new_message/reply_in_time/reply_late) 2. 如果没有未读消息(timeout/proactive),使用最近的历史消息 3. 如果历史消息也为空,从 session 的 mental_log 中提取 4. 最后回退到用户名作为查询 - + Args: chat_stream: 聊天流对象 session: KokoroSession 会话对象 situation_type: 情况类型 user_name: 用户名称 - + Returns: 用于记忆搜索的查询文本 """ target_message = "" - + # 策略1: 优先从未读消息获取(适用于 new_message/reply_in_time/reply_late) if chat_stream and chat_stream.context: unread = chat_stream.context.get_unread_messages() @@ -348,7 +348,7 @@ class PromptBuilder: if target_message: logger.debug(f"[记忆搜索] 使用未读消息作为查询: {target_message[:50]}...") return target_message - + # 策略2: 从最近的历史消息获取(适用于 timeout/proactive) if chat_stream and chat_stream.context: history_messages = chat_stream.context.history_messages @@ -361,17 +361,17 @@ class PromptBuilder: recent_texts.append(content) if len(recent_texts) >= 3: break - + if recent_texts: target_message = " ".join(reversed(recent_texts)) logger.debug(f"[记忆搜索] 使用历史消息作为查询 (situation={situation_type}): {target_message[:80]}...") return target_message - + # 策略3: 从 session 的 mental_log 中提取(超时/主动思考场景的最后手段) if session and situation_type in ("timeout", "proactive"): entries = session.get_recent_entries(limit=10) recent_texts = [] - + for entry in reversed(entries): # 从用户消息中提取 if entry.event_type == EventType.USER_MESSAGE and entry.content: @@ -379,15 +379,15 @@ class PromptBuilder: # 从 bot 的预期反应中提取(可能包含相关话题) elif entry.event_type == EventType.BOT_PLANNING and entry.expected_reaction: recent_texts.append(entry.expected_reaction) - + if len(recent_texts) >= 3: break - + if recent_texts: target_message = " ".join(reversed(recent_texts)) logger.debug(f"[记忆搜索] 使用 mental_log 作为查询 (situation={situation_type}): {target_message[:80]}...") return target_message - + # 策略4: 最后回退 - 使用用户名 + 场景描述 if situation_type == "timeout": target_message = f"与 {user_name} 的对话 等待回复" @@ -395,54 +395,56 @@ class PromptBuilder: target_message = f"与 {user_name} 的对话 主动发起聊天" else: target_message = f"与 {user_name} 的对话" - + logger.debug(f"[记忆搜索] 使用回退查询 (situation={situation_type}): {target_message}") return target_message - - def _get_latest_user_message(self, session: Optional[KokoroSession]) -> str: + + def _get_latest_user_message(self, session: KokoroSession | None) -> str: """ 获取最新的用户消息内容 - + Args: session: KokoroSession 会话对象 - + Returns: 最新用户消息的内容,如果没有则返回提示文本 """ if not session: return "(未知消息)" - + # 从 mental_log 中获取最新的用户消息 entries = session.get_recent_entries(limit=10) for entry in reversed(entries): if entry.event_type == EventType.USER_MESSAGE and entry.content: return entry.content - + return "(消息内容不可用)" - + async def _build_chat_history_block( self, chat_stream: Optional["ChatStream"], ) -> str: """ 构建聊天历史总览块 - + 从 chat_stream 获取历史消息,格式化为可读的聊天记录 类似于 AFC 的已读历史板块 """ if not chat_stream: return "(暂无聊天记录)" - + try: - from src.chat.utils.chat_message_builder import build_readable_messages_with_id - from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat + from src.chat.utils.chat_message_builder import ( + build_readable_messages_with_id, + get_raw_msg_before_timestamp_with_chat, + ) from src.common.data_models.database_data_model import DatabaseMessages - + stream_context = chat_stream.context - + # 获取已读消息 history_messages = stream_context.history_messages if stream_context else [] - + if not history_messages: # 如果内存中没有历史消息,从数据库加载 fallback_messages_dicts = await get_raw_msg_before_timestamp_with_chat( @@ -453,16 +455,16 @@ class PromptBuilder: history_messages = [ DatabaseMessages.from_dict(msg_dict) for msg_dict in fallback_messages_dicts ] - + if not history_messages: return "(暂无聊天记录)" - + # 过滤非文本消息(如戳一戳、禁言等系统通知) text_messages = self._filter_text_messages(history_messages) - + if not text_messages: return "(暂无聊天记录)" - + # 构建可读消息 chat_content, _ = await build_readable_messages_with_id( messages=[msg.flatten() for msg in text_messages[-30:]], # 最多30条 @@ -470,22 +472,22 @@ class PromptBuilder: truncate=False, show_actions=False, ) - + return chat_content if chat_content else "(暂无聊天记录)" - + except Exception as e: logger.warning(f"构建聊天历史块失败: {e}") return "(获取聊天记录失败)" - + def _filter_text_messages(self, messages: list) -> list: """ 过滤非文本消息 - + 移除系统通知消息(如戳一戳、禁言等),只保留正常的文本聊天消息 - + Args: messages: 消息列表(DatabaseMessages 对象) - + Returns: 过滤后的消息列表 """ @@ -494,16 +496,16 @@ class PromptBuilder: # 跳过系统通知消息(戳一戳、禁言等) if getattr(msg, "is_notify", False): continue - + # 跳过没有实际文本内容的消息 content = getattr(msg, "processed_plain_text", "") or getattr(msg, "display_message", "") if not content or not content.strip(): continue - + filtered.append(msg) - + return filtered - + async def _build_activity_stream( self, session: KokoroSession, @@ -511,26 +513,26 @@ class PromptBuilder: ) -> str: """ 构建活动流 - + 将 mental_log 中的事件按时间顺序转换为线性叙事 使用统一的 prompt 模板 """ entries = session.get_recent_entries(limit=30) if not entries: return "" - + parts = [] - + for entry in entries: part = await self._format_entry(entry, user_name) if part: parts.append(part) - + return "\n\n".join(parts) - + async def _format_entry(self, entry: MentalLogEntry, user_name: str) -> str: """格式化单个活动日志条目""" - + if entry.event_type == EventType.USER_MESSAGE: # 用户消息 result = await global_prompt_manager.format_prompt( @@ -539,7 +541,7 @@ class PromptBuilder: user_name=entry.user_name or user_name, content=entry.content, ) - + # 如果有回复状态元数据,添加说明 reply_status = entry.metadata.get("reply_status") if reply_status == "in_time": @@ -558,13 +560,13 @@ class PromptBuilder: elapsed_minutes=elapsed, max_wait_minutes=max_wait, ) - + return result - + elif entry.event_type == EventType.BOT_PLANNING: # Bot 规划 actions_desc = self._format_actions(entry.actions) - + if entry.max_wait_seconds > 0: return await global_prompt_manager.format_prompt( PROMPT_NAMES["entry_bot_planning"], @@ -579,7 +581,7 @@ class PromptBuilder: thought=entry.thought or "(没有特别的想法)", actions_description=actions_desc, ) - + elif entry.event_type == EventType.WAITING_UPDATE: # 等待中心理变化 return await global_prompt_manager.format_prompt( @@ -587,7 +589,7 @@ class PromptBuilder: elapsed_minutes=entry.elapsed_seconds / 60, waiting_thought=entry.waiting_thought or "还在等...", ) - + elif entry.event_type == EventType.PROACTIVE_TRIGGER: # 主动思考触发 silence = entry.metadata.get("silence_duration", "一段时间") @@ -595,18 +597,18 @@ class PromptBuilder: PROMPT_NAMES["entry_proactive_trigger"], silence_duration=silence, ) - + return "" - + def _format_actions(self, actions: list[dict]) -> str: """格式化动作列表为可读描述""" if not actions: return "(无动作)" - + descriptions = [] for action in actions: action_type = action.get("type", "unknown") - + if action_type == "kfc_reply": content = action.get("content", "") if len(content) > 50: @@ -621,9 +623,9 @@ class PromptBuilder: descriptions.append(f"发送表情:{emoji}") else: descriptions.append(f"执行动作:{action_type}") - + return "、".join(descriptions) - + async def _build_current_situation( self, session: KokoroSession, @@ -633,13 +635,13 @@ class PromptBuilder: ) -> str: """构建当前情况描述""" current_time = datetime.now().strftime("%Y年%m月%d日 %H:%M") - + # 如果之前没有设置等待时间(max_wait_seconds == 0),视为 new_message if situation_type in ("reply_in_time", "reply_late"): max_wait = session.waiting_config.max_wait_seconds if max_wait <= 0: situation_type = "new_message" - + if situation_type == "new_message": # 获取最新消息内容 latest_message = self._get_latest_user_message(session) @@ -649,7 +651,7 @@ class PromptBuilder: user_name=user_name, latest_message=latest_message, ) - + elif situation_type == "reply_in_time": elapsed = session.waiting_config.get_elapsed_seconds() max_wait = session.waiting_config.max_wait_seconds @@ -662,7 +664,7 @@ class PromptBuilder: max_wait_minutes=max_wait / 60, latest_message=latest_message, ) - + elif situation_type == "reply_late": elapsed = session.waiting_config.get_elapsed_seconds() max_wait = session.waiting_config.max_wait_seconds @@ -675,29 +677,29 @@ class PromptBuilder: max_wait_minutes=max_wait / 60, latest_message=latest_message, ) - + elif situation_type == "timeout": elapsed = session.waiting_config.get_elapsed_seconds() max_wait = session.waiting_config.max_wait_seconds expected = session.waiting_config.expected_reaction - + # 构建连续超时上下文 timeout_context_parts = [] - + # 添加真正追问次数警告(只有真正发了消息才算追问) followup_count = extra_context.get("followup_count", 0) if followup_count > 0: timeout_context_parts.append(f"⚠️ 你已经连续追问了 {followup_count} 次,对方仍未回复。再追问可能会显得太急躁,请三思。") - + # 添加距离用户上次回复的时间 time_since_user_reply_str = extra_context.get("time_since_user_reply_str") if time_since_user_reply_str: timeout_context_parts.append(f"距离 {user_name} 上一次回复你已经过去了 {time_since_user_reply_str}。") - + timeout_context = "\n".join(timeout_context_parts) if timeout_context: timeout_context = "\n" + timeout_context + "\n" - + return await global_prompt_manager.format_prompt( PROMPT_NAMES["situation_timeout"], current_time=current_time, @@ -707,7 +709,7 @@ class PromptBuilder: expected_reaction=expected or "对方能回复点什么", timeout_context=timeout_context, ) - + elif situation_type == "proactive": silence = extra_context.get("silence_duration", "一段时间") trigger_reason = extra_context.get("trigger_reason", "") @@ -718,18 +720,18 @@ class PromptBuilder: silence_duration=silence, trigger_reason=trigger_reason, ) - + # 默认使用 new_message return await global_prompt_manager.format_prompt( PROMPT_NAMES["situation_new_message"], current_time=current_time, user_name=user_name, ) - - def _build_actions_block(self, available_actions: Optional[dict]) -> str: + + def _build_actions_block(self, available_actions: dict | None) -> str: """ 构建可用动作块 - + 参考 AFC planner 的格式,为每个动作展示: - 动作名和描述 - 使用场景 @@ -737,45 +739,45 @@ class PromptBuilder: """ if not available_actions: return self._get_default_actions_block() - + action_blocks = [] for action_name, action_info in available_actions.items(): block = self._format_single_action(action_name, action_info) if block: action_blocks.append(block) - + return "\n".join(action_blocks) if action_blocks else self._get_default_actions_block() - + def _format_single_action(self, action_name: str, action_info) -> str: """ 格式化单个动作为详细说明块 - + Args: action_name: 动作名称 action_info: ActionInfo 对象 - + Returns: 格式化后的动作说明 """ # 获取动作描述 description = getattr(action_info, "description", "") or f"执行 {action_name}" - + # 获取使用场景 action_require = getattr(action_info, "action_require", []) or [] require_text = "\n".join(f" - {req}" for req in action_require) if action_require else " - 根据情况使用" - + # 获取参数定义 action_parameters = getattr(action_info, "action_parameters", {}) or {} - + # 构建 action_data JSON 示例 if action_parameters: param_lines = [] for param_name, param_desc in action_parameters.items(): param_lines.append(f' "{param_name}": "<{param_desc}>"') - action_data_json = "{\n" + ",\n".join(param_lines) + "\n }" + "{\n" + ",\n".join(param_lines) + "\n }" else: - action_data_json = "{}" - + pass + # 构建完整的动作块 return f"""### {action_name} **描述**: {description} @@ -787,22 +789,22 @@ class PromptBuilder: ```json {{ "type": "{action_name}", - {f'"content": "<你要说的内容>"' if action_name == "kfc_reply" else self._build_params_example(action_parameters)} + {'"content": "<你要说的内容>"' if action_name == "kfc_reply" else self._build_params_example(action_parameters)} }} ``` """ - + def _build_params_example(self, action_parameters: dict) -> str: """构建参数示例字符串""" if not action_parameters: return '"_comment": "此动作无需额外参数"' - + parts = [] for param_name, param_desc in action_parameters.items(): parts.append(f'"{param_name}": "<{param_desc}>"') - + return ",\n ".join(parts) - + def _get_default_actions_block(self) -> str: """获取默认的动作列表""" return """### kfc_reply @@ -832,7 +834,7 @@ class PromptBuilder: "type": "do_nothing" } ```""" - + async def _get_output_format(self) -> str: """获取输出格式模板""" try: @@ -849,7 +851,7 @@ class PromptBuilder: "expected_reaction": "期待的反应", "max_wait_seconds": 300 }""" - + async def _get_planner_output_format(self) -> str: """获取规划器输出格式模板""" try: @@ -868,7 +870,7 @@ class PromptBuilder: } 注意:kfc_reply 动作不需要填写 content 字段,回复内容会单独生成。""" - + async def _build_replyer_situation( self, session: KokoroSession, @@ -878,16 +880,16 @@ class PromptBuilder: ) -> str: """ 构建回复器专用的当前情况描述 - + 与 Planner 的 _build_current_situation 不同,这里不包含决策性语言, 只描述当前的情景背景 """ from datetime import datetime current_time = datetime.now().strftime("%Y年%m月%d日 %H:%M") - + if situation_type == "new_message": return f"现在是 {current_time}。{user_name} 刚给你发了消息。" - + elif situation_type == "reply_in_time": elapsed = session.waiting_config.get_elapsed_seconds() max_wait = session.waiting_config.max_wait_seconds @@ -897,7 +899,7 @@ class PromptBuilder: f"等了大约 {elapsed / 60:.1f} 分钟(你原本打算最多等 {max_wait / 60:.1f} 分钟)。" f"现在 {user_name} 回复了!" ) - + elif situation_type == "reply_late": elapsed = session.waiting_config.get_elapsed_seconds() max_wait = session.waiting_config.max_wait_seconds @@ -907,7 +909,7 @@ class PromptBuilder: f"你原本打算最多等 {max_wait / 60:.1f} 分钟,但实际等了 {elapsed / 60:.1f} 分钟才收到回复。" f"虽然有点迟,但 {user_name} 终于回复了。" ) - + elif situation_type == "timeout": elapsed = session.waiting_config.get_elapsed_seconds() max_wait = session.waiting_config.max_wait_seconds @@ -917,7 +919,7 @@ class PromptBuilder: f"你原本打算最多等 {max_wait / 60:.1f} 分钟,现在已经等了 {elapsed / 60:.1f} 分钟了,对方还是没回。" f"你决定主动说点什么。" ) - + elif situation_type == "proactive": silence = extra_context.get("silence_duration", "一段时间") return ( @@ -925,10 +927,10 @@ class PromptBuilder: f"你和 {user_name} 已经有一段时间没聊天了(沉默了 {silence})。" f"你决定主动找 {user_name} 聊点什么。" ) - + # 默认 return f"现在是 {current_time}。" - + async def _build_reply_context( self, session: KokoroSession, @@ -938,7 +940,7 @@ class PromptBuilder: ) -> str: """ 构建回复情景上下文 - + 根据 situation_type 构建不同的情景描述,帮助回复器理解当前要回复的情境 """ # 获取最后一条用户消息 @@ -948,14 +950,14 @@ class PromptBuilder: if entry.event_type == EventType.USER_MESSAGE: target_message = entry.content or "" break - + if situation_type == "new_message": return await global_prompt_manager.format_prompt( PROMPT_NAMES["replyer_context_normal"], user_name=user_name, target_message=target_message or "(无消息内容)", ) - + elif situation_type == "reply_in_time": elapsed = session.waiting_config.get_elapsed_seconds() max_wait = session.waiting_config.max_wait_seconds @@ -966,7 +968,7 @@ class PromptBuilder: elapsed_minutes=elapsed / 60, max_wait_minutes=max_wait / 60, ) - + elif situation_type == "reply_late": elapsed = session.waiting_config.get_elapsed_seconds() max_wait = session.waiting_config.max_wait_seconds @@ -977,7 +979,7 @@ class PromptBuilder: elapsed_minutes=elapsed / 60, max_wait_minutes=max_wait / 60, ) - + elif situation_type == "proactive": silence = extra_context.get("silence_duration", "一段时间") trigger_reason = extra_context.get("trigger_reason", "") @@ -987,30 +989,30 @@ class PromptBuilder: silence_duration=silence, trigger_reason=trigger_reason, ) - + # 默认使用普通情景 return await global_prompt_manager.format_prompt( PROMPT_NAMES["replyer_context_normal"], user_name=user_name, target_message=target_message or "(无消息内容)", ) - + async def build_unified_prompt( self, session: KokoroSession, user_name: str, situation_type: str = "new_message", chat_stream: Optional["ChatStream"] = None, - available_actions: Optional[dict] = None, - extra_context: Optional[dict] = None, + available_actions: dict | None = None, + extra_context: dict | None = None, ) -> str: """ 构建统一模式提示词(单次 LLM 调用完成思考 + 回复生成) - + 与 planner_prompt 的区别: - 使用完整的输出格式(要求填写 content 字段) - 不使用分离的 replyer 提示词 - + Args: session: 会话对象 user_name: 用户名称 @@ -1018,45 +1020,45 @@ class PromptBuilder: chat_stream: 聊天流对象 available_actions: 可用动作字典 extra_context: 额外上下文 - + Returns: 完整的统一模式提示词 """ extra_context = extra_context or {} - + # 获取 user_id user_id = session.user_id if session else None - + # 1. 构建人设块 persona_block = self._build_persona_block() - + # 1.5. 构建安全互动准则块 safety_guidelines_block = self._build_safety_guidelines_block() - + # 2. 使用 context_builder 获取关系、记忆、表达习惯等 context_data = await self._build_context_data(user_name, chat_stream, user_id) relation_block = context_data.get("relation_info", f"你与 {user_name} 还不太熟悉,这是早期的交流阶段。") memory_block = context_data.get("memory_block", "") tool_info = context_data.get("tool_info", "") expression_habits = self._build_combined_expression_block(context_data.get("expression_habits", "")) - + # 3. 构建活动流 activity_stream = await self._build_activity_stream(session, user_name) - + # 4. 构建当前情况 current_situation = await self._build_current_situation( session, user_name, situation_type, extra_context ) - + # 5. 构建聊天历史总览 chat_history_block = await self._build_chat_history_block(chat_stream) - + # 6. 构建可用动作(统一模式强调需要填写 content) actions_block = self._build_unified_actions_block(available_actions) - + # 7. 获取统一模式输出格式(要求填写 content) output_format = await self._get_unified_output_format() - + # 8. 使用统一的 prompt 管理系统格式化 prompt = await global_prompt_manager.format_prompt( PROMPT_NAMES["main"], @@ -1073,33 +1075,33 @@ class PromptBuilder: available_actions=actions_block, output_format=output_format, ) - + return prompt - - def _build_unified_actions_block(self, available_actions: Optional[dict]) -> str: + + def _build_unified_actions_block(self, available_actions: dict | None) -> str: """ 构建统一模式的可用动作块 - + 与 _build_actions_block 的区别: - 强调 kfc_reply 需要填写 content 字段 """ if not available_actions: return self._get_unified_default_actions_block() - + action_blocks = [] for action_name, action_info in available_actions.items(): block = self._format_unified_action(action_name, action_info) if block: action_blocks.append(block) - + return "\n".join(action_blocks) if action_blocks else self._get_unified_default_actions_block() - + def _format_unified_action(self, action_name: str, action_info) -> str: """格式化统一模式的单个动作""" description = getattr(action_info, "description", "") or f"执行 {action_name}" action_require = getattr(action_info, "action_require", []) or [] require_text = "\n".join(f" - {req}" for req in action_require) if action_require else " - 根据情况使用" - + # 统一模式要求 kfc_reply 必须填写 content if action_name == "kfc_reply": return f"""### {action_name} @@ -1119,7 +1121,7 @@ class PromptBuilder: else: action_parameters = getattr(action_info, "action_parameters", {}) or {} params_example = self._build_params_example(action_parameters) - + return f"""### {action_name} **描述**: {description} @@ -1134,7 +1136,7 @@ class PromptBuilder: }} ``` """ - + def _get_unified_default_actions_block(self) -> str: """获取统一模式的默认动作列表""" return """### kfc_reply @@ -1164,7 +1166,7 @@ class PromptBuilder: "type": "do_nothing" } ```""" - + async def _get_unified_output_format(self) -> str: """获取统一模式的输出格式模板""" try: @@ -1202,7 +1204,7 @@ class PromptBuilder: # 全局单例 -_prompt_builder: Optional[PromptBuilder] = None +_prompt_builder: PromptBuilder | None = None def get_prompt_builder() -> PromptBuilder: diff --git a/src/plugins/built_in/kokoro_flow_chatter/prompt_modules_unified.py b/src/plugins/built_in/kokoro_flow_chatter/prompt_modules_unified.py index d3ecdf509..1cf506097 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/prompt_modules_unified.py +++ b/src/plugins/built_in/kokoro_flow_chatter/prompt_modules_unified.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: from src.chat.message_receive.chat_stream import ChatStream from src.common.data_models.message_manager_data_model import StreamContext -from .models import MentalLogEntry, EventType +from .models import EventType, MentalLogEntry from .session import KokoroSession logger = get_logger("kfc_prompt_unified") @@ -41,31 +41,31 @@ logger = get_logger("kfc_prompt_unified") def build_identity_module() -> str: """ 构建核心身份模块 - + 包含:人格核心、人格细节、外貌身份、世界观背景 这些是回复内容的基础,决定了"你是谁" """ if global_config is None: raise RuntimeError("global_config 未初始化") - + personality_cfg = global_config.personality - + personality_core = personality_cfg.personality_core or "你是一个温暖、真诚的对话伙伴。" personality_side = personality_cfg.personality_side or "" identity = personality_cfg.identity or "" background_story = personality_cfg.background_story or "" - + parts = [personality_core] - + if personality_side: parts.append(f"\n{personality_side}") - + if identity: parts.append(f"\n{identity}") - + if background_story: parts.append(f"\n### 背景故事\n{background_story}") - + return "\n".join(parts) @@ -76,25 +76,25 @@ def build_identity_module() -> str: def build_rules_module() -> str: """ 构建行为准则模块 - + 包含:安全规则 + KFC 通用准则 这些是行为的边界,决定了"什么能做什么不能做" """ if global_config is None: raise RuntimeError("global_config 未初始化") - + safety_guidelines = global_config.personality.safety_guidelines or [] if isinstance(safety_guidelines, list): safety_guidelines_str = "\n".join(f"- {rule}" for rule in safety_guidelines) else: safety_guidelines_str = str(safety_guidelines) - + # KFC 通用准则 kfc_rules = """- 说话要自然,像平时发微信一样 - 有自己的边界感,不用无条件迎合 - 称呼对方用自然的昵称,别念复杂的用户名 - 不要模仿聊天记录里的系统格式(比如"[表情包:xxx]"这种是系统标记,要发送表情包请使用emoji动作)""" - + return f"""{safety_guidelines_str} {kfc_rules}""" @@ -107,66 +107,66 @@ def build_rules_module() -> str: def build_context_module( session: KokoroSession, chat_stream: Optional["ChatStream"] = None, - context_data: Optional[dict[str, str]] = None, + context_data: dict[str, str] | None = None, ) -> str: """ 构建情境上下文模块 - + 包含:当前时间、聊天场景、内在状态、关系信息、记忆 这些是回复的上下文,决定了"当前在什么情况下" - + Args: session: 当前会话 chat_stream: 聊天流(判断群聊/私聊) context_data: S4U 上下文数据 """ context_data = context_data or {} - + # 时间和场景 current_time = datetime.now().strftime("%Y年%m月%d日 %H:%M:%S") is_group_chat = bool(chat_stream and chat_stream.group_info) chat_scene = "你在群里聊天" if is_group_chat else "你在和对方私聊" - + # 日程(如果有) schedule_block = context_data.get("schedule", "") - + # 内在状态(从 context_data 获取,如果没有使用默认值) mood = context_data.get("mood", "平静") - + # 关系信息 relation_info = context_data.get("relation_info", "") - + # 记忆 memory_block = context_data.get("memory_block", "") - + # 工具调用结果 tool_info = context_data.get("tool_info", "") - + parts = [] - + # 时间和场景 parts.append(f"**时间**: {current_time}") parts.append(f"**场景**: {chat_scene}") - + # 日程块 if schedule_block: parts.append(f"{schedule_block}") - + # 内在状态 parts.append(f"\n你现在的心情:{mood}") - + # 关系信息 if relation_info: parts.append(f"\n## 4. 你和对方的关系\n{relation_info}") - + # 记忆 if memory_block: parts.append(f"\n{memory_block}") - + # 工具调用结果 if tool_info: parts.append(f"\n{tool_info}") - + return "\n".join(parts) @@ -174,19 +174,19 @@ def build_context_module( # 模块4: 动作能力 - 可用动作的描述 # ============================================================ -def build_actions_module(available_actions: Optional[dict[str, ActionInfo]] = None) -> str: +def build_actions_module(available_actions: dict[str, ActionInfo] | None = None) -> str: """ 构建动作能力模块 - + 包含:所有可用动作的描述、参数、示例 这部分与回复内容分离,只描述"能做什么" - + Args: available_actions: 可用动作字典 """ if not available_actions: return _get_default_actions_block() - + # 核心限制说明(放在最前面) action_blocks = [ """⚠️ **输出限制(必须遵守)**: @@ -195,37 +195,37 @@ def build_actions_module(available_actions: Optional[dict[str, ActionInfo]] = No 3. 系统会自动把你的回复拆分成多条消息发送,你不需要自己分段 """ ] - + for action_name, action_info in available_actions.items(): description = action_info.description or f"执行 {action_name}" - + # 构建动作块 action_block = f"### `{action_name}` - {description}" - + # 对 kfc_reply 特殊处理,再次强调限制 if action_name == "kfc_reply": action_block += "\n(只能有一个,内容写完整)" - + # 参数说明(如果有) if action_info.action_parameters: params_lines = [f" - `{name}`: {desc}" for name, desc in action_info.action_parameters.items()] action_block += f"\n参数:\n{chr(10).join(params_lines)}" - + # 使用场景(如果有) if action_info.action_require: require_lines = [f" - {req}" for req in action_info.action_require] action_block += f"\n使用场景:\n{chr(10).join(require_lines)}" - + # 示例 example_params = "" if action_info.action_parameters: param_examples = [f'"{name}": "..."' for name in action_info.action_parameters.keys()] example_params = ", " + ", ".join(param_examples) - + action_block += f'\n```json\n{{"type": "{action_name}"{example_params}}}\n```' - + action_blocks.append(action_block) - + return "\n\n".join(action_blocks) @@ -256,24 +256,24 @@ def _get_default_actions_block() -> str: # 模块5: 表达与输出格式 - 回复风格 + JSON格式 # ============================================================ -def build_output_module(context_data: Optional[dict[str, str]] = None) -> str: +def build_output_module(context_data: dict[str, str] | None = None) -> str: """ 构建输出格式模块 - + 包含:表达风格、表达习惯、JSON 输出格式要求 这部分定义了"怎么说"和"输出什么格式" - + Args: context_data: S4U 上下文数据(包含 expression_habits) """ if global_config is None: raise RuntimeError("global_config 未初始化") - + context_data = context_data or {} - + reply_style = global_config.personality.reply_style or "" expression_habits = context_data.get("expression_habits", "") - + # JSON 输出格式说明(更自然的思考引导) json_format = """### 输出格式(JSON) @@ -307,17 +307,17 @@ def build_output_module(context_data: Optional[dict[str, str]] = None) -> str: - 不用分析得很正式 不想做任何事就用 `{"type": "do_nothing"}`""" - + parts = [] - + if reply_style: parts.append(reply_style) - + if expression_habits: parts.append(expression_habits) - + parts.append(json_format) - + return "\n\n".join(parts) @@ -327,13 +327,13 @@ def build_output_module(context_data: Optional[dict[str, str]] = None) -> str: def build_system_prompt( session: KokoroSession, - available_actions: Optional[dict[str, ActionInfo]] = None, - context_data: Optional[dict[str, str]] = None, + available_actions: dict[str, ActionInfo] | None = None, + context_data: dict[str, str] | None = None, chat_stream: Optional["ChatStream"] = None, ) -> str: """ 组装完整的系统提示词 - + 按模块顺序组装: 1. 角色框架 2. 核心身份(人设) @@ -341,19 +341,19 @@ def build_system_prompt( 4. 情境上下文(时间/状态/关系) 5. 动作能力(可用动作) 6. 输出格式(表达风格 + JSON) - + Args: session: 当前会话 available_actions: 可用动作字典 context_data: S4U 上下文数据 chat_stream: 聊天流 - + Returns: 完整的系统提示词 """ # 角色框架(不提AI,直接进入角色) role_frame = "你是一个真实的人,有自己的生活、想法和情绪。像平时和朋友聊天一样自然地交流。" - + # 组装各模块 modules = [ role_frame, @@ -373,7 +373,7 @@ def build_system_prompt( "## 5. 怎么回复", build_output_module(context_data), ] - + return "\n".join(modules) @@ -452,27 +452,27 @@ def format_narrative_history( ) -> str: """ 将心理活动日志格式化为叙事历史 - + Args: mental_log: 心理活动日志列表 max_entries: 最大条目数 - + Returns: str: 格式化的叙事历史文本 """ if not mental_log: return "(这是对话的开始,还没有历史记录)" - + # 获取最近的日志条目 recent_entries = mental_log[-max_entries:] - + narrative_parts = [] for entry in recent_entries: timestamp_str = time.strftime( - "%Y-%m-%d %H:%M:%S", + "%Y-%m-%d %H:%M:%S", time.localtime(entry.timestamp) ) - + if entry.event_type == EventType.USER_MESSAGE: user_name = entry.user_name or "用户" narrative_parts.append( @@ -497,63 +497,63 @@ def format_narrative_history( narrative_parts.append( f"[{timestamp_str}] (等待中的想法:{entry.waiting_thought})" ) - + return "\n".join(narrative_parts) if narrative_parts else "(这是对话的开始,还没有历史记录)" def format_history_from_context( context: "StreamContext", - mental_log: Optional[list[MentalLogEntry]] = None, + mental_log: list[MentalLogEntry] | None = None, ) -> str: """ 从 StreamContext 的历史消息构建叙事历史 - + 这是实现"无缝融入"的关键: - 从同一个数据库读取历史消息(与AFC共享) - 遵循全局配置 [chat].max_context_size - 将消息串渲染成KFC的叙事体格式 - + Args: context: 聊天流上下文,包含共享的历史消息 mental_log: 可选的心理活动日志,用于补充内心独白 - + Returns: str: 格式化的叙事历史文本 """ # 从 StreamContext 获取历史消息,遵循全局上下文长度配置 max_context = 25 # 默认值 - if global_config and hasattr(global_config, 'chat') and global_config.chat: + if global_config and hasattr(global_config, "chat") and global_config.chat: max_context = getattr(global_config.chat, "max_context_size", 25) history_messages = context.get_messages(limit=max_context, include_unread=False) - + if not history_messages and not mental_log: return "(这是对话的开始,还没有历史记录)" - + # 获取Bot的用户ID用于判断消息来源 bot_user_id = None - if global_config and hasattr(global_config, 'bot') and global_config.bot: - bot_user_id = str(getattr(global_config.bot, 'qq_account', '')) - + if global_config and hasattr(global_config, "bot") and global_config.bot: + bot_user_id = str(getattr(global_config.bot, "qq_account", "")) + narrative_parts = [] - + # 首先,将数据库历史消息转换为叙事格式 for msg in history_messages: timestamp_str = time.strftime( "%Y-%m-%d %H:%M:%S", time.localtime(msg.time or time.time()) ) - + # 判断是用户消息还是Bot消息 msg_user_id = str(msg.user_info.user_id) if msg.user_info else "" is_bot_message = bot_user_id and msg_user_id == bot_user_id content = msg.processed_plain_text or msg.display_message or "" - + if is_bot_message: narrative_parts.append(f"[{timestamp_str}] 你回复:{content}") else: sender_name = msg.user_info.user_nickname if msg.user_info else "用户" narrative_parts.append(f"[{timestamp_str}] {sender_name}说:{content}") - + # 然后,补充 mental_log 中的内心独白(如果有) if mental_log: for entry in mental_log[-5:]: # 只取最近5条心理活动 @@ -561,10 +561,10 @@ def format_history_from_context( "%Y-%m-%d %H:%M:%S", time.localtime(entry.timestamp) ) - + if entry.event_type == EventType.BOT_PLANNING and entry.thought: narrative_parts.append(f"[{timestamp_str}] (你的内心:{entry.thought})") - + return "\n".join(narrative_parts) if narrative_parts else "(这是对话的开始,还没有历史记录)" @@ -572,45 +572,45 @@ def format_incoming_messages( message_content: str, sender_name: str, sender_id: str, - message_time: Optional[float] = None, - all_unread_messages: Optional[list] = None, + message_time: float | None = None, + all_unread_messages: list | None = None, ) -> str: """ 格式化收到的消息 - + 支持单条消息(兼容旧调用)和多条消息(打断合并场景) - + Args: message_content: 主消息内容 sender_name: 发送者名称 sender_id: 发送者ID message_time: 消息时间戳 all_unread_messages: 所有未读消息列表 - + Returns: str: 格式化的消息文本 """ if message_time is None: message_time = time.time() - + # 如果有多条消息,格式化为消息组 if all_unread_messages and len(all_unread_messages) > 1: lines = [f"**用户连续发送了 {len(all_unread_messages)} 条消息:**\n"] - + for i, msg in enumerate(all_unread_messages, 1): msg_time = msg.time or time.time() msg_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(msg_time)) msg_sender = msg.user_info.user_nickname if msg.user_info else sender_name msg_content = msg.processed_plain_text or msg.display_message or "" - + lines.append(f"[{i}] 来自:{msg_sender}") lines.append(f" 时间:{msg_time_str}") lines.append(f" 内容:{msg_content}") lines.append("") - + lines.append("**提示**:请综合理解这些消息的整体意图,不需要逐条回复。") return "\n".join(lines) - + # 单条消息(兼容旧格式) message_time_str = time.strftime( "%Y-%m-%d %H:%M:%S", diff --git a/src/plugins/built_in/kokoro_flow_chatter/replyer.py b/src/plugins/built_in/kokoro_flow_chatter/replyer.py index 06a63d80a..7d84eec0b 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/replyer.py +++ b/src/plugins/built_in/kokoro_flow_chatter/replyer.py @@ -27,11 +27,11 @@ async def generate_reply_text( thought: str, situation_type: str = "new_message", chat_stream: Optional["ChatStream"] = None, - extra_context: Optional[dict] = None, + extra_context: dict | None = None, ) -> tuple[bool, str]: """ 生成回复文本 - + Args: session: 会话对象 user_name: 用户名称 @@ -39,7 +39,7 @@ async def generate_reply_text( situation_type: 情况类型 chat_stream: 聊天流对象 extra_context: 额外上下文 - + Returns: (success, reply_text) 元组 - success: 是否成功生成 @@ -56,37 +56,37 @@ async def generate_reply_text( chat_stream=chat_stream, extra_context=extra_context, ) - + from src.config.config import global_config if global_config and global_config.debug.show_prompt: logger.info(f"[KFC Replyer] 生成的回复提示词:\n{prompt}") - + # 2. 获取 replyer 模型配置并调用 LLM models = llm_api.get_available_models() replyer_config = models.get("replyer") - + if not replyer_config: logger.error("[KFC Replyer] 未找到 replyer 模型配置") return False, "(回复生成失败:未找到模型配置)" - - success, raw_response, reasoning, model_name = await llm_api.generate_with_model( + + success, raw_response, _reasoning, _model_name = await llm_api.generate_with_model( prompt=prompt, model_config=replyer_config, request_type="kokoro_flow_chatter.reply", ) - + if not success: logger.error(f"[KFC Replyer] LLM 调用失败: {raw_response}") return False, "(回复生成失败)" - + # 3. 清理并返回回复文本 reply_text = _clean_reply_text(raw_response) - + # 使用 logger 输出美化日志(颜色通过 logger 系统配置) logger.info(f"💬 {reply_text}") - + return True, reply_text - + except Exception as e: logger.error(f"[KFC Replyer] 生成失败: {e}") import traceback @@ -97,28 +97,28 @@ async def generate_reply_text( def _clean_reply_text(raw_text: str) -> str: """ 清理回复文本 - + 移除可能的前后缀、引号、markdown 标记等 """ text = raw_text.strip() - + # 移除可能的 markdown 代码块标记 if text.startswith("```") and text.endswith("```"): lines = text.split("\n") if len(lines) >= 3: # 移除首尾的 ``` 行 text = "\n".join(lines[1:-1]).strip() - + # 移除首尾的引号(如果整个文本被引号包裹) if (text.startswith('"') and text.endswith('"')) or \ (text.startswith("'") and text.endswith("'")): text = text[1:-1].strip() - + # 移除可能的"你说:"、"回复:"等前缀 prefixes_to_remove = ["你说:", "你说:", "回复:", "回复:", "我说:", "我说:"] for prefix in prefixes_to_remove: if text.startswith(prefix): text = text[len(prefix):].strip() break - + return text diff --git a/src/plugins/built_in/kokoro_flow_chatter/session.py b/src/plugins/built_in/kokoro_flow_chatter/session.py index b69adf274..d6df79af6 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/session.py +++ b/src/plugins/built_in/kokoro_flow_chatter/session.py @@ -29,17 +29,17 @@ logger = get_logger("kfc_session") class KokoroSession: """ Kokoro Flow Chatter 会话 - + 为每个私聊用户维护一个独立的会话,包含: - 基本信息(user_id, stream_id) - 状态(只有 IDLE 和 WAITING) - 心理活动历史(mental_log) - 等待配置(waiting_config) """ - + # 心理活动日志最大保留条数 MAX_MENTAL_LOG_SIZE = 50 - + def __init__( self, user_id: str, @@ -47,62 +47,62 @@ class KokoroSession: ): self.user_id = user_id self.stream_id = stream_id - + # 状态(只有 IDLE 和 WAITING) self._status: SessionStatus = SessionStatus.IDLE - + # 心理活动历史 self.mental_log: list[MentalLogEntry] = [] - + # 等待配置 self.waiting_config: WaitingConfig = WaitingConfig() - + # 时间戳 self.created_at: float = time.time() self.last_activity_at: float = time.time() - + # 统计 self.total_interactions: int = 0 - + # 上次主动思考时间 - self.last_proactive_at: Optional[float] = None - + self.last_proactive_at: float | None = None + # 连续超时计数(用于避免过度打扰用户) self.consecutive_timeout_count: int = 0 - + # 用户最后发消息的时间(用于计算距离用户上次回复的时间) - self.last_user_message_at: Optional[float] = None - + self.last_user_message_at: float | None = None + @property def status(self) -> SessionStatus: return self._status - + @status.setter def status(self, value: SessionStatus) -> None: old_status = self._status self._status = value if old_status != value: logger.debug(f"Session {self.user_id} 状态变更: {old_status} → {value}") - + def add_entry(self, entry: MentalLogEntry) -> None: """添加心理活动日志条目""" self.mental_log.append(entry) self.last_activity_at = time.time() - + # 保持日志在合理大小 if len(self.mental_log) > self.MAX_MENTAL_LOG_SIZE: self.mental_log = self.mental_log[-self.MAX_MENTAL_LOG_SIZE:] - + def add_user_message( self, content: str, user_name: str, user_id: str, - timestamp: Optional[float] = None, + timestamp: float | None = None, ) -> MentalLogEntry: """添加用户消息事件""" msg_time = timestamp or time.time() - + entry = MentalLogEntry( event_type=EventType.USER_MESSAGE, timestamp=msg_time, @@ -110,16 +110,16 @@ class KokoroSession: user_name=user_name, user_id=user_id, ) - + # 收到用户消息,重置连续超时计数 self.consecutive_timeout_count = 0 self.last_user_message_at = msg_time - + # 如果之前在等待,记录收到回复的情况 if self.status == SessionStatus.WAITING and self.waiting_config.is_active(): elapsed = self.waiting_config.get_elapsed_seconds() max_wait = self.waiting_config.max_wait_seconds - + if elapsed <= max_wait: entry.metadata["reply_status"] = "in_time" entry.metadata["elapsed_seconds"] = elapsed @@ -128,17 +128,17 @@ class KokoroSession: entry.metadata["reply_status"] = "late" entry.metadata["elapsed_seconds"] = elapsed entry.metadata["max_wait_seconds"] = max_wait - + self.add_entry(entry) return entry - + def add_bot_planning( self, thought: str, actions: list[dict], expected_reaction: str = "", max_wait_seconds: int = 0, - timestamp: Optional[float] = None, + timestamp: float | None = None, ) -> MentalLogEntry: """添加 Bot 规划事件""" entry = MentalLogEntry( @@ -152,12 +152,12 @@ class KokoroSession: self.add_entry(entry) self.total_interactions += 1 return entry - + def add_waiting_update( self, waiting_thought: str, mood: str = "", - timestamp: Optional[float] = None, + timestamp: float | None = None, ) -> MentalLogEntry: """添加等待期间的心理变化""" entry = MentalLogEntry( @@ -169,7 +169,7 @@ class KokoroSession: ) self.add_entry(entry) return entry - + def start_waiting( self, expected_reaction: str, @@ -181,7 +181,7 @@ class KokoroSession: self.status = SessionStatus.IDLE self.waiting_config.reset() return - + self.status = SessionStatus.WAITING self.waiting_config = WaitingConfig( expected_reaction=expected_reaction, @@ -194,19 +194,19 @@ class KokoroSession: f"Session {self.user_id} 开始等待: " f"max_wait={max_wait_seconds}s, expected={expected_reaction[:30]}..." ) - + def end_waiting(self) -> None: """结束等待""" self.status = SessionStatus.IDLE self.waiting_config.reset() # 更新活动时间,防止 ProactiveThinker 并发处理 self.last_activity_at = time.time() - + def get_recent_entries(self, limit: int = 20) -> list[MentalLogEntry]: """获取最近的心理活动日志""" return self.mental_log[-limit:] if self.mental_log else [] - - def get_last_bot_message(self) -> Optional[str]: + + def get_last_bot_message(self) -> str | None: """获取最后一条 Bot 发送的消息""" for entry in reversed(self.mental_log): if entry.event_type == EventType.BOT_PLANNING: @@ -214,7 +214,7 @@ class KokoroSession: if action.get("type") in ("kfc_reply", "respond"): return action.get("content", "") return None - + def to_dict(self) -> dict: """转换为字典(用于持久化)""" return { @@ -230,7 +230,7 @@ class KokoroSession: "consecutive_timeout_count": self.consecutive_timeout_count, "last_user_message_at": self.last_user_message_at, } - + @classmethod def from_dict(cls, data: dict) -> "KokoroSession": """从字典创建会话""" @@ -238,49 +238,49 @@ class KokoroSession: user_id=data.get("user_id", ""), stream_id=data.get("stream_id", ""), ) - + # 状态 status_str = data.get("status", "idle") try: session._status = SessionStatus(status_str) except ValueError: session._status = SessionStatus.IDLE - + # 心理活动历史 mental_log_data = data.get("mental_log", []) session.mental_log = [MentalLogEntry.from_dict(e) for e in mental_log_data] - + # 等待配置 waiting_data = data.get("waiting_config", {}) session.waiting_config = WaitingConfig.from_dict(waiting_data) - + # 时间戳 session.created_at = data.get("created_at", time.time()) session.last_activity_at = data.get("last_activity_at", time.time()) session.total_interactions = data.get("total_interactions", 0) session.last_proactive_at = data.get("last_proactive_at") - + # 连续超时相关 session.consecutive_timeout_count = data.get("consecutive_timeout_count", 0) session.last_user_message_at = data.get("last_user_message_at") - + return session class SessionManager: """ 会话管理器 - + 负责会话的创建、获取、保存和清理 """ - + _instance: Optional["SessionManager"] = None - + def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance - + def __init__( self, data_dir: str = "data/kokoro_flow_chatter/sessions", @@ -288,31 +288,31 @@ class SessionManager: ): if hasattr(self, "_initialized") and self._initialized: return - + self._initialized = True self.data_dir = Path(data_dir) self.max_session_age_days = max_session_age_days - + # 内存缓存 self._sessions: dict[str, KokoroSession] = {} self._locks: dict[str, asyncio.Lock] = {} - + # 确保数据目录存在 self.data_dir.mkdir(parents=True, exist_ok=True) - + logger.info(f"SessionManager 初始化完成: {self.data_dir}") - + def _get_lock(self, user_id: str) -> asyncio.Lock: """获取用户级别的锁""" if user_id not in self._locks: self._locks[user_id] = asyncio.Lock() return self._locks[user_id] - + def _get_file_path(self, user_id: str) -> Path: """获取会话文件路径""" safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in user_id) return self.data_dir / f"{safe_id}.json" - + async def get_session(self, user_id: str, stream_id: str) -> KokoroSession: """获取或创建会话""" async with self._get_lock(user_id): @@ -321,28 +321,28 @@ class SessionManager: session = self._sessions[user_id] session.stream_id = stream_id # 更新 stream_id return session - + # 尝试从文件加载 session = await self._load_from_file(user_id) if session: session.stream_id = stream_id self._sessions[user_id] = session return session - + # 创建新会话 session = KokoroSession(user_id=user_id, stream_id=stream_id) self._sessions[user_id] = session logger.info(f"创建新会话: {user_id}") return session - - async def _load_from_file(self, user_id: str) -> Optional[KokoroSession]: + + async def _load_from_file(self, user_id: str) -> KokoroSession | None: """从文件加载会话""" file_path = self._get_file_path(user_id) if not file_path.exists(): return None - + try: - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: data = json.load(f) session = KokoroSession.from_dict(data) logger.debug(f"从文件加载会话: {user_id}") @@ -350,29 +350,29 @@ class SessionManager: except Exception as e: logger.error(f"加载会话失败 {user_id}: {e}") return None - + async def save_session(self, user_id: str) -> bool: """保存会话到文件""" async with self._get_lock(user_id): if user_id not in self._sessions: return False - + session = self._sessions[user_id] file_path = self._get_file_path(user_id) - + try: data = session.to_dict() temp_path = file_path.with_suffix(".json.tmp") - + with open(temp_path, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) - + os.replace(temp_path, file_path) return True except Exception as e: logger.error(f"保存会话失败 {user_id}: {e}") return False - + async def save_all(self) -> int: """保存所有会话""" count = 0 @@ -380,22 +380,22 @@ class SessionManager: if await self.save_session(user_id): count += 1 return count - + async def get_waiting_sessions(self) -> list[KokoroSession]: """获取所有处于等待状态的会话""" return [s for s in self._sessions.values() if s.status == SessionStatus.WAITING] - + async def get_all_sessions(self) -> list[KokoroSession]: """获取所有会话""" return list(self._sessions.values()) - - def get_session_sync(self, user_id: str) -> Optional[KokoroSession]: + + def get_session_sync(self, user_id: str) -> KokoroSession | None: """同步获取会话(仅从内存)""" return self._sessions.get(user_id) # 全局单例 -_session_manager: Optional[SessionManager] = None +_session_manager: SessionManager | None = None def get_session_manager() -> SessionManager: diff --git a/src/plugins/built_in/kokoro_flow_chatter/unified.py b/src/plugins/built_in/kokoro_flow_chatter/unified.py index 05e9ed2e3..5b34167df 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/unified.py +++ b/src/plugins/built_in/kokoro_flow_chatter/unified.py @@ -21,16 +21,14 @@ from src.config.config import global_config from src.plugin_system.apis import llm_api from src.utils.json_parser import extract_and_parse_json -from .models import LLMResponse, EventType -from .session import KokoroSession - # 统一模式专用的提示词模块 from . import prompt_modules_unified as prompt_modules +from .models import EventType, LLMResponse +from .session import KokoroSession if TYPE_CHECKING: from src.chat.message_receive.chat_stream import ChatStream from src.common.data_models.message_manager_data_model import StreamContext - from src.plugin_system.base.component_types import ActionInfo logger = get_logger("kfc_unified") @@ -38,27 +36,27 @@ logger = get_logger("kfc_unified") class UnifiedPromptGenerator: """ 统一模式提示词生成器 - + 为统一模式构建提示词: - generate_system_prompt: 构建系统提示词 - generate_responding_prompt: 回应消息场景 - generate_timeout_prompt: 超时决策场景 - generate_proactive_prompt: 主动思考场景 """ - + def __init__(self): pass - + async def generate_system_prompt( self, session: KokoroSession, - available_actions: Optional[dict] = None, - context_data: Optional[dict[str, str]] = None, + available_actions: dict | None = None, + context_data: dict[str, str] | None = None, chat_stream: Optional["ChatStream"] = None, ) -> str: """ 生成系统提示词 - + 使用 prompt_modules.build_system_prompt() 构建模块化的提示词 """ return prompt_modules.build_system_prompt( @@ -67,23 +65,23 @@ class UnifiedPromptGenerator: context_data=context_data, chat_stream=chat_stream, ) - + async def generate_responding_prompt( self, session: KokoroSession, message_content: str, sender_name: str, sender_id: str, - message_time: Optional[float] = None, - available_actions: Optional[dict] = None, + message_time: float | None = None, + available_actions: dict | None = None, context: Optional["StreamContext"] = None, - context_data: Optional[dict[str, str]] = None, + context_data: dict[str, str] | None = None, chat_stream: Optional["ChatStream"] = None, - all_unread_messages: Optional[list] = None, + all_unread_messages: list | None = None, ) -> tuple[str, str]: """ 生成回应消息场景的提示词 - + Returns: tuple[str, str]: (系统提示词, 用户提示词) """ @@ -94,7 +92,7 @@ class UnifiedPromptGenerator: context_data=context_data, chat_stream=chat_stream, ) - + # 构建叙事历史 if context: narrative_history = prompt_modules.format_history_from_context( @@ -102,7 +100,7 @@ class UnifiedPromptGenerator: ) else: narrative_history = prompt_modules.format_narrative_history(session.mental_log) - + # 格式化收到的消息 incoming_messages = prompt_modules.format_incoming_messages( message_content=message_content, @@ -111,25 +109,25 @@ class UnifiedPromptGenerator: message_time=message_time, all_unread_messages=all_unread_messages, ) - + # 使用用户提示词模板 user_prompt = prompt_modules.RESPONDING_USER_PROMPT_TEMPLATE.format( narrative_history=narrative_history, incoming_messages=incoming_messages, ) - + return system_prompt, user_prompt - + async def generate_timeout_prompt( self, session: KokoroSession, - available_actions: Optional[dict] = None, - context_data: Optional[dict[str, str]] = None, + available_actions: dict | None = None, + context_data: dict[str, str] | None = None, chat_stream: Optional["ChatStream"] = None, ) -> tuple[str, str]: """ 生成超时决策场景的提示词 - + Returns: tuple[str, str]: (系统提示词, 用户提示词) """ @@ -140,17 +138,17 @@ class UnifiedPromptGenerator: context_data=context_data, chat_stream=chat_stream, ) - + # 构建叙事历史 narrative_history = prompt_modules.format_narrative_history(session.mental_log) - + # 计算等待时间 wait_duration = session.waiting_config.get_elapsed_seconds() - + # 生成连续追问警告(使用 followup_count 作为追问计数,只有真正发消息才算) followup_count = session.waiting_config.followup_count max_followups = 3 # 最多追问3次 - + if followup_count >= max_followups: followup_warning = f"""⚠️ **重要提醒**: 你已经连续追问了 {followup_count} 次,对方都没有回复。 @@ -162,7 +160,7 @@ class UnifiedPromptGenerator: 如果对方持续没有回应,可能真的在忙或不方便,不需要急着追问。""" else: followup_warning = "" - + # 获取最后一条 Bot 消息 last_bot_message = "(没有记录)" for entry in reversed(session.mental_log): @@ -175,7 +173,7 @@ class UnifiedPromptGenerator: break if last_bot_message != "(没有记录)": break - + # 使用用户提示词模板 user_prompt = prompt_modules.TIMEOUT_DECISION_USER_PROMPT_TEMPLATE.format( narrative_history=narrative_history, @@ -185,20 +183,20 @@ class UnifiedPromptGenerator: followup_warning=followup_warning, last_bot_message=last_bot_message, ) - + return system_prompt, user_prompt - + async def generate_proactive_prompt( self, session: KokoroSession, trigger_context: str, - available_actions: Optional[dict] = None, - context_data: Optional[dict[str, str]] = None, + available_actions: dict | None = None, + context_data: dict[str, str] | None = None, chat_stream: Optional["ChatStream"] = None, ) -> tuple[str, str]: """ 生成主动思考场景的提示词 - + Returns: tuple[str, str]: (系统提示词, 用户提示词) """ @@ -209,35 +207,35 @@ class UnifiedPromptGenerator: context_data=context_data, chat_stream=chat_stream, ) - + # 构建叙事历史 narrative_history = prompt_modules.format_narrative_history( session.mental_log, max_entries=10 ) - + # 计算沉默时长 silence_seconds = time.time() - session.last_activity_at if silence_seconds < 3600: silence_duration = f"{silence_seconds / 60:.0f}分钟" else: silence_duration = f"{silence_seconds / 3600:.1f}小时" - + # 当前时间 current_time = datetime.now().strftime("%Y年%m月%d日 %H:%M") - + # 从 context_data 获取关系信息 relation_block = "" if context_data: relation_info = context_data.get("relation_info", "") if relation_info: relation_block = f"### 你与对方的关系\n{relation_info}" - + if not relation_block: # 回退:使用默认关系描述 relation_block = """### 你与对方的关系 - 你们还不太熟悉 - 正在慢慢了解中""" - + # 使用用户提示词模板 user_prompt = prompt_modules.PROACTIVE_THINKING_USER_PROMPT_TEMPLATE.format( narrative_history=narrative_history, @@ -246,9 +244,9 @@ class UnifiedPromptGenerator: relation_block=relation_block, trigger_context=trigger_context, ) - + return system_prompt, user_prompt - + def build_messages_for_llm( self, system_prompt: str, @@ -257,12 +255,12 @@ class UnifiedPromptGenerator: ) -> str: """ 构建 LLM 请求的完整提示词 - + 将 system + user 合并为单个提示词字符串 """ # 合并提示词 full_prompt = f"{system_prompt}\n\n---\n\n{user_prompt}" - + # DEBUG日志:打印完整的KFC提示词(只在 DEBUG 级别输出) logger.debug( f"Final KFC prompt constructed for stream {stream_id}:\n" @@ -270,12 +268,12 @@ class UnifiedPromptGenerator: f"{full_prompt}\n" f"--- PROMPT END ---" ) - + return full_prompt # 全局提示词生成器实例 -_prompt_generator: Optional[UnifiedPromptGenerator] = None +_prompt_generator: UnifiedPromptGenerator | None = None def get_unified_prompt_generator() -> UnifiedPromptGenerator: @@ -291,17 +289,17 @@ async def generate_unified_response( user_name: str, situation_type: str = "new_message", chat_stream: Optional["ChatStream"] = None, - available_actions: Optional[dict] = None, - extra_context: Optional[dict] = None, + available_actions: dict | None = None, + extra_context: dict | None = None, ) -> LLMResponse: """ 统一模式:单次 LLM 调用生成完整响应 - + 调用方式: - 使用 UnifiedPromptGenerator 生成 System + User 提示词 - 使用 replyer 模型调用 LLM - 解析 JSON 响应(thought + actions + max_wait_seconds) - + Args: session: 会话对象 user_name: 用户名称 @@ -309,17 +307,17 @@ async def generate_unified_response( chat_stream: 聊天流对象 available_actions: 可用动作字典 extra_context: 额外上下文 - + Returns: LLMResponse 对象,包含完整的思考和动作 """ try: prompt_generator = get_unified_prompt_generator() extra_context = extra_context or {} - + # 获取上下文数据(关系、记忆等) context_data = await _build_context_data(user_name, chat_stream, session.user_id) - + # 根据情况类型选择提示词生成方法 if situation_type == "timeout": system_prompt, user_prompt = await prompt_generator.generate_timeout_prompt( @@ -343,7 +341,7 @@ async def generate_unified_response( message_content, sender_name, sender_id, message_time, all_unread = _get_last_user_message( session, user_name, chat_stream ) - + system_prompt, user_prompt = await prompt_generator.generate_responding_prompt( session=session, message_content=message_content, @@ -356,14 +354,14 @@ async def generate_unified_response( chat_stream=chat_stream, all_unread_messages=all_unread, ) - + # 构建完整提示词 prompt = prompt_generator.build_messages_for_llm( system_prompt, user_prompt, stream_id=chat_stream.stream_id if chat_stream else "", ) - + # 显示提示词(调试模式 - 只有在配置中开启时才输出) if global_config and global_config.debug.show_prompt: logger.info( @@ -372,26 +370,26 @@ async def generate_unified_response( f"{prompt}\n" f"--- PROMPT END ---" ) - + # 获取 replyer 模型配置并调用 LLM models = llm_api.get_available_models() replyer_config = models.get("replyer") - + if not replyer_config: logger.error("[KFC Unified] 未找到 replyer 模型配置") return LLMResponse.create_error_response("未找到 replyer 模型配置") - + # 调用 LLM(使用合并后的提示词) - success, raw_response, reasoning, model_name = await llm_api.generate_with_model( + success, raw_response, _reasoning, _model_name = await llm_api.generate_with_model( prompt=prompt, model_config=replyer_config, request_type="kokoro_flow_chatter.unified", ) - + if not success: logger.error(f"[KFC Unified] LLM 调用失败: {raw_response}") return LLMResponse.create_error_response(raw_response) - + # 输出原始 JSON 响应(DEBUG 级别,用于调试) logger.debug( f"Raw JSON response from LLM for stream {chat_stream.stream_id if chat_stream else 'unknown'}:\n" @@ -399,10 +397,10 @@ async def generate_unified_response( f"{raw_response}\n" f"--- JSON END ---" ) - + # 解析响应 return _parse_unified_response(raw_response, chat_stream.stream_id if chat_stream else None) - + except Exception as e: logger.error(f"[KFC Unified] 生成失败: {e}") import traceback @@ -413,13 +411,13 @@ async def generate_unified_response( async def _build_context_data( user_name: str, chat_stream: Optional["ChatStream"], - user_id: Optional[str] = None, + user_id: str | None = None, ) -> dict[str, str]: """ 构建上下文数据(关系、记忆、工具、表达习惯等) """ logger.debug(f"[KFC Unified] 开始构建上下文数据: user={user_name}") - + if not chat_stream: logger.warning("[KFC Unified] 无 chat_stream,返回默认上下文") return { @@ -429,33 +427,33 @@ async def _build_context_data( "expression_habits": "", "schedule": "", } - + try: from .context_builder import KFCContextBuilder - + builder = KFCContextBuilder(chat_stream) - + # 获取最近的消息作为 target_message(用于记忆检索) target_message = "" if chat_stream.context: unread = chat_stream.context.get_unread_messages() if unread: target_message = unread[-1].processed_plain_text or unread[-1].display_message or "" - + context_data = await builder.build_all_context( sender_name=user_name, target_message=target_message, context=chat_stream.context, user_id=user_id, ) - + # 打印关键信息 memory_len = len(context_data.get("memory_block", "")) tool_len = len(context_data.get("tool_info", "")) logger.debug(f"[KFC Unified] 上下文构建完成: memory_block={memory_len}字符, tool_info={tool_len}字符") - + return context_data - + except Exception as e: logger.error(f"[KFC Unified] 构建上下文数据失败: {e}") import traceback @@ -473,10 +471,10 @@ def _get_last_user_message( session: KokoroSession, user_name: str, chat_stream: Optional["ChatStream"], -) -> tuple[str, str, str, float, Optional[list]]: +) -> tuple[str, str, str, float, list | None]: """ 获取最后一条用户消息 - + Returns: tuple: (消息内容, 发送者名称, 发送者ID, 消息时间, 所有未读消息列表) """ @@ -485,7 +483,7 @@ def _get_last_user_message( sender_id = session.user_id or "" message_time = time.time() all_unread = None - + # 从 chat_stream 获取未读消息 if chat_stream and chat_stream.context: unread = chat_stream.context.get_unread_messages() @@ -497,7 +495,7 @@ def _get_last_user_message( sender_name = last_msg.user_info.user_nickname or user_name sender_id = str(last_msg.user_info.user_id) message_time = last_msg.time or time.time() - + # 如果没有从 chat_stream 获取到,从 mental_log 获取 if not message_content: for entry in reversed(session.mental_log): @@ -506,14 +504,14 @@ def _get_last_user_message( sender_name = entry.user_name or user_name message_time = entry.timestamp break - + return message_content, sender_name, sender_id, message_time, all_unread def _parse_unified_response(raw_response: str, stream_id: str | None = None) -> LLMResponse: """ 解析统一模式的 LLM 响应 - + 响应格式: { "thought": "...", @@ -523,28 +521,28 @@ def _parse_unified_response(raw_response: str, stream_id: str | None = None) -> } """ data = extract_and_parse_json(raw_response, strict=False) - + if not data or not isinstance(data, dict): logger.warning(f"[KFC Unified] 无法解析 JSON: {raw_response[:200]}...") return LLMResponse.create_error_response("无法解析响应格式") - + # 兼容旧版的字段名 # expected_user_reaction -> expected_reaction if "expected_user_reaction" in data and "expected_reaction" not in data: data["expected_reaction"] = data["expected_user_reaction"] - + # 兼容旧版的 reply -> kfc_reply actions = data.get("actions", []) for action in actions: if isinstance(action, dict): if action.get("type") == "reply": action["type"] = "kfc_reply" - + response = LLMResponse.from_dict(data) - + # 美化日志输出:内心思考 + 回复内容 _log_pretty_response(response, stream_id) - + return response @@ -553,9 +551,9 @@ def _log_pretty_response(response: LLMResponse, stream_id: str | None = None) -> if not response.thought and not response.actions: logger.warning("[KFC] 响应为空") return - + stream_tag = f"({stream_id[:8]}) " if stream_id else "" - + # 收集回复内容和其他动作 replies = [] actions = [] @@ -566,22 +564,22 @@ def _log_pretty_response(response: LLMResponse, stream_id: str | None = None) -> replies.append(content) elif action.type not in ("do_nothing", "no_action"): actions.append(action.type) - + # 逐行输出,简洁明了 if response.thought: logger.info(f"[KFC] {stream_tag}💭 {response.thought}") - + for i, reply in enumerate(replies): if len(replies) > 1: logger.info(f"[KFC] 💬 [{i+1}] {reply}") else: logger.info(f"[KFC] 💬 {reply}") - + if actions: logger.info(f"[KFC] 🎯 {', '.join(actions)}") - + if response.max_wait_seconds > 0 or response.expected_reaction: meta = f"⏱ {response.max_wait_seconds}s" if response.max_wait_seconds > 0 else "" if response.expected_reaction: meta += f" 预期: {response.expected_reaction}" - logger.info(f"[KFC] {meta.strip()}") \ No newline at end of file + logger.info(f"[KFC] {meta.strip()}") diff --git a/src/plugins/built_in/maizone_refactored/services/content_service.py b/src/plugins/built_in/maizone_refactored/services/content_service.py index 7d4768c07..1c1c63e3e 100644 --- a/src/plugins/built_in/maizone_refactored/services/content_service.py +++ b/src/plugins/built_in/maizone_refactored/services/content_service.py @@ -12,10 +12,10 @@ import aiohttp import filetype from src.chat.message_receive.chat_stream import get_chat_manager +from src.common.data_models.database_data_model import DatabaseUserInfo from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest from src.plugin_system.apis import config_api, generator_api, llm_api -from src.common.data_models.database_data_model import DatabaseUserInfo # 导入旧的工具函数,我们稍后会考虑是否也需要重构它 from ..utils.history_utils import get_send_history diff --git a/src/plugins/built_in/maizone_refactored/services/image_service.py b/src/plugins/built_in/maizone_refactored/services/image_service.py index fa6dc634a..d83b246c6 100644 --- a/src/plugins/built_in/maizone_refactored/services/image_service.py +++ b/src/plugins/built_in/maizone_refactored/services/image_service.py @@ -6,15 +6,14 @@ import base64 import random from collections.abc import Callable -from pathlib import Path from io import BytesIO +from pathlib import Path + +import aiohttp from PIL import Image -import aiofiles -import aiohttp - from src.common.logger import get_logger -from src.plugin_system.apis import llm_api, config_api +from src.plugin_system.apis import config_api, llm_api logger = get_logger("MaiZone.ImageService") @@ -44,7 +43,7 @@ class ImageService: api_key = str(self.get_config("models.siliconflow_apikey", "")) image_dir = str(self.get_config("send.image_directory", "./data/plugins/maizone_refactored/images")) image_num_raw = self.get_config("send.ai_image_number", 1) - + # 安全地处理图片数量配置,并限制在API允许的范围内 try: image_num = int(image_num_raw) if image_num_raw not in [None, ""] else 1 @@ -79,7 +78,7 @@ class ImageService: async def _generate_image_prompt(self, story_content: str) -> str: """ 使用LLM生成图片提示词,基于说说内容。 - + :param story_content: 说说内容 :return: 生成的图片提示词,失败时返回空字符串 """ @@ -87,7 +86,7 @@ class ImageService: # 获取配置 identity = config_api.get_global_config("personality.identity", "年龄为19岁,是女孩子,身高为160cm,黑色短发") enable_ref = bool(self.get_config("models.image_ref", True)) - + # 构建提示词 prompt = f""" 请根据以下QQ空间说说内容配图,并构建生成配图的风格和prompt。 @@ -102,14 +101,14 @@ class ImageService: models = llm_api.get_available_models() prompt_model = self.get_config("models.text_model", "replyer") model_config = models.get(prompt_model) - + if not model_config: logger.error(f"找不到模型配置: {prompt_model}") return "" # 调用LLM生成提示词 logger.info("正在生成图片提示词...") - success, image_prompt, reasoning, model_name = await llm_api.generate_with_model( + success, image_prompt, _reasoning, _model_name = await llm_api.generate_with_model( prompt=prompt, model_config=model_config, request_type="story.generate", @@ -118,10 +117,10 @@ class ImageService: ) if success: - logger.info(f'成功生成图片提示词: {image_prompt}') + logger.info(f"成功生成图片提示词: {image_prompt}") return image_prompt else: - logger.error('生成图片提示词失败') + logger.error("生成图片提示词失败") return "" except Exception as e: @@ -174,10 +173,10 @@ class ImageService: async with session.post(url, json=data, headers=headers) as response: if response.status != 200: error_text = await response.text() - logger.error(f'生成图片出错,错误码[{response.status}]') - logger.error(f'错误响应: {error_text}') + logger.error(f"生成图片出错,错误码[{response.status}]") + logger.error(f"错误响应: {error_text}") return False - + json_data = await response.json() image_urls = [img["url"] for img in json_data["images"]] @@ -193,27 +192,27 @@ class ImageService: # 处理图片 try: image = Image.open(BytesIO(img_data)) - + # 保存图片为PNG格式(确保兼容性) filename = f"image_{i}.png" save_path = Path(image_dir) / filename - + # 转换为RGB模式如果必要(避免RGBA等模式的问题) - if image.mode in ('RGBA', 'LA', 'P'): - background = Image.new('RGB', image.size, (255, 255, 255)) - background.paste(image, mask=image.split()[-1] if image.mode == 'RGBA' else None) + if image.mode in ("RGBA", "LA", "P"): + background = Image.new("RGB", image.size, (255, 255, 255)) + background.paste(image, mask=image.split()[-1] if image.mode == "RGBA" else None) image = background - - image.save(save_path, format='PNG') + + image.save(save_path, format="PNG") logger.info(f"图片已保存至: {save_path}") success_count += 1 - + except Exception as e: - logger.error(f"处理图片失败: {str(e)}") + logger.error(f"处理图片失败: {e!s}") continue except Exception as e: - logger.error(f"下载第{i+1}张图片失败: {str(e)}") + logger.error(f"下载第{i+1}张图片失败: {e!s}") continue # 只要至少有一张图片成功就返回True @@ -226,28 +225,28 @@ class ImageService: def _encode_image_to_base64(self, img: Image.Image) -> str: """ 将PIL.Image对象编码为base64 data URL - + :param img: PIL图片对象 :return: base64 data URL字符串,失败时返回空字符串 """ try: # 强制转换为PNG格式,因为SiliconFlow API要求data:image/png buffer = BytesIO() - + # 转换为RGB模式如果必要 - if img.mode in ('RGBA', 'LA', 'P'): - background = Image.new('RGB', img.size, (255, 255, 255)) - background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None) + if img.mode in ("RGBA", "LA", "P"): + background = Image.new("RGB", img.size, (255, 255, 255)) + background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None) img = background - + # 保存为PNG格式 img.save(buffer, format="PNG") byte_data = buffer.getvalue() - + # Base64编码,使用固定的data:image/png encoded_string = base64.b64encode(byte_data).decode("utf-8") return f"data:image/png;base64,{encoded_string}" - + except Exception as e: logger.error(f"编码图片为base64失败: {e}") - return "" \ No newline at end of file + return "" diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index 4b72ceb49..bec0eb921 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -19,8 +19,7 @@ import json5 import orjson from src.common.logger import get_logger -from src.plugin_system.apis import config_api, person_api -from src.plugin_system.apis import cross_context_api +from src.plugin_system.apis import config_api, cross_context_api, person_api from .content_service import ContentService from .cookie_service import CookieService diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index ac994c224..b8d9029b6 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -14,8 +14,6 @@ from sqlalchemy import select from src.common.database.compatibility import get_db_session from src.common.database.core.models import MaiZoneScheduleStatus from src.common.logger import get_logger -from src.config.config import model_config as global_model_config -from src.plugin_system.apis import llm_api from src.schedule.schedule_manager import schedule_manager from .qzone_service import QZoneService diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py index fb1a7ff35..4fcc20ec8 100644 --- a/src/plugins/built_in/napcat_adapter/plugin.py +++ b/src/plugins/built_in/napcat_adapter/plugin.py @@ -12,21 +12,20 @@ from __future__ import annotations import asyncio import uuid -from typing import Any, ClassVar, Dict, List, Optional +from typing import Any, ClassVar import orjson -import websockets - from mofox_wire import CoreSink, MessageEnvelope, WebSocketAdapterOptions + from src.common.logger import get_logger from src.plugin_system import ConfigField, register_plugin -from src.plugin_system.base import BaseAdapter, BasePlugin from src.plugin_system.apis import config_api +from src.plugin_system.base import BaseAdapter, BasePlugin from .src.handlers import utils as handler_utils from .src.handlers.to_core.message_handler import MessageHandler -from .src.handlers.to_core.notice_handler import NoticeHandler from .src.handlers.to_core.meta_event_handler import MetaEventHandler +from .src.handlers.to_core.notice_handler import NoticeHandler from .src.handlers.to_napcat.send_handler import SendHandler logger = get_logger("napcat_adapter") @@ -43,7 +42,7 @@ class NapcatAdapter(BaseAdapter): run_in_subprocess = False - def __init__(self, core_sink: CoreSink, plugin: Optional[BasePlugin] = None, **kwargs): + def __init__(self, core_sink: CoreSink, plugin: BasePlugin | None = None, **kwargs): """初始化 Napcat 适配器""" # 从插件配置读取 WebSocket URL if plugin: @@ -78,7 +77,7 @@ class NapcatAdapter(BaseAdapter): self.send_handler = SendHandler(self) # 响应池:用于存储等待的 API 响应 - self._response_pool: Dict[str, asyncio.Future] = {} + self._response_pool: dict[str, asyncio.Future] = {} self._response_timeout = 30.0 # WebSocket 连接(用于发送 API 请求) @@ -88,10 +87,10 @@ class NapcatAdapter(BaseAdapter): # 注册 utils 内部使用的适配器实例,便于工具方法自动获取 WS handler_utils.register_adapter(self) - def _should_process_event(self, raw: Dict[str, Any]) -> bool: + def _should_process_event(self, raw: dict[str, Any]) -> bool: """ 检查事件是否应该被处理(黑白名单过滤) - + 此方法在 from_platform_message 顶层调用,对所有类型的事件(消息、通知、元事件)进行过滤。 Args: @@ -102,7 +101,7 @@ class NapcatAdapter(BaseAdapter): """ if not self.plugin: return True - + plugin_config = self.plugin.config if not plugin_config: return True # 如果没有配置,默认处理所有事件 @@ -138,7 +137,7 @@ class NapcatAdapter(BaseAdapter): # 获取消息类型(消息事件使用 message_type,通知事件根据 group_id 判断) message_type = raw.get("message_type") group_id = raw.get("group_id") - + # 如果是通知事件,根据是否有 group_id 判断是群通知还是私聊通知 if post_type == "notice": message_type = "group" if group_id else "private" @@ -178,7 +177,7 @@ class NapcatAdapter(BaseAdapter): async def on_adapter_loaded(self) -> None: """适配器加载时的初始化""" logger.info("Napcat 适配器正在启动...") - + # 设置处理器配置 if self.plugin: self.message_handler.set_plugin_config(self.plugin.config) @@ -194,6 +193,7 @@ class NapcatAdapter(BaseAdapter): async def _register_notice_events(self) -> None: """注册 notice 相关事件到 event manager""" from src.plugin_system.core.event_manager import event_manager + from .src.event_types import NapcatEvent # 定义所有 notice 事件类型 @@ -230,7 +230,7 @@ class NapcatAdapter(BaseAdapter): async def on_adapter_unloaded(self) -> None: """适配器卸载时的清理""" logger.info("Napcat 适配器正在关闭...") - + # 清理响应池 for future in self._response_pool.values(): if not future.done(): @@ -239,16 +239,16 @@ class NapcatAdapter(BaseAdapter): logger.info("Napcat 适配器已关闭") - async def from_platform_message(self, raw: Dict[str, Any]) -> MessageEnvelope | None: # type: ignore[override] + async def from_platform_message(self, raw: dict[str, Any]) -> MessageEnvelope | None: # type: ignore[override] """ 将 Napcat/OneBot 原始消息转换为 MessageEnvelope - + 这是核心转换方法,处理: - message 事件 → 消息 - notice 事件 → 通知(戳一戳、表情回复等) - meta_event 事件 → 元事件(心跳、生命周期) - API 响应 → 存入响应池 - + 注意:黑白名单等过滤机制在此方法最开始执行,确保所有类型的事件都能被过滤。 """ post_type = raw.get("post_type") @@ -270,7 +270,7 @@ class NapcatAdapter(BaseAdapter): # 消息事件 if post_type == "message": return await self.message_handler.handle_raw_message(raw) # type: ignore[return-value] - + # 通知事件 elif post_type == "notice": return await self.notice_handler.handle_notice(raw) # type: ignore[return-value] @@ -288,11 +288,11 @@ class NapcatAdapter(BaseAdapter): except Exception as e: logger.error(f"处理 Napcat 事件失败: {e}, 原始数据: {raw}") return None - + async def _send_platform_message(self, envelope: MessageEnvelope) -> None: # type: ignore[override] """ 将 MessageEnvelope 转换并发送到 Napcat - + 这里不直接通过 WebSocket 发送 envelope, 而是调用 Napcat API(send_group_msg, send_private_msg 等) """ @@ -301,15 +301,15 @@ class NapcatAdapter(BaseAdapter): except Exception as e: logger.error(f"发送 Napcat 消息失败: {e}") - async def send_napcat_api(self, action: str, params: Dict[str, Any], timeout: float = 30.0) -> Dict[str, Any]: + async def send_napcat_api(self, action: str, params: dict[str, Any], timeout: float = 30.0) -> dict[str, Any]: """ 发送 Napcat API 请求并等待响应 - + Args: action: API 动作名称(如 send_group_msg) params: API 参数 timeout: 超时时间(秒) - + Returns: API 响应数据 """ diff --git a/src/plugins/built_in/napcat_adapter/src/event_models.py b/src/plugins/built_in/napcat_adapter/src/event_models.py index 6917e9716..4a2994f1b 100644 --- a/src/plugins/built_in/napcat_adapter/src/event_models.py +++ b/src/plugins/built_in/napcat_adapter/src/event_models.py @@ -298,13 +298,13 @@ QQ_FACE = { __all__ = [ - "MetaEventType", - "MessageType", - "NoticeType", - "RealMessageType", - "MessageSentType", - "CommandType", "ACCEPT_FORMAT", "PLUGIN_NAME", "QQ_FACE", + "CommandType", + "MessageSentType", + "MessageType", + "MetaEventType", + "NoticeType", + "RealMessageType", ] diff --git a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py index 6cef2fe40..0c2fc807c 100644 --- a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py +++ b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py @@ -5,19 +5,15 @@ from __future__ import annotations import base64 import time from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -import uuid +from typing import TYPE_CHECKING, Any + +from mofox_wire import ( + MessageBuilder, + SegPayload, +) -from mofox_wire import MessageBuilder from src.common.logger import get_logger from src.plugin_system.apis import config_api -from mofox_wire import ( - MessageEnvelope, - SegPayload, - MessageInfoPayload, - UserInfoPayload, - GroupInfoPayload, -) from ...event_models import ACCEPT_FORMAT, QQ_FACE, RealMessageType from ..utils import * @@ -33,13 +29,13 @@ class MessageHandler: def __init__(self, adapter: "NapcatAdapter"): self.adapter = adapter - self.plugin_config: Optional[Dict[str, Any]] = None + self.plugin_config: dict[str, Any] | None = None - def set_plugin_config(self, config: Dict[str, Any]) -> None: + def set_plugin_config(self, config: dict[str, Any]) -> None: """设置插件配置""" self.plugin_config = config - async def handle_raw_message(self, raw: Dict[str, Any]): + async def handle_raw_message(self, raw: dict[str, Any]): """ 处理原始消息并转换为 MessageEnvelope @@ -48,7 +44,7 @@ class MessageHandler: Returns: MessageEnvelope (dict) or None - + Note: 黑白名单过滤已移动到 NapcatAdapter.from_platform_message 顶层执行, 确保所有类型的事件(消息、通知等)都能被统一过滤。 @@ -95,7 +91,7 @@ class MessageHandler: # 解析消息段 message_segments = raw.get("message", []) - seg_list: List[SegPayload] = [] + seg_list: list[SegPayload] = [] for segment in message_segments: seg_message = await self.handle_single_segment(segment, raw) @@ -158,7 +154,7 @@ class MessageHandler: return await self._handle_json_message(segment) case RealMessageType.file: return await self._handle_file_message(segment) - + case _: logger.warning(f"Unsupported segment type: {seg_type}") return None @@ -189,7 +185,7 @@ class MessageHandler: try: image_base64 = await get_image_base64(message_data.get("url", "")) except Exception as e: - logger.error(f"图片消息处理失败: {str(e)}") + logger.error(f"图片消息处理失败: {e!s}") return None if image_sub_type == 0: return {"type": "image", "data": image_base64} @@ -241,7 +237,7 @@ class MessageHandler: return {"type": "text", "data": "[无法获取被引用的消息]"} # 递归处理被引用的消息 - reply_segments: List[SegPayload] = [] + reply_segments: list[SegPayload] = [] for reply_seg in message_detail.get("message", []): if isinstance(reply_seg, dict): reply_result = await self.handle_single_segment(reply_seg, raw_message, in_reply=True) @@ -280,7 +276,7 @@ class MessageHandler: return None audio_base64 = record_detail.get("base64", "") except Exception as e: - logger.error(f"语音消息处理失败: {str(e)}") + logger.error(f"语音消息处理失败: {e!s}") return None if not audio_base64: @@ -344,7 +340,7 @@ class MessageHandler: return None except Exception as e: - logger.error(f"视频消息处理失败: {str(e)}") + logger.error(f"视频消息处理失败: {e!s}") return None async def _handle_rps_message(self, segment: dict) -> SegPayload: @@ -400,7 +396,7 @@ class MessageHandler: try: encoded_image = await get_image_base64(image_url) except Exception as e: - logger.error(f"图片处理失败: {str(e)}") + logger.error(f"图片处理失败: {e!s}") return {"type": "text", "data": "[图片]"} return {"type": "image", "data": encoded_image} if seg_data.get("type") == "emoji": @@ -408,7 +404,7 @@ class MessageHandler: try: encoded_image = await get_image_base64(image_url) except Exception as e: - logger.error(f"图片处理失败: {str(e)}") + logger.error(f"图片处理失败: {e!s}") return {"type": "text", "data": "[表情包]"} return {"type": "emoji", "data": encoded_image} logger.debug(f"不处理类型: {seg_data.get('type')}") @@ -421,7 +417,7 @@ class MessageHandler: logger.debug(f"不处理类型: {seg_data.get('type')}") return seg_data - async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[SegPayload | None, int]: + async def _handle_forward_message(self, message_list: list, layer: int) -> tuple[SegPayload | None, int]: # sourcery skip: low-code-quality """ 递归处理实际转发消息 @@ -432,7 +428,7 @@ class MessageHandler: seg_data: Seg: 处理后的消息段 image_count: int: 图片数量 """ - seg_list: List[SegPayload] = [] + seg_list: list[SegPayload] = [] image_count = 0 if message_list is None: return None, 0 @@ -441,7 +437,7 @@ class MessageHandler: user_nickname: str = sender_info.get("nickname", "QQ用户") user_nickname_str = f"【{user_nickname}】:" break_seg: SegPayload = {"type": "text", "data": "\n"} - message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message") + message_of_sub_message_list: list[dict[str, Any]] = sub_message.get("message") if not message_of_sub_message_list: logger.warning("转发消息内容为空") continue @@ -475,7 +471,7 @@ class MessageHandler: text_message = sub_message_data.get("text") seg_data: SegPayload = {"type": "text", "data": text_message} nickname_prefix = ("--" * layer) + user_nickname_str if layer > 0 else user_nickname_str - data_list: List[SegPayload] = [ + data_list: list[SegPayload] = [ {"type": "text", "data": nickname_prefix}, seg_data, break_seg, @@ -607,7 +603,7 @@ class MessageHandler: "data": f"这是一条小程序分享消息,可以根据来源,考虑使用对应解析工具\n{formatted_content}", } - + # 检查是否是音乐分享 (QQ音乐类型) if nested_data.get("view") == "music" and "com.tencent.music" in str(nested_data.get("app", "")): @@ -677,7 +673,7 @@ class MessageHandler: except Exception as e: logger.error(f"处理JSON消息时发生未知错误: {e}") return None - + def _is_file_upload_echo(self, nested_data: Any) -> bool: """检查一个JSON对象是否是机器人自己上传文件的回声消息""" if not isinstance(nested_data, dict): @@ -699,26 +695,26 @@ class MessageHandler: return False - def _extract_file_info_from_echo(self, nested_data: dict) -> Optional[dict]: + def _extract_file_info_from_echo(self, nested_data: dict) -> dict | None: """从文件上传的回声消息中提取文件信息""" try: meta = nested_data.get("meta", {}) detail_1 = meta.get("detail_1", {}) - + # 文件名在 'desc' 字段 file_name = detail_1.get("desc") - + # 文件大小在 'summary' 字段,格式为 "大小:1.7MB" summary = detail_1.get("summary", "") file_size_str = summary.replace("大小:", "").strip() # 移除前缀和空格 - + # QQ API有时返回的大小不标准,这里我们只提取它给的字符串 # 实际大小已经由Napcat在发送时记录,这里主要是为了保持格式一致 - + if file_name and file_size_str: return {"file": file_name, "file_size": file_size_str, "file_id": None} # file_id在回声中不可用 except Exception as e: logger.error(f"从文件回声中提取信息失败: {e}") - + return None - + diff --git a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/meta_event_handler.py b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/meta_event_handler.py index 49976a164..6be6eb0ad 100644 --- a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/meta_event_handler.py +++ b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/meta_event_handler.py @@ -1,9 +1,9 @@ """元事件处理器""" from __future__ import annotations -import time import asyncio -from typing import TYPE_CHECKING, Any, Dict, Optional +import time +from typing import TYPE_CHECKING, Any from src.common.logger import get_logger @@ -20,14 +20,14 @@ class MetaEventHandler: def __init__(self, adapter: "NapcatAdapter"): self.adapter = adapter - self.plugin_config: Optional[Dict[str, Any]] = None + self.plugin_config: dict[str, Any] | None = None self._interval_checking = False - def set_plugin_config(self, config: Dict[str, Any]) -> None: + def set_plugin_config(self, config: dict[str, Any]) -> None: """设置插件配置""" self.plugin_config = config - async def handle_meta_event(self, raw: Dict[str, Any]): + async def handle_meta_event(self, raw: dict[str, Any]): event_type = raw.get("meta_event_type") if event_type == MetaEventType.lifecycle: sub_type = raw.get("sub_type") diff --git a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/notice_handler.py b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/notice_handler.py index 31e8937fd..1b235a735 100644 --- a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/notice_handler.py @@ -3,14 +3,15 @@ from __future__ import annotations import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any from mofox_wire import MessageBuilder, SegPayload, UserInfoPayload + from src.common.logger import get_logger from src.plugin_system.apis import config_api -from ...event_models import ACCEPT_FORMAT, NoticeType, QQ_FACE, PLUGIN_NAME, RealMessageType -from ..utils import get_group_info, get_member_info, get_self_info, get_stranger_info, get_message_detail +from ...event_models import ACCEPT_FORMAT, PLUGIN_NAME, QQ_FACE, NoticeType, RealMessageType +from ..utils import get_group_info, get_member_info, get_message_detail, get_self_info, get_stranger_info if TYPE_CHECKING: from ....plugin import NapcatAdapter @@ -23,11 +24,11 @@ class NoticeHandler: def __init__(self, adapter: "NapcatAdapter"): self.adapter = adapter - self.plugin_config: Optional[Dict[str, Any]] = None + self.plugin_config: dict[str, Any] | None = None # 戳一戳防抖时间戳 self.last_poke_time: float = 0.0 - def set_plugin_config(self, config: Dict[str, Any]) -> None: + def set_plugin_config(self, config: dict[str, Any]) -> None: """设置插件配置""" self.plugin_config = config @@ -37,7 +38,7 @@ class NoticeHandler: return default return config_api.get_plugin_config(self.plugin_config, key, default) - async def handle_notice(self, raw: Dict[str, Any]): + async def handle_notice(self, raw: dict[str, Any]): """ 处理通知事件 @@ -57,8 +58,7 @@ class NoticeHandler: handled_segment: SegPayload | None = None user_info: UserInfoPayload | None = None - system_notice: bool = False - notice_config: Dict[str, Any] = { + notice_config: dict[str, Any] = { "is_notice": False, "is_public_notice": False, "target_id": target_id, @@ -93,6 +93,7 @@ class NoticeHandler: case NoticeType.Notify.input_status: from src.plugin_system.core.event_manager import event_manager + from ...event_types import NapcatEvent await event_manager.trigger_event( NapcatEvent.ON_RECEIVED.FRIEND_INPUT, @@ -128,7 +129,6 @@ class NoticeHandler: logger.info("处理群禁言") handled_segment, user_info = await self._handle_ban_notify(raw, group_id) if handled_segment and user_info: - system_notice = True user_id_in_ban = raw.get("user_id") if user_id_in_ban == 0: notice_config["notice_type"] = "group_whole_ban" @@ -140,7 +140,6 @@ class NoticeHandler: logger.info("处理解除群禁言") handled_segment, user_info = await self._handle_lift_ban_notify(raw, group_id) if handled_segment and user_info: - system_notice = True user_id_in_ban = raw.get("user_id") if user_id_in_ban == 0: notice_config["notice_type"] = "group_whole_lift_ban" @@ -217,10 +216,10 @@ class NoticeHandler: envelope = msg_builder.build() envelope["message_info"]["additional_config"] = notice_config return envelope - + async def _handle_poke_notify( - self, raw: Dict[str, Any], group_id: Any, user_id: Any - ) -> Tuple[SegPayload | None, UserInfoPayload | None]: + self, raw: dict[str, Any], group_id: Any, user_id: Any + ) -> tuple[SegPayload | None, UserInfoPayload | None]: """处理戳一戳通知""" self_info: dict | None = await get_self_info() @@ -295,7 +294,7 @@ class NoticeHandler: if len(raw_info) > 4: second_txt = raw_info[4].get("txt", "") except Exception as e: - logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本") + logger.warning(f"解析戳一戳消息失败: {e!s},将使用默认文本") user_info: UserInfoPayload = { "platform": "qq", @@ -311,8 +310,8 @@ class NoticeHandler: return seg_data, user_info async def _handle_group_emoji_like_notify( - self, raw: Dict[str, Any], group_id: Any, user_id: Any - ) -> Tuple[SegPayload | None, UserInfoPayload | None]: + self, raw: dict[str, Any], group_id: Any, user_id: Any + ) -> tuple[SegPayload | None, UserInfoPayload | None]: """处理群聊表情回复通知""" if not group_id: logger.error("群ID不能为空,无法处理群聊表情回复通知") @@ -329,6 +328,7 @@ class NoticeHandler: # 触发事件 from src.plugin_system.core.event_manager import event_manager + from ...event_types import NapcatEvent target_message = await get_message_detail(raw.get("message_id", "")) @@ -367,12 +367,12 @@ class NoticeHandler: } return seg_data, user_info - async def _extract_message_preview(self, message_detail: Dict[str, Any], depth: int = 0) -> str: + async def _extract_message_preview(self, message_detail: dict[str, Any], depth: int = 0) -> str: """提取被表情回应消息的可读摘要,支持多层嵌套""" if depth > 3: return "..." - preview_parts: List[str] = [] + preview_parts: list[str] = [] for seg in message_detail.get("message", []): seg_type = seg.get("type") seg_data = seg.get("data", {}) @@ -410,8 +410,8 @@ class NoticeHandler: return preview async def _handle_group_upload_notify( - self, raw: Dict[str, Any], group_id: Any, user_id: Any, self_id: Any - ) -> Tuple[SegPayload | None, UserInfoPayload | None]: + self, raw: dict[str, Any], group_id: Any, user_id: Any, self_id: Any + ) -> tuple[SegPayload | None, UserInfoPayload | None]: """处理群文件上传通知""" if not group_id: logger.error("群ID不能为空,无法处理群文件上传通知") @@ -448,8 +448,8 @@ class NoticeHandler: return seg_data, user_info async def _handle_ban_notify( - self, raw: Dict[str, Any], group_id: Any - ) -> Tuple[SegPayload | None, UserInfoPayload | None]: + self, raw: dict[str, Any], group_id: Any + ) -> tuple[SegPayload | None, UserInfoPayload | None]: """处理群禁言通知""" if not group_id: logger.error("群ID不能为空,无法处理禁言通知") @@ -476,7 +476,7 @@ class NoticeHandler: # 获取被禁言者信息 user_id = raw.get("user_id") - banned_user_info: Dict[str, Any] | None = None + banned_user_info: dict[str, Any] | None = None user_nickname: str = "QQ用户" user_cardname: str = "" sub_type: str = "" @@ -513,8 +513,8 @@ class NoticeHandler: return seg_data, operator_info async def _handle_lift_ban_notify( - self, raw: Dict[str, Any], group_id: Any - ) -> Tuple[SegPayload | None, UserInfoPayload | None]: + self, raw: dict[str, Any], group_id: Any + ) -> tuple[SegPayload | None, UserInfoPayload | None]: """处理解除群禁言通知""" if not group_id: logger.error("群ID不能为空,无法处理解除禁言通知") @@ -543,7 +543,7 @@ class NoticeHandler: sub_type: str = "" user_nickname: str = "QQ用户" user_cardname: str = "" - lifted_user_info: Dict[str, Any] | None = None + lifted_user_info: dict[str, Any] | None = None user_id = raw.get("user_id") if user_id == 0: # 全体禁言解除 diff --git a/src/plugins/built_in/napcat_adapter/src/handlers/to_napcat/send_handler.py b/src/plugins/built_in/napcat_adapter/src/handlers/to_napcat/send_handler.py index e2ffc068d..99b95d716 100644 --- a/src/plugins/built_in/napcat_adapter/src/handlers/to_napcat/send_handler.py +++ b/src/plugins/built_in/napcat_adapter/src/handlers/to_napcat/send_handler.py @@ -3,13 +3,13 @@ from __future__ import annotations import random -import time -import uuid -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any + +from mofox_wire import GroupInfoPayload, MessageEnvelope, MessageInfoPayload, SegPayload, UserInfoPayload -from mofox_wire import MessageEnvelope, SegPayload, GroupInfoPayload, UserInfoPayload, MessageInfoPayload from src.common.logger import get_logger from src.plugin_system.apis import config_api + from ...event_models import CommandType from ..utils import convert_image_to_gif, get_image_format @@ -24,9 +24,9 @@ class SendHandler: def __init__(self, adapter: "NapcatAdapter"): self.adapter = adapter - self.plugin_config: Optional[Dict[str, Any]] = None + self.plugin_config: dict[str, Any] | None = None - def set_plugin_config(self, config: Dict[str, Any]) -> None: + def set_plugin_config(self, config: dict[str, Any]) -> None: """设置插件配置""" self.plugin_config = config @@ -74,11 +74,11 @@ class SendHandler: else: seg_data = message_segment - group_info: Optional[GroupInfoPayload] = message_info.get("group_info") - user_info: Optional[UserInfoPayload] = message_info.get("user_info") - target_id: Optional[int] = None - action: Optional[str] = None - id_name: Optional[str] = None + group_info: GroupInfoPayload | None = message_info.get("group_info") + user_info: UserInfoPayload | None = message_info.get("user_info") + target_id: int | None = None + action: str | None = None + id_name: str | None = None processed_message: list = [] try: processed_message = await self.handle_seg_recursive(seg_data, user_info or {}) @@ -121,18 +121,18 @@ class SendHandler: if response.get("status") == "ok": logger.info("消息发送成功") else: - logger.warning(f"消息发送失败,napcat返回:{str(response)}") + logger.warning(f"消息发送失败,napcat返回:{response!s}") async def send_command(self, envelope: MessageEnvelope) -> None: """ 处理命令类 """ logger.debug("处理命令中") - message_info: Dict[str, Any] = envelope.get("message_info", {}) - group_info: Optional[Dict[str, Any]] = message_info.get("group_info") + message_info: dict[str, Any] = envelope.get("message_info", {}) + group_info: dict[str, Any] | None = message_info.get("group_info") segment: SegPayload = envelope.get("message_segment", {}) # type: ignore[assignment] - seg_data: Dict[str, Any] = segment.get("data", {}) if isinstance(segment, dict) else {} - command_name: Optional[str] = seg_data.get("name") + seg_data: dict[str, Any] = segment.get("data", {}) if isinstance(segment, dict) else {} + command_name: str | None = seg_data.get("name") try: args = seg_data.get("args", {}) if not isinstance(args, dict): @@ -174,7 +174,7 @@ class SendHandler: if response.get("status") == "ok": logger.info(f"命令 {command_name} 执行成功") else: - logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}") + logger.warning(f"命令 {command_name} 执行失败,napcat返回:{response!s}") async def handle_adapter_command(self, envelope: MessageEnvelope) -> None: """ @@ -182,7 +182,7 @@ class SendHandler: """ logger.info("处理适配器命令中") segment: SegPayload = envelope.get("message_segment", {}) # type: ignore[assignment] - seg_data: Dict[str, Any] = segment.get("data", {}) if isinstance(segment, dict) else {} + seg_data: dict[str, Any] = segment.get("data", {}) if isinstance(segment, dict) else {} try: action = seg_data.get("action") @@ -212,7 +212,7 @@ class SendHandler: if response.get("status") == "ok": logger.info(f"适配器命令 {action} 执行成功") else: - logger.warning(f"适配器命令 {action} 执行失败,napcat返回:{str(response)}") + logger.warning(f"适配器命令 {action} 执行失败,napcat返回:{response!s}") logger.debug(f"适配器命令 {action} 的完整响应: {response}") except Exception as e: @@ -411,11 +411,11 @@ class SendHandler: "data": {"file": f"file://{file_path}"}, } - def delete_msg_command(self, args: Dict[str, Any]) -> tuple[str, Dict[str, Any]]: + def delete_msg_command(self, args: dict[str, Any]) -> tuple[str, dict[str, Any]]: """处理删除消息命令""" return "delete_msg", {"message_id": args["message_id"]} - def handle_ban_command(self, args: Dict[str, Any], group_info: Optional[Dict[str, Any]]) -> tuple[str, Dict[str, Any]]: + def handle_ban_command(self, args: dict[str, Any], group_info: dict[str, Any] | None) -> tuple[str, dict[str, Any]]: """处理封禁命令""" duration: int = int(args["duration"]) user_id: int = int(args["qq_id"]) @@ -435,7 +435,7 @@ class SendHandler: }, ) - def handle_whole_ban_command(self, args: Dict[str, Any], group_info: Optional[Dict[str, Any]]) -> tuple[str, Dict[str, Any]]: + def handle_whole_ban_command(self, args: dict[str, Any], group_info: dict[str, Any] | None) -> tuple[str, dict[str, Any]]: """处理全体禁言命令""" enable = args["enable"] assert isinstance(enable, bool), "enable参数必须是布尔值" @@ -450,7 +450,7 @@ class SendHandler: }, ) - def handle_kick_command(self, args: Dict[str, Any], group_info: Optional[Dict[str, Any]]) -> tuple[str, Dict[str, Any]]: + def handle_kick_command(self, args: dict[str, Any], group_info: dict[str, Any] | None) -> tuple[str, dict[str, Any]]: """处理群成员踢出命令""" user_id: int = int(args["qq_id"]) group_id: int = int(group_info["group_id"]) if group_info and group_info.get("group_id") else 0 @@ -467,10 +467,10 @@ class SendHandler: }, ) - def handle_poke_command(self, args: Dict[str, Any], group_info: Optional[Dict[str, Any]]) -> tuple[str, Dict[str, Any]]: + def handle_poke_command(self, args: dict[str, Any], group_info: dict[str, Any] | None) -> tuple[str, dict[str, Any]]: """处理戳一戳命令""" user_id: int = int(args["qq_id"]) - group_id: Optional[int] = None + group_id: int | None = None if group_info and group_info.get("group_id"): group_id = int(group_info["group_id"]) if group_id <= 0: @@ -485,7 +485,7 @@ class SendHandler: }, ) - def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> tuple[str, Dict[str, Any]]: + def handle_set_emoji_like_command(self, args: dict[str, Any]) -> tuple[str, dict[str, Any]]: """处理设置表情回应命令""" logger.info(f"开始处理表情回应命令, 接收到参数: {args}") try: @@ -501,7 +501,7 @@ class SendHandler: {"message_id": message_id, "emoji_id": emoji_id, "set": set_like}, ) - def handle_send_like_command(self, args: Dict[str, Any]) -> tuple[str, Dict[str, Any]]: + def handle_send_like_command(self, args: dict[str, Any]) -> tuple[str, dict[str, Any]]: """处理发送点赞命令的逻辑。""" try: user_id: int = int(args["qq_id"]) @@ -514,7 +514,7 @@ class SendHandler: {"user_id": user_id, "times": times}, ) - def handle_at_message_command(self, args: Dict[str, Any], group_info: Optional[Dict[str, Any]]) -> tuple[str, Dict[str, Any]]: + def handle_at_message_command(self, args: dict[str, Any], group_info: dict[str, Any] | None) -> tuple[str, dict[str, Any]]: """处理艾特并发送消息命令""" at_user_id = args.get("qq_id") text = args.get("text") @@ -538,7 +538,7 @@ class SendHandler: }, ) - def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: Optional[Dict[str, Any]]) -> tuple[str, Dict[str, Any]]: + def handle_ai_voice_send_command(self, args: dict[str, Any], group_info: dict[str, Any] | None) -> tuple[str, dict[str, Any]]: """ 处理AI语音发送命令的逻辑。 并返回 NapCat 兼容的 (action, params) 元组。 diff --git a/src/plugins/built_in/napcat_adapter/src/handlers/utils.py b/src/plugins/built_in/napcat_adapter/src/handlers/utils.py index 6891f8414..a0354ac0a 100644 --- a/src/plugins/built_in/napcat_adapter/src/handlers/utils.py +++ b/src/plugins/built_in/napcat_adapter/src/handlers/utils.py @@ -3,10 +3,9 @@ import base64 import io import ssl import time -import uuid import weakref from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any import orjson import urllib3 @@ -22,7 +21,7 @@ logger = get_logger("napcat_adapter") # 简单的缓存实现,通过 JSON 文件实现磁盘一价存储 _CACHE_FILE = Path(__file__).resolve().parent / "napcat_cache.json" _CACHE_LOCK = asyncio.Lock() -_CACHE: Dict[str, Dict[str, Dict[str, Any]]] = { +_CACHE: dict[str, dict[str, dict[str, Any]]] = { "group_info": {}, "group_detail_info": {}, "member_info": {}, @@ -106,10 +105,10 @@ def _get_adapter(adapter: "NapcatAdapter | None" = None) -> "NapcatAdapter": async def _call_adapter_api( action: str, - params: Dict[str, Any], + params: dict[str, Any], adapter: "NapcatAdapter | None" = None, timeout: float = 30.0, -) -> Optional[Dict[str, Any]]: +) -> dict[str, Any] | None: """统一通过 adapter 发送和接收 API 调用""" try: target = _get_adapter(adapter) @@ -141,7 +140,7 @@ class SSLAdapter(urllib3.PoolManager): async def get_respose( action: str, - params: Dict[str, Any], + params: dict[str, Any], adapter: "NapcatAdapter | None" = None, timeout: float = 30.0, ): @@ -250,7 +249,7 @@ async def get_image_base64(url: str) -> str: image_bytes = response.data return base64.b64encode(image_bytes).decode("utf-8") except Exception as e: - logger.error(f"图片下载失败: {str(e)}") + logger.error(f"图片下载失败: {e!s}") raise @@ -272,7 +271,7 @@ def convert_image_to_gif(image_base64: str) -> str: output_buffer.seek(0) return base64.b64encode(output_buffer.read()).decode("utf-8") except Exception as e: - logger.error(f"图片转换为GIF失败: {str(e)}") + logger.error(f"图片转换为GIF失败: {e!s}") return image_base64 @@ -338,7 +337,7 @@ async def get_stranger_info( async def get_message_detail( - message_id: Union[str, int], + message_id: str | int, *, adapter: "NapcatAdapter | None" = None, ) -> dict | None: @@ -357,7 +356,7 @@ async def get_message_detail( async def get_record_detail( file: str, - file_id: Optional[str] = None, + file_id: str | None = None, *, adapter: "NapcatAdapter | None" = None, ) -> dict | None: @@ -397,14 +396,14 @@ async def get_forward_message( logger.error("获取转发消息超时") return None except Exception as e: - logger.error(f"获取转发消息失败: {str(e)}") + logger.error(f"获取转发消息失败: {e!s}") return None logger.debug( f"转发消息原始格式:{orjson.dumps(response).decode('utf-8')[:80]}..." if len(orjson.dumps(response).decode("utf-8")) > 80 else orjson.dumps(response).decode("utf-8") ) - response_data: Dict = response.get("data") + response_data: dict = response.get("data") if not response_data: logger.warning("转发消息内容为空或获取失败") return None diff --git a/src/plugins/built_in/napcat_adapter/src/handlers/video_handler.py b/src/plugins/built_in/napcat_adapter/src/handlers/video_handler.py index aa64d2571..2bef75b91 100644 --- a/src/plugins/built_in/napcat_adapter/src/handlers/video_handler.py +++ b/src/plugins/built_in/napcat_adapter/src/handlers/video_handler.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ 视频下载和处理模块 用于从QQ消息中下载视频并转发给Bot进行分析 @@ -7,7 +6,7 @@ import asyncio from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any import aiohttp @@ -52,7 +51,7 @@ class VideoDownloader: # 如果解析失败,默认允许尝试下载(稍后验证) return True - def check_file_size(self, content_length: Optional[str]) -> bool: + def check_file_size(self, content_length: str | None) -> bool: """检查文件大小是否在允许范围内""" if content_length is None: return True # 无法获取大小时允许下载 @@ -64,7 +63,7 @@ class VideoDownloader: except Exception: return True - async def download_video(self, url: str, filename: Optional[str] = None) -> Dict[str, Any]: + async def download_video(self, url: str, filename: str | None = None) -> dict[str, Any]: """ 下载视频文件 diff --git a/src/plugins/built_in/siliconflow_api_index_tts/plugin.py b/src/plugins/built_in/siliconflow_api_index_tts/plugin.py index 44104f94b..124e73221 100644 --- a/src/plugins/built_in/siliconflow_api_index_tts/plugin.py +++ b/src/plugins/built_in/siliconflow_api_index_tts/plugin.py @@ -3,24 +3,22 @@ SiliconFlow IndexTTS 语音合成插件 基于SiliconFlow API的IndexTTS语音合成插件,支持高质量的零样本语音克隆和情感控制 """ -import os +import asyncio import base64 import hashlib -import asyncio -import aiohttp -import json -import toml -from typing import Tuple, Optional, Dict, Any, List, Type from pathlib import Path -from src.plugin_system import BasePlugin, BaseAction, BaseCommand, register_plugin, ConfigField -from src.plugin_system.base.base_action import ActionActivationType, ChatMode +import aiohttp +import toml + from src.common.logger import get_logger +from src.plugin_system import BaseAction, BaseCommand, BasePlugin, ConfigField, register_plugin +from src.plugin_system.base.base_action import ActionActivationType, ChatMode logger = get_logger("SiliconFlow-TTS") -def get_global_siliconflow_api_key() -> Optional[str]: +def get_global_siliconflow_api_key() -> str | None: """从全局配置文件中获取SiliconFlow API密钥""" try: # 读取全局model_config.toml配置文件 @@ -28,10 +26,10 @@ def get_global_siliconflow_api_key() -> Optional[str]: if not config_path.exists(): logger.error("全局配置文件 config/model_config.toml 不存在") return None - - with open(config_path, "r", encoding="utf-8") as f: + + with open(config_path, encoding="utf-8") as f: model_config = toml.load(f) - + # 查找SiliconFlow API提供商配置 api_providers = model_config.get("api_providers", []) for provider in api_providers: @@ -40,10 +38,10 @@ def get_global_siliconflow_api_key() -> Optional[str]: if api_key: logger.info("成功从全局配置读取SiliconFlow API密钥") return api_key - + logger.warning("在全局配置中未找到SiliconFlow API提供商或API密钥为空") return None - + except Exception as e: logger.error(f"读取全局配置失败: {e}") return None @@ -51,22 +49,22 @@ def get_global_siliconflow_api_key() -> Optional[str]: class SiliconFlowTTSClient: """SiliconFlow TTS API客户端""" - - def __init__(self, api_key: str, base_url: str = "https://api.siliconflow.cn/v1/audio/speech", + + def __init__(self, api_key: str, base_url: str = "https://api.siliconflow.cn/v1/audio/speech", timeout: int = 60, max_retries: int = 3): self.api_key = api_key self.base_url = base_url self.timeout = timeout self.max_retries = max_retries - + async def synthesize_speech(self, text: str, voice_id: str, - model: str = "IndexTeam/IndexTTS-2", + model: str = "IndexTeam/IndexTTS-2", speed: float = 1.0, volume: float = 1.0, emotion_strength: float = 1.0, output_format: str = "wav") -> bytes: """ 调用SiliconFlow API进行语音合成 - + Args: text: 要合成的文本 voice_id: 预配置的语音ID @@ -75,7 +73,7 @@ class SiliconFlowTTSClient: volume: 音量 emotion_strength: 情感强度 output_format: 输出格式 - + Returns: 合成的音频数据 """ @@ -83,7 +81,7 @@ class SiliconFlowTTSClient: "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } - + # 构建请求数据 data = { "model": model, @@ -92,9 +90,9 @@ class SiliconFlowTTSClient: "format": output_format, "speed": speed } - + logger.info(f"使用配置的Voice ID: {voice_id}") - + # 发送请求 for attempt in range(self.max_retries): try: @@ -123,7 +121,7 @@ class SiliconFlowTTSClient: if attempt == self.max_retries - 1: raise e await asyncio.sleep(2 ** attempt) # 指数退避 - + raise Exception("所有重试都失败了") @@ -161,7 +159,7 @@ class SiliconFlowIndexTTSAction(BaseAction): # 关联类型 - 支持语音消息 associated_types = ["voice"] - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """执行SiliconFlow IndexTTS语音合成""" logger.info(f"{self.log_prefix} 执行SiliconFlow IndexTTS动作: {self.reasoning}") @@ -176,13 +174,13 @@ class SiliconFlowIndexTTSAction(BaseAction): # 获取文本内容 - 多种来源尝试 text = "" - + # 1. 尝试从action_data获取text参数 text = self.action_data.get("text", "") if not text: # 2. 尝试从action_data获取tts_text参数(兼容其他TTS插件) text = self.action_data.get("tts_text", "") - + if not text: # 3. 如果没有提供具体文本,则生成一个基于reasoning的语音回复 if self.reasoning: @@ -201,7 +199,7 @@ class SiliconFlowIndexTTSAction(BaseAction): # 如果完全没有内容,使用默认回复 text = "喵~使用SiliconFlow IndexTTS测试语音合成功能~" logger.info(f"{self.log_prefix} 使用默认语音内容") - + # 获取其他参数 speed = float(self.action_data.get("speed", self.get_config("synthesis.speed", 1.0))) @@ -232,18 +230,18 @@ class SiliconFlowIndexTTSAction(BaseAction): ) # 转换为base64编码(语音消息需要base64格式) - audio_base64 = base64.b64encode(audio_data).decode('utf-8') + audio_base64 = base64.b64encode(audio_data).decode("utf-8") # 发送语音消息(使用voice类型,支持WAV格式的base64) await self.send_custom( - message_type="voice", + message_type="voice", content=audio_base64 ) # 记录动作信息 await self.store_action_info( - action_build_into_prompt=True, - action_prompt_display=f"已使用SiliconFlow IndexTTS生成语音: {text[:20]}...", + action_build_into_prompt=True, + action_prompt_display=f"已使用SiliconFlow IndexTTS生成语音: {text[:20]}...", action_done=True ) @@ -252,7 +250,7 @@ class SiliconFlowIndexTTSAction(BaseAction): except Exception as e: logger.error(f"{self.log_prefix} 语音合成失败: {e}") - return False, f"语音合成失败: {str(e)}" + return False, f"语音合成失败: {e!s}" class SiliconFlowTTSCommand(BaseCommand): @@ -267,7 +265,7 @@ class SiliconFlowTTSCommand(BaseCommand): "speed": {"type": float, "required": False, "description": "语速 (0.1-3.0)"} } - async def execute(self, text: str, speed: float = 1.0) -> Tuple[bool, str]: + async def execute(self, text: str, speed: float = 1.0) -> tuple[bool, str]: """执行TTS命令""" logger.info(f"{self.log_prefix} 执行SiliconFlow TTS命令") @@ -289,7 +287,7 @@ class SiliconFlowTTSCommand(BaseCommand): plugin_dir = Path(__file__).parent audio_dir = plugin_dir / "audio_reference" reference_audio_path = audio_dir / "refer.mp3" - + if not reference_audio_path.exists(): logger.warning(f"参考音频文件不存在: {reference_audio_path}") reference_audio_path = None @@ -317,7 +315,7 @@ class SiliconFlowTTSCommand(BaseCommand): # 发送音频 await self.send_custom( - message_type="audio_file", + message_type="audio_file", content=audio_data, filename=filename ) @@ -326,7 +324,7 @@ class SiliconFlowTTSCommand(BaseCommand): return True, "命令执行成功" except Exception as e: - error_msg = f"❌ 语音合成失败: {str(e)}" + error_msg = f"❌ 语音合成失败: {e!s}" await self.send_reply(error_msg) logger.error(f"{self.log_prefix} 命令执行失败: {e}") return False, str(e) @@ -352,7 +350,7 @@ class SiliconFlowIndexTTSPlugin(BasePlugin): # 配置描述 config_section_descriptions = { "plugin": "插件基本配置", - "components": "组件启用配置", + "components": "组件启用配置", "api": "SiliconFlow API配置", "synthesis": "语音合成配置" } @@ -368,19 +366,19 @@ class SiliconFlowIndexTTSPlugin(BasePlugin): "enable_command": ConfigField(type=bool, default=True, description="是否启用Command组件"), }, "api": { - "api_key": ConfigField(type=str, default="", + "api_key": ConfigField(type=str, default="", description="SiliconFlow API密钥(可选,优先使用全局配置)"), - "base_url": ConfigField(type=str, default="https://api.siliconflow.cn/v1/audio/speech", + "base_url": ConfigField(type=str, default="https://api.siliconflow.cn/v1/audio/speech", description="SiliconFlow TTS API地址"), "timeout": ConfigField(type=int, default=60, description="API请求超时时间(秒)"), "max_retries": ConfigField(type=int, default=3, description="API请求最大重试次数"), }, "synthesis": { - "model": ConfigField(type=str, default="IndexTeam/IndexTTS-2", + "model": ConfigField(type=str, default="IndexTeam/IndexTTS-2", description="TTS模型名称"), - "speed": ConfigField(type=float, default=1.0, + "speed": ConfigField(type=float, default=1.0, description="默认语速 (0.1-3.0)"), - "output_format": ConfigField(type=str, default="wav", + "output_format": ConfigField(type=str, default="wav", description="输出音频格式"), } } @@ -388,9 +386,9 @@ class SiliconFlowIndexTTSPlugin(BasePlugin): def get_plugin_components(self): """获取插件组件""" from src.plugin_system.base.component_types import ActionInfo, CommandInfo, ComponentType - + components = [] - + # 检查配置是否启用组件 if self.get_config("components.enable_action", True): action_info = ActionInfo( @@ -416,7 +414,7 @@ class SiliconFlowIndexTTSPlugin(BasePlugin): async def on_plugin_load(self): """插件加载时的回调""" logger.info("SiliconFlow IndexTTS插件已加载") - + # 检查audio_reference目录 audio_dir = Path(self.plugin_path) / "audio_reference" if not audio_dir.exists(): @@ -446,4 +444,4 @@ class SiliconFlowIndexTTSPlugin(BasePlugin): async def on_plugin_unload(self): """插件卸载时的回调""" - logger.info("SiliconFlow IndexTTS插件已卸载") \ No newline at end of file + logger.info("SiliconFlow IndexTTS插件已卸载") diff --git a/src/plugins/built_in/siliconflow_api_index_tts/upload_voice.py b/src/plugins/built_in/siliconflow_api_index_tts/upload_voice.py index f18987d87..4249a6eaf 100644 --- a/src/plugins/built_in/siliconflow_api_index_tts/upload_voice.py +++ b/src/plugins/built_in/siliconflow_api_index_tts/upload_voice.py @@ -13,56 +13,55 @@ from pathlib import Path import aiohttp import toml - # 设置日志 logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) class VoiceUploader: """语音上传器""" - + def __init__(self, api_key: str): self.api_key = api_key self.upload_url = "https://api.siliconflow.cn/v1/uploads/audio/voice" - + async def upload_audio(self, audio_path: str) -> str: """ 上传音频文件并获取voice_id - + Args: audio_path: 音频文件路径 - + Returns: voice_id: 返回的语音ID """ audio_path = Path(audio_path) if not audio_path.exists(): raise FileNotFoundError(f"音频文件不存在: {audio_path}") - + # 读取音频文件并转换为base64 with open(audio_path, "rb") as f: audio_data = f.read() - - audio_base64 = base64.b64encode(audio_data).decode('utf-8') - + + audio_base64 = base64.b64encode(audio_data).decode("utf-8") + # 准备请求数据 request_data = { "file": audio_base64, "filename": audio_path } - + headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } - + logger.info(f"正在上传音频文件: {audio_path}") logger.info(f"文件大小: {len(audio_data)} bytes") - + async with aiohttp.ClientSession() as session: async with session.post( self.upload_url, @@ -88,7 +87,7 @@ class VoiceUploader: def load_config(config_path: Path) -> dict: """加载配置文件""" if config_path.exists(): - with open(config_path, 'r', encoding='utf-8') as f: + with open(config_path, encoding="utf-8") as f: return toml.load(f) return {} @@ -96,7 +95,7 @@ def load_config(config_path: Path) -> dict: def save_config(config_path: Path, config: dict): """保存配置文件""" config_path.parent.mkdir(parents=True, exist_ok=True) - with open(config_path, 'w', encoding='utf-8') as f: + with open(config_path, "w", encoding="utf-8") as f: toml.dump(config, f) @@ -106,23 +105,23 @@ async def main(): print("用法: python upload_voice.py <音频文件路径>") print("示例: python upload_voice.py refer.mp3") sys.exit(1) - + audio_file = sys.argv[1] - + # 获取插件目录 plugin_dir = Path(__file__).parent - + # 加载全局配置获取API key bot_dir = plugin_dir.parents[2] # 回到Bot目录 global_config_path = bot_dir / "config" / "model_config.toml" - + if not global_config_path.exists(): logger.error(f"全局配置文件不存在: {global_config_path}") logger.error("请确保Bot/config/model_config.toml文件存在并配置了SiliconFlow API密钥") sys.exit(1) - + global_config = load_config(global_config_path) - + # 从api_providers中查找SiliconFlow的API密钥 api_key = None api_providers = global_config.get("api_providers", []) @@ -130,40 +129,40 @@ async def main(): if provider.get("name") == "SiliconFlow": api_key = provider.get("api_key") break - + if not api_key: logger.error("在全局配置中未找到SiliconFlow API密钥") logger.error("请在Bot/config/model_config.toml中添加SiliconFlow的api_providers配置:") logger.error("[[api_providers]]") - logger.error("name = \"SiliconFlow\"") - logger.error("base_url = \"https://api.siliconflow.cn/v1\"") - logger.error("api_key = \"your_api_key_here\"") - logger.error("client_type = \"openai\"") + logger.error('name = "SiliconFlow"') + logger.error('base_url = "https://api.siliconflow.cn/v1"') + logger.error('api_key = "your_api_key_here"') + logger.error('client_type = "openai"') sys.exit(1) - + try: # 创建上传器并上传音频 uploader = VoiceUploader(api_key) voice_id = await uploader.upload_audio(audio_file) - + # 更新插件配置 plugin_config_path = plugin_dir / "config.toml" plugin_config = load_config(plugin_config_path) - + if "synthesis" not in plugin_config: plugin_config["synthesis"] = {} - + plugin_config["synthesis"]["voice_id"] = voice_id - + save_config(plugin_config_path, plugin_config) - + logger.info(f"配置已更新!voice_id已保存到: {plugin_config_path}") logger.info("现在可以使用SiliconFlow IndexTTS插件了!") - + except Exception as e: logger.error(f"上传失败: {e}") sys.exit(1) if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/src/plugins/built_in/system_management/plugin.py b/src/plugins/built_in/system_management/plugin.py index 3f14075d2..76706938e 100644 --- a/src/plugins/built_in/system_management/plugin.py +++ b/src/plugins/built_in/system_management/plugin.py @@ -31,7 +31,6 @@ from src.plugin_system.base.component_types import ( from src.plugin_system.base.config_types import ConfigField from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.utils.permission_decorators import require_permission -from src.plugin_system.apis.permission_api import permission_api logger = get_logger("SystemManagement") @@ -481,19 +480,19 @@ class SystemCommand(PlusCommand): f" • 加载状态: {details['status']}", ] - if details.get('description'): + if details.get("description"): response_parts.append(f" • 描述: {details['description']}") - if details.get('license'): + if details.get("license"): response_parts.append(f" • 许可证: {details['license']}") # 组件信息 - if details['components']: + if details["components"]: response_parts.append(f"\n🧩 **组件列表** (共 {len(details['components'])} 个):") - for comp in details['components']: - status = "✅" if comp['enabled'] else "❌" + for comp in details["components"]: + status = "✅" if comp["enabled"] else "❌" response_parts.append(f" {status} `{comp['name']}` ({comp['component_type']})") - if comp.get('description'): + if comp.get("description"): response_parts.append(f" {comp['description'][:50]}...") await self._send_long_message("\n".join(response_parts)) @@ -526,11 +525,11 @@ class SystemCommand(PlusCommand): response_parts = ["🧩 **组件类型概览**", ""] for t in ComponentType: comps = plugin_info_api.list_components(t, enabled_only=False) - enabled = sum(1 for c in comps if c['enabled']) + enabled = sum(1 for c in comps if c["enabled"]) if comps: response_parts.append(f"• **{t.value}**: {enabled}/{len(comps)} 启用") - response_parts.append(f"\n💡 使用 `/system plugin list <类型>` 查看详情") + response_parts.append("\n💡 使用 `/system plugin list <类型>` 查看详情") response_parts.append(f"可用类型: {', '.join([f'`{t}`' for t in available_types])}") await self.send_text("\n".join(response_parts)) return @@ -541,7 +540,7 @@ class SystemCommand(PlusCommand): response_parts = [title, ""] for comp in components: - status = "✅" if comp['enabled'] else "❌" + status = "✅" if comp["enabled"] else "❌" response_parts.append(f"{status} `{comp['name']}` (来自: `{comp['plugin_name']}`)") await self._send_long_message("\n".join(response_parts)) @@ -558,7 +557,7 @@ class SystemCommand(PlusCommand): response_parts = [f"🔍 **搜索结果** (关键词: `{keyword}`, 共 {len(results)} 个)", ""] for comp in results: - status = "✅" if comp['enabled'] else "❌" + status = "✅" if comp["enabled"] else "❌" response_parts.append( f"{status} `{comp['name']}` ({comp['component_type']})\n" f" 来自: `{comp['plugin_name']}`" @@ -644,7 +643,7 @@ class SystemCommand(PlusCommand): async def _show_system_report(self): """显示系统插件报告""" report = plugin_info_api.get_system_report() - + response_parts = [ "📊 **系统插件报告**", f" - 已加载插件: {report['system_info']['loaded_plugins_count']}", @@ -655,12 +654,12 @@ class SystemCommand(PlusCommand): response_parts.append("\n✅ **已加载插件:**") for name, info in report["plugins"].items(): response_parts.append(f" • **{info['display_name']} (`{name}`)** v{info['version']} by {info['author']}") - + if report["failed_plugins"]: response_parts.append("\n❌ **加载失败的插件:**") for name, error in report["failed_plugins"].items(): response_parts.append(f" • **`{name}`**: {error}") - + await self._send_long_message("\n".join(response_parts)) @@ -723,7 +722,7 @@ class SystemCommand(PlusCommand): if not found_components: await self.send_text(f"❌ 未找到名为 '{comp_name}' 的组件。") return - + if len(found_components) > 1: suggestions = "\n".join([f"- `{c['name']}` (类型: {c['component_type']})" for c in found_components]) await self.send_text(f"❌ 发现多个名为 '{comp_name}' 的组件,操作已取消。\n找到的组件:\n{suggestions}") @@ -749,7 +748,7 @@ class SystemCommand(PlusCommand): if len(context_args) >= 2: context_type = context_args[0].lower() context_id = context_args[1] - + target_stream = None if context_type == "group": target_stream = chat_api.get_stream_by_group_id( @@ -768,7 +767,7 @@ class SystemCommand(PlusCommand): if not target_stream: await self.send_text(f"❌ 在当前平台找不到指定的 {context_type}: `{context_id}`。") return - + stream_id = target_stream.stream_id # 4. 执行操作 @@ -783,7 +782,7 @@ class SystemCommand(PlusCommand): if success: await self.send_text(f"✅ 在会话 `{stream_id}` 中,已成功将组件 `{comp_name}` ({comp_type_str}) 设置为 {action_text} 状态。") else: - await self.send_text(f"❌ 操作失败。可能无法禁用最后一个启用的 Chatter,或组件不存在。请检查日志。") + await self.send_text("❌ 操作失败。可能无法禁用最后一个启用的 Chatter,或组件不存在。请检查日志。") # ================================================================= diff --git a/src/plugins/built_in/tts_plugin/plugin.py b/src/plugins/built_in/tts_plugin/plugin.py index a056964c9..ed9bd5df1 100644 --- a/src/plugins/built_in/tts_plugin/plugin.py +++ b/src/plugins/built_in/tts_plugin/plugin.py @@ -1,13 +1,13 @@ from typing import ClassVar from src.common.logger import get_logger +from src.config.config import global_config +from src.plugin_system.apis.generator_api import generate_reply from src.plugin_system.apis.plugin_register_api import register_plugin from src.plugin_system.base.base_action import ActionActivationType, BaseAction, ChatMode from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system.base.component_types import ComponentInfo from src.plugin_system.base.config_types import ConfigField -from src.plugin_system.apis.generator_api import generate_reply -from src.config.config import global_config logger = get_logger("tts") @@ -52,7 +52,7 @@ class TTSAction(BaseAction): logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}") - success, response_set, _ = await generate_reply( + _success, response_set, _ = await generate_reply( chat_stream=self.chat_stream, reply_message=self.chat_stream.context.get_last_message(), enable_tool=global_config.tool.enable_tool, @@ -78,7 +78,7 @@ class TTSAction(BaseAction): # 处理文本以优化TTS效果 processed_text = self._process_text_for_tts(reply_text) - + try: # 发送TTS消息 await self.send_custom(message_type="tts_text", content=processed_text) diff --git a/src/plugins/built_in/tts_voice_plugin/actions/tts_action.py b/src/plugins/built_in/tts_voice_plugin/actions/tts_action.py index 4d5b09910..8e10418a8 100644 --- a/src/plugins/built_in/tts_voice_plugin/actions/tts_action.py +++ b/src/plugins/built_in/tts_voice_plugin/actions/tts_action.py @@ -75,7 +75,7 @@ class TTSVoiceAction(BaseAction): super().__init__(*args, **kwargs) # 关键配置项现在由 TTSService 管理 self.tts_service = get_service("tts") - + # 动态更新 voice_style 参数描述(包含可用风格) self._update_voice_style_parameter() @@ -101,40 +101,40 @@ class TTSVoiceAction(BaseAction): """安全地获取可用语音风格列表""" try: # 首先尝试从TTS服务获取 - if hasattr(self.tts_service, 'get_available_styles'): + if hasattr(self.tts_service, "get_available_styles"): styles = self.tts_service.get_available_styles() if styles: return styles - + # 回退到直接读取配置文件 plugin_file = Path(__file__).resolve() bot_root = plugin_file.parent.parent.parent.parent.parent.parent config_file = bot_root / "config" / "plugins" / "tts_voice_plugin" / "config.toml" - + if config_file.exists(): - with open(config_file, 'r', encoding='utf-8') as f: + with open(config_file, encoding="utf-8") as f: config = toml.load(f) - styles_config = config.get('tts_styles', []) - + styles_config = config.get("tts_styles", []) + if isinstance(styles_config, list): style_names = [] for style in styles_config: if isinstance(style, dict): - name = style.get('style_name') + name = style.get("style_name") if isinstance(name, str) and name: style_names.append(name) - return style_names if style_names else ['default'] + return style_names if style_names else ["default"] except Exception as e: logger.debug(f"{self.log_prefix} 获取可用语音风格时出错: {e}") - - return ['default'] # 安全回退 + + return ["default"] # 安全回退 @classmethod def get_action_info(cls) -> "ActionInfo": """重写获取Action信息的方法,动态更新参数描述""" # 先调用父类方法获取基础信息 info = super().get_action_info() - + # 尝试动态更新 voice_style 参数描述 try: # 尝试获取可用风格(不创建完整实例) @@ -149,10 +149,10 @@ class TTSVoiceAction(BaseAction): info.action_parameters["voice_style"]["description"] = updated_description except Exception as e: logger.debug(f"[TTSVoiceAction] 在获取Action信息时更新参数描述失败: {e}") - + return info - @classmethod + @classmethod def _get_available_styles_for_info(cls) -> list[str]: """为 get_action_info 方法获取可用风格(类方法版本)""" try: @@ -160,24 +160,24 @@ class TTSVoiceAction(BaseAction): plugin_file = Path(__file__).resolve() bot_root = plugin_file.parent.parent.parent.parent.parent.parent config_file = bot_root / "config" / "plugins" / "tts_voice_plugin" / "config.toml" - + if config_file.exists(): - with open(config_file, 'r', encoding='utf-8') as f: + with open(config_file, encoding="utf-8") as f: config = toml.load(f) - styles_config = config.get('tts_styles', []) - + styles_config = config.get("tts_styles", []) + if isinstance(styles_config, list): style_names = [] for style in styles_config: if isinstance(style, dict): - name = style.get('style_name') + name = style.get("style_name") if isinstance(name, str) and name: style_names.append(name) - return style_names if style_names else ['default'] + return style_names if style_names else ["default"] except Exception: pass - - return ['default'] # 安全回退 + + return ["default"] # 安全回退 async def go_activate(self, llm_judge_model=None) -> bool: """ diff --git a/src/plugins/built_in/web_search_tool/tools/url_parser.py b/src/plugins/built_in/web_search_tool/tools/url_parser.py index 572b54fe5..c71970a7d 100644 --- a/src/plugins/built_in/web_search_tool/tools/url_parser.py +++ b/src/plugins/built_in/web_search_tool/tools/url_parser.py @@ -106,7 +106,7 @@ class URLParserTool(BaseTool): logger.error("未配置LLM模型") return {"error": "未配置LLM模型"} - success, summary, reasoning, model_name = await llm_api.generate_with_model( + success, summary, _reasoning, _model_name = await llm_api.generate_with_model( prompt=summary_prompt, model_config=model_config, request_type="story.generate", diff --git a/ui_log_adapter.py b/ui_log_adapter.py index bc52d4826..7a9b35d17 100644 --- a/ui_log_adapter.py +++ b/ui_log_adapter.py @@ -99,7 +99,6 @@ class UILogHandler(logging.Handler): if record.levelname == "DEBUG": return - emoji_map = {"info": "", "warning": "", "error": "", "debug": ""} formatted_msg = msg self._send_log_with_retry(formatted_msg, ui_level)