fix: Fix code quality issues - correct exception handling and import statements
Co-authored-by: Windpicker-owo <221029311+Windpicker-owo@users.noreply.github.com>
@@ -6,10 +6,12 @@ from typing import ClassVar
 from src.common.logger import get_logger
 from src.plugin_system import BasePlugin, register_plugin
 from src.plugin_system.base.component_types import ComponentInfo, ToolInfo

 logger = get_logger("memory_graph_plugin")

+# 用于存储后台任务引用
+_background_tasks = set()


 @register_plugin
 class MemoryGraphPlugin(BasePlugin):
@@ -60,6 +62,7 @@ class MemoryGraphPlugin(BasePlugin):
         """插件卸载时的回调"""
         try:
             import asyncio
+
             from src.memory_graph.manager_singleton import shutdown_memory_manager

             logger.info(f"{self.log_prefix} 正在关闭记忆系统...")
@@ -68,7 +71,10 @@ class MemoryGraphPlugin(BasePlugin):
             loop = asyncio.get_event_loop()
             if loop.is_running():
                 # 如果循环正在运行,创建任务
-                asyncio.create_task(shutdown_memory_manager())
+                task = asyncio.create_task(shutdown_memory_manager())
+                # 存储引用以防止任务被垃圾回收
+                _background_tasks.add(task)
+                task.add_done_callback(_background_tasks.discard)
             else:
                 # 如果循环未运行,直接运行
                 loop.run_until_complete(shutdown_memory_manager())
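
Note: the hunk above replaces a fire-and-forget `asyncio.create_task` call with the documented pattern for keeping background tasks alive: the event loop holds only weak references to tasks, so a task with no other reference can be garbage-collected before it finishes. A minimal, self-contained sketch of the pattern (the `shutdown` coroutine is a stand-in for the real `shutdown_memory_manager`):

    import asyncio

    _background_tasks: set[asyncio.Task] = set()

    async def shutdown() -> None:
        # Stand-in for the real shutdown_memory_manager() coroutine.
        await asyncio.sleep(0)

    async def main() -> None:
        # Keep a strong reference so the task cannot be collected mid-flight.
        task = asyncio.create_task(shutdown())
        _background_tasks.add(task)
        # Drop the reference once the task completes.
        task.add_done_callback(_background_tasks.discard)
        await asyncio.sleep(0.1)

    asyncio.run(main())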
@@ -25,14 +25,13 @@ import asyncio
 import sys
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, List, Optional, Set, Tuple

 import numpy as np

 sys.path.insert(0, str(Path(__file__).parent.parent))

 from src.common.logger import get_logger
-from src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager, shutdown_memory_manager
+from src.memory_graph.manager_singleton import initialize_memory_manager, shutdown_memory_manager

 logger = get_logger(__name__)
@@ -65,7 +64,7 @@ class MemoryDeduplicator:
         self.stats["total_memories"] = len(self.manager.graph_store.get_all_memories())
         logger.info(f"✅ 记忆管理器初始化成功,共 {self.stats['total_memories']} 条记忆")

-    async def find_similar_pairs(self) -> List[Tuple[str, str, float]]:
+    async def find_similar_pairs(self) -> list[tuple[str, str, float]]:
         """
         查找所有相似的记忆对(通过向量相似度计算)
@@ -144,7 +143,7 @@ class MemoryDeduplicator:
             logger.error(f"计算余弦相似度失败: {e}")
             return 0.0

-    def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> Tuple[Optional[str], Optional[str]]:
+    def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> tuple[str | None, str | None]:
         """
         决定保留哪个记忆,删除哪个
@@ -197,7 +196,7 @@ class MemoryDeduplicator:
         keep_mem = self.manager.graph_store.get_memory_by_id(keep_id)
         remove_mem = self.manager.graph_store.get_memory_by_id(remove_id)

-        logger.info(f"")
+        logger.info("")
         logger.info(f"{'[预览]' if self.dry_run else '[执行]'} 去重相似记忆对 (相似度={similarity:.3f}):")
         logger.info(f" 保留: {keep_id}")
         logger.info(f" - 主题: {keep_mem.metadata.get('topic', 'N/A')}")
@@ -221,14 +220,14 @@ class MemoryDeduplicator:
             keep_mem.activation = min(1.0, keep_mem.activation + 0.05)

             # 累加访问次数
-            if hasattr(keep_mem, 'access_count') and hasattr(remove_mem, 'access_count'):
+            if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"):
                 keep_mem.access_count += remove_mem.access_count

             # 删除相似记忆
             await self.manager.delete_memory(remove_id)

             self.stats["duplicates_removed"] += 1
-            logger.info(f" ✅ 删除成功")
+            logger.info(" ✅ 删除成功")

             # 让出控制权
             await asyncio.sleep(0)
@@ -6,24 +6,24 @@

 from src.memory_graph.manager import MemoryManager
 from src.memory_graph.models import (
+    EdgeType,
     Memory,
     MemoryEdge,
     MemoryNode,
     MemoryStatus,
     MemoryType,
     NodeType,
-    EdgeType,
 )

 __all__ = [
-    "MemoryManager",
+    "EdgeType",
     "Memory",
-    "MemoryNode",
     "MemoryEdge",
+    "MemoryManager",
+    "MemoryNode",
+    "MemoryStatus",
     "MemoryType",
     "NodeType",
-    "EdgeType",
-    "MemoryStatus",
 ]

 __version__ = "0.1.0"
@@ -6,4 +6,4 @@ from src.memory_graph.core.builder import MemoryBuilder
 from src.memory_graph.core.extractor import MemoryExtractor
 from src.memory_graph.core.node_merger import NodeMerger

-__all__ = ["NodeMerger", "MemoryExtractor", "MemoryBuilder"]
+__all__ = ["MemoryBuilder", "MemoryExtractor", "NodeMerger"]
@@ -5,7 +5,7 @@
 from __future__ import annotations

 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any

 import numpy as np
@@ -16,7 +16,6 @@ from src.memory_graph.models import (
     MemoryEdge,
     MemoryNode,
-    MemoryStatus,
     MemoryType,
     NodeType,
 )
 from src.memory_graph.storage.graph_store import GraphStore
@@ -41,7 +40,7 @@ class MemoryBuilder:
         self,
         vector_store: VectorStore,
         graph_store: GraphStore,
-        embedding_generator: Optional[Any] = None,
+        embedding_generator: Any | None = None,
     ):
         """
         初始化记忆构建器
@@ -55,7 +54,7 @@ class MemoryBuilder:
         self.graph_store = graph_store
         self.embedding_generator = embedding_generator

-    async def build_memory(self, extracted_params: Dict[str, Any]) -> Memory:
+    async def build_memory(self, extracted_params: dict[str, Any]) -> Memory:
         """
         构建完整的记忆对象
@@ -97,7 +96,7 @@ class MemoryBuilder:
             edges.append(memory_type_edge)

             # 4. 如果有客体,创建客体节点并连接
-            if "object" in extracted_params and extracted_params["object"]:
+            if extracted_params.get("object"):
                 object_node = await self._create_object_node(
                     content=extracted_params["object"], memory_id=memory_id
                 )
@@ -258,11 +257,11 @@ class MemoryBuilder:

     async def _process_attributes(
         self,
-        attributes: Dict[str, Any],
+        attributes: dict[str, Any],
         parent_id: str,
         memory_id: str,
         importance: float,
-    ) -> tuple[List[MemoryNode], List[MemoryEdge]]:
+    ) -> tuple[list[MemoryNode], list[MemoryEdge]]:
         """
         处理属性,构建属性子图
@@ -341,7 +340,7 @@ class MemoryBuilder:

     async def _find_existing_node(
         self, content: str, node_type: NodeType
-    ) -> Optional[MemoryNode]:
+    ) -> MemoryNode | None:
         """
         查找已存在的完全匹配节点(用于主体和属性)
@@ -369,7 +368,7 @@ class MemoryBuilder:

     async def _find_similar_topic(
         self, content: str, embedding: np.ndarray
-    ) -> Optional[MemoryNode]:
+    ) -> MemoryNode | None:
         """
         查找相似的主题节点(基于语义相似度)
@@ -414,7 +413,7 @@ class MemoryBuilder:

     async def _find_similar_object(
         self, content: str, embedding: np.ndarray
-    ) -> Optional[MemoryNode]:
+    ) -> MemoryNode | None:
         """
         查找相似的客体节点(基于语义相似度)
@@ -525,7 +524,7 @@ class MemoryBuilder:
             logger.error(f"记忆关联失败: {e}", exc_info=True)
             raise RuntimeError(f"记忆关联失败: {e}")

-    def _find_topic_node(self, memory: Memory) -> Optional[MemoryNode]:
+    def _find_topic_node(self, memory: Memory) -> MemoryNode | None:
         """查找记忆中的主题节点"""
         for node in memory.nodes:
             if node.node_type == NodeType.TOPIC:
@@ -5,7 +5,7 @@
 from __future__ import annotations

 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any

 from src.common.logger import get_logger
 from src.memory_graph.models import MemoryType
@@ -25,7 +25,7 @@ class MemoryExtractor:
     4. 清洗和格式化数据
     """

-    def __init__(self, time_parser: Optional[TimeParser] = None):
+    def __init__(self, time_parser: TimeParser | None = None):
         """
         初始化记忆提取器
@@ -34,7 +34,7 @@ class MemoryExtractor:
         """
         self.time_parser = time_parser or TimeParser()

-    def extract_from_tool_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+    def extract_from_tool_params(self, params: dict[str, Any]) -> dict[str, Any]:
         """
         从工具参数中提取记忆元素
@@ -64,11 +64,11 @@ class MemoryExtractor:
             }

             # 3. 提取可选的客体
-            if "object" in params and params["object"]:
+            if params.get("object"):
                 extracted["object"] = self._clean_text(params["object"])

             # 4. 提取和标准化属性
-            if "attributes" in params and params["attributes"]:
+            if params.get("attributes"):
                 extracted["attributes"] = self._process_attributes(params["attributes"])
             else:
                 extracted["attributes"] = {}
@@ -86,7 +86,7 @@ class MemoryExtractor:
             logger.error(f"记忆提取失败: {e}", exc_info=True)
             raise ValueError(f"记忆提取失败: {e}")

-    def _validate_required_params(self, params: Dict[str, Any]) -> None:
+    def _validate_required_params(self, params: dict[str, Any]) -> None:
         """
         验证必需参数
@@ -181,7 +181,7 @@ class MemoryExtractor:
             logger.warning(f"无效的重要性值: {importance},使用默认值 0.5")
             return 0.5

-    def _process_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any]:
+    def _process_attributes(self, attributes: dict[str, Any]) -> dict[str, Any]:
         """
         处理属性字典
@@ -222,7 +222,7 @@ class MemoryExtractor:

         return processed

-    def extract_link_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+    def extract_link_params(self, params: dict[str, Any]) -> dict[str, Any]:
         """
         提取记忆关联参数(用于 link_memories 工具)
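
Note: the `"object" in params and params["object"]` rewrites above rely on `dict.get` returning `None` for missing keys, so one truthiness test covers both "key absent" and "key present but empty". A small illustration with made-up values:

    params = {"object": "", "topic": "吃饭"}  # made-up tool parameters

    for key in ("object", "topic", "attributes"):
        old = bool(key in params and params[key])  # explicit membership + truthiness
        new = bool(params.get(key))                # get() returns None when absent
        assert old == new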
@@ -4,11 +4,6 @@

 from __future__ import annotations

-from dataclasses import dataclass
-from typing import List, Optional, Tuple
-
 import numpy as np

 from src.common.logger import get_logger
 from src.config.official_configs import MemoryConfig
 from src.memory_graph.models import MemoryNode, NodeType
@@ -54,9 +49,9 @@ class NodeMerger:
     async def find_similar_nodes(
         self,
         node: MemoryNode,
-        threshold: Optional[float] = None,
+        threshold: float | None = None,
         limit: int = 5,
-    ) -> List[Tuple[MemoryNode, float]]:
+    ) -> list[tuple[MemoryNode, float]]:
         """
         查找与指定节点相似的节点
@@ -207,7 +202,7 @@ class NodeMerger:
         # 如果有 30% 以上的邻居重叠,认为上下文匹配
         return overlap_ratio > 0.3

-    def _get_node_content(self, node_id: str) -> Optional[str]:
+    def _get_node_content(self, node_id: str) -> str | None:
         """获取节点的内容"""
         memories = self.graph_store.get_memories_by_node(node_id)
         if memories:
@@ -280,8 +275,8 @@ class NodeMerger:

     async def batch_merge_similar_nodes(
         self,
-        nodes: List[MemoryNode],
-        progress_callback: Optional[callable] = None,
+        nodes: list[MemoryNode],
+        progress_callback: callable | None = None,
     ) -> dict:
         """
         批量处理节点合并
@@ -344,7 +339,7 @@ class NodeMerger:
         self,
         min_similarity: float = 0.85,
         limit: int = 100,
-    ) -> List[Tuple[str, str, float]]:
+    ) -> list[tuple[str, str, float]]:
         """
         获取待合并的候选节点对
@@ -10,22 +10,21 @@

 import asyncio
 import logging
+import uuid
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any

 from src.config.config import global_config
 from src.config.official_configs import MemoryConfig
 from src.memory_graph.core.builder import MemoryBuilder
 from src.memory_graph.core.extractor import MemoryExtractor
-from src.memory_graph.models import Memory, MemoryEdge, MemoryNode, MemoryType, NodeType, EdgeType
+from src.memory_graph.models import EdgeType, Memory, MemoryEdge, NodeType
 from src.memory_graph.storage.graph_store import GraphStore
 from src.memory_graph.storage.persistence import PersistenceManager
 from src.memory_graph.storage.vector_store import VectorStore
 from src.memory_graph.tools.memory_tools import MemoryTools
 from src.memory_graph.utils.embeddings import EmbeddingGenerator
-import uuid
-from typing import TYPE_CHECKING

 if TYPE_CHECKING:
     import numpy as np
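
Note: the typing changes here, repeated throughout the commit, swap `typing.Optional`/`Dict`/`List` annotations for PEP 585 builtin generics and PEP 604 unions. A minimal before/after sketch (assuming Python 3.10+, or 3.9 with `from __future__ import annotations`, which these modules already use):

    from typing import Dict, List, Optional  # legacy spellings


    def old_style(data: Optional[Dict[str, List[int]]] = None) -> Optional[str]:
        return None


    # PEP 585 builtin generics and PEP 604 unions, as used after this commit:
    def new_style(data: dict[str, list[int]] | None = None) -> str | None:
        return None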
@@ -46,7 +45,7 @@ class MemoryManager:

     def __init__(
         self,
-        data_dir: Optional[Path] = None,
+        data_dir: Path | None = None,
     ):
         """
         初始化记忆管理器
@@ -55,29 +54,29 @@ class MemoryManager:
             data_dir: 数据目录(可选,默认从global_config读取)
         """
         # 直接使用 global_config.memory
-        if not global_config.memory or not getattr(global_config.memory, 'enable', False):
+        if not global_config.memory or not getattr(global_config.memory, "enable", False):
             raise ValueError("记忆系统未启用,请在配置文件中启用 [memory] enable = true")

         self.config: MemoryConfig = global_config.memory
-        self.data_dir = data_dir or Path(getattr(self.config, 'data_dir', 'data/memory_graph'))
+        self.data_dir = data_dir or Path(getattr(self.config, "data_dir", "data/memory_graph"))

         # 存储组件
-        self.vector_store: Optional[VectorStore] = None
-        self.graph_store: Optional[GraphStore] = None
-        self.persistence: Optional[PersistenceManager] = None
+        self.vector_store: VectorStore | None = None
+        self.graph_store: GraphStore | None = None
+        self.persistence: PersistenceManager | None = None

         # 核心组件
-        self.embedding_generator: Optional[EmbeddingGenerator] = None
-        self.extractor: Optional[MemoryExtractor] = None
-        self.builder: Optional[MemoryBuilder] = None
-        self.tools: Optional[MemoryTools] = None
+        self.embedding_generator: EmbeddingGenerator | None = None
+        self.extractor: MemoryExtractor | None = None
+        self.builder: MemoryBuilder | None = None
+        self.tools: MemoryTools | None = None

         # 状态
         self._initialized = False
         self._last_maintenance = datetime.now()
-        self._maintenance_task: Optional[asyncio.Task] = None
-        self._maintenance_interval_hours = getattr(self.config, 'consolidation_interval_hours', 1.0)
-        self._maintenance_schedule_id: Optional[str] = None  # 调度任务ID
+        self._maintenance_task: asyncio.Task | None = None
+        self._maintenance_interval_hours = getattr(self.config, "consolidation_interval_hours", 1.0)
+        self._maintenance_schedule_id: str | None = None  # 调度任务ID

         logger.info(f"记忆管理器已创建 (data_dir={self.data_dir}, enable={getattr(self.config, 'enable', False)})")
@@ -101,8 +100,8 @@ class MemoryManager:
         self.data_dir.mkdir(parents=True, exist_ok=True)

         # 获取存储配置
-        storage_config = getattr(self.config, 'storage', None)
-        vector_collection_name = getattr(storage_config, 'vector_collection_name', 'memory_graph') if storage_config else 'memory_graph'
+        storage_config = getattr(self.config, "storage", None)
+        vector_collection_name = getattr(storage_config, "vector_collection_name", "memory_graph") if storage_config else "memory_graph"

         self.vector_store = VectorStore(
             collection_name=vector_collection_name,
@@ -203,11 +202,11 @@ class MemoryManager:
         subject: str,
         memory_type: str,
         topic: str,
-        object: Optional[str] = None,
-        attributes: Optional[Dict[str, str]] = None,
+        object: str | None = None,
+        attributes: dict[str, str] | None = None,
         importance: float = 0.5,
         **kwargs,
-    ) -> Optional[Memory]:
+    ) -> Memory | None:
         """
         创建新记忆
@@ -250,7 +249,7 @@ class MemoryManager:
             logger.error(f"创建记忆时发生异常: {e}", exc_info=True)
             return None

-    async def get_memory(self, memory_id: str) -> Optional[Memory]:
+    async def get_memory(self, memory_id: str) -> Memory | None:
         """
         根据 ID 获取记忆
@@ -348,8 +347,8 @@ class MemoryManager:
     async def generate_multi_queries(
         self,
         query: str,
-        context: Optional[Dict[str, Any]] = None,
-    ) -> List[Tuple[str, float]]:
+        context: dict[str, Any] | None = None,
+    ) -> list[tuple[str, float]]:
         """
         使用小模型生成多个查询语句(用于多路召回)
@@ -364,8 +363,8 @@ class MemoryManager:
             List of (query_string, weight) - 查询语句和权重
         """
         try:
-            from src.llm_models.utils_model import LLMRequest
             from src.config.config import model_config
+            from src.llm_models.utils_model import LLMRequest

             llm = LLMRequest(
                 model_set=model_config.model_task_config.utils_small,
@@ -407,9 +406,10 @@ class MemoryManager:
             response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)

             # 解析JSON
-            import json, re
-            response = re.sub(r'```json\s*', '', response)
-            response = re.sub(r'```\s*$', '', response).strip()
+            import json
+            import re
+
+            response = re.sub(r"```json\s*", "", response)
+            response = re.sub(r"```\s*$", "", response).strip()

             try:
                 data = json.loads(response)
@@ -439,14 +439,14 @@ class MemoryManager:
         self,
         query: str,
         top_k: int = 10,
-        memory_types: Optional[List[str]] = None,
-        time_range: Optional[Tuple[datetime, datetime]] = None,
+        memory_types: list[str] | None = None,
+        time_range: tuple[datetime, datetime] | None = None,
         min_importance: float = 0.0,
         include_forgotten: bool = False,
         use_multi_query: bool = True,
         expand_depth: int | None = None,
-        context: Optional[Dict[str, Any]] = None,
-    ) -> List[Memory]:
+        context: dict[str, Any] | None = None,
+    ) -> list[Memory]:
         """
         搜索记忆
@@ -611,7 +611,7 @@ class MemoryManager:
                 # 计算时间衰减
                 last_access_dt = datetime.fromisoformat(last_access)
                 hours_passed = (now - last_access_dt).total_seconds() / 3600
-                decay_rate = getattr(self.config, 'activation_decay_rate', 0.95)
+                decay_rate = getattr(self.config, "activation_decay_rate", 0.95)
                 decay_factor = decay_rate ** (hours_passed / 24)
                 current_activation = activation_info.get("level", 0.0) * decay_factor
             else:
@@ -631,15 +631,15 @@ class MemoryManager:

             # 激活传播:激活相关记忆
             if strength > 0.1:  # 只有足够强的激活才传播
-                propagation_depth = getattr(self.config, 'activation_propagation_depth', 2)
+                propagation_depth = getattr(self.config, "activation_propagation_depth", 2)
                 related_memories = self._get_related_memories(
                     memory_id,
                     max_depth=propagation_depth
                 )
-                propagation_strength_factor = getattr(self.config, 'activation_propagation_strength', 0.5)
+                propagation_strength_factor = getattr(self.config, "activation_propagation_strength", 0.5)
                 propagation_strength = strength * propagation_strength_factor

-                max_related = getattr(self.config, 'max_related_memories', 5)
+                max_related = getattr(self.config, "max_related_memories", 5)
                 for related_id in related_memories[:max_related]:
                     await self.activate_memory(related_id, propagation_strength)
@@ -652,7 +652,7 @@ class MemoryManager:
             logger.error(f"激活记忆失败: {e}", exc_info=True)
             return False

-    def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> List[str]:
+    def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> list[str]:
         """
         获取相关记忆 ID 列表(旧版本,保留用于激活传播)
@@ -687,12 +687,12 @@ class MemoryManager:

     async def expand_memories_with_semantic_filter(
         self,
-        initial_memory_ids: List[str],
+        initial_memory_ids: list[str],
         query_embedding: "np.ndarray",
         max_depth: int = 2,
         semantic_threshold: float = 0.5,
         max_expanded: int = 20
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         """
         从初始记忆集合出发,沿图结构扩展,并用语义相似度过滤
@@ -712,12 +712,11 @@ class MemoryManager:
             return []

         try:
-            import numpy as np

             # 记录已访问的记忆,避免重复
             visited_memories = set(initial_memory_ids)
             # 记录扩展的记忆及其分数
-            expanded_memories: Dict[str, float] = {}
+            expanded_memories: dict[str, float] = {}

             # BFS扩展
             current_level = initial_memory_ids
@@ -738,7 +737,7 @@ class MemoryManager:
                     # 获取邻居节点
                     try:
                         neighbors = list(self.graph_store.graph.neighbors(node.id))
-                    except:
+                    except Exception:
                         continue

                     for neighbor_id in neighbors:
@@ -764,7 +763,7 @@ class MemoryManager:
                         try:
                             edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id)
                             edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
-                        except:
+                        except Exception:
                             edge_importance = 0.5

                         # 综合评分:语义相似度(70%) + 图结构权重(20%) + 深度衰减(10%)
@@ -785,7 +784,7 @@ class MemoryManager:
                             import json
                             try:
                                 neighbor_memory_ids = json.loads(neighbor_memory_ids)
-                            except:
+                            except Exception:
                                 neighbor_memory_ids = [neighbor_memory_ids]

                             for neighbor_mem_id in neighbor_memory_ids:
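
Note: the repeated `except:` to `except Exception:` rewrites are the exception-handling half of this commit: a bare `except` also traps `KeyboardInterrupt` and `SystemExit`, which inherit from `BaseException`, while `except Exception` lets those control-flow exceptions propagate. A short contrast sketch:

    import json

    raw = "not valid json"

    # Bare except (old): also traps KeyboardInterrupt and SystemExit.
    try:
        data = json.loads(raw)
    except:  # noqa: E722 - kept only to show the old form
        data = None

    # except Exception (new): ordinary errors are caught, BaseException passes.
    try:
        data = json.loads(raw)
    except Exception:
        data = None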
@@ -909,7 +908,7 @@ class MemoryManager:
                     continue

                 # 跳过高重要性记忆
-                min_importance = getattr(self.config, 'forgetting_min_importance', 7.0)
+                min_importance = getattr(self.config, "forgetting_min_importance", 7.0)
                 if memory.importance >= min_importance:
                     continue
@@ -939,7 +938,7 @@ class MemoryManager:

     # ==================== 统计与维护 ====================

-    def get_statistics(self) -> Dict[str, Any]:
+    def get_statistics(self) -> dict[str, Any]:
         """
         获取记忆系统统计信息
@@ -980,7 +979,7 @@ class MemoryManager:
         similarity_threshold: float = 0.85,
         time_window_hours: float = 24.0,
         max_batch_size: int = 50,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         整理记忆:直接合并去重相似记忆(不创建新边)
@@ -1065,7 +1064,7 @@ class MemoryManager:
             result["checked_count"] = len(recent_memories)

             # 按记忆类型分组,减少跨类型比较
-            memories_by_type: Dict[str, List[Memory]] = {}
+            memories_by_type: dict[str, list[Memory]] = {}
             for mem in recent_memories:
                 mem_type = mem.metadata.get("memory_type", "")
                 if mem_type not in memories_by_type:
@@ -1073,7 +1072,7 @@ class MemoryManager:
                 memories_by_type[mem_type].append(mem)

             # 记录需要删除的记忆,延迟批量删除
-            to_delete: List[Tuple[Memory, str]] = []  # (memory, reason)
+            to_delete: list[tuple[Memory, str]] = []  # (memory, reason)
             deleted_ids = set()

             # 对每个类型的记忆进行相似度检测
@@ -1084,7 +1083,7 @@ class MemoryManager:
                 logger.debug(f"🔍 检查类型 '{mem_type}' 的 {len(memories)} 条记忆")

                 # 预提取所有主题节点的嵌入向量
-                embeddings_map: Dict[str, "np.ndarray"] = {}
+                embeddings_map: dict[str, "np.ndarray"] = {}
                 valid_memories = []

                 for mem in memories:
@@ -1094,7 +1093,6 @@ class MemoryManager:
                         valid_memories.append(mem)

                 # 批量计算相似度矩阵(比逐个计算更高效)
-                import numpy as np

                 for i in range(len(valid_memories)):
                     # 更频繁的协作式多任务让出
@@ -1134,7 +1132,7 @@ class MemoryManager:
                         keep_mem.importance = min(1.0, keep_mem.importance + 0.05)

                         # 累加访问次数
-                        if hasattr(keep_mem, 'access_count') and hasattr(remove_mem, 'access_count'):
+                        if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"):
                             keep_mem.access_count += remove_mem.access_count

                         # 标记为待删除(不立即删除)
@@ -1164,7 +1162,7 @@ class MemoryManager:

                 # 批量保存(一次性写入,减少I/O)
                 await self.persistence.save_graph_store(self.graph_store)
-                logger.info(f"💾 批量保存完成")
+                logger.info("💾 批量保存完成")

             logger.info(f"✅ 记忆整理完成: {result}")
@@ -1207,10 +1205,10 @@ class MemoryManager:

     async def auto_link_memories(
         self,
-        time_window_hours: float = None,
-        max_candidates: int = None,
-        min_confidence: float = None,
-    ) -> Dict[str, Any]:
+        time_window_hours: float | None = None,
+        max_candidates: int | None = None,
+        min_confidence: float | None = None,
+    ) -> dict[str, Any]:
         """
         自动关联记忆
@@ -1229,8 +1227,8 @@ class MemoryManager:

         # 使用配置值或参数覆盖
         time_window_hours = time_window_hours if time_window_hours is not None else 24
-        max_candidates = max_candidates if max_candidates is not None else getattr(self.config, 'auto_link_max_candidates', 10)
-        min_confidence = min_confidence if min_confidence is not None else getattr(self.config, 'auto_link_min_confidence', 0.7)
+        max_candidates = max_candidates if max_candidates is not None else getattr(self.config, "auto_link_max_candidates", 10)
+        min_confidence = min_confidence if min_confidence is not None else getattr(self.config, "auto_link_min_confidence", 0.7)

         try:
             logger.info(f"开始自动关联记忆 (时间窗口={time_window_hours}h)...")
@@ -1361,9 +1359,9 @@ class MemoryManager:
     async def _find_link_candidates(
         self,
         memory: Memory,
-        exclude_ids: Set[str],
+        exclude_ids: set[str],
         max_results: int = 5,
-    ) -> List[Memory]:
+    ) -> list[Memory]:
         """
         为记忆寻找关联候选
@@ -1407,9 +1405,9 @@ class MemoryManager:
     async def _analyze_memory_relations(
         self,
         source_memory: Memory,
-        candidate_memories: List[Memory],
+        candidate_memories: list[Memory],
         min_confidence: float = 0.7,
-    ) -> List[Dict[str, Any]]:
+    ) -> list[dict[str, Any]]:
         """
         使用LLM分析记忆之间的关系
@@ -1426,8 +1424,8 @@ class MemoryManager:
             - reasoning: 推理过程
         """
         try:
-            from src.llm_models.utils_model import LLMRequest
             from src.config.config import model_config
+            from src.llm_models.utils_model import LLMRequest

             # 构建LLM请求
             llm = LLMRequest(
@@ -1496,7 +1494,7 @@ class MemoryManager:
             import re

             # 提取JSON
-            json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
+            json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
             if json_match:
                 json_str = json_match.group(1)
             else:
@@ -1507,7 +1505,7 @@ class MemoryManager:
             except json.JSONDecodeError:
                 logger.warning(f"LLM返回格式错误,尝试修复: {response[:200]}")
                 # 尝试简单修复
-                json_str = re.sub(r'[\r\n\t]', '', json_str)
+                json_str = re.sub(r"[\r\n\t]", "", json_str)
                 analysis_results = json.loads(json_str)

             # 转换为结果格式
@@ -1574,7 +1572,7 @@ class MemoryManager:
             logger.warning(f"格式化记忆失败: {e}")
             return f"记忆ID: {memory.id}"

-    async def maintenance(self) -> Dict[str, Any]:
+    async def maintenance(self) -> dict[str, Any]:
         """
         执行维护任务(优化版本)
@@ -1604,12 +1602,12 @@ class MemoryManager:
         start_time = datetime.now()

         # 1. 记忆整理(异步后台执行,不阻塞主流程)
-        if getattr(self.config, 'consolidation_enabled', False):
+        if getattr(self.config, "consolidation_enabled", False):
             logger.info("🚀 启动异步记忆整理任务...")
             consolidate_result = await self.consolidate_memories(
-                similarity_threshold=getattr(self.config, 'consolidation_deduplication_threshold', 0.93),
-                time_window_hours=getattr(self.config, 'consolidation_time_window_hours', 2.0),  # 统一时间窗口
-                max_batch_size=getattr(self.config, 'consolidation_max_batch_size', 30)
+                similarity_threshold=getattr(self.config, "consolidation_deduplication_threshold", 0.93),
+                time_window_hours=getattr(self.config, "consolidation_time_window_hours", 2.0),  # 统一时间窗口
+                max_batch_size=getattr(self.config, "consolidation_max_batch_size", 30)
             )

             if consolidate_result.get("task_started"):
@@ -1620,16 +1618,16 @@ class MemoryManager:
                 logger.warning("❌ 记忆整理任务启动失败")

         # 2. 自动关联记忆(使用统一的时间窗口)
-        if getattr(self.config, 'consolidation_linking_enabled', True):
+        if getattr(self.config, "consolidation_linking_enabled", True):
             logger.info("🔗 执行轻量级自动关联...")
             link_result = await self._lightweight_auto_link_memories()
             result["linked"] = link_result.get("linked_count", 0)

         # 3. 自动遗忘(快速执行)
-        if getattr(self.config, 'forgetting_enabled', True):
+        if getattr(self.config, "forgetting_enabled", True):
             logger.info("🗑️ 执行自动遗忘...")
             forgotten_count = await self.auto_forget_memories(
-                threshold=getattr(self.config, 'forgetting_activation_threshold', 0.1)
+                threshold=getattr(self.config, "forgetting_activation_threshold", 0.1)
             )
             result["forgotten"] = forgotten_count
@@ -1654,10 +1652,10 @@ class MemoryManager:

     async def _lightweight_auto_link_memories(
         self,
-        time_window_hours: float = None,  # 从配置读取
-        max_candidates: int = None,  # 从配置读取
-        max_memories: int = None,  # 从配置读取
-    ) -> Dict[str, Any]:
+        time_window_hours: float | None = None,  # 从配置读取
+        max_candidates: int | None = None,  # 从配置读取
+        max_memories: int | None = None,  # 从配置读取
+    ) -> dict[str, Any]:
         """
         智能轻量级自动关联记忆(保留LLM判断,优化性能)
@@ -1676,11 +1674,11 @@ class MemoryManager:

         # 从配置读取参数,使用统一的时间窗口
         if time_window_hours is None:
-            time_window_hours = getattr(self.config, 'consolidation_time_window_hours', 2.0)
+            time_window_hours = getattr(self.config, "consolidation_time_window_hours", 2.0)
         if max_candidates is None:
-            max_candidates = getattr(self.config, 'consolidation_linking_max_candidates', 10)
+            max_candidates = getattr(self.config, "consolidation_linking_max_candidates", 10)
         if max_memories is None:
-            max_memories = getattr(self.config, 'consolidation_linking_max_memories', 20)
+            max_memories = getattr(self.config, "consolidation_linking_max_memories", 20)

         # 获取用户配置时间窗口内的记忆
         time_threshold = datetime.now() - timedelta(hours=time_window_hours)
@@ -1690,7 +1688,7 @@ class MemoryManager:
             mem for mem in all_memories
             if mem.created_at >= time_threshold
             and not mem.metadata.get("forgotten", False)
-            and mem.importance >= getattr(self.config, 'consolidation_linking_min_importance', 0.5)  # 从配置读取重要性阈值
+            and mem.importance >= getattr(self.config, "consolidation_linking_min_importance", 0.5)  # 从配置读取重要性阈值
         ]

         if len(recent_memories) > max_memories:
@@ -1704,7 +1702,6 @@ class MemoryManager:

         # 第一步:向量相似度预筛选,找到潜在关联对
         candidate_pairs = []
-        import numpy as np

         for i, memory in enumerate(recent_memories):
             # 获取主题节点
@@ -1733,7 +1730,7 @@ class MemoryManager:
                 )

                 # 使用配置的预筛选阈值
-                pre_filter_threshold = getattr(self.config, 'consolidation_linking_pre_filter_threshold', 0.7)
+                pre_filter_threshold = getattr(self.config, "consolidation_linking_pre_filter_threshold", 0.7)
                 if similarity >= pre_filter_threshold:
                     candidate_pairs.append((memory, other_memory, similarity))
@@ -1747,7 +1744,7 @@ class MemoryManager:
             return result

         # 第二步:批量LLM分析(使用配置的最大候选对数)
-        max_pairs_for_llm = getattr(self.config, 'consolidation_linking_max_pairs_for_llm', 5)
+        max_pairs_for_llm = getattr(self.config, "consolidation_linking_max_pairs_for_llm", 5)
         if len(candidate_pairs) <= max_pairs_for_llm:
             link_relations = await self._batch_analyze_memory_relations(candidate_pairs)
             result["llm_calls"] = 1
@@ -1810,8 +1807,8 @@ class MemoryManager:

     async def _batch_analyze_memory_relations(
         self,
-        candidate_pairs: List[Tuple[Memory, Memory, float]]
-    ) -> List[Dict[str, Any]]:
+        candidate_pairs: list[tuple[Memory, Memory, float]]
+    ) -> list[dict[str, Any]]:
         """
         批量分析记忆关系(优化LLM调用)
@@ -1822,8 +1819,8 @@ class MemoryManager:
             关系分析结果列表
         """
         try:
-            from src.llm_models.utils_model import LLMRequest
             from src.config.config import model_config
+            from src.llm_models.utils_model import LLMRequest

             llm = LLMRequest(
                 model_set=model_config.model_task_config.utils_small,
@@ -1843,7 +1840,7 @@ class MemoryManager:
         """

         # 构建批量分析提示词(使用配置的置信度阈值)
-        min_confidence = getattr(self.config, 'consolidation_linking_min_confidence', 0.7)
+        min_confidence = getattr(self.config, "consolidation_linking_min_confidence", 0.7)

         prompt = f"""你是记忆关系分析专家。请批量分析以下候选记忆对之间的关系。
@@ -1885,8 +1882,8 @@ class MemoryManager:
 请分析并输出JSON结果:"""

         # 调用LLM(使用配置的参数)
-        llm_temperature = getattr(self.config, 'consolidation_linking_llm_temperature', 0.2)
-        llm_max_tokens = getattr(self.config, 'consolidation_linking_llm_max_tokens', 1500)
+        llm_temperature = getattr(self.config, "consolidation_linking_llm_temperature", 0.2)
+        llm_max_tokens = getattr(self.config, "consolidation_linking_llm_max_tokens", 1500)

         response, _ = await llm.generate_response_async(
             prompt,
@@ -1899,7 +1896,7 @@ class MemoryManager:
             import re

             # 提取JSON
-            json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
+            json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
             if json_match:
                 json_str = json_match.group(1)
             else:
@@ -1910,7 +1907,7 @@ class MemoryManager:
             except json.JSONDecodeError:
                 logger.warning(f"LLM返回格式错误,尝试修复: {response[:200]}")
                 # 尝试简单修复
-                json_str = re.sub(r'[\r\n\t]', '', json_str)
+                json_str = re.sub(r"[\r\n\t]", "", json_str)
                 analysis_results = json.loads(json_str)

             # 转换为结果格式
@@ -7,7 +7,6 @@
 from __future__ import annotations

 from pathlib import Path
-from typing import Optional

 from src.common.logger import get_logger
 from src.memory_graph.manager import MemoryManager
@@ -15,13 +14,13 @@ from src.memory_graph.manager import MemoryManager
 logger = get_logger(__name__)

 # 全局 MemoryManager 实例
-_memory_manager: Optional[MemoryManager] = None
+_memory_manager: MemoryManager | None = None
 _initialized: bool = False


 async def initialize_memory_manager(
-    data_dir: Optional[Path | str] = None,
-) -> Optional[MemoryManager]:
+    data_dir: Path | str | None = None,
+) -> MemoryManager | None:
     """
     初始化全局 MemoryManager
@@ -43,7 +42,7 @@ async def initialize_memory_manager(
     from src.config.config import global_config

     # 检查是否启用
-    if not global_config.memory or not getattr(global_config.memory, 'enable', False):
+    if not global_config.memory or not getattr(global_config.memory, "enable", False):
         logger.info("记忆图系统已在配置中禁用")
         _initialized = False
         _memory_manager = None
@@ -51,7 +50,7 @@ async def initialize_memory_manager(

     # 处理数据目录
     if data_dir is None:
-        data_dir = getattr(global_config.memory, 'data_dir', 'data/memory_graph')
+        data_dir = getattr(global_config.memory, "data_dir", "data/memory_graph")
     if isinstance(data_dir, str):
         data_dir = Path(data_dir)
@@ -72,7 +71,7 @@ async def initialize_memory_manager(
         raise


-def get_memory_manager() -> Optional[MemoryManager]:
+def get_memory_manager() -> MemoryManager | None:
     """
     获取全局 MemoryManager 实例
@@ -10,7 +10,7 @@ import uuid
 from dataclasses import dataclass, field
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any

 import numpy as np
@@ -60,8 +60,8 @@ class MemoryNode:
     id: str  # 节点唯一ID
     content: str  # 节点内容(如:"我"、"吃饭"、"白米饭")
     node_type: NodeType  # 节点类型
-    embedding: Optional[np.ndarray] = None  # 语义向量(仅主题/客体需要)
-    metadata: Dict[str, Any] = field(default_factory=dict)  # 扩展元数据
+    embedding: np.ndarray | None = None  # 语义向量(仅主题/客体需要)
+    metadata: dict[str, Any] = field(default_factory=dict)  # 扩展元数据
     created_at: datetime = field(default_factory=datetime.now)

     def __post_init__(self):
@@ -69,7 +69,7 @@ class MemoryNode:
         if not self.id:
             self.id = str(uuid.uuid4())

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """转换为字典(用于序列化)"""
         return {
             "id": self.id,
@@ -81,7 +81,7 @@ class MemoryNode:
         }

     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> MemoryNode:
+    def from_dict(cls, data: dict[str, Any]) -> MemoryNode:
         """从字典创建节点"""
         embedding = None
         if data.get("embedding") is not None:
@@ -114,7 +114,7 @@ class MemoryEdge:
     relation: str  # 关系名称(如:"是"、"做"、"时间"、"因为")
     edge_type: EdgeType  # 边类型
     importance: float = 0.5  # 重要性 [0-1]
-    metadata: Dict[str, Any] = field(default_factory=dict)  # 扩展元数据
+    metadata: dict[str, Any] = field(default_factory=dict)  # 扩展元数据
     created_at: datetime = field(default_factory=datetime.now)

     def __post_init__(self):
@@ -124,7 +124,7 @@ class MemoryEdge:
         # 确保重要性在有效范围内
         self.importance = max(0.0, min(1.0, self.importance))

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """转换为字典(用于序列化)"""
         return {
             "id": self.id,
@@ -138,7 +138,7 @@ class MemoryEdge:
         }

     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> MemoryEdge:
+    def from_dict(cls, data: dict[str, Any]) -> MemoryEdge:
         """从字典创建边"""
         return cls(
             id=data["id"],
@@ -162,8 +162,8 @@ class Memory:
     id: str  # 记忆唯一ID
     subject_id: str  # 主体节点ID
     memory_type: MemoryType  # 记忆类型
-    nodes: List[MemoryNode]  # 该记忆包含的所有节点
-    edges: List[MemoryEdge]  # 该记忆包含的所有边
+    nodes: list[MemoryNode]  # 该记忆包含的所有节点
+    edges: list[MemoryEdge]  # 该记忆包含的所有边
     importance: float = 0.5  # 整体重要性 [0-1]
     activation: float = 0.0  # 激活度 [0-1],用于记忆整合和遗忘
     status: MemoryStatus = MemoryStatus.STAGED  # 记忆状态
@@ -171,7 +171,7 @@ class Memory:
     last_accessed: datetime = field(default_factory=datetime.now)  # 最后访问时间
     access_count: int = 0  # 访问次数
     decay_factor: float = 1.0  # 衰减因子(随时间变化)
-    metadata: Dict[str, Any] = field(default_factory=dict)  # 扩展元数据
+    metadata: dict[str, Any] = field(default_factory=dict)  # 扩展元数据

     def __post_init__(self):
         """后初始化处理"""
@@ -181,7 +181,7 @@ class Memory:
         self.importance = max(0.0, min(1.0, self.importance))
         self.activation = max(0.0, min(1.0, self.activation))

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """转换为字典(用于序列化)"""
         return {
             "id": self.id,
@@ -200,7 +200,7 @@ class Memory:
         }

     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> Memory:
+    def from_dict(cls, data: dict[str, Any]) -> Memory:
         """从字典创建记忆"""
         return cls(
             id=data["id"],
@@ -223,14 +223,14 @@ class Memory:
         self.last_accessed = datetime.now()
         self.access_count += 1

-    def get_node_by_id(self, node_id: str) -> Optional[MemoryNode]:
+    def get_node_by_id(self, node_id: str) -> MemoryNode | None:
         """根据ID获取节点"""
         for node in self.nodes:
             if node.id == node_id:
                 return node
         return None

-    def get_subject_node(self) -> Optional[MemoryNode]:
+    def get_subject_node(self) -> MemoryNode | None:
         """获取主体节点"""
         return self.get_node_by_id(self.subject_id)
@@ -274,10 +274,10 @@ class StagedMemory:
     memory: Memory  # 原始记忆对象
     status: MemoryStatus = MemoryStatus.STAGED  # 状态
     created_at: datetime = field(default_factory=datetime.now)
-    consolidated_at: Optional[datetime] = None  # 整理时间
-    merge_history: List[str] = field(default_factory=list)  # 被合并的节点ID列表
+    consolidated_at: datetime | None = None  # 整理时间
+    merge_history: list[str] = field(default_factory=list)  # 被合并的节点ID列表

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """转换为字典"""
         return {
             "memory": self.memory.to_dict(),
@@ -288,7 +288,7 @@ class StagedMemory:
         }

     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> StagedMemory:
+    def from_dict(cls, data: dict[str, Any]) -> StagedMemory:
         """从字典创建临时记忆"""
         return cls(
             memory=Memory.from_dict(data["memory"]),
@@ -58,7 +58,7 @@ class CreateMemoryTool(BaseTool):
         ("memory_type", ToolParamType.STRING, "记忆类型。【事件】=有明确时间点的动作(昨天吃饭、明天开会)【事实】=稳定状态(职业是程序员、住在北京)【观点】=主观看法(喜欢/讨厌/认为)【关系】=人际关系(朋友、同事)", True, ["事件", "事实", "关系", "观点"]),
         ("topic", ToolParamType.STRING, "记忆的核心内容(做什么/是什么状态/什么关系)。必须明确、具体,包含关键动词或状态词", True, None),
         ("object", ToolParamType.STRING, "记忆涉及的对象或目标。如果topic已经很完整可以不填,如果有明确对象建议填写", False, None),
-        ("attributes", ToolParamType.STRING, "详细属性,JSON格式字符串。强烈建议包含:时间(具体到日期和小时分钟)、地点、状态、原因等上下文信息。例:{\"时间\":\"2025-11-06 12:00\",\"地点\":\"公司\",\"状态\":\"进行中\",\"原因\":\"项目需要\"}", False, None),
+        ("attributes", ToolParamType.STRING, '详细属性,JSON格式字符串。强烈建议包含:时间(具体到日期和小时分钟)、地点、状态、原因等上下文信息。例:{"时间":"2025-11-06 12:00","地点":"公司","状态":"进行中","原因":"项目需要"}', False, None),
         ("importance", ToolParamType.FLOAT, "重要性评分 0.0-1.0。参考:日常琐事0.3-0.4,一般对话0.5-0.6,重要信息0.7-0.8,核心记忆0.9-1.0。不确定时用0.5", False, None),
     ]
@@ -124,7 +124,7 @@ class CreateMemoryTool(BaseTool):
             logger.error(f"[CreateMemoryTool] 执行失败: {e}", exc_info=True)
             return {
                 "name": self.name,
-                "content": f"创建记忆时出错: {str(e)}"
+                "content": f"创建记忆时出错: {e!s}"
             }

@@ -184,7 +184,7 @@ class LinkMemoriesTool(BaseTool):
             logger.error(f"[LinkMemoriesTool] 执行失败: {e}", exc_info=True)
             return {
                 "name": self.name,
-                "content": f"关联记忆时出错: {str(e)}"
+                "content": f"关联记忆时出错: {e!s}"
             }

@@ -254,5 +254,5 @@ class SearchMemoriesTool(BaseTool):
             logger.error(f"[SearchMemoriesTool] 执行失败: {e}", exc_info=True)
             return {
                 "name": self.name,
-                "content": f"搜索记忆时出错: {str(e)}"
+                "content": f"搜索记忆时出错: {e!s}"
             }
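
Note: the `{str(e)}` to `{e!s}` rewrites above use the f-string conversion specifier instead of an explicit `str()` call inside the replacement field; both produce the same text. A one-line check (with a made-up exception instance):

    e = ValueError("boom")
    # !s applies str() inside the replacement field; the result is identical.
    assert f"创建记忆时出错: {e!s}" == f"创建记忆时出错: {str(e)}"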
@@ -5,4 +5,4 @@
 from src.memory_graph.storage.graph_store import GraphStore
 from src.memory_graph.storage.vector_store import VectorStore

-__all__ = ["VectorStore", "GraphStore"]
+__all__ = ["GraphStore", "VectorStore"]
@@ -4,12 +4,10 @@

 from __future__ import annotations

-from typing import Dict, List, Optional, Set, Tuple
-
 import networkx as nx

 from src.common.logger import get_logger
-from src.memory_graph.models import Memory, MemoryEdge, MemoryNode
+from src.memory_graph.models import Memory, MemoryEdge

 logger = get_logger(__name__)
@@ -31,10 +29,10 @@ class GraphStore:
         self.graph = nx.DiGraph()

         # 索引:记忆ID -> 记忆对象
-        self.memory_index: Dict[str, Memory] = {}
+        self.memory_index: dict[str, Memory] = {}

         # 索引:节点ID -> 所属记忆ID集合
-        self.node_to_memories: Dict[str, Set[str]] = {}
+        self.node_to_memories: dict[str, set[str]] = {}

         logger.info("初始化图存储")
@@ -84,7 +82,7 @@ class GraphStore:
             logger.error(f"添加记忆失败: {e}", exc_info=True)
             raise

-    def get_memory_by_id(self, memory_id: str) -> Optional[Memory]:
+    def get_memory_by_id(self, memory_id: str) -> Memory | None:
         """
         根据ID获取记忆
@@ -96,7 +94,7 @@ class GraphStore:
         """
         return self.memory_index.get(memory_id)

-    def get_all_memories(self) -> List[Memory]:
+    def get_all_memories(self) -> list[Memory]:
         """
         获取所有记忆
@@ -105,7 +103,7 @@ class GraphStore:
         """
         return list(self.memory_index.values())

-    def get_memories_by_node(self, node_id: str) -> List[Memory]:
+    def get_memories_by_node(self, node_id: str) -> list[Memory]:
         """
         获取包含指定节点的所有记忆
@@ -121,7 +119,7 @@ class GraphStore:
         memory_ids = self.node_to_memories[node_id]
         return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index]

-    def get_edges_from_node(self, node_id: str, relation_types: Optional[List[str]] = None) -> List[Dict]:
+    def get_edges_from_node(self, node_id: str, relation_types: list[str] | None = None) -> list[dict]:
         """
         获取从指定节点出发的所有边
@@ -155,8 +153,8 @@ class GraphStore:
         return edges

     def get_neighbors(
-        self, node_id: str, direction: str = "out", relation_types: Optional[List[str]] = None
-    ) -> List[Tuple[str, Dict]]:
+        self, node_id: str, direction: str = "out", relation_types: list[str] | None = None
+    ) -> list[tuple[str, dict]]:
         """
         获取节点的邻居节点
@@ -187,7 +185,7 @@ class GraphStore:

         return neighbors

-    def find_path(self, source_id: str, target_id: str, max_length: Optional[int] = None) -> Optional[List[str]]:
+    def find_path(self, source_id: str, target_id: str, max_length: int | None = None) -> list[str] | None:
         """
         查找两个节点之间的最短路径
@@ -220,10 +218,10 @@ class GraphStore:

     def bfs_expand(
         self,
-        start_nodes: List[str],
+        start_nodes: list[str],
         depth: int = 1,
-        relation_types: Optional[List[str]] = None,
-    ) -> Set[str]:
+        relation_types: list[str] | None = None,
+    ) -> set[str]:
         """
         从起始节点进行广度优先搜索扩展
@@ -256,7 +254,7 @@ class GraphStore:

         return visited

-    def get_subgraph(self, node_ids: List[str]) -> nx.DiGraph:
+    def get_subgraph(self, node_ids: list[str]) -> nx.DiGraph:
         """
         获取包含指定节点的子图
@@ -308,7 +306,7 @@ class GraphStore:
             logger.error(f"合并节点失败: {e}", exc_info=True)
             raise

-    def get_node_degree(self, node_id: str) -> Tuple[int, int]:
+    def get_node_degree(self, node_id: str) -> tuple[int, int]:
         """
         获取节点的度数
@@ -323,7 +321,7 @@ class GraphStore:

         return (self.graph.in_degree(node_id), self.graph.out_degree(node_id))

-    def get_statistics(self) -> Dict[str, int]:
+    def get_statistics(self) -> dict[str, int]:
         """获取图的统计信息"""
         return {
             "total_nodes": self.graph.number_of_nodes(),
@@ -332,7 +330,7 @@ class GraphStore:
             "connected_components": nx.number_weakly_connected_components(self.graph),
         }

-    def to_dict(self) -> Dict:
+    def to_dict(self) -> dict:
         """
         将图转换为字典(用于持久化)
@@ -356,7 +354,7 @@ class GraphStore:
         }

     @classmethod
-    def from_dict(cls, data: Dict) -> GraphStore:
+    def from_dict(cls, data: dict) -> GraphStore:
         """
         从字典加载图
@@ -406,7 +404,6 @@ class GraphStore:
         规则:对于图中每条边(u, v, data),会尝试将该边注入到所有包含 u 或 v 的记忆中(避免遗漏跨记忆边)。
         已存在的边(通过 edge.id 检查)将不会重复添加。
         """
-        from src.memory_graph.models import MemoryEdge

         # 构建快速查重索引:memory_id -> set(edge_id)
         existing_edges = {mid: {e.id for e in mem.edges} for mid, mem in self.memory_index.items()}
@@ -8,14 +8,12 @@ import asyncio
 import json
 from datetime import datetime
 from pathlib import Path
-from typing import Optional

 import orjson

 from src.common.logger import get_logger
-from src.memory_graph.models import Memory, StagedMemory
+from src.memory_graph.models import StagedMemory
 from src.memory_graph.storage.graph_store import GraphStore
 from src.memory_graph.storage.vector_store import VectorStore

 logger = get_logger(__name__)
@@ -55,7 +53,7 @@ class PersistenceManager:
         self.backup_dir.mkdir(parents=True, exist_ok=True)

         self.auto_save_interval = auto_save_interval
-        self._auto_save_task: Optional[asyncio.Task] = None
+        self._auto_save_task: asyncio.Task | None = None
         self._running = False

         logger.info(f"初始化持久化管理器: data_dir={data_dir}")
@@ -95,7 +93,7 @@ class PersistenceManager:
             logger.error(f"保存图数据失败: {e}", exc_info=True)
             raise

-    async def load_graph_store(self) -> Optional[GraphStore]:
+    async def load_graph_store(self) -> GraphStore | None:
         """
         从文件加载图存储
@@ -179,7 +177,7 @@ class PersistenceManager:
             logger.error(f"加载临时记忆失败: {e}", exc_info=True)
             return []

-    async def create_backup(self) -> Optional[Path]:
+    async def create_backup(self) -> Path | None:
         """
         创建当前数据的备份
@@ -208,7 +206,7 @@ class PersistenceManager:
             logger.error(f"创建备份失败: {e}", exc_info=True)
             return None

-    async def _load_from_backup(self) -> Optional[GraphStore]:
+    async def _load_from_backup(self) -> GraphStore | None:
         """从最新的备份加载数据"""
         try:
             # 查找最新的备份文件
@@ -254,7 +252,7 @@ class PersistenceManager:
     async def start_auto_save(
         self,
         graph_store: GraphStore,
-        staged_memories_getter: callable = None,
+        staged_memories_getter: callable | None = None,
     ) -> None:
         """
         启动自动保存任务
@@ -334,7 +332,7 @@ class PersistenceManager:
             logger.error(f"导出图数据失败: {e}", exc_info=True)
             raise

-    async def import_from_json(self, input_file: Path) -> Optional[GraphStore]:
+    async def import_from_json(self, input_file: Path) -> GraphStore | None:
         """
         从 JSON 文件导入图数据
@@ -4,9 +4,8 @@

 from __future__ import annotations

 import uuid
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any

 import numpy as np
@@ -29,8 +28,8 @@ class VectorStore:
     def __init__(
         self,
         collection_name: str = "memory_nodes",
-        data_dir: Optional[Path] = None,
-        embedding_function: Optional[Any] = None,
+        data_dir: Path | None = None,
+        embedding_function: Any | None = None,
     ):
         """
         初始化向量存储
@@ -103,7 +102,7 @@ class VectorStore:
             for key, value in node.metadata.items():
                 if isinstance(value, (list, dict)):
                     import orjson
-                    metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
+                    metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
                 elif isinstance(value, (str, int, float, bool)) or value is None:
                     metadata[key] = value
                 else:
@@ -122,7 +121,7 @@ class VectorStore:
             logger.error(f"添加节点失败: {e}", exc_info=True)
             raise

-    async def add_nodes_batch(self, nodes: List[MemoryNode]) -> None:
+    async def add_nodes_batch(self, nodes: list[MemoryNode]) -> None:
         """
         批量添加节点
@@ -151,7 +150,7 @@ class VectorStore:
                 }
                 for key, value in n.metadata.items():
                     if isinstance(value, (list, dict)):
-                        metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
+                        metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
                     elif isinstance(value, (str, int, float, bool)) or value is None:
                         metadata[key] = value  # type: ignore
                     else:
@@ -175,9 +174,9 @@ class VectorStore:
         self,
         query_embedding: np.ndarray,
         limit: int = 10,
-        node_types: Optional[List[NodeType]] = None,
+        node_types: list[NodeType] | None = None,
         min_similarity: float = 0.0,
-    ) -> List[Tuple[str, float, Dict[str, Any]]]:
+    ) -> list[tuple[str, float, dict[str, Any]]]:
         """
         搜索相似节点
@@ -226,10 +225,10 @@ class VectorStore:

                 # 解析 JSON 字符串回列表/字典
                 for key, value in list(metadata.items()):
-                    if isinstance(value, str) and (value.startswith('[') or value.startswith('{')):
+                    if isinstance(value, str) and (value.startswith("[") or value.startswith("{")):
                         try:
                             metadata[key] = orjson.loads(value)
-                        except:
+                        except Exception:
                             pass  # 保持原值

                 similar_nodes.append((node_id, similarity, metadata))
@@ -243,13 +242,13 @@ class VectorStore:

     async def search_with_multiple_queries(
         self,
-        query_embeddings: List[np.ndarray],
-        query_weights: Optional[List[float]] = None,
+        query_embeddings: list[np.ndarray],
+        query_weights: list[float] | None = None,
         limit: int = 10,
-        node_types: Optional[List[NodeType]] = None,
+        node_types: list[NodeType] | None = None,
         min_similarity: float = 0.0,
         fusion_strategy: str = "weighted_max",
-    ) -> List[Tuple[str, float, Dict[str, Any]]]:
+    ) -> list[tuple[str, float, dict[str, Any]]]:
         """
         多查询融合搜索
@@ -287,7 +286,7 @@ class VectorStore:

         try:
             # 1. 对每个查询执行搜索
-            all_results: Dict[str, Dict[str, Any]] = {}  # node_id -> {scores, metadata}
+            all_results: dict[str, dict[str, Any]] = {}  # node_id -> {scores, metadata}

             for i, (query_emb, weight) in enumerate(zip(query_embeddings, query_weights)):
                 # 搜索更多结果以提高融合质量
@@ -356,7 +355,7 @@ class VectorStore:
             logger.error(f"多查询融合搜索失败: {e}", exc_info=True)
             raise

-    async def get_node_by_id(self, node_id: str) -> Optional[Dict[str, Any]]:
+    async def get_node_by_id(self, node_id: str) -> dict[str, Any] | None:
         """
         根据ID获取节点元数据
@@ -4,12 +4,12 @@ LLM 工具接口:定义记忆系统的工具 schema 和执行逻辑
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.memory_graph.core.builder import MemoryBuilder
|
||||
from src.memory_graph.core.extractor import MemoryExtractor
|
||||
from src.memory_graph.models import Memory, MemoryStatus
|
||||
from src.memory_graph.models import Memory
|
||||
from src.memory_graph.storage.graph_store import GraphStore
|
||||
from src.memory_graph.storage.persistence import PersistenceManager
|
||||
from src.memory_graph.storage.vector_store import VectorStore
|
||||
@@ -33,7 +33,7 @@ class MemoryTools:
|
||||
vector_store: VectorStore,
|
||||
graph_store: GraphStore,
|
||||
persistence_manager: PersistenceManager,
|
||||
embedding_generator: Optional[EmbeddingGenerator] = None,
|
||||
embedding_generator: EmbeddingGenerator | None = None,
|
||||
max_expand_depth: int = 1,
|
||||
expand_semantic_threshold: float = 0.3,
|
||||
):
|
||||
@@ -72,7 +72,7 @@ class MemoryTools:
self._initialized = True

@staticmethod
def get_create_memory_schema() -> Dict[str, Any]:
def get_create_memory_schema() -> dict[str, Any]:
"""
Get the JSON schema for the create_memory tool

@@ -183,7 +183,7 @@ class MemoryTools:
}

@staticmethod
def get_link_memories_schema() -> Dict[str, Any]:
def get_link_memories_schema() -> dict[str, Any]:
"""
Get the JSON schema for the link_memories tool

@@ -239,7 +239,7 @@ class MemoryTools:
}

@staticmethod
def get_search_memories_schema() -> Dict[str, Any]:
def get_search_memories_schema() -> dict[str, Any]:
"""
Get the JSON schema for the search_memories tool

@@ -307,7 +307,7 @@ class MemoryTools:
},
}

async def create_memory(self, **params) -> Dict[str, Any]:
async def create_memory(self, **params) -> dict[str, Any]:
"""
Execute the create_memory tool

@@ -353,7 +353,7 @@ class MemoryTools:
"message": "Memory creation failed",
}

async def link_memories(self, **params) -> Dict[str, Any]:
async def link_memories(self, **params) -> dict[str, Any]:
"""
Execute the link_memories tool

@@ -433,7 +433,7 @@ class MemoryTools:
"message": "Memory linking failed",
}

async def search_memories(self, **params) -> Dict[str, Any]:
async def search_memories(self, **params) -> dict[str, Any]:
"""
Execute the search_memories tool

@@ -486,7 +486,7 @@ class MemoryTools:
import orjson
try:
ids = orjson.loads(ids)
except:
except Exception:
ids = [ids]
if isinstance(ids, list):
for mem_id in ids:
@@ -526,8 +526,7 @@ class MemoryTools:
# )

# Merge expansion results
for mem_id, score in expanded_results:
expanded_memory_scores[mem_id] = score
expanded_memory_scores.update(dict(expanded_results))

logger.info(f"Graph expansion complete: added {len(expanded_memory_scores)} related memories")

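The replacement above is behavior-preserving: `update(dict(pairs))` writes the same keys as the explicit loop, including the detail that later pairs overwrite earlier ones. A quick check with invented pairs:

```python
expanded_results = [("m1", 0.8), ("m2", 0.6), ("m1", 0.9)]  # illustrative pairs

scores_loop: dict[str, float] = {}
for mem_id, score in expanded_results:
    scores_loop[mem_id] = score

scores_update: dict[str, float] = {}
scores_update.update(dict(expanded_results))

# Both forms keep the last score seen for a duplicated key
assert scores_loop == scores_update == {"m1": 0.9, "m2": 0.6}
```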
@@ -624,16 +623,16 @@ class MemoryTools:
}

async def _generate_multi_queries_simple(
self, query: str, context: Optional[Dict[str, Any]] = None
) -> List[Tuple[str, float]]:
self, query: str, context: dict[str, Any] | None = None
) -> list[tuple[str, float]]:
"""
Simplified multi-query generation (implemented directly in the Tools layer to avoid a circular dependency)

Has a small model directly generate 3-5 query statements from different angles.
"""
try:
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest

llm = LLMRequest(
model_set=model_config.model_task_config.utils_small,
@@ -648,10 +647,10 @@ class MemoryTools:
# Process the chat history, extracting roughly the 5 most recent messages
recent_chat = ""
if chat_history:
lines = chat_history.strip().split('\n')
lines = chat_history.strip().split("\n")
# Take the 5 most recent messages
recent_lines = lines[-5:] if len(lines) > 5 else lines
recent_chat = '\n'.join(recent_lines)
recent_chat = "\n".join(recent_lines)

prompt = f"""Based on the chat context, generate 3-5 search statements for the query from different angles (JSON format).

@@ -686,9 +685,11 @@ class MemoryTools:

response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250)

import orjson, re
response = re.sub(r'```json\s*', '', response)
response = re.sub(r'```\s*$', '', response).strip()
import re

import orjson
response = re.sub(r"```json\s*", "", response)
response = re.sub(r"```\s*$", "", response).strip()

data = orjson.loads(response)
queries = data.get("queries", [])
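Models often wrap JSON output in Markdown code fences, so the code strips a leading and a trailing fence before handing the text to `orjson.loads`. A self-contained run of that cleanup on an invented response (the `queries` payload shape is assumed from the return type above):

```python
import re

import orjson

# Invented model output, fenced the way chat models typically emit JSON
response = '```json\n{"queries": [["what places", 0.9], ["travel plans", 0.7]]}\n```'

# Strip the opening ```json fence, then the closing ``` fence at end of string
response = re.sub(r"```json\s*", "", response)
response = re.sub(r"```\s*$", "", response).strip()

data = orjson.loads(response)
print(data.get("queries", []))  # [['what places', 0.9], ['travel plans', 0.7]]
```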
@@ -707,7 +708,7 @@ class MemoryTools:

async def _single_query_search(
self, query: str, top_k: int
) -> List[Tuple[str, float, Dict[str, Any]]]:
) -> list[tuple[str, float, dict[str, Any]]]:
"""
Traditional single-query search

@@ -735,8 +736,8 @@ class MemoryTools:
return similar_nodes

async def _multi_query_search(
self, query: str, top_k: int, context: Optional[Dict[str, Any]] = None
) -> List[Tuple[str, float, Dict[str, Any]]]:
self, query: str, top_k: int, context: dict[str, Any] | None = None
) -> list[tuple[str, float, dict[str, Any]]]:
"""
Multi-query strategy search (simplified)

@@ -800,7 +801,7 @@ class MemoryTools:
if node.embedding is not None:
await self.vector_store.add_node(node)

async def _find_memory_by_description(self, description: str) -> Optional[Memory]:
async def _find_memory_by_description(self, description: str) -> Memory | None:
"""
Find a memory by its description

@@ -827,7 +828,7 @@ class MemoryTools:
return None

# Get the memories associated with the most similar node
node_id, similarity, metadata = similar_nodes[0]
_node_id, _similarity, metadata = similar_nodes[0]

if "memory_ids" not in metadata or not metadata["memory_ids"]:
return None
@@ -862,12 +863,12 @@ class MemoryTools:

async def _expand_with_semantic_filter(
self,
initial_memory_ids: List[str],
initial_memory_ids: list[str],
query_embedding,
max_depth: int = 2,
semantic_threshold: float = 0.5,
max_expanded: int = 20
) -> List[Tuple[str, float]]:
) -> list[tuple[str, float]]:
"""
Starting from the initial set of memories, expand along the graph structure, filtering by semantic similarity

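The hunks below show only fragments of this method; taken together it reads as a breadth-first walk over graph neighbors up to `max_depth`, keeping neighbors whose embedding clears `semantic_threshold` against the query. A toy sketch of that shape (the graph, embeddings, and scoring are illustrative; the real code walks `self.graph_store.graph` and blends in edge importance):

```python
import numpy as np

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Toy graph: node -> neighbors, plus per-node embeddings (illustrative)
neighbors = {"a": ["b", "c"], "b": ["d"], "c": [], "d": []}
emb = {n: np.random.rand(4).astype(np.float32) for n in "abcd"}

def expand(initial: list[str], query_emb: np.ndarray,
           max_depth: int = 2, threshold: float = 0.5) -> dict[str, float]:
    visited = set(initial)
    expanded: dict[str, float] = {}
    level = list(initial)
    for _ in range(max_depth):
        next_level = []
        for node in level:
            for nb in neighbors.get(node, []):
                if nb in visited:
                    continue
                visited.add(nb)
                sim = cosine(query_emb, emb[nb])
                if sim >= threshold:  # semantic filter gates further expansion
                    expanded[nb] = sim
                    next_level.append(nb)
        level = next_level
    return expanded

print(expand(["a"], np.random.rand(4).astype(np.float32)))
```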
@@ -885,10 +886,9 @@ class MemoryTools:
return []

try:
import numpy as np

visited_memories = set(initial_memory_ids)
expanded_memories: Dict[str, float] = {}
expanded_memories: dict[str, float] = {}

current_level = initial_memory_ids

@@ -906,7 +906,7 @@ class MemoryTools:

try:
neighbors = list(self.graph_store.graph.neighbors(node.id))
except:
except Exception:
continue

for neighbor_id in neighbors:

@@ -932,7 +932,7 @@ class MemoryTools:
try:
edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id)
edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
except:
except Exception:
edge_importance = 0.5

# Combined score

@@ -952,7 +952,7 @@ class MemoryTools:
import orjson
try:
neighbor_memory_ids = orjson.loads(neighbor_memory_ids)
except:
except Exception:
neighbor_memory_ids = [neighbor_memory_ids]

for neighbor_mem_id in neighbor_memory_ids:

@@ -1010,7 +1010,7 @@ class MemoryTools:
return 0.0

@staticmethod
def get_all_tool_schemas() -> List[Dict[str, Any]]:
def get_all_tool_schemas() -> list[dict[str, Any]]:
"""
Get the schemas for all tools

@@ -5,4 +5,4 @@
from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator
from src.memory_graph.utils.time_parser import TimeParser

__all__ = ["TimeParser", "EmbeddingGenerator", "get_embedding_generator"]
__all__ = ["EmbeddingGenerator", "TimeParser", "get_embedding_generator"]

@@ -5,8 +5,6 @@
from __future__ import annotations

import asyncio
from functools import lru_cache
from typing import List, Optional

import numpy as np

@@ -142,7 +140,7 @@ class EmbeddingGenerator:
dim = self._get_dimension()
return np.random.rand(dim).astype(np.float32)

async def _generate_with_api(self, text: str) -> Optional[np.ndarray]:
async def _generate_with_api(self, text: str) -> np.ndarray | None:
"""Generate an embedding via the API"""
try:
# Initialize the API

@@ -166,7 +164,7 @@ class EmbeddingGenerator:
logger.debug(f"API embedding generation failed: {e}")
return None

async def _generate_with_local_model(self, text: str) -> Optional[np.ndarray]:
async def _generate_with_local_model(self, text: str) -> np.ndarray | None:
"""Generate an embedding with the local model"""
try:
# Load the local model

@@ -204,13 +202,13 @@ class EmbeddingGenerator:
if self._local_model_loaded and self._local_model:
try:
return self._local_model.get_sentence_embedding_dimension()
except:
except Exception:
pass

# Default to 384 (a common sentence-transformers dimension)
return 384

async def generate_batch(self, texts: List[str]) -> List[np.ndarray]:
async def generate_batch(self, texts: list[str]) -> list[np.ndarray]:
"""
Generate embedding vectors in batch

@@ -251,7 +249,7 @@ class EmbeddingGenerator:
dim = self._get_dimension()
return [np.random.rand(dim).astype(np.float32) for _ in texts]

async def _generate_batch_with_api(self, texts: List[str]) -> Optional[List[np.ndarray]]:
async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray] | None:
"""Batch generation via the API"""
try:
# For most APIs, a batch call is just multiple individual calls
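As the comment in the hunk above notes, for most APIs batch generation is just repeated single calls. A minimal async sketch of that pattern, with `fake_embed` standing in for the real API call and 384 matching the default dimension mentioned earlier:

```python
import asyncio

import numpy as np

async def fake_embed(text: str) -> np.ndarray:
    """Stand-in for a real single-text embedding API call."""
    await asyncio.sleep(0)  # simulate I/O
    return np.random.rand(384).astype(np.float32)

async def generate_batch(texts: list[str]) -> list[np.ndarray]:
    # "Batch" generation as concurrent single calls
    return list(await asyncio.gather(*(fake_embed(t) for t in texts)))

vectors = asyncio.run(generate_batch(["hello", "world"]))
print(len(vectors), vectors[0].shape)  # 2 (384,)
```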
@@ -273,7 +271,7 @@ class EmbeddingGenerator:


# Global singleton
_global_generator: Optional[EmbeddingGenerator] = None
_global_generator: EmbeddingGenerator | None = None


def get_embedding_generator(

@@ -5,10 +5,9 @@
"""

import logging
from typing import Optional, List, Dict, Any
from datetime import datetime

from src.memory_graph.models import Memory, MemoryNode, NodeType, EdgeType, MemoryType
from src.memory_graph.models import EdgeType, Memory, MemoryType, NodeType

logger = logging.getLogger(__name__)

@@ -42,11 +41,9 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->

# 2. Find the topic node (predicate/action)
topic_node = None
memory_type_relation = None
for edge in memory.edges:
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id:
topic_node = memory.get_node_by_id(edge.target_id)
memory_type_relation = edge.relation
break

if not topic_node:

@@ -65,7 +62,7 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
break

# 4. Collect attribute nodes
attributes: Dict[str, str] = {}
attributes: dict[str, str] = {}
for edge in memory.edges:
if edge.edge_type == EdgeType.ATTRIBUTE:
# Find the attribute node and the value node
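Roughly, each ATTRIBUTE edge links an attribute node to its value node, and the collector typed above builds a name-to-value map from them. A toy sketch with simplified stand-ins for the real node/edge models (the field names here are illustrative, not the actual `src.memory_graph.models` API):

```python
from dataclasses import dataclass

@dataclass
class Node:
    id: str
    content: str

@dataclass
class Edge:
    source_id: str
    target_id: str
    edge_type: str

# Illustrative memory fragment: a "place" attribute whose value is "Tokyo"
nodes = {"n1": Node("n1", "place"), "n2": Node("n2", "Tokyo")}
edges = [Edge("n1", "n2", "ATTRIBUTE")]

attributes: dict[str, str] = {}
for edge in edges:
    if edge.edge_type == "ATTRIBUTE":
        attr_node, value_node = nodes[edge.source_id], nodes[edge.target_id]
        attributes[attr_node.content] = value_node.content

print(attributes)  # {'place': 'Tokyo'}
```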
@@ -143,8 +140,8 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->


def format_memories_for_prompt(
memories: List[Memory],
max_count: Optional[int] = None,
memories: list[Memory],
max_count: int | None = None,
include_metadata: bool = False,
group_by_type: bool = False
) -> str:

@@ -169,7 +166,7 @@ def format_memories_for_prompt(

# Group by type
if group_by_type:
type_groups: Dict[MemoryType, List[Memory]] = {}
type_groups: dict[MemoryType, list[Memory]] = {}
for memory in memories:
if memory.memory_type not in type_groups:
type_groups[memory.memory_type] = []
@@ -250,7 +247,7 @@ def get_memory_type_label(memory_type: str) -> str:
return type_mapping.get(memory_type_lower, "未知")


def _format_relative_time(timestamp: datetime) -> Optional[str]:
def _format_relative_time(timestamp: datetime) -> str | None:
"""
Format a relative time (e.g. "2天前" for "2 days ago", "刚才" for "just now")

@@ -316,8 +313,8 @@ def format_memory_summary(memory: Memory) -> str:

# Export the main functions
__all__ = [
'format_memory_for_prompt',
'format_memories_for_prompt',
'get_memory_type_label',
'format_memory_summary',
"format_memories_for_prompt",
"format_memory_for_prompt",
"format_memory_summary",
"get_memory_type_label",
]

@@ -14,7 +14,6 @@ from __future__ import annotations

import re
from datetime import datetime, timedelta
from typing import Optional, Tuple

from src.common.logger import get_logger

@@ -28,7 +27,7 @@ class TimeParser:
Converts natural-language time expressions into normalized absolute times
"""

def __init__(self, reference_time: Optional[datetime] = None):
def __init__(self, reference_time: datetime | None = None):
"""
Initialize the time parser

@@ -37,7 +36,7 @@ class TimeParser:
"""
self.reference_time = reference_time or datetime.now()

def parse(self, time_str: str) -> Optional[datetime]:
def parse(self, time_str: str) -> datetime | None:
"""
Parse a time string

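`parse` tries each `_parse_*` rule in turn and, as the warning in the next hunk shows, falls back to the reference time when nothing matches. A stripped-down sketch of that parser chain, implementing only the relative-day rule for illustration:

```python
from datetime import datetime, timedelta

class TinyTimeParser:
    """Illustrative subset of the parser-chain design."""

    def __init__(self, reference_time: datetime | None = None):
        self.reference_time = reference_time or datetime.now()

    def _parse_relative_day(self, time_str: str) -> datetime | None:
        offsets = {"今天": 0, "明天": 1, "昨天": -1, "前天": -2, "后天": 2}
        for word, days in offsets.items():
            if word in time_str:
                return self.reference_time + timedelta(days=days)
        return None

    def parse(self, time_str: str) -> datetime:
        for rule in (self._parse_relative_day,):  # the real class chains many rules
            if (result := rule(time_str)) is not None:
                return result
        return self.reference_time  # fallback: the reference time

ref = datetime(2025, 11, 5, 12, 0)
print(TinyTimeParser(ref).parse("明天"))  # 2025-11-06 12:00:00
```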
@@ -81,7 +80,7 @@ class TimeParser:
logger.warning(f"Unable to parse time: '{time_str}', using the current time")
return self.reference_time

def _parse_relative_day(self, time_str: str) -> Optional[datetime]:
def _parse_relative_day(self, time_str: str) -> datetime | None:
"""
Parse relative days: 今天, 明天, 昨天, 前天, 后天 (today, tomorrow, yesterday, the day before yesterday, the day after tomorrow)
"""

@@ -108,7 +107,7 @@ class TimeParser:

return None

def _parse_days_ago(self, time_str: str) -> Optional[datetime]:
def _parse_days_ago(self, time_str: str) -> datetime | None:
"""
Parse X天前/X天后, X周前/X周后, X个月前/X个月后 (X days/weeks/months ago or from now)
"""

@@ -172,7 +171,7 @@ class TimeParser:

return None

def _parse_hours_ago(self, time_str: str) -> Optional[datetime]:
def _parse_hours_ago(self, time_str: str) -> datetime | None:
"""
Parse X小时前/X小时后, X分钟前/X分钟后 (X hours/minutes ago or from now)
"""

@@ -204,7 +203,7 @@ class TimeParser:

return None

def _parse_week_month_year(self, time_str: str) -> Optional[datetime]:
def _parse_week_month_year(self, time_str: str) -> datetime | None:
"""
Parse: 上周, 上个月, 去年, 本周, 本月, 今年 (last week/month/year, this week/month/year)
"""

@@ -232,7 +231,7 @@ class TimeParser:

return None

def _parse_specific_date(self, time_str: str) -> Optional[datetime]:
def _parse_specific_date(self, time_str: str) -> datetime | None:
"""
Parse specific dates:
- 2025-11-05

@@ -266,7 +265,7 @@ class TimeParser:

return None

def _parse_time_of_day(self, time_str: str) -> Optional[datetime]:
def _parse_time_of_day(self, time_str: str) -> datetime | None:
"""
Parse times of day:
- 早上, 上午, 中午, 下午, 晚上, 深夜 (morning, forenoon, noon, afternoon, evening, late night)
@@ -290,7 +289,7 @@ class TimeParser:
}

# First check for a concrete time point: 早上8点, 下午3点 (8 in the morning, 3 in the afternoon)
for period, default_hour in time_periods.items():
for period in time_periods.keys():
pattern = rf"{period}(\d{{1,2}})点?"
match = re.search(pattern, time_str)
if match:
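Since this branch only needs the period keywords, `.items()` becomes `.keys()`. For a concrete point like 早上8点 or 下午3点 the regex pulls out the hour digits; the sketch below reproduces that match (the hour table and the afternoon/evening 12-hour shift are guesses at the surrounding logic, not taken from the diff):

```python
import re

# Illustrative default hours per period keyword
time_periods = {"早上": 8, "上午": 10, "中午": 12, "下午": 15, "晚上": 20}

def match_time_point(time_str: str) -> int | None:
    # e.g. "早上8点" -> 8; "下午3点" -> 15 (afternoon/evening shifted to 24h)
    for period in time_periods.keys():
        match = re.search(rf"{period}(\d{{1,2}})点?", time_str)
        if match:
            hour = int(match.group(1))
            if period in ("下午", "晚上") and hour < 12:
                hour += 12
            return hour
    return None

print(match_time_point("下午3点"))  # 15
print(match_time_point("早上8点"))  # 8
```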
@@ -314,7 +313,7 @@ class TimeParser:

return None

def _parse_combined_time(self, time_str: str) -> Optional[datetime]:
def _parse_combined_time(self, time_str: str) -> datetime | None:
"""
Parse combined time expressions: 今天下午, 昨天晚上, 明天早上 (this afternoon, last night, tomorrow morning)
"""

@@ -461,7 +460,7 @@ class TimeParser:

return str(dt)

def parse_time_range(self, time_str: str) -> Tuple[Optional[datetime], Optional[datetime]]:
def parse_time_range(self, time_str: str) -> tuple[datetime | None, datetime | None]:
"""
Parse time ranges: 最近一周 (the last week), 最近3天 (the last 3 days)