feat(memory): introduce a vector-based instant memory system

Introduces a new instant memory system that stores short-term conversation snippets (message collections) in a dedicated vector database, so retrieval can surface more immediate and relevant context.

The system is implemented through the following components:
- **MessageCollection**: a data structure that wraps a group of related messages.
- **MessageCollectionStorage**: embeds message collections, stores them in a dedicated ChromaDB collection, and manages their lifecycle (cleanup by count and by age).
- **MessageCollectionProcessor**: buffers messages per chat and, once a threshold is reached, combines them into a `MessageCollection` and hands it to the storage layer.
- **Integration**: `MemoryManager` and `MemorySystem` now blend instant memories into the existing memory-retrieval flow, giving priority to context from the current chat. A short usage sketch follows this list.
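
A minimal usage sketch of how these pieces fit together (assuming an already-initialized runtime where the embedding helper and `vector_db_service` are available; the class names and module paths are the ones added in this commit):

```python
import asyncio

from src.chat.memory_system.message_collection_processor import MessageCollectionProcessor
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage


async def demo():
    # Dedicated ChromaDB collection for short-term message collections
    storage = MessageCollectionStorage()
    processor = MessageCollectionProcessor(storage, buffer_size=5)

    # Messages are buffered per chat; the 5th message for a chat triggers
    # creation and storage of a MessageCollection.
    for text in ["hi", "what are we doing tonight?", "movie?", "sure", "8pm then"]:
        await processor.add_message(text, chat_id="chat_123")

    # Retrieval later pulls back the most relevant recent collections.
    for c in await storage.get_relevant_collection("plans for tonight", n_results=3):
        print(c.chat_id, c.combined_text)


asyncio.run(demo())
```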

In addition:
- Removed the `ensure_ascii=False` argument from `orjson.dumps` calls. `orjson` does not accept this keyword and always returns a UTF-8 byte string with non-ASCII characters preserved, which simplifies encoding handling; see the example after this list.
- Added configuration options for the maximum number of instant-memory collections and their retention time.
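
As a quick illustration of the `orjson` behavior the first point relies on (plain `orjson`, nothing project-specific):

```python
import orjson

data = {"用户": "小明", "心情": "开心"}

# orjson always serializes to UTF-8 bytes and leaves non-ASCII characters
# unescaped, so a plain .decode("utf-8") is enough; there is no
# ensure_ascii keyword to pass in the first place.
text = orjson.dumps(data).decode("utf-8")
print(text)  # {"用户":"小明","心情":"开心"}
```
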
Author: minecraft1024a
Date: 2025-10-25 20:07:58 +08:00
Parent: 3877772c7c
Commit: 917754b4e0
8 changed files with 431 additions and 6 deletions

View File

@@ -702,7 +702,7 @@ class MemoryBuilder:
if isinstance(value, list | dict):
try:
- value = orjson.dumps(value, ensure_ascii=False).decode("utf-8")
+ value = orjson.dumps(value).decode("utf-8")
except Exception:
value = str(value)

View File

@@ -454,7 +454,7 @@ class MemoryChunk:
def to_json(self) -> str:
"""转换为JSON字符串"""
- return orjson.dumps(self.to_dict(), ensure_ascii=False).decode("utf-8")
+ return orjson.dumps(self.to_dict()).decode("utf-8")
@classmethod
def from_json(cls, json_str: str) -> "MemoryChunk":
@@ -610,3 +610,38 @@ def create_memory_chunk(
chunk = MemoryChunk(metadata=metadata, content=content, memory_type=memory_type, **kwargs)
return chunk
@dataclass
class MessageCollection:
"""消息集合数据结构"""
collection_id: str = field(default_factory=lambda: str(uuid.uuid4()))
chat_id: str | None = None # 聊天ID(群聊或私聊)
messages: list[str] = field(default_factory=list)
combined_text: str = ""
created_at: float = field(default_factory=time.time)
embedding: list[float] | None = None
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式"""
return {
"collection_id": self.collection_id,
"chat_id": self.chat_id,
"messages": self.messages,
"combined_text": self.combined_text,
"created_at": self.created_at,
"embedding": self.embedding,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MessageCollection":
"""从字典创建实例"""
return cls(
collection_id=data.get("collection_id", str(uuid.uuid4())),
chat_id=data.get("chat_id"),
messages=data.get("messages", []),
combined_text=data.get("combined_text", ""),
created_at=data.get("created_at", time.time()),
embedding=data.get("embedding"),
)

View File

@@ -7,8 +7,10 @@ import re
from dataclasses import dataclass
from typing import Any
- from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
+ from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, MessageCollection
from src.chat.memory_system.memory_system import MemorySystem, initialize_memory_system
from src.chat.memory_system.message_collection_processor import MessageCollectionProcessor
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
from src.common.logger import get_logger
logger = get_logger(__name__)
@@ -33,6 +35,8 @@ class MemoryManager:
def __init__(self):
self.memory_system: MemorySystem | None = None
self.message_collection_storage: MessageCollectionStorage | None = None
self.message_collection_processor: MessageCollectionProcessor | None = None
self.is_initialized = False
self.user_cache = {} # 用户记忆缓存
@@ -69,6 +73,10 @@ class MemoryManager:
# 初始化记忆系统
self.memory_system = await initialize_memory_system(llm_model)
# 初始化消息集合系统
self.message_collection_storage = MessageCollectionStorage()
self.message_collection_processor = MessageCollectionProcessor(self.message_collection_storage)
self.is_initialized = True
logger.info(" 记忆系统初始化完成")
@@ -76,6 +84,8 @@ class MemoryManager:
logger.error(f"记忆系统初始化失败: {e}")
# 如果系统初始化失败,创建一个空的管理器避免系统崩溃
self.memory_system = None
self.message_collection_storage = None
self.message_collection_processor = None
self.is_initialized = True # 标记为已初始化但系统不可用
def get_hippocampus(self):
@@ -235,6 +245,11 @@ class MemoryManager:
return []
try:
# 将消息添加到消息集合处理器
chat_id = context.get("chat_id")
if self.message_collection_processor and chat_id:
await self.message_collection_processor.add_message(conversation_text, chat_id)
payload_context = dict(context or {})
payload_context.setdefault("conversation_text", conversation_text)
if timestamp is not None:
@@ -485,6 +500,60 @@ class MemoryManager:
return text
return text[: max_length - 1] + "…"
async def get_relevant_message_collection(self, query_text: str, n_results: int = 3) -> list[MessageCollection]:
"""获取相关的消息集合列表"""
if not self.is_initialized or not self.message_collection_storage:
return []
try:
return await self.message_collection_storage.get_relevant_collection(query_text, n_results=n_results)
except Exception as e:
logger.error(f"get_relevant_message_collection 失败: {e}")
return []
async def get_message_collection_context(self, query_text: str, chat_id: str) -> str:
"""获取消息集合上下文,用于添加到 prompt 中。优先展示当前聊天的上下文。"""
if not self.is_initialized or not self.message_collection_storage:
return ""
try:
collections = await self.get_relevant_message_collection(query_text, n_results=3)
if not collections:
return ""
# 根据传入的 chat_id 对集合进行排序
collections.sort(key=lambda c: c.chat_id == chat_id, reverse=True)
context_parts = []
for collection in collections:
if not collection.combined_text:
continue
header = "## 📝 相关对话上下文\n"
if collection.chat_id == chat_id:
# 匹配的ID使用更明显的标识
context_parts.append(
f"{header} [🔥 来自当前聊天的上下文]\n```\n{collection.combined_text}\n```"
)
else:
# 不匹配的ID
context_parts.append(
f"{header} [💡 来自其他聊天的相关上下文 (ID: {collection.chat_id})]\n```\n{collection.combined_text}\n```"
)
if not context_parts:
return ""
# 格式化消息集合为 prompt 上下文
final_context = "\n\n---\n\n".join(context_parts) + "\n\n---"
logger.info(f"🔗 为查询 '{query_text[:50]}...' 在聊天 '{chat_id}' 中找到 {len(collections)} 个相关消息集合上下文")
return f"\n{final_context}\n"
except Exception as e:
logger.error(f"get_message_collection_context 失败: {e}")
return ""
async def shutdown(self):
"""关闭增强记忆系统"""
if not self.is_initialized:

View File

@@ -19,6 +19,7 @@ from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractio
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
# 记忆采样模式枚举
@@ -142,6 +143,7 @@ class MemorySystem:
self.memory_builder: MemoryBuilder | None = None
self.fusion_engine: MemoryFusionEngine | None = None
self.unified_storage: VectorMemoryStorage | None = None # 统一存储系统
self.message_collection_storage: MessageCollectionStorage | None = None
self.query_planner: MemoryQueryPlanner | None = None
self.forgetting_engine: MemoryForgettingEngine | None = None
@@ -153,6 +155,7 @@ class MemorySystem:
self.total_memories = 0
self.last_build_time = None
self.last_retrieval_time = None
self.last_collection_cleanup_time: float = time.time()
# 构建节流记录
self._last_memory_build_times: dict[str, float] = {}
@@ -199,6 +202,9 @@ class MemorySystem:
self.memory_builder = MemoryBuilder(self.memory_extraction_model)
self.fusion_engine = MemoryFusionEngine(self.config.fusion_similarity_threshold)
# 初始化消息集合存储
self.message_collection_storage = MessageCollectionStorage()
# 初始化Vector DB存储系统替代旧的unified_memory_storage
from src.chat.memory_system.vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig
@@ -694,7 +700,7 @@ class MemorySystem:
limit: int = 5,
**kwargs,
) -> list[MemoryChunk]:
"""检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排)"""
"""检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排),并融合瞬时记忆"""
raw_query = query_text or kwargs.get("query")
if not raw_query:
raise ValueError("query_text 或 query 参数不能为空")
@@ -839,7 +845,22 @@ class MemorySystem:
scored_memories.sort(key=lambda x: x[1], reverse=True)
# 返回 Top-K
- final_memories = [mem for mem, score, details in scored_memories[:effective_limit]]
+ final_memories = [mem for mem, score, details in scored_memories]
# === 新增:融合瞬时记忆 ===
try:
chat_id = normalized_context.get("chat_id")
instant_memories = await self._retrieve_instant_memories(raw_query, chat_id)
if instant_memories:
# 将瞬时记忆放在列表最前面
final_memories = instant_memories + final_memories
logger.info(f"融合了 {len(instant_memories)} 条瞬时记忆")
except Exception as e:
logger.warning(f"检索瞬时记忆失败: {e}", exc_info=True)
# 最终截断
final_memories = final_memories[:effective_limit]
retrieval_time = time.time() - start_time
@@ -913,6 +934,61 @@ class MemorySystem:
logger.error(f"❌ 记忆检索失败: {e}", exc_info=True)
raise
async def _retrieve_instant_memories(self, query_text: str, chat_id: str | None) -> list[MemoryChunk]:
"""检索瞬时记忆消息集合并转换为MemoryChunk"""
if not self.message_collection_storage:
return []
# 1. 检索与查询最相关的消息集合
collections = await self.message_collection_storage.get_relevant_collection(query_text, n_results=3)
# 2. 优先使用来自当前聊天的集合,只保留最相关的一条
if chat_id:
    collections.sort(key=lambda c: c.chat_id == chat_id, reverse=True)
collections = collections[:1]
if not collections:
return []
# 3. 将 MessageCollection 转换为 MemoryChunk
instant_memories = []
for collection in collections:
from src.chat.memory_system.memory_chunk import (
ContentStructure,
ImportanceLevel,
MemoryMetadata,
MemoryType,
)
header = f"[来自群/聊 {collection.chat_id} 的近期对话]"
if collection.chat_id == chat_id:
header = f"[🔥 来自当前聊天的近期对话]"
display_text = f"{header}\n---\n{collection.combined_text}"
metadata = MemoryMetadata(
memory_id=f"instant_{collection.collection_id}",
user_id=GLOBAL_MEMORY_SCOPE,
chat_id=collection.chat_id,
created_at=collection.created_at,
importance=ImportanceLevel.HIGH, # 瞬时记忆通常具有高重要性
)
content = ContentStructure(
subject="对话上下文",
predicate="包含",
object=collection.combined_text,
display=display_text
)
chunk = MemoryChunk(
metadata=metadata,
content=content,
memory_type=MemoryType.CONTEXTUAL,
)
instant_memories.append(chunk)
return instant_memories
@staticmethod
def _extract_json_payload(response: str) -> str | None:
"""从模型响应中提取JSON部分兼容Markdown代码块等格式"""
@@ -1521,6 +1597,15 @@ class MemorySystem:
if self.fusion_engine:
await self.fusion_engine.maintenance()
# 清理消息集合每12小时
if self.message_collection_storage:
current_time = time.time()
if current_time - self.last_collection_cleanup_time > 12 * 3600:
logger.info("开始清理过期的消息集合...")
self.message_collection_storage.clear_all()
self.last_collection_cleanup_time = current_time
logger.info("✅ 消息集合清理完成")
logger.info("✅ 简化记忆系统维护完成")
except Exception as e:

View File

@@ -0,0 +1,75 @@
"""
消息集合处理器
负责收集消息、创建集合并将其存入向量存储。
"""
import asyncio
from collections import deque
from typing import Any
from src.chat.memory_system.memory_chunk import MessageCollection
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
from src.common.logger import get_logger
logger = get_logger(__name__)
class MessageCollectionProcessor:
"""处理消息集合的创建和存储"""
def __init__(self, storage: MessageCollectionStorage, buffer_size: int = 5):
self.storage = storage
self.buffer_size = buffer_size
self.message_buffers: dict[str, deque[str]] = {}
self._lock = asyncio.Lock()
async def add_message(self, message_text: str, chat_id: str):
"""添加一条新消息到指定聊天的缓冲区,并在满时触发处理"""
async with self._lock:
if not isinstance(message_text, str) or not message_text.strip():
return
if chat_id not in self.message_buffers:
self.message_buffers[chat_id] = deque(maxlen=self.buffer_size)
buffer = self.message_buffers[chat_id]
buffer.append(message_text)
logger.debug(f"消息已添加到聊天 '{chat_id}' 的缓冲区,当前数量: {len(buffer)}/{self.buffer_size}")
if len(buffer) == self.buffer_size:
await self._process_buffer(chat_id)
async def _process_buffer(self, chat_id: str):
"""处理指定聊天缓冲区中的消息,创建并存储一个集合"""
buffer = self.message_buffers.get(chat_id)
if not buffer or len(buffer) < self.buffer_size:
return
messages_to_process = list(buffer)
buffer.clear()
logger.info(f"聊天 '{chat_id}' 的消息缓冲区已满,开始创建消息集合...")
try:
combined_text = "\n".join(messages_to_process)
collection = MessageCollection(
chat_id=chat_id,
messages=messages_to_process,
combined_text=combined_text,
)
await self.storage.add_collection(collection)
logger.info(f"成功为聊天 '{chat_id}' 创建并存储了新的消息集合: {collection.collection_id}")
except Exception as e:
logger.error(f"处理聊天 '{chat_id}' 的消息缓冲区失败: {e}", exc_info=True)
def get_stats(self) -> dict[str, Any]:
"""获取处理器统计信息"""
total_buffered_messages = sum(len(buf) for buf in self.message_buffers.values())
return {
"active_buffers": len(self.message_buffers),
"total_buffered_messages": total_buffered_messages,
"buffer_capacity_per_chat": self.buffer_size,
}

View File

@@ -0,0 +1,155 @@
"""
消息集合向量存储系统
专用于存储和检索消息集合,以提供即时上下文。
"""
import asyncio
import time
from typing import Any
from src.chat.memory_system.memory_chunk import MessageCollection
from src.chat.utils.utils import get_embedding
from src.common.logger import get_logger
from src.common.vector_db import vector_db_service
from src.config.config import global_config
logger = get_logger(__name__)
class MessageCollectionStorage:
"""消息集合向量存储"""
def __init__(self):
self.config = global_config.memory
self.vector_db_service = vector_db_service
self.collection_name = "message_collections"
self._initialize_storage()
def _initialize_storage(self):
"""初始化存储"""
try:
self.vector_db_service.get_or_create_collection(
name=self.collection_name,
metadata={"description": "短期消息集合记忆", "hnsw:space": "cosine"},
)
logger.info(f"消息集合存储初始化完成,集合: '{self.collection_name}'")
except Exception as e:
logger.error(f"消息集合存储初始化失败: {e}", exc_info=True)
raise
async def add_collection(self, collection: MessageCollection):
"""添加一个新的消息集合,并处理容量和时间限制"""
try:
# 清理过期和超额的集合
await self._cleanup_collections()
# 向量化并存储
embedding = await get_embedding(collection.combined_text)
if not embedding:
logger.warning(f"无法为消息集合 {collection.collection_id} 生成向量,跳过存储。")
return
collection.embedding = embedding
self.vector_db_service.add(
collection_name=self.collection_name,
embeddings=[embedding],
ids=[collection.collection_id],
documents=[collection.combined_text],
metadatas=[collection.to_dict()],
)
logger.debug(f"成功存储消息集合: {collection.collection_id}")
except Exception as e:
logger.error(f"存储消息集合失败: {e}", exc_info=True)
async def _cleanup_collections(self):
"""清理超额和过期的消息集合"""
try:
# 基于时间清理
if self.config.instant_memory_retention_hours > 0:
expiration_time = time.time() - self.config.instant_memory_retention_hours * 3600
expired_docs = self.vector_db_service.get(
collection_name=self.collection_name,
where={"created_at": {"$lt": expiration_time}},
include=[], # 只获取ID
)
if expired_docs and expired_docs.get("ids"):
self.vector_db_service.delete(collection_name=self.collection_name, ids=expired_docs["ids"])
logger.info(f"删除了 {len(expired_docs['ids'])} 个过期的瞬时记忆")
# 基于数量清理
current_count = self.vector_db_service.count(self.collection_name)
if current_count > self.config.instant_memory_max_collections:
num_to_delete = current_count - self.config.instant_memory_max_collections
# 获取所有文档的元数据以进行排序
all_docs = self.vector_db_service.get(
collection_name=self.collection_name,
include=["metadatas"]
)
if all_docs and all_docs.get("ids"):
# 在内存中排序找到最旧的文档
sorted_docs = sorted(
zip(all_docs["ids"], all_docs["metadatas"]),
key=lambda item: item[1].get("created_at", 0),
)
ids_to_delete = [doc[0] for doc in sorted_docs[:num_to_delete]]
if ids_to_delete:
self.vector_db_service.delete(collection_name=self.collection_name, ids=ids_to_delete)
logger.info(f"消息集合已满,删除最旧的 {len(ids_to_delete)} 个集合")
except Exception as e:
logger.error(f"清理消息集合失败: {e}", exc_info=True)
async def get_relevant_collection(self, query_text: str, n_results: int = 1) -> list[MessageCollection]:
"""根据查询文本检索最相关的消息集合"""
if not query_text.strip():
return []
try:
query_embedding = await get_embedding(query_text)
if not query_embedding:
return []
results = self.vector_db_service.query(
collection_name=self.collection_name,
query_embeddings=[query_embedding],
n_results=n_results,
)
collections = []
if results and results.get("ids") and results["ids"][0]:
for metadata in results["metadatas"][0]:
collections.append(MessageCollection.from_dict(metadata))
return collections
except Exception as e:
logger.error(f"检索相关消息集合失败: {e}", exc_info=True)
return []
def clear_all(self):
"""清空所有消息集合"""
try:
# In ChromaDB, the easiest way to clear a collection is to delete and recreate it.
self.vector_db_service.delete_collection(name=self.collection_name)
self._initialize_storage()
logger.info(f"已清空所有消息集合: '{self.collection_name}'")
except Exception as e:
logger.error(f"清空消息集合失败: {e}", exc_info=True)
def get_stats(self) -> dict[str, Any]:
"""获取存储统计信息"""
try:
count = self.vector_db_service.count(self.collection_name)
return {
"collection_name": self.collection_name,
"total_collections": count,
"storage_limit": self.config.instant_memory_max_collections,
}
except Exception as e:
logger.error(f"获取消息集合存储统计失败: {e}")
return {}

View File

@@ -315,6 +315,10 @@ class MemoryConfig(ValidatedConfigBase):
enable_vector_memory_storage: bool = Field(default=True, description="启用Vector DB记忆存储")
enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆")
enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆")
instant_memory_max_collections: int = Field(default=100, ge=1, description="瞬时记忆最大集合数")
instant_memory_retention_hours: int = Field(
default=0, ge=0, description="瞬时记忆保留时间(小时),0 表示不基于时间清理"
)
# Vector DB配置
vector_db_similarity_threshold: float = Field(

View File

@@ -1,5 +1,5 @@
[inner]
version = "7.4.7"
version = "7.4.8"
#----以下是给开发人员阅读的如果你只是部署了MoFox-Bot不需要阅读----
#如果你想要修改配置文件请递增version的值
@@ -289,6 +289,8 @@ dormant_threshold_days = 90 # 休眠状态判定天数(超过此天数未访
enable_vector_memory_storage = true # 启用Vector DB存储
enable_llm_instant_memory = true # 启用基于LLM的瞬时记忆
enable_vector_instant_memory = true # 启用基于向量的瞬时记忆
instant_memory_max_collections = 100 # 瞬时记忆最大集合数
instant_memory_retention_hours = 0 # 瞬时记忆保留时间(小时),0 表示不基于时间清理
# Vector DB配置
vector_db_similarity_threshold = 0.5 # Vector DB相似度阈值 (推荐范围: 0.5-0.6, 过高会导致检索不到结果)