feat: enhance memory removal and graph expansion

- Updated the `remove_memory` method in `graph_store.py` to take an optional `cleanup_orphans` parameter for immediately cleaning up orphan nodes.
- Optimized the graph expansion algorithm in `graph_expansion.py` (see the sketch after this list):
  - Memory-level breadth-first search (BFS) traversal instead of node-level traversal.
  - Batched retrieval of neighbor memories to reduce the number of database calls.
  - Early stopping to avoid unnecessary expansion.
  - Richer logging for better traceability.
- Added performance metrics to track the efficiency of memory expansion.
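
The `graph_expansion.py` changes themselves are not shown in this diff. As a rough, hypothetical sketch of the memory-level BFS described above — `get_neighbor_memory_ids` and `fetch_memories_batch` are stand-in helpers, not the repo's actual API:

```python
from collections import deque
from typing import Callable

def expand_memories(
    get_neighbor_memory_ids: Callable[[str], list[str]],  # hypothetical helper
    fetch_memories_batch: Callable[[list[str]], dict],    # hypothetical helper
    seed_ids: list[str],
    max_expanded: int = 20,
    max_depth: int = 2,
) -> list[str]:
    """Memory-level BFS with batched neighbor fetches and early stopping."""
    visited = set(seed_ids)
    expanded: list[str] = []
    frontier = deque((mid, 0) for mid in seed_ids)
    while frontier and len(expanded) < max_expanded:  # early stop on budget
        memory_id, depth = frontier.popleft()
        if depth >= max_depth:
            continue
        # One batched fetch per memory instead of one database call per node.
        new_ids = [n for n in get_neighbor_memory_ids(memory_id) if n not in visited]
        for nid in fetch_memories_batch(new_ids):
            visited.add(nid)
            expanded.append(nid)
            frontier.append((nid, depth + 1))
            if len(expanded) >= max_expanded:
                break
    return expanded
```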
Windpicker-owo
2025-11-09 16:39:46 +08:00
parent a0bb9660d4
commit f4d2b54f83
5 changed files with 795 additions and 155 deletions


@@ -78,7 +78,7 @@ class MemoryManager:
self._last_maintenance = datetime.now()
self._maintenance_task: asyncio.Task | None = None
self._maintenance_interval_hours = getattr(self.config, "consolidation_interval_hours", 1.0)
self._maintenance_schedule_id: str | None = None  # scheduler task ID
self._maintenance_running = False  # maintenance task running flag
logger.info(f"Memory manager created (data_dir={self.data_dir}, enable={getattr(self.config, 'enable', False)})")
@@ -155,8 +155,8 @@ class MemoryManager:
self._initialized = True
logger.info("✅ 记忆管理器初始化完成")
# 启动后台维护调度任务
await self.start_maintenance_scheduler()
# 启动后台维护任务
self._start_maintenance_task()
except Exception as e:
logger.error(f"记忆管理器初始化失败: {e}", exc_info=True)
@@ -178,8 +178,8 @@ class MemoryManager:
try:
logger.info("正在关闭记忆管理器...")
# 1. 停止调度任务
await self.stop_maintenance_scheduler()
# 1. 停止维护任务
await self._stop_maintenance_task()
# 2. 执行最后一次维护(保存数据)
if self.graph_store and self.persistence:
@@ -867,12 +867,19 @@ class MemoryManager:
max_expanded=max_expanded,
)
async def forget_memory(self, memory_id: str) -> bool:
async def forget_memory(self, memory_id: str, cleanup_orphans: bool = True) -> bool:
"""
遗忘记忆(标记为已遗忘,不删除)
遗忘记忆(直接删除)
这个方法会:
1. 从向量存储中删除节点的嵌入向量
2. 从图存储中删除记忆
3. 可选:清理孤立节点(建议批量遗忘后统一清理)
4. 保存更新后的数据
Args:
memory_id: 记忆 ID
cleanup_orphans: 是否立即清理孤立节点默认True批量遗忘时设为False
Returns:
是否遗忘成功
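
A minimal usage sketch of the new signature (hypothetical memory IDs; `manager` stands for an initialized `MemoryManager`): batch deletions defer the orphan sweep to one final pass:

```python
async def prune_stale_memories(manager, stale_ids: list[str]) -> None:
    """Batch-forget memories, deferring orphan cleanup to one final pass."""
    for memory_id in stale_ids:
        # cleanup_orphans=False skips the per-memory orphan sweep.
        await manager.forget_memory(memory_id, cleanup_orphans=False)
    # One consolidated sweep instead of len(stale_ids) sweeps.
    await manager._cleanup_orphan_nodes_and_edges()
```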
@@ -886,13 +893,36 @@ class MemoryManager:
logger.warning(f"记忆不存在: {memory_id}")
return False
memory.metadata["forgotten"] = True
memory.metadata["forgotten_at"] = datetime.now().isoformat()
# 1. 从向量存储删除节点的嵌入向量
deleted_vectors = 0
for node in memory.nodes:
if node.embedding is not None:
try:
await self.vector_store.delete_node(node.id)
deleted_vectors += 1
except Exception as e:
logger.warning(f"删除节点向量失败 {node.id}: {e}")
# 保存更新
await self.persistence.save_graph_store(self.graph_store)
logger.info(f"记忆已遗忘: {memory_id}")
return True
# 2. Remove the memory from the graph store
success = self.graph_store.remove_memory(memory_id, cleanup_orphans=False)
if success:
# 3. Optional: clean up orphan nodes
if cleanup_orphans:
orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges()
logger.info(
f"Memory forgotten and deleted: {memory_id} "
f"(deleted {deleted_vectors} vectors, cleaned up {orphan_nodes} orphan nodes, {orphan_edges} orphan edges)"
)
else:
logger.debug(f"Memory deleted: {memory_id} (deleted {deleted_vectors} vectors)")
# 4. Save the update
await self.persistence.save_graph_store(self.graph_store)
return True
else:
logger.error(f"Failed to remove memory from graph store: {memory_id}")
return False
except Exception as e:
logger.error(f"Failed to forget memory: {e}", exc_info=True)
@@ -900,7 +930,12 @@ class MemoryManager:
async def auto_forget_memories(self, threshold: float = 0.1) -> int:
"""
自动遗忘低激活度的记忆
自动遗忘低激活度的记忆(批量优化版)
应用时间衰减公式计算当前激活度,低于阈值则遗忘。
衰减公式activation = base_activation * (decay_rate ^ days_passed)
优化:批量删除记忆后统一清理孤立节点,减少重复检查
Args:
threshold: 激活度阈值
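
As a worked example of the decay formula (illustrative values, not the config defaults):

```python
# Illustrative numbers only; real values come from config and memory metadata.
base_activation = 0.5
decay_rate = 0.9
threshold = 0.1

for days_passed in (14, 16):
    current = base_activation * (decay_rate ** days_passed)
    print(f"day {days_passed}: activation={current:.3f}, forget={current < threshold}")
# day 14: activation=0.114, forget=False
# day 16: activation=0.093, forget=True
```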
@@ -914,41 +949,145 @@ class MemoryManager:
try:
forgotten_count = 0
all_memories = self.graph_store.get_all_memories()
# Read config parameters
min_importance = getattr(self.config, "forgetting_min_importance", 0.8)
decay_rate = getattr(self.config, "activation_decay_rate", 0.9)
# Collect the IDs of memories to forget
memories_to_forget = []
for memory in all_memories:
# Skip already-forgotten memories
if memory.metadata.get("forgotten", False):
continue
# Skip high-importance memories
min_importance = getattr(self.config, "forgetting_min_importance", 7.0)
# Skip high-importance memories (protect important memories from being forgotten)
if memory.importance >= min_importance:
continue
# Compute current activation
# Compute current activation (apply time decay)
activation_info = memory.metadata.get("activation", {})
base_activation = activation_info.get("level", memory.activation)
last_access = activation_info.get("last_access")
if last_access:
last_access_dt = datetime.fromisoformat(last_access)
days_passed = (datetime.now() - last_access_dt).days
try:
last_access_dt = datetime.fromisoformat(last_access)
days_passed = (datetime.now() - last_access_dt).days
# Apply exponential decay: activation = base * (decay_rate ** days)
current_activation = base_activation * (decay_rate ** days_passed)
logger.debug(
f"Memory {memory.id[:8]}: base activation={base_activation:.3f}, "
f"after {days_passed} days of decay={current_activation:.3f}"
)
except (ValueError, TypeError) as e:
logger.warning(f"Failed to parse timestamp: {e}; using base activation")
current_activation = base_activation
else:
# No access record; use the base activation
current_activation = base_activation
# Long-unaccessed memory: apply time decay
decay_factor = 0.9 ** days_passed
current_activation = activation_info.get("level", 0.0) * decay_factor
# Below the threshold: mark for forgetting
if current_activation < threshold:
memories_to_forget.append((memory.id, current_activation))
logger.debug(
f"Marked for forgetting {memory.id[:8]}: activation={current_activation:.3f} < threshold={threshold:.3f}"
)
# Below the threshold: forget
if current_activation < threshold:
await self.forget_memory(memory.id)
# Batch-forget memories (without immediate orphan cleanup)
if memories_to_forget:
logger.info(f"Starting batch forgetting of {len(memories_to_forget)} memories...")
for memory_id, activation in memories_to_forget:
# cleanup_orphans=False: defer orphan cleanup
success = await self.forget_memory(memory_id, cleanup_orphans=False)
if success:
forgotten_count += 1
# Clean up orphan nodes and edges in one pass
logger.info("Batch forgetting done; starting consolidated cleanup of orphan nodes and edges...")
orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges()
# Save the final update
await self.persistence.save_graph_store(self.graph_store)
logger.info(
f"✅ Auto-forget complete: forgot {forgotten_count} memories, "
f"cleaned up {orphan_nodes} orphan nodes, {orphan_edges} orphan edges"
)
else:
logger.info("✅ Auto-forget complete: nothing to forget")
logger.info(f"Auto-forget complete: forgot {forgotten_count} memories")
return forgotten_count
except Exception as e:
logger.error(f"自动遗忘失败: {e}", exc_info=True)
return 0
async def _cleanup_orphan_nodes_and_edges(self) -> tuple[int, int]:
"""
清理孤立节点和边
孤立节点:不再属于任何记忆的节点
孤立边:连接到已删除节点的边
Returns:
(清理的孤立节点数, 清理的孤立边数)
"""
try:
orphan_nodes_count = 0
orphan_edges_count = 0
# 1. Clean up orphan nodes
# graph_store.node_to_memories records which memories each node belongs to
nodes_to_remove = []
for node_id, memory_ids in list(self.graph_store.node_to_memories.items()):
# If the node no longer belongs to any memory, mark it for removal
if not memory_ids:
nodes_to_remove.append(node_id)
# Remove orphan nodes from the graph
for node_id in nodes_to_remove:
if self.graph_store.graph.has_node(node_id):
self.graph_store.graph.remove_node(node_id)
orphan_nodes_count += 1
# Remove from the mapping
if node_id in self.graph_store.node_to_memories:
del self.graph_store.node_to_memories[node_id]
# 2. Clean up orphan edges (edges pointing at deleted nodes)
edges_to_remove = []
for source, target, edge_id in self.graph_store.graph.edges(data='edge_id'):
# Check whether the edge's source and target nodes still exist in node_to_memories
if source not in self.graph_store.node_to_memories or \
target not in self.graph_store.node_to_memories:
edges_to_remove.append((source, target))
# Remove orphan edges
for source, target in edges_to_remove:
try:
self.graph_store.graph.remove_edge(source, target)
orphan_edges_count += 1
except Exception as e:
logger.debug(f"删除边失败 {source} -> {target}: {e}")
if orphan_nodes_count > 0 or orphan_edges_count > 0:
logger.info(
f"清理完成: {orphan_nodes_count} 个孤立节点, {orphan_edges_count} 条孤立边"
)
return orphan_nodes_count, orphan_edges_count
except Exception as e:
logger.error(f"清理孤立节点和边失败: {e}", exc_info=True)
return 0, 0
# ==================== Statistics & Maintenance ====================
def get_statistics(self) -> dict[str, Any]:
@@ -1043,7 +1182,14 @@ class MemoryManager:
max_batch_size: int,
) -> None:
"""
Concrete implementation of background memory consolidation
Concrete implementation of background memory consolidation (full version)
Pipeline:
1. Collect memories within the time window
2. Filter by importance
3. Retrieve related memories via vector search
4. Hand batches to the LLM for relation analysis
5. Apply all memory-data updates in one pass
This method runs in an independent task and does not block the main flow
"""
@@ -1052,9 +1198,11 @@ class MemoryManager:
"merged_count": 0,
"checked_count": 0,
"skipped_count": 0,
"linked_count": 0,
"importance_filtered": 0,
}
# Collect recently created memories
# ===== Step 1: collect memories within the time window =====
cutoff_time = datetime.now() - timedelta(hours=time_window_hours)
all_memories = self.graph_store.get_all_memories()
@@ -1067,18 +1215,37 @@ class MemoryManager:
logger.info("✅ 记忆整理完成: 没有需要整理的记忆")
return
logger.info(f"📋 步骤1: 找到 {len(recent_memories)} 条时间窗口内的记忆")
# ===== 步骤2: 重要性过滤 =====
min_importance_for_consolidation = getattr(self.config, "consolidation_min_importance", 0.3)
important_memories = [
mem for mem in recent_memories
if mem.importance >= min_importance_for_consolidation
]
result["importance_filtered"] = len(recent_memories) - len(important_memories)
logger.info(
f"📊 Step 2: importance filter (threshold={min_importance_for_consolidation:.2f}): "
f"{len(recent_memories)} -> {len(important_memories)} memories"
)
if not important_memories:
logger.info("✅ Memory consolidation complete: no important memories to consolidate")
return
# Cap the batch size
if len(recent_memories) > max_batch_size:
logger.info(f"📊 Memory count {len(recent_memories)} exceeds the batch limit {max_batch_size}; processing only the newest {max_batch_size}")
recent_memories = sorted(recent_memories, key=lambda m: m.created_at, reverse=True)[:max_batch_size]
result["skipped_count"] = len(all_memories) - max_batch_size
if len(important_memories) > max_batch_size:
logger.info(f"📊 Memory count {len(important_memories)} exceeds the batch limit {max_batch_size}; processing only the newest {max_batch_size}")
# Record the skipped count before truncating; computing it afterwards always yields 0
result["skipped_count"] = len(important_memories) - max_batch_size
important_memories = sorted(important_memories, key=lambda m: m.created_at, reverse=True)[:max_batch_size]
logger.info(f"📋 Found {len(recent_memories)} memories to consolidate")
result["checked_count"] = len(recent_memories)
result["checked_count"] = len(important_memories)
# ===== Step 3: dedup (merge similar memories) =====
# Group by memory type to reduce cross-type comparisons
memories_by_type: dict[str, list[Memory]] = {}
for mem in recent_memories:
for mem in important_memories:
mem_type = mem.metadata.get("memory_type", "")
if mem_type not in memories_by_type:
memories_by_type[mem_type] = []
@@ -1088,7 +1255,8 @@ class MemoryManager:
to_delete: list[tuple[Memory, str]] = [] # (memory, reason)
deleted_ids = set()
# Run similarity detection within each memory type
# Run similarity detection within each memory type (dedup)
logger.info("📍 Step 3: starting similar-memory dedup...")
for mem_type, memories in memories_by_type.items():
if len(memories) < 2:
continue
@@ -1106,7 +1274,6 @@ class MemoryManager:
valid_memories.append(mem)
# Compute the similarity matrix in batch (more efficient than pairwise loops)
for i in range(len(valid_memories)):
# Yield to the event loop more often (cooperative multitasking)
if i % 5 == 0:
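
The similarity computation itself is elided from this hunk. As a rough sketch of the batched similarity-matrix idea (assuming each memory carries an embedding vector; numpy used purely for illustration):

```python
import numpy as np

def cosine_similarity_matrix(embeddings: list[list[float]]) -> np.ndarray:
    """All pairwise cosine similarities in one vectorized pass."""
    mat = np.asarray(embeddings, dtype=np.float64)
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    unit = mat / np.clip(norms, 1e-12, None)  # guard against zero vectors
    return unit @ unit.T

sims = cosine_similarity_matrix([[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]])
print(sims.round(2))  # diagonal is 1.0; sims[0, 1] == 0.6
```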
@@ -1158,7 +1325,7 @@ class MemoryManager:
# Batch-delete the marked memories
if to_delete:
logger.info(f"🗑️ Starting batch deletion of {len(to_delete)} similar memories")
logger.info(f"🗑️ Batch-deleting {len(to_delete)} similar memories")
for memory, reason in to_delete:
try:
@@ -1175,7 +1342,118 @@ class MemoryManager:
# Batch save (one write, less I/O)
await self.persistence.save_graph_store(self.graph_store)
logger.info("💾 Batch save complete")
logger.info("💾 Dedup save complete")
# ===== Step 4: vector search for related memories + LLM relation analysis =====
# Filter out deleted memories
remaining_memories = [m for m in important_memories if m.id not in deleted_ids]
if not remaining_memories:
logger.info("✅ 记忆整理完成: 去重后无剩余记忆")
return
logger.info(f"📍 步骤4: 开始关联分析 ({len(remaining_memories)} 条记忆)...")
# 分批处理记忆关联
llm_batch_size = getattr(self.config, "consolidation_llm_batch_size", 10)
max_candidates_per_memory = getattr(self.config, "consolidation_max_candidates", 5)
min_confidence = getattr(self.config, "consolidation_min_confidence", 0.6)
all_new_edges = []  # Collect all newly created edges
for batch_start in range(0, len(remaining_memories), llm_batch_size):
batch_end = min(batch_start + llm_batch_size, len(remaining_memories))
batch = remaining_memories[batch_start:batch_end]
logger.debug(f"处理批次 {batch_start//llm_batch_size + 1}/{(len(remaining_memories)-1)//llm_batch_size + 1}")
for memory in batch:
# 跳过已经有很多连接的记忆
existing_edges = len([
e for e in memory.edges
if e.edge_type == EdgeType.RELATION
])
if existing_edges >= 10:
continue
# Use vector search to find candidate related memories
candidates = await self._find_link_candidates(
memory,
exclude_ids={memory.id} | deleted_ids,
max_results=max_candidates_per_memory
)
if not candidates:
continue
# Use the LLM to analyze relations
relations = await self._analyze_memory_relations(
source_memory=memory,
candidate_memories=candidates,
min_confidence=min_confidence
)
# Create relation edges
for relation in relations:
try:
# Create the relation edge
edge = MemoryEdge(
id=f"edge_{uuid.uuid4().hex[:12]}",
source_id=memory.subject_id,
target_id=relation["target_memory"].subject_id,
relation=relation["relation_type"],
edge_type=EdgeType.RELATION,
importance=relation["confidence"],
metadata={
"auto_linked": True,
"confidence": relation["confidence"],
"reasoning": relation["reasoning"],
"created_at": datetime.now().isoformat(),
"created_by": "consolidation",
}
)
all_new_edges.append((memory, edge, relation))
result["linked_count"] += 1
except Exception as e:
logger.warning(f"创建关联边失败: {e}")
continue
# 每个批次后让出控制权
await asyncio.sleep(0.01)
# ===== Step 5: apply all memory-data updates in one pass =====
if all_new_edges:
logger.info(f"📍 Step 5: applying {len(all_new_edges)} new relation edges...")
for memory, edge, relation in all_new_edges:
try:
# Add to the graph
self.graph_store.graph.add_edge(
edge.source_id,
edge.target_id,
edge_id=edge.id,
relation=edge.relation,
edge_type=edge.edge_type.value,
importance=edge.importance,
metadata=edge.metadata,
)
# Also append to the memory's edge list
memory.edges.append(edge)
logger.debug(
f"{memory.id[:8]} --[{relation['relation_type']}]--> "
f"{relation['target_memory'].id[:8]} (confidence={relation['confidence']:.2f})"
)
except Exception as e:
logger.warning(f"Failed to add edge to graph: {e}")
# Batch-save the updates
await self.persistence.save_graph_store(self.graph_store)
logger.info("💾 Relation edges saved")
logger.info(f"✅ Memory consolidation complete: {result}")
@@ -1917,11 +2195,11 @@ class MemoryManager:
logger.error(f"LLM批量关系分析失败: {e}", exc_info=True)
return []
async def start_maintenance_scheduler(self) -> None:
def _start_maintenance_task(self) -> None:
"""
Start the memory-maintenance scheduler task
Start the background memory-maintenance task
Uses unified_scheduler to run maintenance tasks periodically
Creates an async task directly instead of using the scheduler, so the main program is not blocked
- Memory consolidation (merge similar memories)
- Auto-forget low-activation memories
- Save data
@@ -1929,57 +2207,96 @@ class MemoryManager:
Default interval: 1 hour
"""
try:
from src.schedule.unified_scheduler import TriggerType, unified_scheduler
# If a maintenance task already exists, cancel it first
if self._maintenance_task and not self._maintenance_task.done():
self._maintenance_task.cancel()
logger.info("Cancelled the old maintenance task")
# If a scheduled task already exists, remove it first
if self._maintenance_schedule_id:
await unified_scheduler.remove_schedule(self._maintenance_schedule_id)
logger.info("Removed the old maintenance schedule")
# Create a new scheduled task
interval_seconds = self._maintenance_interval_hours * 3600
self._maintenance_schedule_id = await unified_scheduler.create_schedule(
callback=self.maintenance,
trigger_type=TriggerType.TIME,
trigger_config={
"delay_seconds": interval_seconds, # 首次延迟启动后1小时
"interval_seconds": interval_seconds, # 循环间隔
},
is_recurring=True,
task_name="memory_maintenance",
# Create the new background maintenance task
self._maintenance_task = asyncio.create_task(
self._maintenance_loop(),
name="memory_maintenance_loop"
)
logger.info(
f"✅ Memory maintenance schedule started "
f"(interval={self._maintenance_interval_hours} hours, "
f"schedule_id={self._maintenance_schedule_id[:8]}...)"
f"✅ Background memory maintenance task started "
f"(interval={self._maintenance_interval_hours} hours)"
)
except ImportError:
logger.warning("Cannot import unified_scheduler; maintenance scheduling is unavailable")
except Exception as e:
logger.error(f"Failed to start the maintenance schedule: {e}", exc_info=True)
logger.error(f"Failed to start the background maintenance task: {e}", exc_info=True)
async def stop_maintenance_scheduler(self) -> None:
async def _stop_maintenance_task(self) -> None:
"""
Stop the memory-maintenance scheduler task
Stop the background memory-maintenance task
"""
if not self._maintenance_schedule_id:
if not self._maintenance_task or self._maintenance_task.done():
return
try:
from src.schedule.unified_scheduler import unified_scheduler
self._maintenance_running = False  # set the stop flag
self._maintenance_task.cancel()
success = await unified_scheduler.remove_schedule(self._maintenance_schedule_id)
if success:
logger.info(f"✅ 记忆维护调度任务已停止 (schedule_id={self._maintenance_schedule_id[:8]}...)")
else:
logger.warning(f"停止维护调度任务失败 (schedule_id={self._maintenance_schedule_id[:8]}...)")
try:
await self._maintenance_task
except asyncio.CancelledError:
logger.debug("维护任务已取消")
self._maintenance_schedule_id = None
logger.info("✅ 记忆维护后台任务已停止")
self._maintenance_task = None
except ImportError:
logger.warning("无法导入 unified_scheduler")
except Exception as e:
logger.error(f"停止维护调度任务失败: {e}", exc_info=True)
logger.error(f"停止维护后台任务失败: {e}", exc_info=True)
async def _maintenance_loop(self) -> None:
"""
Memory maintenance loop
Runs independently in the background and executes maintenance periodically without blocking the main program
"""
self._maintenance_running = True
try:
# First run is delayed (1 hour after startup)
initial_delay = self._maintenance_interval_hours * 3600
logger.debug(f"Memory maintenance will first run in {initial_delay} seconds")
while self._maintenance_running:
try:
# asyncio.sleep is cancellable on its own; wrapping it in
# asyncio.wait_for with timeout=float('inf') added nothing
await asyncio.sleep(initial_delay)
# Check whether we should still be running
if not self._maintenance_running:
break
# Run the maintenance task (try/except so one failure does not kill the loop)
try:
await self.maintenance()
except Exception as e:
logger.error(f"维护任务执行失败: {e}", exc_info=True)
# 后续执行使用相同间隔
initial_delay = self._maintenance_interval_hours * 3600
except asyncio.CancelledError:
logger.debug("维护循环被取消")
break
except Exception as e:
logger.error(f"维护循环发生异常: {e}", exc_info=True)
# 异常后等待较短时间再重试
try:
await asyncio.sleep(300) # 5分钟后重试
except asyncio.CancelledError:
break
except asyncio.CancelledError:
logger.debug("维护循环完全退出")
except Exception as e:
logger.error(f"维护循环意外结束: {e}", exc_info=True)
finally:
self._maintenance_running = False
logger.debug("维护循环已清理完毕")