From 2200a9ff2a71460975cc26bfaaf4cc7464eaa6ea Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 25 Oct 2025 20:07:58 +0800 Subject: [PATCH] =?UTF-8?q?feat(memory):=20=E5=BC=95=E5=85=A5=E5=9F=BA?= =?UTF-8?q?=E4=BA=8E=E5=90=91=E9=87=8F=E7=9A=84=E7=9E=AC=E6=97=B6=E8=AE=B0?= =?UTF-8?q?=E5=BF=86=E7=B3=BB=E7=BB=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 引入了一个新的瞬时记忆系统,该系统将短期对话片段(消息集合)存储在专门的向量数据库中,以提供更即时、更相关的上下文。 该系统通过以下组件实现: - **MessageCollection**: 用于封装一组相关消息的数据结构。 - **MessageCollectionStorage**: 负责将消息集合向量化并存入专用的ChromaDB集合,同时管理集合的生命周期(基于数量和时间清理)。 - **MessageCollectionProcessor**: 缓冲每个聊天的消息,当达到阈值时,将它们组合成一个`MessageCollection`并交由Storage处理。 - **集成**: `MemoryManager`和`MemorySystem`已更新,将瞬时记忆无缝融合到现有的记忆检索流程中,优先展示来自当前聊天的上下文。 此外,还进行了以下调整: - 移除`orjson.dumps`中的`ensure_ascii=False`参数,以遵循`orjson`的默认行为,该行为始终返回UTF-8字节串,从而简化了编码处理。 - 在配置文件中增加了瞬时记忆最大集合数和保留时间的选项。 --- src/chat/memory_system/memory_builder.py | 2 +- src/chat/memory_system/memory_chunk.py | 37 ++++- src/chat/memory_system/memory_manager.py | 71 +++++++- src/chat/memory_system/memory_system.py | 89 +++++++++- .../message_collection_processor.py | 75 +++++++++ .../message_collection_storage.py | 155 ++++++++++++++++++ src/config/official_configs.py | 4 + template/bot_config_template.toml | 4 +- 8 files changed, 431 insertions(+), 6 deletions(-) create mode 100644 src/chat/memory_system/message_collection_processor.py create mode 100644 src/chat/memory_system/message_collection_storage.py diff --git a/src/chat/memory_system/memory_builder.py b/src/chat/memory_system/memory_builder.py index d4aea4153..4e5d2e0e7 100644 --- a/src/chat/memory_system/memory_builder.py +++ b/src/chat/memory_system/memory_builder.py @@ -702,7 +702,7 @@ class MemoryBuilder: if isinstance(value, list | dict): try: - value = orjson.dumps(value, ensure_ascii=False).decode("utf-8") + value = orjson.dumps(value).decode("utf-8") except Exception: value = str(value) diff --git a/src/chat/memory_system/memory_chunk.py b/src/chat/memory_system/memory_chunk.py index 6fc746ce3..c3c5fe0ee 100644 --- a/src/chat/memory_system/memory_chunk.py +++ b/src/chat/memory_system/memory_chunk.py @@ -454,7 +454,7 @@ class MemoryChunk: def to_json(self) -> str: """转换为JSON字符串""" - return orjson.dumps(self.to_dict(), ensure_ascii=False).decode("utf-8") + return orjson.dumps(self.to_dict()).decode("utf-8") @classmethod def from_json(cls, json_str: str) -> "MemoryChunk": @@ -610,3 +610,38 @@ def create_memory_chunk( chunk = MemoryChunk(metadata=metadata, content=content, memory_type=memory_type, **kwargs) return chunk + + +@dataclass +class MessageCollection: + """消息集合数据结构""" + + collection_id: str = field(default_factory=lambda: str(uuid.uuid4())) + chat_id: str | None = None # 聊天ID(群聊或私聊) + messages: list[str] = field(default_factory=list) + combined_text: str = "" + created_at: float = field(default_factory=time.time) + embedding: list[float] | None = None + + def to_dict(self) -> dict[str, Any]: + """转换为字典格式""" + return { + "collection_id": self.collection_id, + "chat_id": self.chat_id, + "messages": self.messages, + "combined_text": self.combined_text, + "created_at": self.created_at, + "embedding": self.embedding, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "MessageCollection": + """从字典创建实例""" + return cls( + collection_id=data.get("collection_id", str(uuid.uuid4())), + chat_id=data.get("chat_id"), + messages=data.get("messages", []), + combined_text=data.get("combined_text", ""), + created_at=data.get("created_at", time.time()), + embedding=data.get("embedding"), + ) diff --git a/src/chat/memory_system/memory_manager.py b/src/chat/memory_system/memory_manager.py index c4c76481b..8b1666c3d 100644 --- a/src/chat/memory_system/memory_manager.py +++ b/src/chat/memory_system/memory_manager.py @@ -7,8 +7,10 @@ import re from dataclasses import dataclass from typing import Any -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType +from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, MessageCollection from src.chat.memory_system.memory_system import MemorySystem, initialize_memory_system +from src.chat.memory_system.message_collection_processor import MessageCollectionProcessor +from src.chat.memory_system.message_collection_storage import MessageCollectionStorage from src.common.logger import get_logger logger = get_logger(__name__) @@ -33,6 +35,8 @@ class MemoryManager: def __init__(self): self.memory_system: MemorySystem | None = None + self.message_collection_storage: MessageCollectionStorage | None = None + self.message_collection_processor: MessageCollectionProcessor | None = None self.is_initialized = False self.user_cache = {} # 用户记忆缓存 @@ -69,6 +73,10 @@ class MemoryManager: # 初始化记忆系统 self.memory_system = await initialize_memory_system(llm_model) + # 初始化消息集合系统 + self.message_collection_storage = MessageCollectionStorage() + self.message_collection_processor = MessageCollectionProcessor(self.message_collection_storage) + self.is_initialized = True logger.info(" 记忆系统初始化完成") @@ -76,6 +84,8 @@ class MemoryManager: logger.error(f"记忆系统初始化失败: {e}") # 如果系统初始化失败,创建一个空的管理器避免系统崩溃 self.memory_system = None + self.message_collection_storage = None + self.message_collection_processor = None self.is_initialized = True # 标记为已初始化但系统不可用 def get_hippocampus(self): @@ -235,6 +245,11 @@ class MemoryManager: return [] try: + # 将消息添加到消息集合处理器 + chat_id = context.get("chat_id") + if self.message_collection_processor and chat_id: + await self.message_collection_processor.add_message(conversation_text, chat_id) + payload_context = dict(context or {}) payload_context.setdefault("conversation_text", conversation_text) if timestamp is not None: @@ -485,6 +500,60 @@ class MemoryManager: return text return text[: max_length - 1] + "…" + async def get_relevant_message_collection(self, query_text: str, n_results: int = 3) -> list[MessageCollection]: + """获取相关的消息集合列表""" + if not self.is_initialized or not self.message_collection_storage: + return [] + + try: + return await self.message_collection_storage.get_relevant_collection(query_text, n_results=n_results) + except Exception as e: + logger.error(f"get_relevant_message_collection 失败: {e}") + return [] + + async def get_message_collection_context(self, query_text: str, chat_id: str) -> str: + """获取消息集合上下文,用于添加到 prompt 中。优先展示当前聊天的上下文。""" + if not self.is_initialized or not self.message_collection_storage: + return "" + + try: + collections = await self.get_relevant_message_collection(query_text, n_results=3) + if not collections: + return "" + + # 根据传入的 chat_id 对集合进行排序 + collections.sort(key=lambda c: c.chat_id == chat_id, reverse=True) + + context_parts = [] + for collection in collections: + if not collection.combined_text: + continue + + header = "## 📝 相关对话上下文\n" + if collection.chat_id == chat_id: + # 匹配的ID,使用更明显的标识 + context_parts.append( + f"{header} [🔥 来自当前聊天的上下文]\n```\n{collection.combined_text}\n```" + ) + else: + # 不匹配的ID + context_parts.append( + f"{header} [💡 来自其他聊天的相关上下文 (ID: {collection.chat_id})]\n```\n{collection.combined_text}\n```" + ) + + if not context_parts: + return "" + + # 格式化消息集合为 prompt 上下文 + final_context = "\n\n---\n\n".join(context_parts) + "\n\n---" + + logger.info(f"🔗 为查询 '{query_text[:50]}...' 在聊天 '{chat_id}' 中找到 {len(collections)} 个相关消息集合上下文") + return f"\n{final_context}\n" + + except Exception as e: + logger.error(f"get_message_collection_context 失败: {e}") + return "" + async def shutdown(self): """关闭增强记忆系统""" if not self.is_initialized: diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index 2551966a1..4c52c26e8 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -19,6 +19,7 @@ from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractio from src.chat.memory_system.memory_chunk import MemoryChunk from src.chat.memory_system.memory_fusion import MemoryFusionEngine from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner +from src.chat.memory_system.message_collection_storage import MessageCollectionStorage # 记忆采样模式枚举 @@ -142,6 +143,7 @@ class MemorySystem: self.memory_builder: MemoryBuilder | None = None self.fusion_engine: MemoryFusionEngine | None = None self.unified_storage: VectorMemoryStorage | None = None # 统一存储系统 + self.message_collection_storage: MessageCollectionStorage | None = None self.query_planner: MemoryQueryPlanner | None = None self.forgetting_engine: MemoryForgettingEngine | None = None @@ -153,6 +155,7 @@ class MemorySystem: self.total_memories = 0 self.last_build_time = None self.last_retrieval_time = None + self.last_collection_cleanup_time: float = time.time() # 构建节流记录 self._last_memory_build_times: dict[str, float] = {} @@ -199,6 +202,9 @@ class MemorySystem: self.memory_builder = MemoryBuilder(self.memory_extraction_model) self.fusion_engine = MemoryFusionEngine(self.config.fusion_similarity_threshold) + # 初始化消息集合存储 + self.message_collection_storage = MessageCollectionStorage() + # 初始化Vector DB存储系统(替代旧的unified_memory_storage) from src.chat.memory_system.vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig @@ -694,7 +700,7 @@ class MemorySystem: limit: int = 5, **kwargs, ) -> list[MemoryChunk]: - """检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排)""" + """检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排),并融合瞬时记忆""" raw_query = query_text or kwargs.get("query") if not raw_query: raise ValueError("query_text 或 query 参数不能为空") @@ -839,7 +845,22 @@ class MemorySystem: scored_memories.sort(key=lambda x: x[1], reverse=True) # 返回 Top-K - final_memories = [mem for mem, score, details in scored_memories[:effective_limit]] + final_memories = [mem for mem, score, details in scored_memories] + + # === 新增:融合瞬时记忆 === + try: + chat_id = normalized_context.get("chat_id") + instant_memories = await self._retrieve_instant_memories(raw_query, chat_id) + if instant_memories: + # 将瞬时记忆放在列表最前面 + final_memories = instant_memories + final_memories + logger.info(f"融合了 {len(instant_memories)} 条瞬时记忆") + + except Exception as e: + logger.warning(f"检索瞬时记忆失败: {e}", exc_info=True) + + # 最终截断 + final_memories = final_memories[:effective_limit] retrieval_time = time.time() - start_time @@ -913,6 +934,61 @@ class MemorySystem: logger.error(f"❌ 记忆检索失败: {e}", exc_info=True) raise + async def _retrieve_instant_memories(self, query_text: str, chat_id: str | None) -> list[MemoryChunk]: + """检索瞬时记忆(消息集合)并转换为MemoryChunk""" + if not self.message_collection_storage: + return [] + + # 1. 优先检索当前聊天的消息集合 + collections = [] + if chat_id: + collections = await self.memory_manger.get_message_collection_context(query_text, chat_id=chat_id, n_results=1) + + # 2. 如果当前聊天没有,或者不需要区分聊天,则进行全局检索 + if not collections: + collections = await self.message_collection_storage.get_relevant_collection(query_text, chat_id=None, n_results=1) + + if not collections: + return [] + + # 3. 将 MessageCollection 转换为 MemoryChunk + instant_memories = [] + for collection in collections: + from src.chat.memory_system.memory_chunk import ( + ContentStructure, + ImportanceLevel, + MemoryMetadata, + MemoryType, + ) + + header = f"[来自群/聊 {collection.chat_id} 的近期对话]" + if collection.chat_id == chat_id: + header = f"[🔥 来自当前聊天的近期对话]" + + display_text = f"{header}\n---\n{collection.combined_text}" + + metadata = MemoryMetadata( + memory_id=f"instant_{collection.collection_id}", + user_id=GLOBAL_MEMORY_SCOPE, + chat_id=collection.chat_id, + created_at=collection.created_at, + importance=ImportanceLevel.HIGH, # 瞬时记忆通常具有高重要性 + ) + content = ContentStructure( + subject="对话上下文", + predicate="包含", + object=collection.combined_text, + display=display_text + ) + chunk = MemoryChunk( + metadata=metadata, + content=content, + memory_type=MemoryType.CONTEXTUAL, + ) + instant_memories.append(chunk) + + return instant_memories + @staticmethod def _extract_json_payload(response: str) -> str | None: """从模型响应中提取JSON部分,兼容Markdown代码块等格式""" @@ -1523,6 +1599,15 @@ class MemorySystem: if self.fusion_engine: await self.fusion_engine.maintenance() + # 清理消息集合(每12小时) + if self.message_collection_storage: + current_time = time.time() + if current_time - self.last_collection_cleanup_time > 12 * 3600: + logger.info("开始清理过期的消息集合...") + self.message_collection_storage.clear_all() + self.last_collection_cleanup_time = current_time + logger.info("✅ 消息集合清理完成") + logger.info("✅ 简化记忆系统维护完成") except Exception as e: diff --git a/src/chat/memory_system/message_collection_processor.py b/src/chat/memory_system/message_collection_processor.py new file mode 100644 index 000000000..756250dc4 --- /dev/null +++ b/src/chat/memory_system/message_collection_processor.py @@ -0,0 +1,75 @@ +""" +消息集合处理器 +负责收集消息、创建集合并将其存入向量存储。 +""" + +import asyncio +from collections import deque +from typing import Any + +from src.chat.memory_system.memory_chunk import MessageCollection +from src.chat.memory_system.message_collection_storage import MessageCollectionStorage +from src.common.logger import get_logger + +logger = get_logger(__name__) + + +class MessageCollectionProcessor: + """处理消息集合的创建和存储""" + + def __init__(self, storage: MessageCollectionStorage, buffer_size: int = 5): + self.storage = storage + self.buffer_size = buffer_size + self.message_buffers: dict[str, deque[str]] = {} + self._lock = asyncio.Lock() + + async def add_message(self, message_text: str, chat_id: str): + """添加一条新消息到指定聊天的缓冲区,并在满时触发处理""" + async with self._lock: + if not isinstance(message_text, str) or not message_text.strip(): + return + + if chat_id not in self.message_buffers: + self.message_buffers[chat_id] = deque(maxlen=self.buffer_size) + + buffer = self.message_buffers[chat_id] + buffer.append(message_text) + logger.debug(f"消息已添加到聊天 '{chat_id}' 的缓冲区,当前数量: {len(buffer)}/{self.buffer_size}") + + if len(buffer) == self.buffer_size: + await self._process_buffer(chat_id) + + async def _process_buffer(self, chat_id: str): + """处理指定聊天缓冲区中的消息,创建并存储一个集合""" + buffer = self.message_buffers.get(chat_id) + if not buffer or len(buffer) < self.buffer_size: + return + + messages_to_process = list(buffer) + buffer.clear() + + logger.info(f"聊天 '{chat_id}' 的消息缓冲区已满,开始创建消息集合...") + + try: + combined_text = "\n".join(messages_to_process) + + collection = MessageCollection( + chat_id=chat_id, + messages=messages_to_process, + combined_text=combined_text, + ) + + await self.storage.add_collection(collection) + logger.info(f"成功为聊天 '{chat_id}' 创建并存储了新的消息集合: {collection.collection_id}") + + except Exception as e: + logger.error(f"处理聊天 '{chat_id}' 的消息缓冲区失败: {e}", exc_info=True) + + def get_stats(self) -> dict[str, Any]: + """获取处理器统计信息""" + total_buffered_messages = sum(len(buf) for buf in self.message_buffers.values()) + return { + "active_buffers": len(self.message_buffers), + "total_buffered_messages": total_buffered_messages, + "buffer_capacity_per_chat": self.buffer_size, + } \ No newline at end of file diff --git a/src/chat/memory_system/message_collection_storage.py b/src/chat/memory_system/message_collection_storage.py new file mode 100644 index 000000000..92c42dcc7 --- /dev/null +++ b/src/chat/memory_system/message_collection_storage.py @@ -0,0 +1,155 @@ +""" +消息集合向量存储系统 +专用于存储和检索消息集合,以提供即时上下文。 +""" + +import asyncio +import time +from typing import Any + +from src.chat.memory_system.memory_chunk import MessageCollection +from src.chat.utils.utils import get_embedding +from src.common.logger import get_logger +from src.common.vector_db import vector_db_service +from src.config.config import global_config + +logger = get_logger(__name__) + +class MessageCollectionStorage: + """消息集合向量存储""" + + def __init__(self): + self.config = global_config.memory + self.vector_db_service = vector_db_service + self.collection_name = "message_collections" + self._initialize_storage() + + def _initialize_storage(self): + """初始化存储""" + try: + self.vector_db_service.get_or_create_collection( + name=self.collection_name, + metadata={"description": "短期消息集合记忆", "hnsw:space": "cosine"}, + ) + logger.info(f"消息集合存储初始化完成,集合: '{self.collection_name}'") + except Exception as e: + logger.error(f"消息集合存储初始化失败: {e}", exc_info=True) + raise + + async def add_collection(self, collection: MessageCollection): + """添加一个新的消息集合,并处理容量和时间限制""" + try: + # 清理过期和超额的集合 + await self._cleanup_collections() + + # 向量化并存储 + embedding = await get_embedding(collection.combined_text) + if not embedding: + logger.warning(f"无法为消息集合 {collection.collection_id} 生成向量,跳过存储。") + return + + collection.embedding = embedding + + self.vector_db_service.add( + collection_name=self.collection_name, + embeddings=[embedding], + ids=[collection.collection_id], + documents=[collection.combined_text], + metadatas=[collection.to_dict()], + ) + logger.debug(f"成功存储消息集合: {collection.collection_id}") + + except Exception as e: + logger.error(f"存储消息集合失败: {e}", exc_info=True) + + async def _cleanup_collections(self): + """清理超额和过期的消息集合""" + try: + # 基于时间清理 + if self.config.instant_memory_retention_hours > 0: + expiration_time = time.time() - self.config.instant_memory_retention_hours * 3600 + expired_docs = self.vector_db_service.get( + collection_name=self.collection_name, + where={"created_at": {"$lt": expiration_time}}, + include=[], # 只获取ID + ) + if expired_docs and expired_docs.get("ids"): + self.vector_db_service.delete(collection_name=self.collection_name, ids=expired_docs["ids"]) + logger.info(f"删除了 {len(expired_docs['ids'])} 个过期的瞬时记忆") + + # 基于数量清理 + current_count = self.vector_db_service.count(self.collection_name) + if current_count > self.config.instant_memory_max_collections: + num_to_delete = current_count - self.config.instant_memory_max_collections + + # 获取所有文档的元数据以进行排序 + all_docs = self.vector_db_service.get( + collection_name=self.collection_name, + include=["metadatas"] + ) + + if all_docs and all_docs.get("ids"): + # 在内存中排序找到最旧的文档 + sorted_docs = sorted( + zip(all_docs["ids"], all_docs["metadatas"]), + key=lambda item: item[1].get("created_at", 0), + ) + + ids_to_delete = [doc[0] for doc in sorted_docs[:num_to_delete]] + + if ids_to_delete: + self.vector_db_service.delete(collection_name=self.collection_name, ids=ids_to_delete) + logger.info(f"消息集合已满,删除最旧的 {len(ids_to_delete)} 个集合") + + except Exception as e: + logger.error(f"清理消息集合失败: {e}", exc_info=True) + + + async def get_relevant_collection(self, query_text: str, n_results: int = 1) -> list[MessageCollection]: + """根据查询文本检索最相关的消息集合""" + if not query_text.strip(): + return [] + + try: + query_embedding = await get_embedding(query_text) + if not query_embedding: + return [] + + results = self.vector_db_service.query( + collection_name=self.collection_name, + query_embeddings=[query_embedding], + n_results=n_results, + ) + + collections = [] + if results and results.get("ids") and results["ids"][0]: + for metadata in results["metadatas"][0]: + collections.append(MessageCollection.from_dict(metadata)) + + return collections + except Exception as e: + logger.error(f"检索相关消息集合失败: {e}", exc_info=True) + return [] + + def clear_all(self): + """清空所有消息集合""" + try: + # In ChromaDB, the easiest way to clear a collection is to delete and recreate it. + self.vector_db_service.delete_collection(name=self.collection_name) + self._initialize_storage() + logger.info(f"已清空所有消息集合: '{self.collection_name}'") + except Exception as e: + logger.error(f"清空消息集合失败: {e}", exc_info=True) + + def get_stats(self) -> dict[str, Any]: + """获取存储统计信息""" + try: + count = self.vector_db_service.count(self.collection_name) + return { + "collection_name": self.collection_name, + "total_collections": count, + "storage_limit": self.config.instant_memory_max_collections, + } + except Exception as e: + logger.error(f"获取消息集合存储统计失败: {e}") + return {} \ No newline at end of file diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 1080180ac..19731b4fb 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -315,6 +315,10 @@ class MemoryConfig(ValidatedConfigBase): enable_vector_memory_storage: bool = Field(default=True, description="启用Vector DB记忆存储") enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆") enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆") + instant_memory_max_collections: int = Field(default=100, ge=1, description="瞬时记忆最大集合数") + instant_memory_retention_hours: int = Field( + default=0, ge=0, description="瞬时记忆保留时间(小时),0表示不基于时间清理" + ) # Vector DB配置 vector_db_similarity_threshold: float = Field( diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 746fa0a33..a67bdfbd7 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.4.7" +version = "7.4.8" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -289,6 +289,8 @@ dormant_threshold_days = 90 # 休眠状态判定天数(超过此天数未访 enable_vector_memory_storage = true # 启用Vector DB存储 enable_llm_instant_memory = true # 启用基于LLM的瞬时记忆 enable_vector_instant_memory = true # 启用基于向量的瞬时记忆 +instant_memory_max_collections = 100 # 瞬时记忆最大集合数 +instant_memory_retention_hours = 0 # 瞬时记忆保留时间(小时),0表示不基于时间清理 # Vector DB配置 vector_db_similarity_threshold = 0.5 # Vector DB相似度阈值 (推荐范围: 0.5-0.6, 过高会导致检索不到结果)