短期记忆补丁

This commit is contained in:
LuiKlee
2025-12-14 14:12:39 +08:00
parent 1ad9c932bb
commit 6de5cd9902
5 changed files with 79 additions and 39 deletions

View File

@@ -121,7 +121,7 @@ async def conversation_loop(
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info(f" [生成器] stream={stream_id[:8]}, 被取消") logger.info(f" [生成器] stream={stream_id[:8]}, 被取消")
break break
except Exception as e: # noqa: BLE001 except Exception as e:
logger.error(f" [生成器] stream={stream_id[:8]}, 出错: {e}") logger.error(f" [生成器] stream={stream_id[:8]}, 出错: {e}")
await asyncio.sleep(5.0) await asyncio.sleep(5.0)
@@ -151,10 +151,10 @@ async def run_chat_stream(
# 创建生成器 # 创建生成器
tick_generator = conversation_loop( tick_generator = conversation_loop(
stream_id=stream_id, stream_id=stream_id,
get_context_func=manager._get_stream_context, # noqa: SLF001 get_context_func=manager._get_stream_context,
calculate_interval_func=manager._calculate_interval, # noqa: SLF001 calculate_interval_func=manager._calculate_interval,
flush_cache_func=manager._flush_cached_messages_to_unread, # noqa: SLF001 flush_cache_func=manager._flush_cached_messages_to_unread,
check_force_dispatch_func=manager._needs_force_dispatch_for_context, # noqa: SLF001 check_force_dispatch_func=manager._needs_force_dispatch_for_context,
is_running_func=lambda: manager.is_running, is_running_func=lambda: manager.is_running,
) )
@@ -162,13 +162,13 @@ async def run_chat_stream(
async for tick in tick_generator: async for tick in tick_generator:
try: try:
# 获取上下文 # 获取上下文
context = await manager._get_stream_context(stream_id) # noqa: SLF001 context = await manager._get_stream_context(stream_id)
if not context: if not context:
continue continue
# 并发保护:检查是否正在处理 # 并发保护:检查是否正在处理
if context.is_chatter_processing: if context.is_chatter_processing:
if manager._recover_stale_chatter_state(stream_id, context): # noqa: SLF001 if manager._recover_stale_chatter_state(stream_id, context):
logger.warning(f" [驱动器] stream={stream_id[:8]}, 处理标志残留已修复") logger.warning(f" [驱动器] stream={stream_id[:8]}, 处理标志残留已修复")
else: else:
logger.debug(f" [驱动器] stream={stream_id[:8]}, Chatter正在处理跳过此Tick") logger.debug(f" [驱动器] stream={stream_id[:8]}, Chatter正在处理跳过此Tick")
@@ -182,7 +182,7 @@ async def run_chat_stream(
# 更新能量值 # 更新能量值
try: try:
await manager._update_stream_energy(stream_id, context) # noqa: SLF001 await manager._update_stream_energy(stream_id, context)
except Exception as e: except Exception as e:
logger.debug(f"更新能量失败: {e}") logger.debug(f"更新能量失败: {e}")
@@ -191,7 +191,7 @@ async def run_chat_stream(
try: try:
async with manager._processing_semaphore: async with manager._processing_semaphore:
success = await asyncio.wait_for( success = await asyncio.wait_for(
manager._process_stream_messages(stream_id, context), # noqa: SLF001 manager._process_stream_messages(stream_id, context),
global_config.chat.thinking_timeout, global_config.chat.thinking_timeout,
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
@@ -209,7 +209,7 @@ async def run_chat_stream(
except asyncio.CancelledError: except asyncio.CancelledError:
raise raise
except Exception as e: # noqa: BLE001 except Exception as e:
logger.error(f" [驱动器] stream={stream_id[:8]}, 处理Tick时出错: {e}") logger.error(f" [驱动器] stream={stream_id[:8]}, 处理Tick时出错: {e}")
manager.stats["total_failures"] += 1 manager.stats["total_failures"] += 1
@@ -222,7 +222,7 @@ async def run_chat_stream(
if context and context.stream_loop_task: if context and context.stream_loop_task:
context.stream_loop_task = None context.stream_loop_task = None
logger.debug(f" [驱动器] stream={stream_id[:8]}, 清理任务记录") logger.debug(f" [驱动器] stream={stream_id[:8]}, 清理任务记录")
except Exception as e: # noqa: BLE001 except Exception as e:
logger.debug(f"清理任务记录失败: {e}") logger.debug(f"清理任务记录失败: {e}")

View File

@@ -110,10 +110,10 @@ class MessageManager:
if not (context.stream_loop_task and not context.stream_loop_task.done()): if not (context.stream_loop_task and not context.stream_loop_task.done()):
# 异步启动驱动器任务;避免在高并发下阻塞消息入队 # 异步启动驱动器任务;避免在高并发下阻塞消息入队
await stream_loop_manager.start_stream_loop(stream_id) await stream_loop_manager.start_stream_loop(stream_id)
# 检查并处理消息打断 # 检查并处理消息打断
await self._check_and_handle_interruption(chat_stream, message) await self._check_and_handle_interruption(chat_stream, message)
# 入队消息 # 入队消息
await chat_stream.context.add_message(message) await chat_stream.context.add_message(message)

View File

@@ -1,8 +1,8 @@
import os import os
import shutil import shutil
import sys import sys
import typing
import types import types
import typing
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, get_args, get_origin from typing import Any, get_args, get_origin
@@ -30,8 +30,8 @@ from src.config.official_configs import (
ExperimentalConfig, ExperimentalConfig,
ExpressionConfig, ExpressionConfig,
InnerConfig, InnerConfig,
LogConfig,
KokoroFlowChatterConfig, KokoroFlowChatterConfig,
LogConfig,
LPMMKnowledgeConfig, LPMMKnowledgeConfig,
MemoryConfig, MemoryConfig,
MessageBusConfig, MessageBusConfig,
@@ -515,7 +515,7 @@ class Config(ValidatedConfigBase):
) )
@property @property
def MMC_VERSION(self) -> str: # noqa: N802 def MMC_VERSION(self) -> str:
return MMC_VERSION return MMC_VERSION

View File

@@ -66,6 +66,13 @@ class LongTermMemoryManager:
self._similar_memory_cache: dict[str, list[Memory]] = {} self._similar_memory_cache: dict[str, list[Memory]] = {}
self._cache_max_size = 100 self._cache_max_size = 100
# 错误/重试统计与配置
self._max_process_retries = 2
self._retry_backoff = 0.5
self._total_processed = 0
self._failed_single_memory_count = 0
self._retry_attempts = 0
logger.info( logger.info(
f"长期记忆管理器已创建 (batch_size={batch_size}, " f"长期记忆管理器已创建 (batch_size={batch_size}, "
f"search_top_k={search_top_k}, decay_factor={long_term_decay_factor:.2f})" f"search_top_k={search_top_k}, decay_factor={long_term_decay_factor:.2f})"
@@ -202,6 +209,10 @@ class LongTermMemoryManager:
else: else:
result["failed_count"] += 1 result["failed_count"] += 1
# 更新全局计数
self._total_processed += result["processed_count"]
self._failed_single_memory_count += result["failed_count"]
# 处理完批次后批量生成embeddings # 处理完批次后批量生成embeddings
await self._flush_pending_embeddings() await self._flush_pending_embeddings()
@@ -217,26 +228,45 @@ class LongTermMemoryManager:
Returns: Returns:
处理结果或None如果失败 处理结果或None如果失败
""" """
try: # 增加重试机制以应对 LLM/执行的临时失败
# 步骤1: 在长期记忆中检索相似记忆 attempt = 0
similar_memories = await self._search_similar_long_term_memories(stm) last_exc: Exception | None = None
while attempt <= self._max_process_retries:
try:
# 步骤1: 在长期记忆中检索相似记忆
similar_memories = await self._search_similar_long_term_memories(stm)
# 步骤2: LLM 决策如何更新图结构 # 步骤2: LLM 决策如何更新图结构
operations = await self._decide_graph_operations(stm, similar_memories) operations = await self._decide_graph_operations(stm, similar_memories)
# 步骤3: 执行图操作 # 步骤3: 执行图操作
success = await self._execute_graph_operations(operations, stm) success = await self._execute_graph_operations(operations, stm)
if success: if success:
return { return {
"success": True, "success": True,
"operations": [op.operation_type for op in operations] "operations": [op.operation_type for op in operations]
} }
return None
except Exception as e: # 如果执行返回 False视为一次失败准备重试
logger.error(f"处理短期记忆 {stm.id} 失败: {e}") last_exc = RuntimeError("_execute_graph_operations 返回 False")
return None raise last_exc
except Exception as e:
last_exc = e
attempt += 1
if attempt <= self._max_process_retries:
self._retry_attempts += 1
backoff = self._retry_backoff * attempt
logger.warning(
f"处理短期记忆 {stm.id} 时发生可恢复错误,重试 {attempt}/{self._max_process_retries},等待 {backoff}s: {e}"
)
await asyncio.sleep(backoff)
continue
# 超过重试次数,记录失败并返回 None
logger.error(f"处理短期记忆 {stm.id} 最终失败: {last_exc}")
self._failed_single_memory_count += 1
return None
async def _search_similar_long_term_memories( async def _search_similar_long_term_memories(
self, stm: ShortTermMemory self, stm: ShortTermMemory

View File

@@ -648,15 +648,15 @@ class ShortTermMemoryManager:
else: else:
low_importance_memories.append(mem) low_importance_memories.append(mem)
# 如果低重要性记忆数量超过了上限(说明积压严重) # 如果总体记忆数量超过了上限,优先清理低重要性最早创建的记忆
# 我们需要清理掉一部分,而不是转移它们 if len(self.memories) > self.max_memories:
if len(low_importance_memories) > self.max_memories:
# 目标保留数量(降至上限的 90% # 目标保留数量(降至上限的 90%
target_keep_count = int(self.max_memories * 0.9) target_keep_count = int(self.max_memories * 0.9)
num_to_remove = len(low_importance_memories) - target_keep_count # 需要删除的数量(从当前总数降到 target_keep_count
num_to_remove = len(self.memories) - target_keep_count
if num_to_remove > 0: if num_to_remove > 0 and low_importance_memories:
# 按创建时间排序,删除最早的 # 按创建时间排序,删除最早的低重要性记忆
low_importance_memories.sort(key=lambda x: x.created_at) low_importance_memories.sort(key=lambda x: x.created_at)
to_remove = low_importance_memories[:num_to_remove] to_remove = low_importance_memories[:num_to_remove]
@@ -664,7 +664,7 @@ class ShortTermMemoryManager:
remove_ids = {mem.id for mem in to_remove} remove_ids = {mem.id for mem in to_remove}
self.memories = [mem for mem in self.memories if mem.id not in remove_ids] self.memories = [mem for mem in self.memories if mem.id not in remove_ids]
for mem_id in remove_ids: for mem_id in remove_ids:
del self._memory_id_index[mem_id] self._memory_id_index.pop(mem_id, None)
self._similarity_cache.pop(mem_id, None) self._similarity_cache.pop(mem_id, None)
logger.info( logger.info(
@@ -675,6 +675,16 @@ class ShortTermMemoryManager:
# 触发保存 # 触发保存
asyncio.create_task(self._save_to_disk()) asyncio.create_task(self._save_to_disk())
# 优先返回高重要性候选
if candidates:
return candidates
# 如果没有高重要性候选但总体超过上限,返回按创建时间最早的低重要性记忆作为后备转移候选
if len(self.memories) > self.max_memories:
needed = len(self.memories) - self.max_memories + 1
low_importance_memories.sort(key=lambda x: x.created_at)
return low_importance_memories[:needed]
return candidates return candidates
async def clear_transferred_memories(self, memory_ids: list[str]) -> None: async def clear_transferred_memories(self, memory_ids: list[str]) -> None: