异步记忆系统优化 & Action组件修复

主要改进:
1. 异步记忆系统优化 - 解决记忆操作阻塞主程序问题
   - 新增异步记忆队列管理器 (async_memory_optimizer.py)
   - 新增异步瞬时记忆包装器 (async_instant_memory_wrapper.py)
   - 优化主程序记忆构建任务为后台非阻塞执行
   - 优化消息处理器记忆调用,增加超时保护和回退机制

2. Action组件修复 - 解决'未找到Action组件: no_reply'问题
   - 修复no_reply动作激活类型配置错误
   - 新增reply回退动作 (reply.py)
   - 增强planner.py动作选择回退机制
   - 增强cycle_processor.py动作创建回退机制
This commit is contained in:
Furina-1013-create
2025-08-22 13:16:19 +08:00
parent ce2e5bd199
commit 980221d589
9 changed files with 1271 additions and 12 deletions

View File

@@ -0,0 +1,390 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Action组件诊断和修复脚本
检查no_reply等核心Action是否正确注册并尝试修复相关问题
"""
import sys
import os
from typing import Dict, Any
# 添加项目路径
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))
from src.common.logger import get_logger
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.base.component_types import ComponentType
logger = get_logger("action_diagnostics")
class ActionDiagnostics:
"""Action组件诊断器"""
def __init__(self):
self.required_actions = ["no_reply", "reply", "emoji", "at_user"]
def check_plugin_loading(self) -> Dict[str, Any]:
"""检查插件加载状态"""
logger.info("开始检查插件加载状态...")
result = {
"plugins_loaded": False,
"total_plugins": 0,
"loaded_plugins": [],
"failed_plugins": [],
"core_actions_plugin": None
}
try:
# 加载所有插件
plugin_manager.load_all_plugins()
# 获取插件统计信息
stats = plugin_manager.get_stats()
result["plugins_loaded"] = True
result["total_plugins"] = stats.get("total_plugins", 0)
# 检查是否有core_actions插件
for plugin_name in plugin_manager.loaded_plugins:
result["loaded_plugins"].append(plugin_name)
if "core_actions" in plugin_name.lower():
result["core_actions_plugin"] = plugin_name
logger.info(f"插件加载成功,总数: {result['total_plugins']}")
logger.info(f"已加载插件: {result['loaded_plugins']}")
except Exception as e:
logger.error(f"插件加载失败: {e}")
result["error"] = str(e)
return result
def check_action_registry(self) -> Dict[str, Any]:
"""检查Action注册状态"""
logger.info("开始检查Action组件注册状态...")
result = {
"registered_actions": [],
"missing_actions": [],
"default_actions": {},
"total_actions": 0
}
try:
# 获取所有注册的Action
all_components = component_registry.get_all_components(ComponentType.ACTION)
result["total_actions"] = len(all_components)
for name, info in all_components.items():
result["registered_actions"].append(name)
logger.debug(f"已注册Action: {name} (插件: {info.plugin_name})")
# 检查必需的Action是否存在
for required_action in self.required_actions:
if required_action not in all_components:
result["missing_actions"].append(required_action)
logger.warning(f"缺失必需Action: {required_action}")
else:
logger.info(f"找到必需Action: {required_action}")
# 获取默认Action
default_actions = component_registry.get_default_actions()
result["default_actions"] = {name: info.plugin_name for name, info in default_actions.items()}
logger.info(f"总注册Action数量: {result['total_actions']}")
logger.info(f"缺失Action: {result['missing_actions']}")
except Exception as e:
logger.error(f"Action注册检查失败: {e}")
result["error"] = str(e)
return result
def check_specific_action(self, action_name: str) -> Dict[str, Any]:
"""检查特定Action的详细信息"""
logger.info(f"检查Action详细信息: {action_name}")
result = {
"exists": False,
"component_info": None,
"component_class": None,
"is_default": False,
"plugin_name": None
}
try:
# 检查组件信息
component_info = component_registry.get_component_info(action_name, ComponentType.ACTION)
if component_info:
result["exists"] = True
result["component_info"] = {
"name": component_info.name,
"description": component_info.description,
"plugin_name": component_info.plugin_name,
"version": component_info.version
}
result["plugin_name"] = component_info.plugin_name
logger.info(f"找到Action组件信息: {action_name}")
else:
logger.warning(f"未找到Action组件信息: {action_name}")
return result
# 检查组件类
component_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
if component_class:
result["component_class"] = component_class.__name__
logger.info(f"找到Action组件类: {component_class.__name__}")
else:
logger.warning(f"未找到Action组件类: {action_name}")
# 检查是否为默认Action
default_actions = component_registry.get_default_actions()
result["is_default"] = action_name in default_actions
logger.info(f"Action {action_name} 检查完成: 存在={result['exists']}, 默认={result['is_default']}")
except Exception as e:
logger.error(f"检查Action {action_name} 失败: {e}")
result["error"] = str(e)
return result
def attempt_fix_missing_actions(self) -> Dict[str, Any]:
"""尝试修复缺失的Action"""
logger.info("尝试修复缺失的Action组件...")
result = {
"fixed_actions": [],
"still_missing": [],
"errors": []
}
try:
# 重新加载插件
plugin_manager.load_all_plugins()
# 再次检查Action注册状态
registry_check = self.check_action_registry()
for required_action in self.required_actions:
if required_action in registry_check["missing_actions"]:
try:
# 尝试手动注册核心Action
if required_action == "no_reply":
self._register_no_reply_action()
result["fixed_actions"].append(required_action)
else:
result["still_missing"].append(required_action)
except Exception as e:
error_msg = f"修复Action {required_action} 失败: {e}"
logger.error(error_msg)
result["errors"].append(error_msg)
result["still_missing"].append(required_action)
logger.info(f"Action修复完成: 已修复={result['fixed_actions']}, 仍缺失={result['still_missing']}")
except Exception as e:
error_msg = f"Action修复过程失败: {e}"
logger.error(error_msg)
result["errors"].append(error_msg)
return result
def _register_no_reply_action(self):
"""手动注册no_reply Action"""
try:
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
from src.plugin_system.base.component_types import ActionInfo
# 创建Action信息
action_info = ActionInfo(
name="no_reply",
description="暂时不回复消息",
plugin_name="built_in.core_actions",
version="1.0.0"
)
# 注册Action
success = component_registry._register_action_component(action_info, NoReplyAction)
if success:
logger.info("手动注册no_reply Action成功")
else:
raise Exception("注册失败")
except Exception as e:
raise Exception(f"手动注册no_reply Action失败: {e}")
def run_full_diagnosis(self) -> Dict[str, Any]:
"""运行完整诊断"""
logger.info("🔧 开始Action组件完整诊断")
logger.info("=" * 60)
diagnosis_result = {
"plugin_status": {},
"registry_status": {},
"action_details": {},
"fix_attempts": {},
"summary": {}
}
# 1. 检查插件加载
logger.info("\n📦 步骤1: 检查插件加载状态")
diagnosis_result["plugin_status"] = self.check_plugin_loading()
# 2. 检查Action注册
logger.info("\n📋 步骤2: 检查Action注册状态")
diagnosis_result["registry_status"] = self.check_action_registry()
# 3. 检查特定Action详细信息
logger.info("\n🔍 步骤3: 检查特定Action详细信息")
diagnosis_result["action_details"] = {}
for action in self.required_actions:
diagnosis_result["action_details"][action] = self.check_specific_action(action)
# 4. 尝试修复缺失的Action
if diagnosis_result["registry_status"].get("missing_actions"):
logger.info("\n🔧 步骤4: 尝试修复缺失的Action")
diagnosis_result["fix_attempts"] = self.attempt_fix_missing_actions()
# 5. 生成诊断摘要
logger.info("\n📊 步骤5: 生成诊断摘要")
diagnosis_result["summary"] = self._generate_summary(diagnosis_result)
self._print_diagnosis_results(diagnosis_result)
return diagnosis_result
def _generate_summary(self, diagnosis_result: Dict[str, Any]) -> Dict[str, Any]:
"""生成诊断摘要"""
summary = {
"overall_status": "unknown",
"critical_issues": [],
"recommendations": []
}
try:
# 检查插件加载状态
if not diagnosis_result["plugin_status"].get("plugins_loaded"):
summary["critical_issues"].append("插件加载失败")
summary["recommendations"].append("检查插件系统配置")
# 检查必需Action
missing_actions = diagnosis_result["registry_status"].get("missing_actions", [])
if "no_reply" in missing_actions:
summary["critical_issues"].append("缺失no_reply Action")
summary["recommendations"].append("检查core_actions插件是否正确加载")
# 检查修复结果
if diagnosis_result.get("fix_attempts"):
still_missing = diagnosis_result["fix_attempts"].get("still_missing", [])
if still_missing:
summary["critical_issues"].append(f"修复后仍缺失Action: {still_missing}")
summary["recommendations"].append("需要手动修复插件注册问题")
# 确定整体状态
if not summary["critical_issues"]:
summary["overall_status"] = "healthy"
elif len(summary["critical_issues"]) <= 2:
summary["overall_status"] = "warning"
else:
summary["overall_status"] = "critical"
except Exception as e:
summary["critical_issues"].append(f"摘要生成失败: {e}")
summary["overall_status"] = "error"
return summary
def _print_diagnosis_results(self, diagnosis_result: Dict[str, Any]):
"""打印诊断结果"""
logger.info("\n" + "=" * 60)
logger.info("📈 诊断结果摘要")
logger.info("=" * 60)
summary = diagnosis_result.get("summary", {})
overall_status = summary.get("overall_status", "unknown")
# 状态指示器
status_indicators = {
"healthy": "✅ 系统健康",
"warning": "⚠️ 存在警告",
"critical": "❌ 存在严重问题",
"error": "💥 诊断出错",
"unknown": "❓ 状态未知"
}
logger.info(f"🎯 整体状态: {status_indicators.get(overall_status, overall_status)}")
# 关键问题
critical_issues = summary.get("critical_issues", [])
if critical_issues:
logger.info("\n🚨 关键问题:")
for issue in critical_issues:
logger.info(f"{issue}")
# 建议
recommendations = summary.get("recommendations", [])
if recommendations:
logger.info("\n💡 建议:")
for rec in recommendations:
logger.info(f"{rec}")
# 详细状态
plugin_status = diagnosis_result.get("plugin_status", {})
if plugin_status.get("plugins_loaded"):
logger.info(f"\n📦 插件状态: 已加载 {plugin_status.get('total_plugins', 0)} 个插件")
else:
logger.info("\n📦 插件状态: ❌ 插件加载失败")
registry_status = diagnosis_result.get("registry_status", {})
total_actions = registry_status.get("total_actions", 0)
missing_actions = registry_status.get("missing_actions", [])
logger.info(f"📋 Action状态: 已注册 {total_actions} 个,缺失 {len(missing_actions)}")
if missing_actions:
logger.info(f" 缺失的Action: {missing_actions}")
logger.info("\n" + "=" * 60)
def main():
"""主函数"""
diagnostics = ActionDiagnostics()
try:
result = diagnostics.run_full_diagnosis()
# 保存诊断结果
import json
with open("action_diagnosis_results.json", "w", encoding="utf-8") as f:
json.dump(result, f, indent=2, ensure_ascii=False, default=str)
logger.info("📄 诊断结果已保存到: action_diagnosis_results.json")
# 根据诊断结果返回适当的退出代码
summary = result.get("summary", {})
overall_status = summary.get("overall_status", "unknown")
if overall_status == "healthy":
return 0
elif overall_status == "warning":
return 1
else:
return 2
except KeyboardInterrupt:
logger.info("❌ 诊断被用户中断")
return 3
except Exception as e:
logger.error(f"❌ 诊断执行失败: {e}")
import traceback
traceback.print_exc()
return 4
if __name__ == "__main__":
import logging
logging.basicConfig(level=logging.INFO)
exit_code = main()
sys.exit(exit_code)

View File

@@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-
"""
异步瞬时记忆包装器
提供对现有瞬时记忆系统的异步包装,支持超时控制和回退机制
"""
import asyncio
import time
from typing import Optional, List, Dict, Any
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("async_instant_memory_wrapper")
class AsyncInstantMemoryWrapper:
"""异步瞬时记忆包装器"""
def __init__(self, chat_id: str):
self.chat_id = chat_id
self.llm_memory = None
self.vector_memory = None
self.cache: Dict[str, tuple[Any, float]] = {} # 缓存:(结果, 时间戳)
self.cache_ttl = 300 # 缓存5分钟
self.default_timeout = 3.0 # 默认超时3秒
# 延迟加载记忆系统
self._initialize_memory_systems()
def _initialize_memory_systems(self):
"""延迟初始化记忆系统"""
try:
# 初始化LLM记忆系统
from src.chat.memory_system.instant_memory import InstantMemory
self.llm_memory = InstantMemory(self.chat_id)
logger.debug(f"LLM瞬时记忆系统已初始化: {self.chat_id}")
except Exception as e:
logger.warning(f"LLM瞬时记忆系统初始化失败: {e}")
try:
# 初始化向量记忆系统
from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2
self.vector_memory = VectorInstantMemoryV2(self.chat_id)
logger.debug(f"向量瞬时记忆系统已初始化: {self.chat_id}")
except Exception as e:
logger.warning(f"向量瞬时记忆系统初始化失败: {e}")
def _get_cache_key(self, operation: str, content: str) -> str:
"""生成缓存键"""
return f"{operation}_{self.chat_id}_{hash(content)}"
def _is_cache_valid(self, cache_key: str) -> bool:
"""检查缓存是否有效"""
if cache_key not in self.cache:
return False
_, timestamp = self.cache[cache_key]
return time.time() - timestamp < self.cache_ttl
def _get_cached_result(self, cache_key: str) -> Optional[Any]:
"""获取缓存结果"""
if self._is_cache_valid(cache_key):
result, _ = self.cache[cache_key]
return result
return None
def _cache_result(self, cache_key: str, result: Any):
"""缓存结果"""
self.cache[cache_key] = (result, time.time())
async def store_memory_async(self, content: str, timeout: float = None) -> bool:
"""异步存储记忆(带超时控制)"""
if timeout is None:
timeout = self.default_timeout
success_count = 0
total_systems = 0
# 异步存储到LLM记忆系统
if self.llm_memory:
total_systems += 1
try:
await asyncio.wait_for(
self.llm_memory.create_and_store_memory(content),
timeout=timeout
)
success_count += 1
logger.debug(f"LLM记忆存储成功: {content[:50]}...")
except asyncio.TimeoutError:
logger.warning(f"LLM记忆存储超时: {content[:50]}...")
except Exception as e:
logger.error(f"LLM记忆存储失败: {e}")
# 异步存储到向量记忆系统
if self.vector_memory:
total_systems += 1
try:
await asyncio.wait_for(
self.vector_memory.store_message(content),
timeout=timeout
)
success_count += 1
logger.debug(f"向量记忆存储成功: {content[:50]}...")
except asyncio.TimeoutError:
logger.warning(f"向量记忆存储超时: {content[:50]}...")
except Exception as e:
logger.error(f"向量记忆存储失败: {e}")
return success_count > 0
async def retrieve_memory_async(self, query: str, timeout: float = None,
use_cache: bool = True) -> Optional[Any]:
"""异步检索记忆(带缓存和超时控制)"""
if timeout is None:
timeout = self.default_timeout
# 检查缓存
if use_cache:
cache_key = self._get_cache_key("retrieve", query)
cached_result = self._get_cached_result(cache_key)
if cached_result is not None:
logger.debug(f"记忆检索命中缓存: {query[:30]}...")
return cached_result
# 尝试多种记忆系统
results = []
# 从向量记忆系统检索(优先,速度快)
if self.vector_memory:
try:
vector_result = await asyncio.wait_for(
self.vector_memory.get_memory_for_context(query),
timeout=timeout * 0.6 # 给向量系统60%的时间
)
if vector_result:
results.append(vector_result)
logger.debug(f"向量记忆检索成功: {query[:30]}...")
except asyncio.TimeoutError:
logger.warning(f"向量记忆检索超时: {query[:30]}...")
except Exception as e:
logger.error(f"向量记忆检索失败: {e}")
# 从LLM记忆系统检索备用更准确但较慢
if self.llm_memory and len(results) == 0: # 只有向量检索失败时才使用LLM
try:
llm_result = await asyncio.wait_for(
self.llm_memory.get_memory(query),
timeout=timeout * 0.4 # 给LLM系统40%的时间
)
if llm_result:
results.extend(llm_result)
logger.debug(f"LLM记忆检索成功: {query[:30]}...")
except asyncio.TimeoutError:
logger.warning(f"LLM记忆检索超时: {query[:30]}...")
except Exception as e:
logger.error(f"LLM记忆检索失败: {e}")
# 合并结果
final_result = None
if results:
if len(results) == 1:
final_result = results[0]
else:
# 合并多个结果
if isinstance(results[0], str):
final_result = "\n".join(str(r) for r in results)
elif isinstance(results[0], list):
final_result = []
for r in results:
if isinstance(r, list):
final_result.extend(r)
else:
final_result.append(r)
else:
final_result = results[0] # 使用第一个结果
# 缓存结果
if use_cache and final_result is not None:
cache_key = self._get_cache_key("retrieve", query)
self._cache_result(cache_key, final_result)
return final_result
async def get_memory_with_fallback(self, query: str, max_timeout: float = 2.0) -> str:
"""获取记忆的回退方法,保证不会长时间阻塞"""
try:
# 首先尝试快速检索
result = await self.retrieve_memory_async(query, timeout=max_timeout)
if result:
if isinstance(result, list):
return "\n".join(str(item) for item in result)
return str(result)
return ""
except Exception as e:
logger.error(f"记忆检索完全失败: {e}")
return ""
def store_memory_background(self, content: str):
"""在后台存储记忆(发后即忘模式)"""
async def background_store():
try:
await self.store_memory_async(content, timeout=10.0) # 后台任务可以用更长超时
except Exception as e:
logger.error(f"后台记忆存储失败: {e}")
# 创建后台任务
asyncio.create_task(background_store())
def get_status(self) -> Dict[str, Any]:
"""获取包装器状态"""
return {
"chat_id": self.chat_id,
"llm_memory_available": self.llm_memory is not None,
"vector_memory_available": self.vector_memory is not None,
"cache_entries": len(self.cache),
"cache_ttl": self.cache_ttl,
"default_timeout": self.default_timeout
}
def clear_cache(self):
"""清理缓存"""
self.cache.clear()
logger.info(f"记忆缓存已清理: {self.chat_id}")
# 缓存包装器实例,避免重复创建
_wrapper_cache: Dict[str, AsyncInstantMemoryWrapper] = {}
def get_async_instant_memory(chat_id: str) -> AsyncInstantMemoryWrapper:
"""获取异步瞬时记忆包装器实例"""
if chat_id not in _wrapper_cache:
_wrapper_cache[chat_id] = AsyncInstantMemoryWrapper(chat_id)
return _wrapper_cache[chat_id]
def clear_wrapper_cache():
"""清理包装器缓存"""
global _wrapper_cache
_wrapper_cache.clear()
logger.info("异步瞬时记忆包装器缓存已清理")

View File

@@ -0,0 +1,337 @@
# -*- coding: utf-8 -*-
"""
异步记忆系统优化器
解决记忆系统阻塞主程序的问题,将同步操作改为异步非阻塞操作
"""
import asyncio
import time
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass
from queue import Queue
import threading
from concurrent.futures import ThreadPoolExecutor
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("async_memory_optimizer")
@dataclass
class MemoryTask:
"""记忆任务数据结构"""
task_id: str
task_type: str # "store", "retrieve", "build"
chat_id: str
content: str
priority: int = 1 # 1=低优先级, 2=中优先级, 3=高优先级
callback: Optional[Callable] = None
created_at: float = None
def __post_init__(self):
if self.created_at is None:
self.created_at = time.time()
class AsyncMemoryQueue:
"""异步记忆任务队列管理器"""
def __init__(self, max_workers: int = 3):
self.max_workers = max_workers
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.task_queue = asyncio.Queue()
self.running_tasks: Dict[str, asyncio.Task] = {}
self.completed_tasks: Dict[str, Any] = {}
self.failed_tasks: Dict[str, str] = {}
self.is_running = False
self.worker_tasks: List[asyncio.Task] = []
async def start(self):
"""启动异步队列处理器"""
if self.is_running:
return
self.is_running = True
# 启动多个工作协程
for i in range(self.max_workers):
worker = asyncio.create_task(self._worker(f"worker-{i}"))
self.worker_tasks.append(worker)
logger.info(f"异步记忆队列已启动,工作线程数: {self.max_workers}")
async def stop(self):
"""停止队列处理器"""
self.is_running = False
# 等待所有工作任务完成
for task in self.worker_tasks:
task.cancel()
await asyncio.gather(*self.worker_tasks, return_exceptions=True)
self.executor.shutdown(wait=True)
logger.info("异步记忆队列已停止")
async def _worker(self, worker_name: str):
"""工作协程,处理队列中的任务"""
logger.info(f"记忆处理工作线程 {worker_name} 启动")
while self.is_running:
try:
# 等待任务超时1秒避免永久阻塞
task = await asyncio.wait_for(self.task_queue.get(), timeout=1.0)
# 执行任务
await self._execute_task(task, worker_name)
except asyncio.TimeoutError:
# 超时正常,继续下一次循环
continue
except Exception as e:
logger.error(f"工作线程 {worker_name} 处理任务时出错: {e}")
async def _execute_task(self, task: MemoryTask, worker_name: str):
"""执行具体的记忆任务"""
try:
logger.debug(f"[{worker_name}] 开始处理任务: {task.task_type} - {task.task_id}")
start_time = time.time()
# 根据任务类型执行不同的处理逻辑
result = None
if task.task_type == "store":
result = await self._handle_store_task(task)
elif task.task_type == "retrieve":
result = await self._handle_retrieve_task(task)
elif task.task_type == "build":
result = await self._handle_build_task(task)
else:
raise ValueError(f"未知的任务类型: {task.task_type}")
# 记录完成的任务
self.completed_tasks[task.task_id] = result
execution_time = time.time() - start_time
logger.debug(f"[{worker_name}] 任务完成: {task.task_id} (耗时: {execution_time:.2f}s)")
# 执行回调函数
if task.callback:
try:
if asyncio.iscoroutinefunction(task.callback):
await task.callback(result)
else:
task.callback(result)
except Exception as e:
logger.error(f"任务回调执行失败: {e}")
except Exception as e:
error_msg = f"任务执行失败: {e}"
logger.error(f"[{worker_name}] {error_msg}")
self.failed_tasks[task.task_id] = error_msg
# 执行错误回调
if task.callback:
try:
if asyncio.iscoroutinefunction(task.callback):
await task.callback(None)
else:
task.callback(None)
except:
pass
async def _handle_store_task(self, task: MemoryTask) -> Any:
"""处理记忆存储任务"""
# 这里需要根据具体的记忆系统来实现
# 为了避免循环导入,这里使用延迟导入
try:
from src.chat.memory_system.instant_memory import InstantMemory
instant_memory = InstantMemory(task.chat_id)
await instant_memory.create_and_store_memory(task.content)
return True
except Exception as e:
logger.error(f"记忆存储失败: {e}")
return False
async def _handle_retrieve_task(self, task: MemoryTask) -> Any:
"""处理记忆检索任务"""
try:
from src.chat.memory_system.instant_memory import InstantMemory
instant_memory = InstantMemory(task.chat_id)
memories = await instant_memory.get_memory(task.content)
return memories or []
except Exception as e:
logger.error(f"记忆检索失败: {e}")
return []
async def _handle_build_task(self, task: MemoryTask) -> Any:
"""处理记忆构建任务(海马体系统)"""
try:
# 延迟导入避免循环依赖
if global_config.memory.enable_memory:
from src.chat.memory_system.Hippocampus import hippocampus_manager
if hippocampus_manager._initialized:
await hippocampus_manager.build_memory()
return True
return False
except Exception as e:
logger.error(f"记忆构建失败: {e}")
return False
async def add_task(self, task: MemoryTask) -> str:
"""添加任务到队列"""
await self.task_queue.put(task)
self.running_tasks[task.task_id] = task
logger.debug(f"任务已加入队列: {task.task_type} - {task.task_id}")
return task.task_id
def get_task_result(self, task_id: str) -> Optional[Any]:
"""获取任务结果(非阻塞)"""
return self.completed_tasks.get(task_id)
def is_task_completed(self, task_id: str) -> bool:
"""检查任务是否完成"""
return task_id in self.completed_tasks or task_id in self.failed_tasks
def get_queue_status(self) -> Dict[str, Any]:
"""获取队列状态"""
return {
"is_running": self.is_running,
"queue_size": self.task_queue.qsize(),
"running_tasks": len(self.running_tasks),
"completed_tasks": len(self.completed_tasks),
"failed_tasks": len(self.failed_tasks),
"worker_count": len(self.worker_tasks)
}
class NonBlockingMemoryManager:
"""非阻塞记忆管理器"""
def __init__(self):
self.queue = AsyncMemoryQueue(max_workers=3)
self.cache: Dict[str, Any] = {}
self.cache_ttl: Dict[str, float] = {}
self.cache_timeout = 300 # 缓存5分钟
async def initialize(self):
"""初始化管理器"""
await self.queue.start()
logger.info("非阻塞记忆管理器已初始化")
async def shutdown(self):
"""关闭管理器"""
await self.queue.stop()
logger.info("非阻塞记忆管理器已关闭")
async def store_memory_async(self, chat_id: str, content: str,
callback: Optional[Callable] = None) -> str:
"""异步存储记忆(非阻塞)"""
task = MemoryTask(
task_id=f"store_{chat_id}_{int(time.time() * 1000)}",
task_type="store",
chat_id=chat_id,
content=content,
priority=1, # 存储优先级较低
callback=callback
)
return await self.queue.add_task(task)
async def retrieve_memory_async(self, chat_id: str, query: str,
callback: Optional[Callable] = None) -> str:
"""异步检索记忆(非阻塞)"""
# 先检查缓存
cache_key = f"retrieve_{chat_id}_{hash(query)}"
if self._is_cache_valid(cache_key):
result = self.cache[cache_key]
if callback:
if asyncio.iscoroutinefunction(callback):
await callback(result)
else:
callback(result)
return "cache_hit"
task = MemoryTask(
task_id=f"retrieve_{chat_id}_{int(time.time() * 1000)}",
task_type="retrieve",
chat_id=chat_id,
content=query,
priority=2, # 检索优先级中等
callback=self._create_cache_callback(cache_key, callback)
)
return await self.queue.add_task(task)
async def build_memory_async(self, callback: Optional[Callable] = None) -> str:
"""异步构建记忆(非阻塞)"""
task = MemoryTask(
task_id=f"build_memory_{int(time.time() * 1000)}",
task_type="build",
chat_id="system",
content="",
priority=1, # 构建优先级较低,避免影响用户体验
callback=callback
)
return await self.queue.add_task(task)
def _is_cache_valid(self, cache_key: str) -> bool:
"""检查缓存是否有效"""
if cache_key not in self.cache:
return False
return time.time() - self.cache_ttl.get(cache_key, 0) < self.cache_timeout
def _create_cache_callback(self, cache_key: str, original_callback: Optional[Callable]):
"""创建带缓存的回调函数"""
async def cache_callback(result):
# 存储到缓存
if result is not None:
self.cache[cache_key] = result
self.cache_ttl[cache_key] = time.time()
# 执行原始回调
if original_callback:
if asyncio.iscoroutinefunction(original_callback):
await original_callback(result)
else:
original_callback(result)
return cache_callback
def get_cached_memory(self, chat_id: str, query: str) -> Optional[Any]:
"""获取缓存的记忆(同步,立即返回)"""
cache_key = f"retrieve_{chat_id}_{hash(query)}"
if self._is_cache_valid(cache_key):
return self.cache[cache_key]
return None
def get_status(self) -> Dict[str, Any]:
"""获取管理器状态"""
status = self.queue.get_queue_status()
status.update({
"cache_entries": len(self.cache),
"cache_timeout": self.cache_timeout
})
return status
# 全局实例
async_memory_manager = NonBlockingMemoryManager()
# 便捷函数
async def store_memory_nonblocking(chat_id: str, content: str) -> str:
"""非阻塞存储记忆的便捷函数"""
return await async_memory_manager.store_memory_async(chat_id, content)
async def retrieve_memory_nonblocking(chat_id: str, query: str) -> Optional[Any]:
"""非阻塞检索记忆的便捷函数,支持缓存"""
# 先尝试从缓存获取
cached_result = async_memory_manager.get_cached_memory(chat_id, query)
if cached_result is not None:
return cached_result
# 缓存未命中,启动异步检索
await async_memory_manager.retrieve_memory_async(chat_id, query)
return None # 返回None表示需要异步获取
async def build_memory_nonblocking() -> str:
"""非阻塞构建记忆的便捷函数"""
return await async_memory_manager.build_memory_async()