refactor(memory): 简化记忆系统架构,移除 MemoryManager
将 `get_message_collection_context` 方法从 `MemoryManager` 移动到 `MessageCollectionStorage`,并直接在 `MemorySystem` 中调用它。这一变更简化了依赖关系,并消除了 `MemoryManager` 类,使其职责更加清晰。 主要变更: - 从 `MemoryManager` 中移除 `get_message_collection_context` 并将其功能迁移。 - 更新 `MemorySystem` 以直接使用 `MessageCollectionStorage` 来检索上下文。 - 移除 `MemoryManager` 的初始化和依赖,因为它不再需要。 - 将 `MemorySystem` 的初始化改为单例模式,以确保全局唯一实例。
This commit is contained in:
@@ -7,8 +7,8 @@ import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, MessageCollection
|
||||
from src.chat.memory_system.memory_system import MemorySystem, initialize_memory_system
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.chat.memory_system.memory_system import MemorySystem
|
||||
from src.chat.memory_system.message_collection_processor import MessageCollectionProcessor
|
||||
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
|
||||
from src.common.logger import get_logger
|
||||
@@ -64,14 +64,9 @@ class MemoryManager:
|
||||
|
||||
logger.info("正在初始化记忆系统...")
|
||||
|
||||
# 获取LLM模型
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory")
|
||||
|
||||
# 初始化记忆系统
|
||||
self.memory_system = await initialize_memory_system(llm_model)
|
||||
from src.chat.memory_system.memory_system import get_memory_system
|
||||
self.memory_system = get_memory_system()
|
||||
|
||||
# 初始化消息集合系统
|
||||
self.message_collection_storage = MessageCollectionStorage()
|
||||
@@ -500,60 +495,6 @@ class MemoryManager:
|
||||
return text
|
||||
return text[: max_length - 1] + "…"
|
||||
|
||||
async def get_relevant_message_collection(self, query_text: str, n_results: int = 3) -> list[MessageCollection]:
|
||||
"""获取相关的消息集合列表"""
|
||||
if not self.is_initialized or not self.message_collection_storage:
|
||||
return []
|
||||
|
||||
try:
|
||||
return await self.message_collection_storage.get_relevant_collection(query_text, n_results=n_results)
|
||||
except Exception as e:
|
||||
logger.error(f"get_relevant_message_collection 失败: {e}")
|
||||
return []
|
||||
|
||||
async def get_message_collection_context(self, query_text: str, chat_id: str) -> str:
|
||||
"""获取消息集合上下文,用于添加到 prompt 中。优先展示当前聊天的上下文。"""
|
||||
if not self.is_initialized or not self.message_collection_storage:
|
||||
return ""
|
||||
|
||||
try:
|
||||
collections = await self.get_relevant_message_collection(query_text, n_results=3)
|
||||
if not collections:
|
||||
return ""
|
||||
|
||||
# 根据传入的 chat_id 对集合进行排序
|
||||
collections.sort(key=lambda c: c.chat_id == chat_id, reverse=True)
|
||||
|
||||
context_parts = []
|
||||
for collection in collections:
|
||||
if not collection.combined_text:
|
||||
continue
|
||||
|
||||
header = "## 📝 相关对话上下文\n"
|
||||
if collection.chat_id == chat_id:
|
||||
# 匹配的ID,使用更明显的标识
|
||||
context_parts.append(
|
||||
f"{header} [🔥 来自当前聊天的上下文]\n```\n{collection.combined_text}\n```"
|
||||
)
|
||||
else:
|
||||
# 不匹配的ID
|
||||
context_parts.append(
|
||||
f"{header} [💡 来自其他聊天的相关上下文 (ID: {collection.chat_id})]\n```\n{collection.combined_text}\n```"
|
||||
)
|
||||
|
||||
if not context_parts:
|
||||
return ""
|
||||
|
||||
# 格式化消息集合为 prompt 上下文
|
||||
final_context = "\n\n---\n\n".join(context_parts) + "\n\n---"
|
||||
|
||||
logger.info(f"🔗 为查询 '{query_text[:50]}...' 在聊天 '{chat_id}' 中找到 {len(collections)} 个相关消息集合上下文")
|
||||
return f"\n{final_context}\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"get_message_collection_context 失败: {e}")
|
||||
return ""
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭增强记忆系统"""
|
||||
if not self.is_initialized:
|
||||
|
||||
@@ -18,7 +18,6 @@ import orjson
|
||||
from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractionError
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||
from src.chat.memory_system.memory_manager import MemoryManager
|
||||
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
|
||||
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
|
||||
|
||||
@@ -147,7 +146,6 @@ class MemorySystem:
|
||||
self.message_collection_storage: MessageCollectionStorage | None = None
|
||||
self.query_planner: MemoryQueryPlanner | None = None
|
||||
self.forgetting_engine: MemoryForgettingEngine | None = None
|
||||
self.memory_manager: MemoryManager | None = None
|
||||
|
||||
# LLM模型
|
||||
self.value_assessment_model: LLMRequest | None = None
|
||||
@@ -173,10 +171,6 @@ class MemorySystem:
|
||||
async def initialize(self):
|
||||
"""异步初始化记忆系统"""
|
||||
try:
|
||||
# 初始化 MemoryManager
|
||||
self.memory_manager = MemoryManager()
|
||||
await self.memory_manager.initialize()
|
||||
|
||||
# 初始化LLM模型
|
||||
fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None
|
||||
|
||||
@@ -942,13 +936,10 @@ class MemorySystem:
|
||||
|
||||
async def _retrieve_instant_memories(self, query_text: str, chat_id: str | None) -> list[MemoryChunk]:
|
||||
"""检索瞬时记忆(消息集合)并转换为MemoryChunk"""
|
||||
if not self.memory_manager:
|
||||
if not self.message_collection_storage or not chat_id:
|
||||
return []
|
||||
|
||||
if not chat_id:
|
||||
return []
|
||||
|
||||
context_text = await self.memory_manager.get_message_collection_context(query_text, chat_id=chat_id)
|
||||
context_text = await self.message_collection_storage.get_message_collection_context(query_text, chat_id=chat_id)
|
||||
if not context_text:
|
||||
return []
|
||||
|
||||
|
||||
@@ -131,6 +131,46 @@ class MessageCollectionStorage:
|
||||
logger.error(f"检索相关消息集合失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_message_collection_context(self, query_text: str, chat_id: str) -> str:
|
||||
"""获取消息集合上下文,用于添加到 prompt 中。优先展示当前聊天的上下文。"""
|
||||
try:
|
||||
collections = await self.get_relevant_collection(query_text, n_results=5)
|
||||
if not collections:
|
||||
return ""
|
||||
|
||||
# 根据传入的 chat_id 对集合进行排序
|
||||
collections.sort(key=lambda c: c.chat_id == chat_id, reverse=True)
|
||||
|
||||
context_parts = []
|
||||
for collection in collections:
|
||||
if not collection.combined_text:
|
||||
continue
|
||||
|
||||
header = "## 📝 相关对话上下文\n"
|
||||
if collection.chat_id == chat_id:
|
||||
# 匹配的ID,使用更明显的标识
|
||||
context_parts.append(
|
||||
f"{header} [🔥 来自当前聊天的上下文]\n```\n{collection.combined_text}\n```"
|
||||
)
|
||||
else:
|
||||
# 不匹配的ID
|
||||
context_parts.append(
|
||||
f"{header} [💡 来自其他聊天的相关上下文 (ID: {collection.chat_id})]\n```\n{collection.combined_text}\n```"
|
||||
)
|
||||
|
||||
if not context_parts:
|
||||
return ""
|
||||
|
||||
# 格式化消息集合为 prompt 上下文
|
||||
final_context = "\n\n---\n\n".join(context_parts) + "\n\n---"
|
||||
|
||||
logger.info(f"🔗 为查询 '{query_text[:50]}...' 在聊天 '{chat_id}' 中找到 {len(collections)} 个相关消息集合上下文")
|
||||
return f"\n{final_context}\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"get_message_collection_context 失败: {e}")
|
||||
return ""
|
||||
|
||||
def clear_all(self):
|
||||
"""清空所有消息集合"""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user