feat(memory): 添加运行期索引以优化节点和边的访问性能
This commit is contained in:
@@ -238,6 +238,10 @@ class Memory:
|
|||||||
decay_factor: float = 1.0 # 衰减因子(随时间变化)
|
decay_factor: float = 1.0 # 衰减因子(随时间变化)
|
||||||
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
|
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
|
||||||
|
|
||||||
|
# 运行期索引(不序列化)
|
||||||
|
_node_index: dict[str, "MemoryNode"] = field(default_factory=dict, init=False, repr=False, compare=False)
|
||||||
|
_edges_by_source: dict[str, list["MemoryEdge"]] = field(default_factory=dict, init=False, repr=False, compare=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""后初始化处理"""
|
"""后初始化处理"""
|
||||||
if not self.id:
|
if not self.id:
|
||||||
@@ -247,6 +251,73 @@ class Memory:
|
|||||||
self.activation = max(0.0, min(1.0, self.activation))
|
self.activation = max(0.0, min(1.0, self.activation))
|
||||||
if not self.updated_at:
|
if not self.updated_at:
|
||||||
self.updated_at = self.created_at
|
self.updated_at = self.created_at
|
||||||
|
# 构建加速索引
|
||||||
|
self._rebuild_indexes()
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# 索引与快速访问
|
||||||
|
# --------------------
|
||||||
|
def _rebuild_indexes(self) -> None:
|
||||||
|
"""从当前 nodes/edges 重新构建运行期索引。"""
|
||||||
|
# 节点索引
|
||||||
|
self._node_index.clear()
|
||||||
|
for n in self.nodes:
|
||||||
|
self._node_index[n.id] = n
|
||||||
|
|
||||||
|
# 源端分组的边索引
|
||||||
|
self._edges_by_source.clear()
|
||||||
|
for e in self.edges:
|
||||||
|
self._edges_by_source.setdefault(e.source_id, []).append(e)
|
||||||
|
|
||||||
|
def add_node(self, node: "MemoryNode") -> None:
|
||||||
|
"""添加节点并更新索引。"""
|
||||||
|
self.nodes.append(node)
|
||||||
|
self._node_index[node.id] = node
|
||||||
|
|
||||||
|
def remove_node(self, node_id: str) -> bool:
|
||||||
|
"""按 ID 移除节点,并清理相关索引。返回是否移除成功。"""
|
||||||
|
node = self._node_index.pop(node_id, None)
|
||||||
|
if node is None:
|
||||||
|
return False
|
||||||
|
# 从列表移除(保持原有数据结构契合度)
|
||||||
|
for i, n in enumerate(self.nodes):
|
||||||
|
if n.id == node_id:
|
||||||
|
del self.nodes[i]
|
||||||
|
break
|
||||||
|
# 同时移除以该节点为源或目标的边
|
||||||
|
new_edges = []
|
||||||
|
affected_sources: set[str] = set()
|
||||||
|
for e in self.edges:
|
||||||
|
if e.source_id == node_id or e.target_id == node_id:
|
||||||
|
affected_sources.add(e.source_id)
|
||||||
|
continue
|
||||||
|
new_edges.append(e)
|
||||||
|
self.edges = new_edges
|
||||||
|
# 更新边索引
|
||||||
|
for sid in affected_sources:
|
||||||
|
if sid in self._edges_by_source:
|
||||||
|
self._edges_by_source[sid] = [e for e in self._edges_by_source[sid] if e.source_id != node_id and e.target_id != node_id]
|
||||||
|
if not self._edges_by_source[sid]:
|
||||||
|
self._edges_by_source.pop(sid, None)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def add_edge(self, edge: "MemoryEdge") -> None:
|
||||||
|
"""添加边并更新索引。"""
|
||||||
|
self.edges.append(edge)
|
||||||
|
self._edges_by_source.setdefault(edge.source_id, []).append(edge)
|
||||||
|
|
||||||
|
def remove_edge(self, edge_id: str) -> bool:
|
||||||
|
"""按 ID 移除边,并更新索引。返回是否移除成功。"""
|
||||||
|
for i, e in enumerate(self.edges):
|
||||||
|
if e.id == edge_id:
|
||||||
|
del self.edges[i]
|
||||||
|
src_list = self._edges_by_source.get(e.source_id)
|
||||||
|
if src_list is not None:
|
||||||
|
self._edges_by_source[e.source_id] = [x for x in src_list if x.id != edge_id]
|
||||||
|
if not self._edges_by_source[e.source_id]:
|
||||||
|
self._edges_by_source.pop(e.source_id, None)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""转换为字典(用于序列化)"""
|
"""转换为字典(用于序列化)"""
|
||||||
@@ -312,10 +383,13 @@ class Memory:
|
|||||||
|
|
||||||
def get_node_by_id(self, node_id: str) -> MemoryNode | None:
|
def get_node_by_id(self, node_id: str) -> MemoryNode | None:
|
||||||
"""根据ID获取节点"""
|
"""根据ID获取节点"""
|
||||||
for node in self.nodes:
|
# O(1) 访问,回退安全(极少情况下索引不同步可重建)
|
||||||
if node.id == node_id:
|
node = self._node_index.get(node_id)
|
||||||
|
if node is not None:
|
||||||
return node
|
return node
|
||||||
return None
|
# 回退:重建索引后再取,避免外部直接修改列表导致的不同步
|
||||||
|
self._rebuild_indexes()
|
||||||
|
return self._node_index.get(node_id)
|
||||||
|
|
||||||
def get_subject_node(self) -> MemoryNode | None:
|
def get_subject_node(self) -> MemoryNode | None:
|
||||||
"""获取主体节点"""
|
"""获取主体节点"""
|
||||||
@@ -330,21 +404,22 @@ class Memory:
|
|||||||
# 简单的文本生成逻辑
|
# 简单的文本生成逻辑
|
||||||
parts = [f"{subject_node.content}"]
|
parts = [f"{subject_node.content}"]
|
||||||
|
|
||||||
# 查找主题节点(通过记忆类型边连接)
|
# 查找主题节点(通过记忆类型边连接)——使用索引加速
|
||||||
topic_node = None
|
topic_node: MemoryNode | None = None
|
||||||
for edge in self.edges:
|
for edge in self._edges_by_source.get(self.subject_id, []):
|
||||||
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == self.subject_id:
|
if edge.edge_type == EdgeType.MEMORY_TYPE:
|
||||||
topic_node = self.get_node_by_id(edge.target_id)
|
topic_node = self.get_node_by_id(edge.target_id)
|
||||||
|
if topic_node is not None:
|
||||||
break
|
break
|
||||||
|
|
||||||
if topic_node:
|
if topic_node:
|
||||||
parts.append(topic_node.content)
|
parts.append(topic_node.content)
|
||||||
|
|
||||||
# 查找客体节点(通过核心关系边连接)
|
# 查找客体节点(通过核心关系边连接)——使用索引加速
|
||||||
for edge in self.edges:
|
for edge in self._edges_by_source.get(topic_node.id, []):
|
||||||
if edge.edge_type == EdgeType.CORE_RELATION and edge.source_id == topic_node.id:
|
if edge.edge_type == EdgeType.CORE_RELATION:
|
||||||
obj_node = self.get_node_by_id(edge.target_id)
|
obj_node = self.get_node_by_id(edge.target_id)
|
||||||
if obj_node:
|
if obj_node is not None:
|
||||||
parts.append(f"{edge.relation} {obj_node.content}")
|
parts.append(f"{edge.relation} {obj_node.content}")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user