refactor(memory): 简化记忆系统架构,移除 MemoryManager

将 `get_message_collection_context` 方法从 `MemoryManager` 移动到 `MessageCollectionStorage`,并直接在 `MemorySystem` 中调用它。这一变更简化了依赖关系,并消除了 `MemoryManager` 类,使其职责更加清晰。

主要变更:
- 从 `MemoryManager` 中移除 `get_message_collection_context` 并将其功能迁移。
- 更新 `MemorySystem` 以直接使用 `MessageCollectionStorage` 来检索上下文。
- 移除 `MemoryManager` 的初始化和依赖,因为它不再需要。
- 将 `MemorySystem` 的初始化改为单例模式,以确保全局唯一实例。
This commit is contained in:
minecraft1024a
2025-10-25 20:20:07 +08:00
parent 5a99442064
commit bb9c31e63a
3 changed files with 46 additions and 74 deletions

View File

@@ -7,8 +7,8 @@ import re
from dataclasses import dataclass
from typing import Any
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, MessageCollection
from src.chat.memory_system.memory_system import MemorySystem, initialize_memory_system
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_system import MemorySystem
from src.chat.memory_system.message_collection_processor import MessageCollectionProcessor
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
from src.common.logger import get_logger
@@ -64,14 +64,9 @@ class MemoryManager:
logger.info("正在初始化记忆系统...")
# 获取LLM模型
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory")
# 初始化记忆系统
self.memory_system = await initialize_memory_system(llm_model)
from src.chat.memory_system.memory_system import get_memory_system
self.memory_system = get_memory_system()
# 初始化消息集合系统
self.message_collection_storage = MessageCollectionStorage()
@@ -500,60 +495,6 @@ class MemoryManager:
return text
return text[: max_length - 1] + ""
async def get_relevant_message_collection(self, query_text: str, n_results: int = 3) -> list[MessageCollection]:
"""获取相关的消息集合列表"""
if not self.is_initialized or not self.message_collection_storage:
return []
try:
return await self.message_collection_storage.get_relevant_collection(query_text, n_results=n_results)
except Exception as e:
logger.error(f"get_relevant_message_collection 失败: {e}")
return []
async def get_message_collection_context(self, query_text: str, chat_id: str) -> str:
"""获取消息集合上下文,用于添加到 prompt 中。优先展示当前聊天的上下文。"""
if not self.is_initialized or not self.message_collection_storage:
return ""
try:
collections = await self.get_relevant_message_collection(query_text, n_results=3)
if not collections:
return ""
# 根据传入的 chat_id 对集合进行排序
collections.sort(key=lambda c: c.chat_id == chat_id, reverse=True)
context_parts = []
for collection in collections:
if not collection.combined_text:
continue
header = "## 📝 相关对话上下文\n"
if collection.chat_id == chat_id:
# 匹配的ID使用更明显的标识
context_parts.append(
f"{header} [🔥 来自当前聊天的上下文]\n```\n{collection.combined_text}\n```"
)
else:
# 不匹配的ID
context_parts.append(
f"{header} [💡 来自其他聊天的相关上下文 (ID: {collection.chat_id})]\n```\n{collection.combined_text}\n```"
)
if not context_parts:
return ""
# 格式化消息集合为 prompt 上下文
final_context = "\n\n---\n\n".join(context_parts) + "\n\n---"
logger.info(f"🔗 为查询 '{query_text[:50]}...' 在聊天 '{chat_id}' 中找到 {len(collections)} 个相关消息集合上下文")
return f"\n{final_context}\n"
except Exception as e:
logger.error(f"get_message_collection_context 失败: {e}")
return ""
async def shutdown(self):
"""关闭增强记忆系统"""
if not self.is_initialized:

View File

@@ -18,7 +18,6 @@ import orjson
from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractionError
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
from src.chat.memory_system.memory_manager import MemoryManager
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
@@ -147,7 +146,6 @@ class MemorySystem:
self.message_collection_storage: MessageCollectionStorage | None = None
self.query_planner: MemoryQueryPlanner | None = None
self.forgetting_engine: MemoryForgettingEngine | None = None
self.memory_manager: MemoryManager | None = None
# LLM模型
self.value_assessment_model: LLMRequest | None = None
@@ -173,10 +171,6 @@ class MemorySystem:
async def initialize(self):
"""异步初始化记忆系统"""
try:
# 初始化 MemoryManager
self.memory_manager = MemoryManager()
await self.memory_manager.initialize()
# 初始化LLM模型
fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None
@@ -942,13 +936,10 @@ class MemorySystem:
async def _retrieve_instant_memories(self, query_text: str, chat_id: str | None) -> list[MemoryChunk]:
"""检索瞬时记忆消息集合并转换为MemoryChunk"""
if not self.memory_manager:
if not self.message_collection_storage or not chat_id:
return []
if not chat_id:
return []
context_text = await self.memory_manager.get_message_collection_context(query_text, chat_id=chat_id)
context_text = await self.message_collection_storage.get_message_collection_context(query_text, chat_id=chat_id)
if not context_text:
return []

View File

@@ -131,6 +131,46 @@ class MessageCollectionStorage:
logger.error(f"检索相关消息集合失败: {e}", exc_info=True)
return []
async def get_message_collection_context(self, query_text: str, chat_id: str) -> str:
"""获取消息集合上下文,用于添加到 prompt 中。优先展示当前聊天的上下文。"""
try:
collections = await self.get_relevant_collection(query_text, n_results=5)
if not collections:
return ""
# 根据传入的 chat_id 对集合进行排序
collections.sort(key=lambda c: c.chat_id == chat_id, reverse=True)
context_parts = []
for collection in collections:
if not collection.combined_text:
continue
header = "## 📝 相关对话上下文\n"
if collection.chat_id == chat_id:
# 匹配的ID使用更明显的标识
context_parts.append(
f"{header} [🔥 来自当前聊天的上下文]\n```\n{collection.combined_text}\n```"
)
else:
# 不匹配的ID
context_parts.append(
f"{header} [💡 来自其他聊天的相关上下文 (ID: {collection.chat_id})]\n```\n{collection.combined_text}\n```"
)
if not context_parts:
return ""
# 格式化消息集合为 prompt 上下文
final_context = "\n\n---\n\n".join(context_parts) + "\n\n---"
logger.info(f"🔗 为查询 '{query_text[:50]}...' 在聊天 '{chat_id}' 中找到 {len(collections)} 个相关消息集合上下文")
return f"\n{final_context}\n"
except Exception as e:
logger.error(f"get_message_collection_context 失败: {e}")
return ""
def clear_all(self):
"""清空所有消息集合"""
try: