feat: 优化长期记忆检索和合并操作,支持图结构扩展和智能合并
This commit is contained in:
@@ -207,23 +207,69 @@ class LongTermMemoryManager:
|
||||
"""
|
||||
在长期记忆中检索与短期记忆相似的记忆
|
||||
|
||||
Args:
|
||||
stm: 短期记忆
|
||||
|
||||
Returns:
|
||||
相似的长期记忆列表
|
||||
优化:不仅检索内容相似的,还利用图结构获取上下文相关的记忆
|
||||
"""
|
||||
try:
|
||||
# 使用短期记忆的内容进行检索
|
||||
from src.config.config import global_config
|
||||
|
||||
# 检查是否启用了高级路径扩展算法
|
||||
use_path_expansion = getattr(global_config.memory, "enable_path_expansion", False)
|
||||
|
||||
# 1. 检索记忆
|
||||
# 如果启用了路径扩展,search_memories 内部会自动使用 PathScoreExpansion
|
||||
# 我们只需要传入合适的 expand_depth
|
||||
expand_depth = getattr(global_config.memory, "path_expansion_max_hops", 2) if use_path_expansion else 0
|
||||
|
||||
memories = await self.memory_manager.search_memories(
|
||||
query=stm.content,
|
||||
top_k=self.search_top_k,
|
||||
include_forgotten=False,
|
||||
use_multi_query=False, # 不使用多查询,避免过度扩展
|
||||
expand_depth=expand_depth
|
||||
)
|
||||
|
||||
logger.debug(f"为短期记忆 {stm.id} 找到 {len(memories)} 个相似长期记忆")
|
||||
return memories
|
||||
# 2. 图结构扩展 (Graph Expansion)
|
||||
# 如果已经使用了高级路径扩展算法,就不需要再做简单的手动扩展了
|
||||
if use_path_expansion:
|
||||
logger.debug(f"已使用路径扩展算法检索到 {len(memories)} 条记忆")
|
||||
return memories
|
||||
|
||||
# 如果未启用高级算法,使用简单的 1 跳邻居扩展作为保底
|
||||
expanded_memories = []
|
||||
seen_ids = {m.id for m in memories}
|
||||
|
||||
for mem in memories:
|
||||
expanded_memories.append(mem)
|
||||
|
||||
# 获取该记忆的直接关联记忆(1跳邻居)
|
||||
try:
|
||||
# 利用 MemoryManager 的底层图遍历能力
|
||||
related_ids = self.memory_manager._get_related_memories(mem.id, max_depth=1)
|
||||
|
||||
# 限制每个记忆扩展的邻居数量,避免上下文爆炸
|
||||
max_neighbors = 2
|
||||
neighbor_count = 0
|
||||
|
||||
for rid in related_ids:
|
||||
if rid not in seen_ids:
|
||||
related_mem = await self.memory_manager.get_memory(rid)
|
||||
if related_mem:
|
||||
expanded_memories.append(related_mem)
|
||||
seen_ids.add(rid)
|
||||
neighbor_count += 1
|
||||
|
||||
if neighbor_count >= max_neighbors:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取关联记忆失败: {e}")
|
||||
|
||||
# 总数限制
|
||||
if len(expanded_memories) >= self.search_top_k * 2:
|
||||
break
|
||||
|
||||
logger.debug(f"为短期记忆 {stm.id} 找到 {len(expanded_memories)} 个长期记忆 (含简单图扩展)")
|
||||
return expanded_memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检索相似长期记忆失败: {e}", exc_info=True)
|
||||
@@ -465,10 +511,26 @@ class LongTermMemoryManager:
|
||||
await self._execute_create_node(op)
|
||||
success_count += 1
|
||||
|
||||
elif op.operation_type == GraphOperationType.UPDATE_NODE:
|
||||
await self._execute_update_node(op)
|
||||
success_count += 1
|
||||
|
||||
elif op.operation_type == GraphOperationType.MERGE_NODES:
|
||||
await self._execute_merge_nodes(op)
|
||||
success_count += 1
|
||||
|
||||
elif op.operation_type == GraphOperationType.CREATE_EDGE:
|
||||
await self._execute_create_edge(op)
|
||||
success_count += 1
|
||||
|
||||
elif op.operation_type == GraphOperationType.UPDATE_EDGE:
|
||||
await self._execute_update_edge(op)
|
||||
success_count += 1
|
||||
|
||||
elif op.operation_type == GraphOperationType.DELETE_EDGE:
|
||||
await self._execute_delete_edge(op)
|
||||
success_count += 1
|
||||
|
||||
else:
|
||||
logger.warning(f"未实现的操作类型: {op.operation_type}")
|
||||
|
||||
@@ -525,7 +587,7 @@ class LongTermMemoryManager:
|
||||
async def _execute_merge_memories(
|
||||
self, op: GraphOperation, source_stm: ShortTermMemory
|
||||
) -> None:
|
||||
"""执行合并记忆操作"""
|
||||
"""执行合并记忆操作 (智能合并版)"""
|
||||
source_ids = op.parameters.get("source_memory_ids", [])
|
||||
merged_content = op.parameters.get("merged_content", "")
|
||||
merged_importance = op.parameters.get("merged_importance", source_stm.importance)
|
||||
@@ -534,38 +596,191 @@ class LongTermMemoryManager:
|
||||
logger.warning("合并操作缺少源记忆ID,跳过")
|
||||
return
|
||||
|
||||
# 简化实现:更新第一个记忆,删除其他记忆
|
||||
# 目标记忆(保留的那个)
|
||||
target_id = source_ids[0]
|
||||
success = await self.memory_manager.update_memory(
|
||||
target_id,
|
||||
metadata={
|
||||
"merged_content": merged_content,
|
||||
"merged_from": source_ids[1:],
|
||||
"merged_from_stm": source_stm.id,
|
||||
},
|
||||
importance=merged_importance,
|
||||
)
|
||||
|
||||
# 待合并记忆(将被删除的)
|
||||
memories_to_merge = source_ids[1:]
|
||||
|
||||
logger.info(f"开始智能合并记忆: {memories_to_merge} -> {target_id}")
|
||||
|
||||
if success:
|
||||
# 删除其他记忆
|
||||
for mem_id in source_ids[1:]:
|
||||
await self.memory_manager.delete_memory(mem_id)
|
||||
# 1. 调用 GraphStore 的合并功能(转移节点和边)
|
||||
merge_success = self.memory_manager.graph_store.merge_memories(target_id, memories_to_merge)
|
||||
|
||||
logger.info(f"✅ 合并记忆: {source_ids} → {target_id}")
|
||||
if merge_success:
|
||||
# 2. 更新目标记忆的元数据
|
||||
await self.memory_manager.update_memory(
|
||||
target_id,
|
||||
metadata={
|
||||
"merged_content": merged_content,
|
||||
"merged_from": memories_to_merge,
|
||||
"merged_from_stm": source_stm.id,
|
||||
"merge_time": datetime.now().isoformat()
|
||||
},
|
||||
importance=merged_importance,
|
||||
)
|
||||
|
||||
# 3. 异步保存
|
||||
asyncio.create_task(self.memory_manager._async_save_graph_store("合并记忆"))
|
||||
logger.info(f"✅ 合并记忆完成: {source_ids} -> {target_id}")
|
||||
else:
|
||||
logger.error(f"合并记忆失败: {source_ids}")
|
||||
|
||||
async def _execute_create_node(self, op: GraphOperation) -> None:
|
||||
"""执行创建节点操作"""
|
||||
# 注意:当前 MemoryManager 不直接支持单独创建节点
|
||||
# 这里记录操作,实际执行需要扩展 MemoryManager API
|
||||
logger.info(f"创建节点操作(待实现): {op.parameters}")
|
||||
params = op.parameters
|
||||
content = params.get("content")
|
||||
node_type = params.get("node_type", "OBJECT")
|
||||
memory_id = params.get("memory_id")
|
||||
|
||||
if not content or not memory_id:
|
||||
logger.warning(f"创建节点失败: 缺少必要参数 (content={content}, memory_id={memory_id})")
|
||||
return
|
||||
|
||||
import uuid
|
||||
node_id = str(uuid.uuid4())
|
||||
|
||||
success = self.memory_manager.graph_store.add_node(
|
||||
node_id=node_id,
|
||||
content=content,
|
||||
node_type=node_type,
|
||||
memory_id=memory_id,
|
||||
metadata={"created_by": "long_term_manager"}
|
||||
)
|
||||
|
||||
if success:
|
||||
# 尝试为新节点生成 embedding (异步)
|
||||
asyncio.create_task(self._generate_node_embedding(node_id, content))
|
||||
logger.info(f"✅ 创建节点: {content} ({node_type}) -> {memory_id}")
|
||||
else:
|
||||
logger.error(f"创建节点失败: {op}")
|
||||
|
||||
async def _execute_update_node(self, op: GraphOperation) -> None:
|
||||
"""执行更新节点操作"""
|
||||
node_id = op.target_id
|
||||
params = op.parameters
|
||||
updated_content = params.get("updated_content")
|
||||
|
||||
if not node_id:
|
||||
logger.warning("更新节点失败: 缺少 node_id")
|
||||
return
|
||||
|
||||
success = self.memory_manager.graph_store.update_node(
|
||||
node_id=node_id,
|
||||
content=updated_content
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"✅ 更新节点: {node_id}")
|
||||
else:
|
||||
logger.error(f"更新节点失败: {node_id}")
|
||||
|
||||
async def _execute_merge_nodes(self, op: GraphOperation) -> None:
|
||||
"""执行合并节点操作"""
|
||||
params = op.parameters
|
||||
source_node_ids = params.get("source_node_ids", [])
|
||||
merged_content = params.get("merged_content")
|
||||
|
||||
if not source_node_ids or len(source_node_ids) < 2:
|
||||
logger.warning("合并节点失败: 需要至少两个节点")
|
||||
return
|
||||
|
||||
target_id = source_node_ids[0]
|
||||
sources = source_node_ids[1:]
|
||||
|
||||
# 更新目标节点内容
|
||||
if merged_content:
|
||||
self.memory_manager.graph_store.update_node(target_id, content=merged_content)
|
||||
|
||||
# 合并其他节点到目标节点
|
||||
for source_id in sources:
|
||||
self.memory_manager.graph_store.merge_nodes(source_id, target_id)
|
||||
|
||||
logger.info(f"✅ 合并节点: {sources} -> {target_id}")
|
||||
|
||||
async def _execute_create_edge(self, op: GraphOperation) -> None:
|
||||
"""执行创建边操作"""
|
||||
# 注意:当前 MemoryManager 不直接支持单独创建边
|
||||
# 这里记录操作,实际执行需要扩展 MemoryManager API
|
||||
logger.info(f"创建边操作(待实现): {op.parameters}")
|
||||
params = op.parameters
|
||||
source_id = params.get("source_node_id")
|
||||
target_id = params.get("target_node_id")
|
||||
relation = params.get("relation", "related")
|
||||
edge_type = params.get("edge_type", "RELATION")
|
||||
importance = params.get("importance", 0.5)
|
||||
|
||||
if not source_id or not target_id:
|
||||
logger.warning(f"创建边失败: 缺少节点ID ({source_id} -> {target_id})")
|
||||
return
|
||||
|
||||
edge_id = self.memory_manager.graph_store.add_edge(
|
||||
source_id=source_id,
|
||||
target_id=target_id,
|
||||
relation=relation,
|
||||
edge_type=edge_type,
|
||||
importance=importance,
|
||||
metadata={"created_by": "long_term_manager"}
|
||||
)
|
||||
|
||||
if edge_id:
|
||||
logger.info(f"✅ 创建边: {source_id} -> {target_id} ({relation})")
|
||||
else:
|
||||
logger.error(f"创建边失败: {op}")
|
||||
|
||||
async def _execute_update_edge(self, op: GraphOperation) -> None:
|
||||
"""执行更新边操作"""
|
||||
edge_id = op.target_id
|
||||
params = op.parameters
|
||||
updated_relation = params.get("updated_relation")
|
||||
updated_importance = params.get("updated_importance")
|
||||
|
||||
if not edge_id:
|
||||
logger.warning("更新边失败: 缺少 edge_id")
|
||||
return
|
||||
|
||||
success = self.memory_manager.graph_store.update_edge(
|
||||
edge_id=edge_id,
|
||||
relation=updated_relation,
|
||||
importance=updated_importance
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"✅ 更新边: {edge_id}")
|
||||
else:
|
||||
logger.error(f"更新边失败: {edge_id}")
|
||||
|
||||
async def _execute_delete_edge(self, op: GraphOperation) -> None:
|
||||
"""执行删除边操作"""
|
||||
edge_id = op.target_id
|
||||
|
||||
if not edge_id:
|
||||
logger.warning("删除边失败: 缺少 edge_id")
|
||||
return
|
||||
|
||||
success = self.memory_manager.graph_store.remove_edge(edge_id)
|
||||
|
||||
if success:
|
||||
logger.info(f"✅ 删除边: {edge_id}")
|
||||
else:
|
||||
logger.error(f"删除边失败: {edge_id}")
|
||||
|
||||
async def _generate_node_embedding(self, node_id: str, content: str) -> None:
|
||||
"""为新节点生成 embedding 并存入向量库"""
|
||||
try:
|
||||
if not self.memory_manager.vector_store or not self.memory_manager.embedding_generator:
|
||||
return
|
||||
|
||||
embedding = await self.memory_manager.embedding_generator.generate(content)
|
||||
if embedding is not None:
|
||||
# 需要构造一个 MemoryNode 对象来调用 add_node
|
||||
from src.memory_graph.models import MemoryNode, NodeType
|
||||
node = MemoryNode(
|
||||
id=node_id,
|
||||
content=content,
|
||||
node_type=NodeType.OBJECT, # 默认
|
||||
embedding=embedding
|
||||
)
|
||||
await self.memory_manager.vector_store.add_node(node)
|
||||
except Exception as e:
|
||||
logger.warning(f"生成节点 embedding 失败: {e}")
|
||||
|
||||
async def apply_long_term_decay(self) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -82,6 +82,374 @@ class GraphStore:
|
||||
logger.error(f"添加记忆失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def add_node(
|
||||
self,
|
||||
node_id: str,
|
||||
content: str,
|
||||
node_type: str,
|
||||
memory_id: str,
|
||||
metadata: dict | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
添加单个节点到图和指定记忆
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
content: 节点内容
|
||||
node_type: 节点类型
|
||||
memory_id: 所属记忆ID
|
||||
metadata: 元数据
|
||||
|
||||
Returns:
|
||||
是否添加成功
|
||||
"""
|
||||
try:
|
||||
# 1. 检查记忆是否存在
|
||||
if memory_id not in self.memory_index:
|
||||
logger.warning(f"添加节点失败: 记忆不存在 {memory_id}")
|
||||
return False
|
||||
|
||||
memory = self.memory_index[memory_id]
|
||||
|
||||
# 2. 添加节点到图
|
||||
if not self.graph.has_node(node_id):
|
||||
from datetime import datetime
|
||||
self.graph.add_node(
|
||||
node_id,
|
||||
content=content,
|
||||
node_type=node_type,
|
||||
created_at=datetime.now().isoformat(),
|
||||
metadata=metadata or {},
|
||||
)
|
||||
else:
|
||||
# 如果节点已存在,更新内容(可选)
|
||||
pass
|
||||
|
||||
# 3. 更新节点到记忆的映射
|
||||
if node_id not in self.node_to_memories:
|
||||
self.node_to_memories[node_id] = set()
|
||||
self.node_to_memories[node_id].add(memory_id)
|
||||
|
||||
# 4. 更新记忆对象的 nodes 列表
|
||||
# 检查是否已在列表中
|
||||
if not any(n.id == node_id for n in memory.nodes):
|
||||
from src.memory_graph.models import MemoryNode, NodeType
|
||||
# 尝试转换 node_type 字符串为枚举
|
||||
try:
|
||||
node_type_enum = NodeType(node_type)
|
||||
except ValueError:
|
||||
node_type_enum = NodeType.OBJECT # 默认
|
||||
|
||||
new_node = MemoryNode(
|
||||
id=node_id,
|
||||
content=content,
|
||||
node_type=node_type_enum,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
memory.nodes.append(new_node)
|
||||
|
||||
logger.debug(f"添加节点成功: {node_id} -> {memory_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"添加节点失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def update_node(
|
||||
self,
|
||||
node_id: str,
|
||||
content: str | None = None,
|
||||
metadata: dict | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
更新节点信息
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
content: 新内容
|
||||
metadata: 要更新的元数据
|
||||
|
||||
Returns:
|
||||
是否更新成功
|
||||
"""
|
||||
if not self.graph.has_node(node_id):
|
||||
logger.warning(f"更新节点失败: 节点不存在 {node_id}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 更新图中的节点数据
|
||||
if content is not None:
|
||||
self.graph.nodes[node_id]["content"] = content
|
||||
|
||||
if metadata:
|
||||
if "metadata" not in self.graph.nodes[node_id]:
|
||||
self.graph.nodes[node_id]["metadata"] = {}
|
||||
self.graph.nodes[node_id]["metadata"].update(metadata)
|
||||
|
||||
# 同步更新所有相关记忆中的节点对象
|
||||
if node_id in self.node_to_memories:
|
||||
for mem_id in self.node_to_memories[node_id]:
|
||||
memory = self.memory_index.get(mem_id)
|
||||
if memory:
|
||||
for node in memory.nodes:
|
||||
if node.id == node_id:
|
||||
if content is not None:
|
||||
node.content = content
|
||||
if metadata:
|
||||
node.metadata.update(metadata)
|
||||
break
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"更新节点失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def add_edge(
|
||||
self,
|
||||
source_id: str,
|
||||
target_id: str,
|
||||
relation: str,
|
||||
edge_type: str,
|
||||
importance: float = 0.5,
|
||||
metadata: dict | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
添加边到图
|
||||
|
||||
Args:
|
||||
source_id: 源节点ID
|
||||
target_id: 目标节点ID
|
||||
relation: 关系描述
|
||||
edge_type: 边类型
|
||||
importance: 重要性
|
||||
metadata: 元数据
|
||||
|
||||
Returns:
|
||||
新边的ID,失败返回 None
|
||||
"""
|
||||
if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
|
||||
logger.warning(f"添加边失败: 节点不存在 ({source_id}, {target_id})")
|
||||
return None
|
||||
|
||||
try:
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from src.memory_graph.models import MemoryEdge, EdgeType
|
||||
|
||||
edge_id = str(uuid.uuid4())
|
||||
created_at = datetime.now().isoformat()
|
||||
|
||||
# 1. 添加到图
|
||||
self.graph.add_edge(
|
||||
source_id,
|
||||
target_id,
|
||||
edge_id=edge_id,
|
||||
relation=relation,
|
||||
edge_type=edge_type,
|
||||
importance=importance,
|
||||
metadata=metadata or {},
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
# 2. 同步到相关记忆
|
||||
# 找到包含源节点或目标节点的记忆
|
||||
related_memory_ids = set()
|
||||
if source_id in self.node_to_memories:
|
||||
related_memory_ids.update(self.node_to_memories[source_id])
|
||||
if target_id in self.node_to_memories:
|
||||
related_memory_ids.update(self.node_to_memories[target_id])
|
||||
|
||||
# 尝试转换 edge_type
|
||||
try:
|
||||
edge_type_enum = EdgeType(edge_type)
|
||||
except ValueError:
|
||||
edge_type_enum = EdgeType.RELATION
|
||||
|
||||
new_edge = MemoryEdge(
|
||||
id=edge_id,
|
||||
source_id=source_id,
|
||||
target_id=target_id,
|
||||
relation=relation,
|
||||
edge_type=edge_type_enum,
|
||||
importance=importance,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
for mem_id in related_memory_ids:
|
||||
memory = self.memory_index.get(mem_id)
|
||||
if memory:
|
||||
memory.edges.append(new_edge)
|
||||
|
||||
logger.debug(f"添加边成功: {source_id} -> {target_id} ({relation})")
|
||||
return edge_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"添加边失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def update_edge(
|
||||
self,
|
||||
edge_id: str,
|
||||
relation: str | None = None,
|
||||
importance: float | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
更新边信息
|
||||
|
||||
Args:
|
||||
edge_id: 边ID
|
||||
relation: 新关系描述
|
||||
importance: 新重要性
|
||||
|
||||
Returns:
|
||||
是否更新成功
|
||||
"""
|
||||
# NetworkX 的边是通过 (u, v) 索引的,没有直接的 edge_id 索引
|
||||
# 需要遍历查找(或者维护一个 edge_id -> (u, v) 的映射,这里简化处理)
|
||||
target_edge = None
|
||||
source_node = None
|
||||
target_node = None
|
||||
|
||||
for u, v, data in self.graph.edges(data=True):
|
||||
if data.get("edge_id") == edge_id or data.get("id") == edge_id:
|
||||
target_edge = data
|
||||
source_node = u
|
||||
target_node = v
|
||||
break
|
||||
|
||||
if not target_edge:
|
||||
logger.warning(f"更新边失败: 边不存在 {edge_id}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 更新图数据
|
||||
if relation is not None:
|
||||
self.graph[source_node][target_node]["relation"] = relation
|
||||
if importance is not None:
|
||||
self.graph[source_node][target_node]["importance"] = importance
|
||||
|
||||
# 同步更新记忆中的边对象
|
||||
related_memory_ids = set()
|
||||
if source_node in self.node_to_memories:
|
||||
related_memory_ids.update(self.node_to_memories[source_node])
|
||||
if target_node in self.node_to_memories:
|
||||
related_memory_ids.update(self.node_to_memories[target_node])
|
||||
|
||||
for mem_id in related_memory_ids:
|
||||
memory = self.memory_index.get(mem_id)
|
||||
if memory:
|
||||
for edge in memory.edges:
|
||||
if edge.id == edge_id:
|
||||
if relation is not None:
|
||||
edge.relation = relation
|
||||
if importance is not None:
|
||||
edge.importance = importance
|
||||
break
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"更新边失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def remove_edge(self, edge_id: str) -> bool:
|
||||
"""
|
||||
删除边
|
||||
|
||||
Args:
|
||||
edge_id: 边ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
target_edge = None
|
||||
source_node = None
|
||||
target_node = None
|
||||
|
||||
for u, v, data in self.graph.edges(data=True):
|
||||
if data.get("edge_id") == edge_id or data.get("id") == edge_id:
|
||||
target_edge = data
|
||||
source_node = u
|
||||
target_node = v
|
||||
break
|
||||
|
||||
if not target_edge:
|
||||
logger.warning(f"删除边失败: 边不存在 {edge_id}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 从图中删除
|
||||
self.graph.remove_edge(source_node, target_node)
|
||||
|
||||
# 从相关记忆中删除
|
||||
related_memory_ids = set()
|
||||
if source_node in self.node_to_memories:
|
||||
related_memory_ids.update(self.node_to_memories[source_node])
|
||||
if target_node in self.node_to_memories:
|
||||
related_memory_ids.update(self.node_to_memories[target_node])
|
||||
|
||||
for mem_id in related_memory_ids:
|
||||
memory = self.memory_index.get(mem_id)
|
||||
if memory:
|
||||
memory.edges = [e for e in memory.edges if e.id != edge_id]
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"删除边失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def merge_memories(self, target_memory_id: str, source_memory_ids: list[str]) -> bool:
|
||||
"""
|
||||
合并多个记忆到目标记忆
|
||||
|
||||
将源记忆的所有节点和边转移到目标记忆,然后删除源记忆。
|
||||
|
||||
Args:
|
||||
target_memory_id: 目标记忆ID
|
||||
source_memory_ids: 源记忆ID列表
|
||||
|
||||
Returns:
|
||||
是否合并成功
|
||||
"""
|
||||
if target_memory_id not in self.memory_index:
|
||||
logger.error(f"合并失败: 目标记忆不存在 {target_memory_id}")
|
||||
return False
|
||||
|
||||
target_memory = self.memory_index[target_memory_id]
|
||||
|
||||
try:
|
||||
for source_id in source_memory_ids:
|
||||
if source_id not in self.memory_index:
|
||||
continue
|
||||
|
||||
source_memory = self.memory_index[source_id]
|
||||
|
||||
# 1. 转移节点
|
||||
for node in source_memory.nodes:
|
||||
# 更新映射
|
||||
if node.id in self.node_to_memories:
|
||||
self.node_to_memories[node.id].discard(source_id)
|
||||
self.node_to_memories[node.id].add(target_memory_id)
|
||||
|
||||
# 添加到目标记忆(如果不存在)
|
||||
if not any(n.id == node.id for n in target_memory.nodes):
|
||||
target_memory.nodes.append(node)
|
||||
|
||||
# 2. 转移边
|
||||
for edge in source_memory.edges:
|
||||
# 添加到目标记忆(如果不存在)
|
||||
if not any(e.id == edge.id for e in target_memory.edges):
|
||||
target_memory.edges.append(edge)
|
||||
|
||||
# 3. 删除源记忆(不清理孤立节点,因为节点已转移)
|
||||
del self.memory_index[source_id]
|
||||
|
||||
logger.info(f"成功合并记忆: {source_memory_ids} -> {target_memory_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"合并记忆失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_memory_by_id(self, memory_id: str) -> Memory | None:
|
||||
"""
|
||||
根据ID获取记忆
|
||||
|
||||
@@ -115,75 +115,75 @@ class BaseEvent:
|
||||
if not self.enabled:
|
||||
return HandlerResultsCollection([])
|
||||
|
||||
# 使用锁确保同一个事件不能同时激活多次
|
||||
async with self.event_handle_lock:
|
||||
sorted_subscribers = sorted(
|
||||
self.subscribers, key=lambda h: h.weight if hasattr(h, "weight") and h.weight != -1 else 0, reverse=True
|
||||
# 移除全局锁,允许同一事件并发触发
|
||||
# async with self.event_handle_lock:
|
||||
sorted_subscribers = sorted(
|
||||
self.subscribers, key=lambda h: h.weight if hasattr(h, "weight") and h.weight != -1 else 0, reverse=True
|
||||
)
|
||||
|
||||
if not sorted_subscribers:
|
||||
return HandlerResultsCollection([])
|
||||
|
||||
concurrency_limit = None
|
||||
if max_concurrency is not None:
|
||||
concurrency_limit = max_concurrency if max_concurrency > 0 else None
|
||||
if concurrency_limit:
|
||||
concurrency_limit = min(concurrency_limit, len(sorted_subscribers))
|
||||
|
||||
semaphore = (
|
||||
asyncio.Semaphore(concurrency_limit)
|
||||
if concurrency_limit and concurrency_limit < len(sorted_subscribers)
|
||||
else None
|
||||
)
|
||||
|
||||
async def _run_handler(subscriber):
|
||||
handler_name = (
|
||||
subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
|
||||
)
|
||||
|
||||
if not sorted_subscribers:
|
||||
return HandlerResultsCollection([])
|
||||
async def _invoke():
|
||||
return await self._execute_subscriber(subscriber, params)
|
||||
|
||||
concurrency_limit = None
|
||||
if max_concurrency is not None:
|
||||
concurrency_limit = max_concurrency if max_concurrency > 0 else None
|
||||
if concurrency_limit:
|
||||
concurrency_limit = min(concurrency_limit, len(sorted_subscribers))
|
||||
|
||||
semaphore = (
|
||||
asyncio.Semaphore(concurrency_limit)
|
||||
if concurrency_limit and concurrency_limit < len(sorted_subscribers)
|
||||
else None
|
||||
)
|
||||
|
||||
async def _run_handler(subscriber):
|
||||
handler_name = (
|
||||
subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
|
||||
)
|
||||
|
||||
async def _invoke():
|
||||
return await self._execute_subscriber(subscriber, params)
|
||||
|
||||
try:
|
||||
if handler_timeout and handler_timeout > 0:
|
||||
result = await asyncio.wait_for(_invoke(), timeout=handler_timeout)
|
||||
else:
|
||||
result = await _invoke()
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"事件处理器 {handler_name} 执行超时 ({handler_timeout}s)")
|
||||
return HandlerResult(False, True, f"timeout after {handler_timeout}s", handler_name)
|
||||
except Exception as exc:
|
||||
logger.error(f"事件处理器 {handler_name} 执行失败: {exc}")
|
||||
return HandlerResult(False, True, str(exc), handler_name)
|
||||
|
||||
if not isinstance(result, HandlerResult):
|
||||
return HandlerResult(True, True, result, handler_name)
|
||||
|
||||
if not result.handler_name:
|
||||
result.handler_name = handler_name
|
||||
return result
|
||||
|
||||
async def _guarded_run(subscriber):
|
||||
if semaphore:
|
||||
async with semaphore:
|
||||
return await _run_handler(subscriber)
|
||||
return await _run_handler(subscriber)
|
||||
|
||||
tasks = [asyncio.create_task(_guarded_run(subscriber)) for subscriber in sorted_subscribers]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
processed_results: list[HandlerResult] = []
|
||||
for subscriber, result in zip(sorted_subscribers, results):
|
||||
handler_name = (
|
||||
subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
|
||||
)
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"事件处理器 {handler_name} 执行失败: {result}")
|
||||
processed_results.append(HandlerResult(False, True, str(result), handler_name))
|
||||
try:
|
||||
if handler_timeout and handler_timeout > 0:
|
||||
result = await asyncio.wait_for(_invoke(), timeout=handler_timeout)
|
||||
else:
|
||||
processed_results.append(result)
|
||||
result = await _invoke()
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"事件处理器 {handler_name} 执行超时 ({handler_timeout}s)")
|
||||
return HandlerResult(False, True, f"timeout after {handler_timeout}s", handler_name)
|
||||
except Exception as exc:
|
||||
logger.error(f"事件处理器 {handler_name} 执行失败: {exc}")
|
||||
return HandlerResult(False, True, str(exc), handler_name)
|
||||
|
||||
return HandlerResultsCollection(processed_results)
|
||||
if not isinstance(result, HandlerResult):
|
||||
return HandlerResult(True, True, result, handler_name)
|
||||
|
||||
if not result.handler_name:
|
||||
result.handler_name = handler_name
|
||||
return result
|
||||
|
||||
async def _guarded_run(subscriber):
|
||||
if semaphore:
|
||||
async with semaphore:
|
||||
return await _run_handler(subscriber)
|
||||
return await _run_handler(subscriber)
|
||||
|
||||
tasks = [asyncio.create_task(_guarded_run(subscriber)) for subscriber in sorted_subscribers]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
processed_results: list[HandlerResult] = []
|
||||
for subscriber, result in zip(sorted_subscribers, results):
|
||||
handler_name = (
|
||||
subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
|
||||
)
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"事件处理器 {handler_name} 执行失败: {result}")
|
||||
processed_results.append(HandlerResult(False, True, str(result), handler_name))
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
return HandlerResultsCollection(processed_results)
|
||||
|
||||
@staticmethod
|
||||
async def _execute_subscriber(subscriber, params: dict) -> HandlerResult:
|
||||
|
||||
@@ -586,10 +586,16 @@ class PluginManager:
|
||||
|
||||
# 从组件注册表中移除插件的所有组件
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(component_registry.unregister_plugin(plugin_name), loop)
|
||||
fut.result(timeout=5)
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
# 如果在运行的事件循环中,直接创建任务,不等待结果以避免死锁
|
||||
# 注意:这意味着我们无法确切知道卸载是否成功完成,但避免了阻塞
|
||||
logger.warning(f"unload_plugin 在异步上下文中被调用 ({plugin_name}),将异步执行组件卸载。建议使用 remove_registered_plugin。")
|
||||
loop.create_task(component_registry.unregister_plugin(plugin_name))
|
||||
else:
|
||||
asyncio.run(component_registry.unregister_plugin(plugin_name))
|
||||
except Exception as e: # 捕获并记录卸载阶段协程调用错误
|
||||
|
||||
Reference in New Issue
Block a user