diff --git a/src/chat/memory_system/memory_manager.py b/src/chat/memory_system/memory_manager.py index 8b1666c3d..3b6ef6a46 100644 --- a/src/chat/memory_system/memory_manager.py +++ b/src/chat/memory_system/memory_manager.py @@ -7,8 +7,8 @@ import re from dataclasses import dataclass from typing import Any -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, MessageCollection -from src.chat.memory_system.memory_system import MemorySystem, initialize_memory_system +from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType +from src.chat.memory_system.memory_system import MemorySystem from src.chat.memory_system.message_collection_processor import MessageCollectionProcessor from src.chat.memory_system.message_collection_storage import MessageCollectionStorage from src.common.logger import get_logger @@ -64,14 +64,9 @@ class MemoryManager: logger.info("正在初始化记忆系统...") - # 获取LLM模型 - from src.config.config import model_config - from src.llm_models.utils_model import LLMRequest - - llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory") - # 初始化记忆系统 - self.memory_system = await initialize_memory_system(llm_model) + from src.chat.memory_system.memory_system import get_memory_system + self.memory_system = get_memory_system() # 初始化消息集合系统 self.message_collection_storage = MessageCollectionStorage() @@ -500,60 +495,6 @@ class MemoryManager: return text return text[: max_length - 1] + "…" - async def get_relevant_message_collection(self, query_text: str, n_results: int = 3) -> list[MessageCollection]: - """获取相关的消息集合列表""" - if not self.is_initialized or not self.message_collection_storage: - return [] - - try: - return await self.message_collection_storage.get_relevant_collection(query_text, n_results=n_results) - except Exception as e: - logger.error(f"get_relevant_message_collection 失败: {e}") - return [] - - async def get_message_collection_context(self, query_text: str, chat_id: str) -> str: - """获取消息集合上下文,用于添加到 prompt 中。优先展示当前聊天的上下文。""" - if not self.is_initialized or not self.message_collection_storage: - return "" - - try: - collections = await self.get_relevant_message_collection(query_text, n_results=3) - if not collections: - return "" - - # 根据传入的 chat_id 对集合进行排序 - collections.sort(key=lambda c: c.chat_id == chat_id, reverse=True) - - context_parts = [] - for collection in collections: - if not collection.combined_text: - continue - - header = "## 📝 相关对话上下文\n" - if collection.chat_id == chat_id: - # 匹配的ID,使用更明显的标识 - context_parts.append( - f"{header} [🔥 来自当前聊天的上下文]\n```\n{collection.combined_text}\n```" - ) - else: - # 不匹配的ID - context_parts.append( - f"{header} [💡 来自其他聊天的相关上下文 (ID: {collection.chat_id})]\n```\n{collection.combined_text}\n```" - ) - - if not context_parts: - return "" - - # 格式化消息集合为 prompt 上下文 - final_context = "\n\n---\n\n".join(context_parts) + "\n\n---" - - logger.info(f"🔗 为查询 '{query_text[:50]}...' 在聊天 '{chat_id}' 中找到 {len(collections)} 个相关消息集合上下文") - return f"\n{final_context}\n" - - except Exception as e: - logger.error(f"get_message_collection_context 失败: {e}") - return "" - async def shutdown(self): """关闭增强记忆系统""" if not self.is_initialized: diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index 8d2cea9d8..9da9c793d 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -18,7 +18,6 @@ import orjson from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractionError from src.chat.memory_system.memory_chunk import MemoryChunk from src.chat.memory_system.memory_fusion import MemoryFusionEngine -from src.chat.memory_system.memory_manager import MemoryManager from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner from src.chat.memory_system.message_collection_storage import MessageCollectionStorage @@ -147,7 +146,6 @@ class MemorySystem: self.message_collection_storage: MessageCollectionStorage | None = None self.query_planner: MemoryQueryPlanner | None = None self.forgetting_engine: MemoryForgettingEngine | None = None - self.memory_manager: MemoryManager | None = None # LLM模型 self.value_assessment_model: LLMRequest | None = None @@ -173,10 +171,6 @@ class MemorySystem: async def initialize(self): """异步初始化记忆系统""" try: - # 初始化 MemoryManager - self.memory_manager = MemoryManager() - await self.memory_manager.initialize() - # 初始化LLM模型 fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None @@ -942,13 +936,10 @@ class MemorySystem: async def _retrieve_instant_memories(self, query_text: str, chat_id: str | None) -> list[MemoryChunk]: """检索瞬时记忆(消息集合)并转换为MemoryChunk""" - if not self.memory_manager: + if not self.message_collection_storage or not chat_id: return [] - if not chat_id: - return [] - - context_text = await self.memory_manager.get_message_collection_context(query_text, chat_id=chat_id) + context_text = await self.message_collection_storage.get_message_collection_context(query_text, chat_id=chat_id) if not context_text: return [] diff --git a/src/chat/memory_system/message_collection_storage.py b/src/chat/memory_system/message_collection_storage.py index d55d3ea92..22f4c75f3 100644 --- a/src/chat/memory_system/message_collection_storage.py +++ b/src/chat/memory_system/message_collection_storage.py @@ -131,6 +131,46 @@ class MessageCollectionStorage: logger.error(f"检索相关消息集合失败: {e}", exc_info=True) return [] + async def get_message_collection_context(self, query_text: str, chat_id: str) -> str: + """获取消息集合上下文,用于添加到 prompt 中。优先展示当前聊天的上下文。""" + try: + collections = await self.get_relevant_collection(query_text, n_results=5) + if not collections: + return "" + + # 根据传入的 chat_id 对集合进行排序 + collections.sort(key=lambda c: c.chat_id == chat_id, reverse=True) + + context_parts = [] + for collection in collections: + if not collection.combined_text: + continue + + header = "## 📝 相关对话上下文\n" + if collection.chat_id == chat_id: + # 匹配的ID,使用更明显的标识 + context_parts.append( + f"{header} [🔥 来自当前聊天的上下文]\n```\n{collection.combined_text}\n```" + ) + else: + # 不匹配的ID + context_parts.append( + f"{header} [💡 来自其他聊天的相关上下文 (ID: {collection.chat_id})]\n```\n{collection.combined_text}\n```" + ) + + if not context_parts: + return "" + + # 格式化消息集合为 prompt 上下文 + final_context = "\n\n---\n\n".join(context_parts) + "\n\n---" + + logger.info(f"🔗 为查询 '{query_text[:50]}...' 在聊天 '{chat_id}' 中找到 {len(collections)} 个相关消息集合上下文") + return f"\n{final_context}\n" + + except Exception as e: + logger.error(f"get_message_collection_context 失败: {e}") + return "" + def clear_all(self): """清空所有消息集合""" try: