feat: 重构统一记忆管理器,整合聊天历史上下文并优化记忆块转移逻辑

This commit is contained in:
Windpicker-owo
2025-11-18 20:39:05 +08:00
parent dc3ad19809
commit 999d7b285f
3 changed files with 144 additions and 614 deletions

View File

@@ -177,7 +177,7 @@ class UnifiedMemoryManager:
# 立即将该块转移到短期记忆
async def search_memories(
self, query_text: str, use_judge: bool = True
self, query_text: str, use_judge: bool = True, recent_chat_history: str = ""
) -> dict[str, Any]:
"""
智能检索记忆
@@ -190,6 +190,7 @@ class UnifiedMemoryManager:
Args:
query_text: 查询文本
use_judge: 是否使用裁判模型
recent_chat_history: 最近的聊天历史上下文(可选)
Returns:
检索结果字典,包含:
@@ -213,24 +214,20 @@ class UnifiedMemoryManager:
perceptual_blocks = await self.perceptual_manager.recall_blocks(query_text)
short_term_memories = await self.short_term_manager.search_memories(query_text)
# 步骤1.5: 检查并处理需要转移的记忆块
# 当某个块的召回次数达到阈值时,立即转移到短期记忆
# 步骤1.5: 检查需要转移的感知块,推迟到后台处理
blocks_to_transfer = [
block for block in perceptual_blocks
block
for block in perceptual_blocks
if block.metadata.get("needs_transfer", False)
]
if blocks_to_transfer:
logger.info(f"检测到 {len(blocks_to_transfer)} 个记忆块需要转移到短期记忆")
logger.info(
f"检测到 {len(blocks_to_transfer)} 个感知记忆需要转移,已交由后台后处理任务执行"
)
for block in blocks_to_transfer:
# 转换为短期记忆
stm = await self.short_term_manager.add_from_block(block)
if stm:
# 从感知记忆中移除
await self.perceptual_manager.remove_block(block.id)
logger.info(f"✅ 记忆块 {block.id} 已转为短期记忆 {stm.id}")
# 将新创建的短期记忆加入结果
short_term_memories.append(stm)
block.metadata["needs_transfer"] = False
self._schedule_perceptual_block_transfer(blocks_to_transfer)
result["perceptual_blocks"] = perceptual_blocks
result["short_term_memories"] = short_term_memories
@@ -243,36 +240,23 @@ class UnifiedMemoryManager:
# 步骤2: 裁判模型评估
if use_judge:
judge_decision = await self._judge_retrieval_sufficiency(
query_text, perceptual_blocks, short_term_memories
query_text, perceptual_blocks, short_term_memories, recent_chat_history
)
result["judge_decision"] = judge_decision
# 步骤3: 如果不充足,检索长期记忆
if not judge_decision.is_sufficient:
logger.info("裁判判定记忆不足,启动长期记忆检索")
logger.info("判官判断记忆不足,开始检索长期记忆")
# 使用额外的 query 检索
long_term_memories = []
queries = [query_text] + judge_decision.additional_queries
long_term_memories = await self._retrieve_long_term_memories(
base_query=query_text,
queries=queries,
recent_chat_history=recent_chat_history,
)
for q in queries:
memories = await self.memory_manager.search_memories(
query=q,
top_k=5,
use_multi_query=False,
)
long_term_memories.extend(memories)
result["long_term_memories"] = long_term_memories
# 去重
seen_ids = set()
unique_memories = []
for mem in long_term_memories:
if mem.id not in seen_ids:
unique_memories.append(mem)
seen_ids.add(mem.id)
result["long_term_memories"] = unique_memories
logger.info(f"长期记忆检索: {len(unique_memories)}")
else:
# 不使用裁判,直接检索长期记忆
long_term_memories = await self.memory_manager.search_memories(
@@ -298,6 +282,7 @@ class UnifiedMemoryManager:
query: str,
perceptual_blocks: list[MemoryBlock],
short_term_memories: list[ShortTermMemory],
recent_chat_history: str = "",
) -> JudgeDecision:
"""
使用裁判模型评估检索结果是否充足
@@ -306,6 +291,7 @@ class UnifiedMemoryManager:
query: 原始查询
perceptual_blocks: 感知记忆块
short_term_memories: 短期记忆
recent_chat_history: 最近的聊天历史上下文(可选)
Returns:
裁判决策
@@ -326,7 +312,7 @@ class UnifiedMemoryManager:
text = str(text)
perceptual_texts.append(f"记忆块{i+1}:\n{text}")
perceptual_desc = "\n\n".join(perceptual_texts)
perceptual_desc = "\n\n".join(str(item) for item in perceptual_texts)
# 短期记忆使用 "主体-主题(属性)" 格式
short_term_texts = []
@@ -335,14 +321,22 @@ class UnifiedMemoryManager:
if formatted: # 只添加非空的格式化结果
short_term_texts.append(f"- {formatted}")
short_term_desc = "\n".join(short_term_texts)
short_term_desc = "\n".join(str(item) for item in short_term_texts)
# 构建聊天历史块(如果提供)
chat_history_block = ""
if recent_chat_history:
chat_history_block = f"""**最近的聊天历史:**
{recent_chat_history}
"""
prompt = f"""你是一个记忆检索评估专家。请判断检索到的记忆是否足以回答用户的问题。
**用户查询:**
{query}
**检索到的感知记忆块:**
{chat_history_block}**检索到的感知记忆块:**
{perceptual_desc or '(无)'}
**检索到的短期记忆(结构化记忆,格式:主体-主题(属性)**
@@ -411,6 +405,114 @@ class UnifiedMemoryManager:
additional_queries=[query],
)
def _schedule_perceptual_block_transfer(self, blocks: list[MemoryBlock]) -> None:
    """Schedule a background transfer of perceptual blocks to short-term memory.

    The transfer runs as a separate asyncio task so the caller is not
    blocked while the blocks are converted. ``asyncio.create_task`` needs a
    running event loop, so this must be called from within one.

    Args:
        blocks: Perceptual memory blocks to transfer. An empty list is a no-op.
    """
    if not blocks:
        return

    # Give the task its own copy of the list so later mutation of the
    # caller's list cannot change what gets transferred.
    task = asyncio.create_task(
        self._transfer_blocks_to_short_term(list(blocks))
    )

    # The event loop keeps only weak references to tasks. Hold a strong
    # reference until the task finishes so it cannot be garbage-collected
    # mid-execution. The holder set is created lazily on first use.
    pending_tasks = getattr(self, "_perceptual_transfer_tasks", None)
    if pending_tasks is None:
        pending_tasks = set()
        self._perceptual_transfer_tasks = pending_tasks
    pending_tasks.add(task)
    task.add_done_callback(pending_tasks.discard)

    self._attach_background_task_callback(task, "perceptual->short-term transfer")
def _attach_background_task_callback(self, task: asyncio.Task, task_name: str) -> None:
    """Register a done-callback that logs how a background task ended.

    Args:
        task: The task to observe.
        task_name: Label used as the prefix of the log messages.
    """

    def _report_outcome(finished: asyncio.Task) -> None:
        # result() re-raises whatever ended the task, so cancellation and
        # failure can be told apart by exception type.
        try:
            finished.result()
        except asyncio.CancelledError:
            logger.info(f"{task_name} 后台任务已取消")
        except Exception as error:
            logger.error(f"{task_name} 后台任务失败: {error}", exc_info=True)

    task.add_done_callback(_report_outcome)
async def _transfer_blocks_to_short_term(self, blocks: list[MemoryBlock]) -> None:
    """Convert each perceptual block into a short-term memory, one at a time.

    A block is removed from perceptual memory only after
    ``short_term_manager.add_from_block`` returns a truthy result; otherwise
    it is left in place. An exception raised for one block is logged and
    does not stop processing of the remaining blocks.

    Args:
        blocks: Perceptual memory blocks to transfer.
    """
    logger.info(f"正在后台处理 {len(blocks)} 个感知记忆块")
    for pending in blocks:
        try:
            promoted = await self.short_term_manager.add_from_block(pending)
            if promoted:
                await self.perceptual_manager.remove_block(pending.id)
                logger.info(f"✓ 记忆块 {pending.id} 已被转移到短期记忆 {promoted.id}")
        except Exception as error:
            logger.error(f"后台转移失败,记忆块 {pending.id}: {error}", exc_info=True)
def _build_manual_multi_queries(self, queries: list[str]) -> list[dict[str, Any]]:
    """Deduplicate queries and attach decaying weights for multi-query search.

    Args:
        queries: Candidate query strings in priority order. Entries that are
            not strings, or that are empty after stripping whitespace, are
            skipped.

    Returns:
        A list of ``{"text": str, "weight": float}`` dicts in first-seen
        order. The first query has weight 1.0 and each later one drops by
        0.15, never going below 0.3. An empty list is returned when at most
        one distinct query remains.
    """
    # Skip non-string entries instead of crashing on ``.strip()``.
    stripped = (raw.strip() for raw in queries if isinstance(raw, str))
    # dict.fromkeys keeps the first occurrence and preserves order.
    unique_texts = list(dict.fromkeys(text for text in stripped if text))

    if len(unique_texts) <= 1:
        return []

    decay = 0.15
    return [
        {"text": text, "weight": round(max(0.3, 1.0 - idx * decay), 2)}
        for idx, text in enumerate(unique_texts)
    ]
async def _retrieve_long_term_memories(
    self,
    base_query: str,
    queries: list[str],
    recent_chat_history: str = "",
) -> list[Any]:
    """Run one long-term memory search that can fuse several queries.

    Args:
        base_query: Query passed to the search as ``query``.
        queries: Candidate queries, turned into weighted entries by
            ``_build_manual_multi_queries``.
        recent_chat_history: Optional chat history, forwarded in the search
            context when non-empty.

    Returns:
        The search results after ``_deduplicate_memories``.
    """
    weighted_queries = self._build_manual_multi_queries(queries)

    # Only non-empty pieces are placed in the context.
    extra_context: dict[str, Any] = {}
    if recent_chat_history:
        extra_context["chat_history"] = recent_chat_history
    if weighted_queries:
        extra_context["manual_multi_queries"] = weighted_queries

    request: dict[str, Any] = {
        "query": base_query,
        "top_k": self._config["long_term"]["search_top_k"],
        "use_multi_query": bool(weighted_queries),
    }
    # The "context" key is omitted entirely when there is nothing to send.
    if extra_context:
        request["context"] = extra_context

    hits = await self.memory_manager.search_memories(**request)
    deduped = self._deduplicate_memories(hits)

    # With no weighted queries, only the base query was searched.
    fused = len(weighted_queries) or 1
    logger.info(
        f"Long-term retrieval done: {len(deduped)} hits (queries fused={fused})"
    )
    return deduped
def _deduplicate_memories(self, memories: list[Any]) -> list[Any]:
    """Drop memories whose ``id`` was already seen, preserving order.

    Items with no ``id`` attribute, or with a falsy ``id``, are always kept
    because they cannot be matched against earlier items.

    Args:
        memories: Memory objects to filter.

    Returns:
        A new list containing the first occurrence of each identified memory
        plus every memory without a usable id.
    """
    kept: list[Any] = []
    known_ids: set[str] = set()
    for memory in memories:
        identifier = getattr(memory, "id", None)
        if not identifier:
            kept.append(memory)
        elif identifier not in known_ids:
            known_ids.add(identifier)
            kept.append(memory)
    return kept
def _start_auto_transfer_task(self) -> None:
"""启动自动转移任务"""
if self._auto_transfer_task and not self._auto_transfer_task.done():