refactor(memory): 引入 MemoryManager 统一管理瞬时记忆

引入 `MemoryManager` 类来封装和统一处理瞬时记忆(消息集合)的检索和管理逻辑。此举将瞬时记忆的相关操作从 `MemorySystem` 中解耦,提高了代码的模块化和可维护性。

主要变更:
- 创建 `MemoryManager` 类,负责消息集合的初始化、上下文检索等。
- `MemorySystem` 现在通过 `MemoryManager` 实例来获取瞬时记忆,简化了其内部实现。
- 移除了 `MemorySystem` 中原有的、分散的瞬时记忆检索代码,使其职责更单一。
This commit is contained in:
minecraft1024a
2025-10-25 20:13:32 +08:00
parent 917754b4e0
commit 5a99442064
2 changed files with 41 additions and 50 deletions

View File

@@ -18,6 +18,7 @@ import orjson
from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractionError
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
from src.chat.memory_system.memory_manager import MemoryManager
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
@@ -146,6 +147,7 @@ class MemorySystem:
self.message_collection_storage: MessageCollectionStorage | None = None
self.query_planner: MemoryQueryPlanner | None = None
self.forgetting_engine: MemoryForgettingEngine | None = None
self.memory_manager: MemoryManager | None = None
# LLM模型
self.value_assessment_model: LLMRequest | None = None
@@ -171,6 +173,10 @@ class MemorySystem:
async def initialize(self):
"""异步初始化记忆系统"""
try:
# 初始化 MemoryManager
self.memory_manager = MemoryManager()
await self.memory_manager.initialize()
# 初始化LLM模型
fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None
@@ -936,58 +942,43 @@ class MemorySystem:
async def _retrieve_instant_memories(self, query_text: str, chat_id: str | None) -> list[MemoryChunk]:
"""检索瞬时记忆消息集合并转换为MemoryChunk"""
if not self.message_collection_storage:
if not self.memory_manager:
return []
# 1. 优先检索当前聊天的消息集合
collections = []
if chat_id:
collections = await self.memory_manger.get_message_collection_context(query_text, chat_id=chat_id, n_results=1)
# 2. 如果当前聊天没有,或者不需要区分聊天,则进行全局检索
if not collections:
collections = await self.message_collection_storage.get_relevant_collection(query_text, chat_id=None, n_results=1)
if not collections:
if not chat_id:
return []
# 3. 将 MessageCollection 转换为 MemoryChunk
instant_memories = []
for collection in collections:
from src.chat.memory_system.memory_chunk import (
ContentStructure,
ImportanceLevel,
MemoryMetadata,
MemoryType,
)
header = f"[来自群/聊 {collection.chat_id} 的近期对话]"
if collection.chat_id == chat_id:
header = f"[🔥 来自当前聊天的近期对话]"
context_text = await self.memory_manager.get_message_collection_context(query_text, chat_id=chat_id)
if not context_text:
return []
display_text = f"{header}\n---\n{collection.combined_text}"
from src.chat.memory_system.memory_chunk import (
ContentStructure,
ImportanceLevel,
MemoryMetadata,
MemoryType,
)
metadata = MemoryMetadata(
memory_id=f"instant_{collection.collection_id}",
user_id=GLOBAL_MEMORY_SCOPE,
chat_id=collection.chat_id,
created_at=collection.created_at,
importance=ImportanceLevel.HIGH, # 瞬时记忆通常具有高重要性
)
content = ContentStructure(
subject="对话上下文",
predicate="包含",
object=collection.combined_text,
display=display_text
)
chunk = MemoryChunk(
metadata=metadata,
content=content,
memory_type=MemoryType.CONTEXTUAL,
)
instant_memories.append(chunk)
return instant_memories
metadata = MemoryMetadata(
memory_id=f"instant_{chat_id}_{time.time()}",
user_id=GLOBAL_MEMORY_SCOPE,
chat_id=chat_id,
created_at=time.time(),
importance=ImportanceLevel.HIGH, # 瞬时记忆通常具有高重要性
)
content = ContentStructure(
subject="近期对话上下文",
predicate="相关内容",
object=context_text,
display=context_text,
)
chunk = MemoryChunk(
metadata=metadata,
content=content,
memory_type=MemoryType.CONTEXTUAL,
)
return [chunk]
@staticmethod
def _extract_json_payload(response: str) -> str | None:

View File

@@ -81,22 +81,22 @@ class MessageCollectionStorage:
current_count = self.vector_db_service.count(self.collection_name)
if current_count > self.config.instant_memory_max_collections:
num_to_delete = current_count - self.config.instant_memory_max_collections
# 获取所有文档的元数据以进行排序
all_docs = self.vector_db_service.get(
collection_name=self.collection_name,
include=["metadatas"]
)
if all_docs and all_docs.get("ids"):
# 在内存中排序找到最旧的文档
sorted_docs = sorted(
zip(all_docs["ids"], all_docs["metadatas"]),
key=lambda item: item[1].get("created_at", 0),
)
ids_to_delete = [doc[0] for doc in sorted_docs[:num_to_delete]]
if ids_to_delete:
self.vector_db_service.delete(collection_name=self.collection_name, ids=ids_to_delete)
logger.info(f"消息集合已满,删除最旧的 {len(ids_to_delete)} 个集合")