From 9a0163d06b6eaa6d298624d653b865561f54e9fe Mon Sep 17 00:00:00 2001 From: LuiKlee Date: Sat, 13 Dec 2025 20:19:11 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=B6=88=E6=81=AF=E7=AE=A1?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/benchmark_distribution_manager.py | 306 ++++++++++++++++++ scripts/benchmark_unified_manager.py | 8 +- .../message_manager/batch_database_writer.py | 85 ++--- .../message_manager/distribution_manager.py | 34 +- src/chat/message_manager/message_manager.py | 26 +- src/memory_graph/unified_manager.py | 20 +- 6 files changed, 405 insertions(+), 74 deletions(-) create mode 100644 scripts/benchmark_distribution_manager.py diff --git a/scripts/benchmark_distribution_manager.py b/scripts/benchmark_distribution_manager.py new file mode 100644 index 000000000..932aaaa62 --- /dev/null +++ b/scripts/benchmark_distribution_manager.py @@ -0,0 +1,306 @@ +import asyncio +import time +import os +import sys +from dataclasses import dataclass +from typing import Any, Optional + +# Benchmark the distribution manager's run_chat_stream/conversation_loop behavior +# by wiring a lightweight dummy manager and contexts. This avoids touching real DB or chat subsystems. + +# Ensure project root is on sys.path when running as a script +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if ROOT_DIR not in sys.path: + sys.path.insert(0, ROOT_DIR) + +# Avoid importing the whole 'src.chat' package to prevent heavy deps (e.g., redis) +# Local minimal implementation of loop and manager to isolate benchmark from heavy deps. +from collections.abc import AsyncIterator, Awaitable, Callable +from dataclasses import dataclass, field + + +@dataclass +class ConversationTick: + stream_id: str + tick_time: float = field(default_factory=time.time) + force_dispatch: bool = False + tick_count: int = 0 + + +async def conversation_loop( + stream_id: str, + get_context_func: Callable[[str], Awaitable["DummyContext | None"]], + calculate_interval_func: Callable[[str, bool], Awaitable[float]], + flush_cache_func: Callable[[str], Awaitable[list[Any]]], + check_force_dispatch_func: Callable[["DummyContext", int], bool], + is_running_func: Callable[[], bool], +) -> AsyncIterator[ConversationTick]: + tick_count = 0 + while is_running_func(): + ctx = await get_context_func(stream_id) + if not ctx: + await asyncio.sleep(0.1) + continue + await flush_cache_func(stream_id) + unread = ctx.get_unread_messages() + ucnt = len(unread) + force = check_force_dispatch_func(ctx, ucnt) + if ucnt > 0 or force: + tick_count += 1 + yield ConversationTick(stream_id=stream_id, force_dispatch=force, tick_count=tick_count) + interval = await calculate_interval_func(stream_id, ucnt > 0) + await asyncio.sleep(interval) + + +class StreamLoopManager: + def __init__(self, max_concurrent_streams: Optional[int] = None): + self.stats: dict[str, Any] = { + "active_streams": 0, + "total_loops": 0, + "total_process_cycles": 0, + "total_failures": 0, + "start_time": time.time(), + } + self.max_concurrent_streams = max_concurrent_streams or 100 + self.force_dispatch_unread_threshold = 20 + self.chatter_manager = DummyChatterManager() + self.is_running = False + self._stream_start_locks: dict[str, asyncio.Lock] = {} + self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams) + self._chat_manager: Optional[DummyChatManager] = None + + def set_chat_manager(self, chat_manager: "DummyChatManager") -> None: + self._chat_manager = chat_manager + + async def start(self): + self.is_running = True + + async def stop(self): + self.is_running = False + + async def _get_stream_context(self, stream_id: str): + assert self._chat_manager is not None + stream = await self._chat_manager.get_stream(stream_id) + return stream.context if stream else None + + async def _flush_cached_messages_to_unread(self, stream_id: str): + ctx = await self._get_stream_context(stream_id) + return ctx.flush_cached_messages() if ctx else [] + + def _needs_force_dispatch_for_context(self, context: "DummyContext", unread_count: int) -> bool: + return unread_count > (self.force_dispatch_unread_threshold or 20) + + async def _process_stream_messages(self, stream_id: str, context: "DummyContext") -> bool: + res = await self.chatter_manager.process_stream_context(stream_id, context) # type: ignore[attr-defined] + return bool(res.get("success", False)) + + async def _update_stream_energy(self, stream_id: str, context: "DummyContext") -> None: + pass + + async def _calculate_interval(self, stream_id: str, has_messages: bool) -> float: + return 0.005 if has_messages else 0.02 + + async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool: + ctx = await self._get_stream_context(stream_id) + if not ctx: + return False + # create driver + loop_task = asyncio.create_task(run_chat_stream(stream_id, self)) + ctx.stream_loop_task = loop_task + self.stats["active_streams"] += 1 + self.stats["total_loops"] += 1 + return True + + +async def run_chat_stream(stream_id: str, manager: StreamLoopManager) -> None: + try: + gen = conversation_loop( + stream_id=stream_id, + get_context_func=manager._get_stream_context, + calculate_interval_func=manager._calculate_interval, + flush_cache_func=manager._flush_cached_messages_to_unread, + check_force_dispatch_func=manager._needs_force_dispatch_for_context, + is_running_func=lambda: manager.is_running, + ) + async for tick in gen: + ctx = await manager._get_stream_context(stream_id) + if not ctx: + continue + if ctx.is_chatter_processing: + continue + try: + async with manager._processing_semaphore: + ok = await manager._process_stream_messages(stream_id, ctx) + except Exception: + ok = False + manager.stats["total_process_cycles"] += 1 + if not ok: + manager.stats["total_failures"] += 1 + except asyncio.CancelledError: + pass + + +@dataclass +class DummyMessage: + time: float + processed_plain_text: str = "" + display_message: str = "" + is_at: bool = False + is_mentioned: bool = False + + +class DummyContext: + def __init__(self, stream_id: str, initial_unread: int): + self.stream_id = stream_id + self.unread_messages = [DummyMessage(time=time.time()) for _ in range(initial_unread)] + self.history_messages: list[DummyMessage] = [] + self.is_chatter_processing: bool = False + self.processing_task: Optional[asyncio.Task] = None + self.stream_loop_task: Optional[asyncio.Task] = None + self.triggering_user_id: Optional[str] = None + + def get_unread_messages(self) -> list[DummyMessage]: + return list(self.unread_messages) + + def flush_cached_messages(self) -> list[DummyMessage]: + return [] + + def get_last_message(self) -> Optional[DummyMessage]: + return self.unread_messages[-1] if self.unread_messages else None + + def get_history_messages(self, limit: int = 50) -> list[DummyMessage]: + return self.history_messages[-limit:] + + +class DummyStream: + def __init__(self, stream_id: str, ctx: DummyContext): + self.stream_id = stream_id + self.context = ctx + self.group_info = None # treat as private chat to accelerate + self._focus_energy = 0.5 + + +class DummyChatManager: + def __init__(self, streams: dict[str, DummyStream]): + self._streams = streams + + async def get_stream(self, stream_id: str) -> Optional[DummyStream]: + return self._streams.get(stream_id) + + def get_all_streams(self) -> dict[str, DummyStream]: + return self._streams + + +class DummyChatterManager: + async def process_stream_context(self, stream_id: str, context: DummyContext) -> dict[str, Any]: + # Simulate some processing latency and consume one unread message + await asyncio.sleep(0.01) + if context.unread_messages: + context.unread_messages.pop(0) + return {"success": True} + + +class BenchStreamLoopManager(StreamLoopManager): + def __init__(self, chat_manager: DummyChatManager, max_concurrent_streams: int | None = None): + super().__init__(max_concurrent_streams=max_concurrent_streams) + self._chat_manager = chat_manager + self.chatter_manager = DummyChatterManager() + + async def _get_stream_context(self, stream_id: str): # type: ignore[override] + stream = await self._chat_manager.get_stream(stream_id) + return stream.context if stream else None + + async def _flush_cached_messages_to_unread(self, stream_id: str): # type: ignore[override] + ctx = await self._get_stream_context(stream_id) + return ctx.flush_cached_messages() if ctx else [] + + def _needs_force_dispatch_for_context(self, context, unread_count: int) -> bool: # type: ignore[override] + # force when unread exceeds threshold + return unread_count > (self.force_dispatch_unread_threshold or 20) + + async def _process_stream_messages(self, stream_id: str, context): # type: ignore[override] + # delegate to chatter manager + res = await self.chatter_manager.process_stream_context(stream_id, context) # type: ignore[attr-defined] + return bool(res.get("success", False)) + + async def _should_skip_for_mute_group(self, stream_id: str, unread_messages: list) -> bool: + return False + + async def _update_stream_energy(self, stream_id: str, context): # type: ignore[override] + # lightweight: compute based on unread size + focus = min(1.0, 0.1 + 0.02 * len(context.get_unread_messages())) + # set for compatibility + stream = await self._chat_manager.get_stream(stream_id) + if stream: + stream._focus_energy = focus + + +def make_streams(n_streams: int, initial_unread: int) -> dict[str, DummyStream]: + streams: dict[str, DummyStream] = {} + for i in range(n_streams): + sid = f"s{i:04d}" + ctx = DummyContext(sid, initial_unread) + streams[sid] = DummyStream(sid, ctx) + return streams + + +async def run_benchmark(n_streams: int, initial_unread: int, max_concurrent: Optional[int]) -> dict[str, Any]: + streams = make_streams(n_streams, initial_unread) + chat_mgr = DummyChatManager(streams) + mgr = BenchStreamLoopManager(chat_mgr, max_concurrent_streams=max_concurrent) + await mgr.start() + + # start loops for all streams + start_ts = time.time() + for sid in list(streams.keys()): + await mgr.start_stream_loop(sid, force=True) + + # run until all unread consumed or timeout + timeout = 5.0 + end_deadline = start_ts + timeout + while time.time() < end_deadline: + remaining = sum(len(s.context.get_unread_messages()) for s in streams.values()) + if remaining == 0: + break + await asyncio.sleep(0.02) + + duration = time.time() - start_ts + total_cycles = mgr.stats.get("total_process_cycles", 0) + total_failures = mgr.stats.get("total_failures", 0) + remaining = sum(len(s.context.get_unread_messages()) for s in streams.values()) + + # stop all + await mgr.stop() + + return { + "n_streams": n_streams, + "initial_unread": initial_unread, + "max_concurrent": max_concurrent, + "duration_sec": duration, + "total_cycles": total_cycles, + "total_failures": total_failures, + "remaining_unread": remaining, + "throughput_msgs_per_sec": (n_streams * initial_unread - remaining) / max(0.001, duration), + } + + +async def main(): + cases = [ + (50, 5, None), # baseline using configured default + (50, 5, 5), # constrained concurrency + (50, 5, 10), # moderate concurrency + (100, 3, 10), # scale streams + ] + + print("Running distribution manager benchmark...\n") + for n_streams, initial_unread, max_concurrent in cases: + res = await run_benchmark(n_streams, initial_unread, max_concurrent) + print( + f"streams={res['n_streams']} unread={res['initial_unread']} max_conc={res['max_concurrent']} | " + f"dur={res['duration_sec']:.3f}s cycles={res['total_cycles']} fail={res['total_failures']} rem={res['remaining_unread']} "+ + f"thr={res['throughput_msgs_per_sec']:.1f}/s" + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/benchmark_unified_manager.py b/scripts/benchmark_unified_manager.py index ec0ec69f0..eae8d187b 100644 --- a/scripts/benchmark_unified_manager.py +++ b/scripts/benchmark_unified_manager.py @@ -6,8 +6,6 @@ import asyncio import time -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch class PerformanceBenchmark: @@ -20,7 +18,7 @@ class PerformanceBenchmark: """测试查询去重性能""" # 这里需要导入实际的管理器 # from src.memory_graph.unified_manager import UnifiedMemoryManager - + test_cases = [ { "name": "small_queries", @@ -67,7 +65,7 @@ class PerformanceBenchmark: seen = set() decay = 0.15 manual_queries = [] - + for raw in queries: text = (raw or "").strip() if text and text not in seen: @@ -192,7 +190,7 @@ class PerformanceBenchmark: mem_id = mem.get("id") else: mem_id = getattr(mem, "id", None) - + if mem_id and mem_id in seen_ids: continue unique_memories.append(mem) diff --git a/src/chat/message_manager/batch_database_writer.py b/src/chat/message_manager/batch_database_writer.py index 74128d8d9..5eca2b1d5 100644 --- a/src/chat/message_manager/batch_database_writer.py +++ b/src/chat/message_manager/batch_database_writer.py @@ -9,6 +9,8 @@ from collections import defaultdict from dataclasses import dataclass, field from typing import Any +from sqlalchemy.exc import SQLAlchemyError + from src.common.database.compatibility import get_db_session from src.common.database.core.models import ChatStreams from src.common.logger import get_logger @@ -159,20 +161,27 @@ class BatchDatabaseWriter: logger.info("批量写入循环结束") async def _collect_batch(self) -> list[StreamUpdatePayload]: - """收集一个批次的数据""" - batch = [] - deadline = time.time() + self.flush_interval + """收集一个批次的数据 + - 自适应刷新:队列增长加快时缩短等待时间 + - 避免长时间空转:添加轻微抖动以分散竞争 + """ + batch: list[StreamUpdatePayload] = [] + # 根据当前队列长度调整刷新时间(最多缩短到 40%) + qsize = self.write_queue.qsize() + adapt_factor = 1.0 + if qsize > 0: + adapt_factor = max(0.4, min(1.0, self.batch_size / max(1, qsize))) + deadline = time.time() + (self.flush_interval * adapt_factor) while len(batch) < self.batch_size and time.time() < deadline: try: - # 计算剩余等待时间 - remaining_time = max(0, deadline - time.time()) + remaining_time = max(0.0, deadline - time.time()) if remaining_time == 0: break - - payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time) + # 轻微抖动,避免多个协程同时争抢队列 + jitter = 0.002 + payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time + jitter) batch.append(payload) - except asyncio.TimeoutError: break @@ -208,48 +217,52 @@ class BatchDatabaseWriter: logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s") - except Exception as e: + except SQLAlchemyError as e: self.stats["failed_writes"] += 1 logger.error(f"批量写入失败: {e}") # 降级到单个写入 for payload in batch: try: await self._direct_write(payload.stream_id, payload.update_data) - except Exception as single_e: + except SQLAlchemyError as single_e: logger.error(f"单个写入也失败: {single_e}") async def _batch_write_to_database(self, payloads: list[StreamUpdatePayload]): - """批量写入数据库""" + """批量写入数据库(单事务、多值 UPSERT)""" if global_config is None: raise RuntimeError("Global config is not initialized") + if not payloads: + return + + # 预组装行数据,确保每行包含 stream_id + rows: list[dict[str, Any]] = [] + for p in payloads: + row = {"stream_id": p.stream_id} + row.update(p.update_data) + rows.append(row) + async with get_db_session() as session: - for payload in payloads: - stream_id = payload.stream_id - update_data = payload.update_data - - # 根据数据库类型选择不同的插入/更新策略 - if global_config.database.database_type == "sqlite": - from sqlalchemy.dialects.sqlite import insert as sqlite_insert - - stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data) - stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data) - elif global_config.database.database_type == "postgresql": - from sqlalchemy.dialects.postgresql import insert as pg_insert - - stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data) - stmt = stmt.on_conflict_do_update( - index_elements=[ChatStreams.stream_id], - set_=update_data - ) - else: - # 默认使用SQLite语法 - from sqlalchemy.dialects.sqlite import insert as sqlite_insert - - stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data) - stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data) - + # 使用单次事务提交,显著减少 I/O + if global_config.database.database_type == "postgresql": + from sqlalchemy.dialects.postgresql import insert as pg_insert + stmt = pg_insert(ChatStreams).values(rows) + stmt = stmt.on_conflict_do_update( + index_elements=[ChatStreams.stream_id], + set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"} + ) await session.execute(stmt) + await session.commit() + else: + # 默认(sqlite) + from sqlalchemy.dialects.sqlite import insert as sqlite_insert + stmt = sqlite_insert(ChatStreams).values(rows) + stmt = stmt.on_conflict_do_update( + index_elements=["stream_id"], + set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"} + ) + await session.execute(stmt) + await session.commit() async def _direct_write(self, stream_id: str, update_data: dict[str, Any]): """直接写入数据库(降级方案)""" if global_config is None: diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index ff3694901..cb52e5f14 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -55,7 +55,7 @@ async def conversation_loop( stream_id: str, get_context_func: Callable[[str], Awaitable["StreamContext | None"]], calculate_interval_func: Callable[[str, bool], Awaitable[float]], - flush_cache_func: Callable[[str], Awaitable[None]], + flush_cache_func: Callable[[str], Awaitable[list[Any]]], check_force_dispatch_func: Callable[["StreamContext", int], bool], is_running_func: Callable[[], bool], ) -> AsyncIterator[ConversationTick]: @@ -121,7 +121,7 @@ async def conversation_loop( except asyncio.CancelledError: logger.info(f" [生成器] stream={stream_id[:8]}, 被取消") break - except Exception as e: + except Exception as e: # noqa: BLE001 logger.error(f" [生成器] stream={stream_id[:8]}, 出错: {e}") await asyncio.sleep(5.0) @@ -151,10 +151,10 @@ async def run_chat_stream( # 创建生成器 tick_generator = conversation_loop( stream_id=stream_id, - get_context_func=manager._get_stream_context, - calculate_interval_func=manager._calculate_interval, - flush_cache_func=manager._flush_cached_messages_to_unread, - check_force_dispatch_func=manager._needs_force_dispatch_for_context, + get_context_func=manager._get_stream_context, # noqa: SLF001 + calculate_interval_func=manager._calculate_interval, # noqa: SLF001 + flush_cache_func=manager._flush_cached_messages_to_unread, # noqa: SLF001 + check_force_dispatch_func=manager._needs_force_dispatch_for_context, # noqa: SLF001 is_running_func=lambda: manager.is_running, ) @@ -162,13 +162,13 @@ async def run_chat_stream( async for tick in tick_generator: try: # 获取上下文 - context = await manager._get_stream_context(stream_id) + context = await manager._get_stream_context(stream_id) # noqa: SLF001 if not context: continue # 并发保护:检查是否正在处理 if context.is_chatter_processing: - if manager._recover_stale_chatter_state(stream_id, context): + if manager._recover_stale_chatter_state(stream_id, context): # noqa: SLF001 logger.warning(f" [驱动器] stream={stream_id[:8]}, 处理标志残留已修复") else: logger.debug(f" [驱动器] stream={stream_id[:8]}, Chatter正在处理,跳过此Tick") @@ -182,17 +182,18 @@ async def run_chat_stream( # 更新能量值 try: - await manager._update_stream_energy(stream_id, context) + await manager._update_stream_energy(stream_id, context) # noqa: SLF001 except Exception as e: logger.debug(f"更新能量失败: {e}") # 处理消息 assert global_config is not None try: - success = await asyncio.wait_for( - manager._process_stream_messages(stream_id, context), - global_config.chat.thinking_timeout - ) + async with manager._processing_semaphore: + success = await asyncio.wait_for( + manager._process_stream_messages(stream_id, context), # noqa: SLF001 + global_config.chat.thinking_timeout, + ) except asyncio.TimeoutError: logger.warning(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 处理超时") success = False @@ -208,7 +209,7 @@ async def run_chat_stream( except asyncio.CancelledError: raise - except Exception as e: + except Exception as e: # noqa: BLE001 logger.error(f" [驱动器] stream={stream_id[:8]}, 处理Tick时出错: {e}") manager.stats["total_failures"] += 1 @@ -221,7 +222,7 @@ async def run_chat_stream( if context and context.stream_loop_task: context.stream_loop_task = None logger.debug(f" [驱动器] stream={stream_id[:8]}, 清理任务记录") - except Exception as e: + except Exception as e: # noqa: BLE001 logger.debug(f"清理任务记录失败: {e}") @@ -268,6 +269,9 @@ class StreamLoopManager: # 流启动锁:防止并发启动同一个流的多个任务 self._stream_start_locks: dict[str, asyncio.Lock] = {} + # 并发控制:限制同时进行的 Chatter 处理任务数 + self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams) + logger.info(f"流循环管理器初始化完成 (最大并发流数: {self.max_concurrent_streams})") # ======================================================================== diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 8c7bdb3e1..e55fe077e 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -104,9 +104,21 @@ class MessageManager: if not chat_stream: logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在") return - # 启动 stream loop 任务(如果尚未启动) - await stream_loop_manager.start_stream_loop(stream_id) - await self._check_and_handle_interruption(chat_stream, message) + + # 快速检查:如果已有驱动器在跑,则跳过重复启动,避免不必要的 await + context = chat_stream.context + if not (context.stream_loop_task and not context.stream_loop_task.done()): + # 异步启动驱动器任务;避免在高并发下阻塞消息入队 + context.stream_loop_task = asyncio.create_task( + stream_loop_manager.start_stream_loop(stream_id) + ) + + # 并行触发打断检查,不阻塞消息入队 + context.interruption_task = asyncio.create_task( + self._check_and_handle_interruption(chat_stream, message) + ) + + # 入队消息 await chat_stream.context.add_message(message) except Exception as e: @@ -476,8 +488,7 @@ class MessageManager: is_processing: 是否正在处理 """ try: - # 尝试更新StreamContext的处理状态 - import asyncio + # 尝试更新StreamContext的处理状态(使用顶层 asyncio 导入) async def _update_context(): try: chat_manager = get_chat_manager() @@ -492,7 +503,7 @@ class MessageManager: try: loop = asyncio.get_event_loop() if loop.is_running(): - asyncio.create_task(_update_context()) + self._update_context_task = asyncio.create_task(_update_context()) else: # 如果事件循环未运行,则跳过 logger.debug("事件循环未运行,跳过StreamContext状态更新") @@ -512,8 +523,7 @@ class MessageManager: bool: 是否正在处理 """ try: - # 尝试从StreamContext获取处理状态 - import asyncio + # 尝试从StreamContext获取处理状态(使用顶层 asyncio 导入) async def _get_context_status(): try: chat_manager = get_chat_manager() diff --git a/src/memory_graph/unified_manager.py b/src/memory_graph/unified_manager.py index c0a5db3e9..78955c3bc 100644 --- a/src/memory_graph/unified_manager.py +++ b/src/memory_graph/unified_manager.py @@ -451,7 +451,7 @@ class UnifiedMemoryManager: (0.3, 10.0, 0.4), (0.1, 15.0, 0.6), ] - + for threshold, min_val, factor in occupancy_thresholds: if occupancy >= threshold: return max(min_val, base_interval * factor) @@ -461,24 +461,24 @@ class UnifiedMemoryManager: async def _transfer_blocks_to_short_term(self, blocks: list[MemoryBlock]) -> None: """实际转换逻辑在后台执行(优化:并行处理多个块,批量触发唤醒)""" logger.debug(f"正在后台处理 {len(blocks)} 个感知记忆块") - + # 优化:使用 asyncio.gather 并行处理转移 async def _transfer_single(block: MemoryBlock) -> tuple[MemoryBlock, bool]: try: stm = await self.short_term_manager.add_from_block(block) if not stm: return block, False - + await self.perceptual_manager.remove_block(block.id) logger.debug(f"✓ 记忆块 {block.id} 已被转移到短期记忆 {stm.id}") return block, True except Exception as exc: logger.error(f"后台转移失败,记忆块 {block.id}: {exc}") return block, False - + # 并行处理所有块 results = await asyncio.gather(*[_transfer_single(block) for block in blocks], return_exceptions=True) - + # 统计成功的转移 success_count = sum(1 for result in results if isinstance(result, tuple) and result[1]) if success_count > 0: @@ -491,7 +491,7 @@ class UnifiedMemoryManager: seen = set() decay = 0.15 manual_queries: list[dict[str, Any]] = [] - + for raw in queries: text = (raw or "").strip() if text and text not in seen: @@ -517,7 +517,7 @@ class UnifiedMemoryManager: "top_k": self._config["long_term"]["search_top_k"], "use_multi_query": bool(manual_queries), } - + if recent_chat_history or manual_queries: context: dict[str, Any] = {} if recent_chat_history: @@ -541,7 +541,7 @@ class UnifiedMemoryManager: mem_id = mem.get("id") else: mem_id = getattr(mem, "id", None) - + # 检查去重 if mem_id and mem_id in seen_ids: continue @@ -600,7 +600,7 @@ class UnifiedMemoryManager: new_memories.append(memory) if mem_id: cached_ids.add(mem_id) - + if new_memories: transfer_cache.extend(new_memories) logger.debug( @@ -632,7 +632,7 @@ class UnifiedMemoryManager: await self.short_term_manager.clear_transferred_memories( result["transferred_memory_ids"] ) - + # 优化:使用生成器表达式保留未转移的记忆 transfer_cache = [ m