refactor: 优化兴趣值管理器和统一调度器,增强任务执行的并发控制

This commit is contained in:
Windpicker-owo
2025-11-07 22:13:00 +08:00
parent ff5d14042c
commit 0cf7f87b66
6 changed files with 173 additions and 75 deletions

View File

@@ -41,19 +41,11 @@ class InterestManager:
async def initialize(self): async def initialize(self):
"""初始化管理器""" """初始化管理器"""
if self._worker_task is None: pass
self._worker_task = asyncio.create_task(self._calculation_worker())
logger.info("兴趣值管理器已启动")
async def shutdown(self): async def shutdown(self):
"""关闭管理器""" """关闭管理器"""
self._shutdown_event.set() self._shutdown_event.set()
if self._worker_task:
self._worker_task.cancel()
try:
await self._worker_task
except asyncio.CancelledError:
pass
if self._current_calculator: if self._current_calculator:
await self._current_calculator.cleanup() await self._current_calculator.cleanup()

View File

@@ -453,6 +453,10 @@ class MemoryConfig(ValidatedConfigBase):
activation_propagation_strength: float = Field(default=0.5, description="激活传播强度") activation_propagation_strength: float = Field(default=0.5, description="激活传播强度")
activation_propagation_depth: int = Field(default=2, description="激活传播深度") activation_propagation_depth: int = Field(default=2, description="激活传播深度")
# 记忆激活配置(强制执行)
auto_activate_base_strength: float = Field(default=0.1, description="记忆被检索时自动激活的基础强度")
auto_activate_max_count: int = Field(default=5, description="单次搜索最多自动激活的记忆数量")
# 性能配置 # 性能配置
max_memory_nodes_per_memory: int = Field(default=10, description="每个记忆最多包含的节点数") max_memory_nodes_per_memory: int = Field(default=10, description="每个记忆最多包含的节点数")
max_related_memories: int = Field(default=5, description="相关记忆最大数量") max_related_memories: int = Field(default=5, description="相关记忆最大数量")

View File

@@ -437,6 +437,11 @@ class MemoryManager:
logger.info( logger.info(
f"搜索完成: 找到 {len(filtered_memories)} 条记忆 (策略={strategy})" f"搜索完成: 找到 {len(filtered_memories)} 条记忆 (策略={strategy})"
) )
# 强制激活被检索到的记忆(核心功能)
if filtered_memories:
await self._auto_activate_searched_memories(filtered_memories)
return filtered_memories[:top_k] return filtered_memories[:top_k]
except Exception as e: except Exception as e:
@@ -536,6 +541,8 @@ class MemoryManager:
"access_count": activation_info.get("access_count", 0) + 1, "access_count": activation_info.get("access_count", 0) + 1,
}) })
# 同步更新 memory.activation 字段,确保数据一致性
memory.activation = new_activation
memory.metadata["activation"] = activation_info memory.metadata["activation"] = activation_info
memory.last_accessed = now memory.last_accessed = now
@@ -562,6 +569,49 @@ class MemoryManager:
logger.error(f"激活记忆失败: {e}", exc_info=True) logger.error(f"激活记忆失败: {e}", exc_info=True)
return False return False
async def _auto_activate_searched_memories(self, memories: list[Memory]) -> None:
"""
自动激活被搜索到的记忆
Args:
memories: 被检索到的记忆列表
"""
try:
# 获取配置参数
base_strength = getattr(self.config, "auto_activate_base_strength", 0.1)
max_activate_count = getattr(self.config, "auto_activate_max_count", 5)
# 激活强度根据记忆重要性调整
activate_tasks = []
for i, memory in enumerate(memories[:max_activate_count]):
# 重要性越高,激活强度越大
strength = base_strength * (0.5 + memory.importance)
# 创建异步激活任务
task = self.activate_memory(memory.id, strength=strength)
activate_tasks.append(task)
if i >= max_activate_count - 1:
break
# 并发执行激活任务(但不等待所有完成,避免阻塞搜索)
if activate_tasks:
import asyncio
# 使用 asyncio.gather 但设置较短的 timeout
try:
await asyncio.wait_for(
asyncio.gather(*activate_tasks, return_exceptions=True),
timeout=2.0 # 2秒超时避免阻塞主流程
)
logger.debug(f"自动激活 {len(activate_tasks)} 条记忆完成")
except asyncio.TimeoutError:
logger.warning(f"自动激活记忆超时,已激活部分记忆")
except Exception as e:
logger.warning(f"自动激活记忆失败: {e}")
except Exception as e:
logger.warning(f"自动激活搜索记忆失败: {e}")
def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> list[str]: def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> list[str]:
""" """
获取相关记忆 ID 列表(旧版本,保留用于激活传播) 获取相关记忆 ID 列表(旧版本,保留用于激活传播)

View File

@@ -199,6 +199,17 @@ class Memory:
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> Memory: def from_dict(cls, data: dict[str, Any]) -> Memory:
"""从字典创建记忆""" """从字典创建记忆"""
metadata = data.get("metadata", {})
# 优先从 metadata 中获取激活度信息
activation_level = 0.0
activation_info = metadata.get("activation", {})
if activation_info and "level" in activation_info:
activation_level = activation_info["level"]
else:
# 备选:使用直接的 activation 字段
activation_level = data.get("activation", 0.0)
return cls( return cls(
id=data["id"], id=data["id"],
subject_id=data["subject_id"], subject_id=data["subject_id"],
@@ -206,13 +217,13 @@ class Memory:
nodes=[MemoryNode.from_dict(n) for n in data["nodes"]], nodes=[MemoryNode.from_dict(n) for n in data["nodes"]],
edges=[MemoryEdge.from_dict(e) for e in data["edges"]], edges=[MemoryEdge.from_dict(e) for e in data["edges"]],
importance=data.get("importance", 0.5), importance=data.get("importance", 0.5),
activation=data.get("activation", 0.0), activation=activation_level, # 使用统一的激活度值
status=MemoryStatus(data.get("status", "staged")), status=MemoryStatus(data.get("status", "staged")),
created_at=datetime.fromisoformat(data["created_at"]), created_at=datetime.fromisoformat(data["created_at"]),
last_accessed=datetime.fromisoformat(data.get("last_accessed", data["created_at"])), last_accessed=datetime.fromisoformat(data.get("last_accessed", data["created_at"])),
access_count=data.get("access_count", 0), access_count=data.get("access_count", 0),
decay_factor=data.get("decay_factor", 1.0), decay_factor=data.get("decay_factor", 1.0),
metadata=data.get("metadata", {}), metadata=metadata,
) )
def update_access(self) -> None: def update_access(self) -> None:

View File

@@ -552,7 +552,7 @@ class MemoryTools:
for memory_id in sorted_memory_ids: for memory_id in sorted_memory_ids:
memory = self.graph_store.get_memory_by_id(memory_id) memory = self.graph_store.get_memory_by_id(memory_id)
if memory: if memory:
# 综合评分:相似度(60%) + 重要性(30%) + 时效性(10%) # 综合评分:相似度(40%) + 重要性(20%) + 时效性(10%) + 激活度(30%)
similarity_score = final_scores[memory_id] similarity_score = final_scores[memory_id]
importance_score = memory.importance importance_score = memory.importance
@@ -567,11 +567,20 @@ class MemoryTools:
age_days = (now - memory_time).total_seconds() / 86400 age_days = (now - memory_time).total_seconds() / 86400
recency_score = 1.0 / (1.0 + age_days / 30) # 30天半衰期 recency_score = 1.0 / (1.0 + age_days / 30) # 30天半衰期
# 综合分数 # 获取激活度分数(从metadata中读取,兼容memory.activation字段)
activation_info = memory.metadata.get("activation", {})
activation_score = activation_info.get("level", memory.activation)
# 如果metadata中没有激活度信息,使用memory.activation作为备选
if activation_score == 0.0 and memory.activation > 0.0:
activation_score = memory.activation
# 综合分数 - 加入激活度影响
final_score = ( final_score = (
similarity_score * 0.6 + similarity_score * 0.4 + # 向量相似度 40%
importance_score * 0.3 + importance_score * 0.2 + # 重要性 20%
recency_score * 0.1 recency_score * 0.1 + # 时效性 10%
activation_score * 0.3 # 激活度 30% ← 新增
) )
memories_with_scores.append((memory, final_score)) memories_with_scores.append((memory, final_score))

View File

@@ -82,6 +82,7 @@ class UnifiedScheduler:
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._event_subscriptions: set[str] = set() # 追踪已订阅的事件 self._event_subscriptions: set[str] = set() # 追踪已订阅的事件
self._executing_tasks: dict[str, asyncio.Task] = {} # 追踪正在执行的任务 self._executing_tasks: dict[str, asyncio.Task] = {} # 追踪正在执行的任务
self._execution_lock = asyncio.Lock() # 专门用于保护执行任务的并发访问
async def _handle_event_trigger(self, event_name: str | EventType, event_params: dict[str, Any]) -> None: async def _handle_event_trigger(self, event_name: str | EventType, event_params: dict[str, Any]) -> None:
"""处理来自 event_manager 的事件通知 """处理来自 event_manager 的事件通知
@@ -118,23 +119,25 @@ class UnifiedScheduler:
logger.debug(f"[调度器] 事件 '{event_name}' 触发,共有 {len(event_tasks)} 个调度任务") logger.debug(f"[调度器] 事件 '{event_name}' 触发,共有 {len(event_tasks)} 个调度任务")
# 并发执行所有事件任务 # 并发执行所有事件任务
execution_tasks = [] async with self._execution_lock:
for task in event_tasks: execution_tasks = []
execution_task = asyncio.create_task( for task in event_tasks:
self._execute_event_task_callback(task, event_params), execution_task = asyncio.create_task(
name=f"execute_event_{task.task_name}" self._execute_event_task_callback(task, event_params),
) name=f"execute_event_{task.task_name}"
execution_tasks.append(execution_task) )
execution_tasks.append(execution_task)
# 追踪正在执行的任务 # 追踪正在执行的任务
self._executing_tasks[task.schedule_id] = execution_task self._executing_tasks[task.schedule_id] = execution_task
# 等待所有任务完成 # 等待所有任务完成
results = await asyncio.gather(*execution_tasks, return_exceptions=True) results = await asyncio.gather(*execution_tasks, return_exceptions=True)
# 清理执行追踪 # 清理执行追踪
for task in event_tasks: async with self._execution_lock:
self._executing_tasks.pop(task.schedule_id, None) for task in event_tasks:
self._executing_tasks.pop(task.schedule_id, None)
# 收集需要移除的任务 # 收集需要移除的任务
tasks_to_remove = [] tasks_to_remove = []
@@ -194,17 +197,29 @@ class UnifiedScheduler:
except ImportError: except ImportError:
pass pass
# 取消所有正在执行的任务 # 取消所有正在执行的任务(避免在锁内进行阻塞操作)
executing_tasks = list(self._executing_tasks.values()) executing_tasks = list(self._executing_tasks.values())
if executing_tasks: if executing_tasks:
logger.debug(f"取消 {len(executing_tasks)} 个正在执行的任务") logger.debug(f"取消 {len(executing_tasks)} 个正在执行的任务")
# 在取消任务前先清空追踪,避免死锁
self._executing_tasks.clear()
# 在锁外取消任务
for task in executing_tasks: for task in executing_tasks:
if not task.done(): if not task.done():
task.cancel() task.cancel()
# 等待所有任务取消完成
await asyncio.gather(*executing_tasks, return_exceptions=True) # 等待所有任务取消完成,使用较长的超时时间
try:
await asyncio.wait_for(
asyncio.gather(*executing_tasks, return_exceptions=True),
timeout=10.0
)
except asyncio.TimeoutError:
logger.warning("部分任务取消超时,强制停止")
logger.info("统一调度器已停止") logger.info("统一调度器已停止")
# 清空资源时不需要锁,因为已经停止运行
self._tasks.clear() self._tasks.clear()
self._event_subscriptions.clear() self._event_subscriptions.clear()
self._executing_tasks.clear() self._executing_tasks.clear()
@@ -259,23 +274,25 @@ class UnifiedScheduler:
return return
# 为每个任务创建独立的异步任务,确保并发执行 # 为每个任务创建独立的异步任务,确保并发执行
execution_tasks = [] async with self._execution_lock:
for task in tasks_to_trigger: execution_tasks = []
execution_task = asyncio.create_task( for task in tasks_to_trigger:
self._execute_task_callback(task, current_time), execution_task = asyncio.create_task(
name=f"execute_{task.task_name}" self._execute_task_callback(task, current_time),
) name=f"execute_{task.task_name}"
execution_tasks.append(execution_task) )
execution_tasks.append(execution_task)
# 追踪正在执行的任务,以便在 remove_schedule 时可以取消 # 追踪正在执行的任务,以便在 remove_schedule 时可以取消
self._executing_tasks[task.schedule_id] = execution_task self._executing_tasks[task.schedule_id] = execution_task
# 等待所有任务完成(使用 return_exceptions=True 避免单个任务失败影响其他任务) # 等待所有任务完成(使用 return_exceptions=True 避免单个任务失败影响其他任务)
results = await asyncio.gather(*execution_tasks, return_exceptions=True) results = await asyncio.gather(*execution_tasks, return_exceptions=True)
# 清理执行追踪 # 清理执行追踪
for task in tasks_to_trigger: async with self._execution_lock:
self._executing_tasks.pop(task.schedule_id, None) for task in tasks_to_trigger:
self._executing_tasks.pop(task.schedule_id, None)
# 第三阶段:收集需要移除的任务并在锁内移除 # 第三阶段:收集需要移除的任务并在锁内移除
tasks_to_remove = [] tasks_to_remove = []
@@ -378,9 +395,8 @@ class UnifiedScheduler:
# 如果不是循环任务,需要移除 # 如果不是循环任务,需要移除
if not task.is_recurring: if not task.is_recurring:
async with self._lock: await self._remove_task_internal(task.schedule_id)
await self._remove_task_internal(task.schedule_id) logger.debug(f"[调度器] 一次性任务 {task.task_name} 已完成并移除")
logger.debug(f"[调度器] 一次性任务 {task.task_name} 已完成并移除")
return True return True
@@ -458,20 +474,21 @@ class UnifiedScheduler:
logger.error(f"执行任务 {task.task_name} 的回调函数时出错: {e}", exc_info=True) logger.error(f"执行任务 {task.task_name} 的回调函数时出错: {e}", exc_info=True)
async def _remove_task_internal(self, schedule_id: str): async def _remove_task_internal(self, schedule_id: str):
"""内部方法:移除任务(不加锁""" """内部方法:移除任务(需要加锁保护"""
task = self._tasks.pop(schedule_id, None) async with self._lock:
if task: task = self._tasks.pop(schedule_id, None)
if task.trigger_type == TriggerType.EVENT: if task:
event_name = task.trigger_config.get("event_name") if task.trigger_type == TriggerType.EVENT:
if event_name: event_name = task.trigger_config.get("event_name")
has_other_subscribers = any( if event_name:
t.trigger_type == TriggerType.EVENT and t.trigger_config.get("event_name") == event_name has_other_subscribers = any(
for t in self._tasks.values() t.trigger_type == TriggerType.EVENT and t.trigger_config.get("event_name") == event_name
) for t in self._tasks.values()
# 如果没有其他任务订阅此事件,从追踪集合中移除 )
if not has_other_subscribers and event_name in self._event_subscriptions: # 如果没有其他任务订阅此事件,从追踪集合中移除
self._event_subscriptions.discard(event_name) if not has_other_subscribers and event_name in self._event_subscriptions:
logger.debug(f"事件 '{event_name}' 已无订阅任务,从追踪中移除") self._event_subscriptions.discard(event_name)
logger.debug(f"事件 '{event_name}' 已无订阅任务,从追踪中移除")
async def create_schedule( async def create_schedule(
self, self,
@@ -518,27 +535,40 @@ class UnifiedScheduler:
如果任务正在执行,会取消执行中的任务 如果任务正在执行,会取消执行中的任务
""" """
# 先获取任务信息和执行任务,避免长时间持有锁
async with self._lock: async with self._lock:
if schedule_id not in self._tasks: if schedule_id not in self._tasks:
logger.warning(f"尝试移除不存在的任务: {schedule_id}") logger.warning(f"尝试移除不存在的任务: {schedule_id}")
return False return False
task = self._tasks[schedule_id] task = self._tasks[schedule_id]
# 检查是否有正在执行的任务
executing_task = self._executing_tasks.get(schedule_id) executing_task = self._executing_tasks.get(schedule_id)
if executing_task and not executing_task.done():
logger.debug(f"取消正在执行的任务: {task.task_name}") # 在锁外取消正在执行的任务,避免死锁
if executing_task and not executing_task.done():
logger.debug(f"取消正在执行的任务: {task.task_name}")
try:
executing_task.cancel() executing_task.cancel()
await asyncio.wait_for(executing_task, 3) await asyncio.wait_for(executing_task, 3)
self._executing_tasks.pop(schedule_id, None) except asyncio.TimeoutError:
logger.warning(f"取消任务 {task.task_name} 超时,强制移除")
except Exception as e:
logger.warning(f"取消任务 {task.task_name} 时发生错误: {e}")
# 重新获取锁移除任务
async with self._lock:
await self._remove_task_internal(schedule_id) await self._remove_task_internal(schedule_id)
logger.debug(f"移除调度任务: {task.task_name}")
return True # 使用执行锁清理执行追踪
async with self._execution_lock:
self._executing_tasks.pop(schedule_id, None)
logger.debug(f"移除调度任务: {task.task_name}")
return True
async def trigger_schedule(self, schedule_id: str) -> bool: async def trigger_schedule(self, schedule_id: str) -> bool:
"""强制触发指定任务""" """强制触发指定任务"""
# 先获取任务信息,减少锁持有时间
async with self._lock: async with self._lock:
task = self._tasks.get(schedule_id) task = self._tasks.get(schedule_id)
if not task: if not task:
@@ -550,16 +580,17 @@ class UnifiedScheduler:
return False return False
# 检查任务是否已经在执行中 # 检查任务是否已经在执行中
if schedule_id in self._executing_tasks: executing_task = self._executing_tasks.get(schedule_id)
executing_task = self._executing_tasks[schedule_id] if executing_task and not executing_task.done():
if not executing_task.done(): logger.warning(f"任务 {task.task_name} 已在执行中,无法重复触发")
logger.warning(f"任务 {task.task_name} 已在执行中,无法重复触发") return False
return False
else:
# 任务已完成但未清理,先清理
self._executing_tasks.pop(schedule_id, None)
# 释放锁,在锁外执行任务 # 清理已完成的任务
if executing_task and executing_task.done():
self._executing_tasks.pop(schedule_id, None)
# 在锁外创建执行任务
async with self._execution_lock:
execution_task = asyncio.create_task( execution_task = asyncio.create_task(
self._execute_trigger_task_callback(task), self._execute_trigger_task_callback(task),
name=f"trigger_{task.task_name}" name=f"trigger_{task.task_name}"
@@ -574,7 +605,8 @@ class UnifiedScheduler:
return result return result
finally: finally:
# 清理执行追踪 # 清理执行追踪
self._executing_tasks.pop(schedule_id, None) async with self._execution_lock:
self._executing_tasks.pop(schedule_id, None)
async def pause_schedule(self, schedule_id: str) -> bool: async def pause_schedule(self, schedule_id: str) -> bool:
"""暂停任务(不删除)""" """暂停任务(不删除)"""