fix: resolve code quality issues - correct exception handling and import statements

Co-authored-by: Windpicker-owo <221029311+Windpicker-owo@users.noreply.github.com>
copilot-swe-agent[bot]
2025-11-07 04:39:35 +00:00
parent 3bdcfa3dd4
commit 5caf630623
20 changed files with 893 additions and 910 deletions

View File

@@ -6,10 +6,12 @@ from typing import ClassVar
from src.common.logger import get_logger
from src.plugin_system import BasePlugin, register_plugin
from src.plugin_system.base.component_types import ComponentInfo, ToolInfo
logger = get_logger("memory_graph_plugin")
# 用于存储后台任务引用
_background_tasks = set()
@register_plugin
class MemoryGraphPlugin(BasePlugin):
@@ -60,6 +62,7 @@ class MemoryGraphPlugin(BasePlugin):
"""插件卸载时的回调"""
try:
import asyncio
from src.memory_graph.manager_singleton import shutdown_memory_manager
logger.info(f"{self.log_prefix} 正在关闭记忆系统...")
@@ -68,7 +71,10 @@ class MemoryGraphPlugin(BasePlugin):
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果循环正在运行,创建任务
asyncio.create_task(shutdown_memory_manager())
task = asyncio.create_task(shutdown_memory_manager())
# 存储引用以防止任务被垃圾回收
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
else:
# 如果循环未运行,直接运行
loop.run_until_complete(shutdown_memory_manager())
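
Note on the hunk above: the asyncio event loop holds only weak references to tasks, so a fire-and-forget asyncio.create_task(...) whose result is discarded can be garbage-collected before it finishes. A minimal sketch of the keep-alive pattern applied here (illustrative names, not the plugin's actual shutdown coroutine):

import asyncio

_background_tasks: set[asyncio.Task] = set()

async def _cleanup() -> None:
    await asyncio.sleep(0)  # stand-in for the real shutdown work

def schedule_cleanup() -> None:
    # Must be called while an event loop is running.
    task = asyncio.create_task(_cleanup())
    _background_tasks.add(task)                        # strong reference keeps the task alive
    task.add_done_callback(_background_tasks.discard)  # release the reference once it finishes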

View File

@@ -25,14 +25,13 @@ import asyncio
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
import numpy as np
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.common.logger import get_logger
from src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager, shutdown_memory_manager
from src.memory_graph.manager_singleton import initialize_memory_manager, shutdown_memory_manager
logger = get_logger(__name__)
@@ -65,7 +64,7 @@ class MemoryDeduplicator:
self.stats["total_memories"] = len(self.manager.graph_store.get_all_memories())
logger.info(f"✅ 记忆管理器初始化成功,共 {self.stats['total_memories']} 条记忆")
async def find_similar_pairs(self) -> List[Tuple[str, str, float]]:
async def find_similar_pairs(self) -> list[tuple[str, str, float]]:
"""
查找所有相似的记忆对(通过向量相似度计算)
@@ -144,7 +143,7 @@ class MemoryDeduplicator:
logger.error(f"计算余弦相似度失败: {e}")
return 0.0
def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> Tuple[Optional[str], Optional[str]]:
def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> tuple[str | None, str | None]:
"""
决定保留哪个记忆,删除哪个
@@ -197,7 +196,7 @@ class MemoryDeduplicator:
keep_mem = self.manager.graph_store.get_memory_by_id(keep_id)
remove_mem = self.manager.graph_store.get_memory_by_id(remove_id)
logger.info(f"")
logger.info("")
logger.info(f"{'[预览]' if self.dry_run else '[执行]'} 去重相似记忆对 (相似度={similarity:.3f}):")
logger.info(f" 保留: {keep_id}")
logger.info(f" - 主题: {keep_mem.metadata.get('topic', 'N/A')}")
@@ -221,14 +220,14 @@ class MemoryDeduplicator:
keep_mem.activation = min(1.0, keep_mem.activation + 0.05)
# 累加访问次数
if hasattr(keep_mem, 'access_count') and hasattr(remove_mem, 'access_count'):
if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"):
keep_mem.access_count += remove_mem.access_count
# 删除相似记忆
await self.manager.delete_memory(remove_id)
self.stats["duplicates_removed"] += 1
logger.info(f" ✅ 删除成功")
logger.info(" ✅ 删除成功")
# 让出控制权
await asyncio.sleep(0)

View File

@@ -6,24 +6,24 @@
from src.memory_graph.manager import MemoryManager
from src.memory_graph.models import (
EdgeType,
Memory,
MemoryEdge,
MemoryNode,
MemoryStatus,
MemoryType,
NodeType,
EdgeType,
)
__all__ = [
"MemoryManager",
"EdgeType",
"Memory",
"MemoryNode",
"MemoryEdge",
"MemoryManager",
"MemoryNode",
"MemoryStatus",
"MemoryType",
"NodeType",
"EdgeType",
"MemoryStatus",
]
__version__ = "0.1.0"

View File

@@ -6,4 +6,4 @@ from src.memory_graph.core.builder import MemoryBuilder
from src.memory_graph.core.extractor import MemoryExtractor
from src.memory_graph.core.node_merger import NodeMerger
__all__ = ["NodeMerger", "MemoryExtractor", "MemoryBuilder"]
__all__ = ["MemoryBuilder", "MemoryExtractor", "NodeMerger"]

View File

@@ -5,7 +5,7 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any
import numpy as np
@@ -16,7 +16,6 @@ from src.memory_graph.models import (
MemoryEdge,
MemoryNode,
MemoryStatus,
MemoryType,
NodeType,
)
from src.memory_graph.storage.graph_store import GraphStore
@@ -41,7 +40,7 @@ class MemoryBuilder:
self,
vector_store: VectorStore,
graph_store: GraphStore,
embedding_generator: Optional[Any] = None,
embedding_generator: Any | None = None,
):
"""
初始化记忆构建器
@@ -55,7 +54,7 @@ class MemoryBuilder:
self.graph_store = graph_store
self.embedding_generator = embedding_generator
async def build_memory(self, extracted_params: Dict[str, Any]) -> Memory:
async def build_memory(self, extracted_params: dict[str, Any]) -> Memory:
"""
构建完整的记忆对象
@@ -97,7 +96,7 @@ class MemoryBuilder:
edges.append(memory_type_edge)
# 4. 如果有客体,创建客体节点并连接
if "object" in extracted_params and extracted_params["object"]:
if extracted_params.get("object"):
object_node = await self._create_object_node(
content=extracted_params["object"], memory_id=memory_id
)
@@ -258,11 +257,11 @@ class MemoryBuilder:
async def _process_attributes(
self,
attributes: Dict[str, Any],
attributes: dict[str, Any],
parent_id: str,
memory_id: str,
importance: float,
) -> tuple[List[MemoryNode], List[MemoryEdge]]:
) -> tuple[list[MemoryNode], list[MemoryEdge]]:
"""
处理属性,构建属性子图
@@ -341,7 +340,7 @@ class MemoryBuilder:
async def _find_existing_node(
self, content: str, node_type: NodeType
) -> Optional[MemoryNode]:
) -> MemoryNode | None:
"""
查找已存在的完全匹配节点(用于主体和属性)
@@ -369,7 +368,7 @@ class MemoryBuilder:
async def _find_similar_topic(
self, content: str, embedding: np.ndarray
) -> Optional[MemoryNode]:
) -> MemoryNode | None:
"""
查找相似的主题节点(基于语义相似度)
@@ -414,7 +413,7 @@ class MemoryBuilder:
async def _find_similar_object(
self, content: str, embedding: np.ndarray
) -> Optional[MemoryNode]:
) -> MemoryNode | None:
"""
查找相似的客体节点(基于语义相似度)
@@ -525,7 +524,7 @@ class MemoryBuilder:
logger.error(f"记忆关联失败: {e}", exc_info=True)
raise RuntimeError(f"记忆关联失败: {e}")
def _find_topic_node(self, memory: Memory) -> Optional[MemoryNode]:
def _find_topic_node(self, memory: Memory) -> MemoryNode | None:
"""查找记忆中的主题节点"""
for node in memory.nodes:
if node.node_type == NodeType.TOPIC:

View File

@@ -5,7 +5,7 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any
from src.common.logger import get_logger
from src.memory_graph.models import MemoryType
@@ -25,7 +25,7 @@ class MemoryExtractor:
4. 清洗和格式化数据
"""
def __init__(self, time_parser: Optional[TimeParser] = None):
def __init__(self, time_parser: TimeParser | None = None):
"""
初始化记忆提取器
@@ -34,7 +34,7 @@ class MemoryExtractor:
"""
self.time_parser = time_parser or TimeParser()
def extract_from_tool_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
def extract_from_tool_params(self, params: dict[str, Any]) -> dict[str, Any]:
"""
从工具参数中提取记忆元素
@@ -64,11 +64,11 @@ class MemoryExtractor:
}
# 3. 提取可选的客体
if "object" in params and params["object"]:
if params.get("object"):
extracted["object"] = self._clean_text(params["object"])
# 4. 提取和标准化属性
if "attributes" in params and params["attributes"]:
if params.get("attributes"):
extracted["attributes"] = self._process_attributes(params["attributes"])
else:
extracted["attributes"] = {}
@@ -86,7 +86,7 @@ class MemoryExtractor:
logger.error(f"记忆提取失败: {e}", exc_info=True)
raise ValueError(f"记忆提取失败: {e}")
def _validate_required_params(self, params: Dict[str, Any]) -> None:
def _validate_required_params(self, params: dict[str, Any]) -> None:
"""
验证必需参数
@@ -181,7 +181,7 @@ class MemoryExtractor:
logger.warning(f"无效的重要性值: {importance},使用默认值 0.5")
return 0.5
def _process_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any]:
def _process_attributes(self, attributes: dict[str, Any]) -> dict[str, Any]:
"""
处理属性字典
@@ -222,7 +222,7 @@ class MemoryExtractor:
return processed
def extract_link_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
def extract_link_params(self, params: dict[str, Any]) -> dict[str, Any]:
"""
提取记忆关联参数(用于 link_memories 工具)

View File

@@ -4,11 +4,6 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy as np
from src.common.logger import get_logger
from src.config.official_configs import MemoryConfig
from src.memory_graph.models import MemoryNode, NodeType
@@ -54,9 +49,9 @@ class NodeMerger:
async def find_similar_nodes(
self,
node: MemoryNode,
threshold: Optional[float] = None,
threshold: float | None = None,
limit: int = 5,
) -> List[Tuple[MemoryNode, float]]:
) -> list[tuple[MemoryNode, float]]:
"""
查找与指定节点相似的节点
@@ -207,7 +202,7 @@ class NodeMerger:
# 如果有 30% 以上的邻居重叠,认为上下文匹配
return overlap_ratio > 0.3
def _get_node_content(self, node_id: str) -> Optional[str]:
def _get_node_content(self, node_id: str) -> str | None:
"""获取节点的内容"""
memories = self.graph_store.get_memories_by_node(node_id)
if memories:
@@ -280,8 +275,8 @@ class NodeMerger:
async def batch_merge_similar_nodes(
self,
nodes: List[MemoryNode],
progress_callback: Optional[callable] = None,
nodes: list[MemoryNode],
progress_callback: callable | None = None,
) -> dict:
"""
批量处理节点合并
@@ -344,7 +339,7 @@ class NodeMerger:
self,
min_similarity: float = 0.85,
limit: int = 100,
) -> List[Tuple[str, str, float]]:
) -> list[tuple[str, str, float]]:
"""
获取待合并的候选节点对

View File

@@ -10,22 +10,21 @@
import asyncio
import logging
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any
from src.config.config import global_config
from src.config.official_configs import MemoryConfig
from src.memory_graph.core.builder import MemoryBuilder
from src.memory_graph.core.extractor import MemoryExtractor
from src.memory_graph.models import Memory, MemoryEdge, MemoryNode, MemoryType, NodeType, EdgeType
from src.memory_graph.models import EdgeType, Memory, MemoryEdge, NodeType
from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.persistence import PersistenceManager
from src.memory_graph.storage.vector_store import VectorStore
from src.memory_graph.tools.memory_tools import MemoryTools
from src.memory_graph.utils.embeddings import EmbeddingGenerator
import uuid
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
@@ -46,7 +45,7 @@ class MemoryManager:
def __init__(
self,
data_dir: Optional[Path] = None,
data_dir: Path | None = None,
):
"""
初始化记忆管理器
@@ -55,29 +54,29 @@ class MemoryManager:
data_dir: 数据目录可选默认从global_config读取
"""
# 直接使用 global_config.memory
if not global_config.memory or not getattr(global_config.memory, 'enable', False):
if not global_config.memory or not getattr(global_config.memory, "enable", False):
raise ValueError("记忆系统未启用,请在配置文件中启用 [memory] enable = true")
self.config: MemoryConfig = global_config.memory
self.data_dir = data_dir or Path(getattr(self.config, 'data_dir', 'data/memory_graph'))
self.data_dir = data_dir or Path(getattr(self.config, "data_dir", "data/memory_graph"))
# 存储组件
self.vector_store: Optional[VectorStore] = None
self.graph_store: Optional[GraphStore] = None
self.persistence: Optional[PersistenceManager] = None
self.vector_store: VectorStore | None = None
self.graph_store: GraphStore | None = None
self.persistence: PersistenceManager | None = None
# 核心组件
self.embedding_generator: Optional[EmbeddingGenerator] = None
self.extractor: Optional[MemoryExtractor] = None
self.builder: Optional[MemoryBuilder] = None
self.tools: Optional[MemoryTools] = None
self.embedding_generator: EmbeddingGenerator | None = None
self.extractor: MemoryExtractor | None = None
self.builder: MemoryBuilder | None = None
self.tools: MemoryTools | None = None
# 状态
self._initialized = False
self._last_maintenance = datetime.now()
self._maintenance_task: Optional[asyncio.Task] = None
self._maintenance_interval_hours = getattr(self.config, 'consolidation_interval_hours', 1.0)
self._maintenance_schedule_id: Optional[str] = None # 调度任务ID
self._maintenance_task: asyncio.Task | None = None
self._maintenance_interval_hours = getattr(self.config, "consolidation_interval_hours", 1.0)
self._maintenance_schedule_id: str | None = None # 调度任务ID
logger.info(f"记忆管理器已创建 (data_dir={self.data_dir}, enable={getattr(self.config, 'enable', False)})")
@@ -101,8 +100,8 @@ class MemoryManager:
self.data_dir.mkdir(parents=True, exist_ok=True)
# 获取存储配置
storage_config = getattr(self.config, 'storage', None)
vector_collection_name = getattr(storage_config, 'vector_collection_name', 'memory_graph') if storage_config else 'memory_graph'
storage_config = getattr(self.config, "storage", None)
vector_collection_name = getattr(storage_config, "vector_collection_name", "memory_graph") if storage_config else "memory_graph"
self.vector_store = VectorStore(
collection_name=vector_collection_name,
@@ -203,11 +202,11 @@ class MemoryManager:
subject: str,
memory_type: str,
topic: str,
object: Optional[str] = None,
attributes: Optional[Dict[str, str]] = None,
object: str | None = None,
attributes: dict[str, str] | None = None,
importance: float = 0.5,
**kwargs,
) -> Optional[Memory]:
) -> Memory | None:
"""
创建新记忆
@@ -250,7 +249,7 @@ class MemoryManager:
logger.error(f"创建记忆时发生异常: {e}", exc_info=True)
return None
async def get_memory(self, memory_id: str) -> Optional[Memory]:
async def get_memory(self, memory_id: str) -> Memory | None:
"""
根据 ID 获取记忆
@@ -348,8 +347,8 @@ class MemoryManager:
async def generate_multi_queries(
self,
query: str,
context: Optional[Dict[str, Any]] = None,
) -> List[Tuple[str, float]]:
context: dict[str, Any] | None = None,
) -> list[tuple[str, float]]:
"""
使用小模型生成多个查询语句(用于多路召回)
@@ -364,8 +363,8 @@ class MemoryManager:
List of (query_string, weight) - 查询语句和权重
"""
try:
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
llm = LLMRequest(
model_set=model_config.model_task_config.utils_small,
@@ -407,9 +406,10 @@ class MemoryManager:
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)
# 解析JSON
import json, re
response = re.sub(r'```json\s*', '', response)
response = re.sub(r'```\s*$', '', response).strip()
import json
import re
response = re.sub(r"```json\s*", "", response)
response = re.sub(r"```\s*$", "", response).strip()
try:
data = json.loads(response)
@@ -439,14 +439,14 @@ class MemoryManager:
self,
query: str,
top_k: int = 10,
memory_types: Optional[List[str]] = None,
time_range: Optional[Tuple[datetime, datetime]] = None,
memory_types: list[str] | None = None,
time_range: tuple[datetime, datetime] | None = None,
min_importance: float = 0.0,
include_forgotten: bool = False,
use_multi_query: bool = True,
expand_depth: int | None = None,
context: Optional[Dict[str, Any]] = None,
) -> List[Memory]:
context: dict[str, Any] | None = None,
) -> list[Memory]:
"""
搜索记忆
@@ -611,7 +611,7 @@ class MemoryManager:
# 计算时间衰减
last_access_dt = datetime.fromisoformat(last_access)
hours_passed = (now - last_access_dt).total_seconds() / 3600
decay_rate = getattr(self.config, 'activation_decay_rate', 0.95)
decay_rate = getattr(self.config, "activation_decay_rate", 0.95)
decay_factor = decay_rate ** (hours_passed / 24)
current_activation = activation_info.get("level", 0.0) * decay_factor
else:
@@ -631,15 +631,15 @@ class MemoryManager:
# 激活传播:激活相关记忆
if strength > 0.1: # 只有足够强的激活才传播
propagation_depth = getattr(self.config, 'activation_propagation_depth', 2)
propagation_depth = getattr(self.config, "activation_propagation_depth", 2)
related_memories = self._get_related_memories(
memory_id,
max_depth=propagation_depth
)
propagation_strength_factor = getattr(self.config, 'activation_propagation_strength', 0.5)
propagation_strength_factor = getattr(self.config, "activation_propagation_strength", 0.5)
propagation_strength = strength * propagation_strength_factor
max_related = getattr(self.config, 'max_related_memories', 5)
max_related = getattr(self.config, "max_related_memories", 5)
for related_id in related_memories[:max_related]:
await self.activate_memory(related_id, propagation_strength)
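
For reference, the configuration reads above drive a simple exponential-forgetting model: activation_decay_rate is applied once per 24 hours since the last access, and sufficiently strong activations propagate a damped signal to related memories. A standalone sketch of the decay step (hypothetical helper; the default mirrors the getattr fallback):

from datetime import datetime

def decayed_activation(level: float, last_access: datetime, decay_rate: float = 0.95) -> float:
    hours_passed = (datetime.now() - last_access).total_seconds() / 3600
    return level * decay_rate ** (hours_passed / 24)  # one decay step per day of inactivity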
@@ -652,7 +652,7 @@ class MemoryManager:
logger.error(f"激活记忆失败: {e}", exc_info=True)
return False
def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> List[str]:
def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> list[str]:
"""
获取相关记忆 ID 列表(旧版本,保留用于激活传播)
@@ -687,12 +687,12 @@ class MemoryManager:
async def expand_memories_with_semantic_filter(
self,
initial_memory_ids: List[str],
initial_memory_ids: list[str],
query_embedding: "np.ndarray",
max_depth: int = 2,
semantic_threshold: float = 0.5,
max_expanded: int = 20
) -> List[Tuple[str, float]]:
) -> list[tuple[str, float]]:
"""
从初始记忆集合出发,沿图结构扩展,并用语义相似度过滤
@@ -712,12 +712,11 @@ class MemoryManager:
return []
try:
import numpy as np
# 记录已访问的记忆,避免重复
visited_memories = set(initial_memory_ids)
# 记录扩展的记忆及其分数
expanded_memories: Dict[str, float] = {}
expanded_memories: dict[str, float] = {}
# BFS扩展
current_level = initial_memory_ids
@@ -738,7 +737,7 @@ class MemoryManager:
# 获取邻居节点
try:
neighbors = list(self.graph_store.graph.neighbors(node.id))
except:
except Exception:
continue
for neighbor_id in neighbors:
@@ -764,7 +763,7 @@ class MemoryManager:
try:
edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id)
edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
except:
except Exception:
edge_importance = 0.5
# 综合评分:语义相似度(70%) + 图结构权重(20%) + 深度衰减(10%)
@@ -785,7 +784,7 @@ class MemoryManager:
import json
try:
neighbor_memory_ids = json.loads(neighbor_memory_ids)
except:
except Exception:
neighbor_memory_ids = [neighbor_memory_ids]
for neighbor_mem_id in neighbor_memory_ids:
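
The except: to except Exception: changes in this hunk are more than style: a bare except also traps SystemExit and KeyboardInterrupt, so Ctrl-C could be silently swallowed inside the expansion loop. A minimal sketch of the narrowed handler around the edge lookup, assuming a networkx graph as used above:

import networkx as nx

def edge_importance(graph: nx.DiGraph, u: str, v: str) -> float:
    try:
        edge_data = graph.get_edge_data(u, v)
        return edge_data.get("importance", 0.5) if edge_data else 0.5
    except Exception:  # ordinary errors fall back; SystemExit/KeyboardInterrupt still propagate
        return 0.5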
@@ -909,7 +908,7 @@ class MemoryManager:
continue
# 跳过高重要性记忆
min_importance = getattr(self.config, 'forgetting_min_importance', 7.0)
min_importance = getattr(self.config, "forgetting_min_importance", 7.0)
if memory.importance >= min_importance:
continue
@@ -939,7 +938,7 @@ class MemoryManager:
# ==================== 统计与维护 ====================
def get_statistics(self) -> Dict[str, Any]:
def get_statistics(self) -> dict[str, Any]:
"""
获取记忆系统统计信息
@@ -980,7 +979,7 @@ class MemoryManager:
similarity_threshold: float = 0.85,
time_window_hours: float = 24.0,
max_batch_size: int = 50,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
整理记忆:直接合并去重相似记忆(不创建新边)
@@ -1065,7 +1064,7 @@ class MemoryManager:
result["checked_count"] = len(recent_memories)
# 按记忆类型分组,减少跨类型比较
memories_by_type: Dict[str, List[Memory]] = {}
memories_by_type: dict[str, list[Memory]] = {}
for mem in recent_memories:
mem_type = mem.metadata.get("memory_type", "")
if mem_type not in memories_by_type:
@@ -1073,7 +1072,7 @@ class MemoryManager:
memories_by_type[mem_type].append(mem)
# 记录需要删除的记忆,延迟批量删除
to_delete: List[Tuple[Memory, str]] = [] # (memory, reason)
to_delete: list[tuple[Memory, str]] = [] # (memory, reason)
deleted_ids = set()
# 对每个类型的记忆进行相似度检测
@@ -1084,7 +1083,7 @@ class MemoryManager:
logger.debug(f"🔍 检查类型 '{mem_type}'{len(memories)} 条记忆")
# 预提取所有主题节点的嵌入向量
embeddings_map: Dict[str, "np.ndarray"] = {}
embeddings_map: dict[str, "np.ndarray"] = {}
valid_memories = []
for mem in memories:
@@ -1094,7 +1093,6 @@ class MemoryManager:
valid_memories.append(mem)
# 批量计算相似度矩阵(比逐个计算更高效)
import numpy as np
for i in range(len(valid_memories)):
# 更频繁的协作式多任务让出
@@ -1134,7 +1132,7 @@ class MemoryManager:
keep_mem.importance = min(1.0, keep_mem.importance + 0.05)
# 累加访问次数
if hasattr(keep_mem, 'access_count') and hasattr(remove_mem, 'access_count'):
if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"):
keep_mem.access_count += remove_mem.access_count
# 标记为待删除(不立即删除)
@@ -1164,7 +1162,7 @@ class MemoryManager:
# 批量保存一次性写入减少I/O
await self.persistence.save_graph_store(self.graph_store)
logger.info(f"💾 批量保存完成")
logger.info("💾 批量保存完成")
logger.info(f"✅ 记忆整理完成: {result}")
@@ -1207,10 +1205,10 @@ class MemoryManager:
async def auto_link_memories(
self,
time_window_hours: float = None,
max_candidates: int = None,
min_confidence: float = None,
) -> Dict[str, Any]:
time_window_hours: float | None = None,
max_candidates: int | None = None,
min_confidence: float | None = None,
) -> dict[str, Any]:
"""
自动关联记忆
@@ -1229,8 +1227,8 @@ class MemoryManager:
# 使用配置值或参数覆盖
time_window_hours = time_window_hours if time_window_hours is not None else 24
max_candidates = max_candidates if max_candidates is not None else getattr(self.config, 'auto_link_max_candidates', 10)
min_confidence = min_confidence if min_confidence is not None else getattr(self.config, 'auto_link_min_confidence', 0.7)
max_candidates = max_candidates if max_candidates is not None else getattr(self.config, "auto_link_max_candidates", 10)
min_confidence = min_confidence if min_confidence is not None else getattr(self.config, "auto_link_min_confidence", 0.7)
try:
logger.info(f"开始自动关联记忆 (时间窗口={time_window_hours}h)...")
@@ -1361,9 +1359,9 @@ class MemoryManager:
async def _find_link_candidates(
self,
memory: Memory,
exclude_ids: Set[str],
exclude_ids: set[str],
max_results: int = 5,
) -> List[Memory]:
) -> list[Memory]:
"""
为记忆寻找关联候选
@@ -1407,9 +1405,9 @@ class MemoryManager:
async def _analyze_memory_relations(
self,
source_memory: Memory,
candidate_memories: List[Memory],
candidate_memories: list[Memory],
min_confidence: float = 0.7,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
使用LLM分析记忆之间的关系
@@ -1426,8 +1424,8 @@ class MemoryManager:
- reasoning: 推理过程
"""
try:
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
# 构建LLM请求
llm = LLMRequest(
@@ -1496,7 +1494,7 @@ class MemoryManager:
import re
# 提取JSON
json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
if json_match:
json_str = json_match.group(1)
else:
@@ -1507,7 +1505,7 @@ class MemoryManager:
except json.JSONDecodeError:
logger.warning(f"LLM返回格式错误尝试修复: {response[:200]}")
# 尝试简单修复
json_str = re.sub(r'[\r\n\t]', '', json_str)
json_str = re.sub(r"[\r\n\t]", "", json_str)
analysis_results = json.loads(json_str)
# 转换为结果格式
@@ -1574,7 +1572,7 @@ class MemoryManager:
logger.warning(f"格式化记忆失败: {e}")
return f"记忆ID: {memory.id}"
async def maintenance(self) -> Dict[str, Any]:
async def maintenance(self) -> dict[str, Any]:
"""
执行维护任务(优化版本)
@@ -1604,12 +1602,12 @@ class MemoryManager:
start_time = datetime.now()
# 1. 记忆整理(异步后台执行,不阻塞主流程)
if getattr(self.config, 'consolidation_enabled', False):
if getattr(self.config, "consolidation_enabled", False):
logger.info("🚀 启动异步记忆整理任务...")
consolidate_result = await self.consolidate_memories(
similarity_threshold=getattr(self.config, 'consolidation_deduplication_threshold', 0.93),
time_window_hours=getattr(self.config, 'consolidation_time_window_hours', 2.0), # 统一时间窗口
max_batch_size=getattr(self.config, 'consolidation_max_batch_size', 30)
similarity_threshold=getattr(self.config, "consolidation_deduplication_threshold", 0.93),
time_window_hours=getattr(self.config, "consolidation_time_window_hours", 2.0), # 统一时间窗口
max_batch_size=getattr(self.config, "consolidation_max_batch_size", 30)
)
if consolidate_result.get("task_started"):
@@ -1620,16 +1618,16 @@ class MemoryManager:
logger.warning("❌ 记忆整理任务启动失败")
# 2. 自动关联记忆(使用统一的时间窗口)
if getattr(self.config, 'consolidation_linking_enabled', True):
if getattr(self.config, "consolidation_linking_enabled", True):
logger.info("🔗 执行轻量级自动关联...")
link_result = await self._lightweight_auto_link_memories()
result["linked"] = link_result.get("linked_count", 0)
# 3. 自动遗忘(快速执行)
if getattr(self.config, 'forgetting_enabled', True):
if getattr(self.config, "forgetting_enabled", True):
logger.info("🗑️ 执行自动遗忘...")
forgotten_count = await self.auto_forget_memories(
threshold=getattr(self.config, 'forgetting_activation_threshold', 0.1)
threshold=getattr(self.config, "forgetting_activation_threshold", 0.1)
)
result["forgotten"] = forgotten_count
@@ -1654,10 +1652,10 @@ class MemoryManager:
async def _lightweight_auto_link_memories(
self,
time_window_hours: float = None, # 从配置读取
max_candidates: int = None, # 从配置读取
max_memories: int = None, # 从配置读取
) -> Dict[str, Any]:
time_window_hours: float | None = None, # 从配置读取
max_candidates: int | None = None, # 从配置读取
max_memories: int | None = None, # 从配置读取
) -> dict[str, Any]:
"""
智能轻量级自动关联记忆保留LLM判断优化性能
@@ -1676,11 +1674,11 @@ class MemoryManager:
# 从配置读取参数,使用统一的时间窗口
if time_window_hours is None:
time_window_hours = getattr(self.config, 'consolidation_time_window_hours', 2.0)
time_window_hours = getattr(self.config, "consolidation_time_window_hours", 2.0)
if max_candidates is None:
max_candidates = getattr(self.config, 'consolidation_linking_max_candidates', 10)
max_candidates = getattr(self.config, "consolidation_linking_max_candidates", 10)
if max_memories is None:
max_memories = getattr(self.config, 'consolidation_linking_max_memories', 20)
max_memories = getattr(self.config, "consolidation_linking_max_memories", 20)
# 获取用户配置时间窗口内的记忆
time_threshold = datetime.now() - timedelta(hours=time_window_hours)
@@ -1690,7 +1688,7 @@ class MemoryManager:
mem for mem in all_memories
if mem.created_at >= time_threshold
and not mem.metadata.get("forgotten", False)
and mem.importance >= getattr(self.config, 'consolidation_linking_min_importance', 0.5) # 从配置读取重要性阈值
and mem.importance >= getattr(self.config, "consolidation_linking_min_importance", 0.5) # 从配置读取重要性阈值
]
if len(recent_memories) > max_memories:
@@ -1704,7 +1702,6 @@ class MemoryManager:
# 第一步:向量相似度预筛选,找到潜在关联对
candidate_pairs = []
import numpy as np
for i, memory in enumerate(recent_memories):
# 获取主题节点
@@ -1733,7 +1730,7 @@ class MemoryManager:
)
# 使用配置的预筛选阈值
pre_filter_threshold = getattr(self.config, 'consolidation_linking_pre_filter_threshold', 0.7)
pre_filter_threshold = getattr(self.config, "consolidation_linking_pre_filter_threshold", 0.7)
if similarity >= pre_filter_threshold:
candidate_pairs.append((memory, other_memory, similarity))
@@ -1747,7 +1744,7 @@ class MemoryManager:
return result
# 第二步批量LLM分析使用配置的最大候选对数
max_pairs_for_llm = getattr(self.config, 'consolidation_linking_max_pairs_for_llm', 5)
max_pairs_for_llm = getattr(self.config, "consolidation_linking_max_pairs_for_llm", 5)
if len(candidate_pairs) <= max_pairs_for_llm:
link_relations = await self._batch_analyze_memory_relations(candidate_pairs)
result["llm_calls"] = 1
@@ -1810,8 +1807,8 @@ class MemoryManager:
async def _batch_analyze_memory_relations(
self,
candidate_pairs: List[Tuple[Memory, Memory, float]]
) -> List[Dict[str, Any]]:
candidate_pairs: list[tuple[Memory, Memory, float]]
) -> list[dict[str, Any]]:
"""
批量分析记忆关系优化LLM调用
@@ -1822,8 +1819,8 @@ class MemoryManager:
关系分析结果列表
"""
try:
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
llm = LLMRequest(
model_set=model_config.model_task_config.utils_small,
@@ -1843,7 +1840,7 @@ class MemoryManager:
"""
# 构建批量分析提示词(使用配置的置信度阈值)
min_confidence = getattr(self.config, 'consolidation_linking_min_confidence', 0.7)
min_confidence = getattr(self.config, "consolidation_linking_min_confidence", 0.7)
prompt = f"""你是记忆关系分析专家。请批量分析以下候选记忆对之间的关系。
@@ -1885,8 +1882,8 @@ class MemoryManager:
请分析并输出JSON结果"""
# 调用LLM使用配置的参数
llm_temperature = getattr(self.config, 'consolidation_linking_llm_temperature', 0.2)
llm_max_tokens = getattr(self.config, 'consolidation_linking_llm_max_tokens', 1500)
llm_temperature = getattr(self.config, "consolidation_linking_llm_temperature", 0.2)
llm_max_tokens = getattr(self.config, "consolidation_linking_llm_max_tokens", 1500)
response, _ = await llm.generate_response_async(
prompt,
@@ -1899,7 +1896,7 @@ class MemoryManager:
import re
# 提取JSON
json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
if json_match:
json_str = json_match.group(1)
else:
@@ -1910,7 +1907,7 @@ class MemoryManager:
except json.JSONDecodeError:
logger.warning(f"LLM返回格式错误尝试修复: {response[:200]}")
# 尝试简单修复
json_str = re.sub(r'[\r\n\t]', '', json_str)
json_str = re.sub(r"[\r\n\t]", "", json_str)
analysis_results = json.loads(json_str)
# 转换为结果格式

View File

@@ -7,7 +7,6 @@
from __future__ import annotations
from pathlib import Path
from typing import Optional
from src.common.logger import get_logger
from src.memory_graph.manager import MemoryManager
@@ -15,13 +14,13 @@ from src.memory_graph.manager import MemoryManager
logger = get_logger(__name__)
# 全局 MemoryManager 实例
_memory_manager: Optional[MemoryManager] = None
_memory_manager: MemoryManager | None = None
_initialized: bool = False
async def initialize_memory_manager(
data_dir: Optional[Path | str] = None,
) -> Optional[MemoryManager]:
data_dir: Path | str | None = None,
) -> MemoryManager | None:
"""
初始化全局 MemoryManager
@@ -43,7 +42,7 @@ async def initialize_memory_manager(
from src.config.config import global_config
# 检查是否启用
if not global_config.memory or not getattr(global_config.memory, 'enable', False):
if not global_config.memory or not getattr(global_config.memory, "enable", False):
logger.info("记忆图系统已在配置中禁用")
_initialized = False
_memory_manager = None
@@ -51,7 +50,7 @@ async def initialize_memory_manager(
# 处理数据目录
if data_dir is None:
data_dir = getattr(global_config.memory, 'data_dir', 'data/memory_graph')
data_dir = getattr(global_config.memory, "data_dir", "data/memory_graph")
if isinstance(data_dir, str):
data_dir = Path(data_dir)
@@ -72,7 +71,7 @@ async def initialize_memory_manager(
raise
def get_memory_manager() -> Optional[MemoryManager]:
def get_memory_manager() -> MemoryManager | None:
"""
获取全局 MemoryManager 实例

View File

@@ -10,7 +10,7 @@ import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any
import numpy as np
@@ -60,8 +60,8 @@ class MemoryNode:
id: str # 节点唯一ID
content: str # 节点内容(如:"我"、"吃饭"、"白米饭"
node_type: NodeType # 节点类型
embedding: Optional[np.ndarray] = None # 语义向量(仅主题/客体需要)
metadata: Dict[str, Any] = field(default_factory=dict) # 扩展元数据
embedding: np.ndarray | None = None # 语义向量(仅主题/客体需要)
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
created_at: datetime = field(default_factory=datetime.now)
def __post_init__(self):
@@ -69,7 +69,7 @@ class MemoryNode:
if not self.id:
self.id = str(uuid.uuid4())
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为字典(用于序列化)"""
return {
"id": self.id,
@@ -81,7 +81,7 @@ class MemoryNode:
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> MemoryNode:
def from_dict(cls, data: dict[str, Any]) -> MemoryNode:
"""从字典创建节点"""
embedding = None
if data.get("embedding") is not None:
@@ -114,7 +114,7 @@ class MemoryEdge:
relation: str # 关系名称(如:"是"、"做"、"时间"、"因为"
edge_type: EdgeType # 边类型
importance: float = 0.5 # 重要性 [0-1]
metadata: Dict[str, Any] = field(default_factory=dict) # 扩展元数据
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
created_at: datetime = field(default_factory=datetime.now)
def __post_init__(self):
@@ -124,7 +124,7 @@ class MemoryEdge:
# 确保重要性在有效范围内
self.importance = max(0.0, min(1.0, self.importance))
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为字典(用于序列化)"""
return {
"id": self.id,
@@ -138,7 +138,7 @@ class MemoryEdge:
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> MemoryEdge:
def from_dict(cls, data: dict[str, Any]) -> MemoryEdge:
"""从字典创建边"""
return cls(
id=data["id"],
@@ -162,8 +162,8 @@ class Memory:
id: str # 记忆唯一ID
subject_id: str # 主体节点ID
memory_type: MemoryType # 记忆类型
nodes: List[MemoryNode] # 该记忆包含的所有节点
edges: List[MemoryEdge] # 该记忆包含的所有边
nodes: list[MemoryNode] # 该记忆包含的所有节点
edges: list[MemoryEdge] # 该记忆包含的所有边
importance: float = 0.5 # 整体重要性 [0-1]
activation: float = 0.0 # 激活度 [0-1],用于记忆整合和遗忘
status: MemoryStatus = MemoryStatus.STAGED # 记忆状态
@@ -171,7 +171,7 @@ class Memory:
last_accessed: datetime = field(default_factory=datetime.now) # 最后访问时间
access_count: int = 0 # 访问次数
decay_factor: float = 1.0 # 衰减因子(随时间变化)
metadata: Dict[str, Any] = field(default_factory=dict) # 扩展元数据
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
def __post_init__(self):
"""后初始化处理"""
@@ -181,7 +181,7 @@ class Memory:
self.importance = max(0.0, min(1.0, self.importance))
self.activation = max(0.0, min(1.0, self.activation))
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为字典(用于序列化)"""
return {
"id": self.id,
@@ -200,7 +200,7 @@ class Memory:
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Memory:
def from_dict(cls, data: dict[str, Any]) -> Memory:
"""从字典创建记忆"""
return cls(
id=data["id"],
@@ -223,14 +223,14 @@ class Memory:
self.last_accessed = datetime.now()
self.access_count += 1
def get_node_by_id(self, node_id: str) -> Optional[MemoryNode]:
def get_node_by_id(self, node_id: str) -> MemoryNode | None:
"""根据ID获取节点"""
for node in self.nodes:
if node.id == node_id:
return node
return None
def get_subject_node(self) -> Optional[MemoryNode]:
def get_subject_node(self) -> MemoryNode | None:
"""获取主体节点"""
return self.get_node_by_id(self.subject_id)
@@ -274,10 +274,10 @@ class StagedMemory:
memory: Memory # 原始记忆对象
status: MemoryStatus = MemoryStatus.STAGED # 状态
created_at: datetime = field(default_factory=datetime.now)
consolidated_at: Optional[datetime] = None # 整理时间
merge_history: List[str] = field(default_factory=list) # 被合并的节点ID列表
consolidated_at: datetime | None = None # 整理时间
merge_history: list[str] = field(default_factory=list) # 被合并的节点ID列表
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为字典"""
return {
"memory": self.memory.to_dict(),
@@ -288,7 +288,7 @@ class StagedMemory:
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> StagedMemory:
def from_dict(cls, data: dict[str, Any]) -> StagedMemory:
"""从字典创建临时记忆"""
return cls(
memory=Memory.from_dict(data["memory"]),
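
The models above are dataclasses whose __post_init__ fills in UUIDs and clamps value ranges, paired with to_dict/from_dict for serialization. A stripped-down round-trip sketch under those assumptions (not the full MemoryNode):

from __future__ import annotations

import uuid
from dataclasses import dataclass, field
from typing import Any

@dataclass
class Node:
    content: str
    id: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        if not self.id:
            self.id = str(uuid.uuid4())

    def to_dict(self) -> dict[str, Any]:
        return {"id": self.id, "content": self.content, "metadata": self.metadata}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Node:
        return cls(content=data["content"], id=data["id"], metadata=data.get("metadata", {}))

node = Node(content="example")
assert Node.from_dict(node.to_dict()) == node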

View File

@@ -58,7 +58,7 @@ class CreateMemoryTool(BaseTool):
("memory_type", ToolParamType.STRING, "记忆类型。【事件】=有明确时间点的动作(昨天吃饭、明天开会)【事实】=稳定状态(职业是程序员、住在北京)【观点】=主观看法(喜欢/讨厌/认为)【关系】=人际关系(朋友、同事)", True, ["事件", "事实", "关系", "观点"]),
("topic", ToolParamType.STRING, "记忆的核心内容(做什么/是什么状态/什么关系)。必须明确、具体,包含关键动词或状态词", True, None),
("object", ToolParamType.STRING, "记忆涉及的对象或目标。如果topic已经很完整可以不填如果有明确对象建议填写", False, None),
("attributes", ToolParamType.STRING, "详细属性JSON格式字符串。强烈建议包含时间具体到日期和小时分钟、地点、状态、原因等上下文信息。例{\"时间\":\"2025-11-06 12:00\",\"地点\":\"公司\",\"状态\":\"进行中\",\"原因\":\"项目需要\"}", False, None),
("attributes", ToolParamType.STRING, '详细属性JSON格式字符串。强烈建议包含时间具体到日期和小时分钟、地点、状态、原因等上下文信息。例{"时间":"2025-11-06 12:00","地点":"公司","状态":"进行中","原因":"项目需要"}', False, None),
("importance", ToolParamType.FLOAT, "重要性评分 0.0-1.0。参考日常琐事0.3-0.4一般对话0.5-0.6重要信息0.7-0.8核心记忆0.9-1.0。不确定时用0.5", False, None),
]
@@ -124,7 +124,7 @@ class CreateMemoryTool(BaseTool):
logger.error(f"[CreateMemoryTool] 执行失败: {e}", exc_info=True)
return {
"name": self.name,
"content": f"创建记忆时出错: {str(e)}"
"content": f"创建记忆时出错: {e!s}"
}
@@ -184,7 +184,7 @@ class LinkMemoriesTool(BaseTool):
logger.error(f"[LinkMemoriesTool] 执行失败: {e}", exc_info=True)
return {
"name": self.name,
"content": f"关联记忆时出错: {str(e)}"
"content": f"关联记忆时出错: {e!s}"
}
@@ -254,5 +254,5 @@ class SearchMemoriesTool(BaseTool):
logger.error(f"[SearchMemoriesTool] 执行失败: {e}", exc_info=True)
return {
"name": self.name,
"content": f"搜索记忆时出错: {str(e)}"
"content": f"搜索记忆时出错: {e!s}"
}
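
The {str(e)} to {e!s} rewrites use the f-string conversion flag instead of an explicit str() call (the explicit-conversion form linters such as Ruff prefer); the rendered text is identical:

e = ValueError("boom")
assert f"创建记忆时出错: {str(e)}" == f"创建记忆时出错: {e!s}"  # identical output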

View File

@@ -5,4 +5,4 @@
from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.vector_store import VectorStore
__all__ = ["VectorStore", "GraphStore"]
__all__ = ["GraphStore", "VectorStore"]

View File

@@ -4,12 +4,10 @@
from __future__ import annotations
from typing import Dict, List, Optional, Set, Tuple
import networkx as nx
from src.common.logger import get_logger
from src.memory_graph.models import Memory, MemoryEdge, MemoryNode
from src.memory_graph.models import Memory, MemoryEdge
logger = get_logger(__name__)
@@ -31,10 +29,10 @@ class GraphStore:
self.graph = nx.DiGraph()
# 索引记忆ID -> 记忆对象
self.memory_index: Dict[str, Memory] = {}
self.memory_index: dict[str, Memory] = {}
# 索引节点ID -> 所属记忆ID集合
self.node_to_memories: Dict[str, Set[str]] = {}
self.node_to_memories: dict[str, set[str]] = {}
logger.info("初始化图存储")
@@ -84,7 +82,7 @@ class GraphStore:
logger.error(f"添加记忆失败: {e}", exc_info=True)
raise
def get_memory_by_id(self, memory_id: str) -> Optional[Memory]:
def get_memory_by_id(self, memory_id: str) -> Memory | None:
"""
根据ID获取记忆
@@ -96,7 +94,7 @@ class GraphStore:
"""
return self.memory_index.get(memory_id)
def get_all_memories(self) -> List[Memory]:
def get_all_memories(self) -> list[Memory]:
"""
获取所有记忆
@@ -105,7 +103,7 @@ class GraphStore:
"""
return list(self.memory_index.values())
def get_memories_by_node(self, node_id: str) -> List[Memory]:
def get_memories_by_node(self, node_id: str) -> list[Memory]:
"""
获取包含指定节点的所有记忆
@@ -121,7 +119,7 @@ class GraphStore:
memory_ids = self.node_to_memories[node_id]
return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index]
def get_edges_from_node(self, node_id: str, relation_types: Optional[List[str]] = None) -> List[Dict]:
def get_edges_from_node(self, node_id: str, relation_types: list[str] | None = None) -> list[dict]:
"""
获取从指定节点出发的所有边
@@ -155,8 +153,8 @@ class GraphStore:
return edges
def get_neighbors(
self, node_id: str, direction: str = "out", relation_types: Optional[List[str]] = None
) -> List[Tuple[str, Dict]]:
self, node_id: str, direction: str = "out", relation_types: list[str] | None = None
) -> list[tuple[str, dict]]:
"""
获取节点的邻居节点
@@ -187,7 +185,7 @@ class GraphStore:
return neighbors
def find_path(self, source_id: str, target_id: str, max_length: Optional[int] = None) -> Optional[List[str]]:
def find_path(self, source_id: str, target_id: str, max_length: int | None = None) -> list[str] | None:
"""
查找两个节点之间的最短路径
@@ -220,10 +218,10 @@ class GraphStore:
def bfs_expand(
self,
start_nodes: List[str],
start_nodes: list[str],
depth: int = 1,
relation_types: Optional[List[str]] = None,
) -> Set[str]:
relation_types: list[str] | None = None,
) -> set[str]:
"""
从起始节点进行广度优先搜索扩展
@@ -256,7 +254,7 @@ class GraphStore:
return visited
def get_subgraph(self, node_ids: List[str]) -> nx.DiGraph:
def get_subgraph(self, node_ids: list[str]) -> nx.DiGraph:
"""
获取包含指定节点的子图
@@ -308,7 +306,7 @@ class GraphStore:
logger.error(f"合并节点失败: {e}", exc_info=True)
raise
def get_node_degree(self, node_id: str) -> Tuple[int, int]:
def get_node_degree(self, node_id: str) -> tuple[int, int]:
"""
获取节点的度数
@@ -323,7 +321,7 @@ class GraphStore:
return (self.graph.in_degree(node_id), self.graph.out_degree(node_id))
def get_statistics(self) -> Dict[str, int]:
def get_statistics(self) -> dict[str, int]:
"""获取图的统计信息"""
return {
"total_nodes": self.graph.number_of_nodes(),
@@ -332,7 +330,7 @@ class GraphStore:
"connected_components": nx.number_weakly_connected_components(self.graph),
}
def to_dict(self) -> Dict:
def to_dict(self) -> dict:
"""
将图转换为字典(用于持久化)
@@ -356,7 +354,7 @@ class GraphStore:
}
@classmethod
def from_dict(cls, data: Dict) -> GraphStore:
def from_dict(cls, data: dict) -> GraphStore:
"""
从字典加载图
@@ -406,7 +404,6 @@ class GraphStore:
规则:对于图中每条边(u, v, data),会尝试将该边注入到所有包含 u 或 v 的记忆中(避免遗漏跨记忆边)。
已存在的边(通过 edge.id 检查)将不会重复添加。
"""
from src.memory_graph.models import MemoryEdge
# 构建快速查重索引memory_id -> set(edge_id)
existing_edges = {mid: {e.id for e in mem.edges} for mid, mem in self.memory_index.items()}
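
GraphStore keeps the networkx DiGraph as the source of truth plus two dict indexes for O(1) memory lookups, and bfs_expand walks outward from a set of start nodes. A simplified sketch of that expansion (direction and relation-type filtering from the real method omitted):

import networkx as nx

def bfs_expand(graph: nx.DiGraph, start_nodes: list[str], depth: int = 1) -> set[str]:
    visited = set(start_nodes)
    frontier = list(start_nodes)
    for _ in range(depth):
        next_frontier: list[str] = []
        for node_id in frontier:
            for neighbor in graph.successors(node_id):  # "out" direction only
                if neighbor not in visited:
                    visited.add(neighbor)
                    next_frontier.append(neighbor)
        frontier = next_frontier
    return visited

g = nx.DiGraph([("a", "b"), ("b", "c")])
assert bfs_expand(g, ["a"], depth=2) == {"a", "b", "c"}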

View File

@@ -8,14 +8,12 @@ import asyncio
import json
from datetime import datetime
from pathlib import Path
from typing import Optional
import orjson
from src.common.logger import get_logger
from src.memory_graph.models import Memory, StagedMemory
from src.memory_graph.models import StagedMemory
from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.vector_store import VectorStore
logger = get_logger(__name__)
@@ -55,7 +53,7 @@ class PersistenceManager:
self.backup_dir.mkdir(parents=True, exist_ok=True)
self.auto_save_interval = auto_save_interval
self._auto_save_task: Optional[asyncio.Task] = None
self._auto_save_task: asyncio.Task | None = None
self._running = False
logger.info(f"初始化持久化管理器: data_dir={data_dir}")
@@ -95,7 +93,7 @@ class PersistenceManager:
logger.error(f"保存图数据失败: {e}", exc_info=True)
raise
async def load_graph_store(self) -> Optional[GraphStore]:
async def load_graph_store(self) -> GraphStore | None:
"""
从文件加载图存储
@@ -179,7 +177,7 @@ class PersistenceManager:
logger.error(f"加载临时记忆失败: {e}", exc_info=True)
return []
async def create_backup(self) -> Optional[Path]:
async def create_backup(self) -> Path | None:
"""
创建当前数据的备份
@@ -208,7 +206,7 @@ class PersistenceManager:
logger.error(f"创建备份失败: {e}", exc_info=True)
return None
async def _load_from_backup(self) -> Optional[GraphStore]:
async def _load_from_backup(self) -> GraphStore | None:
"""从最新的备份加载数据"""
try:
# 查找最新的备份文件
@@ -254,7 +252,7 @@ class PersistenceManager:
async def start_auto_save(
self,
graph_store: GraphStore,
staged_memories_getter: callable = None,
staged_memories_getter: callable | None = None,
) -> None:
"""
启动自动保存任务
@@ -334,7 +332,7 @@ class PersistenceManager:
logger.error(f"导出图数据失败: {e}", exc_info=True)
raise
async def import_from_json(self, input_file: Path) -> Optional[GraphStore]:
async def import_from_json(self, input_file: Path) -> GraphStore | None:
"""
从 JSON 文件导入图数据

View File

@@ -4,9 +4,8 @@
from __future__ import annotations
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any
import numpy as np
@@ -29,8 +28,8 @@ class VectorStore:
def __init__(
self,
collection_name: str = "memory_nodes",
data_dir: Optional[Path] = None,
embedding_function: Optional[Any] = None,
data_dir: Path | None = None,
embedding_function: Any | None = None,
):
"""
初始化向量存储
@@ -103,7 +102,7 @@ class VectorStore:
for key, value in node.metadata.items():
if isinstance(value, (list, dict)):
import orjson
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, (str, int, float, bool)) or value is None:
metadata[key] = value
else:
@@ -122,7 +121,7 @@ class VectorStore:
logger.error(f"添加节点失败: {e}", exc_info=True)
raise
async def add_nodes_batch(self, nodes: List[MemoryNode]) -> None:
async def add_nodes_batch(self, nodes: list[MemoryNode]) -> None:
"""
批量添加节点
@@ -151,7 +150,7 @@ class VectorStore:
}
for key, value in n.metadata.items():
if isinstance(value, (list, dict)):
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, (str, int, float, bool)) or value is None:
metadata[key] = value # type: ignore
else:
@@ -175,9 +174,9 @@ class VectorStore:
self,
query_embedding: np.ndarray,
limit: int = 10,
node_types: Optional[List[NodeType]] = None,
node_types: list[NodeType] | None = None,
min_similarity: float = 0.0,
) -> List[Tuple[str, float, Dict[str, Any]]]:
) -> list[tuple[str, float, dict[str, Any]]]:
"""
搜索相似节点
@@ -226,10 +225,10 @@ class VectorStore:
# 解析 JSON 字符串回列表/字典
for key, value in list(metadata.items()):
if isinstance(value, str) and (value.startswith('[') or value.startswith('{')):
if isinstance(value, str) and (value.startswith("[") or value.startswith("{")):
try:
metadata[key] = orjson.loads(value)
except:
except Exception:
pass # 保持原值
similar_nodes.append((node_id, similarity, metadata))
@@ -243,13 +242,13 @@ class VectorStore:
async def search_with_multiple_queries(
self,
query_embeddings: List[np.ndarray],
query_weights: Optional[List[float]] = None,
query_embeddings: list[np.ndarray],
query_weights: list[float] | None = None,
limit: int = 10,
node_types: Optional[List[NodeType]] = None,
node_types: list[NodeType] | None = None,
min_similarity: float = 0.0,
fusion_strategy: str = "weighted_max",
) -> List[Tuple[str, float, Dict[str, Any]]]:
) -> list[tuple[str, float, dict[str, Any]]]:
"""
多查询融合搜索
@@ -287,7 +286,7 @@ class VectorStore:
try:
# 1. 对每个查询执行搜索
all_results: Dict[str, Dict[str, Any]] = {} # node_id -> {scores, metadata}
all_results: dict[str, dict[str, Any]] = {} # node_id -> {scores, metadata}
for i, (query_emb, weight) in enumerate(zip(query_embeddings, query_weights)):
# 搜索更多结果以提高融合质量
@@ -356,7 +355,7 @@ class VectorStore:
logger.error(f"多查询融合搜索失败: {e}", exc_info=True)
raise
async def get_node_by_id(self, node_id: str) -> Optional[Dict[str, Any]]:
async def get_node_by_id(self, node_id: str) -> dict[str, Any] | None:
"""
根据ID获取节点元数据

View File

@@ -4,12 +4,12 @@ LLM 工具接口:定义记忆系统的工具 schema 和执行逻辑
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
from typing import Any
from src.common.logger import get_logger
from src.memory_graph.core.builder import MemoryBuilder
from src.memory_graph.core.extractor import MemoryExtractor
from src.memory_graph.models import Memory, MemoryStatus
from src.memory_graph.models import Memory
from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.persistence import PersistenceManager
from src.memory_graph.storage.vector_store import VectorStore
@@ -33,7 +33,7 @@ class MemoryTools:
vector_store: VectorStore,
graph_store: GraphStore,
persistence_manager: PersistenceManager,
embedding_generator: Optional[EmbeddingGenerator] = None,
embedding_generator: EmbeddingGenerator | None = None,
max_expand_depth: int = 1,
expand_semantic_threshold: float = 0.3,
):
@@ -72,7 +72,7 @@ class MemoryTools:
self._initialized = True
@staticmethod
def get_create_memory_schema() -> Dict[str, Any]:
def get_create_memory_schema() -> dict[str, Any]:
"""
获取 create_memory 工具的 JSON schema
@@ -183,7 +183,7 @@ class MemoryTools:
}
@staticmethod
def get_link_memories_schema() -> Dict[str, Any]:
def get_link_memories_schema() -> dict[str, Any]:
"""
获取 link_memories 工具的 JSON schema
@@ -239,7 +239,7 @@ class MemoryTools:
}
@staticmethod
def get_search_memories_schema() -> Dict[str, Any]:
def get_search_memories_schema() -> dict[str, Any]:
"""
获取 search_memories 工具的 JSON schema
@@ -307,7 +307,7 @@ class MemoryTools:
},
}
async def create_memory(self, **params) -> Dict[str, Any]:
async def create_memory(self, **params) -> dict[str, Any]:
"""
执行 create_memory 工具
@@ -353,7 +353,7 @@ class MemoryTools:
"message": "记忆创建失败",
}
async def link_memories(self, **params) -> Dict[str, Any]:
async def link_memories(self, **params) -> dict[str, Any]:
"""
执行 link_memories 工具
@@ -433,7 +433,7 @@ class MemoryTools:
"message": "记忆关联失败",
}
async def search_memories(self, **params) -> Dict[str, Any]:
async def search_memories(self, **params) -> dict[str, Any]:
"""
执行 search_memories 工具
@@ -486,7 +486,7 @@ class MemoryTools:
import orjson
try:
ids = orjson.loads(ids)
except:
except Exception:
ids = [ids]
if isinstance(ids, list):
for mem_id in ids:
@@ -526,8 +526,7 @@ class MemoryTools:
# )
# 合并扩展结果
for mem_id, score in expanded_results:
expanded_memory_scores[mem_id] = score
expanded_memory_scores.update(dict(expanded_results))
logger.info(f"图扩展完成: 新增{len(expanded_memory_scores)}个相关记忆")
@@ -624,16 +623,16 @@ class MemoryTools:
}
async def _generate_multi_queries_simple(
self, query: str, context: Optional[Dict[str, Any]] = None
) -> List[Tuple[str, float]]:
self, query: str, context: dict[str, Any] | None = None
) -> list[tuple[str, float]]:
"""
简化版多查询生成(直接在 Tools 层实现,避免循环依赖)
让小模型直接生成3-5个不同角度的查询语句。
"""
try:
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
llm = LLMRequest(
model_set=model_config.model_task_config.utils_small,
@@ -648,10 +647,10 @@ class MemoryTools:
# 处理聊天历史提取最近5条左右的对话
recent_chat = ""
if chat_history:
lines = chat_history.strip().split('\n')
lines = chat_history.strip().split("\n")
# 取最近5条消息
recent_lines = lines[-5:] if len(lines) > 5 else lines
recent_chat = '\n'.join(recent_lines)
recent_chat = "\n".join(recent_lines)
prompt = f"""基于聊天上下文为查询生成3-5个不同角度的搜索语句JSON格式
@@ -686,9 +685,11 @@ class MemoryTools:
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250)
import orjson, re
response = re.sub(r'```json\s*', '', response)
response = re.sub(r'```\s*$', '', response).strip()
import re
import orjson
response = re.sub(r"```json\s*", "", response)
response = re.sub(r"```\s*$", "", response).strip()
data = orjson.loads(response)
queries = data.get("queries", [])
@@ -707,7 +708,7 @@ class MemoryTools:
async def _single_query_search(
self, query: str, top_k: int
) -> List[Tuple[str, float, Dict[str, Any]]]:
) -> list[tuple[str, float, dict[str, Any]]]:
"""
传统的单查询搜索
@@ -735,8 +736,8 @@ class MemoryTools:
return similar_nodes
async def _multi_query_search(
self, query: str, top_k: int, context: Optional[Dict[str, Any]] = None
) -> List[Tuple[str, float, Dict[str, Any]]]:
self, query: str, top_k: int, context: dict[str, Any] | None = None
) -> list[tuple[str, float, dict[str, Any]]]:
"""
多查询策略搜索(简化版)
@@ -800,7 +801,7 @@ class MemoryTools:
if node.embedding is not None:
await self.vector_store.add_node(node)
async def _find_memory_by_description(self, description: str) -> Optional[Memory]:
async def _find_memory_by_description(self, description: str) -> Memory | None:
"""
通过描述查找记忆
@@ -827,7 +828,7 @@ class MemoryTools:
return None
# 获取最相似节点关联的记忆
node_id, similarity, metadata = similar_nodes[0]
_node_id, _similarity, metadata = similar_nodes[0]
if "memory_ids" not in metadata or not metadata["memory_ids"]:
return None
@@ -862,12 +863,12 @@ class MemoryTools:
async def _expand_with_semantic_filter(
self,
initial_memory_ids: List[str],
initial_memory_ids: list[str],
query_embedding,
max_depth: int = 2,
semantic_threshold: float = 0.5,
max_expanded: int = 20
) -> List[Tuple[str, float]]:
) -> list[tuple[str, float]]:
"""
从初始记忆集合出发,沿图结构扩展,并用语义相似度过滤
@@ -885,10 +886,9 @@ class MemoryTools:
return []
try:
import numpy as np
visited_memories = set(initial_memory_ids)
expanded_memories: Dict[str, float] = {}
expanded_memories: dict[str, float] = {}
current_level = initial_memory_ids
@@ -906,7 +906,7 @@ class MemoryTools:
try:
neighbors = list(self.graph_store.graph.neighbors(node.id))
except:
except Exception:
continue
for neighbor_id in neighbors:
@@ -932,7 +932,7 @@ class MemoryTools:
try:
edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id)
edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
except:
except Exception:
edge_importance = 0.5
# 综合评分
@@ -952,7 +952,7 @@ class MemoryTools:
import orjson
try:
neighbor_memory_ids = orjson.loads(neighbor_memory_ids)
except:
except Exception:
neighbor_memory_ids = [neighbor_memory_ids]
for neighbor_mem_id in neighbor_memory_ids:
@@ -1010,7 +1010,7 @@ class MemoryTools:
return 0.0
@staticmethod
def get_all_tool_schemas() -> List[Dict[str, Any]]:
def get_all_tool_schemas() -> list[dict[str, Any]]:
"""
获取所有工具的 schema
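
One non-mechanical simplification in this file: the loop that copied (memory_id, score) pairs into expanded_memory_scores one at a time is folded into a single dict.update call with equivalent behavior:

expanded_results = [("mem-1", 0.8), ("mem-2", 0.6)]
expanded_memory_scores: dict[str, float] = {"mem-0": 0.9}

expanded_memory_scores.update(dict(expanded_results))  # same effect as the removed per-item loop
assert expanded_memory_scores == {"mem-0": 0.9, "mem-1": 0.8, "mem-2": 0.6}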

View File

@@ -5,4 +5,4 @@
from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator
from src.memory_graph.utils.time_parser import TimeParser
__all__ = ["TimeParser", "EmbeddingGenerator", "get_embedding_generator"]
__all__ = ["EmbeddingGenerator", "TimeParser", "get_embedding_generator"]

View File

@@ -5,8 +5,6 @@
from __future__ import annotations
import asyncio
from functools import lru_cache
from typing import List, Optional
import numpy as np
@@ -142,7 +140,7 @@ class EmbeddingGenerator:
dim = self._get_dimension()
return np.random.rand(dim).astype(np.float32)
async def _generate_with_api(self, text: str) -> Optional[np.ndarray]:
async def _generate_with_api(self, text: str) -> np.ndarray | None:
"""使用 API 生成嵌入"""
try:
# 初始化 API
@@ -166,7 +164,7 @@ class EmbeddingGenerator:
logger.debug(f"API 嵌入生成失败: {e}")
return None
async def _generate_with_local_model(self, text: str) -> Optional[np.ndarray]:
async def _generate_with_local_model(self, text: str) -> np.ndarray | None:
"""使用本地模型生成嵌入"""
try:
# 加载本地模型
@@ -204,13 +202,13 @@ class EmbeddingGenerator:
if self._local_model_loaded and self._local_model:
try:
return self._local_model.get_sentence_embedding_dimension()
except:
except Exception:
pass
# 默认 384sentence-transformers 常用维度)
return 384
async def generate_batch(self, texts: List[str]) -> List[np.ndarray]:
async def generate_batch(self, texts: list[str]) -> list[np.ndarray]:
"""
批量生成嵌入向量
@@ -251,7 +249,7 @@ class EmbeddingGenerator:
dim = self._get_dimension()
return [np.random.rand(dim).astype(np.float32) for _ in texts]
async def _generate_batch_with_api(self, texts: List[str]) -> Optional[List[np.ndarray]]:
async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray] | None:
"""使用 API 批量生成"""
try:
# 对于大多数 API批量调用就是多次单独调用
@@ -273,7 +271,7 @@ class EmbeddingGenerator:
# 全局单例
_global_generator: Optional[EmbeddingGenerator] = None
_global_generator: EmbeddingGenerator | None = None
def get_embedding_generator(

View File

@@ -5,10 +5,9 @@
"""
import logging
from typing import Optional, List, Dict, Any
from datetime import datetime
from src.memory_graph.models import Memory, MemoryNode, NodeType, EdgeType, MemoryType
from src.memory_graph.models import EdgeType, Memory, MemoryType, NodeType
logger = logging.getLogger(__name__)
@@ -42,11 +41,9 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
# 2. 查找主题节点(谓语/动作)
topic_node = None
memory_type_relation = None
for edge in memory.edges:
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id:
topic_node = memory.get_node_by_id(edge.target_id)
memory_type_relation = edge.relation
break
if not topic_node:
@@ -65,7 +62,7 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
break
# 4. 收集属性节点
attributes: Dict[str, str] = {}
attributes: dict[str, str] = {}
for edge in memory.edges:
if edge.edge_type == EdgeType.ATTRIBUTE:
# 查找属性节点和值节点
@@ -143,8 +140,8 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
def format_memories_for_prompt(
memories: List[Memory],
max_count: Optional[int] = None,
memories: list[Memory],
max_count: int | None = None,
include_metadata: bool = False,
group_by_type: bool = False
) -> str:
@@ -169,7 +166,7 @@ def format_memories_for_prompt(
# 按类型分组
if group_by_type:
type_groups: Dict[MemoryType, List[Memory]] = {}
type_groups: dict[MemoryType, list[Memory]] = {}
for memory in memories:
if memory.memory_type not in type_groups:
type_groups[memory.memory_type] = []
@@ -250,7 +247,7 @@ def get_memory_type_label(memory_type: str) -> str:
return type_mapping.get(memory_type_lower, "未知")
def _format_relative_time(timestamp: datetime) -> Optional[str]:
def _format_relative_time(timestamp: datetime) -> str | None:
"""
格式化相对时间(如"2天前""刚才"
@@ -316,8 +313,8 @@ def format_memory_summary(memory: Memory) -> str:
# 导出主要函数
__all__ = [
'format_memory_for_prompt',
'format_memories_for_prompt',
'get_memory_type_label',
'format_memory_summary',
"format_memories_for_prompt",
"format_memory_for_prompt",
"format_memory_summary",
"get_memory_type_label",
]

View File

@@ -14,7 +14,6 @@ from __future__ import annotations
import re
from datetime import datetime, timedelta
from typing import Optional, Tuple
from src.common.logger import get_logger
@@ -28,7 +27,7 @@ class TimeParser:
负责将自然语言时间表达转换为标准化的绝对时间
"""
def __init__(self, reference_time: Optional[datetime] = None):
def __init__(self, reference_time: datetime | None = None):
"""
初始化时间解析器
@@ -37,7 +36,7 @@ class TimeParser:
"""
self.reference_time = reference_time or datetime.now()
def parse(self, time_str: str) -> Optional[datetime]:
def parse(self, time_str: str) -> datetime | None:
"""
解析时间字符串
@@ -81,7 +80,7 @@ class TimeParser:
logger.warning(f"无法解析时间: '{time_str}',使用当前时间")
return self.reference_time
def _parse_relative_day(self, time_str: str) -> Optional[datetime]:
def _parse_relative_day(self, time_str: str) -> datetime | None:
"""
解析相对日期:今天、明天、昨天、前天、后天
"""
@@ -108,7 +107,7 @@ class TimeParser:
return None
def _parse_days_ago(self, time_str: str) -> Optional[datetime]:
def _parse_days_ago(self, time_str: str) -> datetime | None:
"""
解析 X天前/X天后、X周前/X周后、X个月前/X个月后
"""
@@ -172,7 +171,7 @@ class TimeParser:
return None
def _parse_hours_ago(self, time_str: str) -> Optional[datetime]:
def _parse_hours_ago(self, time_str: str) -> datetime | None:
"""
解析 X小时前/X小时后、X分钟前/X分钟后
"""
@@ -204,7 +203,7 @@ class TimeParser:
return None
def _parse_week_month_year(self, time_str: str) -> Optional[datetime]:
def _parse_week_month_year(self, time_str: str) -> datetime | None:
"""
解析:上周、上个月、去年、本周、本月、今年
"""
@@ -232,7 +231,7 @@ class TimeParser:
return None
def _parse_specific_date(self, time_str: str) -> Optional[datetime]:
def _parse_specific_date(self, time_str: str) -> datetime | None:
"""
解析具体日期:
- 2025-11-05
@@ -266,7 +265,7 @@ class TimeParser:
return None
def _parse_time_of_day(self, time_str: str) -> Optional[datetime]:
def _parse_time_of_day(self, time_str: str) -> datetime | None:
"""
解析一天中的时间:
- 早上、上午、中午、下午、晚上、深夜
@@ -290,7 +289,7 @@ class TimeParser:
}
# 先检查是否有具体时间点早上8点、下午3点
for period, default_hour in time_periods.items():
for period in time_periods.keys():
pattern = rf"{period}(\d{{1,2}})点?"
match = re.search(pattern, time_str)
if match:
@@ -314,7 +313,7 @@ class TimeParser:
return None
def _parse_combined_time(self, time_str: str) -> Optional[datetime]:
def _parse_combined_time(self, time_str: str) -> datetime | None:
"""
解析组合时间表达:今天下午、昨天晚上、明天早上
"""
@@ -461,7 +460,7 @@ class TimeParser:
return str(dt)
def parse_time_range(self, time_str: str) -> Tuple[Optional[datetime], Optional[datetime]]:
def parse_time_range(self, time_str: str) -> tuple[datetime | None, datetime | None]:
"""
解析时间范围最近一周、最近3天
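
TimeParser resolves each relative expression against self.reference_time through a cascade of regex rules like the ones renamed above. A self-contained sketch of one such rule, the N天前 (N days ago) pattern (illustrative; the real _parse_days_ago also covers 天后, 周, and 月 variants):

import re
from datetime import datetime, timedelta

def parse_days_ago(time_str: str, reference: datetime | None = None) -> datetime | None:
    reference = reference or datetime.now()
    match = re.search(r"(\d+)天前", time_str)
    if match:
        return reference - timedelta(days=int(match.group(1)))
    return None

assert parse_days_ago("3天前", datetime(2025, 11, 7)) == datetime(2025, 11, 4)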