diff --git a/bot.py b/bot.py index 80e6394d1..f58697bec 100644 --- a/bot.py +++ b/bot.py @@ -588,7 +588,7 @@ class MaiBotMain: async def run_async_init(self, main_system): """执行异步初始化步骤""" - + # 初始化数据库表结构 await self.initialize_database_async() diff --git a/scripts/generate_missing_embeddings.py b/scripts/generate_missing_embeddings.py index a8957e50b..afa3f59a6 100644 --- a/scripts/generate_missing_embeddings.py +++ b/scripts/generate_missing_embeddings.py @@ -19,14 +19,13 @@ import asyncio import sys from pathlib import Path -from typing import List # 添加项目根目录到路径 sys.path.insert(0, str(Path(__file__).parent.parent)) async def generate_missing_embeddings( - target_node_types: List[str] = None, + target_node_types: list[str] = None, batch_size: int = 50, ): """ @@ -46,13 +45,13 @@ async def generate_missing_embeddings( target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value] print(f"\n{'='*80}") - print(f"🔧 为节点生成嵌入向量") + print("🔧 为节点生成嵌入向量") print(f"{'='*80}\n") print(f"目标节点类型: {', '.join(target_node_types)}") print(f"批处理大小: {batch_size}\n") # 1. 初始化记忆管理器 - print(f"🔧 正在初始化记忆管理器...") + print("🔧 正在初始化记忆管理器...") await initialize_memory_manager() manager = get_memory_manager() @@ -60,10 +59,10 @@ async def generate_missing_embeddings( print("❌ 记忆管理器初始化失败") return - print(f"✅ 记忆管理器已初始化\n") + print("✅ 记忆管理器已初始化\n") # 2. 获取已索引的节点ID - print(f"🔍 检查现有向量索引...") + print("🔍 检查现有向量索引...") existing_node_ids = set() try: vector_count = manager.vector_store.collection.count() @@ -78,14 +77,14 @@ async def generate_missing_embeddings( ) if result and "ids" in result: existing_node_ids.update(result["ids"]) - + print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n") except Exception as e: logger.warning(f"获取已索引节点ID失败: {e}") - print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n") + print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n") # 3. 收集需要生成嵌入的节点 - print(f"🔍 扫描需要生成嵌入的节点...") + print("🔍 扫描需要生成嵌入的节点...") all_memories = manager.graph_store.get_all_memories() nodes_to_process = [] @@ -110,7 +109,7 @@ async def generate_missing_embeddings( }) type_stats[node.node_type.value]["need_emb"] += 1 - print(f"\n📊 扫描结果:") + print("\n📊 扫描结果:") for node_type in target_node_types: stats = type_stats[node_type] already_ok = stats["already_indexed"] @@ -121,11 +120,11 @@ async def generate_missing_embeddings( print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n") if len(nodes_to_process) == 0: - print(f"✅ 所有节点已有嵌入向量,无需生成") + print("✅ 所有节点已有嵌入向量,无需生成") return # 3. 批量生成嵌入 - print(f"🚀 开始生成嵌入向量...\n") + print("🚀 开始生成嵌入向量...\n") total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size success_count = 0 @@ -193,22 +192,22 @@ async def generate_missing_embeddings( print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n") # 4. 保存图数据(更新节点的 embedding 字段) - print(f"💾 保存图数据...") + print("💾 保存图数据...") try: await manager.persistence.save_graph_store(manager.graph_store) - print(f"✅ 图数据已保存\n") + print("✅ 图数据已保存\n") except Exception as e: - logger.error(f"保存图数据失败", exc_info=True) + logger.error("保存图数据失败", exc_info=True) print(f"❌ 保存失败: {e}\n") # 5. 验证结果 - print(f"🔍 验证向量索引...") + print("🔍 验证向量索引...") final_vector_count = manager.vector_store.collection.count() stats = manager.graph_store.get_statistics() total_nodes = stats["total_nodes"] print(f"\n{'='*80}") - print(f"📊 生成完成") + print("📊 生成完成") print(f"{'='*80}") print(f"处理节点数: {len(nodes_to_process)}") print(f"成功生成: {success_count}") @@ -219,7 +218,7 @@ async def generate_missing_embeddings( print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n") # 6. 
测试搜索 - print(f"🧪 测试搜索功能...") + print("🧪 测试搜索功能...") test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"] for query in test_queries: diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index db0fdbd73..8b5d6f3df 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -38,7 +38,7 @@ OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache") # ========== 性能配置参数 ========== -# +# # 知识提取(步骤2:txt转json)并发控制 # - 控制同时进行的LLM提取请求数量 # - 推荐值: 3-10,取决于API速率限制 @@ -184,7 +184,7 @@ async def extract_info_async(pg_hash, paragraph, llm_api): tuple: (doc_item或None, failed_hash或None) """ temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json") - + # 🔧 优化:使用异步文件检查,避免阻塞 if os.path.exists(temp_file_path): try: @@ -215,11 +215,11 @@ async def extract_info_async(pg_hash, paragraph, llm_api): "extracted_entities": extracted_data.get("entities", []), "extracted_triples": extracted_data.get("triples", []), } - + # 保存到缓存(异步写入) async with aiofiles.open(temp_file_path, "wb") as f: await f.write(orjson.dumps(doc_item)) - + return doc_item, None except Exception as e: logger.error(f"提取信息失败:{pg_hash}, 错误:{e}") @@ -249,13 +249,13 @@ async def extract_information(paragraphs_dict, model_set): os.makedirs(TEMP_DIR, exist_ok=True) failed_hashes, open_ie_docs = [], [] - + # 🔧 关键修复:创建单个 LLM 请求实例,复用连接 llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction") - + # 🔧 并发控制:限制最大并发数,防止速率限制 semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY) - + async def extract_with_semaphore(pg_hash, paragraph): """带信号量控制的提取函数""" async with semaphore: @@ -266,7 +266,7 @@ async def extract_information(paragraphs_dict, model_set): extract_with_semaphore(p_hash, paragraph) for p_hash, paragraph in paragraphs_dict.items() ] - + total = len(tasks) completed = 0 @@ -284,7 +284,7 @@ async def extract_information(paragraphs_dict, model_set): TimeRemainingColumn(), ) as progress: task = progress.add_task("[cyan]正在提取信息...", total=total) - + # 🔧 优化:使用 asyncio.gather 并发执行所有任务 # return_exceptions=True 确保单个失败不影响其他任务 for coro in asyncio.as_completed(tasks): @@ -293,7 +293,7 @@ async def extract_information(paragraphs_dict, model_set): failed_hashes.append(failed_hash) elif doc_item: open_ie_docs.append(doc_item) - + completed += 1 progress.update(task, advance=1) @@ -415,7 +415,7 @@ def rebuild_faiss_only(): logger.info("--- 重建 FAISS 索引 ---") # 重建索引不需要并发参数(不涉及 embedding 生成) embed_manager = EmbeddingManager() - + logger.info("正在加载现有的 Embedding 库...") try: embed_manager.load_from_file() diff --git a/src/api/memory_visualizer_router.py b/src/api/memory_visualizer_router.py index 84971f78a..e1137d684 100644 --- a/src/api/memory_visualizer_router.py +++ b/src/api/memory_visualizer_router.py @@ -4,13 +4,13 @@ 提供 Web API 用于可视化记忆图数据 """ +from collections import defaultdict from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional -from collections import defaultdict +from typing import Any import orjson -from fastapi import APIRouter, HTTPException, Request, Query +from fastapi import APIRouter, HTTPException, Query, Request from fastapi.responses import HTMLResponse, JSONResponse from fastapi.templating import Jinja2Templates @@ -29,7 +29,7 @@ router = APIRouter() templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates")) -def find_available_data_files() -> List[Path]: +def find_available_data_files() -> list[Path]: """查找所有可用的记忆图数据文件""" files = [] if not data_dir.exists(): @@ -62,7 +62,7 @@ def 
find_available_data_files() -> List[Path]: return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True) -def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]: +def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]: """从磁盘加载图数据""" global graph_data_cache, current_data_file @@ -85,7 +85,7 @@ def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any if not graph_file.exists(): return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}} - with open(graph_file, "r", encoding="utf-8") as f: + with open(graph_file, encoding="utf-8") as f: data = orjson.loads(f.read()) nodes = data.get("nodes", []) @@ -150,7 +150,7 @@ async def index(request: Request): return templates.TemplateResponse("visualizer.html", {"request": request}) -def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]: +def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]: """从 MemoryManager 提取并格式化图数据""" if not memory_manager.graph_store: return {"nodes": [], "edges": [], "memories": [], "stats": {}} @@ -188,7 +188,7 @@ def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]: "arrows": "to", "memory_id": memory.id, } - + edges_list = list(edges_dict.values()) stats = memory_manager.get_statistics() @@ -261,7 +261,7 @@ async def get_paginated_graph( page: int = Query(1, ge=1, description="页码"), page_size: int = Query(500, ge=100, le=2000, description="每页节点数"), min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"), - node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"), + node_types: str | None = Query(None, description="节点类型过滤,逗号分隔"), ): """分页获取图数据,支持重要性过滤""" try: @@ -301,13 +301,13 @@ async def get_paginated_graph( total_pages = (total_nodes + page_size - 1) // page_size start_idx = (page - 1) * page_size end_idx = min(start_idx + page_size, total_nodes) - + paginated_nodes = nodes_with_importance[start_idx:end_idx] node_ids = set(n["id"] for n in paginated_nodes) # 只保留连接分页节点的边 paginated_edges = [ - e for e in edges + e for e in edges if e.get("from") in node_ids and e.get("to") in node_ids ] @@ -383,7 +383,7 @@ async def get_clustered_graph( return JSONResponse(content={"success": False, "error": str(e)}, status_code=500) -def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict: +def _cluster_graph_data(nodes: list[dict], edges: list[dict], max_nodes: int, cluster_threshold: int) -> dict: """简单的图聚类算法:按类型和连接度聚类""" # 构建邻接表 adjacency = defaultdict(set) @@ -412,21 +412,21 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl for node in type_nodes: importance = len(adjacency[node["id"]]) node_importance.append((node, importance)) - + node_importance.sort(key=lambda x: x[1], reverse=True) - + # 保留前N个重要节点 keep_count = min(len(type_nodes), max_nodes // len(type_groups)) for node, importance in node_importance[:keep_count]: clustered_nodes.append(node) node_mapping[node["id"]] = node["id"] - + # 其余节点聚合为一个超级节点 if len(node_importance) > keep_count: clustered_node_ids = [n["id"] for n, _ in node_importance[keep_count:]] cluster_id = f"cluster_{node_type}_{len(clustered_nodes)}" cluster_label = f"{node_type} 集群 ({len(clustered_node_ids)}个节点)" - + clustered_nodes.append({ "id": cluster_id, "label": cluster_label, @@ -436,7 +436,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl "cluster_size": len(clustered_node_ids), "clustered_nodes": 
clustered_node_ids[:10], # 只保留前10个用于展示 }) - + for node_id in clustered_node_ids: node_mapping[node_id] = cluster_id @@ -445,7 +445,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl for edge in edges: from_id = node_mapping.get(edge["from"]) to_id = node_mapping.get(edge["to"]) - + if from_id and to_id and from_id != to_id: edge_key = tuple(sorted([from_id, to_id])) if edge_key not in edge_set: diff --git a/src/api/statistic_router.py b/src/api/statistic_router.py index 97c80afa1..54f6836bf 100644 --- a/src/api/statistic_router.py +++ b/src/api/statistic_router.py @@ -1,6 +1,5 @@ -from collections import defaultdict from datetime import datetime, timedelta -from typing import Any, Literal +from typing import Literal from fastapi import APIRouter, HTTPException, Query diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 72be0f0f4..b286fa968 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -161,16 +161,16 @@ class EmbeddingStore: # 限制 chunk_size 和 max_workers 在合理范围内 chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE)) max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS)) - + semaphore = asyncio.Semaphore(max_workers) llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") results = {} - + # 将字符串列表分成多个 chunk chunks = [] for i in range(0, len(strs), chunk_size): chunks.append(strs[i : i + chunk_size]) - + async def _process_chunk(chunk: list[str]): """处理一个 chunk 的字符串(批量获取 embedding)""" async with semaphore: @@ -180,12 +180,12 @@ class EmbeddingStore: embedding = await EmbeddingStore._get_embedding_async(llm, s) embeddings.append(embedding) results[s] = embedding - + if progress_callback: progress_callback(len(chunk)) - + return embeddings - + # 并发处理所有 chunks tasks = [_process_chunk(chunk) for chunk in chunks] await asyncio.gather(*tasks) @@ -418,22 +418,22 @@ class EmbeddingStore: # 🔧 修复:检查所有 embedding 的维度是否一致 dimensions = [len(emb) for emb in array] unique_dims = set(dimensions) - + if len(unique_dims) > 1: logger.error(f"检测到不一致的 embedding 维度: {unique_dims}") logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}") - + # 获取期望的维度(使用最常见的维度) from collections import Counter dim_counter = Counter(dimensions) expected_dim = dim_counter.most_common(1)[0][0] logger.warning(f"将使用最常见的维度: {expected_dim}") - + # 过滤掉维度不匹配的 embedding filtered_array = [] filtered_idx2hash = {} skipped_count = 0 - + for i, emb in enumerate(array): if len(emb) == expected_dim: filtered_array.append(emb) @@ -442,11 +442,11 @@ class EmbeddingStore: skipped_count += 1 hash_key = self.idx2hash[str(i)] logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}") - + logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding") array = filtered_array self.idx2hash = filtered_idx2hash - + if not array: logger.error("过滤后没有可用的 embedding,无法构建索引") embedding_dim = expected_dim diff --git a/src/chat/message_manager/__init__.py b/src/chat/message_manager/__init__.py index 175320774..c8bd18a08 100644 --- a/src/chat/message_manager/__init__.py +++ b/src/chat/message_manager/__init__.py @@ -13,4 +13,4 @@ __all__ = [ "StreamLoopManager", "message_manager", "stream_loop_manager", -] \ No newline at end of file +] diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py index f3dad2fb9..34731952d 100644 --- a/src/chat/message_manager/context_manager.py +++ 
b/src/chat/message_manager/context_manager.py @@ -82,7 +82,7 @@ class SingleStreamContextManager: self.total_messages += 1 self.last_access_time = time.time() - + # 如果使用了缓存系统,输出调试信息 if cache_enabled and self.context.is_cache_enabled: if self.context.is_chatter_processing: diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index c88808d1d..6ae951663 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -111,9 +111,9 @@ class StreamLoopManager: # 获取或创建该流的启动锁 if stream_id not in self._stream_start_locks: self._stream_start_locks[stream_id] = asyncio.Lock() - + lock = self._stream_start_locks[stream_id] - + # 使用锁防止并发启动同一个流的多个循环任务 async with lock: # 获取流上下文 @@ -148,7 +148,7 @@ class StreamLoopManager: # 紧急取消 context.stream_loop_task.cancel() await asyncio.sleep(0.1) - + loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}") # 将任务记录到 StreamContext 中 @@ -252,7 +252,7 @@ class StreamLoopManager: self.stats["total_process_cycles"] += 1 if success: logger.info(f"✅ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理成功") - + # 🔒 处理成功后,等待一小段时间确保清理操作完成 # 这样可以避免在 chatter_manager 清除未读消息之前就进入下一轮循环 await asyncio.sleep(0.1) @@ -382,7 +382,7 @@ class StreamLoopManager: self.chatter_manager.process_stream_context(stream_id, context), name=f"chatter_process_{stream_id}" ) - + # 等待 chatter 任务完成 results = await chatter_task success = results.get("success", False) @@ -398,8 +398,8 @@ class StreamLoopManager: else: logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}") - return success - except asyncio.CancelledError: + return success + except asyncio.CancelledError: if chatter_task and not chatter_task.done(): chatter_task.cancel() raise @@ -709,4 +709,4 @@ class StreamLoopManager: # 全局流循环管理器实例 -stream_loop_manager = StreamLoopManager() \ No newline at end of file +stream_loop_manager = StreamLoopManager() diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 9952a88d4..29da4f068 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -417,7 +417,7 @@ class MessageManager: return # 记录详细信息 - msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}" + msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}" for msg in unread_messages[:3]] # 只显示前3条 logger.info(f"🧹 [清除未读] stream={stream_id[:8]}, 开始清除 {len(unread_messages)} 条未读消息, 示例: {msg_previews}") @@ -446,15 +446,15 @@ class MessageManager: context = chat_stream.context_manager.context if hasattr(context, "unread_messages") and context.unread_messages: unread_count = len(context.unread_messages) - + # 如果还有未读消息,说明 action_manager 可能遗漏了,标记它们 - if unread_count > 0: + if unread_count > 0: # 获取所有未读消息的 ID message_ids = [msg.message_id for msg in context.unread_messages] - + # 标记为已读(会移到历史消息) success = chat_stream.context_manager.mark_messages_as_read(message_ids) - + if success: logger.debug(f"✅ stream={stream_id[:8]}, 成功标记 {unread_count} 条消息为已读") else: @@ -481,7 +481,7 @@ class MessageManager: try: chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) - if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'): + if chat_stream 
and hasattr(chat_stream.context_manager.context, "is_chatter_processing"): chat_stream.context_manager.context.is_chatter_processing = is_processing logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}") except Exception as e: @@ -517,7 +517,7 @@ class MessageManager: try: chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) - if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'): + if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"): return chat_stream.context_manager.context.is_chatter_processing except Exception: pass @@ -677,4 +677,4 @@ class MessageManager: # 创建全局消息管理器实例 -message_manager = MessageManager() \ No newline at end of file +message_manager = MessageManager() diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index f979c1bea..0857089b4 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -248,16 +248,16 @@ class ChatterActionManager: try: # 根据动作类型确定提示词模式 prompt_mode = "s4u" if action_name == "reply" else "normal" - + # 将prompt_mode传递给generate_reply action_data_with_mode = (action_data or {}).copy() action_data_with_mode["prompt_mode"] = prompt_mode - + # 只传递当前正在执行的动作,而不是所有可用动作 # 这样可以让LLM明确知道"已决定执行X动作",而不是"有这些动作可用" current_action_info = self._using_actions.get(action_name) current_actions: dict[str, Any] = {action_name: current_action_info} if current_action_info else {} - + # 附加目标消息信息(如果存在) if target_message: # 提取目标消息的关键信息 @@ -268,7 +268,7 @@ class ChatterActionManager: "time": getattr(target_message, "time", 0), } current_actions["_target_message"] = target_msg_info - + success, response_set, _ = await generator_api.generate_reply( chat_stream=chat_stream, reply_message=target_message, @@ -295,12 +295,12 @@ class ChatterActionManager: should_quote_reply = None if action_data and isinstance(action_data, dict): should_quote_reply = action_data.get("should_quote_reply", None) - + # respond动作默认不引用回复,保持对话流畅 if action_name == "respond" and should_quote_reply is None: should_quote_reply = False - async def _after_reply(): + async def _after_reply(): # 发送并存储回复 loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply( chat_stream, diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 1139496f2..0082268ae 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -372,7 +372,7 @@ class DefaultReplyer: # 确保类型安全 if isinstance(mode, str): prompt_mode_value = mode - + # 构建 Prompt with Timer("构建Prompt", {}): # 内部计时器,可选保留 prompt = await self.build_prompt_reply_context( @@ -1166,16 +1166,16 @@ class DefaultReplyer: from src.plugin_system.apis.chat_api import get_chat_manager chat_manager = get_chat_manager() chat_stream_obj = await chat_manager.get_stream(chat_id) - + if chat_stream_obj: unread_messages = chat_stream_obj.context_manager.get_unread_messages() if unread_messages: # 使用最后一条未读消息作为参考 last_msg = unread_messages[-1] - platform = last_msg.chat_info.platform if hasattr(last_msg, 'chat_info') else chat_stream.platform - user_id = last_msg.user_info.user_id if hasattr(last_msg, 'user_info') else "" - user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, 'user_info') else "" - user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, 'user_info') else "" + platform = last_msg.chat_info.platform if hasattr(last_msg, 
"chat_info") else chat_stream.platform + user_id = last_msg.user_info.user_id if hasattr(last_msg, "user_info") else "" + user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, "user_info") else "" + user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, "user_info") else "" processed_plain_text = last_msg.processed_plain_text or "" else: # 没有未读消息,使用默认值 @@ -1258,19 +1258,19 @@ class DefaultReplyer: if available_actions: # 过滤掉特殊键(以_开头) action_items = {k: v for k, v in available_actions.items() if not k.startswith("_")} - + # 提取目标消息信息(如果存在) target_msg_info = available_actions.get("_target_message") # type: ignore - + if action_items: if len(action_items) == 1: # 单个动作 action_name, action_info = list(action_items.items())[0] action_desc = action_info.description - + # 构建基础决策信息 action_descriptions = f"## 决策信息\n\n你已经决定要执行 **{action_name}** 动作({action_desc})。\n\n" - + # 只有需要目标消息的动作才显示目标消息详情 # respond 动作是统一回应所有未读消息,不应该显示特定目标消息 if action_name not in ["respond"] and target_msg_info and isinstance(target_msg_info, dict): @@ -1279,7 +1279,7 @@ class DefaultReplyer: content = target_msg_info.get("content", "") msg_time = target_msg_info.get("time", 0) time_str = time_module.strftime("%H:%M:%S", time_module.localtime(msg_time)) if msg_time else "未知时间" - + action_descriptions += f"**目标消息**: {time_str} {sender} 说: {content}\n\n" else: # 多个动作 @@ -2166,7 +2166,7 @@ class DefaultReplyer: except Exception as e: logger.error(f"存储聊天记忆失败: {e}") - + def weighted_sample_no_replacement(items, weights, k) -> list: """ diff --git a/src/chat/security/__init__.py b/src/chat/security/__init__.py index 328211db1..3700bb82a 100644 --- a/src/chat/security/__init__.py +++ b/src/chat/security/__init__.py @@ -5,12 +5,12 @@ 插件可以通过实现这些接口来扩展安全功能。 """ -from .interfaces import SecurityCheckResult, SecurityChecker +from .interfaces import SecurityChecker, SecurityCheckResult from .manager import SecurityManager, get_security_manager __all__ = [ - "SecurityChecker", "SecurityCheckResult", + "SecurityChecker", "SecurityManager", "get_security_manager", ] diff --git a/src/chat/security/manager.py b/src/chat/security/manager.py index 1ddc3055a..e160e8a74 100644 --- a/src/chat/security/manager.py +++ b/src/chat/security/manager.py @@ -10,7 +10,7 @@ from typing import Any from src.common.logger import get_logger -from .interfaces import SecurityAction, SecurityCheckResult, SecurityChecker, SecurityLevel +from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel logger = get_logger("security.manager") diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index 108138c5b..8afbbde57 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -98,7 +98,7 @@ class StreamContext(BaseDataModel): break def mark_message_as_read(self, message_id: str): - """标记消息为已读""" + """标记消息为已读""" # 先找到要标记的消息(处理 int/str 类型不匹配问题) message_to_mark = None for msg in self.unread_messages: @@ -106,7 +106,7 @@ class StreamContext(BaseDataModel): if str(msg.message_id) == str(message_id): message_to_mark = msg break - + # 然后移动到历史消息 if message_to_mark: message_to_mark.is_read = True diff --git a/src/common/database/optimization/cache_manager.py b/src/common/database/optimization/cache_manager.py index 60b0b0170..27b7b33a2 100644 --- a/src/common/database/optimization/cache_manager.py +++ b/src/common/database/optimization/cache_manager.py @@ -9,11 +9,12 @@ """ import asyncio 
+import builtins import time from collections import OrderedDict from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union +from typing import Any, Generic, TypeVar from src.common.logger import get_logger from src.common.memory_utils import estimate_size_smart @@ -96,7 +97,7 @@ class LRUCache(Generic[T]): self._lock = asyncio.Lock() self._stats = CacheStats() - async def get(self, key: str) -> Optional[T]: + async def get(self, key: str) -> T | None: """获取缓存值 Args: @@ -137,8 +138,8 @@ class LRUCache(Generic[T]): self, key: str, value: T, - size: Optional[int] = None, - ttl: Optional[float] = None, + size: int | None = None, + ttl: float | None = None, ) -> None: """设置缓存值 @@ -287,8 +288,8 @@ class MultiLevelCache: async def get( self, key: str, - loader: Optional[Callable[[], Any]] = None, - ) -> Optional[Any]: + loader: Callable[[], Any] | None = None, + ) -> Any | None: """从缓存获取数据 查询顺序:L1 -> L2 -> loader @@ -329,8 +330,8 @@ class MultiLevelCache: self, key: str, value: Any, - size: Optional[int] = None, - ttl: Optional[float] = None, + size: int | None = None, + ttl: float | None = None, ) -> None: """设置缓存值 @@ -390,7 +391,7 @@ class MultiLevelCache: await self.l2_cache.clear() logger.info("所有缓存已清空") - async def get_stats(self) -> Dict[str, Any]: + async def get_stats(self) -> dict[str, Any]: """获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)""" # 🔧 修复:并行获取统计信息,避免锁嵌套 l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1")) @@ -492,7 +493,7 @@ class MultiLevelCache: logger.error(f"{cache_name}统计获取异常: {e}") return CacheStats() - async def _get_cache_keys_safe(self, cache) -> Set[str]: + async def _get_cache_keys_safe(self, cache) -> builtins.set[str]: """安全获取缓存键集合(带超时)""" try: # 快速获取键集合,使用超时避免死锁 @@ -507,12 +508,12 @@ class MultiLevelCache: logger.error(f"缓存键获取异常: {e}") return set() - async def _extract_keys_with_lock(self, cache) -> Set[str]: + async def _extract_keys_with_lock(self, cache) -> builtins.set[str]: """在锁保护下提取键集合""" async with cache._lock: return set(cache._cache.keys()) - async def _calculate_memory_usage_safe(self, cache, keys: Set[str]) -> int: + async def _calculate_memory_usage_safe(self, cache, keys: builtins.set[str]) -> int: """安全计算内存使用(带超时)""" if not keys: return 0 @@ -529,7 +530,7 @@ class MultiLevelCache: logger.error(f"内存计算异常: {e}") return 0 - async def _calc_memory_with_lock(self, cache, keys: Set[str]) -> int: + async def _calc_memory_with_lock(self, cache, keys: builtins.set[str]) -> int: """在锁保护下计算内存使用""" total_size = 0 async with cache._lock: @@ -749,7 +750,7 @@ class MultiLevelCache: # 全局缓存实例 -_global_cache: Optional[MultiLevelCache] = None +_global_cache: MultiLevelCache | None = None _cache_lock = asyncio.Lock() diff --git a/src/common/server.py b/src/common/server.py index a6ed588d8..f4553f537 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -3,7 +3,6 @@ import socket from fastapi import APIRouter, FastAPI from fastapi.middleware.cors import CORSMiddleware -from fastapi.staticfiles import StaticFiles from rich.traceback import install from uvicorn import Config from uvicorn import Server as UvicornServer diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 94dc125b4..89e186b3d 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -444,7 +444,7 @@ class OpenaiClient(BaseClient): # 🔧 优化:增加连接池限制,支持高并发embedding请求 # 
默认httpx限制为100,对于高频embedding场景不够用 import httpx - + limits = httpx.Limits( max_keepalive_connections=200, # 保持活跃连接数(原100) max_connections=300, # 最大总连接数(原100) diff --git a/src/memory_graph/core/builder.py b/src/memory_graph/core/builder.py index 4b0d66218..00f55c0fa 100644 --- a/src/memory_graph/core/builder.py +++ b/src/memory_graph/core/builder.py @@ -128,7 +128,7 @@ class MemoryBuilder: # 6. 构建 Memory 对象 # 新记忆应该有较高的初始激活度 initial_activation = 0.75 # 新记忆初始激活度为 0.75 - + memory = Memory( id=memory_id, subject_id=subject_node.id, diff --git a/src/memory_graph/manager.py b/src/memory_graph/manager.py index 2c234ae86..32fe76f20 100644 --- a/src/memory_graph/manager.py +++ b/src/memory_graph/manager.py @@ -149,7 +149,7 @@ class MemoryManager: # 读取阈值过滤配置 search_min_importance = self.config.search_min_importance search_similarity_threshold = self.config.search_similarity_threshold - + logger.info( f"📊 配置检查: search_max_expand_depth={expand_depth}, " f"search_expand_semantic_threshold={expand_semantic_threshold}, " @@ -417,7 +417,7 @@ class MemoryManager: # 使用配置的默认值 if top_k is None: top_k = getattr(self.config, "search_top_k", 10) - + # 准备搜索参数 params = { "query": query, @@ -951,7 +951,7 @@ class MemoryManager: ) else: logger.debug(f"记忆已删除: {memory_id} (删除了 {deleted_vectors} 个向量)") - + # 4. 保存更新 await self.persistence.save_graph_store(self.graph_store) return True @@ -984,7 +984,7 @@ class MemoryManager: try: forgotten_count = 0 all_memories = self.graph_store.get_all_memories() - + # 获取配置参数 min_importance = getattr(self.config, "forgetting_min_importance", 0.8) decay_rate = getattr(self.config, "activation_decay_rate", 0.9) @@ -1010,10 +1010,10 @@ class MemoryManager: try: last_access_dt = datetime.fromisoformat(last_access) days_passed = (datetime.now() - last_access_dt).days - + # 应用指数衰减:activation = base * (decay_rate ^ days) current_activation = base_activation * (decay_rate ** days_passed) - + logger.debug( f"记忆 {memory.id[:8]}: 基础激活度={base_activation:.3f}, " f"经过{days_passed}天衰减后={current_activation:.3f}" @@ -1035,20 +1035,20 @@ class MemoryManager: # 批量遗忘记忆(不立即清理孤立节点) if memories_to_forget: logger.info(f"开始批量遗忘 {len(memories_to_forget)} 条记忆...") - + for memory_id, activation in memories_to_forget: # cleanup_orphans=False:暂不清理孤立节点 success = await self.forget_memory(memory_id, cleanup_orphans=False) if success: forgotten_count += 1 - + # 统一清理孤立节点和边 logger.info("批量遗忘完成,开始统一清理孤立节点和边...") orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges() - + # 保存最终更新 await self.persistence.save_graph_store(self.graph_store) - + logger.info( f"✅ 自动遗忘完成: 遗忘了 {forgotten_count} 条记忆, " f"清理了 {orphan_nodes} 个孤立节点, {orphan_edges} 条孤立边" @@ -1079,31 +1079,31 @@ class MemoryManager: # 1. 清理孤立节点 # graph_store.node_to_memories 记录了每个节点属于哪些记忆 nodes_to_remove = [] - + for node_id, memory_ids in list(self.graph_store.node_to_memories.items()): # 如果节点不再属于任何记忆,标记为删除 if not memory_ids: nodes_to_remove.append(node_id) - + # 从图中删除孤立节点 for node_id in nodes_to_remove: if self.graph_store.graph.has_node(node_id): self.graph_store.graph.remove_node(node_id) orphan_nodes_count += 1 - + # 从映射中删除 if node_id in self.graph_store.node_to_memories: del self.graph_store.node_to_memories[node_id] - + # 2. 
清理孤立边(指向已删除节点的边) edges_to_remove = [] - - for source, target, edge_id in self.graph_store.graph.edges(data='edge_id'): + + for source, target, edge_id in self.graph_store.graph.edges(data="edge_id"): # 检查边的源节点和目标节点是否还存在于node_to_memories中 if source not in self.graph_store.node_to_memories or \ target not in self.graph_store.node_to_memories: edges_to_remove.append((source, target)) - + # 删除孤立边 for source, target in edges_to_remove: try: @@ -1111,12 +1111,12 @@ class MemoryManager: orphan_edges_count += 1 except Exception as e: logger.debug(f"删除边失败 {source} -> {target}: {e}") - + if orphan_nodes_count > 0 or orphan_edges_count > 0: logger.info( f"清理完成: {orphan_nodes_count} 个孤立节点, {orphan_edges_count} 条孤立边" ) - + return orphan_nodes_count, orphan_edges_count except Exception as e: @@ -1258,7 +1258,7 @@ class MemoryManager: mem for mem in recent_memories if mem.importance >= min_importance_for_consolidation ] - + result["importance_filtered"] = len(recent_memories) - len(important_memories) logger.info( f"📊 步骤2: 重要性过滤 (阈值={min_importance_for_consolidation:.2f}): " @@ -1382,26 +1382,26 @@ class MemoryManager: # ===== 步骤4: 向量检索关联记忆 + LLM分析关系 ===== # 过滤掉已删除的记忆 remaining_memories = [m for m in important_memories if m.id not in deleted_ids] - + if not remaining_memories: logger.info("✅ 记忆整理完成: 去重后无剩余记忆") return logger.info(f"📍 步骤4: 开始关联分析 ({len(remaining_memories)} 条记忆)...") - + # 分批处理记忆关联 llm_batch_size = getattr(self.config, "consolidation_llm_batch_size", 10) max_candidates_per_memory = getattr(self.config, "consolidation_max_candidates", 5) min_confidence = getattr(self.config, "consolidation_min_confidence", 0.6) - + all_new_edges = [] # 收集所有新建的边 - + for batch_start in range(0, len(remaining_memories), llm_batch_size): batch_end = min(batch_start + llm_batch_size, len(remaining_memories)) batch = remaining_memories[batch_start:batch_end] - + logger.debug(f"处理批次 {batch_start//llm_batch_size + 1}/{(len(remaining_memories)-1)//llm_batch_size + 1}") - + for memory in batch: # 跳过已经有很多连接的记忆 existing_edges = len([ @@ -1454,14 +1454,14 @@ class MemoryManager: except Exception as e: logger.warning(f"创建关联边失败: {e}") continue - + # 每个批次后让出控制权 await asyncio.sleep(0.01) # ===== 步骤5: 统一更新记忆数据 ===== if all_new_edges: logger.info(f"📍 步骤5: 统一更新 {len(all_new_edges)} 条新关联边...") - + for memory, edge, relation in all_new_edges: try: # 添加到图 @@ -2301,7 +2301,7 @@ class MemoryManager: # 使用 asyncio.wait_for 来支持取消 await asyncio.wait_for( asyncio.sleep(initial_delay), - timeout=float('inf') # 允许随时取消 + timeout=float("inf") # 允许随时取消 ) # 检查是否仍然需要运行 diff --git a/src/memory_graph/storage/graph_store.py b/src/memory_graph/storage/graph_store.py index 93e30cb83..ed7c16a2c 100644 --- a/src/memory_graph/storage/graph_store.py +++ b/src/memory_graph/storage/graph_store.py @@ -482,7 +482,7 @@ class GraphStore: for node in memory.nodes: if node.id in self.node_to_memories: self.node_to_memories[node.id].discard(memory_id) - + # 可选:立即清理孤立节点 if cleanup_orphans: # 如果该节点不再属于任何记忆,从图中移除节点 diff --git a/src/memory_graph/tools/memory_tools.py b/src/memory_graph/tools/memory_tools.py index e87557cb8..33618dd8c 100644 --- a/src/memory_graph/tools/memory_tools.py +++ b/src/memory_graph/tools/memory_tools.py @@ -72,12 +72,12 @@ class MemoryTools: self.max_expand_depth = max_expand_depth self.expand_semantic_threshold = expand_semantic_threshold self.search_top_k = search_top_k - + # 保存权重配置 self.base_vector_weight = search_vector_weight self.base_importance_weight = search_importance_weight self.base_recency_weight = search_recency_weight - + # 
保存阈值过滤配置 self.search_min_importance = search_min_importance self.search_similarity_threshold = search_similarity_threshold @@ -516,14 +516,14 @@ class MemoryTools: # 1. 根据策略选择检索方式 llm_prefer_types = [] # LLM识别的偏好节点类型 - + if use_multi_query: # 多查询策略(返回节点列表 + 偏好类型) similar_nodes, llm_prefer_types = await self._multi_query_search(query, top_k, context) else: # 传统单查询策略 similar_nodes = await self._single_query_search(query, top_k) - + # 合并用户指定的偏好类型和LLM识别的偏好类型 all_prefer_types = list(set(prefer_node_types + llm_prefer_types)) if all_prefer_types: @@ -551,7 +551,7 @@ class MemoryTools: # 记录最高分数 if mem_id not in memory_scores or similarity > memory_scores[mem_id]: memory_scores[mem_id] = similarity - + # 🔥 详细日志:检查初始召回情况 logger.info( f"初始向量搜索: 返回{len(similar_nodes)}个节点 → " @@ -559,8 +559,8 @@ class MemoryTools: ) if len(initial_memory_ids) == 0: logger.warning( - f"⚠️ 向量搜索未找到任何记忆!" - f"可能原因:1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大" + "⚠️ 向量搜索未找到任何记忆!" + "可能原因:1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大" ) # 输出相似节点的详细信息用于调试 if similar_nodes: @@ -692,7 +692,7 @@ class MemoryTools: key=lambda x: final_scores[x], reverse=True ) # 🔥 不再提前截断,让所有候选参与详细评分 - + # 🔍 统计初始记忆的相似度分布(用于诊断) if memory_scores: similarities = list(memory_scores.values()) @@ -707,7 +707,7 @@ class MemoryTools: # 5. 获取完整记忆并进行最终排序(优化后的动态权重系统) memories_with_scores = [] filter_stats = {"importance": 0, "similarity": 0, "total_checked": 0} # 过滤统计 - + for memory_id in sorted_memory_ids: # 遍历所有候选 memory = self.graph_store.get_memory_by_id(memory_id) if memory: @@ -715,7 +715,7 @@ class MemoryTools: # 基础分数 similarity_score = final_scores[memory_id] importance_score = memory.importance - + # 🆕 区分记忆来源(用于过滤) is_initial_memory = memory_id in memory_scores # 是否来自初始向量搜索 true_similarity = memory_scores.get(memory_id, 0.0) if is_initial_memory else None @@ -738,16 +738,16 @@ class MemoryTools: activation_score = memory.activation # 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调 - memory_type = memory.memory_type.value if hasattr(memory.memory_type, 'value') else str(memory.memory_type) - + memory_type = memory.memory_type.value if hasattr(memory.memory_type, "value") else str(memory.memory_type) + # 检测记忆的主要节点类型 node_types_count = {} for node in memory.nodes: - nt = node.node_type.value if hasattr(node.node_type, 'value') else str(node.node_type) + nt = node.node_type.value if hasattr(node.node_type, "value") else str(node.node_type) node_types_count[nt] = node_types_count.get(nt, 0) + 1 - + dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown" - + # 根据记忆类型和节点类型计算调整系数(在配置权重基础上微调) if dominant_node_type in ["ATTRIBUTE", "REFERENCE"] or memory_type == "FACT": # 事实性记忆:提升相似度权重,降低时效性权重 @@ -777,41 +777,41 @@ class MemoryTools: "importance": 1.0, "recency": 1.0, } - + # 应用调整后的权重(基于配置的基础权重) weights = { "similarity": self.base_vector_weight * type_adjustments["similarity"], "importance": self.base_importance_weight * type_adjustments["importance"], "recency": self.base_recency_weight * type_adjustments["recency"], } - + # 归一化权重(确保总和为1.0) total_weight = sum(weights.values()) if total_weight > 0: weights = {k: v / total_weight for k, v in weights.items()} - + # 综合分数计算(🔥 移除激活度影响) final_score = ( similarity_score * weights["similarity"] + importance_score * weights["importance"] + recency_score * weights["recency"] ) - + # 🆕 阈值过滤策略: # 1. 
重要性过滤:应用于所有记忆(过滤极低质量) if memory.importance < self.search_min_importance: filter_stats["importance"] += 1 logger.debug(f"❌ 过滤 {memory.id[:8]}: 重要性 {memory.importance:.2f} < 阈值 {self.search_min_importance}") continue - + # 2. 相似度过滤:不再对初始向量搜索结果过滤(信任向量搜索的排序) # 理由:向量搜索已经按相似度排序,返回的都是最相关结果 # 如果再用阈值过滤,会导致"最相关的也不够相关"的矛盾 - # + # # 注意:如果未来需要对扩展记忆过滤,可以在这里添加逻辑 # if not is_initial_memory and some_score < threshold: # continue - + # 记录通过过滤的记忆(用于调试) if is_initial_memory: logger.debug( @@ -823,11 +823,11 @@ class MemoryTools: f"✅ 保留 {memory.id[:8]} [扩展]: 重要性={memory.importance:.2f}, " f"综合分数={final_score:.4f}" ) - + # 🆕 节点类型加权:对REFERENCE/ATTRIBUTE节点额外加分(促进事实性信息召回) if "REFERENCE" in node_types_count or "ATTRIBUTE" in node_types_count: final_score *= 1.1 # 10% 加成 - + # 🆕 用户指定的优先节点类型额外加权 if prefer_node_types: for prefer_type in prefer_node_types: @@ -835,7 +835,7 @@ class MemoryTools: final_score *= 1.15 # 15% 额外加成 logger.debug(f"记忆 {memory.id[:8]} 包含优先节点类型 {prefer_type},加权后分数: {final_score:.4f}") break - + memories_with_scores.append((memory, final_score, dominant_node_type)) # 按综合分数排序 @@ -845,7 +845,7 @@ class MemoryTools: # 统计过滤情况 total_candidates = len(all_memory_ids) filtered_count = total_candidates - len(memories_with_scores) - + # 6. 格式化结果(包含调试信息) results = [] for memory, score, node_type in memories_with_scores[:top_k]: @@ -866,7 +866,7 @@ class MemoryTools: f"过滤{filtered_count}个 (重要性过滤) → " f"最终返回{len(results)}条记忆" ) - + # 如果过滤率过高,发出警告 if total_candidates > 0: filter_rate = filtered_count / total_candidates @@ -1092,20 +1092,21 @@ class MemoryTools: response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300) import re + import orjson - + # 清理Markdown代码块 response = re.sub(r"```json\s*", "", response) response = re.sub(r"```\s*$", "", response).strip() # 解析JSON data = orjson.loads(response) - + # 提取查询列表 queries = data.get("queries", []) result_queries = [(item.get("text", "").strip(), float(item.get("weight", 0.5))) for item in queries if item.get("text", "").strip()] - + # 提取偏好节点类型 prefer_node_types = data.get("prefer_node_types", []) # 确保类型正确且有效 @@ -1154,7 +1155,7 @@ class MemoryTools: limit=top_k * 5, # 🔥 从2倍提升到5倍,提高初始召回率 min_similarity=0.0, # 不在这里过滤,交给后续评分 ) - + logger.debug(f"单查询向量搜索: 查询='{query}', 返回节点数={len(similar_nodes)}") if similar_nodes: logger.debug(f"Top 3相似度: {[f'{sim:.3f}' for _, sim, _ in similar_nodes[:3]]}") diff --git a/src/memory_graph/utils/graph_expansion.py b/src/memory_graph/utils/graph_expansion.py index f7c850f74..babfba788 100644 --- a/src/memory_graph/utils/graph_expansion.py +++ b/src/memory_graph/utils/graph_expansion.py @@ -62,7 +62,7 @@ async def expand_memories_with_semantic_filter( try: import time start_time = time.time() - + # 记录已访问的记忆,避免重复 visited_memories = set(initial_memory_ids) # 记录扩展的记忆及其分数 @@ -87,17 +87,17 @@ async def expand_memories_with_semantic_filter( # 获取该记忆的邻居记忆(通过边关系) neighbor_memory_ids = set() - + # 🆕 遍历记忆的所有边,收集邻居记忆(带边类型权重) edge_weights = {} # 记录通过不同边类型到达的记忆的权重 - + for edge in memory.edges: # 获取边的目标节点 target_node_id = edge.target_id source_node_id = edge.source_id - + # 🆕 根据边类型设置权重(优先扩展REFERENCE、ATTRIBUTE相关的边) - edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type) + edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, "value") else str(edge.edge_type) if edge_type_str == "REFERENCE": edge_weight = 1.3 # REFERENCE边权重最高(引用关系) elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]: @@ -108,18 +108,18 @@ async def expand_memories_with_semantic_filter( edge_weight = 0.9 # 
一般关系适中降权 else: edge_weight = 1.0 # 默认权重 - + # 通过节点找到其他记忆 for node_id in [target_node_id, source_node_id]: if node_id in graph_store.node_to_memories: for neighbor_id in graph_store.node_to_memories[node_id]: if neighbor_id not in edge_weights or edge_weights[neighbor_id] < edge_weight: edge_weights[neighbor_id] = edge_weight - + # 将权重高的邻居记忆加入候选 for neighbor_id, edge_weight in edge_weights.items(): neighbor_memory_ids.add((neighbor_id, edge_weight)) - + # 过滤掉已访问的和自己 filtered_neighbors = [] for neighbor_id, edge_weight in neighbor_memory_ids: @@ -129,7 +129,7 @@ async def expand_memories_with_semantic_filter( # 批量评估邻居记忆 for neighbor_mem_id, edge_weight in filtered_neighbors: candidates_checked += 1 - + neighbor_memory = graph_store.get_memory_by_id(neighbor_mem_id) if not neighbor_memory: continue @@ -139,7 +139,7 @@ async def expand_memories_with_semantic_filter( (n for n in neighbor_memory.nodes if n.has_embedding()), None ) - + if not topic_node or topic_node.embedding is None: continue @@ -179,11 +179,11 @@ async def expand_memories_with_semantic_filter( if len(expanded_memories) >= max_expanded: logger.debug(f"⏹️ 提前停止:已达到最大扩展数量 {max_expanded}") break - + # 早停检查 if len(expanded_memories) >= max_expanded: break - + # 记录本层统计 depth_stats.append({ "depth": depth + 1, @@ -199,20 +199,20 @@ async def expand_memories_with_semantic_filter( # 限制下一层的记忆数量,避免爆炸性增长 current_level_memories = next_level_memories[:max_expanded] - + # 每层让出控制权 await asyncio.sleep(0.001) # 排序并返回 sorted_results = sorted(expanded_memories.items(), key=lambda x: x[1], reverse=True)[:max_expanded] - + elapsed = time.time() - start_time logger.info( f"✅ 图扩展完成: 初始{len(initial_memory_ids)}个 → " f"扩展{len(sorted_results)}个新记忆 " f"(深度={max_depth}, 阈值={semantic_threshold:.2f}, 耗时={elapsed:.3f}s)" ) - + # 输出每层统计 for stat in depth_stats: logger.debug( diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index 2025a4c31..40de6885a 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -137,7 +137,7 @@ async def generate_reply( prompt_mode = "s4u" # 默认使用s4u模式 if action_data and "prompt_mode" in action_data: prompt_mode = action_data.get("prompt_mode", "s4u") - + # 将prompt_mode添加到available_actions中(作为特殊键) # 注意:这里我们需要暂时使用类型忽略,因为available_actions的类型定义不支持非ActionInfo值 if available_actions is None: diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py index 3837c29d7..bf60af9ab 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py @@ -362,7 +362,7 @@ class ChatterPlanFilter: return "最近没有聊天内容。", "没有未读消息。", [] stream_context = chat_stream.context_manager - + # 获取真正的已读和未读消息 read_messages = stream_context.context.history_messages # 已读消息存储在history_messages中 if not read_messages: @@ -660,30 +660,30 @@ class ChatterPlanFilter: if not action_info: logger.debug(f"动作 {action_name} 不在可用动作列表中,保留所有参数") return action_data - + # 获取该动作定义的合法参数 defined_params = set(action_info.action_parameters.keys()) - + # 合法参数集合 valid_params = defined_params - + # 过滤参数 filtered_data = {} removed_params = [] - + for key, value in action_data.items(): if key in valid_params: filtered_data[key] = value else: removed_params.append(key) - + # 记录被移除的参数 if removed_params: logger.info( f"🧹 [参数过滤] 动作 '{action_name}' 移除了多余参数: {removed_params}. 
" f"合法参数: {sorted(valid_params)}" ) - + return filtered_data def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]: diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py b/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py index c8a0c1e60..908bb487c 100644 --- a/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py @@ -615,14 +615,14 @@ async def execute_proactive_thinking(stream_id: str): # 获取或创建该聊天流的执行锁 if stream_id not in _execution_locks: _execution_locks[stream_id] = asyncio.Lock() - + lock = _execution_locks[stream_id] - + # 尝试获取锁,如果已被占用则跳过本次执行(防止重复) if lock.locked(): logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 已有正在执行的主动思考任务") return - + async with lock: logger.debug(f"🤔 开始主动思考 {stream_id}") @@ -633,13 +633,13 @@ async def execute_proactive_thinking(stream_id: str): from src.chat.message_receive.chat_stream import get_chat_manager chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) - + if chat_stream and chat_stream.context_manager.context.is_chatter_processing: logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 的 chatter 正在处理消息") return except Exception as e: logger.warning(f"检查 chatter 处理状态时出错: {e},继续执行") - + # 0.1 检查白名单/黑名单 # 从 stream_id 获取 stream_config 字符串进行验证 try: diff --git a/src/plugins/built_in/anti_injection_plugin/__init__.py b/src/plugins/built_in/anti_injection_plugin/__init__.py index 808164495..139e2e64d 100644 --- a/src/plugins/built_in/anti_injection_plugin/__init__.py +++ b/src/plugins/built_in/anti_injection_plugin/__init__.py @@ -31,4 +31,4 @@ __plugin_meta__ = PluginMetadata( # 导入插件主类 from .plugin import AntiInjectionPlugin -__all__ = ["__plugin_meta__", "AntiInjectionPlugin"] +__all__ = ["AntiInjectionPlugin", "__plugin_meta__"] diff --git a/src/plugins/built_in/anti_injection_plugin/checker.py b/src/plugins/built_in/anti_injection_plugin/checker.py index 136e4aae4..68c1284e6 100644 --- a/src/plugins/built_in/anti_injection_plugin/checker.py +++ b/src/plugins/built_in/anti_injection_plugin/checker.py @@ -8,8 +8,8 @@ import time from src.chat.security.interfaces import ( SecurityAction, - SecurityCheckResult, SecurityChecker, + SecurityCheckResult, SecurityLevel, ) from src.common.logger import get_logger diff --git a/src/plugins/built_in/anti_injection_plugin/processor.py b/src/plugins/built_in/anti_injection_plugin/processor.py index 9960f1521..fa0380569 100644 --- a/src/plugins/built_in/anti_injection_plugin/processor.py +++ b/src/plugins/built_in/anti_injection_plugin/processor.py @@ -4,7 +4,7 @@ 处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。 """ -from src.chat.security.interfaces import SecurityAction, SecurityCheckResult +from src.chat.security.interfaces import SecurityCheckResult from src.common.logger import get_logger from .counter_attack import CounterAttackGenerator diff --git a/src/plugins/built_in/core_actions/reply.py b/src/plugins/built_in/core_actions/reply.py index 993ddb85a..9a90f7e33 100644 --- a/src/plugins/built_in/core_actions/reply.py +++ b/src/plugins/built_in/core_actions/reply.py @@ -22,23 +22,23 @@ class ReplyAction(BaseAction): - 专注于理解和回应单条消息的具体内容 - 适合 Focus 模式下的精准回复 """ - + # 动作基本信息 action_name = "reply" action_description = "针对特定消息进行精准回复。深度理解并回应单条消息的具体内容。需要指定目标消息ID。" - + # 激活设置 activation_type = ActionActivationType.ALWAYS # 回复动作总是可用 mode_enable = ChatMode.ALL # 在所有模式下都可用 parallel_action 
= False # 回复动作不能与其他动作并行 - + # 动作参数定义 action_parameters: ClassVar = { "target_message_id": "要回复的目标消息ID(必需,来自未读消息的 标签)", "content": "回复的具体内容(可选,由LLM生成)", "should_quote_reply": "是否引用原消息(可选,true/false,默认false。群聊中回复较早消息或需要明确指向时使用true)", } - + # 动作使用场景 action_require: ClassVar = [ "需要针对特定消息进行精准回复时使用", @@ -48,10 +48,10 @@ class ReplyAction(BaseAction): "群聊中需要明确回应某个特定用户或问题时使用", "关注单条消息的具体内容和上下文细节", ] - + # 关联类型 associated_types: ClassVar[list[str]] = ["text"] - + async def execute(self) -> tuple[bool, str]: """执行reply动作 @@ -70,21 +70,21 @@ class RespondAction(BaseAction): - 适合对于群聊消息下的宏观回应 - 避免与单一用户深度对话而忽略其他用户的消息 """ - + # 动作基本信息 action_name = "respond" action_description = "统一回应所有未读消息。理解整体对话动态和话题走向,生成连贯的回复。无需指定目标消息。" - + # 激活设置 activation_type = ActionActivationType.ALWAYS # 回应动作总是可用 mode_enable = ChatMode.ALL # 在所有模式下都可用 parallel_action = False # 回应动作不能与其他动作并行 - + # 动作参数定义 action_parameters: ClassVar = { "content": "回复的具体内容(可选,由LLM生成)", } - + # 动作使用场景 action_require: ClassVar = [ "需要统一回应多条未读消息时使用(Normal 模式专用)", @@ -94,10 +94,10 @@ class RespondAction(BaseAction): "适合群聊中的自然对话流,无需精确指向特定消息", "可以同时回应多个话题或参与者", ] - + # 关联类型 associated_types: ClassVar[list[str]] = ["text"] - + async def execute(self) -> tuple[bool, str]: """执行respond动作 diff --git a/src/plugins/built_in/emoji_plugin/plugin.py b/src/plugins/built_in/emoji_plugin/plugin.py index 3b1c7a014..5baaa3a8e 100644 --- a/src/plugins/built_in/emoji_plugin/plugin.py +++ b/src/plugins/built_in/emoji_plugin/plugin.py @@ -64,15 +64,15 @@ class CoreActionsPlugin(BasePlugin): # --- 根据配置注册组件 --- components: ClassVar = [] - + # 注册 reply 动作 if self.get_config("components.enable_reply", True): components.append((ReplyAction.get_action_info(), ReplyAction)) - + # 注册 respond 动作 if self.get_config("components.enable_respond", True): components.append((RespondAction.get_action_info(), RespondAction)) - + # 注册 emoji 动作 if self.get_config("components.enable_emoji", True): components.append((EmojiAction.get_action_info(), EmojiAction)) diff --git a/src/plugins/built_in/maizone_refactored/services/content_service.py b/src/plugins/built_in/maizone_refactored/services/content_service.py index b446d1dbb..2dc95d949 100644 --- a/src/plugins/built_in/maizone_refactored/services/content_service.py +++ b/src/plugins/built_in/maizone_refactored/services/content_service.py @@ -6,10 +6,10 @@ import asyncio import base64 import datetime -import filetype from collections.abc import Callable import aiohttp +import filetype from maim_message import UserInfo from src.chat.message_receive.chat_stream import get_chat_manager diff --git a/src/schedule/unified_scheduler.py b/src/schedule/unified_scheduler.py index deca5ac91..c0af4db16 100644 --- a/src/schedule/unified_scheduler.py +++ b/src/schedule/unified_scheduler.py @@ -17,7 +17,6 @@ import uuid import weakref from collections import defaultdict from collections.abc import Awaitable, Callable -from contextlib import suppress from dataclasses import dataclass, field from datetime import datetime from enum import Enum
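The extraction pipeline in scripts/lpmm_learning_tool.py caps concurrent LLM requests with an asyncio.Semaphore and consumes results incrementally via asyncio.as_completed, which is what drives its progress bar. Below is a minimal runnable sketch of that pattern; extract() is a stand-in for the real LLMRequest call, and the hash/paragraph inputs are illustrative:

import asyncio

MAX_EXTRACTION_CONCURRENCY = 5  # same knob the script exposes

async def extract(pg_hash: str, paragraph: str):
    # stand-in for the real LLM extraction; returns (doc_item, failed_hash)
    await asyncio.sleep(0.01)
    return {"idx": pg_hash, "passage": paragraph}, None

async def extract_all(paragraphs: dict[str, str]):
    semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)

    async def bounded(pg_hash: str, paragraph: str):
        async with semaphore:  # at most N requests in flight at once
            return await extract(pg_hash, paragraph)

    tasks = [asyncio.create_task(bounded(h, p)) for h, p in paragraphs.items()]
    docs, failed = [], []
    for coro in asyncio.as_completed(tasks):  # progress ticks as each task finishes
        doc_item, failed_hash = await coro
        if failed_hash:
            failed.append(failed_hash)
        elif doc_item:
            docs.append(doc_item)
    return docs, failed

print(asyncio.run(extract_all({"h1": "paragraph one", "h2": "paragraph two"})))

Sharing one client instance and bounding in-flight requests is what lets the script reuse connections without tripping API rate limits.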
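The auto-forgetting pass in src/memory_graph/manager.py decays each memory's activation exponentially with the days since last access before comparing it to the forgetting threshold: activation = base * decay_rate ** days. A self-contained sketch of just that math (threshold handling omitted; decay_rate defaults to the 0.9 fallback seen in the diff):

from datetime import datetime

def decayed_activation(base_activation: float, last_access_iso: str,
                       decay_rate: float = 0.9,
                       now: datetime | None = None) -> float:
    """activation = base * decay_rate ** days_since_last_access"""
    now = now or datetime.now()
    days_passed = max((now - datetime.fromisoformat(last_access_iso)).days, 0)
    return base_activation * (decay_rate ** days_passed)

# A new memory starts at 0.75 (initial_activation in MemoryBuilder);
# after 14 untouched days at decay_rate=0.9 it is down to ~0.17.
print(round(decayed_activation(0.75, "2024-01-01T00:00:00",
                               now=datetime(2024, 1, 15)), 3))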
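src/chat/knowledge/embedding_store.py now guards index building against mixed embedding dimensions: it picks the most common dimension and drops vectors that don't match, so the index can still be built. The real code also maintains idx2hash bookkeeping for the survivors; this sketch keeps only the majority-vote filter, with hypothetical names:

from collections import Counter

def filter_consistent(embeddings: list[list[float]]) -> tuple[list[list[float]], int]:
    dims = [len(e) for e in embeddings]
    if len(set(dims)) <= 1:
        return embeddings, 0  # already consistent, nothing to drop
    expected = Counter(dims).most_common(1)[0][0]  # most common dimension wins
    kept = [e for e in embeddings if len(e) == expected]
    return kept, len(embeddings) - len(kept)

kept, skipped = filter_consistent([[0.1, 0.2], [0.3, 0.4], [0.5]])
print(len(kept), skipped)  # -> 2 1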
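Both StreamLoopManager (distribution_manager.py) and the proactive-thinking executor lazily create one asyncio.Lock per stream so that racing callers cannot start duplicate work for the same stream. A runnable sketch of that double-start guard, assuming a trivial worker in place of the real loop:

import asyncio

_stream_start_locks: dict[str, asyncio.Lock] = {}
_loop_tasks: dict[str, asyncio.Task] = {}

async def _stream_loop_worker(stream_id: str) -> None:
    await asyncio.sleep(0.05)  # stand-in for the real processing loop

async def ensure_loop_started(stream_id: str) -> bool:
    lock = _stream_start_locks.setdefault(stream_id, asyncio.Lock())
    async with lock:  # racing callers queue here instead of double-starting
        task = _loop_tasks.get(stream_id)
        if task and not task.done():
            return False  # a live loop already exists for this stream
        _loop_tasks[stream_id] = asyncio.create_task(
            _stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}"
        )
        return True

async def main() -> None:
    results = await asyncio.gather(*(ensure_loop_started("s1") for _ in range(5)))
    print(results.count(True))  # -> 1: only one loop actually started
    await _loop_tasks["s1"]     # let the worker finish cleanly

asyncio.run(main())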
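The re-ranking in src/memory_graph/tools/memory_tools.py multiplies the configured base weights by per-memory-type adjustment factors, renormalizes them to sum to 1.0, then blends similarity, importance, and recency into the final score. The base weights and the fact-memory adjustment below are illustrative placeholders, not the project's configured values:

def blend_score(similarity: float, importance: float, recency: float,
                base_weights: dict[str, float] | None = None,
                adjustments: dict[str, float] | None = None) -> float:
    base_weights = base_weights or {"similarity": 0.5, "importance": 0.3, "recency": 0.2}
    adjustments = adjustments or {"similarity": 1.3, "importance": 1.0, "recency": 0.5}
    weights = {k: base_weights[k] * adjustments[k] for k in base_weights}
    total = sum(weights.values())
    weights = {k: v / total for k, v in weights.items()}  # renormalize to sum 1.0
    return (similarity * weights["similarity"]
            + importance * weights["importance"]
            + recency * weights["recency"])

# Fact-like memories lean on similarity and discount recency:
print(round(blend_score(0.8, 0.6, 0.2), 3))

Renormalizing after the per-type adjustment keeps scores comparable across memory types even when the multipliers skew the raw weights.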