diff --git a/bot.py b/bot.py
index 80e6394d1..f58697bec 100644
--- a/bot.py
+++ b/bot.py
@@ -588,7 +588,7 @@ class MaiBotMain:
 
     async def run_async_init(self, main_system):
         """执行异步初始化步骤"""
-        
+
         # 初始化数据库表结构
         await self.initialize_database_async()
 
diff --git a/scripts/generate_missing_embeddings.py b/scripts/generate_missing_embeddings.py
index a8957e50b..afa3f59a6 100644
--- a/scripts/generate_missing_embeddings.py
+++ b/scripts/generate_missing_embeddings.py
@@ -19,14 +19,13 @@
 import asyncio
 import sys
 from pathlib import Path
-from typing import List
 
 # 添加项目根目录到路径
 sys.path.insert(0, str(Path(__file__).parent.parent))
 
 
 async def generate_missing_embeddings(
-    target_node_types: List[str] = None,
+    target_node_types: list[str] | None = None,
     batch_size: int = 50,
 ):
     """
@@ -46,13 +45,13 @@ async def generate_missing_embeddings(
         target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
 
     print(f"\n{'='*80}")
-    print(f"🔧 为节点生成嵌入向量")
+    print("🔧 为节点生成嵌入向量")
     print(f"{'='*80}\n")
     print(f"目标节点类型: {', '.join(target_node_types)}")
     print(f"批处理大小: {batch_size}\n")
 
     # 1. 初始化记忆管理器
-    print(f"🔧 正在初始化记忆管理器...")
+    print("🔧 正在初始化记忆管理器...")
     await initialize_memory_manager()
 
     manager = get_memory_manager()
@@ -60,10 +59,10 @@ async def generate_missing_embeddings(
         print("❌ 记忆管理器初始化失败")
         return
 
-    print(f"✅ 记忆管理器已初始化\n")
+    print("✅ 记忆管理器已初始化\n")
 
     # 2. 获取已索引的节点ID
-    print(f"🔍 检查现有向量索引...")
+    print("🔍 检查现有向量索引...")
     existing_node_ids = set()
     try:
         vector_count = manager.vector_store.collection.count()
@@ -78,14 +77,14 @@ async def generate_missing_embeddings(
             )
             if result and "ids" in result:
                 existing_node_ids.update(result["ids"])
-        
+
         print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
     except Exception as e:
         logger.warning(f"获取已索引节点ID失败: {e}")
-        print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
+        print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
 
     # 3. 收集需要生成嵌入的节点
-    print(f"🔍 扫描需要生成嵌入的节点...")
+    print("🔍 扫描需要生成嵌入的节点...")
     all_memories = manager.graph_store.get_all_memories()
 
     nodes_to_process = []
@@ -110,7 +109,7 @@
                 })
                 type_stats[node.node_type.value]["need_emb"] += 1
 
-    print(f"\n📊 扫描结果:")
+    print("\n📊 扫描结果:")
     for node_type in target_node_types:
         stats = type_stats[node_type]
         already_ok = stats["already_indexed"]
@@ -121,11 +120,11 @@
     print(f"\n  总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
 
     if len(nodes_to_process) == 0:
-        print(f"✅ 所有节点已有嵌入向量,无需生成")
+        print("✅ 所有节点已有嵌入向量,无需生成")
         return
 
     # 3. 批量生成嵌入
-    print(f"🚀 开始生成嵌入向量...\n")
+    print("🚀 开始生成嵌入向量...\n")
 
     total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
     success_count = 0
@@ -193,22 +192,22 @@
             print(f"  📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
 
     # 4. 保存图数据(更新节点的 embedding 字段)
-    print(f"💾 保存图数据...")
+    print("💾 保存图数据...")
     try:
         await manager.persistence.save_graph_store(manager.graph_store)
-        print(f"✅ 图数据已保存\n")
+        print("✅ 图数据已保存\n")
     except Exception as e:
-        logger.error(f"保存图数据失败", exc_info=True)
+        logger.error("保存图数据失败", exc_info=True)
         print(f"❌ 保存失败: {e}\n")
 
     # 5. 验证结果
-    print(f"🔍 验证向量索引...")
+    print("🔍 验证向量索引...")
     final_vector_count = manager.vector_store.collection.count()
     stats = manager.graph_store.get_statistics()
     total_nodes = stats["total_nodes"]
 
     print(f"\n{'='*80}")
-    print(f"📊 生成完成")
+    print("📊 生成完成")
     print(f"{'='*80}")
     print(f"处理节点数: {len(nodes_to_process)}")
     print(f"成功生成: {success_count}")
@@ -219,7 +218,7 @@
     print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
 
     # 6. 测试搜索
-    print(f"🧪 测试搜索功能...")
+    print("🧪 测试搜索功能...")
     test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
 
     for query in test_queries:
 
diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py
index db0fdbd73..8b5d6f3df 100644
--- a/scripts/lpmm_learning_tool.py
+++ b/scripts/lpmm_learning_tool.py
@@ -38,7 +38,7 @@ OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
 TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
 
 # ========== 性能配置参数 ==========
-# 
+#
 # 知识提取(步骤2:txt转json)并发控制
 # - 控制同时进行的LLM提取请求数量
 # - 推荐值: 3-10,取决于API速率限制
@@ -184,7 +184,7 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
         tuple: (doc_item或None, failed_hash或None)
     """
     temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
-    
+
     # 🔧 优化:使用异步文件检查,避免阻塞
     if os.path.exists(temp_file_path):
         try:
@@ -215,11 +215,11 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
             "extracted_entities": extracted_data.get("entities", []),
             "extracted_triples": extracted_data.get("triples", []),
         }
-        
+
         # 保存到缓存(异步写入)
         async with aiofiles.open(temp_file_path, "wb") as f:
             await f.write(orjson.dumps(doc_item))
-        
+
         return doc_item, None
     except Exception as e:
         logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
@@ -249,13 +249,13 @@ async def extract_information(paragraphs_dict, model_set):
     os.makedirs(TEMP_DIR, exist_ok=True)
 
     failed_hashes, open_ie_docs = [], []
-    
+
     # 🔧 关键修复:创建单个 LLM 请求实例,复用连接
     llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction")
-    
+
     # 🔧 并发控制:限制最大并发数,防止速率限制
     semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)
-    
+
     async def extract_with_semaphore(pg_hash, paragraph):
         """带信号量控制的提取函数"""
         async with semaphore:
@@ -266,7 +266,7 @@ async def extract_information(paragraphs_dict, model_set):
         extract_with_semaphore(p_hash, paragraph)
         for p_hash, paragraph in paragraphs_dict.items()
     ]
-    
+
     total = len(tasks)
     completed = 0
 
@@ -284,7 +284,7 @@ async def extract_information(paragraphs_dict, model_set):
         TimeRemainingColumn(),
     ) as progress:
         task = progress.add_task("[cyan]正在提取信息...", total=total)
-        
+
         # 🔧 优化:使用 asyncio.gather 并发执行所有任务
         # return_exceptions=True 确保单个失败不影响其他任务
         for coro in asyncio.as_completed(tasks):
@@ -293,7 +293,7 @@ async def extract_information(paragraphs_dict, model_set):
                 failed_hashes.append(failed_hash)
             elif doc_item:
                 open_ie_docs.append(doc_item)
-            
+
             completed += 1
             progress.update(task, advance=1)
 
@@ -415,7 +415,7 @@ def rebuild_faiss_only():
     logger.info("--- 重建 FAISS 索引 ---")
     # 重建索引不需要并发参数(不涉及 embedding 生成)
     embed_manager = EmbeddingManager()
-    
+
     logger.info("正在加载现有的 Embedding 库...")
     try:
         embed_manager.load_from_file()
diff --git a/src/api/memory_visualizer_router.py b/src/api/memory_visualizer_router.py
index 84971f78a..e1137d684 100644
--- a/src/api/memory_visualizer_router.py
+++ b/src/api/memory_visualizer_router.py
@@ -4,13 +4,13 @@
 提供 Web API 用于可视化记忆图数据
 """
 
+from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, List, Optional
-from collections import defaultdict
+from typing import Any
 
 import orjson
-from fastapi import APIRouter, HTTPException, Request, Query
+from fastapi import APIRouter, HTTPException, Query, Request
 from fastapi.responses import HTMLResponse, JSONResponse
 from fastapi.templating import Jinja2Templates
 
@@ -29,7 +29,7 @@
 router = APIRouter()
 templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
 
-def find_available_data_files() -> List[Path]:
+def find_available_data_files() -> list[Path]:
     """查找所有可用的记忆图数据文件"""
     files = []
     if not data_dir.exists():
@@ -62,7 +62,7 @@ def
find_available_data_files() -> List[Path]: return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True) -def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]: +def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]: """从磁盘加载图数据""" global graph_data_cache, current_data_file @@ -85,7 +85,7 @@ def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any if not graph_file.exists(): return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}} - with open(graph_file, "r", encoding="utf-8") as f: + with open(graph_file, encoding="utf-8") as f: data = orjson.loads(f.read()) nodes = data.get("nodes", []) @@ -150,7 +150,7 @@ async def index(request: Request): return templates.TemplateResponse("visualizer.html", {"request": request}) -def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]: +def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]: """从 MemoryManager 提取并格式化图数据""" if not memory_manager.graph_store: return {"nodes": [], "edges": [], "memories": [], "stats": {}} @@ -188,7 +188,7 @@ def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]: "arrows": "to", "memory_id": memory.id, } - + edges_list = list(edges_dict.values()) stats = memory_manager.get_statistics() @@ -261,7 +261,7 @@ async def get_paginated_graph( page: int = Query(1, ge=1, description="页码"), page_size: int = Query(500, ge=100, le=2000, description="每页节点数"), min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"), - node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"), + node_types: str | None = Query(None, description="节点类型过滤,逗号分隔"), ): """分页获取图数据,支持重要性过滤""" try: @@ -301,13 +301,13 @@ async def get_paginated_graph( total_pages = (total_nodes + page_size - 1) // page_size start_idx = (page - 1) * page_size end_idx = min(start_idx + page_size, total_nodes) - + paginated_nodes = nodes_with_importance[start_idx:end_idx] node_ids = set(n["id"] for n in paginated_nodes) # 只保留连接分页节点的边 paginated_edges = [ - e for e in edges + e for e in edges if e.get("from") in node_ids and e.get("to") in node_ids ] @@ -383,7 +383,7 @@ async def get_clustered_graph( return JSONResponse(content={"success": False, "error": str(e)}, status_code=500) -def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict: +def _cluster_graph_data(nodes: list[dict], edges: list[dict], max_nodes: int, cluster_threshold: int) -> dict: """简单的图聚类算法:按类型和连接度聚类""" # 构建邻接表 adjacency = defaultdict(set) @@ -412,21 +412,21 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl for node in type_nodes: importance = len(adjacency[node["id"]]) node_importance.append((node, importance)) - + node_importance.sort(key=lambda x: x[1], reverse=True) - + # 保留前N个重要节点 keep_count = min(len(type_nodes), max_nodes // len(type_groups)) for node, importance in node_importance[:keep_count]: clustered_nodes.append(node) node_mapping[node["id"]] = node["id"] - + # 其余节点聚合为一个超级节点 if len(node_importance) > keep_count: clustered_node_ids = [n["id"] for n, _ in node_importance[keep_count:]] cluster_id = f"cluster_{node_type}_{len(clustered_nodes)}" cluster_label = f"{node_type} 集群 ({len(clustered_node_ids)}个节点)" - + clustered_nodes.append({ "id": cluster_id, "label": cluster_label, @@ -436,7 +436,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl "cluster_size": len(clustered_node_ids), "clustered_nodes": 
clustered_node_ids[:10], # 只保留前10个用于展示 }) - + for node_id in clustered_node_ids: node_mapping[node_id] = cluster_id @@ -445,7 +445,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl for edge in edges: from_id = node_mapping.get(edge["from"]) to_id = node_mapping.get(edge["to"]) - + if from_id and to_id and from_id != to_id: edge_key = tuple(sorted([from_id, to_id])) if edge_key not in edge_set: diff --git a/src/api/statistic_router.py b/src/api/statistic_router.py index 97c80afa1..54f6836bf 100644 --- a/src/api/statistic_router.py +++ b/src/api/statistic_router.py @@ -1,6 +1,5 @@ -from collections import defaultdict from datetime import datetime, timedelta -from typing import Any, Literal +from typing import Literal from fastapi import APIRouter, HTTPException, Query diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 72be0f0f4..b286fa968 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -161,16 +161,16 @@ class EmbeddingStore: # 限制 chunk_size 和 max_workers 在合理范围内 chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE)) max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS)) - + semaphore = asyncio.Semaphore(max_workers) llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") results = {} - + # 将字符串列表分成多个 chunk chunks = [] for i in range(0, len(strs), chunk_size): chunks.append(strs[i : i + chunk_size]) - + async def _process_chunk(chunk: list[str]): """处理一个 chunk 的字符串(批量获取 embedding)""" async with semaphore: @@ -180,12 +180,12 @@ class EmbeddingStore: embedding = await EmbeddingStore._get_embedding_async(llm, s) embeddings.append(embedding) results[s] = embedding - + if progress_callback: progress_callback(len(chunk)) - + return embeddings - + # 并发处理所有 chunks tasks = [_process_chunk(chunk) for chunk in chunks] await asyncio.gather(*tasks) @@ -418,22 +418,22 @@ class EmbeddingStore: # 🔧 修复:检查所有 embedding 的维度是否一致 dimensions = [len(emb) for emb in array] unique_dims = set(dimensions) - + if len(unique_dims) > 1: logger.error(f"检测到不一致的 embedding 维度: {unique_dims}") logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}") - + # 获取期望的维度(使用最常见的维度) from collections import Counter dim_counter = Counter(dimensions) expected_dim = dim_counter.most_common(1)[0][0] logger.warning(f"将使用最常见的维度: {expected_dim}") - + # 过滤掉维度不匹配的 embedding filtered_array = [] filtered_idx2hash = {} skipped_count = 0 - + for i, emb in enumerate(array): if len(emb) == expected_dim: filtered_array.append(emb) @@ -442,11 +442,11 @@ class EmbeddingStore: skipped_count += 1 hash_key = self.idx2hash[str(i)] logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}") - + logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding") array = filtered_array self.idx2hash = filtered_idx2hash - + if not array: logger.error("过滤后没有可用的 embedding,无法构建索引") embedding_dim = expected_dim diff --git a/src/chat/message_manager/__init__.py b/src/chat/message_manager/__init__.py index 175320774..c8bd18a08 100644 --- a/src/chat/message_manager/__init__.py +++ b/src/chat/message_manager/__init__.py @@ -13,4 +13,4 @@ __all__ = [ "StreamLoopManager", "message_manager", "stream_loop_manager", -] \ No newline at end of file +] diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py index f3dad2fb9..34731952d 100644 --- a/src/chat/message_manager/context_manager.py +++ 
b/src/chat/message_manager/context_manager.py @@ -82,7 +82,7 @@ class SingleStreamContextManager: self.total_messages += 1 self.last_access_time = time.time() - + # 如果使用了缓存系统,输出调试信息 if cache_enabled and self.context.is_cache_enabled: if self.context.is_chatter_processing: diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index c88808d1d..6ae951663 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -111,9 +111,9 @@ class StreamLoopManager: # 获取或创建该流的启动锁 if stream_id not in self._stream_start_locks: self._stream_start_locks[stream_id] = asyncio.Lock() - + lock = self._stream_start_locks[stream_id] - + # 使用锁防止并发启动同一个流的多个循环任务 async with lock: # 获取流上下文 @@ -148,7 +148,7 @@ class StreamLoopManager: # 紧急取消 context.stream_loop_task.cancel() await asyncio.sleep(0.1) - + loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}") # 将任务记录到 StreamContext 中 @@ -252,7 +252,7 @@ class StreamLoopManager: self.stats["total_process_cycles"] += 1 if success: logger.info(f"✅ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理成功") - + # 🔒 处理成功后,等待一小段时间确保清理操作完成 # 这样可以避免在 chatter_manager 清除未读消息之前就进入下一轮循环 await asyncio.sleep(0.1) @@ -382,7 +382,7 @@ class StreamLoopManager: self.chatter_manager.process_stream_context(stream_id, context), name=f"chatter_process_{stream_id}" ) - + # 等待 chatter 任务完成 results = await chatter_task success = results.get("success", False) @@ -398,8 +398,8 @@ class StreamLoopManager: else: logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}") - return success - except asyncio.CancelledError: + return success + except asyncio.CancelledError: if chatter_task and not chatter_task.done(): chatter_task.cancel() raise @@ -709,4 +709,4 @@ class StreamLoopManager: # 全局流循环管理器实例 -stream_loop_manager = StreamLoopManager() \ No newline at end of file +stream_loop_manager = StreamLoopManager() diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 9952a88d4..29da4f068 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -417,7 +417,7 @@ class MessageManager: return # 记录详细信息 - msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}" + msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}" for msg in unread_messages[:3]] # 只显示前3条 logger.info(f"🧹 [清除未读] stream={stream_id[:8]}, 开始清除 {len(unread_messages)} 条未读消息, 示例: {msg_previews}") @@ -446,15 +446,15 @@ class MessageManager: context = chat_stream.context_manager.context if hasattr(context, "unread_messages") and context.unread_messages: unread_count = len(context.unread_messages) - + # 如果还有未读消息,说明 action_manager 可能遗漏了,标记它们 - if unread_count > 0: + if unread_count > 0: # 获取所有未读消息的 ID message_ids = [msg.message_id for msg in context.unread_messages] - + # 标记为已读(会移到历史消息) success = chat_stream.context_manager.mark_messages_as_read(message_ids) - + if success: logger.debug(f"✅ stream={stream_id[:8]}, 成功标记 {unread_count} 条消息为已读") else: @@ -481,7 +481,7 @@ class MessageManager: try: chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) - if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'): + if chat_stream 
and hasattr(chat_stream.context_manager.context, "is_chatter_processing"): chat_stream.context_manager.context.is_chatter_processing = is_processing logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}") except Exception as e: @@ -517,7 +517,7 @@ class MessageManager: try: chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) - if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'): + if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"): return chat_stream.context_manager.context.is_chatter_processing except Exception: pass @@ -677,4 +677,4 @@ class MessageManager: # 创建全局消息管理器实例 -message_manager = MessageManager() \ No newline at end of file +message_manager = MessageManager() diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index f979c1bea..0857089b4 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -248,16 +248,16 @@ class ChatterActionManager: try: # 根据动作类型确定提示词模式 prompt_mode = "s4u" if action_name == "reply" else "normal" - + # 将prompt_mode传递给generate_reply action_data_with_mode = (action_data or {}).copy() action_data_with_mode["prompt_mode"] = prompt_mode - + # 只传递当前正在执行的动作,而不是所有可用动作 # 这样可以让LLM明确知道"已决定执行X动作",而不是"有这些动作可用" current_action_info = self._using_actions.get(action_name) current_actions: dict[str, Any] = {action_name: current_action_info} if current_action_info else {} - + # 附加目标消息信息(如果存在) if target_message: # 提取目标消息的关键信息 @@ -268,7 +268,7 @@ class ChatterActionManager: "time": getattr(target_message, "time", 0), } current_actions["_target_message"] = target_msg_info - + success, response_set, _ = await generator_api.generate_reply( chat_stream=chat_stream, reply_message=target_message, @@ -295,12 +295,12 @@ class ChatterActionManager: should_quote_reply = None if action_data and isinstance(action_data, dict): should_quote_reply = action_data.get("should_quote_reply", None) - + # respond动作默认不引用回复,保持对话流畅 if action_name == "respond" and should_quote_reply is None: should_quote_reply = False - async def _after_reply(): + async def _after_reply(): # 发送并存储回复 loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply( chat_stream, diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 2e370679e..c518d6d1d 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -365,7 +365,7 @@ class DefaultReplyer: # 确保类型安全 if isinstance(mode, str): prompt_mode_value = mode - + # 构建 Prompt with Timer("构建Prompt", {}): # 内部计时器,可选保留 prompt = await self.build_prompt_reply_context( @@ -1171,16 +1171,16 @@ class DefaultReplyer: from src.plugin_system.apis.chat_api import get_chat_manager chat_manager = get_chat_manager() chat_stream_obj = await chat_manager.get_stream(chat_id) - + if chat_stream_obj: unread_messages = chat_stream_obj.context_manager.get_unread_messages() if unread_messages: # 使用最后一条未读消息作为参考 last_msg = unread_messages[-1] - platform = last_msg.chat_info.platform if hasattr(last_msg, 'chat_info') else chat_stream.platform - user_id = last_msg.user_info.user_id if hasattr(last_msg, 'user_info') else "" - user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, 'user_info') else "" - user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, 'user_info') else "" + platform = last_msg.chat_info.platform if hasattr(last_msg, 
"chat_info") else chat_stream.platform + user_id = last_msg.user_info.user_id if hasattr(last_msg, "user_info") else "" + user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, "user_info") else "" + user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, "user_info") else "" processed_plain_text = last_msg.processed_plain_text or "" else: # 没有未读消息,使用默认值 @@ -1263,19 +1263,19 @@ class DefaultReplyer: if available_actions: # 过滤掉特殊键(以_开头) action_items = {k: v for k, v in available_actions.items() if not k.startswith("_")} - + # 提取目标消息信息(如果存在) target_msg_info = available_actions.get("_target_message") # type: ignore - + if action_items: if len(action_items) == 1: # 单个动作 action_name, action_info = list(action_items.items())[0] action_desc = action_info.description - + # 构建基础决策信息 action_descriptions = f"## 决策信息\n\n你已经决定要执行 **{action_name}** 动作({action_desc})。\n\n" - + # 只有需要目标消息的动作才显示目标消息详情 # respond 动作是统一回应所有未读消息,不应该显示特定目标消息 if action_name not in ["respond"] and target_msg_info and isinstance(target_msg_info, dict): @@ -1284,7 +1284,7 @@ class DefaultReplyer: content = target_msg_info.get("content", "") msg_time = target_msg_info.get("time", 0) time_str = time_module.strftime("%H:%M:%S", time_module.localtime(msg_time)) if msg_time else "未知时间" - + action_descriptions += f"**目标消息**: {time_str} {sender} 说: {content}\n\n" else: # 多个动作 @@ -2137,7 +2137,7 @@ class DefaultReplyer: except Exception as e: logger.error(f"存储聊天记忆失败: {e}") - + def weighted_sample_no_replacement(items, weights, k) -> list: """ diff --git a/src/chat/security/__init__.py b/src/chat/security/__init__.py index 328211db1..3700bb82a 100644 --- a/src/chat/security/__init__.py +++ b/src/chat/security/__init__.py @@ -5,12 +5,12 @@ 插件可以通过实现这些接口来扩展安全功能。 """ -from .interfaces import SecurityCheckResult, SecurityChecker +from .interfaces import SecurityChecker, SecurityCheckResult from .manager import SecurityManager, get_security_manager __all__ = [ - "SecurityChecker", "SecurityCheckResult", + "SecurityChecker", "SecurityManager", "get_security_manager", ] diff --git a/src/chat/security/manager.py b/src/chat/security/manager.py index 1ddc3055a..e160e8a74 100644 --- a/src/chat/security/manager.py +++ b/src/chat/security/manager.py @@ -10,7 +10,7 @@ from typing import Any from src.common.logger import get_logger -from .interfaces import SecurityAction, SecurityCheckResult, SecurityChecker, SecurityLevel +from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel logger = get_logger("security.manager") diff --git a/src/chat/utils/prompt_component_manager.py b/src/chat/utils/prompt_component_manager.py index 3c68630d1..610bbe0c2 100644 --- a/src/chat/utils/prompt_component_manager.py +++ b/src/chat/utils/prompt_component_manager.py @@ -1,5 +1,7 @@ import asyncio +import copy import re +from collections.abc import Awaitable, Callable from src.chat.utils.prompt_params import PromptParameters from src.common.logger import get_logger @@ -12,122 +14,205 @@ logger = get_logger("prompt_component_manager") class PromptComponentManager: """ - 管理所有 `BasePrompt` 组件的单例类。 + 一个统一的、动态的、可观测的提示词组件管理中心。 - 该管理器负责: - 1. 从 `component_registry` 中查询 `BasePrompt` 子类。 - 2. 根据注入点(目标Prompt名称)对它们进行筛选。 - 3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。 + 该管理器是整个提示词动态注入系统的核心,它负责: + 1. **规则加载**: 在系统启动时,自动扫描所有已注册的 `BasePrompt` 组件, + 并将其静态定义的 `injection_rules` 加载为默认的动态规则。 + 2. **动态管理**: 提供线程安全的 API,允许在运行时动态地添加、更新或移除注入规则, + 使得提示词的结构可以被实时调整。 + 3. 
**状态观测**: 提供丰富的查询 API,用于观测系统当前完整的注入状态, + 例如查询所有注入到特定目标的规则、或查询某个组件定义的所有规则。 + 4. **注入应用**: 在构建核心 Prompt 时,根据统一的、按优先级排序的规则集, + 动态地修改和装配提示词模板,实现灵活的提示词组合。 """ - def _get_rules_for(self, target_prompt_name: str) -> list[tuple[InjectionRule, type[BasePrompt]]]: - """ - 获取指定目标Prompt的所有注入规则及其关联的组件类。 + def __init__(self): + """初始化管理器实例。""" + # _dynamic_rules 是管理器的核心状态,存储所有注入规则。 + # 结构: { + # "target_prompt_name": { + # "prompt_component_name": (InjectionRule, content_provider, source) + # } + # } + # content_provider 是一个异步函数,用于在应用规则时动态生成注入内容。 + # source 记录了规则的来源(例如 "static_default" 或 "runtime")。 + self._dynamic_rules: dict[str, dict[str, tuple[InjectionRule, Callable[..., Awaitable[str]], str]]] = {} + self._lock = asyncio.Lock() # 使用异步锁确保对 _dynamic_rules 的并发访问安全。 + self._initialized = False # 标记静态规则是否已加载,防止重复加载。 - Args: - target_prompt_name (str): 目标 Prompt 的名称。 + # --- 核心生命周期与初始化 --- - Returns: - list[tuple[InjectionRule, Type[BasePrompt]]]: 一个元组列表, - 每个元组包含一个注入规则和其对应的 Prompt 组件类,并已根据优先级排序。 + def load_static_rules(self): """ - # 从注册表中获取所有已启用的 PROMPT 类型的组件 + 在系统启动时加载所有静态注入规则。 + + 该方法会扫描所有已在 `component_registry` 中注册并启用的 Prompt 组件, + 将其类变量 `injection_rules` 转换为管理器的动态规则。 + 这确保了所有插件定义的默认注入行为在系统启动时就能生效。 + 此操作是幂等的,一旦初始化完成就不会重复执行。 + """ + if self._initialized: + return + logger.info("正在加载静态 Prompt 注入规则...") + + # 从组件注册表中获取所有已启用的 Prompt 组件 enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT) - matching_rules = [] - # 遍历所有启用的 Prompt 组件,查找与目标 Prompt 相关的注入规则 for prompt_name, prompt_info in enabled_prompts.items(): if not isinstance(prompt_info, PromptInfo): continue - # prompt_info.injection_rules 已经经过了后向兼容处理,确保总是列表 - for rule in prompt_info.injection_rules: - # 如果规则的目标是当前指定的 Prompt - if rule.target_prompt == target_prompt_name: - # 获取该规则对应的组件类 - component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT) - # 确保获取到的确实是一个 BasePrompt 的子类 - if component_class and issubclass(component_class, BasePrompt): - matching_rules.append((rule, component_class)) + component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT) + if not (component_class and issubclass(component_class, BasePrompt)): + logger.warning(f"无法为 '{prompt_name}' 加载静态规则,因为它不是一个有效的 Prompt 组件。") + continue - # 根据规则的优先级进行排序,数字越小,优先级越高,越先应用 - matching_rules.sort(key=lambda x: x[0].priority) - return matching_rules + def create_provider(cls: type[BasePrompt]) -> Callable[[PromptParameters], Awaitable[str]]: + """ + 为静态组件创建一个内容提供者闭包 (Content Provider Closure)。 + + 这个闭包捕获了组件的类 `cls`,并返回一个标准的 `content_provider` 异步函数。 + 当 `apply_injections` 需要内容时,它会调用这个函数。 + 函数内部会实例化组件,并执行其 `execute` 方法来获取注入内容。 + + Args: + cls (type[BasePrompt]): 需要为其创建提供者的 Prompt 组件类。 + + Returns: + Callable[[PromptParameters], Awaitable[str]]: 一个符合管理器标准的异步内容提供者。 + """ + + async def content_provider(params: PromptParameters) -> str: + """实际执行内容生成的异步函数。""" + try: + # 从注册表获取最新的组件信息,包括插件配置 + p_info = component_registry.get_component_info(cls.prompt_name, ComponentType.PROMPT) + plugin_config = {} + if isinstance(p_info, PromptInfo): + plugin_config = component_registry.get_plugin_config(p_info.plugin_name) + + # 实例化组件并执行 + instance = cls(params=params, plugin_config=plugin_config) + result = await instance.execute() + return str(result) if result is not None else "" + except Exception as e: + logger.error(f"执行静态规则提供者 '{cls.prompt_name}' 时出错: {e}", exc_info=True) + return "" # 出错时返回空字符串,避免影响主流程 + + return content_provider + + # 为该组件的每条静态注入规则创建并注册一个动态规则 + for rule in prompt_info.injection_rules: + 
provider = create_provider(component_class) + target_rules = self._dynamic_rules.setdefault(rule.target_prompt, {}) + target_rules[prompt_name] = (rule, provider, "static_default") + + self._initialized = True + logger.info(f"静态 Prompt 注入规则加载完成,共处理 {len(enabled_prompts)} 个组件。") + + # --- 运行时规则管理 API --- + + async def add_injection_rule( + self, + prompt_name: str, + rule: InjectionRule, + content_provider: Callable[..., Awaitable[str]], + source: str = "runtime", + ) -> bool: + """ + 动态添加或更新一条注入规则。 + + 此方法允许在系统运行时,由外部逻辑(如插件、命令)向管理器中添加新的注入行为。 + 如果已存在同名组件针对同一目标的规则,此方法会覆盖旧规则。 + + Args: + prompt_name (str): 动态注入组件的唯一名称。 + rule (InjectionRule): 描述注入行为的规则对象。 + content_provider (Callable[..., Awaitable[str]]): + 一个异步函数,用于在应用注入时动态生成内容。 + 函数签名应为: `async def provider(params: "PromptParameters") -> str` + source (str, optional): 规则的来源标识,默认为 "runtime"。 + + Returns: + bool: 如果成功添加或更新,则返回 True。 + """ + async with self._lock: + target_rules = self._dynamic_rules.setdefault(rule.target_prompt, {}) + target_rules[prompt_name] = (rule, content_provider, source) + logger.info(f"成功添加/更新注入规则: '{prompt_name}' -> '{rule.target_prompt}' (来源: {source})") + return True + + async def remove_injection_rule(self, prompt_name: str, target_prompt: str) -> bool: + """ + 移除一条动态注入规则。 + + Args: + prompt_name (str): 要移除的注入组件的名称。 + target_prompt (str): 该组件注入的目标核心提示词名称。 + + Returns: + bool: 如果成功移除,则返回 True;如果规则不存在,则返回 False。 + """ + async with self._lock: + if target_prompt in self._dynamic_rules and prompt_name in self._dynamic_rules[target_prompt]: + del self._dynamic_rules[target_prompt][prompt_name] + # 如果目标下已无任何规则,则清理掉这个键 + if not self._dynamic_rules[target_prompt]: + del self._dynamic_rules[target_prompt] + logger.info(f"成功移除注入规则: '{prompt_name}' from '{target_prompt}'") + return True + logger.warning(f"尝试移除注入规则失败: 未找到 '{prompt_name}' on '{target_prompt}'") + return False + + # --- 核心注入逻辑 --- async def apply_injections( self, target_prompt_name: str, original_template: str, params: PromptParameters ) -> str: """ - 获取、实例化并执行所有相关组件,然后根据注入规则修改原始模板。 + 【核心方法】根据目标名称,应用所有匹配的注入规则,返回修改后的模板。 - 这是一个三步走的过程: - 1. 实例化所有需要执行的组件。 - 2. 并行执行它们的 `execute` 方法以获取注入内容。 - 3. 按照优先级顺序,将内容注入到原始模板中。 + 这是提示词构建流程中的关键步骤。它会执行以下操作: + 1. 检查并确保静态规则已加载。 + 2. 获取所有注入到 `target_prompt_name` 的规则。 + 3. 按照规则的 `priority` 属性进行升序排序,优先级数字越小越先应用。 + 4. 依次执行每个规则的 `content_provider` 来异步获取注入内容。 + 5. 
根据规则的 `injection_type` (如 PREPEND, APPEND, REPLACE 等) 将内容应用到模板上。 Args: - target_prompt_name (str): 目标 Prompt 的名称。 - original_template (str): 原始的、未经修改的 Prompt 模板字符串。 - params (PromptParameters): 传递给 Prompt 组件实例的参数。 + target_prompt_name (str): 目标核心提示词的名称。 + original_template (str): 未经修改的原始提示词模板。 + params (PromptParameters): 当前请求的参数,会传递给 `content_provider`。 Returns: - str: 应用了所有注入规则后,修改过的 Prompt 模板字符串。 + str: 应用了所有注入规则后,最终生成的提示词模板字符串。 """ - rules_with_classes = self._get_rules_for(target_prompt_name) - # 如果没有找到任何匹配的规则,就直接返回原始模板,啥也不干 - if not rules_with_classes: + if not self._initialized: + self.load_static_rules() + + # 步骤 1: 获取所有指向当前目标的规则 + # 使用 .values() 获取 (rule, provider, source) 元组列表 + rules_for_target = list(self._dynamic_rules.get(target_prompt_name, {}).values()) + if not rules_for_target: return original_template - # --- 第一步: 实例化所有需要执行的组件 --- - instance_map = {} # 存储组件实例,虽然目前没直接用,但留着总没错 - tasks = [] # 存放所有需要并行执行的 execute 异步任务 - components_to_execute = [] # 存放需要执行的组件类,用于后续结果映射 + # 步骤 2: 按优先级排序,数字越小越优先 + rules_for_target.sort(key=lambda x: x[0].priority) - for rule, component_class in rules_with_classes: - # 如果注入类型是 REMOVE,那就不需要执行组件了,因为它不产生内容 + # 步骤 3: 依次执行内容提供者并根据注入类型修改模板 + modified_template = original_template + for rule, provider, source in rules_for_target: + content = "" + # 对于非 REMOVE 类型的注入,需要先获取内容 if rule.injection_type != InjectionType.REMOVE: try: - # 获取组件的元信息,主要是为了拿到插件名称来读取插件配置 - prompt_info = component_registry.get_component_info( - component_class.prompt_name, ComponentType.PROMPT - ) - if not isinstance(prompt_info, PromptInfo): - plugin_config = {} - else: - # 从注册表获取该组件所属插件的配置 - plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name) - - # 实例化组件,并传入参数和插件配置 - instance = component_class(params=params, plugin_config=plugin_config) - instance_map[component_class.prompt_name] = instance - # 将组件的 execute 方法作为一个任务添加到列表中 - tasks.append(instance.execute()) - components_to_execute.append(component_class) + content = await provider(params) except Exception as e: - logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}") - # 即使失败,也添加一个立即完成的空任务,以保持与其他任务的索引同步 - tasks.append(asyncio.create_task(asyncio.sleep(0, result=e))) # type: ignore - - # --- 第二步: 并行执行所有组件的 execute 方法 --- - # 使用 asyncio.gather 来同时运行所有任务,提高效率 - results = await asyncio.gather(*tasks, return_exceptions=True) - # 创建一个从组件名到执行结果的映射,方便后续查找 - result_map = { - components_to_execute[i].prompt_name: res - for i, res in enumerate(results) - if not isinstance(res, Exception) # 只包含成功的结果 - } - # 单独处理并记录执行失败的组件 - for i, res in enumerate(results): - if isinstance(res, Exception): - logger.error(f"执行 Prompt 组件 '{components_to_execute[i].prompt_name}' 失败: {res}") - - # --- 第三步: 按优先级顺序应用注入规则 --- - modified_template = original_template - for rule, component_class in rules_with_classes: - # 从结果映射中获取该组件生成的内容 - content = result_map.get(component_class.prompt_name) + logger.error(f"执行规则 '{rule}' (来源: {source}) 的内容提供者时失败: {e}", exc_info=True) + continue # 跳过失败的 provider,不中断整个流程 + # 应用注入逻辑 try: if rule.injection_type == InjectionType.PREPEND: if content: @@ -136,28 +221,178 @@ class PromptComponentManager: if content: modified_template = f"{modified_template}\n{content}" elif rule.injection_type == InjectionType.REPLACE: - # 使用正则表达式替换目标内容 - if content and rule.target_content: + # 只有在 content 不为 None 且 target_content 有效时才执行替换 + if content is not None and rule.target_content: modified_template = re.sub(rule.target_content, str(content), modified_template) elif rule.injection_type == 
InjectionType.INSERT_AFTER: - # 在匹配到的内容后面插入 if content and rule.target_content: - # re.sub a little trick: \g<0> represents the entire matched string + # 使用 `\g<0>` 在正则匹配的整个内容后添加新内容 replacement = f"\\g<0>\n{content}" modified_template = re.sub(rule.target_content, replacement, modified_template) elif rule.injection_type == InjectionType.REMOVE: - # 使用正则表达式移除目标内容 if rule.target_content: modified_template = re.sub(rule.target_content, "", modified_template) except re.error as e: - logger.error( - f"在为 '{component_class.prompt_name}' 应用规则时发生正则错误: {e} (pattern: '{rule.target_content}')" - ) + logger.error(f"应用规则时发生正则错误: {e} (pattern: '{rule.target_content}')") except Exception as e: - logger.error(f"应用 Prompt 注入规则 '{rule}' 失败: {e}") + logger.error(f"应用注入规则 '{rule}' (来源: {source}) 失败: {e}", exc_info=True) return modified_template + async def preview_prompt_injections( + self, target_prompt_name: str, params: PromptParameters + ) -> str: + """ + 【预览功能】模拟应用所有注入规则,返回最终生成的模板字符串,而不实际修改任何状态。 -# 创建全局单例 + 这个方法对于调试和测试非常有用,可以查看在特定参数下, + 一个核心提示词经过所有注入规则处理后会变成什么样子。 + + Args: + target_prompt_name (str): 希望预览的目标核心提示词名称。 + params (PromptParameters): 模拟的请求参数。 + + Returns: + str: 模拟生成的最终提示词模板字符串。如果找不到模板,则返回错误信息。 + """ + try: + # 从全局提示词管理器获取最原始的模板内容 + from src.chat.utils.prompt import global_prompt_manager + original_prompt = global_prompt_manager._prompts.get(target_prompt_name) + if not original_prompt: + logger.warning(f"无法预览 '{target_prompt_name}',因为找不到这个核心 Prompt。") + return f"Error: Prompt '{target_prompt_name}' not found." + original_template = original_prompt.template + except KeyError: + logger.warning(f"无法预览 '{target_prompt_name}',因为找不到这个核心 Prompt。") + return f"Error: Prompt '{target_prompt_name}' not found." + + # 直接调用核心注入逻辑来模拟结果 + return await self.apply_injections(target_prompt_name, original_template, params) + + # --- 状态观测与查询 API --- + + def get_core_prompts(self) -> list[str]: + """获取所有已注册的核心提示词模板名称列表(即所有可注入的目标)。""" + from src.chat.utils.prompt import global_prompt_manager + return list(global_prompt_manager._prompts.keys()) + + def get_core_prompt_contents(self) -> dict[str, str]: + """获取所有核心提示词模板的原始内容。""" + from src.chat.utils.prompt import global_prompt_manager + return {name: prompt.template for name, prompt in global_prompt_manager._prompts.items()} + + def get_registered_prompt_component_info(self) -> list[PromptInfo]: + """获取所有在 ComponentRegistry 中注册的 Prompt 组件信息。""" + components = component_registry.get_components_by_type(ComponentType.PROMPT).values() + return [info for info in components if isinstance(info, PromptInfo)] + + async def get_full_injection_map(self) -> dict[str, list[dict]]: + """ + 获取当前完整的注入映射图。 + + 此方法提供了一个系统全局的注入视图,展示了每个核心提示词(target) + 被哪些注入组件(source)以何种优先级注入。 + + Returns: + dict[str, list[dict]]: 一个字典,键是目标提示词名称, + 值是按优先级排序的注入信息列表。 + `[{"name": str, "priority": int, "source": str}]` + """ + injection_map = {} + async with self._lock: + # 合并所有动态规则的目标和所有核心提示词,确保所有潜在目标都被包含 + all_targets = set(self._dynamic_rules.keys()) | set(self.get_core_prompts()) + for target in sorted(all_targets): + rules = self._dynamic_rules.get(target, {}) + if not rules: + injection_map[target] = [] + continue + + info_list = [] + for prompt_name, (rule, _, source) in rules.items(): + info_list.append({"name": prompt_name, "priority": rule.priority, "source": source}) + + # 按优先级排序后存入 map + info_list.sort(key=lambda x: x["priority"]) + injection_map[target] = info_list + return injection_map + + async def get_injections_for_prompt(self, target_prompt_name: str) -> list[dict]: + """ + 
获取指定核心提示词模板的所有注入信息(包含详细规则)。 + + Args: + target_prompt_name (str): 目标核心提示词的名称。 + + Returns: + list[dict]: 一个包含注入规则详细信息的列表,已按优先级排序。 + """ + rules_for_target = self._dynamic_rules.get(target_prompt_name, {}) + if not rules_for_target: + return [] + + info_list = [] + for prompt_name, (rule, _, source) in rules_for_target.items(): + info_list.append( + { + "name": prompt_name, + "priority": rule.priority, + "source": source, + "injection_type": rule.injection_type.value, + "target_content": rule.target_content, + } + ) + info_list.sort(key=lambda x: x["priority"]) + return info_list + + def get_all_dynamic_rules(self) -> dict[str, dict[str, "InjectionRule"]]: + """ + 获取所有当前的动态注入规则,以 InjectionRule 对象形式返回。 + + 此方法返回一个深拷贝的规则副本,隐藏了 `content_provider` 等内部实现细节。 + 适合用于展示或序列化当前的规则配置。 + """ + rules_copy = {} + for target, rules in self._dynamic_rules.items(): + target_copy = {name: rule for name, (rule, _, _) in rules.items()} + rules_copy[target] = target_copy + return copy.deepcopy(rules_copy) + + def get_rules_for_target(self, target_prompt: str) -> dict[str, InjectionRule]: + """ + 获取所有注入到指定核心提示词的动态规则。 + + Args: + target_prompt (str): 目标核心提示词的名称。 + + Returns: + dict[str, InjectionRule]: 一个字典,键是注入组件的名称,值是 `InjectionRule` 对象。 + 如果找不到任何注入到该目标的规则,则返回一个空字典。 + """ + target_rules = self._dynamic_rules.get(target_prompt, {}) + return {name: copy.deepcopy(rule_info[0]) for name, rule_info in target_rules.items()} + + def get_rules_by_component(self, component_name: str) -> dict[str, InjectionRule]: + """ + 获取由指定的单个注入组件定义的所有动态规则。 + + Args: + component_name (str): 注入组件的名称。 + + Returns: + dict[str, InjectionRule]: 一个字典,键是目标核心提示词的名称,值是 `InjectionRule` 对象。 + 如果该组件没有定义任何注入规则,则返回一个空字典。 + """ + found_rules = {} + for target, rules in self._dynamic_rules.items(): + if component_name in rules: + rule_info = rules[component_name] + found_rules[target] = copy.deepcopy(rule_info[0]) + return found_rules + + +# 创建全局单例 (Singleton) +# 在整个应用程序中,应该只使用这一个 `prompt_component_manager` 实例, +# 以确保所有部分都共享和操作同一份动态规则集。 prompt_component_manager = PromptComponentManager() diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index 108138c5b..8afbbde57 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -98,7 +98,7 @@ class StreamContext(BaseDataModel): break def mark_message_as_read(self, message_id: str): - """标记消息为已读""" + """标记消息为已读""" # 先找到要标记的消息(处理 int/str 类型不匹配问题) message_to_mark = None for msg in self.unread_messages: @@ -106,7 +106,7 @@ class StreamContext(BaseDataModel): if str(msg.message_id) == str(message_id): message_to_mark = msg break - + # 然后移动到历史消息 if message_to_mark: message_to_mark.is_read = True diff --git a/src/common/database/optimization/cache_manager.py b/src/common/database/optimization/cache_manager.py index 60b0b0170..27b7b33a2 100644 --- a/src/common/database/optimization/cache_manager.py +++ b/src/common/database/optimization/cache_manager.py @@ -9,11 +9,12 @@ """ import asyncio +import builtins import time from collections import OrderedDict from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union +from typing import Any, Generic, TypeVar from src.common.logger import get_logger from src.common.memory_utils import estimate_size_smart @@ -96,7 +97,7 @@ class LRUCache(Generic[T]): self._lock = asyncio.Lock() self._stats = CacheStats() - async def get(self, key: str) -> 
Optional[T]: + async def get(self, key: str) -> T | None: """获取缓存值 Args: @@ -137,8 +138,8 @@ class LRUCache(Generic[T]): self, key: str, value: T, - size: Optional[int] = None, - ttl: Optional[float] = None, + size: int | None = None, + ttl: float | None = None, ) -> None: """设置缓存值 @@ -287,8 +288,8 @@ class MultiLevelCache: async def get( self, key: str, - loader: Optional[Callable[[], Any]] = None, - ) -> Optional[Any]: + loader: Callable[[], Any] | None = None, + ) -> Any | None: """从缓存获取数据 查询顺序:L1 -> L2 -> loader @@ -329,8 +330,8 @@ class MultiLevelCache: self, key: str, value: Any, - size: Optional[int] = None, - ttl: Optional[float] = None, + size: int | None = None, + ttl: float | None = None, ) -> None: """设置缓存值 @@ -390,7 +391,7 @@ class MultiLevelCache: await self.l2_cache.clear() logger.info("所有缓存已清空") - async def get_stats(self) -> Dict[str, Any]: + async def get_stats(self) -> dict[str, Any]: """获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)""" # 🔧 修复:并行获取统计信息,避免锁嵌套 l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1")) @@ -492,7 +493,7 @@ class MultiLevelCache: logger.error(f"{cache_name}统计获取异常: {e}") return CacheStats() - async def _get_cache_keys_safe(self, cache) -> Set[str]: + async def _get_cache_keys_safe(self, cache) -> builtins.set[str]: """安全获取缓存键集合(带超时)""" try: # 快速获取键集合,使用超时避免死锁 @@ -507,12 +508,12 @@ class MultiLevelCache: logger.error(f"缓存键获取异常: {e}") return set() - async def _extract_keys_with_lock(self, cache) -> Set[str]: + async def _extract_keys_with_lock(self, cache) -> builtins.set[str]: """在锁保护下提取键集合""" async with cache._lock: return set(cache._cache.keys()) - async def _calculate_memory_usage_safe(self, cache, keys: Set[str]) -> int: + async def _calculate_memory_usage_safe(self, cache, keys: builtins.set[str]) -> int: """安全计算内存使用(带超时)""" if not keys: return 0 @@ -529,7 +530,7 @@ class MultiLevelCache: logger.error(f"内存计算异常: {e}") return 0 - async def _calc_memory_with_lock(self, cache, keys: Set[str]) -> int: + async def _calc_memory_with_lock(self, cache, keys: builtins.set[str]) -> int: """在锁保护下计算内存使用""" total_size = 0 async with cache._lock: @@ -749,7 +750,7 @@ class MultiLevelCache: # 全局缓存实例 -_global_cache: Optional[MultiLevelCache] = None +_global_cache: MultiLevelCache | None = None _cache_lock = asyncio.Lock() diff --git a/src/common/server.py b/src/common/server.py index a6ed588d8..f4553f537 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -3,7 +3,6 @@ import socket from fastapi import APIRouter, FastAPI from fastapi.middleware.cors import CORSMiddleware -from fastapi.staticfiles import StaticFiles from rich.traceback import install from uvicorn import Config from uvicorn import Server as UvicornServer diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 2bf07c734..7245a79db 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -436,7 +436,7 @@ class OpenaiClient(BaseClient): # 🔧 优化:增加连接池限制,支持高并发embedding请求 # 默认httpx限制为100,对于高频embedding场景不够用 import httpx - + limits = httpx.Limits( max_keepalive_connections=200, # 保持活跃连接数(原100) max_connections=300, # 最大总连接数(原100) diff --git a/src/memory_graph/core/builder.py b/src/memory_graph/core/builder.py index 4b0d66218..00f55c0fa 100644 --- a/src/memory_graph/core/builder.py +++ b/src/memory_graph/core/builder.py @@ -128,7 +128,7 @@ class MemoryBuilder: # 6. 
构建 Memory 对象 # 新记忆应该有较高的初始激活度 initial_activation = 0.75 # 新记忆初始激活度为 0.75 - + memory = Memory( id=memory_id, subject_id=subject_node.id, diff --git a/src/memory_graph/manager.py b/src/memory_graph/manager.py index 2c234ae86..32fe76f20 100644 --- a/src/memory_graph/manager.py +++ b/src/memory_graph/manager.py @@ -149,7 +149,7 @@ class MemoryManager: # 读取阈值过滤配置 search_min_importance = self.config.search_min_importance search_similarity_threshold = self.config.search_similarity_threshold - + logger.info( f"📊 配置检查: search_max_expand_depth={expand_depth}, " f"search_expand_semantic_threshold={expand_semantic_threshold}, " @@ -417,7 +417,7 @@ class MemoryManager: # 使用配置的默认值 if top_k is None: top_k = getattr(self.config, "search_top_k", 10) - + # 准备搜索参数 params = { "query": query, @@ -951,7 +951,7 @@ class MemoryManager: ) else: logger.debug(f"记忆已删除: {memory_id} (删除了 {deleted_vectors} 个向量)") - + # 4. 保存更新 await self.persistence.save_graph_store(self.graph_store) return True @@ -984,7 +984,7 @@ class MemoryManager: try: forgotten_count = 0 all_memories = self.graph_store.get_all_memories() - + # 获取配置参数 min_importance = getattr(self.config, "forgetting_min_importance", 0.8) decay_rate = getattr(self.config, "activation_decay_rate", 0.9) @@ -1010,10 +1010,10 @@ class MemoryManager: try: last_access_dt = datetime.fromisoformat(last_access) days_passed = (datetime.now() - last_access_dt).days - + # 应用指数衰减:activation = base * (decay_rate ^ days) current_activation = base_activation * (decay_rate ** days_passed) - + logger.debug( f"记忆 {memory.id[:8]}: 基础激活度={base_activation:.3f}, " f"经过{days_passed}天衰减后={current_activation:.3f}" @@ -1035,20 +1035,20 @@ class MemoryManager: # 批量遗忘记忆(不立即清理孤立节点) if memories_to_forget: logger.info(f"开始批量遗忘 {len(memories_to_forget)} 条记忆...") - + for memory_id, activation in memories_to_forget: # cleanup_orphans=False:暂不清理孤立节点 success = await self.forget_memory(memory_id, cleanup_orphans=False) if success: forgotten_count += 1 - + # 统一清理孤立节点和边 logger.info("批量遗忘完成,开始统一清理孤立节点和边...") orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges() - + # 保存最终更新 await self.persistence.save_graph_store(self.graph_store) - + logger.info( f"✅ 自动遗忘完成: 遗忘了 {forgotten_count} 条记忆, " f"清理了 {orphan_nodes} 个孤立节点, {orphan_edges} 条孤立边" @@ -1079,31 +1079,31 @@ class MemoryManager: # 1. 清理孤立节点 # graph_store.node_to_memories 记录了每个节点属于哪些记忆 nodes_to_remove = [] - + for node_id, memory_ids in list(self.graph_store.node_to_memories.items()): # 如果节点不再属于任何记忆,标记为删除 if not memory_ids: nodes_to_remove.append(node_id) - + # 从图中删除孤立节点 for node_id in nodes_to_remove: if self.graph_store.graph.has_node(node_id): self.graph_store.graph.remove_node(node_id) orphan_nodes_count += 1 - + # 从映射中删除 if node_id in self.graph_store.node_to_memories: del self.graph_store.node_to_memories[node_id] - + # 2. 
清理孤立边(指向已删除节点的边) edges_to_remove = [] - - for source, target, edge_id in self.graph_store.graph.edges(data='edge_id'): + + for source, target, edge_id in self.graph_store.graph.edges(data="edge_id"): # 检查边的源节点和目标节点是否还存在于node_to_memories中 if source not in self.graph_store.node_to_memories or \ target not in self.graph_store.node_to_memories: edges_to_remove.append((source, target)) - + # 删除孤立边 for source, target in edges_to_remove: try: @@ -1111,12 +1111,12 @@ class MemoryManager: orphan_edges_count += 1 except Exception as e: logger.debug(f"删除边失败 {source} -> {target}: {e}") - + if orphan_nodes_count > 0 or orphan_edges_count > 0: logger.info( f"清理完成: {orphan_nodes_count} 个孤立节点, {orphan_edges_count} 条孤立边" ) - + return orphan_nodes_count, orphan_edges_count except Exception as e: @@ -1258,7 +1258,7 @@ class MemoryManager: mem for mem in recent_memories if mem.importance >= min_importance_for_consolidation ] - + result["importance_filtered"] = len(recent_memories) - len(important_memories) logger.info( f"📊 步骤2: 重要性过滤 (阈值={min_importance_for_consolidation:.2f}): " @@ -1382,26 +1382,26 @@ class MemoryManager: # ===== 步骤4: 向量检索关联记忆 + LLM分析关系 ===== # 过滤掉已删除的记忆 remaining_memories = [m for m in important_memories if m.id not in deleted_ids] - + if not remaining_memories: logger.info("✅ 记忆整理完成: 去重后无剩余记忆") return logger.info(f"📍 步骤4: 开始关联分析 ({len(remaining_memories)} 条记忆)...") - + # 分批处理记忆关联 llm_batch_size = getattr(self.config, "consolidation_llm_batch_size", 10) max_candidates_per_memory = getattr(self.config, "consolidation_max_candidates", 5) min_confidence = getattr(self.config, "consolidation_min_confidence", 0.6) - + all_new_edges = [] # 收集所有新建的边 - + for batch_start in range(0, len(remaining_memories), llm_batch_size): batch_end = min(batch_start + llm_batch_size, len(remaining_memories)) batch = remaining_memories[batch_start:batch_end] - + logger.debug(f"处理批次 {batch_start//llm_batch_size + 1}/{(len(remaining_memories)-1)//llm_batch_size + 1}") - + for memory in batch: # 跳过已经有很多连接的记忆 existing_edges = len([ @@ -1454,14 +1454,14 @@ class MemoryManager: except Exception as e: logger.warning(f"创建关联边失败: {e}") continue - + # 每个批次后让出控制权 await asyncio.sleep(0.01) # ===== 步骤5: 统一更新记忆数据 ===== if all_new_edges: logger.info(f"📍 步骤5: 统一更新 {len(all_new_edges)} 条新关联边...") - + for memory, edge, relation in all_new_edges: try: # 添加到图 @@ -2301,7 +2301,7 @@ class MemoryManager: # 使用 asyncio.wait_for 来支持取消 await asyncio.wait_for( asyncio.sleep(initial_delay), - timeout=float('inf') # 允许随时取消 + timeout=float("inf") # 允许随时取消 ) # 检查是否仍然需要运行 diff --git a/src/memory_graph/storage/graph_store.py b/src/memory_graph/storage/graph_store.py index 93e30cb83..ed7c16a2c 100644 --- a/src/memory_graph/storage/graph_store.py +++ b/src/memory_graph/storage/graph_store.py @@ -482,7 +482,7 @@ class GraphStore: for node in memory.nodes: if node.id in self.node_to_memories: self.node_to_memories[node.id].discard(memory_id) - + # 可选:立即清理孤立节点 if cleanup_orphans: # 如果该节点不再属于任何记忆,从图中移除节点 diff --git a/src/memory_graph/tools/memory_tools.py b/src/memory_graph/tools/memory_tools.py index e87557cb8..33618dd8c 100644 --- a/src/memory_graph/tools/memory_tools.py +++ b/src/memory_graph/tools/memory_tools.py @@ -72,12 +72,12 @@ class MemoryTools: self.max_expand_depth = max_expand_depth self.expand_semantic_threshold = expand_semantic_threshold self.search_top_k = search_top_k - + # 保存权重配置 self.base_vector_weight = search_vector_weight self.base_importance_weight = search_importance_weight self.base_recency_weight = search_recency_weight - + # 
保存阈值过滤配置 self.search_min_importance = search_min_importance self.search_similarity_threshold = search_similarity_threshold @@ -516,14 +516,14 @@ class MemoryTools: # 1. 根据策略选择检索方式 llm_prefer_types = [] # LLM识别的偏好节点类型 - + if use_multi_query: # 多查询策略(返回节点列表 + 偏好类型) similar_nodes, llm_prefer_types = await self._multi_query_search(query, top_k, context) else: # 传统单查询策略 similar_nodes = await self._single_query_search(query, top_k) - + # 合并用户指定的偏好类型和LLM识别的偏好类型 all_prefer_types = list(set(prefer_node_types + llm_prefer_types)) if all_prefer_types: @@ -551,7 +551,7 @@ class MemoryTools: # 记录最高分数 if mem_id not in memory_scores or similarity > memory_scores[mem_id]: memory_scores[mem_id] = similarity - + # 🔥 详细日志:检查初始召回情况 logger.info( f"初始向量搜索: 返回{len(similar_nodes)}个节点 → " @@ -559,8 +559,8 @@ class MemoryTools: ) if len(initial_memory_ids) == 0: logger.warning( - f"⚠️ 向量搜索未找到任何记忆!" - f"可能原因:1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大" + "⚠️ 向量搜索未找到任何记忆!" + "可能原因:1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大" ) # 输出相似节点的详细信息用于调试 if similar_nodes: @@ -692,7 +692,7 @@ class MemoryTools: key=lambda x: final_scores[x], reverse=True ) # 🔥 不再提前截断,让所有候选参与详细评分 - + # 🔍 统计初始记忆的相似度分布(用于诊断) if memory_scores: similarities = list(memory_scores.values()) @@ -707,7 +707,7 @@ class MemoryTools: # 5. 获取完整记忆并进行最终排序(优化后的动态权重系统) memories_with_scores = [] filter_stats = {"importance": 0, "similarity": 0, "total_checked": 0} # 过滤统计 - + for memory_id in sorted_memory_ids: # 遍历所有候选 memory = self.graph_store.get_memory_by_id(memory_id) if memory: @@ -715,7 +715,7 @@ class MemoryTools: # 基础分数 similarity_score = final_scores[memory_id] importance_score = memory.importance - + # 🆕 区分记忆来源(用于过滤) is_initial_memory = memory_id in memory_scores # 是否来自初始向量搜索 true_similarity = memory_scores.get(memory_id, 0.0) if is_initial_memory else None @@ -738,16 +738,16 @@ class MemoryTools: activation_score = memory.activation # 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调 - memory_type = memory.memory_type.value if hasattr(memory.memory_type, 'value') else str(memory.memory_type) - + memory_type = memory.memory_type.value if hasattr(memory.memory_type, "value") else str(memory.memory_type) + # 检测记忆的主要节点类型 node_types_count = {} for node in memory.nodes: - nt = node.node_type.value if hasattr(node.node_type, 'value') else str(node.node_type) + nt = node.node_type.value if hasattr(node.node_type, "value") else str(node.node_type) node_types_count[nt] = node_types_count.get(nt, 0) + 1 - + dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown" - + # 根据记忆类型和节点类型计算调整系数(在配置权重基础上微调) if dominant_node_type in ["ATTRIBUTE", "REFERENCE"] or memory_type == "FACT": # 事实性记忆:提升相似度权重,降低时效性权重 @@ -777,41 +777,41 @@ class MemoryTools: "importance": 1.0, "recency": 1.0, } - + # 应用调整后的权重(基于配置的基础权重) weights = { "similarity": self.base_vector_weight * type_adjustments["similarity"], "importance": self.base_importance_weight * type_adjustments["importance"], "recency": self.base_recency_weight * type_adjustments["recency"], } - + # 归一化权重(确保总和为1.0) total_weight = sum(weights.values()) if total_weight > 0: weights = {k: v / total_weight for k, v in weights.items()} - + # 综合分数计算(🔥 移除激活度影响) final_score = ( similarity_score * weights["similarity"] + importance_score * weights["importance"] + recency_score * weights["recency"] ) - + # 🆕 阈值过滤策略: # 1. 
重要性过滤:应用于所有记忆(过滤极低质量) if memory.importance < self.search_min_importance: filter_stats["importance"] += 1 logger.debug(f"❌ 过滤 {memory.id[:8]}: 重要性 {memory.importance:.2f} < 阈值 {self.search_min_importance}") continue - + # 2. 相似度过滤:不再对初始向量搜索结果过滤(信任向量搜索的排序) # 理由:向量搜索已经按相似度排序,返回的都是最相关结果 # 如果再用阈值过滤,会导致"最相关的也不够相关"的矛盾 - # + # # 注意:如果未来需要对扩展记忆过滤,可以在这里添加逻辑 # if not is_initial_memory and some_score < threshold: # continue - + # 记录通过过滤的记忆(用于调试) if is_initial_memory: logger.debug( @@ -823,11 +823,11 @@ class MemoryTools: f"✅ 保留 {memory.id[:8]} [扩展]: 重要性={memory.importance:.2f}, " f"综合分数={final_score:.4f}" ) - + # 🆕 节点类型加权:对REFERENCE/ATTRIBUTE节点额外加分(促进事实性信息召回) if "REFERENCE" in node_types_count or "ATTRIBUTE" in node_types_count: final_score *= 1.1 # 10% 加成 - + # 🆕 用户指定的优先节点类型额外加权 if prefer_node_types: for prefer_type in prefer_node_types: @@ -835,7 +835,7 @@ class MemoryTools: final_score *= 1.15 # 15% 额外加成 logger.debug(f"记忆 {memory.id[:8]} 包含优先节点类型 {prefer_type},加权后分数: {final_score:.4f}") break - + memories_with_scores.append((memory, final_score, dominant_node_type)) # 按综合分数排序 @@ -845,7 +845,7 @@ class MemoryTools: # 统计过滤情况 total_candidates = len(all_memory_ids) filtered_count = total_candidates - len(memories_with_scores) - + # 6. 格式化结果(包含调试信息) results = [] for memory, score, node_type in memories_with_scores[:top_k]: @@ -866,7 +866,7 @@ class MemoryTools: f"过滤{filtered_count}个 (重要性过滤) → " f"最终返回{len(results)}条记忆" ) - + # 如果过滤率过高,发出警告 if total_candidates > 0: filter_rate = filtered_count / total_candidates @@ -1092,20 +1092,21 @@ class MemoryTools: response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300) import re + import orjson - + # 清理Markdown代码块 response = re.sub(r"```json\s*", "", response) response = re.sub(r"```\s*$", "", response).strip() # 解析JSON data = orjson.loads(response) - + # 提取查询列表 queries = data.get("queries", []) result_queries = [(item.get("text", "").strip(), float(item.get("weight", 0.5))) for item in queries if item.get("text", "").strip()] - + # 提取偏好节点类型 prefer_node_types = data.get("prefer_node_types", []) # 确保类型正确且有效 @@ -1154,7 +1155,7 @@ class MemoryTools: limit=top_k * 5, # 🔥 从2倍提升到5倍,提高初始召回率 min_similarity=0.0, # 不在这里过滤,交给后续评分 ) - + logger.debug(f"单查询向量搜索: 查询='{query}', 返回节点数={len(similar_nodes)}") if similar_nodes: logger.debug(f"Top 3相似度: {[f'{sim:.3f}' for _, sim, _ in similar_nodes[:3]]}") diff --git a/src/memory_graph/utils/graph_expansion.py b/src/memory_graph/utils/graph_expansion.py index f7c850f74..babfba788 100644 --- a/src/memory_graph/utils/graph_expansion.py +++ b/src/memory_graph/utils/graph_expansion.py @@ -62,7 +62,7 @@ async def expand_memories_with_semantic_filter( try: import time start_time = time.time() - + # 记录已访问的记忆,避免重复 visited_memories = set(initial_memory_ids) # 记录扩展的记忆及其分数 @@ -87,17 +87,17 @@ async def expand_memories_with_semantic_filter( # 获取该记忆的邻居记忆(通过边关系) neighbor_memory_ids = set() - + # 🆕 遍历记忆的所有边,收集邻居记忆(带边类型权重) edge_weights = {} # 记录通过不同边类型到达的记忆的权重 - + for edge in memory.edges: # 获取边的目标节点 target_node_id = edge.target_id source_node_id = edge.source_id - + # 🆕 根据边类型设置权重(优先扩展REFERENCE、ATTRIBUTE相关的边) - edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type) + edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, "value") else str(edge.edge_type) if edge_type_str == "REFERENCE": edge_weight = 1.3 # REFERENCE边权重最高(引用关系) elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]: @@ -108,18 +108,18 @@ async def expand_memories_with_semantic_filter( edge_weight = 0.9 # 
一般关系适中降权 else: edge_weight = 1.0 # 默认权重 - + # 通过节点找到其他记忆 for node_id in [target_node_id, source_node_id]: if node_id in graph_store.node_to_memories: for neighbor_id in graph_store.node_to_memories[node_id]: if neighbor_id not in edge_weights or edge_weights[neighbor_id] < edge_weight: edge_weights[neighbor_id] = edge_weight - + # 将权重高的邻居记忆加入候选 for neighbor_id, edge_weight in edge_weights.items(): neighbor_memory_ids.add((neighbor_id, edge_weight)) - + # 过滤掉已访问的和自己 filtered_neighbors = [] for neighbor_id, edge_weight in neighbor_memory_ids: @@ -129,7 +129,7 @@ async def expand_memories_with_semantic_filter( # 批量评估邻居记忆 for neighbor_mem_id, edge_weight in filtered_neighbors: candidates_checked += 1 - + neighbor_memory = graph_store.get_memory_by_id(neighbor_mem_id) if not neighbor_memory: continue @@ -139,7 +139,7 @@ async def expand_memories_with_semantic_filter( (n for n in neighbor_memory.nodes if n.has_embedding()), None ) - + if not topic_node or topic_node.embedding is None: continue @@ -179,11 +179,11 @@ async def expand_memories_with_semantic_filter( if len(expanded_memories) >= max_expanded: logger.debug(f"⏹️ 提前停止:已达到最大扩展数量 {max_expanded}") break - + # 早停检查 if len(expanded_memories) >= max_expanded: break - + # 记录本层统计 depth_stats.append({ "depth": depth + 1, @@ -199,20 +199,20 @@ async def expand_memories_with_semantic_filter( # 限制下一层的记忆数量,避免爆炸性增长 current_level_memories = next_level_memories[:max_expanded] - + # 每层让出控制权 await asyncio.sleep(0.001) # 排序并返回 sorted_results = sorted(expanded_memories.items(), key=lambda x: x[1], reverse=True)[:max_expanded] - + elapsed = time.time() - start_time logger.info( f"✅ 图扩展完成: 初始{len(initial_memory_ids)}个 → " f"扩展{len(sorted_results)}个新记忆 " f"(深度={max_depth}, 阈值={semantic_threshold:.2f}, 耗时={elapsed:.3f}s)" ) - + # 输出每层统计 for stat in depth_stats: logger.debug( diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index 4ff982d06..7c4d6c9ed 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -78,11 +78,9 @@ __all__ = [ # 消息 "MaiMessages", # 工具函数 - "ManifestValidator", "PluginInfo", # 增强命令系统 "PlusCommand", - "PlusCommandAdapter", "PythonDependency", "ToolInfo", "ToolParamType", diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index 616db0b88..86c8f3210 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -132,7 +132,7 @@ async def generate_reply( prompt_mode = "s4u" # 默认使用s4u模式 if action_data and "prompt_mode" in action_data: prompt_mode = action_data.get("prompt_mode", "s4u") - + # 将prompt_mode添加到available_actions中(作为特殊键) # 注意:这里我们需要暂时使用类型忽略,因为available_actions的类型定义不支持非ActionInfo值 if available_actions is None: diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py index a180b727b..efabbf4a5 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py @@ -362,7 +362,7 @@ class ChatterPlanFilter: return "最近没有聊天内容。", "没有未读消息。", [] stream_context = chat_stream.context_manager - + # 获取真正的已读和未读消息 read_messages = stream_context.context.history_messages # 已读消息存储在history_messages中 if not read_messages: @@ -652,30 +652,30 @@ class ChatterPlanFilter: if not action_info: logger.debug(f"动作 {action_name} 不在可用动作列表中,保留所有参数") return action_data - + # 获取该动作定义的合法参数 defined_params = set(action_info.action_parameters.keys()) - + # 合法参数集合 
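        # Illustrative aside (not part of this change): the whitelist below is a
        # plain set-membership filter over the action's declared parameters. For a
        # hypothetical action defined with
        #   action_parameters = {"content": "..."}
        # and LLM-produced
        #   action_data = {"content": "hi", "mood": "happy"}
        # the loop keeps {"content": "hi"} and logs ["mood"] as removed, i.e. the
        # equivalent of:
        #   filtered = {k: v for k, v in action_data.items() if k in valid_params}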
valid_params = defined_params - + # 过滤参数 filtered_data = {} removed_params = [] - + for key, value in action_data.items(): if key in valid_params: filtered_data[key] = value else: removed_params.append(key) - + # 记录被移除的参数 if removed_params: logger.info( f"🧹 [参数过滤] 动作 '{action_name}' 移除了多余参数: {removed_params}. " f"合法参数: {sorted(valid_params)}" ) - + return filtered_data def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]: diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py b/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py index 4c182a70c..23d19cc23 100644 --- a/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py @@ -545,14 +545,14 @@ async def execute_proactive_thinking(stream_id: str): # 获取或创建该聊天流的执行锁 if stream_id not in _execution_locks: _execution_locks[stream_id] = asyncio.Lock() - + lock = _execution_locks[stream_id] - + # 尝试获取锁,如果已被占用则跳过本次执行(防止重复) if lock.locked(): logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 已有正在执行的主动思考任务") return - + async with lock: logger.debug(f"🤔 开始主动思考 {stream_id}") @@ -563,13 +563,13 @@ async def execute_proactive_thinking(stream_id: str): from src.chat.message_receive.chat_stream import get_chat_manager chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) - + if chat_stream and chat_stream.context_manager.context.is_chatter_processing: logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 的 chatter 正在处理消息") return except Exception as e: logger.warning(f"检查 chatter 处理状态时出错: {e},继续执行") - + # 0.1 检查白名单/黑名单 # 从 stream_id 获取 stream_config 字符串进行验证 try: diff --git a/src/plugins/built_in/anti_injection_plugin/__init__.py b/src/plugins/built_in/anti_injection_plugin/__init__.py index 808164495..139e2e64d 100644 --- a/src/plugins/built_in/anti_injection_plugin/__init__.py +++ b/src/plugins/built_in/anti_injection_plugin/__init__.py @@ -31,4 +31,4 @@ __plugin_meta__ = PluginMetadata( # 导入插件主类 from .plugin import AntiInjectionPlugin -__all__ = ["__plugin_meta__", "AntiInjectionPlugin"] +__all__ = ["AntiInjectionPlugin", "__plugin_meta__"] diff --git a/src/plugins/built_in/anti_injection_plugin/checker.py b/src/plugins/built_in/anti_injection_plugin/checker.py index 136e4aae4..68c1284e6 100644 --- a/src/plugins/built_in/anti_injection_plugin/checker.py +++ b/src/plugins/built_in/anti_injection_plugin/checker.py @@ -8,8 +8,8 @@ import time from src.chat.security.interfaces import ( SecurityAction, - SecurityCheckResult, SecurityChecker, + SecurityCheckResult, SecurityLevel, ) from src.common.logger import get_logger diff --git a/src/plugins/built_in/anti_injection_plugin/processor.py b/src/plugins/built_in/anti_injection_plugin/processor.py index 9960f1521..fa0380569 100644 --- a/src/plugins/built_in/anti_injection_plugin/processor.py +++ b/src/plugins/built_in/anti_injection_plugin/processor.py @@ -4,7 +4,7 @@ 处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。 """ -from src.chat.security.interfaces import SecurityAction, SecurityCheckResult +from src.chat.security.interfaces import SecurityCheckResult from src.common.logger import get_logger from .counter_attack import CounterAttackGenerator diff --git a/src/plugins/built_in/core_actions/plugin.py b/src/plugins/built_in/core_actions/plugin.py index 3b1c7a014..5baaa3a8e 100644 --- a/src/plugins/built_in/core_actions/plugin.py +++ b/src/plugins/built_in/core_actions/plugin.py @@ 
-64,15 +64,15 @@ class CoreActionsPlugin(BasePlugin): # --- 根据配置注册组件 --- components: ClassVar = [] - + # 注册 reply 动作 if self.get_config("components.enable_reply", True): components.append((ReplyAction.get_action_info(), ReplyAction)) - + # 注册 respond 动作 if self.get_config("components.enable_respond", True): components.append((RespondAction.get_action_info(), RespondAction)) - + # 注册 emoji 动作 if self.get_config("components.enable_emoji", True): components.append((EmojiAction.get_action_info(), EmojiAction)) diff --git a/src/plugins/built_in/core_actions/reply.py b/src/plugins/built_in/core_actions/reply.py index 993ddb85a..9a90f7e33 100644 --- a/src/plugins/built_in/core_actions/reply.py +++ b/src/plugins/built_in/core_actions/reply.py @@ -22,23 +22,23 @@ class ReplyAction(BaseAction): - 专注于理解和回应单条消息的具体内容 - 适合 Focus 模式下的精准回复 """ - + # 动作基本信息 action_name = "reply" action_description = "针对特定消息进行精准回复。深度理解并回应单条消息的具体内容。需要指定目标消息ID。" - + # 激活设置 activation_type = ActionActivationType.ALWAYS # 回复动作总是可用 mode_enable = ChatMode.ALL # 在所有模式下都可用 parallel_action = False # 回复动作不能与其他动作并行 - + # 动作参数定义 action_parameters: ClassVar = { "target_message_id": "要回复的目标消息ID(必需,来自未读消息的 标签)", "content": "回复的具体内容(可选,由LLM生成)", "should_quote_reply": "是否引用原消息(可选,true/false,默认false。群聊中回复较早消息或需要明确指向时使用true)", } - + # 动作使用场景 action_require: ClassVar = [ "需要针对特定消息进行精准回复时使用", @@ -48,10 +48,10 @@ class ReplyAction(BaseAction): "群聊中需要明确回应某个特定用户或问题时使用", "关注单条消息的具体内容和上下文细节", ] - + # 关联类型 associated_types: ClassVar[list[str]] = ["text"] - + async def execute(self) -> tuple[bool, str]: """执行reply动作 @@ -70,21 +70,21 @@ class RespondAction(BaseAction): - 适合对于群聊消息下的宏观回应 - 避免与单一用户深度对话而忽略其他用户的消息 """ - + # 动作基本信息 action_name = "respond" action_description = "统一回应所有未读消息。理解整体对话动态和话题走向,生成连贯的回复。无需指定目标消息。" - + # 激活设置 activation_type = ActionActivationType.ALWAYS # 回应动作总是可用 mode_enable = ChatMode.ALL # 在所有模式下都可用 parallel_action = False # 回应动作不能与其他动作并行 - + # 动作参数定义 action_parameters: ClassVar = { "content": "回复的具体内容(可选,由LLM生成)", } - + # 动作使用场景 action_require: ClassVar = [ "需要统一回应多条未读消息时使用(Normal 模式专用)", @@ -94,10 +94,10 @@ class RespondAction(BaseAction): "适合群聊中的自然对话流,无需精确指向特定消息", "可以同时回应多个话题或参与者", ] - + # 关联类型 associated_types: ClassVar[list[str]] = ["text"] - + async def execute(self) -> tuple[bool, str]: """执行respond动作 diff --git a/src/plugins/built_in/maizone_refactored/services/content_service.py b/src/plugins/built_in/maizone_refactored/services/content_service.py index b446d1dbb..2dc95d949 100644 --- a/src/plugins/built_in/maizone_refactored/services/content_service.py +++ b/src/plugins/built_in/maizone_refactored/services/content_service.py @@ -6,10 +6,10 @@ import asyncio import base64 import datetime -import filetype from collections.abc import Callable import aiohttp +import filetype from maim_message import UserInfo from src.chat.message_receive.chat_stream import get_chat_manager diff --git a/src/plugins/built_in/system_management/plugin.py b/src/plugins/built_in/system_management/plugin.py index a1eb007ea..77638b4f4 100644 --- a/src/plugins/built_in/system_management/plugin.py +++ b/src/plugins/built_in/system_management/plugin.py @@ -6,7 +6,7 @@ import re from typing import ClassVar - +from src.chat.utils.prompt_component_manager import prompt_component_manager from src.plugin_system.apis import ( plugin_manage_api, ) @@ -74,6 +74,7 @@ class SystemCommand(PlusCommand): • `/system permission` - 权限管理 • `/system plugin` - 插件管理 • `/system schedule` - 定时任务管理 +• `/system prompt` - 提示词注入管理 """ elif target == "schedule": help_text = """📅 
定时任务管理帮助 @@ -113,8 +114,17 @@ class SystemCommand(PlusCommand): • /system permission nodes [插件名] - 查看权限节点 • /system permission allnodes - 查看所有权限节点详情 """ - await self.send_text(help_text) + elif target == "prompt": + help_text = """📝 提示词注入管理帮助 +🔎 查询命令 (需要 `system.prompt.view` 权限): +• `/system prompt help` - 显示此帮助 +• `/system prompt map` - 查看全局注入关系图 +• `/system prompt targets` - 列出所有可被注入的核心提示词 +• `/system prompt components` - 列出所有已注册的提示词组件 +• `/system prompt info <目标名>` - 查看特定核心提示词的注入详情 +""" + await self.send_text(help_text) # ================================================================= # Plugin Management Section @@ -231,6 +241,101 @@ class SystemCommand(PlusCommand): else: await self.send_text(f"❌ 恢复任务失败: `{schedule_id}`") + # ================================================================= + # Prompt Management Section + # ================================================================= + async def _handle_prompt_commands(self, args: list[str]): + """处理提示词管理相关命令""" + if not args or args[0].lower() in ["help", "帮助"]: + await self._show_help("prompt") + return + + action = args[0].lower() + remaining_args = args[1:] + + if action in ["map", "关系图"]: + await self._show_injection_map() + elif action in ["targets", "目标"]: + await self._list_core_prompts() + elif action in ["components", "组件"]: + await self._list_prompt_components() + elif action in ["info", "详情"] and remaining_args: + await self._get_prompt_injection_info(remaining_args[0]) + else: + await self.send_text("❌ 提示词管理命令不合法\n使用 /system prompt help 查看帮助") + + @require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限") + async def _show_injection_map(self): + """显示全局注入关系图""" + injection_map = await prompt_component_manager.get_full_injection_map() + if not injection_map: + await self.send_text("📊 当前没有任何提示词注入关系") + return + + response_parts = ["📊 全局提示词注入关系图:\n"] + for target, injections in injection_map.items(): + if injections: + response_parts.append(f"🎯 **{target}** (注入源):") + for inj in injections: + source_tag = f"({inj['source']})" if inj['source'] != 'static_default' else '' + response_parts.append(f" ⎿ `{inj['name']}` (优先级: {inj['priority']}) {source_tag}") + else: + response_parts.append(f"🎯 **{target}** (无注入)") + + await self._send_long_message("\n".join(response_parts)) + + @require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限") + async def _list_core_prompts(self): + """列出所有可注入的核心提示词""" + targets = prompt_component_manager.get_core_prompts() + if not targets: + await self.send_text("🎯 当前没有可注入的核心提示词") + return + + response = "🎯 所有可注入的核心提示词:\n" + "\n".join([f"• `{name}`" for name in targets]) + await self.send_text(response) + + @require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限") + async def _list_prompt_components(self): + """列出所有已注册的提示词组件""" + components = prompt_component_manager.get_registered_prompt_component_info() + if not components: + await self.send_text("🧩 当前没有已注册的提示词组件") + return + + response_parts = [f"🧩 已注册的提示词组件 (共 {len(components)} 个):"] + for comp in components: + response_parts.append(f"• `{comp.name}` (来自: `{comp.plugin_name}`)") + + await self._send_long_message("\n".join(response_parts)) + + + @require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限") + async def _get_prompt_injection_info(self, target_name: str): + """获取特定核心提示词的注入详情""" + injections = await prompt_component_manager.get_injections_for_prompt(target_name) + + core_prompts = prompt_component_manager.get_core_prompts() + if target_name not in core_prompts: + await self.send_text(f"❌ 
找不到核心提示词: `{target_name}`") + return + + if not injections: + await self.send_text(f"🎯 核心提示词 `{target_name}` 当前没有被任何组件注入。") + return + + response_parts = [f"🔎 核心提示词 `{target_name}` 的注入详情:"] + for inj in injections: + response_parts.append( + f" • **`{inj['name']}`** (优先级: {inj['priority']})" + ) + response_parts.append(f" - 来源: `{inj['source']}`") + response_parts.append(f" - 类型: `{inj['injection_type']}`") + if inj.get('target_content'): + response_parts.append(f" - 操作目标: `{inj['target_content']}`") + + await self.send_text("\n".join(response_parts)) + # ================================================================= # Permission Management Section # ================================================================= diff --git a/src/schedule/unified_scheduler.py b/src/schedule/unified_scheduler.py index 30eaf61b1..c0af4db16 100644 --- a/src/schedule/unified_scheduler.py +++ b/src/schedule/unified_scheduler.py @@ -17,7 +17,6 @@ import uuid import weakref from collections import defaultdict from collections.abc import Awaitable, Callable -from contextlib import suppress from dataclasses import dataclass, field from datetime import datetime from enum import Enum @@ -31,29 +30,30 @@ logger = get_logger("unified_scheduler") # ==================== 配置和常量 ==================== + @dataclass class SchedulerConfig: """调度器配置""" - + # 检查间隔 check_interval: float = 1.0 # 主循环检查间隔(秒) deadlock_check_interval: float = 30.0 # 死锁检查间隔(秒) - + # 超时配置 task_default_timeout: float = 300.0 # 默认任务超时(5分钟) task_cancel_timeout: float = 10.0 # 任务取消超时(10秒) shutdown_timeout: float = 30.0 # 关闭超时(30秒) deadlock_threshold: float = 600.0 # 死锁检测阈值(10分钟,超过此时间视为死锁) - + # 并发控制 max_concurrent_tasks: int = 100 # 最大并发任务数 enable_task_semaphore: bool = True # 是否启用任务信号量 - + # 重试配置 enable_retry: bool = True # 是否启用失败重试 max_retries: int = 3 # 最大重试次数 retry_delay: float = 5.0 # 重试延迟(秒) - + # 资源管理 cleanup_interval: float = 60.0 # 清理已完成任务的间隔(秒) keep_completed_tasks: int = 100 # 保留的已完成任务数(用于统计) @@ -61,8 +61,10 @@ class SchedulerConfig: # ==================== 枚举类型 ==================== + class TriggerType(Enum): """触发类型枚举""" + TIME = "time" # 时间触发 EVENT = "event" # 事件触发(通过 event_manager) CUSTOM = "custom" # 自定义条件触发 @@ -70,6 +72,7 @@ class TriggerType(Enum): class TaskStatus(Enum): """任务状态枚举""" + PENDING = "pending" # 等待触发 RUNNING = "running" # 正在执行 COMPLETED = "completed" # 已完成 @@ -81,9 +84,11 @@ class TaskStatus(Enum): # ==================== 任务模型 ==================== + @dataclass class TaskExecution: """任务执行记录""" + execution_id: str started_at: datetime ended_at: datetime | None = None @@ -91,21 +96,21 @@ class TaskExecution: error: Exception | None = None result: Any = None duration: float = 0.0 - + def complete(self, result: Any = None) -> None: """标记执行完成""" self.ended_at = datetime.now() self.status = TaskStatus.COMPLETED self.result = result self.duration = (self.ended_at - self.started_at).total_seconds() - + def fail(self, error: Exception) -> None: """标记执行失败""" self.ended_at = datetime.now() self.status = TaskStatus.FAILED self.error = error self.duration = (self.ended_at - self.started_at).total_seconds() - + def cancel(self) -> None: """标记执行取消""" self.ended_at = datetime.now() @@ -116,79 +121,76 @@ class TaskExecution: @dataclass class ScheduleTask: """调度任务模型(重构版)""" - + # 基本信息 schedule_id: str task_name: str callback: Callable[..., Awaitable[Any]] - + # 触发配置 trigger_type: TriggerType trigger_config: dict[str, Any] is_recurring: bool = False - + # 回调参数 callback_args: tuple = field(default_factory=tuple) callback_kwargs: dict = 
field(default_factory=dict) - + # 状态信息 status: TaskStatus = TaskStatus.PENDING created_at: datetime = field(default_factory=datetime.now) last_triggered_at: datetime | None = None next_trigger_at: datetime | None = None - + # 统计信息 trigger_count: int = 0 success_count: int = 0 failure_count: int = 0 total_execution_time: float = 0.0 - + # 执行记录(弱引用,避免内存泄漏) execution_history: list[TaskExecution] = field(default_factory=list) current_execution: TaskExecution | None = None - + # 重试配置 max_retries: int = 0 retry_count: int = 0 last_error: Exception | None = None - + # 超时配置 timeout: float | None = None - + # 运行时引用 _asyncio_task: asyncio.Task | None = field(default=None, init=False, repr=False) _weak_scheduler: Any = field(default=None, init=False, repr=False) - + def __repr__(self) -> str: return ( f"ScheduleTask(id={self.schedule_id[:8]}..., " f"name={self.task_name}, type={self.trigger_type.value}, " f"status={self.status.value}, recurring={self.is_recurring})" ) - + def is_active(self) -> bool: """任务是否活跃(可以被触发)""" return self.status in (TaskStatus.PENDING, TaskStatus.RUNNING) - + def can_trigger(self) -> bool: """任务是否可以被触发""" return self.status == TaskStatus.PENDING - + def start_execution(self) -> TaskExecution: """开始新的执行""" - execution = TaskExecution( - execution_id=str(uuid.uuid4()), - started_at=datetime.now() - ) + execution = TaskExecution(execution_id=str(uuid.uuid4()), started_at=datetime.now()) self.current_execution = execution self.status = TaskStatus.RUNNING return execution - + def finish_execution(self, success: bool, result: Any = None, error: Exception | None = None) -> None: """完成当前执行""" if not self.current_execution: return - + if success: self.current_execution.complete(result) self.success_count += 1 @@ -197,18 +199,18 @@ class ScheduleTask: self.current_execution.fail(error or Exception("Unknown error")) self.failure_count += 1 self.last_error = error - + self.total_execution_time += self.current_execution.duration - + # 保留最近10条执行记录 self.execution_history.append(self.current_execution) if len(self.execution_history) > 10: self.execution_history.pop(0) - + self.current_execution = None self.last_triggered_at = datetime.now() self.trigger_count += 1 - + # 更新状态 if self.is_recurring: self.status = TaskStatus.PENDING @@ -218,76 +220,77 @@ class ScheduleTask: # ==================== 死锁检测器(重构版)==================== + class DeadlockDetector: """死锁检测器(重构版) - + 功能增强: 1. 多级超时检测 2. 任务健康度评分 3. 
自动恢复建议 """ - + def __init__(self, config: SchedulerConfig): self.config = config self._monitored_tasks: dict[str, tuple[float, str]] = {} # task_id -> (start_time, task_name) self._timeout_history: defaultdict[str, list[float]] = defaultdict(list) # task_id -> [timeout_times] - + def register_task(self, task_id: str, task_name: str) -> None: """注册任务开始监控""" self._monitored_tasks[task_id] = (time.time(), task_name) - + def unregister_task(self, task_id: str) -> None: """取消注册任务""" self._monitored_tasks.pop(task_id, None) - + def get_running_time(self, task_id: str) -> float: """获取任务运行时间""" if task_id not in self._monitored_tasks: return 0.0 start_time, _ = self._monitored_tasks[task_id] return time.time() - start_time - + def check_deadlocks(self) -> list[tuple[str, float, str]]: """检查死锁任务 - + Returns: List[Tuple[task_id, runtime, task_name]]: 疑似死锁的任务列表 """ current_time = time.time() deadlocked = [] - + for task_id, (start_time, task_name) in list(self._monitored_tasks.items()): runtime = current_time - start_time # 使用死锁阈值而不是默认超时 if runtime > self.config.deadlock_threshold: deadlocked.append((task_id, runtime, task_name)) - + return deadlocked - + def record_timeout(self, task_id: str) -> None: """记录超时事件""" self._timeout_history[task_id].append(time.time()) # 只保留最近10次记录 if len(self._timeout_history[task_id]) > 10: self._timeout_history[task_id].pop(0) - + def get_health_score(self, task_id: str) -> float: """计算任务健康度 (0.0-1.0) - + 基于超时频率计算,频繁超时的任务健康度低 """ if task_id not in self._timeout_history: return 1.0 - + timeouts = self._timeout_history[task_id] if not timeouts: return 1.0 - + # 最近10次执行中的超时次数 recent_count = len(timeouts) # 健康度 = 1 - (超时次数 / 10) return max(0.0, 1.0 - (recent_count / 10.0)) - + def clear(self) -> None: """清空所有监控数据""" self._monitored_tasks.clear() @@ -296,9 +299,10 @@ class DeadlockDetector: # ==================== 统一调度器(完全重构版)==================== + class UnifiedScheduler: """统一调度器(完全重构版) - + 核心改进: 1. 完全无锁设计 - 利用 asyncio 的单线程特性 2. 任务完全隔离 - 使用独立的 Task,互不阻塞 @@ -307,7 +311,7 @@ class UnifiedScheduler: 5. 资源自动清理 - 防止内存泄漏 6. 并发控制 - 可配置的并发限制 7. 
健康监控 - 任务健康度评分和统计 - + 特点: - 每秒检查一次所有任务 - 自动执行到期任务 @@ -316,247 +320,233 @@ class UnifiedScheduler: - 与 event_manager 集成 - 内置死锁检测和恢复机制 """ - + def __init__(self, config: SchedulerConfig | None = None): self.config = config or SchedulerConfig() - + # 任务存储 self._tasks: dict[str, ScheduleTask] = {} self._tasks_by_name: dict[str, str] = {} # task_name -> schedule_id 快速查找 - + # 运行状态 self._running = False self._stopping = False - + # 后台任务 self._check_loop_task: asyncio.Task | None = None self._deadlock_check_task: asyncio.Task | None = None self._cleanup_task: asyncio.Task | None = None - + # 事件订阅追踪 self._event_subscriptions: dict[str | EventType, set[str]] = defaultdict(set) # event -> {task_ids} - + # 死锁检测器 self._deadlock_detector = DeadlockDetector(self.config) - + # 并发控制 self._task_semaphore: asyncio.Semaphore | None = None if self.config.enable_task_semaphore: self._task_semaphore = asyncio.Semaphore(self.config.max_concurrent_tasks) - + # 统计信息 self._total_executions = 0 self._total_failures = 0 self._total_timeouts = 0 self._start_time: datetime | None = None - + # 已完成任务缓存(用于统计) self._completed_tasks: list[ScheduleTask] = [] - + # ==================== 生命周期管理 ==================== - + async def start(self) -> None: """启动调度器""" if self._running: logger.warning("调度器已在运行中") return - + logger.info("正在启动统一调度器...") self._running = True self._stopping = False self._start_time = datetime.now() - + # 启动后台任务 - self._check_loop_task = asyncio.create_task( - self._check_loop(), - name="scheduler_check_loop" - ) - self._deadlock_check_task = asyncio.create_task( - self._deadlock_check_loop(), - name="scheduler_deadlock_check" - ) - self._cleanup_task = asyncio.create_task( - self._cleanup_loop(), - name="scheduler_cleanup" - ) - + self._check_loop_task = asyncio.create_task(self._check_loop(), name="scheduler_check_loop") + self._deadlock_check_task = asyncio.create_task(self._deadlock_check_loop(), name="scheduler_deadlock_check") + self._cleanup_task = asyncio.create_task(self._cleanup_loop(), name="scheduler_cleanup") + # 注册到 event_manager try: from src.plugin_system.core.event_manager import event_manager + event_manager.register_scheduler_callback(self._handle_event_trigger) logger.debug("调度器已注册到 event_manager") except ImportError: logger.warning("无法导入 event_manager,事件触发功能将不可用") - + logger.info("统一调度器已启动") - + async def stop(self) -> None: """停止调度器(优雅关闭)""" if not self._running: return - + logger.info("正在停止统一调度器...") self._stopping = True self._running = False - + # 取消后台任务 background_tasks = [ self._check_loop_task, self._deadlock_check_task, self._cleanup_task, ] - + for task in background_tasks: if task and not task.done(): task.cancel() - + # 等待后台任务完成 await asyncio.gather(*[t for t in background_tasks if t], return_exceptions=True) - + # 取消注册 event_manager try: from src.plugin_system.core.event_manager import event_manager + event_manager.unregister_scheduler_callback() logger.debug("调度器已从 event_manager 注销") except ImportError: pass - + # 取消所有正在执行的任务 await self._cancel_all_running_tasks() - + # 显示最终统计 stats = self.get_statistics() - logger.info(f"调度器最终统计: 总任务={stats['total_tasks']}, " - f"执行次数={stats['total_executions']}, " - f"失败={stats['total_failures']}") - + logger.info( + f"调度器最终统计: 总任务={stats['total_tasks']}, " + f"执行次数={stats['total_executions']}, " + f"失败={stats['total_failures']}" + ) + # 清理资源 self._tasks.clear() self._tasks_by_name.clear() self._event_subscriptions.clear() self._completed_tasks.clear() self._deadlock_detector.clear() - + logger.info("统一调度器已停止") - + async def 
_cancel_all_running_tasks(self) -> None: """取消所有正在运行的任务""" running_tasks = [ - task for task in self._tasks.values() - if task.status == TaskStatus.RUNNING and task._asyncio_task + task for task in self._tasks.values() if task.status == TaskStatus.RUNNING and task._asyncio_task ] - + if not running_tasks: return - + logger.info(f"正在取消 {len(running_tasks)} 个运行中的任务...") - + # 第一阶段:发送取消信号 for task in running_tasks: if task._asyncio_task and not task._asyncio_task.done(): task._asyncio_task.cancel() - + # 第二阶段:等待取消完成(带超时) cancel_tasks = [ - task._asyncio_task for task in running_tasks - if task._asyncio_task and not task._asyncio_task.done() + task._asyncio_task for task in running_tasks if task._asyncio_task and not task._asyncio_task.done() ] - + if cancel_tasks: try: await asyncio.wait_for( - asyncio.gather(*cancel_tasks, return_exceptions=True), - timeout=self.config.shutdown_timeout + asyncio.gather(*cancel_tasks, return_exceptions=True), timeout=self.config.shutdown_timeout ) logger.info("所有任务已成功取消") except asyncio.TimeoutError: logger.warning(f"部分任务取消超时({self.config.shutdown_timeout}秒),强制停止") - + # ==================== 后台循环 ==================== - + async def _check_loop(self) -> None: """主循环:定期检查和触发任务""" logger.debug("调度器主循环已启动") - + while self._running: try: await asyncio.sleep(self.config.check_interval) - + if not self._stopping: # 使用 create_task 避免阻塞循环 - asyncio.create_task( - self._check_and_trigger_tasks(), - name="check_trigger_tasks" - ) - + asyncio.create_task(self._check_and_trigger_tasks(), name="check_trigger_tasks") + except asyncio.CancelledError: logger.debug("调度器主循环被取消") break except Exception as e: logger.error(f"调度器主循环发生错误: {e}", exc_info=True) - + async def _deadlock_check_loop(self) -> None: """死锁检测循环""" logger.debug("死锁检测循环已启动") - + while self._running: try: await asyncio.sleep(self.config.deadlock_check_interval) - + if not self._stopping: # 使用 create_task 避免阻塞循环,并限制错误传播 - asyncio.create_task( - self._safe_check_and_handle_deadlocks(), - name="deadlock_check" - ) - + asyncio.create_task(self._safe_check_and_handle_deadlocks(), name="deadlock_check") + except asyncio.CancelledError: logger.debug("死锁检测循环被取消") break except Exception as e: logger.error(f"死锁检测循环发生错误: {e}", exc_info=True) # 继续运行,不因单次错误停止 - + async def _cleanup_loop(self) -> None: """清理循环:定期清理已完成的任务""" logger.debug("清理循环已启动") - + while self._running: try: await asyncio.sleep(self.config.cleanup_interval) - + if not self._stopping: await self._cleanup_completed_tasks() - + except asyncio.CancelledError: logger.debug("清理循环被取消") break except Exception as e: logger.error(f"清理循环发生错误: {e}", exc_info=True) - + # ==================== 任务触发逻辑 ==================== - + async def _check_and_trigger_tasks(self) -> None: """检查并触发到期任务(完全无锁设计)""" current_time = datetime.now() tasks_to_trigger: list[ScheduleTask] = [] - + # 第一阶段:收集需要触发的任务 for task in list(self._tasks.values()): if not task.can_trigger(): continue - + try: should_trigger = await self._should_trigger_task(task, current_time) if should_trigger: tasks_to_trigger.append(task) except Exception as e: logger.error(f"检查任务 {task.task_name} 触发条件时出错: {e}", exc_info=True) - + # 第二阶段:并发触发所有任务 if tasks_to_trigger: await self._trigger_tasks_concurrently(tasks_to_trigger) - + async def _should_trigger_task(self, task: ScheduleTask, current_time: datetime) -> bool: """判断任务是否应该触发""" if task.trigger_type == TriggerType.TIME: @@ -565,17 +555,17 @@ class UnifiedScheduler: return await self._check_custom_trigger(task) # EVENT 类型由 event_manager 触发 return False - + def 
_check_time_trigger(self, task: ScheduleTask, current_time: datetime) -> bool: """检查时间触发条件""" config = task.trigger_config - + # 检查 trigger_at if "trigger_at" in config: trigger_time = config["trigger_at"] if isinstance(trigger_time, str): trigger_time = datetime.fromisoformat(trigger_time) - + if task.is_recurring and "interval_seconds" in config: # 循环任务:检查是否达到间隔 if task.last_triggered_at is None: @@ -586,7 +576,7 @@ class UnifiedScheduler: else: # 一次性任务:检查是否到达触发时间 return current_time >= trigger_time - + # 检查 delay_seconds elif "delay_seconds" in config: if task.last_triggered_at is None: @@ -597,16 +587,16 @@ class UnifiedScheduler: # 后续触发:从上次触发时间算起 elapsed = (current_time - task.last_triggered_at).total_seconds() return elapsed >= config["delay_seconds"] - + return False - + async def _check_custom_trigger(self, task: ScheduleTask) -> bool: """检查自定义触发条件""" condition_func = task.trigger_config.get("condition_func") if not condition_func or not callable(condition_func): logger.warning(f"任务 {task.task_name} 的自定义条件函数无效") return False - + try: if asyncio.iscoroutinefunction(condition_func): result = await condition_func() @@ -616,48 +606,41 @@ class UnifiedScheduler: except Exception as e: logger.error(f"执行任务 {task.task_name} 的自定义条件函数时出错: {e}", exc_info=True) return False - + async def _trigger_tasks_concurrently(self, tasks: list[ScheduleTask]) -> None: """并发触发多个任务""" logger.debug(f"并发触发 {len(tasks)} 个任务") - + # 为每个任务创建独立的执行 Task execution_tasks = [] for task in tasks: - exec_task = asyncio.create_task( - self._execute_task(task), - name=f"exec_{task.task_name}" - ) + exec_task = asyncio.create_task(self._execute_task(task), name=f"exec_{task.task_name}") task._asyncio_task = exec_task execution_tasks.append(exec_task) - + # 等待所有任务完成(不阻塞主循环) # 使用 return_exceptions=True 确保单个任务失败不影响其他任务 await asyncio.gather(*execution_tasks, return_exceptions=True) - + async def _execute_task(self, task: ScheduleTask) -> None: """执行单个任务(完全隔离)""" execution = task.start_execution() self._deadlock_detector.register_task(task.schedule_id, task.task_name) - + try: # 使用信号量控制并发 async with self._acquire_semaphore(): # 应用超时保护 timeout = task.timeout or self.config.task_default_timeout - + try: - await asyncio.wait_for( - self._run_callback(task), - timeout=timeout - ) - + await asyncio.wait_for(self._run_callback(task), timeout=timeout) + # 执行成功 task.finish_execution(success=True) self._total_executions += 1 - logger.debug(f"任务 {task.task_name} 执行成功 " - f"(第{task.trigger_count}次)") - + logger.debug(f"任务 {task.task_name} 执行成功 (第{task.trigger_count}次)") + except asyncio.TimeoutError: # 任务超时 logger.warning(f"任务 {task.task_name} 执行超时 ({timeout}秒)") @@ -665,7 +648,7 @@ class UnifiedScheduler: task.finish_execution(success=False, error=TimeoutError(f"Task timeout after {timeout}s")) self._total_timeouts += 1 self._deadlock_detector.record_timeout(task.schedule_id) - + except asyncio.CancelledError: # 任务被取消 logger.debug(f"任务 {task.task_name} 被取消") @@ -673,30 +656,32 @@ class UnifiedScheduler: task.current_execution.cancel() task.status = TaskStatus.CANCELLED raise # 重新抛出,让上层处理 - + except Exception as e: # 任务执行失败 logger.error(f"任务 {task.task_name} 执行失败: {e}", exc_info=True) task.finish_execution(success=False, error=e) self._total_failures += 1 - + # 检查是否需要重试 if self.config.enable_retry and task.retry_count < task.max_retries: task.retry_count += 1 - logger.info(f"任务 {task.task_name} 将在 {self.config.retry_delay}秒后重试 " - f"({task.retry_count}/{task.max_retries})") + logger.info( + f"任务 {task.task_name} 将在 
{self.config.retry_delay}秒后重试 " + f"({task.retry_count}/{task.max_retries})" + ) await asyncio.sleep(self.config.retry_delay) task.status = TaskStatus.PENDING # 重置为待触发状态 - + finally: # 清理 self._deadlock_detector.unregister_task(task.schedule_id) task._asyncio_task = None - + # 如果是一次性任务且成功完成,移动到已完成列表 if not task.is_recurring and task.status == TaskStatus.COMPLETED: await self._move_to_completed(task) - + async def _run_callback(self, task: ScheduleTask) -> Any: """运行任务回调函数""" try: @@ -706,14 +691,13 @@ class UnifiedScheduler: # 同步函数在线程池中运行,避免阻塞事件循环 loop = asyncio.get_running_loop() result = await loop.run_in_executor( - None, - lambda: task.callback(*task.callback_args, **task.callback_kwargs) + None, lambda: task.callback(*task.callback_args, **task.callback_kwargs) ) return result except Exception as e: logger.error(f"执行任务 {task.task_name} 的回调函数时出错: {e}", exc_info=True) raise - + def _acquire_semaphore(self): """获取信号量(如果启用)""" if self._task_semaphore: @@ -721,14 +705,15 @@ class UnifiedScheduler: else: # 返回一个空的上下文管理器 from contextlib import nullcontext + return nullcontext() - + async def _move_to_completed(self, task: ScheduleTask) -> None: """将任务移动到已完成列表""" if task.schedule_id in self._tasks: self._tasks.pop(task.schedule_id) self._tasks_by_name.pop(task.task_name, None) - + # 清理事件订阅 if task.trigger_type == TriggerType.EVENT: event_name = task.trigger_config.get("event_name") @@ -736,108 +721,101 @@ class UnifiedScheduler: self._event_subscriptions[event_name].discard(task.schedule_id) if not self._event_subscriptions[event_name]: del self._event_subscriptions[event_name] - + # 添加到已完成列表 self._completed_tasks.append(task) if len(self._completed_tasks) > self.config.keep_completed_tasks: self._completed_tasks.pop(0) - + logger.debug(f"一次性任务 {task.task_name} 已完成并移除") - + # ==================== 事件触发处理 ==================== - + async def _handle_event_trigger(self, event_name: str | EventType, event_params: dict[str, Any]) -> None: """处理来自 event_manager 的事件通知(无锁设计)""" task_ids = self._event_subscriptions.get(event_name, set()) if not task_ids: return - + # 收集需要触发的任务 tasks_to_trigger = [] for task_id in list(task_ids): # 使用 list() 避免迭代时修改 task = self._tasks.get(task_id) if task and task.can_trigger(): tasks_to_trigger.append(task) - + if not tasks_to_trigger: return - + logger.debug(f"事件 '{event_name}' 触发 {len(tasks_to_trigger)} 个任务") - + # 并发执行所有事件任务 execution_tasks = [] for task in tasks_to_trigger: # 将事件参数注入到回调 exec_task = asyncio.create_task( - self._execute_event_task(task, event_params), - name=f"event_exec_{task.task_name}" + self._execute_event_task(task, event_params), name=f"event_exec_{task.task_name}" ) task._asyncio_task = exec_task execution_tasks.append(exec_task) - + # 等待所有任务完成 await asyncio.gather(*execution_tasks, return_exceptions=True) - + async def _execute_event_task(self, task: ScheduleTask, event_params: dict[str, Any]) -> None: """执行事件触发的任务""" execution = task.start_execution() self._deadlock_detector.register_task(task.schedule_id, task.task_name) - + try: async with self._acquire_semaphore(): timeout = task.timeout or self.config.task_default_timeout - + try: # 合并事件参数和任务参数 merged_kwargs = {**task.callback_kwargs, **event_params} - + if asyncio.iscoroutinefunction(task.callback): - await asyncio.wait_for( - task.callback(*task.callback_args, **merged_kwargs), - timeout=timeout - ) + await asyncio.wait_for(task.callback(*task.callback_args, **merged_kwargs), timeout=timeout) else: loop = asyncio.get_running_loop() await asyncio.wait_for( - loop.run_in_executor( - None, 
- lambda: task.callback(*task.callback_args, **merged_kwargs) - ), - timeout=timeout + loop.run_in_executor(None, lambda: task.callback(*task.callback_args, **merged_kwargs)), + timeout=timeout, ) - + task.finish_execution(success=True) self._total_executions += 1 logger.debug(f"事件任务 {task.task_name} 执行成功") - + except asyncio.TimeoutError: logger.warning(f"事件任务 {task.task_name} 执行超时") task.status = TaskStatus.TIMEOUT task.finish_execution(success=False, error=TimeoutError()) self._total_timeouts += 1 self._deadlock_detector.record_timeout(task.schedule_id) - + except asyncio.CancelledError: logger.debug(f"事件任务 {task.task_name} 被取消") if task.current_execution: task.current_execution.cancel() task.status = TaskStatus.CANCELLED raise - + except Exception as e: logger.error(f"事件任务 {task.task_name} 执行失败: {e}", exc_info=True) task.finish_execution(success=False, error=e) self._total_failures += 1 - + finally: self._deadlock_detector.unregister_task(task.schedule_id) task._asyncio_task = None - + if not task.is_recurring and task.status == TaskStatus.COMPLETED: await self._move_to_completed(task) - + # ==================== 死锁检测和处理 ==================== - + async def _safe_check_and_handle_deadlocks(self) -> None: """安全地检查并处理死锁任务(带错误隔离)""" try: @@ -846,28 +824,25 @@ class UnifiedScheduler: logger.error("死锁检测发生递归错误,跳过本轮检测") except Exception as e: logger.error(f"死锁检测处理失败: {e}", exc_info=True) - + async def _check_and_handle_deadlocks(self) -> None: """检查并处理死锁任务""" deadlocked = self._deadlock_detector.check_deadlocks() - + if not deadlocked: return - + logger.warning(f"检测到 {len(deadlocked)} 个可能的死锁任务") - + for task_id, runtime, task_name in deadlocked: task = self._tasks.get(task_id) if not task: self._deadlock_detector.unregister_task(task_id) continue - + health = self._deadlock_detector.get_health_score(task_id) - logger.warning( - f"任务 {task_name} 疑似死锁: " - f"运行时间={runtime:.1f}秒, 健康度={health:.2f}" - ) - + logger.warning(f"任务 {task_name} 疑似死锁: 运行时间={runtime:.1f}秒, 健康度={health:.2f}") + # 尝试取消任务(每个取消操作独立处理错误) try: await self._cancel_task(task, reason="deadlock detected") @@ -877,73 +852,69 @@ class UnifiedScheduler: task._asyncio_task = None task.status = TaskStatus.CANCELLED self._deadlock_detector.unregister_task(task_id) - + async def _cancel_task(self, task: ScheduleTask, reason: str = "manual") -> bool: """取消正在运行的任务(多级超时机制)""" if not task._asyncio_task or task._asyncio_task.done(): return True - + logger.info(f"取消任务 {task.task_name} (原因: {reason})") - + # 第一阶段:发送取消信号 task._asyncio_task.cancel() - + # 第二阶段:渐进式等待(使用 asyncio.wait 避免递归) timeouts = [1.0, 3.0, 5.0, 10.0] for i, timeout in enumerate(timeouts): try: # 使用 asyncio.wait 代替 wait_for,避免重新抛出异常 - done, pending = await asyncio.wait( - {task._asyncio_task}, - timeout=timeout - ) - + done, pending = await asyncio.wait({task._asyncio_task}, timeout=timeout) + if done: # 任务已完成(可能是正常完成或被取消) - logger.debug(f"任务 {task.task_name} 在阶段 {i+1} 成功停止") + logger.debug(f"任务 {task.task_name} 在阶段 {i + 1} 成功停止") return True - + # 超时:继续下一阶段或放弃 if i < len(timeouts) - 1: - logger.warning(f"任务 {task.task_name} 取消阶段 {i+1} 超时,继续等待...") + logger.warning(f"任务 {task.task_name} 取消阶段 {i + 1} 超时,继续等待...") continue else: logger.error(f"任务 {task.task_name} 取消失败,强制清理") break - + except Exception as e: logger.error(f"取消任务 {task.task_name} 时发生异常: {e}", exc_info=True) return False - + # 第三阶段:强制清理 task._asyncio_task = None task.status = TaskStatus.CANCELLED self._deadlock_detector.unregister_task(task.schedule_id) return False - + # ==================== 资源清理 ==================== - + 
async def _cleanup_completed_tasks(self) -> None: """清理已完成的任务""" # 清理已完成的一次性任务 completed_tasks = [ - task for task in self._tasks.values() - if not task.is_recurring and task.status == TaskStatus.COMPLETED + task for task in self._tasks.values() if not task.is_recurring and task.status == TaskStatus.COMPLETED ] - + for task in completed_tasks: await self._move_to_completed(task) - + if completed_tasks: logger.debug(f"清理了 {len(completed_tasks)} 个已完成的任务") - + # 清理已完成的 asyncio Task for task in list(self._tasks.values()): if task._asyncio_task and task._asyncio_task.done(): task._asyncio_task = None - + # ==================== 任务管理 API ==================== - + async def create_schedule( self, callback: Callable[..., Awaitable[Any]], @@ -958,7 +929,7 @@ class UnifiedScheduler: max_retries: int = 0, ) -> str: """创建调度任务 - + Args: callback: 回调函数(必须是异步函数) trigger_type: 触发类型 @@ -970,27 +941,27 @@ class UnifiedScheduler: force_overwrite: 如果同名任务已存在,是否强制覆盖 timeout: 任务超时时间(秒),None表示使用默认值 max_retries: 最大重试次数 - + Returns: str: 创建的 schedule_id - + Raises: ValueError: 如果同名任务已存在且未启用强制覆盖 RuntimeError: 如果调度器未运行 """ if not self._running: raise RuntimeError("调度器未运行,请先调用 start()") - + # 生成任务ID和名称 schedule_id = str(uuid.uuid4()) if task_name is None: task_name = f"Task-{schedule_id[:8]}" - + # 检查同名任务 if task_name in self._tasks_by_name: existing_id = self._tasks_by_name[task_name] existing_task = self._tasks.get(existing_id) - + if existing_task and existing_task.is_active(): if force_overwrite: logger.info(f"检测到同名活跃任务 '{task_name}',启用强制覆盖,移除现有任务") @@ -1000,7 +971,7 @@ class UnifiedScheduler: f"任务名称 '{task_name}' 已存在活跃任务 (ID: {existing_id[:8]}...)。" f"如需覆盖,请设置 force_overwrite=True" ) - + # 创建任务 task = ScheduleTask( schedule_id=schedule_id, @@ -1014,14 +985,14 @@ class UnifiedScheduler: timeout=timeout, max_retries=max_retries, ) - + # 保存弱引用到调度器(避免循环引用) task._weak_scheduler = weakref.ref(self) - + # 注册任务 self._tasks[schedule_id] = task self._tasks_by_name[task_name] = schedule_id - + # 如果是事件触发,注册事件订阅 if trigger_type == TriggerType.EVENT: event_name = trigger_config.get("event_name") @@ -1029,18 +1000,18 @@ class UnifiedScheduler: raise ValueError("事件触发类型必须提供 event_name") self._event_subscriptions[event_name].add(schedule_id) logger.debug(f"任务 {task_name} 订阅事件: {event_name}") - + logger.debug(f"创建调度任务: {task_name} (ID: {schedule_id[:8]}...)") return schedule_id - + async def remove_schedule(self, schedule_id: str) -> bool: """移除调度任务 - + 如果任务正在执行,会安全地取消执行中的任务 - + Args: schedule_id: 任务ID - + Returns: bool: 是否成功移除 """ @@ -1048,15 +1019,15 @@ class UnifiedScheduler: if not task: logger.warning(f"尝试移除不存在的任务: {schedule_id[:8]}...") return False - + # 如果任务正在运行,先取消 if task.status == TaskStatus.RUNNING: await self._cancel_task(task, reason="removed") - + # 从字典中移除 self._tasks.pop(schedule_id, None) self._tasks_by_name.pop(task.task_name, None) - + # 清理事件订阅 if task.trigger_type == TriggerType.EVENT: event_name = task.trigger_config.get("event_name") @@ -1065,16 +1036,16 @@ class UnifiedScheduler: if not self._event_subscriptions[event_name]: del self._event_subscriptions[event_name] logger.debug(f"事件 '{event_name}' 已无订阅任务") - + logger.debug(f"移除调度任务: {task.task_name}") return True - + async def remove_schedule_by_name(self, task_name: str) -> bool: """根据任务名称移除调度任务 - + Args: task_name: 任务名称 - + Returns: bool: 是否成功移除 """ @@ -1083,24 +1054,24 @@ class UnifiedScheduler: return await self.remove_schedule(schedule_id) logger.warning(f"未找到名为 '{task_name}' 的任务") return False - + async def find_schedule_by_name(self, task_name: str) -> 
str | None: """根据任务名称查找 schedule_id - + Args: task_name: 任务名称 - + Returns: str | None: 找到的 schedule_id,如果不存在则返回 None """ return self._tasks_by_name.get(task_name) - + async def trigger_schedule(self, schedule_id: str) -> bool: """强制触发指定任务(立即执行) - + Args: schedule_id: 任务ID - + Returns: bool: 是否成功触发 """ @@ -1108,20 +1079,17 @@ class UnifiedScheduler: if not task: logger.warning(f"尝试触发不存在的任务: {schedule_id[:8]}...") return False - + if not task.can_trigger(): logger.warning(f"任务 {task.task_name} 当前状态 {task.status.value} 无法触发") return False - + logger.info(f"强制触发任务: {task.task_name}") - + # 创建执行任务 - exec_task = asyncio.create_task( - self._execute_task(task), - name=f"manual_trigger_{task.task_name}" - ) + exec_task = asyncio.create_task(self._execute_task(task), name=f"manual_trigger_{task.task_name}") task._asyncio_task = exec_task - + # 等待完成 try: await exec_task @@ -1129,13 +1097,13 @@ class UnifiedScheduler: except Exception as e: logger.error(f"强制触发任务 {task.task_name} 失败: {e}", exc_info=True) return False - + async def pause_schedule(self, schedule_id: str) -> bool: """暂停任务(不删除,但不会被触发) - + Args: schedule_id: 任务ID - + Returns: bool: 是否成功暂停 """ @@ -1143,21 +1111,21 @@ class UnifiedScheduler: if not task: logger.warning(f"尝试暂停不存在的任务: {schedule_id[:8]}...") return False - + if task.status == TaskStatus.RUNNING: logger.warning(f"任务 {task.task_name} 正在运行,无法暂停") return False - + task.status = TaskStatus.PAUSED logger.debug(f"暂停任务: {task.task_name}") return True - + async def resume_schedule(self, schedule_id: str) -> bool: """恢复暂停的任务 - + Args: schedule_id: 任务ID - + Returns: bool: 是否成功恢复 """ @@ -1165,36 +1133,36 @@ class UnifiedScheduler: if not task: logger.warning(f"尝试恢复不存在的任务: {schedule_id[:8]}...") return False - + if task.status != TaskStatus.PAUSED: logger.warning(f"任务 {task.task_name} 状态为 {task.status.value},无需恢复") return False - + task.status = TaskStatus.PENDING logger.debug(f"恢复任务: {task.task_name}") return True - + async def get_task_info(self, schedule_id: str) -> dict[str, Any] | None: """获取任务详细信息 - + Args: schedule_id: 任务ID - + Returns: dict | None: 任务信息字典,如果不存在返回 None """ task = self._tasks.get(schedule_id) if not task: return None - + # 计算平均执行时间 avg_execution_time = 0.0 if task.success_count > 0: avg_execution_time = task.total_execution_time / task.success_count - + # 获取健康度 health = self._deadlock_detector.get_health_score(schedule_id) - + return { "schedule_id": task.schedule_id, "task_name": task.task_name, @@ -1217,18 +1185,18 @@ class UnifiedScheduler: "timeout": task.timeout, "last_error": str(task.last_error) if task.last_error else None, } - + async def list_tasks( self, trigger_type: TriggerType | None = None, status: TaskStatus | None = None, ) -> list[dict[str, Any]]: """列出所有任务或指定类型/状态的任务 - + Args: trigger_type: 触发类型过滤 status: 状态过滤 - + Returns: list: 任务信息列表 """ @@ -1239,16 +1207,16 @@ class UnifiedScheduler: continue if status is not None and task.status != status: continue - + task_info = await self.get_task_info(task.schedule_id) if task_info: tasks.append(task_info) - + return tasks - + def get_statistics(self) -> dict[str, Any]: """获取调度器统计信息 - + Returns: dict: 统计信息字典 """ @@ -1256,17 +1224,17 @@ class UnifiedScheduler: status_counts = defaultdict(int) for task in self._tasks.values(): status_counts[task.status.value] += 1 - + # 统计各类型的任务数 type_counts = defaultdict(int) for task in self._tasks.values(): type_counts[task.trigger_type.value] += 1 - + # 计算运行时长 uptime = 0.0 if self._start_time: uptime = (datetime.now() - self._start_time).total_seconds() - + # 获取正在运行的任务 
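        # Note: this snapshot needs no lock; per the class docstring the scheduler
        # is deliberately lock-free and relies on asyncio's single-threaded event
        # loop, so _tasks cannot be mutated by another thread mid-iteration.
        # runtime below is wall-clock seconds since current_execution.started_at,
        # e.g. a task that began 7 seconds before this call reports runtime == 7.0.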
running_tasks_info = [] for task in self._tasks.values(): @@ -1274,12 +1242,14 @@ class UnifiedScheduler: runtime = 0.0 if task.current_execution: runtime = (datetime.now() - task.current_execution.started_at).total_seconds() - running_tasks_info.append({ - "schedule_id": task.schedule_id[:8] + "...", - "task_name": task.task_name, - "runtime": runtime, - }) - + running_tasks_info.append( + { + "schedule_id": task.schedule_id[:8] + "...", + "task_name": task.task_name, + "runtime": runtime, + } + ) + return { "is_running": self._running, "uptime_seconds": uptime, @@ -1316,6 +1286,7 @@ class UnifiedScheduler: # 全局调度器实例 unified_scheduler = UnifiedScheduler() + async def initialize_scheduler(): """初始化调度器