feat(memory-graph): implement Phase 1 core infrastructure

- Define core data models (MemoryNode, MemoryEdge, Memory)
- Implement the configuration system (MemoryGraphConfig)
- Implement the vector storage layer (VectorStore with ChromaDB)
- Implement the graph storage layer (GraphStore with NetworkX)
- Create the design document outline
- Add basic tests; all pass

Still to do:
- Persistence management
- Node deduplication logic
- Memory builder
- Memory retriever
src/memory_graph/storage/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""
Storage layer module
"""

from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.vector_store import VectorStore

__all__ = ["VectorStore", "GraphStore"]
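Downstream code is expected to pull both stores from this package root, e.g.:

    from src.memory_graph.storage import GraphStore, VectorStore

    graph_store = GraphStore()
    vector_store = VectorStore()  # call await vector_store.initialize() before use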
src/memory_graph/storage/graph_store.py (new file, 389 lines)
@@ -0,0 +1,389 @@
"""
Graph storage layer: graph structure management built on NetworkX
"""

from __future__ import annotations

from collections import deque
from typing import Dict, List, Optional, Set, Tuple

import networkx as nx

from src.common.logger import get_logger
from src.memory_graph.models import Memory, MemoryEdge, MemoryNode

logger = get_logger(__name__)


class GraphStore:
    """
    Graph storage wrapper.

    Responsibilities:
    1. Building and maintaining the memory graph
    2. Fast lookup of nodes and edges
    3. Graph traversal algorithms (BFS/DFS)
    4. Adjacency queries
    """

    def __init__(self):
        """Initialize the graph store."""
        # Use a directed graph (memory relations are usually directional)
        self.graph = nx.DiGraph()

        # Index: memory ID -> memory object
        self.memory_index: Dict[str, Memory] = {}

        # Index: node ID -> set of IDs of memories containing that node
        self.node_to_memories: Dict[str, Set[str]] = {}

        logger.info("Graph store initialized")

    def add_memory(self, memory: Memory) -> None:
        """
        Add a memory to the graph.

        Args:
            memory: The memory to add.
        """
        try:
            # 1. Add all nodes to the graph
            for node in memory.nodes:
                if not self.graph.has_node(node.id):
                    self.graph.add_node(
                        node.id,
                        content=node.content,
                        node_type=node.node_type.value,
                        created_at=node.created_at.isoformat(),
                        metadata=node.metadata,
                    )

                # Update the node-to-memories mapping (for every node,
                # including ones already present in the graph)
                if node.id not in self.node_to_memories:
                    self.node_to_memories[node.id] = set()
                self.node_to_memories[node.id].add(memory.id)

            # 2. Add all edges to the graph
            for edge in memory.edges:
                self.graph.add_edge(
                    edge.source_id,
                    edge.target_id,
                    edge_id=edge.id,
                    relation=edge.relation,
                    edge_type=edge.edge_type.value,
                    importance=edge.importance,
                    metadata=edge.metadata,
                    created_at=edge.created_at.isoformat(),
                )

            # 3. Store the memory object
            self.memory_index[memory.id] = memory

            logger.debug(f"Added memory to graph: {memory}")

        except Exception as e:
            logger.error(f"Failed to add memory: {e}", exc_info=True)
            raise

    def get_memory_by_id(self, memory_id: str) -> Optional[Memory]:
        """
        Get a memory by its ID.

        Args:
            memory_id: The memory ID.

        Returns:
            The memory object, or None if not found.
        """
        return self.memory_index.get(memory_id)

    def get_memories_by_node(self, node_id: str) -> List[Memory]:
        """
        Get all memories that contain the given node.

        Args:
            node_id: The node ID.

        Returns:
            A list of memories.
        """
        if node_id not in self.node_to_memories:
            return []

        memory_ids = self.node_to_memories[node_id]
        return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index]

    def get_edges_from_node(self, node_id: str, relation_types: Optional[List[str]] = None) -> List[Dict]:
        """
        Get all outgoing edges of the given node.

        Args:
            node_id: The source node ID.
            relation_types: Optional filter on relation types.

        Returns:
            A list of edge-info dictionaries.
        """
        if not self.graph.has_node(node_id):
            return []

        edges = []
        for _, target_id, edge_data in self.graph.out_edges(node_id, data=True):
            # Filter by relation type
            if relation_types and edge_data.get("relation") not in relation_types:
                continue

            edges.append(
                {
                    "source_id": node_id,
                    "target_id": target_id,
                    "relation": edge_data.get("relation"),
                    "edge_type": edge_data.get("edge_type"),
                    "importance": edge_data.get("importance", 0.5),
                    **edge_data,
                }
            )

        return edges

    def get_neighbors(
        self, node_id: str, direction: str = "out", relation_types: Optional[List[str]] = None
    ) -> List[Tuple[str, Dict]]:
        """
        Get the neighbors of a node.

        Args:
            node_id: The node ID.
            direction: Edge direction ("out" = outgoing, "in" = incoming, "both" = both).
            relation_types: Optional filter on relation types.

        Returns:
            List of (neighbor_id, edge_data) tuples.
        """
        if not self.graph.has_node(node_id):
            return []

        neighbors = []

        # Outgoing edges
        if direction in ["out", "both"]:
            for _, target_id, edge_data in self.graph.out_edges(node_id, data=True):
                if not relation_types or edge_data.get("relation") in relation_types:
                    neighbors.append((target_id, edge_data))

        # Incoming edges
        if direction in ["in", "both"]:
            for source_id, _, edge_data in self.graph.in_edges(node_id, data=True):
                if not relation_types or edge_data.get("relation") in relation_types:
                    neighbors.append((source_id, edge_data))

        return neighbors

    def find_path(self, source_id: str, target_id: str, max_length: Optional[int] = None) -> Optional[List[str]]:
        """
        Find the shortest path between two nodes.

        Args:
            source_id: The source node ID.
            target_id: The target node ID.
            max_length: Optional maximum path length in edges.

        Returns:
            The path as a list of node IDs, or None if no path exists.
        """
        if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
            return None

        try:
            path = nx.shortest_path(self.graph, source_id, target_id, weight=None)
        except nx.NetworkXNoPath:
            return None
        except Exception as e:
            logger.error(f"Path search failed: {e}", exc_info=True)
            return None

        # Enforce the length limit after the fact (edge count = node count - 1)
        if max_length is not None and len(path) - 1 > max_length:
            return None
        return path

    def bfs_expand(
        self,
        start_nodes: List[str],
        depth: int = 1,
        relation_types: Optional[List[str]] = None,
    ) -> Set[str]:
        """
        Expand from the start nodes via breadth-first search.

        Args:
            start_nodes: List of starting node IDs.
            depth: Expansion depth.
            relation_types: Optional filter on relation types.

        Returns:
            The set of all node IDs reached by the expansion.
        """
        visited = set()
        # deque gives O(1) pops from the left; list.pop(0) would be O(n)
        queue = deque((node_id, 0) for node_id in start_nodes if self.graph.has_node(node_id))

        while queue:
            current_node, current_depth = queue.popleft()

            if current_node in visited:
                continue
            visited.add(current_node)

            if current_depth >= depth:
                continue

            # Collect neighbors and enqueue them
            neighbors = self.get_neighbors(current_node, direction="out", relation_types=relation_types)
            for neighbor_id, _ in neighbors:
                if neighbor_id not in visited:
                    queue.append((neighbor_id, current_depth + 1))

        return visited

    def get_subgraph(self, node_ids: List[str]) -> nx.DiGraph:
        """
        Get the subgraph induced by the given nodes.

        Args:
            node_ids: List of node IDs.

        Returns:
            A NetworkX subgraph copy.
        """
        return self.graph.subgraph(node_ids).copy()

    def merge_nodes(self, source_id: str, target_id: str) -> None:
        """
        Merge two nodes (move all edges of source to target, then delete source).

        Args:
            source_id: The source node ID (will be removed).
            target_id: The target node ID (kept).
        """
        if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
            logger.warning(f"Node merge failed: node does not exist ({source_id}, {target_id})")
            return

        try:
            # 1. Move incoming edges
            for pred, _, edge_data in self.graph.in_edges(source_id, data=True):
                if pred != target_id:  # avoid self-loops
                    self.graph.add_edge(pred, target_id, **edge_data)

            # 2. Move outgoing edges
            for _, succ, edge_data in self.graph.out_edges(source_id, data=True):
                if succ != target_id:  # avoid self-loops
                    self.graph.add_edge(target_id, succ, **edge_data)

            # 3. Update the node-to-memories mapping
            if source_id in self.node_to_memories:
                memory_ids = self.node_to_memories[source_id]
                if target_id not in self.node_to_memories:
                    self.node_to_memories[target_id] = set()
                self.node_to_memories[target_id].update(memory_ids)
                del self.node_to_memories[source_id]

            # 4. Remove the source node
            self.graph.remove_node(source_id)

            logger.info(f"Merged nodes: {source_id} → {target_id}")

        except Exception as e:
            logger.error(f"Node merge failed: {e}", exc_info=True)
            raise

    def get_node_degree(self, node_id: str) -> Tuple[int, int]:
        """
        Get the degree of a node.

        Args:
            node_id: The node ID.

        Returns:
            (in_degree, out_degree)
        """
        if not self.graph.has_node(node_id):
            return (0, 0)

        return (self.graph.in_degree(node_id), self.graph.out_degree(node_id))

    def get_statistics(self) -> Dict[str, int]:
        """Get summary statistics of the graph."""
        return {
            "total_nodes": self.graph.number_of_nodes(),
            "total_edges": self.graph.number_of_edges(),
            "total_memories": len(self.memory_index),
            "connected_components": nx.number_weakly_connected_components(self.graph),
        }

    def to_dict(self) -> Dict:
        """
        Convert the graph to a dictionary (for persistence).

        Returns:
            A dictionary representation of the graph.
        """
        return {
            "nodes": [
                {"id": node_id, **self.graph.nodes[node_id]} for node_id in self.graph.nodes()
            ],
            "edges": [
                {
                    "source": u,
                    "target": v,
                    **data,
                }
                for u, v, data in self.graph.edges(data=True)
            ],
            "memories": {memory_id: memory.to_dict() for memory_id, memory in self.memory_index.items()},
            "node_to_memories": {node_id: list(mem_ids) for node_id, mem_ids in self.node_to_memories.items()},
        }

    @classmethod
    def from_dict(cls, data: Dict) -> GraphStore:
        """
        Load a graph from a dictionary.

        Args:
            data: A dictionary representation of the graph.

        Returns:
            A GraphStore instance.
        """
        store = cls()

        # 1. Load nodes
        for node_data in data.get("nodes", []):
            node_id = node_data.pop("id")
            store.graph.add_node(node_id, **node_data)

        # 2. Load edges
        for edge_data in data.get("edges", []):
            source = edge_data.pop("source")
            target = edge_data.pop("target")
            store.graph.add_edge(source, target, **edge_data)

        # 3. Load memories
        for memory_id, memory_dict in data.get("memories", {}).items():
            store.memory_index[memory_id] = Memory.from_dict(memory_dict)

        # 4. Load the node-to-memories mapping
        for node_id, mem_ids in data.get("node_to_memories", {}).items():
            store.node_to_memories[node_id] = set(mem_ids)

        logger.info(f"Loaded graph from dict: {store.get_statistics()}")
        return store

    def clear(self) -> None:
        """Clear the graph (destructive; for tests only)."""
        self.graph.clear()
        self.memory_index.clear()
        self.node_to_memories.clear()
        logger.warning("Graph store cleared")
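Not part of the commit: a minimal traversal and persistence sketch. It seeds store.graph directly with plain string IDs, since constructing full Memory objects would require the model definitions in src/memory_graph/models, which sit outside this diff:

    from src.memory_graph.storage.graph_store import GraphStore

    store = GraphStore()

    # Seed a tiny relation chain on the underlying DiGraph (illustration only;
    # real nodes and edges enter through add_memory())
    store.graph.add_edge("alice", "cat", relation="owns", importance=0.8)
    store.graph.add_edge("cat", "fish", relation="eats", importance=0.5)

    store.get_neighbors("alice")          # [("cat", {"relation": "owns", ...})]
    store.find_path("alice", "fish")      # ["alice", "cat", "fish"]
    store.bfs_expand(["alice"], depth=2)  # {"alice", "cat", "fish"}
    store.get_node_degree("cat")          # (1, 1)

    # Round-trip through the persistence format
    restored = GraphStore.from_dict(store.to_dict())
    assert restored.get_statistics() == store.get_statistics()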
src/memory_graph/storage/vector_store.py (new file, 297 lines)
@@ -0,0 +1,297 @@
"""
Vector storage layer: semantic embedding storage built on ChromaDB
"""

from __future__ import annotations

import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from src.common.logger import get_logger
from src.memory_graph.models import MemoryNode, NodeType

logger = get_logger(__name__)


class VectorStore:
    """
    Vector storage wrapper.

    Responsibilities:
    1. Storing and retrieving semantic embeddings of nodes
    2. Similarity-based vector search
    3. Finding similar nodes during node deduplication
    """

    def __init__(
        self,
        collection_name: str = "memory_nodes",
        data_dir: Optional[Path] = None,
        embedding_function: Optional[Any] = None,
    ):
        """
        Initialize the vector store.

        Args:
            collection_name: Name of the ChromaDB collection.
            data_dir: Data storage directory.
            embedding_function: Embedding function (the default is used when None).
        """
        self.collection_name = collection_name
        self.data_dir = data_dir or Path("data/memory_graph")
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.client = None
        self.collection = None
        self.embedding_function = embedding_function

        logger.info(f"Vector store initialized: collection={collection_name}, dir={self.data_dir}")

    async def initialize(self) -> None:
        """Asynchronously initialize ChromaDB."""
        try:
            import chromadb
            from chromadb.config import Settings

            # Create a persistent client
            self.client = chromadb.PersistentClient(
                path=str(self.data_dir / "chroma"),
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True,
                ),
            )

            # Get or create the collection. Use cosine space so that the
            # similarity = 1 - distance conversion in search_similar_nodes()
            # holds (ChromaDB defaults to L2 distance otherwise).
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "Memory graph node embeddings", "hnsw:space": "cosine"},
            )

            logger.info(f"ChromaDB initialized; collection holds {self.collection.count()} nodes")

        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB: {e}", exc_info=True)
            raise

    async def add_node(self, node: MemoryNode) -> None:
        """
        Add a node to the vector store.

        Args:
            node: The node to add.
        """
        if not self.collection:
            raise RuntimeError("Vector store is not initialized")

        if not node.has_embedding():
            logger.warning(f"Node {node.id} has no embedding; skipping")
            return

        try:
            self.collection.add(
                ids=[node.id],
                embeddings=[node.embedding.tolist()],
                metadatas=[
                    {
                        "content": node.content,
                        "node_type": node.node_type.value,
                        "created_at": node.created_at.isoformat(),
                        **node.metadata,
                    }
                ],
                documents=[node.content],  # raw text, used for retrieval
            )

            logger.debug(f"Added node to vector store: {node}")

        except Exception as e:
            logger.error(f"Failed to add node: {e}", exc_info=True)
            raise

    async def add_nodes_batch(self, nodes: List[MemoryNode]) -> None:
        """
        Add nodes in batch.

        Args:
            nodes: List of nodes.
        """
        if not self.collection:
            raise RuntimeError("Vector store is not initialized")

        # Keep only nodes that have an embedding
        valid_nodes = [n for n in nodes if n.has_embedding()]

        if not valid_nodes:
            logger.warning("Batch add: no valid nodes (missing embeddings)")
            return

        try:
            self.collection.add(
                ids=[n.id for n in valid_nodes],
                embeddings=[n.embedding.tolist() for n in valid_nodes],
                metadatas=[
                    {
                        "content": n.content,
                        "node_type": n.node_type.value,
                        "created_at": n.created_at.isoformat(),
                        **n.metadata,
                    }
                    for n in valid_nodes
                ],
                documents=[n.content for n in valid_nodes],
            )

            logger.info(f"Added {len(valid_nodes)} nodes to vector store in batch")

        except Exception as e:
            logger.error(f"Batch add failed: {e}", exc_info=True)
            raise

    async def search_similar_nodes(
        self,
        query_embedding: np.ndarray,
        limit: int = 10,
        node_types: Optional[List[NodeType]] = None,
        min_similarity: float = 0.0,
    ) -> List[Tuple[str, float, Dict[str, Any]]]:
        """
        Search for similar nodes.

        Args:
            query_embedding: The query vector.
            limit: Number of results to return.
            node_types: Optional restriction on node types.
            min_similarity: Minimum similarity threshold.

        Returns:
            List of (node_id, similarity, metadata) tuples.
        """
        if not self.collection:
            raise RuntimeError("Vector store is not initialized")

        try:
            # Build the where filter
            where_filter = None
            if node_types:
                where_filter = {"node_type": {"$in": [nt.value for nt in node_types]}}

            # Run the query
            results = self.collection.query(
                query_embeddings=[query_embedding.tolist()],
                n_results=limit,
                where=where_filter,
            )

            # Parse the results
            similar_nodes = []
            if results["ids"] and results["ids"][0]:
                for i, node_id in enumerate(results["ids"][0]):
                    # ChromaDB returns distances; convert to similarity.
                    # Cosine distance: distance = 1 - similarity
                    distance = results["distances"][0][i]
                    similarity = 1.0 - distance

                    if similarity >= min_similarity:
                        metadata = results["metadatas"][0][i] if results["metadatas"] else {}
                        similar_nodes.append((node_id, similarity, metadata))

            logger.debug(f"Similarity search: found {len(similar_nodes)} results")
            return similar_nodes

        except Exception as e:
            logger.error(f"Similarity search failed: {e}", exc_info=True)
            raise

    async def get_node_by_id(self, node_id: str) -> Optional[Dict[str, Any]]:
        """
        Get node metadata by ID.

        Args:
            node_id: The node ID.

        Returns:
            Node metadata, or None if not found.
        """
        if not self.collection:
            raise RuntimeError("Vector store is not initialized")

        try:
            result = self.collection.get(ids=[node_id], include=["metadatas", "embeddings"])

            if result["ids"]:
                return {
                    "id": result["ids"][0],
                    "metadata": result["metadatas"][0] if result["metadatas"] else {},
                    "embedding": np.array(result["embeddings"][0]) if result["embeddings"] else None,
                }

            return None

        except Exception as e:
            logger.error(f"Failed to get node: {e}", exc_info=True)
            return None

    async def delete_node(self, node_id: str) -> None:
        """
        Delete a node.

        Args:
            node_id: The node ID.
        """
        if not self.collection:
            raise RuntimeError("Vector store is not initialized")

        try:
            self.collection.delete(ids=[node_id])
            logger.debug(f"Deleted node: {node_id}")

        except Exception as e:
            logger.error(f"Failed to delete node: {e}", exc_info=True)
            raise

    async def update_node_embedding(self, node_id: str, embedding: np.ndarray) -> None:
        """
        Update a node's embedding.

        Args:
            node_id: The node ID.
            embedding: The new vector.
        """
        if not self.collection:
            raise RuntimeError("Vector store is not initialized")

        try:
            self.collection.update(ids=[node_id], embeddings=[embedding.tolist()])
            logger.debug(f"Updated node embedding: {node_id}")

        except Exception as e:
            logger.error(f"Failed to update node embedding: {e}", exc_info=True)
            raise

    def get_total_count(self) -> int:
        """Get the total number of nodes in the vector store."""
        if not self.collection:
            return 0
        return self.collection.count()

    async def clear(self) -> None:
        """Clear the vector store (destructive; for tests only)."""
        if not self.collection:
            return

        try:
            # Delete and recreate the collection (same cosine-space settings)
            self.client.delete_collection(self.collection_name)
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "Memory graph node embeddings", "hnsw:space": "cosine"},
            )
            logger.warning(f"Vector store cleared: {self.collection_name}")

        except Exception as e:
            logger.error(f"Failed to clear vector store: {e}", exc_info=True)
            raise
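Also not part of the commit: an async usage sketch. The 384-dimensional random query vector is a stand-in (the embedding dimension is an assumption); real embeddings are expected to come from the not-yet-implemented memory builder. Against an empty collection the hit list is simply empty:

    import asyncio
    from pathlib import Path

    import numpy as np

    from src.memory_graph.storage.vector_store import VectorStore

    async def demo() -> None:
        store = VectorStore(collection_name="demo_nodes", data_dir=Path("data/demo"))
        await store.initialize()

        query = np.random.rand(384).astype(np.float32)  # hypothetical embedding dim
        hits = await store.search_similar_nodes(query, limit=5, min_similarity=0.6)
        for node_id, similarity, metadata in hits:
            print(node_id, round(similarity, 3), metadata.get("content"))

    asyncio.run(demo())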