feat(memory): implement the enhanced memory system and fully replace the legacy architecture
Introduce the new enhanced memory system, fully replacing the hippocampus-based memory architecture.

- Remove the legacy memory system modules, including Hippocampus, the async wrapper, and the optimizer
- Rework the message-handling flow to integrate enhanced-memory storage and retrieval
- Update the configuration structure to support the enhanced memory system's parameters
- Disable the old scheduled tasks in favor of the built-in maintenance mechanism to preserve performance
src/chat/memory_system/async_instant_memory_wrapper.py (deleted)
@@ -1,248 +0,0 @@
# -*- coding: utf-8 -*-
"""
Async instant-memory wrapper.

Provides an async wrapper around the existing instant-memory systems,
with timeout control and fallback handling.
"""

import asyncio
import time
from typing import Optional, Dict, Any

from src.common.logger import get_logger
from src.config.config import global_config

logger = get_logger("async_instant_memory_wrapper")


class AsyncInstantMemoryWrapper:
    """Async wrapper around the instant-memory systems."""

    def __init__(self, chat_id: str):
        self.chat_id = chat_id
        self.llm_memory = None
        self.vector_memory = None
        self.cache: Dict[str, tuple[Any, float]] = {}  # cache: (result, timestamp)
        self.cache_ttl = 300  # cache entries live for 5 minutes
        self.default_timeout = 3.0  # default timeout: 3 seconds

        # Read from config whether each memory system is enabled
        self.llm_memory_enabled = global_config.memory.enable_llm_instant_memory
        self.vector_memory_enabled = global_config.memory.enable_vector_instant_memory

    async def _ensure_llm_memory(self):
        """Make sure the LLM memory system is initialized."""
        if self.llm_memory is None and self.llm_memory_enabled:
            try:
                from src.chat.memory_system.instant_memory import InstantMemory

                self.llm_memory = InstantMemory(self.chat_id)
                logger.info(f"LLM instant memory initialized: {self.chat_id}")
            except Exception as e:
                logger.warning(f"LLM instant memory initialization failed: {e}")
                self.llm_memory_enabled = False  # disable on init failure

    async def _ensure_vector_memory(self):
        """Make sure the vector memory system is initialized."""
        if self.vector_memory is None and self.vector_memory_enabled:
            try:
                from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2

                self.vector_memory = VectorInstantMemoryV2(self.chat_id)
                logger.info(f"Vector instant memory initialized: {self.chat_id}")
            except Exception as e:
                logger.warning(f"Vector instant memory initialization failed: {e}")
                self.vector_memory_enabled = False  # disable on init failure

    def _get_cache_key(self, operation: str, content: str) -> str:
        """Build a cache key."""
        return f"{operation}_{self.chat_id}_{hash(content)}"

    def _is_cache_valid(self, cache_key: str) -> bool:
        """Check whether a cache entry is still valid."""
        if cache_key not in self.cache:
            return False

        _, timestamp = self.cache[cache_key]
        return time.time() - timestamp < self.cache_ttl

    def _get_cached_result(self, cache_key: str) -> Optional[Any]:
        """Return a cached result, if still valid."""
        if self._is_cache_valid(cache_key):
            result, _ = self.cache[cache_key]
            return result
        return None

    def _cache_result(self, cache_key: str, result: Any):
        """Store a result in the cache."""
        self.cache[cache_key] = (result, time.time())

    async def store_memory_async(self, content: str, timeout: Optional[float] = None) -> bool:
        """Store a memory asynchronously (with timeout control)."""
        if timeout is None:
            timeout = self.default_timeout

        success_count = 0

        # Store into the LLM memory system
        await self._ensure_llm_memory()
        if self.llm_memory:
            try:
                await asyncio.wait_for(self.llm_memory.create_and_store_memory(content), timeout=timeout)
                success_count += 1
                logger.debug(f"LLM memory store succeeded: {content[:50]}...")
            except asyncio.TimeoutError:
                logger.warning(f"LLM memory store timed out: {content[:50]}...")
            except Exception as e:
                logger.error(f"LLM memory store failed: {e}")

        # Store into the vector memory system
        await self._ensure_vector_memory()
        if self.vector_memory:
            try:
                await asyncio.wait_for(self.vector_memory.store_message(content), timeout=timeout)
                success_count += 1
                logger.debug(f"Vector memory store succeeded: {content[:50]}...")
            except asyncio.TimeoutError:
                logger.warning(f"Vector memory store timed out: {content[:50]}...")
            except Exception as e:
                logger.error(f"Vector memory store failed: {e}")

        return success_count > 0

    async def retrieve_memory_async(
        self, query: str, timeout: Optional[float] = None, use_cache: bool = True
    ) -> Optional[Any]:
        """Retrieve memories asynchronously (with caching and timeout control)."""
        if timeout is None:
            timeout = self.default_timeout

        # Check the cache first
        if use_cache:
            cache_key = self._get_cache_key("retrieve", query)
            cached_result = self._get_cached_result(cache_key)
            if cached_result is not None:
                logger.debug(f"Memory retrieval cache hit: {query[:30]}...")
                return cached_result

        # Try the available memory systems
        results = []

        # Query the vector memory system first (faster)
        await self._ensure_vector_memory()
        if self.vector_memory:
            try:
                vector_result = await asyncio.wait_for(
                    self.vector_memory.get_memory_for_context(query),
                    timeout=timeout * 0.6,  # give the vector system 60% of the budget
                )
                if vector_result:
                    results.append(vector_result)
                    logger.debug(f"Vector memory retrieval succeeded: {query[:30]}...")
            except asyncio.TimeoutError:
                logger.warning(f"Vector memory retrieval timed out: {query[:30]}...")
            except Exception as e:
                logger.error(f"Vector memory retrieval failed: {e}")

        # Fall back to the LLM memory system (more accurate but slower)
        await self._ensure_llm_memory()
        if self.llm_memory and len(results) == 0:  # only when vector retrieval returned nothing
            try:
                llm_result = await asyncio.wait_for(
                    self.llm_memory.get_memory(query),
                    timeout=timeout * 0.4,  # give the LLM system 40% of the budget
                )
                if llm_result:
                    results.extend(llm_result)
                    logger.debug(f"LLM memory retrieval succeeded: {query[:30]}...")
            except asyncio.TimeoutError:
                logger.warning(f"LLM memory retrieval timed out: {query[:30]}...")
            except Exception as e:
                logger.error(f"LLM memory retrieval failed: {e}")

        # Merge the results
        final_result = None
        if results:
            if len(results) == 1:
                final_result = results[0]
            else:
                # Merge multiple results
                if isinstance(results[0], str):
                    final_result = "\n".join(str(r) for r in results)
                elif isinstance(results[0], list):
                    final_result = []
                    for r in results:
                        if isinstance(r, list):
                            final_result.extend(r)
                        else:
                            final_result.append(r)
                else:
                    final_result = results[0]  # fall back to the first result

        # Cache the result
        if use_cache and final_result is not None:
            cache_key = self._get_cache_key("retrieve", query)
            self._cache_result(cache_key, final_result)

        return final_result

    async def get_memory_with_fallback(self, query: str, max_timeout: float = 2.0) -> str:
        """Retrieve memories with a fallback; guaranteed not to block for long."""
        try:
            # Try a fast retrieval first
            result = await self.retrieve_memory_async(query, timeout=max_timeout)

            if result:
                if isinstance(result, list):
                    return "\n".join(str(item) for item in result)
                return str(result)

            return ""

        except Exception as e:
            logger.error(f"Memory retrieval failed completely: {e}")
            return ""

    def store_memory_background(self, content: str):
        """Store a memory in the background (fire-and-forget)."""

        async def background_store():
            try:
                await self.store_memory_async(content, timeout=10.0)  # background tasks may use a longer timeout
            except Exception as e:
                logger.error(f"Background memory store failed: {e}")

        # Spawn the background task
        asyncio.create_task(background_store())

    def get_status(self) -> Dict[str, Any]:
        """Return the wrapper's status."""
        return {
            "chat_id": self.chat_id,
            "llm_memory_available": self.llm_memory is not None,
            "vector_memory_available": self.vector_memory is not None,
            "cache_entries": len(self.cache),
            "cache_ttl": self.cache_ttl,
            "default_timeout": self.default_timeout,
        }

    def clear_cache(self):
        """Clear the cache."""
        self.cache.clear()
        logger.info(f"Memory cache cleared: {self.chat_id}")


# Cache wrapper instances to avoid repeated construction
_wrapper_cache: Dict[str, AsyncInstantMemoryWrapper] = {}


def get_async_instant_memory(chat_id: str) -> AsyncInstantMemoryWrapper:
    """Return the async instant-memory wrapper for a chat."""
    if chat_id not in _wrapper_cache:
        _wrapper_cache[chat_id] = AsyncInstantMemoryWrapper(chat_id)
    return _wrapper_cache[chat_id]


def clear_wrapper_cache():
    """Clear the wrapper cache."""
    global _wrapper_cache
    _wrapper_cache.clear()
    logger.info("Async instant-memory wrapper cache cleared")
src/chat/memory_system/async_memory_optimizer.py (deleted)
@@ -1,358 +0,0 @@
# -*- coding: utf-8 -*-
"""
Async memory system optimizer.

Stops the memory system from blocking the main program by turning
synchronous operations into asynchronous, non-blocking ones.
"""

import asyncio
import time
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor

from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory

logger = get_logger("async_memory_optimizer")


@dataclass
class MemoryTask:
    """Memory task record."""

    task_id: str
    task_type: str  # "store", "retrieve", "build"
    chat_id: str
    content: str
    priority: int = 1  # 1=low, 2=medium, 3=high
    callback: Optional[Callable] = None
    created_at: Optional[float] = None

    def __post_init__(self):
        if self.created_at is None:
            self.created_at = time.time()


class AsyncMemoryQueue:
    """Async memory task queue manager."""

    def __init__(self, max_workers: int = 3):
        self.max_workers = max_workers
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.task_queue = asyncio.Queue()
        self.running_tasks: Dict[str, MemoryTask] = {}
        self.completed_tasks: Dict[str, Any] = {}
        self.failed_tasks: Dict[str, str] = {}
        self.is_running = False
        self.worker_tasks: List[asyncio.Task] = []

    async def start(self):
        """Start the async queue processor."""
        if self.is_running:
            return

        self.is_running = True
        # Spawn the worker coroutines
        for i in range(self.max_workers):
            worker = asyncio.create_task(self._worker(f"worker-{i}"))
            self.worker_tasks.append(worker)

        logger.info(f"Async memory queue started with {self.max_workers} workers")

    async def stop(self):
        """Stop the queue processor."""
        self.is_running = False

        # Cancel all workers and wait for them to finish
        for task in self.worker_tasks:
            task.cancel()

        await asyncio.gather(*self.worker_tasks, return_exceptions=True)
        self.executor.shutdown(wait=True)
        logger.info("Async memory queue stopped")

    async def _worker(self, worker_name: str):
        """Worker coroutine that processes queued tasks."""
        logger.info(f"Memory worker {worker_name} started")

        while self.is_running:
            try:
                # Wait for a task; a 1-second timeout avoids blocking forever
                task = await asyncio.wait_for(self.task_queue.get(), timeout=1.0)

                # Run the task
                await self._execute_task(task, worker_name)

            except asyncio.TimeoutError:
                # Timeouts are expected; go around the loop again
                continue
            except Exception as e:
                logger.error(f"Worker {worker_name} failed while processing a task: {e}")

    async def _execute_task(self, task: MemoryTask, worker_name: str):
        """Execute a single memory task."""
        try:
            logger.debug(f"[{worker_name}] Starting task: {task.task_type} - {task.task_id}")
            start_time = time.time()

            # Dispatch on the task type
            result = None
            if task.task_type == "store":
                result = await self._handle_store_task(task)
            elif task.task_type == "retrieve":
                result = await self._handle_retrieve_task(task)
            elif task.task_type == "build":
                result = await self._handle_build_task(task)
            else:
                raise ValueError(f"Unknown task type: {task.task_type}")

            # Record the completed task
            self.completed_tasks[task.task_id] = result
            execution_time = time.time() - start_time

            logger.debug(f"[{worker_name}] Task finished: {task.task_id} (took {execution_time:.2f}s)")

            # Invoke the callback
            if task.callback:
                try:
                    if asyncio.iscoroutinefunction(task.callback):
                        await task.callback(result)
                    else:
                        task.callback(result)
                except Exception as e:
                    logger.error(f"Task callback failed: {e}")

        except Exception as e:
            error_msg = f"Task execution failed: {e}"
            logger.error(f"[{worker_name}] {error_msg}")
            self.failed_tasks[task.task_id] = error_msg

            # Invoke the callback on error as well
            if task.callback:
                try:
                    if asyncio.iscoroutinefunction(task.callback):
                        await task.callback(None)
                    else:
                        task.callback(None)
                except Exception:
                    pass

    @staticmethod
    async def _handle_store_task(task: MemoryTask) -> Any:
        """Handle a memory store task."""
        # The implementation depends on the concrete memory system;
        # import lazily to avoid circular imports.
        try:
            # Get the wrapper instance
            memory_wrapper = get_async_instant_memory(task.chat_id)

            # Use the wrapper's llm_memory instance
            if memory_wrapper and memory_wrapper.llm_memory:
                await memory_wrapper.llm_memory.create_and_store_memory(task.content)
                return True
            else:
                logger.warning(f"No memory system instance available; store task failed: chat_id={task.chat_id}")
                return False
        except Exception as e:
            logger.error(f"Memory store failed: {e}")
            return False

    @staticmethod
    async def _handle_retrieve_task(task: MemoryTask) -> Any:
        """Handle a memory retrieval task."""
        try:
            # Get the wrapper instance
            memory_wrapper = get_async_instant_memory(task.chat_id)

            # Use the wrapper's llm_memory instance
            if memory_wrapper and memory_wrapper.llm_memory:
                memories = await memory_wrapper.llm_memory.get_memory(task.content)
                return memories or []
            else:
                logger.warning(f"No memory system instance available; retrieval task failed: chat_id={task.chat_id}")
                return []
        except Exception as e:
            logger.error(f"Memory retrieval failed: {e}")
            return []

    @staticmethod
    async def _handle_build_task(task: MemoryTask) -> Any:
        """Handle a memory build task (hippocampus system)."""
        try:
            # Import lazily to avoid circular dependencies
            if global_config.memory.enable_memory:
                from src.chat.memory_system.Hippocampus import hippocampus_manager

                if hippocampus_manager._initialized:
                    # Make sure the hippocampus object is fully initialized
                    if not hippocampus_manager._hippocampus.parahippocampal_gyrus:
                        logger.warning("Hippocampus object not fully initialized; initializing synchronously")
                        hippocampus_manager._hippocampus.initialize()

                    await hippocampus_manager.build_memory()
                    return True
            return False
        except Exception as e:
            logger.error(f"Memory build failed: {e}")
            return False

    async def add_task(self, task: MemoryTask) -> str:
        """Add a task to the queue."""
        await self.task_queue.put(task)
        self.running_tasks[task.task_id] = task
        logger.debug(f"Task enqueued: {task.task_type} - {task.task_id}")
        return task.task_id

    def get_task_result(self, task_id: str) -> Optional[Any]:
        """Get a task's result (non-blocking)."""
        return self.completed_tasks.get(task_id)

    def is_task_completed(self, task_id: str) -> bool:
        """Check whether a task has finished."""
        return task_id in self.completed_tasks or task_id in self.failed_tasks

    def get_queue_status(self) -> Dict[str, Any]:
        """Return the queue status."""
        return {
            "is_running": self.is_running,
            "queue_size": self.task_queue.qsize(),
            "running_tasks": len(self.running_tasks),
            "completed_tasks": len(self.completed_tasks),
            "failed_tasks": len(self.failed_tasks),
            "worker_count": len(self.worker_tasks),
        }


class NonBlockingMemoryManager:
    """Non-blocking memory manager."""

    def __init__(self):
        self.queue = AsyncMemoryQueue(max_workers=3)
        self.cache: Dict[str, Any] = {}
        self.cache_ttl: Dict[str, float] = {}
        self.cache_timeout = 300  # cache entries live for 5 minutes

    async def initialize(self):
        """Initialize the manager."""
        await self.queue.start()
        logger.info("Non-blocking memory manager initialized")

    async def shutdown(self):
        """Shut the manager down."""
        await self.queue.stop()
        logger.info("Non-blocking memory manager shut down")

    async def store_memory_async(self, chat_id: str, content: str, callback: Optional[Callable] = None) -> str:
        """Store a memory asynchronously (non-blocking)."""
        task = MemoryTask(
            task_id=f"store_{chat_id}_{int(time.time() * 1000)}",
            task_type="store",
            chat_id=chat_id,
            content=content,
            priority=1,  # stores are low priority
            callback=callback,
        )

        return await self.queue.add_task(task)

    async def retrieve_memory_async(self, chat_id: str, query: str, callback: Optional[Callable] = None) -> str:
        """Retrieve memories asynchronously (non-blocking)."""
        # Check the cache first
        cache_key = f"retrieve_{chat_id}_{hash(query)}"
        if self._is_cache_valid(cache_key):
            result = self.cache[cache_key]
            if callback:
                if asyncio.iscoroutinefunction(callback):
                    await callback(result)
                else:
                    callback(result)
            return "cache_hit"

        task = MemoryTask(
            task_id=f"retrieve_{chat_id}_{int(time.time() * 1000)}",
            task_type="retrieve",
            chat_id=chat_id,
            content=query,
            priority=2,  # retrievals are medium priority
            callback=self._create_cache_callback(cache_key, callback),
        )

        return await self.queue.add_task(task)

    async def build_memory_async(self, callback: Optional[Callable] = None) -> str:
        """Build memories asynchronously (non-blocking)."""
        task = MemoryTask(
            task_id=f"build_memory_{int(time.time() * 1000)}",
            task_type="build",
            chat_id="system",
            content="",
            priority=1,  # builds are low priority so they do not hurt the user experience
            callback=callback,
        )

        return await self.queue.add_task(task)

    def _is_cache_valid(self, cache_key: str) -> bool:
        """Check whether a cache entry is still valid."""
        if cache_key not in self.cache:
            return False

        return time.time() - self.cache_ttl.get(cache_key, 0) < self.cache_timeout

    def _create_cache_callback(self, cache_key: str, original_callback: Optional[Callable]):
        """Wrap a callback so that results are also written to the cache."""

        async def cache_callback(result):
            # Store the result in the cache
            if result is not None:
                self.cache[cache_key] = result
                self.cache_ttl[cache_key] = time.time()

            # Invoke the original callback
            if original_callback:
                if asyncio.iscoroutinefunction(original_callback):
                    await original_callback(result)
                else:
                    original_callback(result)

        return cache_callback

    def get_cached_memory(self, chat_id: str, query: str) -> Optional[Any]:
        """Get a cached memory (synchronous, returns immediately)."""
        cache_key = f"retrieve_{chat_id}_{hash(query)}"
        if self._is_cache_valid(cache_key):
            return self.cache[cache_key]
        return None

    def get_status(self) -> Dict[str, Any]:
        """Return the manager status."""
        status = self.queue.get_queue_status()
        status.update({"cache_entries": len(self.cache), "cache_timeout": self.cache_timeout})
        return status


# Global instance
async_memory_manager = NonBlockingMemoryManager()


# Convenience functions
async def store_memory_nonblocking(chat_id: str, content: str) -> str:
    """Convenience function for non-blocking memory storage."""
    return await async_memory_manager.store_memory_async(chat_id, content)


async def retrieve_memory_nonblocking(chat_id: str, query: str) -> Optional[Any]:
    """Convenience function for non-blocking, cache-aware memory retrieval."""
    # Try the cache first
    cached_result = async_memory_manager.get_cached_memory(chat_id, query)
    if cached_result is not None:
        return cached_result

    # Cache miss: kick off an async retrieval
    await async_memory_manager.retrieve_memory_async(chat_id, query)
    return None  # None means the result will arrive asynchronously


async def build_memory_nonblocking() -> str:
    """Convenience function for non-blocking memory building."""
    return await async_memory_manager.build_memory_async()
@@ -1,196 +0,0 @@
# Memory System Async Optimization Notes

## 🎯 Optimization Goals

Stop the MaiBot-Plus memory system from blocking the main program by replacing the original linear, synchronous call path with asynchronous, non-blocking execution.

## ⚠️ Problem Analysis

### Original problems
1. **Instant memory blocks the reply path**: on every user message, `await self.instant_memory.get_memory_for_context(target)` blocks while waiting for the LLM response
2. **Scheduled memory builds block everything**: the `build_memory_task()` that runs every 600 seconds freezes the main program for tens of seconds
3. **Blocking LLM call chain**: both memory storage and retrieval call the LLM, adding significant latency

### Symptoms
- Noticeably higher response latency after a user sends a message
- The whole program becomes unresponsive during scheduled memory builds
- Under high concurrency, the memory system becomes the performance bottleneck

## 🚀 Optimization Approach

### 1. Async memory queue (`async_memory_optimizer.py`)

**Core idea**: push memory operations onto an async queue and process them in the background, off the main program's critical path.

**Key features**:
- Task queue management: supports store, retrieve, and build task types
- Priority scheduling: high-priority tasks (user queries) are handled first
- Thread-pool execution: keeps the event loop unblocked
- Result caching: avoids repeated computation
- Failure retry: improves reliability

```python
# Usage example
from src.chat.memory_system.async_memory_optimizer import (
    store_memory_nonblocking,
    retrieve_memory_nonblocking,
    build_memory_nonblocking
)

# Store a memory without blocking
task_id = await store_memory_nonblocking(chat_id, content)

# Retrieve memories without blocking (cache-aware)
memories = await retrieve_memory_nonblocking(chat_id, query)

# Build memories without blocking
task_id = await build_memory_nonblocking()
```

### 2. Async instant-memory wrapper (`async_instant_memory_wrapper.py`)

**Core idea**: give the existing instant-memory systems an async wrapper with timeout control and multi-level fallbacks.

**Key features**:
- Timeout control: prevents long blocking waits
- Caching: hot queries get fast responses
- Multi-system fusion: LLM memory + vector memory
- Fallback strategy: keeps the system stable
- Background storage: store operations never block

```python
# Usage example
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory

async_memory = get_async_instant_memory(chat_id)

# Background store (fire-and-forget)
async_memory.store_memory_background(content)

# Fast retrieval (with timeout)
result = await async_memory.get_memory_with_fallback(query, max_timeout=2.0)
```

### 3. Main program changes

**Memory builds made asynchronous**:
- Before: `await self.hippocampus_manager.build_memory()` blocked the main program
- After: executed in the background via the async queue or a thread pool

**Message handling**:
- Before: waited synchronously for memory retrieval to finish
- After: capped at a 2-second timeout to protect the user experience (see the sketch below)
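A minimal sketch of these two call-site changes, assuming an already-running event loop (the function names below are illustrative, not the actual handler names in the main program):

```python
import asyncio

from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory


def schedule_memory_build(hippocampus_manager) -> None:
    # Before: `await hippocampus_manager.build_memory()` stalled the event loop.
    # After: submit the build as a fire-and-forget background task.
    asyncio.create_task(hippocampus_manager.build_memory())


async def fetch_memory_for_reply(chat_id: str, query: str) -> str:
    # Cap retrieval at 2 seconds so a slow memory backend cannot stall
    # the reply path; get_memory_with_fallback returns "" on any failure.
    async_memory = get_async_instant_memory(chat_id)
    return await async_memory.get_memory_with_fallback(query, max_timeout=2.0)
```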
## 📊 Expected Performance Gains

### Response speed
- **User message response**: from 3-10 seconds down to 0.5-2 seconds
- **Memory retrieval**: near-instant on cache hits
- **Memory storage**: moved from synchronous blocking to background processing

### Concurrency
- **Multiple simultaneous users**: no longer block one another through the memory system
- **Peak-hour stability**: memory tasks queue up instead of crashing the process

### Resource usage
- **CPU**: async processing gives better CPU utilization
- **Memory**: caching reduces repeated computation
- **Network latency**: LLM calls run in parallel, cutting wait time

## 🔧 Deployment and Configuration

### 1. Automatic deployment
The new async system is already integrated into the existing code, falling back automatically in this priority order:

1. Async instant-memory wrapper (best)
2. Async memory manager (second choice)
3. Synchronous mode with a timeout (last resort)

### 2. Configuration parameters

The relevant parameters can be tuned in `config.toml`:

```toml
[memory]
enable_memory = true
enable_instant_memory = true
memory_build_interval = 600  # memory build interval (seconds)
```

### 3. Monitoring and debugging

```python
# Inspect the async queue status
from src.chat.memory_system.async_memory_optimizer import async_memory_manager
status = async_memory_manager.get_status()
print(status)

# Inspect the wrapper status
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
wrapper = get_async_instant_memory(chat_id)
status = wrapper.get_status()
print(status)
```

## 🧪 Verification

### 1. Performance testing
```bash
# Measure message response time
time curl -X POST "http://localhost:8080/api/message" -d '{"message": "Do you remember what we talked about yesterday?"}'

# Watch program responsiveness during a memory build:
# send messages while a build is running and check for blocking
```

### 2. Concurrency testing
```python
import asyncio
import time

async def test_concurrent_messages():
    """Test concurrent message handling."""
    tasks = []
    for i in range(10):
        # send_message is a placeholder for the project's message entry point
        task = asyncio.create_task(send_message(f"test message {i}"))
        tasks.append(task)

    start_time = time.time()
    results = await asyncio.gather(*tasks)
    end_time = time.time()

    print(f"10 concurrent messages handled in {end_time - start_time:.2f}s")
```

### 3. Log monitoring
Watch for the following (Chinese) log lines:
- `"异步瞬时记忆:"` (async instant memory) — confirms the async system is in use
- `"记忆构建任务已提交"` (memory build task submitted) — confirms build tasks are non-blocking
- `"瞬时记忆检索超时"` (instant memory retrieval timed out) — monitors timeouts

## 🔄 Fallback Mechanism

The system has multiple fallback layers so that basic functionality survives even if the new system fails (see the sketch after this list):

1. **Async wrapper fails** → use the async queue manager
2. **Async queue fails** → use synchronous mode with a timeout
3. **Timeout guard** → never wait longer than 2 seconds
4. **Total failure** → skip memory entirely and keep basic conversation working
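A condensed sketch of how the chain fits together (the helper below is illustrative; in the real code these layers are spread across the message handlers):

```python
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
from src.chat.memory_system.async_memory_optimizer import async_memory_manager


async def retrieve_with_fallbacks(chat_id: str, query: str) -> str:
    # Layer 1: async instant-memory wrapper (preferred; hard 2s cap).
    try:
        wrapper = get_async_instant_memory(chat_id)
        result = await wrapper.get_memory_with_fallback(query, max_timeout=2.0)
        if result:
            return result
    except Exception:
        pass  # fall through to the next layer

    # Layer 2: async queue manager; serve from cache if possible,
    # otherwise schedule a background retrieval for next time.
    try:
        cached = async_memory_manager.get_cached_memory(chat_id, query)
        if cached:
            return str(cached)
        await async_memory_manager.retrieve_memory_async(chat_id, query)
    except Exception:
        pass

    # Layers 3-4: a timeout-guarded synchronous path would go here;
    # on total failure, skip memory and keep the conversation going.
    return ""
```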
## 📝 Notes

1. **First startup**: the async system needs time to initialize, so the first few memory calls may be slightly slower
2. **Cache warm-up**: response times improve noticeably once the cache has warmed up after some runtime
3. **Memory usage**: caching increases memory use, a worthwhile trade for the performance gain
4. **Compatibility**: if the async system misbehaves, the related imports can be disabled temporarily to fall back to the original system

## 🎉 Expected Results

- ✅ **Message response speed up by 60%+**
- ✅ **Memory builds no longer block the main program**
- ✅ **Higher concurrent user capacity**
- ✅ **Better overall system stability**
- ✅ **Full preservation of existing memory functionality**
src/chat/memory_system/enhanced_memory_activator.py (new file, 237 lines)
@@ -0,0 +1,237 @@
# -*- coding: utf-8 -*-
"""
Enhanced memory activator.

Replaces the original MemoryActivator, using the enhanced memory system.
"""

import difflib
import orjson
import time
from typing import List, Dict, Optional
from datetime import datetime

from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager, EnhancedMemoryResult

logger = get_logger("enhanced_memory_activator")


def get_keywords_from_json(json_str) -> List:
    """
    Extract the keyword list from a JSON string.

    Args:
        json_str: a JSON-formatted string

    Returns:
        List[str]: the extracted keywords
    """
    try:
        # Repair malformed JSON first
        fixed_json = repair_json(json_str)

        # repair_json may return a string, which still needs parsing
        result = orjson.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
        return result.get("keywords", [])
    except Exception as e:
        logger.error(f"Failed to parse keyword JSON: {e}")
        return []


def init_prompt():
    # --- Enhanced Memory Activator Prompt ---
    # The prompt payload stays in Chinese: it is sent verbatim to the LLM.
    enhanced_memory_activator_prompt = """
你是一个增强记忆分析器,你需要根据以下信息来进行记忆检索

以下是一段聊天记录,请根据这些信息,总结出几个关键词作为记忆检索的触发词

聊天记录:
{obs_info_text}

用户想要回复的消息:
{target_message}

历史关键词(请避免重复提取这些关键词):
{cached_keywords}

请输出一个json格式,包含以下字段:
{{
    "keywords": ["关键词1", "关键词2", "关键词3",......]
}}

不要输出其他多余内容,只输出json格式就好
"""

    Prompt(enhanced_memory_activator_prompt, "enhanced_memory_activator_prompt")


class EnhancedMemoryActivator:
    """Enhanced memory activator - replaces the original MemoryActivator."""

    def __init__(self):
        self.key_words_model = LLMRequest(
            model_set=model_config.model_task_config.utils_small,
            request_type="enhanced_memory.activator",
        )

        self.running_memory = []
        self.cached_keywords = set()  # cache of previously extracted keywords
        self.last_enhanced_query_time = 0  # when the enhanced memory was last queried

    async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
        """
        Activate enhanced memories.
        """
        # If the memory system is disabled, return an empty list immediately
        if not global_config.memory.enable_memory:
            return []

        # Join the cached keywords into a string for the prompt
        # ("暂无历史关键词" = "no cached keywords yet"; kept Chinese to match the prompt)
        cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "暂无历史关键词"

        prompt = await global_prompt_manager.format_prompt(
            "enhanced_memory_activator_prompt",
            obs_info_text=chat_history_prompt,
            target_message=target_message,
            cached_keywords=cached_keywords_str,
        )

        # Generate keywords
        response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
            prompt, temperature=0.5
        )

        keywords = list(get_keywords_from_json(response))

        # Update the keyword cache
        if keywords:
            # Cap the cache size at roughly 10 keywords
            if len(self.cached_keywords) > 10:
                # Convert to a list and drop the oldest keywords
                cached_list = list(self.cached_keywords)
                self.cached_keywords = set(cached_list[-8:])

            # Add the new keywords to the cache
            self.cached_keywords.update(keywords)

            logger.debug(f"Enhanced memory keywords: {self.cached_keywords}")

        # Query the enhanced memory system for relevant memories
        enhanced_results = await self._query_enhanced_memory(keywords, target_message)

        # Process and merge the memory results
        if enhanced_results:
            for result in enhanced_results:
                # Skip memories whose content (near-)duplicates an existing one
                exists = any(
                    m["content"] == result.content or
                    difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
                    for m in self.running_memory
                )
                if not exists:
                    memory_entry = {
                        "topic": result.memory_type,
                        "content": result.content,
                        "timestamp": datetime.fromtimestamp(result.timestamp).isoformat(),
                        "duration": 1,
                        "confidence": result.confidence,
                        "importance": result.importance,
                        "source": result.source
                    }
                    self.running_memory.append(memory_entry)
                    logger.debug(f"Added new enhanced memory: {result.memory_type} - {result.content}")

        # On each activation, bump every memory's duration by 1; evict at 3
        for m in self.running_memory[:]:
            m["duration"] = m.get("duration", 1) + 1
        self.running_memory = [m for m in self.running_memory if m["duration"] < 3]

        # Limit concurrently loaded memories to the 5 most recent
        # (the enhanced system can handle more than the old one)
        if len(self.running_memory) > 5:
            self.running_memory = self.running_memory[-5:]

        return self.running_memory

    async def _query_enhanced_memory(self, keywords: List[str], query_text: str) -> List[EnhancedMemoryResult]:
        """Query the enhanced memory system."""
        try:
            # Make sure the enhanced memory manager is initialized
            if not enhanced_memory_manager.is_initialized:
                await enhanced_memory_manager.initialize()

            # Build the query context
            context = {
                "keywords": keywords,
                "query_intent": "conversation_response",
                "expected_memory_types": [
                    "personal_fact", "event", "preference", "opinion"
                ]
            }

            # Query the enhanced memory
            enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
                query_text=query_text,
                user_id="default_user",  # adjust to the real user ID where available
                context=context,
                limit=5
            )

            logger.debug(f"Enhanced memory query returned {len(enhanced_results)} results")
            return enhanced_results

        except Exception as e:
            logger.error(f"Enhanced memory query failed: {e}")
            return []

    async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]:
        """
        Get an instant memory - compatible with the original interface.
        """
        try:
            # Use the enhanced memory system to fetch relevant memories
            if not enhanced_memory_manager.is_initialized:
                await enhanced_memory_manager.initialize()

            context = {
                "query_intent": "instant_response",
                "chat_id": chat_id,
                "expected_memory_types": ["preference", "opinion", "personal_fact"]
            }

            enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
                query_text=target_message,
                user_id="default_user",
                context=context,
                limit=1
            )

            if enhanced_results:
                # Return the most relevant memory's content
                return enhanced_results[0].content

            return None

        except Exception as e:
            logger.error(f"Failed to get instant memory: {e}")
            return None

    def clear_cache(self):
        """Clear the caches."""
        self.cached_keywords.clear()
        self.running_memory.clear()
        logger.debug("Enhanced memory activator caches cleared")


# Global instance
enhanced_memory_activator = EnhancedMemoryActivator()


# Keep the original name for compatibility
MemoryActivator = EnhancedMemoryActivator


init_prompt()
src/chat/memory_system/enhanced_memory_adapter.py (new file, 332 lines)
@@ -0,0 +1,332 @@
# -*- coding: utf-8 -*-
"""
Enhanced memory system adapter.

Integrates the enhanced memory system into the existing MoFox Bot architecture.
"""

import asyncio
import time
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass

from src.common.logger import get_logger
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.llm_models.utils_model import LLMRequest

logger = get_logger(__name__)


@dataclass
class AdapterConfig:
    """Adapter configuration."""
    enable_enhanced_memory: bool = True
    integration_mode: str = "enhanced_only"  # replace, enhanced_only
    auto_migration: bool = True
    memory_value_threshold: float = 0.6
    fusion_threshold: float = 0.85
    max_retrieval_results: int = 10


class EnhancedMemoryAdapter:
    """Adapter for the enhanced memory system."""

    def __init__(self, llm_model: LLMRequest, config: Optional[AdapterConfig] = None):
        self.llm_model = llm_model
        self.config = config or AdapterConfig()
        self.integration_layer: Optional[MemoryIntegrationLayer] = None
        self._initialized = False

        # Statistics
        self.adapter_stats = {
            "total_processed": 0,
            "enhanced_used": 0,
            "legacy_used": 0,
            "hybrid_used": 0,
            "memories_created": 0,
            "memories_retrieved": 0,
            "average_processing_time": 0.0
        }

    async def initialize(self):
        """Initialize the adapter."""
        if self._initialized:
            return

        try:
            logger.info("🚀 Initializing the enhanced memory system adapter...")

            # Convert the config format
            integration_config = IntegrationConfig(
                mode=IntegrationMode(self.config.integration_mode),
                enable_enhanced_memory=self.config.enable_enhanced_memory,
                memory_value_threshold=self.config.memory_value_threshold,
                fusion_threshold=self.config.fusion_threshold,
                max_retrieval_results=self.config.max_retrieval_results,
                enable_learning=True  # enable the learning feature
            )

            # Create the integration layer
            self.integration_layer = MemoryIntegrationLayer(
                llm_model=self.llm_model,
                config=integration_config
            )

            # Initialize the integration layer
            await self.integration_layer.initialize()

            self._initialized = True
            logger.info("✅ Enhanced memory system adapter initialized")

        except Exception as e:
            logger.error(f"❌ Enhanced memory system adapter initialization failed: {e}", exc_info=True)
            # Disable enhanced memory if initialization fails
            self.config.enable_enhanced_memory = False

    async def process_conversation_memory(
        self,
        conversation_text: str,
        context: Dict[str, Any],
        user_id: str,
        timestamp: Optional[float] = None
    ) -> Dict[str, Any]:
        """Process conversation memory."""
        if not self._initialized or not self.config.enable_enhanced_memory:
            return {"success": False, "error": "Enhanced memory not available"}

        start_time = time.time()
        self.adapter_stats["total_processed"] += 1

        try:
            # Process the conversation through the integration layer
            result = await self.integration_layer.process_conversation(
                conversation_text, context, user_id, timestamp
            )

            # Update statistics
            processing_time = time.time() - start_time
            self._update_processing_stats(processing_time)

            if result["success"]:
                created_count = len(result.get("created_memories", []))
                self.adapter_stats["memories_created"] += created_count
                logger.debug(f"Conversation memory processed; created {created_count} memories")

            return result

        except Exception as e:
            logger.error(f"Failed to process conversation memory: {e}", exc_info=True)
            return {"success": False, "error": str(e)}

    async def retrieve_relevant_memories(
        self,
        query: str,
        user_id: str,
        context: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None
    ) -> List[MemoryChunk]:
        """Retrieve relevant memories."""
        if not self._initialized or not self.config.enable_enhanced_memory:
            return []

        try:
            limit = limit or self.config.max_retrieval_results
            memories = await self.integration_layer.retrieve_relevant_memories(
                query, user_id, context, limit
            )

            self.adapter_stats["memories_retrieved"] += len(memories)
            logger.debug(f"Retrieved {len(memories)} relevant memories")

            return memories

        except Exception as e:
            logger.error(f"Failed to retrieve relevant memories: {e}", exc_info=True)
            return []

    async def get_memory_context_for_prompt(
        self,
        query: str,
        user_id: str,
        context: Optional[Dict[str, Any]] = None,
        max_memories: int = 5
    ) -> str:
        """Build the memory context used in prompts."""
        memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)

        if not memories:
            return ""

        # Format the memories into a prompt-friendly form
        memory_context_parts = []
        for memory in memories:
            memory_context_parts.append(f"- {memory.text_content}")

        return "\n".join(memory_context_parts)

    async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
        """Get a summary of the enhanced memory system."""
        if not self._initialized or not self.config.enable_enhanced_memory:
            return {"available": False, "reason": "Not initialized or disabled"}

        try:
            # System status
            status = await self.integration_layer.get_system_status()

            # Adapter statistics
            adapter_stats = self.adapter_stats.copy()

            # Integration statistics
            integration_stats = self.integration_layer.get_integration_stats()

            return {
                "available": True,
                "system_status": status,
                "adapter_stats": adapter_stats,
                "integration_stats": integration_stats,
                "total_memories_created": adapter_stats["memories_created"],
                "total_memories_retrieved": adapter_stats["memories_retrieved"]
            }

        except Exception as e:
            logger.error(f"Failed to get enhanced memory summary: {e}", exc_info=True)
            return {"available": False, "error": str(e)}

    def _update_processing_stats(self, processing_time: float):
        """Update the running average of processing time."""
        total_processed = self.adapter_stats["total_processed"]
        if total_processed > 0:
            current_avg = self.adapter_stats["average_processing_time"]
            new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed
            self.adapter_stats["average_processing_time"] = new_avg

    def get_adapter_stats(self) -> Dict[str, Any]:
        """Return the adapter statistics."""
        return self.adapter_stats.copy()

    async def maintenance(self):
        """Run maintenance."""
        if not self._initialized:
            return

        try:
            logger.info("🔧 Running enhanced memory adapter maintenance...")
            await self.integration_layer.maintenance()
            logger.info("✅ Enhanced memory adapter maintenance finished")
        except Exception as e:
            logger.error(f"❌ Enhanced memory adapter maintenance failed: {e}", exc_info=True)

    async def shutdown(self):
        """Shut down the adapter."""
        if not self._initialized:
            return

        try:
            logger.info("🔄 Shutting down the enhanced memory system adapter...")
            await self.integration_layer.shutdown()
            self._initialized = False
            logger.info("✅ Enhanced memory system adapter shut down")
        except Exception as e:
            logger.error(f"❌ Failed to shut down the enhanced memory system adapter: {e}", exc_info=True)


# Global adapter instance
_enhanced_memory_adapter: Optional[EnhancedMemoryAdapter] = None


async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter:
    """Return the global enhanced memory adapter instance."""
    global _enhanced_memory_adapter

    if _enhanced_memory_adapter is None:
        # Read the adapter configuration from the global config
        from src.config.config import global_config

        adapter_config = AdapterConfig(
            enable_enhanced_memory=getattr(global_config.memory, 'enable_enhanced_memory', True),
            integration_mode=getattr(global_config.memory, 'enhanced_memory_mode', 'enhanced_only'),
            auto_migration=getattr(global_config.memory, 'enable_memory_migration', True),
            memory_value_threshold=getattr(global_config.memory, 'memory_value_threshold', 0.6),
            fusion_threshold=getattr(global_config.memory, 'fusion_threshold', 0.85),
            max_retrieval_results=getattr(global_config.memory, 'max_retrieval_results', 10)
        )

        _enhanced_memory_adapter = EnhancedMemoryAdapter(llm_model, adapter_config)
        await _enhanced_memory_adapter.initialize()

    return _enhanced_memory_adapter


async def initialize_enhanced_memory_system(llm_model: LLMRequest):
    """Initialize the enhanced memory system."""
    try:
        logger.info("🚀 Initializing the enhanced memory system...")
        adapter = await get_enhanced_memory_adapter(llm_model)
        logger.info("✅ Enhanced memory system initialized")
        return adapter
    except Exception as e:
        logger.error(f"❌ Enhanced memory system initialization failed: {e}", exc_info=True)
        return None


async def process_conversation_with_enhanced_memory(
    conversation_text: str,
    context: Dict[str, Any],
    user_id: str,
    timestamp: Optional[float] = None,
    llm_model: Optional[LLMRequest] = None
) -> Dict[str, Any]:
    """Process a conversation with the enhanced memory system."""
    if not llm_model:
        # Fall back to the default LLM model
        from src.llm_models.utils_model import get_global_llm_model
        llm_model = get_global_llm_model()

    try:
        adapter = await get_enhanced_memory_adapter(llm_model)
        return await adapter.process_conversation_memory(conversation_text, context, user_id, timestamp)
    except Exception as e:
        logger.error(f"Failed to process a conversation with the enhanced memory system: {e}", exc_info=True)
        return {"success": False, "error": str(e)}


async def retrieve_memories_with_enhanced_system(
    query: str,
    user_id: str,
    context: Optional[Dict[str, Any]] = None,
    limit: int = 10,
    llm_model: Optional[LLMRequest] = None
) -> List[MemoryChunk]:
    """Retrieve memories with the enhanced memory system."""
    if not llm_model:
        # Fall back to the default LLM model
        from src.llm_models.utils_model import get_global_llm_model
        llm_model = get_global_llm_model()

    try:
        adapter = await get_enhanced_memory_adapter(llm_model)
        return await adapter.retrieve_relevant_memories(query, user_id, context, limit)
    except Exception as e:
        logger.error(f"Failed to retrieve memories with the enhanced system: {e}", exc_info=True)
        return []


async def get_memory_context_for_prompt(
    query: str,
    user_id: str,
    context: Optional[Dict[str, Any]] = None,
    max_memories: int = 5,
    llm_model: Optional[LLMRequest] = None
) -> str:
    """Build the memory context used in prompts."""
    if not llm_model:
        # Fall back to the default LLM model
        from src.llm_models.utils_model import get_global_llm_model
        llm_model = get_global_llm_model()

    try:
        adapter = await get_enhanced_memory_adapter(llm_model)
        return await adapter.get_memory_context_for_prompt(query, user_id, context, max_memories)
    except Exception as e:
        logger.error(f"Failed to build the memory context: {e}", exc_info=True)
        return ""
src/chat/memory_system/enhanced_memory_core.py (new file, 753 lines)
@@ -0,0 +1,753 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强型精准记忆系统核心模块
|
||||
基于文档设计的高效记忆构建、存储与召回优化系统
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import orjson
|
||||
import re
|
||||
from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.chat.memory_system.memory_builder import MemoryBuilder
|
||||
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
|
||||
from src.chat.memory_system.metadata_index import MetadataIndexManager
|
||||
from src.chat.memory_system.multi_stage_retrieval import MultiStageRetrieval, RetrievalConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MemorySystemStatus(Enum):
|
||||
"""记忆系统状态"""
|
||||
INITIALIZING = "initializing"
|
||||
READY = "ready"
|
||||
BUILDING = "building"
|
||||
RETRIEVING = "retrieving"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemorySystemConfig:
|
||||
"""记忆系统配置"""
|
||||
# 记忆构建配置
|
||||
min_memory_length: int = 10
|
||||
max_memory_length: int = 500
|
||||
memory_value_threshold: float = 0.7
|
||||
|
||||
# 向量存储配置
|
||||
vector_dimension: int = 768
|
||||
similarity_threshold: float = 0.8
|
||||
|
||||
# 召回配置
|
||||
coarse_recall_limit: int = 50
|
||||
fine_recall_limit: int = 10
|
||||
final_recall_limit: int = 5
|
||||
|
||||
# 融合配置
|
||||
fusion_similarity_threshold: float = 0.85
|
||||
deduplication_window: timedelta = timedelta(hours=24)
|
||||
|
||||
@classmethod
|
||||
def from_global_config(cls):
|
||||
"""从全局配置创建配置实例"""
|
||||
from src.config.config import global_config
|
||||
|
||||
return cls(
|
||||
# 记忆构建配置
|
||||
min_memory_length=global_config.memory.min_memory_length,
|
||||
max_memory_length=global_config.memory.max_memory_length,
|
||||
memory_value_threshold=global_config.memory.memory_value_threshold,
|
||||
|
||||
# 向量存储配置
|
||||
vector_dimension=global_config.memory.vector_dimension,
|
||||
similarity_threshold=global_config.memory.vector_similarity_threshold,
|
||||
|
||||
# 召回配置
|
||||
coarse_recall_limit=global_config.memory.metadata_filter_limit,
|
||||
fine_recall_limit=global_config.memory.final_result_limit,
|
||||
final_recall_limit=global_config.memory.final_result_limit,
|
||||
|
||||
# 融合配置
|
||||
fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold,
|
||||
deduplication_window=timedelta(hours=global_config.memory.deduplication_window_hours)
|
||||
)
|
||||
|
||||
|
||||
class EnhancedMemorySystem:
|
||||
"""增强型精准记忆系统核心类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_model: Optional[LLMRequest] = None,
|
||||
config: Optional[MemorySystemConfig] = None
|
||||
):
|
||||
self.config = config or MemorySystemConfig.from_global_config()
|
||||
self.llm_model = llm_model
|
||||
self.status = MemorySystemStatus.INITIALIZING
|
||||
|
||||
# 核心组件
|
||||
self.memory_builder: MemoryBuilder = None
|
||||
self.fusion_engine: MemoryFusionEngine = None
|
||||
self.vector_storage: VectorStorageManager = None
|
||||
self.metadata_index: MetadataIndexManager = None
|
||||
self.retrieval_system: MultiStageRetrieval = None
|
||||
|
||||
# LLM模型
|
||||
self.value_assessment_model: LLMRequest = None
|
||||
self.memory_extraction_model: LLMRequest = None
|
||||
|
||||
# 统计信息
|
||||
self.total_memories = 0
|
||||
self.last_build_time = None
|
||||
self.last_retrieval_time = None
|
||||
|
||||
logger.info("EnhancedMemorySystem 初始化开始")
|
||||
|
||||
async def initialize(self):
|
||||
"""异步初始化记忆系统"""
|
||||
try:
|
||||
logger.info("正在初始化增强型记忆系统...")
|
||||
|
||||
# 初始化LLM模型
|
||||
task_config = (
|
||||
self.llm_model.model_for_task
|
||||
if self.llm_model is not None
|
||||
else model_config.model_task_config.utils
|
||||
)
|
||||
|
||||
self.value_assessment_model = LLMRequest(
|
||||
model_set=task_config,
|
||||
request_type="memory.value_assessment"
|
||||
)
|
||||
|
||||
self.memory_extraction_model = LLMRequest(
|
||||
model_set=task_config,
|
||||
request_type="memory.extraction"
|
||||
)
|
||||
|
||||
# 初始化核心组件
|
||||
self.memory_builder = MemoryBuilder(self.memory_extraction_model)
|
||||
self.fusion_engine = MemoryFusionEngine(self.config.fusion_similarity_threshold)
|
||||
# 创建向量存储配置
|
||||
vector_config = VectorStorageConfig(
|
||||
dimension=self.config.vector_dimension,
|
||||
similarity_threshold=self.config.similarity_threshold
|
||||
)
|
||||
self.vector_storage = VectorStorageManager(vector_config)
|
||||
self.metadata_index = MetadataIndexManager()
|
||||
# 创建检索配置
|
||||
retrieval_config = RetrievalConfig(
|
||||
metadata_filter_limit=self.config.coarse_recall_limit,
|
||||
vector_search_limit=self.config.fine_recall_limit,
|
||||
final_result_limit=self.config.final_recall_limit
|
||||
)
|
||||
self.retrieval_system = MultiStageRetrieval(retrieval_config)
|
||||
|
||||
# 加载持久化数据
|
||||
await self.vector_storage.load_storage()
|
||||
await self.metadata_index.load_index()
|
||||
|
||||
self.status = MemorySystemStatus.READY
|
||||
logger.info("✅ 增强型记忆系统初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
self.status = MemorySystemStatus.ERROR
|
||||
logger.error(f"❌ 记忆系统初始化失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def build_memory_from_conversation(
|
||||
self,
|
||||
conversation_text: str,
|
||||
context: Dict[str, Any],
|
||||
user_id: str,
|
||||
timestamp: Optional[float] = None
|
||||
) -> List[MemoryChunk]:
|
||||
"""从对话中构建记忆
|
||||
|
||||
Args:
|
||||
conversation_text: 对话文本
|
||||
context: 上下文信息(包括用户信息、群组信息等)
|
||||
user_id: 用户ID
|
||||
timestamp: 时间戳,默认为当前时间
|
||||
|
||||
Returns:
|
||||
构建的记忆块列表
|
||||
"""
|
||||
if self.status != MemorySystemStatus.READY:
|
||||
raise RuntimeError("记忆系统未就绪")
|
||||
|
||||
self.status = MemorySystemStatus.BUILDING
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
normalized_context = self._normalize_context(context, user_id, timestamp)
|
||||
conversation_text = self._resolve_conversation_context(conversation_text, normalized_context)
|
||||
|
||||
logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
|
||||
|
||||
# 1. 信息价值评估
|
||||
value_score = await self._assess_information_value(conversation_text, normalized_context)
|
||||
|
||||
if value_score < self.config.memory_value_threshold:
|
||||
logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建")
|
||||
self.status = MemorySystemStatus.READY
|
||||
return []
|
||||
|
||||
# 2. 构建记忆块
|
||||
memory_chunks = await self.memory_builder.build_memories(
|
||||
conversation_text,
|
||||
normalized_context,
|
||||
user_id,
|
||||
timestamp or time.time()
|
||||
)
|
||||
|
||||
if not memory_chunks:
|
||||
logger.debug("未提取到有效记忆块")
|
||||
self.status = MemorySystemStatus.READY
|
||||
return []
|
||||
|
||||
# 3. 记忆融合与去重
|
||||
fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks)
|
||||
|
||||
# 4. 存储记忆
|
||||
await self._store_memories(fused_chunks)
|
||||
|
||||
# 5. 更新统计
|
||||
self.total_memories += len(fused_chunks)
|
||||
self.last_build_time = time.time()
|
||||
|
||||
build_time = time.time() - start_time
|
||||
logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}秒")
|
||||
|
||||
self.status = MemorySystemStatus.READY
|
||||
return fused_chunks
|
||||
|
||||
except Exception as e:
|
||||
self.status = MemorySystemStatus.ERROR
|
||||
logger.error(f"❌ 记忆构建失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def process_conversation_memory(
|
||||
self,
|
||||
conversation_text: str,
|
||||
context: Dict[str, Any],
|
||||
user_id: str,
|
||||
timestamp: Optional[float] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""对外暴露的对话记忆处理接口,兼容旧调用方式"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
normalized_context = self._normalize_context(context, user_id, timestamp)
|
||||
|
||||
memories = await self.build_memory_from_conversation(
|
||||
conversation_text=conversation_text,
|
||||
context=normalized_context,
|
||||
user_id=user_id,
|
||||
timestamp=timestamp
|
||||
)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
memory_count = len(memories)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"created_memories": memories,
|
||||
"memory_count": memory_count,
|
||||
"processing_time": processing_time,
|
||||
"status": self.status.value
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
processing_time = time.time() - start_time
|
||||
logger.error(f"对话记忆处理失败: {e}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"processing_time": processing_time,
|
||||
"status": self.status.value
|
||||
}
|
||||
|
||||
async def retrieve_relevant_memories(
|
||||
self,
|
||||
query_text: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
limit: int = 5,
|
||||
**kwargs
|
||||
) -> List[MemoryChunk]:
|
||||
"""检索相关记忆,兼容 query/query_text 参数形式"""
|
||||
if self.status != MemorySystemStatus.READY:
|
||||
raise RuntimeError("记忆系统未就绪")
|
||||
|
||||
query_text = query_text or kwargs.get("query")
|
||||
if not query_text:
|
||||
raise ValueError("query_text 或 query 参数不能为空")
|
||||
|
||||
context = context or {}
|
||||
user_id = user_id or kwargs.get("user_id")
|
||||
|
||||
self.status = MemorySystemStatus.RETRIEVING
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
normalized_context = self._normalize_context(context, user_id, None)
|
||||
|
||||
candidate_memories = list(self.vector_storage.memory_cache.values())
|
||||
if user_id:
|
||||
candidate_memories = [m for m in candidate_memories if m.user_id == user_id]
|
||||
|
||||
if not candidate_memories:
|
||||
self.status = MemorySystemStatus.READY
|
||||
self.last_retrieval_time = time.time()
|
||||
logger.debug(f"未找到用户 {user_id} 的候选记忆")
|
||||
return []
|
||||
|
||||
scored_memories = []
|
||||
for memory in candidate_memories:
|
||||
score = self._compute_memory_score(query_text, memory, normalized_context)
|
||||
if score > 0:
|
||||
scored_memories.append((memory, score))
|
||||
|
||||
if not scored_memories:
|
||||
# 如果所有分数为0,返回最近的记忆作为降级策略
|
||||
candidate_memories.sort(key=lambda m: m.metadata.last_accessed, reverse=True)
|
||||
scored_memories = [(memory, 0.0) for memory in candidate_memories[:limit]]
|
||||
else:
|
||||
scored_memories.sort(key=lambda item: item[1], reverse=True)
|
||||
|
||||
top_memories = [memory for memory, _ in scored_memories[:limit]]
|
||||
|
||||
# 更新访问信息和缓存
|
||||
for memory, score in scored_memories[:limit]:
|
||||
memory.update_access()
|
||||
memory.update_relevance(score)
|
||||
|
||||
cache_entry = self.metadata_index.memory_metadata_cache.get(memory.memory_id)
|
||||
if cache_entry is not None:
|
||||
cache_entry["last_accessed"] = memory.metadata.last_accessed
|
||||
cache_entry["access_count"] = memory.metadata.access_count
|
||||
cache_entry["relevance_score"] = memory.metadata.relevance_score
|
||||
|
||||
retrieval_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"✅ 为用户 {user_id or 'unknown'} 检索到 {len(top_memories)} 条相关记忆,耗时 {retrieval_time:.3f}秒"
|
||||
)
|
||||
|
||||
self.last_retrieval_time = time.time()
|
||||
self.status = MemorySystemStatus.READY
|
||||
|
||||
return top_memories
|
||||
|
||||
except Exception as e:
|
||||
self.status = MemorySystemStatus.ERROR
|
||||
logger.error(f"❌ 记忆检索失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_payload(response: str) -> Optional[str]:
|
||||
"""从模型响应中提取JSON部分,兼容Markdown代码块等格式"""
|
||||
        if not response:
            return None

        stripped = response.strip()

        # 优先处理Markdown代码块格式 ```json ... ```
        code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
        if code_block_match:
            candidate = code_block_match.group(1).strip()
            if candidate:
                return candidate

        # 回退到查找第一个 JSON 对象的大括号范围
        start = stripped.find("{")
        end = stripped.rfind("}")
        if start != -1 and end != -1 and end > start:
            return stripped[start:end + 1].strip()

        return stripped if stripped.startswith("{") and stripped.endswith("}") else None

    def _normalize_context(
        self,
        raw_context: Optional[Dict[str, Any]],
        user_id: Optional[str],
        timestamp: Optional[float]
    ) -> Dict[str, Any]:
        """标准化上下文,确保必备字段存在且格式正确"""
        context: Dict[str, Any] = {}
        if raw_context:
            try:
                context = dict(raw_context)
            except Exception:
                # 无法转换为字典时回退为空上下文
                context = {}

        # 基础字段
        context["user_id"] = context.get("user_id") or user_id or "unknown"
        context["timestamp"] = context.get("timestamp") or timestamp or time.time()
        context["message_type"] = context.get("message_type") or "normal"
        context["platform"] = context.get("platform") or context.get("source_platform") or "unknown"

        # 标准化关键词类型
        keywords = context.get("keywords")
        if keywords is None:
            context["keywords"] = []
        elif isinstance(keywords, tuple):
            context["keywords"] = list(keywords)
        elif not isinstance(keywords, list):
            context["keywords"] = [str(keywords)] if keywords else []

        # 统一 stream_id
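        # 同时兼容历史调用方写入的错拼键名 "stram_id",避免丢失会话关联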
        stream_id = context.get("stream_id") or context.get("stram_id")
        if not stream_id:
            potential = context.get("chat_id") or context.get("session_id")
            if isinstance(potential, str) and potential:
                stream_id = potential
        if stream_id:
            context["stream_id"] = stream_id

        # chat_id 兜底
        context["chat_id"] = context.get("chat_id") or context.get("stream_id") or f"session_{context['user_id']}"

        # 历史窗口配置
        window_candidate = (
            context.get("history_limit")
            or context.get("history_window")
            or context.get("memory_history_limit")
        )
        if window_candidate is not None:
            try:
                context["history_limit"] = int(window_candidate)
            except (TypeError, ValueError):
                context.pop("history_limit", None)

        return context

    def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
        """使用 stream_id 历史消息充实对话文本,默认回退到传入文本"""
        if not context:
            return fallback_text

        stream_id = context.get("stream_id") or context.get("stram_id")
        if not stream_id:
            return fallback_text

        try:
            from src.chat.message_receive.chat_stream import get_chat_manager

            chat_manager = get_chat_manager()
            chat_stream = chat_manager.get_stream(stream_id)
            if not chat_stream or not hasattr(chat_stream, "context_manager"):
                logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
                return fallback_text

            history_limit = self._determine_history_limit(context)
            messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
            if not messages:
                logger.debug(f"stream_id={stream_id} 未获取到历史消息")
                return fallback_text

            transcript = self._format_history_messages(messages)
            if not transcript:
                return fallback_text

            cleaned_fallback = (fallback_text or "").strip()
            if cleaned_fallback and cleaned_fallback not in transcript:
                transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"

            logger.debug(
                "使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
                stream_id,
                len(messages),
                history_limit,
            )
            return transcript

        except Exception as exc:
            logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
            return fallback_text

    def _determine_history_limit(self, context: Dict[str, Any]) -> int:
        """确定历史消息获取数量,限制在30-50之间"""
        default_limit = 40
        candidate = (
            context.get("history_limit")
            or context.get("history_window")
            or context.get("memory_history_limit")
        )

        if isinstance(candidate, str):
            try:
                candidate = int(candidate)
            except ValueError:
                candidate = None

        if isinstance(candidate, int):
            history_limit = max(30, min(50, candidate))
        else:
            history_limit = default_limit

        return history_limit

    def _format_history_messages(self, messages: List["DatabaseMessages"]) -> Optional[str]:
        """将历史消息格式化为可供LLM处理的多轮对话文本"""
        if not messages:
            return None

        lines: List[str] = []
        for msg in messages:
            try:
                content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None)
                if not content:
                    continue

                content = re.sub(r"\s+", " ", str(content).strip())
                if not content:
                    continue

                speaker = None
                if hasattr(msg, "user_info") and msg.user_info:
                    speaker = (
                        getattr(msg.user_info, "user_nickname", None)
                        or getattr(msg.user_info, "user_cardname", None)
                        or getattr(msg.user_info, "user_id", None)
                    )
                speaker = speaker or getattr(msg, "user_nickname", None) or getattr(msg, "user_id", None) or "用户"

                timestamp_value = getattr(msg, "time", None) or 0.0
                try:
                    timestamp_dt = datetime.fromtimestamp(float(timestamp_value)) if timestamp_value else datetime.now()
                except (TypeError, ValueError, OSError):
                    timestamp_dt = datetime.now()

                timestamp_str = timestamp_dt.strftime("%Y-%m-%d %H:%M:%S")
                lines.append(f"[{timestamp_str}] {speaker}: {content}")

            except Exception as message_exc:
                logger.debug(f"格式化历史消息失败: {message_exc}")
                continue

        return "\n".join(lines) if lines else None

    async def _assess_information_value(self, text: str, context: Dict[str, Any]) -> float:
        """评估信息价值

        Args:
            text: 文本内容
            context: 上下文信息

        Returns:
            价值评分 (0.0-1.0)
        """
        try:
            # 构建评估提示
            prompt = f"""
请评估以下对话内容的信息价值,重点识别包含个人事实、事件、偏好、观点等重要信息的内容。

## 🎯 价值评估重点标准:

### 高价值信息 (0.7-1.0分):
1. **个人事实** (personal_fact):包含姓名、年龄、职业、联系方式、住址、健康状况、家庭情况等个人信息
2. **重要事件** (event):约会、会议、旅行、考试、面试、搬家等重要活动或经历
3. **明确偏好** (preference):表达喜欢/不喜欢的食物、电影、音乐、品牌、生活习惯等偏好信息
4. **观点态度** (opinion):对事物的评价、看法、建议、态度等主观观点
5. **核心关系** (relationship):重要的朋友、家人、同事等人际关系信息

### 中等价值信息 (0.4-0.7分):
1. **情感表达**:当前情绪状态、心情变化
2. **日常活动**:常规的工作、学习、生活安排
3. **一般兴趣**:兴趣爱好、休闲活动
4. **短期计划**:即将进行的安排和计划

### 低价值信息 (0.0-0.4分):
1. **寒暄问候**:简单的打招呼、礼貌用语
2. **重复信息**:已经多次提到的相同内容
3. **临时状态**:短暂的情绪波动、临时想法
4. **无关内容**:与用户画像建立无关的信息

对话内容:
{text}

上下文信息:
- 用户ID: {context.get('user_id', 'unknown')}
- 消息类型: {context.get('message_type', 'unknown')}
- 时间: {datetime.fromtimestamp(context.get('timestamp', time.time()))}

## 📋 评估要求:

### 积极识别原则:
- **宁可高估,不可低估** - 对于可能的个人信息给予较高评估
- **重点关注** - 特别注意包含 personal_fact、event、preference、opinion 的内容
- **细节丰富** - 具体的细节信息比笼统的描述更有价值
- **建立画像** - 有助于建立完整用户画像的信息更有价值

### 评分指导:
- **0.9-1.0**:核心个人信息(姓名、联系方式、重要偏好)
- **0.7-0.8**:重要的个人事实、观点、事件经历
- **0.5-0.6**:一般性偏好、日常活动、情感表达
- **0.3-0.4**:简单的兴趣表达、临时状态
- **0.0-0.2**:寒暄问候、重复内容、无关信息

请以JSON格式输出评估结果:
{{
    "value_score": 0.0到1.0之间的数值,
    "reasoning": "评估理由,包含具体识别到的信息类型",
    "key_factors": ["关键因素1", "关键因素2"],
    "detected_types": ["personal_fact", "preference", "opinion", "event", "relationship", "emotion", "goal"]
}}
"""

            response, _ = await self.value_assessment_model.generate_response_async(
                prompt, temperature=0.3
            )

            # 解析响应
            try:
                payload = self._extract_json_payload(response)
                if not payload:
                    raise ValueError("未在响应中找到有效的JSON负载")

                result = orjson.loads(payload)
                value_score = float(result.get("value_score", 0.0))
                reasoning = result.get("reasoning", "")
                key_factors = result.get("key_factors", [])

                logger.info(f"信息价值评估: {value_score:.2f}, 理由: {reasoning}")
                if key_factors:
                    logger.info(f"关键因素: {', '.join(key_factors)}")

                return max(0.0, min(1.0, value_score))

            except (orjson.JSONDecodeError, ValueError) as e:
                preview = response[:200].replace('\n', ' ')
                logger.warning(f"解析价值评估响应失败: {e}, 响应片段: {preview}")
                return 0.5  # 默认中等价值

        except Exception as e:
            logger.error(f"信息价值评估失败: {e}", exc_info=True)
            return 0.5  # 默认中等价值

    async def _store_memories(self, memory_chunks: List[MemoryChunk]):
        """存储记忆块到各个存储系统"""
        if not memory_chunks:
            return

        # 并行存储到向量数据库和元数据索引
        storage_tasks = []

        # 向量存储
        storage_tasks.append(self.vector_storage.store_memories(memory_chunks))

        # 元数据索引
        storage_tasks.append(self.metadata_index.index_memories(memory_chunks))

        # 等待所有存储任务完成
        await asyncio.gather(*storage_tasks, return_exceptions=True)

        logger.debug(f"成功存储 {len(memory_chunks)} 条记忆到各个存储系统")

    def get_system_stats(self) -> Dict[str, Any]:
        """获取系统统计信息"""
        return {
            "status": self.status.value,
            "total_memories": self.total_memories,
            "last_build_time": self.last_build_time,
            "last_retrieval_time": self.last_retrieval_time,
            "config": asdict(self.config)
        }

    def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
        """根据查询和上下文为记忆计算匹配分数"""
        tokens_query = self._tokenize_text(query_text)
        tokens_memory = self._tokenize_text(memory.text_content)

        if tokens_query and tokens_memory:
            base_score = len(tokens_query & tokens_memory) / len(tokens_query | tokens_memory)
        else:
            base_score = 0.0

        context_keywords = context.get("keywords") or []
        keyword_overlap = 0.0
        if context_keywords:
            memory_keywords = set(k.lower() for k in memory.keywords)
            keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max(len(context_keywords), 1)

        importance_boost = (memory.metadata.importance.value - 1) / 3 * 0.1
        confidence_boost = (memory.metadata.confidence.value - 1) / 3 * 0.05

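        # 权重为经验值:词面 Jaccard 相似度占 0.7,上下文关键词重叠占 0.15,
        # 重要性/置信度按枚举档位折算为最多约 0.1 / 0.05 的加成,最终截断到 [0, 1]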
        final_score = base_score * 0.7 + keyword_overlap * 0.15 + importance_boost + confidence_boost
        return max(0.0, min(1.0, final_score))

    def _tokenize_text(self, text: str) -> Set[str]:
        """简单分词,兼容中英文"""
        if not text:
            return set()

        tokens = re.findall(r"[\w\u4e00-\u9fa5]+", text.lower())
        return {token for token in tokens if len(token) > 1}

    async def maintenance(self):
        """系统维护操作"""
        try:
            logger.info("开始记忆系统维护...")

            # 向量存储优化
            await self.vector_storage.optimize_storage()

            # 元数据索引优化
            await self.metadata_index.optimize_index()

            # 记忆融合引擎维护
            await self.fusion_engine.maintenance()

            logger.info("✅ 记忆系统维护完成")

        except Exception as e:
            logger.error(f"❌ 记忆系统维护失败: {e}", exc_info=True)

    async def shutdown(self):
        """关闭系统"""
        try:
            logger.info("正在关闭增强型记忆系统...")

            # 保存持久化数据
            await self.vector_storage.save_storage()
            await self.metadata_index.save_index()

            logger.info("✅ 增强型记忆系统已关闭")

        except Exception as e:
            logger.error(f"❌ 记忆系统关闭失败: {e}", exc_info=True)


# 全局记忆系统实例
enhanced_memory_system: Optional[EnhancedMemorySystem] = None


def get_enhanced_memory_system() -> EnhancedMemorySystem:
    """获取全局记忆系统实例"""
    global enhanced_memory_system
    if enhanced_memory_system is None:
        enhanced_memory_system = EnhancedMemorySystem()
    return enhanced_memory_system


async def initialize_enhanced_memory_system():
    """初始化全局记忆系统"""
    global enhanced_memory_system
    if enhanced_memory_system is None:
        enhanced_memory_system = EnhancedMemorySystem()
        await enhanced_memory_system.initialize()
    return enhanced_memory_system
src/chat/memory_system/enhanced_memory_hooks.py (new file, 181 lines)
@@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统钩子
用于在消息处理过程中自动构建和检索记忆
"""

from typing import Dict, List, Any, Optional
from datetime import datetime

from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager

logger = get_logger(__name__)


class EnhancedMemoryHooks:
    """增强记忆系统钩子 - 自动处理消息的记忆构建和检索"""

    def __init__(self):
        self.enabled = (global_config.memory.enable_memory and
                        global_config.memory.enable_enhanced_memory)
        self.processed_messages = set()  # 避免重复处理

    async def process_message_for_memory(
        self,
        message_content: str,
        user_id: str,
        chat_id: str,
        message_id: str,
        context: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        处理消息并构建记忆

        Args:
            message_content: 消息内容
            user_id: 用户ID
            chat_id: 聊天ID
            message_id: 消息ID
            context: 上下文信息

        Returns:
            bool: 是否成功处理
        """
        if not self.enabled:
            return False

        if message_id in self.processed_messages:
            return False

        try:
            # 确保增强记忆管理器已初始化
            if not enhanced_memory_manager.is_initialized:
                await enhanced_memory_manager.initialize()

            # 构建上下文
            memory_context = {
                "chat_id": chat_id,
                "message_id": message_id,
                "timestamp": datetime.now().timestamp(),
                "message_type": "user_message",
                **(context or {})
            }

            # 处理对话并构建记忆
            memory_chunks = await enhanced_memory_manager.process_conversation(
                conversation_text=message_content,
                context=memory_context,
                user_id=user_id,
                timestamp=memory_context["timestamp"]
            )

            # 标记消息已处理
            self.processed_messages.add(message_id)

            # 限制处理历史大小(集合无序,这里只是裁剪到约500条,并非严格按时间移除最旧记录)
            if len(self.processed_messages) > 1000:
                self.processed_messages = set(list(self.processed_messages)[-500:])

            logger.debug(f"为消息 {message_id} 构建了 {len(memory_chunks)} 条记忆")
            return len(memory_chunks) > 0

        except Exception as e:
            logger.error(f"处理消息记忆失败: {e}")
            return False

    async def get_memory_for_response(
        self,
        query_text: str,
        user_id: str,
        chat_id: str,
        limit: int = 5
    ) -> List[Dict[str, Any]]:
        """
        为回复获取相关记忆

        Args:
            query_text: 查询文本
            user_id: 用户ID
            chat_id: 聊天ID
            limit: 返回记忆数量限制

        Returns:
            List[Dict]: 相关记忆列表
        """
        if not self.enabled:
            return []

        try:
            # 确保增强记忆管理器已初始化
            if not enhanced_memory_manager.is_initialized:
                await enhanced_memory_manager.initialize()

            # 构建查询上下文
            context = {
                "chat_id": chat_id,
                "query_intent": "response_generation",
                "expected_memory_types": [
                    "personal_fact", "event", "preference", "opinion"
                ]
            }

            # 获取相关记忆
            enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
                query_text=query_text,
                user_id=user_id,
                context=context,
                limit=limit
            )

            # 转换为字典格式
            results = []
            for result in enhanced_results:
                memory_dict = {
                    "content": result.content,
                    "type": result.memory_type,
                    "confidence": result.confidence,
                    "importance": result.importance,
                    "timestamp": result.timestamp,
                    "source": result.source
                }
                results.append(memory_dict)

            logger.debug(f"为回复查询到 {len(results)} 条相关记忆")
            return results

        except Exception as e:
            logger.error(f"获取回复记忆失败: {e}")
            return []

    async def cleanup_old_memories(self):
        """清理旧记忆"""
        try:
            # enhanced_system 在初始化失败时可能为 None,需一并判断
            if enhanced_memory_manager.is_initialized and enhanced_memory_manager.enhanced_system:
                # 调用增强记忆系统的维护功能
                await enhanced_memory_manager.enhanced_system.maintenance()
                logger.debug("增强记忆系统维护完成")
        except Exception as e:
            logger.error(f"清理旧记忆失败: {e}")

    def clear_processed_cache(self):
        """清除已处理消息的缓存"""
        self.processed_messages.clear()
        logger.debug("已清除消息处理缓存")

    def enable(self):
        """启用记忆钩子"""
        self.enabled = True
        logger.info("增强记忆钩子已启用")

    def disable(self):
        """禁用记忆钩子"""
        self.enabled = False
        logger.info("增强记忆钩子已禁用")


# 创建全局实例
enhanced_memory_hooks = EnhancedMemoryHooks()
src/chat/memory_system/enhanced_memory_integration.py (new file, 206 lines)
@@ -0,0 +1,206 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统集成脚本
用于在现有系统中无缝集成增强记忆功能
"""

from typing import Dict, Any, Optional

from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks

logger = get_logger(__name__)


async def process_user_message_memory(
    message_content: str,
    user_id: str,
    chat_id: str,
    message_id: str,
    context: Optional[Dict[str, Any]] = None
) -> bool:
    """
    处理用户消息并构建记忆

    Args:
        message_content: 消息内容
        user_id: 用户ID
        chat_id: 聊天ID
        message_id: 消息ID
        context: 额外的上下文信息

    Returns:
        bool: 是否成功构建记忆
    """
    try:
        success = await enhanced_memory_hooks.process_message_for_memory(
            message_content=message_content,
            user_id=user_id,
            chat_id=chat_id,
            message_id=message_id,
            context=context
        )

        if success:
            logger.debug(f"成功为消息 {message_id} 构建记忆")

        return success

    except Exception as e:
        logger.error(f"处理用户消息记忆失败: {e}")
        return False


async def get_relevant_memories_for_response(
    query_text: str,
    user_id: str,
    chat_id: str,
    limit: int = 5
) -> Dict[str, Any]:
    """
    为回复获取相关记忆

    Args:
        query_text: 查询文本(通常是用户的当前消息)
        user_id: 用户ID
        chat_id: 聊天ID
        limit: 返回记忆数量限制

    Returns:
        Dict: 包含记忆信息的字典
    """
    try:
        memories = await enhanced_memory_hooks.get_memory_for_response(
            query_text=query_text,
            user_id=user_id,
            chat_id=chat_id,
            limit=limit
        )

        result = {
            "has_memories": len(memories) > 0,
            "memories": memories,
            "memory_count": len(memories)
        }

        logger.debug(f"为回复获取到 {len(memories)} 条相关记忆")
        return result

    except Exception as e:
        logger.error(f"获取回复记忆失败: {e}")
        return {
            "has_memories": False,
            "memories": [],
            "memory_count": 0
        }


def format_memories_for_prompt(memories: Dict[str, Any]) -> str:
    """
    格式化记忆信息用于Prompt

    Args:
        memories: 记忆信息字典

    Returns:
        str: 格式化后的记忆文本
    """
    if not memories["has_memories"]:
        return ""

    memory_lines = ["以下是相关的记忆信息:"]

    for memory in memories["memories"]:
        content = memory["content"]
        memory_type = memory["type"]
        confidence = memory["confidence"]
        importance = memory["importance"]

        # 根据重要性添加不同的标记
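        # 阈值基于 importance/confidence 的枚举数值(此处按 1-4 档的假设):>=3 视为高,>=2 视为中,其余为低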
        importance_marker = "🔥" if importance >= 3 else "⭐" if importance >= 2 else "📝"
        confidence_marker = "✅" if confidence >= 3 else "⚠️" if confidence >= 2 else "💭"

        memory_line = f"{importance_marker} {content} ({memory_type}, {confidence_marker}置信度)"
        memory_lines.append(memory_line)

    return "\n".join(memory_lines)


async def cleanup_memory_system():
    """清理记忆系统"""
    try:
        await enhanced_memory_hooks.cleanup_old_memories()
        logger.info("记忆系统清理完成")
    except Exception as e:
        logger.error(f"记忆系统清理失败: {e}")


def get_memory_system_status() -> Dict[str, Any]:
    """
    获取记忆系统状态

    Returns:
        Dict: 系统状态信息
    """
    from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager

    return {
        "enabled": enhanced_memory_hooks.enabled,
        "enhanced_system_initialized": enhanced_memory_manager.is_initialized,
        "processed_messages_count": len(enhanced_memory_hooks.processed_messages),
        "system_type": "enhanced_memory_system"
    }


# 便捷函数
async def remember_message(
    message: str,
    user_id: str = "default_user",
    chat_id: str = "default_chat"
) -> bool:
    """
    便捷的记忆构建函数

    Args:
        message: 要记住的消息
        user_id: 用户ID
        chat_id: 聊天ID

    Returns:
        bool: 是否成功
    """
    import uuid
    message_id = str(uuid.uuid4())
    return await process_user_message_memory(
        message_content=message,
        user_id=user_id,
        chat_id=chat_id,
        message_id=message_id
    )


async def recall_memories(
    query: str,
    user_id: str = "default_user",
    chat_id: str = "default_chat",
    limit: int = 5
) -> Dict[str, Any]:
    """
    便捷的记忆检索函数

    Args:
        query: 查询文本
        user_id: 用户ID
        chat_id: 聊天ID
        limit: 返回数量限制

    Returns:
        Dict: 记忆信息
    """
    return await get_relevant_memories_for_response(
        query_text=query,
        user_id=user_id,
        chat_id=chat_id,
        limit=limit
    )
src/chat/memory_system/enhanced_memory_manager.py (new file, 305 lines)
@@ -0,0 +1,305 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统管理器
替代原有的 Hippocampus 和 instant_memory 系统
"""

from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass

from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.enhanced_memory_adapter import (
    initialize_enhanced_memory_system
)

logger = get_logger(__name__)


@dataclass
class EnhancedMemoryResult:
    """增强记忆查询结果"""
    content: str
    memory_type: str
    confidence: float
    importance: float
    timestamp: float
    source: str = "enhanced_memory"
    relevance_score: float = 0.0


class EnhancedMemoryManager:
    """增强记忆系统管理器 - 替代原有的 HippocampusManager"""

    def __init__(self):
        self.enhanced_system: Optional[EnhancedMemorySystem] = None
        self.is_initialized = False
        self.user_cache = {}  # 用户记忆缓存

    async def initialize(self):
        """初始化增强记忆系统"""
        if self.is_initialized:
            return

        try:
            # 检查是否启用增强记忆系统
            if not global_config.memory.enable_enhanced_memory:
                logger.info("增强记忆系统已禁用,跳过初始化")
                self.is_initialized = True
                return

            logger.info("正在初始化增强记忆系统...")

            # 获取LLM模型
            from src.llm_models.utils_model import LLMRequest
            from src.config.config import model_config
            llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory")

            # 初始化增强记忆系统(模块级的 enhanced_memory_manager 即全局入口)
            self.enhanced_system = await initialize_enhanced_memory_system(llm_model)

            self.is_initialized = True
            logger.info("✅ 增强记忆系统初始化完成")

        except Exception as e:
            logger.error(f"❌ 增强记忆系统初始化失败: {e}")
            # 如果增强系统初始化失败,创建一个空的管理器避免系统崩溃
            self.enhanced_system = None
            self.is_initialized = True  # 标记为已初始化但系统不可用

    def get_hippocampus(self):
        """兼容原有接口 - 返回空"""
        logger.debug("get_hippocampus 调用 - 增强记忆系统不使用此方法")
        return {}

    async def build_memory(self):
        """兼容原有接口 - 构建记忆"""
        if not self.is_initialized or not self.enhanced_system:
            return

        try:
            # 增强记忆系统使用实时构建,不需要定时构建
            logger.debug("build_memory 调用 - 增强记忆系统使用实时构建")
        except Exception as e:
            logger.error(f"build_memory 失败: {e}")

    async def forget_memory(self, percentage: float = 0.005):
        """兼容原有接口 - 遗忘机制"""
        if not self.is_initialized or not self.enhanced_system:
            return

        try:
            # 增强记忆系统有内置的遗忘机制
            logger.debug(f"forget_memory 调用 - 参数: {percentage}")
            # 可以在这里调用增强系统的维护功能
            await self.enhanced_system.maintenance()
        except Exception as e:
            logger.error(f"forget_memory 失败: {e}")

    async def consolidate_memory(self):
        """兼容原有接口 - 记忆巩固"""
        if not self.is_initialized or not self.enhanced_system:
            return

        try:
            # 增强记忆系统自动处理记忆巩固
            logger.debug("consolidate_memory 调用 - 增强记忆系统自动处理")
        except Exception as e:
            logger.error(f"consolidate_memory 失败: {e}")

    async def get_memory_from_text(
        self,
        text: str,
        chat_id: str,
        user_id: str,
        max_memory_num: int = 3,
        max_memory_length: int = 2,
        time_weight: float = 1.0,
        keyword_weight: float = 1.0
    ) -> List[Tuple[str, str]]:
        """从文本获取相关记忆 - 兼容原有接口"""
        if not self.is_initialized or not self.enhanced_system:
            return []

        try:
            # 使用增强记忆系统检索
            context = {
                "chat_id": chat_id,
                "expected_memory_types": [MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE]
            }

            relevant_memories = await self.enhanced_system.retrieve_relevant_memories(
                query=text,
                user_id=user_id,
                context=context,
                limit=max_memory_num
            )

            # 转换为原有格式 (topic, content)
            results = []
            for memory in relevant_memories:
                topic = memory.memory_type.value
                content = memory.text_content
                results.append((topic, content))

            logger.debug(f"从文本检索到 {len(results)} 条相关记忆")
            return results

        except Exception as e:
            logger.error(f"get_memory_from_text 失败: {e}")
            return []

    async def get_memory_from_topic(
        self,
        valid_keywords: List[str],
        max_memory_num: int = 3,
        max_memory_length: int = 2,
        max_depth: int = 3
    ) -> List[Tuple[str, str]]:
        """从关键词获取记忆 - 兼容原有接口"""
        if not self.is_initialized or not self.enhanced_system:
            return []

        try:
            # 将关键词转换为查询文本
            query_text = " ".join(valid_keywords)

            # 使用增强记忆系统检索
            context = {
                "keywords": valid_keywords,
                "expected_memory_types": [
                    MemoryType.PERSONAL_FACT,
                    MemoryType.EVENT,
                    MemoryType.PREFERENCE,
                    MemoryType.OPINION
                ]
            }

            relevant_memories = await self.enhanced_system.retrieve_relevant_memories(
                query=query_text,
                user_id="default_user",  # 可以根据实际需要传递
                context=context,
                limit=max_memory_num
            )

            # 转换为原有格式 (topic, content)
            results = []
            for memory in relevant_memories:
                topic = memory.memory_type.value
                content = memory.text_content
                results.append((topic, content))

            logger.debug(f"从关键词 {valid_keywords} 检索到 {len(results)} 条相关记忆")
            return results

        except Exception as e:
            logger.error(f"get_memory_from_topic 失败: {e}")
            return []

    def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
        """从单个关键词获取记忆 - 兼容原有接口(同步方法,仅保留签名)"""
        if not self.is_initialized or not self.enhanced_system:
            return []

        try:
            # 检索需走异步接口,此同步方法固定返回空列表
            logger.debug(f"get_memory_from_keyword 调用 - 关键词: {keyword}")
            return []
        except Exception as e:
            logger.error(f"get_memory_from_keyword 失败: {e}")
            return []

    async def process_conversation(
        self,
        conversation_text: str,
        context: Dict[str, Any],
        user_id: str,
        timestamp: Optional[float] = None
    ) -> List[MemoryChunk]:
        """处理对话并构建记忆 - 新增功能"""
        if not self.is_initialized or not self.enhanced_system:
            return []

        try:
            result = await self.enhanced_system.process_conversation_memory(
                conversation_text=conversation_text,
                context=context,
                user_id=user_id,
                timestamp=timestamp
            )

            # 从结果中提取记忆块
            memory_chunks = []
            if result.get("success"):
                memory_chunks = result.get("created_memories", [])

            logger.info(f"从对话构建了 {len(memory_chunks)} 条记忆")
            return memory_chunks

        except Exception as e:
            logger.error(f"process_conversation 失败: {e}")
            return []

    async def get_enhanced_memory_context(
        self,
        query_text: str,
        user_id: str,
        context: Optional[Dict[str, Any]] = None,
        limit: int = 5
    ) -> List[EnhancedMemoryResult]:
        """获取增强记忆上下文 - 新增功能"""
        if not self.is_initialized or not self.enhanced_system:
            return []

        try:
            relevant_memories = await self.enhanced_system.retrieve_relevant_memories(
                query=query_text,
                user_id=user_id,
                context=context or {},
                limit=limit
            )

            results = []
            for memory in relevant_memories:
                result = EnhancedMemoryResult(
                    content=memory.text_content,
                    memory_type=memory.memory_type.value,
                    confidence=memory.metadata.confidence.value,
                    importance=memory.metadata.importance.value,
                    timestamp=memory.metadata.created_at,
                    source="enhanced_memory",
                    relevance_score=memory.metadata.relevance_score
                )
                results.append(result)

            return results

        except Exception as e:
            logger.error(f"get_enhanced_memory_context 失败: {e}")
            return []

    async def shutdown(self):
        """关闭增强记忆系统"""
        if not self.is_initialized:
            return

        try:
            if self.enhanced_system:
                await self.enhanced_system.shutdown()
                logger.info("✅ 增强记忆系统已关闭")
        except Exception as e:
            logger.error(f"关闭增强记忆系统失败: {e}")


# 全局增强记忆管理器实例
enhanced_memory_manager = EnhancedMemoryManager()
@@ -1,168 +0,0 @@
# 混合瞬时记忆系统设计

## 系统概述

融合 `instant_memory.py`(LLM系统)和 `vector_instant_memory.py`(向量系统)的混合记忆系统,智能选择最优策略,无需配置文件控制。

## 融合架构

```
聊天输入 → 智能调度器 → 选择策略 → 双重存储 → 融合检索 → 统一输出
```

## 核心组件设计

### 1. HybridInstantMemory (主类)

**职责**: 统一接口,智能调度两套记忆系统

**关键方法**:
- `__init__(chat_id)` - 初始化两套子系统
- `create_and_store_memory(text)` - 智能存储记忆
- `get_memory(target)` - 融合检索记忆
- `get_stats()` - 统计信息

### 2. MemoryStrategy (策略判断器)

**职责**: 判断使用哪种记忆策略

**判断规则**:
- 文本长度 < 30字符 → 优先向量系统(快速)
- 包含情感词汇/重要信息 → 使用LLM系统(准确)
- 复杂场景 → 双重验证

**实现方法**:
```python
def decide_strategy(self, text: str) -> MemoryMode:
    # 长度判断
    if len(text) < 30:
        return MemoryMode.VECTOR_ONLY

    # 情感关键词检测
    if self._contains_emotional_content(text):
        return MemoryMode.LLM_PREFERRED

    # 默认混合模式
    return MemoryMode.HYBRID
```
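
上面引用的 `_contains_emotional_content` 在本文档中没有给出定义,下面是一个最小示意实现(假设采用关键词表匹配,词表内容仅为示例,实际可替换为情感分类模型):

```python
EMOTIONAL_KEYWORDS = {"喜欢", "讨厌", "开心", "难过", "生气", "重要", "记住", "永远"}

def _contains_emotional_content(self, text: str) -> bool:
    # 简单的关键词命中判断:任一情感/重要性词出现即走LLM路径
    return any(keyword in text for keyword in EMOTIONAL_KEYWORDS)
```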

### 3. MemorySync (同步器)

**职责**: 处理两套系统间的记忆同步和去重

**同步策略**(示意代码见下方):
- 向量系统存储的记忆 → 异步同步到LLM系统
- LLM系统生成的高质量记忆 → 生成向量存储
- 定期去重,避免重复记忆
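
同步与去重的一个最小示意(假设两套子系统分别暴露 `get_all_memories()` / `store(text)` 接口,相似度阈值 0.9 为示例值):

```python
import difflib

class MemorySync:
    SIMILARITY_THRESHOLD = 0.9  # 高于该值视为重复,示例值

    def __init__(self, vector_memory, llm_memory):
        self.vector_memory = vector_memory
        self.llm_memory = llm_memory

    def _is_duplicate(self, text: str, existing: list[str]) -> bool:
        # 用序列相似度近似判断重复;实际实现可换成向量余弦相似度
        return any(
            difflib.SequenceMatcher(None, text, other).ratio() >= self.SIMILARITY_THRESHOLD
            for other in existing
        )

    async def sync_to_llm(self, text: str):
        # 向量侧新记忆异步补写到LLM侧,写入前先去重
        existing = await self.llm_memory.get_all_memories()
        if not self._is_duplicate(text, existing):
            await self.llm_memory.store(text)
```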

### 4. HybridRetriever (检索器)

**职责**: 融合两种检索方式,提供最优结果

**检索策略**:
1. 并行查询向量系统和LLM系统
2. 按相似度/相关性排序
3. 去重合并,返回最相关的记忆

## 智能调度逻辑

### 快速路径 (Vector Path)
- 适用: 短文本、常规对话、快速查询
- 优势: 响应速度快,资源消耗低
- 时机: 文本简单、无特殊情感内容

### 准确路径 (LLM Path)
- 适用: 重要信息、情感表达、复杂语义
- 优势: 语义理解深度,记忆质量高
- 时机: 检测到重要性标志

### 混合路径 (Hybrid Path)
- 适用: 中等复杂度内容
- 策略: 向量快速筛选 + LLM精确处理
- 平衡: 速度与准确性

## 记忆存储策略

### 双重备份机制
1. **主存储**: 根据策略选择主要存储方式
2. **备份存储**: 异步备份到另一系统
3. **同步检查**: 定期校验两边数据一致性

### 存储优化
- 向量系统: 立即存储,快速可用
- LLM系统: 批量处理,高质量整理
- 重复检测: 跨系统去重

## 检索融合策略

### 并行检索
```python
async def get_memory(self, target: str):
    # 并行查询两个系统
    vector_task = self.vector_memory.get_memory(target)
    llm_task = self.llm_memory.get_memory(target)

    vector_results, llm_results = await asyncio.gather(
        vector_task, llm_task, return_exceptions=True
    )

    # 融合结果
    return self._merge_results(vector_results, llm_results)
```

### 结果融合(示意代码见下方)
1. **相似度评分**: 统一两种系统的相似度计算
2. **权重调整**: 根据查询类型调整系统权重
3. **去重合并**: 移除重复内容,保留最相关的
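
`_merge_results` 的一个最小示意(假设每条结果是 `(text, score)` 元组,0.6/0.4 的系统权重为示例值):

```python
def _merge_results(self, vector_results, llm_results,
                   vector_weight: float = 0.6, llm_weight: float = 0.4):
    merged: dict[str, float] = {}

    # gather(return_exceptions=True) 可能返回异常对象,先过滤
    for results, weight in ((vector_results, vector_weight), (llm_results, llm_weight)):
        if isinstance(results, Exception) or not results:
            continue
        for text, score in results:
            # 同一内容取加权后的最高分,顺带完成去重合并
            merged[text] = max(merged.get(text, 0.0), score * weight)

    # 按融合得分从高到低返回
    return sorted(merged.items(), key=lambda item: item[1], reverse=True)
```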

## 性能优化

### 异步处理
- 向量检索: 同步快速响应
- LLM处理: 异步后台处理
- 批量操作: 减少系统调用开销

### 缓存策略
- 热点记忆缓存
- 查询结果缓存
- 向量计算缓存

### 降级机制(示意代码见下方)
- 向量系统故障 → 只使用LLM系统
- LLM系统故障 → 只使用向量系统
- 全部故障 → 返回空结果,记录错误
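
降级逻辑的一个最小示意(沿用上文的并行检索,假设子系统抛出异常即视为故障):

```python
async def get_memory_with_fallback(self, target: str):
    vector_results, llm_results = await asyncio.gather(
        self.vector_memory.get_memory(target),
        self.llm_memory.get_memory(target),
        return_exceptions=True,
    )

    vector_ok = not isinstance(vector_results, Exception)
    llm_ok = not isinstance(llm_results, Exception)

    if vector_ok and llm_ok:
        return self._merge_results(vector_results, llm_results)
    if vector_ok:
        return vector_results  # LLM系统故障,只用向量结果
    if llm_ok:
        return llm_results  # 向量系统故障,只用LLM结果

    logger.error("两套记忆子系统均不可用,返回空结果")
    return []
```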

## 实现计划

1. **基础框架**: 创建HybridInstantMemory主类
2. **策略判断**: 实现智能调度逻辑
3. **存储融合**: 实现双重存储机制
4. **检索融合**: 实现并行检索和结果合并
5. **同步机制**: 实现跨系统数据同步
6. **性能优化**: 异步处理和缓存优化
7. **错误处理**: 降级机制和异常处理

## 使用接口

```python
# 初始化混合记忆系统
hybrid_memory = HybridInstantMemory(chat_id="user_123")

# 智能存储记忆
await hybrid_memory.create_and_store_memory("今天天气真好,我去公园散步了")

# 融合检索记忆
memories = await hybrid_memory.get_memory("天气")

# 获取系统状态
stats = hybrid_memory.get_stats()
print(f"向量记忆: {stats['vector_count']} 条")
print(f"LLM记忆: {stats['llm_count']} 条")
```

## 预期效果

- **响应速度**: 比纯LLM系统快60%+
- **记忆质量**: 比纯向量系统准确30%+
- **资源使用**: 智能调度,按需使用资源
- **可靠性**: 双系统备份,单点故障不影响服务
@@ -1,254 +0,0 @@
# -*- coding: utf-8 -*-
import time
import re
import orjson
import traceback

from json_repair import repair_json
from datetime import datetime, timedelta

from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import Memory  # SQLAlchemy Models导入
from src.common.database.sqlalchemy_database_api import get_db_session
from src.config.config import model_config

from sqlalchemy import select

logger = get_logger(__name__)


class MemoryItem:
    def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
        self.memory_id = memory_id
        self.chat_id = chat_id
        self.memory_text: str = memory_text
        self.keywords: list[str] = keywords
        self.create_time: float = time.time()
        self.last_view_time: float = time.time()


class InstantMemory:
    def __init__(self, chat_id):
        self.chat_id = chat_id
        self.last_view_time = time.time()
        self.summary_model = LLMRequest(
            model_set=model_config.model_task_config.utils,
            request_type="memory.summary",
        )

    async def if_need_build(self, text):
        prompt = f"""
请判断以下内容中是否有值得记忆的信息,如果有,请输出1,否则输出0
{text}
请只输出1或0就好
"""

        try:
            response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
            print(prompt)
            print(response)

            return "1" in response
        except Exception as e:
            logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
            return False

    async def build_memory(self, text):
        prompt = f"""
以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出
{text}
请以json格式输出一段概括的记忆内容和关键词
{{
    "memory_text": "记忆内容",
    "keywords": "关键词,用/划分"
}}
"""
        try:
            response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
            if not response:
                return None
            try:
                repaired = repair_json(response)
                result = orjson.loads(repaired)
                memory_text = result.get("memory_text", "")
                keywords = result.get("keywords", "")
                if isinstance(keywords, str):
                    keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
                elif isinstance(keywords, list):
                    keywords_list = keywords
                else:
                    keywords_list = []
                return {"memory_text": memory_text, "keywords": keywords_list}
            except Exception as parse_e:
                logger.error(f"解析记忆json失败:{str(parse_e)} {traceback.format_exc()}")
                return None
        except Exception as e:
            logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
            return None

    async def create_and_store_memory(self, text):
        if_need = await self.if_need_build(text)
        if if_need:
            logger.info(f"需要记忆:{text}")
            memory = await self.build_memory(text)
            if memory and memory.get("memory_text"):
                memory_id = f"{self.chat_id}_{time.time()}"
                memory_item = MemoryItem(
                    memory_id=memory_id,
                    chat_id=self.chat_id,
                    memory_text=memory["memory_text"],
                    keywords=memory.get("keywords", []),
                )
                await self.store_memory(memory_item)
        else:
            logger.info(f"不需要记忆:{text}")

    @staticmethod
    async def store_memory(memory_item: MemoryItem):
        async with get_db_session() as session:
            memory = Memory(
                memory_id=memory_item.memory_id,
                chat_id=memory_item.chat_id,
                memory_text=memory_item.memory_text,
                keywords=orjson.dumps(memory_item.keywords).decode("utf-8"),
                create_time=memory_item.create_time,
                last_view_time=memory_item.last_view_time,
            )
            session.add(memory)
            await session.commit()

    async def get_memory(self, target: str):
        from json_repair import repair_json

        prompt = f"""
请根据以下发言内容,判断是否需要提取记忆
{target}
请用json格式输出,包含以下字段:
其中,time的要求是:
可以选择具体日期时间,格式为YYYY-MM-DD HH:MM:SS,或者大致时间,格式为YYYY-MM-DD
可以选择相对时间,例如:今天,昨天,前天,5天前,1个月前
可以选择留空进行模糊搜索
{{
    "need_memory": 1,
    "keywords": "希望获取的记忆关键词,用/划分",
    "time": "希望获取的记忆大致时间"
}}
请只输出json格式,不要输出其他多余内容
"""
        try:
            response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
            print(prompt)
            print(response)
            if not response:
                return None
            try:
                repaired = repair_json(response)
                result = orjson.loads(repaired)
                # 解析keywords
                keywords = result.get("keywords", "")
                if isinstance(keywords, str):
                    keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
                elif isinstance(keywords, list):
                    keywords_list = keywords
                else:
                    keywords_list = []
                # 解析time为时间段
                time_str = result.get("time", "").strip()
                start_time, end_time = self._parse_time_range(time_str)
                logger.info(f"start_time: {start_time}, end_time: {end_time}")
                # 检索包含关键词的记忆
                memories_set = set()
                async with get_db_session() as session:
                    if start_time and end_time:
                        start_ts = start_time.timestamp()
                        end_ts = end_time.timestamp()

                        query = (await session.execute(
                            select(Memory).where(
                                (Memory.chat_id == self.chat_id)
                                & (Memory.create_time >= start_ts)
                                & (Memory.create_time < end_ts)
                            )
                        )).scalars()
                    else:
                        result = await session.execute(select(Memory).where(Memory.chat_id == self.chat_id))
                        query = result.scalars()
                    for mem in query:
                        # 对每条记忆,任一关键词命中即收录
                        mem_keywords_str = mem.keywords or "[]"
                        try:
                            mem_keywords = orjson.loads(mem_keywords_str)
                        except orjson.JSONDecodeError:
                            mem_keywords = []
                        for kw in keywords_list:
                            if kw in mem_keywords:
                                memories_set.add(mem.memory_text)
                                break
                return list(memories_set)
            except Exception as parse_e:
                logger.error(f"解析记忆json失败:{str(parse_e)} {traceback.format_exc()}")
                return None
        except Exception as e:
            logger.error(f"获取记忆出现错误:{str(e)} {traceback.format_exc()}")
            return None

    @staticmethod
    def _parse_time_range(time_str):
        # sourcery skip: extract-duplicate-method, use-contextlib-suppress
        """
        支持解析如下格式:
        - 具体日期时间:YYYY-MM-DD HH:MM:SS
        - 具体日期:YYYY-MM-DD
        - 相对时间:今天,昨天,前天,N天前,N个月前
        - 空字符串或无法解析:返回 (0, now),start 为 0 表示不限起始时间
        """
        now = datetime.now()
        if not time_str:
            return 0, now
        time_str = time_str.strip()
        # 具体日期时间
        try:
            dt = datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
            return dt, dt + timedelta(hours=1)
        except Exception:
            ...
        # 具体日期
        try:
            dt = datetime.strptime(time_str, "%Y-%m-%d")
            return dt, dt + timedelta(days=1)
        except Exception:
            ...
        # 相对时间
        if time_str == "今天":
            start = now.replace(hour=0, minute=0, second=0, microsecond=0)
            end = start + timedelta(days=1)
            return start, end
        if time_str == "昨天":
            start = (now - timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
            end = start + timedelta(days=1)
            return start, end
        if time_str == "前天":
            start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0)
            end = start + timedelta(days=1)
            return start, end
        if m := re.match(r"(\d+)天前", time_str):
            days = int(m.group(1))
            start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
            end = start + timedelta(days=1)
            return start, end
        if m := re.match(r"(\d+)个月前", time_str):
            months = int(m.group(1))
            # 近似每月30天
            start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
            end = start + timedelta(days=1)
            return start, end
        # 其他无法解析
        return 0, now
src/chat/memory_system/integration_layer.py (new file, 255 lines)
@@ -0,0 +1,255 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统集成层
现在只管理新的增强记忆系统,旧系统已被完全移除
"""

import time
import asyncio
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from enum import Enum

from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.llm_models.utils_model import LLMRequest

logger = get_logger(__name__)


class IntegrationMode(Enum):
    """集成模式"""
    REPLACE = "replace"  # 完全替换现有记忆系统
    ENHANCED_ONLY = "enhanced_only"  # 仅使用增强记忆系统


@dataclass
class IntegrationConfig:
    """集成配置"""
    mode: IntegrationMode = IntegrationMode.ENHANCED_ONLY
    enable_enhanced_memory: bool = True
    memory_value_threshold: float = 0.6
    fusion_threshold: float = 0.85
    max_retrieval_results: int = 10
    enable_learning: bool = True


class MemoryIntegrationLayer:
    """记忆系统集成层 - 现在只管理增强记忆系统"""

    def __init__(self, llm_model: LLMRequest, config: Optional[IntegrationConfig] = None):
        self.llm_model = llm_model
        self.config = config or IntegrationConfig()

        # 只初始化增强记忆系统
        self.enhanced_memory: Optional[EnhancedMemorySystem] = None

        # 集成统计
        self.integration_stats = {
            "total_queries": 0,
            "enhanced_queries": 0,
            "memory_creations": 0,
            "average_response_time": 0.0,
            "success_rate": 0.0
        }

        # 初始化锁
        self._initialization_lock = asyncio.Lock()
        self._initialized = False

    async def initialize(self):
        """初始化集成层"""
        if self._initialized:
            return

        async with self._initialization_lock:
            if self._initialized:
                return

            logger.info("🚀 开始初始化增强记忆系统集成层...")

            try:
                # 初始化增强记忆系统
                if self.config.enable_enhanced_memory:
                    await self._initialize_enhanced_memory()

                self._initialized = True
                logger.info("✅ 增强记忆系统集成层初始化完成")

            except Exception as e:
                logger.error(f"❌ 集成层初始化失败: {e}", exc_info=True)
                raise

    async def _initialize_enhanced_memory(self):
        """初始化增强记忆系统"""
        try:
            logger.debug("初始化增强记忆系统...")

            # 创建增强记忆系统配置
            from src.chat.memory_system.enhanced_memory_core import MemorySystemConfig
            memory_config = MemorySystemConfig.from_global_config()

            # 使用集成配置覆盖部分值
            memory_config.memory_value_threshold = self.config.memory_value_threshold
            memory_config.fusion_similarity_threshold = self.config.fusion_threshold
            memory_config.final_recall_limit = self.config.max_retrieval_results

            # 创建增强记忆系统
            self.enhanced_memory = EnhancedMemorySystem(
                config=memory_config
            )

            # 如果外部提供了LLM模型,注入到系统中
            if self.llm_model is not None:
                self.enhanced_memory.llm_model = self.llm_model

            # 初始化系统
            await self.enhanced_memory.initialize()
            logger.info("✅ 增强记忆系统初始化完成")

        except Exception as e:
            logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
            raise

    async def process_conversation(
        self,
        conversation_text: str,
        context: Dict[str, Any],
        user_id: str,
        timestamp: Optional[float] = None
    ) -> Dict[str, Any]:
        """处理对话记忆"""
        if not self._initialized or not self.enhanced_memory:
            return {"success": False, "error": "Memory system not available"}

        start_time = time.time()
        self.integration_stats["total_queries"] += 1
        self.integration_stats["enhanced_queries"] += 1

        try:
            # 直接使用增强记忆系统处理
            result = await self.enhanced_memory.process_conversation_memory(
                conversation_text=conversation_text,
                context=context,
                user_id=user_id,
                timestamp=timestamp
            )

            # 更新统计
            processing_time = time.time() - start_time
            self._update_response_stats(processing_time, result.get("success", False))

            if result.get("success"):
                created_count = len(result.get("created_memories", []))
                self.integration_stats["memory_creations"] += created_count
                logger.debug(f"对话处理完成,创建 {created_count} 条记忆")

            return result

        except Exception as e:
            processing_time = time.time() - start_time
            self._update_response_stats(processing_time, False)
            logger.error(f"处理对话记忆失败: {e}", exc_info=True)
            return {"success": False, "error": str(e)}

    async def retrieve_relevant_memories(
        self,
        query: str,
        user_id: str,
        context: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None
    ) -> List[MemoryChunk]:
        """检索相关记忆"""
        if not self._initialized or not self.enhanced_memory:
            return []

        try:
            limit = limit or self.config.max_retrieval_results
            memories = await self.enhanced_memory.retrieve_relevant_memories(
                query=query,
                user_id=user_id,
                context=context or {},
                limit=limit
            )

            logger.debug(f"检索到 {len(memories)} 条相关记忆")
            return memories

        except Exception as e:
            logger.error(f"检索相关记忆失败: {e}", exc_info=True)
            return []

    async def get_system_status(self) -> Dict[str, Any]:
        """获取系统状态"""
        if not self._initialized:
            return {"status": "not_initialized"}

        try:
            enhanced_status = {}
            if self.enhanced_memory:
                enhanced_status = await self.enhanced_memory.get_system_status()

            return {
                "status": "initialized",
                "mode": self.config.mode.value,
                "enhanced_memory": enhanced_status,
                "integration_stats": self.integration_stats.copy()
            }

        except Exception as e:
            logger.error(f"获取系统状态失败: {e}", exc_info=True)
            return {"status": "error", "error": str(e)}

    def get_integration_stats(self) -> Dict[str, Any]:
        """获取集成统计信息"""
        return self.integration_stats.copy()

    def _update_response_stats(self, processing_time: float, success: bool):
        """更新响应统计"""
        total_queries = self.integration_stats["total_queries"]
        if total_queries > 0:
            # 更新平均响应时间
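            # 增量均值:new_avg = (old_avg * (n - 1) + x) / n,无需保存全部历史耗时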
            current_avg = self.integration_stats["average_response_time"]
            new_avg = (current_avg * (total_queries - 1) + processing_time) / total_queries
            self.integration_stats["average_response_time"] = new_avg

            # 更新成功率(失败同样参与稀释,否则成功率只升不降)
            current_success_rate = self.integration_stats["success_rate"]
            new_success_rate = (current_success_rate * (total_queries - 1) + (1 if success else 0)) / total_queries
            self.integration_stats["success_rate"] = new_success_rate

    async def maintenance(self):
        """执行维护操作"""
        if not self._initialized:
            return

        try:
            logger.info("🔧 执行记忆系统集成层维护...")

            if self.enhanced_memory:
                await self.enhanced_memory.maintenance()

            logger.info("✅ 记忆系统集成层维护完成")

        except Exception as e:
            logger.error(f"❌ 集成层维护失败: {e}", exc_info=True)

    async def shutdown(self):
        """关闭集成层"""
        if not self._initialized:
            return

        try:
            logger.info("🔄 关闭记忆系统集成层...")

            if self.enhanced_memory:
                await self.enhanced_memory.shutdown()

            self._initialized = False
            logger.info("✅ 记忆系统集成层已关闭")

        except Exception as e:
            logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True)
@@ -1,144 +0,0 @@
import difflib
import orjson

from json_repair import repair_json
from typing import List, Dict
from datetime import datetime

from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.Hippocampus import hippocampus_manager


logger = get_logger("memory_activator")


def get_keywords_from_json(json_str) -> List:
    """
    从JSON字符串中提取关键词列表

    Args:
        json_str: JSON格式的字符串

    Returns:
        List[str]: 关键词列表
    """
    try:
        # 使用repair_json修复JSON格式
        fixed_json = repair_json(json_str)

        # 如果repair_json返回的是字符串,需要解析为Python对象
        result = orjson.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
        return result.get("keywords", [])
    except Exception as e:
        logger.error(f"解析关键词JSON失败: {e}")
        return []


def init_prompt():
    # --- Group Chat Prompt ---
    memory_activator_prompt = """
    你是一个记忆分析器,你需要根据以下信息来进行回忆
    以下是一段聊天记录,请根据这些信息,总结出几个关键词作为记忆回忆的触发词

    聊天记录:
    {obs_info_text}
    你想要回复的消息:
    {target_message}

    历史关键词(请避免重复提取这些关键词):
    {cached_keywords}

    请输出一个json格式,包含以下字段:
    {{
        "keywords": ["关键词1", "关键词2", "关键词3",......]
    }}
    不要输出其他多余内容,只输出json格式就好
    """

    Prompt(memory_activator_prompt, "memory_activator_prompt")


class MemoryActivator:
    def __init__(self):
        self.key_words_model = LLMRequest(
            model_set=model_config.model_task_config.utils_small,
            request_type="memory.activator",
        )

        self.running_memory = []
        self.cached_keywords = set()  # 用于缓存历史关键词

    async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
        """
        激活记忆
        """
        # 如果记忆系统被禁用,直接返回空列表
        if not global_config.memory.enable_memory:
            return []

        # 将缓存的关键词转换为字符串,用于prompt
        cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "暂无历史关键词"

        prompt = await global_prompt_manager.format_prompt(
            "memory_activator_prompt",
            obs_info_text=chat_history_prompt,
            target_message=target_message,
            cached_keywords=cached_keywords_str,
        )

        response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
            prompt, temperature=0.5
        )

        keywords = list(get_keywords_from_json(response))

        # 更新关键词缓存
        if keywords:
            # 限制缓存大小,超过10个时裁剪到8个(集合无序,并非严格移除最早的关键词)
            if len(self.cached_keywords) > 10:
                cached_list = list(self.cached_keywords)
                self.cached_keywords = set(cached_list[-8:])

            # 添加新的关键词到缓存
            self.cached_keywords.update(keywords)

        # 调用记忆系统获取相关记忆
        related_memory = await hippocampus_manager.get_memory_from_topic(
            valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
        )

        logger.debug(f"当前记忆关键词: {self.cached_keywords} ")
        logger.debug(f"获取到的记忆: {related_memory}")

        # 激活时,所有已有记忆的duration+1,达到3则移除
        for m in self.running_memory[:]:
            m["duration"] = m.get("duration", 1) + 1
        self.running_memory = [m for m in self.running_memory if m["duration"] < 3]

        if related_memory:
            for topic, memory in related_memory:
                # 检查是否已存在相同topic或相似内容(相似度>=0.7)的记忆
                exists = any(
                    m["topic"] == topic or difflib.SequenceMatcher(None, m["content"], memory).ratio() >= 0.7
                    for m in self.running_memory
                )
                if not exists:
                    self.running_memory.append(
                        {"topic": topic, "content": memory, "timestamp": datetime.now().isoformat(), "duration": 1}
                    )
                    logger.debug(f"添加新记忆: {topic} - {memory}")

        # 限制同时加载的记忆条数,最多保留最后3条
        if len(self.running_memory) > 3:
            self.running_memory = self.running_memory[-3:]

        return self.running_memory


init_prompt()
src/chat/memory_system/memory_builder.py (new file, 602 lines)
@@ -0,0 +1,602 @@
# -*- coding: utf-8 -*-
"""
Memory building module
Extracts high-quality, structured memory units from conversation streams
"""

import re
import time
import orjson
from typing import Dict, List, Optional, Tuple, Any, Set
from datetime import datetime
from dataclasses import dataclass
from enum import Enum

from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.chat.memory_system.memory_chunk import (
    MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
    ContentStructure, MemoryMetadata, create_memory_chunk
)

logger = get_logger(__name__)


class ExtractionStrategy(Enum):
    """Extraction strategy"""
    LLM_BASED = "llm_based"  # LLM-driven intelligent extraction
    RULE_BASED = "rule_based"  # Rule-based extraction
    HYBRID = "hybrid"  # Hybrid strategy


@dataclass
class ExtractionResult:
    """Extraction result"""
    memories: List[MemoryChunk]
    confidence_scores: List[float]
    extraction_time: float
    strategy_used: ExtractionStrategy


class MemoryBuilder:
    """Memory builder"""

    def __init__(self, llm_model: LLMRequest):
        self.llm_model = llm_model
        self.extraction_stats = {
            "total_extractions": 0,
            "successful_extractions": 0,
            "failed_extractions": 0,
            "average_confidence": 0.0
        }

    async def build_memories(
        self,
        conversation_text: str,
        context: Dict[str, Any],
        user_id: str,
        timestamp: float
    ) -> List[MemoryChunk]:
        """Build memories from a conversation"""
        start_time = time.time()

        try:
            logger.debug(f"Building memories from conversation, text length: {len(conversation_text)}")

            # Preprocess the text
            processed_text = self._preprocess_text(conversation_text)

            # Determine the extraction strategy
            strategy = self._determine_extraction_strategy(processed_text, context)

            # Extract memories according to the strategy
            if strategy == ExtractionStrategy.LLM_BASED:
                memories = await self._extract_with_llm(processed_text, context, user_id, timestamp)
            elif strategy == ExtractionStrategy.RULE_BASED:
                memories = self._extract_with_rules(processed_text, context, user_id, timestamp)
            else:  # HYBRID
                memories = await self._extract_with_hybrid(processed_text, context, user_id, timestamp)

            # Post-process and validate
            validated_memories = self._validate_and_enhance_memories(memories, context)

            # Update statistics
            extraction_time = time.time() - start_time
            self._update_extraction_stats(len(validated_memories), extraction_time)

            logger.info(f"✅ Built {len(validated_memories)} memories in {extraction_time:.2f}s")
            return validated_memories

        except Exception as e:
            logger.error(f"❌ Memory construction failed: {e}", exc_info=True)
            self.extraction_stats["failed_extractions"] += 1
            return []

    def _preprocess_text(self, text: str) -> str:
        """Preprocess text"""
        # Collapse redundant whitespace
        text = re.sub(r'\s+', ' ', text.strip())

        # Strip special characters while keeping basic punctuation
        text = re.sub(r'[^\w\s\u4e00-\u9fff,。!?、;:""''()【】]', '', text)

        # Truncate overly long text
        if len(text) > 2000:
            text = text[:2000] + "..."

        return text

    def _determine_extraction_strategy(self, text: str, context: Dict[str, Any]) -> ExtractionStrategy:
        """Determine the extraction strategy"""
        text_length = len(text)
        has_structured_data = any(key in context for key in ["structured_data", "entities", "keywords"])
        message_type = context.get("message_type", "normal")

        # Short texts use rule-based extraction
        if text_length < 50:
            return ExtractionStrategy.RULE_BASED

        # Texts carrying structured data use the hybrid strategy
        if has_structured_data:
            return ExtractionStrategy.HYBRID

        # System messages and commands use rule-based extraction
        if message_type in ["command", "system"]:
            return ExtractionStrategy.RULE_BASED

        # Default to LLM-based extraction
        return ExtractionStrategy.LLM_BASED

    async def _extract_with_llm(
        self,
        text: str,
        context: Dict[str, Any],
        user_id: str,
        timestamp: float
    ) -> List[MemoryChunk]:
        """Extract memories with the LLM"""
        try:
            prompt = self._build_llm_extraction_prompt(text, context)

            response, _ = await self.llm_model.generate_response_async(
                prompt, temperature=0.3
            )

            # Parse the LLM response
            memories = self._parse_llm_response(response, user_id, timestamp, context)

            return memories

        except Exception as e:
            logger.error(f"LLM extraction failed: {e}")
            return []

    def _extract_with_rules(
        self,
        text: str,
        context: Dict[str, Any],
        user_id: str,
        timestamp: float
    ) -> List[MemoryChunk]:
        """Extract memories with rules"""
        memories = []

        # Rule 1: detect personal information
        personal_info = self._extract_personal_info(text, user_id, timestamp, context)
        memories.extend(personal_info)

        # Rule 2: detect preference information
        preferences = self._extract_preferences(text, user_id, timestamp, context)
        memories.extend(preferences)

        # Rule 3: detect event information
        events = self._extract_events(text, user_id, timestamp, context)
        memories.extend(events)

        return memories

    async def _extract_with_hybrid(
        self,
        text: str,
        context: Dict[str, Any],
        user_id: str,
        timestamp: float
    ) -> List[MemoryChunk]:
        """Extract memories with the hybrid strategy"""
        all_memories = []

        # First extract with rules
        rule_memories = self._extract_with_rules(text, context, user_id, timestamp)
        all_memories.extend(rule_memories)

        # Then extract with the LLM
        llm_memories = await self._extract_with_llm(text, context, user_id, timestamp)

        # Merge and deduplicate
        final_memories = self._merge_hybrid_results(all_memories, llm_memories)

        return final_memories

    def _build_llm_extraction_prompt(self, text: str, context: Dict[str, Any]) -> str:
        """Build the LLM extraction prompt (the prompt text is the model's working input and stays in Chinese)"""
        current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        chat_id = context.get("chat_id", "unknown")
        message_type = context.get("message_type", "normal")

        prompt = f"""
你是一个专业的记忆提取专家。请从以下对话中主动识别并提取所有可能重要的信息,特别是包含个人事实、事件、偏好、观点等要素的内容。

当前时间: {current_date}
聊天ID: {chat_id}
消息类型: {message_type}

对话内容:
{text}

## 🎯 重点记忆类型识别指南

### 1. **个人事实** (personal_fact) - 高优先级记忆
**包括但不限于:**
- 基本信息:姓名、年龄、职业、学校、专业、工作地点
- 生活状况:住址、电话、邮箱、社交账号
- 身份特征:生日、星座、血型、国籍、语言能力
- 健康信息:身体状况、疾病史、药物过敏、运动习惯
- 家庭情况:家庭成员、婚姻状况、子女信息、宠物信息

**判断标准:** 涉及个人身份和生活的重要信息,都应该记忆

### 2. **事件** (event) - 高优先级记忆
**包括但不限于:**
- 重要时刻:生日聚会、毕业典礼、婚礼、旅行
- 日常活动:上班、上学、约会、看电影、吃饭
- 特殊经历:考试、面试、会议、搬家、购物
- 计划安排:约会、会议、旅行、活动

**判断标准:** 涉及时间地点的具体活动和经历,都应该记忆

### 3. **偏好** (preference) - 高优先级记忆
**包括但不限于:**
- 饮食偏好:喜欢的食物、餐厅、口味、禁忌
- 娱乐喜好:喜欢的电影、音乐、游戏、书籍
- 生活习惯:作息时间、运动方式、购物习惯
- 消费偏好:品牌喜好、价格敏感度、购物场所
- 风格偏好:服装风格、装修风格、颜色喜好

**判断标准:** 任何表达"喜欢"、"不喜欢"、"习惯"、"经常"等偏好的内容,都应该记忆

### 4. **观点** (opinion) - 高优先级记忆
**包括但不限于:**
- 评价看法:对事物的评价、意见、建议
- 价值判断:认为什么重要、什么不重要
- 态度立场:支持、反对、中立的态度
- 感受反馈:对经历的感受、反馈

**判断标准:** 任何表达主观看法和态度的内容,都应该记忆

### 5. **关系** (relationship) - 中等优先级记忆
**包括但不限于:**
- 人际关系:朋友、同事、家人、恋人的关系状态
- 社交互动:与他人的互动、交流、合作
- 群体归属:所属团队、组织、社群

### 6. **情感** (emotion) - 中等优先级记忆
**包括但不限于:**
- 情绪状态:开心、难过、生气、焦虑、兴奋
- 情感变化:情绪的转变、原因和结果

### 7. **目标** (goal) - 中等优先级记忆
**包括但不限于:**
- 计划安排:短期计划、长期目标
- 愿望期待:想要实现的事情、期望的结果

## 📝 记忆提取原则

### ✅ 积极提取原则:
1. **宁可错记,不可遗漏** - 对于可能的个人信息优先记忆
2. **持续追踪** - 相同信息的多次提及要强化记忆
3. **上下文关联** - 结合对话背景理解信息重要性
4. **细节丰富** - 记录具体的细节和描述

### 🎯 重要性等级标准:
- **4分 (关键)**:个人核心信息(姓名、联系方式、重要日期)
- **3分 (高)**:重要偏好、观点、经历事件
- **2分 (一般)**:一般性信息、日常活动、感受表达
- **1分 (低)**:琐碎细节、重复信息、临时状态

### 🔍 置信度标准:
- **4分 (已验证)**:用户明确确认的信息
- **3分 (高)**:用户直接表达的清晰信息
- **2分 (中等)**:需要推理或上下文判断的信息
- **1分 (低)**:模糊或不完整的信息

输出格式要求:
{{
    "memories": [
        {{
            "type": "记忆类型",
            "subject": "主语(通常是用户)",
            "predicate": "谓语(动作/状态)",
            "object": "宾语(对象/属性)",
            "keywords": ["关键词1", "关键词2"],
            "importance": "重要性等级(1-4)",
            "confidence": "置信度(1-4)",
            "reasoning": "提取理由"
        }}
    ]
}}

注意:
1. 只提取确实值得记忆的信息,不要提取琐碎内容
2. 确保提取的信息准确、具体、有价值
3. 使用主谓宾结构确保信息清晰
4. 重要性等级: 1=低, 2=一般, 3=高, 4=关键
5. 置信度: 1=低, 2=中等, 3=高, 4=已验证
"""

        return prompt

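    # For illustration only (hypothetical values, not part of this commit): a
    # well-formed model reply that _parse_llm_response below can consume looks like
    #
    #   {
    #     "memories": [
    #       {
    #         "type": "preference",
    #         "subject": "user_123",
    #         "predicate": "likes_food",
    #         "object": "拉面",
    #         "keywords": ["拉面", "饮食"],
    #         "importance": 3,
    #         "confidence": 3,
    #         "reasoning": "用户明确表达了饮食偏好"
    #       }
    #     ]
    #   }
    #
    # orjson.loads() needs the reply to be bare JSON; markdown code fences around
    # it would make parsing fail and the memory list come back empty.
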
    def _parse_llm_response(
        self,
        response: str,
        user_id: str,
        timestamp: float,
        context: Dict[str, Any]
    ) -> List[MemoryChunk]:
        """Parse the LLM response"""
        memories = []

        try:
            data = orjson.loads(response)
            memory_list = data.get("memories", [])

            for mem_data in memory_list:
                try:
                    # Create the memory chunk; the model may quote the numeric
                    # ratings (the prompt shows them quoted), so coerce via int()
                    memory = create_memory_chunk(
                        user_id=user_id,
                        subject=mem_data.get("subject", user_id),
                        predicate=mem_data.get("predicate", ""),
                        obj=mem_data.get("object", ""),
                        memory_type=MemoryType(mem_data.get("type", "contextual")),
                        chat_id=context.get("chat_id"),
                        source_context=mem_data.get("reasoning", ""),
                        importance=ImportanceLevel(int(mem_data.get("importance", 2))),
                        confidence=ConfidenceLevel(int(mem_data.get("confidence", 2)))
                    )

                    # Add keywords
                    keywords = mem_data.get("keywords", [])
                    for keyword in keywords:
                        memory.add_keyword(keyword)

                    memories.append(memory)

                except Exception as e:
                    logger.warning(f"Failed to parse a single memory: {e}, data: {mem_data}")
                    continue

        except Exception as e:
            logger.error(f"Failed to parse LLM response: {e}, response: {response}")

        return memories

    def _extract_personal_info(
        self,
        text: str,
        user_id: str,
        timestamp: float,
        context: Dict[str, Any]
    ) -> List[MemoryChunk]:
        """Extract personal information"""
        memories = []

        # Common personal-information patterns (matched against raw user input)
        patterns = {
            r"我叫(\w+)": ("is_named", {"name": "$1"}),
            r"我今年(\d+)岁": ("is_age", {"age": "$1"}),
            r"我是(\w+)": ("is_profession", {"profession": "$1"}),
            r"我住在(\w+)": ("lives_in", {"location": "$1"}),
            r"我的电话是(\d+)": ("has_phone", {"phone": "$1"}),
            r"我的邮箱是(\w+@\w+\.\w+)": ("has_email", {"email": "$1"}),
        }

        for pattern, (predicate, obj_template) in patterns.items():
            match = re.search(pattern, text)
            if match:
                # Substitute the captured groups into the template's $N placeholders
                obj = obj_template
                for i, group in enumerate(match.groups(), 1):
                    obj = {k: v.replace(f"${i}", group) for k, v in obj.items()}

                memory = create_memory_chunk(
                    user_id=user_id,
                    subject=user_id,
                    predicate=predicate,
                    obj=obj,
                    memory_type=MemoryType.PERSONAL_FACT,
                    chat_id=context.get("chat_id"),
                    importance=ImportanceLevel.HIGH,
                    confidence=ConfidenceLevel.HIGH
                )

                memories.append(memory)

        return memories

    def _extract_preferences(
        self,
        text: str,
        user_id: str,
        timestamp: float,
        context: Dict[str, Any]
    ) -> List[MemoryChunk]:
        """Extract preference information"""
        memories = []

        # Preference patterns (matched against raw user input)
        preference_patterns = [
            (r"我喜欢(.+)", "likes"),
            (r"我不喜欢(.+)", "dislikes"),
            (r"我爱吃(.+)", "likes_food"),
            (r"我讨厌(.+)", "hates"),
            (r"我最喜欢的(.+)", "favorite_is"),
        ]

        for pattern, predicate in preference_patterns:
            match = re.search(pattern, text)
            if match:
                memory = create_memory_chunk(
                    user_id=user_id,
                    subject=user_id,
                    predicate=predicate,
                    obj=match.group(1),
                    memory_type=MemoryType.PREFERENCE,
                    chat_id=context.get("chat_id"),
                    importance=ImportanceLevel.NORMAL,
                    confidence=ConfidenceLevel.MEDIUM
                )

                memories.append(memory)

        return memories

    def _extract_events(
        self,
        text: str,
        user_id: str,
        timestamp: float,
        context: Dict[str, Any]
    ) -> List[MemoryChunk]:
        """Extract event information"""
        memories = []

        # Event trigger keywords (matched against raw user input)
        event_keywords = ["明天", "今天", "昨天", "上周", "下周", "约会", "会议", "活动", "旅行", "生日"]

        if any(keyword in text for keyword in event_keywords):
            memory = create_memory_chunk(
                user_id=user_id,
                subject=user_id,
                predicate="mentioned_event",
                obj={"event_text": text, "timestamp": timestamp},
                memory_type=MemoryType.EVENT,
                chat_id=context.get("chat_id"),
                importance=ImportanceLevel.NORMAL,
                confidence=ConfidenceLevel.MEDIUM
            )

            memories.append(memory)

        return memories

    def _merge_hybrid_results(
        self,
        rule_memories: List[MemoryChunk],
        llm_memories: List[MemoryChunk]
    ) -> List[MemoryChunk]:
        """Merge the results of the hybrid strategy"""
        all_memories = rule_memories.copy()

        # Add LLM memories while avoiding duplicates
        for llm_memory in llm_memories:
            is_duplicate = False
            for rule_memory in rule_memories:
                if llm_memory.is_similar_to(rule_memory, threshold=0.7):
                    is_duplicate = True
                    # Keep the higher of the two confidence levels
                    rule_memory.metadata.confidence = ConfidenceLevel(
                        max(rule_memory.metadata.confidence.value, llm_memory.metadata.confidence.value)
                    )
                    break

            if not is_duplicate:
                all_memories.append(llm_memory)

        return all_memories

    def _validate_and_enhance_memories(
        self,
        memories: List[MemoryChunk],
        context: Dict[str, Any]
    ) -> List[MemoryChunk]:
        """Validate and enrich memories"""
        validated_memories = []

        for memory in memories:
            # Basic validation
            if not self._validate_memory(memory):
                continue

            # Enrich the memory
            enhanced_memory = self._enhance_memory(memory, context)
            validated_memories.append(enhanced_memory)

        return validated_memories

    def _validate_memory(self, memory: MemoryChunk) -> bool:
        """Validate a memory chunk"""
        # Check the essential fields
        if not memory.content.subject or not memory.content.predicate:
            logger.debug(f"Memory chunk is missing a subject or predicate: {memory.memory_id}")
            return False

        # Check the content length
        content_length = len(memory.text_content)
        if content_length < 5 or content_length > 500:
            logger.debug(f"Memory chunk has an abnormal content length: {content_length}")
            return False

        # Check the confidence
        if memory.metadata.confidence == ConfidenceLevel.LOW:
            logger.debug(f"Memory chunk confidence is too low: {memory.memory_id}")
            return False

        return True

    def _enhance_memory(
        self,
        memory: MemoryChunk,
        context: Dict[str, Any]
    ) -> MemoryChunk:
        """Enrich a memory chunk"""
        # Attach temporal context
        if not memory.temporal_context:
            memory.temporal_context = {
                "timestamp": memory.metadata.created_at,
                "timezone": context.get("timezone", "UTC"),
                "day_of_week": datetime.fromtimestamp(memory.metadata.created_at).strftime("%A")
            }

        # Attach emotional context, if available
        if context.get("sentiment"):
            memory.metadata.emotional_context = context["sentiment"]

        # Auto-tag the memory
        self._auto_tag_memory(memory)

        return memory

    def _auto_tag_memory(self, memory: MemoryChunk):
        """Automatically tag a memory"""
        # Default tags per memory type (the tag values are stored data and stay in Chinese)
        type_tags = {
            MemoryType.PERSONAL_FACT: ["个人信息", "基本资料"],
            MemoryType.EVENT: ["事件", "日程"],
            MemoryType.PREFERENCE: ["偏好", "喜好"],
            MemoryType.OPINION: ["观点", "态度"],
            MemoryType.RELATIONSHIP: ["关系", "社交"],
            MemoryType.EMOTION: ["情感", "情绪"],
            MemoryType.KNOWLEDGE: ["知识", "信息"],
            MemoryType.SKILL: ["技能", "能力"],
            MemoryType.GOAL: ["目标", "计划"],
            MemoryType.EXPERIENCE: ["经验", "经历"],
        }

        tags = type_tags.get(memory.memory_type, [])
        for tag in tags:
            memory.add_tag(tag)

    def _update_extraction_stats(self, success_count: int, extraction_time: float):
        """Update extraction statistics"""
        self.extraction_stats["total_extractions"] += 1
        self.extraction_stats["successful_extractions"] += success_count
        self.extraction_stats["failed_extractions"] += max(0, 1 - success_count)

        # Update the running average confidence
        if self.extraction_stats["successful_extractions"] > 0:
            total_confidence = self.extraction_stats["average_confidence"] * (self.extraction_stats["successful_extractions"] - success_count)
            # Assume newly created memories average 0.8 confidence
            total_confidence += 0.8 * success_count
            self.extraction_stats["average_confidence"] = total_confidence / self.extraction_stats["successful_extractions"]

    def get_extraction_stats(self) -> Dict[str, Any]:
        """Return a copy of the extraction statistics"""
        return self.extraction_stats.copy()

    def reset_stats(self):
        """Reset the statistics"""
        self.extraction_stats = {
            "total_extractions": 0,
            "successful_extractions": 0,
            "failed_extractions": 0,
            "average_confidence": 0.0
        }
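
# A minimal usage sketch (not part of this commit). It assumes an initialized
# LLMRequest instance is available; `llm` is a placeholder.
async def _memory_builder_demo(llm: LLMRequest) -> None:
    builder = MemoryBuilder(llm)
    memories = await builder.build_memories(
        conversation_text="我叫小明,我喜欢拉面",
        context={"chat_id": "chat_1", "message_type": "normal"},
        user_id="user_123",
        timestamp=time.time(),
    )
    for memory in memories:
        print(memory)  # MemoryChunk.__str__ renders type, triple, and rating icons
    print(builder.get_extraction_stats())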

463 src/chat/memory_system/memory_chunk.py Normal file
@@ -0,0 +1,463 @@
# -*- coding: utf-8 -*-
"""
Structured memory unit design
Implements high-quality, structured memory units per the design document
"""

import time
import uuid
import orjson
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
import hashlib

import numpy as np
from src.common.logger import get_logger

logger = get_logger(__name__)


class MemoryType(Enum):
    """Memory type taxonomy"""
    PERSONAL_FACT = "personal_fact"  # Personal facts (name, profession, address, ...)
    EVENT = "event"  # Events (significant experiences, appointments, ...)
    PREFERENCE = "preference"  # Preferences (likes, habits, ...)
    OPINION = "opinion"  # Opinions (views on things)
    RELATIONSHIP = "relationship"  # Relationships (ties to other people)
    EMOTION = "emotion"  # Emotional states
    KNOWLEDGE = "knowledge"  # Knowledge and information
    SKILL = "skill"  # Skills and abilities
    GOAL = "goal"  # Goals and plans
    EXPERIENCE = "experience"  # Experience and lessons learned
    CONTEXTUAL = "contextual"  # Contextual information


class ConfidenceLevel(Enum):
    """Confidence levels"""
    LOW = 1  # Low confidence, possibly inaccurate
    MEDIUM = 2  # Medium confidence, somewhat substantiated
    HIGH = 3  # High confidence, clearly sourced
    VERIFIED = 4  # Verified, highly reliable


class ImportanceLevel(Enum):
    """Importance levels"""
    LOW = 1  # Low importance, ordinary information
    NORMAL = 2  # Normal importance, everyday information
    HIGH = 3  # High importance, significant information
    CRITICAL = 4  # Critical importance, core information


@dataclass
class ContentStructure:
    """Subject-predicate-object triple"""
    subject: str  # Subject (usually the user)
    predicate: str  # Predicate (action, state, relation)
    object: Union[str, Dict]  # Object (target, attribute, value)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dict"""
        return {
            "subject": self.subject,
            "predicate": self.predicate,
            "object": self.object
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure":
        """Create an instance from a dict"""
        return cls(
            subject=data.get("subject", ""),
            predicate=data.get("predicate", ""),
            object=data.get("object", "")
        )

    def __str__(self) -> str:
        """String representation"""
        # str() covers both the plain-string and the dict case
        object_str = str(self.object)
        return f"{self.subject} {self.predicate} {object_str}"


@dataclass
class MemoryMetadata:
    """Memory metadata"""
    # Basic information
    memory_id: str  # Unique identifier
    user_id: str  # User ID
    chat_id: Optional[str] = None  # Chat ID (group or private chat)

    # Timing information
    created_at: float = 0.0  # Creation timestamp
    last_accessed: float = 0.0  # Last access time
    last_modified: float = 0.0  # Last modification time

    # Statistics
    access_count: int = 0  # Number of accesses
    relevance_score: float = 0.0  # Relevance score

    # Confidence and importance
    confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM
    importance: ImportanceLevel = ImportanceLevel.NORMAL

    # Emotion and relationship
    emotional_context: Optional[str] = None  # Emotional context
    relationship_score: float = 0.0  # Relationship score (0-1)

    # Source and verification
    source_context: Optional[str] = None  # Source context snippet
    verification_status: bool = False  # Verification status

    def __post_init__(self):
        """Post-initialization defaults"""
        if not self.memory_id:
            self.memory_id = str(uuid.uuid4())

        if self.created_at == 0:
            self.created_at = time.time()

        if self.last_accessed == 0:
            self.last_accessed = self.created_at

        if self.last_modified == 0:
            self.last_modified = self.created_at

    def update_access(self):
        """Record an access"""
        current_time = time.time()
        self.last_accessed = current_time
        self.access_count += 1

    def update_relevance(self, new_score: float):
        """Update the relevance score (clamped to [0, 1])"""
        self.relevance_score = max(0.0, min(1.0, new_score))
        self.last_modified = time.time()

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dict"""
        return {
            "memory_id": self.memory_id,
            "user_id": self.user_id,
            "chat_id": self.chat_id,
            "created_at": self.created_at,
            "last_accessed": self.last_accessed,
            "last_modified": self.last_modified,
            "access_count": self.access_count,
            "relevance_score": self.relevance_score,
            "confidence": self.confidence.value,
            "importance": self.importance.value,
            "emotional_context": self.emotional_context,
            "relationship_score": self.relationship_score,
            "source_context": self.source_context,
            "verification_status": self.verification_status
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryMetadata":
        """Create an instance from a dict"""
        return cls(
            memory_id=data.get("memory_id", ""),
            user_id=data.get("user_id", ""),
            chat_id=data.get("chat_id"),
            created_at=data.get("created_at", 0),
            last_accessed=data.get("last_accessed", 0),
            last_modified=data.get("last_modified", 0),
            access_count=data.get("access_count", 0),
            relevance_score=data.get("relevance_score", 0.0),
            confidence=ConfidenceLevel(data.get("confidence", ConfidenceLevel.MEDIUM.value)),
            importance=ImportanceLevel(data.get("importance", ImportanceLevel.NORMAL.value)),
            emotional_context=data.get("emotional_context"),
            relationship_score=data.get("relationship_score", 0.0),
            source_context=data.get("source_context"),
            verification_status=data.get("verification_status", False)
        )


@dataclass
class MemoryChunk:
    """Structured memory unit - the core data structure"""

    # Metadata
    metadata: MemoryMetadata

    # Content structure
    content: ContentStructure  # Subject-predicate-object triple
    memory_type: MemoryType  # Memory type

    # Extended information
    keywords: List[str] = field(default_factory=list)  # Keyword list
    tags: List[str] = field(default_factory=list)  # Tag list
    categories: List[str] = field(default_factory=list)  # Category list

    # Semantic information
    embedding: Optional[List[float]] = None  # Semantic vector
    semantic_hash: Optional[str] = None  # Semantic hash

    # Association information
    related_memories: List[str] = field(default_factory=list)  # IDs of related memories
    temporal_context: Optional[Dict[str, Any]] = None  # Temporal context

    def __post_init__(self):
        """Post-initialization processing"""
        if self.embedding and len(self.embedding) > 0:
            self._generate_semantic_hash()

    def _generate_semantic_hash(self):
        """Generate the semantic hash"""
        if not self.embedding:
            return

        try:
            # Derive a stable hash from the triple and the (rounded) embedding
            content_str = f"{self.content.subject}:{self.content.predicate}:{str(self.content.object)}"
            embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding]))

            hash_input = f"{content_str}|{embedding_str}"
            hash_object = hashlib.sha256(hash_input.encode('utf-8'))
            self.semantic_hash = hash_object.hexdigest()[:16]

        except Exception as e:
            logger.warning(f"Failed to generate semantic hash: {e}")
            self.semantic_hash = str(uuid.uuid4())[:16]

    @property
    def memory_id(self) -> str:
        """Memory ID"""
        return self.metadata.memory_id

    @property
    def user_id(self) -> str:
        """User ID"""
        return self.metadata.user_id

    @property
    def text_content(self) -> str:
        """Textual content"""
        return str(self.content)

    def update_access(self):
        """Record an access"""
        self.metadata.update_access()

    def update_relevance(self, new_score: float):
        """Update the relevance score"""
        self.metadata.update_relevance(new_score)

    def add_keyword(self, keyword: str):
        """Add a keyword"""
        if keyword and keyword not in self.keywords:
            self.keywords.append(keyword.strip())

    def add_tag(self, tag: str):
        """Add a tag"""
        if tag and tag not in self.tags:
            self.tags.append(tag.strip())

    def add_category(self, category: str):
        """Add a category"""
        if category and category not in self.categories:
            self.categories.append(category.strip())

    def add_related_memory(self, memory_id: str):
        """Link a related memory"""
        if memory_id and memory_id not in self.related_memories:
            self.related_memories.append(memory_id)

    def set_embedding(self, embedding: List[float]):
        """Set the semantic vector"""
        self.embedding = embedding
        self._generate_semantic_hash()

    def calculate_similarity(self, other: "MemoryChunk") -> float:
        """Compute the similarity to another memory chunk"""
        if not self.embedding or not other.embedding:
            return 0.0

        try:
            # Cosine similarity
            v1 = np.array(self.embedding)
            v2 = np.array(other.embedding)

            dot_product = np.dot(v1, v2)
            norm1 = np.linalg.norm(v1)
            norm2 = np.linalg.norm(v2)

            if norm1 == 0 or norm2 == 0:
                return 0.0

            similarity = dot_product / (norm1 * norm2)
            return max(0.0, min(1.0, similarity))

        except Exception as e:
            logger.warning(f"Failed to compute memory similarity: {e}")
            return 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a full dict"""
        return {
            "metadata": self.metadata.to_dict(),
            "content": self.content.to_dict(),
            "memory_type": self.memory_type.value,
            "keywords": self.keywords,
            "tags": self.tags,
            "categories": self.categories,
            "embedding": self.embedding,
            "semantic_hash": self.semantic_hash,
            "related_memories": self.related_memories,
            "temporal_context": self.temporal_context
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryChunk":
        """Create an instance from a dict"""
        metadata = MemoryMetadata.from_dict(data.get("metadata", {}))
        content = ContentStructure.from_dict(data.get("content", {}))

        chunk = cls(
            metadata=metadata,
            content=content,
            memory_type=MemoryType(data.get("memory_type", MemoryType.CONTEXTUAL.value)),
            keywords=data.get("keywords", []),
            tags=data.get("tags", []),
            categories=data.get("categories", []),
            embedding=data.get("embedding"),
            semantic_hash=data.get("semantic_hash"),
            related_memories=data.get("related_memories", []),
            temporal_context=data.get("temporal_context")
        )

        return chunk

    def to_json(self) -> str:
        """Serialize to a JSON string"""
        # orjson always emits UTF-8 and has no ensure_ascii option; non-ASCII text is kept as-is
        return orjson.dumps(self.to_dict()).decode('utf-8')

    @classmethod
    def from_json(cls, json_str: str) -> "MemoryChunk":
        """Create an instance from a JSON string"""
        try:
            data = orjson.loads(json_str)
            return cls.from_dict(data)
        except Exception as e:
            logger.error(f"Failed to create memory chunk from JSON: {e}")
            raise

    def is_similar_to(self, other: "MemoryChunk", threshold: float = 0.8) -> bool:
        """Check whether this chunk is similar to another one"""
        # Equal semantic hashes are a fast path for near-exact duplicates; otherwise
        # fall back to embedding similarity so the threshold still applies
        if self.semantic_hash and other.semantic_hash and self.semantic_hash == other.semantic_hash:
            return True

        return self.calculate_similarity(other) >= threshold

    def merge_with(self, other: "MemoryChunk") -> bool:
        """Merge another memory chunk into this one (when similar)"""
        if not self.is_similar_to(other):
            return False

        try:
            # Merge keywords
            for keyword in other.keywords:
                self.add_keyword(keyword)

            # Merge tags
            for tag in other.tags:
                self.add_tag(tag)

            # Merge categories
            for category in other.categories:
                self.add_category(category)

            # Merge related-memory links
            for memory_id in other.related_memories:
                self.add_related_memory(memory_id)

            # Update the metadata
            self.metadata.last_modified = time.time()
            self.metadata.access_count += other.metadata.access_count
            self.metadata.relevance_score = max(self.metadata.relevance_score, other.metadata.relevance_score)

            # Keep the higher confidence
            if other.metadata.confidence.value > self.metadata.confidence.value:
                self.metadata.confidence = other.metadata.confidence

            # Keep the higher importance
            if other.metadata.importance.value > self.metadata.importance.value:
                self.metadata.importance = other.metadata.importance

            logger.debug(f"Memory chunk {self.memory_id} absorbed memory chunk {other.memory_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to merge memory chunks: {e}")
            return False

    def __str__(self) -> str:
        """String representation"""
        type_emoji = {
            MemoryType.PERSONAL_FACT: "👤",
            MemoryType.EVENT: "📅",
            MemoryType.PREFERENCE: "❤️",
            MemoryType.OPINION: "💭",
            MemoryType.RELATIONSHIP: "👥",
            MemoryType.EMOTION: "😊",
            MemoryType.KNOWLEDGE: "📚",
            MemoryType.SKILL: "🛠️",
            MemoryType.GOAL: "🎯",
            MemoryType.EXPERIENCE: "💡",
            MemoryType.CONTEXTUAL: "📝"
        }

        emoji = type_emoji.get(self.memory_type, "📝")
        confidence_icon = "●" * self.metadata.confidence.value
        importance_icon = "★" * self.metadata.importance.value

        return f"{emoji} [{self.memory_type.value}] {self.text_content} {confidence_icon} {importance_icon}"

    def __repr__(self) -> str:
        """Debug representation"""
        return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"


def create_memory_chunk(
    user_id: str,
    subject: str,
    predicate: str,
    obj: Union[str, Dict],
    memory_type: MemoryType,
    chat_id: Optional[str] = None,
    source_context: Optional[str] = None,
    importance: ImportanceLevel = ImportanceLevel.NORMAL,
    confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
    **kwargs
) -> MemoryChunk:
    """Convenience factory for memory chunks"""
    metadata = MemoryMetadata(
        memory_id="",
        user_id=user_id,
        chat_id=chat_id,
        created_at=time.time(),
        last_accessed=0,
        last_modified=0,
        confidence=confidence,
        importance=importance,
        source_context=source_context
    )

    content = ContentStructure(
        subject=subject,
        predicate=predicate,
        object=obj
    )

    chunk = MemoryChunk(
        metadata=metadata,
        content=content,
        memory_type=memory_type,
        **kwargs
    )

    return chunk
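
# A minimal usage sketch (not part of this commit): create two chunks with the
# factory above, then deduplicate them; the embedding values are toy placeholders.
def _memory_chunk_demo() -> None:
    a = create_memory_chunk(
        user_id="user_123",
        subject="user_123",
        predicate="likes_food",
        obj="拉面",
        memory_type=MemoryType.PREFERENCE,
    )
    b = create_memory_chunk(
        user_id="user_123",
        subject="user_123",
        predicate="likes_food",
        obj="拉面",
        memory_type=MemoryType.PREFERENCE,
    )
    a.set_embedding([0.1, 0.2, 0.3])
    b.set_embedding([0.1, 0.2, 0.3])
    # Identical triples and embeddings yield identical semantic hashes,
    # so is_similar_to() takes the fast path and merge_with() succeeds
    if a.is_similar_to(b):
        a.merge_with(b)
    print(a.to_json())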

522 src/chat/memory_system/memory_fusion.py Normal file
@@ -0,0 +1,522 @@
# -*- coding: utf-8 -*-
"""
Memory fusion and deduplication
Prevents memory fragmentation and keeps the long-term memory store high quality
"""

import time
import hashlib
from typing import Dict, List, Optional, Tuple, Set, Any
from datetime import datetime, timedelta
from dataclasses import dataclass
from collections import defaultdict
import asyncio

from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import (
    MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
)

logger = get_logger(__name__)


@dataclass
class FusionResult:
    """Fusion result"""
    original_count: int
    fused_count: int
    removed_duplicates: int
    merged_memories: List[MemoryChunk]
    fusion_time: float
    details: List[str]


@dataclass
class DuplicateGroup:
    """Group of duplicate memories"""
    group_id: str
    memories: List[MemoryChunk]
    similarity_matrix: List[List[float]]
    representative_memory: Optional[MemoryChunk] = None


class MemoryFusionEngine:
    """Memory fusion engine"""

    def __init__(self, similarity_threshold: float = 0.85):
        self.similarity_threshold = similarity_threshold
        self.fusion_stats = {
            "total_fusions": 0,
            "memories_fused": 0,
            "duplicates_removed": 0,
            "average_similarity": 0.0
        }

        # Fusion strategy switches
        self.fusion_strategies = {
            "semantic_similarity": True,  # Fuse on semantic similarity
            "temporal_proximity": True,  # Fuse on temporal proximity
            "logical_consistency": True,  # Fuse on logical consistency
            "confidence_boosting": True,  # Boost confidence on multi-source support
            "importance_preservation": True  # Preserve the highest importance
        }

    async def fuse_memories(
        self,
        new_memories: List[MemoryChunk],
        existing_memories: Optional[List[MemoryChunk]] = None
    ) -> List[MemoryChunk]:
        """Fuse a list of memories"""
        start_time = time.time()

        try:
            if not new_memories:
                return []

            logger.info(f"Starting memory fusion, new: {len(new_memories)}, existing: {len(existing_memories or [])}")

            # 1. Detect groups of duplicate memories
            duplicate_groups = await self._detect_duplicate_groups(
                new_memories, existing_memories or []
            )

            # 2. Fuse each duplicate group
            fused_memories = []
            removed_count = 0

            for group in duplicate_groups:
                if len(group.memories) == 1:
                    # Single memory, keep it as-is
                    fused_memories.append(group.memories[0])
                else:
                    # Multiple memories, fuse them
                    fused_memory = await self._fuse_memory_group(group)
                    if fused_memory:
                        fused_memories.append(fused_memory)
                        removed_count += len(group.memories) - 1

            # 3. Update statistics
            fusion_time = time.time() - start_time
            self._update_fusion_stats(len(new_memories), removed_count, fusion_time)

            logger.info(f"✅ Memory fusion finished: {len(fused_memories)} memories, {removed_count} duplicates removed")
            return fused_memories

        except Exception as e:
            logger.error(f"❌ Memory fusion failed: {e}", exc_info=True)
            return new_memories  # Fall back to the unfused input on failure

    async def _detect_duplicate_groups(
        self,
        new_memories: List[MemoryChunk],
        existing_memories: List[MemoryChunk]
    ) -> List[DuplicateGroup]:
        """Detect groups of duplicate memories"""
        all_memories = new_memories + existing_memories
        groups = []
        processed_ids = set()

        for i, memory1 in enumerate(all_memories):
            if memory1.memory_id in processed_ids:
                continue

            # Open a new duplicate group
            group = DuplicateGroup(
                group_id=f"group_{len(groups)}",
                memories=[memory1],
                similarity_matrix=[[1.0]]
            )

            processed_ids.add(memory1.memory_id)

            # Look for similar memories
            for j, memory2 in enumerate(all_memories[i+1:], i+1):
                if memory2.memory_id in processed_ids:
                    continue

                similarity = self._calculate_comprehensive_similarity(memory1, memory2)

                if similarity >= self.similarity_threshold:
                    group.memories.append(memory2)
                    processed_ids.add(memory2.memory_id)

                    # Update the similarity matrix
                    self._update_similarity_matrix(group, memory2, similarity)

            if len(group.memories) > 1:
                # Pick a representative memory for multi-member groups
                group.representative_memory = self._select_representative_memory(group)

            # Singleton groups must be kept as well; dropping them here would
            # silently lose every unique memory in fuse_memories
            groups.append(group)

        logger.debug(f"Detected {len(groups)} duplicate memory groups")
        return groups

    def _calculate_comprehensive_similarity(self, mem1: MemoryChunk, mem2: MemoryChunk) -> float:
        """Compute the comprehensive similarity"""
        similarity_scores = []

        # 1. Semantic vector similarity
        if self.fusion_strategies["semantic_similarity"]:
            semantic_sim = mem1.calculate_similarity(mem2)
            similarity_scores.append(("semantic", semantic_sim))

        # 2. Text similarity
        text_sim = self._calculate_text_similarity(mem1.text_content, mem2.text_content)
        similarity_scores.append(("text", text_sim))

        # 3. Keyword overlap
        keyword_sim = self._calculate_keyword_similarity(mem1.keywords, mem2.keywords)
        similarity_scores.append(("keyword", keyword_sim))

        # 4. Type consistency
        type_consistency = 1.0 if mem1.memory_type == mem2.memory_type else 0.0
        similarity_scores.append(("type", type_consistency))

        # 5. Temporal proximity
        if self.fusion_strategies["temporal_proximity"]:
            temporal_sim = self._calculate_temporal_similarity(
                mem1.metadata.created_at, mem2.metadata.created_at
            )
            similarity_scores.append(("temporal", temporal_sim))

        # 6. Logical consistency
        if self.fusion_strategies["logical_consistency"]:
            logical_sim = self._calculate_logical_similarity(mem1, mem2)
            similarity_scores.append(("logical", logical_sim))

        # Weighted average of the individual similarities
        weights = {
            "semantic": 0.35,
            "text": 0.25,
            "keyword": 0.15,
            "type": 0.10,
            "temporal": 0.10,
            "logical": 0.05
        }

        weighted_sum = 0.0
        total_weight = 0.0

        for score_type, score in similarity_scores:
            weight = weights.get(score_type, 0.1)
            weighted_sum += weight * score
            total_weight += weight

        final_similarity = weighted_sum / total_weight if total_weight > 0 else 0.0

        logger.debug(f"Comprehensive similarity: {final_similarity:.3f} - {[(t, f'{s:.3f}') for t, s in similarity_scores]}")

        return final_similarity

    def _calculate_text_similarity(self, text1: str, text2: str) -> float:
        """Compute text similarity"""
        # Simple token-overlap (Jaccard) measure; note it relies on whitespace
        # tokenization, which is coarse for unsegmented Chinese text
        words1 = set(text1.lower().split())
        words2 = set(text2.lower().split())

        if not words1 or not words2:
            return 0.0

        intersection = words1 & words2
        union = words1 | words2

        jaccard_similarity = len(intersection) / len(union)
        return jaccard_similarity

    def _calculate_keyword_similarity(self, keywords1: List[str], keywords2: List[str]) -> float:
        """Compute keyword similarity"""
        if not keywords1 or not keywords2:
            return 0.0

        set1 = set(k.lower() for k in keywords1)
        set2 = set(k.lower() for k in keywords2)

        intersection = set1 & set2
        union = set1 | set2

        return len(intersection) / len(union) if union else 0.0

    def _calculate_temporal_similarity(self, time1: float, time2: float) -> float:
        """Compute temporal similarity"""
        time_diff = abs(time1 - time2)
        hours_diff = time_diff / 3600

        # High similarity within 24 hours
        if hours_diff <= 24:
            return 1.0 - (hours_diff / 24)
        elif hours_diff <= 168:  # Within one week
            return 0.7 - ((hours_diff - 24) / 168) * 0.5
        else:
            return 0.2

    def _calculate_logical_similarity(self, mem1: MemoryChunk, mem2: MemoryChunk) -> float:
        """Compute logical consistency"""
        # Check the logical consistency of the subject-predicate-object triples
        consistency_score = 0.0

        # Subject consistency
        if mem1.content.subject == mem2.content.subject:
            consistency_score += 0.4

        # Predicate similarity
        predicate_sim = self._calculate_text_similarity(mem1.content.predicate, mem2.content.predicate)
        consistency_score += predicate_sim * 0.3

        # Object similarity
        if isinstance(mem1.content.object, str) and isinstance(mem2.content.object, str):
            object_sim = self._calculate_text_similarity(mem1.content.object, mem2.content.object)
            consistency_score += object_sim * 0.3

        return consistency_score

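    # Worked example (hypothetical scores): with semantic=0.9, text=0.6,
    # keyword=0.5, type=1.0, temporal=0.8 and logical=0.7, the weighted sum is
    # 0.35*0.9 + 0.25*0.6 + 0.15*0.5 + 0.10*1.0 + 0.10*0.8 + 0.05*0.7 = 0.755
    # over a total weight of 1.0, i.e. below the default 0.85 threshold, so
    # this pair would not be grouped as duplicates.
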
    def _update_similarity_matrix(self, group: DuplicateGroup, new_memory: MemoryChunk, similarity: float):
        """Update the group's similarity matrix"""
        # Approximation: only the similarity to the group's first member is
        # computed; pairs among the remaining members are filled in as 1.0
        for i in range(len(group.similarity_matrix)):
            group.similarity_matrix[i].append(similarity)

        # Append the row for the new memory
        new_row = [similarity] + [1.0] * len(group.similarity_matrix)
        group.similarity_matrix.append(new_row)

    def _select_representative_memory(self, group: DuplicateGroup) -> MemoryChunk:
        """Select the representative memory"""
        if not group.memories:
            return None

        # Scoring criteria
        best_memory = None
        best_score = -1.0

        for memory in group.memories:
            score = 0.0

            # Confidence weight
            score += memory.metadata.confidence.value * 0.3

            # Importance weight
            score += memory.metadata.importance.value * 0.3

            # Access-count weight
            score += min(memory.metadata.access_count * 0.1, 0.2)

            # Relevance weight
            score += memory.metadata.relevance_score * 0.2

            if score > best_score:
                best_score = score
                best_memory = memory

        return best_memory

    async def _fuse_memory_group(self, group: DuplicateGroup) -> Optional[MemoryChunk]:
        """Fuse a memory group"""
        if not group.memories:
            return None

        if len(group.memories) == 1:
            return group.memories[0]

        try:
            # Choose the base memory (usually the representative one)
            base_memory = group.representative_memory or group.memories[0]

            # Merge in the attributes of the other memories
            fused_memory = await self._merge_memory_attributes(base_memory, group.memories)

            # Update the metadata
            self._update_fused_metadata(fused_memory, group)

            logger.debug(f"Fused a memory group of {len(group.memories)} original memories")
            return fused_memory

        except Exception as e:
            logger.error(f"Failed to fuse memory group: {e}")
            # Fall back to the memory with the highest confidence
            return max(group.memories, key=lambda m: m.metadata.confidence.value)

    async def _merge_memory_attributes(
        self,
        base_memory: MemoryChunk,
        memories: List[MemoryChunk]
    ) -> MemoryChunk:
        """Merge memory attributes"""
        # Deep-copy the base memory via its dict round trip
        fused_memory = MemoryChunk.from_dict(base_memory.to_dict())

        # Merge keywords
        all_keywords = set()
        for memory in memories:
            all_keywords.update(memory.keywords)
        fused_memory.keywords = sorted(all_keywords)

        # Merge tags
        all_tags = set()
        for memory in memories:
            all_tags.update(memory.tags)
        fused_memory.tags = sorted(all_tags)

        # Merge categories
        all_categories = set()
        for memory in memories:
            all_categories.update(memory.categories)
        fused_memory.categories = sorted(all_categories)

        # Merge related-memory links
        all_related = set()
        for memory in memories:
            all_related.update(memory.related_memories)
        # Drop references to the group's own members
        all_related = {rid for rid in all_related if rid not in [m.memory_id for m in memories]}
        fused_memory.related_memories = sorted(all_related)

        # Merge temporal context
        if self.fusion_strategies["temporal_proximity"]:
            fused_memory.temporal_context = self._merge_temporal_context(memories)

        return fused_memory

    def _update_fused_metadata(self, fused_memory: MemoryChunk, group: DuplicateGroup):
        """Update the fused memory's metadata"""
        # Refresh the modification time
        fused_memory.metadata.last_modified = time.time()

        # Sum the access counts
        total_access = sum(m.metadata.access_count for m in group.memories)
        fused_memory.metadata.access_count = total_access

        # Boost confidence when several sources agree
        if self.fusion_strategies["confidence_boosting"] and len(group.memories) > 1:
            max_confidence = max(m.metadata.confidence.value for m in group.memories)
            if max_confidence < ConfidenceLevel.VERIFIED.value:
                fused_memory.metadata.confidence = ConfidenceLevel(
                    min(max_confidence + 1, ConfidenceLevel.VERIFIED.value)
                )

        # Preserve the highest importance
        if self.fusion_strategies["importance_preservation"]:
            max_importance = max(m.metadata.importance.value for m in group.memories)
            fused_memory.metadata.importance = ImportanceLevel(max_importance)

        # Average the relevance scores, with a slight boost
        avg_relevance = sum(m.metadata.relevance_score for m in group.memories) / len(group.memories)
        fused_memory.metadata.relevance_score = min(avg_relevance * 1.1, 1.0)

        # Record the provenance
        source_ids = [m.memory_id[:8] for m in group.memories]
        fused_memory.metadata.source_context = f"Fused from {len(group.memories)} memories: {', '.join(source_ids)}"

    def _merge_temporal_context(self, memories: List[MemoryChunk]) -> Dict[str, Any]:
        """Merge temporal contexts"""
        contexts = [m.temporal_context for m in memories if m.temporal_context]

        if not contexts:
            return {}

        # Compute the time span
        timestamps = [m.metadata.created_at for m in memories]
        earliest_time = min(timestamps)
        latest_time = max(timestamps)

        merged_context = {
            "earliest_timestamp": earliest_time,
            "latest_timestamp": latest_time,
            "time_span_hours": (latest_time - earliest_time) / 3600,
            "source_memories": len(memories)
        }

        # Merge the remaining context entries
        for context in contexts:
            for key, value in context.items():
                if key not in ["timestamp", "earliest_timestamp", "latest_timestamp"]:
                    if key not in merged_context:
                        merged_context[key] = value
                    elif merged_context[key] != value:
                        merged_context[key] = f"multiple: {value}"

        return merged_context

    async def incremental_fusion(
        self,
        new_memory: MemoryChunk,
        existing_memories: List[MemoryChunk]
    ) -> Tuple[MemoryChunk, List[MemoryChunk]]:
        """Incremental fusion (fuse a single new memory into the existing set)"""
        # Look for similar memories
        similar_memories = []

        for existing in existing_memories:
            similarity = self._calculate_comprehensive_similarity(new_memory, existing)
            if similarity >= self.similarity_threshold:
                similar_memories.append((existing, similarity))

        if not similar_memories:
            # No similar memory, return unchanged
            return new_memory, existing_memories

        # Sort by similarity
        similar_memories.sort(key=lambda x: x[1], reverse=True)

        # Fuse with the most similar memory
        best_match, similarity = similar_memories[0]

        # Build the fusion group
        group = DuplicateGroup(
            group_id=f"incremental_{int(time.time())}",
            memories=[new_memory, best_match],
            similarity_matrix=[[1.0, similarity], [similarity, 1.0]]
        )

        # Perform the fusion
        fused_memory = await self._fuse_memory_group(group)

        # Replace the fused-away memory in the existing set
        updated_existing = [m for m in existing_memories if m.memory_id != best_match.memory_id]
        updated_existing.append(fused_memory)

        logger.debug(f"Incremental fusion finished, similarity: {similarity:.3f}")

        return fused_memory, updated_existing

    def _update_fusion_stats(self, original_count: int, removed_count: int, fusion_time: float):
        """Update fusion statistics"""
        self.fusion_stats["total_fusions"] += 1
        self.fusion_stats["memories_fused"] += original_count
        self.fusion_stats["duplicates_removed"] += removed_count

        # Update the (estimated) average similarity
        if removed_count > 0:
            avg_similarity = 0.9  # Assume a fairly high average similarity
            total_similarity = self.fusion_stats["average_similarity"] * (self.fusion_stats["total_fusions"] - 1)
            total_similarity += avg_similarity
            self.fusion_stats["average_similarity"] = total_similarity / self.fusion_stats["total_fusions"]

    async def maintenance(self):
        """Maintenance pass"""
        try:
            logger.info("Starting memory fusion engine maintenance...")

            # Periodic maintenance tasks can be added here, e.g.:
            # - re-evaluate low-confidence memories
            # - clean up orphaned memory references
            # - tune fusion strategy parameters

            logger.info("✅ Memory fusion engine maintenance finished")

        except Exception as e:
            logger.error(f"❌ Memory fusion engine maintenance failed: {e}", exc_info=True)

    def get_fusion_stats(self) -> Dict[str, Any]:
        """Return a copy of the fusion statistics"""
        return self.fusion_stats.copy()

    def reset_stats(self):
        """Reset the statistics"""
        self.fusion_stats = {
            "total_fusions": 0,
            "memories_fused": 0,
            "duplicates_removed": 0,
            "average_similarity": 0.0
        }
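
# A minimal usage sketch (not part of this commit); the chunk lists would come
# from the memory_builder / memory_chunk modules.
async def _fusion_demo(new_chunks: List[MemoryChunk], existing: List[MemoryChunk]) -> None:
    engine = MemoryFusionEngine(similarity_threshold=0.85)
    fused = await engine.fuse_memories(new_chunks, existing)
    print(f"{len(new_chunks) + len(existing)} memories in, {len(fused)} out")
    print(engine.get_fusion_stats())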

542 src/chat/memory_system/memory_integration_hooks.py Normal file
@@ -0,0 +1,542 @@
# -*- coding: utf-8 -*-
"""
Memory system integration hooks
Provide seamless integration points with the existing MoFox Bot systems
"""

import asyncio
import time
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass

from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_adapter import (
    get_enhanced_memory_adapter,
    process_conversation_with_enhanced_memory,
    retrieve_memories_with_enhanced_system,
    get_memory_context_for_prompt
)

logger = get_logger(__name__)


@dataclass
class HookResult:
    """Result of a hook execution"""
    success: bool
    data: Any = None
    error: Optional[str] = None
    processing_time: float = 0.0


class MemoryIntegrationHooks:
    """Memory system integration hooks"""

    def __init__(self):
        self.hooks_registered = False
        self.hook_stats = {
            "message_processing_hooks": 0,
            "memory_retrieval_hooks": 0,
            "prompt_enhancement_hooks": 0,
            "total_hook_executions": 0,
            "average_hook_time": 0.0
        }

    async def register_hooks(self):
        """Register all integration hooks"""
        if self.hooks_registered:
            return

        try:
            logger.info("🔗 Registering memory system integration hooks...")

            # Register message processing hooks
            await self._register_message_processing_hooks()

            # Register memory retrieval hooks
            await self._register_memory_retrieval_hooks()

            # Register prompt enhancement hooks
            await self._register_prompt_enhancement_hooks()

            # Register system maintenance hooks
            await self._register_maintenance_hooks()

            self.hooks_registered = True
            logger.info("✅ Memory system integration hooks registered")

        except Exception as e:
            logger.error(f"❌ Failed to register memory system integration hooks: {e}", exc_info=True)

    async def _register_message_processing_hooks(self):
        """Register message processing hooks"""
        try:
            # Hook 1: create memories after a message is processed
            await self._register_post_message_hook()

            # Hook 2: process memories when a chat stream is saved
            await self._register_chat_stream_hook()

            logger.debug("Message processing hooks registered")

        except Exception as e:
            logger.error(f"Failed to register message processing hooks: {e}")

    async def _register_memory_retrieval_hooks(self):
        """Register memory retrieval hooks"""
        try:
            # Hook 1: retrieve related memories before generating a reply
            await self._register_pre_response_hook()

            # Hook 2: enrich the context before querying the knowledge base
            await self._register_knowledge_query_hook()

            logger.debug("Memory retrieval hooks registered")

        except Exception as e:
            logger.error(f"Failed to register memory retrieval hooks: {e}")

    async def _register_prompt_enhancement_hooks(self):
        """Register prompt enhancement hooks"""
        try:
            # Hook 1: enhance prompt building
            await self._register_prompt_building_hook()

            logger.debug("Prompt enhancement hooks registered")

        except Exception as e:
            logger.error(f"Failed to register prompt enhancement hooks: {e}")

    async def _register_maintenance_hooks(self):
        """Register system maintenance hooks"""
        try:
            # Hook 1: memory system maintenance during system maintenance
            await self._register_system_maintenance_hook()

            logger.debug("System maintenance hooks registered")

        except Exception as e:
            logger.error(f"Failed to register system maintenance hooks: {e}")

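    # Typical wiring (a sketch, assuming an async startup entry point exists):
    #   hooks = MemoryIntegrationHooks()
    #   await hooks.register_hooks()
    # Each _register_* helper below degrades gracefully: if a target subsystem
    # cannot be imported or exposes no registration API, the hook is skipped
    # with a debug log instead of raising.
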
    async def _register_post_message_hook(self):
        """Register the post-message-processing hook"""
        try:
            # How hooks are attached depends on the actual system architecture;
            # the implementation below is an example that adapts to whichever
            # plugin or event system is available

            # Try the event system first
            try:
                from src.plugin_system.core.event_manager import event_manager
                from src.plugin_system.base.component_types import EventType

                # Subscribe to the post-message-processing event
                event_manager.subscribe(
                    EventType.MESSAGE_PROCESSED,
                    self._on_message_processed_handler
                )
                logger.debug("Registered the message processing hook with the event system")

            except ImportError:
                logger.debug("Event system unavailable, skipping event hook registration")

            # Then try the message manager
            try:
                from src.chat.message_manager import message_manager

                # If the message manager supports hook registration
                if hasattr(message_manager, 'register_post_process_hook'):
                    message_manager.register_post_process_hook(
                        self._on_message_processed_hook
                    )
                    logger.debug("Registered the processing hook with the message manager")

            except ImportError:
                logger.debug("Message manager unavailable, skipping message manager hook registration")

        except Exception as e:
            logger.error(f"Failed to register the post-message hook: {e}")

    async def _register_chat_stream_hook(self):
        """Register the chat stream hook"""
        try:
            # Try the chat stream manager
            try:
                from src.chat.message_receive.chat_stream import get_chat_manager

                chat_manager = get_chat_manager()
                if hasattr(chat_manager, 'register_save_hook'):
                    chat_manager.register_save_hook(
                        self._on_chat_stream_save_hook
                    )
                    logger.debug("Registered the save hook with the chat stream manager")

            except ImportError:
                logger.debug("Chat stream manager unavailable, skipping chat stream hook registration")

        except Exception as e:
            logger.error(f"Failed to register the chat stream hook: {e}")

    async def _register_pre_response_hook(self):
        """Register the pre-response hook"""
        try:
            # Try the reply generator
            try:
                from src.chat.replyer.default_generator import default_generator

                if hasattr(default_generator, 'register_pre_generation_hook'):
                    default_generator.register_pre_generation_hook(
                        self._on_pre_response_hook
                    )
                    logger.debug("Registered the pre-generation hook with the reply generator")

            except ImportError:
                logger.debug("Reply generator unavailable, skipping pre-response hook registration")

        except Exception as e:
            logger.error(f"Failed to register the pre-response hook: {e}")

    async def _register_knowledge_query_hook(self):
        """Register the knowledge base query hook"""
        try:
            # Try the knowledge base system
            try:
                from src.chat.knowledge.knowledge_lib import knowledge_manager

                if hasattr(knowledge_manager, 'register_query_enhancer'):
                    knowledge_manager.register_query_enhancer(
                        self._on_knowledge_query_hook
                    )
                    logger.debug("Registered the query enhancer with the knowledge base")

            except ImportError:
                logger.debug("Knowledge base system unavailable, skipping knowledge base hook registration")

        except Exception as e:
            logger.error(f"Failed to register the knowledge query hook: {e}")

    async def _register_prompt_building_hook(self):
        """Register the prompt building hook"""
        try:
            # Try the prompt system
            try:
                from src.chat.utils.prompt import prompt_manager

                if hasattr(prompt_manager, 'register_enhancer'):
                    prompt_manager.register_enhancer(
                        self._on_prompt_building_hook
                    )
                    logger.debug("Registered the enhancer with the prompt manager")

            except ImportError:
                logger.debug("Prompt system unavailable, skipping prompt hook registration")

        except Exception as e:
            logger.error(f"Failed to register the prompt building hook: {e}")

    async def _register_system_maintenance_hook(self):
        """Register the system maintenance hook"""
        try:
            # Try the system maintainer
            try:
                from src.manager.async_task_manager import async_task_manager

                # Register the periodic maintenance task
                async_task_manager.add_task(MemoryMaintenanceTask())
                logger.debug("Registered the periodic task with the system maintainer")

            except ImportError:
                logger.debug("Async task manager unavailable, skipping system maintenance hook registration")

        except Exception as e:
            logger.error(f"Failed to register the system maintenance hook: {e}")

# 钩子处理器方法
|
||||
|
||||
async def _on_message_processed_handler(self, event_data: Dict[str, Any]) -> HookResult:
|
||||
"""事件系统的消息处理处理器"""
|
||||
return await self._on_message_processed_hook(event_data)
|
||||
|
||||
    async def _on_message_processed_hook(self, message_data: Dict[str, Any]) -> HookResult:
        """Post-message-processing hook."""
        start_time = time.time()

        try:
            self.hook_stats["message_processing_hooks"] += 1

            # Extract the required information
            message_info = message_data.get("message_info", {})
            user_info = message_info.get("user_info", {})
            conversation_text = message_data.get("processed_plain_text", "")

            if not conversation_text:
                return HookResult(success=True, data="No conversation text")

            user_id = str(user_info.get("user_id", "unknown"))
            context = {
                "chat_id": message_data.get("chat_id"),
                "message_type": message_data.get("message_type", "normal"),
                "platform": message_info.get("platform", "unknown"),
                "interest_value": message_data.get("interest_value", 0.0),
                "keywords": message_data.get("key_words", []),
                "timestamp": message_data.get("time", time.time())
            }

            # Process the conversation with the enhanced memory system
            result = await process_conversation_with_enhanced_memory(
                conversation_text, context, user_id
            )

            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)

            if result["success"]:
                logger.debug(f"Message-processing hook succeeded, created {len(result.get('created_memories', []))} memories")
                return HookResult(success=True, data=result, processing_time=processing_time)
            else:
                logger.warning(f"Message-processing hook failed: {result.get('error')}")
                return HookResult(success=False, error=result.get('error'), processing_time=processing_time)

        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Message-processing hook raised an exception: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    async def _on_chat_stream_save_hook(self, chat_stream_data: Dict[str, Any]) -> HookResult:
        """Chat-stream save hook."""
        start_time = time.time()

        try:
            self.hook_stats["message_processing_hooks"] += 1

            # Extract conversation information from the chat-stream data
            stream_context = chat_stream_data.get("stream_context", {})
            user_id = stream_context.get("user_id", "unknown")
            messages = stream_context.get("messages", [])

            if not messages:
                return HookResult(success=True, data="No messages to process")

            # Build the conversation text
            conversation_parts = []
            for msg in messages[-10:]:  # Only the 10 most recent messages
                text = msg.get("processed_plain_text", "")
                if text:
                    conversation_parts.append(f"{msg.get('user_nickname', 'User')}: {text}")

            conversation_text = "\n".join(conversation_parts)
            if not conversation_text:
                return HookResult(success=True, data="No conversation text")

            context = {
                "chat_id": chat_stream_data.get("chat_id"),
                "stream_id": chat_stream_data.get("stream_id"),
                "platform": chat_stream_data.get("platform", "unknown"),
                "message_count": len(messages),
                "timestamp": time.time()
            }

            # Process the conversation with the enhanced memory system
            result = await process_conversation_with_enhanced_memory(
                conversation_text, context, user_id
            )

            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)

            if result["success"]:
                logger.debug(f"Chat-stream save hook succeeded, created {len(result.get('created_memories', []))} memories")
                return HookResult(success=True, data=result, processing_time=processing_time)
            else:
                logger.warning(f"Chat-stream save hook failed: {result.get('error')}")
                return HookResult(success=False, error=result.get('error'), processing_time=processing_time)

        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Chat-stream save hook raised an exception: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    async def _on_pre_response_hook(self, response_data: Dict[str, Any]) -> HookResult:
        """Pre-response hook."""
        start_time = time.time()

        try:
            self.hook_stats["memory_retrieval_hooks"] += 1

            # Extract the query information
            query = response_data.get("query", "")
            user_id = response_data.get("user_id", "unknown")
            context = response_data.get("context", {})

            if not query:
                return HookResult(success=True, data="No query provided")

            # Retrieve relevant memories
            memories = await retrieve_memories_with_enhanced_system(
                query, user_id, context, limit=5
            )

            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)

            # Attach the memories to the response data
            response_data["enhanced_memories"] = memories
            response_data["enhanced_memory_context"] = await get_memory_context_for_prompt(
                query, user_id, context, max_memories=5
            )

            logger.debug(f"Pre-response hook succeeded, retrieved {len(memories)} memories")
            return HookResult(success=True, data=memories, processing_time=processing_time)

        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Pre-response hook raised an exception: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    async def _on_knowledge_query_hook(self, query_data: Dict[str, Any]) -> HookResult:
        """Knowledge-base query hook."""
        start_time = time.time()

        try:
            self.hook_stats["memory_retrieval_hooks"] += 1

            query = query_data.get("query", "")
            user_id = query_data.get("user_id", "unknown")
            context = query_data.get("context", {})

            if not query:
                return HookResult(success=True, data="No query provided")

            # Fetch the memory context to enhance the query
            memory_context = await get_memory_context_for_prompt(
                query, user_id, context, max_memories=3
            )

            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)

            # Attach the memory context to the query data
            query_data["enhanced_memory_context"] = memory_context

            logger.debug("Knowledge-base query hook succeeded")
            return HookResult(success=True, data=memory_context, processing_time=processing_time)

        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Knowledge-base query hook raised an exception: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    async def _on_prompt_building_hook(self, prompt_data: Dict[str, Any]) -> HookResult:
        """Prompt-building hook."""
        start_time = time.time()

        try:
            self.hook_stats["prompt_enhancement_hooks"] += 1

            query = prompt_data.get("query", "")
            user_id = prompt_data.get("user_id", "unknown")
            context = prompt_data.get("context", {})
            base_prompt = prompt_data.get("base_prompt", "")

            if not query:
                return HookResult(success=True, data="No query provided")

            # Fetch the memory context
            memory_context = await get_memory_context_for_prompt(
                query, user_id, context, max_memories=5
            )

            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)

            # Build the enhanced prompt (the section header below is Chinese
            # prompt text consumed by the model, so it stays untranslated)
            enhanced_prompt = base_prompt
            if memory_context:
                enhanced_prompt += f"\n\n### 相关记忆上下文 ###\n{memory_context}\n"

            # Attach the enhanced prompt to the data
            prompt_data["enhanced_prompt"] = enhanced_prompt
            prompt_data["memory_context"] = memory_context

            logger.debug("Prompt-building hook succeeded")
            return HookResult(success=True, data=enhanced_prompt, processing_time=processing_time)

        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Prompt-building hook raised an exception: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    def _update_hook_stats(self, processing_time: float):
        """Update the hook statistics."""
        self.hook_stats["total_hook_executions"] += 1

        # Incremental running average over all hook executions; the
        # total is always positive right after the increment
        total_executions = self.hook_stats["total_hook_executions"]
        current_avg = self.hook_stats["average_hook_time"]
        new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions
        self.hook_stats["average_hook_time"] = new_avg

    def get_hook_stats(self) -> Dict[str, Any]:
        """Return a copy of the hook statistics."""
        return self.hook_stats.copy()


class MemoryMaintenanceTask:
    """Maintenance task for the memory system."""

    def __init__(self):
        self.task_name = "enhanced_memory_maintenance"
        self.interval = 3600  # Run once per hour

    async def execute(self):
        """Run the maintenance task."""
        try:
            logger.info("🔧 Running enhanced memory system maintenance task...")

            # Fetch the adapter instance
            try:
                from src.chat.memory_system.enhanced_memory_adapter import _enhanced_memory_adapter
                if _enhanced_memory_adapter:
                    await _enhanced_memory_adapter.maintenance()
                    logger.info("✅ Enhanced memory system maintenance finished")
                else:
                    logger.debug("Enhanced memory adapter not initialized; skipping maintenance")
            except Exception as e:
                logger.error(f"Enhanced memory system maintenance failed: {e}")

        except Exception as e:
            logger.error(f"Unexpected exception while running maintenance task: {e}", exc_info=True)

    def get_interval(self) -> int:
        """Return the execution interval in seconds."""
        return self.interval

    def get_task_name(self) -> str:
        """Return the task name."""
        return self.task_name


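# Illustrative driver (an assumption about the task-manager contract, not code
# from this commit): _register_system_maintenance_hook() hands this task to
# async_task_manager.add_task(), which is expected to behave roughly like this.
async def _example_run_maintenance(task: MemoryMaintenanceTask) -> None:
    import asyncio

    while True:
        await task.execute()
        await asyncio.sleep(task.get_interval())

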
# Global hooks instance
_memory_hooks: Optional[MemoryIntegrationHooks] = None


async def get_memory_integration_hooks() -> MemoryIntegrationHooks:
    """Return the global memory-integration hooks instance."""
    global _memory_hooks

    if _memory_hooks is None:
        _memory_hooks = MemoryIntegrationHooks()
        await _memory_hooks.register_hooks()

    return _memory_hooks


async def initialize_memory_integration_hooks():
    """Initialize the memory-integration hooks."""
    try:
        logger.info("🚀 Initializing memory integration hooks...")
        hooks = await get_memory_integration_hooks()
        logger.info("✅ Memory integration hooks initialized")
        return hooks
    except Exception as e:
        logger.error(f"❌ Failed to initialize memory integration hooks: {e}", exc_info=True)
        return None

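# Illustrative startup usage (the entry point is an assumption, not part of
# this commit): initialize the hooks once while the application boots.
#
#     import asyncio
#
#     hooks = asyncio.run(initialize_memory_integration_hooks())
#     if hooks:
#         print(hooks.get_hook_stats())
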
src/chat/memory_system/metadata_index.py (new file, 832 lines)
@@ -0,0 +1,832 @@
# -*- coding: utf-8 -*-
"""
Metadata index system
Provides multi-dimensional, precise filtering and query capabilities for the memory system
"""

import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Any

import orjson

from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel

logger = get_logger(__name__)


class IndexType(Enum):
    """Index types."""
    MEMORY_TYPE = "memory_type"                # Memory-type index
    USER_ID = "user_id"                        # User-ID index
    KEYWORD = "keyword"                        # Keyword index
    TAG = "tag"                                # Tag index
    CATEGORY = "category"                      # Category index
    TIMESTAMP = "timestamp"                    # Time index
    CONFIDENCE = "confidence"                  # Confidence index
    IMPORTANCE = "importance"                  # Importance index
    RELATIONSHIP_SCORE = "relationship_score"  # Relationship-score index
    ACCESS_FREQUENCY = "access_frequency"      # Access-frequency index
    SEMANTIC_HASH = "semantic_hash"            # Semantic-hash index


@dataclass
class IndexQuery:
    """Index query conditions."""
    user_ids: Optional[List[str]] = None
    memory_types: Optional[List[MemoryType]] = None
    keywords: Optional[List[str]] = None
    tags: Optional[List[str]] = None
    categories: Optional[List[str]] = None
    time_range: Optional[Tuple[float, float]] = None
    confidence_levels: Optional[List[ConfidenceLevel]] = None
    importance_levels: Optional[List[ImportanceLevel]] = None
    min_relationship_score: Optional[float] = None
    max_relationship_score: Optional[float] = None
    min_access_count: Optional[int] = None
    semantic_hashes: Optional[List[str]] = None
    limit: Optional[int] = None
    sort_by: Optional[str] = None  # "created_at", "access_count", "relevance_score"
    sort_order: str = "desc"  # "asc" or "desc"


@dataclass
class IndexResult:
    """Index query result."""
    memory_ids: List[str]
    total_count: int
    query_time: float
    filtered_by: List[str]


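# Illustrative query (hypothetical values; `index_manager` stands in for a
# MetadataIndexManager instance, defined below): fetch a user's ten most
# recent preference memories that mention a keyword.
#
#     query = IndexQuery(
#         user_ids=["user_123"],
#         memory_types=[MemoryType.PREFERENCE],
#         keywords=["coffee"],
#         limit=10,
#         sort_by="created_at",
#         sort_order="desc",
#     )
#     result = await index_manager.query_memories(query)

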
class MetadataIndexManager:
    """Metadata index manager."""

    def __init__(self, index_path: str = "data/memory_metadata"):
        self.index_path = Path(index_path)
        self.index_path.mkdir(parents=True, exist_ok=True)

        # Per-type indices
        self.indices = {
            IndexType.MEMORY_TYPE: defaultdict(set),
            IndexType.USER_ID: defaultdict(set),
            IndexType.KEYWORD: defaultdict(set),
            IndexType.TAG: defaultdict(set),
            IndexType.CATEGORY: defaultdict(set),
            IndexType.CONFIDENCE: defaultdict(set),
            IndexType.IMPORTANCE: defaultdict(set),
            IndexType.SEMANTIC_HASH: defaultdict(set),
        }

        # Ordered indices (kept sorted, handled separately)
        self.time_index = []  # [(timestamp, memory_id), ...]
        self.relationship_index = []  # [(relationship_score, memory_id), ...]
        self.access_frequency_index = []  # [(access_count, memory_id), ...]

        # In-memory cache
        self.memory_metadata_cache: Dict[str, Dict[str, Any]] = {}

        # Statistics
        self.index_stats = {
            "total_memories": 0,
            "index_build_time": 0.0,
            "average_query_time": 0.0,
            "total_queries": 0,
            "cache_hit_rate": 0.0,
            "cache_hits": 0
        }

        # Thread lock
        self._lock = threading.RLock()
        self._dirty = False  # Whether the index has unsaved changes

        # Auto-save configuration
        self.auto_save_interval = 500  # Auto-save every 500 operations
        self._operation_count = 0
        self._index_batch_count = 0  # Batches indexed so far (for the build-time average)

    async def index_memories(self, memories: List[MemoryChunk]):
        """Build indices for a batch of memories."""
        if not memories:
            return

        start_time = time.time()

        try:
            with self._lock:
                for memory in memories:
                    self._index_single_memory(memory)

                # Mark the index as needing a save
                self._dirty = True
                self._operation_count += len(memories)

                # Auto-save check
                if self._operation_count >= self.auto_save_interval:
                    await self.save_index()
                    self._operation_count = 0

            index_time = time.time() - start_time
            # Running average of the per-batch build time; the previous formula
            # mistakenly used the batch size as the sample count
            self._index_batch_count += 1
            n = self._index_batch_count
            self.index_stats["index_build_time"] = (
                (self.index_stats["index_build_time"] * (n - 1) + index_time) / n
            )

            logger.debug(f"Metadata indexing finished: {len(memories)} memories in {index_time:.3f}s")

        except Exception as e:
            logger.error(f"❌ Metadata indexing failed: {e}", exc_info=True)

    def _index_single_memory(self, memory: MemoryChunk):
        """Build indices for a single memory."""
        memory_id = memory.memory_id

        # Update the in-memory cache
        self.memory_metadata_cache[memory_id] = {
            "user_id": memory.user_id,
            "memory_type": memory.memory_type,
            "created_at": memory.metadata.created_at,
            "last_accessed": memory.metadata.last_accessed,
            "access_count": memory.metadata.access_count,
            "confidence": memory.metadata.confidence,
            "importance": memory.metadata.importance,
            "relationship_score": memory.metadata.relationship_score,
            "relevance_score": memory.metadata.relevance_score,
            "semantic_hash": memory.semantic_hash
        }

        # Memory-type index
        self.indices[IndexType.MEMORY_TYPE][memory.memory_type].add(memory_id)

        # User-ID index
        self.indices[IndexType.USER_ID][memory.user_id].add(memory_id)

        # Keyword index
        for keyword in memory.keywords:
            self.indices[IndexType.KEYWORD][keyword.lower()].add(memory_id)

        # Tag index
        for tag in memory.tags:
            self.indices[IndexType.TAG][tag.lower()].add(memory_id)

        # Category index
        for category in memory.categories:
            self.indices[IndexType.CATEGORY][category.lower()].add(memory_id)

        # Confidence index
        self.indices[IndexType.CONFIDENCE][memory.metadata.confidence].add(memory_id)

        # Importance index
        self.indices[IndexType.IMPORTANCE][memory.metadata.importance].add(memory_id)

        # Semantic-hash index
        if memory.semantic_hash:
            self.indices[IndexType.SEMANTIC_HASH][memory.semantic_hash].add(memory_id)

        # Ordered indices (sorted insertion keeps them in descending order)
        self._insert_into_time_index(memory.metadata.created_at, memory_id)
        self._insert_into_relationship_index(memory.metadata.relationship_score, memory_id)
        self._insert_into_access_frequency_index(memory.metadata.access_count, memory_id)

        # Update statistics
        self.index_stats["total_memories"] += 1

    def _insert_into_time_index(self, timestamp: float, memory_id: str):
        """Insert into the time index (kept in descending order)."""
        insert_pos = len(self.time_index)
        for i, (ts, _) in enumerate(self.time_index):
            if timestamp >= ts:
                insert_pos = i
                break

        self.time_index.insert(insert_pos, (timestamp, memory_id))

    def _insert_into_relationship_index(self, relationship_score: float, memory_id: str):
        """Insert into the relationship-score index (kept in descending order)."""
        insert_pos = len(self.relationship_index)
        for i, (score, _) in enumerate(self.relationship_index):
            if relationship_score >= score:
                insert_pos = i
                break

        self.relationship_index.insert(insert_pos, (relationship_score, memory_id))

    def _insert_into_access_frequency_index(self, access_count: int, memory_id: str):
        """Insert into the access-frequency index (kept in descending order)."""
        insert_pos = len(self.access_frequency_index)
        for i, (count, _) in enumerate(self.access_frequency_index):
            if access_count >= count:
                insert_pos = i
                break

        self.access_frequency_index.insert(insert_pos, (access_count, memory_id))

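    # A sketch of an O(log n) alternative to the linear scans above, using the
    # standard library's bisect module (the key= parameter needs Python 3.10+).
    # Not part of the original commit; negating the sort key lets bisect, which
    # assumes ascending order, maintain these descending lists:
    #
    #     import bisect
    #
    #     def _insert_descending(index, score, memory_id):
    #         pos = bisect.bisect_left(index, -score, key=lambda item: -item[0])
    #         index.insert(pos, (score, memory_id))
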
    async def query_memories(self, query: IndexQuery) -> IndexResult:
        """Query memories through the indices."""
        start_time = time.time()

        try:
            with self._lock:
                # Gather the candidate memory-ID set
                candidate_ids = self._get_candidate_memories(query)

                # Apply the remaining filter conditions
                filtered_ids = self._apply_filters(candidate_ids, query)
                total_before_limit = len(filtered_ids)

                # Sort
                if query.sort_by:
                    filtered_ids = self._sort_memories(filtered_ids, query.sort_by, query.sort_order)

                # Apply the result limit
                if query.limit and len(filtered_ids) > query.limit:
                    filtered_ids = filtered_ids[:query.limit]

                # Record query statistics
                query_time = time.time() - start_time
                self.index_stats["total_queries"] += 1
                self.index_stats["average_query_time"] = (
                    (self.index_stats["average_query_time"] * (self.index_stats["total_queries"] - 1) + query_time) /
                    self.index_stats["total_queries"]
                )

                return IndexResult(
                    memory_ids=filtered_ids,
                    # Match count before the limit was applied, so callers can
                    # tell how many results were cut off
                    total_count=total_before_limit,
                    query_time=query_time,
                    filtered_by=self._get_applied_filters(query)
                )

        except Exception as e:
            logger.error(f"❌ Metadata query failed: {e}", exc_info=True)
            return IndexResult(memory_ids=[], total_count=0, query_time=0.0, filtered_by=[])

    def _get_candidate_memories(self, query: IndexQuery) -> Set[str]:
        """Collect the candidate memory-ID set."""
        candidate_ids = set()

        # Start from the full set of memory IDs
        all_memory_ids = set(self.memory_metadata_cache.keys())

        if not all_memory_ids:
            return candidate_ids

        # Apply the set-based filters; the first one seeds the candidate set,
        # each later one intersects with it
        applied_filters = []

        if query.user_ids:
            user_ids_set = set()
            for user_id in query.user_ids:
                user_ids_set.update(self.indices[IndexType.USER_ID].get(user_id, set()))
            candidate_ids.update(user_ids_set)
            applied_filters.append("user_ids")

        if query.memory_types:
            memory_types_set = set()
            for memory_type in query.memory_types:
                memory_types_set.update(self.indices[IndexType.MEMORY_TYPE].get(memory_type, set()))
            if applied_filters:
                candidate_ids &= memory_types_set
            else:
                candidate_ids.update(memory_types_set)
            applied_filters.append("memory_types")

        if query.keywords:
            keywords_set = set()
            for keyword in query.keywords:
                keywords_set.update(self.indices[IndexType.KEYWORD].get(keyword.lower(), set()))
            if applied_filters:
                candidate_ids &= keywords_set
            else:
                candidate_ids.update(keywords_set)
            applied_filters.append("keywords")

        if query.tags:
            tags_set = set()
            for tag in query.tags:
                tags_set.update(self.indices[IndexType.TAG].get(tag.lower(), set()))
            if applied_filters:
                candidate_ids &= tags_set
            else:
                candidate_ids.update(tags_set)
            applied_filters.append("tags")

        if query.categories:
            categories_set = set()
            for category in query.categories:
                categories_set.update(self.indices[IndexType.CATEGORY].get(category.lower(), set()))
            if applied_filters:
                candidate_ids &= categories_set
            else:
                candidate_ids.update(categories_set)
            applied_filters.append("categories")

        # With no set-based filters applied, every memory is a candidate
        if not applied_filters:
            return all_memory_ids

        return candidate_ids

    def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]:
        """Apply the remaining filter conditions."""
        filtered_ids = list(candidate_ids)

        # Time-range filter
        if query.time_range:
            start_time, end_time = query.time_range
            filtered_ids = [
                memory_id for memory_id in filtered_ids
                if self._is_in_time_range(memory_id, start_time, end_time)
            ]

        # Confidence filter
        if query.confidence_levels:
            confidence_set = set(query.confidence_levels)
            filtered_ids = [
                memory_id for memory_id in filtered_ids
                if self.memory_metadata_cache[memory_id]["confidence"] in confidence_set
            ]

        # Importance filter
        if query.importance_levels:
            importance_set = set(query.importance_levels)
            filtered_ids = [
                memory_id for memory_id in filtered_ids
                if self.memory_metadata_cache[memory_id]["importance"] in importance_set
            ]

        # Relationship-score range filter
        if query.min_relationship_score is not None:
            filtered_ids = [
                memory_id for memory_id in filtered_ids
                if self.memory_metadata_cache[memory_id]["relationship_score"] >= query.min_relationship_score
            ]

        if query.max_relationship_score is not None:
            filtered_ids = [
                memory_id for memory_id in filtered_ids
                if self.memory_metadata_cache[memory_id]["relationship_score"] <= query.max_relationship_score
            ]

        # Minimum-access-count filter
        if query.min_access_count is not None:
            filtered_ids = [
                memory_id for memory_id in filtered_ids
                if self.memory_metadata_cache[memory_id]["access_count"] >= query.min_access_count
            ]

        # Semantic-hash filter
        if query.semantic_hashes:
            hash_set = set(query.semantic_hashes)
            filtered_ids = [
                memory_id for memory_id in filtered_ids
                if self.memory_metadata_cache[memory_id]["semantic_hash"] in hash_set
            ]

        return filtered_ids

    def _is_in_time_range(self, memory_id: str, start_time: float, end_time: float) -> bool:
        """Check whether a memory falls inside the time range."""
        created_at = self.memory_metadata_cache[memory_id]["created_at"]
        return start_time <= created_at <= end_time

    def _sort_memories(self, memory_ids: List[str], sort_by: str, sort_order: str) -> List[str]:
        """Sort memories by a metadata field.

        The candidates arrive as an unordered set, so the pre-sorted ordered
        indices carry no usable order here; always sort via the metadata cache.
        """
        sort_fields = {
            "created_at": "created_at",
            "access_count": "access_count",
            "last_accessed": "last_accessed",
            "relevance_score": "relevance_score",
            "relationship_score": "relationship_score",
        }

        field_name = sort_fields.get(sort_by)
        if field_name is None:
            return memory_ids

        memory_ids.sort(
            key=lambda mid: self.memory_metadata_cache[mid][field_name],
            reverse=(sort_order == "desc")
        )
        return memory_ids

    def _get_applied_filters(self, query: IndexQuery) -> List[str]:
        """List the filters a query applied."""
        filters = []
        if query.user_ids:
            filters.append("user_ids")
        if query.memory_types:
            filters.append("memory_types")
        if query.keywords:
            filters.append("keywords")
        if query.tags:
            filters.append("tags")
        if query.categories:
            filters.append("categories")
        if query.time_range:
            filters.append("time_range")
        if query.confidence_levels:
            filters.append("confidence_levels")
        if query.importance_levels:
            filters.append("importance_levels")
        if query.min_relationship_score is not None or query.max_relationship_score is not None:
            filters.append("relationship_score_range")
        if query.min_access_count is not None:
            filters.append("min_access_count")
        if query.semantic_hashes:
            filters.append("semantic_hashes")
        return filters

    async def update_memory_index(self, memory: MemoryChunk):
        """Update the index entries for a memory."""
        with self._lock:
            try:
                memory_id = memory.memory_id

                # If the memory is already indexed, drop the old entries first
                if memory_id in self.memory_metadata_cache:
                    await self.remove_memory_index(memory_id)

                # Re-index the memory
                self._index_single_memory(memory)
                self._dirty = True
                self._operation_count += 1

                # Auto-save check
                if self._operation_count >= self.auto_save_interval:
                    await self.save_index()
                    self._operation_count = 0

                logger.debug(f"Memory index updated: {memory_id}")

            except Exception as e:
                logger.error(f"❌ Failed to update memory index: {e}")

    async def remove_memory_index(self, memory_id: str):
        """Remove a memory's index entries."""
        with self._lock:
            try:
                if memory_id not in self.memory_metadata_cache:
                    return

                # Fetch the memory's metadata
                metadata = self.memory_metadata_cache[memory_id]

                # Remove from the per-type indices
                self.indices[IndexType.MEMORY_TYPE][metadata["memory_type"]].discard(memory_id)
                self.indices[IndexType.USER_ID][metadata["user_id"]].discard(memory_id)

                # Remove from the time index
                self.time_index = [(ts, mid) for ts, mid in self.time_index if mid != memory_id]

                # Remove from the relationship-score index
                self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid != memory_id]

                # Remove from the access-frequency index
                self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid != memory_id]

                # Note: cleaning the keyword/tag/category entries would need the
                # original memory (or a reverse index); this is simplified here,
                # and stale entries are swept later by _cleanup_invalid_references()

                # Remove from the cache
                del self.memory_metadata_cache[memory_id]

                # Update statistics
                self.index_stats["total_memories"] = max(0, self.index_stats["total_memories"] - 1)
                self._dirty = True

                logger.debug(f"Memory index removed: {memory_id}")

            except Exception as e:
                logger.error(f"❌ Failed to remove memory index: {e}")

    async def get_memory_metadata(self, memory_id: str) -> Optional[Dict[str, Any]]:
        """Return the cached metadata for a memory."""
        return self.memory_metadata_cache.get(memory_id)

    async def get_user_memory_ids(self, user_id: str, limit: Optional[int] = None) -> List[str]:
        """Return the memory IDs belonging to a user."""
        user_memory_ids = list(self.indices[IndexType.USER_ID].get(user_id, set()))

        if limit and len(user_memory_ids) > limit:
            user_memory_ids = user_memory_ids[:limit]

        return user_memory_ids

    async def get_memory_statistics(self, user_id: Optional[str] = None) -> Dict[str, Any]:
        """Return memory statistics, optionally scoped to one user."""
        stats = {
            "total_memories": self.index_stats["total_memories"],
            "memory_types": {},
            "average_confidence": 0.0,
            "average_importance": 0.0,
            "average_relationship_score": 0.0,
            "top_keywords": [],
            "top_tags": []
        }

        if user_id:
            # Per-user statistics
            user_memory_ids = self.indices[IndexType.USER_ID].get(user_id, set())
            stats["user_total_memories"] = len(user_memory_ids)

            if not user_memory_ids:
                return stats

            # Distribution of this user's memory types
            user_types = {}
            for memory_type, memory_ids in self.indices[IndexType.MEMORY_TYPE].items():
                user_count = len(user_memory_ids & memory_ids)
                if user_count > 0:
                    user_types[memory_type.value] = user_count
            stats["memory_types"] = user_types

            # Per-user averages
            user_confidences = []
            user_importances = []
            user_relationship_scores = []

            for memory_id in user_memory_ids:
                metadata = self.memory_metadata_cache.get(memory_id, {})
                if metadata:
                    user_confidences.append(metadata["confidence"].value)
                    user_importances.append(metadata["importance"].value)
                    user_relationship_scores.append(metadata["relationship_score"])

            if user_confidences:
                stats["average_confidence"] = sum(user_confidences) / len(user_confidences)
            if user_importances:
                stats["average_importance"] = sum(user_importances) / len(user_importances)
            if user_relationship_scores:
                stats["average_relationship_score"] = sum(user_relationship_scores) / len(user_relationship_scores)

        else:
            # Global statistics
            for memory_type, memory_ids in self.indices[IndexType.MEMORY_TYPE].items():
                stats["memory_types"][memory_type.value] = len(memory_ids)

            # Global averages
            if self.memory_metadata_cache:
                all_confidences = [m["confidence"].value for m in self.memory_metadata_cache.values()]
                all_importances = [m["importance"].value for m in self.memory_metadata_cache.values()]
                all_relationship_scores = [m["relationship_score"] for m in self.memory_metadata_cache.values()]

                if all_confidences:
                    stats["average_confidence"] = sum(all_confidences) / len(all_confidences)
                if all_importances:
                    stats["average_importance"] = sum(all_importances) / len(all_importances)
                if all_relationship_scores:
                    stats["average_relationship_score"] = sum(all_relationship_scores) / len(all_relationship_scores)

        # Most frequent keywords and tags
        keyword_counts = [(keyword, len(memory_ids)) for keyword, memory_ids in self.indices[IndexType.KEYWORD].items()]
        keyword_counts.sort(key=lambda x: x[1], reverse=True)
        stats["top_keywords"] = keyword_counts[:10]

        tag_counts = [(tag, len(memory_ids)) for tag, memory_ids in self.indices[IndexType.TAG].items()]
        tag_counts.sort(key=lambda x: x[1], reverse=True)
        stats["top_tags"] = tag_counts[:10]

        return stats

    async def save_index(self):
        """Persist the index to disk."""
        if not self._dirty:
            return

        try:
            logger.info("Saving metadata index...")

            # OPT_NON_STR_KEYS is an added safeguard so enum-valued index keys
            # (e.g. MemoryType) serialize as their values; it is a no-op if the
            # enums are already str subclasses
            dump_options = orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS

            # Save the per-type indices
            indices_data = {}
            for index_type, index_data in self.indices.items():
                indices_data[index_type.value] = {
                    key: list(values) for key, values in index_data.items()
                }

            indices_file = self.index_path / "indices.json"
            with open(indices_file, 'w', encoding='utf-8') as f:
                f.write(orjson.dumps(indices_data, option=dump_options).decode('utf-8'))

            # Save the time index
            time_index_file = self.index_path / "time_index.json"
            with open(time_index_file, 'w', encoding='utf-8') as f:
                f.write(orjson.dumps(self.time_index, option=dump_options).decode('utf-8'))

            # Save the relationship-score index
            relationship_index_file = self.index_path / "relationship_index.json"
            with open(relationship_index_file, 'w', encoding='utf-8') as f:
                f.write(orjson.dumps(self.relationship_index, option=dump_options).decode('utf-8'))

            # Save the access-frequency index
            access_frequency_index_file = self.index_path / "access_frequency_index.json"
            with open(access_frequency_index_file, 'w', encoding='utf-8') as f:
                f.write(orjson.dumps(self.access_frequency_index, option=dump_options).decode('utf-8'))

            # Save the metadata cache
            metadata_cache_file = self.index_path / "metadata_cache.json"
            with open(metadata_cache_file, 'w', encoding='utf-8') as f:
                f.write(orjson.dumps(self.memory_metadata_cache, option=dump_options).decode('utf-8'))

            # Save the statistics
            stats_file = self.index_path / "index_stats.json"
            with open(stats_file, 'w', encoding='utf-8') as f:
                f.write(orjson.dumps(self.index_stats, option=dump_options).decode('utf-8'))

            self._dirty = False
            logger.info("✅ Metadata index saved")

        except Exception as e:
            logger.error(f"❌ Failed to save metadata index: {e}")

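    # Resulting on-disk layout under index_path, from the writes above:
    #
    #   data/memory_metadata/
    #     indices.json                  per-type inverted indices
    #     time_index.json               [timestamp, memory_id] pairs, descending
    #     relationship_index.json       [score, memory_id] pairs, descending
    #     access_frequency_index.json   [count, memory_id] pairs, descending
    #     metadata_cache.json           memory_id -> metadata dict
    #     index_stats.json              aggregate statistics
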
    async def load_index(self):
        """Load the index from disk."""
        try:
            logger.info("Loading metadata index...")

            # Load the per-type indices. Keys of enum-typed indices come back
            # in their serialized string form; defaultdict(set) is restored so
            # that later insertions keep working.
            indices_file = self.index_path / "indices.json"
            if indices_file.exists():
                with open(indices_file, 'r', encoding='utf-8') as f:
                    indices_data = orjson.loads(f.read())

                for index_type_value, index_data in indices_data.items():
                    index_type = IndexType(index_type_value)
                    self.indices[index_type] = defaultdict(set, {
                        key: set(values) for key, values in index_data.items()
                    })

            # Load the time index
            time_index_file = self.index_path / "time_index.json"
            if time_index_file.exists():
                with open(time_index_file, 'r', encoding='utf-8') as f:
                    self.time_index = orjson.loads(f.read())

            # Load the relationship-score index
            relationship_index_file = self.index_path / "relationship_index.json"
            if relationship_index_file.exists():
                with open(relationship_index_file, 'r', encoding='utf-8') as f:
                    self.relationship_index = orjson.loads(f.read())

            # Load the access-frequency index
            access_frequency_index_file = self.index_path / "access_frequency_index.json"
            if access_frequency_index_file.exists():
                with open(access_frequency_index_file, 'r', encoding='utf-8') as f:
                    self.access_frequency_index = orjson.loads(f.read())

            # Load the metadata cache
            metadata_cache_file = self.index_path / "metadata_cache.json"
            if metadata_cache_file.exists():
                with open(metadata_cache_file, 'r', encoding='utf-8') as f:
                    cache_data = orjson.loads(f.read())

                # Restore confidence and importance to their enum types
                for memory_id, metadata in cache_data.items():
                    if isinstance(metadata["confidence"], str):
                        metadata["confidence"] = ConfidenceLevel(metadata["confidence"])
                    if isinstance(metadata["importance"], str):
                        metadata["importance"] = ImportanceLevel(metadata["importance"])

                self.memory_metadata_cache = cache_data

            # Load the statistics
            stats_file = self.index_path / "index_stats.json"
            if stats_file.exists():
                with open(stats_file, 'r', encoding='utf-8') as f:
                    self.index_stats = orjson.loads(f.read())

            # Refresh the memory count
            self.index_stats["total_memories"] = len(self.memory_metadata_cache)

            logger.info(f"✅ Metadata index loaded, {self.index_stats['total_memories']} memories")

        except Exception as e:
            logger.error(f"❌ Failed to load metadata index: {e}")

    async def optimize_index(self):
        """Optimize the index."""
        try:
            logger.info("Starting metadata index optimization...")

            # Drop references to memories that no longer exist
            self._cleanup_invalid_references()

            # Rebuild the ordered indices
            self._rebuild_ordered_indices()

            # Drop low-frequency keywords and tags
            self._cleanup_low_frequency_terms()

            # Refresh statistics
            if self.index_stats["total_queries"] > 0:
                self.index_stats["cache_hit_rate"] = (
                    self.index_stats["cache_hits"] / self.index_stats["total_queries"]
                )

            logger.info("✅ Metadata index optimization finished")

        except Exception as e:
            logger.error(f"❌ Metadata index optimization failed: {e}")

    def _cleanup_invalid_references(self):
        """Drop index references to memories that no longer exist."""
        valid_memory_ids = set(self.memory_metadata_cache.keys())

        # Clean the per-type indices
        for index_type in self.indices:
            for key in list(self.indices[index_type].keys()):
                valid_ids = self.indices[index_type][key] & valid_memory_ids
                self.indices[index_type][key] = valid_ids

                # Drop the key entirely once it has no memories left
                if not valid_ids:
                    del self.indices[index_type][key]

        # Clean the time index
        self.time_index = [(ts, mid) for ts, mid in self.time_index if mid in valid_memory_ids]

        # Clean the relationship-score index
        self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid in valid_memory_ids]

        # Clean the access-frequency index
        self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid in valid_memory_ids]

        # Refresh the total memory count
        self.index_stats["total_memories"] = len(valid_memory_ids)

    def _rebuild_ordered_indices(self):
        """Re-sort the ordered indices (descending)."""
        self.time_index.sort(key=lambda x: x[0], reverse=True)
        self.relationship_index.sort(key=lambda x: x[0], reverse=True)
        self.access_frequency_index.sort(key=lambda x: x[0], reverse=True)

    def _cleanup_low_frequency_terms(self, min_frequency: int = 2):
        """Drop terms that index fewer than min_frequency memories."""
        # Low-frequency keywords
        for keyword in list(self.indices[IndexType.KEYWORD].keys()):
            if len(self.indices[IndexType.KEYWORD][keyword]) < min_frequency:
                del self.indices[IndexType.KEYWORD][keyword]

        # Low-frequency tags
        for tag in list(self.indices[IndexType.TAG].keys()):
            if len(self.indices[IndexType.TAG][tag]) < min_frequency:
                del self.indices[IndexType.TAG][tag]

        # Low-frequency categories
        for category in list(self.indices[IndexType.CATEGORY].keys()):
            if len(self.indices[IndexType.CATEGORY][category]) < min_frequency:
                del self.indices[IndexType.CATEGORY][category]

    def get_index_stats(self) -> Dict[str, Any]:
        """Return the index statistics."""
        stats = self.index_stats.copy()
        if stats["total_queries"] > 0:
            stats["cache_hit_rate"] = stats["cache_hits"] / stats["total_queries"]
        else:
            stats["cache_hit_rate"] = 0.0

        # Attach per-index detail
        stats["index_details"] = {
            "memory_types": len(self.indices[IndexType.MEMORY_TYPE]),
            "user_ids": len(self.indices[IndexType.USER_ID]),
            "keywords": len(self.indices[IndexType.KEYWORD]),
            "tags": len(self.indices[IndexType.TAG]),
            "categories": len(self.indices[IndexType.CATEGORY]),
            "confidence_levels": len(self.indices[IndexType.CONFIDENCE]),
            "importance_levels": len(self.indices[IndexType.IMPORTANCE]),
            "semantic_hashes": len(self.indices[IndexType.SEMANTIC_HASH])
        }

        return stats

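# Illustrative lifecycle sketch (assumed wiring; `memories` is a list of
# MemoryChunk objects produced elsewhere in the memory system):
#
#     manager = MetadataIndexManager()
#     await manager.load_index()
#     await manager.index_memories(memories)
#     result = await manager.query_memories(
#         IndexQuery(user_ids=["user_123"], limit=20, sort_by="created_at")
#     )
#     await manager.save_index()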
src/chat/memory_system/multi_stage_retrieval.py (new file, 595 lines)
@@ -0,0 +1,595 @@
# -*- coding: utf-8 -*-
"""
Multi-stage recall mechanism
Coarse-to-fine memory retrieval optimization
"""

import time
from typing import Dict, List, Optional, Set, Any
from dataclasses import dataclass
from enum import Enum

from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType

logger = get_logger(__name__)


class RetrievalStage(Enum):
    """Retrieval stages."""
    METADATA_FILTERING = "metadata_filtering"      # Metadata filtering stage
    VECTOR_SEARCH = "vector_search"                # Vector search stage
    SEMANTIC_RERANKING = "semantic_reranking"      # Semantic reranking stage
    CONTEXTUAL_FILTERING = "contextual_filtering"  # Contextual filtering stage


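# With the default RetrievalConfig below, the four stages form a funnel:
# metadata filtering keeps up to 100 candidates, vector search narrows them
# to 50, semantic reranking to 20, and contextual filtering returns the
# final 10.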
@dataclass
class RetrievalConfig:
    """Retrieval configuration."""
    # Per-stage result limits
    metadata_filter_limit: int = 100   # Candidates kept by metadata filtering
    vector_search_limit: int = 50      # Candidates kept by vector search
    semantic_rerank_limit: int = 20    # Candidates kept by semantic reranking
    final_result_limit: int = 10       # Final number of results

    # Similarity thresholds
    vector_similarity_threshold: float = 0.7    # Vector-similarity threshold
    semantic_similarity_threshold: float = 0.6  # Semantic-similarity threshold

    # Score weights
    vector_weight: float = 0.4    # Weight of vector similarity
    semantic_weight: float = 0.3  # Weight of semantic similarity
    context_weight: float = 0.2   # Weight of context relevance
    recency_weight: float = 0.1   # Weight of recency

    @classmethod
    def from_global_config(cls):
        """Build a config instance from the global configuration."""
        from src.config.config import global_config

        return cls(
            # Per-stage result limits
            metadata_filter_limit=global_config.memory.metadata_filter_limit,
            vector_search_limit=global_config.memory.vector_search_limit,
            semantic_rerank_limit=global_config.memory.semantic_rerank_limit,
            final_result_limit=global_config.memory.final_result_limit,

            # Similarity thresholds
            vector_similarity_threshold=global_config.memory.vector_similarity_threshold,
            semantic_similarity_threshold=0.6,  # Keep the default value

            # Score weights
            vector_weight=global_config.memory.vector_weight,
            semantic_weight=global_config.memory.semantic_weight,
            context_weight=global_config.memory.context_weight,
            recency_weight=global_config.memory.recency_weight
        )


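# Illustrative override (hypothetical values): tighten the vector gate and
# return fewer final results, bypassing the global config.
#
#     config = RetrievalConfig(
#         vector_similarity_threshold=0.8,
#         final_result_limit=5,
#     )
#     retrieval = MultiStageRetrieval(config)

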
@dataclass
class StageResult:
    """Result of a single retrieval stage."""
    stage: RetrievalStage
    memory_ids: List[str]
    processing_time: float
    filtered_count: int
    score_threshold: float


@dataclass
class RetrievalResult:
    """Overall retrieval result."""
    query: str
    user_id: str
    final_memories: List[MemoryChunk]
    stage_results: List[StageResult]
    total_processing_time: float
    total_filtered: int
    retrieval_stats: Dict[str, Any]


class MultiStageRetrieval:
    """Multi-stage recall system."""

    def __init__(self, config: Optional[RetrievalConfig] = None):
        self.config = config or RetrievalConfig.from_global_config()
        self.retrieval_stats = {
            "total_queries": 0,
            "average_retrieval_time": 0.0,
            "stage_stats": {
                "metadata_filtering": {"calls": 0, "avg_time": 0.0},
                "vector_search": {"calls": 0, "avg_time": 0.0},
                "semantic_reranking": {"calls": 0, "avg_time": 0.0},
                "contextual_filtering": {"calls": 0, "avg_time": 0.0}
            }
        }

    async def retrieve_memories(
        self,
        query: str,
        user_id: str,
        context: Dict[str, Any],
        metadata_index,
        vector_storage,
        all_memories_cache: Dict[str, MemoryChunk],
        limit: Optional[int] = None
    ) -> RetrievalResult:
        """Run the multi-stage memory retrieval pipeline."""
        start_time = time.time()
        limit = limit or self.config.final_result_limit

        stage_results = []
        current_memory_ids = set()

        try:
            logger.debug(f"Starting multi-stage retrieval: query='{query}', user_id='{user_id}'")

            # Stage 1: metadata filtering
            stage1_result = await self._metadata_filtering_stage(
                query, user_id, context, metadata_index, all_memories_cache
            )
            stage_results.append(stage1_result)
            current_memory_ids.update(stage1_result.memory_ids)

            # Stage 2: vector search
            stage2_result = await self._vector_search_stage(
                query, user_id, context, vector_storage, current_memory_ids, all_memories_cache
            )
            stage_results.append(stage2_result)
            current_memory_ids.update(stage2_result.memory_ids)

            # Stage 3: semantic reranking
            stage3_result = await self._semantic_reranking_stage(
                query, user_id, context, current_memory_ids, all_memories_cache
            )
            stage_results.append(stage3_result)

            # Stage 4: contextual filtering
            stage4_result = await self._contextual_filtering_stage(
                query, user_id, context, stage3_result.memory_ids, all_memories_cache, limit
            )
            stage_results.append(stage4_result)

            # Resolve the final memory objects
            final_memories = []
            for memory_id in stage4_result.memory_ids:
                if memory_id in all_memories_cache:
                    final_memories.append(all_memories_cache[memory_id])

            # Update statistics
            total_time = time.time() - start_time
            self._update_retrieval_stats(total_time, stage_results)

            total_filtered = sum(result.filtered_count for result in stage_results)

            logger.debug(f"Multi-stage retrieval finished: {len(final_memories)} memories in {total_time:.3f}s")

            return RetrievalResult(
                query=query,
                user_id=user_id,
                final_memories=final_memories,
                stage_results=stage_results,
                total_processing_time=total_time,
                total_filtered=total_filtered,
                retrieval_stats=self.retrieval_stats.copy()
            )

        except Exception as e:
            logger.error(f"Multi-stage retrieval failed: {e}", exc_info=True)
            # Return an empty result
            return RetrievalResult(
                query=query,
                user_id=user_id,
                final_memories=[],
                stage_results=stage_results,
                total_processing_time=time.time() - start_time,
                total_filtered=0,
                retrieval_stats=self.retrieval_stats.copy()
            )

    async def _metadata_filtering_stage(
        self,
        query: str,
        user_id: str,
        context: Dict[str, Any],
        metadata_index,
        all_memories_cache: Dict[str, MemoryChunk]
    ) -> StageResult:
        """Stage 1: metadata filtering."""
        start_time = time.time()

        try:
            from .metadata_index import IndexQuery

            # Build the index query
            index_query = IndexQuery(
                user_ids=[user_id],
                memory_types=self._extract_memory_types_from_context(context),
                keywords=self._extract_keywords_from_query(query),
                limit=self.config.metadata_filter_limit,
                sort_by="last_accessed",
                sort_order="desc"
            )

            # Run the query; total_count is the pre-limit match count, so the
            # difference is how many candidates the limit cut off
            result = await metadata_index.query_memories(index_query)
            filtered_count = result.total_count - len(result.memory_ids)

            logger.debug(f"Metadata filtering: {result.total_count} -> {len(result.memory_ids)} memories")

            return StageResult(
                stage=RetrievalStage.METADATA_FILTERING,
                memory_ids=result.memory_ids,
                processing_time=time.time() - start_time,
                filtered_count=filtered_count,
                score_threshold=0.0
            )

        except Exception as e:
            logger.error(f"Metadata filtering stage failed: {e}")
            return StageResult(
                stage=RetrievalStage.METADATA_FILTERING,
                memory_ids=[],
                processing_time=time.time() - start_time,
                filtered_count=0,
                score_threshold=0.0
            )

    async def _vector_search_stage(
        self,
        query: str,
        user_id: str,
        context: Dict[str, Any],
        vector_storage,
        candidate_ids: Set[str],
        all_memories_cache: Dict[str, MemoryChunk]
    ) -> StageResult:
        """Stage 2: vector search."""
        start_time = time.time()

        try:
            # Generate the query embedding
            query_embedding = await self._generate_query_embedding(query, context)

            if not query_embedding:
                return StageResult(
                    stage=RetrievalStage.VECTOR_SEARCH,
                    memory_ids=[],
                    processing_time=time.time() - start_time,
                    filtered_count=0,
                    score_threshold=self.config.vector_similarity_threshold
                )

            # Run the vector search
            search_result = await vector_storage.search_similar(
                query_embedding,
                limit=self.config.vector_search_limit
            )

            # Keep only candidates above the similarity threshold
            filtered_memories = []
            for memory_id, similarity in search_result:
                if memory_id in candidate_ids and similarity >= self.config.vector_similarity_threshold:
                    filtered_memories.append((memory_id, similarity))

            # Sort by similarity
            filtered_memories.sort(key=lambda x: x[1], reverse=True)
            result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]]

            filtered_count = len(candidate_ids) - len(result_ids)

            logger.debug(f"Vector search: {len(candidate_ids)} -> {len(result_ids)} memories")

            return StageResult(
                stage=RetrievalStage.VECTOR_SEARCH,
                memory_ids=result_ids,
                processing_time=time.time() - start_time,
                filtered_count=filtered_count,
                score_threshold=self.config.vector_similarity_threshold
            )

        except Exception as e:
            logger.error(f"Vector search stage failed: {e}")
            return StageResult(
                stage=RetrievalStage.VECTOR_SEARCH,
                memory_ids=[],
                processing_time=time.time() - start_time,
                filtered_count=0,
                score_threshold=self.config.vector_similarity_threshold
            )

    async def _semantic_reranking_stage(
        self,
        query: str,
        user_id: str,
        context: Dict[str, Any],
        candidate_ids: Set[str],
        all_memories_cache: Dict[str, MemoryChunk]
    ) -> StageResult:
        """Stage 3: semantic reranking."""
        start_time = time.time()

        try:
            reranked_memories = []

            for memory_id in candidate_ids:
                if memory_id not in all_memories_cache:
                    continue

                memory = all_memories_cache[memory_id]

                # Compute the combined semantic similarity
                semantic_score = await self._calculate_semantic_similarity(query, memory, context)

                if semantic_score >= self.config.semantic_similarity_threshold:
                    reranked_memories.append((memory_id, semantic_score))

            # Sort by semantic similarity
            reranked_memories.sort(key=lambda x: x[1], reverse=True)
            result_ids = [memory_id for memory_id, _ in reranked_memories[:self.config.semantic_rerank_limit]]

            filtered_count = len(candidate_ids) - len(result_ids)

            logger.debug(f"Semantic reranking: {len(candidate_ids)} -> {len(result_ids)} memories")

            return StageResult(
                stage=RetrievalStage.SEMANTIC_RERANKING,
                memory_ids=result_ids,
                processing_time=time.time() - start_time,
                filtered_count=filtered_count,
                score_threshold=self.config.semantic_similarity_threshold
            )

        except Exception as e:
            logger.error(f"Semantic reranking stage failed: {e}")
            return StageResult(
                stage=RetrievalStage.SEMANTIC_RERANKING,
                memory_ids=list(candidate_ids),  # Fall back to the raw candidate set on failure
                processing_time=time.time() - start_time,
                filtered_count=0,
                score_threshold=self.config.semantic_similarity_threshold
            )

    async def _contextual_filtering_stage(
        self,
        query: str,
        user_id: str,
        context: Dict[str, Any],
        candidate_ids: List[str],
        all_memories_cache: Dict[str, MemoryChunk],
        limit: int
    ) -> StageResult:
        """Stage 4: contextual filtering."""
        start_time = time.time()

        try:
            final_memories = []

            for memory_id in candidate_ids:
                if memory_id not in all_memories_cache:
                    continue

                memory = all_memories_cache[memory_id]

                # Score the memory's relevance to the context
                context_score = await self._calculate_context_relevance(query, memory, context)

                # Combine the multi-factor scores
                final_score = await self._calculate_final_score(query, memory, context, context_score)

                final_memories.append((memory_id, final_score))

            # Sort by final score
            final_memories.sort(key=lambda x: x[1], reverse=True)
            result_ids = [memory_id for memory_id, _ in final_memories[:limit]]

            filtered_count = len(candidate_ids) - len(result_ids)

            logger.debug(f"Contextual filtering: {len(candidate_ids)} -> {len(result_ids)} memories")

            return StageResult(
                stage=RetrievalStage.CONTEXTUAL_FILTERING,
                memory_ids=result_ids,
                processing_time=time.time() - start_time,
                filtered_count=filtered_count,
                score_threshold=0.0  # Dynamic threshold
            )

        except Exception as e:
            logger.error(f"Contextual filtering stage failed: {e}")
            return StageResult(
                stage=RetrievalStage.CONTEXTUAL_FILTERING,
                memory_ids=candidate_ids[:limit],  # Fall back to the first `limit` candidates on failure
                processing_time=time.time() - start_time,
                filtered_count=0,
                score_threshold=0.0
            )

    async def _generate_query_embedding(self, query: str, context: Dict[str, Any]) -> Optional[List[float]]:
        """Generate the query embedding."""
        try:
            # This should call the same embedding model the memory store uses.
            # No embedding model is wired up here yet, so return None, which
            # makes the vector-search stage a no-op.
            return None
        except Exception as e:
            logger.warning(f"Failed to generate query embedding: {e}")
            return None

    async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
        """Compute a semantic-similarity score.

        This is a rough Jaccard similarity over whitespace tokens; it does not
        segment unspaced Chinese text, so scores on such input are coarse.
        """
        try:
            query_words = set(query.lower().split())
            memory_words = set(memory.text_content.lower().split())

            if not query_words or not memory_words:
                return 0.0

            intersection = query_words & memory_words
            union = query_words | memory_words

            jaccard_similarity = len(intersection) / len(union)
            return jaccard_similarity

        except Exception as e:
            logger.warning(f"Failed to compute semantic similarity: {e}")
            return 0.0

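    # Worked example of the Jaccard score above, on whitespace-tokenized input:
    # query "do you like coffee" vs. memory "i like coffee a lot" share
    # {like, coffee}; the union has 7 tokens, so the score is 2/7 ≈ 0.29.
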
    async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
        """Compute the context-relevance score."""
        try:
            score = 0.0

            # Does the memory type match what the context expects?
            if context.get("expected_memory_types"):
                if memory.memory_type in context["expected_memory_types"]:
                    score += 0.3

            # Keyword overlap with the context
            if context.get("keywords"):
                memory_keywords = set(memory.keywords)
                context_keywords = set(context["keywords"])
                overlap = memory_keywords & context_keywords
                if overlap:
                    score += len(overlap) / max(len(context_keywords), 1) * 0.4

            # Recency bonus
            if context.get("recent_only", False):
                memory_age = time.time() - memory.metadata.created_at
                if memory_age < 7 * 24 * 3600:  # Within 7 days
                    score += 0.3

            return min(score, 1.0)

        except Exception as e:
            logger.warning(f"Failed to compute context relevance: {e}")
            return 0.0

    async def _calculate_final_score(self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float) -> float:
        """Compute the final combined score."""
        try:
            # Semantic similarity
            semantic_score = await self._calculate_semantic_similarity(query, memory, context)

            # Vector similarity (when an embedding is available)
            vector_score = 0.0
            if memory.embedding:
                # A proper similarity against the query embedding belongs here;
                # simplified to a fixed placeholder for now
                vector_score = 0.5

            # Recency score
            recency_score = self._calculate_recency_score(memory.metadata.created_at)

            # Weighted combination
            final_score = (
                semantic_score * self.config.semantic_weight +
                vector_score * self.config.vector_weight +
                context_score * self.config.context_weight +
                recency_score * self.config.recency_weight
            )

            # Factor in the memory's importance
            importance_weight = memory.metadata.importance.value / 4.0  # Normalize to 0-1
            final_score = final_score * (0.7 + importance_weight * 0.3)  # Importance sways up to 30%

            return final_score

        except Exception as e:
            logger.warning(f"Failed to compute final score: {e}")
            return 0.0

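    # Worked example with the default weights (0.3 semantic, 0.4 vector,
    # 0.2 context, 0.1 recency) and hypothetical inputs: semantic=0.5,
    # vector=0.5, context=0.4, recency=0.8 combine to
    # 0.15 + 0.20 + 0.08 + 0.08 = 0.51; an importance of 2/4 then scales it
    # by 0.7 + 0.5 * 0.3 = 0.85, giving a final score of about 0.434.
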
    def _calculate_recency_score(self, timestamp: float) -> float:
        """Compute the recency score."""
        try:
            age = time.time() - timestamp
            age_days = age / (24 * 3600)

            if age_days < 1:
                return 1.0
            elif age_days < 7:
                return 0.8
            elif age_days < 30:
                return 0.6
            elif age_days < 90:
                return 0.4
            else:
                return 0.2

        except Exception:
            return 0.5

    def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]:
        """Derive the expected memory types from the context."""
        try:
            if "expected_memory_types" in context:
                return context["expected_memory_types"]

            # Infer memory types from the message type
            if "message_type" in context:
                message_type = context["message_type"]
                if message_type in ["personal_info", "fact"]:
                    return [MemoryType.PERSONAL_FACT]
                elif message_type in ["event", "activity"]:
                    return [MemoryType.EVENT]
                elif message_type in ["preference", "like"]:
                    return [MemoryType.PREFERENCE]
                elif message_type in ["opinion", "view"]:
                    return [MemoryType.OPINION]

            return []

        except Exception:
            return []

    def _extract_keywords_from_query(self, query: str) -> List[str]:
        """Extract keywords from the query."""
        try:
            # Naive keyword extraction over whitespace tokens (unspaced
            # Chinese text will not be segmented)
            words = query.lower().split()
            # Filter out common Chinese stopwords
            stopwords = {"的", "是", "在", "有", "我", "你", "他", "她", "它", "这", "那", "了", "吗", "呢"}
            keywords = [word for word in words if len(word) > 1 and word not in stopwords]
            return keywords[:10]  # Return at most 10 keywords
        except Exception:
            return []

    def _update_retrieval_stats(self, total_time: float, stage_results: List[StageResult]):
        """Update the retrieval statistics."""
        self.retrieval_stats["total_queries"] += 1

        # Update the running average of total retrieval time
        current_avg = self.retrieval_stats["average_retrieval_time"]
        total_queries = self.retrieval_stats["total_queries"]
        new_avg = (current_avg * (total_queries - 1) + total_time) / total_queries
        self.retrieval_stats["average_retrieval_time"] = new_avg

        # Update the per-stage statistics
        for result in stage_results:
            stage_name = result.stage.value
            if stage_name in self.retrieval_stats["stage_stats"]:
                stage_stat = self.retrieval_stats["stage_stats"][stage_name]
                stage_stat["calls"] += 1

                current_stage_avg = stage_stat["avg_time"]
                new_stage_avg = (current_stage_avg * (stage_stat["calls"] - 1) + result.processing_time) / stage_stat["calls"]
                stage_stat["avg_time"] = new_stage_avg

    def get_retrieval_stats(self) -> Dict[str, Any]:
        """Return a copy of the retrieval statistics."""
        return self.retrieval_stats.copy()

    def reset_stats(self):
        """Reset the retrieval statistics."""
        self.retrieval_stats = {
            "total_queries": 0,
            "average_retrieval_time": 0.0,
            "stage_stats": {
                "metadata_filtering": {"calls": 0, "avg_time": 0.0},
                "vector_search": {"calls": 0, "avg_time": 0.0},
                "semantic_reranking": {"calls": 0, "avg_time": 0.0},
                "contextual_filtering": {"calls": 0, "avg_time": 0.0}
            }
        }

@@ -1,126 +0,0 @@
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class MemoryBuildScheduler:
|
||||
def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
|
||||
"""
|
||||
初始化记忆构建调度器
|
||||
|
||||
参数:
|
||||
n_hours1 (float): 第一个分布的均值(距离现在的小时数)
|
||||
std_hours1 (float): 第一个分布的标准差(小时)
|
||||
weight1 (float): 第一个分布的权重
|
||||
n_hours2 (float): 第二个分布的均值(距离现在的小时数)
|
||||
std_hours2 (float): 第二个分布的标准差(小时)
|
||||
weight2 (float): 第二个分布的权重
|
||||
total_samples (int): 要生成的总时间点数量
|
||||
"""
|
||||
# 验证参数
|
||||
if total_samples <= 0:
|
||||
raise ValueError("total_samples 必须大于0")
|
||||
if weight1 < 0 or weight2 < 0:
|
||||
raise ValueError("权重必须为非负数")
|
||||
if std_hours1 < 0 or std_hours2 < 0:
|
||||
raise ValueError("标准差必须为非负数")
|
||||
|
||||
# 归一化权重
|
||||
total_weight = weight1 + weight2
|
||||
if total_weight == 0:
|
||||
raise ValueError("权重总和不能为0")
|
||||
self.weight1 = weight1 / total_weight
|
||||
self.weight2 = weight2 / total_weight
|
||||
|
||||
self.n_hours1 = n_hours1
|
||||
self.std_hours1 = std_hours1
|
||||
self.n_hours2 = n_hours2
|
||||
self.std_hours2 = std_hours2
|
||||
self.total_samples = total_samples
|
||||
self.base_time = datetime.now()
|
||||
|
||||
def generate_time_samples(self):
|
||||
"""生成混合分布的时间采样点"""
|
||||
# 根据权重计算每个分布的样本数
|
||||
samples1 = max(1, int(self.total_samples * self.weight1))
|
||||
samples2 = max(1, self.total_samples - samples1)  # 确保 samples2 至少为1(当 samples1 占满时,总样本数可能因此比 total_samples 多 1)
|
||||
|
||||
# 生成两个正态分布的小时偏移
|
||||
hours_offset1 = np.random.normal(loc=self.n_hours1, scale=self.std_hours1, size=samples1)
|
||||
hours_offset2 = np.random.normal(loc=self.n_hours2, scale=self.std_hours2, size=samples2)
|
||||
|
||||
# 合并两个分布的偏移
|
||||
hours_offset = np.concatenate([hours_offset1, hours_offset2])
|
||||
|
||||
# 将偏移转换为实际时间戳(使用绝对值确保时间点在过去)
|
||||
timestamps = [self.base_time - timedelta(hours=abs(offset)) for offset in hours_offset]
|
||||
|
||||
# 按时间排序(从最早到最近)
|
||||
return sorted(timestamps)
|
||||
|
||||
def get_timestamp_array(self):
|
||||
"""返回时间戳数组"""
|
||||
timestamps = self.generate_time_samples()
|
||||
return [int(t.timestamp()) for t in timestamps]
|
||||
|
||||
|
||||
# def print_time_samples(timestamps, show_distribution=True):
|
||||
# """打印时间样本和分布信息"""
|
||||
# print(f"\n生成的{len(timestamps)}个时间点分布:")
|
||||
# print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
|
||||
# print("-" * 50)
|
||||
|
||||
# now = datetime.now()
|
||||
# time_diffs = []
|
||||
|
||||
# for i, timestamp in enumerate(timestamps, 1):
|
||||
# hours_diff = (now - timestamp).total_seconds() / 3600
|
||||
# time_diffs.append(hours_diff)
|
||||
# print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
|
||||
|
||||
# # 打印统计信息
|
||||
# print("\n统计信息:")
|
||||
# print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
|
||||
# print(f"标准差:{np.std(time_diffs):.2f}小时")
|
||||
# print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
|
||||
# print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
|
||||
|
||||
# if show_distribution:
|
||||
# # 计算时间分布的直方图
|
||||
# hist, bins = np.histogram(time_diffs, bins=40)
|
||||
# print("\n时间分布(每个*代表一个时间点):")
|
||||
# for i in range(len(hist)):
|
||||
# if hist[i] > 0:
|
||||
# print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")
|
||||
|
||||
|
||||
# # 使用示例
|
||||
# if __name__ == "__main__":
|
||||
# # 创建一个双峰分布的记忆调度器
|
||||
# scheduler = MemoryBuildScheduler(
|
||||
# n_hours1=12, # 第一个分布均值(12小时前)
|
||||
# std_hours1=8, # 第一个分布标准差
|
||||
# weight1=0.7, # 第一个分布权重 70%
|
||||
# n_hours2=36, # 第二个分布均值(36小时前)
|
||||
# std_hours2=24, # 第二个分布标准差
|
||||
# weight2=0.3, # 第二个分布权重 30%
|
||||
# total_samples=50, # 总共生成50个时间点
|
||||
# )
|
||||
|
||||
# # 生成时间分布
|
||||
# timestamps = scheduler.generate_time_samples()
|
||||
|
||||
# # 打印结果,包含分布可视化
|
||||
# print_time_samples(timestamps, show_distribution=True)
|
||||
|
||||
# # 打印时间戳数组
|
||||
# timestamp_array = scheduler.get_timestamp_array()
|
||||
# print("\n时间戳数组(Unix时间戳):")
|
||||
# print("[", end="")
|
||||
# for i, ts in enumerate(timestamp_array):
|
||||
# if i > 0:
|
||||
# print(", ", end="")
|
||||
# print(ts, end="")
|
||||
# print("]")
|
||||
@@ -1,359 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.common.vector_db import vector_db_service
|
||||
|
||||
|
||||
logger = get_logger("vector_instant_memory_v2")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""聊天消息数据结构"""
|
||||
|
||||
message_id: str
|
||||
chat_id: str
|
||||
content: str
|
||||
timestamp: float
|
||||
sender: str = "unknown"
|
||||
message_type: str = "text"
|
||||
|
||||
|
||||
class VectorInstantMemoryV2:
|
||||
"""重构的向量瞬时记忆系统 V2
|
||||
|
||||
新设计理念:
|
||||
1. 全量存储 - 所有聊天记录都存储为向量
|
||||
2. 定时清理 - 定期清理过期记录
|
||||
3. 实时匹配 - 新消息与历史记录做向量相似度匹配
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, retention_hours: int = 24, cleanup_interval: int = 3600):
|
||||
"""
|
||||
初始化向量瞬时记忆系统
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
retention_hours: 记忆保留时长(小时)
|
||||
cleanup_interval: 清理间隔(秒)
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.retention_hours = retention_hours
|
||||
self.cleanup_interval = cleanup_interval
|
||||
self.collection_name = "instant_memory"
|
||||
|
||||
# 清理任务相关
|
||||
self.cleanup_task = None
|
||||
self.is_running = True
|
||||
|
||||
# 初始化系统
|
||||
self._init_chroma()
|
||||
self._start_cleanup_task()
|
||||
|
||||
logger.info(f"向量瞬时记忆系统V2初始化完成: {chat_id} (保留{retention_hours}小时)")
|
||||
|
||||
def _init_chroma(self):
|
||||
"""使用全局服务初始化向量数据库集合"""
|
||||
try:
|
||||
# 现在我们只获取集合,而不是创建新的客户端
|
||||
vector_db_service.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"})
|
||||
logger.info(f"向量记忆集合 '{self.collection_name}' 已准备就绪")
|
||||
except Exception as e:
|
||||
logger.error(f"获取向量记忆集合失败: {e}")
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""启动定时清理任务"""
|
||||
|
||||
def cleanup_worker():
|
||||
while self.is_running:
|
||||
try:
|
||||
self._cleanup_expired_messages()
|
||||
time.sleep(self.cleanup_interval)
|
||||
except Exception as e:
|
||||
logger.error(f"清理任务异常: {e}")
|
||||
time.sleep(60) # 异常时等待1分钟再继续
|
||||
|
||||
self.cleanup_task = threading.Thread(target=cleanup_worker, daemon=True)
|
||||
self.cleanup_task.start()
|
||||
logger.info(f"定时清理任务已启动,间隔{self.cleanup_interval}秒")
|
||||
|
||||
def _cleanup_expired_messages(self):
|
||||
"""清理过期的聊天记录"""
|
||||
try:
|
||||
expire_time = time.time() - (self.retention_hours * 3600)
|
||||
|
||||
# 采用 get -> filter -> delete 模式,避免复杂的 where 查询
|
||||
# 1. 获取当前 chat_id 的所有文档
|
||||
results = vector_db_service.get(
|
||||
collection_name=self.collection_name, where={"chat_id": self.chat_id}, include=["metadatas"]
|
||||
)
|
||||
|
||||
if not results or not results.get("ids"):
|
||||
logger.info(f"chat_id '{self.chat_id}' 没有找到任何记录,无需清理")
|
||||
return
|
||||
|
||||
# 2. 在内存中过滤出过期的文档
|
||||
expired_ids = []
|
||||
metadatas = results.get("metadatas", [])
|
||||
ids = results.get("ids", [])
|
||||
|
||||
for i, metadata in enumerate(metadatas):
|
||||
if metadata and metadata.get("timestamp", float("inf")) < expire_time:
|
||||
expired_ids.append(ids[i])
|
||||
|
||||
# 3. 如果有过期文档,根据 ID 进行删除
|
||||
if expired_ids:
|
||||
vector_db_service.delete(collection_name=self.collection_name, ids=expired_ids)
|
||||
logger.info(f"为 chat_id '{self.chat_id}' 清理了 {len(expired_ids)} 条过期记录")
|
||||
else:
|
||||
logger.info(f"chat_id '{self.chat_id}' 没有需要清理的过期记录")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理过期记录失败: {e}")
|
||||
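上面的实现走 get -> 内存过滤 -> delete 三步。如果底层 vector_db_service 透传 ChromaDB 原生的复合 where 条件(此处仅是假设,封装接口未必支持),过期清理也可以一步完成:

# 假设 delete 透传 ChromaDB 的 where 语法($and / $lt 为 ChromaDB 原生操作符)
vector_db_service.delete(
    collection_name=self.collection_name,
    where={"$and": [
        {"chat_id": self.chat_id},
        {"timestamp": {"$lt": expire_time}},
    ]},
)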
|
||||
async def store_message(self, content: str, sender: str = "user") -> bool:
|
||||
"""
|
||||
存储聊天消息到向量库
|
||||
|
||||
Args:
|
||||
content: 消息内容
|
||||
sender: 发送者
|
||||
|
||||
Returns:
|
||||
bool: 是否存储成功
|
||||
"""
|
||||
if not content.strip():
|
||||
return False
|
||||
|
||||
try:
|
||||
# 生成消息向量
|
||||
message_vector = await get_embedding(content)
|
||||
if not message_vector:
|
||||
logger.warning(f"消息向量生成失败: {content[:50]}...")
|
||||
return False
|
||||
|
||||
message_id = f"{self.chat_id}_{int(time.time() * 1000)}_{hash(content) % 10000}"
|
||||
|
||||
message = ChatMessage(
|
||||
message_id=message_id, chat_id=self.chat_id, content=content, timestamp=time.time(), sender=sender
|
||||
)
|
||||
|
||||
# 使用新的服务存储
|
||||
vector_db_service.add(
|
||||
collection_name=self.collection_name,
|
||||
embeddings=[message_vector],
|
||||
documents=[content],
|
||||
metadatas=[
|
||||
{
|
||||
"message_id": message.message_id,
|
||||
"chat_id": message.chat_id,
|
||||
"timestamp": message.timestamp,
|
||||
"sender": message.sender,
|
||||
"message_type": message.message_type,
|
||||
}
|
||||
],
|
||||
ids=[message_id],
|
||||
)
|
||||
|
||||
logger.debug(f"消息已存储: {content[:50]}...")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储消息失败: {e}")
|
||||
return False
|
||||
|
||||
async def find_similar_messages(
|
||||
self, query: str, top_k: int = 5, similarity_threshold: float = 0.7
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
查找与查询相似的历史消息
|
||||
|
||||
Args:
|
||||
query: 查询内容
|
||||
top_k: 返回的最相似消息数量
|
||||
similarity_threshold: 相似度阈值
|
||||
|
||||
Returns:
|
||||
List[Dict]: 相似消息列表,包含content、similarity、timestamp等信息
|
||||
"""
|
||||
if not query.strip():
|
||||
return []
|
||||
|
||||
try:
|
||||
query_vector = await get_embedding(query)
|
||||
if not query_vector:
|
||||
return []
|
||||
|
||||
# 使用新的服务进行查询
|
||||
results = vector_db_service.query(
|
||||
collection_name=self.collection_name,
|
||||
query_embeddings=[query_vector],
|
||||
n_results=top_k,
|
||||
where={"chat_id": self.chat_id},
|
||||
)
|
||||
|
||||
if not results.get("documents") or not results["documents"][0]:
|
||||
return []
|
||||
|
||||
# 处理搜索结果
|
||||
similar_messages = []
|
||||
documents = results["documents"][0]
|
||||
distances = results["distances"][0] if results["distances"] else []
|
||||
metadatas = results["metadatas"][0] if results["metadatas"] else []
|
||||
|
||||
for i, doc in enumerate(documents):
|
||||
# ChromaDB 的 cosine 距离 = 1 - 余弦相似度,因此相似度 = 1 - 距离
|
||||
distance = distances[i] if i < len(distances) else 1.0
|
||||
similarity = 1 - distance
|
||||
|
||||
# 过滤低相似度结果
|
||||
if similarity < similarity_threshold:
|
||||
continue
|
||||
|
||||
# 获取元数据
|
||||
metadata = metadatas[i] if i < len(metadatas) else {}
|
||||
|
||||
# 安全获取timestamp
|
||||
timestamp = metadata.get("timestamp", 0) if isinstance(metadata, dict) else 0
|
||||
timestamp = float(timestamp) if isinstance(timestamp, (int, float)) else 0.0
|
||||
|
||||
similar_messages.append(
|
||||
{
|
||||
"content": doc,
|
||||
"similarity": similarity,
|
||||
"timestamp": timestamp,
|
||||
"sender": metadata.get("sender", "unknown") if isinstance(metadata, dict) else "unknown",
|
||||
"message_id": metadata.get("message_id", "") if isinstance(metadata, dict) else "",
|
||||
"time_ago": self._format_time_ago(timestamp),
|
||||
}
|
||||
)
|
||||
|
||||
# 按相似度排序
|
||||
similar_messages.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
|
||||
logger.debug(f"找到 {len(similar_messages)} 条相似消息 (查询: {query[:30]}...)")
|
||||
return similar_messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查找相似消息失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _format_time_ago(timestamp: float) -> str:
|
||||
"""格式化时间差显示"""
|
||||
if timestamp <= 0:
|
||||
return "未知时间"
|
||||
|
||||
try:
|
||||
now = time.time()
|
||||
diff = now - timestamp
|
||||
|
||||
if diff < 60:
|
||||
return f"{int(diff)}秒前"
|
||||
elif diff < 3600:
|
||||
return f"{int(diff / 60)}分钟前"
|
||||
elif diff < 86400:
|
||||
return f"{int(diff / 3600)}小时前"
|
||||
else:
|
||||
return f"{int(diff / 86400)}天前"
|
||||
except Exception:
|
||||
return "时间格式错误"
|
||||
|
||||
async def get_memory_for_context(self, current_message: str, context_size: int = 3) -> str:
|
||||
"""
|
||||
获取与当前消息相关的记忆上下文
|
||||
|
||||
Args:
|
||||
current_message: 当前消息
|
||||
context_size: 上下文消息数量
|
||||
|
||||
Returns:
|
||||
str: 格式化的记忆上下文
|
||||
"""
|
||||
similar_messages = await self.find_similar_messages(
|
||||
current_message,
|
||||
top_k=context_size,
|
||||
similarity_threshold=0.6, # 降低阈值以获得更多上下文
|
||||
)
|
||||
|
||||
if not similar_messages:
|
||||
return ""
|
||||
|
||||
# 格式化上下文
|
||||
context_lines = []
|
||||
for msg in similar_messages:
|
||||
context_lines.append(
|
||||
f"[{msg['time_ago']}] {msg['sender']}: {msg['content']} (相似度: {msg['similarity']:.2f})"
|
||||
)
|
||||
|
||||
return "相关的历史记忆:\n" + "\n".join(context_lines)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取记忆系统统计信息"""
|
||||
stats = {
|
||||
"chat_id": self.chat_id,
|
||||
"retention_hours": self.retention_hours,
|
||||
"cleanup_interval": self.cleanup_interval,
|
||||
"system_status": "running" if self.is_running else "stopped",
|
||||
"total_messages": 0,
|
||||
"db_status": "connected",
|
||||
}
|
||||
|
||||
try:
|
||||
# 注意:count() 现在没有 chat_id 过滤,返回的是整个集合的数量
|
||||
# 若要精确计数,需要 get(where={"chat_id": ...}) 然后 len(results['ids'])
|
||||
# 这里为了简化,暂时显示集合总数
|
||||
result = vector_db_service.count(collection_name=self.collection_name)
|
||||
stats["total_messages"] = result
|
||||
except Exception:
|
||||
stats["total_messages"] = "查询失败"
|
||||
stats["db_status"] = "disconnected"
|
||||
|
||||
return stats
|
||||
|
||||
def stop(self):
|
||||
"""停止记忆系统"""
|
||||
self.is_running = False
|
||||
if self.cleanup_task and self.cleanup_task.is_alive():
|
||||
logger.info("正在停止定时清理任务...")
|
||||
logger.info(f"向量瞬时记忆系统已停止: {self.chat_id}")
|
||||
|
||||
|
||||
# 为了兼容现有代码,提供工厂函数
|
||||
def create_vector_memory_v2(chat_id: str, retention_hours: int = 24) -> VectorInstantMemoryV2:
|
||||
"""创建向量瞬时记忆系统V2实例"""
|
||||
return VectorInstantMemoryV2(chat_id, retention_hours)
|
||||
|
||||
|
||||
# 使用示例
|
||||
async def demo():
|
||||
"""使用演示"""
|
||||
memory = VectorInstantMemoryV2("demo_chat")
|
||||
|
||||
# 存储一些测试消息
|
||||
await memory.store_message("今天天气不错,出去散步了", "用户")
|
||||
await memory.store_message("刚才买了个冰淇淋,很好吃", "用户")
|
||||
await memory.store_message("明天要开会,有点紧张", "用户")
|
||||
|
||||
# 查找相似消息
|
||||
similar = await memory.find_similar_messages("天气怎么样")
|
||||
print("相似消息:", similar)
|
||||
|
||||
# 获取上下文
|
||||
context = await memory.get_memory_for_context("今天心情如何")
|
||||
print("记忆上下文:", context)
|
||||
|
||||
# 查看统计信息
|
||||
stats = memory.get_stats()
|
||||
print("系统状态:", stats)
|
||||
|
||||
memory.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(demo())
|
||||
723
src/chat/memory_system/vector_storage.py
Normal file
@@ -0,0 +1,723 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
向量数据库存储接口
|
||||
为记忆系统提供高效的向量存储和语义搜索能力
|
||||
"""
|
||||
|
||||
|
||||
import time
|
||||
import orjson
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 尝试导入FAISS,如果不可用则使用简单替代
|
||||
try:
|
||||
import faiss
|
||||
FAISS_AVAILABLE = True
|
||||
except ImportError:
|
||||
FAISS_AVAILABLE = False
|
||||
logger.warning("FAISS not available, using simple vector storage")
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorStorageConfig:
|
||||
"""向量存储配置"""
|
||||
dimension: int = 768
|
||||
similarity_threshold: float = 0.8
|
||||
index_type: str = "flat" # flat, ivf, hnsw
|
||||
max_index_size: int = 100000
|
||||
storage_path: str = "data/memory_vectors"
|
||||
auto_save_interval: int = 100 # 每N次操作自动保存
|
||||
enable_compression: bool = True
|
||||
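一个最小的装配示意(演示代码:字段名取自上面的 dataclass,取值仅为示例):

config = VectorStorageConfig(
    dimension=768,
    index_type="hnsw",                  # 可选 flat / ivf / hnsw
    storage_path="data/memory_vectors",
)
manager = VectorStorageManager(config)  # 构造时即按 index_type 初始化索引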
|
||||
|
||||
class VectorStorageManager:
|
||||
"""向量存储管理器"""
|
||||
|
||||
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
||||
self.config = config or VectorStorageConfig()
|
||||
self.storage_path = Path(self.config.storage_path)
|
||||
self.storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 向量索引
|
||||
self.vector_index = None
|
||||
self.memory_id_to_index = {} # memory_id -> vector index
|
||||
self.index_to_memory_id = {} # vector index -> memory_id
|
||||
|
||||
# 内存缓存
|
||||
self.memory_cache: Dict[str, MemoryChunk] = {}
|
||||
self.vector_cache: Dict[str, List[float]] = {}
|
||||
|
||||
# 统计信息
|
||||
self.storage_stats = {
|
||||
"total_vectors": 0,
|
||||
"index_build_time": 0.0,
|
||||
"average_search_time": 0.0,
|
||||
"cache_hit_rate": 0.0,
|
||||
"total_searches": 0,
|
||||
"cache_hits": 0
|
||||
}
|
||||
|
||||
# 线程锁
|
||||
self._lock = threading.RLock()
|
||||
self._operation_count = 0
|
||||
|
||||
# 初始化索引
|
||||
self._initialize_index()
|
||||
|
||||
# 嵌入模型
|
||||
self.embedding_model: Optional[LLMRequest] = None
|
||||
|
||||
def _initialize_index(self):
|
||||
"""初始化向量索引"""
|
||||
try:
|
||||
if FAISS_AVAILABLE:
|
||||
if self.config.index_type == "flat":
|
||||
self.vector_index = faiss.IndexFlatIP(self.config.dimension)
|
||||
elif self.config.index_type == "ivf":
|
||||
quantizer = faiss.IndexFlatIP(self.config.dimension)
|
||||
nlist = min(100, max(1, self.config.max_index_size // 1000))
|
||||
self.vector_index = faiss.IndexIVFFlat(quantizer, self.config.dimension, nlist)
|
||||
elif self.config.index_type == "hnsw":
|
||||
self.vector_index = faiss.IndexHNSWFlat(self.config.dimension, 32)
|
||||
self.vector_index.hnsw.efConstruction = 40
|
||||
else:
|
||||
self.vector_index = faiss.IndexFlatIP(self.config.dimension)
|
||||
else:
|
||||
# 简单的向量存储实现
|
||||
self.vector_index = SimpleVectorIndex(self.config.dimension)
|
||||
|
||||
logger.info(f"✅ 向量索引初始化完成,类型: {self.config.index_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 向量索引初始化失败: {e}")
|
||||
# 回退到简单实现
|
||||
self.vector_index = SimpleVectorIndex(self.config.dimension)
|
||||
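IndexFlatIP 计算的是内积,只有在向量经过 L2 归一化后内积才等于余弦相似度(下文的 _normalize_vector 正是为此服务)。独立的最小验证示意(假设环境中已安装 faiss-cpu 与 numpy):

import numpy as np
import faiss  # 假设:环境中装有 faiss-cpu

d = 4
xs = np.random.rand(8, d).astype("float32")
xs /= np.linalg.norm(xs, axis=1, keepdims=True)  # L2 归一化

index = faiss.IndexFlatIP(d)
index.add(xs)
scores, ids = index.search(xs[:1], 3)  # 归一化后内积即余弦相似度,自身命中得分约为 1.0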
|
||||
async def initialize_embedding_model(self):
|
||||
"""初始化嵌入模型"""
|
||||
if self.embedding_model is None:
|
||||
self.embedding_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.embedding,
|
||||
request_type="memory.embedding"
|
||||
)
|
||||
logger.info("✅ 嵌入模型初始化完成")
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]):
|
||||
"""存储记忆向量"""
|
||||
if not memories:
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确保嵌入模型已初始化
|
||||
await self.initialize_embedding_model()
|
||||
|
||||
# 批量获取嵌入向量
|
||||
|
||||
memory_texts = []
|
||||
|
||||
for memory in memories:
|
||||
if memory.embedding is None:
|
||||
# 如果没有嵌入向量,需要生成
|
||||
text = self._prepare_embedding_text(memory)
|
||||
memory_texts.append((memory.memory_id, text))
|
||||
else:
|
||||
# 已有嵌入向量,直接使用
|
||||
await self._add_single_memory(memory, memory.embedding)
|
||||
|
||||
# 批量生成缺失的嵌入向量
|
||||
if memory_texts:
|
||||
await self._batch_generate_and_store_embeddings(memory_texts)
|
||||
|
||||
# 自动保存检查
|
||||
self._operation_count += len(memories)
|
||||
if self._operation_count >= self.config.auto_save_interval:
|
||||
await self.save_storage()
|
||||
self._operation_count = 0
|
||||
|
||||
storage_time = time.time() - start_time
|
||||
logger.debug(f"向量存储完成,{len(memories)} 条记忆,耗时 {storage_time:.3f}秒")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 向量存储失败: {e}", exc_info=True)
|
||||
|
||||
def _prepare_embedding_text(self, memory: MemoryChunk) -> str:
|
||||
"""准备用于嵌入的文本"""
|
||||
# 构建包含丰富信息的文本
|
||||
text_parts = [
|
||||
memory.text_content,
|
||||
f"类型: {memory.memory_type.value}",
|
||||
f"关键词: {', '.join(memory.keywords)}",
|
||||
f"标签: {', '.join(memory.tags)}"
|
||||
]
|
||||
|
||||
if memory.metadata.emotional_context:
|
||||
text_parts.append(f"情感: {memory.metadata.emotional_context}")
|
||||
|
||||
return " | ".join(text_parts)
|
||||
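按上面的拼接规则,一条记忆最终送入嵌入模型的文本大致形如(字段值为虚构示例):

# "用户喜欢在清晨跑步 | 类型: preference | 关键词: 跑步, 清晨 | 标签: 运动, 习惯 | 情感: 积极"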
|
||||
async def _batch_generate_and_store_embeddings(self, memory_texts: List[Tuple[str, str]]):
|
||||
"""批量生成和存储嵌入向量"""
|
||||
if not memory_texts:
|
||||
return
|
||||
|
||||
try:
|
||||
texts = [text for _, text in memory_texts]
|
||||
memory_ids = [memory_id for memory_id, _ in memory_texts]
|
||||
|
||||
# 批量生成嵌入向量
|
||||
embeddings = await self._batch_generate_embeddings(texts)
|
||||
|
||||
# 存储向量和记忆
|
||||
for memory_id, embedding in zip(memory_ids, embeddings):
|
||||
if embedding and len(embedding) == self.config.dimension:
|
||||
memory = self.memory_cache.get(memory_id)
|
||||
if memory:
|
||||
await self._add_single_memory(memory, embedding)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
|
||||
|
||||
async def _batch_generate_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||
"""批量生成嵌入向量"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
try:
|
||||
# 获取当前正在运行的事件循环(在协程内新建事件循环并等待其任务会导致挂起)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
|
||||
|
||||
# 使用线程池并行生成嵌入向量
|
||||
with ThreadPoolExecutor(max_workers=min(4, len(texts))) as executor:
|
||||
tasks = []
|
||||
for text in texts:
|
||||
task = loop.run_in_executor(
|
||||
executor,
|
||||
self._generate_single_embedding,
|
||||
text
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
embeddings = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理结果
|
||||
valid_embeddings = []
|
||||
for i, embedding in enumerate(embeddings):
|
||||
if isinstance(embedding, Exception):
|
||||
logger.warning(f"生成第 {i} 个文本的嵌入向量失败: {embedding}")
|
||||
valid_embeddings.append([])
|
||||
elif embedding and len(embedding) == self.config.dimension:
|
||||
valid_embeddings.append(embedding)
|
||||
else:
|
||||
logger.warning(f"第 {i} 个文本的嵌入向量格式异常")
|
||||
valid_embeddings.append([])
|
||||
|
||||
return valid_embeddings
|
||||
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
|
||||
return [[] for _ in texts]
|
||||
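既然 get_embedding 本身是协程(见下方 _generate_single_embedding 中的用法),更直接的批量写法是省掉线程池、在当前事件循环里 gather 并发。下面是一个保持输入输出一一对应的示意实现(非源码):

async def batch_embed(model, texts):
    """直接并发 await 协程;失败项以空列表占位,保持与输入对齐"""
    results = await asyncio.gather(
        *(model.get_embedding(t) for t in texts),
        return_exceptions=True,
    )
    out = []
    for r in results:
        if isinstance(r, Exception):
            out.append([])
        else:
            embedding, _ = r  # get_embedding 返回 (embedding, usage) 元组
            out.append(embedding)
    return out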
|
||||
def _generate_single_embedding(self, text: str) -> List[float]:
|
||||
"""生成单个文本的嵌入向量"""
|
||||
try:
|
||||
# 在线程池工作线程中创建独立的事件循环(每个工作线程各自持有)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 使用模型生成嵌入向量
|
||||
embedding, _ = loop.run_until_complete(
|
||||
self.embedding_model.get_embedding(text)
|
||||
)
|
||||
|
||||
if embedding and len(embedding) == self.config.dimension:
|
||||
return embedding
|
||||
else:
|
||||
logger.warning(f"嵌入向量维度不匹配: 期望 {self.config.dimension}, 实际 {len(embedding) if embedding else 0}")
|
||||
return []
|
||||
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成嵌入向量失败: {e}")
|
||||
return []
|
||||
|
||||
async def _add_single_memory(self, memory: MemoryChunk, embedding: List[float]):
|
||||
"""添加单个记忆到向量存储"""
|
||||
with self._lock:
|
||||
try:
|
||||
# 规范化向量
|
||||
if embedding:
|
||||
embedding = self._normalize_vector(embedding)
|
||||
|
||||
# 添加到缓存
|
||||
self.memory_cache[memory.memory_id] = memory
|
||||
self.vector_cache[memory.memory_id] = embedding
|
||||
|
||||
# 更新记忆的嵌入向量
|
||||
memory.set_embedding(embedding)
|
||||
|
||||
# 添加到向量索引
|
||||
if hasattr(self.vector_index, 'add'):
|
||||
# FAISS索引
|
||||
if isinstance(embedding, np.ndarray):
|
||||
vector_array = embedding.reshape(1, -1).astype('float32')
|
||||
else:
|
||||
vector_array = np.array([embedding], dtype='float32')
|
||||
|
||||
# 特殊处理IVF索引
|
||||
if self.config.index_type == "ivf" and self.vector_index.ntotal == 0:
|
||||
# IVF索引需要先训练;仅用单条向量训练效果很差,实际使用应先积累足够样本再训练
|
||||
logger.debug("训练IVF索引...")
|
||||
self.vector_index.train(vector_array)
|
||||
|
||||
self.vector_index.add(vector_array)
|
||||
index_id = self.vector_index.ntotal - 1
|
||||
|
||||
else:
|
||||
# 简单索引
|
||||
index_id = self.vector_index.add_vector(embedding)
|
||||
|
||||
# 更新映射关系
|
||||
self.memory_id_to_index[memory.memory_id] = index_id
|
||||
self.index_to_memory_id[index_id] = memory.memory_id
|
||||
|
||||
# 更新统计
|
||||
self.storage_stats["total_vectors"] += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 添加记忆到向量存储失败: {e}")
|
||||
|
||||
def _normalize_vector(self, vector: List[float]) -> List[float]:
|
||||
"""L2归一化向量"""
|
||||
if not vector:
|
||||
return vector
|
||||
|
||||
try:
|
||||
vector_array = np.array(vector, dtype=np.float32)
|
||||
norm = np.linalg.norm(vector_array)
|
||||
if norm == 0:
|
||||
return vector
|
||||
|
||||
normalized = vector_array / norm
|
||||
return normalized.tolist()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"向量归一化失败: {e}")
|
||||
return vector
|
||||
|
||||
async def search_similar_memories(
|
||||
self,
|
||||
query_vector: List[float],
|
||||
limit: int = 10,
|
||||
user_id: Optional[str] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
"""搜索相似记忆"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 规范化查询向量
|
||||
query_vector = self._normalize_vector(query_vector)
|
||||
|
||||
# 执行向量搜索
|
||||
with self._lock:
|
||||
if hasattr(self.vector_index, 'search'):
|
||||
# FAISS索引
|
||||
if isinstance(query_vector, np.ndarray):
|
||||
query_array = query_vector.reshape(1, -1).astype('float32')
|
||||
else:
|
||||
query_array = np.array([query_vector], dtype='float32')
|
||||
|
||||
if self.config.index_type == "ivf" and self.vector_index.ntotal > 0:
|
||||
# 设置IVF搜索参数
|
||||
nprobe = min(self.vector_index.nlist, 10)
|
||||
self.vector_index.nprobe = nprobe
|
||||
|
||||
distances, indices = self.vector_index.search(query_array, min(limit, self.storage_stats["total_vectors"]))
|
||||
distances = distances.flatten().tolist()
|
||||
indices = indices.flatten().tolist()
|
||||
else:
|
||||
# 简单索引
|
||||
results = self.vector_index.search(query_vector, limit)
|
||||
distances = [score for _, score in results]
|
||||
indices = [idx for idx, _ in results]
|
||||
|
||||
# 处理搜索结果
|
||||
results = []
|
||||
for distance, index in zip(distances, indices):
|
||||
if index == -1: # FAISS的无效索引标记
|
||||
continue
|
||||
|
||||
memory_id = self.index_to_memory_id.get(index)
|
||||
if memory_id:
|
||||
# 应用用户过滤
|
||||
if user_id:
|
||||
memory = self.memory_cache.get(memory_id)
|
||||
if memory and memory.user_id != user_id:
|
||||
continue
|
||||
|
||||
similarity = max(0.0, min(1.0, distance))  # IndexFlatIP 返回的是内积(归一化后即余弦相似度),此处裁剪到 0-1
|
||||
results.append((memory_id, similarity))
|
||||
|
||||
# 更新统计
|
||||
search_time = time.time() - start_time
|
||||
self.storage_stats["total_searches"] += 1
|
||||
self.storage_stats["average_search_time"] = (
|
||||
(self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time) /
|
||||
self.storage_stats["total_searches"]
|
||||
)
|
||||
|
||||
return results[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 向量搜索失败: {e}")
|
||||
return []
|
||||
|
||||
async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
|
||||
"""根据ID获取记忆"""
|
||||
# 先检查缓存
|
||||
self.storage_stats["total_searches"] += 1
if memory_id in self.memory_cache:
|
||||
self.storage_stats["cache_hits"] += 1
|
||||
return self.memory_cache[memory_id]
|
||||
|
||||
self.storage_stats["total_searches"] += 1
|
||||
return None
|
||||
|
||||
async def update_memory_embedding(self, memory_id: str, new_embedding: List[float]):
|
||||
"""更新记忆的嵌入向量"""
|
||||
with self._lock:
|
||||
try:
|
||||
if memory_id not in self.memory_id_to_index:
|
||||
logger.warning(f"记忆 {memory_id} 不存在于向量索引中")
|
||||
return
|
||||
|
||||
# 获取旧索引
|
||||
old_index = self.memory_id_to_index[memory_id]
|
||||
|
||||
# 删除旧向量(如果支持)
|
||||
if hasattr(self.vector_index, 'remove_ids'):
|
||||
try:
|
||||
self.vector_index.remove_ids(np.array([old_index]))
|
||||
except Exception:
|
||||
logger.warning("无法删除旧向量,将直接添加新向量")
|
||||
|
||||
# 规范化新向量
|
||||
new_embedding = self._normalize_vector(new_embedding)
|
||||
|
||||
# 添加新向量
|
||||
if hasattr(self.vector_index, 'add'):
|
||||
if isinstance(new_embedding, np.ndarray):
|
||||
vector_array = new_embedding.reshape(1, -1).astype('float32')
|
||||
else:
|
||||
vector_array = np.array([new_embedding], dtype='float32')
|
||||
|
||||
self.vector_index.add(vector_array)
|
||||
new_index = self.vector_index.ntotal - 1
|
||||
else:
|
||||
new_index = self.vector_index.add_vector(new_embedding)
|
||||
|
||||
# 更新映射关系
|
||||
self.memory_id_to_index[memory_id] = new_index
|
||||
self.index_to_memory_id[new_index] = memory_id
|
||||
|
||||
# 更新缓存
|
||||
self.vector_cache[memory_id] = new_embedding
|
||||
|
||||
# 更新记忆对象
|
||||
memory = self.memory_cache.get(memory_id)
|
||||
if memory:
|
||||
memory.set_embedding(new_embedding)
|
||||
|
||||
logger.debug(f"更新记忆 {memory_id} 的嵌入向量")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 更新记忆嵌入向量失败: {e}")
|
||||
|
||||
async def delete_memory(self, memory_id: str):
|
||||
"""删除记忆"""
|
||||
with self._lock:
|
||||
try:
|
||||
if memory_id not in self.memory_id_to_index:
|
||||
return
|
||||
|
||||
# 获取索引
|
||||
index = self.memory_id_to_index[memory_id]
|
||||
|
||||
# 从向量索引中删除(如果支持)
|
||||
if hasattr(self.vector_index, 'remove_ids'):
|
||||
try:
|
||||
self.vector_index.remove_ids(np.array([index]))
|
||||
except Exception:
|
||||
logger.warning("无法从向量索引中删除,仅从缓存中移除")
|
||||
|
||||
# 删除映射关系
|
||||
del self.memory_id_to_index[memory_id]
|
||||
if index in self.index_to_memory_id:
|
||||
del self.index_to_memory_id[index]
|
||||
|
||||
# 从缓存中删除
|
||||
self.memory_cache.pop(memory_id, None)
|
||||
self.vector_cache.pop(memory_id, None)
|
||||
|
||||
# 更新统计
|
||||
self.storage_stats["total_vectors"] = max(0, self.storage_stats["total_vectors"] - 1)
|
||||
|
||||
logger.debug(f"删除记忆 {memory_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 删除记忆失败: {e}")
|
||||
|
||||
async def save_storage(self):
|
||||
"""保存向量存储到文件"""
|
||||
try:
|
||||
logger.info("正在保存向量存储...")
|
||||
|
||||
# 保存记忆缓存
|
||||
cache_data = {
|
||||
memory_id: memory.to_dict()
|
||||
for memory_id, memory in self.memory_cache.items()
|
||||
}
|
||||
|
||||
cache_file = self.storage_path / "memory_cache.json"
|
||||
with open(cache_file, 'w', encoding='utf-8') as f:
|
||||
f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
|
||||
|
||||
# 保存向量缓存
|
||||
vector_cache_file = self.storage_path / "vector_cache.json"
|
||||
with open(vector_cache_file, 'w', encoding='utf-8') as f:
|
||||
f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode('utf-8'))
|
||||
|
||||
# 保存映射关系
|
||||
mapping_file = self.storage_path / "id_mapping.json"
|
||||
mapping_data = {
|
||||
"memory_id_to_index": self.memory_id_to_index,
|
||||
"index_to_memory_id": self.index_to_memory_id
|
||||
}
|
||||
with open(mapping_file, 'w', encoding='utf-8') as f:
|
||||
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8'))  # index_to_memory_id 的键是整数
|
||||
|
||||
# 保存FAISS索引(如果可用)
|
||||
if FAISS_AVAILABLE and isinstance(self.vector_index, faiss.Index):  # FAISS 索引对象没有 save 方法,应按类型判断
|
||||
index_file = self.storage_path / "vector_index.faiss"
|
||||
faiss.write_index(self.vector_index, str(index_file))
|
||||
|
||||
# 保存统计信息
|
||||
stats_file = self.storage_path / "storage_stats.json"
|
||||
with open(stats_file, 'w', encoding='utf-8') as f:
|
||||
f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode('utf-8'))
|
||||
|
||||
logger.info("✅ 向量存储保存完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 保存向量存储失败: {e}")
|
||||
|
||||
async def load_storage(self):
|
||||
"""从文件加载向量存储"""
|
||||
try:
|
||||
logger.info("正在加载向量存储...")
|
||||
|
||||
# 加载记忆缓存
|
||||
cache_file = self.storage_path / "memory_cache.json"
|
||||
if cache_file.exists():
|
||||
with open(cache_file, 'r', encoding='utf-8') as f:
|
||||
cache_data = orjson.loads(f.read())
|
||||
|
||||
self.memory_cache = {
|
||||
memory_id: MemoryChunk.from_dict(memory_data)
|
||||
for memory_id, memory_data in cache_data.items()
|
||||
}
|
||||
|
||||
# 加载向量缓存
|
||||
vector_cache_file = self.storage_path / "vector_cache.json"
|
||||
if vector_cache_file.exists():
|
||||
with open(vector_cache_file, 'r', encoding='utf-8') as f:
|
||||
self.vector_cache = orjson.loads(f.read())
|
||||
|
||||
# 加载映射关系
|
||||
mapping_file = self.storage_path / "id_mapping.json"
|
||||
if mapping_file.exists():
|
||||
with open(mapping_file, 'r', encoding='utf-8') as f:
|
||||
mapping_data = orjson.loads(f.read())
|
||||
self.memory_id_to_index = mapping_data.get("memory_id_to_index", {})
|
||||
self.index_to_memory_id = {int(k): v for k, v in mapping_data.get("index_to_memory_id", {}).items()}  # JSON 键为字符串,需恢复为整数索引
|
||||
|
||||
# 加载FAISS索引(如果可用)
|
||||
if FAISS_AVAILABLE:
|
||||
index_file = self.storage_path / "vector_index.faiss"
|
||||
if index_file.exists() and isinstance(self.vector_index, faiss.Index):
|
||||
try:
|
||||
loaded_index = faiss.read_index(str(index_file))
|
||||
# 如果索引类型匹配,则替换
|
||||
if type(loaded_index) is type(self.vector_index):
|
||||
self.vector_index = loaded_index
|
||||
logger.info("✅ FAISS索引加载完成")
|
||||
else:
|
||||
logger.warning("索引类型不匹配,重新构建索引")
|
||||
await self._rebuild_index()
|
||||
except Exception as e:
|
||||
logger.warning(f"加载FAISS索引失败: {e},重新构建")
|
||||
await self._rebuild_index()
|
||||
|
||||
# 加载统计信息
|
||||
stats_file = self.storage_path / "storage_stats.json"
|
||||
if stats_file.exists():
|
||||
with open(stats_file, 'r', encoding='utf-8') as f:
|
||||
self.storage_stats = orjson.loads(f.read())
|
||||
|
||||
# 更新向量计数
|
||||
self.storage_stats["total_vectors"] = len(self.memory_id_to_index)
|
||||
|
||||
logger.info(f"✅ 向量存储加载完成,{self.storage_stats['total_vectors']} 个向量")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 加载向量存储失败: {e}")
|
||||
|
||||
async def _rebuild_index(self):
|
||||
"""重建向量索引"""
|
||||
try:
|
||||
logger.info("正在重建向量索引...")
|
||||
|
||||
# 重新初始化索引
|
||||
self._initialize_index()
|
||||
|
||||
# 重新添加所有向量
|
||||
for memory_id, embedding in self.vector_cache.items():
|
||||
if embedding:
|
||||
memory = self.memory_cache.get(memory_id)
|
||||
if memory:
|
||||
await self._add_single_memory(memory, embedding)
|
||||
|
||||
logger.info("✅ 向量索引重建完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重建向量索引失败: {e}")
|
||||
|
||||
async def optimize_storage(self):
|
||||
"""优化存储"""
|
||||
try:
|
||||
logger.info("开始向量存储优化...")
|
||||
|
||||
# 清理无效引用
|
||||
self._cleanup_invalid_references()
|
||||
|
||||
# 规模较大时顺带重建索引(此处并未真正度量碎片化程度,仅以向量数为触发条件)
|
||||
if self.storage_stats["total_vectors"] > 1000:
|
||||
await self._rebuild_index()
|
||||
|
||||
# 更新缓存命中率
|
||||
if self.storage_stats["total_searches"] > 0:
|
||||
self.storage_stats["cache_hit_rate"] = (
|
||||
self.storage_stats["cache_hits"] / self.storage_stats["total_searches"]
|
||||
)
|
||||
|
||||
logger.info("✅ 向量存储优化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 向量存储优化失败: {e}")
|
||||
|
||||
def _cleanup_invalid_references(self):
|
||||
"""清理无效引用"""
|
||||
with self._lock:
|
||||
# 清理无效的memory_id到index的映射
|
||||
valid_memory_ids = set(self.memory_cache.keys())
|
||||
invalid_memory_ids = set(self.memory_id_to_index.keys()) - valid_memory_ids
|
||||
|
||||
for memory_id in invalid_memory_ids:
|
||||
index = self.memory_id_to_index[memory_id]
|
||||
del self.memory_id_to_index[memory_id]
|
||||
if index in self.index_to_memory_id:
|
||||
del self.index_to_memory_id[index]
|
||||
|
||||
if invalid_memory_ids:
|
||||
logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用")
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
"""获取存储统计信息"""
|
||||
stats = self.storage_stats.copy()
|
||||
if stats["total_searches"] > 0:
|
||||
stats["cache_hit_rate"] = stats["cache_hits"] / stats["total_searches"]
|
||||
else:
|
||||
stats["cache_hit_rate"] = 0.0
|
||||
return stats
|
||||
|
||||
|
||||
class SimpleVectorIndex:
|
||||
"""简单的向量索引实现(当FAISS不可用时的替代方案)"""
|
||||
|
||||
def __init__(self, dimension: int):
|
||||
self.dimension = dimension
|
||||
self.vectors: List[List[float]] = []
|
||||
self.vector_ids: List[int] = []
|
||||
self.next_id = 0
|
||||
|
||||
def add_vector(self, vector: List[float]) -> int:
|
||||
"""添加向量"""
|
||||
if len(vector) != self.dimension:
|
||||
raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}")
|
||||
|
||||
vector_id = self.next_id
|
||||
self.vectors.append(vector.copy())
|
||||
self.vector_ids.append(vector_id)
|
||||
self.next_id += 1
|
||||
|
||||
return vector_id
|
||||
|
||||
def search(self, query_vector: List[float], limit: int) -> List[Tuple[int, float]]:
|
||||
"""搜索相似向量"""
|
||||
if len(query_vector) != self.dimension:
|
||||
raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}")
|
||||
|
||||
results = []
|
||||
|
||||
for i, vector in enumerate(self.vectors):
|
||||
similarity = self._calculate_cosine_similarity(query_vector, vector)
|
||||
results.append((self.vector_ids[i], similarity))
|
||||
|
||||
# 按相似度排序
|
||||
results.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
return results[:limit]
|
||||
|
||||
def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float:
|
||||
"""计算余弦相似度"""
|
||||
try:
|
||||
dot_product = sum(x * y for x, y in zip(v1, v2))
|
||||
norm1 = sum(x * x for x in v1) ** 0.5
|
||||
norm2 = sum(x * x for x in v2) ** 0.5
|
||||
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm1 * norm2)
|
||||
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
@property
|
||||
def ntotal(self) -> int:
|
||||
"""向量总数"""
|
||||
return len(self.vectors)
|
||||
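SimpleVectorIndex.search 是纯 Python 的 O(N·d) 线性扫描;若直接用文件顶部已导入的 numpy,同样的余弦检索可以向量化(示意实现,非源码):

import numpy as np

def search_vectorized(vectors: np.ndarray, query: np.ndarray, limit: int):
    """与 SimpleVectorIndex.search 等价的向量化示意;vectors 形状为 (N, d)"""
    norms = np.linalg.norm(vectors, axis=1) * np.linalg.norm(query)
    sims = (vectors @ query) / np.where(norms == 0, 1.0, norms)
    top = np.argsort(-sims)[:limit]
    return [(int(i), float(sims[i])) for i in top]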
@@ -28,8 +28,8 @@ from src.chat.utils.chat_message_builder import (
|
||||
replace_user_references_sync,
|
||||
)
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2
|
||||
# 旧记忆系统已被移除
|
||||
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
@@ -231,9 +231,12 @@ class DefaultReplyer:
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
|
||||
|
||||
self.heart_fc_sender = HeartFCSender()
|
||||
self.memory_activator = MemoryActivator()
|
||||
# 使用纯向量瞬时记忆系统V2,支持自定义保留时间
|
||||
self.instant_memory = VectorInstantMemoryV2(chat_id=self.chat_stream.stream_id, retention_hours=1)
|
||||
# 使用新的增强记忆系统
|
||||
# from src.chat.memory_system.enhanced_memory_activator import EnhancedMemoryActivator
|
||||
# self.memory_activator = EnhancedMemoryActivator()
|
||||
self.memory_activator = None # 暂时禁用记忆激活器
|
||||
# 旧的即时记忆系统已被移除,现在使用增强记忆系统
|
||||
# self.instant_memory = VectorInstantMemoryV2(chat_id=self.chat_stream.stream_id, retention_hours=1)
|
||||
|
||||
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
|
||||
|
||||
@@ -459,90 +462,65 @@ class DefaultReplyer:
|
||||
|
||||
instant_memory = None
|
||||
|
||||
running_memories = await self.memory_activator.activate_memory_with_chat_history(
|
||||
target_message=target, chat_history_prompt=chat_history
|
||||
)
|
||||
# 使用新的增强记忆系统检索记忆
|
||||
running_memories = []
|
||||
instant_memory = None
|
||||
|
||||
if global_config.memory.enable_instant_memory:
|
||||
# 使用异步记忆包装器(最优化的非阻塞模式)
|
||||
try:
|
||||
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
|
||||
# 使用新的增强记忆系统
|
||||
from src.chat.memory_system.enhanced_memory_integration import recall_memories, remember_message
|
||||
|
||||
# 获取异步记忆包装器
|
||||
async_memory = get_async_instant_memory(self.chat_stream.stream_id)
|
||||
|
||||
# 后台存储聊天历史(完全非阻塞)
|
||||
async_memory.store_memory_background(chat_history)
|
||||
|
||||
# 快速检索记忆,最大超时2秒
|
||||
instant_memory = await async_memory.get_memory_with_fallback(target, max_timeout=2.0)
|
||||
|
||||
logger.info(f"异步瞬时记忆:{instant_memory}")
|
||||
|
||||
except ImportError:
|
||||
# 如果异步包装器不可用,尝试使用异步记忆管理器
|
||||
try:
|
||||
from src.chat.memory_system.async_memory_optimizer import (
|
||||
retrieve_memory_nonblocking,
|
||||
store_memory_nonblocking,
|
||||
# 异步存储聊天历史(非阻塞)
|
||||
asyncio.create_task(
|
||||
remember_message(
|
||||
message=chat_history,
|
||||
user_id=str(self.chat_stream.stream_id),
|
||||
chat_id=self.chat_stream.stream_id
|
||||
)
|
||||
)
|
||||
|
||||
# 异步存储聊天历史(非阻塞)
|
||||
asyncio.create_task(
|
||||
store_memory_nonblocking(chat_id=self.chat_stream.stream_id, content=chat_history)
|
||||
)
|
||||
# 检索相关记忆
|
||||
enhanced_memories = await recall_memories(
|
||||
query=target,
|
||||
user_id=str(self.chat_stream.stream_id),
|
||||
chat_id=self.chat_stream.stream_id
|
||||
)
|
||||
|
||||
# 尝试从缓存获取瞬时记忆
|
||||
instant_memory = await retrieve_memory_nonblocking(chat_id=self.chat_stream.stream_id, query=target)
|
||||
# 转换格式以兼容现有代码
|
||||
running_memories = []
|
||||
if enhanced_memories and enhanced_memories.get("has_memories"):
|
||||
for memory in enhanced_memories.get("memories", []):
|
||||
running_memories.append({
|
||||
'content': memory.get("content", ""),
|
||||
'score': memory.get("confidence", 0.0),
|
||||
'memory_type': memory.get("type", "unknown")
|
||||
})
|
||||
|
||||
# 如果没有缓存结果,快速检索一次
|
||||
if instant_memory is None:
|
||||
try:
|
||||
instant_memory = await asyncio.wait_for(
|
||||
self.instant_memory.get_memory_for_context(target), timeout=1.5
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("瞬时记忆检索超时,使用空结果")
|
||||
instant_memory = ""
|
||||
# 构建瞬时记忆字符串
|
||||
if enhanced_memories and enhanced_memories.get("has_memories"):
|
||||
instant_memory = "\\n".join([
|
||||
f"{memory.get('content', '')} (相似度: {memory.get('confidence', 0.0):.2f})"
|
||||
for memory in enhanced_memories.get("memories", [])[:3] # 取前3条
|
||||
])
|
||||
|
||||
logger.info(f"向量瞬时记忆:{instant_memory}")
|
||||
|
||||
except ImportError:
|
||||
# 最后的fallback:使用原有逻辑但加上超时控制
|
||||
logger.warning("异步记忆系统不可用,使用带超时的同步方式")
|
||||
|
||||
# 异步存储聊天历史
|
||||
asyncio.create_task(self.instant_memory.store_message(chat_history))
|
||||
|
||||
# 带超时的记忆检索
|
||||
try:
|
||||
instant_memory = await asyncio.wait_for(
|
||||
self.instant_memory.get_memory_for_context(target),
|
||||
timeout=1.0, # 最保守的1秒超时
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("瞬时记忆检索超时,跳过记忆获取")
|
||||
instant_memory = ""
|
||||
except Exception as e:
|
||||
logger.error(f"瞬时记忆检索失败: {e}")
|
||||
instant_memory = ""
|
||||
|
||||
logger.info(f"同步瞬时记忆:{instant_memory}")
|
||||
logger.info(f"增强记忆系统检索到 {len(running_memories)} 条记忆")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"瞬时记忆系统异常: {e}")
|
||||
logger.warning(f"增强记忆系统检索失败: {e}")
|
||||
running_memories = []
|
||||
instant_memory = ""
|
||||
|
||||
# 构建记忆字符串,即使某种记忆为空也要继续
|
||||
memory_str = ""
|
||||
has_any_memory = False
|
||||
|
||||
# 添加长期记忆
|
||||
# 添加长期记忆(来自增强记忆系统)
|
||||
if running_memories:
|
||||
if not memory_str:
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
for running_memory in running_memories:
|
||||
memory_str += f"- {running_memory['content']}\n"
|
||||
memory_str += f"- {running_memory['content']} (类型: {running_memory['memory_type']}, 相似度: {running_memory['score']:.2f})\n"
|
||||
has_any_memory = True
|
||||
|
||||
# 添加瞬时记忆
|
||||
|
||||
@@ -371,28 +371,35 @@ class Prompt:
|
||||
tasks.append(self._build_cross_context())
|
||||
task_names.append("cross_context")
|
||||
|
||||
# 性能优化
|
||||
base_timeout = 10.0
|
||||
task_timeout = 2.0
|
||||
timeout_seconds = min(
|
||||
max(base_timeout, len(tasks) * task_timeout),
|
||||
30.0,
|
||||
)
|
||||
# 性能优化 - 为不同任务设置不同的超时时间
|
||||
task_timeouts = {
|
||||
"memory_block": 5.0, # 记忆系统可能较慢,单独设置超时
|
||||
"tool_info": 3.0, # 工具信息中等速度
|
||||
"relation_info": 2.0, # 关系信息通常较快
|
||||
"knowledge_info": 3.0, # 知识库查询中等速度
|
||||
"cross_context": 2.0, # 上下文处理通常较快
|
||||
"expression_habits": 1.5, # 表达习惯处理很快
|
||||
}
|
||||
|
||||
max_concurrent_tasks = 5
|
||||
if len(tasks) > max_concurrent_tasks:
|
||||
results = []
|
||||
for i in range(0, len(tasks), max_concurrent_tasks):
|
||||
batch_tasks = tasks[i : i + max_concurrent_tasks]
|
||||
# 分别处理每个任务,避免慢任务影响快任务
|
||||
results = []
|
||||
for i, task in enumerate(tasks):
|
||||
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
|
||||
task_timeout = task_timeouts.get(task_name, 2.0) # 默认2秒
|
||||
|
||||
batch_results = await asyncio.wait_for(
|
||||
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
|
||||
)
|
||||
results.extend(batch_results)
|
||||
else:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
|
||||
)
|
||||
try:
|
||||
result = await asyncio.wait_for(task, timeout=task_timeout)
|
||||
results.append(result)
|
||||
logger.debug(f"构建任务{task_name}完成 ({task_timeout}s)")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"构建任务{task_name}超时 ({task_timeout}s),使用默认值")
|
||||
# 为超时任务提供默认值
|
||||
default_result = self._get_default_result_for_task(task_name)
|
||||
results.append(default_result)
|
||||
except Exception as e:
|
||||
logger.error(f"构建任务{task_name}失败: {str(e)}")
|
||||
default_result = self._get_default_result_for_task(task_name)
|
||||
results.append(default_result)
|
||||
|
||||
# 处理结果
|
||||
context_data = {}
|
||||
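这次改动把"整批共享一个超时"换成"每个任务各自 wait_for、超时用默认值兜底"。这个模式可以抽成一个小工具函数(示意代码,非本次提交内容):

async def run_with_fallback(coro, timeout: float, fallback):
    """单任务超时与异常兜底,避免慢任务拖垮整批构建"""
    try:
        return await asyncio.wait_for(coro, timeout=timeout)
    except asyncio.TimeoutError:
        return fallback
    except Exception:
        return fallback

配合下文 _get_default_result_for_task 返回的默认值使用即可。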
@@ -528,8 +535,7 @@ class Prompt:
|
||||
return {"memory_block": ""}
|
||||
|
||||
try:
|
||||
from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
|
||||
from src.chat.memory_system.enhanced_memory_activator import enhanced_memory_activator
|
||||
|
||||
# 获取聊天历史
|
||||
chat_history = ""
|
||||
@@ -539,15 +545,38 @@ class Prompt:
|
||||
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||
)
|
||||
|
||||
# 激活长期记忆
|
||||
memory_activator = MemoryActivator()
|
||||
running_memories = await memory_activator.activate_memory_with_chat_history(
|
||||
target_message=self.parameters.target, chat_history_prompt=chat_history
|
||||
)
|
||||
# 并行执行记忆查询以提高性能
|
||||
import asyncio
|
||||
|
||||
# 获取即时记忆
|
||||
async_memory_wrapper = get_async_instant_memory(self.parameters.chat_id)
|
||||
instant_memory = await async_memory_wrapper.get_memory_with_fallback(self.parameters.target)
|
||||
# 创建记忆查询任务
|
||||
memory_tasks = [
|
||||
enhanced_memory_activator.activate_memory_with_chat_history(
|
||||
target_message=self.parameters.target, chat_history_prompt=chat_history
|
||||
),
|
||||
enhanced_memory_activator.get_instant_memory(
|
||||
target_message=self.parameters.target, chat_id=self.parameters.chat_id
|
||||
)
|
||||
]
|
||||
|
||||
# 等待所有记忆查询完成(最多3秒)
|
||||
try:
|
||||
running_memories, instant_memory = await asyncio.wait_for(
|
||||
asyncio.gather(*memory_tasks, return_exceptions=True),
|
||||
timeout=3.0
|
||||
)
|
||||
|
||||
# 处理可能的异常结果
|
||||
if isinstance(running_memories, Exception):
|
||||
logger.warning(f"长期记忆查询失败: {running_memories}")
|
||||
running_memories = []
|
||||
if isinstance(instant_memory, Exception):
|
||||
logger.warning(f"即时记忆查询失败: {instant_memory}")
|
||||
instant_memory = None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("记忆查询超时,使用部分结果")
|
||||
running_memories = []
|
||||
instant_memory = None
|
||||
|
||||
# 构建记忆块
|
||||
memory_parts = []
|
||||
@@ -870,6 +899,32 @@ class Prompt:
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
def _get_default_result_for_task(self, task_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
为超时的任务提供默认结果
|
||||
|
||||
Args:
|
||||
task_name: 任务名称
|
||||
|
||||
Returns:
|
||||
Dict: 默认结果
|
||||
"""
|
||||
defaults = {
|
||||
"memory_block": {"memory_block": ""},
|
||||
"tool_info": {"tool_info_block": ""},
|
||||
"relation_info": {"relation_info_block": ""},
|
||||
"knowledge_info": {"knowledge_prompt": ""},
|
||||
"cross_context": {"cross_context_block": ""},
|
||||
"expression_habits": {"expression_habits_block": ""},
|
||||
}
|
||||
|
||||
if task_name in defaults:
|
||||
logger.info(f"为超时任务 {task_name} 提供默认值")
|
||||
return defaults[task_name]
|
||||
else:
|
||||
logger.warning(f"未知任务类型 {task_name},返回空结果")
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
|
||||
@@ -459,6 +459,40 @@ class MemoryConfig(ValidatedConfigBase):
|
||||
enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆")
|
||||
enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆")
|
||||
|
||||
# 增强记忆系统配置
|
||||
enable_enhanced_memory: bool = Field(default=True, description="启用增强记忆系统")
|
||||
enhanced_memory_auto_save: bool = Field(default=True, description="自动保存增强记忆")
|
||||
|
||||
# 记忆构建配置
|
||||
min_memory_length: int = Field(default=10, description="最小记忆长度")
|
||||
max_memory_length: int = Field(default=500, description="最大记忆长度")
|
||||
memory_value_threshold: float = Field(default=0.7, description="记忆价值阈值")
|
||||
|
||||
# 向量存储配置
|
||||
vector_dimension: int = Field(default=768, description="向量维度")
|
||||
vector_similarity_threshold: float = Field(default=0.8, description="向量相似度阈值")
|
||||
|
||||
# 多阶段检索配置
|
||||
metadata_filter_limit: int = Field(default=100, description="元数据过滤阶段返回数量")
|
||||
vector_search_limit: int = Field(default=50, description="向量搜索阶段返回数量")
|
||||
semantic_rerank_limit: int = Field(default=20, description="语义重排序阶段返回数量")
|
||||
final_result_limit: int = Field(default=10, description="最终结果数量")
|
||||
|
||||
# 检索权重配置
|
||||
vector_weight: float = Field(default=0.4, description="向量相似度权重")
|
||||
semantic_weight: float = Field(default=0.3, description="语义相似度权重")
|
||||
context_weight: float = Field(default=0.2, description="上下文权重")
|
||||
recency_weight: float = Field(default=0.1, description="时效性权重")
|
||||
|
||||
# 记忆融合配置
|
||||
fusion_similarity_threshold: float = Field(default=0.85, description="融合相似度阈值")
|
||||
deduplication_window_hours: int = Field(default=24, description="去重时间窗口(小时)")
|
||||
|
||||
# 缓存配置
|
||||
enable_memory_cache: bool = Field(default=True, description="启用记忆缓存")
|
||||
cache_ttl_seconds: int = Field(default=300, description="缓存生存时间(秒)")
|
||||
max_cache_size: int = Field(default=1000, description="最大缓存大小")
|
||||
|
||||
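四项检索权重的默认值 0.4/0.3/0.2/0.1 恰好归一;如果允许用户在配置文件里改动,加载后做一次轻量校验会更稳妥(示意代码,非本次提交内容):

cfg = MemoryConfig()  # 假设可按默认值直接实例化
weight_sum = cfg.vector_weight + cfg.semantic_weight + cfg.context_weight + cfg.recency_weight
assert abs(weight_sum - 1.0) < 1e-6, f"检索权重之和应为 1.0,当前为 {weight_sum}"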
|
||||
class MoodConfig(ValidatedConfigBase):
|
||||
"""情绪配置类"""
|
||||
|
||||
157
src/main.py
@@ -34,54 +34,8 @@ from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||
# 导入消息API和traceback模块
|
||||
from src.common.message import get_global_api
|
||||
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
|
||||
if not global_config.memory.enable_memory:
|
||||
import src.chat.memory_system.Hippocampus as hippocampus_module
|
||||
|
||||
class MockHippocampusManager:
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
def get_hippocampus(self):
|
||||
return None
|
||||
|
||||
async def build_memory(self):
|
||||
pass
|
||||
|
||||
async def forget_memory(self, percentage: float = 0.005):
|
||||
pass
|
||||
|
||||
async def consolidate_memory(self):
|
||||
pass
|
||||
|
||||
async def get_memory_from_text(
|
||||
self,
|
||||
text: str,
|
||||
max_memory_num: int = 3,
|
||||
max_memory_length: int = 2,
|
||||
max_depth: int = 3,
|
||||
fast_retrieval: bool = False,
|
||||
) -> list:
|
||||
return []
|
||||
|
||||
async def get_memory_from_topic(
|
||||
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
||||
) -> list:
|
||||
return []
|
||||
|
||||
async def get_activate_from_text(
|
||||
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
|
||||
) -> tuple[float, list[str]]:
|
||||
return 0.0, []
|
||||
|
||||
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
|
||||
return []
|
||||
|
||||
def get_all_node_names(self) -> list:
|
||||
return []
|
||||
|
||||
hippocampus_module.hippocampus_manager = MockHippocampusManager()
|
||||
# 导入增强记忆系统管理器
|
||||
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
|
||||
|
||||
# 插件系统现在使用统一的插件加载器
|
||||
|
||||
@@ -106,7 +60,8 @@ def _task_done_callback(task: asyncio.Task, message_id: str, start_time: float):
|
||||
|
||||
class MainSystem:
|
||||
def __init__(self):
|
||||
self.hippocampus_manager = hippocampus_manager
|
||||
# 使用增强记忆系统
|
||||
self.enhanced_memory_manager = enhanced_memory_manager
|
||||
|
||||
self.individuality: Individuality = get_individuality()
|
||||
|
||||
@@ -169,19 +124,18 @@ class MainSystem:
|
||||
logger.error(f"停止热重载系统时出错: {e}")
|
||||
|
||||
try:
|
||||
# 停止异步记忆管理器
|
||||
# 停止增强记忆系统
|
||||
if global_config.memory.enable_memory:
|
||||
from src.chat.memory_system.async_memory_optimizer import async_memory_manager
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
asyncio.create_task(async_memory_manager.shutdown())
|
||||
asyncio.create_task(self.enhanced_memory_manager.shutdown())
|
||||
else:
|
||||
loop.run_until_complete(async_memory_manager.shutdown())
|
||||
logger.info("🛑 记忆管理器已停止")
|
||||
loop.run_until_complete(self.enhanced_memory_manager.shutdown())
|
||||
logger.info("🛑 增强记忆系统已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"停止记忆管理器时出错: {e}")
|
||||
logger.error(f"停止增强记忆系统时出错: {e}")
|
||||
|
||||
async def _message_process_wrapper(self, message_data: Dict[str, Any]):
|
||||
"""并行处理消息的包装器"""
|
||||
@@ -304,9 +258,11 @@ MoFox_Bot(第三方修改版)
|
||||
|
||||
logger.info("聊天管理器初始化成功")
|
||||
|
||||
# 初始化记忆系统
|
||||
self.hippocampus_manager.initialize()
|
||||
logger.info("记忆系统初始化成功")
|
||||
# 初始化增强记忆系统
|
||||
await self.enhanced_memory_manager.initialize()
|
||||
logger.info("增强记忆系统初始化成功")
|
||||
|
||||
# 老记忆系统已完全删除
|
||||
|
||||
# 初始化LPMM知识库
|
||||
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
|
||||
@@ -314,14 +270,8 @@ MoFox_Bot(第三方修改版)
|
||||
initialize_lpmm_knowledge()
|
||||
logger.info("LPMM知识库初始化成功")
|
||||
|
||||
# 初始化异步记忆管理器
|
||||
try:
|
||||
from src.chat.memory_system.async_memory_optimizer import async_memory_manager
|
||||
|
||||
await async_memory_manager.initialize()
|
||||
logger.info("记忆管理器初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"记忆管理器初始化失败: {e}")
|
||||
# 异步记忆管理器已禁用,增强记忆系统有内置的优化机制
|
||||
logger.info("异步记忆管理器已禁用 - 使用增强记忆系统内置优化")
|
||||
|
||||
# await asyncio.sleep(0.5) #防止logger输出飞了
|
||||
|
||||
@@ -376,81 +326,12 @@ MoFox_Bot (third-party modified edition)
             self.server.run(),
         ]
 
-        # Add the memory-system tasks
-        tasks.extend(
-            [
-                self.build_memory_task(),
-                self.forget_memory_task(),
-                self.consolidate_memory_task(),
-            ]
-        )
+        # The enhanced memory system needs no scheduled tasks; the legacy timers are disabled
+        logger.info("Legacy memory-system scheduled tasks disabled - using the enhanced memory system")
 
         await asyncio.gather(*tasks)
 
-    async def build_memory_task(self):
-        """Periodic memory-building task"""
-        while True:
-            await asyncio.sleep(global_config.memory.memory_build_interval)
-
-            try:
-                # Build memory without blocking, via the async memory manager
-                from src.chat.memory_system.async_memory_optimizer import build_memory_nonblocking
-
-                logger.info("Starting memory build")
-
-                # Callback invoked when the build finishes
-                def build_completed(result):
-                    if result:
-                        logger.info("Memory build finished")
-                    else:
-                        logger.warning("Memory build failed")
-
-                # Kick off the async build without waiting for it to complete
-                task_id = await build_memory_nonblocking()
-                logger.info(f"Memory build task submitted: {task_id}")
-
-            except ImportError:
-                # If the async optimizer is unavailable, fall back to the old synchronous path (run in a separate thread)
-                logger.warning("Memory optimizer unavailable; running the memory build synchronously")
-
-                def sync_build_memory():
-                    """Run the synchronous memory build in the thread pool"""
-                    try:
-                        loop = asyncio.new_event_loop()
-                        asyncio.set_event_loop(loop)
-                        result = loop.run_until_complete(self.hippocampus_manager.build_memory())
-                        logger.info("Memory build finished")
-                        return result
-                    except Exception as e:
-                        logger.error(f"Memory build failed: {e}")
-                        return None
-                    finally:
-                        loop.close()
-
-                # Execute the memory build in the thread pool
-                asyncio.get_event_loop().run_in_executor(None, sync_build_memory)
-
-            except Exception as e:
-                logger.error(f"Failed to start the memory build task: {e}")
-                # Fall back to the original synchronous path
-                logger.info("Building memory (synchronous mode)")
-                await self.hippocampus_manager.build_memory()  # type: ignore
-
-    async def forget_memory_task(self):
-        """Periodic memory-forgetting task"""
-        while True:
-            await asyncio.sleep(global_config.memory.forget_memory_interval)
-            logger.info("[Memory forgetting] Starting to forget memories...")
-            await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage)  # type: ignore
-            logger.info("[Memory forgetting] Memory forgetting finished")
-
-    async def consolidate_memory_task(self):
-        """Periodic memory-consolidation task"""
-        while True:
-            await asyncio.sleep(global_config.memory.consolidate_memory_interval)
-            logger.info("[Memory consolidation] Starting to consolidate memories...")
-            await self.hippocampus_manager.consolidate_memory()  # type: ignore
-            logger.info("[Memory consolidation] Memory consolidation finished")
+    # The legacy scheduled tasks have been removed - the enhanced memory system uses built-in maintenance
 
 
 async def main():
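The commit replaces three externally scheduled timers (build, forget, consolidate) with maintenance the enhanced system runs internally. The diff does not show how that internal maintenance is implemented; one plausible shape, where the loop is owned by the manager itself and every name (`_consolidate`, `_forget_stale`, `maintenance_interval`) is hypothetical:

```python
import asyncio
from typing import Optional

class SelfMaintainingMemorySketch:
    """Skeleton of a memory manager that maintains itself, so no external
    build/forget/consolidate timers are needed. Only the pattern is real;
    the names here are illustrative, not the commit's actual API."""

    def __init__(self, maintenance_interval: float = 3600.0):
        self.maintenance_interval = maintenance_interval  # seconds between passes
        self._maintenance_task: Optional[asyncio.Task] = None

    async def initialize(self) -> None:
        # Maintenance starts as a side effect of initialize(), replacing the
        # removed build_memory_task/forget_memory_task/consolidate_memory_task.
        self._maintenance_task = asyncio.create_task(self._maintenance_loop())

    async def shutdown(self) -> None:
        # Cancel the loop and swallow the expected CancelledError.
        if self._maintenance_task is not None:
            self._maintenance_task.cancel()
            try:
                await self._maintenance_task
            except asyncio.CancelledError:
                pass

    async def _maintenance_loop(self) -> None:
        while True:
            await asyncio.sleep(self.maintenance_interval)
            await self._consolidate()   # merge/strengthen memories (hypothetical)
            await self._forget_stale()  # decay stale memories (hypothetical)

    async def _consolidate(self) -> None:
        ...

    async def _forget_stale(self) -> None:
        ...
```

Owning the loop inside the manager keeps start and stop paired with `initialize()`/`shutdown()`, which is exactly why the scheduler hunk above shrinks to a single log line.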
@@ -2,7 +2,8 @@ import asyncio
 import math
 from typing import Tuple
 
-from src.chat.memory_system.Hippocampus import hippocampus_manager
+# The old Hippocampus system has been removed; the enhanced memory system is used now
+# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
 from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
 from maim_message.message_base import GroupInfo
 from src.chat.message_receive.storage import MessageStorage
@@ -40,11 +41,31 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
 
     if global_config.memory.enable_memory:
-        with Timer("Memory activation"):
-            interested_rate, _ = await hippocampus_manager.get_activate_from_text(
-                message.processed_plain_text,
-                fast_retrieval=True,
-            )
-        logger.debug(f"Memory activation rate: {interested_rate:.2f}")
+        # Compute the interest rate with the new enhanced memory system
+        try:
+            from src.chat.memory_system.enhanced_memory_integration import recall_memories
+
+            # Retrieve related memories to estimate interest
+            enhanced_memories = await recall_memories(
+                query=message.processed_plain_text,
+                user_id=str(message.user_info.user_id),
+                chat_id=message.chat_id
+            )
+
+            # Derive the interest rate from the retrieval results
+            if enhanced_memories:
+                # Related memories exist: base the interest rate on similarity
+                max_score = max(score for _, score in enhanced_memories)
+                interested_rate = min(max_score, 1.0)  # clamp to the 0-1 range
+            else:
+                # No related memories: assign a baseline interest rate
+                interested_rate = 0.1
+
+            logger.debug(f"Enhanced memory interest rate: {interested_rate:.2f}")
+
+        except Exception as e:
+            logger.warning(f"Enhanced memory interest calculation failed: {e}")
+            interested_rate = 0.1  # default baseline interest
 
     text_len = len(message.processed_plain_text)
     # Adjust the interest rate by text length, using a piecewise function for finer-grained scoring
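The new interest computation reduces the retrieval results to a single number: the best similarity score, clamped to [0, 1], with 0.1 as the floor both when nothing is recalled and when retrieval fails. Isolated as a pure function (the `(memory, score)` pair shape follows the call site above; the function and constant names are ours):

```python
from typing import Iterable, Tuple

BASELINE_INTEREST = 0.1  # used when nothing is recalled or retrieval fails

def interest_from_memories(memories: Iterable[Tuple[object, float]]) -> float:
    """Map (memory, similarity) pairs to an interest rate in [0, 1]."""
    scores = [score for _, score in memories]
    if not scores:
        return BASELINE_INTEREST
    # The best match drives interest; clamp in case scores exceed 1.0.
    return min(max(scores), 1.0)
```

With this mapping, `interest_from_memories([])` is 0.1 and a score above 1.0 saturates at 1.0, matching the two branches in the hunk above.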
@@ -4,7 +4,8 @@ from src.chat.utils.prompt import Prompt, global_prompt_manager
 from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
 import time
 from src.chat.utils.utils import get_recent_group_speaker
-from src.chat.memory_system.Hippocampus import hippocampus_manager
+# The old Hippocampus system has been removed; the enhanced memory system is used now
+# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
 import random
 from datetime import datetime
 import asyncio
@@ -171,16 +172,26 @@ class PromptBuilder:
 
     @staticmethod
     async def build_memory_block(text: str) -> str:
-        related_memory = await hippocampus_manager.get_memory_from_text(
-            text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
-        )
-
-        related_memory_info = ""
-        if related_memory:
-            for memory in related_memory:
-                related_memory_info += memory[1]
-            return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info)
-        return ""
+        # Retrieve memories with the new enhanced memory system
+        try:
+            from src.chat.memory_system.enhanced_memory_integration import recall_memories
+
+            enhanced_memories = await recall_memories(
+                query=text,
+                user_id="system",  # system-level query
+                chat_id="system"
+            )
+
+            related_memory_info = ""
+            if enhanced_memories and enhanced_memories.get("has_memories"):
+                for memory in enhanced_memories.get("memories", []):
+                    related_memory_info += memory.get("content", "") + " "
+                return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info.strip())
+            return ""
+
+        except Exception as e:
+            logger.warning(f"Enhanced memory retrieval failed: {e}")
+            return ""
 
     @staticmethod
     async def build_chat_history_prompts(chat_stream: ChatStream, message: MessageRecvS4U):
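Worth noting: `recall_memories` is consumed in two different shapes in this commit. `_calculate_interest` iterates it as `(memory, score)` pairs, while `build_memory_block` reads a dict with `has_memories`/`memories` keys. A defensive normalizer that tolerates either shape is one way to keep the call sites consistent; this is purely illustrative, since the integration module's real return type is not shown in the diff:

```python
from typing import Any, Dict, List

def normalize_recall_result(result: Any) -> List[Dict[str, Any]]:
    """Coerce either observed recall_memories shape into a list of memory dicts."""
    if not result:
        return []
    # Dict shape: {"has_memories": bool, "memories": [{"content": ..., ...}]}
    if isinstance(result, dict):
        if not result.get("has_memories"):
            return []
        return list(result.get("memories", []))
    # Pair shape: [(memory, score), ...] - keep the memory, record the score.
    normalized = []
    for memory, score in result:
        entry = dict(memory) if isinstance(memory, dict) else {"content": str(memory)}
        entry["score"] = score
        normalized.append(entry)
    return normalized
```

Funneling every caller through one normalizer would let the integration module change its return type without touching three call sites.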
@@ -98,7 +98,6 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
         message_recv = MessageRecv(new_message_dict)
 
         logger.info(f"[SendAPI] Found a matching reply message, sender: {message_dict.get('user_nickname', '')}")
-        logger.info(message_recv)
         return message_recv
 
 
@@ -11,7 +11,8 @@ from typing import Any, Dict, List, Optional
 
 from json_repair import repair_json
 
-from src.chat.memory_system.Hippocampus import hippocampus_manager
+# The old Hippocampus system has been removed; the enhanced memory system is used now
+# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
 from src.chat.utils.chat_message_builder import (
     build_readable_actions,
     build_readable_messages_with_id,
@@ -602,14 +603,32 @@ class ChatterPlanFilter:
             else:
                 keywords.append("晚上")
 
-            retrieved_memories = await hippocampus_manager.get_memory_from_topic(
-                valid_keywords=keywords, max_memory_num=5, max_memory_length=1
-            )
+            # Retrieve memories with the new enhanced memory system
+            try:
+                from src.chat.memory_system.enhanced_memory_integration import recall_memories
 
-            if not retrieved_memories:
-                return "最近没有什么特别的记忆。"
+                # Turn the keywords into a single query string
+                query = " ".join(keywords)
+                enhanced_memories = await recall_memories(
+                    query=query,
+                    user_id="system",  # system-level query
+                    chat_id="system"
+                )
+
+                if not enhanced_memories:
+                    return "最近没有什么特别的记忆。"
+
+                # Convert the format for compatibility with the existing code
+                retrieved_memories = []
+                if enhanced_memories and enhanced_memories.get("has_memories"):
+                    for memory in enhanced_memories.get("memories", []):
+                        retrieved_memories.append((memory.get("type", "unknown"), memory.get("content", "")))
+
+                memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories]
+
+            except Exception as e:
+                logger.warning(f"Enhanced memory retrieval failed, using the default reply: {e}")
+                return "最近没有什么特别的记忆。"
 
             memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories]
             return " ".join(memory_statements)
         except Exception as e:
             logger.error(f"Error while fetching long-term memory: {e}")
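The plan filter flattens its time-of-day keywords into one query string and then renders each recalled `(type, content)` pair into a natural-language statement. That rendering step as a standalone helper, with the sentence template and fallback reply taken verbatim from the diff (only the function name is ours):

```python
from typing import List, Tuple

def render_memory_statements(retrieved_memories: List[Tuple[str, str]]) -> str:
    """Render (topic, content) pairs with the sentence template used above."""
    if not retrieved_memories:
        # Same fallback reply the plan filter uses: "nothing special lately".
        return "最近没有什么特别的记忆。"
    statements = [
        f"关于'{topic}', 你记得'{memory_item}'。"  # "About X, you remember Y."
        for topic, memory_item in retrieved_memories
    ]
    return " ".join(statements)
```

Extracting this would also remove the duplicated `memory_statements` list comprehension visible in the hunk above, which currently builds the same list twice on the success path.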