refactor: 优化记忆激活机制,增加快速批量激活功能以提升性能
This commit is contained in:
@@ -438,9 +438,9 @@ class MemoryManager:
|
||||
f"搜索完成: 找到 {len(filtered_memories)} 条记忆 (策略={strategy})"
|
||||
)
|
||||
|
||||
# 强制激活被检索到的记忆(核心功能)
|
||||
# 强制激活被检索到的记忆(核心功能)- 使用快速批量激活
|
||||
if filtered_memories:
|
||||
await self._auto_activate_searched_memories(filtered_memories)
|
||||
await self._quick_batch_activate_memories(filtered_memories)
|
||||
|
||||
return filtered_memories[:top_k]
|
||||
|
||||
@@ -571,46 +571,235 @@ class MemoryManager:
|
||||
|
||||
async def _auto_activate_searched_memories(self, memories: list[Memory]) -> None:
|
||||
"""
|
||||
自动激活被搜索到的记忆
|
||||
批量激活被搜索到的记忆
|
||||
|
||||
Args:
|
||||
memories: 被检索到的记忆列表
|
||||
"""
|
||||
try:
|
||||
if not memories:
|
||||
return
|
||||
|
||||
# 获取配置参数
|
||||
base_strength = getattr(self.config, "auto_activate_base_strength", 0.1)
|
||||
max_activate_count = getattr(self.config, "auto_activate_max_count", 5)
|
||||
decay_rate = getattr(self.config, "activation_decay_rate", 0.9)
|
||||
now = datetime.now()
|
||||
|
||||
# 激活强度根据记忆重要性调整
|
||||
activate_tasks = []
|
||||
for i, memory in enumerate(memories[:max_activate_count]):
|
||||
# 重要性越高,激活强度越大
|
||||
# 限制处理的记忆数量
|
||||
memories_to_activate = memories[:max_activate_count]
|
||||
|
||||
# 批量更新激活度
|
||||
activation_updates = []
|
||||
for memory in memories_to_activate:
|
||||
# 计算激活强度
|
||||
strength = base_strength * (0.5 + memory.importance)
|
||||
|
||||
# 创建异步激活任务
|
||||
task = self.activate_memory(memory.id, strength=strength)
|
||||
activate_tasks.append(task)
|
||||
# 获取当前激活度信息
|
||||
activation_info = memory.metadata.get("activation", {})
|
||||
last_access = activation_info.get("last_access")
|
||||
|
||||
if i >= max_activate_count - 1:
|
||||
break
|
||||
if last_access:
|
||||
# 计算时间衰减
|
||||
last_access_dt = datetime.fromisoformat(last_access)
|
||||
hours_passed = (now - last_access_dt).total_seconds() / 3600
|
||||
decay_factor = decay_rate ** (hours_passed / 24)
|
||||
current_activation = activation_info.get("level", 0.0) * decay_factor
|
||||
else:
|
||||
current_activation = 0.0
|
||||
|
||||
# 并发执行激活任务(但不等待所有完成,避免阻塞搜索)
|
||||
if activate_tasks:
|
||||
import asyncio
|
||||
# 使用 asyncio.gather 但设置较短的 timeout
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*activate_tasks, return_exceptions=True),
|
||||
timeout=2.0 # 2秒超时,避免阻塞主流程
|
||||
)
|
||||
logger.debug(f"自动激活 {len(activate_tasks)} 条记忆完成")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"自动激活记忆超时,已激活部分记忆")
|
||||
except Exception as e:
|
||||
logger.warning(f"自动激活记忆失败: {e}")
|
||||
# 计算新的激活度
|
||||
new_activation = min(1.0, current_activation + strength)
|
||||
|
||||
# 更新记忆对象
|
||||
memory.activation = new_activation
|
||||
memory.last_accessed = now
|
||||
activation_info.update({
|
||||
"level": new_activation,
|
||||
"last_access": now.isoformat(),
|
||||
"access_count": activation_info.get("access_count", 0) + 1,
|
||||
})
|
||||
memory.metadata["activation"] = activation_info
|
||||
|
||||
activation_updates.append({
|
||||
"memory_id": memory.id,
|
||||
"old_activation": current_activation,
|
||||
"new_activation": new_activation,
|
||||
"strength": strength
|
||||
})
|
||||
|
||||
# 批量保存到数据库
|
||||
if activation_updates:
|
||||
await self.persistence.save_graph_store(self.graph_store)
|
||||
|
||||
# 激活传播(异步执行,不阻塞主流程)
|
||||
asyncio.create_task(self._batch_propagate_activation(memories_to_activate, base_strength))
|
||||
|
||||
logger.debug(f"批量激活 {len(activation_updates)} 条记忆完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"自动激活搜索记忆失败: {e}")
|
||||
logger.warning(f"批量激活搜索记忆失败: {e}")
|
||||
|
||||
async def _quick_batch_activate_memories(self, memories: list[Memory]) -> None:
|
||||
"""
|
||||
快速批量激活记忆(用于搜索结果,优化性能)
|
||||
|
||||
与 _auto_activate_searched_memories 的区别:
|
||||
- 更轻量级,专注于速度
|
||||
- 简化激活传播逻辑
|
||||
- 减少数据库写入次数
|
||||
|
||||
Args:
|
||||
memories: 需要激活的记忆列表
|
||||
"""
|
||||
try:
|
||||
if not memories:
|
||||
return
|
||||
|
||||
# 获取配置参数
|
||||
base_strength = getattr(self.config, "auto_activate_base_strength", 0.1)
|
||||
max_activate_count = getattr(self.config, "auto_activate_max_count", 5)
|
||||
decay_rate = getattr(self.config, "activation_decay_rate", 0.9)
|
||||
now = datetime.now()
|
||||
|
||||
# 限制处理的记忆数量
|
||||
memories_to_activate = memories[:max_activate_count]
|
||||
|
||||
# 批量更新激活度(内存操作)
|
||||
for memory in memories_to_activate:
|
||||
# 计算激活强度
|
||||
strength = base_strength * (0.5 + memory.importance)
|
||||
|
||||
# 快速计算新的激活度(简化版)
|
||||
activation_info = memory.metadata.get("activation", {})
|
||||
last_access = activation_info.get("last_access")
|
||||
|
||||
if last_access:
|
||||
# 简化的时间衰减计算
|
||||
try:
|
||||
last_access_dt = datetime.fromisoformat(last_access)
|
||||
hours_passed = (now - last_access_dt).total_seconds() / 3600
|
||||
decay_factor = decay_rate ** (hours_passed / 24)
|
||||
current_activation = activation_info.get("level", 0.0) * decay_factor
|
||||
except (ValueError, TypeError):
|
||||
current_activation = activation_info.get("level", 0.0) * 0.9 # 默认衰减
|
||||
else:
|
||||
current_activation = 0.0
|
||||
|
||||
# 计算新的激活度
|
||||
new_activation = min(1.0, current_activation + strength)
|
||||
|
||||
# 直接更新记忆对象(内存中)
|
||||
memory.activation = new_activation
|
||||
memory.last_accessed = now
|
||||
activation_info.update({
|
||||
"level": new_activation,
|
||||
"last_access": now.isoformat(),
|
||||
"access_count": activation_info.get("access_count", 0) + 1,
|
||||
})
|
||||
memory.metadata["activation"] = activation_info
|
||||
|
||||
# 异步批量保存(不阻塞搜索)
|
||||
if memories_to_activate:
|
||||
asyncio.create_task(self._background_save_activation(memories_to_activate, base_strength))
|
||||
|
||||
logger.debug(f"快速批量激活 {len(memories_to_activate)} 条记忆")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"快速批量激活记忆失败: {e}")
|
||||
|
||||
async def _background_save_activation(self, memories: list[Memory], base_strength: float) -> None:
|
||||
"""
|
||||
后台保存激活更新并执行传播
|
||||
|
||||
Args:
|
||||
memories: 已更新的记忆列表
|
||||
base_strength: 基础激活强度
|
||||
"""
|
||||
try:
|
||||
# 批量保存到数据库
|
||||
await self.persistence.save_graph_store(self.graph_store)
|
||||
|
||||
# 简化的激活传播(仅在强度足够时执行)
|
||||
if base_strength > 0.08: # 提高传播阈值,减少传播频率
|
||||
propagation_strength_factor = getattr(self.config, "activation_propagation_strength", 0.3) # 降低传播强度
|
||||
max_related = getattr(self.config, "max_related_memories", 3) # 减少传播数量
|
||||
|
||||
# 只传播最重要的记忆的激活
|
||||
important_memories = [m for m in memories if m.importance > 0.6][:2] # 最多2个重要记忆
|
||||
|
||||
for memory in important_memories:
|
||||
related_memories = self._get_related_memories(memory.id, max_depth=1) # 减少传播深度
|
||||
propagation_strength = base_strength * propagation_strength_factor
|
||||
|
||||
for related_id in related_memories[:max_related]:
|
||||
try:
|
||||
related_memory = self.graph_store.get_memory_by_id(related_id)
|
||||
if related_memory:
|
||||
# 简单的激活度增加(不调用完整激活方法)
|
||||
current_activation = related_memory.metadata.get("activation", {}).get("level", related_memory.activation)
|
||||
new_activation = min(1.0, current_activation + propagation_strength * 0.5)
|
||||
|
||||
related_memory.activation = new_activation
|
||||
related_memory.metadata["activation"] = {
|
||||
"level": new_activation,
|
||||
"last_access": datetime.now().isoformat(),
|
||||
"access_count": related_memory.metadata.get("activation", {}).get("access_count", 0) + 1,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug(f"传播激活到相关记忆 {related_id[:8]} 失败: {e}")
|
||||
|
||||
# 再次保存传播后的更新
|
||||
await self.persistence.save_graph_store(self.graph_store)
|
||||
|
||||
logger.debug(f"后台保存激活更新完成,处理了 {len(memories)} 条记忆")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"后台保存激活更新失败: {e}")
|
||||
|
||||
async def _batch_propagate_activation(self, memories: list[Memory], base_strength: float) -> None:
|
||||
"""
|
||||
批量传播激活到相关记忆(后台执行)
|
||||
|
||||
Args:
|
||||
memories: 已激活的记忆列表
|
||||
base_strength: 基础激活强度
|
||||
"""
|
||||
try:
|
||||
propagation_strength_factor = getattr(self.config, "activation_propagation_strength", 0.5)
|
||||
propagation_depth = getattr(self.config, "activation_propagation_depth", 2)
|
||||
max_related = getattr(self.config, "max_related_memories", 5)
|
||||
|
||||
# 收集所有需要传播激活的记忆ID
|
||||
propagation_tasks = []
|
||||
for memory in memories:
|
||||
if base_strength > 0.05: # 只有足够强的激活才传播
|
||||
related_memories = self._get_related_memories(
|
||||
memory.id,
|
||||
max_depth=propagation_depth
|
||||
)
|
||||
propagation_strength = base_strength * propagation_strength_factor
|
||||
|
||||
for related_id in related_memories[:max_related]:
|
||||
task = self.activate_memory(related_id, propagation_strength)
|
||||
propagation_tasks.append(task)
|
||||
|
||||
# 批量执行传播任务
|
||||
if propagation_tasks:
|
||||
import asyncio
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*propagation_tasks, return_exceptions=True),
|
||||
timeout=3.0 # 传播操作超时时间稍长
|
||||
)
|
||||
logger.debug(f"激活传播完成: {len(propagation_tasks)} 个相关记忆")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("激活传播超时,部分相关记忆未激活")
|
||||
except Exception as e:
|
||||
logger.warning(f"激活传播失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"批量传播激活失败: {e}")
|
||||
|
||||
def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> list[str]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user