feat(memory-graph): Phase 1 foundational architecture

- Define the core data models (MemoryNode, MemoryEdge, Memory)
- Implement the configuration management system (MemoryGraphConfig)
- Implement the vector storage layer (VectorStore, backed by ChromaDB)
- Implement the graph storage layer (GraphStore, backed by NetworkX)
- Create the design document outline
- Add basic tests, all passing

Still to do:
- Persistence management
- Node deduplication logic
- Memory builder
- Memory retriever
docs/memory_graph/design_outline.md (new file, 1534 lines; diff suppressed because it is too large)
src/memory_graph/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
"""
Memory Graph System

A hybrid memory architecture combining a knowledge graph with semantic vectors.
"""

from src.memory_graph.models import (
    Memory,
    MemoryEdge,
    MemoryNode,
    MemoryStatus,
    MemoryType,
    NodeType,
    EdgeType,
)

__all__ = [
    "Memory",
    "MemoryNode",
    "MemoryEdge",
    "MemoryType",
    "NodeType",
    "EdgeType",
    "MemoryStatus",
]

__version__ = "0.1.0"
src/memory_graph/config.py (new file, 145 lines)
@@ -0,0 +1,145 @@
"""
Configuration management for the memory graph system.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Optional


@dataclass
class ConsolidationConfig:
    """Memory consolidation settings."""

    interval_hours: int = 6  # consolidation interval (hours)
    batch_size: int = 100  # number of memories processed per run
    enable_auto_discovery: bool = True  # enable automatic relation discovery
    enable_conflict_detection: bool = True  # enable conflict detection


@dataclass
class RetrievalConfig:
    """Memory retrieval settings."""

    default_mode: str = "auto"  # auto/fast/deep
    max_expand_depth: int = 2  # maximum graph expansion depth
    vector_weight: float = 0.4  # weight of vector similarity
    graph_distance_weight: float = 0.2  # weight of graph distance
    importance_weight: float = 0.2  # weight of importance
    recency_weight: float = 0.2  # weight of recency

    def __post_init__(self):
        """Validate that the weights sum to 1.0."""
        total = self.vector_weight + self.graph_distance_weight + self.importance_weight + self.recency_weight
        if abs(total - 1.0) > 0.01:
            raise ValueError(f"Weights must sum to 1.0, got {total}")


@dataclass
class NodeMergerConfig:
    """Node deduplication settings."""

    similarity_threshold: float = 0.85  # similarity threshold
    context_match_required: bool = True  # require matching context before merging
    merge_batch_size: int = 50  # batch size for merge processing

    def __post_init__(self):
        """Validate the threshold range."""
        if not 0.0 <= self.similarity_threshold <= 1.0:
            raise ValueError(f"Similarity threshold must be within [0, 1], got {self.similarity_threshold}")


@dataclass
class StorageConfig:
    """Storage settings."""

    data_dir: Path = field(default_factory=lambda: Path("data/memory_graph"))
    vector_collection_name: str = "memory_nodes"
    graph_file_name: str = "memory_graph.json"
    enable_persistence: bool = True  # enable persistence
    auto_save_interval: int = 300  # auto-save interval (seconds)


@dataclass
class MemoryGraphConfig:
    """Top-level configuration for the memory graph system."""

    consolidation: ConsolidationConfig = field(default_factory=ConsolidationConfig)
    retrieval: RetrievalConfig = field(default_factory=RetrievalConfig)
    node_merger: NodeMergerConfig = field(default_factory=NodeMergerConfig)
    storage: StorageConfig = field(default_factory=StorageConfig)

    # time-decay settings
    decay_rates: Dict[str, float] = field(
        default_factory=lambda: {
            "EVENT": 0.05,  # events decay quickly
            "FACT": 0.01,  # facts decay slowly
            "RELATION": 0.005,  # relations decay very slowly
            "OPINION": 0.03,  # opinions decay at a moderate rate
        }
    )

    # embedding model settings
    embedding_model: Optional[str] = None  # None means use the system default
    embedding_dimension: int = 384  # default dimension of sentence-transformers models

    # debugging and logging
    enable_debug_logging: bool = False
    enable_visualization: bool = False  # enable memory visualization

    @classmethod
    def from_dict(cls, config_dict: Dict) -> MemoryGraphConfig:
        """Create a configuration from a dictionary."""
        return cls(
            consolidation=ConsolidationConfig(**config_dict.get("consolidation", {})),
            retrieval=RetrievalConfig(**config_dict.get("retrieval", {})),
            node_merger=NodeMergerConfig(**config_dict.get("node_merger", {})),
            storage=StorageConfig(**config_dict.get("storage", {})),
            decay_rates=config_dict.get("decay_rates", cls().decay_rates),
            embedding_model=config_dict.get("embedding_model"),
            embedding_dimension=config_dict.get("embedding_dimension", 384),
            enable_debug_logging=config_dict.get("enable_debug_logging", False),
            enable_visualization=config_dict.get("enable_visualization", False),
        )

    def to_dict(self) -> Dict:
        """Convert the configuration to a dictionary."""
        return {
            "consolidation": {
                "interval_hours": self.consolidation.interval_hours,
                "batch_size": self.consolidation.batch_size,
                "enable_auto_discovery": self.consolidation.enable_auto_discovery,
                "enable_conflict_detection": self.consolidation.enable_conflict_detection,
            },
            "retrieval": {
                "default_mode": self.retrieval.default_mode,
                "max_expand_depth": self.retrieval.max_expand_depth,
                "vector_weight": self.retrieval.vector_weight,
                "graph_distance_weight": self.retrieval.graph_distance_weight,
                "importance_weight": self.retrieval.importance_weight,
                "recency_weight": self.retrieval.recency_weight,
            },
            "node_merger": {
                "similarity_threshold": self.node_merger.similarity_threshold,
                "context_match_required": self.node_merger.context_match_required,
                "merge_batch_size": self.node_merger.merge_batch_size,
            },
            "storage": {
                "data_dir": str(self.storage.data_dir),
                "vector_collection_name": self.storage.vector_collection_name,
                "graph_file_name": self.storage.graph_file_name,
                "enable_persistence": self.storage.enable_persistence,
                "auto_save_interval": self.storage.auto_save_interval,
            },
            "decay_rates": self.decay_rates,
            "embedding_model": self.embedding_model,
            "embedding_dimension": self.embedding_dimension,
            "enable_debug_logging": self.enable_debug_logging,
            "enable_visualization": self.enable_visualization,
        }


# default configuration instance
DEFAULT_CONFIG = MemoryGraphConfig()
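To make the config round-trip and the weight validation concrete, here is a minimal usage sketch (not part of this commit). It assumes only the classes defined above; the weight values are illustrative.

from src.memory_graph.config import MemoryGraphConfig, RetrievalConfig

# build a config from a partial dict; unspecified sections keep their defaults
cfg = MemoryGraphConfig.from_dict({
    "retrieval": {"vector_weight": 0.5, "graph_distance_weight": 0.2,
                  "importance_weight": 0.2, "recency_weight": 0.1},
})
assert cfg.to_dict()["retrieval"]["vector_weight"] == 0.5

# weights that do not sum to 1.0 are rejected in __post_init__
try:
    RetrievalConfig(vector_weight=0.9)  # 0.9 + 0.2 + 0.2 + 0.2 = 1.5
except ValueError as e:
    print(e)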
src/memory_graph/models.py (new file, 294 lines)
@@ -0,0 +1,294 @@
"""
Core data models for the memory graph system.

Defines the node, edge, and memory data structures.
"""

from __future__ import annotations

import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

import numpy as np


class NodeType(Enum):
    """Node type enumeration (values are kept in Chinese for serialization compatibility)."""

    SUBJECT = "主体"  # subject of the memory (e.g., "I", "Xiao Ming", "the teacher")
    TOPIC = "主题"  # action or state (e.g., eating, mood, studying)
    OBJECT = "客体"  # object (e.g., plain rice, school, a book)
    ATTRIBUTE = "属性"  # extended attribute (e.g., time, place, cause)
    VALUE = "值"  # concrete value of an attribute (e.g., 2025-11-05, unhappy)


class MemoryType(Enum):
    """Memory type enumeration."""

    EVENT = "事件"  # an action anchored to a point in time
    FACT = "事实"  # a relatively stable state
    RELATION = "关系"  # an interpersonal relationship
    OPINION = "观点"  # a subjective evaluation


class EdgeType(Enum):
    """Edge type enumeration."""

    MEMORY_TYPE = "记忆类型"  # subject → topic
    CORE_RELATION = "核心关系"  # topic → object (is/does/has)
    ATTRIBUTE = "属性关系"  # any node → attribute
    CAUSALITY = "因果关系"  # memory → memory
    REFERENCE = "引用关系"  # memory → memory (reported speech)


class MemoryStatus(Enum):
    """Memory status enumeration."""

    STAGED = "staged"  # temporary, not yet consolidated
    CONSOLIDATED = "consolidated"  # consolidated
    ARCHIVED = "archived"  # archived (low value, rarely accessed)


@dataclass
class MemoryNode:
    """A node in the memory graph."""

    id: str  # unique node ID
    content: str  # node content (e.g., "I", "eat", "plain rice")
    node_type: NodeType  # node type
    embedding: Optional[np.ndarray] = None  # semantic vector (only topic/object nodes need one)
    metadata: Dict[str, Any] = field(default_factory=dict)  # extension metadata
    created_at: datetime = field(default_factory=datetime.now)

    def __post_init__(self):
        """Post-initialization handling."""
        if not self.id:
            self.id = str(uuid.uuid4())

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary (for serialization)."""
        return {
            "id": self.id,
            "content": self.content,
            "node_type": self.node_type.value,
            "embedding": self.embedding.tolist() if self.embedding is not None else None,
            "metadata": self.metadata,
            "created_at": self.created_at.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> MemoryNode:
        """Create a node from a dictionary."""
        embedding = None
        if data.get("embedding") is not None:
            embedding = np.array(data["embedding"])

        return cls(
            id=data["id"],
            content=data["content"],
            node_type=NodeType(data["node_type"]),
            embedding=embedding,
            metadata=data.get("metadata", {}),
            created_at=datetime.fromisoformat(data["created_at"]),
        )

    def has_embedding(self) -> bool:
        """Whether this node has a semantic vector."""
        return self.embedding is not None

    def __str__(self) -> str:
        return f"Node({self.node_type.value}: {self.content})"


@dataclass
class MemoryEdge:
    """An edge (relation) between memory nodes."""

    id: str  # unique edge ID
    source_id: str  # source node ID
    target_id: str  # target node ID (or target memory ID)
    relation: str  # relation name (e.g., "is", "does", "time", "because")
    edge_type: EdgeType  # edge type
    importance: float = 0.5  # importance in [0, 1]
    metadata: Dict[str, Any] = field(default_factory=dict)  # extension metadata
    created_at: datetime = field(default_factory=datetime.now)

    def __post_init__(self):
        """Post-initialization handling."""
        if not self.id:
            self.id = str(uuid.uuid4())
        # clamp importance to the valid range
        self.importance = max(0.0, min(1.0, self.importance))

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary (for serialization)."""
        return {
            "id": self.id,
            "source_id": self.source_id,
            "target_id": self.target_id,
            "relation": self.relation,
            "edge_type": self.edge_type.value,
            "importance": self.importance,
            "metadata": self.metadata,
            "created_at": self.created_at.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> MemoryEdge:
        """Create an edge from a dictionary."""
        return cls(
            id=data["id"],
            source_id=data["source_id"],
            target_id=data["target_id"],
            relation=data["relation"],
            edge_type=EdgeType(data["edge_type"]),
            importance=data.get("importance", 0.5),
            metadata=data.get("metadata", {}),
            created_at=datetime.fromisoformat(data["created_at"]),
        )

    def __str__(self) -> str:
        return f"Edge({self.source_id} --{self.relation}--> {self.target_id})"


@dataclass
class Memory:
    """A complete memory (a subgraph of nodes and edges)."""

    id: str  # unique memory ID
    subject_id: str  # subject node ID
    memory_type: MemoryType  # memory type
    nodes: List[MemoryNode]  # all nodes belonging to this memory
    edges: List[MemoryEdge]  # all edges belonging to this memory
    importance: float = 0.5  # overall importance in [0, 1]
    status: MemoryStatus = MemoryStatus.STAGED  # memory status
    created_at: datetime = field(default_factory=datetime.now)
    last_accessed: datetime = field(default_factory=datetime.now)  # last access time
    access_count: int = 0  # access count
    decay_factor: float = 1.0  # decay factor (changes over time)
    metadata: Dict[str, Any] = field(default_factory=dict)  # extension metadata

    def __post_init__(self):
        """Post-initialization handling."""
        if not self.id:
            self.id = str(uuid.uuid4())
        # clamp importance to the valid range
        self.importance = max(0.0, min(1.0, self.importance))

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary (for serialization)."""
        return {
            "id": self.id,
            "subject_id": self.subject_id,
            "memory_type": self.memory_type.value,
            "nodes": [node.to_dict() for node in self.nodes],
            "edges": [edge.to_dict() for edge in self.edges],
            "importance": self.importance,
            "status": self.status.value,
            "created_at": self.created_at.isoformat(),
            "last_accessed": self.last_accessed.isoformat(),
            "access_count": self.access_count,
            "decay_factor": self.decay_factor,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> Memory:
        """Create a memory from a dictionary."""
        return cls(
            id=data["id"],
            subject_id=data["subject_id"],
            memory_type=MemoryType(data["memory_type"]),
            nodes=[MemoryNode.from_dict(n) for n in data["nodes"]],
            edges=[MemoryEdge.from_dict(e) for e in data["edges"]],
            importance=data.get("importance", 0.5),
            status=MemoryStatus(data.get("status", "staged")),
            created_at=datetime.fromisoformat(data["created_at"]),
            last_accessed=datetime.fromisoformat(data.get("last_accessed", data["created_at"])),
            access_count=data.get("access_count", 0),
            decay_factor=data.get("decay_factor", 1.0),
            metadata=data.get("metadata", {}),
        )

    def update_access(self) -> None:
        """Record an access to this memory."""
        self.last_accessed = datetime.now()
        self.access_count += 1

    def get_node_by_id(self, node_id: str) -> Optional[MemoryNode]:
        """Look up a node by ID."""
        for node in self.nodes:
            if node.id == node_id:
                return node
        return None

    def get_subject_node(self) -> Optional[MemoryNode]:
        """Get the subject node."""
        return self.get_node_by_id(self.subject_id)

    def to_text(self) -> str:
        """Convert to a textual description (for display and LLM processing)."""
        subject_node = self.get_subject_node()
        if not subject_node:
            return f"[Memory {self.id[:8]}]"

        # simple text generation logic
        parts = [f"{subject_node.content}"]

        # find the topic node (connected via a MEMORY_TYPE edge)
        topic_node = None
        for edge in self.edges:
            if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == self.subject_id:
                topic_node = self.get_node_by_id(edge.target_id)
                break

        if topic_node:
            parts.append(topic_node.content)

            # find the object node (connected via a CORE_RELATION edge);
            # this lookup stays inside the guard so topic_node is never None here
            for edge in self.edges:
                if edge.edge_type == EdgeType.CORE_RELATION and edge.source_id == topic_node.id:
                    obj_node = self.get_node_by_id(edge.target_id)
                    if obj_node:
                        parts.append(f"{edge.relation} {obj_node.content}")
                    break

        return " ".join(parts)

    def __str__(self) -> str:
        return f"Memory({self.memory_type.value}: {self.to_text()})"


@dataclass
class StagedMemory:
    """A staged (not yet consolidated) memory."""

    memory: Memory  # the original memory object
    status: MemoryStatus = MemoryStatus.STAGED  # status
    created_at: datetime = field(default_factory=datetime.now)
    consolidated_at: Optional[datetime] = None  # consolidation time
    merge_history: List[str] = field(default_factory=list)  # IDs of nodes merged into this memory

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary."""
        return {
            "memory": self.memory.to_dict(),
            "status": self.status.value,
            "created_at": self.created_at.isoformat(),
            "consolidated_at": self.consolidated_at.isoformat() if self.consolidated_at else None,
            "merge_history": self.merge_history,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> StagedMemory:
        """Create a staged memory from a dictionary."""
        return cls(
            memory=Memory.from_dict(data["memory"]),
            status=MemoryStatus(data.get("status", "staged")),
            created_at=datetime.fromisoformat(data["created_at"]),
            consolidated_at=datetime.fromisoformat(data["consolidated_at"]) if data.get("consolidated_at") else None,
            merge_history=data.get("merge_history", []),
        )
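A minimal construction sketch (not part of this commit) showing how a subject/topic/object subgraph becomes a Memory and how to_text() renders it. The node contents and relation names are illustrative only.

from src.memory_graph.models import (
    Memory, MemoryEdge, MemoryNode, MemoryType, NodeType, EdgeType,
)

# empty IDs are replaced with fresh UUIDs in __post_init__
subject = MemoryNode(id="", content="我", node_type=NodeType.SUBJECT)
topic = MemoryNode(id="", content="吃饭", node_type=NodeType.TOPIC)
obj = MemoryNode(id="", content="白米饭", node_type=NodeType.OBJECT)

edges = [
    MemoryEdge(id="", source_id=subject.id, target_id=topic.id,
               relation="事件", edge_type=EdgeType.MEMORY_TYPE),
    MemoryEdge(id="", source_id=topic.id, target_id=obj.id,
               relation="吃", edge_type=EdgeType.CORE_RELATION),
]

memory = Memory(id="", subject_id=subject.id, memory_type=MemoryType.EVENT,
                nodes=[subject, topic, obj], edges=edges)

print(memory.to_text())  # "我 吃饭 吃 白米饭"
assert Memory.from_dict(memory.to_dict()).id == memory.id  # serialization round-trip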
src/memory_graph/storage/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""
Storage layer module.
"""

from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.vector_store import VectorStore

__all__ = ["VectorStore", "GraphStore"]
src/memory_graph/storage/graph_store.py (new file, 389 lines)
@@ -0,0 +1,389 @@
"""
Graph storage layer: graph structure management built on NetworkX.
"""

from __future__ import annotations

from typing import Dict, List, Optional, Set, Tuple

import networkx as nx

from src.common.logger import get_logger
from src.memory_graph.models import Memory, MemoryEdge, MemoryNode

logger = get_logger(__name__)


class GraphStore:
    """
    Graph storage wrapper.

    Responsibilities:
    1. Building and maintaining the memory graph
    2. Fast lookup of nodes and edges
    3. Graph traversal algorithms (BFS/DFS)
    4. Adjacency queries
    """

    def __init__(self):
        """Initialize the graph store."""
        # use a directed graph (memory relations are usually directional)
        self.graph = nx.DiGraph()

        # index: memory ID -> memory object
        self.memory_index: Dict[str, Memory] = {}

        # index: node ID -> set of memory IDs the node belongs to
        self.node_to_memories: Dict[str, Set[str]] = {}

        logger.info("Graph store initialized")

    def add_memory(self, memory: Memory) -> None:
        """
        Add a memory to the graph.

        Args:
            memory: the memory to add
        """
        try:
            # 1. add all nodes to the graph
            for node in memory.nodes:
                if not self.graph.has_node(node.id):
                    self.graph.add_node(
                        node.id,
                        content=node.content,
                        node_type=node.node_type.value,
                        created_at=node.created_at.isoformat(),
                        metadata=node.metadata,
                    )

                # update the node-to-memories mapping
                if node.id not in self.node_to_memories:
                    self.node_to_memories[node.id] = set()
                self.node_to_memories[node.id].add(memory.id)

            # 2. add all edges to the graph
            for edge in memory.edges:
                self.graph.add_edge(
                    edge.source_id,
                    edge.target_id,
                    edge_id=edge.id,
                    relation=edge.relation,
                    edge_type=edge.edge_type.value,
                    importance=edge.importance,
                    metadata=edge.metadata,
                    created_at=edge.created_at.isoformat(),
                )

            # 3. store the memory object
            self.memory_index[memory.id] = memory

            logger.debug(f"Added memory to graph: {memory}")

        except Exception as e:
            logger.error(f"Failed to add memory: {e}", exc_info=True)
            raise

    def get_memory_by_id(self, memory_id: str) -> Optional[Memory]:
        """
        Get a memory by ID.

        Args:
            memory_id: memory ID

        Returns:
            The memory object, or None.
        """
        return self.memory_index.get(memory_id)

    def get_memories_by_node(self, node_id: str) -> List[Memory]:
        """
        Get all memories that contain the given node.

        Args:
            node_id: node ID

        Returns:
            List of memories.
        """
        if node_id not in self.node_to_memories:
            return []

        memory_ids = self.node_to_memories[node_id]
        return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index]

    def get_edges_from_node(self, node_id: str, relation_types: Optional[List[str]] = None) -> List[Dict]:
        """
        Get all outgoing edges of the given node.

        Args:
            node_id: source node ID
            relation_types: optional relation-type filter

        Returns:
            List of edge info dictionaries.
        """
        if not self.graph.has_node(node_id):
            return []

        edges = []
        for _, target_id, edge_data in self.graph.out_edges(node_id, data=True):
            # filter by relation type
            if relation_types and edge_data.get("relation") not in relation_types:
                continue

            edges.append(
                {
                    "source_id": node_id,
                    "target_id": target_id,
                    "relation": edge_data.get("relation"),
                    "edge_type": edge_data.get("edge_type"),
                    "importance": edge_data.get("importance", 0.5),
                    **edge_data,
                }
            )

        return edges

    def get_neighbors(
        self, node_id: str, direction: str = "out", relation_types: Optional[List[str]] = None
    ) -> List[Tuple[str, Dict]]:
        """
        Get the neighbors of a node.

        Args:
            node_id: node ID
            direction: "out" = outgoing edges, "in" = incoming edges, "both" = both
            relation_types: optional relation-type filter

        Returns:
            List of (neighbor_id, edge_data).
        """
        if not self.graph.has_node(node_id):
            return []

        neighbors = []

        # outgoing edges
        if direction in ["out", "both"]:
            for _, target_id, edge_data in self.graph.out_edges(node_id, data=True):
                if not relation_types or edge_data.get("relation") in relation_types:
                    neighbors.append((target_id, edge_data))

        # incoming edges
        if direction in ["in", "both"]:
            for source_id, _, edge_data in self.graph.in_edges(node_id, data=True):
                if not relation_types or edge_data.get("relation") in relation_types:
                    neighbors.append((source_id, edge_data))

        return neighbors

    def find_path(self, source_id: str, target_id: str, max_length: Optional[int] = None) -> Optional[List[str]]:
        """
        Find the shortest path between two nodes.

        Args:
            source_id: source node ID
            target_id: target node ID
            max_length: optional maximum path length (in edges)

        Returns:
            List of node IDs along the path, or None if no path exists.
        """
        if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
            return None

        try:
            if max_length:
                # compute the shortest path, then reject it if it exceeds max_length
                path = nx.shortest_path(self.graph, source_id, target_id, weight=None)
                if len(path) - 1 <= max_length:  # edge count = node count - 1
                    return path
                return None
            else:
                return nx.shortest_path(self.graph, source_id, target_id, weight=None)

        except nx.NetworkXNoPath:
            return None
        except Exception as e:
            logger.error(f"Path lookup failed: {e}", exc_info=True)
            return None

    def bfs_expand(
        self,
        start_nodes: List[str],
        depth: int = 1,
        relation_types: Optional[List[str]] = None,
    ) -> Set[str]:
        """
        Expand from the start nodes via breadth-first search.

        Args:
            start_nodes: list of starting node IDs
            depth: expansion depth
            relation_types: optional relation-type filter

        Returns:
            Set of all reached node IDs.
        """
        visited = set()
        queue = [(node_id, 0) for node_id in start_nodes if self.graph.has_node(node_id)]

        while queue:
            current_node, current_depth = queue.pop(0)

            if current_node in visited:
                continue
            visited.add(current_node)

            if current_depth >= depth:
                continue

            # enqueue neighbors
            neighbors = self.get_neighbors(current_node, direction="out", relation_types=relation_types)
            for neighbor_id, _ in neighbors:
                if neighbor_id not in visited:
                    queue.append((neighbor_id, current_depth + 1))

        return visited

    def get_subgraph(self, node_ids: List[str]) -> nx.DiGraph:
        """
        Get the subgraph induced by the given nodes.

        Args:
            node_ids: list of node IDs

        Returns:
            A NetworkX subgraph copy.
        """
        return self.graph.subgraph(node_ids).copy()

    def merge_nodes(self, source_id: str, target_id: str) -> None:
        """
        Merge two nodes (move all of source's edges onto target, then delete source).

        Args:
            source_id: source node ID (will be deleted)
            target_id: target node ID (kept)
        """
        if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
            logger.warning(f"Node merge failed: node missing ({source_id}, {target_id})")
            return

        try:
            # 1. move incoming edges
            for pred, _, edge_data in self.graph.in_edges(source_id, data=True):
                if pred != target_id:  # avoid self-loops
                    self.graph.add_edge(pred, target_id, **edge_data)

            # 2. move outgoing edges
            for _, succ, edge_data in self.graph.out_edges(source_id, data=True):
                if succ != target_id:  # avoid self-loops
                    self.graph.add_edge(target_id, succ, **edge_data)

            # 3. update the node-to-memories mapping
            if source_id in self.node_to_memories:
                memory_ids = self.node_to_memories[source_id]
                if target_id not in self.node_to_memories:
                    self.node_to_memories[target_id] = set()
                self.node_to_memories[target_id].update(memory_ids)
                del self.node_to_memories[source_id]

            # 4. remove the source node
            self.graph.remove_node(source_id)

            logger.info(f"Nodes merged: {source_id} → {target_id}")

        except Exception as e:
            logger.error(f"Node merge failed: {e}", exc_info=True)
            raise

    def get_node_degree(self, node_id: str) -> Tuple[int, int]:
        """
        Get a node's degree.

        Args:
            node_id: node ID

        Returns:
            (in_degree, out_degree)
        """
        if not self.graph.has_node(node_id):
            return (0, 0)

        return (self.graph.in_degree(node_id), self.graph.out_degree(node_id))

    def get_statistics(self) -> Dict[str, int]:
        """Get summary statistics of the graph."""
        return {
            "total_nodes": self.graph.number_of_nodes(),
            "total_edges": self.graph.number_of_edges(),
            "total_memories": len(self.memory_index),
            "connected_components": nx.number_weakly_connected_components(self.graph),
        }

    def to_dict(self) -> Dict:
        """
        Convert the graph to a dictionary (for persistence).

        Returns:
            Dictionary representation of the graph.
        """
        return {
            "nodes": [
                {"id": node_id, **self.graph.nodes[node_id]} for node_id in self.graph.nodes()
            ],
            "edges": [
                {
                    "source": u,
                    "target": v,
                    **data,
                }
                for u, v, data in self.graph.edges(data=True)
            ],
            "memories": {memory_id: memory.to_dict() for memory_id, memory in self.memory_index.items()},
            "node_to_memories": {node_id: list(mem_ids) for node_id, mem_ids in self.node_to_memories.items()},
        }

    @classmethod
    def from_dict(cls, data: Dict) -> GraphStore:
        """
        Load a graph from a dictionary.

        Args:
            data: dictionary representation of the graph

        Returns:
            A GraphStore instance.
        """
        store = cls()

        # 1. load nodes
        for node_data in data.get("nodes", []):
            node_id = node_data.pop("id")
            store.graph.add_node(node_id, **node_data)

        # 2. load edges
        for edge_data in data.get("edges", []):
            source = edge_data.pop("source")
            target = edge_data.pop("target")
            store.graph.add_edge(source, target, **edge_data)

        # 3. load memories
        for memory_id, memory_dict in data.get("memories", {}).items():
            store.memory_index[memory_id] = Memory.from_dict(memory_dict)

        # 4. load the node-to-memories mapping
        for node_id, mem_ids in data.get("node_to_memories", {}).items():
            store.node_to_memories[node_id] = set(mem_ids)

        logger.info(f"Graph loaded from dict: {store.get_statistics()}")
        return store

    def clear(self) -> None:
        """Clear the graph (destructive; for tests only)."""
        self.graph.clear()
        self.memory_index.clear()
        self.node_to_memories.clear()
        logger.warning("Graph store cleared")
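A quick usage sketch (not part of this commit) exercising add_memory, BFS expansion, and the persistence round-trip. The two-node memory and its relation name are illustrative only.

from src.memory_graph.models import (
    Memory, MemoryEdge, MemoryNode, MemoryType, NodeType, EdgeType,
)
from src.memory_graph.storage.graph_store import GraphStore

a = MemoryNode(id="", content="我", node_type=NodeType.SUBJECT)
b = MemoryNode(id="", content="学习", node_type=NodeType.TOPIC)
m = Memory(
    id="", subject_id=a.id, memory_type=MemoryType.FACT,
    nodes=[a, b],
    edges=[MemoryEdge(id="", source_id=a.id, target_id=b.id,
                      relation="记忆类型", edge_type=EdgeType.MEMORY_TYPE)],
)

store = GraphStore()
store.add_memory(m)

# one hop out from the subject reaches the topic node
print(store.bfs_expand([a.id], depth=1))  # {a.id, b.id}
print(store.get_statistics())  # 2 nodes, 1 edge, 1 memory, 1 component

# persistence round-trip preserves the statistics
assert GraphStore.from_dict(store.to_dict()).get_statistics() == store.get_statistics()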
src/memory_graph/storage/vector_store.py (new file, 297 lines)
@@ -0,0 +1,297 @@
"""
Vector storage layer: semantic vector storage built on ChromaDB.
"""

from __future__ import annotations

import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from src.common.logger import get_logger
from src.memory_graph.models import MemoryNode, NodeType

logger = get_logger(__name__)


class VectorStore:
    """
    Vector storage wrapper.

    Responsibilities:
    1. Storing and retrieving node embeddings
    2. Similarity-based vector search
    3. Finding similar nodes during deduplication
    """

    def __init__(
        self,
        collection_name: str = "memory_nodes",
        data_dir: Optional[Path] = None,
        embedding_function: Optional[Any] = None,
    ):
        """
        Initialize the vector store.

        Args:
            collection_name: ChromaDB collection name
            data_dir: data directory
            embedding_function: embedding function (None = use the default)
        """
        self.collection_name = collection_name
        self.data_dir = data_dir or Path("data/memory_graph")
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.client = None
        self.collection = None
        self.embedding_function = embedding_function

        logger.info(f"Vector store initialized: collection={collection_name}, dir={self.data_dir}")

    async def initialize(self) -> None:
        """Asynchronously initialize ChromaDB."""
        try:
            import chromadb
            from chromadb.config import Settings

            # create a persistent client
            self.client = chromadb.PersistentClient(
                path=str(self.data_dir / "chroma"),
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True,
                ),
            )

            # get or create the collection
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "Memory graph node embeddings"},
            )

            logger.info(f"ChromaDB initialized; collection holds {self.collection.count()} nodes")

        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB: {e}", exc_info=True)
            raise

    async def add_node(self, node: MemoryNode) -> None:
        """
        Add a node to the vector store.

        Args:
            node: the node to add
        """
        if not self.collection:
            raise RuntimeError("Vector store is not initialized")

        if not node.has_embedding():
            logger.warning(f"Node {node.id} has no embedding; skipping")
            return

        try:
            self.collection.add(
                ids=[node.id],
                embeddings=[node.embedding.tolist()],
                metadatas=[
                    {
                        "content": node.content,
                        "node_type": node.node_type.value,
                        "created_at": node.created_at.isoformat(),
                        **node.metadata,
                    }
                ],
                documents=[node.content],  # text content, used for retrieval
            )

            logger.debug(f"Added node to vector store: {node}")

        except Exception as e:
            logger.error(f"Failed to add node: {e}", exc_info=True)
            raise

    async def add_nodes_batch(self, nodes: List[MemoryNode]) -> None:
        """
        Add nodes in batch.

        Args:
            nodes: list of nodes
        """
        if not self.collection:
            raise RuntimeError("Vector store is not initialized")

        # keep only nodes that have an embedding
        valid_nodes = [n for n in nodes if n.has_embedding()]

        if not valid_nodes:
            logger.warning("Batch add: no valid nodes (missing embeddings)")
            return

        try:
            self.collection.add(
                ids=[n.id for n in valid_nodes],
                embeddings=[n.embedding.tolist() for n in valid_nodes],
                metadatas=[
                    {
                        "content": n.content,
                        "node_type": n.node_type.value,
                        "created_at": n.created_at.isoformat(),
                        **n.metadata,
                    }
                    for n in valid_nodes
                ],
                documents=[n.content for n in valid_nodes],
            )

            logger.info(f"Batch-added {len(valid_nodes)} nodes to the vector store")

        except Exception as e:
            logger.error(f"Batch add failed: {e}", exc_info=True)
            raise

    async def search_similar_nodes(
        self,
        query_embedding: np.ndarray,
        limit: int = 10,
        node_types: Optional[List[NodeType]] = None,
        min_similarity: float = 0.0,
    ) -> List[Tuple[str, float, Dict[str, Any]]]:
        """
        Search for similar nodes.

        Args:
            query_embedding: query vector
            limit: number of results to return
            node_types: optional node-type restriction
            min_similarity: minimum similarity threshold

        Returns:
            List of (node_id, similarity, metadata).
        """
        if not self.collection:
            raise RuntimeError("Vector store is not initialized")

        try:
            # build the where filter
            where_filter = None
            if node_types:
                where_filter = {"node_type": {"$in": [nt.value for nt in node_types]}}

            # run the query
            results = self.collection.query(
                query_embeddings=[query_embedding.tolist()],
                n_results=limit,
                where=where_filter,
            )

            # parse results
            similar_nodes = []
            if results["ids"] and results["ids"][0]:
                for i, node_id in enumerate(results["ids"][0]):
                    # ChromaDB returns distances; convert to similarity
                    # cosine distance: distance = 1 - similarity
                    distance = results["distances"][0][i]
                    similarity = 1.0 - distance

                    if similarity >= min_similarity:
                        metadata = results["metadatas"][0][i] if results["metadatas"] else {}
                        similar_nodes.append((node_id, similarity, metadata))

            logger.debug(f"Similarity search: {len(similar_nodes)} results")
            return similar_nodes

        except Exception as e:
            logger.error(f"Similarity search failed: {e}", exc_info=True)
            raise

    async def get_node_by_id(self, node_id: str) -> Optional[Dict[str, Any]]:
        """
        Get node metadata by ID.

        Args:
            node_id: node ID

        Returns:
            Node metadata, or None.
        """
        if not self.collection:
            raise RuntimeError("Vector store is not initialized")

        try:
            result = self.collection.get(ids=[node_id], include=["metadatas", "embeddings"])

            if result["ids"]:
                return {
                    "id": result["ids"][0],
                    "metadata": result["metadatas"][0] if result["metadatas"] else {},
                    "embedding": np.array(result["embeddings"][0]) if result["embeddings"] else None,
                }

            return None

        except Exception as e:
            logger.error(f"Failed to get node: {e}", exc_info=True)
            return None

    async def delete_node(self, node_id: str) -> None:
        """
        Delete a node.

        Args:
            node_id: node ID
        """
        if not self.collection:
            raise RuntimeError("Vector store is not initialized")

        try:
            self.collection.delete(ids=[node_id])
            logger.debug(f"Deleted node: {node_id}")

        except Exception as e:
            logger.error(f"Failed to delete node: {e}", exc_info=True)
            raise

    async def update_node_embedding(self, node_id: str, embedding: np.ndarray) -> None:
        """
        Update a node's embedding.

        Args:
            node_id: node ID
            embedding: the new vector
        """
        if not self.collection:
            raise RuntimeError("Vector store is not initialized")

        try:
            self.collection.update(ids=[node_id], embeddings=[embedding.tolist()])
            logger.debug(f"Updated node embedding: {node_id}")

        except Exception as e:
            logger.error(f"Failed to update node embedding: {e}", exc_info=True)
            raise

    def get_total_count(self) -> int:
        """Total number of nodes in the vector store."""
        if not self.collection:
            return 0
        return self.collection.count()

    async def clear(self) -> None:
        """Clear the vector store (destructive; for tests only)."""
        if not self.collection:
            return

        try:
            # drop and recreate the collection
            self.client.delete_collection(self.collection_name)
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "Memory graph node embeddings"},
            )
            logger.warning(f"Vector store cleared: {self.collection_name}")

        except Exception as e:
            logger.error(f"Failed to clear vector store: {e}", exc_info=True)
            raise
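A minimal async usage sketch (not part of this commit). The collection name "demo_nodes", the /tmp data directory, and the random vectors are hypothetical stand-ins; real callers would pass model-generated embeddings. Also note that the 1 - distance conversion in search_similar_nodes assumes a cosine-style distance, while ChromaDB collections default to L2 unless configured otherwise; for a self-match query the two agree (distance ≈ 0).

import asyncio
from pathlib import Path

import numpy as np

from src.memory_graph.models import MemoryNode, NodeType
from src.memory_graph.storage.vector_store import VectorStore


async def main() -> None:
    # hypothetical collection name and data directory for the demo
    store = VectorStore(collection_name="demo_nodes", data_dir=Path("/tmp/memory_graph_demo"))
    await store.initialize()

    # a random 384-dim vector stands in for a real sentence embedding
    rng = np.random.default_rng(0)
    node = MemoryNode(id="", content="吃饭", node_type=NodeType.TOPIC,
                      embedding=rng.random(384))
    await store.add_node(node)

    # querying with the node's own vector should return it with similarity ~1.0
    hits = await store.search_similar_nodes(node.embedding, limit=5, min_similarity=0.5)
    print(hits[0][0] == node.id, round(hits[0][1], 3))


asyncio.run(main())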