feat(long_term_manager): optimize long-term memory manager performance

LuiKlee
2025-12-13 16:17:30 +08:00
parent 30648565a5
commit 4fe8e29ba5
2 changed files with 534 additions and 86 deletions


@@ -57,6 +57,15 @@ class LongTermMemoryManager:
         # State
         self._initialized = False
+        # Queue for batched embedding generation
+        self._pending_embeddings: list[tuple[str, str]] = []  # (node_id, content)
+        self._embedding_batch_size = 10
+        self._embedding_lock = asyncio.Lock()
+        # Similar-memory cache (stm_id -> memories)
+        self._similar_memory_cache: dict[str, list[Memory]] = {}
+        self._cache_max_size = 100
         logger.info(
             f"Long-term memory manager created (batch_size={batch_size}, "
             f"search_top_k={search_top_k}, decay_factor={long_term_decay_factor:.2f})"
         )
@@ -150,7 +159,7 @@ class LongTermMemoryManager:
     async def _process_batch(self, batch: list[ShortTermMemory]) -> dict[str, Any]:
         """
-        Process a batch of short-term memories
+        Process a batch of short-term memories (in parallel)
         Args:
             batch: the batch of short-term memories
@@ -167,57 +176,89 @@ class LongTermMemoryManager:
"transferred_memory_ids": [],
}
for stm in batch:
try:
# 步骤1: 在长期记忆中检索相似记忆
similar_memories = await self._search_similar_long_term_memories(stm)
# 并行处理批次中的所有记忆
tasks = [self._process_single_memory(stm) for stm in batch]
results = await asyncio.gather(*tasks, return_exceptions=True)
# 步骤2: LLM 决策如何更新图结构
operations = await self._decide_graph_operations(stm, similar_memories)
# 汇总结果
for stm, single_result in zip(batch, results):
if isinstance(single_result, Exception):
logger.error(f"处理短期记忆 {stm.id} 失败: {single_result}")
result["failed_count"] += 1
elif single_result and isinstance(single_result, dict):
result["processed_count"] += 1
result["transferred_memory_ids"].append(stm.id)
# 步骤3: 执行图操作
success = await self._execute_graph_operations(operations, stm)
if success:
result["processed_count"] += 1
result["transferred_memory_ids"].append(stm.id)
# 统计操作类型
for op in operations:
if op.operation_type == GraphOperationType.CREATE_MEMORY:
# 统计操作类型
operations = single_result.get("operations", [])
if isinstance(operations, list):
for op_type in operations:
if op_type == GraphOperationType.CREATE_MEMORY:
result["created_count"] += 1
elif op.operation_type == GraphOperationType.UPDATE_MEMORY:
elif op_type == GraphOperationType.UPDATE_MEMORY:
result["updated_count"] += 1
elif op.operation_type == GraphOperationType.MERGE_MEMORIES:
elif op_type == GraphOperationType.MERGE_MEMORIES:
result["merged_count"] += 1
else:
result["failed_count"] += 1
except Exception as e:
logger.error(f"处理短期记忆 {stm.id} 失败: {e}")
else:
result["failed_count"] += 1
# 处理完批次后批量生成embeddings
await self._flush_pending_embeddings()
return result
async def _process_single_memory(self, stm: ShortTermMemory) -> dict[str, Any] | None:
"""
处理单条短期记忆
Args:
stm: 短期记忆
Returns:
处理结果或None如果失败
"""
try:
# 步骤1: 在长期记忆中检索相似记忆
similar_memories = await self._search_similar_long_term_memories(stm)
# 步骤2: LLM 决策如何更新图结构
operations = await self._decide_graph_operations(stm, similar_memories)
# 步骤3: 执行图操作
success = await self._execute_graph_operations(operations, stm)
if success:
return {
"success": True,
"operations": [op.operation_type for op in operations]
}
return None
except Exception as e:
logger.error(f"处理短期记忆 {stm.id} 失败: {e}")
return None
     async def _search_similar_long_term_memories(
         self, stm: ShortTermMemory
     ) -> list[Memory]:
         """
         Search long-term memory for memories similar to the given short-term memory
-        Optimization: retrieve not only content-similar memories, but also contextually related ones via the graph structure
+        Optimization: use a cache and reduce duplicate queries
         """
+        # Check the cache
+        if stm.id in self._similar_memory_cache:
+            logger.debug(f"Using cached similar memories: {stm.id}")
+            return self._similar_memory_cache[stm.id]
         try:
             from src.config.config import global_config
             # Check whether the advanced path-expansion algorithm is enabled
             use_path_expansion = getattr(global_config.memory, "enable_path_expansion", False)
-            # 1. Retrieve memories
             # If path expansion is enabled, search_memories automatically uses PathScoreExpansion internally;
             # we only need to pass an appropriate expand_depth
             expand_depth = getattr(global_config.memory, "path_expansion_max_hops", 2) if use_path_expansion else 0
+            # 1. Retrieve memories
             memories = await self.memory_manager.search_memories(
                 query=stm.content,
                 top_k=self.search_top_k,
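The batch rewrite above leans on asyncio.gather(*tasks, return_exceptions=True): the returned list is position-aligned with batch, and failures come back as exception objects instead of propagating, which is what makes the zip(batch, results) aggregation safe. A minimal, self-contained sketch of the pattern (names hypothetical, not from this codebase):

import asyncio

async def work(n: int) -> int:
    # Hypothetical task that fails on odd input, to show exception capture
    if n % 2:
        raise ValueError(f"odd input: {n}")
    return n * 10

async def main() -> None:
    items = [1, 2, 3, 4]
    results = await asyncio.gather(*(work(n) for n in items), return_exceptions=True)
    for item, res in zip(items, results):
        if isinstance(res, Exception):
            print(f"{item} failed: {res}")  # the exception is returned as a value
        else:
            print(f"{item} -> {res}")

asyncio.run(main())

One side effect of the fan-out: all memories in a batch now run their LLM decisions and graph writes concurrently, so operations within a batch are no longer strictly ordered.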
@@ -226,53 +267,91 @@ class LongTermMemoryManager:
                 expand_depth=expand_depth
             )
-            # 2. Graph structure expansion (Graph Expansion)
-            # If the advanced path-expansion algorithm was already used, no simple manual expansion is needed
+            # 2. If advanced path expansion is enabled, return directly
             if use_path_expansion:
                 logger.debug(f"Retrieved {len(memories)} memories via path expansion")
+                self._cache_similar_memories(stm.id, memories)
                 return memories
-            # If the advanced algorithm is disabled, fall back to a simple 1-hop neighbor expansion
-            expanded_memories = []
-            seen_ids = {m.id for m in memories}
+            # 3. Simplified graph expansion (only when the advanced algorithm is disabled)
+            if memories:
+                # Fetch related memory IDs in bulk to reduce one-off queries
+                related_ids_batch = await self._batch_get_related_memories(
+                    [m.id for m in memories], max_depth=1, max_per_memory=2
+                )
-            for mem in memories:
-                expanded_memories.append(mem)
+                # Load the related memories
+                seen_ids = {m.id for m in memories}
+                new_memories = []
+                for rid in related_ids_batch:
+                    if rid not in seen_ids and len(new_memories) < self.search_top_k:
+                        related_mem = await self.memory_manager.get_memory(rid)
+                        if related_mem:
+                            new_memories.append(related_mem)
+                            seen_ids.add(rid)
-                # Get the memory's directly related memories (1-hop neighbors)
-                try:
-                    # Use MemoryManager's low-level graph traversal
-                    related_ids = self.memory_manager._get_related_memories(mem.id, max_depth=1)
+                memories.extend(new_memories)
-                    # Limit the neighbors expanded per memory to avoid context blow-up
-                    max_neighbors = 2
-                    neighbor_count = 0
+            logger.debug(f"Found {len(memories)} long-term memories for short-term memory {stm.id}")
-                    for rid in related_ids:
-                        if rid not in seen_ids:
-                            related_mem = await self.memory_manager.get_memory(rid)
-                            if related_mem:
-                                expanded_memories.append(related_mem)
-                                seen_ids.add(rid)
-                                neighbor_count += 1
-                                if neighbor_count >= max_neighbors:
-                                    break
-                except Exception as e:
-                    logger.warning(f"Failed to fetch related memories: {e}")
-                # Cap the total count
-                if len(expanded_memories) >= self.search_top_k * 2:
-                    break
-            logger.debug(f"Found {len(expanded_memories)} long-term memories for short-term memory {stm.id} (with simple graph expansion)")
-            return expanded_memories
+            # Cache the result
+            self._cache_similar_memories(stm.id, memories)
+            return memories
         except Exception as e:
             logger.error(f"Failed to search for similar long-term memories: {e}")
             return []
+    async def _batch_get_related_memories(
+        self, memory_ids: list[str], max_depth: int = 1, max_per_memory: int = 2
+    ) -> set[str]:
+        """
+        Fetch related memory IDs in bulk
+        Args:
+            memory_ids: list of memory IDs
+            max_depth: maximum traversal depth
+            max_per_memory: maximum number of related memories per memory
+        Returns:
+            Set of related memory IDs
+        """
+        all_related_ids = set()
+        try:
+            for mem_id in memory_ids:
+                if len(all_related_ids) >= max_per_memory * len(memory_ids):
+                    break
+                try:
+                    related_ids = self.memory_manager._get_related_memories(mem_id, max_depth=max_depth)
+                    # Cap the number of related memories taken per memory
+                    for rid in list(related_ids)[:max_per_memory]:
+                        all_related_ids.add(rid)
+                except Exception as e:
+                    logger.warning(f"Failed to fetch related memories for {mem_id}: {e}")
+        except Exception as e:
+            logger.error(f"Bulk fetch of related memories failed: {e}")
+        return all_related_ids
+    def _cache_similar_memories(self, stm_id: str, memories: list[Memory]) -> None:
+        """
+        Cache similar memories
+        Args:
+            stm_id: short-term memory ID
+            memories: list of similar memories
+        """
+        # Simple FIFO-style eviction (not true LRU): once the cache is full, drop the earliest entry
+        if len(self._similar_memory_cache) >= self._cache_max_size:
+            # Remove the first (oldest) key
+            first_key = next(iter(self._similar_memory_cache))
+            del self._similar_memory_cache[first_key]
+        self._similar_memory_cache[stm_id] = memories
     async def _decide_graph_operations(
         self, stm: ShortTermMemory, similar_memories: list[Memory]
     ) -> list[GraphOperation]:
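Because plain dict lookups do not move keys, the eviction in _cache_similar_memories is FIFO on insertion order rather than true LRU: a frequently hit entry can still be the first one dropped. If recency-aware eviction ever matters here, a minimal sketch with collections.OrderedDict (a hypothetical alternative, not part of this commit):

from collections import OrderedDict

class LRUCache:
    """Tiny LRU cache: hits refresh recency; eviction drops the least recently used."""

    def __init__(self, max_size: int = 100) -> None:
        self._data: OrderedDict[str, object] = OrderedDict()
        self._max_size = max_size

    def get(self, key: str):
        if key not in self._data:
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key: str, value: object) -> None:
        if key in self._data:
            self._data.move_to_end(key)
        self._data[key] = value
        if len(self._data) > self._max_size:
            self._data.popitem(last=False)  # evict the least recently used entry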
@@ -587,17 +666,24 @@ class LongTermMemoryManager:
         return temp_id_map.get(raw_id, raw_id)
     def _resolve_value(self, value: Any, temp_id_map: dict[str, str]) -> Any:
-        if isinstance(value, str):
-            return self._resolve_id(value, temp_id_map)
-        if isinstance(value, list):
-            return [self._resolve_value(v, temp_id_map) for v in value]
-        if isinstance(value, dict):
-            return {k: self._resolve_value(v, temp_id_map) for k, v in value.items()}
+        """Optimized value resolution: less recursion, fewer type checks"""
+        value_type = type(value)
+        if value_type is str:
+            return temp_id_map.get(value, value)
+        elif value_type is list:
+            return [temp_id_map.get(v, v) if isinstance(v, str) else v for v in value]
+        elif value_type is dict:
+            return {k: temp_id_map.get(v, v) if isinstance(v, str) else v
+                    for k, v in value.items()}
         return value
     def _resolve_parameters(
         self, params: dict[str, Any], temp_id_map: dict[str, str]
     ) -> dict[str, Any]:
+        """Optimized parameter resolution"""
+        if not temp_id_map:
+            return params
         return {k: self._resolve_value(v, temp_id_map) for k, v in params.items()}
     def _register_aliases_from_params(
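One behavioral difference worth noting: the old _resolve_value recursed into arbitrarily nested containers, while the flattened version resolves only top-level strings and strings one level deep inside a list or dict. A quick illustration with toy values (hypothetical, for comparison only):

temp_id_map = {"tmp_1": "mem_42"}

def resolve_deep(value, id_map):
    # Old behavior: full recursion into nested containers
    if isinstance(value, str):
        return id_map.get(value, value)
    if isinstance(value, list):
        return [resolve_deep(v, id_map) for v in value]
    if isinstance(value, dict):
        return {k: resolve_deep(v, id_map) for k, v in value.items()}
    return value

def resolve_shallow(value, id_map):
    # New behavior: only strings one level inside a list/dict are mapped
    if type(value) is str:
        return id_map.get(value, value)
    elif type(value) is list:
        return [id_map.get(v, v) if isinstance(v, str) else v for v in value]
    elif type(value) is dict:
        return {k: id_map.get(v, v) if isinstance(v, str) else v for k, v in value.items()}
    return value

nested = {"refs": [["tmp_1"]]}
print(resolve_deep(nested, temp_id_map))     # {'refs': [['mem_42']]}
print(resolve_shallow(nested, temp_id_map))  # {'refs': [['tmp_1']]} - nested list untouched

This is fine as long as operation parameters only nest temp IDs one level deep, which the early-exit in _resolve_parameters suggests is the expected shape.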
@@ -730,8 +816,10 @@ class LongTermMemoryManager:
                 importance=merged_importance,
             )
-            # 3. Save asynchronously
-            asyncio.create_task(self.memory_manager._async_save_graph_store("merge memories"))
+            # 3. Save asynchronously (background task; no need to await it)
+            asyncio.create_task(  # noqa: RUF006
+                self.memory_manager._async_save_graph_store("merge memories")
+            )
             logger.info(f"Memory merge complete: {source_ids} -> {target_id}")
         else:
             logger.error(f"Memory merge failed: {source_ids}")
@@ -761,8 +849,8 @@ class LongTermMemoryManager:
             )
             if success:
-                # Try to generate an embedding for the new node (async)
-                asyncio.create_task(self._generate_node_embedding(node_id, content))
+                # Queue the embedding generation for batched processing
+                await self._queue_embedding_generation(node_id, content)
                 logger.info(f"Created node: {content} ({node_type}) -> {memory_id}")
                 # Force-register target_id whether or not it matches the placeholder format
                 self._register_temp_id(op.target_id, node_id, temp_id_map, force=True)
@@ -820,7 +908,7 @@ class LongTermMemoryManager:
             # Merge the remaining nodes into the target node
             for source_id in sources:
                 self.memory_manager.graph_store.merge_nodes(source_id, target_id)
             logger.info(f"Merged nodes: {sources} -> {target_id}")
     async def _execute_create_edge(
@@ -901,20 +989,83 @@ class LongTermMemoryManager:
         else:
             logger.error(f"Failed to delete edge: {edge_id}")
-    async def _generate_node_embedding(self, node_id: str, content: str) -> None:
-        """Generate an embedding for a new node and store it in the vector store"""
+    async def _queue_embedding_generation(self, node_id: str, content: str) -> None:
+        """Add a node to the embedding-generation queue"""
+        async with self._embedding_lock:
+            self._pending_embeddings.append((node_id, content))
+        # If the queue has reached the batch size, process it right away
+        if len(self._pending_embeddings) >= self._embedding_batch_size:
+            await self._flush_pending_embeddings()
+    async def _flush_pending_embeddings(self) -> None:
+        """Process the queued embeddings in one batch"""
+        async with self._embedding_lock:
+            if not self._pending_embeddings:
+                return
+            batch = self._pending_embeddings[:]
+            self._pending_embeddings.clear()
+        if not self.memory_manager.vector_store or not self.memory_manager.embedding_generator:
+            return
+        try:
+            # Generate the embeddings in a single batch
+            contents = [content for _, content in batch]
+            embeddings = await self.memory_manager.embedding_generator.generate_batch(contents)
+            if not embeddings or len(embeddings) != len(batch):
+                logger.warning("Batch embedding generation failed or returned a mismatched count")
+                # Fall back to generating one at a time
+                for node_id, content in batch:
+                    await self._generate_node_embedding_single(node_id, content)
+                return
+            # Add the nodes to the vector store in bulk
+            from src.memory_graph.models import MemoryNode, NodeType
+            nodes = [
+                MemoryNode(
+                    id=node_id,
+                    content=content,
+                    node_type=NodeType.OBJECT,
+                    embedding=embedding
+                )
+                for (node_id, content), embedding in zip(batch, embeddings)
+                if embedding is not None
+            ]
+            if nodes:
+                # Bulk-add the nodes
+                await self.memory_manager.vector_store.add_nodes_batch(nodes)
+                # Bulk-update the graph store
+                for node in nodes:
+                    node.mark_vector_stored()
+                    if self.memory_manager.graph_store.graph.has_node(node.id):
+                        self.memory_manager.graph_store.graph.nodes[node.id]["has_vector"] = True
+                logger.debug(f"Generated embeddings for {len(nodes)} nodes in one batch")
+        except Exception as e:
+            logger.error(f"Batch embedding generation failed: {e}")
+            # Fall back to generating one at a time
+            for node_id, content in batch:
+                await self._generate_node_embedding_single(node_id, content)
+    async def _generate_node_embedding_single(self, node_id: str, content: str) -> None:
+        """Generate an embedding for a single node and store it in the vector store (fallback path)"""
         try:
             if not self.memory_manager.vector_store or not self.memory_manager.embedding_generator:
                 return
             embedding = await self.memory_manager.embedding_generator.generate(content)
             if embedding is not None:
                 # Construct a MemoryNode so we can call add_node
                 from src.memory_graph.models import MemoryNode, NodeType
                 node = MemoryNode(
                     id=node_id,
                     content=content,
-                    node_type=NodeType.OBJECT,  # default
+                    node_type=NodeType.OBJECT,
                     embedding=embedding
                 )
                 await self.memory_manager.vector_store.add_node(node)
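A detail worth keeping in mind with this queue-and-flush design: asyncio.Lock is not reentrant, so _flush_pending_embeddings (which takes the lock) must not be awaited while the same lock is already held, and scoping the lock to the queue mutation keeps slow embedding calls from blocking producers. A stripped-down sketch of the same pattern (hypothetical names, assumptions mine):

import asyncio

class BatchQueue:
    def __init__(self, batch_size: int = 10) -> None:
        self._pending: list[str] = []
        self._batch_size = batch_size
        self._lock = asyncio.Lock()

    async def add(self, item: str) -> None:
        async with self._lock:
            self._pending.append(item)
        # Flush outside the lock: flush() re-acquires it, and an
        # asyncio.Lock deadlocks if the same task re-enters it.
        if len(self._pending) >= self._batch_size:
            await self.flush()

    async def flush(self) -> None:
        async with self._lock:
            if not self._pending:
                return
            batch, self._pending = self._pending[:], []
        # Slow work happens after the lock is released
        await asyncio.sleep(0)  # stand-in for the batched embedding call
        print(f"flushed {len(batch)} items")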
@@ -926,7 +1077,7 @@ class LongTermMemoryManager:
     async def apply_long_term_decay(self) -> dict[str, Any]:
         """
-        Apply activation decay to long-term memories
+        Apply activation decay to long-term memories (optimized)

         Long-term memories decay more slowly than short-term ones, using a higher decay factor.
@@ -941,6 +1092,12 @@ class LongTermMemoryManager:
             all_memories = self.memory_manager.graph_store.get_all_memories()
             decayed_count = 0
+            now = datetime.now()
+            # Precompute powers of the decay factor (cache the common values)
+            decay_cache = {i: self.long_term_decay_factor ** i for i in range(1, 31)}  # cache days 1-30
+            memories_to_update = []
             for memory in all_memories:
                 # Skip memories that have already been forgotten
@@ -954,27 +1111,34 @@ class LongTermMemoryManager:
                 if last_access:
                     try:
                         last_access_dt = datetime.fromisoformat(last_access)
-                        days_passed = (datetime.now() - last_access_dt).days
+                        days_passed = (now - last_access_dt).days
                         if days_passed > 0:
-                            # Use the long-term decay factor
+                            # Use the cached decay factor, or compute a fresh value
+                            decay_factor = decay_cache.get(
+                                days_passed,
+                                self.long_term_decay_factor ** days_passed
+                            )
                             base_activation = activation_info.get("level", memory.activation)
-                            new_activation = base_activation * (self.long_term_decay_factor ** days_passed)
+                            new_activation = base_activation * decay_factor
                             # Update the activation level
                             memory.activation = new_activation
                             activation_info["level"] = new_activation
                             memory.metadata["activation"] = activation_info
+                            memories_to_update.append(memory)
                             decayed_count += 1
                     except (ValueError, TypeError) as e:
                         logger.warning(f"Failed to parse timestamp: {e}")
-            # Save the updates
-            await self.memory_manager.persistence.save_graph_store(
-                self.memory_manager.graph_store
-            )
+            # Save the updates in one pass (only if anything changed)
+            if memories_to_update:
+                await self.memory_manager.persistence.save_graph_store(
+                    self.memory_manager.graph_store
+                )
             logger.info(f"Long-term memory decay complete: {decayed_count} memories updated")
             return {"decayed_count": decayed_count, "total_memories": len(all_memories)}
@@ -1002,6 +1166,12 @@ class LongTermMemoryManager:
         try:
             logger.info("Shutting down long-term memory manager...")
+            # Flush any pending embeddings in the queue
+            await self._flush_pending_embeddings()
+            # Clear the cache
+            self._similar_memory_cache.clear()
             # Saving long-term memories is handled by MemoryManager
             self._initialized = False