From 816bfdb8e0fd2db2a7c4c42388dd4de28dbc8c9d Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Fri, 7 Nov 2025 22:49:41 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E8=AE=B0?= =?UTF-8?q?=E5=BF=86=E6=BF=80=E6=B4=BB=E6=9C=BA=E5=88=B6=EF=BC=8C=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E5=BF=AB=E9=80=9F=E6=89=B9=E9=87=8F=E6=BF=80=E6=B4=BB?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E4=BB=A5=E6=8F=90=E5=8D=87=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memory_graph/manager.py | 243 ++++++++++++++++++++++++++++++++---- 1 file changed, 216 insertions(+), 27 deletions(-) diff --git a/src/memory_graph/manager.py b/src/memory_graph/manager.py index d2ba2951a..ef7305efa 100644 --- a/src/memory_graph/manager.py +++ b/src/memory_graph/manager.py @@ -438,9 +438,9 @@ class MemoryManager: f"搜索完成: 找到 {len(filtered_memories)} 条记忆 (策略={strategy})" ) - # 强制激活被检索到的记忆(核心功能) + # 强制激活被检索到的记忆(核心功能)- 使用快速批量激活 if filtered_memories: - await self._auto_activate_searched_memories(filtered_memories) + await self._quick_batch_activate_memories(filtered_memories) return filtered_memories[:top_k] @@ -571,46 +571,235 @@ class MemoryManager: async def _auto_activate_searched_memories(self, memories: list[Memory]) -> None: """ - 自动激活被搜索到的记忆 + 批量激活被搜索到的记忆 Args: memories: 被检索到的记忆列表 """ try: + if not memories: + return + # 获取配置参数 base_strength = getattr(self.config, "auto_activate_base_strength", 0.1) max_activate_count = getattr(self.config, "auto_activate_max_count", 5) + decay_rate = getattr(self.config, "activation_decay_rate", 0.9) + now = datetime.now() - # 激活强度根据记忆重要性调整 - activate_tasks = [] - for i, memory in enumerate(memories[:max_activate_count]): - # 重要性越高,激活强度越大 + # 限制处理的记忆数量 + memories_to_activate = memories[:max_activate_count] + + # 批量更新激活度 + activation_updates = [] + for memory in memories_to_activate: + # 计算激活强度 strength = base_strength * (0.5 + memory.importance) - # 创建异步激活任务 - task = self.activate_memory(memory.id, strength=strength) - activate_tasks.append(task) + # 获取当前激活度信息 + activation_info = memory.metadata.get("activation", {}) + last_access = activation_info.get("last_access") - if i >= max_activate_count - 1: - break + if last_access: + # 计算时间衰减 + last_access_dt = datetime.fromisoformat(last_access) + hours_passed = (now - last_access_dt).total_seconds() / 3600 + decay_factor = decay_rate ** (hours_passed / 24) + current_activation = activation_info.get("level", 0.0) * decay_factor + else: + current_activation = 0.0 - # 并发执行激活任务(但不等待所有完成,避免阻塞搜索) - if activate_tasks: - import asyncio - # 使用 asyncio.gather 但设置较短的 timeout - try: - await asyncio.wait_for( - asyncio.gather(*activate_tasks, return_exceptions=True), - timeout=2.0 # 2秒超时,避免阻塞主流程 - ) - logger.debug(f"自动激活 {len(activate_tasks)} 条记忆完成") - except asyncio.TimeoutError: - logger.warning(f"自动激活记忆超时,已激活部分记忆") - except Exception as e: - logger.warning(f"自动激活记忆失败: {e}") + # 计算新的激活度 + new_activation = min(1.0, current_activation + strength) + + # 更新记忆对象 + memory.activation = new_activation + memory.last_accessed = now + activation_info.update({ + "level": new_activation, + "last_access": now.isoformat(), + "access_count": activation_info.get("access_count", 0) + 1, + }) + memory.metadata["activation"] = activation_info + + activation_updates.append({ + "memory_id": memory.id, + "old_activation": current_activation, + "new_activation": new_activation, + "strength": strength + }) + + # 批量保存到数据库 + if activation_updates: + await self.persistence.save_graph_store(self.graph_store) + + # 激活传播(异步执行,不阻塞主流程) + asyncio.create_task(self._batch_propagate_activation(memories_to_activate, base_strength)) + + logger.debug(f"批量激活 {len(activation_updates)} 条记忆完成") except Exception as e: - logger.warning(f"自动激活搜索记忆失败: {e}") + logger.warning(f"批量激活搜索记忆失败: {e}") + + async def _quick_batch_activate_memories(self, memories: list[Memory]) -> None: + """ + 快速批量激活记忆(用于搜索结果,优化性能) + + 与 _auto_activate_searched_memories 的区别: + - 更轻量级,专注于速度 + - 简化激活传播逻辑 + - 减少数据库写入次数 + + Args: + memories: 需要激活的记忆列表 + """ + try: + if not memories: + return + + # 获取配置参数 + base_strength = getattr(self.config, "auto_activate_base_strength", 0.1) + max_activate_count = getattr(self.config, "auto_activate_max_count", 5) + decay_rate = getattr(self.config, "activation_decay_rate", 0.9) + now = datetime.now() + + # 限制处理的记忆数量 + memories_to_activate = memories[:max_activate_count] + + # 批量更新激活度(内存操作) + for memory in memories_to_activate: + # 计算激活强度 + strength = base_strength * (0.5 + memory.importance) + + # 快速计算新的激活度(简化版) + activation_info = memory.metadata.get("activation", {}) + last_access = activation_info.get("last_access") + + if last_access: + # 简化的时间衰减计算 + try: + last_access_dt = datetime.fromisoformat(last_access) + hours_passed = (now - last_access_dt).total_seconds() / 3600 + decay_factor = decay_rate ** (hours_passed / 24) + current_activation = activation_info.get("level", 0.0) * decay_factor + except (ValueError, TypeError): + current_activation = activation_info.get("level", 0.0) * 0.9 # 默认衰减 + else: + current_activation = 0.0 + + # 计算新的激活度 + new_activation = min(1.0, current_activation + strength) + + # 直接更新记忆对象(内存中) + memory.activation = new_activation + memory.last_accessed = now + activation_info.update({ + "level": new_activation, + "last_access": now.isoformat(), + "access_count": activation_info.get("access_count", 0) + 1, + }) + memory.metadata["activation"] = activation_info + + # 异步批量保存(不阻塞搜索) + if memories_to_activate: + asyncio.create_task(self._background_save_activation(memories_to_activate, base_strength)) + + logger.debug(f"快速批量激活 {len(memories_to_activate)} 条记忆") + + except Exception as e: + logger.warning(f"快速批量激活记忆失败: {e}") + + async def _background_save_activation(self, memories: list[Memory], base_strength: float) -> None: + """ + 后台保存激活更新并执行传播 + + Args: + memories: 已更新的记忆列表 + base_strength: 基础激活强度 + """ + try: + # 批量保存到数据库 + await self.persistence.save_graph_store(self.graph_store) + + # 简化的激活传播(仅在强度足够时执行) + if base_strength > 0.08: # 提高传播阈值,减少传播频率 + propagation_strength_factor = getattr(self.config, "activation_propagation_strength", 0.3) # 降低传播强度 + max_related = getattr(self.config, "max_related_memories", 3) # 减少传播数量 + + # 只传播最重要的记忆的激活 + important_memories = [m for m in memories if m.importance > 0.6][:2] # 最多2个重要记忆 + + for memory in important_memories: + related_memories = self._get_related_memories(memory.id, max_depth=1) # 减少传播深度 + propagation_strength = base_strength * propagation_strength_factor + + for related_id in related_memories[:max_related]: + try: + related_memory = self.graph_store.get_memory_by_id(related_id) + if related_memory: + # 简单的激活度增加(不调用完整激活方法) + current_activation = related_memory.metadata.get("activation", {}).get("level", related_memory.activation) + new_activation = min(1.0, current_activation + propagation_strength * 0.5) + + related_memory.activation = new_activation + related_memory.metadata["activation"] = { + "level": new_activation, + "last_access": datetime.now().isoformat(), + "access_count": related_memory.metadata.get("activation", {}).get("access_count", 0) + 1, + } + except Exception as e: + logger.debug(f"传播激活到相关记忆 {related_id[:8]} 失败: {e}") + + # 再次保存传播后的更新 + await self.persistence.save_graph_store(self.graph_store) + + logger.debug(f"后台保存激活更新完成,处理了 {len(memories)} 条记忆") + + except Exception as e: + logger.warning(f"后台保存激活更新失败: {e}") + + async def _batch_propagate_activation(self, memories: list[Memory], base_strength: float) -> None: + """ + 批量传播激活到相关记忆(后台执行) + + Args: + memories: 已激活的记忆列表 + base_strength: 基础激活强度 + """ + try: + propagation_strength_factor = getattr(self.config, "activation_propagation_strength", 0.5) + propagation_depth = getattr(self.config, "activation_propagation_depth", 2) + max_related = getattr(self.config, "max_related_memories", 5) + + # 收集所有需要传播激活的记忆ID + propagation_tasks = [] + for memory in memories: + if base_strength > 0.05: # 只有足够强的激活才传播 + related_memories = self._get_related_memories( + memory.id, + max_depth=propagation_depth + ) + propagation_strength = base_strength * propagation_strength_factor + + for related_id in related_memories[:max_related]: + task = self.activate_memory(related_id, propagation_strength) + propagation_tasks.append(task) + + # 批量执行传播任务 + if propagation_tasks: + import asyncio + try: + await asyncio.wait_for( + asyncio.gather(*propagation_tasks, return_exceptions=True), + timeout=3.0 # 传播操作超时时间稍长 + ) + logger.debug(f"激活传播完成: {len(propagation_tasks)} 个相关记忆") + except asyncio.TimeoutError: + logger.warning("激活传播超时,部分相关记忆未激活") + except Exception as e: + logger.warning(f"激活传播失败: {e}") + + except Exception as e: + logger.warning(f"批量传播激活失败: {e}") def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> list[str]: """