diff --git a/src/chat/interest_system/interest_manager.py b/src/chat/interest_system/interest_manager.py
index faf77b888..af01d1fb5 100644
--- a/src/chat/interest_system/interest_manager.py
+++ b/src/chat/interest_system/interest_manager.py
@@ -41,19 +41,11 @@ class InterestManager:
 
     async def initialize(self):
         """Initialize the manager."""
-        if self._worker_task is None:
-            self._worker_task = asyncio.create_task(self._calculation_worker())
-            logger.info("Interest manager started")
+        pass
 
     async def shutdown(self):
         """Shut down the manager."""
         self._shutdown_event.set()
-        if self._worker_task:
-            self._worker_task.cancel()
-            try:
-                await self._worker_task
-            except asyncio.CancelledError:
-                pass
 
         if self._current_calculator:
             await self._current_calculator.cleanup()
diff --git a/src/config/official_configs.py b/src/config/official_configs.py
index 6bc7e602e..f40e48eea 100644
--- a/src/config/official_configs.py
+++ b/src/config/official_configs.py
@@ -453,6 +453,10 @@ class MemoryConfig(ValidatedConfigBase):
     activation_propagation_strength: float = Field(default=0.5, description="Activation propagation strength")
     activation_propagation_depth: int = Field(default=2, description="Activation propagation depth")
 
+    # Memory activation settings (always enforced)
+    auto_activate_base_strength: float = Field(default=0.1, description="Base activation strength applied when a memory is retrieved")
+    auto_activate_max_count: int = Field(default=5, description="Maximum number of memories auto-activated per search")
+
     # Performance settings
     max_memory_nodes_per_memory: int = Field(default=10, description="Maximum number of nodes per memory")
     max_related_memories: int = Field(default=5, description="Maximum number of related memories")
diff --git a/src/memory_graph/manager.py b/src/memory_graph/manager.py
index 5becfe0b1..d2ba2951a 100644
--- a/src/memory_graph/manager.py
+++ b/src/memory_graph/manager.py
@@ -437,6 +437,11 @@ class MemoryManager:
             logger.info(
                 f"Search complete: found {len(filtered_memories)} memories (strategy={strategy})"
             )
+
+            # Always activate the retrieved memories (core feature)
+            if filtered_memories:
+                await self._auto_activate_searched_memories(filtered_memories)
+
             return filtered_memories[:top_k]
 
         except Exception as e:
@@ -536,6 +541,8 @@ class MemoryManager:
                 "access_count": activation_info.get("access_count", 0) + 1,
             })
 
+            # Keep memory.activation in sync with the metadata copy for data consistency
+            memory.activation = new_activation
             memory.metadata["activation"] = activation_info
             memory.last_accessed = now
 
@@ -562,6 +569,44 @@ class MemoryManager:
             logger.error(f"Failed to activate memory: {e}", exc_info=True)
             return False
 
+    async def _auto_activate_searched_memories(self, memories: list[Memory]) -> None:
+        """
+        Automatically activate memories returned by a search.
+
+        Args:
+            memories: The retrieved memories.
+        """
+        try:
+            # Read the configuration parameters
+            base_strength = getattr(self.config, "auto_activate_base_strength", 0.1)
+            max_activate_count = getattr(self.config, "auto_activate_max_count", 5)
+
+            # Scale activation strength by each memory's importance:
+            # the more important the memory, the stronger the activation
+            activate_tasks = []
+            for memory in memories[:max_activate_count]:
+                strength = base_strength * (0.5 + memory.importance)
+                activate_tasks.append(self.activate_memory(memory.id, strength=strength))
+
+            # Run the activations concurrently under a short timeout so they
+            # can never block the main search flow
+            if activate_tasks:
+                import asyncio
+
+                try:
+                    await asyncio.wait_for(
+                        asyncio.gather(*activate_tasks, return_exceptions=True),
+                        timeout=2.0,  # 2-second cap to avoid stalling the caller
+                    )
+                    logger.debug(f"Auto-activated {len(activate_tasks)} memories")
+                except asyncio.TimeoutError:
+                    logger.warning("Auto-activation timed out; only some memories were activated")
+                except Exception as e:
+                    logger.warning(f"Auto-activation failed: {e}")
+
+        except Exception as e:
+            logger.warning(f"Failed to auto-activate searched memories: {e}")
+
     def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> list[str]:
         """
         Get related memory IDs (legacy; kept for activation propagation).
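The manager.py hunks above are the core of this change: every search now feeds back into activation. Below is a minimal, self-contained sketch of that pattern; `Memory` and `activate` here are stand-ins for the project's real model and `MemoryManager.activate_memory`, and the defaults mirror the new config fields.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Memory:
    id: str
    importance: float  # 0.0 .. 1.0


async def activate(memory: Memory, strength: float) -> None:
    """Stand-in for MemoryManager.activate_memory."""
    await asyncio.sleep(0.01)
    print(f"activated {memory.id} with strength {strength:.3f}")


async def auto_activate(
    memories: list[Memory],
    base_strength: float = 0.1,  # auto_activate_base_strength
    max_count: int = 5,          # auto_activate_max_count
    timeout: float = 2.0,
) -> None:
    # Importance scales strength: 0.5x base at importance 0.0, 1.5x at 1.0
    coros = [
        activate(m, base_strength * (0.5 + m.importance))
        for m in memories[:max_count]
    ]
    try:
        # Bounded wait: a slow activation degrades to partial activation
        # instead of delaying the search response
        await asyncio.wait_for(asyncio.gather(*coros, return_exceptions=True), timeout)
    except asyncio.TimeoutError:
        pass


asyncio.run(auto_activate([Memory("m1", 0.9), Memory("m2", 0.2)]))
```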
diff --git a/src/memory_graph/models.py b/src/memory_graph/models.py
index 6f7f1109d..0441c9bf3 100644
--- a/src/memory_graph/models.py
+++ b/src/memory_graph/models.py
@@ -199,6 +199,16 @@ class Memory:
     @classmethod
     def from_dict(cls, data: dict[str, Any]) -> Memory:
         """Create a Memory from a dict."""
+        metadata = data.get("metadata", {})
+
+        # Prefer the activation level stored in metadata
+        activation_info = metadata.get("activation", {})
+        if activation_info and "level" in activation_info:
+            activation_level = activation_info["level"]
+        else:
+            # Fallback: use the flat "activation" field
+            activation_level = data.get("activation", 0.0)
+
        return cls(
             id=data["id"],
             subject_id=data["subject_id"],
@@ -206,13 +216,13 @@ class Memory:
             nodes=[MemoryNode.from_dict(n) for n in data["nodes"]],
             edges=[MemoryEdge.from_dict(e) for e in data["edges"]],
             importance=data.get("importance", 0.5),
-            activation=data.get("activation", 0.0),
+            activation=activation_level,  # unified activation value
             status=MemoryStatus(data.get("status", "staged")),
             created_at=datetime.fromisoformat(data["created_at"]),
             last_accessed=datetime.fromisoformat(data.get("last_accessed", data["created_at"])),
             access_count=data.get("access_count", 0),
             decay_factor=data.get("decay_factor", 1.0),
-            metadata=data.get("metadata", {}),
+            metadata=metadata,
         )
 
     def update_access(self) -> None:
diff --git a/src/memory_graph/tools/memory_tools.py b/src/memory_graph/tools/memory_tools.py
index 8986ce732..d26ac9f64 100644
--- a/src/memory_graph/tools/memory_tools.py
+++ b/src/memory_graph/tools/memory_tools.py
@@ -552,7 +552,7 @@ class MemoryTools:
         for memory_id in sorted_memory_ids:
             memory = self.graph_store.get_memory_by_id(memory_id)
             if memory:
-                # Composite score: similarity (60%) + importance (30%) + recency (10%)
+                # Composite score: similarity (40%) + importance (20%) + recency (10%) + activation (30%)
                 similarity_score = final_scores[memory_id]
                 importance_score = memory.importance
 
@@ -567,11 +567,21 @@ class MemoryTools:
                     age_days = (now - memory_time).total_seconds() / 86400
                     recency_score = 1.0 / (1.0 + age_days / 30)  # 30-day half-life
 
-                # Composite score
+                # Read the activation level from metadata (falls back to memory.activation)
+                activation_info = memory.metadata.get("activation", {})
+                activation_score = activation_info.get("level", memory.activation)
+
+                # If metadata reports zero but the in-memory field is non-zero, trust
+                # the field (handles stale metadata written before the last activation)
+                if activation_score == 0.0 and memory.activation > 0.0:
+                    activation_score = memory.activation
+
+                # Composite score, now including activation
                 final_score = (
-                    similarity_score * 0.6 +
-                    importance_score * 0.3 +
-                    recency_score * 0.1
+                    similarity_score * 0.4 +   # vector similarity: 40%
+                    importance_score * 0.2 +   # importance: 20%
+                    recency_score * 0.1 +      # recency: 10%
+                    activation_score * 0.3     # activation: 30% (new)
                 )
                 memories_with_scores.append((memory, final_score))
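The reweighting in memory_tools.py is easiest to sanity-check with concrete numbers. A small worked example with illustrative values:

```python
def final_score(similarity: float, importance: float,
                recency: float, activation: float) -> float:
    """Composite score as defined in the hunk above."""
    return similarity * 0.4 + importance * 0.2 + recency * 0.1 + activation * 0.3


# A highly similar but never-activated memory vs. a slightly less similar
# memory that was recently activated:
cold = final_score(similarity=0.9, importance=0.5, recency=0.8, activation=0.0)
warm = final_score(similarity=0.7, importance=0.5, recency=0.8, activation=0.9)
print(f"cold={cold:.2f} warm={warm:.2f}")  # cold=0.54 warm=0.73
```

Under the old 60/30/10 weights the same inputs give 0.77 vs. 0.65, so the "cold" memory would have ranked first; giving activation 30% lets recently activated memories overtake purely similar ones.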
diff --git a/src/schedule/unified_scheduler.py b/src/schedule/unified_scheduler.py
index 195518c96..1ae1bb61f 100644
--- a/src/schedule/unified_scheduler.py
+++ b/src/schedule/unified_scheduler.py
@@ -82,6 +82,7 @@ class UnifiedScheduler:
         self._lock = asyncio.Lock()
         self._event_subscriptions: set[str] = set()  # Tracks subscribed events
         self._executing_tasks: dict[str, asyncio.Task] = {}  # Tracks in-flight task executions
+        self._execution_lock = asyncio.Lock()  # Dedicated lock guarding concurrent access to _executing_tasks
 
     async def _handle_event_trigger(self, event_name: str | EventType, event_params: dict[str, Any]) -> None:
         """Handle an event notification from event_manager
@@ -118,23 +119,25 @@ class UnifiedScheduler:
 
         logger.debug(f"[scheduler] Event '{event_name}' fired; {len(event_tasks)} scheduled task(s) to run")
 
         # Run all event tasks concurrently
-        execution_tasks = []
-        for task in event_tasks:
-            execution_task = asyncio.create_task(
-                self._execute_event_task_callback(task, event_params),
-                name=f"execute_event_{task.task_name}"
-            )
-            execution_tasks.append(execution_task)
+        async with self._execution_lock:
+            execution_tasks = []
+            for task in event_tasks:
+                execution_task = asyncio.create_task(
+                    self._execute_event_task_callback(task, event_params),
+                    name=f"execute_event_{task.task_name}"
+                )
+                execution_tasks.append(execution_task)
 
-            # Track the in-flight execution
-            self._executing_tasks[task.schedule_id] = execution_task
+                # Track the in-flight execution
+                self._executing_tasks[task.schedule_id] = execution_task
 
         # Wait for every task to finish
         results = await asyncio.gather(*execution_tasks, return_exceptions=True)
 
         # Clean up execution tracking
-        for task in event_tasks:
-            self._executing_tasks.pop(task.schedule_id, None)
+        async with self._execution_lock:
+            for task in event_tasks:
+                self._executing_tasks.pop(task.schedule_id, None)
 
         # Collect tasks that need removal
         tasks_to_remove = []
@@ -194,17 +197,29 @@ class UnifiedScheduler:
         except ImportError:
             pass
 
-        # Cancel all in-flight executions
+        # Cancel all in-flight executions (avoid blocking work while holding a lock)
         executing_tasks = list(self._executing_tasks.values())
         if executing_tasks:
             logger.debug(f"Cancelling {len(executing_tasks)} in-flight task(s)")
+            # Clear the tracking dict before cancelling to avoid deadlocks
+            self._executing_tasks.clear()
+
+            # Cancel outside any lock
             for task in executing_tasks:
                 if not task.done():
                     task.cancel()
-            # Wait for the cancellations to finish
-            await asyncio.gather(*executing_tasks, return_exceptions=True)
+
+            # Wait for the cancellations to finish, with a generous timeout
+            try:
+                await asyncio.wait_for(
+                    asyncio.gather(*executing_tasks, return_exceptions=True),
+                    timeout=10.0
+                )
+            except asyncio.TimeoutError:
+                logger.warning("Some tasks did not cancel in time; stopping anyway")
 
         logger.info("Unified scheduler stopped")
+        # No lock needed here: the scheduler is no longer running
         self._tasks.clear()
         self._event_subscriptions.clear()
         self._executing_tasks.clear()
@@ -259,23 +274,25 @@ class UnifiedScheduler:
             return
 
         # Create an independent asyncio task per schedule so they run concurrently
-        execution_tasks = []
-        for task in tasks_to_trigger:
-            execution_task = asyncio.create_task(
-                self._execute_task_callback(task, current_time),
-                name=f"execute_{task.task_name}"
-            )
-            execution_tasks.append(execution_task)
+        async with self._execution_lock:
+            execution_tasks = []
+            for task in tasks_to_trigger:
+                execution_task = asyncio.create_task(
+                    self._execute_task_callback(task, current_time),
+                    name=f"execute_{task.task_name}"
+                )
+                execution_tasks.append(execution_task)
 
-            # Track the execution so remove_schedule can cancel it
-            self._executing_tasks[task.schedule_id] = execution_task
+                # Track the execution so remove_schedule can cancel it
+                self._executing_tasks[task.schedule_id] = execution_task
 
         # Wait for all tasks (return_exceptions=True so one failure cannot affect the others)
         results = await asyncio.gather(*execution_tasks, return_exceptions=True)
 
         # Clean up execution tracking
-        for task in tasks_to_trigger:
-            self._executing_tasks.pop(task.schedule_id, None)
+        async with self._execution_lock:
+            for task in tasks_to_trigger:
+                self._executing_tasks.pop(task.schedule_id, None)
 
         # Phase 3: collect tasks to remove, then remove them under the lock
         tasks_to_remove = []
@@ -378,9 +395,8 @@ class UnifiedScheduler:
 
             # Remove the task if it is not recurring
             if not task.is_recurring:
-                async with self._lock:
-                    await self._remove_task_internal(task.schedule_id)
-                    logger.debug(f"[scheduler] One-shot task {task.task_name} completed and removed")
+                await self._remove_task_internal(task.schedule_id)
+                logger.debug(f"[scheduler] One-shot task {task.task_name} completed and removed")
 
             return True
 
@@ -458,20 +474,21 @@ class UnifiedScheduler:
             logger.error(f"Error while running the callback for task {task.task_name}: {e}", exc_info=True)
 
     async def _remove_task_internal(self, schedule_id: str):
-        """Internal: remove a task (does not acquire the lock)."""
-        task = self._tasks.pop(schedule_id, None)
-        if task:
-            if task.trigger_type == TriggerType.EVENT:
-                event_name = task.trigger_config.get("event_name")
-                if event_name:
-                    has_other_subscribers = any(
-                        t.trigger_type == TriggerType.EVENT and t.trigger_config.get("event_name") == event_name
-                        for t in self._tasks.values()
-                    )
-                    # Drop the event from tracking when no other task subscribes to it
-                    if not has_other_subscribers and event_name in self._event_subscriptions:
-                        self._event_subscriptions.discard(event_name)
-                        logger.debug(f"Event '{event_name}' has no subscribing tasks left; removed from tracking")
+        """Internal: remove a task (acquires the lock itself)."""
+        async with self._lock:
+            task = self._tasks.pop(schedule_id, None)
+            if task:
+                if task.trigger_type == TriggerType.EVENT:
+                    event_name = task.trigger_config.get("event_name")
+                    if event_name:
+                        has_other_subscribers = any(
+                            t.trigger_type == TriggerType.EVENT and t.trigger_config.get("event_name") == event_name
+                            for t in self._tasks.values()
+                        )
+                        # Drop the event from tracking when no other task subscribes to it
+                        if not has_other_subscribers and event_name in self._event_subscriptions:
+                            self._event_subscriptions.discard(event_name)
+                            logger.debug(f"Event '{event_name}' has no subscribing tasks left; removed from tracking")
 
     async def create_schedule(
         self,
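Every scheduler hunk above follows the same locking discipline: `_executing_tasks` is mutated only while holding the dedicated `_execution_lock`, and the long `await asyncio.gather(...)` always happens outside any lock. A stripped-down sketch of that discipline, with hypothetical names:

```python
import asyncio


class Tracker:
    def __init__(self) -> None:
        self._executing: dict[str, asyncio.Task] = {}
        self._exec_lock = asyncio.Lock()

    async def run_all(self, jobs: dict[str, float]) -> None:
        # Brief critical section: spawn and register the tasks
        async with self._exec_lock:
            tasks: dict[str, asyncio.Task] = {}
            for job_id, delay in jobs.items():
                task = asyncio.create_task(asyncio.sleep(delay), name=job_id)
                self._executing[job_id] = task
                tasks[job_id] = task

        # Await outside the lock so new triggers are never blocked;
        # return_exceptions=True keeps one failure from poisoning the rest
        await asyncio.gather(*tasks.values(), return_exceptions=True)

        # Second brief critical section: deregister
        async with self._exec_lock:
            for job_id in tasks:
                self._executing.pop(job_id, None)


asyncio.run(Tracker().run_all({"a": 0.01, "b": 0.02}))
```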
@@ -518,27 +535,41 @@ class UnifiedScheduler:
 
             If the task is currently executing, the in-flight execution will be cancelled.
         """
+        # Fetch the task info and any in-flight execution first, keeping the lock hold brief
         async with self._lock:
             if schedule_id not in self._tasks:
                 logger.warning(f"Attempted to remove a nonexistent task: {schedule_id}")
                 return False
 
             task = self._tasks[schedule_id]
-
-            # Check for an in-flight execution
             executing_task = self._executing_tasks.get(schedule_id)
-            if executing_task and not executing_task.done():
-                logger.debug(f"Cancelling in-flight task: {task.task_name}")
+
+        # Cancel the in-flight execution outside the lock to avoid deadlocks
+        if executing_task and not executing_task.done():
+            logger.debug(f"Cancelling in-flight task: {task.task_name}")
+            try:
                 executing_task.cancel()
                 await asyncio.wait_for(executing_task, 3)
-                self._executing_tasks.pop(schedule_id, None)
+            except asyncio.CancelledError:
+                pass  # expected once the cancelled task is awaited
+            except asyncio.TimeoutError:
+                logger.warning(f"Cancelling task {task.task_name} timed out; removing it anyway")
+            except Exception as e:
+                logger.warning(f"Error while cancelling task {task.task_name}: {e}")
 
+        # _remove_task_internal acquires self._lock itself
         await self._remove_task_internal(schedule_id)
-            logger.debug(f"Removed scheduled task: {task.task_name}")
-            return True
+
+        # Clear the execution tracking under the execution lock
+        async with self._execution_lock:
+            self._executing_tasks.pop(schedule_id, None)
+
+        logger.debug(f"Removed scheduled task: {task.task_name}")
+        return True
 
     async def trigger_schedule(self, schedule_id: str) -> bool:
         """Force-trigger the specified task."""
+        # Fetch the task info first to keep lock hold time short
         async with self._lock:
             task = self._tasks.get(schedule_id)
             if not task:
@@ -550,16 +581,17 @@ class UnifiedScheduler:
                 return False
 
             # Check whether the task is already executing
-            if schedule_id in self._executing_tasks:
-                executing_task = self._executing_tasks[schedule_id]
-                if not executing_task.done():
-                    logger.warning(f"Task {task.task_name} is already executing; cannot trigger it again")
-                    return False
-                else:
-                    # Finished but not yet cleaned up; clear it first
-                    self._executing_tasks.pop(schedule_id, None)
+            executing_task = self._executing_tasks.get(schedule_id)
+            if executing_task and not executing_task.done():
+                logger.warning(f"Task {task.task_name} is already executing; cannot trigger it again")
+                return False
 
-            # Release the lock and run the task outside it
+            # Clean up a finished execution that was not yet cleared
+            if executing_task and executing_task.done():
+                self._executing_tasks.pop(schedule_id, None)
+
+        # Create the execution task outside the main lock
+        async with self._execution_lock:
             execution_task = asyncio.create_task(
                 self._execute_trigger_task_callback(task),
                 name=f"trigger_{task.task_name}"
@@ -574,7 +606,8 @@ class UnifiedScheduler:
             return result
         finally:
             # Clean up execution tracking
-            self._executing_tasks.pop(schedule_id, None)
+            async with self._execution_lock:
+                self._executing_tasks.pop(schedule_id, None)
 
     async def pause_schedule(self, schedule_id: str) -> bool:
         """Pause a task (without deleting it)."""
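The cancellation path in `remove_schedule` depends on a subtle detail: awaiting a cancelled task re-raises `asyncio.CancelledError`, which on Python 3.8+ derives from `BaseException`, so a bare `except Exception` never catches it. That is why the hunk above gives it a dedicated handler. A runnable sketch of the idiom:

```python
import asyncio


async def cancel_with_timeout(task: asyncio.Task, timeout: float = 3.0) -> None:
    task.cancel()
    try:
        await asyncio.wait_for(task, timeout)
    except asyncio.CancelledError:
        pass  # normal: the cancelled task re-raises when awaited
    except asyncio.TimeoutError:
        # The task ignored cancellation; give up after the timeout
        print("cancellation timed out")


async def main() -> None:
    task = asyncio.create_task(asyncio.sleep(60), name="long_job")
    await asyncio.sleep(0)  # let the task start
    await cancel_with_timeout(task)
    print(f"done={task.done()} cancelled={task.cancelled()}")


asyncio.run(main())
```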