feat: 优化长期记忆检索和合并操作,支持图结构扩展和智能合并

This commit is contained in:
Windpicker-owo
2025-11-19 11:33:10 +08:00
parent f3ea6a692e
commit 4c7bc1928e
4 changed files with 687 additions and 98 deletions

View File

@@ -207,24 +207,70 @@ class LongTermMemoryManager:
""" """
在长期记忆中检索与短期记忆相似的记忆 在长期记忆中检索与短期记忆相似的记忆
Args: 优化:不仅检索内容相似的,还利用图结构获取上下文相关的记忆
stm: 短期记忆
Returns:
相似的长期记忆列表
""" """
try: try:
# 使用短期记忆的内容进行检索 from src.config.config import global_config
# 检查是否启用了高级路径扩展算法
use_path_expansion = getattr(global_config.memory, "enable_path_expansion", False)
# 1. 检索记忆
# 如果启用了路径扩展search_memories 内部会自动使用 PathScoreExpansion
# 我们只需要传入合适的 expand_depth
expand_depth = getattr(global_config.memory, "path_expansion_max_hops", 2) if use_path_expansion else 0
memories = await self.memory_manager.search_memories( memories = await self.memory_manager.search_memories(
query=stm.content, query=stm.content,
top_k=self.search_top_k, top_k=self.search_top_k,
include_forgotten=False, include_forgotten=False,
use_multi_query=False, # 不使用多查询,避免过度扩展 use_multi_query=False, # 不使用多查询,避免过度扩展
expand_depth=expand_depth
) )
logger.debug(f"为短期记忆 {stm.id} 找到 {len(memories)} 个相似长期记忆") # 2. 图结构扩展 (Graph Expansion)
# 如果已经使用了高级路径扩展算法,就不需要再做简单的手动扩展了
if use_path_expansion:
logger.debug(f"已使用路径扩展算法检索到 {len(memories)} 条记忆")
return memories return memories
# 如果未启用高级算法,使用简单的 1 跳邻居扩展作为保底
expanded_memories = []
seen_ids = {m.id for m in memories}
for mem in memories:
expanded_memories.append(mem)
# 获取该记忆的直接关联记忆1跳邻居
try:
# 利用 MemoryManager 的底层图遍历能力
related_ids = self.memory_manager._get_related_memories(mem.id, max_depth=1)
# 限制每个记忆扩展的邻居数量,避免上下文爆炸
max_neighbors = 2
neighbor_count = 0
for rid in related_ids:
if rid not in seen_ids:
related_mem = await self.memory_manager.get_memory(rid)
if related_mem:
expanded_memories.append(related_mem)
seen_ids.add(rid)
neighbor_count += 1
if neighbor_count >= max_neighbors:
break
except Exception as e:
logger.warning(f"获取关联记忆失败: {e}")
# 总数限制
if len(expanded_memories) >= self.search_top_k * 2:
break
logger.debug(f"为短期记忆 {stm.id} 找到 {len(expanded_memories)} 个长期记忆 (含简单图扩展)")
return expanded_memories
except Exception as e: except Exception as e:
logger.error(f"检索相似长期记忆失败: {e}", exc_info=True) logger.error(f"检索相似长期记忆失败: {e}", exc_info=True)
return [] return []
@@ -465,10 +511,26 @@ class LongTermMemoryManager:
await self._execute_create_node(op) await self._execute_create_node(op)
success_count += 1 success_count += 1
elif op.operation_type == GraphOperationType.UPDATE_NODE:
await self._execute_update_node(op)
success_count += 1
elif op.operation_type == GraphOperationType.MERGE_NODES:
await self._execute_merge_nodes(op)
success_count += 1
elif op.operation_type == GraphOperationType.CREATE_EDGE: elif op.operation_type == GraphOperationType.CREATE_EDGE:
await self._execute_create_edge(op) await self._execute_create_edge(op)
success_count += 1 success_count += 1
elif op.operation_type == GraphOperationType.UPDATE_EDGE:
await self._execute_update_edge(op)
success_count += 1
elif op.operation_type == GraphOperationType.DELETE_EDGE:
await self._execute_delete_edge(op)
success_count += 1
else: else:
logger.warning(f"未实现的操作类型: {op.operation_type}") logger.warning(f"未实现的操作类型: {op.operation_type}")
@@ -525,7 +587,7 @@ class LongTermMemoryManager:
async def _execute_merge_memories( async def _execute_merge_memories(
self, op: GraphOperation, source_stm: ShortTermMemory self, op: GraphOperation, source_stm: ShortTermMemory
) -> None: ) -> None:
"""执行合并记忆操作""" """执行合并记忆操作 (智能合并版)"""
source_ids = op.parameters.get("source_memory_ids", []) source_ids = op.parameters.get("source_memory_ids", [])
merged_content = op.parameters.get("merged_content", "") merged_content = op.parameters.get("merged_content", "")
merged_importance = op.parameters.get("merged_importance", source_stm.importance) merged_importance = op.parameters.get("merged_importance", source_stm.importance)
@@ -534,38 +596,191 @@ class LongTermMemoryManager:
logger.warning("合并操作缺少源记忆ID跳过") logger.warning("合并操作缺少源记忆ID跳过")
return return
# 简化实现:更新第一个记忆,删除其他记忆 # 目标记忆(保留的那个)
target_id = source_ids[0] target_id = source_ids[0]
success = await self.memory_manager.update_memory(
# 待合并记忆(将被删除的)
memories_to_merge = source_ids[1:]
logger.info(f"开始智能合并记忆: {memories_to_merge} -> {target_id}")
# 1. 调用 GraphStore 的合并功能(转移节点和边)
merge_success = self.memory_manager.graph_store.merge_memories(target_id, memories_to_merge)
if merge_success:
# 2. 更新目标记忆的元数据
await self.memory_manager.update_memory(
target_id, target_id,
metadata={ metadata={
"merged_content": merged_content, "merged_content": merged_content,
"merged_from": source_ids[1:], "merged_from": memories_to_merge,
"merged_from_stm": source_stm.id, "merged_from_stm": source_stm.id,
"merge_time": datetime.now().isoformat()
}, },
importance=merged_importance, importance=merged_importance,
) )
if success: # 3. 异步保存
# 删除其他记忆 asyncio.create_task(self.memory_manager._async_save_graph_store("合并记忆"))
for mem_id in source_ids[1:]: logger.info(f"✅ 合并记忆完成: {source_ids} -> {target_id}")
await self.memory_manager.delete_memory(mem_id)
logger.info(f"✅ 合并记忆: {source_ids}{target_id}")
else: else:
logger.error(f"合并记忆失败: {source_ids}") logger.error(f"合并记忆失败: {source_ids}")
async def _execute_create_node(self, op: GraphOperation) -> None: async def _execute_create_node(self, op: GraphOperation) -> None:
"""执行创建节点操作""" """执行创建节点操作"""
# 注意:当前 MemoryManager 不直接支持单独创建节点 params = op.parameters
# 这里记录操作,实际执行需要扩展 MemoryManager API content = params.get("content")
logger.info(f"创建节点操作(待实现): {op.parameters}") node_type = params.get("node_type", "OBJECT")
memory_id = params.get("memory_id")
if not content or not memory_id:
logger.warning(f"创建节点失败: 缺少必要参数 (content={content}, memory_id={memory_id})")
return
import uuid
node_id = str(uuid.uuid4())
success = self.memory_manager.graph_store.add_node(
node_id=node_id,
content=content,
node_type=node_type,
memory_id=memory_id,
metadata={"created_by": "long_term_manager"}
)
if success:
# 尝试为新节点生成 embedding (异步)
asyncio.create_task(self._generate_node_embedding(node_id, content))
logger.info(f"✅ 创建节点: {content} ({node_type}) -> {memory_id}")
else:
logger.error(f"创建节点失败: {op}")
async def _execute_update_node(self, op: GraphOperation) -> None:
"""执行更新节点操作"""
node_id = op.target_id
params = op.parameters
updated_content = params.get("updated_content")
if not node_id:
logger.warning("更新节点失败: 缺少 node_id")
return
success = self.memory_manager.graph_store.update_node(
node_id=node_id,
content=updated_content
)
if success:
logger.info(f"✅ 更新节点: {node_id}")
else:
logger.error(f"更新节点失败: {node_id}")
async def _execute_merge_nodes(self, op: GraphOperation) -> None:
"""执行合并节点操作"""
params = op.parameters
source_node_ids = params.get("source_node_ids", [])
merged_content = params.get("merged_content")
if not source_node_ids or len(source_node_ids) < 2:
logger.warning("合并节点失败: 需要至少两个节点")
return
target_id = source_node_ids[0]
sources = source_node_ids[1:]
# 更新目标节点内容
if merged_content:
self.memory_manager.graph_store.update_node(target_id, content=merged_content)
# 合并其他节点到目标节点
for source_id in sources:
self.memory_manager.graph_store.merge_nodes(source_id, target_id)
logger.info(f"✅ 合并节点: {sources} -> {target_id}")
async def _execute_create_edge(self, op: GraphOperation) -> None: async def _execute_create_edge(self, op: GraphOperation) -> None:
"""执行创建边操作""" """执行创建边操作"""
# 注意:当前 MemoryManager 不直接支持单独创建边 params = op.parameters
# 这里记录操作,实际执行需要扩展 MemoryManager API source_id = params.get("source_node_id")
logger.info(f"创建边操作(待实现): {op.parameters}") target_id = params.get("target_node_id")
relation = params.get("relation", "related")
edge_type = params.get("edge_type", "RELATION")
importance = params.get("importance", 0.5)
if not source_id or not target_id:
logger.warning(f"创建边失败: 缺少节点ID ({source_id} -> {target_id})")
return
edge_id = self.memory_manager.graph_store.add_edge(
source_id=source_id,
target_id=target_id,
relation=relation,
edge_type=edge_type,
importance=importance,
metadata={"created_by": "long_term_manager"}
)
if edge_id:
logger.info(f"✅ 创建边: {source_id} -> {target_id} ({relation})")
else:
logger.error(f"创建边失败: {op}")
async def _execute_update_edge(self, op: GraphOperation) -> None:
"""执行更新边操作"""
edge_id = op.target_id
params = op.parameters
updated_relation = params.get("updated_relation")
updated_importance = params.get("updated_importance")
if not edge_id:
logger.warning("更新边失败: 缺少 edge_id")
return
success = self.memory_manager.graph_store.update_edge(
edge_id=edge_id,
relation=updated_relation,
importance=updated_importance
)
if success:
logger.info(f"✅ 更新边: {edge_id}")
else:
logger.error(f"更新边失败: {edge_id}")
async def _execute_delete_edge(self, op: GraphOperation) -> None:
"""执行删除边操作"""
edge_id = op.target_id
if not edge_id:
logger.warning("删除边失败: 缺少 edge_id")
return
success = self.memory_manager.graph_store.remove_edge(edge_id)
if success:
logger.info(f"✅ 删除边: {edge_id}")
else:
logger.error(f"删除边失败: {edge_id}")
async def _generate_node_embedding(self, node_id: str, content: str) -> None:
"""为新节点生成 embedding 并存入向量库"""
try:
if not self.memory_manager.vector_store or not self.memory_manager.embedding_generator:
return
embedding = await self.memory_manager.embedding_generator.generate(content)
if embedding is not None:
# 需要构造一个 MemoryNode 对象来调用 add_node
from src.memory_graph.models import MemoryNode, NodeType
node = MemoryNode(
id=node_id,
content=content,
node_type=NodeType.OBJECT, # 默认
embedding=embedding
)
await self.memory_manager.vector_store.add_node(node)
except Exception as e:
logger.warning(f"生成节点 embedding 失败: {e}")
async def apply_long_term_decay(self) -> dict[str, Any]: async def apply_long_term_decay(self) -> dict[str, Any]:
""" """

View File

@@ -82,6 +82,374 @@ class GraphStore:
logger.error(f"添加记忆失败: {e}", exc_info=True) logger.error(f"添加记忆失败: {e}", exc_info=True)
raise raise
def add_node(
self,
node_id: str,
content: str,
node_type: str,
memory_id: str,
metadata: dict | None = None,
) -> bool:
"""
添加单个节点到图和指定记忆
Args:
node_id: 节点ID
content: 节点内容
node_type: 节点类型
memory_id: 所属记忆ID
metadata: 元数据
Returns:
是否添加成功
"""
try:
# 1. 检查记忆是否存在
if memory_id not in self.memory_index:
logger.warning(f"添加节点失败: 记忆不存在 {memory_id}")
return False
memory = self.memory_index[memory_id]
# 2. 添加节点到图
if not self.graph.has_node(node_id):
from datetime import datetime
self.graph.add_node(
node_id,
content=content,
node_type=node_type,
created_at=datetime.now().isoformat(),
metadata=metadata or {},
)
else:
# 如果节点已存在,更新内容(可选)
pass
# 3. 更新节点到记忆的映射
if node_id not in self.node_to_memories:
self.node_to_memories[node_id] = set()
self.node_to_memories[node_id].add(memory_id)
# 4. 更新记忆对象的 nodes 列表
# 检查是否已在列表中
if not any(n.id == node_id for n in memory.nodes):
from src.memory_graph.models import MemoryNode, NodeType
# 尝试转换 node_type 字符串为枚举
try:
node_type_enum = NodeType(node_type)
except ValueError:
node_type_enum = NodeType.OBJECT # 默认
new_node = MemoryNode(
id=node_id,
content=content,
node_type=node_type_enum,
metadata=metadata or {}
)
memory.nodes.append(new_node)
logger.debug(f"添加节点成功: {node_id} -> {memory_id}")
return True
except Exception as e:
logger.error(f"添加节点失败: {e}", exc_info=True)
return False
def update_node(
self,
node_id: str,
content: str | None = None,
metadata: dict | None = None
) -> bool:
"""
更新节点信息
Args:
node_id: 节点ID
content: 新内容
metadata: 要更新的元数据
Returns:
是否更新成功
"""
if not self.graph.has_node(node_id):
logger.warning(f"更新节点失败: 节点不存在 {node_id}")
return False
try:
# 更新图中的节点数据
if content is not None:
self.graph.nodes[node_id]["content"] = content
if metadata:
if "metadata" not in self.graph.nodes[node_id]:
self.graph.nodes[node_id]["metadata"] = {}
self.graph.nodes[node_id]["metadata"].update(metadata)
# 同步更新所有相关记忆中的节点对象
if node_id in self.node_to_memories:
for mem_id in self.node_to_memories[node_id]:
memory = self.memory_index.get(mem_id)
if memory:
for node in memory.nodes:
if node.id == node_id:
if content is not None:
node.content = content
if metadata:
node.metadata.update(metadata)
break
return True
except Exception as e:
logger.error(f"更新节点失败: {e}", exc_info=True)
return False
def add_edge(
self,
source_id: str,
target_id: str,
relation: str,
edge_type: str,
importance: float = 0.5,
metadata: dict | None = None,
) -> str | None:
"""
添加边到图
Args:
source_id: 源节点ID
target_id: 目标节点ID
relation: 关系描述
edge_type: 边类型
importance: 重要性
metadata: 元数据
Returns:
新边的ID失败返回 None
"""
if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
logger.warning(f"添加边失败: 节点不存在 ({source_id}, {target_id})")
return None
try:
import uuid
from datetime import datetime
from src.memory_graph.models import MemoryEdge, EdgeType
edge_id = str(uuid.uuid4())
created_at = datetime.now().isoformat()
# 1. 添加到图
self.graph.add_edge(
source_id,
target_id,
edge_id=edge_id,
relation=relation,
edge_type=edge_type,
importance=importance,
metadata=metadata or {},
created_at=created_at,
)
# 2. 同步到相关记忆
# 找到包含源节点或目标节点的记忆
related_memory_ids = set()
if source_id in self.node_to_memories:
related_memory_ids.update(self.node_to_memories[source_id])
if target_id in self.node_to_memories:
related_memory_ids.update(self.node_to_memories[target_id])
# 尝试转换 edge_type
try:
edge_type_enum = EdgeType(edge_type)
except ValueError:
edge_type_enum = EdgeType.RELATION
new_edge = MemoryEdge(
id=edge_id,
source_id=source_id,
target_id=target_id,
relation=relation,
edge_type=edge_type_enum,
importance=importance,
metadata=metadata or {}
)
for mem_id in related_memory_ids:
memory = self.memory_index.get(mem_id)
if memory:
memory.edges.append(new_edge)
logger.debug(f"添加边成功: {source_id} -> {target_id} ({relation})")
return edge_id
except Exception as e:
logger.error(f"添加边失败: {e}", exc_info=True)
return None
def update_edge(
self,
edge_id: str,
relation: str | None = None,
importance: float | None = None
) -> bool:
"""
更新边信息
Args:
edge_id: 边ID
relation: 新关系描述
importance: 新重要性
Returns:
是否更新成功
"""
# NetworkX 的边是通过 (u, v) 索引的,没有直接的 edge_id 索引
# 需要遍历查找(或者维护一个 edge_id -> (u, v) 的映射,这里简化处理)
target_edge = None
source_node = None
target_node = None
for u, v, data in self.graph.edges(data=True):
if data.get("edge_id") == edge_id or data.get("id") == edge_id:
target_edge = data
source_node = u
target_node = v
break
if not target_edge:
logger.warning(f"更新边失败: 边不存在 {edge_id}")
return False
try:
# 更新图数据
if relation is not None:
self.graph[source_node][target_node]["relation"] = relation
if importance is not None:
self.graph[source_node][target_node]["importance"] = importance
# 同步更新记忆中的边对象
related_memory_ids = set()
if source_node in self.node_to_memories:
related_memory_ids.update(self.node_to_memories[source_node])
if target_node in self.node_to_memories:
related_memory_ids.update(self.node_to_memories[target_node])
for mem_id in related_memory_ids:
memory = self.memory_index.get(mem_id)
if memory:
for edge in memory.edges:
if edge.id == edge_id:
if relation is not None:
edge.relation = relation
if importance is not None:
edge.importance = importance
break
return True
except Exception as e:
logger.error(f"更新边失败: {e}", exc_info=True)
return False
def remove_edge(self, edge_id: str) -> bool:
"""
删除边
Args:
edge_id: 边ID
Returns:
是否删除成功
"""
target_edge = None
source_node = None
target_node = None
for u, v, data in self.graph.edges(data=True):
if data.get("edge_id") == edge_id or data.get("id") == edge_id:
target_edge = data
source_node = u
target_node = v
break
if not target_edge:
logger.warning(f"删除边失败: 边不存在 {edge_id}")
return False
try:
# 从图中删除
self.graph.remove_edge(source_node, target_node)
# 从相关记忆中删除
related_memory_ids = set()
if source_node in self.node_to_memories:
related_memory_ids.update(self.node_to_memories[source_node])
if target_node in self.node_to_memories:
related_memory_ids.update(self.node_to_memories[target_node])
for mem_id in related_memory_ids:
memory = self.memory_index.get(mem_id)
if memory:
memory.edges = [e for e in memory.edges if e.id != edge_id]
return True
except Exception as e:
logger.error(f"删除边失败: {e}", exc_info=True)
return False
def merge_memories(self, target_memory_id: str, source_memory_ids: list[str]) -> bool:
"""
合并多个记忆到目标记忆
将源记忆的所有节点和边转移到目标记忆,然后删除源记忆。
Args:
target_memory_id: 目标记忆ID
source_memory_ids: 源记忆ID列表
Returns:
是否合并成功
"""
if target_memory_id not in self.memory_index:
logger.error(f"合并失败: 目标记忆不存在 {target_memory_id}")
return False
target_memory = self.memory_index[target_memory_id]
try:
for source_id in source_memory_ids:
if source_id not in self.memory_index:
continue
source_memory = self.memory_index[source_id]
# 1. 转移节点
for node in source_memory.nodes:
# 更新映射
if node.id in self.node_to_memories:
self.node_to_memories[node.id].discard(source_id)
self.node_to_memories[node.id].add(target_memory_id)
# 添加到目标记忆(如果不存在)
if not any(n.id == node.id for n in target_memory.nodes):
target_memory.nodes.append(node)
# 2. 转移边
for edge in source_memory.edges:
# 添加到目标记忆(如果不存在)
if not any(e.id == edge.id for e in target_memory.edges):
target_memory.edges.append(edge)
# 3. 删除源记忆(不清理孤立节点,因为节点已转移)
del self.memory_index[source_id]
logger.info(f"成功合并记忆: {source_memory_ids} -> {target_memory_id}")
return True
except Exception as e:
logger.error(f"合并记忆失败: {e}", exc_info=True)
return False
def get_memory_by_id(self, memory_id: str) -> Memory | None: def get_memory_by_id(self, memory_id: str) -> Memory | None:
""" """
根据ID获取记忆 根据ID获取记忆

View File

@@ -115,8 +115,8 @@ class BaseEvent:
if not self.enabled: if not self.enabled:
return HandlerResultsCollection([]) return HandlerResultsCollection([])
# 使用锁确保同一事件不能同时激活多次 # 移除全局锁,允许同一事件并发触发
async with self.event_handle_lock: # async with self.event_handle_lock:
sorted_subscribers = sorted( sorted_subscribers = sorted(
self.subscribers, key=lambda h: h.weight if hasattr(h, "weight") and h.weight != -1 else 0, reverse=True self.subscribers, key=lambda h: h.weight if hasattr(h, "weight") and h.weight != -1 else 0, reverse=True
) )

View File

@@ -586,10 +586,16 @@ class PluginManager:
# 从组件注册表中移除插件的所有组件 # 从组件注册表中移除插件的所有组件
try: try:
loop = asyncio.get_event_loop() try:
if loop.is_running(): loop = asyncio.get_running_loop()
fut = asyncio.run_coroutine_threadsafe(component_registry.unregister_plugin(plugin_name), loop) except RuntimeError:
fut.result(timeout=5) loop = None
if loop and loop.is_running():
# 如果在运行的事件循环中,直接创建任务,不等待结果以避免死锁
# 注意:这意味着我们无法确切知道卸载是否成功完成,但避免了阻塞
logger.warning(f"unload_plugin 在异步上下文中被调用 ({plugin_name}),将异步执行组件卸载。建议使用 remove_registered_plugin。")
loop.create_task(component_registry.unregister_plugin(plugin_name))
else: else:
asyncio.run(component_registry.unregister_plugin(plugin_name)) asyncio.run(component_registry.unregister_plugin(plugin_name))
except Exception as e: # 捕获并记录卸载阶段协程调用错误 except Exception as e: # 捕获并记录卸载阶段协程调用错误