From 968e9b4535398cf90b19780a9c317e10dd00a562 Mon Sep 17 00:00:00 2001 From: LuiKlee Date: Wed, 17 Dec 2025 18:25:25 +0800 Subject: [PATCH] =?UTF-8?q?feat(memory):=20=E6=B7=BB=E5=8A=A0=E8=BF=90?= =?UTF-8?q?=E8=A1=8C=E6=9C=9F=E7=B4=A2=E5=BC=95=E4=BB=A5=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E8=8A=82=E7=82=B9=E5=92=8C=E8=BE=B9=E7=9A=84=E8=AE=BF=E9=97=AE?= =?UTF-8?q?=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memory_graph/models.py | 101 ++++++++++++++++++++++++++++++++----- 1 file changed, 88 insertions(+), 13 deletions(-) diff --git a/src/memory_graph/models.py b/src/memory_graph/models.py index a2a22ac47..d4d7314b5 100644 --- a/src/memory_graph/models.py +++ b/src/memory_graph/models.py @@ -238,6 +238,10 @@ class Memory: decay_factor: float = 1.0 # 衰减因子(随时间变化) metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据 + # 运行期索引(不序列化) + _node_index: dict[str, "MemoryNode"] = field(default_factory=dict, init=False, repr=False, compare=False) + _edges_by_source: dict[str, list["MemoryEdge"]] = field(default_factory=dict, init=False, repr=False, compare=False) + def __post_init__(self): """后初始化处理""" if not self.id: @@ -247,6 +251,73 @@ class Memory: self.activation = max(0.0, min(1.0, self.activation)) if not self.updated_at: self.updated_at = self.created_at + # 构建加速索引 + self._rebuild_indexes() + + # -------------------- + # 索引与快速访问 + # -------------------- + def _rebuild_indexes(self) -> None: + """从当前 nodes/edges 重新构建运行期索引。""" + # 节点索引 + self._node_index.clear() + for n in self.nodes: + self._node_index[n.id] = n + + # 源端分组的边索引 + self._edges_by_source.clear() + for e in self.edges: + self._edges_by_source.setdefault(e.source_id, []).append(e) + + def add_node(self, node: "MemoryNode") -> None: + """添加节点并更新索引。""" + self.nodes.append(node) + self._node_index[node.id] = node + + def remove_node(self, node_id: str) -> bool: + """按 ID 移除节点,并清理相关索引。返回是否移除成功。""" + node = self._node_index.pop(node_id, None) + if node is None: + return False + # 从列表移除(保持原有数据结构契合度) + for i, n in enumerate(self.nodes): + if n.id == node_id: + del self.nodes[i] + break + # 同时移除以该节点为源或目标的边 + new_edges = [] + affected_sources: set[str] = set() + for e in self.edges: + if e.source_id == node_id or e.target_id == node_id: + affected_sources.add(e.source_id) + continue + new_edges.append(e) + self.edges = new_edges + # 更新边索引 + for sid in affected_sources: + if sid in self._edges_by_source: + self._edges_by_source[sid] = [e for e in self._edges_by_source[sid] if e.source_id != node_id and e.target_id != node_id] + if not self._edges_by_source[sid]: + self._edges_by_source.pop(sid, None) + return True + + def add_edge(self, edge: "MemoryEdge") -> None: + """添加边并更新索引。""" + self.edges.append(edge) + self._edges_by_source.setdefault(edge.source_id, []).append(edge) + + def remove_edge(self, edge_id: str) -> bool: + """按 ID 移除边,并更新索引。返回是否移除成功。""" + for i, e in enumerate(self.edges): + if e.id == edge_id: + del self.edges[i] + src_list = self._edges_by_source.get(e.source_id) + if src_list is not None: + self._edges_by_source[e.source_id] = [x for x in src_list if x.id != edge_id] + if not self._edges_by_source[e.source_id]: + self._edges_by_source.pop(e.source_id, None) + return True + return False def to_dict(self) -> dict[str, Any]: """转换为字典(用于序列化)""" @@ -312,10 +383,13 @@ class Memory: def get_node_by_id(self, node_id: str) -> MemoryNode | None: """根据ID获取节点""" - for node in self.nodes: - if node.id == node_id: - return node - return None + # O(1) 访问,回退安全(极少情况下索引不同步可重建) + node = self._node_index.get(node_id) + if node is not None: + return node + # 回退:重建索引后再取,避免外部直接修改列表导致的不同步 + self._rebuild_indexes() + return self._node_index.get(node_id) def get_subject_node(self) -> MemoryNode | None: """获取主体节点""" @@ -330,21 +404,22 @@ class Memory: # 简单的文本生成逻辑 parts = [f"{subject_node.content}"] - # 查找主题节点(通过记忆类型边连接) - topic_node = None - for edge in self.edges: - if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == self.subject_id: + # 查找主题节点(通过记忆类型边连接)——使用索引加速 + topic_node: MemoryNode | None = None + for edge in self._edges_by_source.get(self.subject_id, []): + if edge.edge_type == EdgeType.MEMORY_TYPE: topic_node = self.get_node_by_id(edge.target_id) - break + if topic_node is not None: + break if topic_node: parts.append(topic_node.content) - # 查找客体节点(通过核心关系边连接) - for edge in self.edges: - if edge.edge_type == EdgeType.CORE_RELATION and edge.source_id == topic_node.id: + # 查找客体节点(通过核心关系边连接)——使用索引加速 + for edge in self._edges_by_source.get(topic_node.id, []): + if edge.edge_type == EdgeType.CORE_RELATION: obj_node = self.get_node_by_id(edge.target_id) - if obj_node: + if obj_node is not None: parts.append(f"{edge.relation} {obj_node.content}") break