fix: repair code-quality issues - correct exception handling and import statements

Co-authored-by: Windpicker-owo <221029311+Windpicker-owo@users.noreply.github.com>
copilot-swe-agent[bot]
2025-11-07 04:39:35 +00:00
parent 3bdcfa3dd4
commit 5caf630623
20 changed files with 893 additions and 910 deletions

View File

@@ -6,10 +6,12 @@ from typing import ClassVar
 from src.common.logger import get_logger
 from src.plugin_system import BasePlugin, register_plugin
-from src.plugin_system.base.component_types import ComponentInfo, ToolInfo
 
 logger = get_logger("memory_graph_plugin")
 
+# 用于存储后台任务引用
+_background_tasks = set()
+
 @register_plugin
 class MemoryGraphPlugin(BasePlugin):
@@ -60,6 +62,7 @@ class MemoryGraphPlugin(BasePlugin):
         """插件卸载时的回调"""
         try:
             import asyncio
+
             from src.memory_graph.manager_singleton import shutdown_memory_manager
 
             logger.info(f"{self.log_prefix} 正在关闭记忆系统...")
@@ -68,7 +71,10 @@ class MemoryGraphPlugin(BasePlugin):
             loop = asyncio.get_event_loop()
             if loop.is_running():
                 # 如果循环正在运行,创建任务
-                asyncio.create_task(shutdown_memory_manager())
+                task = asyncio.create_task(shutdown_memory_manager())
+                # 存储引用以防止任务被垃圾回收
+                _background_tasks.add(task)
+                task.add_done_callback(_background_tasks.discard)
             else:
                 # 如果循环未运行,直接运行
                 loop.run_until_complete(shutdown_memory_manager())
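
Note on the pattern introduced here: the asyncio event loop holds only a weak reference to tasks, so a task created with asyncio.create_task() and never stored anywhere can be garbage-collected before it finishes. Keeping the task in a module-level set and discarding it on completion is the documented fix (ruff flags the original form as RUF006). A minimal, self-contained sketch of the idiom, with illustrative names rather than this repository's:

import asyncio

# Strong references keep in-flight tasks alive until they finish.
_tasks: set[asyncio.Task] = set()

async def _shutdown_work() -> None:
    await asyncio.sleep(0.1)  # stand-in for real cleanup work

def fire_and_forget() -> None:
    task = asyncio.create_task(_shutdown_work())
    _tasks.add(task)                        # hold a strong reference
    task.add_done_callback(_tasks.discard)  # release it once done

async def main() -> None:
    fire_and_forget()
    await asyncio.gather(*_tasks)  # in the sketch, wait so the demo completes

asyncio.run(main())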

View File

@@ -25,14 +25,13 @@ import asyncio
 import sys
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, List, Optional, Set, Tuple
 
 import numpy as np
 
 sys.path.insert(0, str(Path(__file__).parent.parent))
 
 from src.common.logger import get_logger
-from src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager, shutdown_memory_manager
+from src.memory_graph.manager_singleton import initialize_memory_manager, shutdown_memory_manager
 
 logger = get_logger(__name__)
@@ -65,7 +64,7 @@ class MemoryDeduplicator:
         self.stats["total_memories"] = len(self.manager.graph_store.get_all_memories())
         logger.info(f"✅ 记忆管理器初始化成功,共 {self.stats['total_memories']} 条记忆")
 
-    async def find_similar_pairs(self) -> List[Tuple[str, str, float]]:
+    async def find_similar_pairs(self) -> list[tuple[str, str, float]]:
         """
         查找所有相似的记忆对(通过向量相似度计算)
@@ -144,7 +143,7 @@ class MemoryDeduplicator:
             logger.error(f"计算余弦相似度失败: {e}")
             return 0.0
 
-    def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> Tuple[Optional[str], Optional[str]]:
+    def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> tuple[str | None, str | None]:
         """
         决定保留哪个记忆,删除哪个
@@ -197,7 +196,7 @@ class MemoryDeduplicator:
         keep_mem = self.manager.graph_store.get_memory_by_id(keep_id)
         remove_mem = self.manager.graph_store.get_memory_by_id(remove_id)
 
-        logger.info(f"")
+        logger.info("")
         logger.info(f"{'[预览]' if self.dry_run else '[执行]'} 去重相似记忆对 (相似度={similarity:.3f}):")
         logger.info(f" 保留: {keep_id}")
         logger.info(f" - 主题: {keep_mem.metadata.get('topic', 'N/A')}")
@@ -221,14 +220,14 @@ class MemoryDeduplicator:
             keep_mem.activation = min(1.0, keep_mem.activation + 0.05)
 
             # 累加访问次数
-            if hasattr(keep_mem, 'access_count') and hasattr(remove_mem, 'access_count'):
+            if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"):
                 keep_mem.access_count += remove_mem.access_count
 
         # 删除相似记忆
         await self.manager.delete_memory(remove_id)
         self.stats["duplicates_removed"] += 1
-        logger.info(f" ✅ 删除成功")
+        logger.info(" ✅ 删除成功")
 
         # 让出控制权
         await asyncio.sleep(0)
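
For reference, the cosine-similarity helper whose failure path is logged above (计算余弦相似度失败) is conceptually just a normalized dot product. A defensive sketch with the same return-0.0-on-failure contract the diff shows; this is an assumed shape, not the script's exact implementation:

import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity in [-1, 1]; degenerate or failing inputs yield 0.0."""
    try:
        denom = float(np.linalg.norm(a) * np.linalg.norm(b))
        if denom == 0.0:
            return 0.0  # zero vector: similarity is undefined, report 0.0
        return float(np.dot(a, b) / denom)
    except Exception:
        return 0.0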

View File

@@ -6,24 +6,24 @@
 from src.memory_graph.manager import MemoryManager
 from src.memory_graph.models import (
+    EdgeType,
     Memory,
     MemoryEdge,
     MemoryNode,
     MemoryStatus,
     MemoryType,
     NodeType,
-    EdgeType,
 )
 
 __all__ = [
-    "MemoryManager",
+    "EdgeType",
     "Memory",
-    "MemoryNode",
     "MemoryEdge",
+    "MemoryManager",
+    "MemoryNode",
+    "MemoryStatus",
     "MemoryType",
     "NodeType",
-    "EdgeType",
-    "MemoryStatus",
 ]
 
 __version__ = "0.1.0"

View File

@@ -6,4 +6,4 @@ from src.memory_graph.core.builder import MemoryBuilder
 from src.memory_graph.core.extractor import MemoryExtractor
 from src.memory_graph.core.node_merger import NodeMerger
 
-__all__ = ["NodeMerger", "MemoryExtractor", "MemoryBuilder"]
+__all__ = ["MemoryBuilder", "MemoryExtractor", "NodeMerger"]
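
The __all__ reordering here (and in the two other __init__ files in this commit) is the alphabetical sort that linters such as ruff apply via rule RUF022; it is purely cosmetic, since the same names are exported either way:

old = ["NodeMerger", "MemoryExtractor", "MemoryBuilder"]
new = ["MemoryBuilder", "MemoryExtractor", "NodeMerger"]
assert sorted(old) == new  # same exports, canonical order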

View File

@@ -5,7 +5,7 @@
 from __future__ import annotations
 
 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 import numpy as np
@@ -16,7 +16,6 @@ from src.memory_graph.models import (
     MemoryEdge,
     MemoryNode,
     MemoryStatus,
-    MemoryType,
     NodeType,
 )
 from src.memory_graph.storage.graph_store import GraphStore
@@ -41,7 +40,7 @@ class MemoryBuilder:
         self,
         vector_store: VectorStore,
         graph_store: GraphStore,
-        embedding_generator: Optional[Any] = None,
+        embedding_generator: Any | None = None,
     ):
         """
         初始化记忆构建器
@@ -55,7 +54,7 @@ class MemoryBuilder:
         self.graph_store = graph_store
         self.embedding_generator = embedding_generator
 
-    async def build_memory(self, extracted_params: Dict[str, Any]) -> Memory:
+    async def build_memory(self, extracted_params: dict[str, Any]) -> Memory:
         """
         构建完整的记忆对象
@@ -97,7 +96,7 @@ class MemoryBuilder:
             edges.append(memory_type_edge)
 
         # 4. 如果有客体,创建客体节点并连接
-        if "object" in extracted_params and extracted_params["object"]:
+        if extracted_params.get("object"):
             object_node = await self._create_object_node(
                 content=extracted_params["object"], memory_id=memory_id
             )
@@ -258,11 +257,11 @@ class MemoryBuilder:
     async def _process_attributes(
         self,
-        attributes: Dict[str, Any],
+        attributes: dict[str, Any],
         parent_id: str,
         memory_id: str,
         importance: float,
-    ) -> tuple[List[MemoryNode], List[MemoryEdge]]:
+    ) -> tuple[list[MemoryNode], list[MemoryEdge]]:
         """
         处理属性,构建属性子图
@@ -341,7 +340,7 @@ class MemoryBuilder:
     async def _find_existing_node(
         self, content: str, node_type: NodeType
-    ) -> Optional[MemoryNode]:
+    ) -> MemoryNode | None:
         """
         查找已存在的完全匹配节点(用于主体和属性)
@@ -369,7 +368,7 @@ class MemoryBuilder:
     async def _find_similar_topic(
         self, content: str, embedding: np.ndarray
-    ) -> Optional[MemoryNode]:
+    ) -> MemoryNode | None:
         """
         查找相似的主题节点(基于语义相似度)
@@ -414,7 +413,7 @@ class MemoryBuilder:
     async def _find_similar_object(
         self, content: str, embedding: np.ndarray
-    ) -> Optional[MemoryNode]:
+    ) -> MemoryNode | None:
         """
         查找相似的客体节点(基于语义相似度)
@@ -525,7 +524,7 @@ class MemoryBuilder:
             logger.error(f"记忆关联失败: {e}", exc_info=True)
             raise RuntimeError(f"记忆关联失败: {e}")
 
-    def _find_topic_node(self, memory: Memory) -> Optional[MemoryNode]:
+    def _find_topic_node(self, memory: Memory) -> MemoryNode | None:
         """查找记忆中的主题节点"""
         for node in memory.nodes:
             if node.node_type == NodeType.TOPIC:

View File

@@ -5,7 +5,7 @@
 from __future__ import annotations
 
 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any
 
 from src.common.logger import get_logger
 from src.memory_graph.models import MemoryType
@@ -25,7 +25,7 @@ class MemoryExtractor:
     4. 清洗和格式化数据
     """
 
-    def __init__(self, time_parser: Optional[TimeParser] = None):
+    def __init__(self, time_parser: TimeParser | None = None):
         """
         初始化记忆提取器
@@ -34,7 +34,7 @@ class MemoryExtractor:
         """
         self.time_parser = time_parser or TimeParser()
 
-    def extract_from_tool_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+    def extract_from_tool_params(self, params: dict[str, Any]) -> dict[str, Any]:
         """
         从工具参数中提取记忆元素
@@ -64,11 +64,11 @@ class MemoryExtractor:
             }
 
             # 3. 提取可选的客体
-            if "object" in params and params["object"]:
+            if params.get("object"):
                 extracted["object"] = self._clean_text(params["object"])
 
             # 4. 提取和标准化属性
-            if "attributes" in params and params["attributes"]:
+            if params.get("attributes"):
                 extracted["attributes"] = self._process_attributes(params["attributes"])
             else:
                 extracted["attributes"] = {}
@@ -86,7 +86,7 @@ class MemoryExtractor:
             logger.error(f"记忆提取失败: {e}", exc_info=True)
             raise ValueError(f"记忆提取失败: {e}")
 
-    def _validate_required_params(self, params: Dict[str, Any]) -> None:
+    def _validate_required_params(self, params: dict[str, Any]) -> None:
         """
         验证必需参数
@@ -181,7 +181,7 @@ class MemoryExtractor:
             logger.warning(f"无效的重要性值: {importance},使用默认值 0.5")
             return 0.5
 
-    def _process_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any]:
+    def _process_attributes(self, attributes: dict[str, Any]) -> dict[str, Any]:
         """
         处理属性字典
@@ -222,7 +222,7 @@ class MemoryExtractor:
         return processed
 
-    def extract_link_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+    def extract_link_params(self, params: dict[str, Any]) -> dict[str, Any]:
         """
         提取记忆关联参数(用于 link_memories 工具)
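
The params.get(...) rewrites above are behavior-preserving for a truthiness test: "object" in params and params["object"] and params.get("object") accept and reject exactly the same inputs, since a missing key yields a falsy None and an empty value is falsy either way. Only the expression's value differs (False vs. None), which the if statement never observes:

params = {"object": "", "topic": "午饭"}

assert not ("object" in params and params["object"])  # empty string: falsy
assert not params.get("object")                       # same outcome
assert not params.get("missing")                      # absent key: also falsy
assert params.get("topic") == "午饭"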

View File

@@ -4,11 +4,6 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
-from typing import List, Optional, Tuple
-
-import numpy as np
-
 from src.common.logger import get_logger
 from src.config.official_configs import MemoryConfig
 from src.memory_graph.models import MemoryNode, NodeType
@@ -54,9 +49,9 @@ class NodeMerger:
     async def find_similar_nodes(
         self,
         node: MemoryNode,
-        threshold: Optional[float] = None,
+        threshold: float | None = None,
         limit: int = 5,
-    ) -> List[Tuple[MemoryNode, float]]:
+    ) -> list[tuple[MemoryNode, float]]:
         """
         查找与指定节点相似的节点
@@ -207,7 +202,7 @@ class NodeMerger:
         # 如果有 30% 以上的邻居重叠,认为上下文匹配
         return overlap_ratio > 0.3
 
-    def _get_node_content(self, node_id: str) -> Optional[str]:
+    def _get_node_content(self, node_id: str) -> str | None:
         """获取节点的内容"""
         memories = self.graph_store.get_memories_by_node(node_id)
         if memories:
@@ -280,8 +275,8 @@ class NodeMerger:
     async def batch_merge_similar_nodes(
         self,
-        nodes: List[MemoryNode],
-        progress_callback: Optional[callable] = None,
+        nodes: list[MemoryNode],
+        progress_callback: callable | None = None,
     ) -> dict:
         """
         批量处理节点合并
@@ -344,7 +339,7 @@ class NodeMerger:
         self,
         min_similarity: float = 0.85,
         limit: int = 100,
-    ) -> List[Tuple[str, str, float]]:
+    ) -> list[tuple[str, str, float]]:
         """
         获取待合并的候选节点对

View File

@@ -10,22 +10,21 @@
 import asyncio
 import logging
+import uuid
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any
 
 from src.config.config import global_config
 from src.config.official_configs import MemoryConfig
 from src.memory_graph.core.builder import MemoryBuilder
 from src.memory_graph.core.extractor import MemoryExtractor
-from src.memory_graph.models import Memory, MemoryEdge, MemoryNode, MemoryType, NodeType, EdgeType
+from src.memory_graph.models import EdgeType, Memory, MemoryEdge, NodeType
 from src.memory_graph.storage.graph_store import GraphStore
 from src.memory_graph.storage.persistence import PersistenceManager
 from src.memory_graph.storage.vector_store import VectorStore
 from src.memory_graph.tools.memory_tools import MemoryTools
 from src.memory_graph.utils.embeddings import EmbeddingGenerator
-import uuid
-from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     import numpy as np
@@ -46,7 +45,7 @@ class MemoryManager:
     def __init__(
         self,
-        data_dir: Optional[Path] = None,
+        data_dir: Path | None = None,
     ):
         """
         初始化记忆管理器
@@ -55,29 +54,29 @@ class MemoryManager:
             data_dir: 数据目录(可选,默认从global_config读取)
         """
         # 直接使用 global_config.memory
-        if not global_config.memory or not getattr(global_config.memory, 'enable', False):
+        if not global_config.memory or not getattr(global_config.memory, "enable", False):
            raise ValueError("记忆系统未启用,请在配置文件中启用 [memory] enable = true")
 
         self.config: MemoryConfig = global_config.memory
-        self.data_dir = data_dir or Path(getattr(self.config, 'data_dir', 'data/memory_graph'))
+        self.data_dir = data_dir or Path(getattr(self.config, "data_dir", "data/memory_graph"))
 
         # 存储组件
-        self.vector_store: Optional[VectorStore] = None
-        self.graph_store: Optional[GraphStore] = None
-        self.persistence: Optional[PersistenceManager] = None
+        self.vector_store: VectorStore | None = None
+        self.graph_store: GraphStore | None = None
+        self.persistence: PersistenceManager | None = None
 
         # 核心组件
-        self.embedding_generator: Optional[EmbeddingGenerator] = None
-        self.extractor: Optional[MemoryExtractor] = None
-        self.builder: Optional[MemoryBuilder] = None
-        self.tools: Optional[MemoryTools] = None
+        self.embedding_generator: EmbeddingGenerator | None = None
+        self.extractor: MemoryExtractor | None = None
+        self.builder: MemoryBuilder | None = None
+        self.tools: MemoryTools | None = None
 
         # 状态
         self._initialized = False
         self._last_maintenance = datetime.now()
-        self._maintenance_task: Optional[asyncio.Task] = None
-        self._maintenance_interval_hours = getattr(self.config, 'consolidation_interval_hours', 1.0)
-        self._maintenance_schedule_id: Optional[str] = None  # 调度任务ID
+        self._maintenance_task: asyncio.Task | None = None
+        self._maintenance_interval_hours = getattr(self.config, "consolidation_interval_hours", 1.0)
+        self._maintenance_schedule_id: str | None = None  # 调度任务ID
 
         logger.info(f"记忆管理器已创建 (data_dir={self.data_dir}, enable={getattr(self.config, 'enable', False)})")
@@ -101,8 +100,8 @@ class MemoryManager:
             self.data_dir.mkdir(parents=True, exist_ok=True)
 
             # 获取存储配置
-            storage_config = getattr(self.config, 'storage', None)
-            vector_collection_name = getattr(storage_config, 'vector_collection_name', 'memory_graph') if storage_config else 'memory_graph'
+            storage_config = getattr(self.config, "storage", None)
+            vector_collection_name = getattr(storage_config, "vector_collection_name", "memory_graph") if storage_config else "memory_graph"
 
             self.vector_store = VectorStore(
                 collection_name=vector_collection_name,
@@ -203,11 +202,11 @@ class MemoryManager:
         subject: str,
         memory_type: str,
         topic: str,
-        object: Optional[str] = None,
-        attributes: Optional[Dict[str, str]] = None,
+        object: str | None = None,
+        attributes: dict[str, str] | None = None,
         importance: float = 0.5,
         **kwargs,
-    ) -> Optional[Memory]:
+    ) -> Memory | None:
         """
         创建新记忆
@@ -250,7 +249,7 @@ class MemoryManager:
             logger.error(f"创建记忆时发生异常: {e}", exc_info=True)
             return None
 
-    async def get_memory(self, memory_id: str) -> Optional[Memory]:
+    async def get_memory(self, memory_id: str) -> Memory | None:
         """
         根据 ID 获取记忆
@@ -348,8 +347,8 @@ class MemoryManager:
     async def generate_multi_queries(
         self,
         query: str,
-        context: Optional[Dict[str, Any]] = None,
-    ) -> List[Tuple[str, float]]:
+        context: dict[str, Any] | None = None,
+    ) -> list[tuple[str, float]]:
         """
         使用小模型生成多个查询语句(用于多路召回)
@@ -364,8 +363,8 @@ class MemoryManager:
             List of (query_string, weight) - 查询语句和权重
         """
         try:
-            from src.llm_models.utils_model import LLMRequest
             from src.config.config import model_config
+            from src.llm_models.utils_model import LLMRequest
 
             llm = LLMRequest(
                 model_set=model_config.model_task_config.utils_small,
@@ -407,9 +406,10 @@ class MemoryManager:
             response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)
 
             # 解析JSON
-            import json, re
-            response = re.sub(r'```json\s*', '', response)
-            response = re.sub(r'```\s*$', '', response).strip()
+            import json
+            import re
+            response = re.sub(r"```json\s*", "", response)
+            response = re.sub(r"```\s*$", "", response).strip()
 
             try:
                 data = json.loads(response)
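
Splitting import json, re onto separate lines is standard lint style (one module per import statement); the quote changes are formatter output. The fence-stripping itself is a common defense for LLM responses that wrap JSON in Markdown code fences. A standalone sketch of the same idea, under the assumption that the fence always opens with ```json:

import json
import re

def parse_fenced_json(response: str) -> dict:
    """Strip a leading ```json fence and a trailing ``` fence, then parse."""
    cleaned = re.sub(r"```json\s*", "", response)
    cleaned = re.sub(r"```\s*$", "", cleaned).strip()
    return json.loads(cleaned)

print(parse_fenced_json('```json\n{"queries": [["小明 午饭", 0.8]]}\n```'))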
@@ -439,14 +439,14 @@ class MemoryManager:
         self,
         query: str,
         top_k: int = 10,
-        memory_types: Optional[List[str]] = None,
-        time_range: Optional[Tuple[datetime, datetime]] = None,
+        memory_types: list[str] | None = None,
+        time_range: tuple[datetime, datetime] | None = None,
         min_importance: float = 0.0,
         include_forgotten: bool = False,
         use_multi_query: bool = True,
         expand_depth: int | None = None,
-        context: Optional[Dict[str, Any]] = None,
-    ) -> List[Memory]:
+        context: dict[str, Any] | None = None,
+    ) -> list[Memory]:
         """
         搜索记忆
@@ -611,7 +611,7 @@ class MemoryManager:
                 # 计算时间衰减
                 last_access_dt = datetime.fromisoformat(last_access)
                 hours_passed = (now - last_access_dt).total_seconds() / 3600
-                decay_rate = getattr(self.config, 'activation_decay_rate', 0.95)
+                decay_rate = getattr(self.config, "activation_decay_rate", 0.95)
                 decay_factor = decay_rate ** (hours_passed / 24)
                 current_activation = activation_info.get("level", 0.0) * decay_factor
             else:
@@ -631,15 +631,15 @@ class MemoryManager:
             # 激活传播:激活相关记忆
             if strength > 0.1:  # 只有足够强的激活才传播
-                propagation_depth = getattr(self.config, 'activation_propagation_depth', 2)
+                propagation_depth = getattr(self.config, "activation_propagation_depth", 2)
                 related_memories = self._get_related_memories(
                     memory_id,
                     max_depth=propagation_depth
                 )
-                propagation_strength_factor = getattr(self.config, 'activation_propagation_strength', 0.5)
+                propagation_strength_factor = getattr(self.config, "activation_propagation_strength", 0.5)
                 propagation_strength = strength * propagation_strength_factor
-                max_related = getattr(self.config, 'max_related_memories', 5)
+                max_related = getattr(self.config, "max_related_memories", 5)
                 for related_id in related_memories[:max_related]:
                     await self.activate_memory(related_id, propagation_strength)
@@ -652,7 +652,7 @@ class MemoryManager:
             logger.error(f"激活记忆失败: {e}", exc_info=True)
             return False
 
-    def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> List[str]:
+    def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> list[str]:
         """
         获取相关记忆 ID 列表(旧版本,保留用于激活传播)
@@ -687,12 +687,12 @@ class MemoryManager:
     async def expand_memories_with_semantic_filter(
         self,
-        initial_memory_ids: List[str],
+        initial_memory_ids: list[str],
         query_embedding: "np.ndarray",
         max_depth: int = 2,
         semantic_threshold: float = 0.5,
         max_expanded: int = 20
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         """
         从初始记忆集合出发,沿图结构扩展,并用语义相似度过滤
@@ -712,12 +712,11 @@ class MemoryManager:
             return []
 
         try:
-            import numpy as np
-
             # 记录已访问的记忆,避免重复
             visited_memories = set(initial_memory_ids)
             # 记录扩展的记忆及其分数
-            expanded_memories: Dict[str, float] = {}
+            expanded_memories: dict[str, float] = {}
 
             # BFS扩展
             current_level = initial_memory_ids
@@ -738,7 +737,7 @@ class MemoryManager:
                     # 获取邻居节点
                    try:
                         neighbors = list(self.graph_store.graph.neighbors(node.id))
-                    except:
+                    except Exception:
                         continue
 
                     for neighbor_id in neighbors:
@@ -764,7 +763,7 @@ class MemoryManager:
                         try:
                             edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id)
                             edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
-                        except:
+                        except Exception:
                             edge_importance = 0.5
 
                         # 综合评分:语义相似度(70%) + 图结构权重(20%) + 深度衰减(10%)
@@ -785,7 +784,7 @@ class MemoryManager:
                             import json
                             try:
                                 neighbor_memory_ids = json.loads(neighbor_memory_ids)
-                            except:
+                            except Exception:
                                 neighbor_memory_ids = [neighbor_memory_ids]
 
                             for neighbor_mem_id in neighbor_memory_ids:
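
The three except: → except Exception: fixes in this hunk (ruff E722) matter beyond style: a bare except also catches KeyboardInterrupt and SystemExit, which derive from BaseException, so a Ctrl-C during graph traversal would have been silently swallowed and turned into continue. The distinction, in brief:

def interrupted() -> None:
    raise KeyboardInterrupt

try:
    interrupted()
except Exception:
    print("not reached: KeyboardInterrupt is not an Exception")
except BaseException:
    print("reached: this is what a bare 'except:' would have caught")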
@@ -909,7 +908,7 @@ class MemoryManager:
                     continue
 
                 # 跳过高重要性记忆
-                min_importance = getattr(self.config, 'forgetting_min_importance', 7.0)
+                min_importance = getattr(self.config, "forgetting_min_importance", 7.0)
                 if memory.importance >= min_importance:
                     continue
@@ -939,7 +938,7 @@ class MemoryManager:
     # ==================== 统计与维护 ====================
 
-    def get_statistics(self) -> Dict[str, Any]:
+    def get_statistics(self) -> dict[str, Any]:
         """
         获取记忆系统统计信息
@@ -980,7 +979,7 @@ class MemoryManager:
         similarity_threshold: float = 0.85,
         time_window_hours: float = 24.0,
         max_batch_size: int = 50,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         整理记忆:直接合并去重相似记忆(不创建新边)
@@ -1065,7 +1064,7 @@ class MemoryManager:
             result["checked_count"] = len(recent_memories)
 
             # 按记忆类型分组,减少跨类型比较
-            memories_by_type: Dict[str, List[Memory]] = {}
+            memories_by_type: dict[str, list[Memory]] = {}
             for mem in recent_memories:
                 mem_type = mem.metadata.get("memory_type", "")
                 if mem_type not in memories_by_type:
@@ -1073,7 +1072,7 @@ class MemoryManager:
                 memories_by_type[mem_type].append(mem)
 
             # 记录需要删除的记忆,延迟批量删除
-            to_delete: List[Tuple[Memory, str]] = []  # (memory, reason)
+            to_delete: list[tuple[Memory, str]] = []  # (memory, reason)
             deleted_ids = set()
 
             # 对每个类型的记忆进行相似度检测
@@ -1084,7 +1083,7 @@ class MemoryManager:
                 logger.debug(f"🔍 检查类型 '{mem_type}':{len(memories)} 条记忆")
 
                 # 预提取所有主题节点的嵌入向量
-                embeddings_map: Dict[str, "np.ndarray"] = {}
+                embeddings_map: dict[str, "np.ndarray"] = {}
                 valid_memories = []
 
                 for mem in memories:
@@ -1094,7 +1093,6 @@ class MemoryManager:
                         valid_memories.append(mem)
 
                 # 批量计算相似度矩阵(比逐个计算更高效)
-                import numpy as np
 
                 for i in range(len(valid_memories)):
                     # 更频繁的协作式多任务让出
@@ -1134,7 +1132,7 @@ class MemoryManager:
                         keep_mem.importance = min(1.0, keep_mem.importance + 0.05)
 
                         # 累加访问次数
-                        if hasattr(keep_mem, 'access_count') and hasattr(remove_mem, 'access_count'):
+                        if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"):
                            keep_mem.access_count += remove_mem.access_count
 
                         # 标记为待删除(不立即删除)
@@ -1164,7 +1162,7 @@ class MemoryManager:
                 # 批量保存(一次性写入,减少I/O)
                 await self.persistence.save_graph_store(self.graph_store)
-                logger.info(f"💾 批量保存完成")
+                logger.info("💾 批量保存完成")
 
             logger.info(f"✅ 记忆整理完成: {result}")
@@ -1207,10 +1205,10 @@ class MemoryManager:
     async def auto_link_memories(
         self,
-        time_window_hours: float = None,
-        max_candidates: int = None,
-        min_confidence: float = None,
-    ) -> Dict[str, Any]:
+        time_window_hours: float | None = None,
+        max_candidates: int | None = None,
+        min_confidence: float | None = None,
+    ) -> dict[str, Any]:
         """
         自动关联记忆
@@ -1229,8 +1227,8 @@ class MemoryManager:
         # 使用配置值或参数覆盖
         time_window_hours = time_window_hours if time_window_hours is not None else 24
-        max_candidates = max_candidates if max_candidates is not None else getattr(self.config, 'auto_link_max_candidates', 10)
-        min_confidence = min_confidence if min_confidence is not None else getattr(self.config, 'auto_link_min_confidence', 0.7)
+        max_candidates = max_candidates if max_candidates is not None else getattr(self.config, "auto_link_max_candidates", 10)
+        min_confidence = min_confidence if min_confidence is not None else getattr(self.config, "auto_link_min_confidence", 0.7)
 
         try:
             logger.info(f"开始自动关联记忆 (时间窗口={time_window_hours}h)...")
@@ -1361,9 +1359,9 @@ class MemoryManager:
     async def _find_link_candidates(
         self,
         memory: Memory,
-        exclude_ids: Set[str],
+        exclude_ids: set[str],
         max_results: int = 5,
-    ) -> List[Memory]:
+    ) -> list[Memory]:
         """
         为记忆寻找关联候选
@@ -1407,9 +1405,9 @@ class MemoryManager:
     async def _analyze_memory_relations(
         self,
         source_memory: Memory,
-        candidate_memories: List[Memory],
+        candidate_memories: list[Memory],
         min_confidence: float = 0.7,
-    ) -> List[Dict[str, Any]]:
+    ) -> list[dict[str, Any]]:
         """
         使用LLM分析记忆之间的关系
@@ -1426,8 +1424,8 @@ class MemoryManager:
             - reasoning: 推理过程
         """
         try:
-            from src.llm_models.utils_model import LLMRequest
             from src.config.config import model_config
+            from src.llm_models.utils_model import LLMRequest
 
             # 构建LLM请求
             llm = LLMRequest(
@@ -1496,7 +1494,7 @@ class MemoryManager:
             import re
 
             # 提取JSON
-            json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
+            json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
             if json_match:
                 json_str = json_match.group(1)
             else:
@@ -1507,7 +1505,7 @@ class MemoryManager:
             except json.JSONDecodeError:
                 logger.warning(f"LLM返回格式错误,尝试修复: {response[:200]}")
                 # 尝试简单修复
-                json_str = re.sub(r'[\r\n\t]', '', json_str)
+                json_str = re.sub(r"[\r\n\t]", "", json_str)
                 analysis_results = json.loads(json_str)
 
             # 转换为结果格式
@@ -1574,7 +1572,7 @@ class MemoryManager:
             logger.warning(f"格式化记忆失败: {e}")
             return f"记忆ID: {memory.id}"
 
-    async def maintenance(self) -> Dict[str, Any]:
+    async def maintenance(self) -> dict[str, Any]:
         """
         执行维护任务(优化版本)
@@ -1604,12 +1602,12 @@ class MemoryManager:
         start_time = datetime.now()
 
         # 1. 记忆整理(异步后台执行,不阻塞主流程)
-        if getattr(self.config, 'consolidation_enabled', False):
+        if getattr(self.config, "consolidation_enabled", False):
            logger.info("🚀 启动异步记忆整理任务...")
             consolidate_result = await self.consolidate_memories(
-                similarity_threshold=getattr(self.config, 'consolidation_deduplication_threshold', 0.93),
-                time_window_hours=getattr(self.config, 'consolidation_time_window_hours', 2.0),  # 统一时间窗口
-                max_batch_size=getattr(self.config, 'consolidation_max_batch_size', 30)
+                similarity_threshold=getattr(self.config, "consolidation_deduplication_threshold", 0.93),
+                time_window_hours=getattr(self.config, "consolidation_time_window_hours", 2.0),  # 统一时间窗口
+                max_batch_size=getattr(self.config, "consolidation_max_batch_size", 30)
             )
 
             if consolidate_result.get("task_started"):
@@ -1620,16 +1618,16 @@ class MemoryManager:
                 logger.warning("❌ 记忆整理任务启动失败")
 
         # 2. 自动关联记忆(使用统一的时间窗口)
-        if getattr(self.config, 'consolidation_linking_enabled', True):
+        if getattr(self.config, "consolidation_linking_enabled", True):
             logger.info("🔗 执行轻量级自动关联...")
             link_result = await self._lightweight_auto_link_memories()
             result["linked"] = link_result.get("linked_count", 0)
 
         # 3. 自动遗忘(快速执行)
-        if getattr(self.config, 'forgetting_enabled', True):
+        if getattr(self.config, "forgetting_enabled", True):
             logger.info("🗑️ 执行自动遗忘...")
             forgotten_count = await self.auto_forget_memories(
-                threshold=getattr(self.config, 'forgetting_activation_threshold', 0.1)
+                threshold=getattr(self.config, "forgetting_activation_threshold", 0.1)
             )
             result["forgotten"] = forgotten_count
@@ -1654,10 +1652,10 @@ class MemoryManager:
     async def _lightweight_auto_link_memories(
         self,
-        time_window_hours: float = None,  # 从配置读取
-        max_candidates: int = None,  # 从配置读取
-        max_memories: int = None,  # 从配置读取
-    ) -> Dict[str, Any]:
+        time_window_hours: float | None = None,  # 从配置读取
+        max_candidates: int | None = None,  # 从配置读取
+        max_memories: int | None = None,  # 从配置读取
+    ) -> dict[str, Any]:
         """
         智能轻量级自动关联记忆(保留LLM判断,优化性能)
@@ -1676,11 +1674,11 @@ class MemoryManager:
         # 从配置读取参数,使用统一的时间窗口
         if time_window_hours is None:
-            time_window_hours = getattr(self.config, 'consolidation_time_window_hours', 2.0)
+            time_window_hours = getattr(self.config, "consolidation_time_window_hours", 2.0)
         if max_candidates is None:
-            max_candidates = getattr(self.config, 'consolidation_linking_max_candidates', 10)
+            max_candidates = getattr(self.config, "consolidation_linking_max_candidates", 10)
         if max_memories is None:
-            max_memories = getattr(self.config, 'consolidation_linking_max_memories', 20)
+            max_memories = getattr(self.config, "consolidation_linking_max_memories", 20)
 
         # 获取用户配置时间窗口内的记忆
         time_threshold = datetime.now() - timedelta(hours=time_window_hours)
@@ -1690,7 +1688,7 @@ class MemoryManager:
             mem for mem in all_memories
             if mem.created_at >= time_threshold
             and not mem.metadata.get("forgotten", False)
-            and mem.importance >= getattr(self.config, 'consolidation_linking_min_importance', 0.5)  # 从配置读取重要性阈值
+            and mem.importance >= getattr(self.config, "consolidation_linking_min_importance", 0.5)  # 从配置读取重要性阈值
         ]
 
         if len(recent_memories) > max_memories:
@@ -1704,7 +1702,6 @@ class MemoryManager:
         # 第一步:向量相似度预筛选,找到潜在关联对
         candidate_pairs = []
-        import numpy as np
 
         for i, memory in enumerate(recent_memories):
             # 获取主题节点
@@ -1733,7 +1730,7 @@ class MemoryManager:
                 )
 
                 # 使用配置的预筛选阈值
-                pre_filter_threshold = getattr(self.config, 'consolidation_linking_pre_filter_threshold', 0.7)
+                pre_filter_threshold = getattr(self.config, "consolidation_linking_pre_filter_threshold", 0.7)
                 if similarity >= pre_filter_threshold:
                     candidate_pairs.append((memory, other_memory, similarity))
@@ -1747,7 +1744,7 @@ class MemoryManager:
             return result
 
        # 第二步:批量LLM分析(使用配置的最大候选对数)
-        max_pairs_for_llm = getattr(self.config, 'consolidation_linking_max_pairs_for_llm', 5)
+        max_pairs_for_llm = getattr(self.config, "consolidation_linking_max_pairs_for_llm", 5)
         if len(candidate_pairs) <= max_pairs_for_llm:
             link_relations = await self._batch_analyze_memory_relations(candidate_pairs)
             result["llm_calls"] = 1
@@ -1810,8 +1807,8 @@ class MemoryManager:
     async def _batch_analyze_memory_relations(
         self,
-        candidate_pairs: List[Tuple[Memory, Memory, float]]
-    ) -> List[Dict[str, Any]]:
+        candidate_pairs: list[tuple[Memory, Memory, float]]
+    ) -> list[dict[str, Any]]:
         """
         批量分析记忆关系(优化LLM调用)
@@ -1822,8 +1819,8 @@ class MemoryManager:
             关系分析结果列表
         """
         try:
-            from src.llm_models.utils_model import LLMRequest
             from src.config.config import model_config
+            from src.llm_models.utils_model import LLMRequest
 
             llm = LLMRequest(
                 model_set=model_config.model_task_config.utils_small,
@@ -1843,7 +1840,7 @@ class MemoryManager:
             """
 
             # 构建批量分析提示词(使用配置的置信度阈值)
-            min_confidence = getattr(self.config, 'consolidation_linking_min_confidence', 0.7)
+            min_confidence = getattr(self.config, "consolidation_linking_min_confidence", 0.7)
            prompt = f"""你是记忆关系分析专家。请批量分析以下候选记忆对之间的关系。
@@ -1885,8 +1882,8 @@ class MemoryManager:
请分析并输出JSON结果:"""
 
             # 调用LLM(使用配置的参数)
-            llm_temperature = getattr(self.config, 'consolidation_linking_llm_temperature', 0.2)
-            llm_max_tokens = getattr(self.config, 'consolidation_linking_llm_max_tokens', 1500)
+            llm_temperature = getattr(self.config, "consolidation_linking_llm_temperature", 0.2)
+            llm_max_tokens = getattr(self.config, "consolidation_linking_llm_max_tokens", 1500)
             response, _ = await llm.generate_response_async(
                 prompt,
@@ -1899,7 +1896,7 @@ class MemoryManager:
             import re
 
             # 提取JSON
-            json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
+            json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
             if json_match:
                 json_str = json_match.group(1)
             else:
@@ -1910,7 +1907,7 @@ class MemoryManager:
             except json.JSONDecodeError:
                 logger.warning(f"LLM返回格式错误,尝试修复: {response[:200]}")
                 # 尝试简单修复
-                json_str = re.sub(r'[\r\n\t]', '', json_str)
+                json_str = re.sub(r"[\r\n\t]", "", json_str)
                 analysis_results = json.loads(json_str)
 
             # 转换为结果格式

View File

@@ -7,7 +7,6 @@
 from __future__ import annotations
 
 from pathlib import Path
-from typing import Optional
 
 from src.common.logger import get_logger
 from src.memory_graph.manager import MemoryManager
@@ -15,13 +14,13 @@ from src.memory_graph.manager import MemoryManager
 logger = get_logger(__name__)
 
 # 全局 MemoryManager 实例
-_memory_manager: Optional[MemoryManager] = None
+_memory_manager: MemoryManager | None = None
 _initialized: bool = False
 
 async def initialize_memory_manager(
-    data_dir: Optional[Path | str] = None,
-) -> Optional[MemoryManager]:
+    data_dir: Path | str | None = None,
+) -> MemoryManager | None:
     """
     初始化全局 MemoryManager
@@ -43,7 +42,7 @@ async def initialize_memory_manager(
         from src.config.config import global_config
 
         # 检查是否启用
-        if not global_config.memory or not getattr(global_config.memory, 'enable', False):
+        if not global_config.memory or not getattr(global_config.memory, "enable", False):
             logger.info("记忆图系统已在配置中禁用")
             _initialized = False
             _memory_manager = None
@@ -51,7 +50,7 @@ async def initialize_memory_manager(
         # 处理数据目录
         if data_dir is None:
-            data_dir = getattr(global_config.memory, 'data_dir', 'data/memory_graph')
+            data_dir = getattr(global_config.memory, "data_dir", "data/memory_graph")
         if isinstance(data_dir, str):
             data_dir = Path(data_dir)
@@ -72,7 +71,7 @@ async def initialize_memory_manager(
         raise
 
-def get_memory_manager() -> Optional[MemoryManager]:
+def get_memory_manager() -> MemoryManager | None:
     """
     获取全局 MemoryManager 实例
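
For orientation, this module is the usual async singleton: one module-global instance plus accessor functions, now annotated with | None unions. A simplified sketch of the pattern (illustrative only, omitting this module's config checks and error handling):

from __future__ import annotations

import asyncio

class Manager:
    async def start(self) -> None:
        await asyncio.sleep(0)  # stand-in for real initialization

_manager: Manager | None = None

async def initialize_manager() -> Manager:
    """Create and start the singleton on first call; reuse it afterwards."""
    global _manager
    if _manager is None:
        _manager = Manager()
        await _manager.start()
    return _manager

def get_manager() -> Manager | None:
    """Return the singleton if initialized, else None."""
    return _manager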

View File

@@ -10,7 +10,7 @@ import uuid
 from dataclasses import dataclass, field
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 import numpy as np
@@ -60,8 +60,8 @@ class MemoryNode:
     id: str  # 节点唯一ID
     content: str  # 节点内容(如:"我"、"吃饭"、"白米饭")
     node_type: NodeType  # 节点类型
-    embedding: Optional[np.ndarray] = None  # 语义向量(仅主题/客体需要)
-    metadata: Dict[str, Any] = field(default_factory=dict)  # 扩展元数据
+    embedding: np.ndarray | None = None  # 语义向量(仅主题/客体需要)
+    metadata: dict[str, Any] = field(default_factory=dict)  # 扩展元数据
     created_at: datetime = field(default_factory=datetime.now)
 
     def __post_init__(self):
@@ -69,7 +69,7 @@ class MemoryNode:
         if not self.id:
            self.id = str(uuid.uuid4())
 
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """转换为字典(用于序列化)"""
         return {
             "id": self.id,
@@ -81,7 +81,7 @@ class MemoryNode:
         }
 
     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> MemoryNode:
+    def from_dict(cls, data: dict[str, Any]) -> MemoryNode:
         """从字典创建节点"""
         embedding = None
         if data.get("embedding") is not None:
@@ -114,7 +114,7 @@ class MemoryEdge:
     relation: str  # 关系名称(如:"是"、"做"、"时间"、"因为")
     edge_type: EdgeType  # 边类型
     importance: float = 0.5  # 重要性 [0-1]
-    metadata: Dict[str, Any] = field(default_factory=dict)  # 扩展元数据
+    metadata: dict[str, Any] = field(default_factory=dict)  # 扩展元数据
     created_at: datetime = field(default_factory=datetime.now)
 
     def __post_init__(self):
@@ -124,7 +124,7 @@ class MemoryEdge:
         # 确保重要性在有效范围内
         self.importance = max(0.0, min(1.0, self.importance))
 
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """转换为字典(用于序列化)"""
         return {
             "id": self.id,
@@ -138,7 +138,7 @@ class MemoryEdge:
         }
 
     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> MemoryEdge:
+    def from_dict(cls, data: dict[str, Any]) -> MemoryEdge:
         """从字典创建边"""
         return cls(
             id=data["id"],
@@ -162,8 +162,8 @@ class Memory:
     id: str  # 记忆唯一ID
     subject_id: str  # 主体节点ID
     memory_type: MemoryType  # 记忆类型
-    nodes: List[MemoryNode]  # 该记忆包含的所有节点
-    edges: List[MemoryEdge]  # 该记忆包含的所有边
+    nodes: list[MemoryNode]  # 该记忆包含的所有节点
+    edges: list[MemoryEdge]  # 该记忆包含的所有边
     importance: float = 0.5  # 整体重要性 [0-1]
     activation: float = 0.0  # 激活度 [0-1],用于记忆整合和遗忘
     status: MemoryStatus = MemoryStatus.STAGED  # 记忆状态
@@ -171,7 +171,7 @@ class Memory:
     last_accessed: datetime = field(default_factory=datetime.now)  # 最后访问时间
     access_count: int = 0  # 访问次数
     decay_factor: float = 1.0  # 衰减因子(随时间变化)
-    metadata: Dict[str, Any] = field(default_factory=dict)  # 扩展元数据
+    metadata: dict[str, Any] = field(default_factory=dict)  # 扩展元数据
 
     def __post_init__(self):
         """后初始化处理"""
@@ -181,7 +181,7 @@ class Memory:
         self.importance = max(0.0, min(1.0, self.importance))
         self.activation = max(0.0, min(1.0, self.activation))
 
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """转换为字典(用于序列化)"""
         return {
             "id": self.id,
@@ -200,7 +200,7 @@ class Memory:
         }
 
     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> Memory:
+    def from_dict(cls, data: dict[str, Any]) -> Memory:
         """从字典创建记忆"""
         return cls(
             id=data["id"],
@@ -223,14 +223,14 @@ class Memory:
         self.last_accessed = datetime.now()
         self.access_count += 1
 
-    def get_node_by_id(self, node_id: str) -> Optional[MemoryNode]:
+    def get_node_by_id(self, node_id: str) -> MemoryNode | None:
         """根据ID获取节点"""
         for node in self.nodes:
             if node.id == node_id:
                 return node
         return None
 
-    def get_subject_node(self) -> Optional[MemoryNode]:
+    def get_subject_node(self) -> MemoryNode | None:
         """获取主体节点"""
         return self.get_node_by_id(self.subject_id)
@@ -274,10 +274,10 @@ class StagedMemory:
     memory: Memory  # 原始记忆对象
     status: MemoryStatus = MemoryStatus.STAGED  # 状态
     created_at: datetime = field(default_factory=datetime.now)
-    consolidated_at: Optional[datetime] = None  # 整理时间
-    merge_history: List[str] = field(default_factory=list)  # 被合并的节点ID列表
+    consolidated_at: datetime | None = None  # 整理时间
+    merge_history: list[str] = field(default_factory=list)  # 被合并的节点ID列表
 
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """转换为字典"""
         return {
             "memory": self.memory.to_dict(),
@@ -288,7 +288,7 @@ class StagedMemory:
         }
 
     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> StagedMemory:
+    def from_dict(cls, data: dict[str, Any]) -> StagedMemory:
         """从字典创建临时记忆"""
         return cls(
             memory=Memory.from_dict(data["memory"]),
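
Every annotation change in this file follows PEP 585 (builtin generics such as dict[str, Any] replace typing.Dict) and PEP 604 (X | None replaces Optional[X]). Because the module starts with from __future__ import annotations, the new spellings are never evaluated at runtime, so they are safe even on interpreters older than 3.9/3.10 where dict[...] and | in annotations would otherwise fail. A minimal illustration:

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

@dataclass
class Record:
    payload: dict[str, Any] = field(default_factory=dict)  # PEP 585 builtin generic
    note: str | None = None                                # PEP 604 union

print(Record(payload={"k": 1}))  # annotations stay as strings; nothing evaluates them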

View File

@@ -58,7 +58,7 @@ class CreateMemoryTool(BaseTool):
         ("memory_type", ToolParamType.STRING, "记忆类型。【事件】=有明确时间点的动作(昨天吃饭、明天开会)【事实】=稳定状态(职业是程序员、住在北京)【观点】=主观看法(喜欢/讨厌/认为)【关系】=人际关系(朋友、同事)", True, ["事件", "事实", "关系", "观点"]),
         ("topic", ToolParamType.STRING, "记忆的核心内容(做什么/是什么状态/什么关系)。必须明确、具体,包含关键动词或状态词", True, None),
         ("object", ToolParamType.STRING, "记忆涉及的对象或目标。如果topic已经很完整,可以不填;如果有明确对象,建议填写", False, None),
-        ("attributes", ToolParamType.STRING, "详细属性(JSON格式字符串)。强烈建议包含:时间(具体到日期和小时分钟)、地点、状态、原因等上下文信息。例:{\"时间\":\"2025-11-06 12:00\",\"地点\":\"公司\",\"状态\":\"进行中\",\"原因\":\"项目需要\"}", False, None),
+        ("attributes", ToolParamType.STRING, '详细属性(JSON格式字符串)。强烈建议包含:时间(具体到日期和小时分钟)、地点、状态、原因等上下文信息。例:{"时间":"2025-11-06 12:00","地点":"公司","状态":"进行中","原因":"项目需要"}', False, None),
         ("importance", ToolParamType.FLOAT, "重要性评分 0.0-1.0。参考:日常琐事0.3-0.4,一般对话0.5-0.6,重要信息0.7-0.8,核心记忆0.9-1.0。不确定时用0.5", False, None),
     ]
@@ -124,7 +124,7 @@ class CreateMemoryTool(BaseTool):
             logger.error(f"[CreateMemoryTool] 执行失败: {e}", exc_info=True)
             return {
                 "name": self.name,
-                "content": f"创建记忆时出错: {str(e)}"
+                "content": f"创建记忆时出错: {e!s}"
             }
@@ -184,7 +184,7 @@ class LinkMemoriesTool(BaseTool):
             logger.error(f"[LinkMemoriesTool] 执行失败: {e}", exc_info=True)
             return {
                 "name": self.name,
-                "content": f"关联记忆时出错: {str(e)}"
+                "content": f"关联记忆时出错: {e!s}"
             }
@@ -254,5 +254,5 @@ class SearchMemoriesTool(BaseTool):
             logger.error(f"[SearchMemoriesTool] 执行失败: {e}", exc_info=True)
             return {
                 "name": self.name,
-                "content": f"搜索记忆时出错: {str(e)}"
+                "content": f"搜索记忆时出错: {e!s}"
             }
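
The {e!s} rewrites are ruff's RUF010: inside an f-string the !s conversion calls str() on the value directly, so the output is identical to wrapping the value in str(), just without the redundant call spelled out:

err = ValueError("连接超时")
assert f"出错: {str(err)}" == f"出错: {err!s}" == "出错: 连接超时"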

View File

@@ -5,4 +5,4 @@
 from src.memory_graph.storage.graph_store import GraphStore
 from src.memory_graph.storage.vector_store import VectorStore
 
-__all__ = ["VectorStore", "GraphStore"]
+__all__ = ["GraphStore", "VectorStore"]

View File

@@ -4,12 +4,10 @@
 from __future__ import annotations
 
-from typing import Dict, List, Optional, Set, Tuple
-
 import networkx as nx
 
 from src.common.logger import get_logger
-from src.memory_graph.models import Memory, MemoryEdge, MemoryNode
+from src.memory_graph.models import Memory, MemoryEdge
 
 logger = get_logger(__name__)
@@ -31,10 +29,10 @@ class GraphStore:
         self.graph = nx.DiGraph()
 
         # 索引:记忆ID -> 记忆对象
-        self.memory_index: Dict[str, Memory] = {}
+        self.memory_index: dict[str, Memory] = {}
 
         # 索引:节点ID -> 所属记忆ID集合
-        self.node_to_memories: Dict[str, Set[str]] = {}
+        self.node_to_memories: dict[str, set[str]] = {}
 
         logger.info("初始化图存储")
@@ -84,7 +82,7 @@ class GraphStore:
             logger.error(f"添加记忆失败: {e}", exc_info=True)
             raise
 
-    def get_memory_by_id(self, memory_id: str) -> Optional[Memory]:
+    def get_memory_by_id(self, memory_id: str) -> Memory | None:
         """
         根据ID获取记忆
@@ -96,7 +94,7 @@ class GraphStore:
         """
         return self.memory_index.get(memory_id)
 
-    def get_all_memories(self) -> List[Memory]:
+    def get_all_memories(self) -> list[Memory]:
         """
         获取所有记忆
@@ -105,7 +103,7 @@ class GraphStore:
         """
         return list(self.memory_index.values())
 
-    def get_memories_by_node(self, node_id: str) -> List[Memory]:
+    def get_memories_by_node(self, node_id: str) -> list[Memory]:
         """
         获取包含指定节点的所有记忆
@@ -121,7 +119,7 @@ class GraphStore:
         memory_ids = self.node_to_memories[node_id]
         return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index]
 
-    def get_edges_from_node(self, node_id: str, relation_types: Optional[List[str]] = None) -> List[Dict]:
+    def get_edges_from_node(self, node_id: str, relation_types: list[str] | None = None) -> list[dict]:
         """
         获取从指定节点出发的所有边
@@ -155,8 +153,8 @@ class GraphStore:
         return edges
 
     def get_neighbors(
-        self, node_id: str, direction: str = "out", relation_types: Optional[List[str]] = None
-    ) -> List[Tuple[str, Dict]]:
+        self, node_id: str, direction: str = "out", relation_types: list[str] | None = None
+    ) -> list[tuple[str, dict]]:
         """
         获取节点的邻居节点
@@ -187,7 +185,7 @@ class GraphStore:
         return neighbors
 
-    def find_path(self, source_id: str, target_id: str, max_length: Optional[int] = None) -> Optional[List[str]]:
+    def find_path(self, source_id: str, target_id: str, max_length: int | None = None) -> list[str] | None:
         """
         查找两个节点之间的最短路径
@@ -220,10 +218,10 @@ class GraphStore:
     def bfs_expand(
         self,
-        start_nodes: List[str],
+        start_nodes: list[str],
         depth: int = 1,
-        relation_types: Optional[List[str]] = None,
-    ) -> Set[str]:
+        relation_types: list[str] | None = None,
+    ) -> set[str]:
         """
         从起始节点进行广度优先搜索扩展
@@ -256,7 +254,7 @@ class GraphStore:
         return visited
 
-    def get_subgraph(self, node_ids: List[str]) -> nx.DiGraph:
+    def get_subgraph(self, node_ids: list[str]) -> nx.DiGraph:
         """
         获取包含指定节点的子图
@@ -308,7 +306,7 @@ class GraphStore:
             logger.error(f"合并节点失败: {e}", exc_info=True)
             raise
 
-    def get_node_degree(self, node_id: str) -> Tuple[int, int]:
+    def get_node_degree(self, node_id: str) -> tuple[int, int]:
         """
         获取节点的度数
@@ -323,7 +321,7 @@ class GraphStore:
         return (self.graph.in_degree(node_id), self.graph.out_degree(node_id))
 
-    def get_statistics(self) -> Dict[str, int]:
+    def get_statistics(self) -> dict[str, int]:
         """获取图的统计信息"""
         return {
             "total_nodes": self.graph.number_of_nodes(),
@@ -332,7 +330,7 @@ class GraphStore:
             "connected_components": nx.number_weakly_connected_components(self.graph),
         }
 
-    def to_dict(self) -> Dict:
+    def to_dict(self) -> dict:
         """
         将图转换为字典(用于持久化)
@@ -356,7 +354,7 @@ class GraphStore:
         }
 
     @classmethod
-    def from_dict(cls, data: Dict) -> GraphStore:
+    def from_dict(cls, data: dict) -> GraphStore:
         """
         从字典加载图
@@ -406,7 +404,6 @@ class GraphStore:
         规则:对于图中每条边(u, v, data),会尝试将该边注入到所有包含 u 或 v 的记忆中(避免遗漏跨记忆边)。
         已存在的边(通过 edge.id 检查)将不会重复添加。
         """
-        from src.memory_graph.models import MemoryEdge
 
         # 构建快速查重索引:memory_id -> set(edge_id)
         existing_edges = {mid: {e.id for e in mem.edges} for mid, mem in self.memory_index.items()}

View File

@@ -8,14 +8,12 @@ import asyncio
import json import json
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Optional
import orjson import orjson
from src.common.logger import get_logger from src.common.logger import get_logger
from src.memory_graph.models import Memory, StagedMemory from src.memory_graph.models import StagedMemory
from src.memory_graph.storage.graph_store import GraphStore from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.vector_store import VectorStore
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -55,7 +53,7 @@ class PersistenceManager:
self.backup_dir.mkdir(parents=True, exist_ok=True) self.backup_dir.mkdir(parents=True, exist_ok=True)
self.auto_save_interval = auto_save_interval self.auto_save_interval = auto_save_interval
self._auto_save_task: Optional[asyncio.Task] = None self._auto_save_task: asyncio.Task | None = None
self._running = False self._running = False
logger.info(f"初始化持久化管理器: data_dir={data_dir}") logger.info(f"初始化持久化管理器: data_dir={data_dir}")
@@ -95,7 +93,7 @@ class PersistenceManager:
             logger.error(f"Failed to save graph data: {e}", exc_info=True)
             raise

-    async def load_graph_store(self) -> Optional[GraphStore]:
+    async def load_graph_store(self) -> GraphStore | None:
         """
         Load the graph store from file
@@ -179,7 +177,7 @@ class PersistenceManager:
             logger.error(f"Failed to load staged memories: {e}", exc_info=True)
             return []

-    async def create_backup(self) -> Optional[Path]:
+    async def create_backup(self) -> Path | None:
         """
         Create a backup of the current data
@@ -208,7 +206,7 @@ class PersistenceManager:
             logger.error(f"Failed to create backup: {e}", exc_info=True)
             return None

-    async def _load_from_backup(self) -> Optional[GraphStore]:
+    async def _load_from_backup(self) -> GraphStore | None:
         """Load data from the most recent backup"""
         try:
             # Find the latest backup file
@@ -254,7 +252,7 @@ class PersistenceManager:
     async def start_auto_save(
         self,
         graph_store: GraphStore,
-        staged_memories_getter: callable = None,
+        staged_memories_getter: callable | None = None,
     ) -> None:
         """
         Start the auto-save task
@@ -334,7 +332,7 @@ class PersistenceManager:
             logger.error(f"Failed to export graph data: {e}", exc_info=True)
             raise

-    async def import_from_json(self, input_file: Path) -> Optional[GraphStore]:
+    async def import_from_json(self, input_file: Path) -> GraphStore | None:
         """
         Import graph data from a JSON file
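load_graph_store and _load_from_backup both signal failure with None instead of raising, which lets a caller chain the primary load with a backup fallback. A minimal sketch of that pattern under simplified, hypothetical signatures (the real manager works on a GraphStore, not a dict):

import asyncio
from pathlib import Path

import orjson

async def load_graph(path: Path) -> dict | None:
    # Return None on any failure instead of raising.
    try:
        return orjson.loads(path.read_bytes())
    except Exception:
        return None

async def load_with_fallback(primary: Path, backups: list[Path]) -> dict | None:
    # Try the primary file first, then each backup, newest first.
    if (data := await load_graph(primary)) is not None:
        return data
    for backup in sorted(backups, reverse=True):
        if (data := await load_graph(backup)) is not None:
            return data
    return None  # caller starts with an empty store

# asyncio.run(load_with_fallback(Path("graph.json"), [Path("backup_1.json")]))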


@@ -4,9 +4,8 @@
 from __future__ import annotations

-import uuid
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any

 import numpy as np
@@ -29,8 +28,8 @@ class VectorStore:
     def __init__(
         self,
         collection_name: str = "memory_nodes",
-        data_dir: Optional[Path] = None,
-        embedding_function: Optional[Any] = None,
+        data_dir: Path | None = None,
+        embedding_function: Any | None = None,
     ):
         """
         Initialize the vector store
@@ -103,7 +102,7 @@ class VectorStore:
             for key, value in node.metadata.items():
                 if isinstance(value, (list, dict)):
                     import orjson
-                    metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
+                    metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
                 elif isinstance(value, (str, int, float, bool)) or value is None:
                     metadata[key] = value
                 else:
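Vector databases such as ChromaDB typically accept only scalar metadata values, which is why list/dict values are serialized to JSON strings on write and parsed back on read. A standalone sketch of the round-trip (assuming only that orjson is available; the "store" here is a plain dict):

import orjson

def to_scalar_metadata(metadata: dict) -> dict:
    # Lists/dicts become JSON strings; scalars pass through unchanged.
    out = {}
    for key, value in metadata.items():
        if isinstance(value, (list, dict)):
            out[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
        elif isinstance(value, (str, int, float, bool)) or value is None:
            out[key] = value
    return out

def from_scalar_metadata(metadata: dict) -> dict:
    # Strings that look like JSON containers are parsed back; others kept.
    out = dict(metadata)
    for key, value in out.items():
        if isinstance(value, str) and value.startswith(("[", "{")):
            try:
                out[key] = orjson.loads(value)
            except Exception:
                pass  # keep the original string
    return out

stored = to_scalar_metadata({"memory_ids": ["m1", "m2"], "weight": 0.7})
assert from_scalar_metadata(stored)["memory_ids"] == ["m1", "m2"]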
@@ -122,7 +121,7 @@ class VectorStore:
             logger.error(f"Failed to add node: {e}", exc_info=True)
             raise

-    async def add_nodes_batch(self, nodes: List[MemoryNode]) -> None:
+    async def add_nodes_batch(self, nodes: list[MemoryNode]) -> None:
         """
         Add nodes in batch
@@ -151,7 +150,7 @@ class VectorStore:
             }
             for key, value in n.metadata.items():
                 if isinstance(value, (list, dict)):
-                    metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
+                    metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
                 elif isinstance(value, (str, int, float, bool)) or value is None:
                     metadata[key] = value  # type: ignore
                 else:
@@ -175,9 +174,9 @@ class VectorStore:
         self,
         query_embedding: np.ndarray,
         limit: int = 10,
-        node_types: Optional[List[NodeType]] = None,
+        node_types: list[NodeType] | None = None,
         min_similarity: float = 0.0,
-    ) -> List[Tuple[str, float, Dict[str, Any]]]:
+    ) -> list[tuple[str, float, dict[str, Any]]]:
         """
         Search for similar nodes
@@ -226,10 +225,10 @@ class VectorStore:
             # Parse JSON strings back into lists/dicts
             for key, value in list(metadata.items()):
-                if isinstance(value, str) and (value.startswith('[') or value.startswith('{')):
+                if isinstance(value, str) and (value.startswith("[") or value.startswith("{")):
                     try:
                         metadata[key] = orjson.loads(value)
-                    except:
+                    except Exception:
                         pass  # keep the original value
             similar_nodes.append((node_id, similarity, metadata))
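Replacing bare except: with except Exception: is the main exception-handling fix in this commit. A bare except also traps KeyboardInterrupt and SystemExit, which inherit from BaseException, so Ctrl-C can be swallowed inside a best-effort parse. A runnable illustration:

def swallow_everything() -> None:
    try:
        raise KeyboardInterrupt  # stands in for Ctrl-C during a parse
    except:  # noqa: E722 -- bare except traps BaseException subclasses too
        pass  # the interrupt is silently swallowed

def swallow_errors_only() -> None:
    try:
        raise KeyboardInterrupt
    except Exception:
        pass  # not reached: KeyboardInterrupt is not an Exception

swallow_everything()  # returns normally; the interrupt is lost
try:
    swallow_errors_only()
except KeyboardInterrupt:
    print("interrupt propagated, as it should")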
@@ -243,13 +242,13 @@ class VectorStore:
     async def search_with_multiple_queries(
         self,
-        query_embeddings: List[np.ndarray],
-        query_weights: Optional[List[float]] = None,
+        query_embeddings: list[np.ndarray],
+        query_weights: list[float] | None = None,
         limit: int = 10,
-        node_types: Optional[List[NodeType]] = None,
+        node_types: list[NodeType] | None = None,
         min_similarity: float = 0.0,
         fusion_strategy: str = "weighted_max",
-    ) -> List[Tuple[str, float, Dict[str, Any]]]:
+    ) -> list[tuple[str, float, dict[str, Any]]]:
         """
         Multi-query fusion search
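The fusion_strategy="weighted_max" default suggests each node's fused score is the maximum of its per-query similarities scaled by the query weights. A minimal sketch of that fusion step under this assumption (the real method also applies node-type filters and similarity thresholds):

def weighted_max_fusion(
    per_query_hits: list[list[tuple[str, float]]],  # one hit list per query
    weights: list[float] | None = None,
    limit: int = 10,
) -> list[tuple[str, float]]:
    weights = weights or [1.0] * len(per_query_hits)
    fused: dict[str, float] = {}
    for hits, w in zip(per_query_hits, weights):
        for node_id, sim in hits:
            # Keep the best weighted similarity seen for each node.
            fused[node_id] = max(fused.get(node_id, 0.0), w * sim)
    return sorted(fused.items(), key=lambda kv: kv[1], reverse=True)[:limit]

hits_q1 = [("n1", 0.9), ("n2", 0.4)]
hits_q2 = [("n2", 0.8), ("n3", 0.6)]
print(weighted_max_fusion([hits_q1, hits_q2], weights=[1.0, 0.5]))
# [('n1', 0.9), ('n2', 0.4), ('n3', 0.3)]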
@@ -287,7 +286,7 @@ class VectorStore:
         try:
             # 1. Run the search for each query
-            all_results: Dict[str, Dict[str, Any]] = {}  # node_id -> {scores, metadata}
+            all_results: dict[str, dict[str, Any]] = {}  # node_id -> {scores, metadata}

             for i, (query_emb, weight) in enumerate(zip(query_embeddings, query_weights)):
                 # Search for more results to improve fusion quality
@@ -356,7 +355,7 @@ class VectorStore:
             logger.error(f"Multi-query fusion search failed: {e}", exc_info=True)
             raise

-    async def get_node_by_id(self, node_id: str) -> Optional[Dict[str, Any]]:
+    async def get_node_by_id(self, node_id: str) -> dict[str, Any] | None:
         """
         Get node metadata by ID


@@ -4,12 +4,12 @@ LLM tool interface: defines the memory system's tool schemas and execution logic
 from __future__ import annotations

-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any

 from src.common.logger import get_logger
 from src.memory_graph.core.builder import MemoryBuilder
 from src.memory_graph.core.extractor import MemoryExtractor
-from src.memory_graph.models import Memory, MemoryStatus
+from src.memory_graph.models import Memory
 from src.memory_graph.storage.graph_store import GraphStore
 from src.memory_graph.storage.persistence import PersistenceManager
 from src.memory_graph.storage.vector_store import VectorStore
@@ -33,7 +33,7 @@ class MemoryTools:
         vector_store: VectorStore,
         graph_store: GraphStore,
         persistence_manager: PersistenceManager,
-        embedding_generator: Optional[EmbeddingGenerator] = None,
+        embedding_generator: EmbeddingGenerator | None = None,
         max_expand_depth: int = 1,
         expand_semantic_threshold: float = 0.3,
     ):
@@ -72,7 +72,7 @@
         self._initialized = True

     @staticmethod
-    def get_create_memory_schema() -> Dict[str, Any]:
+    def get_create_memory_schema() -> dict[str, Any]:
         """
         Get the JSON schema for the create_memory tool
@@ -183,7 +183,7 @@
         }

     @staticmethod
-    def get_link_memories_schema() -> Dict[str, Any]:
+    def get_link_memories_schema() -> dict[str, Any]:
         """
         Get the JSON schema for the link_memories tool
@@ -239,7 +239,7 @@
         }

     @staticmethod
-    def get_search_memories_schema() -> Dict[str, Any]:
+    def get_search_memories_schema() -> dict[str, Any]:
         """
         Get the JSON schema for the search_memories tool
@@ -307,7 +307,7 @@
             },
         }

-    async def create_memory(self, **params) -> Dict[str, Any]:
+    async def create_memory(self, **params) -> dict[str, Any]:
         """
         Execute the create_memory tool
@@ -353,7 +353,7 @@
                 "message": "Memory creation failed",
             }

-    async def link_memories(self, **params) -> Dict[str, Any]:
+    async def link_memories(self, **params) -> dict[str, Any]:
         """
         Execute the link_memories tool
@@ -433,7 +433,7 @@
                 "message": "Memory linking failed",
             }

-    async def search_memories(self, **params) -> Dict[str, Any]:
+    async def search_memories(self, **params) -> dict[str, Any]:
         """
         Execute the search_memories tool
@@ -486,7 +486,7 @@
                         import orjson
                         try:
                             ids = orjson.loads(ids)
-                        except:
+                        except Exception:
                             ids = [ids]
                     if isinstance(ids, list):
                         for mem_id in ids:
@@ -526,8 +526,7 @@
                 # )

                 # Merge the expansion results
-                for mem_id, score in expanded_results:
-                    expanded_memory_scores[mem_id] = score
+                expanded_memory_scores.update(dict(expanded_results))

                 logger.info(f"Graph expansion complete: {len(expanded_memory_scores)} additional related memories")
@@ -624,16 +623,16 @@
     async def _generate_multi_queries_simple(
-        self, query: str, context: Optional[Dict[str, Any]] = None
-    ) -> List[Tuple[str, float]]:
+        self, query: str, context: dict[str, Any] | None = None
+    ) -> list[tuple[str, float]]:
         """
         Simplified multi-query generation (implemented directly in the Tools layer to avoid a circular dependency)

         Has a small model directly generate 3-5 query statements from different angles.
         """
         try:
-            from src.llm_models.utils_model import LLMRequest
             from src.config.config import model_config
+            from src.llm_models.utils_model import LLMRequest

             llm = LLMRequest(
                 model_set=model_config.model_task_config.utils_small,
@@ -648,10 +647,10 @@
             # Process the chat history: extract roughly the 5 most recent messages
             recent_chat = ""
             if chat_history:
-                lines = chat_history.strip().split('\n')
+                lines = chat_history.strip().split("\n")
                 # Take the 5 most recent messages
                 recent_lines = lines[-5:] if len(lines) > 5 else lines
-                recent_chat = '\n'.join(recent_lines)
+                recent_chat = "\n".join(recent_lines)

             prompt = f"""Based on the chat context, generate 3-5 search statements for the query from different angles (JSON format)
@@ -686,9 +685,11 @@
             response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250)

-            import orjson, re
-            response = re.sub(r'```json\s*', '', response)
-            response = re.sub(r'```\s*$', '', response).strip()
+            import re
+
+            import orjson
+
+            response = re.sub(r"```json\s*", "", response)
+            response = re.sub(r"```\s*$", "", response).strip()
             data = orjson.loads(response)
             queries = data.get("queries", [])
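LLMs often wrap JSON answers in ```json fences even when asked not to; the two re.sub calls strip the leading fence marker and the trailing one before orjson.loads. A standalone sketch of the same cleanup:

import re

import orjson

def parse_llm_json(response: str) -> dict:
    # Strip a leading ```json fence and a trailing ``` fence.
    response = re.sub(r"```json\s*", "", response)
    response = re.sub(r"```\s*$", "", response).strip()
    return orjson.loads(response)

raw = '```json\n{"queries": [["cat diet", 1.0], ["what cats eat", 0.8]]}\n```'
print(parse_llm_json(raw)["queries"][0])  # ['cat diet', 1.0]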
@@ -707,7 +708,7 @@
     async def _single_query_search(
         self, query: str, top_k: int
-    ) -> List[Tuple[str, float, Dict[str, Any]]]:
+    ) -> list[tuple[str, float, dict[str, Any]]]:
         """
         Traditional single-query search
@@ -735,8 +736,8 @@
         return similar_nodes

     async def _multi_query_search(
-        self, query: str, top_k: int, context: Optional[Dict[str, Any]] = None
-    ) -> List[Tuple[str, float, Dict[str, Any]]]:
+        self, query: str, top_k: int, context: dict[str, Any] | None = None
+    ) -> list[tuple[str, float, dict[str, Any]]]:
         """
         Multi-query strategy search (simplified)
@@ -800,7 +801,7 @@
             if node.embedding is not None:
                 await self.vector_store.add_node(node)

-    async def _find_memory_by_description(self, description: str) -> Optional[Memory]:
+    async def _find_memory_by_description(self, description: str) -> Memory | None:
         """
         Find a memory by its description
@@ -827,7 +828,7 @@
             return None

         # Get the memory associated with the most similar node
-        node_id, similarity, metadata = similar_nodes[0]
+        _node_id, _similarity, metadata = similar_nodes[0]
         if "memory_ids" not in metadata or not metadata["memory_ids"]:
             return None
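Renaming the unused unpacked values to _node_id/_similarity follows the leading-underscore convention that linters such as Ruff recognize for intentionally unused bindings. A small illustration:

similar_nodes = [("node-1", 0.93, {"memory_ids": ["m1"]})]

# Leading underscores mark values we unpack but never read,
# silencing unused-variable warnings without changing behavior.
_node_id, _similarity, metadata = similar_nodes[0]
print(metadata["memory_ids"])  # ['m1']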
@@ -862,12 +863,12 @@
     async def _expand_with_semantic_filter(
         self,
-        initial_memory_ids: List[str],
+        initial_memory_ids: list[str],
         query_embedding,
         max_depth: int = 2,
         semantic_threshold: float = 0.5,
         max_expanded: int = 20
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         """
         Expand from the initial set of memories along the graph structure, filtering by semantic similarity
@@ -885,10 +886,9 @@
             return []

         try:
-            import numpy as np

             visited_memories = set(initial_memory_ids)
-            expanded_memories: Dict[str, float] = {}
+            expanded_memories: dict[str, float] = {}
             current_level = initial_memory_ids
@@ -906,7 +906,7 @@
                     try:
                         neighbors = list(self.graph_store.graph.neighbors(node.id))
-                    except:
+                    except Exception:
                         continue

                     for neighbor_id in neighbors:
@@ -932,7 +932,7 @@
                         try:
                             edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id)
                             edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
-                        except:
+                        except Exception:
                             edge_importance = 0.5

                         # Combined score
@@ -952,7 +952,7 @@
                             import orjson
                             try:
                                 neighbor_memory_ids = orjson.loads(neighbor_memory_ids)
-                            except:
+                            except Exception:
                                 neighbor_memory_ids = [neighbor_memory_ids]

                             for neighbor_mem_id in neighbor_memory_ids:
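Putting the pieces of _expand_with_semantic_filter together: a breadth-first walk from the seed memories that keeps only neighbors whose similarity to the query clears semantic_threshold, capped by max_depth and max_expanded. A simplified sketch over a plain adjacency dict (the real method walks a networkx graph and node metadata):

import numpy as np

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def expand(
    graph: dict[str, list[str]],        # memory_id -> neighbor memory ids
    embeddings: dict[str, np.ndarray],  # memory_id -> embedding
    seeds: list[str],
    query: np.ndarray,
    max_depth: int = 2,
    threshold: float = 0.5,
    max_expanded: int = 20,
) -> list[tuple[str, float]]:
    visited = set(seeds)
    scores: dict[str, float] = {}
    level = list(seeds)
    for _ in range(max_depth):
        next_level = []
        for mid in level:
            for nb in graph.get(mid, []):
                if nb in visited or len(scores) >= max_expanded:
                    continue
                visited.add(nb)
                sim = cosine(embeddings[nb], query)
                if sim >= threshold:       # semantic filter
                    scores[nb] = sim
                    next_level.append(nb)  # expand from it next round
        level = next_level
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)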
@@ -1010,7 +1010,7 @@
             return 0.0

     @staticmethod
-    def get_all_tool_schemas() -> List[Dict[str, Any]]:
+    def get_all_tool_schemas() -> list[dict[str, Any]]:
         """
         Get the schemas for all tools


@@ -5,4 +5,4 @@
 from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator
 from src.memory_graph.utils.time_parser import TimeParser

-__all__ = ["TimeParser", "EmbeddingGenerator", "get_embedding_generator"]
+__all__ = ["EmbeddingGenerator", "TimeParser", "get_embedding_generator"]


@@ -5,8 +5,6 @@
 from __future__ import annotations

 import asyncio
-from functools import lru_cache
-from typing import List, Optional

 import numpy as np
@@ -142,7 +140,7 @@ class EmbeddingGenerator:
             dim = self._get_dimension()
             return np.random.rand(dim).astype(np.float32)

-    async def _generate_with_api(self, text: str) -> Optional[np.ndarray]:
+    async def _generate_with_api(self, text: str) -> np.ndarray | None:
         """Generate an embedding via the API"""
         try:
             # Initialize the API
@@ -166,7 +164,7 @@
             logger.debug(f"API embedding generation failed: {e}")
             return None

-    async def _generate_with_local_model(self, text: str) -> Optional[np.ndarray]:
+    async def _generate_with_local_model(self, text: str) -> np.ndarray | None:
         """Generate an embedding with the local model"""
         try:
             # Load the local model
@@ -204,13 +202,13 @@
         if self._local_model_loaded and self._local_model:
             try:
                 return self._local_model.get_sentence_embedding_dimension()
-            except:
+            except Exception:
                 pass

         # Default to 384 (a common sentence-transformers dimension)
         return 384

-    async def generate_batch(self, texts: List[str]) -> List[np.ndarray]:
+    async def generate_batch(self, texts: list[str]) -> list[np.ndarray]:
         """
         Generate embedding vectors in batch
@@ -251,7 +249,7 @@
             dim = self._get_dimension()
             return [np.random.rand(dim).astype(np.float32) for _ in texts]

-    async def _generate_batch_with_api(self, texts: List[str]) -> Optional[List[np.ndarray]]:
+    async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray] | None:
         """Batch generation via the API"""
         try:
             # For most APIs, a batch call is just multiple individual calls
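Since "a batch call is just multiple individual calls" for most embedding APIs, the natural async implementation fans the texts out with asyncio.gather. A minimal sketch with a fake per-text call standing in for the real API client:

import asyncio

import numpy as np

async def embed_one(text: str) -> np.ndarray | None:
    # Stand-in for a real API call; returns None on failure.
    await asyncio.sleep(0)  # yield to the loop, as a network call would
    return np.full(384, float(len(text)), dtype=np.float32)

async def embed_batch(texts: list[str]) -> list[np.ndarray] | None:
    # Fan out one request per text and run them concurrently.
    results = await asyncio.gather(*(embed_one(t) for t in texts))
    if any(r is None for r in results):
        return None  # treat any single failure as a batch failure
    return list(results)

vectors = asyncio.run(embed_batch(["hello", "memory graph"]))
print(len(vectors), vectors[0].shape)  # 2 (384,)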
@@ -273,7 +271,7 @@
 # Global singleton
-_global_generator: Optional[EmbeddingGenerator] = None
+_global_generator: EmbeddingGenerator | None = None

 def get_embedding_generator(


@@ -5,10 +5,9 @@
 """

 import logging
-from typing import Optional, List, Dict, Any
 from datetime import datetime

-from src.memory_graph.models import Memory, MemoryNode, NodeType, EdgeType, MemoryType
+from src.memory_graph.models import EdgeType, Memory, MemoryType, NodeType

 logger = logging.getLogger(__name__)
@@ -42,11 +41,9 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
     # 2. Find the topic node (predicate/action)
     topic_node = None
-    memory_type_relation = None
     for edge in memory.edges:
         if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id:
             topic_node = memory.get_node_by_id(edge.target_id)
-            memory_type_relation = edge.relation
             break

     if not topic_node:
@@ -65,7 +62,7 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
             break

     # 4. Collect attribute nodes
-    attributes: Dict[str, str] = {}
+    attributes: dict[str, str] = {}
     for edge in memory.edges:
         if edge.edge_type == EdgeType.ATTRIBUTE:
             # Find the attribute node and the value node
@@ -143,8 +140,8 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->

 def format_memories_for_prompt(
-    memories: List[Memory],
-    max_count: Optional[int] = None,
+    memories: list[Memory],
+    max_count: int | None = None,
     include_metadata: bool = False,
     group_by_type: bool = False
 ) -> str:
@@ -169,7 +166,7 @@ def format_memories_for_prompt(
     # Group by type
     if group_by_type:
-        type_groups: Dict[MemoryType, List[Memory]] = {}
+        type_groups: dict[MemoryType, list[Memory]] = {}
         for memory in memories:
             if memory.memory_type not in type_groups:
                 type_groups[memory.memory_type] = []
@@ -250,7 +247,7 @@ def get_memory_type_label(memory_type: str) -> str:
     return type_mapping.get(memory_type_lower, "unknown")

-def _format_relative_time(timestamp: datetime) -> Optional[str]:
+def _format_relative_time(timestamp: datetime) -> str | None:
     """
     Format a relative time (e.g. "2 days ago", "just now")
@@ -316,8 +313,8 @@ def format_memory_summary(memory: Memory) -> str:

 # Export the main functions
 __all__ = [
-    'format_memory_for_prompt',
-    'format_memories_for_prompt',
-    'get_memory_type_label',
-    'format_memory_summary',
+    "format_memories_for_prompt",
+    "format_memory_for_prompt",
+    "format_memory_summary",
+    "get_memory_type_label",
 ]
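format_memories_for_prompt groups by memory_type with an explicit membership check before appending; dict.setdefault is the compact equivalent. A sketch of that grouping step, using hypothetical plain-string types in place of the MemoryType enum:

memories = [("m1", "event"), ("m2", "fact"), ("m3", "event")]

type_groups: dict[str, list[str]] = {}
for mem_id, mem_type in memories:
    # setdefault inserts an empty list on first sight of the type.
    type_groups.setdefault(mem_type, []).append(mem_id)

print(type_groups)  # {'event': ['m1', 'm3'], 'fact': ['m2']}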


@@ -14,7 +14,6 @@ from __future__ import annotations

 import re
 from datetime import datetime, timedelta
-from typing import Optional, Tuple

 from src.common.logger import get_logger
@@ -28,7 +27,7 @@ class TimeParser:
     Responsible for converting natural-language time expressions into standardized absolute times
     """

-    def __init__(self, reference_time: Optional[datetime] = None):
+    def __init__(self, reference_time: datetime | None = None):
         """
         Initialize the time parser
@@ -37,7 +36,7 @@
         """
         self.reference_time = reference_time or datetime.now()

-    def parse(self, time_str: str) -> Optional[datetime]:
+    def parse(self, time_str: str) -> datetime | None:
         """
         Parse a time string
@@ -81,7 +80,7 @@
         logger.warning(f"Could not parse time: '{time_str}', using the current time")
         return self.reference_time

-    def _parse_relative_day(self, time_str: str) -> Optional[datetime]:
+    def _parse_relative_day(self, time_str: str) -> datetime | None:
         """
         Parse relative days: 今天 (today), 明天 (tomorrow), 昨天 (yesterday), 前天 (the day before yesterday), 后天 (the day after tomorrow)
         """
@@ -108,7 +107,7 @@
         return None

-    def _parse_days_ago(self, time_str: str) -> Optional[datetime]:
+    def _parse_days_ago(self, time_str: str) -> datetime | None:
         """
         Parse X天前/X天后 (X days ago / X days from now), X周前/X周后 (weeks), X个月前/X个月后 (months)
         """
@@ -172,7 +171,7 @@
         return None

-    def _parse_hours_ago(self, time_str: str) -> Optional[datetime]:
+    def _parse_hours_ago(self, time_str: str) -> datetime | None:
         """
         Parse X小时前/X小时后 (X hours ago / from now), X分钟前/X分钟后 (minutes)
         """
@@ -204,7 +203,7 @@
         return None

-    def _parse_week_month_year(self, time_str: str) -> Optional[datetime]:
+    def _parse_week_month_year(self, time_str: str) -> datetime | None:
         """
         Parse: 上周 (last week), 上个月 (last month), 去年 (last year), 本周 (this week), 本月 (this month), 今年 (this year)
         """
@@ -232,7 +231,7 @@
         return None

-    def _parse_specific_date(self, time_str: str) -> Optional[datetime]:
+    def _parse_specific_date(self, time_str: str) -> datetime | None:
         """
         Parse specific dates:
         - 2025-11-05
@@ -266,7 +265,7 @@
         return None

-    def _parse_time_of_day(self, time_str: str) -> Optional[datetime]:
+    def _parse_time_of_day(self, time_str: str) -> datetime | None:
         """
         Parse times of day:
         - 早上 (morning), 上午 (forenoon), 中午 (noon), 下午 (afternoon), 晚上 (evening), 深夜 (late night)
@@ -290,7 +289,7 @@
         }

         # First check for a specific time point (e.g. 早上8点 "8 in the morning", 下午3点 "3 in the afternoon")
-        for period, default_hour in time_periods.items():
+        for period in time_periods.keys():
             pattern = rf"{period}(\d{{1,2}})点?"
             match = re.search(pattern, time_str)
             if match:
@@ -314,7 +313,7 @@
         return None

-    def _parse_combined_time(self, time_str: str) -> Optional[datetime]:
+    def _parse_combined_time(self, time_str: str) -> datetime | None:
         """
         Parse combined time expressions: 今天下午 (this afternoon), 昨天晚上 (yesterday evening), 明天早上 (tomorrow morning)
         """
@@ -461,7 +460,7 @@
         return str(dt)

-    def parse_time_range(self, time_str: str) -> Tuple[Optional[datetime], Optional[datetime]]:
+    def parse_time_range(self, time_str: str) -> tuple[datetime | None, datetime | None]:
         """
         Parse a time range (e.g. 最近一周 "the past week", 最近3天 "the past 3 days")
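A deterministic usage sketch of the parser's public surface as it appears in this diff (__init__(reference_time=...), parse, parse_time_range); pinning reference_time makes relative expressions reproducible in tests. The expected values assume the semantics the docstrings describe:

from datetime import datetime

from src.memory_graph.utils.time_parser import TimeParser

# Pin the reference time so "3天前" (3 days ago) resolves deterministically.
parser = TimeParser(reference_time=datetime(2025, 11, 7, 12, 0))

when = parser.parse("3天前")                      # expected: 2025-11-04
start, end = parser.parse_time_range("最近一周")  # expected: last-7-days window
print(when, start, end)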