feat(memory-graph): Phase 1 base architecture implementation

- Define the core data models (MemoryNode, MemoryEdge, Memory)
- Implement the configuration management system (MemoryGraphConfig)
- Implement the vector storage layer (VectorStore, backed by ChromaDB)
- Implement the graph storage layer (GraphStore, backed by NetworkX)
- Create the design document outline
- Add basic tests and verify they pass

Still to do:
- Persistence management
- Node deduplication logic
- Memory builder
- Memory retriever
Windpicker-owo committed 2025-11-05 16:46:53 +08:00
parent 25c50f759f
commit 47af755805
7 changed files with 2694 additions and 0 deletions
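The four pieces listed in the message are meant to be used together: GraphStore keeps the relational structure in a NetworkX DiGraph, while VectorStore keeps node embeddings in a persistent ChromaDB collection. A rough sketch of the intended wiring (illustrative only, not part of this diff; assumes chromadb and networkx are installed):

import asyncio
from src.memory_graph.config import DEFAULT_CONFIG
from src.memory_graph.models import Memory, MemoryNode, MemoryType, NodeType
from src.memory_graph.storage import GraphStore, VectorStore

async def main():
    # Vector layer: persistent ChromaDB collection under the configured data dir.
    vector_store = VectorStore(
        collection_name=DEFAULT_CONFIG.storage.vector_collection_name,
        data_dir=DEFAULT_CONFIG.storage.data_dir,
    )
    await vector_store.initialize()
    # Graph layer: in-memory NetworkX DiGraph plus memory indexes.
    graph_store = GraphStore()
    # A one-node memory is enough to exercise the graph layer.
    subject = MemoryNode(id="", content="我", node_type=NodeType.SUBJECT)
    memory = Memory(id="", subject_id=subject.id, memory_type=MemoryType.FACT,
                    nodes=[subject], edges=[])
    graph_store.add_memory(memory)
    print(graph_store.get_statistics())    # {'total_nodes': 1, 'total_edges': 0, ...}
    print(vector_store.get_total_count())

asyncio.run(main())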

File diff suppressed because it is too large.

src/memory_graph/__init__.py

@@ -0,0 +1,27 @@
"""
记忆图系统 (Memory Graph System)
基于知识图谱 + 语义向量的混合记忆架构
"""
from src.memory_graph.models import (
Memory,
MemoryEdge,
MemoryNode,
MemoryStatus,
MemoryType,
NodeType,
EdgeType,
)
__all__ = [
"Memory",
"MemoryNode",
"MemoryEdge",
"MemoryType",
"NodeType",
"EdgeType",
"MemoryStatus",
]
__version__ = "0.1.0"

src/memory_graph/config.py

@@ -0,0 +1,145 @@
"""
记忆图系统配置管理
"""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Optional
@dataclass
class ConsolidationConfig:
"""记忆整理配置"""
interval_hours: int = 6 # 整理间隔(小时)
batch_size: int = 100 # 每次处理记忆数量
enable_auto_discovery: bool = True # 是否启用自动关联发现
enable_conflict_detection: bool = True # 是否启用冲突检测
@dataclass
class RetrievalConfig:
"""记忆检索配置"""
default_mode: str = "auto" # auto/fast/deep
max_expand_depth: int = 2 # 最大图扩展深度
vector_weight: float = 0.4 # 向量相似度权重
graph_distance_weight: float = 0.2 # 图距离权重
importance_weight: float = 0.2 # 重要性权重
recency_weight: float = 0.2 # 时效性权重
def __post_init__(self):
"""验证权重总和"""
total = self.vector_weight + self.graph_distance_weight + self.importance_weight + self.recency_weight
if abs(total - 1.0) > 0.01:
raise ValueError(f"权重总和必须为1.0,当前为 {total}")
@dataclass
class NodeMergerConfig:
"""节点去重配置"""
similarity_threshold: float = 0.85 # 相似度阈值
context_match_required: bool = True # 是否要求上下文匹配
merge_batch_size: int = 50 # 批量处理大小
def __post_init__(self):
"""验证阈值范围"""
if not 0.0 <= self.similarity_threshold <= 1.0:
raise ValueError(f"相似度阈值必须在 [0, 1] 范围内,当前为 {self.similarity_threshold}")
@dataclass
class StorageConfig:
"""存储配置"""
data_dir: Path = field(default_factory=lambda: Path("data/memory_graph"))
vector_collection_name: str = "memory_nodes"
graph_file_name: str = "memory_graph.json"
enable_persistence: bool = True # 是否启用持久化
auto_save_interval: int = 300 # 自动保存间隔(秒)
@dataclass
class MemoryGraphConfig:
"""记忆图系统总配置"""
consolidation: ConsolidationConfig = field(default_factory=ConsolidationConfig)
retrieval: RetrievalConfig = field(default_factory=RetrievalConfig)
node_merger: NodeMergerConfig = field(default_factory=NodeMergerConfig)
storage: StorageConfig = field(default_factory=StorageConfig)
# 时间衰减配置
decay_rates: Dict[str, float] = field(
default_factory=lambda: {
"EVENT": 0.05, # 事件衰减较快
"FACT": 0.01, # 事实衰减慢
"RELATION": 0.005, # 关系衰减很慢
"OPINION": 0.03, # 观点中等衰减
}
)
# 嵌入模型配置
embedding_model: Optional[str] = None # 如果为None则使用系统默认
embedding_dimension: int = 384 # 默认使用 sentence-transformers 的维度
# 调试和日志
enable_debug_logging: bool = False
enable_visualization: bool = False # 是否启用记忆可视化
@classmethod
def from_dict(cls, config_dict: Dict) -> MemoryGraphConfig:
"""从字典创建配置"""
return cls(
consolidation=ConsolidationConfig(**config_dict.get("consolidation", {})),
retrieval=RetrievalConfig(**config_dict.get("retrieval", {})),
node_merger=NodeMergerConfig(**config_dict.get("node_merger", {})),
storage=StorageConfig(**config_dict.get("storage", {})),
decay_rates=config_dict.get("decay_rates", cls().decay_rates),
embedding_model=config_dict.get("embedding_model"),
embedding_dimension=config_dict.get("embedding_dimension", 384),
enable_debug_logging=config_dict.get("enable_debug_logging", False),
enable_visualization=config_dict.get("enable_visualization", False),
)
def to_dict(self) -> Dict:
"""转换为字典"""
return {
"consolidation": {
"interval_hours": self.consolidation.interval_hours,
"batch_size": self.consolidation.batch_size,
"enable_auto_discovery": self.consolidation.enable_auto_discovery,
"enable_conflict_detection": self.consolidation.enable_conflict_detection,
},
"retrieval": {
"default_mode": self.retrieval.default_mode,
"max_expand_depth": self.retrieval.max_expand_depth,
"vector_weight": self.retrieval.vector_weight,
"graph_distance_weight": self.retrieval.graph_distance_weight,
"importance_weight": self.retrieval.importance_weight,
"recency_weight": self.retrieval.recency_weight,
},
"node_merger": {
"similarity_threshold": self.node_merger.similarity_threshold,
"context_match_required": self.node_merger.context_match_required,
"merge_batch_size": self.node_merger.merge_batch_size,
},
"storage": {
"data_dir": str(self.storage.data_dir),
"vector_collection_name": self.storage.vector_collection_name,
"graph_file_name": self.storage.graph_file_name,
"enable_persistence": self.storage.enable_persistence,
"auto_save_interval": self.storage.auto_save_interval,
},
"decay_rates": self.decay_rates,
"embedding_model": self.embedding_model,
"embedding_dimension": self.embedding_dimension,
"enable_debug_logging": self.enable_debug_logging,
"enable_visualization": self.enable_visualization,
}
# 默认配置实例
DEFAULT_CONFIG = MemoryGraphConfig()
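For reference, a hedged sketch of how overrides flow through from_dict (illustrative only; the keys mirror the dataclasses above, and RetrievalConfig.__post_init__ enforces that the four retrieval weights sum to 1.0):

from src.memory_graph.config import MemoryGraphConfig

config = MemoryGraphConfig.from_dict({
    # The four retrieval weights must sum to 1.0, or __post_init__ raises ValueError.
    "retrieval": {
        "vector_weight": 0.5,
        "graph_distance_weight": 0.2,
        "importance_weight": 0.2,
        "recency_weight": 0.1,
    },
    "node_merger": {"similarity_threshold": 0.9},
    "decay_rates": {"EVENT": 0.08, "FACT": 0.01, "RELATION": 0.005, "OPINION": 0.03},
})
print(config.retrieval.vector_weight)            # 0.5
print(config.node_merger.similarity_threshold)   # 0.9
print(config.to_dict()["storage"]["data_dir"])   # default storage dir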

src/memory_graph/models.py

@@ -0,0 +1,294 @@
"""
记忆图系统核心数据模型
定义节点、边、记忆等核心数据结构
"""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
import numpy as np
class NodeType(Enum):
"""节点类型枚举"""
SUBJECT = "主体" # 记忆的主语(我、小明、老师)
TOPIC = "主题" # 动作或状态(吃饭、情绪、学习)
OBJECT = "客体" # 宾语(白米饭、学校、书)
ATTRIBUTE = "属性" # 延伸属性(时间、地点、原因)
VALUE = "值" # 属性的具体值(如:2025-11-05、不开心)
class MemoryType(Enum):
"""记忆类型枚举"""
EVENT = "事件" # 有时间点的动作
FACT = "事实" # 相对稳定的状态
RELATION = "关系" # 人际关系
OPINION = "观点" # 主观评价
class EdgeType(Enum):
"""边类型枚举"""
MEMORY_TYPE = "记忆类型" # 主体 → 主题
CORE_RELATION = "核心关系" # 主题 → 客体(是/做/有)
ATTRIBUTE = "属性关系" # 任意节点 → 属性
CAUSALITY = "因果关系" # 记忆 → 记忆
REFERENCE = "引用关系" # 记忆 → 记忆(转述)
class MemoryStatus(Enum):
"""记忆状态枚举"""
STAGED = "staged" # 临时状态,未整理
CONSOLIDATED = "consolidated" # 已整理
ARCHIVED = "archived" # 已归档(低价值,很少访问)
@dataclass
class MemoryNode:
"""记忆节点"""
id: str # 节点唯一ID
content: str # 节点内容(如:"我"、"吃饭"、"白米饭")
node_type: NodeType # 节点类型
embedding: Optional[np.ndarray] = None # 语义向量(仅主题/客体需要)
metadata: Dict[str, Any] = field(default_factory=dict) # 扩展元数据
created_at: datetime = field(default_factory=datetime.now)
def __post_init__(self):
"""后初始化处理"""
if not self.id:
self.id = str(uuid.uuid4())
def to_dict(self) -> Dict[str, Any]:
"""转换为字典(用于序列化)"""
return {
"id": self.id,
"content": self.content,
"node_type": self.node_type.value,
"embedding": self.embedding.tolist() if self.embedding is not None else None,
"metadata": self.metadata,
"created_at": self.created_at.isoformat(),
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> MemoryNode:
"""从字典创建节点"""
embedding = None
if data.get("embedding") is not None:
embedding = np.array(data["embedding"])
return cls(
id=data["id"],
content=data["content"],
node_type=NodeType(data["node_type"]),
embedding=embedding,
metadata=data.get("metadata", {}),
created_at=datetime.fromisoformat(data["created_at"]),
)
def has_embedding(self) -> bool:
"""是否有语义向量"""
return self.embedding is not None
def __str__(self) -> str:
return f"Node({self.node_type.value}: {self.content})"
@dataclass
class MemoryEdge:
"""记忆边(节点之间的关系)"""
id: str # 边唯一ID
source_id: str # 源节点ID
target_id: str # 目标节点ID或目标记忆ID
relation: str # 关系名称(如:"是"、"做"、"时间"、"因为")
edge_type: EdgeType # 边类型
importance: float = 0.5 # 重要性 [0-1]
metadata: Dict[str, Any] = field(default_factory=dict) # 扩展元数据
created_at: datetime = field(default_factory=datetime.now)
def __post_init__(self):
"""后初始化处理"""
if not self.id:
self.id = str(uuid.uuid4())
# 确保重要性在有效范围内
self.importance = max(0.0, min(1.0, self.importance))
def to_dict(self) -> Dict[str, Any]:
"""转换为字典(用于序列化)"""
return {
"id": self.id,
"source_id": self.source_id,
"target_id": self.target_id,
"relation": self.relation,
"edge_type": self.edge_type.value,
"importance": self.importance,
"metadata": self.metadata,
"created_at": self.created_at.isoformat(),
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> MemoryEdge:
"""从字典创建边"""
return cls(
id=data["id"],
source_id=data["source_id"],
target_id=data["target_id"],
relation=data["relation"],
edge_type=EdgeType(data["edge_type"]),
importance=data.get("importance", 0.5),
metadata=data.get("metadata", {}),
created_at=datetime.fromisoformat(data["created_at"]),
)
def __str__(self) -> str:
return f"Edge({self.source_id} --{self.relation}--> {self.target_id})"
@dataclass
class Memory:
"""完整记忆(由节点和边组成的子图)"""
id: str # 记忆唯一ID
subject_id: str # 主体节点ID
memory_type: MemoryType # 记忆类型
nodes: List[MemoryNode] # 该记忆包含的所有节点
edges: List[MemoryEdge] # 该记忆包含的所有边
importance: float = 0.5 # 整体重要性 [0-1]
status: MemoryStatus = MemoryStatus.STAGED # 记忆状态
created_at: datetime = field(default_factory=datetime.now)
last_accessed: datetime = field(default_factory=datetime.now) # 最后访问时间
access_count: int = 0 # 访问次数
decay_factor: float = 1.0 # 衰减因子(随时间变化)
metadata: Dict[str, Any] = field(default_factory=dict) # 扩展元数据
def __post_init__(self):
"""后初始化处理"""
if not self.id:
self.id = str(uuid.uuid4())
# 确保重要性在有效范围内
self.importance = max(0.0, min(1.0, self.importance))
def to_dict(self) -> Dict[str, Any]:
"""转换为字典(用于序列化)"""
return {
"id": self.id,
"subject_id": self.subject_id,
"memory_type": self.memory_type.value,
"nodes": [node.to_dict() for node in self.nodes],
"edges": [edge.to_dict() for edge in self.edges],
"importance": self.importance,
"status": self.status.value,
"created_at": self.created_at.isoformat(),
"last_accessed": self.last_accessed.isoformat(),
"access_count": self.access_count,
"decay_factor": self.decay_factor,
"metadata": self.metadata,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Memory:
"""从字典创建记忆"""
return cls(
id=data["id"],
subject_id=data["subject_id"],
memory_type=MemoryType(data["memory_type"]),
nodes=[MemoryNode.from_dict(n) for n in data["nodes"]],
edges=[MemoryEdge.from_dict(e) for e in data["edges"]],
importance=data.get("importance", 0.5),
status=MemoryStatus(data.get("status", "staged")),
created_at=datetime.fromisoformat(data["created_at"]),
last_accessed=datetime.fromisoformat(data.get("last_accessed", data["created_at"])),
access_count=data.get("access_count", 0),
decay_factor=data.get("decay_factor", 1.0),
metadata=data.get("metadata", {}),
)
def update_access(self) -> None:
"""更新访问记录"""
self.last_accessed = datetime.now()
self.access_count += 1
def get_node_by_id(self, node_id: str) -> Optional[MemoryNode]:
"""根据ID获取节点"""
for node in self.nodes:
if node.id == node_id:
return node
return None
def get_subject_node(self) -> Optional[MemoryNode]:
"""获取主体节点"""
return self.get_node_by_id(self.subject_id)
def to_text(self) -> str:
"""转换为文本描述用于显示和LLM处理"""
subject_node = self.get_subject_node()
if not subject_node:
return f"[记忆 {self.id[:8]}]"
# 简单的文本生成逻辑
parts = [f"{subject_node.content}"]
# 查找主题节点(通过记忆类型边连接)
topic_node = None
for edge in self.edges:
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == self.subject_id:
topic_node = self.get_node_by_id(edge.target_id)
break
if topic_node:
parts.append(topic_node.content)
# 查找客体节点(通过核心关系边连接)
for edge in self.edges:
if edge.edge_type == EdgeType.CORE_RELATION and edge.source_id == topic_node.id:
obj_node = self.get_node_by_id(edge.target_id)
if obj_node:
parts.append(f"{edge.relation} {obj_node.content}")
break
return " ".join(parts)
def __str__(self) -> str:
return f"Memory({self.memory_type.value}: {self.to_text()})"
@dataclass
class StagedMemory:
"""临时记忆(未整理状态)"""
memory: Memory # 原始记忆对象
status: MemoryStatus = MemoryStatus.STAGED # 状态
created_at: datetime = field(default_factory=datetime.now)
consolidated_at: Optional[datetime] = None # 整理时间
merge_history: List[str] = field(default_factory=list) # 被合并的节点ID列表
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
"memory": self.memory.to_dict(),
"status": self.status.value,
"created_at": self.created_at.isoformat(),
"consolidated_at": self.consolidated_at.isoformat() if self.consolidated_at else None,
"merge_history": self.merge_history,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> StagedMemory:
"""从字典创建临时记忆"""
return cls(
memory=Memory.from_dict(data["memory"]),
status=MemoryStatus(data.get("status", "staged")),
created_at=datetime.fromisoformat(data["created_at"]),
consolidated_at=datetime.fromisoformat(data["consolidated_at"]) if data.get("consolidated_at") else None,
merge_history=data.get("merge_history", []),
)
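A minimal round-trip through these models (illustrative only, not part of the diff; the random vector stands in for a real embedding):

import numpy as np
from src.memory_graph.models import (
    EdgeType, Memory, MemoryEdge, MemoryNode, MemoryType, NodeType,
)

# "我 / 吃饭 / 白米饭": subject -> topic -> object, following the edge types above.
subject = MemoryNode(id="", content="我", node_type=NodeType.SUBJECT)
topic = MemoryNode(id="", content="吃饭", node_type=NodeType.TOPIC,
                   embedding=np.random.rand(384))  # placeholder embedding
obj = MemoryNode(id="", content="白米饭", node_type=NodeType.OBJECT)
memory = Memory(
    id="",
    subject_id=subject.id,
    memory_type=MemoryType.EVENT,
    nodes=[subject, topic, obj],
    edges=[
        MemoryEdge(id="", source_id=subject.id, target_id=topic.id,
                   relation="记忆类型", edge_type=EdgeType.MEMORY_TYPE),
        MemoryEdge(id="", source_id=topic.id, target_id=obj.id,
                   relation="吃", edge_type=EdgeType.CORE_RELATION, importance=0.7),
    ],
)
print(memory.to_text())                      # 我 吃饭 吃 白米饭
restored = Memory.from_dict(memory.to_dict())
assert restored.id == memory.id and len(restored.nodes) == 3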

src/memory_graph/storage/__init__.py

@@ -0,0 +1,8 @@
"""
存储层模块
"""
from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.vector_store import VectorStore
__all__ = ["VectorStore", "GraphStore"]

src/memory_graph/storage/graph_store.py

@@ -0,0 +1,389 @@
"""
图存储层:基于 NetworkX 的图结构管理
"""
from __future__ import annotations
from typing import Dict, List, Optional, Set, Tuple
import networkx as nx
from src.common.logger import get_logger
from src.memory_graph.models import Memory, MemoryEdge, MemoryNode
logger = get_logger(__name__)
class GraphStore:
"""
图存储封装类
负责:
1. 记忆图的构建和维护
2. 节点和边的快速查询
3. 图遍历算法BFS/DFS
4. 邻接关系查询
"""
def __init__(self):
"""初始化图存储"""
# 使用有向图(记忆关系通常是有向的)
self.graph = nx.DiGraph()
# 索引记忆ID -> 记忆对象
self.memory_index: Dict[str, Memory] = {}
# 索引节点ID -> 所属记忆ID集合
self.node_to_memories: Dict[str, Set[str]] = {}
logger.info("初始化图存储")
def add_memory(self, memory: Memory) -> None:
"""
添加记忆到图
Args:
memory: 要添加的记忆
"""
try:
# 1. 添加所有节点到图
for node in memory.nodes:
if not self.graph.has_node(node.id):
self.graph.add_node(
node.id,
content=node.content,
node_type=node.node_type.value,
created_at=node.created_at.isoformat(),
metadata=node.metadata,
)
# 更新节点到记忆的映射
if node.id not in self.node_to_memories:
self.node_to_memories[node.id] = set()
self.node_to_memories[node.id].add(memory.id)
# 2. 添加所有边到图
for edge in memory.edges:
self.graph.add_edge(
edge.source_id,
edge.target_id,
edge_id=edge.id,
relation=edge.relation,
edge_type=edge.edge_type.value,
importance=edge.importance,
metadata=edge.metadata,
created_at=edge.created_at.isoformat(),
)
# 3. 保存记忆对象
self.memory_index[memory.id] = memory
logger.debug(f"添加记忆到图: {memory}")
except Exception as e:
logger.error(f"添加记忆失败: {e}", exc_info=True)
raise
def get_memory_by_id(self, memory_id: str) -> Optional[Memory]:
"""
根据ID获取记忆
Args:
memory_id: 记忆ID
Returns:
记忆对象或 None
"""
return self.memory_index.get(memory_id)
def get_memories_by_node(self, node_id: str) -> List[Memory]:
"""
获取包含指定节点的所有记忆
Args:
node_id: 节点ID
Returns:
记忆列表
"""
if node_id not in self.node_to_memories:
return []
memory_ids = self.node_to_memories[node_id]
return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index]
def get_edges_from_node(self, node_id: str, relation_types: Optional[List[str]] = None) -> List[Dict]:
"""
获取从指定节点出发的所有边
Args:
node_id: 源节点ID
relation_types: 关系类型过滤(可选)
Returns:
边信息列表
"""
if not self.graph.has_node(node_id):
return []
edges = []
for _, target_id, edge_data in self.graph.out_edges(node_id, data=True):
# 过滤关系类型
if relation_types and edge_data.get("relation") not in relation_types:
continue
edges.append(
{
"source_id": node_id,
"target_id": target_id,
"relation": edge_data.get("relation"),
"edge_type": edge_data.get("edge_type"),
"importance": edge_data.get("importance", 0.5),
**edge_data,
}
)
return edges
def get_neighbors(
self, node_id: str, direction: str = "out", relation_types: Optional[List[str]] = None
) -> List[Tuple[str, Dict]]:
"""
获取节点的邻居节点
Args:
node_id: 节点ID
direction: 方向 ("out"=出边, "in"=入边, "both"=双向)
relation_types: 关系类型过滤
Returns:
List of (neighbor_id, edge_data)
"""
if not self.graph.has_node(node_id):
return []
neighbors = []
# 处理出边
if direction in ["out", "both"]:
for _, target_id, edge_data in self.graph.out_edges(node_id, data=True):
if not relation_types or edge_data.get("relation") in relation_types:
neighbors.append((target_id, edge_data))
# 处理入边
if direction in ["in", "both"]:
for source_id, _, edge_data in self.graph.in_edges(node_id, data=True):
if not relation_types or edge_data.get("relation") in relation_types:
neighbors.append((source_id, edge_data))
return neighbors
def find_path(self, source_id: str, target_id: str, max_length: Optional[int] = None) -> Optional[List[str]]:
"""
查找两个节点之间的最短路径
Args:
source_id: 源节点ID
target_id: 目标节点ID
max_length: 最大路径长度(可选)
Returns:
路径节点ID列表或 None如果不存在路径
"""
if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
return None
try:
if max_length:
# 先求最短路径,再检查其边数是否超过 max_length
path = nx.shortest_path(self.graph, source_id, target_id, weight=None)
if len(path) - 1 <= max_length: # 边数 = 节点数 - 1
return path
return None
else:
return nx.shortest_path(self.graph, source_id, target_id, weight=None)
except nx.NetworkXNoPath:
return None
except Exception as e:
logger.error(f"查找路径失败: {e}", exc_info=True)
return None
def bfs_expand(
self,
start_nodes: List[str],
depth: int = 1,
relation_types: Optional[List[str]] = None,
) -> Set[str]:
"""
从起始节点进行广度优先搜索扩展
Args:
start_nodes: 起始节点ID列表
depth: 扩展深度
relation_types: 关系类型过滤
Returns:
扩展到的所有节点ID集合
"""
visited = set()
queue = [(node_id, 0) for node_id in start_nodes if self.graph.has_node(node_id)]
while queue:
current_node, current_depth = queue.pop(0)
if current_node in visited:
continue
visited.add(current_node)
if current_depth >= depth:
continue
# 获取邻居并加入队列
neighbors = self.get_neighbors(current_node, direction="out", relation_types=relation_types)
for neighbor_id, _ in neighbors:
if neighbor_id not in visited:
queue.append((neighbor_id, current_depth + 1))
return visited
def get_subgraph(self, node_ids: List[str]) -> nx.DiGraph:
"""
获取包含指定节点的子图
Args:
node_ids: 节点ID列表
Returns:
NetworkX 子图
"""
return self.graph.subgraph(node_ids).copy()
def merge_nodes(self, source_id: str, target_id: str) -> None:
"""
合并两个节点将source的所有边转移到target然后删除source
Args:
source_id: 源节点ID将被删除
target_id: 目标节点ID保留
"""
if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
logger.warning(f"合并节点失败: 节点不存在 ({source_id}, {target_id})")
return
try:
# 1. 转移入边
for pred, _, edge_data in self.graph.in_edges(source_id, data=True):
if pred != target_id: # 避免自环
self.graph.add_edge(pred, target_id, **edge_data)
# 2. 转移出边
for _, succ, edge_data in self.graph.out_edges(source_id, data=True):
if succ != target_id: # 避免自环
self.graph.add_edge(target_id, succ, **edge_data)
# 3. 更新节点到记忆的映射
if source_id in self.node_to_memories:
memory_ids = self.node_to_memories[source_id]
if target_id not in self.node_to_memories:
self.node_to_memories[target_id] = set()
self.node_to_memories[target_id].update(memory_ids)
del self.node_to_memories[source_id]
# 4. 删除源节点
self.graph.remove_node(source_id)
logger.info(f"节点合并: {source_id}{target_id}")
except Exception as e:
logger.error(f"合并节点失败: {e}", exc_info=True)
raise
def get_node_degree(self, node_id: str) -> Tuple[int, int]:
"""
获取节点的度数
Args:
node_id: 节点ID
Returns:
(in_degree, out_degree)
"""
if not self.graph.has_node(node_id):
return (0, 0)
return (self.graph.in_degree(node_id), self.graph.out_degree(node_id))
def get_statistics(self) -> Dict[str, int]:
"""获取图的统计信息"""
return {
"total_nodes": self.graph.number_of_nodes(),
"total_edges": self.graph.number_of_edges(),
"total_memories": len(self.memory_index),
"connected_components": nx.number_weakly_connected_components(self.graph),
}
def to_dict(self) -> Dict:
"""
将图转换为字典(用于持久化)
Returns:
图的字典表示
"""
return {
"nodes": [
{"id": node_id, **self.graph.nodes[node_id]} for node_id in self.graph.nodes()
],
"edges": [
{
"source": u,
"target": v,
**data,
}
for u, v, data in self.graph.edges(data=True)
],
"memories": {memory_id: memory.to_dict() for memory_id, memory in self.memory_index.items()},
"node_to_memories": {node_id: list(mem_ids) for node_id, mem_ids in self.node_to_memories.items()},
}
@classmethod
def from_dict(cls, data: Dict) -> GraphStore:
"""
从字典加载图
Args:
data: 图的字典表示
Returns:
GraphStore 实例
"""
store = cls()
# 1. 加载节点
for node_data in data.get("nodes", []):
node_id = node_data.pop("id")
store.graph.add_node(node_id, **node_data)
# 2. 加载边
for edge_data in data.get("edges", []):
source = edge_data.pop("source")
target = edge_data.pop("target")
store.graph.add_edge(source, target, **edge_data)
# 3. 加载记忆
for memory_id, memory_dict in data.get("memories", {}).items():
store.memory_index[memory_id] = Memory.from_dict(memory_dict)
# 4. 加载节点到记忆的映射
for node_id, mem_ids in data.get("node_to_memories", {}).items():
store.node_to_memories[node_id] = set(mem_ids)
logger.info(f"从字典加载图: {store.get_statistics()}")
return store
def clear(self) -> None:
"""清空图(危险操作,仅用于测试)"""
self.graph.clear()
self.memory_index.clear()
self.node_to_memories.clear()
logger.warning("图存储已清空")
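A usage sketch for the graph layer (illustrative only; node and edge construction follows models.py above):

from src.memory_graph.models import (
    EdgeType, Memory, MemoryEdge, MemoryNode, MemoryType, NodeType,
)
from src.memory_graph.storage.graph_store import GraphStore

store = GraphStore()
a = MemoryNode(id="", content="我", node_type=NodeType.SUBJECT)
b = MemoryNode(id="", content="学习", node_type=NodeType.TOPIC)
c = MemoryNode(id="", content="数学", node_type=NodeType.OBJECT)
store.add_memory(Memory(
    id="", subject_id=a.id, memory_type=MemoryType.EVENT,
    nodes=[a, b, c],
    edges=[
        MemoryEdge(id="", source_id=a.id, target_id=b.id,
                   relation="记忆类型", edge_type=EdgeType.MEMORY_TYPE),
        MemoryEdge(id="", source_id=b.id, target_id=c.id,
                   relation="学", edge_type=EdgeType.CORE_RELATION),
    ],
))
expanded = store.bfs_expand([a.id], depth=2)      # reaches the topic and the object
assert {a.id, b.id, c.id} <= expanded
print(store.find_path(a.id, c.id))                # [a.id, b.id, c.id]
print(store.get_node_degree(b.id))                # (1, 1)
restored = GraphStore.from_dict(store.to_dict())  # persistence round-trip
print(restored.get_statistics())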

src/memory_graph/storage/vector_store.py

@@ -0,0 +1,297 @@
"""
向量存储层:基于 ChromaDB 的语义向量存储
"""
from __future__ import annotations
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from src.common.logger import get_logger
from src.memory_graph.models import MemoryNode, NodeType
logger = get_logger(__name__)
class VectorStore:
"""
向量存储封装类
负责:
1. 节点的语义向量存储和检索
2. 基于相似度的向量搜索
3. 节点去重时的相似节点查找
"""
def __init__(
self,
collection_name: str = "memory_nodes",
data_dir: Optional[Path] = None,
embedding_function: Optional[Any] = None,
):
"""
初始化向量存储
Args:
collection_name: ChromaDB 集合名称
data_dir: 数据存储目录
embedding_function: 嵌入函数如果为None则使用默认
"""
self.collection_name = collection_name
self.data_dir = data_dir or Path("data/memory_graph")
self.data_dir.mkdir(parents=True, exist_ok=True)
self.client = None
self.collection = None
self.embedding_function = embedding_function
logger.info(f"初始化向量存储: collection={collection_name}, dir={self.data_dir}")
async def initialize(self) -> None:
"""异步初始化 ChromaDB"""
try:
import chromadb
from chromadb.config import Settings
# 创建持久化客户端
self.client = chromadb.PersistentClient(
path=str(self.data_dir / "chroma"),
settings=Settings(
anonymized_telemetry=False,
allow_reset=True,
),
)
# 获取或创建集合
self.collection = self.client.get_or_create_collection(
name=self.collection_name,
metadata={"description": "Memory graph node embeddings", "hnsw:space": "cosine"}, # 使用余弦距离,与检索时 similarity = 1 - distance 的换算保持一致
)
logger.info(f"ChromaDB 初始化完成,集合包含 {self.collection.count()} 个节点")
except Exception as e:
logger.error(f"初始化 ChromaDB 失败: {e}", exc_info=True)
raise
async def add_node(self, node: MemoryNode) -> None:
"""
添加节点到向量存储
Args:
node: 要添加的节点
"""
if not self.collection:
raise RuntimeError("向量存储未初始化")
if not node.has_embedding():
logger.warning(f"节点 {node.id} 没有 embedding跳过添加")
return
try:
self.collection.add(
ids=[node.id],
embeddings=[node.embedding.tolist()],
metadatas=[
{
"content": node.content,
"node_type": node.node_type.value,
"created_at": node.created_at.isoformat(),
**node.metadata,
}
],
documents=[node.content], # 文本内容用于检索
)
logger.debug(f"添加节点到向量存储: {node}")
except Exception as e:
logger.error(f"添加节点失败: {e}", exc_info=True)
raise
async def add_nodes_batch(self, nodes: List[MemoryNode]) -> None:
"""
批量添加节点
Args:
nodes: 节点列表
"""
if not self.collection:
raise RuntimeError("向量存储未初始化")
# 过滤出有 embedding 的节点
valid_nodes = [n for n in nodes if n.has_embedding()]
if not valid_nodes:
logger.warning("批量添加:没有有效的节点(缺少 embedding")
return
try:
self.collection.add(
ids=[n.id for n in valid_nodes],
embeddings=[n.embedding.tolist() for n in valid_nodes],
metadatas=[
{
"content": n.content,
"node_type": n.node_type.value,
"created_at": n.created_at.isoformat(),
**n.metadata,
}
for n in valid_nodes
],
documents=[n.content for n in valid_nodes],
)
logger.info(f"批量添加 {len(valid_nodes)} 个节点到向量存储")
except Exception as e:
logger.error(f"批量添加节点失败: {e}", exc_info=True)
raise
async def search_similar_nodes(
self,
query_embedding: np.ndarray,
limit: int = 10,
node_types: Optional[List[NodeType]] = None,
min_similarity: float = 0.0,
) -> List[Tuple[str, float, Dict[str, Any]]]:
"""
搜索相似节点
Args:
query_embedding: 查询向量
limit: 返回结果数量
node_types: 限制节点类型(可选)
min_similarity: 最小相似度阈值
Returns:
List of (node_id, similarity, metadata)
"""
if not self.collection:
raise RuntimeError("向量存储未初始化")
try:
# 构建 where 条件
where_filter = None
if node_types:
where_filter = {"node_type": {"$in": [nt.value for nt in node_types]}}
# 执行查询
results = self.collection.query(
query_embeddings=[query_embedding.tolist()],
n_results=limit,
where=where_filter,
)
# 解析结果
similar_nodes = []
if results["ids"] and results["ids"][0]:
for i, node_id in enumerate(results["ids"][0]):
# ChromaDB 返回的是距离,需要转换为相似度
# 余弦距离: distance = 1 - similarity
distance = results["distances"][0][i]
similarity = 1.0 - distance
if similarity >= min_similarity:
metadata = results["metadatas"][0][i] if results["metadatas"] else {}
similar_nodes.append((node_id, similarity, metadata))
logger.debug(f"相似节点搜索: 找到 {len(similar_nodes)} 个结果")
return similar_nodes
except Exception as e:
logger.error(f"相似节点搜索失败: {e}", exc_info=True)
raise
async def get_node_by_id(self, node_id: str) -> Optional[Dict[str, Any]]:
"""
根据ID获取节点元数据
Args:
node_id: 节点ID
Returns:
节点元数据或 None
"""
if not self.collection:
raise RuntimeError("向量存储未初始化")
try:
result = self.collection.get(ids=[node_id], include=["metadatas", "embeddings"])
if result["ids"]:
return {
"id": result["ids"][0],
"metadata": result["metadatas"][0] if result["metadatas"] else {},
"embedding": np.array(result["embeddings"][0]) if result["embeddings"] is not None and len(result["embeddings"]) > 0 else None,
}
return None
except Exception as e:
logger.error(f"获取节点失败: {e}", exc_info=True)
return None
async def delete_node(self, node_id: str) -> None:
"""
删除节点
Args:
node_id: 节点ID
"""
if not self.collection:
raise RuntimeError("向量存储未初始化")
try:
self.collection.delete(ids=[node_id])
logger.debug(f"删除节点: {node_id}")
except Exception as e:
logger.error(f"删除节点失败: {e}", exc_info=True)
raise
async def update_node_embedding(self, node_id: str, embedding: np.ndarray) -> None:
"""
更新节点的 embedding
Args:
node_id: 节点ID
embedding: 新的向量
"""
if not self.collection:
raise RuntimeError("向量存储未初始化")
try:
self.collection.update(ids=[node_id], embeddings=[embedding.tolist()])
logger.debug(f"更新节点 embedding: {node_id}")
except Exception as e:
logger.error(f"更新节点 embedding 失败: {e}", exc_info=True)
raise
def get_total_count(self) -> int:
"""获取向量存储中的节点总数"""
if not self.collection:
return 0
return self.collection.count()
async def clear(self) -> None:
"""清空向量存储(危险操作,仅用于测试)"""
if not self.collection:
return
try:
# 删除并重新创建集合
self.client.delete_collection(self.collection_name)
self.collection = self.client.get_or_create_collection(
name=self.collection_name,
metadata={"description": "Memory graph node embeddings", "hnsw:space": "cosine"},
)
logger.warning(f"向量存储已清空: {self.collection_name}")
except Exception as e:
logger.error(f"清空向量存储失败: {e}", exc_info=True)
raise
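And a sketch for the vector layer (illustrative only; the random vector stands in for a real embedding-model output, and chromadb must be installed):

import asyncio
import numpy as np
from src.memory_graph.models import MemoryNode, NodeType
from src.memory_graph.storage.vector_store import VectorStore

async def main():
    store = VectorStore(collection_name="memory_nodes_demo")
    await store.initialize()

    rng = np.random.default_rng(0)
    node = MemoryNode(id="", content="吃饭", node_type=NodeType.TOPIC,
                      embedding=rng.random(384))  # 384 matches the config default
    await store.add_node(node)

    # Querying with the same vector should return the node itself with similarity close to 1.
    hits = await store.search_similar_nodes(
        query_embedding=node.embedding,
        limit=5,
        node_types=[NodeType.TOPIC],
        min_similarity=0.5,
    )
    for node_id, similarity, metadata in hits:
        print(node_id, round(similarity, 3), metadata.get("content"))

asyncio.run(main())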