refactor(memory): 引入 MemoryManager 统一管理瞬时记忆
引入 `MemoryManager` 类来封装和统一处理瞬时记忆(消息集合)的检索和管理逻辑。此举将瞬时记忆的相关操作从 `MemorySystem` 中解耦,提高了代码的模块化和可维护性。

主要变更:

- 创建 `MemoryManager` 类,负责消息集合的初始化、上下文检索等。
- `MemorySystem` 现在通过 `MemoryManager` 实例来获取瞬时记忆,简化了其内部实现。
- 移除了 `MemorySystem` 中原有的、分散的瞬时记忆检索代码,使其职责更单一。
This commit is contained in:
@@ -18,6 +18,7 @@ import orjson
|
||||
from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractionError
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||
from src.chat.memory_system.memory_manager import MemoryManager
|
||||
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
|
||||
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
|
||||
|
||||
@@ -146,6 +147,7 @@ class MemorySystem:
|
||||
self.message_collection_storage: MessageCollectionStorage | None = None
|
||||
self.query_planner: MemoryQueryPlanner | None = None
|
||||
self.forgetting_engine: MemoryForgettingEngine | None = None
|
||||
self.memory_manager: MemoryManager | None = None
|
||||
|
||||
# LLM模型
|
||||
self.value_assessment_model: LLMRequest | None = None
|
||||
@@ -171,6 +173,10 @@ class MemorySystem:
|
||||
async def initialize(self):
|
||||
"""异步初始化记忆系统"""
|
||||
try:
|
||||
# 初始化 MemoryManager
|
||||
self.memory_manager = MemoryManager()
|
||||
await self.memory_manager.initialize()
|
||||
|
||||
# 初始化LLM模型
|
||||
fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None
|
||||
|
||||
@@ -936,58 +942,43 @@ class MemorySystem:
|
||||
|
||||
async def _retrieve_instant_memories(self, query_text: str, chat_id: str | None) -> list[MemoryChunk]:
|
||||
"""检索瞬时记忆(消息集合)并转换为MemoryChunk"""
|
||||
if not self.message_collection_storage:
|
||||
if not self.memory_manager:
|
||||
return []
|
||||
|
||||
# 1. 优先检索当前聊天的消息集合
|
||||
collections = []
|
||||
if chat_id:
|
||||
collections = await self.memory_manager.get_message_collection_context(query_text, chat_id=chat_id, n_results=1)
|
||||
|
||||
# 2. 如果当前聊天没有,或者不需要区分聊天,则进行全局检索
|
||||
if not collections:
|
||||
collections = await self.message_collection_storage.get_relevant_collection(query_text, chat_id=None, n_results=1)
|
||||
|
||||
if not collections:
|
||||
if not chat_id:
|
||||
return []
|
||||
|
||||
# 3. 将 MessageCollection 转换为 MemoryChunk
|
||||
instant_memories = []
|
||||
for collection in collections:
|
||||
from src.chat.memory_system.memory_chunk import (
|
||||
ContentStructure,
|
||||
ImportanceLevel,
|
||||
MemoryMetadata,
|
||||
MemoryType,
|
||||
)
|
||||
|
||||
header = f"[来自群/聊 {collection.chat_id} 的近期对话]"
|
||||
if collection.chat_id == chat_id:
|
||||
header = f"[🔥 来自当前聊天的近期对话]"
|
||||
context_text = await self.memory_manager.get_message_collection_context(query_text, chat_id=chat_id)
|
||||
if not context_text:
|
||||
return []
|
||||
|
||||
display_text = f"{header}\n---\n{collection.combined_text}"
|
||||
from src.chat.memory_system.memory_chunk import (
|
||||
ContentStructure,
|
||||
ImportanceLevel,
|
||||
MemoryMetadata,
|
||||
MemoryType,
|
||||
)
|
||||
|
||||
metadata = MemoryMetadata(
|
||||
memory_id=f"instant_{collection.collection_id}",
|
||||
user_id=GLOBAL_MEMORY_SCOPE,
|
||||
chat_id=collection.chat_id,
|
||||
created_at=collection.created_at,
|
||||
importance=ImportanceLevel.HIGH, # 瞬时记忆通常具有高重要性
|
||||
)
|
||||
content = ContentStructure(
|
||||
subject="对话上下文",
|
||||
predicate="包含",
|
||||
object=collection.combined_text,
|
||||
display=display_text
|
||||
)
|
||||
chunk = MemoryChunk(
|
||||
metadata=metadata,
|
||||
content=content,
|
||||
memory_type=MemoryType.CONTEXTUAL,
|
||||
)
|
||||
instant_memories.append(chunk)
|
||||
|
||||
return instant_memories
|
||||
metadata = MemoryMetadata(
|
||||
memory_id=f"instant_{chat_id}_{time.time()}",
|
||||
user_id=GLOBAL_MEMORY_SCOPE,
|
||||
chat_id=chat_id,
|
||||
created_at=time.time(),
|
||||
importance=ImportanceLevel.HIGH, # 瞬时记忆通常具有高重要性
|
||||
)
|
||||
content = ContentStructure(
|
||||
subject="近期对话上下文",
|
||||
predicate="相关内容",
|
||||
object=context_text,
|
||||
display=context_text,
|
||||
)
|
||||
chunk = MemoryChunk(
|
||||
metadata=metadata,
|
||||
content=content,
|
||||
memory_type=MemoryType.CONTEXTUAL,
|
||||
)
|
||||
|
||||
return [chunk]
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_payload(response: str) -> str | None:
|
||||
|
||||
@@ -81,22 +81,22 @@ class MessageCollectionStorage:
|
||||
current_count = self.vector_db_service.count(self.collection_name)
|
||||
if current_count > self.config.instant_memory_max_collections:
|
||||
num_to_delete = current_count - self.config.instant_memory_max_collections
|
||||
|
||||
|
||||
# 获取所有文档的元数据以进行排序
|
||||
all_docs = self.vector_db_service.get(
|
||||
collection_name=self.collection_name,
|
||||
include=["metadatas"]
|
||||
)
|
||||
|
||||
|
||||
if all_docs and all_docs.get("ids"):
|
||||
# 在内存中排序找到最旧的文档
|
||||
sorted_docs = sorted(
|
||||
zip(all_docs["ids"], all_docs["metadatas"]),
|
||||
key=lambda item: item[1].get("created_at", 0),
|
||||
)
|
||||
|
||||
|
||||
ids_to_delete = [doc[0] for doc in sorted_docs[:num_to_delete]]
|
||||
|
||||
|
||||
if ids_to_delete:
|
||||
self.vector_db_service.delete(collection_name=self.collection_name, ids=ids_to_delete)
|
||||
logger.info(f"消息集合已满,删除最旧的 {len(ids_to_delete)} 个集合")
|
||||
|
||||
Reference in New Issue
Block a user