From 5caf630623f0bac7b24897cb9a7f536104abfdfb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 7 Nov 2025 04:39:35 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E8=B4=A8=E9=87=8F=E9=97=AE=E9=A2=98=20-=20=E6=9B=B4=E6=AD=A3?= =?UTF-8?q?=E5=BC=82=E5=B8=B8=E5=A4=84=E7=90=86=E5=92=8C=E5=AF=BC=E5=85=A5?= =?UTF-8?q?=E8=AF=AD=E5=8F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Windpicker-owo <221029311+Windpicker-owo@users.noreply.github.com> --- plugins/memory_graph_plugin/plugin.py | 10 +- scripts/deduplicate_memories.py | 157 +++-- src/memory_graph/__init__.py | 10 +- src/memory_graph/core/__init__.py | 2 +- src/memory_graph/core/builder.py | 69 +- src/memory_graph/core/extractor.py | 54 +- src/memory_graph/core/node_merger.py | 51 +- src/memory_graph/manager.py | 611 +++++++++--------- src/memory_graph/manager_singleton.py | 43 +- src/memory_graph/models.py | 38 +- .../plugin_tools/memory_plugin_tools.py | 58 +- src/memory_graph/storage/__init__.py | 2 +- src/memory_graph/storage/graph_store.py | 99 ++- src/memory_graph/storage/persistence.py | 42 +- src/memory_graph/storage/vector_store.py | 79 ++- src/memory_graph/tools/memory_tools.py | 230 +++---- src/memory_graph/utils/__init__.py | 2 +- src/memory_graph/utils/embeddings.py | 78 ++- src/memory_graph/utils/memory_formatter.py | 113 ++-- src/memory_graph/utils/time_parser.py | 55 +- 20 files changed, 893 insertions(+), 910 deletions(-) diff --git a/plugins/memory_graph_plugin/plugin.py b/plugins/memory_graph_plugin/plugin.py index 3237d6363..eb20c1b78 100644 --- a/plugins/memory_graph_plugin/plugin.py +++ b/plugins/memory_graph_plugin/plugin.py @@ -6,10 +6,12 @@ from typing import ClassVar from src.common.logger import get_logger from src.plugin_system import BasePlugin, register_plugin -from src.plugin_system.base.component_types import ComponentInfo, ToolInfo logger = get_logger("memory_graph_plugin") +# 用于存储后台任务引用 +_background_tasks = set() + @register_plugin class MemoryGraphPlugin(BasePlugin): @@ -60,6 +62,7 @@ class MemoryGraphPlugin(BasePlugin): """插件卸载时的回调""" try: import asyncio + from src.memory_graph.manager_singleton import shutdown_memory_manager logger.info(f"{self.log_prefix} 正在关闭记忆系统...") @@ -68,7 +71,10 @@ class MemoryGraphPlugin(BasePlugin): loop = asyncio.get_event_loop() if loop.is_running(): # 如果循环正在运行,创建任务 - asyncio.create_task(shutdown_memory_manager()) + task = asyncio.create_task(shutdown_memory_manager()) + # 存储引用以防止任务被垃圾回收 + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) else: # 如果循环未运行,直接运行 loop.run_until_complete(shutdown_memory_manager()) diff --git a/scripts/deduplicate_memories.py b/scripts/deduplicate_memories.py index 936fd9014..5eb1f1aa1 100644 --- a/scripts/deduplicate_memories.py +++ b/scripts/deduplicate_memories.py @@ -10,13 +10,13 @@ 使用方法: # 预览模式(不实际删除) python scripts/deduplicate_memories.py --dry-run - + # 执行去重 python scripts/deduplicate_memories.py - + # 指定相似度阈值 python scripts/deduplicate_memories.py --threshold 0.9 - + # 指定数据目录 python scripts/deduplicate_memories.py --data-dir data/memory_graph """ @@ -25,27 +25,26 @@ import asyncio import sys from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional, Set, Tuple import numpy as np sys.path.insert(0, str(Path(__file__).parent.parent)) from src.common.logger import get_logger -from 
src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager, shutdown_memory_manager +from src.memory_graph.manager_singleton import initialize_memory_manager, shutdown_memory_manager logger = get_logger(__name__) class MemoryDeduplicator: """记忆去重器""" - + def __init__(self, data_dir: str = "data/memory_graph", dry_run: bool = False, threshold: float = 0.85): self.data_dir = data_dir self.dry_run = dry_run self.threshold = threshold self.manager = None - + # 统计信息 self.stats = { "total_memories": 0, @@ -54,34 +53,34 @@ class MemoryDeduplicator: "duplicates_removed": 0, "errors": 0, } - + async def initialize(self): """初始化记忆管理器""" logger.info(f"正在初始化记忆管理器 (data_dir={self.data_dir})...") self.manager = await initialize_memory_manager(data_dir=self.data_dir) if not self.manager: raise RuntimeError("记忆管理器初始化失败") - + self.stats["total_memories"] = len(self.manager.graph_store.get_all_memories()) logger.info(f"✅ 记忆管理器初始化成功,共 {self.stats['total_memories']} 条记忆") - - async def find_similar_pairs(self) -> List[Tuple[str, str, float]]: + + async def find_similar_pairs(self) -> list[tuple[str, str, float]]: """ 查找所有相似的记忆对(通过向量相似度计算) - + Returns: [(memory_id_1, memory_id_2, similarity), ...] """ logger.info("正在扫描相似记忆对...") similar_pairs = [] seen_pairs = set() # 避免重复 - + # 获取所有记忆 all_memories = self.manager.graph_store.get_all_memories() total_memories = len(all_memories) - + logger.info(f"开始计算 {total_memories} 条记忆的相似度...") - + # 两两比较记忆的相似度 for i, memory_i in enumerate(all_memories): # 每处理10条记忆让出控制权 @@ -89,115 +88,115 @@ class MemoryDeduplicator: await asyncio.sleep(0) if i > 0: logger.info(f"进度: {i}/{total_memories} ({i*100//total_memories}%)") - + # 获取记忆i的向量(从主题节点) vector_i = None for node in memory_i.nodes: if node.embedding is not None: vector_i = node.embedding break - + if vector_i is None: continue - + # 与后续记忆比较 for j in range(i + 1, total_memories): memory_j = all_memories[j] - + # 获取记忆j的向量 vector_j = None for node in memory_j.nodes: if node.embedding is not None: vector_j = node.embedding break - + if vector_j is None: continue - + # 计算余弦相似度 similarity = self._cosine_similarity(vector_i, vector_j) - + # 只保存满足阈值的相似对 if similarity >= self.threshold: pair_key = tuple(sorted([memory_i.id, memory_j.id])) if pair_key not in seen_pairs: seen_pairs.add(pair_key) similar_pairs.append((memory_i.id, memory_j.id, similarity)) - + self.stats["similar_pairs"] = len(similar_pairs) logger.info(f"找到 {len(similar_pairs)} 对相似记忆(阈值>={self.threshold})") - + return similar_pairs - + def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float: """计算余弦相似度""" try: vec1_norm = np.linalg.norm(vec1) vec2_norm = np.linalg.norm(vec2) - + if vec1_norm == 0 or vec2_norm == 0: return 0.0 - + similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm) return float(similarity) except Exception as e: logger.error(f"计算余弦相似度失败: {e}") return 0.0 - - def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> Tuple[Optional[str], Optional[str]]: + + def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> tuple[str | None, str | None]: """ 决定保留哪个记忆,删除哪个 - + 优先级: 1. 重要性更高的 2. 激活度更高的 3. 
创建时间更早的 - + Returns: (keep_id, remove_id) """ mem1 = self.manager.graph_store.get_memory_by_id(mem_id_1) mem2 = self.manager.graph_store.get_memory_by_id(mem_id_2) - + if not mem1 or not mem2: logger.warning(f"记忆不存在: {mem_id_1} or {mem_id_2}") return None, None - + # 比较重要性 if mem1.importance > mem2.importance: return mem_id_1, mem_id_2 elif mem1.importance < mem2.importance: return mem_id_2, mem_id_1 - + # 重要性相同,比较激活度 if mem1.activation > mem2.activation: return mem_id_1, mem_id_2 elif mem1.activation < mem2.activation: return mem_id_2, mem_id_1 - + # 激活度也相同,保留更早创建的 if mem1.created_at < mem2.created_at: return mem_id_1, mem_id_2 else: return mem_id_2, mem_id_1 - + async def deduplicate_pair(self, mem_id_1: str, mem_id_2: str, similarity: float) -> bool: """ 去重一对相似记忆 - + Returns: 是否成功去重 """ keep_id, remove_id = self.decide_which_to_keep(mem_id_1, mem_id_2) - + if not keep_id or not remove_id: self.stats["errors"] += 1 return False - + keep_mem = self.manager.graph_store.get_memory_by_id(keep_id) remove_mem = self.manager.graph_store.get_memory_by_id(remove_id) - - logger.info(f"") + + logger.info("") logger.info(f"{'[预览]' if self.dry_run else '[执行]'} 去重相似记忆对 (相似度={similarity:.3f}):") logger.info(f" 保留: {keep_id}") logger.info(f" - 主题: {keep_mem.metadata.get('topic', 'N/A')}") @@ -209,41 +208,41 @@ class MemoryDeduplicator: logger.info(f" - 重要性: {remove_mem.importance:.2f}") logger.info(f" - 激活度: {remove_mem.activation:.2f}") logger.info(f" - 创建时间: {remove_mem.created_at}") - + if self.dry_run: logger.info(" [预览模式] 不执行实际删除") self.stats["duplicates_found"] += 1 return True - + try: # 增强保留记忆的属性 keep_mem.importance = min(1.0, keep_mem.importance + 0.05) keep_mem.activation = min(1.0, keep_mem.activation + 0.05) - + # 累加访问次数 - if hasattr(keep_mem, 'access_count') and hasattr(remove_mem, 'access_count'): + if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"): keep_mem.access_count += remove_mem.access_count - + # 删除相似记忆 await self.manager.delete_memory(remove_id) - + self.stats["duplicates_removed"] += 1 - logger.info(f" ✅ 删除成功") - + logger.info(" ✅ 删除成功") + # 让出控制权 await asyncio.sleep(0) - + return True - + except Exception as e: logger.error(f" ❌ 删除失败: {e}", exc_info=True) self.stats["errors"] += 1 return False - + async def run(self): """执行去重""" start_time = datetime.now() - + print("="*70) print("记忆去重工具") print("="*70) @@ -252,13 +251,13 @@ class MemoryDeduplicator: print(f"模式: {'预览模式(不实际删除)' if self.dry_run else '执行模式(会实际删除)'}") print("="*70) print() - + # 初始化 await self.initialize() - + # 查找相似对 similar_pairs = await self.find_similar_pairs() - + if not similar_pairs: logger.info("未找到需要去重的相似记忆对") print() @@ -266,19 +265,19 @@ class MemoryDeduplicator: print("未找到需要去重的记忆") print("="*70) return - + # 去重处理 logger.info(f"开始{'预览' if self.dry_run else '执行'}去重...") print() - + processed_pairs = set() # 避免重复处理 - + for mem_id_1, mem_id_2, similarity in similar_pairs: # 检查是否已处理(可能一个记忆已被删除) pair_key = tuple(sorted([mem_id_1, mem_id_2])) if pair_key in processed_pairs: continue - + # 检查记忆是否仍存在 if not self.manager.graph_store.get_memory_by_id(mem_id_1): logger.debug(f"记忆 {mem_id_1} 已不存在,跳过") @@ -286,22 +285,22 @@ class MemoryDeduplicator: if not self.manager.graph_store.get_memory_by_id(mem_id_2): logger.debug(f"记忆 {mem_id_2} 已不存在,跳过") continue - + # 执行去重 success = await self.deduplicate_pair(mem_id_1, mem_id_2, similarity) - + if success: processed_pairs.add(pair_key) - + # 保存数据(如果不是干运行) if not self.dry_run: logger.info("正在保存数据...") await 
self.manager.persistence.save_graph_store(self.manager.graph_store) logger.info("✅ 数据已保存") - + # 统计报告 elapsed = (datetime.now() - start_time).total_seconds() - + print() print("="*70) print("去重报告") @@ -312,7 +311,7 @@ class MemoryDeduplicator: print(f"{'预览通过' if self.dry_run else '成功删除'}: {self.stats['duplicates_found'] if self.dry_run else self.stats['duplicates_removed']}") print(f"错误数: {self.stats['errors']}") print(f"耗时: {elapsed:.2f}秒") - + if self.dry_run: print() print("⚠️ 这是预览模式,未实际删除任何记忆") @@ -322,9 +321,9 @@ class MemoryDeduplicator: print("✅ 去重完成!") final_count = len(self.manager.graph_store.get_all_memories()) print(f"📊 最终记忆数: {final_count} (减少 {self.stats['total_memories'] - final_count} 条)") - + print("="*70) - + async def cleanup(self): """清理资源""" if self.manager: @@ -340,50 +339,50 @@ async def main(): 示例: # 预览模式(推荐先运行) python scripts/deduplicate_memories.py --dry-run - + # 执行去重 python scripts/deduplicate_memories.py - + # 指定相似度阈值(只处理相似度>=0.9的记忆对) python scripts/deduplicate_memories.py --threshold 0.9 - + # 指定数据目录 python scripts/deduplicate_memories.py --data-dir data/memory_graph - + # 组合使用 python scripts/deduplicate_memories.py --dry-run --threshold 0.95 --data-dir data/test """ ) - + parser.add_argument( "--dry-run", action="store_true", help="预览模式,不实际删除记忆(推荐先运行此模式)" ) - + parser.add_argument( "--threshold", type=float, default=0.85, help="相似度阈值,只处理相似度>=此值的记忆对(默认: 0.85)" ) - + parser.add_argument( "--data-dir", type=str, default="data/memory_graph", help="记忆数据目录(默认: data/memory_graph)" ) - + args = parser.parse_args() - + # 创建去重器 deduplicator = MemoryDeduplicator( data_dir=args.data_dir, dry_run=args.dry_run, threshold=args.threshold ) - + try: # 执行去重 await deduplicator.run() @@ -396,7 +395,7 @@ async def main(): finally: # 清理资源 await deduplicator.cleanup() - + return 0 diff --git a/src/memory_graph/__init__.py b/src/memory_graph/__init__.py index f2070de74..c692d6c07 100644 --- a/src/memory_graph/__init__.py +++ b/src/memory_graph/__init__.py @@ -6,24 +6,24 @@ from src.memory_graph.manager import MemoryManager from src.memory_graph.models import ( + EdgeType, Memory, MemoryEdge, MemoryNode, MemoryStatus, MemoryType, NodeType, - EdgeType, ) __all__ = [ - "MemoryManager", + "EdgeType", "Memory", - "MemoryNode", "MemoryEdge", + "MemoryManager", + "MemoryNode", + "MemoryStatus", "MemoryType", "NodeType", - "EdgeType", - "MemoryStatus", ] __version__ = "0.1.0" diff --git a/src/memory_graph/core/__init__.py b/src/memory_graph/core/__init__.py index c6dc426db..556407efd 100644 --- a/src/memory_graph/core/__init__.py +++ b/src/memory_graph/core/__init__.py @@ -6,4 +6,4 @@ from src.memory_graph.core.builder import MemoryBuilder from src.memory_graph.core.extractor import MemoryExtractor from src.memory_graph.core.node_merger import NodeMerger -__all__ = ["NodeMerger", "MemoryExtractor", "MemoryBuilder"] +__all__ = ["MemoryBuilder", "MemoryExtractor", "NodeMerger"] diff --git a/src/memory_graph/core/builder.py b/src/memory_graph/core/builder.py index df3494d2e..a5cad1f20 100644 --- a/src/memory_graph/core/builder.py +++ b/src/memory_graph/core/builder.py @@ -5,7 +5,7 @@ from __future__ import annotations from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any import numpy as np @@ -16,7 +16,6 @@ from src.memory_graph.models import ( MemoryEdge, MemoryNode, MemoryStatus, - MemoryType, NodeType, ) from src.memory_graph.storage.graph_store import GraphStore @@ -28,7 +27,7 @@ logger = get_logger(__name__) class MemoryBuilder: """ 记忆构建器 - + 
负责: 1. 根据提取的元素自动构造记忆子图 2. 创建节点和边的完整结构 @@ -41,11 +40,11 @@ class MemoryBuilder: self, vector_store: VectorStore, graph_store: GraphStore, - embedding_generator: Optional[Any] = None, + embedding_generator: Any | None = None, ): """ 初始化记忆构建器 - + Args: vector_store: 向量存储 graph_store: 图存储 @@ -55,13 +54,13 @@ class MemoryBuilder: self.graph_store = graph_store self.embedding_generator = embedding_generator - async def build_memory(self, extracted_params: Dict[str, Any]) -> Memory: + async def build_memory(self, extracted_params: dict[str, Any]) -> Memory: """ 构建完整的记忆对象 - + Args: extracted_params: 提取器返回的标准化参数 - + Returns: Memory 对象(状态为 STAGED) """ @@ -97,7 +96,7 @@ class MemoryBuilder: edges.append(memory_type_edge) # 4. 如果有客体,创建客体节点并连接 - if "object" in extracted_params and extracted_params["object"]: + if extracted_params.get("object"): object_node = await self._create_object_node( content=extracted_params["object"], memory_id=memory_id ) @@ -158,14 +157,14 @@ class MemoryBuilder: ) -> MemoryNode: """ 创建新节点或复用已存在的相似节点 - + 对于主体(SUBJECT)和属性(ATTRIBUTE),检查是否已存在相同内容的节点 - + Args: content: 节点内容 node_type: 节点类型 memory_id: 所属记忆ID - + Returns: MemoryNode 对象 """ @@ -190,11 +189,11 @@ class MemoryBuilder: async def _create_topic_node(self, content: str, memory_id: str) -> MemoryNode: """ 创建主题节点(需要生成嵌入向量) - + Args: content: 节点内容 memory_id: 所属记忆ID - + Returns: MemoryNode 对象 """ @@ -225,11 +224,11 @@ class MemoryBuilder: async def _create_object_node(self, content: str, memory_id: str) -> MemoryNode: """ 创建客体节点(需要生成嵌入向量) - + Args: content: 节点内容 memory_id: 所属记忆ID - + Returns: MemoryNode 对象 """ @@ -258,22 +257,22 @@ class MemoryBuilder: async def _process_attributes( self, - attributes: Dict[str, Any], + attributes: dict[str, Any], parent_id: str, memory_id: str, importance: float, - ) -> tuple[List[MemoryNode], List[MemoryEdge]]: + ) -> tuple[list[MemoryNode], list[MemoryEdge]]: """ 处理属性,构建属性子图 - + 结构:TOPIC -> ATTRIBUTE -> VALUE - + Args: attributes: 属性字典 parent_id: 父节点ID(通常是TOPIC) memory_id: 所属记忆ID importance: 重要性 - + Returns: (属性节点列表, 属性边列表) """ @@ -322,10 +321,10 @@ class MemoryBuilder: async def _generate_embedding(self, text: str) -> np.ndarray: """ 生成文本的嵌入向量 - + Args: text: 文本内容 - + Returns: 嵌入向量 """ @@ -341,14 +340,14 @@ class MemoryBuilder: async def _find_existing_node( self, content: str, node_type: NodeType - ) -> Optional[MemoryNode]: + ) -> MemoryNode | None: """ 查找已存在的完全匹配节点(用于主体和属性) - + Args: content: 节点内容 node_type: 节点类型 - + Returns: 已存在的节点,如果没有则返回 None """ @@ -369,14 +368,14 @@ class MemoryBuilder: async def _find_similar_topic( self, content: str, embedding: np.ndarray - ) -> Optional[MemoryNode]: + ) -> MemoryNode | None: """ 查找相似的主题节点(基于语义相似度) - + Args: content: 内容 embedding: 嵌入向量 - + Returns: 相似节点,如果没有则返回 None """ @@ -414,14 +413,14 @@ class MemoryBuilder: async def _find_similar_object( self, content: str, embedding: np.ndarray - ) -> Optional[MemoryNode]: + ) -> MemoryNode | None: """ 查找相似的客体节点(基于语义相似度) - + Args: content: 内容 embedding: 嵌入向量 - + Returns: 相似节点,如果没有则返回 None """ @@ -480,13 +479,13 @@ class MemoryBuilder: ) -> MemoryEdge: """ 关联两个记忆(创建因果或引用边) - + Args: source_memory: 源记忆 target_memory: 目标记忆 relation_type: 关系类型(如 "导致", "引用") importance: 重要性 - + Returns: 创建的边 """ @@ -525,7 +524,7 @@ class MemoryBuilder: logger.error(f"记忆关联失败: {e}", exc_info=True) raise RuntimeError(f"记忆关联失败: {e}") - def _find_topic_node(self, memory: Memory) -> Optional[MemoryNode]: + def _find_topic_node(self, memory: Memory) -> MemoryNode | None: """查找记忆中的主题节点""" for node in memory.nodes: if node.node_type == 
NodeType.TOPIC: diff --git a/src/memory_graph/core/extractor.py b/src/memory_graph/core/extractor.py index afe2ad370..988feae92 100644 --- a/src/memory_graph/core/extractor.py +++ b/src/memory_graph/core/extractor.py @@ -5,7 +5,7 @@ from __future__ import annotations from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any from src.common.logger import get_logger from src.memory_graph.models import MemoryType @@ -17,7 +17,7 @@ logger = get_logger(__name__) class MemoryExtractor: """ 记忆提取器 - + 负责: 1. 从工具调用参数中提取记忆元素 2. 验证参数完整性和有效性 @@ -25,19 +25,19 @@ class MemoryExtractor: 4. 清洗和格式化数据 """ - def __init__(self, time_parser: Optional[TimeParser] = None): + def __init__(self, time_parser: TimeParser | None = None): """ 初始化记忆提取器 - + Args: time_parser: 时间解析器(可选) """ self.time_parser = time_parser or TimeParser() - def extract_from_tool_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + def extract_from_tool_params(self, params: dict[str, Any]) -> dict[str, Any]: """ 从工具参数中提取记忆元素 - + Args: params: 工具调用参数,例如: { @@ -48,7 +48,7 @@ class MemoryExtractor: "attributes": {"时间": "今天", "地点": "家里"}, "importance": 0.3 } - + Returns: 提取和标准化后的参数字典 """ @@ -64,11 +64,11 @@ class MemoryExtractor: } # 3. 提取可选的客体 - if "object" in params and params["object"]: + if params.get("object"): extracted["object"] = self._clean_text(params["object"]) # 4. 提取和标准化属性 - if "attributes" in params and params["attributes"]: + if params.get("attributes"): extracted["attributes"] = self._process_attributes(params["attributes"]) else: extracted["attributes"] = {} @@ -86,13 +86,13 @@ class MemoryExtractor: logger.error(f"记忆提取失败: {e}", exc_info=True) raise ValueError(f"记忆提取失败: {e}") - def _validate_required_params(self, params: Dict[str, Any]) -> None: + def _validate_required_params(self, params: dict[str, Any]) -> None: """ 验证必需参数 - + Args: params: 参数字典 - + Raises: ValueError: 如果缺少必需参数 """ @@ -105,10 +105,10 @@ class MemoryExtractor: def _clean_text(self, text: Any) -> str: """ 清洗文本 - + Args: text: 输入文本 - + Returns: 清洗后的文本 """ @@ -128,13 +128,13 @@ class MemoryExtractor: def _parse_memory_type(self, type_str: str) -> MemoryType: """ 解析记忆类型 - + Args: type_str: 类型字符串 - + Returns: MemoryType 枚举 - + Raises: ValueError: 如果类型无效 """ @@ -166,10 +166,10 @@ class MemoryExtractor: def _parse_importance(self, importance: Any) -> float: """ 解析重要性值 - + Args: importance: 重要性值(可以是数字、字符串等) - + Returns: 0-1之间的浮点数 """ @@ -181,13 +181,13 @@ class MemoryExtractor: logger.warning(f"无效的重要性值: {importance},使用默认值 0.5") return 0.5 - def _process_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any]: + def _process_attributes(self, attributes: dict[str, Any]) -> dict[str, Any]: """ 处理属性字典 - + Args: attributes: 原始属性字典 - + Returns: 处理后的属性字典 """ @@ -222,10 +222,10 @@ class MemoryExtractor: return processed - def extract_link_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + def extract_link_params(self, params: dict[str, Any]) -> dict[str, Any]: """ 提取记忆关联参数(用于 link_memories 工具) - + Args: params: 工具参数,例如: { @@ -234,7 +234,7 @@ class MemoryExtractor: "relation_type": "导致", "importance": 0.6 } - + Returns: 提取后的参数 """ @@ -266,10 +266,10 @@ class MemoryExtractor: def validate_relation_type(self, relation_type: str) -> str: """ 验证关系类型 - + Args: relation_type: 关系类型字符串 - + Returns: 标准化的关系类型 """ diff --git a/src/memory_graph/core/node_merger.py b/src/memory_graph/core/node_merger.py index e8b790f1e..74a2674ff 100644 --- a/src/memory_graph/core/node_merger.py +++ b/src/memory_graph/core/node_merger.py @@ -4,11 +4,6 
@@ from __future__ import annotations -from dataclasses import dataclass -from typing import List, Optional, Tuple - -import numpy as np - from src.common.logger import get_logger from src.config.official_configs import MemoryConfig from src.memory_graph.models import MemoryNode, NodeType @@ -21,7 +16,7 @@ logger = get_logger(__name__) class NodeMerger: """ 节点合并器 - + 负责: 1. 基于语义相似度查找重复节点 2. 验证上下文匹配 @@ -36,7 +31,7 @@ class NodeMerger: ): """ 初始化节点合并器 - + Args: vector_store: 向量存储 graph_store: 图存储 @@ -54,17 +49,17 @@ class NodeMerger: async def find_similar_nodes( self, node: MemoryNode, - threshold: Optional[float] = None, + threshold: float | None = None, limit: int = 5, - ) -> List[Tuple[MemoryNode, float]]: + ) -> list[tuple[MemoryNode, float]]: """ 查找与指定节点相似的节点 - + Args: node: 查询节点 threshold: 相似度阈值(可选,默认使用配置值) limit: 返回结果数量 - + Returns: List of (similar_node, similarity) """ @@ -112,12 +107,12 @@ class NodeMerger: ) -> bool: """ 判断两个节点是否应该合并 - + Args: source_node: 源节点 target_node: 目标节点 similarity: 语义相似度 - + Returns: 是否应该合并 """ @@ -157,16 +152,16 @@ class NodeMerger: ) -> bool: """ 检查两个节点的上下文是否匹配 - + 上下文匹配的标准: 1. 节点类型相同 2. 邻居节点有重叠 3. 邻居节点的内容相似 - + Args: source_node: 源节点 target_node: 目标节点 - + Returns: 是否匹配 """ @@ -207,7 +202,7 @@ class NodeMerger: # 如果有 30% 以上的邻居重叠,认为上下文匹配 return overlap_ratio > 0.3 - def _get_node_content(self, node_id: str) -> Optional[str]: + def _get_node_content(self, node_id: str) -> str | None: """获取节点的内容""" memories = self.graph_store.get_memories_by_node(node_id) if memories: @@ -223,13 +218,13 @@ class NodeMerger: ) -> bool: """ 合并两个节点 - + 将 source 节点的所有边转移到 target 节点,然后删除 source - + Args: source: 源节点(将被删除) target: 目标节点(保留) - + Returns: 是否成功 """ @@ -255,7 +250,7 @@ class NodeMerger: def _update_memory_references(self, old_node_id: str, new_node_id: str) -> None: """ 更新记忆中的节点引用 - + Args: old_node_id: 旧节点ID new_node_id: 新节点ID @@ -280,16 +275,16 @@ class NodeMerger: async def batch_merge_similar_nodes( self, - nodes: List[MemoryNode], - progress_callback: Optional[callable] = None, + nodes: list[MemoryNode], + progress_callback: callable | None = None, ) -> dict: """ 批量处理节点合并 - + Args: nodes: 要处理的节点列表 progress_callback: 进度回调函数 - + Returns: 统计信息字典 """ @@ -344,14 +339,14 @@ class NodeMerger: self, min_similarity: float = 0.85, limit: int = 100, - ) -> List[Tuple[str, str, float]]: + ) -> list[tuple[str, str, float]]: """ 获取待合并的候选节点对 - + Args: min_similarity: 最小相似度 limit: 最大返回数量 - + Returns: List of (node_id_1, node_id_2, similarity) """ diff --git a/src/memory_graph/manager.py b/src/memory_graph/manager.py index 142fe101a..e10894bf8 100644 --- a/src/memory_graph/manager.py +++ b/src/memory_graph/manager.py @@ -10,22 +10,21 @@ import asyncio import logging +import uuid from datetime import datetime, timedelta from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any from src.config.config import global_config from src.config.official_configs import MemoryConfig from src.memory_graph.core.builder import MemoryBuilder from src.memory_graph.core.extractor import MemoryExtractor -from src.memory_graph.models import Memory, MemoryEdge, MemoryNode, MemoryType, NodeType, EdgeType +from src.memory_graph.models import EdgeType, Memory, MemoryEdge, NodeType from src.memory_graph.storage.graph_store import GraphStore from src.memory_graph.storage.persistence import PersistenceManager from src.memory_graph.storage.vector_store import VectorStore from src.memory_graph.tools.memory_tools import MemoryTools from 
src.memory_graph.utils.embeddings import EmbeddingGenerator -import uuid -from typing import TYPE_CHECKING if TYPE_CHECKING: import numpy as np @@ -36,7 +35,7 @@ logger = logging.getLogger(__name__) class MemoryManager: """ 记忆管理器 - + 核心管理类,提供记忆系统的统一接口: - 记忆 CRUD 操作 - 记忆生命周期管理 @@ -46,45 +45,45 @@ class MemoryManager: def __init__( self, - data_dir: Optional[Path] = None, + data_dir: Path | None = None, ): """ 初始化记忆管理器 - + Args: data_dir: 数据目录(可选,默认从global_config读取) """ # 直接使用 global_config.memory - if not global_config.memory or not getattr(global_config.memory, 'enable', False): + if not global_config.memory or not getattr(global_config.memory, "enable", False): raise ValueError("记忆系统未启用,请在配置文件中启用 [memory] enable = true") - + self.config: MemoryConfig = global_config.memory - self.data_dir = data_dir or Path(getattr(self.config, 'data_dir', 'data/memory_graph')) - + self.data_dir = data_dir or Path(getattr(self.config, "data_dir", "data/memory_graph")) + # 存储组件 - self.vector_store: Optional[VectorStore] = None - self.graph_store: Optional[GraphStore] = None - self.persistence: Optional[PersistenceManager] = None - + self.vector_store: VectorStore | None = None + self.graph_store: GraphStore | None = None + self.persistence: PersistenceManager | None = None + # 核心组件 - self.embedding_generator: Optional[EmbeddingGenerator] = None - self.extractor: Optional[MemoryExtractor] = None - self.builder: Optional[MemoryBuilder] = None - self.tools: Optional[MemoryTools] = None - + self.embedding_generator: EmbeddingGenerator | None = None + self.extractor: MemoryExtractor | None = None + self.builder: MemoryBuilder | None = None + self.tools: MemoryTools | None = None + # 状态 self._initialized = False self._last_maintenance = datetime.now() - self._maintenance_task: Optional[asyncio.Task] = None - self._maintenance_interval_hours = getattr(self.config, 'consolidation_interval_hours', 1.0) - self._maintenance_schedule_id: Optional[str] = None # 调度任务ID - + self._maintenance_task: asyncio.Task | None = None + self._maintenance_interval_hours = getattr(self.config, "consolidation_interval_hours", 1.0) + self._maintenance_schedule_id: str | None = None # 调度任务ID + logger.info(f"记忆管理器已创建 (data_dir={self.data_dir}, enable={getattr(self.config, 'enable', False)})") async def initialize(self) -> None: """ 初始化所有组件 - + 按照依赖顺序初始化: 1. 存储层(向量存储、图存储、持久化) 2. 工具层(嵌入生成器、提取器) @@ -96,22 +95,22 @@ class MemoryManager: try: logger.info("开始初始化记忆管理器...") - + # 1. 初始化存储层 self.data_dir.mkdir(parents=True, exist_ok=True) - + # 获取存储配置 - storage_config = getattr(self.config, 'storage', None) - vector_collection_name = getattr(storage_config, 'vector_collection_name', 'memory_graph') if storage_config else 'memory_graph' - + storage_config = getattr(self.config, "storage", None) + vector_collection_name = getattr(storage_config, "vector_collection_name", "memory_graph") if storage_config else "memory_graph" + self.vector_store = VectorStore( collection_name=vector_collection_name, data_dir=self.data_dir, ) await self.vector_store.initialize() - + self.persistence = PersistenceManager(data_dir=self.data_dir) - + # 尝试加载现有图数据 self.graph_store = await self.persistence.load_graph_store() if not self.graph_store: @@ -123,20 +122,20 @@ class MemoryManager: f"加载图数据: {stats['total_memories']} 条记忆, " f"{stats['total_nodes']} 个节点, {stats['total_edges']} 条边" ) - + # 2. 初始化工具层 self.embedding_generator = EmbeddingGenerator() # EmbeddingGenerator 使用延迟初始化,在第一次调用时自动初始化 - + self.extractor = MemoryExtractor() - + # 3. 
初始化管理层 self.builder = MemoryBuilder( vector_store=self.vector_store, graph_store=self.graph_store, embedding_generator=self.embedding_generator, ) - + # 检查配置值 expand_depth = self.config.search_max_expand_depth expand_semantic_threshold = self.config.search_expand_semantic_threshold @@ -150,13 +149,13 @@ class MemoryManager: max_expand_depth=expand_depth, # 从配置读取图扩展深度 expand_semantic_threshold=expand_semantic_threshold, # 从配置读取图扩展语义阈值 ) - + self._initialized = True logger.info("✅ 记忆管理器初始化完成") - + # 启动后台维护调度任务 await self.start_maintenance_scheduler() - + except Exception as e: logger.error(f"记忆管理器初始化失败: {e}", exc_info=True) raise @@ -164,7 +163,7 @@ class MemoryManager: async def shutdown(self) -> None: """ 关闭记忆管理器 - + 执行清理操作: - 停止维护调度任务 - 保存所有数据 @@ -176,23 +175,23 @@ class MemoryManager: try: logger.info("正在关闭记忆管理器...") - + # 1. 停止调度任务 await self.stop_maintenance_scheduler() - + # 2. 执行最后一次维护(保存数据) if self.graph_store and self.persistence: logger.info("执行最终数据保存...") await self.persistence.save_graph_store(self.graph_store) - + # 3. 关闭存储组件 if self.vector_store: # VectorStore 使用 chromadb,无需显式关闭 pass - + self._initialized = False logger.info("✅ 记忆管理器已关闭") - + except Exception as e: logger.error(f"关闭记忆管理器失败: {e}", exc_info=True) @@ -203,14 +202,14 @@ class MemoryManager: subject: str, memory_type: str, topic: str, - object: Optional[str] = None, - attributes: Optional[Dict[str, str]] = None, + object: str | None = None, + attributes: dict[str, str] | None = None, importance: float = 0.5, **kwargs, - ) -> Optional[Memory]: + ) -> Memory | None: """ 创建新记忆 - + Args: subject: 主体(谁) memory_type: 记忆类型(事件/观点/事实/关系) @@ -219,7 +218,7 @@ class MemoryManager: attributes: 属性字典(时间、地点、原因等) importance: 重要性 (0.0-1.0) **kwargs: 其他参数 - + Returns: 创建的记忆对象,失败返回 None """ @@ -236,7 +235,7 @@ class MemoryManager: importance=importance, **kwargs, ) - + if result["success"]: memory_id = result["memory_id"] memory = self.graph_store.get_memory_by_id(memory_id) @@ -245,18 +244,18 @@ class MemoryManager: else: logger.error(f"记忆创建失败: {result.get('error', 'Unknown error')}") return None - + except Exception as e: logger.error(f"创建记忆时发生异常: {e}", exc_info=True) return None - async def get_memory(self, memory_id: str) -> Optional[Memory]: + async def get_memory(self, memory_id: str) -> Memory | None: """ 根据 ID 获取记忆 - + Args: memory_id: 记忆 ID - + Returns: 记忆对象,不存在返回 None """ @@ -272,11 +271,11 @@ class MemoryManager: ) -> bool: """ 更新记忆 - + Args: memory_id: 记忆 ID **updates: 要更新的字段 - + Returns: 是否更新成功 """ @@ -288,21 +287,21 @@ class MemoryManager: if not memory: logger.warning(f"记忆不存在: {memory_id}") return False - + # 更新元数据 if "importance" in updates: memory.importance = updates["importance"] - + if "metadata" in updates: memory.metadata.update(updates["metadata"]) - + memory.updated_at = datetime.now() - + # 保存更新 await self.persistence.save_graph_store(self.graph_store) logger.info(f"记忆更新成功: {memory_id}") return True - + except Exception as e: logger.error(f"更新记忆失败: {e}", exc_info=True) return False @@ -310,10 +309,10 @@ class MemoryManager: async def delete_memory(self, memory_id: str) -> bool: """ 删除记忆 - + Args: memory_id: 记忆 ID - + Returns: 是否删除成功 """ @@ -325,20 +324,20 @@ class MemoryManager: if not memory: logger.warning(f"记忆不存在: {memory_id}") return False - + # 从向量存储删除节点 for node in memory.nodes: if node.embedding is not None: await self.vector_store.delete_node(node.id) - + # 从图存储删除记忆 self.graph_store.remove_memory(memory_id) - + # 保存更新 await self.persistence.save_graph_store(self.graph_store) logger.info(f"记忆删除成功: 
{memory_id}") return True - + except Exception as e: logger.error(f"删除记忆失败: {e}", exc_info=True) return False @@ -348,33 +347,33 @@ class MemoryManager: async def generate_multi_queries( self, query: str, - context: Optional[Dict[str, Any]] = None, - ) -> List[Tuple[str, float]]: + context: dict[str, Any] | None = None, + ) -> list[tuple[str, float]]: """ 使用小模型生成多个查询语句(用于多路召回) - + 简化版多查询策略:直接让小模型生成3-5个不同角度的查询, 避免复杂的查询分解和组合逻辑。 - + Args: query: 原始查询 context: 上下文信息(聊天历史、发言人、参与者等) - + Returns: List of (query_string, weight) - 查询语句和权重 """ try: - from src.llm_models.utils_model import LLMRequest from src.config.config import model_config - + from src.llm_models.utils_model import LLMRequest + llm = LLMRequest( model_set=model_config.model_task_config.utils_small, request_type="memory.multi_query_generator" ) - + # 构建上下文信息 chat_history = context.get("chat_history", "") if context else "" - + prompt = f"""你是记忆检索助手。为提高检索准确率,请为查询生成3-5个不同角度的搜索语句。 **核心原则(重要!):** @@ -405,33 +404,34 @@ class MemoryManager: ```""" response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300) - + # 解析JSON - import json, re - response = re.sub(r'```json\s*', '', response) - response = re.sub(r'```\s*$', '', response).strip() - + import json + import re + response = re.sub(r"```json\s*", "", response) + response = re.sub(r"```\s*$", "", response).strip() + try: data = json.loads(response) queries = data.get("queries", []) - + result = [] for item in queries: text = item.get("text", "").strip() weight = float(item.get("weight", 0.5)) if text: result.append((text, weight)) - + if result: logger.info(f"生成 {len(result)} 个查询: {[q for q, _ in result]}") return result - + except json.JSONDecodeError as e: logger.warning(f"解析失败: {e}, response={response[:100]}") - + except Exception as e: logger.warning(f"多查询生成失败: {e}") - + # 回退到原始查询 return [(query, 1.0)] @@ -439,23 +439,23 @@ class MemoryManager: self, query: str, top_k: int = 10, - memory_types: Optional[List[str]] = None, - time_range: Optional[Tuple[datetime, datetime]] = None, + memory_types: list[str] | None = None, + time_range: tuple[datetime, datetime] | None = None, min_importance: float = 0.0, include_forgotten: bool = False, use_multi_query: bool = True, expand_depth: int | None = None, - context: Optional[Dict[str, Any]] = None, - ) -> List[Memory]: + context: dict[str, Any] | None = None, + ) -> list[Memory]: """ 搜索记忆 - + 使用多策略检索优化,解决复杂查询问题。 例如:"杰瑞喵如何评价新的记忆系统" 会被分解为多个子查询, 确保同时匹配"杰瑞喵"和"新的记忆系统"两个关键概念。 - + 同时支持图扩展:从初始检索结果出发,沿图结构查找语义相关的邻居记忆。 - + Args: query: 搜索查询 top_k: 返回结果数 @@ -466,7 +466,7 @@ class MemoryManager: use_multi_query: 是否使用多查询策略(推荐,默认True) expand_depth: 图扩展深度(0=禁用, 1=推荐, 2-3=深度探索) context: 查询上下文(用于优化) - + Returns: 记忆列表 """ @@ -482,19 +482,19 @@ class MemoryManager: "expand_depth": expand_depth or global_config.memory.search_max_expand_depth, # 传递图扩展深度 "context": context, } - + if memory_types: params["memory_types"] = memory_types - + # 执行搜索 result = await self.tools.search_memories(**params) - + if not result["success"]: logger.error(f"搜索失败: {result.get('error', 'Unknown error')}") return [] - + memories = result.get("results", []) - + # 后处理过滤 filtered_memories = [] for mem_dict in memories: @@ -502,33 +502,33 @@ class MemoryManager: memory_id = mem_dict.get("memory_id", "") if not memory_id: continue - + memory = self.graph_store.get_memory_by_id(memory_id) if not memory: continue - + # 重要性过滤 if min_importance is not None and memory.importance < min_importance: continue - + # 遗忘状态过滤 if not include_forgotten and 
memory.metadata.get("forgotten", False): continue - + # 时间范围过滤 if time_range: mem_time = memory.created_at if not (time_range[0] <= mem_time <= time_range[1]): continue - + filtered_memories.append(memory) - + strategy = result.get("strategy", "unknown") logger.info( f"搜索完成: 找到 {len(filtered_memories)} 条记忆 (策略={strategy})" ) return filtered_memories[:top_k] - + except Exception as e: logger.error(f"搜索记忆失败: {e}", exc_info=True) return [] @@ -542,13 +542,13 @@ class MemoryManager: ) -> bool: """ 关联两条记忆 - + Args: source_description: 源记忆描述 target_description: 目标记忆描述 relation_type: 关系类型(导致/引用/相似/相反) importance: 关系重要性 - + Returns: 是否关联成功 """ @@ -562,7 +562,7 @@ class MemoryManager: relation_type=relation_type, importance=importance, ) - + if result["success"]: logger.info( f"记忆关联成功: {result['source_memory_id']} -> " @@ -572,7 +572,7 @@ class MemoryManager: else: logger.error(f"记忆关联失败: {result.get('error', 'Unknown error')}") return False - + except Exception as e: logger.error(f"关联记忆失败: {e}", exc_info=True) return False @@ -582,13 +582,13 @@ class MemoryManager: async def activate_memory(self, memory_id: str, strength: float = 1.0) -> bool: """ 激活记忆 - + 更新记忆的激活度,并传播到相关记忆 - + Args: memory_id: 记忆 ID strength: 激活强度 (0.0-1.0) - + Returns: 是否激活成功 """ @@ -600,80 +600,80 @@ class MemoryManager: if not memory: logger.warning(f"记忆不存在: {memory_id}") return False - + # 更新激活信息 now = datetime.now() activation_info = memory.metadata.get("activation", {}) - + # 更新激活度(考虑时间衰减) last_access = activation_info.get("last_access") if last_access: # 计算时间衰减 last_access_dt = datetime.fromisoformat(last_access) hours_passed = (now - last_access_dt).total_seconds() / 3600 - decay_rate = getattr(self.config, 'activation_decay_rate', 0.95) + decay_rate = getattr(self.config, "activation_decay_rate", 0.95) decay_factor = decay_rate ** (hours_passed / 24) current_activation = activation_info.get("level", 0.0) * decay_factor else: current_activation = 0.0 - + # 新的激活度 = 当前激活度 + 激活强度 new_activation = min(1.0, current_activation + strength) - + activation_info.update({ "level": new_activation, "last_access": now.isoformat(), "access_count": activation_info.get("access_count", 0) + 1, }) - + memory.metadata["activation"] = activation_info memory.last_accessed = now - + # 激活传播:激活相关记忆 if strength > 0.1: # 只有足够强的激活才传播 - propagation_depth = getattr(self.config, 'activation_propagation_depth', 2) + propagation_depth = getattr(self.config, "activation_propagation_depth", 2) related_memories = self._get_related_memories( memory_id, max_depth=propagation_depth ) - propagation_strength_factor = getattr(self.config, 'activation_propagation_strength', 0.5) + propagation_strength_factor = getattr(self.config, "activation_propagation_strength", 0.5) propagation_strength = strength * propagation_strength_factor - - max_related = getattr(self.config, 'max_related_memories', 5) + + max_related = getattr(self.config, "max_related_memories", 5) for related_id in related_memories[:max_related]: await self.activate_memory(related_id, propagation_strength) - + # 保存更新 await self.persistence.save_graph_store(self.graph_store) logger.debug(f"记忆已激活: {memory_id} (level={new_activation:.3f})") return True - + except Exception as e: logger.error(f"激活记忆失败: {e}", exc_info=True) return False - def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> List[str]: + def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> list[str]: """ 获取相关记忆 ID 列表(旧版本,保留用于激活传播) - + Args: memory_id: 记忆 ID max_depth: 最大遍历深度 - + Returns: 相关记忆 ID 列表 """ 
memory = self.graph_store.get_memory_by_id(memory_id) if not memory: return [] - + related_ids = set() - + # 遍历记忆的节点 for node in memory.nodes: # 获取节点的邻居 neighbors = list(self.graph_store.graph.neighbors(node.id)) - + for neighbor_id in neighbors: # 获取邻居节点所属的记忆 neighbor_node = self.graph_store.graph.nodes.get(neighbor_id) @@ -682,116 +682,115 @@ class MemoryManager: for mem_id in neighbor_memory_ids: if mem_id != memory_id: related_ids.add(mem_id) - + return list(related_ids) async def expand_memories_with_semantic_filter( self, - initial_memory_ids: List[str], + initial_memory_ids: list[str], query_embedding: "np.ndarray", max_depth: int = 2, semantic_threshold: float = 0.5, max_expanded: int = 20 - ) -> List[Tuple[str, float]]: + ) -> list[tuple[str, float]]: """ 从初始记忆集合出发,沿图结构扩展,并用语义相似度过滤 - + 这个方法解决了纯向量搜索可能遗漏的"语义相关且图结构相关"的记忆。 - + Args: initial_memory_ids: 初始记忆ID集合(由向量搜索得到) query_embedding: 查询向量 max_depth: 最大扩展深度(1-3推荐) semantic_threshold: 语义相似度阈值(0.5推荐) max_expanded: 最多扩展多少个记忆 - + Returns: List[(memory_id, relevance_score)] 按相关度排序 """ if not initial_memory_ids or query_embedding is None: return [] - + try: - import numpy as np - + # 记录已访问的记忆,避免重复 visited_memories = set(initial_memory_ids) # 记录扩展的记忆及其分数 - expanded_memories: Dict[str, float] = {} - + expanded_memories: dict[str, float] = {} + # BFS扩展 current_level = initial_memory_ids - + for depth in range(max_depth): next_level = [] - + for memory_id in current_level: memory = self.graph_store.get_memory_by_id(memory_id) if not memory: continue - + # 遍历该记忆的所有节点 for node in memory.nodes: if not node.has_embedding(): continue - + # 获取邻居节点 try: neighbors = list(self.graph_store.graph.neighbors(node.id)) - except: + except Exception: continue - + for neighbor_id in neighbors: # 获取邻居节点信息 neighbor_node_data = self.graph_store.graph.nodes.get(neighbor_id) if not neighbor_node_data: continue - + # 获取邻居节点的向量(从向量存储) neighbor_vector_data = await self.vector_store.get_node_by_id(neighbor_id) if not neighbor_vector_data or neighbor_vector_data.get("embedding") is None: continue - + neighbor_embedding = neighbor_vector_data["embedding"] - + # 计算与查询的语义相似度 semantic_sim = self._cosine_similarity( query_embedding, neighbor_embedding ) - + # 获取边的权重 try: edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id) edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5 - except: + except Exception: edge_importance = 0.5 - + # 综合评分:语义相似度(70%) + 图结构权重(20%) + 深度衰减(10%) depth_decay = 1.0 / (depth + 1) # 深度越深,权重越低 relevance_score = ( - semantic_sim * 0.7 + - edge_importance * 0.2 + + semantic_sim * 0.7 + + edge_importance * 0.2 + depth_decay * 0.1 ) - + # 只保留超过阈值的节点 if relevance_score < semantic_threshold: continue - + # 提取邻居节点所属的记忆 neighbor_memory_ids = neighbor_node_data.get("memory_ids", []) if isinstance(neighbor_memory_ids, str): import json try: neighbor_memory_ids = json.loads(neighbor_memory_ids) - except: + except Exception: neighbor_memory_ids = [neighbor_memory_ids] - + for neighbor_mem_id in neighbor_memory_ids: if neighbor_mem_id in visited_memories: continue - + # 记录这个扩展记忆 if neighbor_mem_id not in expanded_memories: expanded_memories[neighbor_mem_id] = relevance_score @@ -803,54 +802,54 @@ class MemoryManager: expanded_memories[neighbor_mem_id], relevance_score ) - + # 如果没有新节点或已达到数量限制,提前终止 if not next_level or len(expanded_memories) >= max_expanded: break - + current_level = next_level[:max_expanded] # 限制每层的扩展数量 - + # 排序并返回 sorted_results = sorted( expanded_memories.items(), key=lambda x: x[1], reverse=True 
)[:max_expanded] - + logger.info( f"图扩展完成: 初始{len(initial_memory_ids)}个 → " f"扩展{len(sorted_results)}个新记忆 " f"(深度={max_depth}, 阈值={semantic_threshold:.2f})" ) - + return sorted_results - + except Exception as e: logger.error(f"语义图扩展失败: {e}", exc_info=True) return [] - + def _cosine_similarity(self, vec1: "np.ndarray", vec2: "np.ndarray") -> float: """计算余弦相似度""" try: import numpy as np - + # 确保是numpy数组 if not isinstance(vec1, np.ndarray): vec1 = np.array(vec1) if not isinstance(vec2, np.ndarray): vec2 = np.array(vec2) - + # 归一化 vec1_norm = np.linalg.norm(vec1) vec2_norm = np.linalg.norm(vec2) - + if vec1_norm == 0 or vec2_norm == 0: return 0.0 - + # 余弦相似度 similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm) return float(similarity) - + except Exception as e: logger.warning(f"计算余弦相似度失败: {e}") return 0.0 @@ -858,10 +857,10 @@ class MemoryManager: async def forget_memory(self, memory_id: str) -> bool: """ 遗忘记忆(标记为已遗忘,不删除) - + Args: memory_id: 记忆 ID - + Returns: 是否遗忘成功 """ @@ -873,15 +872,15 @@ class MemoryManager: if not memory: logger.warning(f"记忆不存在: {memory_id}") return False - + memory.metadata["forgotten"] = True memory.metadata["forgotten_at"] = datetime.now().isoformat() - + # 保存更新 await self.persistence.save_graph_store(self.graph_store) logger.info(f"记忆已遗忘: {memory_id}") return True - + except Exception as e: logger.error(f"遗忘记忆失败: {e}", exc_info=True) return False @@ -889,10 +888,10 @@ class MemoryManager: async def auto_forget_memories(self, threshold: float = 0.1) -> int: """ 自动遗忘低激活度的记忆 - + Args: threshold: 激活度阈值 - + Returns: 遗忘的记忆数量 """ @@ -902,47 +901,47 @@ class MemoryManager: try: forgotten_count = 0 all_memories = self.graph_store.get_all_memories() - + for memory in all_memories: # 跳过已遗忘的记忆 if memory.metadata.get("forgotten", False): continue - + # 跳过高重要性记忆 - min_importance = getattr(self.config, 'forgetting_min_importance', 7.0) + min_importance = getattr(self.config, "forgetting_min_importance", 7.0) if memory.importance >= min_importance: continue - + # 计算当前激活度 activation_info = memory.metadata.get("activation", {}) last_access = activation_info.get("last_access") - + if last_access: last_access_dt = datetime.fromisoformat(last_access) days_passed = (datetime.now() - last_access_dt).days - + # 长时间未访问的记忆,应用时间衰减 decay_factor = 0.9 ** days_passed current_activation = activation_info.get("level", 0.0) * decay_factor - + # 低于阈值则遗忘 if current_activation < threshold: await self.forget_memory(memory.id) forgotten_count += 1 - + logger.info(f"自动遗忘完成: 遗忘了 {forgotten_count} 条记忆") return forgotten_count - + except Exception as e: logger.error(f"自动遗忘失败: {e}", exc_info=True) return 0 # ==================== 统计与维护 ==================== - def get_statistics(self) -> Dict[str, Any]: + def get_statistics(self) -> dict[str, Any]: """ 获取记忆系统统计信息 - + Returns: 统计信息字典 """ @@ -950,29 +949,29 @@ class MemoryManager: return {} stats = self.graph_store.get_statistics() - + # 添加激活度统计 all_memories = self.graph_store.get_all_memories() activation_levels = [] forgotten_count = 0 - + for memory in all_memories: if memory.metadata.get("forgotten", False): forgotten_count += 1 else: activation_info = memory.metadata.get("activation", {}) activation_levels.append(activation_info.get("level", 0.0)) - + if activation_levels: stats["avg_activation"] = sum(activation_levels) / len(activation_levels) stats["max_activation"] = max(activation_levels) else: stats["avg_activation"] = 0.0 stats["max_activation"] = 0.0 - + stats["forgotten_memories"] = forgotten_count stats["active_memories"] = 
stats["total_memories"] - forgotten_count - + return stats async def consolidate_memories( @@ -980,7 +979,7 @@ class MemoryManager: similarity_threshold: float = 0.85, time_window_hours: float = 24.0, max_batch_size: int = 50, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ 整理记忆:直接合并去重相似记忆(不创建新边) @@ -1065,7 +1064,7 @@ class MemoryManager: result["checked_count"] = len(recent_memories) # 按记忆类型分组,减少跨类型比较 - memories_by_type: Dict[str, List[Memory]] = {} + memories_by_type: dict[str, list[Memory]] = {} for mem in recent_memories: mem_type = mem.metadata.get("memory_type", "") if mem_type not in memories_by_type: @@ -1073,7 +1072,7 @@ class MemoryManager: memories_by_type[mem_type].append(mem) # 记录需要删除的记忆,延迟批量删除 - to_delete: List[Tuple[Memory, str]] = [] # (memory, reason) + to_delete: list[tuple[Memory, str]] = [] # (memory, reason) deleted_ids = set() # 对每个类型的记忆进行相似度检测 @@ -1084,7 +1083,7 @@ class MemoryManager: logger.debug(f"🔍 检查类型 '{mem_type}' 的 {len(memories)} 条记忆") # 预提取所有主题节点的嵌入向量 - embeddings_map: Dict[str, "np.ndarray"] = {} + embeddings_map: dict[str, "np.ndarray"] = {} valid_memories = [] for mem in memories: @@ -1094,7 +1093,6 @@ class MemoryManager: valid_memories.append(mem) # 批量计算相似度矩阵(比逐个计算更高效) - import numpy as np for i in range(len(valid_memories)): # 更频繁的协作式多任务让出 @@ -1134,7 +1132,7 @@ class MemoryManager: keep_mem.importance = min(1.0, keep_mem.importance + 0.05) # 累加访问次数 - if hasattr(keep_mem, 'access_count') and hasattr(remove_mem, 'access_count'): + if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"): keep_mem.access_count += remove_mem.access_count # 标记为待删除(不立即删除) @@ -1164,7 +1162,7 @@ class MemoryManager: # 批量保存(一次性写入,减少I/O) await self.persistence.save_graph_store(self.graph_store) - logger.info(f"💾 批量保存完成") + logger.info("💾 批量保存完成") logger.info(f"✅ 记忆整理完成: {result}") @@ -1207,20 +1205,20 @@ class MemoryManager: async def auto_link_memories( self, - time_window_hours: float = None, - max_candidates: int = None, - min_confidence: float = None, - ) -> Dict[str, Any]: + time_window_hours: float | None = None, + max_candidates: int | None = None, + min_confidence: float | None = None, + ) -> dict[str, Any]: """ 自动关联记忆 - + 使用LLM分析记忆之间的关系,自动建立关联边。 - + Args: time_window_hours: 分析时间窗口(小时) max_candidates: 每个记忆最多关联的候选数 min_confidence: 最低置信度阈值 - + Returns: 关联结果统计 """ @@ -1229,39 +1227,39 @@ class MemoryManager: # 使用配置值或参数覆盖 time_window_hours = time_window_hours if time_window_hours is not None else 24 - max_candidates = max_candidates if max_candidates is not None else getattr(self.config, 'auto_link_max_candidates', 10) - min_confidence = min_confidence if min_confidence is not None else getattr(self.config, 'auto_link_min_confidence', 0.7) + max_candidates = max_candidates if max_candidates is not None else getattr(self.config, "auto_link_max_candidates", 10) + min_confidence = min_confidence if min_confidence is not None else getattr(self.config, "auto_link_min_confidence", 0.7) try: logger.info(f"开始自动关联记忆 (时间窗口={time_window_hours}h)...") - + result = { "checked_count": 0, "linked_count": 0, "relation_stats": {}, # 关系类型统计 {类型: 数量} "relations": {}, # 详细关系 {source_id: [关系列表]} } - + # 1. 
获取时间窗口内的记忆 time_threshold = datetime.now() - timedelta(hours=time_window_hours) all_memories = self.graph_store.get_all_memories() - + recent_memories = [ mem for mem in all_memories if mem.created_at >= time_threshold and not mem.metadata.get("forgotten", False) ] - + if len(recent_memories) < 2: logger.info("记忆数量不足,跳过自动关联") return result - + logger.info(f"找到 {len(recent_memories)} 条待关联记忆") - + # 2. 为每个记忆寻找关联候选 for memory in recent_memories: result["checked_count"] += 1 - + # 跳过已经有很多连接的记忆 existing_edges = len([ e for e in memory.edges @@ -1269,24 +1267,24 @@ class MemoryManager: ]) if existing_edges >= 10: continue - + # 3. 使用向量搜索找候选记忆 candidates = await self._find_link_candidates( memory, exclude_ids={memory.id}, max_results=max_candidates ) - + if not candidates: continue - + # 4. 使用LLM分析关系 relations = await self._analyze_memory_relations( source_memory=memory, candidate_memories=candidates, min_confidence=min_confidence ) - + # 5. 建立关联 for relation in relations: try: @@ -1305,7 +1303,7 @@ class MemoryManager: "created_at": datetime.now().isoformat(), } ) - + # 添加到图 self.graph_store.graph.add_edge( edge.source_id, @@ -1316,16 +1314,16 @@ class MemoryManager: importance=edge.importance, metadata=edge.metadata, ) - + # 同时添加到记忆的边列表 memory.edges.append(edge) - + result["linked_count"] += 1 - + # 更新统计 result["relation_stats"][relation["relation_type"]] = \ result["relation_stats"].get(relation["relation_type"], 0) + 1 - + # 记录详细关系 if memory.id not in result["relations"]: result["relations"][memory.id] = [] @@ -1335,25 +1333,25 @@ class MemoryManager: "confidence": relation["confidence"], "reasoning": relation["reasoning"], }) - + logger.info( f"建立关联: {memory.id[:8]} --[{relation['relation_type']}]--> " f"{relation['target_memory'].id[:8]} " f"(置信度={relation['confidence']:.2f})" ) - + except Exception as e: logger.warning(f"建立关联失败: {e}") continue - + # 保存更新后的图数据 if result["linked_count"] > 0: await self.persistence.save_graph_store(self.graph_store) logger.info(f"已保存 {result['linked_count']} 条自动关联边") - + logger.info(f"自动关联完成: {result}") return result - + except Exception as e: logger.error(f"自动关联失败: {e}", exc_info=True) return {"error": str(e), "checked_count": 0, "linked_count": 0} @@ -1361,12 +1359,12 @@ class MemoryManager: async def _find_link_candidates( self, memory: Memory, - exclude_ids: Set[str], + exclude_ids: set[str], max_results: int = 5, - ) -> List[Memory]: + ) -> list[Memory]: """ 为记忆寻找关联候选 - + 使用向量相似度 + 时间接近度找到潜在相关记忆 """ try: @@ -1375,31 +1373,31 @@ class MemoryManager: (n for n in memory.nodes if n.node_type == NodeType.TOPIC), None ) - + if not topic_node or not topic_node.content: return [] - + # 使用主题内容搜索相似记忆 candidates = await self.search_memories( query=topic_node.content, top_k=max_results * 2, include_forgotten=False, ) - + # 过滤:排除自己和已关联的 existing_targets = { e.target_id for e in memory.edges if e.edge_type == EdgeType.RELATION } - + filtered = [ c for c in candidates if c.id not in exclude_ids and c.id not in existing_targets ] - + return filtered[:max_results] - + except Exception as e: logger.warning(f"查找候选失败: {e}") return [] @@ -1407,17 +1405,17 @@ class MemoryManager: async def _analyze_memory_relations( self, source_memory: Memory, - candidate_memories: List[Memory], + candidate_memories: list[Memory], min_confidence: float = 0.7, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ 使用LLM分析记忆之间的关系 - + Args: source_memory: 源记忆 candidate_memories: 候选记忆列表 min_confidence: 最低置信度 - + Returns: 关系列表,每项包含: - target_memory: 目标记忆 @@ -1426,22 +1424,22 @@ class 
MemoryManager: - reasoning: 推理过程 """ try: - from src.llm_models.utils_model import LLMRequest from src.config.config import model_config - + from src.llm_models.utils_model import LLMRequest + # 构建LLM请求 llm = LLMRequest( model_set=model_config.model_task_config.utils_small, request_type="memory.relation_analysis" ) - + # 格式化记忆信息 source_desc = self._format_memory_for_llm(source_memory) candidates_desc = "\n\n".join([ f"记忆{i+1}:\n{self._format_memory_for_llm(mem)}" for i, mem in enumerate(candidate_memories) ]) - + # 构建提示词 prompt = f"""你是一个记忆关系分析专家。请分析源记忆与候选记忆之间是否存在有意义的关系。 @@ -1490,36 +1488,36 @@ class MemoryManager: temperature=0.3, max_tokens=1000, ) - + # 解析响应 import json import re - + # 提取JSON - json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL) + json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL) if json_match: json_str = json_match.group(1) else: json_str = response.strip() - + try: analysis_results = json.loads(json_str) except json.JSONDecodeError: logger.warning(f"LLM返回格式错误,尝试修复: {response[:200]}") # 尝试简单修复 - json_str = re.sub(r'[\r\n\t]', '', json_str) + json_str = re.sub(r"[\r\n\t]", "", json_str) analysis_results = json.loads(json_str) - + # 转换为结果格式 relations = [] for result in analysis_results: if not result.get("has_relation", False): continue - + confidence = result.get("confidence", 0.0) if confidence < min_confidence: continue - + candidate_id = result.get("candidate_id", 0) - 1 if 0 <= candidate_id < len(candidate_memories): relations.append({ @@ -1528,10 +1526,10 @@ class MemoryManager: "confidence": confidence, "reasoning": result.get("reasoning", ""), }) - + logger.debug(f"LLM分析完成: 发现 {len(relations)} 个关系") return relations - + except Exception as e: logger.error(f"LLM关系分析失败: {e}", exc_info=True) return [] @@ -1552,29 +1550,29 @@ class MemoryManager: (n for n in memory.nodes if n.node_type == NodeType.OBJECT), None ) - + parts = [] parts.append(f"类型: {memory.memory_type.value}") - + if subject_node: parts.append(f"主体: {subject_node.content}") - + if topic_node: parts.append(f"主题: {topic_node.content}") - + if object_node: parts.append(f"对象: {object_node.content}") - + parts.append(f"重要性: {memory.importance:.2f}") parts.append(f"时间: {memory.created_at.strftime('%Y-%m-%d %H:%M')}") - + return " | ".join(parts) - + except Exception as e: logger.warning(f"格式化记忆失败: {e}") return f"记忆ID: {memory.id}" - async def maintenance(self) -> Dict[str, Any]: + async def maintenance(self) -> dict[str, Any]: """ 执行维护任务(优化版本) @@ -1604,12 +1602,12 @@ class MemoryManager: start_time = datetime.now() # 1. 记忆整理(异步后台执行,不阻塞主流程) - if getattr(self.config, 'consolidation_enabled', False): + if getattr(self.config, "consolidation_enabled", False): logger.info("🚀 启动异步记忆整理任务...") consolidate_result = await self.consolidate_memories( - similarity_threshold=getattr(self.config, 'consolidation_deduplication_threshold', 0.93), - time_window_hours=getattr(self.config, 'consolidation_time_window_hours', 2.0), # 统一时间窗口 - max_batch_size=getattr(self.config, 'consolidation_max_batch_size', 30) + similarity_threshold=getattr(self.config, "consolidation_deduplication_threshold", 0.93), + time_window_hours=getattr(self.config, "consolidation_time_window_hours", 2.0), # 统一时间窗口 + max_batch_size=getattr(self.config, "consolidation_max_batch_size", 30) ) if consolidate_result.get("task_started"): @@ -1620,16 +1618,16 @@ class MemoryManager: logger.warning("❌ 记忆整理任务启动失败") # 2. 
自动关联记忆(使用统一的时间窗口) - if getattr(self.config, 'consolidation_linking_enabled', True): + if getattr(self.config, "consolidation_linking_enabled", True): logger.info("🔗 执行轻量级自动关联...") link_result = await self._lightweight_auto_link_memories() result["linked"] = link_result.get("linked_count", 0) # 3. 自动遗忘(快速执行) - if getattr(self.config, 'forgetting_enabled', True): + if getattr(self.config, "forgetting_enabled", True): logger.info("🗑️ 执行自动遗忘...") forgotten_count = await self.auto_forget_memories( - threshold=getattr(self.config, 'forgetting_activation_threshold', 0.1) + threshold=getattr(self.config, "forgetting_activation_threshold", 0.1) ) result["forgotten"] = forgotten_count @@ -1654,10 +1652,10 @@ class MemoryManager: async def _lightweight_auto_link_memories( self, - time_window_hours: float = None, # 从配置读取 - max_candidates: int = None, # 从配置读取 - max_memories: int = None, # 从配置读取 - ) -> Dict[str, Any]: + time_window_hours: float | None = None, # 从配置读取 + max_candidates: int | None = None, # 从配置读取 + max_memories: int | None = None, # 从配置读取 + ) -> dict[str, Any]: """ 智能轻量级自动关联记忆(保留LLM判断,优化性能) @@ -1676,11 +1674,11 @@ class MemoryManager: # 从配置读取参数,使用统一的时间窗口 if time_window_hours is None: - time_window_hours = getattr(self.config, 'consolidation_time_window_hours', 2.0) + time_window_hours = getattr(self.config, "consolidation_time_window_hours", 2.0) if max_candidates is None: - max_candidates = getattr(self.config, 'consolidation_linking_max_candidates', 10) + max_candidates = getattr(self.config, "consolidation_linking_max_candidates", 10) if max_memories is None: - max_memories = getattr(self.config, 'consolidation_linking_max_memories', 20) + max_memories = getattr(self.config, "consolidation_linking_max_memories", 20) # 获取用户配置时间窗口内的记忆 time_threshold = datetime.now() - timedelta(hours=time_window_hours) @@ -1690,7 +1688,7 @@ class MemoryManager: mem for mem in all_memories if mem.created_at >= time_threshold and not mem.metadata.get("forgotten", False) - and mem.importance >= getattr(self.config, 'consolidation_linking_min_importance', 0.5) # 从配置读取重要性阈值 + and mem.importance >= getattr(self.config, "consolidation_linking_min_importance", 0.5) # 从配置读取重要性阈值 ] if len(recent_memories) > max_memories: @@ -1704,7 +1702,6 @@ class MemoryManager: # 第一步:向量相似度预筛选,找到潜在关联对 candidate_pairs = [] - import numpy as np for i, memory in enumerate(recent_memories): # 获取主题节点 @@ -1733,7 +1730,7 @@ class MemoryManager: ) # 使用配置的预筛选阈值 - pre_filter_threshold = getattr(self.config, 'consolidation_linking_pre_filter_threshold', 0.7) + pre_filter_threshold = getattr(self.config, "consolidation_linking_pre_filter_threshold", 0.7) if similarity >= pre_filter_threshold: candidate_pairs.append((memory, other_memory, similarity)) @@ -1747,7 +1744,7 @@ class MemoryManager: return result # 第二步:批量LLM分析(使用配置的最大候选对数) - max_pairs_for_llm = getattr(self.config, 'consolidation_linking_max_pairs_for_llm', 5) + max_pairs_for_llm = getattr(self.config, "consolidation_linking_max_pairs_for_llm", 5) if len(candidate_pairs) <= max_pairs_for_llm: link_relations = await self._batch_analyze_memory_relations(candidate_pairs) result["llm_calls"] = 1 @@ -1810,8 +1807,8 @@ class MemoryManager: async def _batch_analyze_memory_relations( self, - candidate_pairs: List[Tuple[Memory, Memory, float]] - ) -> List[Dict[str, Any]]: + candidate_pairs: list[tuple[Memory, Memory, float]] + ) -> list[dict[str, Any]]: """ 批量分析记忆关系(优化LLM调用) @@ -1822,8 +1819,8 @@ class MemoryManager: 关系分析结果列表 """ try: - from src.llm_models.utils_model import LLMRequest from 
src.config.config import model_config + from src.llm_models.utils_model import LLMRequest llm = LLMRequest( model_set=model_config.model_task_config.utils_small, @@ -1843,7 +1840,7 @@ class MemoryManager: """ # 构建批量分析提示词(使用配置的置信度阈值) - min_confidence = getattr(self.config, 'consolidation_linking_min_confidence', 0.7) + min_confidence = getattr(self.config, "consolidation_linking_min_confidence", 0.7) prompt = f"""你是记忆关系分析专家。请批量分析以下候选记忆对之间的关系。 @@ -1885,8 +1882,8 @@ class MemoryManager: 请分析并输出JSON结果:""" # 调用LLM(使用配置的参数) - llm_temperature = getattr(self.config, 'consolidation_linking_llm_temperature', 0.2) - llm_max_tokens = getattr(self.config, 'consolidation_linking_llm_max_tokens', 1500) + llm_temperature = getattr(self.config, "consolidation_linking_llm_temperature", 0.2) + llm_max_tokens = getattr(self.config, "consolidation_linking_llm_max_tokens", 1500) response, _ = await llm.generate_response_async( prompt, @@ -1899,7 +1896,7 @@ class MemoryManager: import re # 提取JSON - json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL) + json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL) if json_match: json_str = json_match.group(1) else: @@ -1910,7 +1907,7 @@ class MemoryManager: except json.JSONDecodeError: logger.warning(f"LLM返回格式错误,尝试修复: {response[:200]}") # 尝试简单修复 - json_str = re.sub(r'[\r\n\t]', '', json_str) + json_str = re.sub(r"[\r\n\t]", "", json_str) analysis_results = json.loads(json_str) # 转换为结果格式 @@ -1944,25 +1941,25 @@ class MemoryManager: async def start_maintenance_scheduler(self) -> None: """ 启动记忆维护调度任务 - + 使用 unified_scheduler 定期执行维护任务: - 记忆整合(合并相似记忆) - 自动遗忘低激活度记忆 - 保存数据 - + 默认间隔:1小时 """ try: from src.schedule.unified_scheduler import TriggerType, unified_scheduler - + # 如果已有调度任务,先移除 if self._maintenance_schedule_id: await unified_scheduler.remove_schedule(self._maintenance_schedule_id) logger.info("移除旧的维护调度任务") - + # 创建新的调度任务 interval_seconds = self._maintenance_interval_hours * 3600 - + self._maintenance_schedule_id = await unified_scheduler.create_schedule( callback=self.maintenance, trigger_type=TriggerType.TIME, @@ -1973,13 +1970,13 @@ class MemoryManager: is_recurring=True, task_name="memory_maintenance", ) - + logger.info( f"✅ 记忆维护调度任务已启动 " f"(间隔={self._maintenance_interval_hours}小时, " f"schedule_id={self._maintenance_schedule_id[:8]}...)" ) - + except ImportError: logger.warning("无法导入 unified_scheduler,维护调度功能不可用") except Exception as e: @@ -1991,18 +1988,18 @@ class MemoryManager: """ if not self._maintenance_schedule_id: return - + try: from src.schedule.unified_scheduler import unified_scheduler - + success = await unified_scheduler.remove_schedule(self._maintenance_schedule_id) if success: logger.info(f"✅ 记忆维护调度任务已停止 (schedule_id={self._maintenance_schedule_id[:8]}...)") else: logger.warning(f"停止维护调度任务失败 (schedule_id={self._maintenance_schedule_id[:8]}...)") - + self._maintenance_schedule_id = None - + except ImportError: logger.warning("无法导入 unified_scheduler") except Exception as e: diff --git a/src/memory_graph/manager_singleton.py b/src/memory_graph/manager_singleton.py index b8422192e..dc735a06b 100644 --- a/src/memory_graph/manager_singleton.py +++ b/src/memory_graph/manager_singleton.py @@ -7,7 +7,6 @@ from __future__ import annotations from pathlib import Path -from typing import Optional from src.common.logger import get_logger from src.memory_graph.manager import MemoryManager @@ -15,56 +14,56 @@ from src.memory_graph.manager import MemoryManager logger = get_logger(__name__) # 全局 MemoryManager 实例 -_memory_manager: 
Optional[MemoryManager] = None +_memory_manager: MemoryManager | None = None _initialized: bool = False async def initialize_memory_manager( - data_dir: Optional[Path | str] = None, -) -> Optional[MemoryManager]: + data_dir: Path | str | None = None, +) -> MemoryManager | None: """ 初始化全局 MemoryManager - + 直接从 global_config.memory 读取配置 - + Args: data_dir: 数据目录(可选,默认从配置读取) - + Returns: MemoryManager 实例,如果禁用则返回 None """ global _memory_manager, _initialized - + if _initialized and _memory_manager: logger.info("MemoryManager 已经初始化,返回现有实例") return _memory_manager - + try: from src.config.config import global_config - + # 检查是否启用 - if not global_config.memory or not getattr(global_config.memory, 'enable', False): + if not global_config.memory or not getattr(global_config.memory, "enable", False): logger.info("记忆图系统已在配置中禁用") _initialized = False _memory_manager = None return None - + # 处理数据目录 if data_dir is None: - data_dir = getattr(global_config.memory, 'data_dir', 'data/memory_graph') + data_dir = getattr(global_config.memory, "data_dir", "data/memory_graph") if isinstance(data_dir, str): data_dir = Path(data_dir) - + logger.info(f"正在初始化全局 MemoryManager (data_dir={data_dir})...") - + _memory_manager = MemoryManager(data_dir=data_dir) await _memory_manager.initialize() - + _initialized = True logger.info("✅ 全局 MemoryManager 初始化成功") - + return _memory_manager - + except Exception as e: logger.error(f"初始化 MemoryManager 失败: {e}", exc_info=True) _initialized = False @@ -72,24 +71,24 @@ async def initialize_memory_manager( raise -def get_memory_manager() -> Optional[MemoryManager]: +def get_memory_manager() -> MemoryManager | None: """ 获取全局 MemoryManager 实例 - + Returns: MemoryManager 实例,如果未初始化则返回 None """ if not _initialized or _memory_manager is None: logger.warning("MemoryManager 尚未初始化,请先调用 initialize_memory_manager()") return None - + return _memory_manager async def shutdown_memory_manager(): """关闭全局 MemoryManager""" global _memory_manager, _initialized - + if _memory_manager: try: logger.info("正在关闭全局 MemoryManager...") diff --git a/src/memory_graph/models.py b/src/memory_graph/models.py index 497ab4892..c01716aff 100644 --- a/src/memory_graph/models.py +++ b/src/memory_graph/models.py @@ -10,7 +10,7 @@ import uuid from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any import numpy as np @@ -60,8 +60,8 @@ class MemoryNode: id: str # 节点唯一ID content: str # 节点内容(如:"我"、"吃饭"、"白米饭") node_type: NodeType # 节点类型 - embedding: Optional[np.ndarray] = None # 语义向量(仅主题/客体需要) - metadata: Dict[str, Any] = field(default_factory=dict) # 扩展元数据 + embedding: np.ndarray | None = None # 语义向量(仅主题/客体需要) + metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据 created_at: datetime = field(default_factory=datetime.now) def __post_init__(self): @@ -69,7 +69,7 @@ class MemoryNode: if not self.id: self.id = str(uuid.uuid4()) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为字典(用于序列化)""" return { "id": self.id, @@ -81,7 +81,7 @@ class MemoryNode: } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> MemoryNode: + def from_dict(cls, data: dict[str, Any]) -> MemoryNode: """从字典创建节点""" embedding = None if data.get("embedding") is not None: @@ -114,7 +114,7 @@ class MemoryEdge: relation: str # 关系名称(如:"是"、"做"、"时间"、"因为") edge_type: EdgeType # 边类型 importance: float = 0.5 # 重要性 [0-1] - metadata: Dict[str, Any] = field(default_factory=dict) # 扩展元数据 + metadata: dict[str, Any] = 
field(default_factory=dict) # 扩展元数据 created_at: datetime = field(default_factory=datetime.now) def __post_init__(self): @@ -124,7 +124,7 @@ class MemoryEdge: # 确保重要性在有效范围内 self.importance = max(0.0, min(1.0, self.importance)) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为字典(用于序列化)""" return { "id": self.id, @@ -138,7 +138,7 @@ class MemoryEdge: } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> MemoryEdge: + def from_dict(cls, data: dict[str, Any]) -> MemoryEdge: """从字典创建边""" return cls( id=data["id"], @@ -162,8 +162,8 @@ class Memory: id: str # 记忆唯一ID subject_id: str # 主体节点ID memory_type: MemoryType # 记忆类型 - nodes: List[MemoryNode] # 该记忆包含的所有节点 - edges: List[MemoryEdge] # 该记忆包含的所有边 + nodes: list[MemoryNode] # 该记忆包含的所有节点 + edges: list[MemoryEdge] # 该记忆包含的所有边 importance: float = 0.5 # 整体重要性 [0-1] activation: float = 0.0 # 激活度 [0-1],用于记忆整合和遗忘 status: MemoryStatus = MemoryStatus.STAGED # 记忆状态 @@ -171,7 +171,7 @@ class Memory: last_accessed: datetime = field(default_factory=datetime.now) # 最后访问时间 access_count: int = 0 # 访问次数 decay_factor: float = 1.0 # 衰减因子(随时间变化) - metadata: Dict[str, Any] = field(default_factory=dict) # 扩展元数据 + metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据 def __post_init__(self): """后初始化处理""" @@ -181,7 +181,7 @@ class Memory: self.importance = max(0.0, min(1.0, self.importance)) self.activation = max(0.0, min(1.0, self.activation)) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为字典(用于序列化)""" return { "id": self.id, @@ -200,7 +200,7 @@ class Memory: } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> Memory: + def from_dict(cls, data: dict[str, Any]) -> Memory: """从字典创建记忆""" return cls( id=data["id"], @@ -223,14 +223,14 @@ class Memory: self.last_accessed = datetime.now() self.access_count += 1 - def get_node_by_id(self, node_id: str) -> Optional[MemoryNode]: + def get_node_by_id(self, node_id: str) -> MemoryNode | None: """根据ID获取节点""" for node in self.nodes: if node.id == node_id: return node return None - def get_subject_node(self) -> Optional[MemoryNode]: + def get_subject_node(self) -> MemoryNode | None: """获取主体节点""" return self.get_node_by_id(self.subject_id) @@ -274,10 +274,10 @@ class StagedMemory: memory: Memory # 原始记忆对象 status: MemoryStatus = MemoryStatus.STAGED # 状态 created_at: datetime = field(default_factory=datetime.now) - consolidated_at: Optional[datetime] = None # 整理时间 - merge_history: List[str] = field(default_factory=list) # 被合并的节点ID列表 + consolidated_at: datetime | None = None # 整理时间 + merge_history: list[str] = field(default_factory=list) # 被合并的节点ID列表 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为字典""" return { "memory": self.memory.to_dict(), @@ -288,7 +288,7 @@ class StagedMemory: } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> StagedMemory: + def from_dict(cls, data: dict[str, Any]) -> StagedMemory: """从字典创建临时记忆""" return cls( memory=Memory.from_dict(data["memory"]), diff --git a/src/memory_graph/plugin_tools/memory_plugin_tools.py b/src/memory_graph/plugin_tools/memory_plugin_tools.py index 9fab29055..f921d6851 100644 --- a/src/memory_graph/plugin_tools/memory_plugin_tools.py +++ b/src/memory_graph/plugin_tools/memory_plugin_tools.py @@ -52,16 +52,16 @@ class CreateMemoryTool(BaseTool): 示例:"我最近在学Python,想找数据分析的工作" → 调用1:{{subject:"[从历史提取真实名字]", memory_type:"事实", topic:"学习", object:"Python", attributes:{{时间:"最近", 状态:"进行中"}}, importance:0.7}} → 调用2:{{subject:"[从历史提取真实名字]", memory_type:"目标", 
topic:"求职", object:"数据分析岗位", attributes:{{状态:"计划中"}}, importance:0.8}}""" - + parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [ ("subject", ToolParamType.STRING, "记忆主体(重要!)。从对话历史中提取真实发送人名字。示例:如果看到'Prou(12345678): 我喜欢...',subject应填'Prou';如果看到'张三: 我在...',subject应填'张三'。❌禁止使用'用户'这种泛指,必须用具体名字!", True, None), ("memory_type", ToolParamType.STRING, "记忆类型。【事件】=有明确时间点的动作(昨天吃饭、明天开会)【事实】=稳定状态(职业是程序员、住在北京)【观点】=主观看法(喜欢/讨厌/认为)【关系】=人际关系(朋友、同事)", True, ["事件", "事实", "关系", "观点"]), ("topic", ToolParamType.STRING, "记忆的核心内容(做什么/是什么状态/什么关系)。必须明确、具体,包含关键动词或状态词", True, None), ("object", ToolParamType.STRING, "记忆涉及的对象或目标。如果topic已经很完整可以不填,如果有明确对象建议填写", False, None), - ("attributes", ToolParamType.STRING, "详细属性,JSON格式字符串。强烈建议包含:时间(具体到日期和小时分钟)、地点、状态、原因等上下文信息。例:{\"时间\":\"2025-11-06 12:00\",\"地点\":\"公司\",\"状态\":\"进行中\",\"原因\":\"项目需要\"}", False, None), + ("attributes", ToolParamType.STRING, '详细属性,JSON格式字符串。强烈建议包含:时间(具体到日期和小时分钟)、地点、状态、原因等上下文信息。例:{"时间":"2025-11-06 12:00","地点":"公司","状态":"进行中","原因":"项目需要"}', False, None), ("importance", ToolParamType.FLOAT, "重要性评分 0.0-1.0。参考:日常琐事0.3-0.4,一般对话0.5-0.6,重要信息0.7-0.8,核心记忆0.9-1.0。不确定时用0.5", False, None), ] - + available_for_llm = True async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: @@ -69,20 +69,20 @@ class CreateMemoryTool(BaseTool): try: # 获取全局 memory_manager from src.memory_graph.manager_singleton import get_memory_manager - + manager = get_memory_manager() if not manager: return { "name": self.name, "content": "记忆系统未初始化" } - + # 提取参数 subject = function_args.get("subject", "") memory_type = function_args.get("memory_type", "") topic = function_args.get("topic", "") obj = function_args.get("object") - + # 处理 attributes(可能是字符串或字典) attributes_raw = function_args.get("attributes", {}) if isinstance(attributes_raw, str): @@ -93,9 +93,9 @@ class CreateMemoryTool(BaseTool): attributes = {} else: attributes = attributes_raw - + importance = function_args.get("importance", 0.5) - + # 创建记忆 memory = await manager.create_memory( subject=subject, @@ -105,7 +105,7 @@ class CreateMemoryTool(BaseTool): attributes=attributes, importance=importance, ) - + if memory: logger.info(f"[CreateMemoryTool] 成功创建记忆: {memory.id}") return { @@ -119,12 +119,12 @@ class CreateMemoryTool(BaseTool): "content": "创建记忆失败", "memory_id": None, } - + except Exception as e: logger.error(f"[CreateMemoryTool] 执行失败: {e}", exc_info=True) return { "name": self.name, - "content": f"创建记忆时出错: {str(e)}" + "content": f"创建记忆时出错: {e!s}" } @@ -133,33 +133,33 @@ class LinkMemoriesTool(BaseTool): name = "link_memories" description = "在两个记忆之间建立关联关系。用于连接相关的记忆,形成知识网络。" - + parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [ ("source_query", ToolParamType.STRING, "源记忆的搜索查询(如记忆的主题关键词)", True, None), ("target_query", ToolParamType.STRING, "目标记忆的搜索查询", True, None), ("relation", ToolParamType.STRING, "关系类型", True, ["导致", "引用", "相似", "相反", "部分"]), ("strength", ToolParamType.FLOAT, "关系强度(0.0-1.0),默认0.7", False, None), ] - + available_for_llm = False # 暂不对 LLM 开放 async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """执行关联记忆""" try: from src.memory_graph.manager_singleton import get_memory_manager - + manager = get_memory_manager() if not manager: return { "name": self.name, "content": "记忆系统未初始化" } - + source_query = function_args.get("source_query", "") target_query = function_args.get("target_query", "") relation = function_args.get("relation", "引用") strength = function_args.get("strength", 0.7) - + # 关联记忆 success = await 
manager.link_memories( source_description=source_query, @@ -167,7 +167,7 @@ class LinkMemoriesTool(BaseTool): relation_type=relation, importance=strength, ) - + if success: logger.info(f"[LinkMemoriesTool] 成功关联记忆: {source_query} -> {target_query}") return { @@ -179,12 +179,12 @@ class LinkMemoriesTool(BaseTool): "name": self.name, "content": "关联记忆失败,可能找不到匹配的记忆" } - + except Exception as e: logger.error(f"[LinkMemoriesTool] 执行失败: {e}", exc_info=True) return { "name": self.name, - "content": f"关联记忆时出错: {str(e)}" + "content": f"关联记忆时出错: {e!s}" } @@ -193,39 +193,39 @@ class SearchMemoriesTool(BaseTool): name = "search_memories" description = "搜索相关的记忆。根据查询词搜索记忆库,返回最相关的记忆。" - + parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [ ("query", ToolParamType.STRING, "搜索查询词,描述想要找什么样的记忆", True, None), ("top_k", ToolParamType.INTEGER, "返回的记忆数量,默认5", False, None), ("min_importance", ToolParamType.FLOAT, "最低重要性阈值(0.0-1.0),只返回重要性不低于此值的记忆", False, None), ] - + available_for_llm = False # 暂不对 LLM 开放,记忆检索在提示词构建时自动执行 async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """执行搜索记忆""" try: from src.memory_graph.manager_singleton import get_memory_manager - + manager = get_memory_manager() if not manager: return { "name": self.name, "content": "记忆系统未初始化" } - + query = function_args.get("query", "") top_k = function_args.get("top_k", 5) min_importance_raw = function_args.get("min_importance") min_importance = float(min_importance_raw) if min_importance_raw is not None else 0.0 - + # 搜索记忆 memories = await manager.search_memories( query=query, top_k=top_k, min_importance=min_importance, ) - + if memories: # 格式化结果 result_lines = [f"找到 {len(memories)} 条相关记忆:\n"] @@ -236,10 +236,10 @@ class SearchMemoriesTool(BaseTool): result_lines.append( f"{i}. [{mem_type}] {topic} (重要性: {importance:.2f})" ) - + result_text = "\n".join(result_lines) logger.info(f"[SearchMemoriesTool] 搜索成功: 查询='{query}', 结果数={len(memories)}") - + return { "name": self.name, "content": result_text @@ -249,10 +249,10 @@ class SearchMemoriesTool(BaseTool): "name": self.name, "content": f"未找到与 '{query}' 相关的记忆" } - + except Exception as e: logger.error(f"[SearchMemoriesTool] 执行失败: {e}", exc_info=True) return { "name": self.name, - "content": f"搜索记忆时出错: {str(e)}" + "content": f"搜索记忆时出错: {e!s}" } diff --git a/src/memory_graph/storage/__init__.py b/src/memory_graph/storage/__init__.py index a2be09edf..ee407917b 100644 --- a/src/memory_graph/storage/__init__.py +++ b/src/memory_graph/storage/__init__.py @@ -5,4 +5,4 @@ from src.memory_graph.storage.graph_store import GraphStore from src.memory_graph.storage.vector_store import VectorStore -__all__ = ["VectorStore", "GraphStore"] +__all__ = ["GraphStore", "VectorStore"] diff --git a/src/memory_graph/storage/graph_store.py b/src/memory_graph/storage/graph_store.py index 69b49d089..1e1a9a91f 100644 --- a/src/memory_graph/storage/graph_store.py +++ b/src/memory_graph/storage/graph_store.py @@ -4,12 +4,10 @@ from __future__ import annotations -from typing import Dict, List, Optional, Set, Tuple - import networkx as nx from src.common.logger import get_logger -from src.memory_graph.models import Memory, MemoryEdge, MemoryNode +from src.memory_graph.models import Memory, MemoryEdge logger = get_logger(__name__) @@ -17,7 +15,7 @@ logger = get_logger(__name__) class GraphStore: """ 图存储封装类 - + 负责: 1. 记忆图的构建和维护 2. 
节点和边的快速查询 @@ -31,17 +29,17 @@ class GraphStore: self.graph = nx.DiGraph() # 索引:记忆ID -> 记忆对象 - self.memory_index: Dict[str, Memory] = {} + self.memory_index: dict[str, Memory] = {} # 索引:节点ID -> 所属记忆ID集合 - self.node_to_memories: Dict[str, Set[str]] = {} + self.node_to_memories: dict[str, set[str]] = {} logger.info("初始化图存储") def add_memory(self, memory: Memory) -> None: """ 添加记忆到图 - + Args: memory: 要添加的记忆 """ @@ -84,34 +82,34 @@ class GraphStore: logger.error(f"添加记忆失败: {e}", exc_info=True) raise - def get_memory_by_id(self, memory_id: str) -> Optional[Memory]: + def get_memory_by_id(self, memory_id: str) -> Memory | None: """ 根据ID获取记忆 - + Args: memory_id: 记忆ID - + Returns: 记忆对象或 None """ return self.memory_index.get(memory_id) - def get_all_memories(self) -> List[Memory]: + def get_all_memories(self) -> list[Memory]: """ 获取所有记忆 - + Returns: 所有记忆的列表 """ return list(self.memory_index.values()) - def get_memories_by_node(self, node_id: str) -> List[Memory]: + def get_memories_by_node(self, node_id: str) -> list[Memory]: """ 获取包含指定节点的所有记忆 - + Args: node_id: 节点ID - + Returns: 记忆列表 """ @@ -121,14 +119,14 @@ class GraphStore: memory_ids = self.node_to_memories[node_id] return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index] - def get_edges_from_node(self, node_id: str, relation_types: Optional[List[str]] = None) -> List[Dict]: + def get_edges_from_node(self, node_id: str, relation_types: list[str] | None = None) -> list[dict]: """ 获取从指定节点出发的所有边 - + Args: node_id: 源节点ID relation_types: 关系类型过滤(可选) - + Returns: 边信息列表 """ @@ -155,16 +153,16 @@ class GraphStore: return edges def get_neighbors( - self, node_id: str, direction: str = "out", relation_types: Optional[List[str]] = None - ) -> List[Tuple[str, Dict]]: + self, node_id: str, direction: str = "out", relation_types: list[str] | None = None + ) -> list[tuple[str, dict]]: """ 获取节点的邻居节点 - + Args: node_id: 节点ID direction: 方向 ("out"=出边, "in"=入边, "both"=双向) relation_types: 关系类型过滤 - + Returns: List of (neighbor_id, edge_data) """ @@ -187,15 +185,15 @@ class GraphStore: return neighbors - def find_path(self, source_id: str, target_id: str, max_length: Optional[int] = None) -> Optional[List[str]]: + def find_path(self, source_id: str, target_id: str, max_length: int | None = None) -> list[str] | None: """ 查找两个节点之间的最短路径 - + Args: source_id: 源节点ID target_id: 目标节点ID max_length: 最大路径长度(可选) - + Returns: 路径节点ID列表,或 None(如果不存在路径) """ @@ -220,18 +218,18 @@ class GraphStore: def bfs_expand( self, - start_nodes: List[str], + start_nodes: list[str], depth: int = 1, - relation_types: Optional[List[str]] = None, - ) -> Set[str]: + relation_types: list[str] | None = None, + ) -> set[str]: """ 从起始节点进行广度优先搜索扩展 - + Args: start_nodes: 起始节点ID列表 depth: 扩展深度 relation_types: 关系类型过滤 - + Returns: 扩展到的所有节点ID集合 """ @@ -256,13 +254,13 @@ class GraphStore: return visited - def get_subgraph(self, node_ids: List[str]) -> nx.DiGraph: + def get_subgraph(self, node_ids: list[str]) -> nx.DiGraph: """ 获取包含指定节点的子图 - + Args: node_ids: 节点ID列表 - + Returns: NetworkX 子图 """ @@ -271,7 +269,7 @@ class GraphStore: def merge_nodes(self, source_id: str, target_id: str) -> None: """ 合并两个节点(将source的所有边转移到target,然后删除source) - + Args: source_id: 源节点ID(将被删除) target_id: 目标节点ID(保留) @@ -308,13 +306,13 @@ class GraphStore: logger.error(f"合并节点失败: {e}", exc_info=True) raise - def get_node_degree(self, node_id: str) -> Tuple[int, int]: + def get_node_degree(self, node_id: str) -> tuple[int, int]: """ 获取节点的度数 - + Args: node_id: 节点ID - + Returns: (in_degree, out_degree) """ @@ -323,7 +321,7 
@@ class GraphStore: return (self.graph.in_degree(node_id), self.graph.out_degree(node_id)) - def get_statistics(self) -> Dict[str, int]: + def get_statistics(self) -> dict[str, int]: """获取图的统计信息""" return { "total_nodes": self.graph.number_of_nodes(), @@ -332,10 +330,10 @@ class GraphStore: "connected_components": nx.number_weakly_connected_components(self.graph), } - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """ 将图转换为字典(用于持久化) - + Returns: 图的字典表示 """ @@ -356,13 +354,13 @@ class GraphStore: } @classmethod - def from_dict(cls, data: Dict) -> GraphStore: + def from_dict(cls, data: dict) -> GraphStore: """ 从字典加载图 - + Args: data: 图的字典表示 - + Returns: GraphStore 实例 """ @@ -406,7 +404,6 @@ class GraphStore: 规则:对于图中每条边(u, v, data),会尝试将该边注入到所有包含 u 或 v 的记忆中(避免遗漏跨记忆边)。 已存在的边(通过 edge.id 检查)将不会重复添加。 """ - from src.memory_graph.models import MemoryEdge # 构建快速查重索引:memory_id -> set(edge_id) existing_edges = {mid: {e.id for e in mem.edges} for mid, mem in self.memory_index.items()} @@ -465,10 +462,10 @@ class GraphStore: def remove_memory(self, memory_id: str) -> bool: """ 从图中删除指定记忆 - + Args: memory_id: 要删除的记忆ID - + Returns: 是否删除成功 """ @@ -477,9 +474,9 @@ class GraphStore: if memory_id not in self.memory_index: logger.warning(f"记忆不存在,无法删除: {memory_id}") return False - + memory = self.memory_index[memory_id] - + # 2. 从节点映射中移除此记忆 for node in memory.nodes: if node.id in self.node_to_memories: @@ -489,13 +486,13 @@ class GraphStore: if self.graph.has_node(node.id): self.graph.remove_node(node.id) del self.node_to_memories[node.id] - + # 3. 从记忆索引中移除 del self.memory_index[memory_id] - + logger.info(f"成功删除记忆: {memory_id}") return True - + except Exception as e: logger.error(f"删除记忆失败 {memory_id}: {e}", exc_info=True) return False diff --git a/src/memory_graph/storage/persistence.py b/src/memory_graph/storage/persistence.py index 3600ab5f6..bb6dc2946 100644 --- a/src/memory_graph/storage/persistence.py +++ b/src/memory_graph/storage/persistence.py @@ -8,14 +8,12 @@ import asyncio import json from datetime import datetime from pathlib import Path -from typing import Optional import orjson from src.common.logger import get_logger -from src.memory_graph.models import Memory, StagedMemory +from src.memory_graph.models import StagedMemory from src.memory_graph.storage.graph_store import GraphStore -from src.memory_graph.storage.vector_store import VectorStore logger = get_logger(__name__) @@ -23,7 +21,7 @@ logger = get_logger(__name__) class PersistenceManager: """ 持久化管理器 - + 负责: 1. 图数据的保存和加载 2. 
定期自动保存 @@ -39,7 +37,7 @@ class PersistenceManager: ): """ 初始化持久化管理器 - + Args: data_dir: 数据存储目录 graph_file_name: 图数据文件名 @@ -55,7 +53,7 @@ class PersistenceManager: self.backup_dir.mkdir(parents=True, exist_ok=True) self.auto_save_interval = auto_save_interval - self._auto_save_task: Optional[asyncio.Task] = None + self._auto_save_task: asyncio.Task | None = None self._running = False logger.info(f"初始化持久化管理器: data_dir={data_dir}") @@ -63,7 +61,7 @@ class PersistenceManager: async def save_graph_store(self, graph_store: GraphStore) -> None: """ 保存图存储到文件 - + Args: graph_store: 图存储对象 """ @@ -95,10 +93,10 @@ class PersistenceManager: logger.error(f"保存图数据失败: {e}", exc_info=True) raise - async def load_graph_store(self) -> Optional[GraphStore]: + async def load_graph_store(self) -> GraphStore | None: """ 从文件加载图存储 - + Returns: GraphStore 对象,如果文件不存在则返回 None """ @@ -129,7 +127,7 @@ class PersistenceManager: async def save_staged_memories(self, staged_memories: list[StagedMemory]) -> None: """ 保存临时记忆列表 - + Args: staged_memories: 临时记忆列表 """ @@ -158,7 +156,7 @@ class PersistenceManager: async def load_staged_memories(self) -> list[StagedMemory]: """ 加载临时记忆列表 - + Returns: 临时记忆列表 """ @@ -179,10 +177,10 @@ class PersistenceManager: logger.error(f"加载临时记忆失败: {e}", exc_info=True) return [] - async def create_backup(self) -> Optional[Path]: + async def create_backup(self) -> Path | None: """ 创建当前数据的备份 - + Returns: 备份文件路径,如果失败则返回 None """ @@ -208,7 +206,7 @@ class PersistenceManager: logger.error(f"创建备份失败: {e}", exc_info=True) return None - async def _load_from_backup(self) -> Optional[GraphStore]: + async def _load_from_backup(self) -> GraphStore | None: """从最新的备份加载数据""" try: # 查找最新的备份文件 @@ -236,7 +234,7 @@ class PersistenceManager: async def _cleanup_old_backups(self, keep: int = 10) -> None: """ 清理旧备份,只保留最近的几个 - + Args: keep: 保留的备份数量 """ @@ -254,11 +252,11 @@ class PersistenceManager: async def start_auto_save( self, graph_store: GraphStore, - staged_memories_getter: callable = None, + staged_memories_getter: callable | None = None, ) -> None: """ 启动自动保存任务 - + Args: graph_store: 图存储对象 staged_memories_getter: 获取临时记忆的回调函数 @@ -310,7 +308,7 @@ class PersistenceManager: async def export_to_json(self, output_file: Path, graph_store: GraphStore) -> None: """ 导出图数据到指定的 JSON 文件(用于数据迁移或分析) - + Args: output_file: 输出文件路径 graph_store: 图存储对象 @@ -334,13 +332,13 @@ class PersistenceManager: logger.error(f"导出图数据失败: {e}", exc_info=True) raise - async def import_from_json(self, input_file: Path) -> Optional[GraphStore]: + async def import_from_json(self, input_file: Path) -> GraphStore | None: """ 从 JSON 文件导入图数据 - + Args: input_file: 输入文件路径 - + Returns: GraphStore 对象 """ @@ -360,7 +358,7 @@ class PersistenceManager: def get_data_size(self) -> dict[str, int]: """ 获取数据文件的大小信息 - + Returns: 文件大小字典(字节) """ diff --git a/src/memory_graph/storage/vector_store.py b/src/memory_graph/storage/vector_store.py index ff5d433d4..883e32f6b 100644 --- a/src/memory_graph/storage/vector_store.py +++ b/src/memory_graph/storage/vector_store.py @@ -4,9 +4,8 @@ from __future__ import annotations -import uuid from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import numpy as np @@ -19,7 +18,7 @@ logger = get_logger(__name__) class VectorStore: """ 向量存储封装类 - + 负责: 1. 节点的语义向量存储和检索 2. 
基于相似度的向量搜索 @@ -29,12 +28,12 @@ class VectorStore: def __init__( self, collection_name: str = "memory_nodes", - data_dir: Optional[Path] = None, - embedding_function: Optional[Any] = None, + data_dir: Path | None = None, + embedding_function: Any | None = None, ): """ 初始化向量存储 - + Args: collection_name: ChromaDB 集合名称 data_dir: 数据存储目录 @@ -80,7 +79,7 @@ class VectorStore: async def add_node(self, node: MemoryNode) -> None: """ 添加节点到向量存储 - + Args: node: 要添加的节点 """ @@ -98,17 +97,17 @@ class VectorStore: "node_type": node.node_type.value, "created_at": node.created_at.isoformat(), } - + # 处理额外的元数据,将 list 转换为 JSON 字符串 for key, value in node.metadata.items(): if isinstance(value, (list, dict)): import orjson - metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8') + metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8") elif isinstance(value, (str, int, float, bool)) or value is None: metadata[key] = value else: metadata[key] = str(value) - + self.collection.add( ids=[node.id], embeddings=[node.embedding.tolist()], @@ -122,10 +121,10 @@ class VectorStore: logger.error(f"添加节点失败: {e}", exc_info=True) raise - async def add_nodes_batch(self, nodes: List[MemoryNode]) -> None: + async def add_nodes_batch(self, nodes: list[MemoryNode]) -> None: """ 批量添加节点 - + Args: nodes: 节点列表 """ @@ -151,13 +150,13 @@ class VectorStore: } for key, value in n.metadata.items(): if isinstance(value, (list, dict)): - metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8') + metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8") elif isinstance(value, (str, int, float, bool)) or value is None: metadata[key] = value # type: ignore else: metadata[key] = str(value) metadatas.append(metadata) - + self.collection.add( ids=[n.id for n in valid_nodes], embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore @@ -175,18 +174,18 @@ class VectorStore: self, query_embedding: np.ndarray, limit: int = 10, - node_types: Optional[List[NodeType]] = None, + node_types: list[NodeType] | None = None, min_similarity: float = 0.0, - ) -> List[Tuple[str, float, Dict[str, Any]]]: + ) -> list[tuple[str, float, dict[str, Any]]]: """ 搜索相似节点 - + Args: query_embedding: 查询向量 limit: 返回结果数量 node_types: 限制节点类型(可选) min_similarity: 最小相似度阈值 - + Returns: List of (node_id, similarity, metadata) """ @@ -214,7 +213,7 @@ class VectorStore: if ids is not None and len(ids) > 0 and len(ids[0]) > 0: distances = results.get("distances") metadatas = results.get("metadatas") - + for i, node_id in enumerate(ids[0]): # ChromaDB 返回的是距离,需要转换为相似度 # 余弦距离: distance = 1 - similarity @@ -223,15 +222,15 @@ class VectorStore: if similarity >= min_similarity: metadata = metadatas[0][i] if metadatas is not None and len(metadatas) > 0 else {} # type: ignore - + # 解析 JSON 字符串回列表/字典 for key, value in list(metadata.items()): - if isinstance(value, str) and (value.startswith('[') or value.startswith('{')): + if isinstance(value, str) and (value.startswith("[") or value.startswith("{")): try: metadata[key] = orjson.loads(value) - except: + except Exception: pass # 保持原值 - + similar_nodes.append((node_id, similarity, metadata)) logger.debug(f"相似节点搜索: 找到 {len(similar_nodes)} 个结果") @@ -243,19 +242,19 @@ class VectorStore: async def search_with_multiple_queries( self, - query_embeddings: List[np.ndarray], - query_weights: Optional[List[float]] = None, + query_embeddings: list[np.ndarray], + query_weights: list[float] | None = None, limit: int = 10, - node_types: 
Optional[List[NodeType]] = None, + node_types: list[NodeType] | None = None, min_similarity: float = 0.0, fusion_strategy: str = "weighted_max", - ) -> List[Tuple[str, float, Dict[str, Any]]]: + ) -> list[tuple[str, float, dict[str, Any]]]: """ 多查询融合搜索 - + 使用多个查询向量进行搜索,然后融合结果。 这能解决单一查询向量无法同时关注多个关键概念的问题。 - + Args: query_embeddings: 查询向量列表 query_weights: 每个查询的权重(可选,默认均等) @@ -266,7 +265,7 @@ class VectorStore: - "weighted_max": 加权最大值(推荐) - "weighted_sum": 加权求和 - "rrf": Reciprocal Rank Fusion - + Returns: 融合后的节点列表 [(node_id, fused_score, metadata), ...] """ @@ -279,7 +278,7 @@ class VectorStore: # 默认权重均等 if query_weights is None: query_weights = [1.0 / len(query_embeddings)] * len(query_embeddings) - + # 归一化权重 total_weight = sum(query_weights) if total_weight > 0: @@ -287,7 +286,7 @@ class VectorStore: try: # 1. 对每个查询执行搜索 - all_results: Dict[str, Dict[str, Any]] = {} # node_id -> {scores, metadata} + all_results: dict[str, dict[str, Any]] = {} # node_id -> {scores, metadata} for i, (query_emb, weight) in enumerate(zip(query_embeddings, query_weights)): # 搜索更多结果以提高融合质量 @@ -307,13 +306,13 @@ class VectorStore: "ranks": [], "metadata": metadata, } - + all_results[node_id]["scores"].append((similarity, weight)) all_results[node_id]["ranks"].append((rank, weight)) # 2. 融合分数 fused_results = [] - + for node_id, data in all_results.items(): scores = data["scores"] ranks = data["ranks"] @@ -356,13 +355,13 @@ class VectorStore: logger.error(f"多查询融合搜索失败: {e}", exc_info=True) raise - async def get_node_by_id(self, node_id: str) -> Optional[Dict[str, Any]]: + async def get_node_by_id(self, node_id: str) -> dict[str, Any] | None: """ 根据ID获取节点元数据 - + Args: node_id: 节点ID - + Returns: 节点元数据或 None """ @@ -378,7 +377,7 @@ class VectorStore: if ids is not None and len(ids) > 0: metadatas = result.get("metadatas") embeddings = result.get("embeddings") - + return { "id": ids[0], "metadata": metadatas[0] if metadatas is not None and len(metadatas) > 0 else {}, @@ -394,7 +393,7 @@ class VectorStore: async def delete_node(self, node_id: str) -> None: """ 删除节点 - + Args: node_id: 节点ID """ @@ -412,7 +411,7 @@ class VectorStore: async def update_node_embedding(self, node_id: str, embedding: np.ndarray) -> None: """ 更新节点的 embedding - + Args: node_id: 节点ID embedding: 新的向量 diff --git a/src/memory_graph/tools/memory_tools.py b/src/memory_graph/tools/memory_tools.py index 9f0e431c9..b05e89c5b 100644 --- a/src/memory_graph/tools/memory_tools.py +++ b/src/memory_graph/tools/memory_tools.py @@ -4,12 +4,12 @@ LLM 工具接口:定义记忆系统的工具 schema 和执行逻辑 from __future__ import annotations -from typing import Any, Dict, List, Optional, Tuple +from typing import Any from src.common.logger import get_logger from src.memory_graph.core.builder import MemoryBuilder from src.memory_graph.core.extractor import MemoryExtractor -from src.memory_graph.models import Memory, MemoryStatus +from src.memory_graph.models import Memory from src.memory_graph.storage.graph_store import GraphStore from src.memory_graph.storage.persistence import PersistenceManager from src.memory_graph.storage.vector_store import VectorStore @@ -21,7 +21,7 @@ logger = get_logger(__name__) class MemoryTools: """ 记忆系统工具集 - + 提供给 LLM 使用的工具接口: 1. create_memory: 创建新记忆 2. 
link_memories: 关联两个记忆 @@ -33,7 +33,7 @@ class MemoryTools: vector_store: VectorStore, graph_store: GraphStore, persistence_manager: PersistenceManager, - embedding_generator: Optional[EmbeddingGenerator] = None, + embedding_generator: EmbeddingGenerator | None = None, max_expand_depth: int = 1, expand_semantic_threshold: float = 0.3, ): @@ -72,10 +72,10 @@ class MemoryTools: self._initialized = True @staticmethod - def get_create_memory_schema() -> Dict[str, Any]: + def get_create_memory_schema() -> dict[str, Any]: """ 获取 create_memory 工具的 JSON schema - + Returns: 工具 schema 定义 """ @@ -145,15 +145,15 @@ class MemoryTools: "description": "时间信息(强烈建议填写):\n- 具体日期:'2025-11-05'、'2025年11月'\n- 相对时间:'今天'、'昨天'、'上周'、'最近'、'3天前'\n- 时间段:'今天下午'、'上个月'、'这学期'", }, "地点": { - "type": "string", + "type": "string", "description": "地点信息(如涉及):\n- 具体地址、城市名、国家\n- 场所类型:'在家'、'公司'、'学校'、'咖啡店'" }, "原因": { - "type": "string", + "type": "string", "description": "为什么这样做/这样想(如明确提到)" }, "方式": { - "type": "string", + "type": "string", "description": "怎么做的/通过什么方式(如明确提到)" }, "结果": { @@ -183,10 +183,10 @@ class MemoryTools: } @staticmethod - def get_link_memories_schema() -> Dict[str, Any]: + def get_link_memories_schema() -> dict[str, Any]: """ 获取 link_memories 工具的 JSON schema - + Returns: 工具 schema 定义 """ @@ -239,10 +239,10 @@ class MemoryTools: } @staticmethod - def get_search_memories_schema() -> Dict[str, Any]: + def get_search_memories_schema() -> dict[str, Any]: """ 获取 search_memories 工具的 JSON schema - + Returns: 工具 schema 定义 """ @@ -307,13 +307,13 @@ class MemoryTools: }, } - async def create_memory(self, **params) -> Dict[str, Any]: + async def create_memory(self, **params) -> dict[str, Any]: """ 执行 create_memory 工具 - + Args: **params: 工具参数 - + Returns: 执行结果 """ @@ -353,13 +353,13 @@ class MemoryTools: "message": "记忆创建失败", } - async def link_memories(self, **params) -> Dict[str, Any]: + async def link_memories(self, **params) -> dict[str, Any]: """ 执行 link_memories 工具 - + Args: **params: 工具参数 - + Returns: 执行结果 """ @@ -433,15 +433,15 @@ class MemoryTools: "message": "记忆关联失败", } - async def search_memories(self, **params) -> Dict[str, Any]: + async def search_memories(self, **params) -> dict[str, Any]: """ 执行 search_memories 工具 - + 使用多策略检索优化: 1. 查询分解(识别主要实体和概念) 2. 多查询并行检索 3. 结果融合和重排 - + Args: **params: 工具参数 - query: 查询字符串 @@ -449,7 +449,7 @@ class MemoryTools: - expand_depth: 扩展深度(暂未使用) - use_multi_query: 是否使用多查询策略(默认True) - context: 查询上下文(可选) - + Returns: 搜索结果 """ @@ -477,7 +477,7 @@ class MemoryTools: # 2. 
提取初始记忆ID(来自向量搜索) initial_memory_ids = set() memory_scores = {} # 记录每个记忆的初始分数 - + for node_id, similarity, metadata in similar_nodes: if "memory_ids" in metadata: ids = metadata["memory_ids"] @@ -486,7 +486,7 @@ class MemoryTools: import orjson try: ids = orjson.loads(ids) - except: + except Exception: ids = [ids] if isinstance(ids, list): for mem_id in ids: @@ -499,12 +499,12 @@ class MemoryTools: expanded_memory_scores = {} if expand_depth > 0 and initial_memory_ids: logger.info(f"开始图扩展: 初始记忆{len(initial_memory_ids)}个, 深度={expand_depth}") - + # 获取查询的embedding用于语义过滤 if self.builder.embedding_generator: try: query_embedding = await self.builder.embedding_generator.generate(query) - + # 直接使用图扩展逻辑(避免循环依赖) expanded_results = await self._expand_with_semantic_filter( initial_memory_ids=list(initial_memory_ids), @@ -513,7 +513,7 @@ class MemoryTools: semantic_threshold=self.expand_semantic_threshold, # 使用配置的阈值 max_expanded=top_k * 2 ) - + # 旧代码(如果需要使用Manager): # from src.memory_graph.manager import MemoryManager # manager = MemoryManager.get_instance() @@ -524,19 +524,18 @@ class MemoryTools: # semantic_threshold=0.5, # max_expanded=top_k * 2 # ) - + # 合并扩展结果 - for mem_id, score in expanded_results: - expanded_memory_scores[mem_id] = score - + expanded_memory_scores.update(dict(expanded_results)) + logger.info(f"图扩展完成: 新增{len(expanded_memory_scores)}个相关记忆") - + except Exception as e: logger.warning(f"图扩展失败: {e}") # 4. 合并初始记忆和扩展记忆 all_memory_ids = set(initial_memory_ids) | set(expanded_memory_scores.keys()) - + # 计算最终分数:初始记忆保持原分数,扩展记忆使用扩展分数 final_scores = {} for mem_id in all_memory_ids: @@ -546,7 +545,7 @@ class MemoryTools: elif mem_id in expanded_memory_scores: # 扩展记忆:使用图扩展分数(稍微降权) final_scores[mem_id] = expanded_memory_scores[mem_id] * 0.8 - + # 按分数排序 sorted_memory_ids = sorted( final_scores.keys(), @@ -562,7 +561,7 @@ class MemoryTools: # 综合评分:相似度(60%) + 重要性(30%) + 时效性(10%) similarity_score = final_scores[memory_id] importance_score = memory.importance - + # 计算时效性分数(最近的记忆得分更高) from datetime import datetime, timezone now = datetime.now(timezone.utc) @@ -573,16 +572,16 @@ class MemoryTools: memory_time = memory.created_at age_days = (now - memory_time).total_seconds() / 86400 recency_score = 1.0 / (1.0 + age_days / 30) # 30天半衰期 - + # 综合分数 final_score = ( similarity_score * 0.6 + importance_score * 0.3 + recency_score * 0.1 ) - + memories_with_scores.append((memory, final_score)) - + # 按综合分数排序 memories_with_scores.sort(key=lambda x: x[1], reverse=True) memories = [mem for mem, _ in memories_with_scores[:top_k]] @@ -624,16 +623,16 @@ class MemoryTools: } async def _generate_multi_queries_simple( - self, query: str, context: Optional[Dict[str, Any]] = None - ) -> List[Tuple[str, float]]: + self, query: str, context: dict[str, Any] | None = None + ) -> list[tuple[str, float]]: """ 简化版多查询生成(直接在 Tools 层实现,避免循环依赖) - + 让小模型直接生成3-5个不同角度的查询语句。 """ try: - from src.llm_models.utils_model import LLMRequest from src.config.config import model_config + from src.llm_models.utils_model import LLMRequest llm = LLMRequest( model_set=model_config.model_task_config.utils_small, @@ -648,10 +647,10 @@ class MemoryTools: # 处理聊天历史,提取最近5条左右的对话 recent_chat = "" if chat_history: - lines = chat_history.strip().split('\n') + lines = chat_history.strip().split("\n") # 取最近5条消息 recent_lines = lines[-5:] if len(lines) > 5 else lines - recent_chat = '\n'.join(recent_lines) + recent_chat = "\n".join(recent_lines) prompt = f"""基于聊天上下文为查询生成3-5个不同角度的搜索语句(JSON格式)。 @@ -685,36 +684,38 @@ class MemoryTools: """ response, _ = await 
llm.generate_response_async(prompt, temperature=0.3, max_tokens=250) - - import orjson, re - response = re.sub(r'```json\s*', '', response) - response = re.sub(r'```\s*$', '', response).strip() - + + import re + + import orjson + response = re.sub(r"```json\s*", "", response) + response = re.sub(r"```\s*$", "", response).strip() + data = orjson.loads(response) queries = data.get("queries", []) - - result = [(item.get("text", "").strip(), float(item.get("weight", 0.5))) + + result = [(item.get("text", "").strip(), float(item.get("weight", 0.5))) for item in queries if item.get("text", "").strip()] - + if result: logger.info(f"生成查询: {[q for q, _ in result]}") return result - + except Exception as e: logger.warning(f"多查询生成失败: {e}") - + return [(query, 1.0)] async def _single_query_search( self, query: str, top_k: int - ) -> List[Tuple[str, float, Dict[str, Any]]]: + ) -> list[tuple[str, float, dict[str, Any]]]: """ 传统的单查询搜索 - + Args: query: 查询字符串 top_k: 返回结果数 - + Returns: 相似节点列表 [(node_id, similarity, metadata), ...] """ @@ -735,30 +736,30 @@ class MemoryTools: return similar_nodes async def _multi_query_search( - self, query: str, top_k: int, context: Optional[Dict[str, Any]] = None - ) -> List[Tuple[str, float, Dict[str, Any]]]: + self, query: str, top_k: int, context: dict[str, Any] | None = None + ) -> list[tuple[str, float, dict[str, Any]]]: """ 多查询策略搜索(简化版) - + 直接使用小模型生成多个查询,无需复杂的分解和组合。 - + 步骤: 1. 让小模型生成3-5个不同角度的查询 2. 为每个查询生成嵌入 3. 并行搜索并融合结果 - + Args: query: 查询字符串 top_k: 返回结果数 context: 查询上下文 - + Returns: 融合后的相似节点列表 """ try: # 1. 使用小模型生成多个查询 multi_queries = await self._generate_multi_queries_simple(query, context) - + logger.debug(f"生成 {len(multi_queries)} 个查询: {multi_queries}") # 2. 生成所有查询的嵌入 @@ -800,13 +801,13 @@ class MemoryTools: if node.embedding is not None: await self.vector_store.add_node(node) - async def _find_memory_by_description(self, description: str) -> Optional[Memory]: + async def _find_memory_by_description(self, description: str) -> Memory | None: """ 通过描述查找记忆 - + Args: description: 记忆描述 - + Returns: 找到的记忆,如果没有则返回 None """ @@ -827,13 +828,13 @@ class MemoryTools: return None # 获取最相似节点关联的记忆 - node_id, similarity, metadata = similar_nodes[0] - + _node_id, _similarity, metadata = similar_nodes[0] + if "memory_ids" not in metadata or not metadata["memory_ids"]: return None - + ids = metadata["memory_ids"] - + # 确保是列表 if isinstance(ids, str): import orjson @@ -842,11 +843,11 @@ class MemoryTools: except Exception as e: logger.warning(f"JSON 解析失败: {e}") ids = [ids] - + if isinstance(ids, list) and ids: memory_id = ids[0] return self.graph_store.get_memory_by_id(memory_id) - + return None def _summarize_memory(self, memory: Memory) -> str: @@ -862,103 +863,102 @@ class MemoryTools: async def _expand_with_semantic_filter( self, - initial_memory_ids: List[str], + initial_memory_ids: list[str], query_embedding, max_depth: int = 2, semantic_threshold: float = 0.5, max_expanded: int = 20 - ) -> List[Tuple[str, float]]: + ) -> list[tuple[str, float]]: """ 从初始记忆集合出发,沿图结构扩展,并用语义相似度过滤 - + Args: initial_memory_ids: 初始记忆ID集合 query_embedding: 查询向量 max_depth: 最大扩展深度 semantic_threshold: 语义相似度阈值 max_expanded: 最多扩展多少个记忆 - + Returns: List[(memory_id, relevance_score)] """ if not initial_memory_ids or query_embedding is None: return [] - + try: - import numpy as np - + visited_memories = set(initial_memory_ids) - expanded_memories: Dict[str, float] = {} - + expanded_memories: dict[str, float] = {} + current_level = initial_memory_ids - + for depth in range(max_depth): next_level = [] - + for 
memory_id in current_level: memory = self.graph_store.get_memory_by_id(memory_id) if not memory: continue - + for node in memory.nodes: if not node.has_embedding(): continue - + try: neighbors = list(self.graph_store.graph.neighbors(node.id)) - except: + except Exception: continue - + for neighbor_id in neighbors: neighbor_node_data = self.graph_store.graph.nodes.get(neighbor_id) if not neighbor_node_data: continue - + neighbor_vector_data = await self.vector_store.get_node_by_id(neighbor_id) if neighbor_vector_data is None: continue - + neighbor_embedding = neighbor_vector_data.get("embedding") if neighbor_embedding is None: continue - + # 计算语义相似度 semantic_sim = self._cosine_similarity( query_embedding, neighbor_embedding ) - + # 获取边权重 try: edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id) edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5 - except: + except Exception: edge_importance = 0.5 - + # 综合评分 depth_decay = 1.0 / (depth + 1) relevance_score = ( - semantic_sim * 0.7 + - edge_importance * 0.2 + + semantic_sim * 0.7 + + edge_importance * 0.2 + depth_decay * 0.1 ) - + if relevance_score < semantic_threshold: continue - + # 提取记忆ID neighbor_memory_ids = neighbor_node_data.get("memory_ids", []) if isinstance(neighbor_memory_ids, str): import orjson try: neighbor_memory_ids = orjson.loads(neighbor_memory_ids) - except: + except Exception: neighbor_memory_ids = [neighbor_memory_ids] - + for neighbor_mem_id in neighbor_memory_ids: if neighbor_mem_id in visited_memories: continue - + if neighbor_mem_id not in expanded_memories: expanded_memories[neighbor_mem_id] = relevance_score visited_memories.add(neighbor_mem_id) @@ -968,52 +968,52 @@ class MemoryTools: expanded_memories[neighbor_mem_id], relevance_score ) - + if not next_level or len(expanded_memories) >= max_expanded: break - + current_level = next_level[:max_expanded] - + sorted_results = sorted( expanded_memories.items(), key=lambda x: x[1], reverse=True )[:max_expanded] - + return sorted_results - + except Exception as e: logger.error(f"图扩展失败: {e}", exc_info=True) return [] - + def _cosine_similarity(self, vec1, vec2) -> float: """计算余弦相似度""" try: import numpy as np - + if not isinstance(vec1, np.ndarray): vec1 = np.array(vec1) if not isinstance(vec2, np.ndarray): vec2 = np.array(vec2) - + vec1_norm = np.linalg.norm(vec1) vec2_norm = np.linalg.norm(vec2) - + if vec1_norm == 0 or vec2_norm == 0: return 0.0 - + similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm) return float(similarity) - + except Exception as e: logger.warning(f"计算余弦相似度失败: {e}") return 0.0 @staticmethod - def get_all_tool_schemas() -> List[Dict[str, Any]]: + def get_all_tool_schemas() -> list[dict[str, Any]]: """ 获取所有工具的 schema - + Returns: 工具 schema 列表 """ diff --git a/src/memory_graph/utils/__init__.py b/src/memory_graph/utils/__init__.py index 0b23863d0..cfd4f4f98 100644 --- a/src/memory_graph/utils/__init__.py +++ b/src/memory_graph/utils/__init__.py @@ -5,4 +5,4 @@ from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator from src.memory_graph.utils.time_parser import TimeParser -__all__ = ["TimeParser", "EmbeddingGenerator", "get_embedding_generator"] +__all__ = ["EmbeddingGenerator", "TimeParser", "get_embedding_generator"] diff --git a/src/memory_graph/utils/embeddings.py b/src/memory_graph/utils/embeddings.py index 016b7bda3..ae80b5aa0 100644 --- a/src/memory_graph/utils/embeddings.py +++ b/src/memory_graph/utils/embeddings.py @@ -5,8 +5,6 @@ from __future__ import annotations 
import asyncio -from functools import lru_cache -from typing import List, Optional import numpy as np @@ -18,12 +16,12 @@ logger = get_logger(__name__) class EmbeddingGenerator: """ 嵌入向量生成器 - + 策略: 1. 优先使用配置的 embedding API(通过 LLMRequest) 2. 如果 API 不可用,回退到本地 sentence-transformers 3. 如果 sentence-transformers 未安装,使用随机向量(仅测试) - + 优点: - 降低本地运算负载 - 即使未安装 sentence-transformers 也可正常运行 @@ -37,19 +35,19 @@ class EmbeddingGenerator: ): """ 初始化嵌入生成器 - + Args: use_api: 是否优先使用 API(默认 True) fallback_model_name: 回退本地模型名称 """ self.use_api = use_api self.fallback_model_name = fallback_model_name - + # API 相关 self._llm_request = None self._api_available = False self._api_dimension = None - + # 本地模型相关 self._local_model = None self._local_model_loaded = False @@ -58,24 +56,24 @@ class EmbeddingGenerator: """初始化 embedding API""" if self._api_available: return - + try: from src.config.config import model_config from src.llm_models.utils_model import LLMRequest - + embedding_config = model_config.model_task_config.embedding self._llm_request = LLMRequest( model_set=embedding_config, request_type="memory_graph.embedding" ) - + # 获取嵌入维度 if hasattr(embedding_config, "embedding_dimension") and embedding_config.embedding_dimension: self._api_dimension = embedding_config.embedding_dimension - + self._api_available = True logger.info(f"✅ Embedding API 初始化成功 (维度: {self._api_dimension})") - + except Exception as e: logger.warning(f"⚠️ Embedding API 初始化失败: {e}") self._api_available = False @@ -103,15 +101,15 @@ class EmbeddingGenerator: async def generate(self, text: str) -> np.ndarray: """ 生成单个文本的嵌入向量 - + 策略: 1. 优先使用 API 2. API 失败则使用本地模型 3. 本地模型不可用则使用随机向量 - + Args: text: 输入文本 - + Returns: 嵌入向量 """ @@ -126,12 +124,12 @@ class EmbeddingGenerator: embedding = await self._generate_with_api(text) if embedding is not None: return embedding - + # 策略 2: 使用本地模型 embedding = await self._generate_with_local_model(text) if embedding is not None: return embedding - + # 策略 3: 随机向量(仅测试) logger.warning(f"⚠️ 所有嵌入策略失败,使用随机向量: {text[:30]}...") dim = self._get_dimension() @@ -142,47 +140,47 @@ class EmbeddingGenerator: dim = self._get_dimension() return np.random.rand(dim).astype(np.float32) - async def _generate_with_api(self, text: str) -> Optional[np.ndarray]: + async def _generate_with_api(self, text: str) -> np.ndarray | None: """使用 API 生成嵌入""" try: # 初始化 API if not self._api_available: await self._initialize_api() - + if not self._api_available or not self._llm_request: return None - + # 调用 API embedding_list, model_name = await self._llm_request.get_embedding(text) - + if embedding_list and len(embedding_list) > 0: embedding = np.array(embedding_list, dtype=np.float32) logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})") return embedding - + return None - + except Exception as e: logger.debug(f"API 嵌入生成失败: {e}") return None - async def _generate_with_local_model(self, text: str) -> Optional[np.ndarray]: + async def _generate_with_local_model(self, text: str) -> np.ndarray | None: """使用本地模型生成嵌入""" try: # 加载本地模型 if not self._local_model_loaded: self._load_local_model() - + if not self._local_model_loaded or not self._local_model: return None - + # 在线程池中运行 loop = asyncio.get_event_loop() embedding = await loop.run_in_executor(None, self._encode_single_local, text) - + logger.debug(f"💻 本地生成嵌入: {text[:30]}... 
-> {len(embedding)}维") return embedding - + except Exception as e: logger.debug(f"本地模型嵌入生成失败: {e}") return None @@ -199,24 +197,24 @@ class EmbeddingGenerator: # 优先使用 API 维度 if self._api_dimension: return self._api_dimension - + # 其次使用本地模型维度 if self._local_model_loaded and self._local_model: try: return self._local_model.get_sentence_embedding_dimension() - except: + except Exception: pass - + # 默认 384(sentence-transformers 常用维度) return 384 - async def generate_batch(self, texts: List[str]) -> List[np.ndarray]: + async def generate_batch(self, texts: list[str]) -> list[np.ndarray]: """ 批量生成嵌入向量 - + Args: texts: 文本列表 - + Returns: 嵌入向量列表 """ @@ -236,13 +234,13 @@ class EmbeddingGenerator: results = await self._generate_batch_with_api(valid_texts) if results: return results - + # 回退到逐个生成 results = [] for text in valid_texts: embedding = await self.generate(text) results.append(embedding) - + logger.info(f"✅ 批量生成嵌入: {len(texts)} 个文本") return results @@ -251,7 +249,7 @@ class EmbeddingGenerator: dim = self._get_dimension() return [np.random.rand(dim).astype(np.float32) for _ in texts] - async def _generate_batch_with_api(self, texts: List[str]) -> Optional[List[np.ndarray]]: + async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray] | None: """使用 API 批量生成""" try: # 对于大多数 API,批量调用就是多次单独调用 @@ -273,7 +271,7 @@ class EmbeddingGenerator: # 全局单例 -_global_generator: Optional[EmbeddingGenerator] = None +_global_generator: EmbeddingGenerator | None = None def get_embedding_generator( @@ -282,11 +280,11 @@ def get_embedding_generator( ) -> EmbeddingGenerator: """ 获取全局嵌入生成器单例 - + Args: use_api: 是否优先使用 API fallback_model_name: 回退本地模型名称 - + Returns: EmbeddingGenerator 实例 """ diff --git a/src/memory_graph/utils/memory_formatter.py b/src/memory_graph/utils/memory_formatter.py index b19266af5..fac7d45a2 100644 --- a/src/memory_graph/utils/memory_formatter.py +++ b/src/memory_graph/utils/memory_formatter.py @@ -5,10 +5,9 @@ """ import logging -from typing import Optional, List, Dict, Any from datetime import datetime -from src.memory_graph.models import Memory, MemoryNode, NodeType, EdgeType, MemoryType +from src.memory_graph.models import EdgeType, Memory, MemoryType, NodeType logger = logging.getLogger(__name__) @@ -16,18 +15,18 @@ logger = logging.getLogger(__name__) def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) -> str: """ 将记忆对象格式化为适合提示词的自然语言描述 - + 根据记忆的图结构,构建完整的主谓宾描述,包含: - 主语(subject node) - 谓语/动作(topic node) - 宾语/对象(object node,如果存在) - 属性信息(attributes,如时间、地点等) - 关系信息(记忆之间的关系) - + Args: memory: 记忆对象 include_metadata: 是否包含元数据(时间、重要性等) - + Returns: 格式化后的自然语言描述 """ @@ -37,24 +36,22 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) -> if not subject_node: logger.warning(f"记忆 {memory.id} 缺少主体节点") return "(记忆格式错误:缺少主体)" - + subject_text = subject_node.content - + # 2. 查找主题节点(谓语/动作) topic_node = None - memory_type_relation = None for edge in memory.edges: if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id: topic_node = memory.get_node_by_id(edge.target_id) - memory_type_relation = edge.relation break - + if not topic_node: logger.warning(f"记忆 {memory.id} 缺少主题节点") return f"{subject_text}(记忆格式错误:缺少主题)" - + topic_text = topic_node.content - + # 3. 
查找客体节点(宾语)和核心关系 object_node = None core_relation = None @@ -63,9 +60,9 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) -> object_node = memory.get_node_by_id(edge.target_id) core_relation = edge.relation if edge.relation else "" break - + # 4. 收集属性节点 - attributes: Dict[str, str] = {} + attributes: dict[str, str] = {} for edge in memory.edges: if edge.edge_type == EdgeType.ATTRIBUTE: # 查找属性节点和值节点 @@ -73,16 +70,16 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) -> if attr_node and attr_node.node_type == NodeType.ATTRIBUTE: # 查找这个属性的值 for value_edge in memory.edges: - if (value_edge.edge_type == EdgeType.ATTRIBUTE + if (value_edge.edge_type == EdgeType.ATTRIBUTE and value_edge.source_id == attr_node.id): value_node = memory.get_node_by_id(value_edge.target_id) if value_node and value_node.node_type == NodeType.VALUE: attributes[attr_node.content] = value_node.content break - + # 5. 构建自然语言描述 parts = [] - + # 主谓宾结构 if object_node is not None: # 有完整的主谓宾 @@ -93,7 +90,7 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) -> else: # 只有主谓 parts.append(f"{subject_text}{topic_text}") - + # 添加属性信息 if attributes: attr_parts = [] @@ -106,78 +103,78 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) -> for key, value in attributes.items(): if key not in ["时间", "地点"]: attr_parts.append(f"{key}:{value}") - + if attr_parts: parts.append(f"({' '.join(attr_parts)})") - + description = "".join(parts) - + # 6. 添加元数据(可选) if include_metadata: metadata_parts = [] - + # 记忆类型 if memory.memory_type: metadata_parts.append(f"类型:{memory.memory_type.value}") - + # 重要性 if memory.importance >= 0.8: metadata_parts.append("重要") elif memory.importance >= 0.6: metadata_parts.append("一般") - + # 时间(如果没有在属性中) if "时间" not in attributes: time_str = _format_relative_time(memory.created_at) if time_str: metadata_parts.append(time_str) - + if metadata_parts: description += f" [{', '.join(metadata_parts)}]" - + return description - + except Exception as e: logger.error(f"格式化记忆失败: {e}", exc_info=True) return f"(记忆格式化错误: {str(e)[:50]})" def format_memories_for_prompt( - memories: List[Memory], - max_count: Optional[int] = None, + memories: list[Memory], + max_count: int | None = None, include_metadata: bool = False, group_by_type: bool = False ) -> str: """ 批量格式化多条记忆为提示词文本 - + Args: memories: 记忆列表 max_count: 最大记忆数量(可选) include_metadata: 是否包含元数据 group_by_type: 是否按类型分组 - + Returns: 格式化后的文本,包含标题和列表 """ if not memories: return "" - + # 限制数量 if max_count: memories = memories[:max_count] - + # 按类型分组 if group_by_type: - type_groups: Dict[MemoryType, List[Memory]] = {} + type_groups: dict[MemoryType, list[Memory]] = {} for memory in memories: if memory.memory_type not in type_groups: type_groups[memory.memory_type] = [] type_groups[memory.memory_type].append(memory) - + # 构建分组文本 parts = ["### 🧠 相关记忆 (Relevant Memories)", ""] - + type_order = [MemoryType.FACT, MemoryType.EVENT, MemoryType.RELATION, MemoryType.OPINION] for mem_type in type_order: if mem_type in type_groups: @@ -186,33 +183,33 @@ def format_memories_for_prompt( desc = format_memory_for_prompt(memory, include_metadata) parts.append(f"- {desc}") parts.append("") - + return "\n".join(parts) - + else: # 不分组,直接列出 parts = ["### 🧠 相关记忆 (Relevant Memories)", ""] - + for memory in memories: # 获取类型标签 type_label = memory.memory_type.value if memory.memory_type else "未知" - + # 格式化记忆内容 desc = format_memory_for_prompt(memory, include_metadata) - + # 添加类型标签 parts.append(f"- 
**[{type_label}]** {desc}") - + return "\n".join(parts) def get_memory_type_label(memory_type: str) -> str: """ 获取记忆类型的中文标签 - + Args: memory_type: 记忆类型(可能是英文或中文) - + Returns: 中文标签 """ @@ -243,27 +240,27 @@ def get_memory_type_label(memory_type: str) -> str: "经历": "经历", "情境": "情境", } - + # 转换为小写进行匹配 memory_type_lower = memory_type.lower() if memory_type else "" - + return type_mapping.get(memory_type_lower, "未知") -def _format_relative_time(timestamp: datetime) -> Optional[str]: +def _format_relative_time(timestamp: datetime) -> str | None: """ 格式化相对时间(如"2天前"、"刚才") - + Args: timestamp: 时间戳 - + Returns: 相对时间描述,如果太久远则返回None """ try: now = datetime.now() delta = now - timestamp - + if delta.total_seconds() < 60: return "刚才" elif delta.total_seconds() < 3600: @@ -290,17 +287,17 @@ def _format_relative_time(timestamp: datetime) -> Optional[str]: def format_memory_summary(memory: Memory) -> str: """ 生成记忆的简短摘要(用于日志和调试) - + Args: memory: 记忆对象 - + Returns: 简短摘要 """ try: subject_node = memory.get_subject_node() subject_text = subject_node.content if subject_node else "?" - + topic_text = "?" for edge in memory.edges: if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id: @@ -308,7 +305,7 @@ def format_memory_summary(memory: Memory) -> str: if topic_node: topic_text = topic_node.content break - + return f"{subject_text} - {memory.memory_type.value if memory.memory_type else '?'}: {topic_text}" except Exception: return f"记忆 {memory.id[:8]}" @@ -316,8 +313,8 @@ def format_memory_summary(memory: Memory) -> str: # 导出主要函数 __all__ = [ - 'format_memory_for_prompt', - 'format_memories_for_prompt', - 'get_memory_type_label', - 'format_memory_summary', + "format_memories_for_prompt", + "format_memory_for_prompt", + "format_memory_summary", + "get_memory_type_label", ] diff --git a/src/memory_graph/utils/time_parser.py b/src/memory_graph/utils/time_parser.py index bde10c133..99498c689 100644 --- a/src/memory_graph/utils/time_parser.py +++ b/src/memory_graph/utils/time_parser.py @@ -14,7 +14,6 @@ from __future__ import annotations import re from datetime import datetime, timedelta -from typing import Optional, Tuple from src.common.logger import get_logger @@ -24,26 +23,26 @@ logger = get_logger(__name__) class TimeParser: """ 时间解析器 - + 负责将自然语言时间表达转换为标准化的绝对时间 """ - def __init__(self, reference_time: Optional[datetime] = None): + def __init__(self, reference_time: datetime | None = None): """ 初始化时间解析器 - + Args: reference_time: 参考时间(通常是当前时间) """ self.reference_time = reference_time or datetime.now() - def parse(self, time_str: str) -> Optional[datetime]: + def parse(self, time_str: str) -> datetime | None: """ 解析时间字符串 - + Args: time_str: 时间字符串 - + Returns: 标准化的datetime对象,如果解析失败则返回None """ @@ -81,7 +80,7 @@ class TimeParser: logger.warning(f"无法解析时间: '{time_str}',使用当前时间") return self.reference_time - def _parse_relative_day(self, time_str: str) -> Optional[datetime]: + def _parse_relative_day(self, time_str: str) -> datetime | None: """ 解析相对日期:今天、明天、昨天、前天、后天 """ @@ -108,7 +107,7 @@ class TimeParser: return None - def _parse_days_ago(self, time_str: str) -> Optional[datetime]: + def _parse_days_ago(self, time_str: str) -> datetime | None: """ 解析 X天前/X天后、X周前/X周后、X个月前/X个月后 """ @@ -172,7 +171,7 @@ class TimeParser: return None - def _parse_hours_ago(self, time_str: str) -> Optional[datetime]: + def _parse_hours_ago(self, time_str: str) -> datetime | None: """ 解析 X小时前/X小时后、X分钟前/X分钟后 """ @@ -204,7 +203,7 @@ class TimeParser: return None - def _parse_week_month_year(self, time_str: str) -> 
Optional[datetime]: + def _parse_week_month_year(self, time_str: str) -> datetime | None: """ 解析:上周、上个月、去年、本周、本月、今年 """ @@ -232,7 +231,7 @@ class TimeParser: return None - def _parse_specific_date(self, time_str: str) -> Optional[datetime]: + def _parse_specific_date(self, time_str: str) -> datetime | None: """ 解析具体日期: - 2025-11-05 @@ -266,7 +265,7 @@ class TimeParser: return None - def _parse_time_of_day(self, time_str: str) -> Optional[datetime]: + def _parse_time_of_day(self, time_str: str) -> datetime | None: """ 解析一天中的时间: - 早上、上午、中午、下午、晚上、深夜 @@ -290,7 +289,7 @@ class TimeParser: } # 先检查是否有具体时间点:早上8点、下午3点 - for period, default_hour in time_periods.items(): + for period in time_periods.keys(): pattern = rf"{period}(\d{{1,2}})点?" match = re.search(pattern, time_str) if match: @@ -314,13 +313,13 @@ class TimeParser: return None - def _parse_combined_time(self, time_str: str) -> Optional[datetime]: + def _parse_combined_time(self, time_str: str) -> datetime | None: """ 解析组合时间表达:今天下午、昨天晚上、明天早上 """ # 先解析日期部分 date_result = None - + # 相对日期关键词 relative_days = { "今天": 0, "今日": 0, @@ -330,16 +329,16 @@ class TimeParser: "后天": 2, "后日": 2, "大前天": -3, "大后天": 3, } - + for keyword, days in relative_days.items(): if keyword in time_str: date_result = self.reference_time + timedelta(days=days) date_result = date_result.replace(hour=0, minute=0, second=0, microsecond=0) break - + if not date_result: return None - + # 再解析时间段部分 time_periods = { "早上": 8, "早晨": 8, @@ -351,7 +350,7 @@ class TimeParser: "深夜": 23, "凌晨": 2, } - + for period, hour in time_periods.items(): if period in time_str: # 检查是否有具体时间点 @@ -363,17 +362,17 @@ class TimeParser: if period in ["下午", "晚上"] and hour < 12: hour += 12 return date_result.replace(hour=hour) - + # 如果没有时间段,返回日期(默认0点) return date_result def _chinese_num_to_int(self, num_str: str) -> int: """ 将中文数字转换为阿拉伯数字 - + Args: num_str: 中文数字字符串(如:"一"、"十"、"3") - + Returns: 整数 """ @@ -418,11 +417,11 @@ class TimeParser: def format_time(self, dt: datetime, format_type: str = "iso") -> str: """ 格式化时间 - + Args: dt: datetime对象 format_type: 格式类型 ("iso", "cn", "relative") - + Returns: 格式化的时间字符串 """ @@ -461,13 +460,13 @@ class TimeParser: return str(dt) - def parse_time_range(self, time_str: str) -> Tuple[Optional[datetime], Optional[datetime]]: + def parse_time_range(self, time_str: str) -> tuple[datetime | None, datetime | None]: """ 解析时间范围:最近一周、最近3天 - + Args: time_str: 时间范围字符串 - + Returns: (start_time, end_time) """
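
For reference, the code-quality pattern this commit applies across `embeddings.py`, `memory_formatter.py`, and `time_parser.py` (built-in generics instead of `typing.List`/`Dict`/`Optional`, PEP 604 `X | None` unions, and narrowing bare `except:` to `except Exception:`) can be summarized with the small, self-contained sketch below. It is illustrative only; `parse_all` is a hypothetical function and is not part of the patch.

```python
from __future__ import annotations

from datetime import datetime

# Old style (removed throughout this patch):
#   from typing import Dict, List, Optional
#   def parse_all(items: List[str]) -> Dict[str, Optional[datetime]]:
#       try: ...
#       except:  # bare except also swallows KeyboardInterrupt / SystemExit
#           ...

# New style (introduced by this patch): built-in generics, PEP 604 unions,
# and an explicit "except Exception" so control-flow exceptions still propagate.
def parse_all(items: list[str]) -> dict[str, datetime | None]:
    results: dict[str, datetime | None] = {}
    for item in items:
        try:
            results[item] = datetime.fromisoformat(item)
        except Exception:
            results[item] = None
    return results


if __name__ == "__main__":
    # Example: the first entry parses, the second falls back to None.
    print(parse_all(["2025-11-05", "not-a-date"]))
```
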