feat(memory): 实现增强记忆系统并完全替换原有架构

引入全新的增强记忆系统,彻底取代海马体记忆架构
删除旧版记忆系统相关模块,包括Hippocampus、异步包装器和优化器
重构消息处理流程,集成增强记忆系统的存储和检索功能
更新配置结构以支持增强记忆的各项参数设置
禁用原有定时任务,采用内置维护机制保证系统性能
This commit is contained in:
Windpicker-owo
2025-09-30 00:09:46 +08:00
parent 33be072f04
commit b30db43776
31 changed files with 6806 additions and 3878 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1,248 +0,0 @@
# -*- coding: utf-8 -*-
"""
异步瞬时记忆包装器
提供对现有瞬时记忆系统的异步包装,支持超时控制和回退机制
"""
import asyncio
import time
from typing import Optional, Dict, Any
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("async_instant_memory_wrapper")
class AsyncInstantMemoryWrapper:
"""异步瞬时记忆包装器"""
def __init__(self, chat_id: str):
self.chat_id = chat_id
self.llm_memory = None
self.vector_memory = None
self.cache: Dict[str, tuple[Any, float]] = {} # 缓存:(结果, 时间戳)
self.cache_ttl = 300 # 缓存5分钟
self.default_timeout = 3.0 # 默认超时3秒
# 从配置中读取是否启用各种记忆系统
self.llm_memory_enabled = global_config.memory.enable_llm_instant_memory
self.vector_memory_enabled = global_config.memory.enable_vector_instant_memory
async def _ensure_llm_memory(self):
"""确保LLM记忆系统已初始化"""
if self.llm_memory is None and self.llm_memory_enabled:
try:
from src.chat.memory_system.instant_memory import InstantMemory
self.llm_memory = InstantMemory(self.chat_id)
logger.info(f"LLM瞬时记忆系统已初始化: {self.chat_id}")
except Exception as e:
logger.warning(f"LLM瞬时记忆系统初始化失败: {e}")
self.llm_memory_enabled = False # 初始化失败则禁用
async def _ensure_vector_memory(self):
"""确保向量记忆系统已初始化"""
if self.vector_memory is None and self.vector_memory_enabled:
try:
from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2
self.vector_memory = VectorInstantMemoryV2(self.chat_id)
logger.info(f"向量瞬时记忆系统已初始化: {self.chat_id}")
except Exception as e:
logger.warning(f"向量瞬时记忆系统初始化失败: {e}")
self.vector_memory_enabled = False # 初始化失败则禁用
def _get_cache_key(self, operation: str, content: str) -> str:
"""生成缓存键"""
return f"{operation}_{self.chat_id}_{hash(content)}"
def _is_cache_valid(self, cache_key: str) -> bool:
"""检查缓存是否有效"""
if cache_key not in self.cache:
return False
_, timestamp = self.cache[cache_key]
return time.time() - timestamp < self.cache_ttl
def _get_cached_result(self, cache_key: str) -> Optional[Any]:
"""获取缓存结果"""
if self._is_cache_valid(cache_key):
result, _ = self.cache[cache_key]
return result
return None
def _cache_result(self, cache_key: str, result: Any):
"""缓存结果"""
self.cache[cache_key] = (result, time.time())
async def store_memory_async(self, content: str, timeout: Optional[float] = None) -> bool:
"""异步存储记忆(带超时控制)"""
if timeout is None:
timeout = self.default_timeout
success_count = 0
# 异步存储到LLM记忆系统
await self._ensure_llm_memory()
if self.llm_memory:
try:
await asyncio.wait_for(self.llm_memory.create_and_store_memory(content), timeout=timeout)
success_count += 1
logger.debug(f"LLM记忆存储成功: {content[:50]}...")
except asyncio.TimeoutError:
logger.warning(f"LLM记忆存储超时: {content[:50]}...")
except Exception as e:
logger.error(f"LLM记忆存储失败: {e}")
# 异步存储到向量记忆系统
await self._ensure_vector_memory()
if self.vector_memory:
try:
await asyncio.wait_for(self.vector_memory.store_message(content), timeout=timeout)
success_count += 1
logger.debug(f"向量记忆存储成功: {content[:50]}...")
except asyncio.TimeoutError:
logger.warning(f"向量记忆存储超时: {content[:50]}...")
except Exception as e:
logger.error(f"向量记忆存储失败: {e}")
return success_count > 0
async def retrieve_memory_async(
self, query: str, timeout: Optional[float] = None, use_cache: bool = True
) -> Optional[Any]:
"""异步检索记忆(带缓存和超时控制)"""
if timeout is None:
timeout = self.default_timeout
# 检查缓存
if use_cache:
cache_key = self._get_cache_key("retrieve", query)
cached_result = self._get_cached_result(cache_key)
if cached_result is not None:
logger.debug(f"记忆检索命中缓存: {query[:30]}...")
return cached_result
# 尝试多种记忆系统
results = []
# 从向量记忆系统检索(优先,速度快)
await self._ensure_vector_memory()
if self.vector_memory:
try:
vector_result = await asyncio.wait_for(
self.vector_memory.get_memory_for_context(query),
timeout=timeout * 0.6, # 给向量系统60%的时间
)
if vector_result:
results.append(vector_result)
logger.debug(f"向量记忆检索成功: {query[:30]}...")
except asyncio.TimeoutError:
logger.warning(f"向量记忆检索超时: {query[:30]}...")
except Exception as e:
logger.error(f"向量记忆检索失败: {e}")
# 从LLM记忆系统检索备用更准确但较慢
await self._ensure_llm_memory()
if self.llm_memory and len(results) == 0: # 只有向量检索失败时才使用LLM
try:
llm_result = await asyncio.wait_for(
self.llm_memory.get_memory(query),
timeout=timeout * 0.4, # 给LLM系统40%的时间
)
if llm_result:
results.extend(llm_result)
logger.debug(f"LLM记忆检索成功: {query[:30]}...")
except asyncio.TimeoutError:
logger.warning(f"LLM记忆检索超时: {query[:30]}...")
except Exception as e:
logger.error(f"LLM记忆检索失败: {e}")
# 合并结果
final_result = None
if results:
if len(results) == 1:
final_result = results[0]
else:
# 合并多个结果
if isinstance(results[0], str):
final_result = "\n".join(str(r) for r in results)
elif isinstance(results[0], list):
final_result = []
for r in results:
if isinstance(r, list):
final_result.extend(r)
else:
final_result.append(r)
else:
final_result = results[0] # 使用第一个结果
# 缓存结果
if use_cache and final_result is not None:
cache_key = self._get_cache_key("retrieve", query)
self._cache_result(cache_key, final_result)
return final_result
async def get_memory_with_fallback(self, query: str, max_timeout: float = 2.0) -> str:
"""获取记忆的回退方法,保证不会长时间阻塞"""
try:
# 首先尝试快速检索
result = await self.retrieve_memory_async(query, timeout=max_timeout)
if result:
if isinstance(result, list):
return "\n".join(str(item) for item in result)
return str(result)
return ""
except Exception as e:
logger.error(f"记忆检索完全失败: {e}")
return ""
def store_memory_background(self, content: str):
"""在后台存储记忆(发后即忘模式)"""
async def background_store():
try:
await self.store_memory_async(content, timeout=10.0) # 后台任务可以用更长超时
except Exception as e:
logger.error(f"后台记忆存储失败: {e}")
# 创建后台任务
asyncio.create_task(background_store())
def get_status(self) -> Dict[str, Any]:
"""获取包装器状态"""
return {
"chat_id": self.chat_id,
"llm_memory_available": self.llm_memory is not None,
"vector_memory_available": self.vector_memory is not None,
"cache_entries": len(self.cache),
"cache_ttl": self.cache_ttl,
"default_timeout": self.default_timeout,
}
def clear_cache(self):
"""清理缓存"""
self.cache.clear()
logger.info(f"记忆缓存已清理: {self.chat_id}")
# 缓存包装器实例,避免重复创建
_wrapper_cache: Dict[str, AsyncInstantMemoryWrapper] = {}
def get_async_instant_memory(chat_id: str) -> AsyncInstantMemoryWrapper:
"""获取异步瞬时记忆包装器实例"""
if chat_id not in _wrapper_cache:
_wrapper_cache[chat_id] = AsyncInstantMemoryWrapper(chat_id)
return _wrapper_cache[chat_id]
def clear_wrapper_cache():
"""清理包装器缓存"""
global _wrapper_cache
_wrapper_cache.clear()
logger.info("异步瞬时记忆包装器缓存已清理")

View File

@@ -1,358 +0,0 @@
# -*- coding: utf-8 -*-
"""
异步记忆系统优化器
解决记忆系统阻塞主程序的问题,将同步操作改为异步非阻塞操作
"""
import asyncio
import time
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
logger = get_logger("async_memory_optimizer")
@dataclass
class MemoryTask:
"""记忆任务数据结构"""
task_id: str
task_type: str # "store", "retrieve", "build"
chat_id: str
content: str
priority: int = 1 # 1=低优先级, 2=中优先级, 3=高优先级
callback: Optional[Callable] = None
created_at: float = None
def __post_init__(self):
if self.created_at is None:
self.created_at = time.time()
class AsyncMemoryQueue:
"""异步记忆任务队列管理器"""
def __init__(self, max_workers: int = 3):
self.max_workers = max_workers
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.task_queue = asyncio.Queue()
self.running_tasks: Dict[str, asyncio.Task] = {}
self.completed_tasks: Dict[str, Any] = {}
self.failed_tasks: Dict[str, str] = {}
self.is_running = False
self.worker_tasks: List[asyncio.Task] = []
async def start(self):
"""启动异步队列处理器"""
if self.is_running:
return
self.is_running = True
# 启动多个工作协程
for i in range(self.max_workers):
worker = asyncio.create_task(self._worker(f"worker-{i}"))
self.worker_tasks.append(worker)
logger.info(f"异步记忆队列已启动,工作线程数: {self.max_workers}")
async def stop(self):
"""停止队列处理器"""
self.is_running = False
# 等待所有工作任务完成
for task in self.worker_tasks:
task.cancel()
await asyncio.gather(*self.worker_tasks, return_exceptions=True)
self.executor.shutdown(wait=True)
logger.info("异步记忆队列已停止")
async def _worker(self, worker_name: str):
"""工作协程,处理队列中的任务"""
logger.info(f"记忆处理工作线程 {worker_name} 启动")
while self.is_running:
try:
# 等待任务超时1秒避免永久阻塞
task = await asyncio.wait_for(self.task_queue.get(), timeout=1.0)
# 执行任务
await self._execute_task(task, worker_name)
except asyncio.TimeoutError:
# 超时正常,继续下一次循环
continue
except Exception as e:
logger.error(f"工作线程 {worker_name} 处理任务时出错: {e}")
async def _execute_task(self, task: MemoryTask, worker_name: str):
"""执行具体的记忆任务"""
try:
logger.debug(f"[{worker_name}] 开始处理任务: {task.task_type} - {task.task_id}")
start_time = time.time()
# 根据任务类型执行不同的处理逻辑
result = None
if task.task_type == "store":
result = await self._handle_store_task(task)
elif task.task_type == "retrieve":
result = await self._handle_retrieve_task(task)
elif task.task_type == "build":
result = await self._handle_build_task(task)
else:
raise ValueError(f"未知的任务类型: {task.task_type}")
# 记录完成的任务
self.completed_tasks[task.task_id] = result
execution_time = time.time() - start_time
logger.debug(f"[{worker_name}] 任务完成: {task.task_id} (耗时: {execution_time:.2f}s)")
# 执行回调函数
if task.callback:
try:
if asyncio.iscoroutinefunction(task.callback):
await task.callback(result)
else:
task.callback(result)
except Exception as e:
logger.error(f"任务回调执行失败: {e}")
except Exception as e:
error_msg = f"任务执行失败: {e}"
logger.error(f"[{worker_name}] {error_msg}")
self.failed_tasks[task.task_id] = error_msg
# 执行错误回调
if task.callback:
try:
if asyncio.iscoroutinefunction(task.callback):
await task.callback(None)
else:
task.callback(None)
except Exception:
pass
@staticmethod
async def _handle_store_task(task: MemoryTask) -> Any:
"""处理记忆存储任务"""
# 这里需要根据具体的记忆系统来实现
# 为了避免循环导入,这里使用延迟导入
try:
# 获取包装器实例
memory_wrapper = get_async_instant_memory(task.chat_id)
# 使用包装器中的llm_memory实例
if memory_wrapper and memory_wrapper.llm_memory:
await memory_wrapper.llm_memory.create_and_store_memory(task.content)
return True
else:
logger.warning(f"无法获取记忆系统实例,存储任务失败: chat_id={task.chat_id}")
return False
except Exception as e:
logger.error(f"记忆存储失败: {e}")
return False
@staticmethod
async def _handle_retrieve_task(task: MemoryTask) -> Any:
"""处理记忆检索任务"""
try:
# 获取包装器实例
memory_wrapper = get_async_instant_memory(task.chat_id)
# 使用包装器中的llm_memory实例
if memory_wrapper and memory_wrapper.llm_memory:
memories = await memory_wrapper.llm_memory.get_memory(task.content)
return memories or []
else:
logger.warning(f"无法获取记忆系统实例,检索任务失败: chat_id={task.chat_id}")
return []
except Exception as e:
logger.error(f"记忆检索失败: {e}")
return []
@staticmethod
async def _handle_build_task(task: MemoryTask) -> Any:
"""处理记忆构建任务(海马体系统)"""
try:
# 延迟导入避免循环依赖
if global_config.memory.enable_memory:
from src.chat.memory_system.Hippocampus import hippocampus_manager
if hippocampus_manager._initialized:
# 确保海马体对象已正确初始化
if not hippocampus_manager._hippocampus.parahippocampal_gyrus:
logger.warning("海马体对象未完全初始化,进行同步初始化")
hippocampus_manager._hippocampus.initialize()
await hippocampus_manager.build_memory()
return True
return False
except Exception as e:
logger.error(f"记忆构建失败: {e}")
return False
async def add_task(self, task: MemoryTask) -> str:
"""添加任务到队列"""
await self.task_queue.put(task)
self.running_tasks[task.task_id] = task
logger.debug(f"任务已加入队列: {task.task_type} - {task.task_id}")
return task.task_id
def get_task_result(self, task_id: str) -> Optional[Any]:
"""获取任务结果(非阻塞)"""
return self.completed_tasks.get(task_id)
def is_task_completed(self, task_id: str) -> bool:
"""检查任务是否完成"""
return task_id in self.completed_tasks or task_id in self.failed_tasks
def get_queue_status(self) -> Dict[str, Any]:
"""获取队列状态"""
return {
"is_running": self.is_running,
"queue_size": self.task_queue.qsize(),
"running_tasks": len(self.running_tasks),
"completed_tasks": len(self.completed_tasks),
"failed_tasks": len(self.failed_tasks),
"worker_count": len(self.worker_tasks),
}
class NonBlockingMemoryManager:
"""非阻塞记忆管理器"""
def __init__(self):
self.queue = AsyncMemoryQueue(max_workers=3)
self.cache: Dict[str, Any] = {}
self.cache_ttl: Dict[str, float] = {}
self.cache_timeout = 300 # 缓存5分钟
async def initialize(self):
"""初始化管理器"""
await self.queue.start()
logger.info("非阻塞记忆管理器已初始化")
async def shutdown(self):
"""关闭管理器"""
await self.queue.stop()
logger.info("非阻塞记忆管理器已关闭")
async def store_memory_async(self, chat_id: str, content: str, callback: Optional[Callable] = None) -> str:
"""异步存储记忆(非阻塞)"""
task = MemoryTask(
task_id=f"store_{chat_id}_{int(time.time() * 1000)}",
task_type="store",
chat_id=chat_id,
content=content,
priority=1, # 存储优先级较低
callback=callback,
)
return await self.queue.add_task(task)
async def retrieve_memory_async(self, chat_id: str, query: str, callback: Optional[Callable] = None) -> str:
"""异步检索记忆(非阻塞)"""
# 先检查缓存
cache_key = f"retrieve_{chat_id}_{hash(query)}"
if self._is_cache_valid(cache_key):
result = self.cache[cache_key]
if callback:
if asyncio.iscoroutinefunction(callback):
await callback(result)
else:
callback(result)
return "cache_hit"
task = MemoryTask(
task_id=f"retrieve_{chat_id}_{int(time.time() * 1000)}",
task_type="retrieve",
chat_id=chat_id,
content=query,
priority=2, # 检索优先级中等
callback=self._create_cache_callback(cache_key, callback),
)
return await self.queue.add_task(task)
async def build_memory_async(self, callback: Optional[Callable] = None) -> str:
"""异步构建记忆(非阻塞)"""
task = MemoryTask(
task_id=f"build_memory_{int(time.time() * 1000)}",
task_type="build",
chat_id="system",
content="",
priority=1, # 构建优先级较低,避免影响用户体验
callback=callback,
)
return await self.queue.add_task(task)
def _is_cache_valid(self, cache_key: str) -> bool:
"""检查缓存是否有效"""
if cache_key not in self.cache:
return False
return time.time() - self.cache_ttl.get(cache_key, 0) < self.cache_timeout
def _create_cache_callback(self, cache_key: str, original_callback: Optional[Callable]):
"""创建带缓存的回调函数"""
async def cache_callback(result):
# 存储到缓存
if result is not None:
self.cache[cache_key] = result
self.cache_ttl[cache_key] = time.time()
# 执行原始回调
if original_callback:
if asyncio.iscoroutinefunction(original_callback):
await original_callback(result)
else:
original_callback(result)
return cache_callback
def get_cached_memory(self, chat_id: str, query: str) -> Optional[Any]:
"""获取缓存的记忆(同步,立即返回)"""
cache_key = f"retrieve_{chat_id}_{hash(query)}"
if self._is_cache_valid(cache_key):
return self.cache[cache_key]
return None
def get_status(self) -> Dict[str, Any]:
"""获取管理器状态"""
status = self.queue.get_queue_status()
status.update({"cache_entries": len(self.cache), "cache_timeout": self.cache_timeout})
return status
# 全局实例
async_memory_manager = NonBlockingMemoryManager()
# 便捷函数
async def store_memory_nonblocking(chat_id: str, content: str) -> str:
"""非阻塞存储记忆的便捷函数"""
return await async_memory_manager.store_memory_async(chat_id, content)
async def retrieve_memory_nonblocking(chat_id: str, query: str) -> Optional[Any]:
"""非阻塞检索记忆的便捷函数,支持缓存"""
# 先尝试从缓存获取
cached_result = async_memory_manager.get_cached_memory(chat_id, query)
if cached_result is not None:
return cached_result
# 缓存未命中,启动异步检索
await async_memory_manager.retrieve_memory_async(chat_id, query)
return None # 返回None表示需要异步获取
async def build_memory_nonblocking() -> str:
"""非阻塞构建记忆的便捷函数"""
return await async_memory_manager.build_memory_async()

View File

@@ -1,196 +0,0 @@
# 记忆系统异步优化说明
## 🎯 优化目标
解决MaiBot-Plus记忆系统阻塞主程序的问题将原本的线性同步调用改为异步非阻塞运行。
## ⚠️ 问题分析
### 原有问题
1. **瞬时记忆阻塞**:每次用户发消息时,`await self.instant_memory.get_memory_for_context(target)` 会阻塞等待LLM响应
2. **定时记忆构建阻塞**每600秒执行的 `build_memory_task()` 会完全阻塞主程序数十秒
3. **LLM调用链阻塞**记忆存储和检索都需要调用LLM延迟较高
### 卡顿表现
- 用户发消息后,程序响应延迟明显增加
- 定时记忆构建时,整个程序无响应
- 高并发时,记忆系统成为性能瓶颈
## 🚀 优化方案
### 1. 异步记忆队列系统 (`async_memory_optimizer.py`)
**核心思想**:将记忆操作放入异步队列,后台处理,不阻塞主程序。
**关键特性**
- 任务队列管理:支持存储、检索、构建三种任务类型
- 优先级调度:高优先级任务(用户查询)优先处理
- 线程池执行:避免阻塞事件循环
- 结果缓存:减少重复计算
- 失败重试:提高系统可靠性
```python
# 使用示例
from src.chat.memory_system.async_memory_optimizer import (
store_memory_nonblocking,
retrieve_memory_nonblocking,
build_memory_nonblocking
)
# 非阻塞存储记忆
task_id = await store_memory_nonblocking(chat_id, content)
# 非阻塞检索记忆(支持缓存)
memories = await retrieve_memory_nonblocking(chat_id, query)
# 非阻塞构建记忆
task_id = await build_memory_nonblocking()
```
### 2. 异步瞬时记忆包装器 (`async_instant_memory_wrapper.py`)
**核心思想**:为现有瞬时记忆系统提供异步包装,支持超时控制和多层回退。
**关键特性**
- 超时控制:防止长时间阻塞
- 缓存机制:热点查询快速响应
- 多系统融合LLM记忆 + 向量记忆
- 回退策略:保证系统稳定性
- 后台存储:存储操作完全非阻塞
```python
# 使用示例
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
async_memory = get_async_instant_memory(chat_id)
# 后台存储(发后即忘)
async_memory.store_memory_background(content)
# 快速检索(带超时)
result = await async_memory.get_memory_with_fallback(query, max_timeout=2.0)
```
### 3. 主程序优化
**记忆构建任务异步化**
- 原来:`await self.hippocampus_manager.build_memory()` 阻塞主程序
- 现在:使用异步队列或线程池,后台执行
**消息处理优化**
- 原来:同步等待记忆检索完成
- 现在最大2秒超时保证用户体验
## 📊 性能提升预期
### 响应速度
- **用户消息响应**从原来的3-10秒减少到0.5-2秒
- **记忆检索**:缓存命中时几乎即时响应
- **记忆存储**:从同步阻塞改为后台处理
### 并发能力
- **多用户同时使用**:不再因记忆系统相互阻塞
- **高峰期稳定性**:记忆任务排队处理,不会崩溃
### 资源使用
- **CPU使用**异步处理更好的CPU利用率
- **内存优化**:缓存机制,减少重复计算
- **网络延迟**LLM调用并行化减少等待时间
## 🔧 部署和配置
### 1. 自动部署
新的异步系统已经集成到现有代码中,支持自动回退:
```python
# 优先级回退机制
1. 异步瞬时记忆包装器 (最优)
2. 异步记忆管理器 (次优)
3. 带超时的同步模式 (保底)
```
### 2. 配置参数
`config.toml` 中可以调整相关参数:
```toml
[memory]
enable_memory = true
enable_instant_memory = true
memory_build_interval = 600 # 记忆构建间隔(秒)
```
### 3. 监控和调试
```python
# 查看异步队列状态
from src.chat.memory_system.async_memory_optimizer import async_memory_manager
status = async_memory_manager.get_status()
print(status)
# 查看包装器状态
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
wrapper = get_async_instant_memory(chat_id)
status = wrapper.get_status()
print(status)
```
## 🧪 验证方法
### 1. 性能测试
```bash
# 测试用户消息响应时间
time curl -X POST "http://localhost:8080/api/message" -d '{"message": "你还记得我们昨天聊的内容吗?"}'
# 观察内存构建时的程序响应
# 构建期间发送消息,观察是否还有阻塞
```
### 2. 并发测试
```python
import asyncio
import time
async def test_concurrent_messages():
"""测试并发消息处理"""
tasks = []
for i in range(10):
task = asyncio.create_task(send_message(f"测试消息 {i}"))
tasks.append(task)
start_time = time.time()
results = await asyncio.gather(*tasks)
end_time = time.time()
print(f"10条并发消息处理完成耗时: {end_time - start_time:.2f}秒")
```
### 3. 日志监控
关注以下日志输出:
- `"异步瞬时记忆:"` - 确认使用了异步系统
- `"记忆构建任务已提交"` - 确认构建任务非阻塞
- `"瞬时记忆检索超时"` - 监控超时情况
## 🔄 回退机制
系统设计了多层回退机制,确保即使新系统出现问题,也能维持基本功能:
1. **异步包装器失败** → 使用异步队列管理器
2. **异步队列失败** → 使用带超时的同步模式
3. **超时保护** → 最长等待时间不超过2秒
4. **完全失败** → 跳过记忆功能,保证基本对话
## 📝 注意事项
1. **首次启动**:异步系统需要初始化时间,可能前几次记忆调用延迟稍高
2. **缓存预热**:系统运行一段时间后,缓存效果会显著提升响应速度
3. **内存使用**:缓存会增加内存使用,但相对于性能提升是值得的
4. **兼容性**:如果发现异步系统有问题,可以临时禁用相关导入,自动回退到原系统
## 🎉 预期效果
-**消息响应速度提升60%+**
-**记忆构建不再阻塞主程序**
-**支持更高的并发用户数**
-**系统整体稳定性提升**
-**保持原有记忆功能完整性**

View File

@@ -0,0 +1,237 @@
# -*- coding: utf-8 -*-
"""
增强记忆激活器
替代原有的 MemoryActivator使用增强记忆系统
"""
import difflib
import orjson
import time
from typing import List, Dict, Optional
from datetime import datetime
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager, EnhancedMemoryResult
logger = get_logger("enhanced_memory_activator")
def get_keywords_from_json(json_str) -> List:
"""
从JSON字符串中提取关键词列表
Args:
json_str: JSON格式的字符串
Returns:
List[str]: 关键词列表
"""
try:
# 使用repair_json修复JSON格式
fixed_json = repair_json(json_str)
# 如果repair_json返回的是字符串需要解析为Python对象
result = orjson.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
return result.get("keywords", [])
except Exception as e:
logger.error(f"解析关键词JSON失败: {e}")
return []
def init_prompt():
# --- Enhanced Memory Activator Prompt ---
enhanced_memory_activator_prompt = """
你是一个增强记忆分析器,你需要根据以下信息来进行记忆检索
以下是一段聊天记录,请根据这些信息,总结出几个关键词作为记忆检索的触发词
聊天记录:
{obs_info_text}
用户想要回复的消息:
{target_message}
历史关键词(请避免重复提取这些关键词):
{cached_keywords}
请输出一个json格式包含以下字段
{{
"keywords": ["关键词1", "关键词2", "关键词3",......]
}}
不要输出其他多余内容只输出json格式就好
"""
Prompt(enhanced_memory_activator_prompt, "enhanced_memory_activator_prompt")
class EnhancedMemoryActivator:
"""增强记忆激活器 - 替代原有的 MemoryActivator"""
def __init__(self):
self.key_words_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="enhanced_memory.activator",
)
self.running_memory = []
self.cached_keywords = set() # 用于缓存历史关键词
self.last_enhanced_query_time = 0 # 上次查询增强记忆的时间
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
"""
激活增强记忆
"""
# 如果记忆系统被禁用,直接返回空列表
if not global_config.memory.enable_memory:
return []
# 将缓存的关键词转换为字符串用于prompt
cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "暂无历史关键词"
prompt = await global_prompt_manager.format_prompt(
"enhanced_memory_activator_prompt",
obs_info_text=chat_history_prompt,
target_message=target_message,
cached_keywords=cached_keywords_str,
)
# 生成关键词
response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
prompt, temperature=0.5
)
keywords = list(get_keywords_from_json(response))
# 更新关键词缓存
if keywords:
# 限制缓存大小最多保留10个关键词
if len(self.cached_keywords) > 10:
# 转换为列表,移除最早的关键词
cached_list = list(self.cached_keywords)
self.cached_keywords = set(cached_list[-8:])
# 添加新的关键词到缓存
self.cached_keywords.update(keywords)
logger.debug(f"增强记忆关键词: {self.cached_keywords}")
# 使用增强记忆系统获取相关记忆
enhanced_results = await self._query_enhanced_memory(keywords, target_message)
# 处理和增强记忆结果
if enhanced_results:
for result in enhanced_results:
# 检查是否已存在相似内容的记忆
exists = any(
m["content"] == result.content or
difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
for m in self.running_memory
)
if not exists:
memory_entry = {
"topic": result.memory_type,
"content": result.content,
"timestamp": datetime.fromtimestamp(result.timestamp).isoformat(),
"duration": 1,
"confidence": result.confidence,
"importance": result.importance,
"source": result.source
}
self.running_memory.append(memory_entry)
logger.debug(f"添加新增强记忆: {result.memory_type} - {result.content}")
# 激活时所有已有记忆的duration+1达到3则移除
for m in self.running_memory[:]:
m["duration"] = m.get("duration", 1) + 1
self.running_memory = [m for m in self.running_memory if m["duration"] < 3]
# 限制同时加载的记忆条数最多保留最后5条增强记忆可以处理更多
if len(self.running_memory) > 5:
self.running_memory = self.running_memory[-5:]
return self.running_memory
async def _query_enhanced_memory(self, keywords: List[str], query_text: str) -> List[EnhancedMemoryResult]:
"""查询增强记忆系统"""
try:
# 确保增强记忆管理器已初始化
if not enhanced_memory_manager.is_initialized:
await enhanced_memory_manager.initialize()
# 构建查询上下文
context = {
"keywords": keywords,
"query_intent": "conversation_response",
"expected_memory_types": [
"personal_fact", "event", "preference", "opinion"
]
}
# 查询增强记忆
enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
query_text=query_text,
user_id="default_user", # 可以根据实际用户ID调整
context=context,
limit=5
)
logger.debug(f"增强记忆查询返回 {len(enhanced_results)} 条结果")
return enhanced_results
except Exception as e:
logger.error(f"查询增强记忆失败: {e}")
return []
async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]:
"""
获取即时记忆 - 兼容原有接口
"""
try:
# 使用增强记忆系统获取相关记忆
if not enhanced_memory_manager.is_initialized:
await enhanced_memory_manager.initialize()
context = {
"query_intent": "instant_response",
"chat_id": chat_id,
"expected_memory_types": ["preference", "opinion", "personal_fact"]
}
enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
query_text=target_message,
user_id="default_user",
context=context,
limit=1
)
if enhanced_results:
# 返回最相关的记忆内容
return enhanced_results[0].content
return None
except Exception as e:
logger.error(f"获取即时记忆失败: {e}")
return None
def clear_cache(self):
"""清除缓存"""
self.cached_keywords.clear()
self.running_memory.clear()
logger.debug("增强记忆激活器缓存已清除")
# 创建全局实例
enhanced_memory_activator = EnhancedMemoryActivator()
# 为了兼容性,保留原有名称
MemoryActivator = EnhancedMemoryActivator
init_prompt()

View File

@@ -0,0 +1,332 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统适配器
将增强记忆系统集成到现有MoFox Bot架构中
"""
import asyncio
import time
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from src.common.logger import get_logger
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@dataclass
class AdapterConfig:
"""适配器配置"""
enable_enhanced_memory: bool = True
integration_mode: str = "enhanced_only" # replace, enhanced_only
auto_migration: bool = True
memory_value_threshold: float = 0.6
fusion_threshold: float = 0.85
max_retrieval_results: int = 10
class EnhancedMemoryAdapter:
"""增强记忆系统适配器"""
def __init__(self, llm_model: LLMRequest, config: Optional[AdapterConfig] = None):
self.llm_model = llm_model
self.config = config or AdapterConfig()
self.integration_layer: Optional[MemoryIntegrationLayer] = None
self._initialized = False
# 统计信息
self.adapter_stats = {
"total_processed": 0,
"enhanced_used": 0,
"legacy_used": 0,
"hybrid_used": 0,
"memories_created": 0,
"memories_retrieved": 0,
"average_processing_time": 0.0
}
async def initialize(self):
"""初始化适配器"""
if self._initialized:
return
try:
logger.info("🚀 初始化增强记忆系统适配器...")
# 转换配置格式
integration_config = IntegrationConfig(
mode=IntegrationMode(self.config.integration_mode),
enable_enhanced_memory=self.config.enable_enhanced_memory,
memory_value_threshold=self.config.memory_value_threshold,
fusion_threshold=self.config.fusion_threshold,
max_retrieval_results=self.config.max_retrieval_results,
enable_learning=True # 启用学习功能
)
# 创建集成层
self.integration_layer = MemoryIntegrationLayer(
llm_model=self.llm_model,
config=integration_config
)
# 初始化集成层
await self.integration_layer.initialize()
self._initialized = True
logger.info("✅ 增强记忆系统适配器初始化完成")
except Exception as e:
logger.error(f"❌ 增强记忆系统适配器初始化失败: {e}", exc_info=True)
# 如果初始化失败,禁用增强记忆功能
self.config.enable_enhanced_memory = False
async def process_conversation_memory(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
) -> Dict[str, Any]:
"""处理对话记忆"""
if not self._initialized or not self.config.enable_enhanced_memory:
return {"success": False, "error": "Enhanced memory not available"}
start_time = time.time()
self.adapter_stats["total_processed"] += 1
try:
# 使用集成层处理对话
result = await self.integration_layer.process_conversation(
conversation_text, context, user_id, timestamp
)
# 更新统计
processing_time = time.time() - start_time
self._update_processing_stats(processing_time)
if result["success"]:
created_count = len(result.get("created_memories", []))
self.adapter_stats["memories_created"] += created_count
logger.debug(f"对话记忆处理完成,创建 {created_count} 条记忆")
return result
except Exception as e:
logger.error(f"处理对话记忆失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def retrieve_relevant_memories(
self,
query: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None
) -> List[MemoryChunk]:
"""检索相关记忆"""
if not self._initialized or not self.config.enable_enhanced_memory:
return []
try:
limit = limit or self.config.max_retrieval_results
memories = await self.integration_layer.retrieve_relevant_memories(
query, user_id, context, limit
)
self.adapter_stats["memories_retrieved"] += len(memories)
logger.debug(f"检索到 {len(memories)} 条相关记忆")
return memories
except Exception as e:
logger.error(f"检索相关记忆失败: {e}", exc_info=True)
return []
async def get_memory_context_for_prompt(
self,
query: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
max_memories: int = 5
) -> str:
"""获取用于提示词的记忆上下文"""
memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)
if not memories:
return ""
# 格式化记忆为提示词友好的格式
memory_context_parts = []
for memory in memories:
memory_context_parts.append(f"- {memory.text_content}")
return "\n".join(memory_context_parts)
async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
"""获取增强记忆系统摘要"""
if not self._initialized or not self.config.enable_enhanced_memory:
return {"available": False, "reason": "Not initialized or disabled"}
try:
# 获取系统状态
status = await self.integration_layer.get_system_status()
# 获取适配器统计
adapter_stats = self.adapter_stats.copy()
# 获取集成统计
integration_stats = self.integration_layer.get_integration_stats()
return {
"available": True,
"system_status": status,
"adapter_stats": adapter_stats,
"integration_stats": integration_stats,
"total_memories_created": adapter_stats["memories_created"],
"total_memories_retrieved": adapter_stats["memories_retrieved"]
}
except Exception as e:
logger.error(f"获取增强记忆摘要失败: {e}", exc_info=True)
return {"available": False, "error": str(e)}
def _update_processing_stats(self, processing_time: float):
"""更新处理统计"""
total_processed = self.adapter_stats["total_processed"]
if total_processed > 0:
current_avg = self.adapter_stats["average_processing_time"]
new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed
self.adapter_stats["average_processing_time"] = new_avg
def get_adapter_stats(self) -> Dict[str, Any]:
"""获取适配器统计信息"""
return self.adapter_stats.copy()
async def maintenance(self):
"""维护操作"""
if not self._initialized:
return
try:
logger.info("🔧 增强记忆系统适配器维护...")
await self.integration_layer.maintenance()
logger.info("✅ 增强记忆系统适配器维护完成")
except Exception as e:
logger.error(f"❌ 增强记忆系统适配器维护失败: {e}", exc_info=True)
async def shutdown(self):
"""关闭适配器"""
if not self._initialized:
return
try:
logger.info("🔄 关闭增强记忆系统适配器...")
await self.integration_layer.shutdown()
self._initialized = False
logger.info("✅ 增强记忆系统适配器已关闭")
except Exception as e:
logger.error(f"❌ 关闭增强记忆系统适配器失败: {e}", exc_info=True)
# 全局适配器实例
_enhanced_memory_adapter: Optional[EnhancedMemoryAdapter] = None
async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter:
"""获取全局增强记忆适配器实例"""
global _enhanced_memory_adapter
if _enhanced_memory_adapter is None:
# 从配置中获取适配器配置
from src.config.config import global_config
adapter_config = AdapterConfig(
enable_enhanced_memory=getattr(global_config.memory, 'enable_enhanced_memory', True),
integration_mode=getattr(global_config.memory, 'enhanced_memory_mode', 'enhanced_only'),
auto_migration=getattr(global_config.memory, 'enable_memory_migration', True),
memory_value_threshold=getattr(global_config.memory, 'memory_value_threshold', 0.6),
fusion_threshold=getattr(global_config.memory, 'fusion_threshold', 0.85),
max_retrieval_results=getattr(global_config.memory, 'max_retrieval_results', 10)
)
_enhanced_memory_adapter = EnhancedMemoryAdapter(llm_model, adapter_config)
await _enhanced_memory_adapter.initialize()
return _enhanced_memory_adapter
async def initialize_enhanced_memory_system(llm_model: LLMRequest):
"""初始化增强记忆系统"""
try:
logger.info("🚀 初始化增强记忆系统...")
adapter = await get_enhanced_memory_adapter(llm_model)
logger.info("✅ 增强记忆系统初始化完成")
return adapter
except Exception as e:
logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
return None
async def process_conversation_with_enhanced_memory(
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None,
llm_model: Optional[LLMRequest] = None
) -> Dict[str, Any]:
"""使用增强记忆系统处理对话"""
if not llm_model:
# 获取默认的LLM模型
from src.llm_models.utils_model import get_global_llm_model
llm_model = get_global_llm_model()
try:
adapter = await get_enhanced_memory_adapter(llm_model)
return await adapter.process_conversation_memory(conversation_text, context, user_id, timestamp)
except Exception as e:
logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def retrieve_memories_with_enhanced_system(
query: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
limit: int = 10,
llm_model: Optional[LLMRequest] = None
) -> List[MemoryChunk]:
"""使用增强记忆系统检索记忆"""
if not llm_model:
# 获取默认的LLM模型
from src.llm_models.utils_model import get_global_llm_model
llm_model = get_global_llm_model()
try:
adapter = await get_enhanced_memory_adapter(llm_model)
return await adapter.retrieve_relevant_memories(query, user_id, context, limit)
except Exception as e:
logger.error(f"使用增强记忆系统检索记忆失败: {e}", exc_info=True)
return []
async def get_memory_context_for_prompt(
query: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
max_memories: int = 5,
llm_model: Optional[LLMRequest] = None
) -> str:
"""获取用于提示词的记忆上下文"""
if not llm_model:
# 获取默认的LLM模型
from src.llm_models.utils_model import get_global_llm_model
llm_model = get_global_llm_model()
try:
adapter = await get_enhanced_memory_adapter(llm_model)
return await adapter.get_memory_context_for_prompt(query, user_id, context, max_memories)
except Exception as e:
logger.error(f"获取记忆上下文失败: {e}", exc_info=True)
return ""

View File

@@ -0,0 +1,753 @@
# -*- coding: utf-8 -*-
"""
增强型精准记忆系统核心模块
基于文档设计的高效记忆构建、存储与召回优化系统
"""
import asyncio
import time
import orjson
import re
from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
from enum import Enum
import numpy as np
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_builder import MemoryBuilder
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
from src.chat.memory_system.metadata_index import MetadataIndexManager
from src.chat.memory_system.multi_stage_retrieval import MultiStageRetrieval, RetrievalConfig
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger(__name__)
class MemorySystemStatus(Enum):
"""记忆系统状态"""
INITIALIZING = "initializing"
READY = "ready"
BUILDING = "building"
RETRIEVING = "retrieving"
ERROR = "error"
@dataclass
class MemorySystemConfig:
"""记忆系统配置"""
# 记忆构建配置
min_memory_length: int = 10
max_memory_length: int = 500
memory_value_threshold: float = 0.7
# 向量存储配置
vector_dimension: int = 768
similarity_threshold: float = 0.8
# 召回配置
coarse_recall_limit: int = 50
fine_recall_limit: int = 10
final_recall_limit: int = 5
# 融合配置
fusion_similarity_threshold: float = 0.85
deduplication_window: timedelta = timedelta(hours=24)
@classmethod
def from_global_config(cls):
"""从全局配置创建配置实例"""
from src.config.config import global_config
return cls(
# 记忆构建配置
min_memory_length=global_config.memory.min_memory_length,
max_memory_length=global_config.memory.max_memory_length,
memory_value_threshold=global_config.memory.memory_value_threshold,
# 向量存储配置
vector_dimension=global_config.memory.vector_dimension,
similarity_threshold=global_config.memory.vector_similarity_threshold,
# 召回配置
coarse_recall_limit=global_config.memory.metadata_filter_limit,
fine_recall_limit=global_config.memory.final_result_limit,
final_recall_limit=global_config.memory.final_result_limit,
# 融合配置
fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold,
deduplication_window=timedelta(hours=global_config.memory.deduplication_window_hours)
)
class EnhancedMemorySystem:
"""增强型精准记忆系统核心类"""
def __init__(
self,
llm_model: Optional[LLMRequest] = None,
config: Optional[MemorySystemConfig] = None
):
self.config = config or MemorySystemConfig.from_global_config()
self.llm_model = llm_model
self.status = MemorySystemStatus.INITIALIZING
# 核心组件
self.memory_builder: MemoryBuilder = None
self.fusion_engine: MemoryFusionEngine = None
self.vector_storage: VectorStorageManager = None
self.metadata_index: MetadataIndexManager = None
self.retrieval_system: MultiStageRetrieval = None
# LLM模型
self.value_assessment_model: LLMRequest = None
self.memory_extraction_model: LLMRequest = None
# 统计信息
self.total_memories = 0
self.last_build_time = None
self.last_retrieval_time = None
logger.info("EnhancedMemorySystem 初始化开始")
async def initialize(self):
"""异步初始化记忆系统"""
try:
logger.info("正在初始化增强型记忆系统...")
# 初始化LLM模型
task_config = (
self.llm_model.model_for_task
if self.llm_model is not None
else model_config.model_task_config.utils
)
self.value_assessment_model = LLMRequest(
model_set=task_config,
request_type="memory.value_assessment"
)
self.memory_extraction_model = LLMRequest(
model_set=task_config,
request_type="memory.extraction"
)
# 初始化核心组件
self.memory_builder = MemoryBuilder(self.memory_extraction_model)
self.fusion_engine = MemoryFusionEngine(self.config.fusion_similarity_threshold)
# 创建向量存储配置
vector_config = VectorStorageConfig(
dimension=self.config.vector_dimension,
similarity_threshold=self.config.similarity_threshold
)
self.vector_storage = VectorStorageManager(vector_config)
self.metadata_index = MetadataIndexManager()
# 创建检索配置
retrieval_config = RetrievalConfig(
metadata_filter_limit=self.config.coarse_recall_limit,
vector_search_limit=self.config.fine_recall_limit,
final_result_limit=self.config.final_recall_limit
)
self.retrieval_system = MultiStageRetrieval(retrieval_config)
# 加载持久化数据
await self.vector_storage.load_storage()
await self.metadata_index.load_index()
self.status = MemorySystemStatus.READY
logger.info("✅ 增强型记忆系统初始化完成")
except Exception as e:
self.status = MemorySystemStatus.ERROR
logger.error(f"❌ 记忆系统初始化失败: {e}", exc_info=True)
raise
async def build_memory_from_conversation(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
) -> List[MemoryChunk]:
"""从对话中构建记忆
Args:
conversation_text: 对话文本
context: 上下文信息(包括用户信息、群组信息等)
user_id: 用户ID
timestamp: 时间戳,默认为当前时间
Returns:
构建的记忆块列表
"""
if self.status != MemorySystemStatus.READY:
raise RuntimeError("记忆系统未就绪")
self.status = MemorySystemStatus.BUILDING
start_time = time.time()
try:
normalized_context = self._normalize_context(context, user_id, timestamp)
conversation_text = self._resolve_conversation_context(conversation_text, normalized_context)
logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
# 1. 信息价值评估
value_score = await self._assess_information_value(conversation_text, normalized_context)
if value_score < self.config.memory_value_threshold:
logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建")
self.status = MemorySystemStatus.READY
return []
# 2. 构建记忆块
memory_chunks = await self.memory_builder.build_memories(
conversation_text,
normalized_context,
user_id,
timestamp or time.time()
)
if not memory_chunks:
logger.debug("未提取到有效记忆块")
self.status = MemorySystemStatus.READY
return []
# 3. 记忆融合与去重
fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks)
# 4. 存储记忆
await self._store_memories(fused_chunks)
# 5. 更新统计
self.total_memories += len(fused_chunks)
self.last_build_time = time.time()
build_time = time.time() - start_time
logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}")
self.status = MemorySystemStatus.READY
return fused_chunks
except Exception as e:
self.status = MemorySystemStatus.ERROR
logger.error(f"❌ 记忆构建失败: {e}", exc_info=True)
raise
async def process_conversation_memory(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
) -> Dict[str, Any]:
"""对外暴露的对话记忆处理接口,兼容旧调用方式"""
start_time = time.time()
try:
normalized_context = self._normalize_context(context, user_id, timestamp)
memories = await self.build_memory_from_conversation(
conversation_text=conversation_text,
context=normalized_context,
user_id=user_id,
timestamp=timestamp
)
processing_time = time.time() - start_time
memory_count = len(memories)
return {
"success": True,
"created_memories": memories,
"memory_count": memory_count,
"processing_time": processing_time,
"status": self.status.value
}
except Exception as e:
processing_time = time.time() - start_time
logger.error(f"对话记忆处理失败: {e}", exc_info=True)
return {
"success": False,
"error": str(e),
"processing_time": processing_time,
"status": self.status.value
}
async def retrieve_relevant_memories(
self,
query_text: Optional[str] = None,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
limit: int = 5,
**kwargs
) -> List[MemoryChunk]:
"""检索相关记忆,兼容 query/query_text 参数形式"""
if self.status != MemorySystemStatus.READY:
raise RuntimeError("记忆系统未就绪")
query_text = query_text or kwargs.get("query")
if not query_text:
raise ValueError("query_text 或 query 参数不能为空")
context = context or {}
user_id = user_id or kwargs.get("user_id")
self.status = MemorySystemStatus.RETRIEVING
start_time = time.time()
try:
normalized_context = self._normalize_context(context, user_id, None)
candidate_memories = list(self.vector_storage.memory_cache.values())
if user_id:
candidate_memories = [m for m in candidate_memories if m.user_id == user_id]
if not candidate_memories:
self.status = MemorySystemStatus.READY
self.last_retrieval_time = time.time()
logger.debug(f"未找到用户 {user_id} 的候选记忆")
return []
scored_memories = []
for memory in candidate_memories:
score = self._compute_memory_score(query_text, memory, normalized_context)
if score > 0:
scored_memories.append((memory, score))
if not scored_memories:
# 如果所有分数为0返回最近的记忆作为降级策略
candidate_memories.sort(key=lambda m: m.metadata.last_accessed, reverse=True)
scored_memories = [(memory, 0.0) for memory in candidate_memories[:limit]]
else:
scored_memories.sort(key=lambda item: item[1], reverse=True)
top_memories = [memory for memory, _ in scored_memories[:limit]]
# 更新访问信息和缓存
for memory, score in scored_memories[:limit]:
memory.update_access()
memory.update_relevance(score)
cache_entry = self.metadata_index.memory_metadata_cache.get(memory.memory_id)
if cache_entry is not None:
cache_entry["last_accessed"] = memory.metadata.last_accessed
cache_entry["access_count"] = memory.metadata.access_count
cache_entry["relevance_score"] = memory.metadata.relevance_score
retrieval_time = time.time() - start_time
logger.info(
f"✅ 为用户 {user_id or 'unknown'} 检索到 {len(top_memories)} 条相关记忆,耗时 {retrieval_time:.3f}"
)
self.last_retrieval_time = time.time()
self.status = MemorySystemStatus.READY
return top_memories
except Exception as e:
self.status = MemorySystemStatus.ERROR
logger.error(f"❌ 记忆检索失败: {e}", exc_info=True)
raise
@staticmethod
def _extract_json_payload(response: str) -> Optional[str]:
"""从模型响应中提取JSON部分兼容Markdown代码块等格式"""
if not response:
return None
stripped = response.strip()
# 优先处理Markdown代码块格式 ```json ... ```
code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
if code_block_match:
candidate = code_block_match.group(1).strip()
if candidate:
return candidate
# 回退到查找第一个 JSON 对象的大括号范围
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start:end + 1].strip()
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
def _normalize_context(
self,
raw_context: Optional[Dict[str, Any]],
user_id: Optional[str],
timestamp: Optional[float]
) -> Dict[str, Any]:
"""标准化上下文,确保必备字段存在且格式正确"""
context: Dict[str, Any] = {}
if raw_context:
try:
context = dict(raw_context)
except Exception:
context = dict(raw_context or {})
# 基础字段
context["user_id"] = context.get("user_id") or user_id or "unknown"
context["timestamp"] = context.get("timestamp") or timestamp or time.time()
context["message_type"] = context.get("message_type") or "normal"
context["platform"] = context.get("platform") or context.get("source_platform") or "unknown"
# 标准化关键词类型
keywords = context.get("keywords")
if keywords is None:
context["keywords"] = []
elif isinstance(keywords, tuple):
context["keywords"] = list(keywords)
elif not isinstance(keywords, list):
context["keywords"] = [str(keywords)] if keywords else []
# 统一 stream_id
stream_id = context.get("stream_id") or context.get("stram_id")
if not stream_id:
potential = context.get("chat_id") or context.get("session_id")
if isinstance(potential, str) and potential:
stream_id = potential
if stream_id:
context["stream_id"] = stream_id
# chat_id 兜底
context["chat_id"] = context.get("chat_id") or context.get("stream_id") or f"session_{context['user_id']}"
# 历史窗口配置
window_candidate = (
context.get("history_limit")
or context.get("history_window")
or context.get("memory_history_limit")
)
if window_candidate is not None:
try:
context["history_limit"] = int(window_candidate)
except (TypeError, ValueError):
context.pop("history_limit", None)
return context
def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
"""使用 stream_id 历史消息充实对话文本,默认回退到传入文本"""
if not context:
return fallback_text
stream_id = context.get("stream_id") or context.get("stram_id")
if not stream_id:
return fallback_text
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream or not hasattr(chat_stream, "context_manager"):
logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
return fallback_text
history_limit = self._determine_history_limit(context)
messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
if not messages:
logger.debug(f"stream_id={stream_id} 未获取到历史消息")
return fallback_text
transcript = self._format_history_messages(messages)
if not transcript:
return fallback_text
cleaned_fallback = (fallback_text or "").strip()
if cleaned_fallback and cleaned_fallback not in transcript:
transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"
logger.debug(
"使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
stream_id,
len(messages),
history_limit,
)
return transcript
except Exception as exc:
logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
return fallback_text
def _determine_history_limit(self, context: Dict[str, Any]) -> int:
"""确定历史消息获取数量限制在30-50之间"""
default_limit = 40
candidate = (
context.get("history_limit")
or context.get("history_window")
or context.get("memory_history_limit")
)
if isinstance(candidate, str):
try:
candidate = int(candidate)
except ValueError:
candidate = None
if isinstance(candidate, int):
history_limit = max(30, min(50, candidate))
else:
history_limit = default_limit
return history_limit
def _format_history_messages(self, messages: List["DatabaseMessages"]) -> Optional[str]:
"""将历史消息格式化为可供LLM处理的多轮对话文本"""
if not messages:
return None
lines: List[str] = []
for msg in messages:
try:
content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None)
if not content:
continue
content = re.sub(r"\s+", " ", str(content).strip())
if not content:
continue
speaker = None
if hasattr(msg, "user_info") and msg.user_info:
speaker = (
getattr(msg.user_info, "user_nickname", None)
or getattr(msg.user_info, "user_cardname", None)
or getattr(msg.user_info, "user_id", None)
)
speaker = speaker or getattr(msg, "user_nickname", None) or getattr(msg, "user_id", None) or "用户"
timestamp_value = getattr(msg, "time", None) or 0.0
try:
timestamp_dt = datetime.fromtimestamp(float(timestamp_value)) if timestamp_value else datetime.now()
except (TypeError, ValueError, OSError):
timestamp_dt = datetime.now()
timestamp_str = timestamp_dt.strftime("%Y-%m-%d %H:%M:%S")
lines.append(f"[{timestamp_str}] {speaker}: {content}")
except Exception as message_exc:
logger.debug(f"格式化历史消息失败: {message_exc}")
continue
return "\n".join(lines) if lines else None
async def _assess_information_value(self, text: str, context: Dict[str, Any]) -> float:
"""评估信息价值
Args:
text: 文本内容
context: 上下文信息
Returns:
价值评分 (0.0-1.0)
"""
try:
# 构建评估提示
prompt = f"""
请评估以下对话内容的信息价值,重点识别包含个人事实、事件、偏好、观点等重要信息的内容。
## 🎯 价值评估重点标准:
### 高价值信息 (0.7-1.0分)
1. **个人事实** (personal_fact):包含姓名、年龄、职业、联系方式、住址、健康状况、家庭情况等个人信息
2. **重要事件** (event):约会、会议、旅行、考试、面试、搬家等重要活动或经历
3. **明确偏好** (preference):表达喜欢/不喜欢的食物、电影、音乐、品牌、生活习惯等偏好信息
4. **观点态度** (opinion):对事物的评价、看法、建议、态度等主观观点
5. **核心关系** (relationship):重要的朋友、家人、同事等人际关系信息
### 中等价值信息 (0.4-0.7分)
1. **情感表达**:当前情绪状态、心情变化
2. **日常活动**:常规的工作、学习、生活安排
3. **一般兴趣**:兴趣爱好、休闲活动
4. **短期计划**:即将进行的安排和计划
### 低价值信息 (0.0-0.4分)
1. **寒暄问候**:简单的打招呼、礼貌用语
2. **重复信息**:已经多次提到的相同内容
3. **临时状态**:短暂的情绪波动、临时想法
4. **无关内容**:与用户画像建立无关的信息
对话内容:
{text}
上下文信息:
- 用户ID: {context.get('user_id', 'unknown')}
- 消息类型: {context.get('message_type', 'unknown')}
- 时间: {datetime.fromtimestamp(context.get('timestamp', time.time()))}
## 📋 评估要求:
### 积极识别原则:
- **宁可高估,不可低估** - 对于可能的个人信息给予较高评估
- **重点关注** - 特别注意包含 personal_fact、event、preference、opinion 的内容
- **细节丰富** - 具体的细节信息比笼统的描述更有价值
- **建立画像** - 有助于建立完整用户画像的信息更有价值
### 评分指导:
- **0.9-1.0**:核心个人信息(姓名、联系方式、重要偏好)
- **0.7-0.8**:重要的个人事实、观点、事件经历
- **0.5-0.6**:一般性偏好、日常活动、情感表达
- **0.3-0.4**:简单的兴趣表达、临时状态
- **0.0-0.2**:寒暄问候、重复内容、无关信息
请以JSON格式输出评估结果
{{
"value_score": 0.0到1.0之间的数值,
"reasoning": "评估理由,包含具体识别到的信息类型",
"key_factors": ["关键因素1", "关键因素2"],
"detected_types": ["personal_fact", "preference", "opinion", "event", "relationship", "emotion", "goal"]
}}
"""
response, _ = await self.value_assessment_model.generate_response_async(
prompt, temperature=0.3
)
# 解析响应
try:
payload = self._extract_json_payload(response)
if not payload:
raise ValueError("未在响应中找到有效的JSON负载")
result = orjson.loads(payload)
value_score = float(result.get("value_score", 0.0))
reasoning = result.get("reasoning", "")
key_factors = result.get("key_factors", [])
logger.info(f"信息价值评估: {value_score:.2f}, 理由: {reasoning}")
if key_factors:
logger.info(f"关键因素: {', '.join(key_factors)}")
return max(0.0, min(1.0, value_score))
except (orjson.JSONDecodeError, ValueError) as e:
preview = response[:200].replace('\n', ' ')
logger.warning(f"解析价值评估响应失败: {e}, 响应片段: {preview}")
return 0.5 # 默认中等价值
except Exception as e:
logger.error(f"信息价值评估失败: {e}", exc_info=True)
return 0.5 # 默认中等价值
async def _store_memories(self, memory_chunks: List[MemoryChunk]):
"""存储记忆块到各个存储系统"""
if not memory_chunks:
return
# 并行存储到向量数据库和元数据索引
storage_tasks = []
# 向量存储
storage_tasks.append(self.vector_storage.store_memories(memory_chunks))
# 元数据索引
storage_tasks.append(self.metadata_index.index_memories(memory_chunks))
# 等待所有存储任务完成
await asyncio.gather(*storage_tasks, return_exceptions=True)
logger.debug(f"成功存储 {len(memory_chunks)} 条记忆到各个存储系统")
def get_system_stats(self) -> Dict[str, Any]:
"""获取系统统计信息"""
return {
"status": self.status.value,
"total_memories": self.total_memories,
"last_build_time": self.last_build_time,
"last_retrieval_time": self.last_retrieval_time,
"config": asdict(self.config)
}
def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
"""根据查询和上下文为记忆计算匹配分数"""
tokens_query = self._tokenize_text(query_text)
tokens_memory = self._tokenize_text(memory.text_content)
if tokens_query and tokens_memory:
base_score = len(tokens_query & tokens_memory) / len(tokens_query | tokens_memory)
else:
base_score = 0.0
context_keywords = context.get("keywords") or []
keyword_overlap = 0.0
if context_keywords:
memory_keywords = set(k.lower() for k in memory.keywords)
keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max(len(context_keywords), 1)
importance_boost = (memory.metadata.importance.value - 1) / 3 * 0.1
confidence_boost = (memory.metadata.confidence.value - 1) / 3 * 0.05
final_score = base_score * 0.7 + keyword_overlap * 0.15 + importance_boost + confidence_boost
return max(0.0, min(1.0, final_score))
def _tokenize_text(self, text: str) -> Set[str]:
"""简单分词,兼容中英文"""
if not text:
return set()
tokens = re.findall(r"[\w\u4e00-\u9fa5]+", text.lower())
return {token for token in tokens if len(token) > 1}
async def maintenance(self):
"""系统维护操作"""
try:
logger.info("开始记忆系统维护...")
# 向量存储优化
await self.vector_storage.optimize_storage()
# 元数据索引优化
await self.metadata_index.optimize_index()
# 记忆融合引擎维护
await self.fusion_engine.maintenance()
logger.info("✅ 记忆系统维护完成")
except Exception as e:
logger.error(f"❌ 记忆系统维护失败: {e}", exc_info=True)
async def shutdown(self):
"""关闭系统"""
try:
logger.info("正在关闭增强型记忆系统...")
# 保存持久化数据
await self.vector_storage.save_storage()
await self.metadata_index.save_index()
logger.info("✅ 增强型记忆系统已关闭")
except Exception as e:
logger.error(f"❌ 记忆系统关闭失败: {e}", exc_info=True)
# 全局记忆系统实例
enhanced_memory_system: EnhancedMemorySystem = None
def get_enhanced_memory_system() -> EnhancedMemorySystem:
"""获取全局记忆系统实例"""
global enhanced_memory_system
if enhanced_memory_system is None:
enhanced_memory_system = EnhancedMemorySystem()
return enhanced_memory_system
async def initialize_enhanced_memory_system():
"""初始化全局记忆系统"""
global enhanced_memory_system
if enhanced_memory_system is None:
enhanced_memory_system = EnhancedMemorySystem()
await enhanced_memory_system.initialize()
return enhanced_memory_system

View File

@@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统钩子
用于在消息处理过程中自动构建和检索记忆
"""
import asyncio
from typing import Dict, List, Any, Optional
from datetime import datetime
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
logger = get_logger(__name__)
class EnhancedMemoryHooks:
"""增强记忆系统钩子 - 自动处理消息的记忆构建和检索"""
def __init__(self):
self.enabled = (global_config.memory.enable_memory and
global_config.memory.enable_enhanced_memory)
self.processed_messages = set() # 避免重复处理
async def process_message_for_memory(
self,
message_content: str,
user_id: str,
chat_id: str,
message_id: str,
context: Optional[Dict[str, Any]] = None
) -> bool:
"""
处理消息并构建记忆
Args:
message_content: 消息内容
user_id: 用户ID
chat_id: 聊天ID
message_id: 消息ID
context: 上下文信息
Returns:
bool: 是否成功处理
"""
if not self.enabled:
return False
if message_id in self.processed_messages:
return False
try:
# 确保增强记忆管理器已初始化
if not enhanced_memory_manager.is_initialized:
await enhanced_memory_manager.initialize()
# 构建上下文
memory_context = {
"chat_id": chat_id,
"message_id": message_id,
"timestamp": datetime.now().timestamp(),
"message_type": "user_message",
**(context or {})
}
# 处理对话并构建记忆
memory_chunks = await enhanced_memory_manager.process_conversation(
conversation_text=message_content,
context=memory_context,
user_id=user_id,
timestamp=memory_context["timestamp"]
)
# 标记消息已处理
self.processed_messages.add(message_id)
# 限制处理历史大小
if len(self.processed_messages) > 1000:
# 移除最旧的500个记录
self.processed_messages = set(list(self.processed_messages)[-500:])
logger.debug(f"为消息 {message_id} 构建了 {len(memory_chunks)} 条记忆")
return len(memory_chunks) > 0
except Exception as e:
logger.error(f"处理消息记忆失败: {e}")
return False
async def get_memory_for_response(
self,
query_text: str,
user_id: str,
chat_id: str,
limit: int = 5
) -> List[Dict[str, Any]]:
"""
为回复获取相关记忆
Args:
query_text: 查询文本
user_id: 用户ID
chat_id: 聊天ID
limit: 返回记忆数量限制
Returns:
List[Dict]: 相关记忆列表
"""
if not self.enabled:
return []
try:
# 确保增强记忆管理器已初始化
if not enhanced_memory_manager.is_initialized:
await enhanced_memory_manager.initialize()
# 构建查询上下文
context = {
"chat_id": chat_id,
"query_intent": "response_generation",
"expected_memory_types": [
"personal_fact", "event", "preference", "opinion"
]
}
# 获取相关记忆
enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
query_text=query_text,
user_id=user_id,
context=context,
limit=limit
)
# 转换为字典格式
results = []
for result in enhanced_results:
memory_dict = {
"content": result.content,
"type": result.memory_type,
"confidence": result.confidence,
"importance": result.importance,
"timestamp": result.timestamp,
"source": result.source
}
results.append(memory_dict)
logger.debug(f"为回复查询到 {len(results)} 条相关记忆")
return results
except Exception as e:
logger.error(f"获取回复记忆失败: {e}")
return []
async def cleanup_old_memories(self):
"""清理旧记忆"""
try:
if enhanced_memory_manager.is_initialized:
# 调用增强记忆系统的维护功能
await enhanced_memory_manager.enhanced_system.maintenance()
logger.debug("增强记忆系统维护完成")
except Exception as e:
logger.error(f"清理旧记忆失败: {e}")
def clear_processed_cache(self):
"""清除已处理消息的缓存"""
self.processed_messages.clear()
logger.debug("已清除消息处理缓存")
def enable(self):
"""启用记忆钩子"""
self.enabled = True
logger.info("增强记忆钩子已启用")
def disable(self):
"""禁用记忆钩子"""
self.enabled = False
logger.info("增强记忆钩子已禁用")
# 创建全局实例
enhanced_memory_hooks = EnhancedMemoryHooks()

View File

@@ -0,0 +1,206 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统集成脚本
用于在现有系统中无缝集成增强记忆功能
"""
import asyncio
from typing import Dict, Any, Optional
from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
logger = get_logger(__name__)
async def process_user_message_memory(
message_content: str,
user_id: str,
chat_id: str,
message_id: str,
context: Optional[Dict[str, Any]] = None
) -> bool:
"""
处理用户消息并构建记忆
Args:
message_content: 消息内容
user_id: 用户ID
chat_id: 聊天ID
message_id: 消息ID
context: 额外的上下文信息
Returns:
bool: 是否成功构建记忆
"""
try:
success = await enhanced_memory_hooks.process_message_for_memory(
message_content=message_content,
user_id=user_id,
chat_id=chat_id,
message_id=message_id,
context=context
)
if success:
logger.debug(f"成功为消息 {message_id} 构建记忆")
return success
except Exception as e:
logger.error(f"处理用户消息记忆失败: {e}")
return False
async def get_relevant_memories_for_response(
query_text: str,
user_id: str,
chat_id: str,
limit: int = 5
) -> Dict[str, Any]:
"""
为回复获取相关记忆
Args:
query_text: 查询文本(通常是用户的当前消息)
user_id: 用户ID
chat_id: 聊天ID
limit: 返回记忆数量限制
Returns:
Dict: 包含记忆信息的字典
"""
try:
memories = await enhanced_memory_hooks.get_memory_for_response(
query_text=query_text,
user_id=user_id,
chat_id=chat_id,
limit=limit
)
result = {
"has_memories": len(memories) > 0,
"memories": memories,
"memory_count": len(memories)
}
logger.debug(f"为回复获取到 {len(memories)} 条相关记忆")
return result
except Exception as e:
logger.error(f"获取回复记忆失败: {e}")
return {
"has_memories": False,
"memories": [],
"memory_count": 0
}
def format_memories_for_prompt(memories: Dict[str, Any]) -> str:
"""
格式化记忆信息用于Prompt
Args:
memories: 记忆信息字典
Returns:
str: 格式化后的记忆文本
"""
if not memories["has_memories"]:
return ""
memory_lines = ["以下是相关的记忆信息:"]
for memory in memories["memories"]:
content = memory["content"]
memory_type = memory["type"]
confidence = memory["confidence"]
importance = memory["importance"]
# 根据重要性添加不同的标记
importance_marker = "🔥" if importance >= 3 else "" if importance >= 2 else "📝"
confidence_marker = "" if confidence >= 3 else "⚠️" if confidence >= 2 else "💭"
memory_line = f"{importance_marker} {content} ({memory_type}, {confidence_marker}置信度)"
memory_lines.append(memory_line)
return "\n".join(memory_lines)
async def cleanup_memory_system():
"""清理记忆系统"""
try:
await enhanced_memory_hooks.cleanup_old_memories()
logger.info("记忆系统清理完成")
except Exception as e:
logger.error(f"记忆系统清理失败: {e}")
def get_memory_system_status() -> Dict[str, Any]:
"""
获取记忆系统状态
Returns:
Dict: 系统状态信息
"""
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
return {
"enabled": enhanced_memory_hooks.enabled,
"enhanced_system_initialized": enhanced_memory_manager.is_initialized,
"processed_messages_count": len(enhanced_memory_hooks.processed_messages),
"system_type": "enhanced_memory_system"
}
# 便捷函数
async def remember_message(
message: str,
user_id: str = "default_user",
chat_id: str = "default_chat"
) -> bool:
"""
便捷的记忆构建函数
Args:
message: 要记住的消息
user_id: 用户ID
chat_id: 聊天ID
Returns:
bool: 是否成功
"""
import uuid
message_id = str(uuid.uuid4())
return await process_user_message_memory(
message_content=message,
user_id=user_id,
chat_id=chat_id,
message_id=message_id
)
async def recall_memories(
query: str,
user_id: str = "default_user",
chat_id: str = "default_chat",
limit: int = 5
) -> Dict[str, Any]:
"""
便捷的记忆检索函数
Args:
query: 查询文本
user_id: 用户ID
chat_id: 聊天ID
limit: 返回数量限制
Returns:
Dict: 记忆信息
"""
return await get_relevant_memories_for_response(
query_text=query,
user_id=user_id,
chat_id=chat_id,
limit=limit
)

View File

@@ -0,0 +1,305 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统管理器
替代原有的 Hippocampus 和 instant_memory 系统
"""
import asyncio
import time
from typing import Dict, List, Optional, Any, Tuple
from datetime import datetime
from dataclasses import dataclass
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.enhanced_memory_adapter import (
initialize_enhanced_memory_system
)
logger = get_logger(__name__)
@dataclass
class EnhancedMemoryResult:
"""增强记忆查询结果"""
content: str
memory_type: str
confidence: float
importance: float
timestamp: float
source: str = "enhanced_memory"
relevance_score: float = 0.0
class EnhancedMemoryManager:
"""增强记忆系统管理器 - 替代原有的 HippocampusManager"""
def __init__(self):
self.enhanced_system: Optional[EnhancedMemorySystem] = None
self.is_initialized = False
self.user_cache = {} # 用户记忆缓存
async def initialize(self):
"""初始化增强记忆系统"""
if self.is_initialized:
return
try:
from src.config.config import global_config
# 检查是否启用增强记忆系统
if not global_config.memory.enable_enhanced_memory:
logger.info("增强记忆系统已禁用,跳过初始化")
self.is_initialized = True
return
logger.info("正在初始化增强记忆系统...")
# 获取LLM模型
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory")
# 初始化增强记忆系统
self.enhanced_system = await initialize_enhanced_memory_system(llm_model)
# 设置全局实例
global_enhanced_manager = self.enhanced_system
self.is_initialized = True
logger.info("✅ 增强记忆系统初始化完成")
except Exception as e:
logger.error(f"❌ 增强记忆系统初始化失败: {e}")
# 如果增强系统初始化失败,创建一个空的管理器避免系统崩溃
self.enhanced_system = None
self.is_initialized = True # 标记为已初始化但系统不可用
def get_hippocampus(self):
"""兼容原有接口 - 返回空"""
logger.debug("get_hippocampus 调用 - 增强记忆系统不使用此方法")
return {}
async def build_memory(self):
"""兼容原有接口 - 构建记忆"""
if not self.is_initialized or not self.enhanced_system:
return
try:
# 增强记忆系统使用实时构建,不需要定时构建
logger.debug("build_memory 调用 - 增强记忆系统使用实时构建")
except Exception as e:
logger.error(f"build_memory 失败: {e}")
async def forget_memory(self, percentage: float = 0.005):
"""兼容原有接口 - 遗忘机制"""
if not self.is_initialized or not self.enhanced_system:
return
try:
# 增强记忆系统有内置的遗忘机制
logger.debug(f"forget_memory 调用 - 参数: {percentage}")
# 可以在这里调用增强系统的维护功能
await self.enhanced_system.maintenance()
except Exception as e:
logger.error(f"forget_memory 失败: {e}")
async def consolidate_memory(self):
"""兼容原有接口 - 记忆巩固"""
if not self.is_initialized or not self.enhanced_system:
return
try:
# 增强记忆系统自动处理记忆巩固
logger.debug("consolidate_memory 调用 - 增强记忆系统自动处理")
except Exception as e:
logger.error(f"consolidate_memory 失败: {e}")
async def get_memory_from_text(
self,
text: str,
chat_id: str,
user_id: str,
max_memory_num: int = 3,
max_memory_length: int = 2,
time_weight: float = 1.0,
keyword_weight: float = 1.0
) -> List[Tuple[str, str]]:
"""从文本获取相关记忆 - 兼容原有接口"""
if not self.is_initialized or not self.enhanced_system:
return []
try:
# 使用增强记忆系统检索
context = {
"chat_id": chat_id,
"expected_memory_types": [MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE]
}
relevant_memories = await self.enhanced_system.retrieve_relevant_memories(
query=text,
user_id=user_id,
context=context,
limit=max_memory_num
)
# 转换为原有格式 (topic, content)
results = []
for memory in relevant_memories:
topic = memory.memory_type.value
content = memory.text_content
results.append((topic, content))
logger.debug(f"从文本检索到 {len(results)} 条相关记忆")
return results
except Exception as e:
logger.error(f"get_memory_from_text 失败: {e}")
return []
async def get_memory_from_topic(
self,
valid_keywords: List[str],
max_memory_num: int = 3,
max_memory_length: int = 2,
max_depth: int = 3
) -> List[Tuple[str, str]]:
"""从关键词获取记忆 - 兼容原有接口"""
if not self.is_initialized or not self.enhanced_system:
return []
try:
# 将关键词转换为查询文本
query_text = " ".join(valid_keywords)
# 使用增强记忆系统检索
context = {
"keywords": valid_keywords,
"expected_memory_types": [
MemoryType.PERSONAL_FACT,
MemoryType.EVENT,
MemoryType.PREFERENCE,
MemoryType.OPINION
]
}
relevant_memories = await self.enhanced_system.retrieve_relevant_memories(
query_text=query_text,
user_id="default_user", # 可以根据实际需要传递
context=context,
limit=max_memory_num
)
# 转换为原有格式 (topic, content)
results = []
for memory in relevant_memories:
topic = memory.memory_type.value
content = memory.text_content
results.append((topic, content))
logger.debug(f"从关键词 {valid_keywords} 检索到 {len(results)} 条相关记忆")
return results
except Exception as e:
logger.error(f"get_memory_from_topic 失败: {e}")
return []
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
"""从单个关键词获取记忆 - 兼容原有接口"""
if not self.is_initialized or not self.enhanced_system:
return []
try:
# 同步方法,返回空列表
logger.debug(f"get_memory_from_keyword 调用 - 关键词: {keyword}")
return []
except Exception as e:
logger.error(f"get_memory_from_keyword 失败: {e}")
return []
async def process_conversation(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
) -> List[MemoryChunk]:
"""处理对话并构建记忆 - 新增功能"""
if not self.is_initialized or not self.enhanced_system:
return []
try:
result = await self.enhanced_system.process_conversation_memory(
conversation_text=conversation_text,
context=context,
user_id=user_id,
timestamp=timestamp
)
# 从结果中提取记忆块
memory_chunks = []
if result.get("success"):
memory_chunks = result.get("created_memories", [])
logger.info(f"从对话构建了 {len(memory_chunks)} 条记忆")
return memory_chunks
except Exception as e:
logger.error(f"process_conversation 失败: {e}")
return []
async def get_enhanced_memory_context(
self,
query_text: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
limit: int = 5
) -> List[EnhancedMemoryResult]:
"""获取增强记忆上下文 - 新增功能"""
if not self.is_initialized or not self.enhanced_system:
return []
try:
relevant_memories = await self.enhanced_system.retrieve_relevant_memories(
query=query_text,
user_id=user_id,
context=context or {},
limit=limit
)
results = []
for memory in relevant_memories:
result = EnhancedMemoryResult(
content=memory.text_content,
memory_type=memory.memory_type.value,
confidence=memory.metadata.confidence.value,
importance=memory.metadata.importance.value,
timestamp=memory.metadata.created_at,
source="enhanced_memory",
relevance_score=memory.metadata.relevance_score
)
results.append(result)
return results
except Exception as e:
logger.error(f"get_enhanced_memory_context 失败: {e}")
return []
async def shutdown(self):
"""关闭增强记忆系统"""
if not self.is_initialized:
return
try:
if self.enhanced_system:
await self.enhanced_system.shutdown()
logger.info("✅ 增强记忆系统已关闭")
except Exception as e:
logger.error(f"关闭增强记忆系统失败: {e}")
# 全局增强记忆管理器实例
enhanced_memory_manager = EnhancedMemoryManager()

View File

@@ -1,168 +0,0 @@
# 混合瞬时记忆系统设计
## 系统概述
融合 `instant_memory.py`LLM系统`vector_instant_memory.py`(向量系统)的混合记忆系统,智能选择最优策略,无需配置文件控制。
## 融合架构
```
聊天输入 → 智能调度器 → 选择策略 → 双重存储 → 融合检索 → 统一输出
```
## 核心组件设计
### 1. HybridInstantMemory (主类)
**职责**: 统一接口,智能调度两套记忆系统
**关键方法**:
- `__init__(chat_id)` - 初始化两套子系统
- `create_and_store_memory(text)` - 智能存储记忆
- `get_memory(target)` - 融合检索记忆
- `get_stats()` - 统计信息
### 2. MemoryStrategy (策略判断器)
**职责**: 判断使用哪种记忆策略
**判断规则**:
- 文本长度 < 30字符 优先向量系统快速
- 包含情感词汇/重要信息 使用LLM系统准确
- 复杂场景 双重验证
**实现方法**:
```python
def decide_strategy(self, text: str) -> MemoryMode:
# 长度判断
if len(text) < 30:
return MemoryMode.VECTOR_ONLY
# 情感关键词检测
if self._contains_emotional_content(text):
return MemoryMode.LLM_PREFERRED
# 默认混合模式
return MemoryMode.HYBRID
```
### 3. MemorySync (同步器)
**职责**: 处理两套系统间的记忆同步和去重
**同步策略**:
- 向量系统存储的记忆 异步同步到LLM系统
- LLM系统生成的高质量记忆 生成向量存储
- 定期去重避免重复记忆
### 4. HybridRetriever (检索器)
**职责**: 融合两种检索方式提供最优结果
**检索策略**:
1. 并行查询向量系统和LLM系统
2. 按相似度/相关性排序
3. 去重合并返回最相关的记忆
## 智能调度逻辑
### 快速路径 (Vector Path)
- 适用: 短文本常规对话快速查询
- 优势: 响应速度快资源消耗低
- 时机: 文本简单无特殊情感内容
### 准确路径 (LLM Path)
- 适用: 重要信息情感表达复杂语义
- 优势: 语义理解深度记忆质量高
- 时机: 检测到重要性标志
### 混合路径 (Hybrid Path)
- 适用: 中等复杂度内容
- 策略: 向量快速筛选 + LLM精确处理
- 平衡: 速度与准确性
## 记忆存储策略
### 双重备份机制
1. **主存储**: 根据策略选择主要存储方式
2. **备份存储**: 异步备份到另一系统
3. **同步检查**: 定期校验两边数据一致性
### 存储优化
- 向量系统: 立即存储快速可用
- LLM系统: 批量处理高质量整理
- 重复检测: 跨系统去重
## 检索融合策略
### 并行检索
```python
async def get_memory(self, target: str):
# 并行查询两个系统
vector_task = self.vector_memory.get_memory(target)
llm_task = self.llm_memory.get_memory(target)
vector_results, llm_results = await asyncio.gather(
vector_task, llm_task, return_exceptions=True
)
# 融合结果
return self._merge_results(vector_results, llm_results)
```
### 结果融合
1. **相似度评分**: 统一两种系统的相似度计算
2. **权重调整**: 根据查询类型调整系统权重
3. **去重合并**: 移除重复内容保留最相关的
## 性能优化
### 异步处理
- 向量检索: 同步快速响应
- LLM处理: 异步后台处理
- 批量操作: 减少系统调用开销
### 缓存策略
- 热点记忆缓存
- 查询结果缓存
- 向量计算缓存
### 降级机制
- 向量系统故障 只使用LLM系统
- LLM系统故障 只使用向量系统
- 全部故障 返回空结果记录错误
## 实现计划
1. **基础框架**: 创建HybridInstantMemory主类
2. **策略判断**: 实现智能调度逻辑
3. **存储融合**: 实现双重存储机制
4. **检索融合**: 实现并行检索和结果合并
5. **同步机制**: 实现跨系统数据同步
6. **性能优化**: 异步处理和缓存优化
7. **错误处理**: 降级机制和异常处理
## 使用接口
```python
# 初始化混合记忆系统
hybrid_memory = HybridInstantMemory(chat_id="user_123")
# 智能存储记忆
await hybrid_memory.create_and_store_memory("今天天气真好,我去公园散步了")
# 融合检索记忆
memories = await hybrid_memory.get_memory("天气")
# 获取系统状态
stats = hybrid_memory.get_stats()
print(f"向量记忆: {stats['vector_count']} 条")
print(f"LLM记忆: {stats['llm_count']} 条")
```
## 预期效果
- **响应速度**: 比纯LLM系统快60%+
- **记忆质量**: 比纯向量系统准确30%+
- **资源使用**: 智能调度按需使用资源
- **可靠性**: 双系统备份单点故障不影响服务

View File

@@ -1,254 +0,0 @@
# -*- coding: utf-8 -*-
import time
import re
import orjson
import traceback
from json_repair import repair_json
from datetime import datetime, timedelta
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import Memory # SQLAlchemy Models导入
from src.common.database.sqlalchemy_database_api import get_db_session
from src.config.config import model_config
from sqlalchemy import select
logger = get_logger(__name__)
class MemoryItem:
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
self.memory_id = memory_id
self.chat_id = chat_id
self.memory_text: str = memory_text
self.keywords: list[str] = keywords
self.create_time: float = time.time()
self.last_view_time: float = time.time()
class InstantMemory:
def __init__(self, chat_id):
self.chat_id = chat_id
self.last_view_time = time.time()
self.summary_model = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="memory.summary",
)
async def if_need_build(self, text):
prompt = f"""
请判断以下内容中是否有值得记忆的信息如果有请输出1否则输出0
{text}
请只输出1或0就好
"""
try:
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt)
print(response)
return "1" in response
except Exception as e:
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
return False
async def build_memory(self, text):
prompt = f"""
以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出
{text}
请以json格式输出一段概括的记忆内容和关键词
{{
"memory_text": "记忆内容",
"keywords": "关键词,用/划分"
}}
"""
try:
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
# print(prompt)
# print(response)
if not response:
return None
try:
repaired = repair_json(response)
result = orjson.loads(repaired)
memory_text = result.get("memory_text", "")
keywords = result.get("keywords", "")
if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
elif isinstance(keywords, list):
keywords_list = keywords
else:
keywords_list = []
return {"memory_text": memory_text, "keywords": keywords_list}
except Exception as parse_e:
logger.error(f"解析记忆json失败{str(parse_e)} {traceback.format_exc()}")
return None
except Exception as e:
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
return None
async def create_and_store_memory(self, text):
if_need = await self.if_need_build(text)
if if_need:
logger.info(f"需要记忆:{text}")
memory = await self.build_memory(text)
if memory and memory.get("memory_text"):
memory_id = f"{self.chat_id}_{time.time()}"
memory_item = MemoryItem(
memory_id=memory_id,
chat_id=self.chat_id,
memory_text=memory["memory_text"],
keywords=memory.get("keywords", []),
)
await self.store_memory(memory_item)
else:
logger.info(f"不需要记忆:{text}")
@staticmethod
async def store_memory(memory_item: MemoryItem):
async with get_db_session() as session:
memory = Memory(
memory_id=memory_item.memory_id,
chat_id=memory_item.chat_id,
memory_text=memory_item.memory_text,
keywords=orjson.dumps(memory_item.keywords).decode("utf-8"),
create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time,
)
session.add(memory)
await session.commit()
async def get_memory(self, target: str):
from json_repair import repair_json
prompt = f"""
请根据以下发言内容,判断是否需要提取记忆
{target}
请用json格式输出包含以下字段
其中time的要求是
可以选择具体日期时间格式为YYYY-MM-DD HH:MM:SS或者大致时间格式为YYYY-MM-DD
可以选择相对时间例如今天昨天前天5天前1个月前
可以选择留空进行模糊搜索
{{
"need_memory": 1,
"keywords": "希望获取的记忆关键词,用/划分",
"time": "希望获取的记忆大致时间"
}}
请只输出json格式不要输出其他多余内容
"""
try:
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt)
print(response)
if not response:
return None
try:
repaired = repair_json(response)
result = orjson.loads(repaired)
# 解析keywords
keywords = result.get("keywords", "")
if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
elif isinstance(keywords, list):
keywords_list = keywords
else:
keywords_list = []
# 解析time为时间段
time_str = result.get("time", "").strip()
start_time, end_time = self._parse_time_range(time_str)
logger.info(f"start_time: {start_time}, end_time: {end_time}")
# 检索包含关键词的记忆
memories_set = set()
async with get_db_session() as session:
if start_time and end_time:
start_ts = start_time.timestamp()
end_ts = end_time.timestamp()
query = (await session.execute(
select(Memory).where(
(Memory.chat_id == self.chat_id)
& (Memory.create_time >= start_ts)
& (Memory.create_time < end_ts)
)
)).scalars()
else:
result = await session.execute(select(Memory).where(Memory.chat_id == self.chat_id))
query = result.scalars()
for mem in query:
# 对每条记忆
mem_keywords_str = mem.keywords or "[]"
try:
mem_keywords = orjson.loads(mem_keywords_str)
except orjson.JSONDecodeError:
mem_keywords = []
# logger.info(f"mem_keywords: {mem_keywords}")
# logger.info(f"keywords_list: {keywords_list}")
for kw in keywords_list:
# logger.info(f"kw: {kw}")
# logger.info(f"kw in mem_keywords: {kw in mem_keywords}")
if kw in mem_keywords:
# logger.info(f"mem.memory_text: {mem.memory_text}")
memories_set.add(mem.memory_text)
break
return list(memories_set)
except Exception as parse_e:
logger.error(f"解析记忆json失败{str(parse_e)} {traceback.format_exc()}")
return None
except Exception as e:
logger.error(f"获取记忆出现错误:{str(e)} {traceback.format_exc()}")
return None
@staticmethod
def _parse_time_range(time_str):
# sourcery skip: extract-duplicate-method, use-contextlib-suppress
"""
支持解析如下格式:
- 具体日期时间YYYY-MM-DD HH:MM:SS
- 具体日期YYYY-MM-DD
- 相对时间今天昨天前天N天前N个月前
- 空字符串:返回(None, None)
"""
now = datetime.now()
if not time_str:
return 0, now
time_str = time_str.strip()
# 具体日期时间
try:
dt = datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
return dt, dt + timedelta(hours=1)
except Exception:
...
# 具体日期
try:
dt = datetime.strptime(time_str, "%Y-%m-%d")
return dt, dt + timedelta(days=1)
except Exception:
...
# 相对时间
if time_str == "今天":
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
if time_str == "昨天":
start = (now - timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
if time_str == "前天":
start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
if m := re.match(r"(\d+)天前", time_str):
days = int(m.group(1))
start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
if m := re.match(r"(\d+)个月前", time_str):
months = int(m.group(1))
# 近似每月30天
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
# 其他无法解析
return 0, now

View File

@@ -0,0 +1,255 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统集成层
现在只管理新的增强记忆系统,旧系统已被完全移除
"""
import time
import asyncio
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from enum import Enum
from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
class IntegrationMode(Enum):
"""集成模式"""
REPLACE = "replace" # 完全替换现有记忆系统
ENHANCED_ONLY = "enhanced_only" # 仅使用增强记忆系统
@dataclass
class IntegrationConfig:
"""集成配置"""
mode: IntegrationMode = IntegrationMode.ENHANCED_ONLY
enable_enhanced_memory: bool = True
memory_value_threshold: float = 0.6
fusion_threshold: float = 0.85
max_retrieval_results: int = 10
enable_learning: bool = True
class MemoryIntegrationLayer:
"""记忆系统集成层 - 现在只管理增强记忆系统"""
def __init__(self, llm_model: LLMRequest, config: Optional[IntegrationConfig] = None):
self.llm_model = llm_model
self.config = config or IntegrationConfig()
# 只初始化增强记忆系统
self.enhanced_memory: Optional[EnhancedMemorySystem] = None
# 集成统计
self.integration_stats = {
"total_queries": 0,
"enhanced_queries": 0,
"memory_creations": 0,
"average_response_time": 0.0,
"success_rate": 0.0
}
# 初始化锁
self._initialization_lock = asyncio.Lock()
self._initialized = False
async def initialize(self):
"""初始化集成层"""
if self._initialized:
return
async with self._initialization_lock:
if self._initialized:
return
logger.info("🚀 开始初始化增强记忆系统集成层...")
try:
# 初始化增强记忆系统
if self.config.enable_enhanced_memory:
await self._initialize_enhanced_memory()
self._initialized = True
logger.info("✅ 增强记忆系统集成层初始化完成")
except Exception as e:
logger.error(f"❌ 集成层初始化失败: {e}", exc_info=True)
raise
async def _initialize_enhanced_memory(self):
"""初始化增强记忆系统"""
try:
logger.debug("初始化增强记忆系统...")
# 创建增强记忆系统配置
from src.chat.memory_system.enhanced_memory_core import MemorySystemConfig
memory_config = MemorySystemConfig.from_global_config()
# 使用集成配置覆盖部分值
memory_config.memory_value_threshold = self.config.memory_value_threshold
memory_config.fusion_similarity_threshold = self.config.fusion_threshold
memory_config.final_recall_limit = self.config.max_retrieval_results
# 创建增强记忆系统
self.enhanced_memory = EnhancedMemorySystem(
config=memory_config
)
# 如果外部提供了LLM模型注入到系统中
if self.llm_model is not None:
self.enhanced_memory.llm_model = self.llm_model
# 初始化系统
await self.enhanced_memory.initialize()
logger.info("✅ 增强记忆系统初始化完成")
except Exception as e:
logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
raise
async def process_conversation(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
) -> Dict[str, Any]:
"""处理对话记忆"""
if not self._initialized or not self.enhanced_memory:
return {"success": False, "error": "Memory system not available"}
start_time = time.time()
self.integration_stats["total_queries"] += 1
self.integration_stats["enhanced_queries"] += 1
try:
# 直接使用增强记忆系统处理
result = await self.enhanced_memory.process_conversation_memory(
conversation_text=conversation_text,
context=context,
user_id=user_id,
timestamp=timestamp
)
# 更新统计
processing_time = time.time() - start_time
self._update_response_stats(processing_time, result.get("success", False))
if result.get("success"):
created_count = len(result.get("created_memories", []))
self.integration_stats["memory_creations"] += created_count
logger.debug(f"对话处理完成,创建 {created_count} 条记忆")
return result
except Exception as e:
processing_time = time.time() - start_time
self._update_response_stats(processing_time, False)
logger.error(f"处理对话记忆失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def retrieve_relevant_memories(
self,
query: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None
) -> List[MemoryChunk]:
"""检索相关记忆"""
if not self._initialized or not self.enhanced_memory:
return []
try:
limit = limit or self.config.max_retrieval_results
memories = await self.enhanced_memory.retrieve_relevant_memories(
query=query,
user_id=user_id,
context=context or {},
limit=limit
)
memory_count = len(memories)
logger.debug(f"检索到 {memory_count} 条相关记忆")
return memories
except Exception as e:
logger.error(f"检索相关记忆失败: {e}", exc_info=True)
return []
async def get_system_status(self) -> Dict[str, Any]:
"""获取系统状态"""
if not self._initialized:
return {"status": "not_initialized"}
try:
enhanced_status = {}
if self.enhanced_memory:
enhanced_status = await self.enhanced_memory.get_system_status()
return {
"status": "initialized",
"mode": self.config.mode.value,
"enhanced_memory": enhanced_status,
"integration_stats": self.integration_stats.copy()
}
except Exception as e:
logger.error(f"获取系统状态失败: {e}", exc_info=True)
return {"status": "error", "error": str(e)}
def get_integration_stats(self) -> Dict[str, Any]:
"""获取集成统计信息"""
return self.integration_stats.copy()
def _update_response_stats(self, processing_time: float, success: bool):
"""更新响应统计"""
total_queries = self.integration_stats["total_queries"]
if total_queries > 0:
# 更新平均响应时间
current_avg = self.integration_stats["average_response_time"]
new_avg = (current_avg * (total_queries - 1) + processing_time) / total_queries
self.integration_stats["average_response_time"] = new_avg
# 更新成功率
if success:
current_success_rate = self.integration_stats["success_rate"]
new_success_rate = (current_success_rate * (total_queries - 1) + 1) / total_queries
self.integration_stats["success_rate"] = new_success_rate
async def maintenance(self):
"""执行维护操作"""
if not self._initialized:
return
try:
logger.info("🔧 执行记忆系统集成层维护...")
if self.enhanced_memory:
await self.enhanced_memory.maintenance()
logger.info("✅ 记忆系统集成层维护完成")
except Exception as e:
logger.error(f"❌ 集成层维护失败: {e}", exc_info=True)
async def shutdown(self):
"""关闭集成层"""
if not self._initialized:
return
try:
logger.info("🔄 关闭记忆系统集成层...")
if self.enhanced_memory:
await self.enhanced_memory.shutdown()
self._initialized = False
logger.info("✅ 记忆系统集成层已关闭")
except Exception as e:
logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True)

View File

@@ -1,144 +0,0 @@
import difflib
import orjson
from json_repair import repair_json
from typing import List, Dict
from datetime import datetime
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.Hippocampus import hippocampus_manager
logger = get_logger("memory_activator")
def get_keywords_from_json(json_str) -> List:
"""
从JSON字符串中提取关键词列表
Args:
json_str: JSON格式的字符串
Returns:
List[str]: 关键词列表
"""
try:
# 使用repair_json修复JSON格式
fixed_json = repair_json(json_str)
# 如果repair_json返回的是字符串需要解析为Python对象
result = orjson.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
return result.get("keywords", [])
except Exception as e:
logger.error(f"解析关键词JSON失败: {e}")
return []
def init_prompt():
# --- Group Chat Prompt ---
memory_activator_prompt = """
你是一个记忆分析器,你需要根据以下信息来进行回忆
以下是一段聊天记录,请根据这些信息,总结出几个关键词作为记忆回忆的触发词
聊天记录:
{obs_info_text}
你想要回复的消息:
{target_message}
历史关键词(请避免重复提取这些关键词):
{cached_keywords}
请输出一个json格式包含以下字段
{{
"keywords": ["关键词1", "关键词2", "关键词3",......]
}}
不要输出其他多余内容只输出json格式就好
"""
Prompt(memory_activator_prompt, "memory_activator_prompt")
class MemoryActivator:
def __init__(self):
self.key_words_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="memory.activator",
)
self.running_memory = []
self.cached_keywords = set() # 用于缓存历史关键词
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
"""
激活记忆
"""
# 如果记忆系统被禁用,直接返回空列表
if not global_config.memory.enable_memory:
return []
# 将缓存的关键词转换为字符串用于prompt
cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "暂无历史关键词"
prompt = await global_prompt_manager.format_prompt(
"memory_activator_prompt",
obs_info_text=chat_history_prompt,
target_message=target_message,
cached_keywords=cached_keywords_str,
)
# logger.debug(f"prompt: {prompt}")
response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
prompt, temperature=0.5
)
keywords = list(get_keywords_from_json(response))
# 更新关键词缓存
if keywords:
# 限制缓存大小最多保留10个关键词
if len(self.cached_keywords) > 10:
# 转换为列表,移除最早的关键词
cached_list = list(self.cached_keywords)
self.cached_keywords = set(cached_list[-8:])
# 添加新的关键词到缓存
self.cached_keywords.update(keywords)
# 调用记忆系统获取相关记忆
related_memory = await hippocampus_manager.get_memory_from_topic(
valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
)
logger.debug(f"当前记忆关键词: {self.cached_keywords} ")
logger.debug(f"获取到的记忆: {related_memory}")
# 激活时所有已有记忆的duration+1达到3则移除
for m in self.running_memory[:]:
m["duration"] = m.get("duration", 1) + 1
self.running_memory = [m for m in self.running_memory if m["duration"] < 3]
if related_memory:
for topic, memory in related_memory:
# 检查是否已存在相同topic或相似内容相似度>=0.7)的记忆
exists = any(
m["topic"] == topic or difflib.SequenceMatcher(None, m["content"], memory).ratio() >= 0.7
for m in self.running_memory
)
if not exists:
self.running_memory.append(
{"topic": topic, "content": memory, "timestamp": datetime.now().isoformat(), "duration": 1}
)
logger.debug(f"添加新记忆: {topic} - {memory}")
# 限制同时加载的记忆条数最多保留最后3条
if len(self.running_memory) > 3:
self.running_memory = self.running_memory[-3:]
return self.running_memory
init_prompt()

View File

@@ -0,0 +1,602 @@
# -*- coding: utf-8 -*-
"""
记忆构建模块
从对话流中提取高质量、结构化记忆单元
"""
import re
import time
import orjson
from typing import Dict, List, Optional, Tuple, Any, Set
from datetime import datetime
from dataclasses import dataclass
from enum import Enum
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.chat.memory_system.memory_chunk import (
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
ContentStructure, MemoryMetadata, create_memory_chunk
)
logger = get_logger(__name__)
class ExtractionStrategy(Enum):
"""提取策略"""
LLM_BASED = "llm_based" # 基于LLM的智能提取
RULE_BASED = "rule_based" # 基于规则的提取
HYBRID = "hybrid" # 混合策略
@dataclass
class ExtractionResult:
"""提取结果"""
memories: List[MemoryChunk]
confidence_scores: List[float]
extraction_time: float
strategy_used: ExtractionStrategy
class MemoryBuilder:
"""记忆构建器"""
def __init__(self, llm_model: LLMRequest):
self.llm_model = llm_model
self.extraction_stats = {
"total_extractions": 0,
"successful_extractions": 0,
"failed_extractions": 0,
"average_confidence": 0.0
}
async def build_memories(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: float
) -> List[MemoryChunk]:
"""从对话中构建记忆"""
start_time = time.time()
try:
logger.debug(f"开始从对话构建记忆,文本长度: {len(conversation_text)}")
# 预处理文本
processed_text = self._preprocess_text(conversation_text)
# 确定提取策略
strategy = self._determine_extraction_strategy(processed_text, context)
# 根据策略提取记忆
if strategy == ExtractionStrategy.LLM_BASED:
memories = await self._extract_with_llm(processed_text, context, user_id, timestamp)
elif strategy == ExtractionStrategy.RULE_BASED:
memories = self._extract_with_rules(processed_text, context, user_id, timestamp)
else: # HYBRID
memories = await self._extract_with_hybrid(processed_text, context, user_id, timestamp)
# 后处理和验证
validated_memories = self._validate_and_enhance_memories(memories, context)
# 更新统计
extraction_time = time.time() - start_time
self._update_extraction_stats(len(validated_memories), extraction_time)
logger.info(f"✅ 成功构建 {len(validated_memories)} 条记忆,耗时 {extraction_time:.2f}")
return validated_memories
except Exception as e:
logger.error(f"❌ 记忆构建失败: {e}", exc_info=True)
self.extraction_stats["failed_extractions"] += 1
return []
def _preprocess_text(self, text: str) -> str:
"""预处理文本"""
# 移除多余的空白字符
text = re.sub(r'\s+', ' ', text.strip())
# 移除特殊字符,但保留基本标点
text = re.sub(r'[^\w\s\u4e00-\u9fff""''()【】]', '', text)
# 截断过长的文本
if len(text) > 2000:
text = text[:2000] + "..."
return text
def _determine_extraction_strategy(self, text: str, context: Dict[str, Any]) -> ExtractionStrategy:
"""确定提取策略"""
text_length = len(text)
has_structured_data = any(key in context for key in ["structured_data", "entities", "keywords"])
message_type = context.get("message_type", "normal")
# 短文本使用规则提取
if text_length < 50:
return ExtractionStrategy.RULE_BASED
# 包含结构化数据使用混合策略
if has_structured_data:
return ExtractionStrategy.HYBRID
# 系统消息或命令使用规则提取
if message_type in ["command", "system"]:
return ExtractionStrategy.RULE_BASED
# 默认使用LLM提取
return ExtractionStrategy.LLM_BASED
async def _extract_with_llm(
self,
text: str,
context: Dict[str, Any],
user_id: str,
timestamp: float
) -> List[MemoryChunk]:
"""使用LLM提取记忆"""
try:
prompt = self._build_llm_extraction_prompt(text, context)
response, _ = await self.llm_model.generate_response_async(
prompt, temperature=0.3
)
# 解析LLM响应
memories = self._parse_llm_response(response, user_id, timestamp, context)
return memories
except Exception as e:
logger.error(f"LLM提取失败: {e}")
return []
def _extract_with_rules(
self,
text: str,
context: Dict[str, Any],
user_id: str,
timestamp: float
) -> List[MemoryChunk]:
"""使用规则提取记忆"""
memories = []
# 规则1: 检测个人信息
personal_info = self._extract_personal_info(text, user_id, timestamp, context)
memories.extend(personal_info)
# 规则2: 检测偏好信息
preferences = self._extract_preferences(text, user_id, timestamp, context)
memories.extend(preferences)
# 规则3: 检测事件信息
events = self._extract_events(text, user_id, timestamp, context)
memories.extend(events)
return memories
async def _extract_with_hybrid(
self,
text: str,
context: Dict[str, Any],
user_id: str,
timestamp: float
) -> List[MemoryChunk]:
"""混合策略提取记忆"""
all_memories = []
# 首先使用规则提取
rule_memories = self._extract_with_rules(text, context, user_id, timestamp)
all_memories.extend(rule_memories)
# 然后使用LLM提取
llm_memories = await self._extract_with_llm(text, context, user_id, timestamp)
# 合并和去重
final_memories = self._merge_hybrid_results(all_memories, llm_memories)
return final_memories
def _build_llm_extraction_prompt(self, text: str, context: Dict[str, Any]) -> str:
"""构建LLM提取提示"""
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chat_id = context.get("chat_id", "unknown")
message_type = context.get("message_type", "normal")
prompt = f"""
你是一个专业的记忆提取专家。请从以下对话中主动识别并提取所有可能重要的信息,特别是包含个人事实、事件、偏好、观点等要素的内容。
当前时间: {current_date}
聊天ID: {chat_id}
消息类型: {message_type}
对话内容:
{text}
## 🎯 重点记忆类型识别指南
### 1. **个人事实** (personal_fact) - 高优先级记忆
**包括但不限于:**
- 基本信息:姓名、年龄、职业、学校、专业、工作地点
- 生活状况:住址、电话、邮箱、社交账号
- 身份特征:生日、星座、血型、国籍、语言能力
- 健康信息:身体状况、疾病史、药物过敏、运动习惯
- 家庭情况:家庭成员、婚姻状况、子女信息、宠物信息
**判断标准:** 涉及个人身份和生活的重要信息,都应该记忆
### 2. **事件** (event) - 高优先级记忆
**包括但不限于:**
- 重要时刻:生日聚会、毕业典礼、婚礼、旅行
- 日常活动:上班、上学、约会、看电影、吃饭
- 特殊经历:考试、面试、会议、搬家、购物
- 计划安排:约会、会议、旅行、活动
**判断标准:** 涉及时间地点的具体活动和经历,都应该记忆
### 3. **偏好** (preference) - 高优先级记忆
**包括但不限于:**
- 饮食偏好:喜欢的食物、餐厅、口味、禁忌
- 娱乐喜好:喜欢的电影、音乐、游戏、书籍
- 生活习惯:作息时间、运动方式、购物习惯
- 消费偏好:品牌喜好、价格敏感度、购物场所
- 风格偏好:服装风格、装修风格、颜色喜好
**判断标准:** 任何表达"喜欢""不喜欢""习惯""经常"等偏好的内容,都应该记忆
### 4. **观点** (opinion) - 高优先级记忆
**包括但不限于:**
- 评价看法:对事物的评价、意见、建议
- 价值判断:认为什么重要、什么不重要
- 态度立场:支持、反对、中立的态度
- 感受反馈:对经历的感受、反馈
**判断标准:** 任何表达主观看法和态度的内容,都应该记忆
### 5. **关系** (relationship) - 中等优先级记忆
**包括但不限于:**
- 人际关系:朋友、同事、家人、恋人的关系状态
- 社交互动:与他人的互动、交流、合作
- 群体归属:所属团队、组织、社群
### 6. **情感** (emotion) - 中等优先级记忆
**包括但不限于:**
- 情绪状态:开心、难过、生气、焦虑、兴奋
- 情感变化:情绪的转变、原因和结果
### 7. **目标** (goal) - 中等优先级记忆
**包括但不限于:**
- 计划安排:短期计划、长期目标
- 愿望期待:想要实现的事情、期望的结果
## 📝 记忆提取原则
### ✅ 积极提取原则:
1. **宁可错记,不可遗漏** - 对于可能的个人信息优先记忆
2. **持续追踪** - 相同信息的多次提及要强化记忆
3. **上下文关联** - 结合对话背景理解信息重要性
4. **细节丰富** - 记录具体的细节和描述
### 🎯 重要性等级标准:
- **4分 (关键)**:个人核心信息(姓名、联系方式、重要日期)
- **3分 (高)**:重要偏好、观点、经历事件
- **2分 (一般)**:一般性信息、日常活动、感受表达
- **1分 (低)**:琐碎细节、重复信息、临时状态
### 🔍 置信度标准:
- **4分 (已验证)**:用户明确确认的信息
- **3分 (高)**:用户直接表达的清晰信息
- **2分 (中等)**:需要推理或上下文判断的信息
- **1分 (低)**:模糊或不完整的信息
输出格式要求:
{{
"memories": [
{{
"type": "记忆类型",
"subject": "主语(通常是用户)",
"predicate": "谓语(动作/状态)",
"object": "宾语(对象/属性)",
"keywords": ["关键词1", "关键词2"],
"importance": "重要性等级(1-4)",
"confidence": "置信度(1-4)",
"reasoning": "提取理由"
}}
]
}}
注意:
1. 只提取确实值得记忆的信息,不要提取琐碎内容
2. 确保提取的信息准确、具体、有价值
3. 使用主谓宾结构确保信息清晰
4. 重要性等级: 1=低, 2=一般, 3=高, 4=关键
5. 置信度: 1=低, 2=中等, 3=高, 4=已验证
"""
return prompt
def _parse_llm_response(
self,
response: str,
user_id: str,
timestamp: float,
context: Dict[str, Any]
) -> List[MemoryChunk]:
"""解析LLM响应"""
memories = []
try:
data = orjson.loads(response)
memory_list = data.get("memories", [])
for mem_data in memory_list:
try:
# 创建记忆块
memory = create_memory_chunk(
user_id=user_id,
subject=mem_data.get("subject", user_id),
predicate=mem_data.get("predicate", ""),
obj=mem_data.get("object", ""),
memory_type=MemoryType(mem_data.get("type", "contextual")),
chat_id=context.get("chat_id"),
source_context=mem_data.get("reasoning", ""),
importance=ImportanceLevel(mem_data.get("importance", 2)),
confidence=ConfidenceLevel(mem_data.get("confidence", 2))
)
# 添加关键词
keywords = mem_data.get("keywords", [])
for keyword in keywords:
memory.add_keyword(keyword)
memories.append(memory)
except Exception as e:
logger.warning(f"解析单个记忆失败: {e}, 数据: {mem_data}")
continue
except Exception as e:
logger.error(f"解析LLM响应失败: {e}, 响应: {response}")
return memories
def _extract_personal_info(
self,
text: str,
user_id: str,
timestamp: float,
context: Dict[str, Any]
) -> List[MemoryChunk]:
"""提取个人信息"""
memories = []
# 常见个人信息模式
patterns = {
r"我叫(\w+)": ("is_named", {"name": "$1"}),
r"我今年(\d+)岁": ("is_age", {"age": "$1"}),
r"我是(\w+)": ("is_profession", {"profession": "$1"}),
r"我住在(\w+)": ("lives_in", {"location": "$1"}),
r"我的电话是(\d+)": ("has_phone", {"phone": "$1"}),
r"我的邮箱是(\w+@\w+\.\w+)": ("has_email", {"email": "$1"}),
}
for pattern, (predicate, obj_template) in patterns.items():
match = re.search(pattern, text)
if match:
obj = obj_template
for i, group in enumerate(match.groups(), 1):
obj = {k: v.replace(f"${i}", group) for k, v in obj.items()}
memory = create_memory_chunk(
user_id=user_id,
subject=user_id,
predicate=predicate,
obj=obj,
memory_type=MemoryType.PERSONAL_FACT,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.HIGH,
confidence=ConfidenceLevel.HIGH
)
memories.append(memory)
return memories
def _extract_preferences(
self,
text: str,
user_id: str,
timestamp: float,
context: Dict[str, Any]
) -> List[MemoryChunk]:
"""提取偏好信息"""
memories = []
# 偏好模式
preference_patterns = [
(r"我喜欢(.+)", "likes"),
(r"我不喜欢(.+)", "dislikes"),
(r"我爱吃(.+)", "likes_food"),
(r"我讨厌(.+)", "hates"),
(r"我最喜欢的(.+)", "favorite_is"),
]
for pattern, predicate in preference_patterns:
match = re.search(pattern, text)
if match:
memory = create_memory_chunk(
user_id=user_id,
subject=user_id,
predicate=predicate,
obj=match.group(1),
memory_type=MemoryType.PREFERENCE,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.NORMAL,
confidence=ConfidenceLevel.MEDIUM
)
memories.append(memory)
return memories
def _extract_events(
self,
text: str,
user_id: str,
timestamp: float,
context: Dict[str, Any]
) -> List[MemoryChunk]:
"""提取事件信息"""
memories = []
# 事件关键词
event_keywords = ["明天", "今天", "昨天", "上周", "下周", "约会", "会议", "活动", "旅行", "生日"]
if any(keyword in text for keyword in event_keywords):
memory = create_memory_chunk(
user_id=user_id,
subject=user_id,
predicate="mentioned_event",
obj={"event_text": text, "timestamp": timestamp},
memory_type=MemoryType.EVENT,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.NORMAL,
confidence=ConfidenceLevel.MEDIUM
)
memories.append(memory)
return memories
def _merge_hybrid_results(
self,
rule_memories: List[MemoryChunk],
llm_memories: List[MemoryChunk]
) -> List[MemoryChunk]:
"""合并混合策略结果"""
all_memories = rule_memories.copy()
# 添加LLM记忆避免重复
for llm_memory in llm_memories:
is_duplicate = False
for rule_memory in rule_memories:
if llm_memory.is_similar_to(rule_memory, threshold=0.7):
is_duplicate = True
# 合并置信度
rule_memory.metadata.confidence = ConfidenceLevel(
max(rule_memory.metadata.confidence.value, llm_memory.metadata.confidence.value)
)
break
if not is_duplicate:
all_memories.append(llm_memory)
return all_memories
def _validate_and_enhance_memories(
self,
memories: List[MemoryChunk],
context: Dict[str, Any]
) -> List[MemoryChunk]:
"""验证和增强记忆"""
validated_memories = []
for memory in memories:
# 基本验证
if not self._validate_memory(memory):
continue
# 增强记忆
enhanced_memory = self._enhance_memory(memory, context)
validated_memories.append(enhanced_memory)
return validated_memories
def _validate_memory(self, memory: MemoryChunk) -> bool:
"""验证记忆块"""
# 检查基本字段
if not memory.content.subject or not memory.content.predicate:
logger.debug(f"记忆块缺少主语或谓语: {memory.memory_id}")
return False
# 检查内容长度
content_length = len(memory.text_content)
if content_length < 5 or content_length > 500:
logger.debug(f"记忆块内容长度异常: {content_length}")
return False
# 检查置信度
if memory.metadata.confidence == ConfidenceLevel.LOW:
logger.debug(f"记忆块置信度过低: {memory.memory_id}")
return False
return True
def _enhance_memory(
self,
memory: MemoryChunk,
context: Dict[str, Any]
) -> MemoryChunk:
"""增强记忆块"""
# 添加时间上下文
if not memory.temporal_context:
memory.temporal_context = {
"timestamp": memory.metadata.created_at,
"timezone": context.get("timezone", "UTC"),
"day_of_week": datetime.fromtimestamp(memory.metadata.created_at).strftime("%A")
}
# 添加情感上下文(如果有)
if context.get("sentiment"):
memory.metadata.emotional_context = context["sentiment"]
# 自动添加标签
self._auto_tag_memory(memory)
return memory
def _auto_tag_memory(self, memory: MemoryChunk):
"""自动为记忆添加标签"""
# 基于记忆类型的自动标签
type_tags = {
MemoryType.PERSONAL_FACT: ["个人信息", "基本资料"],
MemoryType.EVENT: ["事件", "日程"],
MemoryType.PREFERENCE: ["偏好", "喜好"],
MemoryType.OPINION: ["观点", "态度"],
MemoryType.RELATIONSHIP: ["关系", "社交"],
MemoryType.EMOTION: ["情感", "情绪"],
MemoryType.KNOWLEDGE: ["知识", "信息"],
MemoryType.SKILL: ["技能", "能力"],
MemoryType.GOAL: ["目标", "计划"],
MemoryType.EXPERIENCE: ["经验", "经历"],
}
tags = type_tags.get(memory.memory_type, [])
for tag in tags:
memory.add_tag(tag)
def _update_extraction_stats(self, success_count: int, extraction_time: float):
"""更新提取统计"""
self.extraction_stats["total_extractions"] += 1
self.extraction_stats["successful_extractions"] += success_count
self.extraction_stats["failed_extractions"] += max(0, 1 - success_count)
# 更新平均置信度
if self.extraction_stats["successful_extractions"] > 0:
total_confidence = self.extraction_stats["average_confidence"] * (self.extraction_stats["successful_extractions"] - success_count)
# 假设新记忆的平均置信度为0.8
total_confidence += 0.8 * success_count
self.extraction_stats["average_confidence"] = total_confidence / self.extraction_stats["successful_extractions"]
def get_extraction_stats(self) -> Dict[str, Any]:
"""获取提取统计信息"""
return self.extraction_stats.copy()
def reset_stats(self):
"""重置统计信息"""
self.extraction_stats = {
"total_extractions": 0,
"successful_extractions": 0,
"failed_extractions": 0,
"average_confidence": 0.0
}

View File

@@ -0,0 +1,463 @@
# -*- coding: utf-8 -*-
"""
结构化记忆单元设计
实现高质量、结构化的记忆单元,符合文档设计规范
"""
import time
import uuid
import orjson
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
import hashlib
import numpy as np
from src.common.logger import get_logger
logger = get_logger(__name__)
class MemoryType(Enum):
"""记忆类型分类"""
PERSONAL_FACT = "personal_fact" # 个人事实(姓名、职业、住址等)
EVENT = "event" # 事件(重要经历、约会等)
PREFERENCE = "preference" # 偏好(喜好、习惯等)
OPINION = "opinion" # 观点(对事物的看法)
RELATIONSHIP = "relationship" # 关系(与他人的关系)
EMOTION = "emotion" # 情感状态
KNOWLEDGE = "knowledge" # 知识信息
SKILL = "skill" # 技能能力
GOAL = "goal" # 目标计划
EXPERIENCE = "experience" # 经验教训
CONTEXTUAL = "contextual" # 上下文信息
class ConfidenceLevel(Enum):
"""置信度等级"""
LOW = 1 # 低置信度,可能不准确
MEDIUM = 2 # 中等置信度,有一定依据
HIGH = 3 # 高置信度,有明确来源
VERIFIED = 4 # 已验证,非常可靠
class ImportanceLevel(Enum):
"""重要性等级"""
LOW = 1 # 低重要性,普通信息
NORMAL = 2 # 一般重要性,日常信息
HIGH = 3 # 高重要性,重要信息
CRITICAL = 4 # 关键重要性,核心信息
@dataclass
class ContentStructure:
"""主谓宾三元组结构"""
subject: str # 主语(通常为用户)
predicate: str # 谓语(动作、状态、关系)
object: Union[str, Dict] # 宾语(对象、属性、值)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"subject": self.subject,
"predicate": self.predicate,
"object": self.object
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure":
"""从字典创建实例"""
return cls(
subject=data.get("subject", ""),
predicate=data.get("predicate", ""),
object=data.get("object", "")
)
def __str__(self) -> str:
"""字符串表示"""
if isinstance(self.object, dict):
object_str = str(self.object)
else:
object_str = str(self.object)
return f"{self.subject} {self.predicate} {object_str}"
@dataclass
class MemoryMetadata:
"""记忆元数据"""
# 基础信息
memory_id: str # 唯一标识符
user_id: str # 用户ID
chat_id: Optional[str] = None # 聊天ID群聊或私聊
# 时间信息
created_at: float = 0.0 # 创建时间戳
last_accessed: float = 0.0 # 最后访问时间
last_modified: float = 0.0 # 最后修改时间
# 统计信息
access_count: int = 0 # 访问次数
relevance_score: float = 0.0 # 相关度评分
# 信心和重要性
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM
importance: ImportanceLevel = ImportanceLevel.NORMAL
# 情感和关系
emotional_context: Optional[str] = None # 情感上下文
relationship_score: float = 0.0 # 关系分0-1
# 来源和验证
source_context: Optional[str] = None # 来源上下文片段
verification_status: bool = False # 验证状态
def __post_init__(self):
"""后初始化处理"""
if not self.memory_id:
self.memory_id = str(uuid.uuid4())
if self.created_at == 0:
self.created_at = time.time()
if self.last_accessed == 0:
self.last_accessed = self.created_at
if self.last_modified == 0:
self.last_modified = self.created_at
def update_access(self):
"""更新访问信息"""
current_time = time.time()
self.last_accessed = current_time
self.access_count += 1
def update_relevance(self, new_score: float):
"""更新相关度评分"""
self.relevance_score = max(0.0, min(1.0, new_score))
self.last_modified = time.time()
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"memory_id": self.memory_id,
"user_id": self.user_id,
"chat_id": self.chat_id,
"created_at": self.created_at,
"last_accessed": self.last_accessed,
"last_modified": self.last_modified,
"access_count": self.access_count,
"relevance_score": self.relevance_score,
"confidence": self.confidence.value,
"importance": self.importance.value,
"emotional_context": self.emotional_context,
"relationship_score": self.relationship_score,
"source_context": self.source_context,
"verification_status": self.verification_status
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryMetadata":
"""从字典创建实例"""
return cls(
memory_id=data.get("memory_id", ""),
user_id=data.get("user_id", ""),
chat_id=data.get("chat_id"),
created_at=data.get("created_at", 0),
last_accessed=data.get("last_accessed", 0),
last_modified=data.get("last_modified", 0),
access_count=data.get("access_count", 0),
relevance_score=data.get("relevance_score", 0.0),
confidence=ConfidenceLevel(data.get("confidence", ConfidenceLevel.MEDIUM.value)),
importance=ImportanceLevel(data.get("importance", ImportanceLevel.NORMAL.value)),
emotional_context=data.get("emotional_context"),
relationship_score=data.get("relationship_score", 0.0),
source_context=data.get("source_context"),
verification_status=data.get("verification_status", False)
)
@dataclass
class MemoryChunk:
"""结构化记忆单元 - 核心数据结构"""
# 元数据
metadata: MemoryMetadata
# 内容结构
content: ContentStructure # 主谓宾结构
memory_type: MemoryType # 记忆类型
# 扩展信息
keywords: List[str] = field(default_factory=list) # 关键词列表
tags: List[str] = field(default_factory=list) # 标签列表
categories: List[str] = field(default_factory=list) # 分类列表
# 语义信息
embedding: Optional[List[float]] = None # 语义向量
semantic_hash: Optional[str] = None # 语义哈希值
# 关联信息
related_memories: List[str] = field(default_factory=list) # 关联记忆ID列表
temporal_context: Optional[Dict[str, Any]] = None # 时间上下文
def __post_init__(self):
"""后初始化处理"""
if self.embedding and len(self.embedding) > 0:
self._generate_semantic_hash()
def _generate_semantic_hash(self):
"""生成语义哈希值"""
if not self.embedding:
return
try:
# 使用向量和内容生成稳定的哈希
content_str = f"{self.content.subject}:{self.content.predicate}:{str(self.content.object)}"
embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding]))
hash_input = f"{content_str}|{embedding_str}"
hash_object = hashlib.sha256(hash_input.encode('utf-8'))
self.semantic_hash = hash_object.hexdigest()[:16]
except Exception as e:
logger.warning(f"生成语义哈希失败: {e}")
self.semantic_hash = str(uuid.uuid4())[:16]
@property
def memory_id(self) -> str:
"""获取记忆ID"""
return self.metadata.memory_id
@property
def user_id(self) -> str:
"""获取用户ID"""
return self.metadata.user_id
@property
def text_content(self) -> str:
"""获取文本内容"""
return str(self.content)
def update_access(self):
"""更新访问信息"""
self.metadata.update_access()
def update_relevance(self, new_score: float):
"""更新相关度评分"""
self.metadata.update_relevance(new_score)
def add_keyword(self, keyword: str):
"""添加关键词"""
if keyword and keyword not in self.keywords:
self.keywords.append(keyword.strip())
def add_tag(self, tag: str):
"""添加标签"""
if tag and tag not in self.tags:
self.tags.append(tag.strip())
def add_category(self, category: str):
"""添加分类"""
if category and category not in self.categories:
self.categories.append(category.strip())
def add_related_memory(self, memory_id: str):
"""添加关联记忆"""
if memory_id and memory_id not in self.related_memories:
self.related_memories.append(memory_id)
def set_embedding(self, embedding: List[float]):
"""设置语义向量"""
self.embedding = embedding
self._generate_semantic_hash()
def calculate_similarity(self, other: "MemoryChunk") -> float:
"""计算与另一个记忆块的相似度"""
if not self.embedding or not other.embedding:
return 0.0
try:
# 计算余弦相似度
v1 = np.array(self.embedding)
v2 = np.array(other.embedding)
dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1)
norm2 = np.linalg.norm(v2)
if norm1 == 0 or norm2 == 0:
return 0.0
similarity = dot_product / (norm1 * norm2)
return max(0.0, min(1.0, similarity))
except Exception as e:
logger.warning(f"计算记忆相似度失败: {e}")
return 0.0
def to_dict(self) -> Dict[str, Any]:
"""转换为完整的字典格式"""
return {
"metadata": self.metadata.to_dict(),
"content": self.content.to_dict(),
"memory_type": self.memory_type.value,
"keywords": self.keywords,
"tags": self.tags,
"categories": self.categories,
"embedding": self.embedding,
"semantic_hash": self.semantic_hash,
"related_memories": self.related_memories,
"temporal_context": self.temporal_context
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryChunk":
"""从字典创建实例"""
metadata = MemoryMetadata.from_dict(data.get("metadata", {}))
content = ContentStructure.from_dict(data.get("content", {}))
chunk = cls(
metadata=metadata,
content=content,
memory_type=MemoryType(data.get("memory_type", MemoryType.CONTEXTUAL.value)),
keywords=data.get("keywords", []),
tags=data.get("tags", []),
categories=data.get("categories", []),
embedding=data.get("embedding"),
semantic_hash=data.get("semantic_hash"),
related_memories=data.get("related_memories", []),
temporal_context=data.get("temporal_context")
)
return chunk
def to_json(self) -> str:
"""转换为JSON字符串"""
return orjson.dumps(self.to_dict(), ensure_ascii=False).decode('utf-8')
@classmethod
def from_json(cls, json_str: str) -> "MemoryChunk":
"""从JSON字符串创建实例"""
try:
data = orjson.loads(json_str)
return cls.from_dict(data)
except Exception as e:
logger.error(f"从JSON创建记忆块失败: {e}")
raise
def is_similar_to(self, other: "MemoryChunk", threshold: float = 0.8) -> bool:
"""判断是否与另一个记忆块相似"""
if self.semantic_hash and other.semantic_hash:
return self.semantic_hash == other.semantic_hash
return self.calculate_similarity(other) >= threshold
def merge_with(self, other: "MemoryChunk") -> bool:
"""与另一个记忆块合并(如果相似)"""
if not self.is_similar_to(other):
return False
try:
# 合并关键词
for keyword in other.keywords:
self.add_keyword(keyword)
# 合并标签
for tag in other.tags:
self.add_tag(tag)
# 合并分类
for category in other.categories:
self.add_category(category)
# 合并关联记忆
for memory_id in other.related_memories:
self.add_related_memory(memory_id)
# 更新元数据
self.metadata.last_modified = time.time()
self.metadata.access_count += other.metadata.access_count
self.metadata.relevance_score = max(self.metadata.relevance_score, other.metadata.relevance_score)
# 更新置信度
if other.metadata.confidence.value > self.metadata.confidence.value:
self.metadata.confidence = other.metadata.confidence
# 更新重要性
if other.metadata.importance.value > self.metadata.importance.value:
self.metadata.importance = other.metadata.importance
logger.debug(f"记忆块 {self.memory_id} 合并了记忆块 {other.memory_id}")
return True
except Exception as e:
logger.error(f"合并记忆块失败: {e}")
return False
def __str__(self) -> str:
"""字符串表示"""
type_emoji = {
MemoryType.PERSONAL_FACT: "👤",
MemoryType.EVENT: "📅",
MemoryType.PREFERENCE: "❤️",
MemoryType.OPINION: "💭",
MemoryType.RELATIONSHIP: "👥",
MemoryType.EMOTION: "😊",
MemoryType.KNOWLEDGE: "📚",
MemoryType.SKILL: "🛠️",
MemoryType.GOAL: "🎯",
MemoryType.EXPERIENCE: "💡",
MemoryType.CONTEXTUAL: "📝"
}
emoji = type_emoji.get(self.memory_type, "📝")
confidence_icon = "" * self.metadata.confidence.value
importance_icon = "" * self.metadata.importance.value
return f"{emoji} [{self.memory_type.value}] {self.text_content} {confidence_icon} {importance_icon}"
def __repr__(self) -> str:
"""调试表示"""
return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"
def create_memory_chunk(
user_id: str,
subject: str,
predicate: str,
obj: Union[str, Dict],
memory_type: MemoryType,
chat_id: Optional[str] = None,
source_context: Optional[str] = None,
importance: ImportanceLevel = ImportanceLevel.NORMAL,
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
**kwargs
) -> MemoryChunk:
"""便捷的内存块创建函数"""
metadata = MemoryMetadata(
memory_id="",
user_id=user_id,
chat_id=chat_id,
created_at=time.time(),
last_accessed=0,
last_modified=0,
confidence=confidence,
importance=importance,
source_context=source_context
)
content = ContentStructure(
subject=subject,
predicate=predicate,
object=obj
)
chunk = MemoryChunk(
metadata=metadata,
content=content,
memory_type=memory_type,
**kwargs
)
return chunk

View File

@@ -0,0 +1,522 @@
# -*- coding: utf-8 -*-
"""
记忆融合与去重机制
避免记忆碎片化,确保长期记忆库的高质量
"""
import time
import hashlib
from typing import Dict, List, Optional, Tuple, Set, Any
from datetime import datetime, timedelta
from dataclasses import dataclass
from collections import defaultdict
import asyncio
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import (
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
)
logger = get_logger(__name__)
@dataclass
class FusionResult:
"""融合结果"""
original_count: int
fused_count: int
removed_duplicates: int
merged_memories: List[MemoryChunk]
fusion_time: float
details: List[str]
@dataclass
class DuplicateGroup:
"""重复记忆组"""
group_id: str
memories: List[MemoryChunk]
similarity_matrix: List[List[float]]
representative_memory: Optional[MemoryChunk] = None
class MemoryFusionEngine:
"""记忆融合引擎"""
def __init__(self, similarity_threshold: float = 0.85):
self.similarity_threshold = similarity_threshold
self.fusion_stats = {
"total_fusions": 0,
"memories_fused": 0,
"duplicates_removed": 0,
"average_similarity": 0.0
}
# 融合策略配置
self.fusion_strategies = {
"semantic_similarity": True, # 语义相似性融合
"temporal_proximity": True, # 时间接近性融合
"logical_consistency": True, # 逻辑一致性融合
"confidence_boosting": True, # 置信度提升
"importance_preservation": True # 重要性保持
}
async def fuse_memories(
self,
new_memories: List[MemoryChunk],
existing_memories: Optional[List[MemoryChunk]] = None
) -> List[MemoryChunk]:
"""融合记忆列表"""
start_time = time.time()
try:
if not new_memories:
return []
logger.info(f"开始记忆融合,新记忆: {len(new_memories)},现有记忆: {len(existing_memories or [])}")
# 1. 检测重复记忆组
duplicate_groups = await self._detect_duplicate_groups(
new_memories, existing_memories or []
)
# 2. 对每个重复组进行融合
fused_memories = []
removed_count = 0
for group in duplicate_groups:
if len(group.memories) == 1:
# 单个记忆,直接添加
fused_memories.append(group.memories[0])
else:
# 多个记忆,进行融合
fused_memory = await self._fuse_memory_group(group)
if fused_memory:
fused_memories.append(fused_memory)
removed_count += len(group.memories) - 1
# 3. 更新统计
fusion_time = time.time() - start_time
self._update_fusion_stats(len(new_memories), removed_count, fusion_time)
logger.info(f"✅ 记忆融合完成: {len(fused_memories)} 条记忆,移除 {removed_count} 条重复")
return fused_memories
except Exception as e:
logger.error(f"❌ 记忆融合失败: {e}", exc_info=True)
return new_memories # 失败时返回原始记忆
async def _detect_duplicate_groups(
self,
new_memories: List[MemoryChunk],
existing_memories: List[MemoryChunk]
) -> List[DuplicateGroup]:
"""检测重复记忆组"""
all_memories = new_memories + existing_memories
groups = []
processed_ids = set()
for i, memory1 in enumerate(all_memories):
if memory1.memory_id in processed_ids:
continue
# 创建新的重复组
group = DuplicateGroup(
group_id=f"group_{len(groups)}",
memories=[memory1],
similarity_matrix=[[1.0]]
)
processed_ids.add(memory1.memory_id)
# 寻找相似记忆
for j, memory2 in enumerate(all_memories[i+1:], i+1):
if memory2.memory_id in processed_ids:
continue
similarity = self._calculate_comprehensive_similarity(memory1, memory2)
if similarity >= self.similarity_threshold:
group.memories.append(memory2)
processed_ids.add(memory2.memory_id)
# 更新相似度矩阵
self._update_similarity_matrix(group, memory2, similarity)
if len(group.memories) > 1:
# 选择代表性记忆
group.representative_memory = self._select_representative_memory(group)
groups.append(group)
logger.debug(f"检测到 {len(groups)} 个重复记忆组")
return groups
def _calculate_comprehensive_similarity(self, mem1: MemoryChunk, mem2: MemoryChunk) -> float:
"""计算综合相似度"""
similarity_scores = []
# 1. 语义向量相似度
if self.fusion_strategies["semantic_similarity"]:
semantic_sim = mem1.calculate_similarity(mem2)
similarity_scores.append(("semantic", semantic_sim))
# 2. 文本相似度
text_sim = self._calculate_text_similarity(mem1.text_content, mem2.text_content)
similarity_scores.append(("text", text_sim))
# 3. 关键词重叠度
keyword_sim = self._calculate_keyword_similarity(mem1.keywords, mem2.keywords)
similarity_scores.append(("keyword", keyword_sim))
# 4. 类型一致性
type_consistency = 1.0 if mem1.memory_type == mem2.memory_type else 0.0
similarity_scores.append(("type", type_consistency))
# 5. 时间接近性
if self.fusion_strategies["temporal_proximity"]:
temporal_sim = self._calculate_temporal_similarity(
mem1.metadata.created_at, mem2.metadata.created_at
)
similarity_scores.append(("temporal", temporal_sim))
# 6. 逻辑一致性
if self.fusion_strategies["logical_consistency"]:
logical_sim = self._calculate_logical_similarity(mem1, mem2)
similarity_scores.append(("logical", logical_sim))
# 计算加权平均相似度
weights = {
"semantic": 0.35,
"text": 0.25,
"keyword": 0.15,
"type": 0.10,
"temporal": 0.10,
"logical": 0.05
}
weighted_sum = 0.0
total_weight = 0.0
for score_type, score in similarity_scores:
weight = weights.get(score_type, 0.1)
weighted_sum += weight * score
total_weight += weight
final_similarity = weighted_sum / total_weight if total_weight > 0 else 0.0
logger.debug(f"综合相似度计算: {final_similarity:.3f} - {[(t, f'{s:.3f}') for t, s in similarity_scores]}")
return final_similarity
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
"""计算文本相似度"""
# 简单的词汇重叠度计算
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
if not words1 or not words2:
return 0.0
intersection = words1 & words2
union = words1 | words2
jaccard_similarity = len(intersection) / len(union)
return jaccard_similarity
def _calculate_keyword_similarity(self, keywords1: List[str], keywords2: List[str]) -> float:
"""计算关键词相似度"""
if not keywords1 or not keywords2:
return 0.0
set1 = set(k.lower() for k in keywords1)
set2 = set(k.lower() for k in keywords2)
intersection = set1 & set2
union = set1 | set2
return len(intersection) / len(union) if union else 0.0
def _calculate_temporal_similarity(self, time1: float, time2: float) -> float:
"""计算时间相似度"""
time_diff = abs(time1 - time2)
hours_diff = time_diff / 3600
# 24小时内相似度较高
if hours_diff <= 24:
return 1.0 - (hours_diff / 24)
elif hours_diff <= 168: # 一周内
return 0.7 - ((hours_diff - 24) / 168) * 0.5
else:
return 0.2
def _calculate_logical_similarity(self, mem1: MemoryChunk, mem2: MemoryChunk) -> float:
"""计算逻辑一致性"""
# 检查主谓宾结构的逻辑一致性
consistency_score = 0.0
# 主语一致性
if mem1.content.subject == mem2.content.subject:
consistency_score += 0.4
# 谓语相似性
predicate_sim = self._calculate_text_similarity(mem1.content.predicate, mem2.content.predicate)
consistency_score += predicate_sim * 0.3
# 宾语相似性
if isinstance(mem1.content.object, str) and isinstance(mem2.content.object, str):
object_sim = self._calculate_text_similarity(
str(mem1.content.object), str(mem2.content.object)
)
consistency_score += object_sim * 0.3
return consistency_score
def _update_similarity_matrix(self, group: DuplicateGroup, new_memory: MemoryChunk, similarity: float):
"""更新组的相似度矩阵"""
# 为新记忆添加行和列
for i in range(len(group.similarity_matrix)):
group.similarity_matrix[i].append(similarity)
# 添加新行
new_row = [similarity] + [1.0] * len(group.similarity_matrix)
group.similarity_matrix.append(new_row)
def _select_representative_memory(self, group: DuplicateGroup) -> MemoryChunk:
"""选择代表性记忆"""
if not group.memories:
return None
# 评分标准
best_memory = None
best_score = -1.0
for memory in group.memories:
score = 0.0
# 置信度权重
score += memory.metadata.confidence.value * 0.3
# 重要性权重
score += memory.metadata.importance.value * 0.3
# 访问次数权重
score += min(memory.metadata.access_count * 0.1, 0.2)
# 相关度权重
score += memory.metadata.relevance_score * 0.2
if score > best_score:
best_score = score
best_memory = memory
return best_memory
async def _fuse_memory_group(self, group: DuplicateGroup) -> Optional[MemoryChunk]:
"""融合记忆组"""
if not group.memories:
return None
if len(group.memories) == 1:
return group.memories[0]
try:
# 选择基础记忆(通常是代表性记忆)
base_memory = group.representative_memory or group.memories[0]
# 融合其他记忆的属性
fused_memory = await self._merge_memory_attributes(base_memory, group.memories)
# 更新元数据
self._update_fused_metadata(fused_memory, group)
logger.debug(f"成功融合记忆组,包含 {len(group.memories)} 条原始记忆")
return fused_memory
except Exception as e:
logger.error(f"融合记忆组失败: {e}")
# 返回置信度最高的记忆
return max(group.memories, key=lambda m: m.metadata.confidence.value)
async def _merge_memory_attributes(
self,
base_memory: MemoryChunk,
memories: List[MemoryChunk]
) -> MemoryChunk:
"""合并记忆属性"""
# 创建基础记忆的深拷贝
fused_memory = MemoryChunk.from_dict(base_memory.to_dict())
# 合并关键词
all_keywords = set()
for memory in memories:
all_keywords.update(memory.keywords)
fused_memory.keywords = sorted(all_keywords)
# 合并标签
all_tags = set()
for memory in memories:
all_tags.update(memory.tags)
fused_memory.tags = sorted(all_tags)
# 合并分类
all_categories = set()
for memory in memories:
all_categories.update(memory.categories)
fused_memory.categories = sorted(all_categories)
# 合并关联记忆
all_related = set()
for memory in memories:
all_related.update(memory.related_memories)
# 移除对自身和组内记忆的引用
all_related = {rid for rid in all_related if rid not in [m.memory_id for m in memories]}
fused_memory.related_memories = sorted(all_related)
# 合并时间上下文
if self.fusion_strategies["temporal_proximity"]:
fused_memory.temporal_context = self._merge_temporal_context(memories)
return fused_memory
def _update_fused_metadata(self, fused_memory: MemoryChunk, group: DuplicateGroup):
"""更新融合记忆的元数据"""
# 更新修改时间
fused_memory.metadata.last_modified = time.time()
# 计算平均访问次数
total_access = sum(m.metadata.access_count for m in group.memories)
fused_memory.metadata.access_count = total_access
# 提升置信度(如果有多个来源支持)
if self.fusion_strategies["confidence_boosting"] and len(group.memories) > 1:
max_confidence = max(m.metadata.confidence.value for m in group.memories)
if max_confidence < ConfidenceLevel.VERIFIED.value:
fused_memory.metadata.confidence = ConfidenceLevel(
min(max_confidence + 1, ConfidenceLevel.VERIFIED.value)
)
# 保持最高重要性
if self.fusion_strategies["importance_preservation"]:
max_importance = max(m.metadata.importance.value for m in group.memories)
fused_memory.metadata.importance = ImportanceLevel(max_importance)
# 计算平均相关度
avg_relevance = sum(m.metadata.relevance_score for m in group.memories) / len(group.memories)
fused_memory.metadata.relevance_score = min(avg_relevance * 1.1, 1.0) # 稍微提升相关度
# 设置来源信息
source_ids = [m.memory_id[:8] for m in group.memories]
fused_memory.metadata.source_context = f"Fused from {len(group.memories)} memories: {', '.join(source_ids)}"
def _merge_temporal_context(self, memories: List[MemoryChunk]) -> Dict[str, Any]:
"""合并时间上下文"""
contexts = [m.temporal_context for m in memories if m.temporal_context]
if not contexts:
return {}
# 计算时间范围
timestamps = [m.metadata.created_at for m in memories]
earliest_time = min(timestamps)
latest_time = max(timestamps)
merged_context = {
"earliest_timestamp": earliest_time,
"latest_timestamp": latest_time,
"time_span_hours": (latest_time - earliest_time) / 3600,
"source_memories": len(memories)
}
# 合并其他上下文信息
for context in contexts:
for key, value in context.items():
if key not in ["timestamp", "earliest_timestamp", "latest_timestamp"]:
if key not in merged_context:
merged_context[key] = value
elif merged_context[key] != value:
merged_context[key] = f"multiple: {value}"
return merged_context
async def incremental_fusion(
self,
new_memory: MemoryChunk,
existing_memories: List[MemoryChunk]
) -> Tuple[MemoryChunk, List[MemoryChunk]]:
"""增量融合(单个新记忆与现有记忆融合)"""
# 寻找相似记忆
similar_memories = []
for existing in existing_memories:
similarity = self._calculate_comprehensive_similarity(new_memory, existing)
if similarity >= self.similarity_threshold:
similar_memories.append((existing, similarity))
if not similar_memories:
# 没有相似记忆,直接返回
return new_memory, existing_memories
# 按相似度排序
similar_memories.sort(key=lambda x: x[1], reverse=True)
# 与最相似的记忆融合
best_match, similarity = similar_memories[0]
# 创建融合组
group = DuplicateGroup(
group_id=f"incremental_{int(time.time())}",
memories=[new_memory, best_match],
similarity_matrix=[[1.0, similarity], [similarity, 1.0]]
)
# 执行融合
fused_memory = await self._fuse_memory_group(group)
# 从现有记忆中移除被融合的记忆
updated_existing = [m for m in existing_memories if m.memory_id != best_match.memory_id]
updated_existing.append(fused_memory)
logger.debug(f"增量融合完成,相似度: {similarity:.3f}")
return fused_memory, updated_existing
def _update_fusion_stats(self, original_count: int, removed_count: int, fusion_time: float):
"""更新融合统计"""
self.fusion_stats["total_fusions"] += 1
self.fusion_stats["memories_fused"] += original_count
self.fusion_stats["duplicates_removed"] += removed_count
# 更新平均相似度(估算)
if removed_count > 0:
avg_similarity = 0.9 # 假设平均相似度较高
total_similarity = self.fusion_stats["average_similarity"] * (self.fusion_stats["total_fusions"] - 1)
total_similarity += avg_similarity
self.fusion_stats["average_similarity"] = total_similarity / self.fusion_stats["total_fusions"]
async def maintenance(self):
"""维护操作"""
try:
logger.info("开始记忆融合引擎维护...")
# 可以在这里添加定期维护任务,如:
# - 重新评估低置信度记忆
# - 清理孤立记忆引用
# - 优化融合策略参数
logger.info("✅ 记忆融合引擎维护完成")
except Exception as e:
logger.error(f"❌ 记忆融合引擎维护失败: {e}", exc_info=True)
def get_fusion_stats(self) -> Dict[str, Any]:
"""获取融合统计信息"""
return self.fusion_stats.copy()
def reset_stats(self):
"""重置统计信息"""
self.fusion_stats = {
"total_fusions": 0,
"memories_fused": 0,
"duplicates_removed": 0,
"average_similarity": 0.0
}

View File

@@ -0,0 +1,542 @@
# -*- coding: utf-8 -*-
"""
记忆系统集成钩子
提供与现有MoFox Bot系统的无缝集成点
"""
import asyncio
import time
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass
from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_adapter import (
get_enhanced_memory_adapter,
process_conversation_with_enhanced_memory,
retrieve_memories_with_enhanced_system,
get_memory_context_for_prompt
)
logger = get_logger(__name__)
@dataclass
class HookResult:
"""钩子执行结果"""
success: bool
data: Any = None
error: Optional[str] = None
processing_time: float = 0.0
class MemoryIntegrationHooks:
"""记忆系统集成钩子"""
def __init__(self):
self.hooks_registered = False
self.hook_stats = {
"message_processing_hooks": 0,
"memory_retrieval_hooks": 0,
"prompt_enhancement_hooks": 0,
"total_hook_executions": 0,
"average_hook_time": 0.0
}
async def register_hooks(self):
"""注册所有集成钩子"""
if self.hooks_registered:
return
try:
logger.info("🔗 注册记忆系统集成钩子...")
# 注册消息处理钩子
await self._register_message_processing_hooks()
# 注册记忆检索钩子
await self._register_memory_retrieval_hooks()
# 注册提示词增强钩子
await self._register_prompt_enhancement_hooks()
# 注册系统维护钩子
await self._register_maintenance_hooks()
self.hooks_registered = True
logger.info("✅ 记忆系统集成钩子注册完成")
except Exception as e:
logger.error(f"❌ 注册记忆系统集成钩子失败: {e}", exc_info=True)
async def _register_message_processing_hooks(self):
"""注册消息处理钩子"""
try:
# 钩子1: 在消息处理后创建记忆
await self._register_post_message_hook()
# 钩子2: 在聊天流保存时处理记忆
await self._register_chat_stream_hook()
logger.debug("消息处理钩子注册完成")
except Exception as e:
logger.error(f"注册消息处理钩子失败: {e}")
async def _register_memory_retrieval_hooks(self):
"""注册记忆检索钩子"""
try:
# 钩子1: 在生成回复前检索相关记忆
await self._register_pre_response_hook()
# 钩子2: 在知识库查询前增强上下文
await self._register_knowledge_query_hook()
logger.debug("记忆检索钩子注册完成")
except Exception as e:
logger.error(f"注册记忆检索钩子失败: {e}")
async def _register_prompt_enhancement_hooks(self):
"""注册提示词增强钩子"""
try:
# 钩子1: 增强提示词构建
await self._register_prompt_building_hook()
logger.debug("提示词增强钩子注册完成")
except Exception as e:
logger.error(f"注册提示词增强钩子失败: {e}")
async def _register_maintenance_hooks(self):
"""注册系统维护钩子"""
try:
# 钩子1: 系统维护时的记忆系统维护
await self._register_system_maintenance_hook()
logger.debug("系统维护钩子注册完成")
except Exception as e:
logger.error(f"注册系统维护钩子失败: {e}")
async def _register_post_message_hook(self):
"""注册消息后处理钩子"""
try:
# 这里需要根据实际的系统架构来注册钩子
# 以下是一个示例实现,需要根据实际的插件系统或事件系统来调整
# 尝试注册到事件系统
try:
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.base.component_types import EventType
# 注册消息后处理事件
event_manager.subscribe(
EventType.MESSAGE_PROCESSED,
self._on_message_processed_handler
)
logger.debug("已注册到事件系统的消息处理钩子")
except ImportError:
logger.debug("事件系统不可用,跳过事件钩子注册")
# 尝试注册到消息管理器
try:
from src.chat.message_manager import message_manager
# 如果消息管理器支持钩子注册
if hasattr(message_manager, 'register_post_process_hook'):
message_manager.register_post_process_hook(
self._on_message_processed_hook
)
logger.debug("已注册到消息管理器的处理钩子")
except ImportError:
logger.debug("消息管理器不可用,跳过消息管理器钩子注册")
except Exception as e:
logger.error(f"注册消息后处理钩子失败: {e}")
async def _register_chat_stream_hook(self):
"""注册聊天流钩子"""
try:
# 尝试注册到聊天流管理器
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
if hasattr(chat_manager, 'register_save_hook'):
chat_manager.register_save_hook(
self._on_chat_stream_save_hook
)
logger.debug("已注册到聊天流管理器的保存钩子")
except ImportError:
logger.debug("聊天流管理器不可用,跳过聊天流钩子注册")
except Exception as e:
logger.error(f"注册聊天流钩子失败: {e}")
async def _register_pre_response_hook(self):
"""注册回复前钩子"""
try:
# 尝试注册到回复生成器
try:
from src.chat.replyer.default_generator import default_generator
if hasattr(default_generator, 'register_pre_generation_hook'):
default_generator.register_pre_generation_hook(
self._on_pre_response_hook
)
logger.debug("已注册到回复生成器的前置钩子")
except ImportError:
logger.debug("回复生成器不可用,跳过回复前钩子注册")
except Exception as e:
logger.error(f"注册回复前钩子失败: {e}")
async def _register_knowledge_query_hook(self):
"""注册知识库查询钩子"""
try:
# 尝试注册到知识库系统
try:
from src.chat.knowledge.knowledge_lib import knowledge_manager
if hasattr(knowledge_manager, 'register_query_enhancer'):
knowledge_manager.register_query_enhancer(
self._on_knowledge_query_hook
)
logger.debug("已注册到知识库的查询增强钩子")
except ImportError:
logger.debug("知识库系统不可用,跳过知识库钩子注册")
except Exception as e:
logger.error(f"注册知识库查询钩子失败: {e}")
async def _register_prompt_building_hook(self):
"""注册提示词构建钩子"""
try:
# 尝试注册到提示词系统
try:
from src.chat.utils.prompt import prompt_manager
if hasattr(prompt_manager, 'register_enhancer'):
prompt_manager.register_enhancer(
self._on_prompt_building_hook
)
logger.debug("已注册到提示词管理器的增强钩子")
except ImportError:
logger.debug("提示词系统不可用,跳过提示词钩子注册")
except Exception as e:
logger.error(f"注册提示词构建钩子失败: {e}")
async def _register_system_maintenance_hook(self):
"""注册系统维护钩子"""
try:
# 尝试注册到系统维护器
try:
from src.manager.async_task_manager import async_task_manager
# 注册定期维护任务
async_task_manager.add_task(MemoryMaintenanceTask())
logger.debug("已注册到系统维护器的定期任务")
except ImportError:
logger.debug("异步任务管理器不可用,跳过系统维护钩子注册")
except Exception as e:
logger.error(f"注册系统维护钩子失败: {e}")
# 钩子处理器方法
async def _on_message_processed_handler(self, event_data: Dict[str, Any]) -> HookResult:
"""事件系统的消息处理处理器"""
return await self._on_message_processed_hook(event_data)
async def _on_message_processed_hook(self, message_data: Dict[str, Any]) -> HookResult:
"""消息后处理钩子"""
start_time = time.time()
try:
self.hook_stats["message_processing_hooks"] += 1
# 提取必要的信息
message_info = message_data.get("message_info", {})
user_info = message_info.get("user_info", {})
conversation_text = message_data.get("processed_plain_text", "")
if not conversation_text:
return HookResult(success=True, data="No conversation text")
user_id = str(user_info.get("user_id", "unknown"))
context = {
"chat_id": message_data.get("chat_id"),
"message_type": message_data.get("message_type", "normal"),
"platform": message_info.get("platform", "unknown"),
"interest_value": message_data.get("interest_value", 0.0),
"keywords": message_data.get("key_words", []),
"timestamp": message_data.get("time", time.time())
}
# 使用增强记忆系统处理对话
result = await process_conversation_with_enhanced_memory(
conversation_text, context, user_id
)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
if result["success"]:
logger.debug(f"消息处理钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆")
return HookResult(success=True, data=result, processing_time=processing_time)
else:
logger.warning(f"消息处理钩子执行失败: {result.get('error')}")
return HookResult(success=False, error=result.get('error'), processing_time=processing_time)
except Exception as e:
processing_time = time.time() - start_time
logger.error(f"消息处理钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_chat_stream_save_hook(self, chat_stream_data: Dict[str, Any]) -> HookResult:
"""聊天流保存钩子"""
start_time = time.time()
try:
self.hook_stats["message_processing_hooks"] += 1
# 从聊天流数据中提取对话信息
stream_context = chat_stream_data.get("stream_context", {})
user_id = stream_context.get("user_id", "unknown")
messages = stream_context.get("messages", [])
if not messages:
return HookResult(success=True, data="No messages to process")
# 构建对话文本
conversation_parts = []
for msg in messages[-10:]: # 只处理最近10条消息
text = msg.get("processed_plain_text", "")
if text:
conversation_parts.append(f"{msg.get('user_nickname', 'User')}: {text}")
conversation_text = "\n".join(conversation_parts)
if not conversation_text:
return HookResult(success=True, data="No conversation text")
context = {
"chat_id": chat_stream_data.get("chat_id"),
"stream_id": chat_stream_data.get("stream_id"),
"platform": chat_stream_data.get("platform", "unknown"),
"message_count": len(messages),
"timestamp": time.time()
}
# 使用增强记忆系统处理对话
result = await process_conversation_with_enhanced_memory(
conversation_text, context, user_id
)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
if result["success"]:
logger.debug(f"聊天流保存钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆")
return HookResult(success=True, data=result, processing_time=processing_time)
else:
logger.warning(f"聊天流保存钩子执行失败: {result.get('error')}")
return HookResult(success=False, error=result.get('error'), processing_time=processing_time)
except Exception as e:
processing_time = time.time() - start_time
logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_pre_response_hook(self, response_data: Dict[str, Any]) -> HookResult:
"""回复前钩子"""
start_time = time.time()
try:
self.hook_stats["memory_retrieval_hooks"] += 1
# 提取查询信息
query = response_data.get("query", "")
user_id = response_data.get("user_id", "unknown")
context = response_data.get("context", {})
if not query:
return HookResult(success=True, data="No query provided")
# 检索相关记忆
memories = await retrieve_memories_with_enhanced_system(
query, user_id, context, limit=5
)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
# 将记忆添加到响应数据中
response_data["enhanced_memories"] = memories
response_data["enhanced_memory_context"] = await get_memory_context_for_prompt(
query, user_id, context, max_memories=5
)
logger.debug(f"回复前钩子执行成功,检索到 {len(memories)} 条记忆")
return HookResult(success=True, data=memories, processing_time=processing_time)
except Exception as e:
processing_time = time.time() - start_time
logger.error(f"回复前钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_knowledge_query_hook(self, query_data: Dict[str, Any]) -> HookResult:
"""知识库查询钩子"""
start_time = time.time()
try:
self.hook_stats["memory_retrieval_hooks"] += 1
query = query_data.get("query", "")
user_id = query_data.get("user_id", "unknown")
context = query_data.get("context", {})
if not query:
return HookResult(success=True, data="No query provided")
# 获取记忆上下文并增强查询
memory_context = await get_memory_context_for_prompt(
query, user_id, context, max_memories=3
)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
# 将记忆上下文添加到查询数据中
query_data["enhanced_memory_context"] = memory_context
logger.debug("知识库查询钩子执行成功")
return HookResult(success=True, data=memory_context, processing_time=processing_time)
except Exception as e:
processing_time = time.time() - start_time
logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_prompt_building_hook(self, prompt_data: Dict[str, Any]) -> HookResult:
"""提示词构建钩子"""
start_time = time.time()
try:
self.hook_stats["prompt_enhancement_hooks"] += 1
query = prompt_data.get("query", "")
user_id = prompt_data.get("user_id", "unknown")
context = prompt_data.get("context", {})
base_prompt = prompt_data.get("base_prompt", "")
if not query:
return HookResult(success=True, data="No query provided")
# 获取记忆上下文
memory_context = await get_memory_context_for_prompt(
query, user_id, context, max_memories=5
)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
# 构建增强的提示词
enhanced_prompt = base_prompt
if memory_context:
enhanced_prompt += f"\n\n### 相关记忆上下文 ###\n{memory_context}\n"
# 将增强的提示词添加到数据中
prompt_data["enhanced_prompt"] = enhanced_prompt
prompt_data["memory_context"] = memory_context
logger.debug("提示词构建钩子执行成功")
return HookResult(success=True, data=enhanced_prompt, processing_time=processing_time)
except Exception as e:
processing_time = time.time() - start_time
logger.error(f"提示词构建钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
def _update_hook_stats(self, processing_time: float):
"""更新钩子统计"""
self.hook_stats["total_hook_executions"] += 1
total_executions = self.hook_stats["total_hook_executions"]
if total_executions > 0:
current_avg = self.hook_stats["average_hook_time"]
new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions
self.hook_stats["average_hook_time"] = new_avg
def get_hook_stats(self) -> Dict[str, Any]:
"""获取钩子统计信息"""
return self.hook_stats.copy()
class MemoryMaintenanceTask:
"""记忆系统维护任务"""
def __init__(self):
self.task_name = "enhanced_memory_maintenance"
self.interval = 3600 # 1小时执行一次
async def execute(self):
"""执行维护任务"""
try:
logger.info("🔧 执行增强记忆系统维护任务...")
# 获取适配器实例
try:
from src.chat.memory_system.enhanced_memory_adapter import _enhanced_memory_adapter
if _enhanced_memory_adapter:
await _enhanced_memory_adapter.maintenance()
logger.info("✅ 增强记忆系统维护任务完成")
else:
logger.debug("增强记忆适配器未初始化,跳过维护")
except Exception as e:
logger.error(f"增强记忆系统维护失败: {e}")
except Exception as e:
logger.error(f"执行维护任务时发生异常: {e}", exc_info=True)
def get_interval(self) -> int:
"""获取执行间隔"""
return self.interval
def get_task_name(self) -> str:
"""获取任务名称"""
return self.task_name
# 全局钩子实例
_memory_hooks: Optional[MemoryIntegrationHooks] = None
async def get_memory_integration_hooks() -> MemoryIntegrationHooks:
"""获取全局记忆集成钩子实例"""
global _memory_hooks
if _memory_hooks is None:
_memory_hooks = MemoryIntegrationHooks()
await _memory_hooks.register_hooks()
return _memory_hooks
async def initialize_memory_integration_hooks():
"""初始化记忆集成钩子"""
try:
logger.info("🚀 初始化记忆集成钩子...")
hooks = await get_memory_integration_hooks()
logger.info("✅ 记忆集成钩子初始化完成")
return hooks
except Exception as e:
logger.error(f"❌ 记忆集成钩子初始化失败: {e}", exc_info=True)
return None

View File

@@ -0,0 +1,832 @@
# -*- coding: utf-8 -*-
"""
元数据索引系统
为记忆系统提供多维度的精准过滤和查询能力
"""
import os
import time
import orjson
from typing import Dict, List, Optional, Tuple, Set, Any, Union
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from enum import Enum
import threading
from collections import defaultdict
from pathlib import Path
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
logger = get_logger(__name__)
class IndexType(Enum):
"""索引类型"""
MEMORY_TYPE = "memory_type" # 记忆类型索引
USER_ID = "user_id" # 用户ID索引
KEYWORD = "keyword" # 关键词索引
TAG = "tag" # 标签索引
CATEGORY = "category" # 分类索引
TIMESTAMP = "timestamp" # 时间索引
CONFIDENCE = "confidence" # 置信度索引
IMPORTANCE = "importance" # 重要性索引
RELATIONSHIP_SCORE = "relationship_score" # 关系分索引
ACCESS_FREQUENCY = "access_frequency" # 访问频率索引
SEMANTIC_HASH = "semantic_hash" # 语义哈希索引
@dataclass
class IndexQuery:
"""索引查询条件"""
user_ids: Optional[List[str]] = None
memory_types: Optional[List[MemoryType]] = None
keywords: Optional[List[str]] = None
tags: Optional[List[str]] = None
categories: Optional[List[str]] = None
time_range: Optional[Tuple[float, float]] = None
confidence_levels: Optional[List[ConfidenceLevel]] = None
importance_levels: Optional[List[ImportanceLevel]] = None
min_relationship_score: Optional[float] = None
max_relationship_score: Optional[float] = None
min_access_count: Optional[int] = None
semantic_hashes: Optional[List[str]] = None
limit: Optional[int] = None
sort_by: Optional[str] = None # "created_at", "access_count", "relevance_score"
sort_order: str = "desc" # "asc", "desc"
@dataclass
class IndexResult:
"""索引结果"""
memory_ids: List[str]
total_count: int
query_time: float
filtered_by: List[str]
class MetadataIndexManager:
"""元数据索引管理器"""
def __init__(self, index_path: str = "data/memory_metadata"):
self.index_path = Path(index_path)
self.index_path.mkdir(parents=True, exist_ok=True)
# 各类索引
self.indices = {
IndexType.MEMORY_TYPE: defaultdict(set),
IndexType.USER_ID: defaultdict(set),
IndexType.KEYWORD: defaultdict(set),
IndexType.TAG: defaultdict(set),
IndexType.CATEGORY: defaultdict(set),
IndexType.CONFIDENCE: defaultdict(set),
IndexType.IMPORTANCE: defaultdict(set),
IndexType.SEMANTIC_HASH: defaultdict(set),
}
# 时间索引(特殊处理)
self.time_index = [] # [(timestamp, memory_id), ...]
self.relationship_index = [] # [(relationship_score, memory_id), ...]
self.access_frequency_index = [] # [(access_count, memory_id), ...]
# 内存缓存
self.memory_metadata_cache: Dict[str, Dict[str, Any]] = {}
# 统计信息
self.index_stats = {
"total_memories": 0,
"index_build_time": 0.0,
"average_query_time": 0.0,
"total_queries": 0,
"cache_hit_rate": 0.0,
"cache_hits": 0
}
# 线程锁
self._lock = threading.RLock()
self._dirty = False # 标记索引是否有未保存的更改
# 自动保存配置
self.auto_save_interval = 500 # 每500次操作自动保存
self._operation_count = 0
async def index_memories(self, memories: List[MemoryChunk]):
"""为记忆建立索引"""
if not memories:
return
start_time = time.time()
try:
with self._lock:
for memory in memories:
self._index_single_memory(memory)
# 标记为需要保存
self._dirty = True
self._operation_count += len(memories)
# 自动保存检查
if self._operation_count >= self.auto_save_interval:
await self.save_index()
self._operation_count = 0
index_time = time.time() - start_time
self.index_stats["index_build_time"] = (
(self.index_stats["index_build_time"] * (len(memories) - 1) + index_time) /
len(memories)
)
logger.debug(f"元数据索引完成,{len(memories)} 条记忆,耗时 {index_time:.3f}")
except Exception as e:
logger.error(f"❌ 元数据索引失败: {e}", exc_info=True)
def _index_single_memory(self, memory: MemoryChunk):
"""为单个记忆建立索引"""
memory_id = memory.memory_id
# 更新内存缓存
self.memory_metadata_cache[memory_id] = {
"user_id": memory.user_id,
"memory_type": memory.memory_type,
"created_at": memory.metadata.created_at,
"last_accessed": memory.metadata.last_accessed,
"access_count": memory.metadata.access_count,
"confidence": memory.metadata.confidence,
"importance": memory.metadata.importance,
"relationship_score": memory.metadata.relationship_score,
"relevance_score": memory.metadata.relevance_score,
"semantic_hash": memory.semantic_hash
}
# 记忆类型索引
self.indices[IndexType.MEMORY_TYPE][memory.memory_type].add(memory_id)
# 用户ID索引
self.indices[IndexType.USER_ID][memory.user_id].add(memory_id)
# 关键词索引
for keyword in memory.keywords:
self.indices[IndexType.KEYWORD][keyword.lower()].add(memory_id)
# 标签索引
for tag in memory.tags:
self.indices[IndexType.TAG][tag.lower()].add(memory_id)
# 分类索引
for category in memory.categories:
self.indices[IndexType.CATEGORY][category.lower()].add(memory_id)
# 置信度索引
self.indices[IndexType.CONFIDENCE][memory.metadata.confidence].add(memory_id)
# 重要性索引
self.indices[IndexType.IMPORTANCE][memory.metadata.importance].add(memory_id)
# 语义哈希索引
if memory.semantic_hash:
self.indices[IndexType.SEMANTIC_HASH][memory.semantic_hash].add(memory_id)
# 时间索引(插入排序保持有序)
self._insert_into_time_index(memory.metadata.created_at, memory_id)
# 关系分索引(插入排序保持有序)
self._insert_into_relationship_index(memory.metadata.relationship_score, memory_id)
# 访问频率索引(插入排序保持有序)
self._insert_into_access_frequency_index(memory.metadata.access_count, memory_id)
# 更新统计
self.index_stats["total_memories"] += 1
def _insert_into_time_index(self, timestamp: float, memory_id: str):
"""插入时间索引(保持降序)"""
insert_pos = len(self.time_index)
for i, (ts, _) in enumerate(self.time_index):
if timestamp >= ts:
insert_pos = i
break
self.time_index.insert(insert_pos, (timestamp, memory_id))
def _insert_into_relationship_index(self, relationship_score: float, memory_id: str):
"""插入关系分索引(保持降序)"""
insert_pos = len(self.relationship_index)
for i, (score, _) in enumerate(self.relationship_index):
if relationship_score >= score:
insert_pos = i
break
self.relationship_index.insert(insert_pos, (relationship_score, memory_id))
def _insert_into_access_frequency_index(self, access_count: int, memory_id: str):
"""插入访问频率索引(保持降序)"""
insert_pos = len(self.access_frequency_index)
for i, (count, _) in enumerate(self.access_frequency_index):
if access_count >= count:
insert_pos = i
break
self.access_frequency_index.insert(insert_pos, (access_count, memory_id))
async def query_memories(self, query: IndexQuery) -> IndexResult:
"""查询记忆"""
start_time = time.time()
try:
with self._lock:
# 获取候选记忆ID集合
candidate_ids = self._get_candidate_memories(query)
# 应用过滤条件
filtered_ids = self._apply_filters(candidate_ids, query)
# 排序
if query.sort_by:
filtered_ids = self._sort_memories(filtered_ids, query.sort_by, query.sort_order)
# 限制数量
if query.limit and len(filtered_ids) > query.limit:
filtered_ids = filtered_ids[:query.limit]
# 记录查询统计
query_time = time.time() - start_time
self.index_stats["total_queries"] += 1
self.index_stats["average_query_time"] = (
(self.index_stats["average_query_time"] * (self.index_stats["total_queries"] - 1) + query_time) /
self.index_stats["total_queries"]
)
return IndexResult(
memory_ids=filtered_ids,
total_count=len(filtered_ids),
query_time=query_time,
filtered_by=self._get_applied_filters(query)
)
except Exception as e:
logger.error(f"❌ 元数据查询失败: {e}", exc_info=True)
return IndexResult(memory_ids=[], total_count=0, query_time=0.0, filtered_by=[])
def _get_candidate_memories(self, query: IndexQuery) -> Set[str]:
"""获取候选记忆ID集合"""
candidate_ids = set()
# 获取所有记忆ID作为起点
all_memory_ids = set(self.memory_metadata_cache.keys())
if not all_memory_ids:
return candidate_ids
# 应用最严格的过滤条件
applied_filters = []
if query.user_ids:
user_ids_set = set()
for user_id in query.user_ids:
user_ids_set.update(self.indices[IndexType.USER_ID].get(user_id, set()))
candidate_ids.update(user_ids_set)
applied_filters.append("user_ids")
if query.memory_types:
memory_types_set = set()
for memory_type in query.memory_types:
memory_types_set.update(self.indices[IndexType.MEMORY_TYPE].get(memory_type, set()))
if applied_filters:
candidate_ids &= memory_types_set
else:
candidate_ids.update(memory_types_set)
applied_filters.append("memory_types")
if query.keywords:
keywords_set = set()
for keyword in query.keywords:
keywords_set.update(self.indices[IndexType.KEYWORD].get(keyword.lower(), set()))
if applied_filters:
candidate_ids &= keywords_set
else:
candidate_ids.update(keywords_set)
applied_filters.append("keywords")
if query.tags:
tags_set = set()
for tag in query.tags:
tags_set.update(self.indices[IndexType.TAG].get(tag.lower(), set()))
if applied_filters:
candidate_ids &= tags_set
else:
candidate_ids.update(tags_set)
applied_filters.append("tags")
if query.categories:
categories_set = set()
for category in query.categories:
categories_set.update(self.indices[IndexType.CATEGORY].get(category.lower(), set()))
if applied_filters:
candidate_ids &= categories_set
else:
candidate_ids.update(categories_set)
applied_filters.append("categories")
# 如果没有应用任何过滤条件,返回所有记忆
if not applied_filters:
return all_memory_ids
return candidate_ids
def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]:
"""应用过滤条件"""
filtered_ids = list(candidate_ids)
# 时间范围过滤
if query.time_range:
start_time, end_time = query.time_range
filtered_ids = [
memory_id for memory_id in filtered_ids
if self._is_in_time_range(memory_id, start_time, end_time)
]
# 置信度过滤
if query.confidence_levels:
confidence_set = set(query.confidence_levels)
filtered_ids = [
memory_id for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["confidence"] in confidence_set
]
# 重要性过滤
if query.importance_levels:
importance_set = set(query.importance_levels)
filtered_ids = [
memory_id for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["importance"] in importance_set
]
# 关系分范围过滤
if query.min_relationship_score is not None:
filtered_ids = [
memory_id for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["relationship_score"] >= query.min_relationship_score
]
if query.max_relationship_score is not None:
filtered_ids = [
memory_id for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["relationship_score"] <= query.max_relationship_score
]
# 最小访问次数过滤
if query.min_access_count is not None:
filtered_ids = [
memory_id for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["access_count"] >= query.min_access_count
]
# 语义哈希过滤
if query.semantic_hashes:
hash_set = set(query.semantic_hashes)
filtered_ids = [
memory_id for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["semantic_hash"] in hash_set
]
return filtered_ids
def _is_in_time_range(self, memory_id: str, start_time: float, end_time: float) -> bool:
"""检查记忆是否在时间范围内"""
created_at = self.memory_metadata_cache[memory_id]["created_at"]
return start_time <= created_at <= end_time
def _sort_memories(self, memory_ids: List[str], sort_by: str, sort_order: str) -> List[str]:
"""对记忆进行排序"""
if sort_by == "created_at":
# 使用时间索引(已经有序)
if sort_order == "desc":
return memory_ids # 时间索引已经是降序
else:
return memory_ids[::-1] # 反转为升序
elif sort_by == "access_count":
# 使用访问频率索引(已经有序)
if sort_order == "desc":
return memory_ids # 访问频率索引已经是降序
else:
return memory_ids[::-1] # 反转为升序
elif sort_by == "relevance_score":
# 按相关度排序
memory_ids.sort(
key=lambda mid: self.memory_metadata_cache[mid]["relevance_score"],
reverse=(sort_order == "desc")
)
elif sort_by == "relationship_score":
# 使用关系分索引(已经有序)
if sort_order == "desc":
return memory_ids # 关系分索引已经是降序
else:
return memory_ids[::-1] # 反转为升序
elif sort_by == "last_accessed":
# 按最后访问时间排序
memory_ids.sort(
key=lambda mid: self.memory_metadata_cache[mid]["last_accessed"],
reverse=(sort_order == "desc")
)
return memory_ids
def _get_applied_filters(self, query: IndexQuery) -> List[str]:
"""获取应用的过滤器列表"""
filters = []
if query.user_ids:
filters.append("user_ids")
if query.memory_types:
filters.append("memory_types")
if query.keywords:
filters.append("keywords")
if query.tags:
filters.append("tags")
if query.categories:
filters.append("categories")
if query.time_range:
filters.append("time_range")
if query.confidence_levels:
filters.append("confidence_levels")
if query.importance_levels:
filters.append("importance_levels")
if query.min_relationship_score is not None or query.max_relationship_score is not None:
filters.append("relationship_score_range")
if query.min_access_count is not None:
filters.append("min_access_count")
if query.semantic_hashes:
filters.append("semantic_hashes")
return filters
async def update_memory_index(self, memory: MemoryChunk):
"""更新记忆索引"""
with self._lock:
try:
memory_id = memory.memory_id
# 如果记忆已存在,先删除旧索引
if memory_id in self.memory_metadata_cache:
await self.remove_memory_index(memory_id)
# 重新建立索引
self._index_single_memory(memory)
self._dirty = True
self._operation_count += 1
# 自动保存检查
if self._operation_count >= self.auto_save_interval:
await self.save_index()
self._operation_count = 0
logger.debug(f"更新记忆索引完成: {memory_id}")
except Exception as e:
logger.error(f"❌ 更新记忆索引失败: {e}")
async def remove_memory_index(self, memory_id: str):
"""移除记忆索引"""
with self._lock:
try:
if memory_id not in self.memory_metadata_cache:
return
# 获取记忆元数据
metadata = self.memory_metadata_cache[memory_id]
# 从各类索引中移除
self.indices[IndexType.MEMORY_TYPE][metadata["memory_type"]].discard(memory_id)
self.indices[IndexType.USER_ID][metadata["user_id"]].discard(memory_id)
# 从时间索引中移除
self.time_index = [(ts, mid) for ts, mid in self.time_index if mid != memory_id]
# 从关系分索引中移除
self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid != memory_id]
# 从访问频率索引中移除
self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid != memory_id]
# 注意:关键词、标签、分类索引需要从原始记忆中获取,这里简化处理
# 实际实现中可能需要重新加载记忆或维护反向索引
# 从缓存中移除
del self.memory_metadata_cache[memory_id]
# 更新统计
self.index_stats["total_memories"] = max(0, self.index_stats["total_memories"] - 1)
self._dirty = True
logger.debug(f"移除记忆索引完成: {memory_id}")
except Exception as e:
logger.error(f"❌ 移除记忆索引失败: {e}")
async def get_memory_metadata(self, memory_id: str) -> Optional[Dict[str, Any]]:
"""获取记忆元数据"""
return self.memory_metadata_cache.get(memory_id)
async def get_user_memory_ids(self, user_id: str, limit: Optional[int] = None) -> List[str]:
"""获取用户的所有记忆ID"""
user_memory_ids = list(self.indices[IndexType.USER_ID].get(user_id, set()))
if limit and len(user_memory_ids) > limit:
user_memory_ids = user_memory_ids[:limit]
return user_memory_ids
async def get_memory_statistics(self, user_id: Optional[str] = None) -> Dict[str, Any]:
"""获取记忆统计信息"""
stats = {
"total_memories": self.index_stats["total_memories"],
"memory_types": {},
"average_confidence": 0.0,
"average_importance": 0.0,
"average_relationship_score": 0.0,
"top_keywords": [],
"top_tags": []
}
if user_id:
# 限定用户统计
user_memory_ids = self.indices[IndexType.USER_ID].get(user_id, set())
stats["user_total_memories"] = len(user_memory_ids)
if not user_memory_ids:
return stats
# 用户记忆类型分布
user_types = {}
for memory_type, memory_ids in self.indices[IndexType.MEMORY_TYPE].items():
user_count = len(user_memory_ids & memory_ids)
if user_count > 0:
user_types[memory_type.value] = user_count
stats["memory_types"] = user_types
# 计算用户平均值
user_confidences = []
user_importances = []
user_relationship_scores = []
for memory_id in user_memory_ids:
metadata = self.memory_metadata_cache.get(memory_id, {})
if metadata:
user_confidences.append(metadata["confidence"].value)
user_importances.append(metadata["importance"].value)
user_relationship_scores.append(metadata["relationship_score"])
if user_confidences:
stats["average_confidence"] = sum(user_confidences) / len(user_confidences)
if user_importances:
stats["average_importance"] = sum(user_importances) / len(user_importances)
if user_relationship_scores:
stats["average_relationship_score"] = sum(user_relationship_scores) / len(user_relationship_scores)
else:
# 全局统计
for memory_type, memory_ids in self.indices[IndexType.MEMORY_TYPE].items():
stats["memory_types"][memory_type.value] = len(memory_ids)
# 计算全局平均值
if self.memory_metadata_cache:
all_confidences = [m["confidence"].value for m in self.memory_metadata_cache.values()]
all_importances = [m["importance"].value for m in self.memory_metadata_cache.values()]
all_relationship_scores = [m["relationship_score"] for m in self.memory_metadata_cache.values()]
if all_confidences:
stats["average_confidence"] = sum(all_confidences) / len(all_confidences)
if all_importances:
stats["average_importance"] = sum(all_importances) / len(all_importances)
if all_relationship_scores:
stats["average_relationship_score"] = sum(all_relationship_scores) / len(all_relationship_scores)
# 统计热门关键词和标签
keyword_counts = [(keyword, len(memory_ids)) for keyword, memory_ids in self.indices[IndexType.KEYWORD].items()]
keyword_counts.sort(key=lambda x: x[1], reverse=True)
stats["top_keywords"] = keyword_counts[:10]
tag_counts = [(tag, len(memory_ids)) for tag, memory_ids in self.indices[IndexType.TAG].items()]
tag_counts.sort(key=lambda x: x[1], reverse=True)
stats["top_tags"] = tag_counts[:10]
return stats
async def save_index(self):
"""保存索引到文件"""
if not self._dirty:
return
try:
logger.info("正在保存元数据索引...")
# 保存各类索引
indices_data = {}
for index_type, index_data in self.indices.items():
indices_data[index_type.value] = {
key: list(values) for key, values in index_data.items()
}
indices_file = self.index_path / "indices.json"
with open(indices_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(indices_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
# 保存时间索引
time_index_file = self.index_path / "time_index.json"
with open(time_index_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.time_index, option=orjson.OPT_INDENT_2).decode('utf-8'))
# 保存关系分索引
relationship_index_file = self.index_path / "relationship_index.json"
with open(relationship_index_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.relationship_index, option=orjson.OPT_INDENT_2).decode('utf-8'))
# 保存访问频率索引
access_frequency_index_file = self.index_path / "access_frequency_index.json"
with open(access_frequency_index_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.access_frequency_index, option=orjson.OPT_INDENT_2).decode('utf-8'))
# 保存元数据缓存
metadata_cache_file = self.index_path / "metadata_cache.json"
with open(metadata_cache_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.memory_metadata_cache, option=orjson.OPT_INDENT_2).decode('utf-8'))
# 保存统计信息
stats_file = self.index_path / "index_stats.json"
with open(stats_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.index_stats, option=orjson.OPT_INDENT_2).decode('utf-8'))
self._dirty = False
logger.info("✅ 元数据索引保存完成")
except Exception as e:
logger.error(f"❌ 保存元数据索引失败: {e}")
async def load_index(self):
"""从文件加载索引"""
try:
logger.info("正在加载元数据索引...")
# 加载各类索引
indices_file = self.index_path / "indices.json"
if indices_file.exists():
with open(indices_file, 'r', encoding='utf-8') as f:
indices_data = orjson.loads(f.read())
for index_type_value, index_data in indices_data.items():
index_type = IndexType(index_type_value)
self.indices[index_type] = {
key: set(values) for key, values in index_data.items()
}
# 加载时间索引
time_index_file = self.index_path / "time_index.json"
if time_index_file.exists():
with open(time_index_file, 'r', encoding='utf-8') as f:
self.time_index = orjson.loads(f.read())
# 加载关系分索引
relationship_index_file = self.index_path / "relationship_index.json"
if relationship_index_file.exists():
with open(relationship_index_file, 'r', encoding='utf-8') as f:
self.relationship_index = orjson.loads(f.read())
# 加载访问频率索引
access_frequency_index_file = self.index_path / "access_frequency_index.json"
if access_frequency_index_file.exists():
with open(access_frequency_index_file, 'r', encoding='utf-8') as f:
self.access_frequency_index = orjson.loads(f.read())
# 加载元数据缓存
metadata_cache_file = self.index_path / "metadata_cache.json"
if metadata_cache_file.exists():
with open(metadata_cache_file, 'r', encoding='utf-8') as f:
cache_data = orjson.loads(f.read())
# 转换置信度和重要性为枚举类型
for memory_id, metadata in cache_data.items():
if isinstance(metadata["confidence"], str):
metadata["confidence"] = ConfidenceLevel(metadata["confidence"])
if isinstance(metadata["importance"], str):
metadata["importance"] = ImportanceLevel(metadata["importance"])
self.memory_metadata_cache = cache_data
# 加载统计信息
stats_file = self.index_path / "index_stats.json"
if stats_file.exists():
with open(stats_file, 'r', encoding='utf-8') as f:
self.index_stats = orjson.loads(f.read())
# 更新记忆计数
self.index_stats["total_memories"] = len(self.memory_metadata_cache)
logger.info(f"✅ 元数据索引加载完成,{self.index_stats['total_memories']} 个记忆")
except Exception as e:
logger.error(f"❌ 加载元数据索引失败: {e}")
async def optimize_index(self):
"""优化索引"""
try:
logger.info("开始元数据索引优化...")
# 清理无效引用
self._cleanup_invalid_references()
# 重建有序索引
self._rebuild_ordered_indices()
# 清理低频关键词和标签
self._cleanup_low_frequency_terms()
# 更新统计信息
if self.index_stats["total_queries"] > 0:
self.index_stats["cache_hit_rate"] = (
self.index_stats["cache_hits"] / self.index_stats["total_queries"]
)
logger.info("✅ 元数据索引优化完成")
except Exception as e:
logger.error(f"❌ 元数据索引优化失败: {e}")
def _cleanup_invalid_references(self):
"""清理无效引用"""
valid_memory_ids = set(self.memory_metadata_cache.keys())
# 清理各类索引中的无效引用
for index_type in self.indices:
for key in list(self.indices[index_type].keys()):
valid_ids = self.indices[index_type][key] & valid_memory_ids
self.indices[index_type][key] = valid_ids
# 如果某类别下没有记忆了,删除该类别
if not valid_ids:
del self.indices[index_type][key]
# 清理时间索引中的无效引用
self.time_index = [(ts, mid) for ts, mid in self.time_index if mid in valid_memory_ids]
# 清理关系分索引中的无效引用
self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid in valid_memory_ids]
# 清理访问频率索引中的无效引用
self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid in valid_memory_ids]
# 更新总记忆数
self.index_stats["total_memories"] = len(valid_memory_ids)
def _rebuild_ordered_indices(self):
"""重建有序索引"""
# 重建时间索引
self.time_index.sort(key=lambda x: x[0], reverse=True)
# 重建关系分索引
self.relationship_index.sort(key=lambda x: x[0], reverse=True)
# 重建访问频率索引
self.access_frequency_index.sort(key=lambda x: x[0], reverse=True)
def _cleanup_low_frequency_terms(self, min_frequency: int = 2):
"""清理低频术语"""
# 清理低频关键词
for keyword in list(self.indices[IndexType.KEYWORD].keys()):
if len(self.indices[IndexType.KEYWORD][keyword]) < min_frequency:
del self.indices[IndexType.KEYWORD][keyword]
# 清理低频标签
for tag in list(self.indices[IndexType.TAG].keys()):
if len(self.indices[IndexType.TAG][tag]) < min_frequency:
del self.indices[IndexType.TAG][tag]
# 清理低频分类
for category in list(self.indices[IndexType.CATEGORY].keys()):
if len(self.indices[IndexType.CATEGORY][category]) < min_frequency:
del self.indices[IndexType.CATEGORY][category]
def get_index_stats(self) -> Dict[str, Any]:
"""获取索引统计信息"""
stats = self.index_stats.copy()
if stats["total_queries"] > 0:
stats["cache_hit_rate"] = stats["cache_hits"] / stats["total_queries"]
else:
stats["cache_hit_rate"] = 0.0
# 添加索引详细信息
stats["index_details"] = {
"memory_types": len(self.indices[IndexType.MEMORY_TYPE]),
"user_ids": len(self.indices[IndexType.USER_ID]),
"keywords": len(self.indices[IndexType.KEYWORD]),
"tags": len(self.indices[IndexType.TAG]),
"categories": len(self.indices[IndexType.CATEGORY]),
"confidence_levels": len(self.indices[IndexType.CONFIDENCE]),
"importance_levels": len(self.indices[IndexType.IMPORTANCE]),
"semantic_hashes": len(self.indices[IndexType.SEMANTIC_HASH])
}
return stats

View File

@@ -0,0 +1,595 @@
# -*- coding: utf-8 -*-
"""
多阶段召回机制
实现粗粒度到细粒度的记忆检索优化
"""
import time
import asyncio
from typing import Dict, List, Optional, Tuple, Set, Any
from dataclasses import dataclass
from enum import Enum
import numpy as np
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
logger = get_logger(__name__)
class RetrievalStage(Enum):
"""检索阶段"""
METADATA_FILTERING = "metadata_filtering" # 元数据过滤阶段
VECTOR_SEARCH = "vector_search" # 向量搜索阶段
SEMANTIC_RERANKING = "semantic_reranking" # 语义重排序阶段
CONTEXTUAL_FILTERING = "contextual_filtering" # 上下文过滤阶段
@dataclass
class RetrievalConfig:
"""检索配置"""
# 各阶段配置
metadata_filter_limit: int = 100 # 元数据过滤阶段返回数量
vector_search_limit: int = 50 # 向量搜索阶段返回数量
semantic_rerank_limit: int = 20 # 语义重排序阶段返回数量
final_result_limit: int = 10 # 最终结果数量
# 相似度阈值
vector_similarity_threshold: float = 0.7 # 向量相似度阈值
semantic_similarity_threshold: float = 0.6 # 语义相似度阈值
# 权重配置
vector_weight: float = 0.4 # 向量相似度权重
semantic_weight: float = 0.3 # 语义相似度权重
context_weight: float = 0.2 # 上下文权重
recency_weight: float = 0.1 # 时效性权重
@classmethod
def from_global_config(cls):
"""从全局配置创建配置实例"""
from src.config.config import global_config
return cls(
# 各阶段配置
metadata_filter_limit=global_config.memory.metadata_filter_limit,
vector_search_limit=global_config.memory.vector_search_limit,
semantic_rerank_limit=global_config.memory.semantic_rerank_limit,
final_result_limit=global_config.memory.final_result_limit,
# 相似度阈值
vector_similarity_threshold=global_config.memory.vector_similarity_threshold,
semantic_similarity_threshold=0.6, # 保持默认值
# 权重配置
vector_weight=global_config.memory.vector_weight,
semantic_weight=global_config.memory.semantic_weight,
context_weight=global_config.memory.context_weight,
recency_weight=global_config.memory.recency_weight
)
@dataclass
class StageResult:
"""阶段结果"""
stage: RetrievalStage
memory_ids: List[str]
processing_time: float
filtered_count: int
score_threshold: float
@dataclass
class RetrievalResult:
"""检索结果"""
query: str
user_id: str
final_memories: List[MemoryChunk]
stage_results: List[StageResult]
total_processing_time: float
total_filtered: int
retrieval_stats: Dict[str, Any]
class MultiStageRetrieval:
"""多阶段召回系统"""
def __init__(self, config: Optional[RetrievalConfig] = None):
self.config = config or RetrievalConfig.from_global_config()
self.retrieval_stats = {
"total_queries": 0,
"average_retrieval_time": 0.0,
"stage_stats": {
"metadata_filtering": {"calls": 0, "avg_time": 0.0},
"vector_search": {"calls": 0, "avg_time": 0.0},
"semantic_reranking": {"calls": 0, "avg_time": 0.0},
"contextual_filtering": {"calls": 0, "avg_time": 0.0}
}
}
async def retrieve_memories(
self,
query: str,
user_id: str,
context: Dict[str, Any],
metadata_index,
vector_storage,
all_memories_cache: Dict[str, MemoryChunk],
limit: Optional[int] = None
) -> RetrievalResult:
"""多阶段记忆检索"""
start_time = time.time()
limit = limit or self.config.final_result_limit
stage_results = []
current_memory_ids = set()
try:
logger.debug(f"开始多阶段检索query='{query}', user_id='{user_id}'")
# 阶段1元数据过滤
stage1_result = await self._metadata_filtering_stage(
query, user_id, context, metadata_index, all_memories_cache
)
stage_results.append(stage1_result)
current_memory_ids.update(stage1_result.memory_ids)
# 阶段2向量搜索
stage2_result = await self._vector_search_stage(
query, user_id, context, vector_storage, current_memory_ids, all_memories_cache
)
stage_results.append(stage2_result)
current_memory_ids.update(stage2_result.memory_ids)
# 阶段3语义重排序
stage3_result = await self._semantic_reranking_stage(
query, user_id, context, current_memory_ids, all_memories_cache
)
stage_results.append(stage3_result)
# 阶段4上下文过滤
stage4_result = await self._contextual_filtering_stage(
query, user_id, context, stage3_result.memory_ids, all_memories_cache, limit
)
stage_results.append(stage4_result)
# 获取最终记忆对象
final_memories = []
for memory_id in stage4_result.memory_ids:
if memory_id in all_memories_cache:
final_memories.append(all_memories_cache[memory_id])
# 更新统计
total_time = time.time() - start_time
self._update_retrieval_stats(total_time, stage_results)
total_filtered = sum(result.filtered_count for result in stage_results)
logger.debug(f"多阶段检索完成:返回 {len(final_memories)} 条记忆,耗时 {total_time:.3f}s")
return RetrievalResult(
query=query,
user_id=user_id,
final_memories=final_memories,
stage_results=stage_results,
total_processing_time=total_time,
total_filtered=total_filtered,
retrieval_stats=self.retrieval_stats.copy()
)
except Exception as e:
logger.error(f"多阶段检索失败: {e}", exc_info=True)
# 返回空结果
return RetrievalResult(
query=query,
user_id=user_id,
final_memories=[],
stage_results=stage_results,
total_processing_time=time.time() - start_time,
total_filtered=0,
retrieval_stats=self.retrieval_stats.copy()
)
async def _metadata_filtering_stage(
self,
query: str,
user_id: str,
context: Dict[str, Any],
metadata_index,
all_memories_cache: Dict[str, MemoryChunk]
) -> StageResult:
"""阶段1元数据过滤"""
start_time = time.time()
try:
from .metadata_index import IndexQuery
# 构建索引查询
index_query = IndexQuery(
user_ids=[user_id],
memory_types=self._extract_memory_types_from_context(context),
keywords=self._extract_keywords_from_query(query),
limit=self.config.metadata_filter_limit,
sort_by="last_accessed",
sort_order="desc"
)
# 执行查询
result = await metadata_index.query_memories(index_query)
filtered_count = result.total_count - len(result.memory_ids)
logger.debug(f"元数据过滤:{result.total_count} -> {len(result.memory_ids)} 条记忆")
return StageResult(
stage=RetrievalStage.METADATA_FILTERING,
memory_ids=result.memory_ids,
processing_time=time.time() - start_time,
filtered_count=filtered_count,
score_threshold=0.0
)
except Exception as e:
logger.error(f"元数据过滤阶段失败: {e}")
return StageResult(
stage=RetrievalStage.METADATA_FILTERING,
memory_ids=[],
processing_time=time.time() - start_time,
filtered_count=0,
score_threshold=0.0
)
async def _vector_search_stage(
self,
query: str,
user_id: str,
context: Dict[str, Any],
vector_storage,
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk]
) -> StageResult:
"""阶段2向量搜索"""
start_time = time.time()
try:
# 生成查询向量
query_embedding = await self._generate_query_embedding(query, context)
if not query_embedding:
return StageResult(
stage=RetrievalStage.VECTOR_SEARCH,
memory_ids=[],
processing_time=time.time() - start_time,
filtered_count=0,
score_threshold=self.config.vector_similarity_threshold
)
# 执行向量搜索
search_result = await vector_storage.search_similar(
query_embedding,
limit=self.config.vector_search_limit
)
# 过滤候选记忆
filtered_memories = []
for memory_id, similarity in search_result:
if memory_id in candidate_ids and similarity >= self.config.vector_similarity_threshold:
filtered_memories.append((memory_id, similarity))
# 按相似度排序
filtered_memories.sort(key=lambda x: x[1], reverse=True)
result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]]
filtered_count = len(candidate_ids) - len(result_ids)
logger.debug(f"向量搜索:{len(candidate_ids)} -> {len(result_ids)} 条记忆")
return StageResult(
stage=RetrievalStage.VECTOR_SEARCH,
memory_ids=result_ids,
processing_time=time.time() - start_time,
filtered_count=filtered_count,
score_threshold=self.config.vector_similarity_threshold
)
except Exception as e:
logger.error(f"向量搜索阶段失败: {e}")
return StageResult(
stage=RetrievalStage.VECTOR_SEARCH,
memory_ids=[],
processing_time=time.time() - start_time,
filtered_count=0,
score_threshold=self.config.vector_similarity_threshold
)
async def _semantic_reranking_stage(
self,
query: str,
user_id: str,
context: Dict[str, Any],
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk]
) -> StageResult:
"""阶段3语义重排序"""
start_time = time.time()
try:
reranked_memories = []
for memory_id in candidate_ids:
if memory_id not in all_memories_cache:
continue
memory = all_memories_cache[memory_id]
# 计算综合语义相似度
semantic_score = await self._calculate_semantic_similarity(query, memory, context)
if semantic_score >= self.config.semantic_similarity_threshold:
reranked_memories.append((memory_id, semantic_score))
# 按语义相似度排序
reranked_memories.sort(key=lambda x: x[1], reverse=True)
result_ids = [memory_id for memory_id, _ in reranked_memories[:self.config.semantic_rerank_limit]]
filtered_count = len(candidate_ids) - len(result_ids)
logger.debug(f"语义重排序:{len(candidate_ids)} -> {len(result_ids)} 条记忆")
return StageResult(
stage=RetrievalStage.SEMANTIC_RERANKING,
memory_ids=result_ids,
processing_time=time.time() - start_time,
filtered_count=filtered_count,
score_threshold=self.config.semantic_similarity_threshold
)
except Exception as e:
logger.error(f"语义重排序阶段失败: {e}")
return StageResult(
stage=RetrievalStage.SEMANTIC_RERANKING,
memory_ids=list(candidate_ids), # 失败时返回原候选集
processing_time=time.time() - start_time,
filtered_count=0,
score_threshold=self.config.semantic_similarity_threshold
)
async def _contextual_filtering_stage(
self,
query: str,
user_id: str,
context: Dict[str, Any],
candidate_ids: List[str],
all_memories_cache: Dict[str, MemoryChunk],
limit: int
) -> StageResult:
"""阶段4上下文过滤"""
start_time = time.time()
try:
final_memories = []
for memory_id in candidate_ids:
if memory_id not in all_memories_cache:
continue
memory = all_memories_cache[memory_id]
# 计算上下文相关度评分
context_score = await self._calculate_context_relevance(query, memory, context)
# 结合多因子评分
final_score = await self._calculate_final_score(query, memory, context, context_score)
final_memories.append((memory_id, final_score))
# 按最终评分排序
final_memories.sort(key=lambda x: x[1], reverse=True)
result_ids = [memory_id for memory_id, _ in final_memories[:limit]]
filtered_count = len(candidate_ids) - len(result_ids)
logger.debug(f"上下文过滤:{len(candidate_ids)} -> {len(result_ids)} 条记忆")
return StageResult(
stage=RetrievalStage.CONTEXTUAL_FILTERING,
memory_ids=result_ids,
processing_time=time.time() - start_time,
filtered_count=filtered_count,
score_threshold=0.0 # 动态阈值
)
except Exception as e:
logger.error(f"上下文过滤阶段失败: {e}")
return StageResult(
stage=RetrievalStage.CONTEXTUAL_FILTERING,
memory_ids=candidate_ids[:limit], # 失败时返回前limit个
processing_time=time.time() - start_time,
filtered_count=0,
score_threshold=0.0
)
async def _generate_query_embedding(self, query: str, context: Dict[str, Any]) -> Optional[List[float]]:
"""生成查询向量"""
try:
# 这里应该调用embedding模型
# 由于我们可能没有直接的embedding模型返回None或使用简单的方法
# 在实际实现中这里应该调用与记忆存储相同的embedding模型
return None
except Exception as e:
logger.warning(f"生成查询向量失败: {e}")
return None
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
"""计算语义相似度"""
try:
# 简单的文本相似度计算
query_words = set(query.lower().split())
memory_words = set(memory.text_content.lower().split())
if not query_words or not memory_words:
return 0.0
intersection = query_words & memory_words
union = query_words | memory_words
jaccard_similarity = len(intersection) / len(union)
return jaccard_similarity
except Exception as e:
logger.warning(f"计算语义相似度失败: {e}")
return 0.0
async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
"""计算上下文相关度"""
try:
score = 0.0
# 检查记忆类型是否匹配上下文
if context.get("expected_memory_types"):
if memory.memory_type in context["expected_memory_types"]:
score += 0.3
# 检查关键词匹配
if context.get("keywords"):
memory_keywords = set(memory.keywords)
context_keywords = set(context["keywords"])
overlap = memory_keywords & context_keywords
if overlap:
score += len(overlap) / max(len(context_keywords), 1) * 0.4
# 检查时效性
if context.get("recent_only", False):
memory_age = time.time() - memory.metadata.created_at
if memory_age < 7 * 24 * 3600: # 7天内
score += 0.3
return min(score, 1.0)
except Exception as e:
logger.warning(f"计算上下文相关度失败: {e}")
return 0.0
async def _calculate_final_score(self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float) -> float:
"""计算最终评分"""
try:
# 语义相似度
semantic_score = await self._calculate_semantic_similarity(query, memory, context)
# 向量相似度(如果有)
vector_score = 0.0
if memory.embedding:
# 这里应该有向量相似度计算,简化处理
vector_score = 0.5
# 时效性评分
recency_score = self._calculate_recency_score(memory.metadata.created_at)
# 权重组合
final_score = (
semantic_score * self.config.semantic_weight +
vector_score * self.config.vector_weight +
context_score * self.config.context_weight +
recency_score * self.config.recency_weight
)
# 加入记忆重要性权重
importance_weight = memory.metadata.importance.value / 4.0 # 标准化到0-1
final_score = final_score * (0.7 + importance_weight * 0.3) # 重要性影响30%
return final_score
except Exception as e:
logger.warning(f"计算最终评分失败: {e}")
return 0.0
def _calculate_recency_score(self, timestamp: float) -> float:
"""计算时效性评分"""
try:
age = time.time() - timestamp
age_days = age / (24 * 3600)
if age_days < 1:
return 1.0
elif age_days < 7:
return 0.8
elif age_days < 30:
return 0.6
elif age_days < 90:
return 0.4
else:
return 0.2
except Exception:
return 0.5
def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]:
"""从上下文中提取记忆类型"""
try:
if "expected_memory_types" in context:
return context["expected_memory_types"]
# 根据上下文推断记忆类型
if "message_type" in context:
message_type = context["message_type"]
if message_type in ["personal_info", "fact"]:
return [MemoryType.PERSONAL_FACT]
elif message_type in ["event", "activity"]:
return [MemoryType.EVENT]
elif message_type in ["preference", "like"]:
return [MemoryType.PREFERENCE]
elif message_type in ["opinion", "view"]:
return [MemoryType.OPINION]
return []
except Exception:
return []
def _extract_keywords_from_query(self, query: str) -> List[str]:
"""从查询中提取关键词"""
try:
# 简单的关键词提取
words = query.lower().split()
# 过滤停用词
stopwords = {"", "", "", "", "", "", "", "", "", "", "", "", "", ""}
keywords = [word for word in words if len(word) > 1 and word not in stopwords]
return keywords[:10] # 最多返回10个关键词
except Exception:
return []
def _update_retrieval_stats(self, total_time: float, stage_results: List[StageResult]):
"""更新检索统计"""
self.retrieval_stats["total_queries"] += 1
# 更新平均检索时间
current_avg = self.retrieval_stats["average_retrieval_time"]
total_queries = self.retrieval_stats["total_queries"]
new_avg = (current_avg * (total_queries - 1) + total_time) / total_queries
self.retrieval_stats["average_retrieval_time"] = new_avg
# 更新各阶段统计
for result in stage_results:
stage_name = result.stage.value
if stage_name in self.retrieval_stats["stage_stats"]:
stage_stat = self.retrieval_stats["stage_stats"][stage_name]
stage_stat["calls"] += 1
current_stage_avg = stage_stat["avg_time"]
new_stage_avg = (current_stage_avg * (stage_stat["calls"] - 1) + result.processing_time) / stage_stat["calls"]
stage_stat["avg_time"] = new_stage_avg
def get_retrieval_stats(self) -> Dict[str, Any]:
"""获取检索统计信息"""
return self.retrieval_stats.copy()
def reset_stats(self):
"""重置统计信息"""
self.retrieval_stats = {
"total_queries": 0,
"average_retrieval_time": 0.0,
"stage_stats": {
"metadata_filtering": {"calls": 0, "avg_time": 0.0},
"vector_search": {"calls": 0, "avg_time": 0.0},
"semantic_reranking": {"calls": 0, "avg_time": 0.0},
"contextual_filtering": {"calls": 0, "avg_time": 0.0}
}
}

View File

@@ -1,126 +0,0 @@
import numpy as np
from datetime import datetime, timedelta
from rich.traceback import install
install(extra_lines=3)
class MemoryBuildScheduler:
def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
"""
初始化记忆构建调度器
参数:
n_hours1 (float): 第一个分布的均值(距离现在的小时数)
std_hours1 (float): 第一个分布的标准差(小时)
weight1 (float): 第一个分布的权重
n_hours2 (float): 第二个分布的均值(距离现在的小时数)
std_hours2 (float): 第二个分布的标准差(小时)
weight2 (float): 第二个分布的权重
total_samples (int): 要生成的总时间点数量
"""
# 验证参数
if total_samples <= 0:
raise ValueError("total_samples 必须大于0")
if weight1 < 0 or weight2 < 0:
raise ValueError("权重必须为非负数")
if std_hours1 < 0 or std_hours2 < 0:
raise ValueError("标准差必须为非负数")
# 归一化权重
total_weight = weight1 + weight2
if total_weight == 0:
raise ValueError("权重总和不能为0")
self.weight1 = weight1 / total_weight
self.weight2 = weight2 / total_weight
self.n_hours1 = n_hours1
self.std_hours1 = std_hours1
self.n_hours2 = n_hours2
self.std_hours2 = std_hours2
self.total_samples = total_samples
self.base_time = datetime.now()
def generate_time_samples(self):
"""生成混合分布的时间采样点"""
# 根据权重计算每个分布的样本数
samples1 = max(1, int(self.total_samples * self.weight1))
samples2 = max(1, self.total_samples - samples1) # 确保 samples2 至少为1
# 生成两个正态分布的小时偏移
hours_offset1 = np.random.normal(loc=self.n_hours1, scale=self.std_hours1, size=samples1)
hours_offset2 = np.random.normal(loc=self.n_hours2, scale=self.std_hours2, size=samples2)
# 合并两个分布的偏移
hours_offset = np.concatenate([hours_offset1, hours_offset2])
# 将偏移转换为实际时间戳(使用绝对值确保时间点在过去)
timestamps = [self.base_time - timedelta(hours=abs(offset)) for offset in hours_offset]
# 按时间排序(从最早到最近)
return sorted(timestamps)
def get_timestamp_array(self):
"""返回时间戳数组"""
timestamps = self.generate_time_samples()
return [int(t.timestamp()) for t in timestamps]
# def print_time_samples(timestamps, show_distribution=True):
# """打印时间样本和分布信息"""
# print(f"\n生成的{len(timestamps)}个时间点分布:")
# print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
# print("-" * 50)
# now = datetime.now()
# time_diffs = []
# for i, timestamp in enumerate(timestamps, 1):
# hours_diff = (now - timestamp).total_seconds() / 3600
# time_diffs.append(hours_diff)
# print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
# # 打印统计信息
# print("\n统计信息")
# print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
# print(f"标准差:{np.std(time_diffs):.2f}小时")
# print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
# print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
# if show_distribution:
# # 计算时间分布的直方图
# hist, bins = np.histogram(time_diffs, bins=40)
# print("\n时间分布每个*代表一个时间点):")
# for i in range(len(hist)):
# if hist[i] > 0:
# print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")
# # 使用示例
# if __name__ == "__main__":
# # 创建一个双峰分布的记忆调度器
# scheduler = MemoryBuildScheduler(
# n_hours1=12, # 第一个分布均值12小时前
# std_hours1=8, # 第一个分布标准差
# weight1=0.7, # 第一个分布权重 70%
# n_hours2=36, # 第二个分布均值36小时前
# std_hours2=24, # 第二个分布标准差
# weight2=0.3, # 第二个分布权重 30%
# total_samples=50, # 总共生成50个时间点
# )
# # 生成时间分布
# timestamps = scheduler.generate_time_samples()
# # 打印结果,包含分布可视化
# print_time_samples(timestamps, show_distribution=True)
# # 打印时间戳数组
# timestamp_array = scheduler.get_timestamp_array()
# print("\n时间戳数组Unix时间戳")
# print("[", end="")
# for i, ts in enumerate(timestamp_array):
# if i > 0:
# print(", ", end="")
# print(ts, end="")
# print("]")

View File

@@ -1,359 +0,0 @@
import asyncio
import time
from typing import List, Dict, Any
from dataclasses import dataclass
import threading
from src.common.logger import get_logger
from src.chat.utils.utils import get_embedding
from src.common.vector_db import vector_db_service
logger = get_logger("vector_instant_memory_v2")
@dataclass
class ChatMessage:
"""聊天消息数据结构"""
message_id: str
chat_id: str
content: str
timestamp: float
sender: str = "unknown"
message_type: str = "text"
class VectorInstantMemoryV2:
"""重构的向量瞬时记忆系统 V2
新设计理念:
1. 全量存储 - 所有聊天记录都存储为向量
2. 定时清理 - 定期清理过期记录
3. 实时匹配 - 新消息与历史记录做向量相似度匹配
"""
def __init__(self, chat_id: str, retention_hours: int = 24, cleanup_interval: int = 3600):
"""
初始化向量瞬时记忆系统
Args:
chat_id: 聊天ID
retention_hours: 记忆保留时长(小时)
cleanup_interval: 清理间隔(秒)
"""
self.chat_id = chat_id
self.retention_hours = retention_hours
self.cleanup_interval = cleanup_interval
self.collection_name = "instant_memory"
# 清理任务相关
self.cleanup_task = None
self.is_running = True
# 初始化系统
self._init_chroma()
self._start_cleanup_task()
logger.info(f"向量瞬时记忆系统V2初始化完成: {chat_id} (保留{retention_hours}小时)")
def _init_chroma(self):
"""使用全局服务初始化向量数据库集合"""
try:
# 现在我们只获取集合,而不是创建新的客户端
vector_db_service.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"})
logger.info(f"向量记忆集合 '{self.collection_name}' 已准备就绪")
except Exception as e:
logger.error(f"获取向量记忆集合失败: {e}")
def _start_cleanup_task(self):
"""启动定时清理任务"""
def cleanup_worker():
while self.is_running:
try:
self._cleanup_expired_messages()
time.sleep(self.cleanup_interval)
except Exception as e:
logger.error(f"清理任务异常: {e}")
time.sleep(60) # 异常时等待1分钟再继续
self.cleanup_task = threading.Thread(target=cleanup_worker, daemon=True)
self.cleanup_task.start()
logger.info(f"定时清理任务已启动,间隔{self.cleanup_interval}")
def _cleanup_expired_messages(self):
"""清理过期的聊天记录"""
try:
expire_time = time.time() - (self.retention_hours * 3600)
# 采用 get -> filter -> delete 模式,避免复杂的 where 查询
# 1. 获取当前 chat_id 的所有文档
results = vector_db_service.get(
collection_name=self.collection_name, where={"chat_id": self.chat_id}, include=["metadatas"]
)
if not results or not results.get("ids"):
logger.info(f"chat_id '{self.chat_id}' 没有找到任何记录,无需清理")
return
# 2. 在内存中过滤出过期的文档
expired_ids = []
metadatas = results.get("metadatas", [])
ids = results.get("ids", [])
for i, metadata in enumerate(metadatas):
if metadata and metadata.get("timestamp", float("inf")) < expire_time:
expired_ids.append(ids[i])
# 3. 如果有过期文档,根据 ID 进行删除
if expired_ids:
vector_db_service.delete(collection_name=self.collection_name, ids=expired_ids)
logger.info(f"为 chat_id '{self.chat_id}' 清理了 {len(expired_ids)} 条过期记录")
else:
logger.info(f"chat_id '{self.chat_id}' 没有需要清理的过期记录")
except Exception as e:
logger.error(f"清理过期记录失败: {e}")
async def store_message(self, content: str, sender: str = "user") -> bool:
"""
存储聊天消息到向量库
Args:
content: 消息内容
sender: 发送者
Returns:
bool: 是否存储成功
"""
if not content.strip():
return False
try:
# 生成消息向量
message_vector = await get_embedding(content)
if not message_vector:
logger.warning(f"消息向量生成失败: {content[:50]}...")
return False
message_id = f"{self.chat_id}_{int(time.time() * 1000)}_{hash(content) % 10000}"
message = ChatMessage(
message_id=message_id, chat_id=self.chat_id, content=content, timestamp=time.time(), sender=sender
)
# 使用新的服务存储
vector_db_service.add(
collection_name=self.collection_name,
embeddings=[message_vector],
documents=[content],
metadatas=[
{
"message_id": message.message_id,
"chat_id": message.chat_id,
"timestamp": message.timestamp,
"sender": message.sender,
"message_type": message.message_type,
}
],
ids=[message_id],
)
logger.debug(f"消息已存储: {content[:50]}...")
return True
except Exception as e:
logger.error(f"存储消息失败: {e}")
return False
async def find_similar_messages(
self, query: str, top_k: int = 5, similarity_threshold: float = 0.7
) -> List[Dict[str, Any]]:
"""
查找与查询相似的历史消息
Args:
query: 查询内容
top_k: 返回的最相似消息数量
similarity_threshold: 相似度阈值
Returns:
List[Dict]: 相似消息列表包含content、similarity、timestamp等信息
"""
if not query.strip():
return []
try:
query_vector = await get_embedding(query)
if not query_vector:
return []
# 使用新的服务进行查询
results = vector_db_service.query(
collection_name=self.collection_name,
query_embeddings=[query_vector],
n_results=top_k,
where={"chat_id": self.chat_id},
)
if not results.get("documents") or not results["documents"][0]:
return []
# 处理搜索结果
similar_messages = []
documents = results["documents"][0]
distances = results["distances"][0] if results["distances"] else []
metadatas = results["metadatas"][0] if results["metadatas"] else []
for i, doc in enumerate(documents):
# 计算相似度ChromaDB返回距离需转换
distance = distances[i] if i < len(distances) else 1.0
similarity = 1 - distance
# 过滤低相似度结果
if similarity < similarity_threshold:
continue
# 获取元数据
metadata = metadatas[i] if i < len(metadatas) else {}
# 安全获取timestamp
timestamp = metadata.get("timestamp", 0) if isinstance(metadata, dict) else 0
timestamp = float(timestamp) if isinstance(timestamp, (int, float)) else 0.0
similar_messages.append(
{
"content": doc,
"similarity": similarity,
"timestamp": timestamp,
"sender": metadata.get("sender", "unknown") if isinstance(metadata, dict) else "unknown",
"message_id": metadata.get("message_id", "") if isinstance(metadata, dict) else "",
"time_ago": self._format_time_ago(timestamp),
}
)
# 按相似度排序
similar_messages.sort(key=lambda x: x["similarity"], reverse=True)
logger.debug(f"找到 {len(similar_messages)} 条相似消息 (查询: {query[:30]}...)")
return similar_messages
except Exception as e:
logger.error(f"查找相似消息失败: {e}")
return []
@staticmethod
def _format_time_ago(timestamp: float) -> str:
"""格式化时间差显示"""
if timestamp <= 0:
return "未知时间"
try:
now = time.time()
diff = now - timestamp
if diff < 60:
return f"{int(diff)}秒前"
elif diff < 3600:
return f"{int(diff / 60)}分钟前"
elif diff < 86400:
return f"{int(diff / 3600)}小时前"
else:
return f"{int(diff / 86400)}天前"
except Exception:
return "时间格式错误"
async def get_memory_for_context(self, current_message: str, context_size: int = 3) -> str:
"""
获取与当前消息相关的记忆上下文
Args:
current_message: 当前消息
context_size: 上下文消息数量
Returns:
str: 格式化的记忆上下文
"""
similar_messages = await self.find_similar_messages(
current_message,
top_k=context_size,
similarity_threshold=0.6, # 降低阈值以获得更多上下文
)
if not similar_messages:
return ""
# 格式化上下文
context_lines = []
for msg in similar_messages:
context_lines.append(
f"[{msg['time_ago']}] {msg['sender']}: {msg['content']} (相似度: {msg['similarity']:.2f})"
)
return "相关的历史记忆:\n" + "\n".join(context_lines)
def get_stats(self) -> Dict[str, Any]:
"""获取记忆系统统计信息"""
stats = {
"chat_id": self.chat_id,
"retention_hours": self.retention_hours,
"cleanup_interval": self.cleanup_interval,
"system_status": "running" if self.is_running else "stopped",
"total_messages": 0,
"db_status": "connected",
}
try:
# 注意count() 现在没有 chat_id 过滤,返回的是整个集合的数量
# 若要精确计数,需要 get(where={"chat_id": ...}) 然后 len(results['ids'])
# 这里为了简化,暂时显示集合总数
result = vector_db_service.count(collection_name=self.collection_name)
stats["total_messages"] = result
except Exception:
stats["total_messages"] = "查询失败"
stats["db_status"] = "disconnected"
return stats
def stop(self):
"""停止记忆系统"""
self.is_running = False
if self.cleanup_task and self.cleanup_task.is_alive():
logger.info("正在停止定时清理任务...")
logger.info(f"向量瞬时记忆系统已停止: {self.chat_id}")
# 为了兼容现有代码,提供工厂函数
def create_vector_memory_v2(chat_id: str, retention_hours: int = 24) -> VectorInstantMemoryV2:
"""创建向量瞬时记忆系统V2实例"""
return VectorInstantMemoryV2(chat_id, retention_hours)
# 使用示例
async def demo():
"""使用演示"""
memory = VectorInstantMemoryV2("demo_chat")
# 存储一些测试消息
await memory.store_message("今天天气不错,出去散步了", "用户")
await memory.store_message("刚才买了个冰淇淋,很好吃", "用户")
await memory.store_message("明天要开会,有点紧张", "用户")
# 查找相似消息
similar = await memory.find_similar_messages("天气怎么样")
print("相似消息:", similar)
# 获取上下文
context = await memory.get_memory_for_context("今天心情如何")
print("记忆上下文:", context)
# 查看统计信息
stats = memory.get_stats()
print("系统状态:", stats)
memory.stop()
if __name__ == "__main__":
asyncio.run(demo())

View File

@@ -0,0 +1,723 @@
# -*- coding: utf-8 -*-
"""
向量数据库存储接口
为记忆系统提供高效的向量存储和语义搜索能力
"""
import os
import time
import orjson
import asyncio
from typing import Dict, List, Optional, Tuple, Set, Any
from dataclasses import dataclass
from datetime import datetime
import threading
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import pandas as pd
from pathlib import Path
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.memory_system.memory_chunk import MemoryChunk
logger = get_logger(__name__)
# 尝试导入FAISS如果不可用则使用简单替代
try:
import faiss
FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
logger.warning("FAISS not available, using simple vector storage")
@dataclass
class VectorStorageConfig:
"""向量存储配置"""
dimension: int = 768
similarity_threshold: float = 0.8
index_type: str = "flat" # flat, ivf, hnsw
max_index_size: int = 100000
storage_path: str = "data/memory_vectors"
auto_save_interval: int = 100 # 每N次操作自动保存
enable_compression: bool = True
class VectorStorageManager:
"""向量存储管理器"""
def __init__(self, config: Optional[VectorStorageConfig] = None):
self.config = config or VectorStorageConfig()
self.storage_path = Path(self.config.storage_path)
self.storage_path.mkdir(parents=True, exist_ok=True)
# 向量索引
self.vector_index = None
self.memory_id_to_index = {} # memory_id -> vector index
self.index_to_memory_id = {} # vector index -> memory_id
# 内存缓存
self.memory_cache: Dict[str, MemoryChunk] = {}
self.vector_cache: Dict[str, List[float]] = {}
# 统计信息
self.storage_stats = {
"total_vectors": 0,
"index_build_time": 0.0,
"average_search_time": 0.0,
"cache_hit_rate": 0.0,
"total_searches": 0,
"cache_hits": 0
}
# 线程锁
self._lock = threading.RLock()
self._operation_count = 0
# 初始化索引
self._initialize_index()
# 嵌入模型
self.embedding_model: LLMRequest = None
def _initialize_index(self):
"""初始化向量索引"""
try:
if FAISS_AVAILABLE:
if self.config.index_type == "flat":
self.vector_index = faiss.IndexFlatIP(self.config.dimension)
elif self.config.index_type == "ivf":
quantizer = faiss.IndexFlatIP(self.config.dimension)
nlist = min(100, max(1, self.config.max_index_size // 1000))
self.vector_index = faiss.IndexIVFFlat(quantizer, self.config.dimension, nlist)
elif self.config.index_type == "hnsw":
self.vector_index = faiss.IndexHNSWFlat(self.config.dimension, 32)
self.vector_index.hnsw.efConstruction = 40
else:
self.vector_index = faiss.IndexFlatIP(self.config.dimension)
else:
# 简单的向量存储实现
self.vector_index = SimpleVectorIndex(self.config.dimension)
logger.info(f"✅ 向量索引初始化完成,类型: {self.config.index_type}")
except Exception as e:
logger.error(f"❌ 向量索引初始化失败: {e}")
# 回退到简单实现
self.vector_index = SimpleVectorIndex(self.config.dimension)
async def initialize_embedding_model(self):
"""初始化嵌入模型"""
if self.embedding_model is None:
self.embedding_model = LLMRequest(
model_set=model_config.model_task_config.embedding,
request_type="memory.embedding"
)
logger.info("✅ 嵌入模型初始化完成")
async def store_memories(self, memories: List[MemoryChunk]):
"""存储记忆向量"""
if not memories:
return
start_time = time.time()
try:
# 确保嵌入模型已初始化
await self.initialize_embedding_model()
# 批量获取嵌入向量
embedding_tasks = []
memory_texts = []
for memory in memories:
if memory.embedding is None:
# 如果没有嵌入向量,需要生成
text = self._prepare_embedding_text(memory)
memory_texts.append((memory.memory_id, text))
else:
# 已有嵌入向量,直接使用
await self._add_single_memory(memory, memory.embedding)
# 批量生成缺失的嵌入向量
if memory_texts:
await self._batch_generate_and_store_embeddings(memory_texts)
# 自动保存检查
self._operation_count += len(memories)
if self._operation_count >= self.config.auto_save_interval:
await self.save_storage()
self._operation_count = 0
storage_time = time.time() - start_time
logger.debug(f"向量存储完成,{len(memories)} 条记忆,耗时 {storage_time:.3f}")
except Exception as e:
logger.error(f"❌ 向量存储失败: {e}", exc_info=True)
def _prepare_embedding_text(self, memory: MemoryChunk) -> str:
"""准备用于嵌入的文本"""
# 构建包含丰富信息的文本
text_parts = [
memory.text_content,
f"类型: {memory.memory_type.value}",
f"关键词: {', '.join(memory.keywords)}",
f"标签: {', '.join(memory.tags)}"
]
if memory.metadata.emotional_context:
text_parts.append(f"情感: {memory.metadata.emotional_context}")
return " | ".join(text_parts)
async def _batch_generate_and_store_embeddings(self, memory_texts: List[Tuple[str, str]]):
"""批量生成和存储嵌入向量"""
if not memory_texts:
return
try:
texts = [text for _, text in memory_texts]
memory_ids = [memory_id for memory_id, _ in memory_texts]
# 批量生成嵌入向量
embeddings = await self._batch_generate_embeddings(texts)
# 存储向量和记忆
for memory_id, embedding in zip(memory_ids, embeddings):
if embedding and len(embedding) == self.config.dimension:
memory = self.memory_cache.get(memory_id)
if memory:
await self._add_single_memory(memory, embedding)
except Exception as e:
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
async def _batch_generate_embeddings(self, texts: List[str]) -> List[List[float]]:
"""批量生成嵌入向量"""
if not texts:
return []
try:
# 创建新的事件循环来运行异步操作
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 使用线程池并行生成嵌入向量
with ThreadPoolExecutor(max_workers=min(4, len(texts))) as executor:
tasks = []
for text in texts:
task = loop.run_in_executor(
executor,
self._generate_single_embedding,
text
)
tasks.append(task)
embeddings = await asyncio.gather(*tasks, return_exceptions=True)
# 处理结果
valid_embeddings = []
for i, embedding in enumerate(embeddings):
if isinstance(embedding, Exception):
logger.warning(f"生成第 {i} 个文本的嵌入向量失败: {embedding}")
valid_embeddings.append([])
elif embedding and len(embedding) == self.config.dimension:
valid_embeddings.append(embedding)
else:
logger.warning(f"{i} 个文本的嵌入向量格式异常")
valid_embeddings.append([])
return valid_embeddings
finally:
loop.close()
except Exception as e:
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
return [[] for _ in texts]
def _generate_single_embedding(self, text: str) -> List[float]:
"""生成单个文本的嵌入向量"""
try:
# 创建新的事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 使用模型生成嵌入向量
embedding, _ = loop.run_until_complete(
self.embedding_model.get_embedding(text)
)
if embedding and len(embedding) == self.config.dimension:
return embedding
else:
logger.warning(f"嵌入向量维度不匹配: 期望 {self.config.dimension}, 实际 {len(embedding) if embedding else 0}")
return []
finally:
loop.close()
except Exception as e:
logger.error(f"生成嵌入向量失败: {e}")
return []
async def _add_single_memory(self, memory: MemoryChunk, embedding: List[float]):
"""添加单个记忆到向量存储"""
with self._lock:
try:
# 规范化向量
if embedding:
embedding = self._normalize_vector(embedding)
# 添加到缓存
self.memory_cache[memory.memory_id] = memory
self.vector_cache[memory.memory_id] = embedding
# 更新记忆的嵌入向量
memory.set_embedding(embedding)
# 添加到向量索引
if hasattr(self.vector_index, 'add'):
# FAISS索引
if isinstance(embedding, np.ndarray):
vector_array = embedding.reshape(1, -1).astype('float32')
else:
vector_array = np.array([embedding], dtype='float32')
# 特殊处理IVF索引
if self.config.index_type == "ivf" and self.vector_index.ntotal == 0:
# IVF索引需要先训练
logger.debug("训练IVF索引...")
self.vector_index.train(vector_array)
self.vector_index.add(vector_array)
index_id = self.vector_index.ntotal - 1
else:
# 简单索引
index_id = self.vector_index.add_vector(embedding)
# 更新映射关系
self.memory_id_to_index[memory.memory_id] = index_id
self.index_to_memory_id[index_id] = memory.memory_id
# 更新统计
self.storage_stats["total_vectors"] += 1
except Exception as e:
logger.error(f"❌ 添加记忆到向量存储失败: {e}")
def _normalize_vector(self, vector: List[float]) -> List[float]:
"""L2归一化向量"""
if not vector:
return vector
try:
vector_array = np.array(vector, dtype=np.float32)
norm = np.linalg.norm(vector_array)
if norm == 0:
return vector
normalized = vector_array / norm
return normalized.tolist()
except Exception as e:
logger.warning(f"向量归一化失败: {e}")
return vector
async def search_similar_memories(
self,
query_vector: List[float],
limit: int = 10,
user_id: Optional[str] = None
) -> List[Tuple[str, float]]:
"""搜索相似记忆"""
start_time = time.time()
try:
# 规范化查询向量
query_vector = self._normalize_vector(query_vector)
# 执行向量搜索
with self._lock:
if hasattr(self.vector_index, 'search'):
# FAISS索引
if isinstance(query_vector, np.ndarray):
query_array = query_vector.reshape(1, -1).astype('float32')
else:
query_array = np.array([query_vector], dtype='float32')
if self.config.index_type == "ivf" and self.vector_index.ntotal > 0:
# 设置IVF搜索参数
nprobe = min(self.vector_index.nlist, 10)
self.vector_index.nprobe = nprobe
distances, indices = self.vector_index.search(query_array, min(limit, self.storage_stats["total_vectors"]))
distances = distances.flatten().tolist()
indices = indices.flatten().tolist()
else:
# 简单索引
results = self.vector_index.search(query_vector, limit)
distances = [score for _, score in results]
indices = [idx for idx, _ in results]
# 处理搜索结果
results = []
for distance, index in zip(distances, indices):
if index == -1: # FAISS的无效索引标记
continue
memory_id = self.index_to_memory_id.get(index)
if memory_id:
# 应用用户过滤
if user_id:
memory = self.memory_cache.get(memory_id)
if memory and memory.user_id != user_id:
continue
similarity = max(0.0, min(1.0, distance)) # 确保在0-1范围内
results.append((memory_id, similarity))
# 更新统计
search_time = time.time() - start_time
self.storage_stats["total_searches"] += 1
self.storage_stats["average_search_time"] = (
(self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time) /
self.storage_stats["total_searches"]
)
return results[:limit]
except Exception as e:
logger.error(f"❌ 向量搜索失败: {e}")
return []
async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
"""根据ID获取记忆"""
# 先检查缓存
if memory_id in self.memory_cache:
self.storage_stats["cache_hits"] += 1
return self.memory_cache[memory_id]
self.storage_stats["total_searches"] += 1
return None
async def update_memory_embedding(self, memory_id: str, new_embedding: List[float]):
"""更新记忆的嵌入向量"""
with self._lock:
try:
if memory_id not in self.memory_id_to_index:
logger.warning(f"记忆 {memory_id} 不存在于向量索引中")
return
# 获取旧索引
old_index = self.memory_id_to_index[memory_id]
# 删除旧向量(如果支持)
if hasattr(self.vector_index, 'remove_ids'):
try:
self.vector_index.remove_ids(np.array([old_index]))
except:
logger.warning("无法删除旧向量,将直接添加新向量")
# 规范化新向量
new_embedding = self._normalize_vector(new_embedding)
# 添加新向量
if hasattr(self.vector_index, 'add'):
if isinstance(new_embedding, np.ndarray):
vector_array = new_embedding.reshape(1, -1).astype('float32')
else:
vector_array = np.array([new_embedding], dtype='float32')
self.vector_index.add(vector_array)
new_index = self.vector_index.ntotal - 1
else:
new_index = self.vector_index.add_vector(new_embedding)
# 更新映射关系
self.memory_id_to_index[memory_id] = new_index
self.index_to_memory_id[new_index] = memory_id
# 更新缓存
self.vector_cache[memory_id] = new_embedding
# 更新记忆对象
memory = self.memory_cache.get(memory_id)
if memory:
memory.set_embedding(new_embedding)
logger.debug(f"更新记忆 {memory_id} 的嵌入向量")
except Exception as e:
logger.error(f"❌ 更新记忆嵌入向量失败: {e}")
async def delete_memory(self, memory_id: str):
"""删除记忆"""
with self._lock:
try:
if memory_id not in self.memory_id_to_index:
return
# 获取索引
index = self.memory_id_to_index[memory_id]
# 从向量索引中删除(如果支持)
if hasattr(self.vector_index, 'remove_ids'):
try:
self.vector_index.remove_ids(np.array([index]))
except:
logger.warning("无法从向量索引中删除,仅从缓存中移除")
# 删除映射关系
del self.memory_id_to_index[memory_id]
if index in self.index_to_memory_id:
del self.index_to_memory_id[index]
# 从缓存中删除
self.memory_cache.pop(memory_id, None)
self.vector_cache.pop(memory_id, None)
# 更新统计
self.storage_stats["total_vectors"] = max(0, self.storage_stats["total_vectors"] - 1)
logger.debug(f"删除记忆 {memory_id}")
except Exception as e:
logger.error(f"❌ 删除记忆失败: {e}")
async def save_storage(self):
"""保存向量存储到文件"""
try:
logger.info("正在保存向量存储...")
# 保存记忆缓存
cache_data = {
memory_id: memory.to_dict()
for memory_id, memory in self.memory_cache.items()
}
cache_file = self.storage_path / "memory_cache.json"
with open(cache_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
# 保存向量缓存
vector_cache_file = self.storage_path / "vector_cache.json"
with open(vector_cache_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode('utf-8'))
# 保存映射关系
mapping_file = self.storage_path / "id_mapping.json"
mapping_data = {
"memory_id_to_index": self.memory_id_to_index,
"index_to_memory_id": self.index_to_memory_id
}
with open(mapping_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
# 保存FAISS索引如果可用
if FAISS_AVAILABLE and hasattr(self.vector_index, 'save'):
index_file = self.storage_path / "vector_index.faiss"
faiss.write_index(self.vector_index, str(index_file))
# 保存统计信息
stats_file = self.storage_path / "storage_stats.json"
with open(stats_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode('utf-8'))
logger.info("✅ 向量存储保存完成")
except Exception as e:
logger.error(f"❌ 保存向量存储失败: {e}")
async def load_storage(self):
"""从文件加载向量存储"""
try:
logger.info("正在加载向量存储...")
# 加载记忆缓存
cache_file = self.storage_path / "memory_cache.json"
if cache_file.exists():
with open(cache_file, 'r', encoding='utf-8') as f:
cache_data = orjson.loads(f.read())
self.memory_cache = {
memory_id: MemoryChunk.from_dict(memory_data)
for memory_id, memory_data in cache_data.items()
}
# 加载向量缓存
vector_cache_file = self.storage_path / "vector_cache.json"
if vector_cache_file.exists():
with open(vector_cache_file, 'r', encoding='utf-8') as f:
self.vector_cache = orjson.loads(f.read())
# 加载映射关系
mapping_file = self.storage_path / "id_mapping.json"
if mapping_file.exists():
with open(mapping_file, 'r', encoding='utf-8') as f:
mapping_data = orjson.loads(f.read())
self.memory_id_to_index = mapping_data.get("memory_id_to_index", {})
self.index_to_memory_id = mapping_data.get("index_to_memory_id", {})
# 加载FAISS索引如果可用
if FAISS_AVAILABLE:
index_file = self.storage_path / "vector_index.faiss"
if index_file.exists() and hasattr(self.vector_index, 'load'):
try:
loaded_index = faiss.read_index(str(index_file))
# 如果索引类型匹配,则替换
if type(loaded_index) == type(self.vector_index):
self.vector_index = loaded_index
logger.info("✅ FAISS索引加载完成")
else:
logger.warning("索引类型不匹配,重新构建索引")
await self._rebuild_index()
except Exception as e:
logger.warning(f"加载FAISS索引失败: {e},重新构建")
await self._rebuild_index()
# 加载统计信息
stats_file = self.storage_path / "storage_stats.json"
if stats_file.exists():
with open(stats_file, 'r', encoding='utf-8') as f:
self.storage_stats = orjson.loads(f.read())
# 更新向量计数
self.storage_stats["total_vectors"] = len(self.memory_id_to_index)
logger.info(f"✅ 向量存储加载完成,{self.storage_stats['total_vectors']} 个向量")
except Exception as e:
logger.error(f"❌ 加载向量存储失败: {e}")
async def _rebuild_index(self):
"""重建向量索引"""
try:
logger.info("正在重建向量索引...")
# 重新初始化索引
self._initialize_index()
# 重新添加所有向量
for memory_id, embedding in self.vector_cache.items():
if embedding:
memory = self.memory_cache.get(memory_id)
if memory:
await self._add_single_memory(memory, embedding)
logger.info("✅ 向量索引重建完成")
except Exception as e:
logger.error(f"❌ 重建向量索引失败: {e}")
async def optimize_storage(self):
"""优化存储"""
try:
logger.info("开始向量存储优化...")
# 清理无效引用
self._cleanup_invalid_references()
# 重新构建索引(如果碎片化严重)
if self.storage_stats["total_vectors"] > 1000:
await self._rebuild_index()
# 更新缓存命中率
if self.storage_stats["total_searches"] > 0:
self.storage_stats["cache_hit_rate"] = (
self.storage_stats["cache_hits"] / self.storage_stats["total_searches"]
)
logger.info("✅ 向量存储优化完成")
except Exception as e:
logger.error(f"❌ 向量存储优化失败: {e}")
def _cleanup_invalid_references(self):
"""清理无效引用"""
with self._lock:
# 清理无效的memory_id到index的映射
valid_memory_ids = set(self.memory_cache.keys())
invalid_memory_ids = set(self.memory_id_to_index.keys()) - valid_memory_ids
for memory_id in invalid_memory_ids:
index = self.memory_id_to_index[memory_id]
del self.memory_id_to_index[memory_id]
if index in self.index_to_memory_id:
del self.index_to_memory_id[index]
if invalid_memory_ids:
logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用")
def get_storage_stats(self) -> Dict[str, Any]:
"""获取存储统计信息"""
stats = self.storage_stats.copy()
if stats["total_searches"] > 0:
stats["cache_hit_rate"] = stats["cache_hits"] / stats["total_searches"]
else:
stats["cache_hit_rate"] = 0.0
return stats
class SimpleVectorIndex:
"""简单的向量索引实现当FAISS不可用时的替代方案"""
def __init__(self, dimension: int):
self.dimension = dimension
self.vectors: List[List[float]] = []
self.vector_ids: List[int] = []
self.next_id = 0
def add_vector(self, vector: List[float]) -> int:
"""添加向量"""
if len(vector) != self.dimension:
raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}")
vector_id = self.next_id
self.vectors.append(vector.copy())
self.vector_ids.append(vector_id)
self.next_id += 1
return vector_id
def search(self, query_vector: List[float], limit: int) -> List[Tuple[int, float]]:
"""搜索相似向量"""
if len(query_vector) != self.dimension:
raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}")
results = []
for i, vector in enumerate(self.vectors):
similarity = self._calculate_cosine_similarity(query_vector, vector)
results.append((self.vector_ids[i], similarity))
# 按相似度排序
results.sort(key=lambda x: x[1], reverse=True)
return results[:limit]
def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float:
"""计算余弦相似度"""
try:
dot_product = sum(x * y for x, y in zip(v1, v2))
norm1 = sum(x * x for x in v1) ** 0.5
norm2 = sum(x * x for x in v2) ** 0.5
if norm1 == 0 or norm2 == 0:
return 0.0
return dot_product / (norm1 * norm2)
except Exception:
return 0.0
@property
def ntotal(self) -> int:
"""向量总数"""
return len(self.vectors)

View File

@@ -28,8 +28,8 @@ from src.chat.utils.chat_message_builder import (
replace_user_references_sync,
)
from src.chat.express.expression_selector import expression_selector
from src.chat.memory_system.memory_activator import MemoryActivator
from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2
# 旧记忆系统已被移除
# 旧记忆系统已被移除
from src.mood.mood_manager import mood_manager
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.base.component_types import ActionInfo, EventType
@@ -231,9 +231,12 @@ class DefaultReplyer:
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
self.heart_fc_sender = HeartFCSender()
self.memory_activator = MemoryActivator()
# 使用纯向量瞬时记忆系统V2支持自定义保留时间
self.instant_memory = VectorInstantMemoryV2(chat_id=self.chat_stream.stream_id, retention_hours=1)
# 使用新的增强记忆系统
# from src.chat.memory_system.enhanced_memory_activator import EnhancedMemoryActivator
# self.memory_activator = EnhancedMemoryActivator()
self.memory_activator = None # 暂时禁用记忆激活器
# 旧的即时记忆系统已被移除,现在使用增强记忆系统
# self.instant_memory = VectorInstantMemoryV2(chat_id=self.chat_stream.stream_id, retention_hours=1)
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖
@@ -459,90 +462,65 @@ class DefaultReplyer:
instant_memory = None
running_memories = await self.memory_activator.activate_memory_with_chat_history(
target_message=target, chat_history_prompt=chat_history
)
# 使用新的增强记忆系统检索记忆
running_memories = []
instant_memory = None
if global_config.memory.enable_instant_memory:
# 使用异步记忆包装器(最优化的非阻塞模式)
try:
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
# 使用新的增强记忆系统
from src.chat.memory_system.enhanced_memory_integration import recall_memories, remember_message
# 获取异步记忆包装器
async_memory = get_async_instant_memory(self.chat_stream.stream_id)
# 后台存储聊天历史(完全非阻塞)
async_memory.store_memory_background(chat_history)
# 快速检索记忆最大超时2秒
instant_memory = await async_memory.get_memory_with_fallback(target, max_timeout=2.0)
logger.info(f"异步瞬时记忆:{instant_memory}")
except ImportError:
# 如果异步包装器不可用,尝试使用异步记忆管理器
try:
from src.chat.memory_system.async_memory_optimizer import (
retrieve_memory_nonblocking,
store_memory_nonblocking,
# 异步存储聊天历史(非阻塞)
asyncio.create_task(
remember_message(
message=chat_history,
user_id=str(self.chat_stream.stream_id),
chat_id=self.chat_stream.stream_id
)
)
# 异步存储聊天历史(非阻塞)
asyncio.create_task(
store_memory_nonblocking(chat_id=self.chat_stream.stream_id, content=chat_history)
)
# 检索相关记忆
enhanced_memories = await recall_memories(
query=target,
user_id=str(self.chat_stream.stream_id),
chat_id=self.chat_stream.stream_id
)
# 尝试从缓存获取瞬时记忆
instant_memory = await retrieve_memory_nonblocking(chat_id=self.chat_stream.stream_id, query=target)
# 转换格式以兼容现有代码
running_memories = []
if enhanced_memories and enhanced_memories.get("has_memories"):
for memory in enhanced_memories.get("memories", []):
running_memories.append({
'content': memory.get("content", ""),
'score': memory.get("confidence", 0.0),
'memory_type': memory.get("type", "unknown")
})
# 如果没有缓存结果,快速检索一次
if instant_memory is None:
try:
instant_memory = await asyncio.wait_for(
self.instant_memory.get_memory_for_context(target), timeout=1.5
)
except asyncio.TimeoutError:
logger.warning("瞬时记忆检索超时,使用空结果")
instant_memory = ""
# 构建瞬时记忆字符串
if enhanced_memories and enhanced_memories.get("has_memories"):
instant_memory = "\\n".join([
f"{memory.get('content', '')} (相似度: {memory.get('confidence', 0.0):.2f})"
for memory in enhanced_memories.get("memories", [])[:3] # 取前3条
])
logger.info(f"向量瞬时记忆:{instant_memory}")
except ImportError:
# 最后的fallback使用原有逻辑但加上超时控制
logger.warning("异步记忆系统不可用,使用带超时的同步方式")
# 异步存储聊天历史
asyncio.create_task(self.instant_memory.store_message(chat_history))
# 带超时的记忆检索
try:
instant_memory = await asyncio.wait_for(
self.instant_memory.get_memory_for_context(target),
timeout=1.0, # 最保守的1秒超时
)
except asyncio.TimeoutError:
logger.warning("瞬时记忆检索超时,跳过记忆获取")
instant_memory = ""
except Exception as e:
logger.error(f"瞬时记忆检索失败: {e}")
instant_memory = ""
logger.info(f"同步瞬时记忆:{instant_memory}")
logger.info(f"增强记忆系统检索到 {len(running_memories)} 条记忆")
except Exception as e:
logger.error(f"瞬时记忆系统异常: {e}")
logger.warning(f"增强记忆系统检索失败: {e}")
running_memories = []
instant_memory = ""
# 构建记忆字符串,即使某种记忆为空也要继续
memory_str = ""
has_any_memory = False
# 添加长期记忆
# 添加长期记忆(来自增强记忆系统)
if running_memories:
if not memory_str:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memories:
memory_str += f"- {running_memory['content']}\n"
memory_str += f"- {running_memory['content']} (类型: {running_memory['memory_type']}, 相似度: {running_memory['score']:.2f})\n"
has_any_memory = True
# 添加瞬时记忆

View File

@@ -371,28 +371,35 @@ class Prompt:
tasks.append(self._build_cross_context())
task_names.append("cross_context")
# 性能优化
base_timeout = 10.0
task_timeout = 2.0
timeout_seconds = min(
max(base_timeout, len(tasks) * task_timeout),
30.0,
)
# 性能优化 - 为不同任务设置不同的超时时间
task_timeouts = {
"memory_block": 5.0, # 记忆系统可能较慢,单独设置超时
"tool_info": 3.0, # 工具信息中等速度
"relation_info": 2.0, # 关系信息通常较快
"knowledge_info": 3.0, # 知识库查询中等速度
"cross_context": 2.0, # 上下文处理通常较快
"expression_habits": 1.5, # 表达习惯处理很快
}
max_concurrent_tasks = 5
if len(tasks) > max_concurrent_tasks:
results = []
for i in range(0, len(tasks), max_concurrent_tasks):
batch_tasks = tasks[i : i + max_concurrent_tasks]
# 分别处理每个任务,避免慢任务影响快任务
results = []
for i, task in enumerate(tasks):
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
task_timeout = task_timeouts.get(task_name, 2.0) # 默认2秒
batch_results = await asyncio.wait_for(
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
)
results.extend(batch_results)
else:
results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
)
try:
result = await asyncio.wait_for(task, timeout=task_timeout)
results.append(result)
logger.debug(f"构建任务{task_name}完成 ({task_timeout}s)")
except asyncio.TimeoutError:
logger.warning(f"构建任务{task_name}超时 ({task_timeout}s),使用默认值")
# 为超时任务提供默认值
default_result = self._get_default_result_for_task(task_name)
results.append(default_result)
except Exception as e:
logger.error(f"构建任务{task_name}失败: {str(e)}")
default_result = self._get_default_result_for_task(task_name)
results.append(default_result)
# 处理结果
context_data = {}
@@ -528,8 +535,7 @@ class Prompt:
return {"memory_block": ""}
try:
from src.chat.memory_system.memory_activator import MemoryActivator
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
from src.chat.memory_system.enhanced_memory_activator import enhanced_memory_activator
# 获取聊天历史
chat_history = ""
@@ -539,15 +545,38 @@ class Prompt:
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
# 激活长期记忆
memory_activator = MemoryActivator()
running_memories = await memory_activator.activate_memory_with_chat_history(
target_message=self.parameters.target, chat_history_prompt=chat_history
)
# 并行执行记忆查询以提高性能
import asyncio
# 获取即时记忆
async_memory_wrapper = get_async_instant_memory(self.parameters.chat_id)
instant_memory = await async_memory_wrapper.get_memory_with_fallback(self.parameters.target)
# 创建记忆查询任务
memory_tasks = [
enhanced_memory_activator.activate_memory_with_chat_history(
target_message=self.parameters.target, chat_history_prompt=chat_history
),
enhanced_memory_activator.get_instant_memory(
target_message=self.parameters.target, chat_id=self.parameters.chat_id
)
]
# 等待所有记忆查询完成最多3秒
try:
running_memories, instant_memory = await asyncio.wait_for(
asyncio.gather(*memory_tasks, return_exceptions=True),
timeout=3.0
)
# 处理可能的异常结果
if isinstance(running_memories, Exception):
logger.warning(f"长期记忆查询失败: {running_memories}")
running_memories = []
if isinstance(instant_memory, Exception):
logger.warning(f"即时记忆查询失败: {instant_memory}")
instant_memory = None
except asyncio.TimeoutError:
logger.warning("记忆查询超时,使用部分结果")
running_memories = []
instant_memory = None
# 构建记忆块
memory_parts = []
@@ -870,6 +899,32 @@ class Prompt:
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
def _get_default_result_for_task(self, task_name: str) -> Dict[str, Any]:
"""
为超时的任务提供默认结果
Args:
task_name: 任务名称
Returns:
Dict: 默认结果
"""
defaults = {
"memory_block": {"memory_block": ""},
"tool_info": {"tool_info_block": ""},
"relation_info": {"relation_info_block": ""},
"knowledge_info": {"knowledge_prompt": ""},
"cross_context": {"cross_context_block": ""},
"expression_habits": {"expression_habits_block": ""},
}
if task_name in defaults:
logger.info(f"为超时任务 {task_name} 提供默认值")
return defaults[task_name]
else:
logger.warning(f"未知任务类型 {task_name},返回空结果")
return {}
@staticmethod
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
"""

View File

@@ -459,6 +459,40 @@ class MemoryConfig(ValidatedConfigBase):
enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆")
enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆")
# 增强记忆系统配置
enable_enhanced_memory: bool = Field(default=True, description="启用增强记忆系统")
enhanced_memory_auto_save: bool = Field(default=True, description="自动保存增强记忆")
# 记忆构建配置
min_memory_length: int = Field(default=10, description="最小记忆长度")
max_memory_length: int = Field(default=500, description="最大记忆长度")
memory_value_threshold: float = Field(default=0.7, description="记忆价值阈值")
# 向量存储配置
vector_dimension: int = Field(default=768, description="向量维度")
vector_similarity_threshold: float = Field(default=0.8, description="向量相似度阈值")
# 多阶段检索配置
metadata_filter_limit: int = Field(default=100, description="元数据过滤阶段返回数量")
vector_search_limit: int = Field(default=50, description="向量搜索阶段返回数量")
semantic_rerank_limit: int = Field(default=20, description="语义重排序阶段返回数量")
final_result_limit: int = Field(default=10, description="最终结果数量")
# 检索权重配置
vector_weight: float = Field(default=0.4, description="向量相似度权重")
semantic_weight: float = Field(default=0.3, description="语义相似度权重")
context_weight: float = Field(default=0.2, description="上下文权重")
recency_weight: float = Field(default=0.1, description="时效性权重")
# 记忆融合配置
fusion_similarity_threshold: float = Field(default=0.85, description="融合相似度阈值")
deduplication_window_hours: int = Field(default=24, description="去重时间窗口(小时)")
# 缓存配置
enable_memory_cache: bool = Field(default=True, description="启用记忆缓存")
cache_ttl_seconds: int = Field(default=300, description="缓存生存时间(秒)")
max_cache_size: int = Field(default=1000, description="最大缓存大小")
class MoodConfig(ValidatedConfigBase):
"""情绪配置类"""

View File

@@ -34,54 +34,8 @@ from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
# 导入消息API和traceback模块
from src.common.message import get_global_api
from src.chat.memory_system.Hippocampus import hippocampus_manager
if not global_config.memory.enable_memory:
import src.chat.memory_system.Hippocampus as hippocampus_module
class MockHippocampusManager:
def initialize(self):
pass
def get_hippocampus(self):
return None
async def build_memory(self):
pass
async def forget_memory(self, percentage: float = 0.005):
pass
async def consolidate_memory(self):
pass
async def get_memory_from_text(
self,
text: str,
max_memory_num: int = 3,
max_memory_length: int = 2,
max_depth: int = 3,
fast_retrieval: bool = False,
) -> list:
return []
async def get_memory_from_topic(
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
) -> list:
return []
async def get_activate_from_text(
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
) -> tuple[float, list[str]]:
return 0.0, []
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
return []
def get_all_node_names(self) -> list:
return []
hippocampus_module.hippocampus_manager = MockHippocampusManager()
# 导入增强记忆系统管理器
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
# 插件系统现在使用统一的插件加载器
@@ -106,7 +60,8 @@ def _task_done_callback(task: asyncio.Task, message_id: str, start_time: float):
class MainSystem:
def __init__(self):
self.hippocampus_manager = hippocampus_manager
# 使用增强记忆系统
self.enhanced_memory_manager = enhanced_memory_manager
self.individuality: Individuality = get_individuality()
@@ -169,19 +124,18 @@ class MainSystem:
logger.error(f"停止热重载系统时出错: {e}")
try:
# 停止异步记忆管理器
# 停止增强记忆系统
if global_config.memory.enable_memory:
from src.chat.memory_system.async_memory_optimizer import async_memory_manager
import asyncio
loop = asyncio.get_event_loop()
if loop.is_running():
asyncio.create_task(async_memory_manager.shutdown())
asyncio.create_task(self.enhanced_memory_manager.shutdown())
else:
loop.run_until_complete(async_memory_manager.shutdown())
logger.info("🛑 记忆管理器已停止")
loop.run_until_complete(self.enhanced_memory_manager.shutdown())
logger.info("🛑 增强记忆系统已停止")
except Exception as e:
logger.error(f"停止记忆管理器时出错: {e}")
logger.error(f"停止增强记忆系统时出错: {e}")
async def _message_process_wrapper(self, message_data: Dict[str, Any]):
"""并行处理消息的包装器"""
@@ -304,9 +258,11 @@ MoFox_Bot(第三方修改版)
logger.info("聊天管理器初始化成功")
# 初始化记忆系统
self.hippocampus_manager.initialize()
logger.info("记忆系统初始化成功")
# 初始化增强记忆系统
await self.enhanced_memory_manager.initialize()
logger.info("增强记忆系统初始化成功")
# 老记忆系统已完全删除
# 初始化LPMM知识库
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
@@ -314,14 +270,8 @@ MoFox_Bot(第三方修改版)
initialize_lpmm_knowledge()
logger.info("LPMM知识库初始化成功")
# 初始化异步记忆管理器
try:
from src.chat.memory_system.async_memory_optimizer import async_memory_manager
await async_memory_manager.initialize()
logger.info("记忆管理器初始化成功")
except Exception as e:
logger.error(f"记忆管理器初始化失败: {e}")
# 异步记忆管理器已禁用,增强记忆系统有内置的优化机制
logger.info("异步记忆管理器已禁用 - 使用增强记忆系统内置优化")
# await asyncio.sleep(0.5) #防止logger输出飞了
@@ -376,81 +326,12 @@ MoFox_Bot(第三方修改版)
self.server.run(),
]
# 添加记忆系统相关任务
tasks.extend(
[
self.build_memory_task(),
self.forget_memory_task(),
self.consolidate_memory_task(),
]
)
# 增强记忆系统不需要定时任务,已禁用原有记忆系统的定时任务
logger.info("原有记忆系统定时任务已禁用 - 使用增强记忆系统")
await asyncio.gather(*tasks)
async def build_memory_task(self):
"""记忆构建任务"""
while True:
await asyncio.sleep(global_config.memory.memory_build_interval)
try:
# 使用异步记忆管理器进行非阻塞记忆构建
from src.chat.memory_system.async_memory_optimizer import build_memory_nonblocking
logger.info("正在启动记忆构建")
# 定义构建完成的回调函数
def build_completed(result):
if result:
logger.info("记忆构建完成")
else:
logger.warning("记忆构建失败")
# 启动异步构建,不等待完成
task_id = await build_memory_nonblocking()
logger.info(f"记忆构建任务已提交:{task_id}")
except ImportError:
# 如果异步优化器不可用,使用原有的同步方式(但在单独的线程中运行)
logger.warning("记忆优化器不可用,使用线性运行执行记忆构建")
def sync_build_memory():
"""在线程池中执行同步记忆构建"""
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(self.hippocampus_manager.build_memory())
logger.info("记忆构建完成")
return result
except Exception as e:
logger.error(f"记忆构建失败: {e}")
return None
finally:
loop.close()
# 在线程池中执行记忆构建
asyncio.get_event_loop().run_in_executor(None, sync_build_memory)
except Exception as e:
logger.error(f"记忆构建任务启动失败: {e}")
# fallback到原有的同步方式
logger.info("正在进行记忆构建(同步模式)")
await self.hippocampus_manager.build_memory() # type: ignore
async def forget_memory_task(self):
"""记忆遗忘任务"""
while True:
await asyncio.sleep(global_config.memory.forget_memory_interval)
logger.info("[记忆遗忘] 开始遗忘记忆...")
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore
logger.info("[记忆遗忘] 记忆遗忘完成")
async def consolidate_memory_task(self):
"""记忆整合任务"""
while True:
await asyncio.sleep(global_config.memory.consolidate_memory_interval)
logger.info("[记忆整合] 开始整合记忆...")
await self.hippocampus_manager.consolidate_memory() # type: ignore
logger.info("[记忆整合] 记忆整合完成")
# 老记忆系统的定时任务已删除 - 增强记忆系统使用内置的维护机制
async def main():

View File

@@ -2,7 +2,8 @@ import asyncio
import math
from typing import Tuple
from src.chat.memory_system.Hippocampus import hippocampus_manager
# 旧的Hippocampus系统已被移除现在使用增强记忆系统
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from maim_message.message_base import GroupInfo
from src.chat.message_receive.storage import MessageStorage
@@ -40,11 +41,31 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
if global_config.memory.enable_memory:
with Timer("记忆激活"):
interested_rate, _ = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text,
fast_retrieval=True,
)
logger.debug(f"记忆激活率: {interested_rate:.2f}")
# 使用新的增强记忆系统计算兴趣度
try:
from src.chat.memory_system.enhanced_memory_integration import recall_memories
# 检索相关记忆来估算兴趣度
enhanced_memories = await recall_memories(
query=message.processed_plain_text,
user_id=str(message.user_info.user_id),
chat_id=message.chat_id
)
# 基于检索结果计算兴趣度
if enhanced_memories:
# 有相关记忆,兴趣度基于相似度计算
max_score = max(score for _, score in enhanced_memories)
interested_rate = min(max_score, 1.0) # 限制在0-1之间
else:
# 没有相关记忆,给予基础兴趣度
interested_rate = 0.1
logger.debug(f"增强记忆系统兴趣度: {interested_rate:.2f}")
except Exception as e:
logger.warning(f"增强记忆系统兴趣度计算失败: {e}")
interested_rate = 0.1 # 默认基础兴趣度
text_len = len(message.processed_plain_text)
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算

View File

@@ -4,7 +4,8 @@ from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
import time
from src.chat.utils.utils import get_recent_group_speaker
from src.chat.memory_system.Hippocampus import hippocampus_manager
# 旧的Hippocampus系统已被移除现在使用增强记忆系统
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
import random
from datetime import datetime
import asyncio
@@ -171,16 +172,26 @@ class PromptBuilder:
@staticmethod
async def build_memory_block(text: str) -> str:
related_memory = await hippocampus_manager.get_memory_from_text(
text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
)
# 使用新的增强记忆系统检索记忆
try:
from src.chat.memory_system.enhanced_memory_integration import recall_memories
related_memory_info = ""
if related_memory:
for memory in related_memory:
related_memory_info += memory[1]
return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info)
return ""
enhanced_memories = await recall_memories(
query=text,
user_id="system", # 系统查询
chat_id="system"
)
related_memory_info = ""
if enhanced_memories and enhanced_memories.get("has_memories"):
for memory in enhanced_memories.get("memories", []):
related_memory_info += memory.get("content", "") + " "
return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info.strip())
return ""
except Exception as e:
logger.warning(f"增强记忆系统检索失败: {e}")
return ""
@staticmethod
async def build_chat_history_prompts(chat_stream: ChatStream, message: MessageRecvS4U):

View File

@@ -98,7 +98,6 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
message_recv = MessageRecv(new_message_dict)
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
logger.info(message_recv)
return message_recv

View File

@@ -11,7 +11,8 @@ from typing import Any, Dict, List, Optional
from json_repair import repair_json
from src.chat.memory_system.Hippocampus import hippocampus_manager
# 旧的Hippocampus系统已被移除现在使用增强记忆系统
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.chat.utils.chat_message_builder import (
build_readable_actions,
build_readable_messages_with_id,
@@ -602,14 +603,32 @@ class ChatterPlanFilter:
else:
keywords.append("晚上")
retrieved_memories = await hippocampus_manager.get_memory_from_topic(
valid_keywords=keywords, max_memory_num=5, max_memory_length=1
)
# 使用新的增强记忆系统检索记忆
try:
from src.chat.memory_system.enhanced_memory_integration import recall_memories
if not retrieved_memories:
# 将关键词转换为查询字符串
query = " ".join(keywords)
enhanced_memories = await recall_memories(
query=query,
user_id="system", # 系统查询
chat_id="system"
)
if not enhanced_memories:
return "最近没有什么特别的记忆。"
# 转换格式以兼容现有代码
retrieved_memories = []
if enhanced_memories and enhanced_memories.get("has_memories"):
for memory in enhanced_memories.get("memories", []):
retrieved_memories.append((memory.get("type", "unknown"), memory.get("content", "")))
memory_statements = [f"关于'{topic}', 你记得'{memory_item}'" for topic, memory_item in retrieved_memories]
except Exception as e:
logger.warning(f"增强记忆系统检索失败,使用默认回复: {e}")
return "最近没有什么特别的记忆。"
memory_statements = [f"关于'{topic}', 你记得'{memory_item}'" for topic, memory_item in retrieved_memories]
return " ".join(memory_statements)
except Exception as e:
logger.error(f"获取长期记忆时出错: {e}")