feat(memory): 添加运行期索引以优化节点和边的访问性能
This commit is contained in:
@@ -238,6 +238,10 @@ class Memory:
|
||||
decay_factor: float = 1.0 # 衰减因子(随时间变化)
|
||||
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
|
||||
|
||||
# 运行期索引(不序列化)
|
||||
_node_index: dict[str, "MemoryNode"] = field(default_factory=dict, init=False, repr=False, compare=False)
|
||||
_edges_by_source: dict[str, list["MemoryEdge"]] = field(default_factory=dict, init=False, repr=False, compare=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""后初始化处理"""
|
||||
if not self.id:
|
||||
@@ -247,6 +251,73 @@ class Memory:
|
||||
self.activation = max(0.0, min(1.0, self.activation))
|
||||
if not self.updated_at:
|
||||
self.updated_at = self.created_at
|
||||
# 构建加速索引
|
||||
self._rebuild_indexes()
|
||||
|
||||
# --------------------
|
||||
# 索引与快速访问
|
||||
# --------------------
|
||||
def _rebuild_indexes(self) -> None:
|
||||
"""从当前 nodes/edges 重新构建运行期索引。"""
|
||||
# 节点索引
|
||||
self._node_index.clear()
|
||||
for n in self.nodes:
|
||||
self._node_index[n.id] = n
|
||||
|
||||
# 源端分组的边索引
|
||||
self._edges_by_source.clear()
|
||||
for e in self.edges:
|
||||
self._edges_by_source.setdefault(e.source_id, []).append(e)
|
||||
|
||||
def add_node(self, node: "MemoryNode") -> None:
|
||||
"""添加节点并更新索引。"""
|
||||
self.nodes.append(node)
|
||||
self._node_index[node.id] = node
|
||||
|
||||
def remove_node(self, node_id: str) -> bool:
|
||||
"""按 ID 移除节点,并清理相关索引。返回是否移除成功。"""
|
||||
node = self._node_index.pop(node_id, None)
|
||||
if node is None:
|
||||
return False
|
||||
# 从列表移除(保持原有数据结构契合度)
|
||||
for i, n in enumerate(self.nodes):
|
||||
if n.id == node_id:
|
||||
del self.nodes[i]
|
||||
break
|
||||
# 同时移除以该节点为源或目标的边
|
||||
new_edges = []
|
||||
affected_sources: set[str] = set()
|
||||
for e in self.edges:
|
||||
if e.source_id == node_id or e.target_id == node_id:
|
||||
affected_sources.add(e.source_id)
|
||||
continue
|
||||
new_edges.append(e)
|
||||
self.edges = new_edges
|
||||
# 更新边索引
|
||||
for sid in affected_sources:
|
||||
if sid in self._edges_by_source:
|
||||
self._edges_by_source[sid] = [e for e in self._edges_by_source[sid] if e.source_id != node_id and e.target_id != node_id]
|
||||
if not self._edges_by_source[sid]:
|
||||
self._edges_by_source.pop(sid, None)
|
||||
return True
|
||||
|
||||
def add_edge(self, edge: "MemoryEdge") -> None:
|
||||
"""添加边并更新索引。"""
|
||||
self.edges.append(edge)
|
||||
self._edges_by_source.setdefault(edge.source_id, []).append(edge)
|
||||
|
||||
def remove_edge(self, edge_id: str) -> bool:
|
||||
"""按 ID 移除边,并更新索引。返回是否移除成功。"""
|
||||
for i, e in enumerate(self.edges):
|
||||
if e.id == edge_id:
|
||||
del self.edges[i]
|
||||
src_list = self._edges_by_source.get(e.source_id)
|
||||
if src_list is not None:
|
||||
self._edges_by_source[e.source_id] = [x for x in src_list if x.id != edge_id]
|
||||
if not self._edges_by_source[e.source_id]:
|
||||
self._edges_by_source.pop(e.source_id, None)
|
||||
return True
|
||||
return False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典(用于序列化)"""
|
||||
@@ -312,10 +383,13 @@ class Memory:
|
||||
|
||||
def get_node_by_id(self, node_id: str) -> MemoryNode | None:
|
||||
"""根据ID获取节点"""
|
||||
for node in self.nodes:
|
||||
if node.id == node_id:
|
||||
return node
|
||||
return None
|
||||
# O(1) 访问,回退安全(极少情况下索引不同步可重建)
|
||||
node = self._node_index.get(node_id)
|
||||
if node is not None:
|
||||
return node
|
||||
# 回退:重建索引后再取,避免外部直接修改列表导致的不同步
|
||||
self._rebuild_indexes()
|
||||
return self._node_index.get(node_id)
|
||||
|
||||
def get_subject_node(self) -> MemoryNode | None:
|
||||
"""获取主体节点"""
|
||||
@@ -330,21 +404,22 @@ class Memory:
|
||||
# 简单的文本生成逻辑
|
||||
parts = [f"{subject_node.content}"]
|
||||
|
||||
# 查找主题节点(通过记忆类型边连接)
|
||||
topic_node = None
|
||||
for edge in self.edges:
|
||||
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == self.subject_id:
|
||||
# 查找主题节点(通过记忆类型边连接)——使用索引加速
|
||||
topic_node: MemoryNode | None = None
|
||||
for edge in self._edges_by_source.get(self.subject_id, []):
|
||||
if edge.edge_type == EdgeType.MEMORY_TYPE:
|
||||
topic_node = self.get_node_by_id(edge.target_id)
|
||||
break
|
||||
if topic_node is not None:
|
||||
break
|
||||
|
||||
if topic_node:
|
||||
parts.append(topic_node.content)
|
||||
|
||||
# 查找客体节点(通过核心关系边连接)
|
||||
for edge in self.edges:
|
||||
if edge.edge_type == EdgeType.CORE_RELATION and edge.source_id == topic_node.id:
|
||||
# 查找客体节点(通过核心关系边连接)——使用索引加速
|
||||
for edge in self._edges_by_source.get(topic_node.id, []):
|
||||
if edge.edge_type == EdgeType.CORE_RELATION:
|
||||
obj_node = self.get_node_by_id(edge.target_id)
|
||||
if obj_node:
|
||||
if obj_node is not None:
|
||||
parts.append(f"{edge.relation} {obj_node.content}")
|
||||
break
|
||||
|
||||
|
||||
Reference in New Issue
Block a user