From 9a0163d06b6eaa6d298624d653b865561f54e9fe Mon Sep 17 00:00:00 2001 From: LuiKlee Date: Sat, 13 Dec 2025 20:19:11 +0800 Subject: [PATCH 01/10] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/benchmark_distribution_manager.py | 306 ++++++++++++++++++ scripts/benchmark_unified_manager.py | 8 +- .../message_manager/batch_database_writer.py | 85 ++--- .../message_manager/distribution_manager.py | 34 +- src/chat/message_manager/message_manager.py | 26 +- src/memory_graph/unified_manager.py | 20 +- 6 files changed, 405 insertions(+), 74 deletions(-) create mode 100644 scripts/benchmark_distribution_manager.py diff --git a/scripts/benchmark_distribution_manager.py b/scripts/benchmark_distribution_manager.py new file mode 100644 index 000000000..932aaaa62 --- /dev/null +++ b/scripts/benchmark_distribution_manager.py @@ -0,0 +1,306 @@ +import asyncio +import time +import os +import sys +from dataclasses import dataclass +from typing import Any, Optional + +# Benchmark the distribution manager's run_chat_stream/conversation_loop behavior +# by wiring a lightweight dummy manager and contexts. This avoids touching real DB or chat subsystems. + +# Ensure project root is on sys.path when running as a script +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if ROOT_DIR not in sys.path: + sys.path.insert(0, ROOT_DIR) + +# Avoid importing the whole 'src.chat' package to prevent heavy deps (e.g., redis) +# Local minimal implementation of loop and manager to isolate benchmark from heavy deps. +from collections.abc import AsyncIterator, Awaitable, Callable +from dataclasses import dataclass, field + + +@dataclass +class ConversationTick: + stream_id: str + tick_time: float = field(default_factory=time.time) + force_dispatch: bool = False + tick_count: int = 0 + + +async def conversation_loop( + stream_id: str, + get_context_func: Callable[[str], Awaitable["DummyContext | None"]], + calculate_interval_func: Callable[[str, bool], Awaitable[float]], + flush_cache_func: Callable[[str], Awaitable[list[Any]]], + check_force_dispatch_func: Callable[["DummyContext", int], bool], + is_running_func: Callable[[], bool], +) -> AsyncIterator[ConversationTick]: + tick_count = 0 + while is_running_func(): + ctx = await get_context_func(stream_id) + if not ctx: + await asyncio.sleep(0.1) + continue + await flush_cache_func(stream_id) + unread = ctx.get_unread_messages() + ucnt = len(unread) + force = check_force_dispatch_func(ctx, ucnt) + if ucnt > 0 or force: + tick_count += 1 + yield ConversationTick(stream_id=stream_id, force_dispatch=force, tick_count=tick_count) + interval = await calculate_interval_func(stream_id, ucnt > 0) + await asyncio.sleep(interval) + + +class StreamLoopManager: + def __init__(self, max_concurrent_streams: Optional[int] = None): + self.stats: dict[str, Any] = { + "active_streams": 0, + "total_loops": 0, + "total_process_cycles": 0, + "total_failures": 0, + "start_time": time.time(), + } + self.max_concurrent_streams = max_concurrent_streams or 100 + self.force_dispatch_unread_threshold = 20 + self.chatter_manager = DummyChatterManager() + self.is_running = False + self._stream_start_locks: dict[str, asyncio.Lock] = {} + self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams) + self._chat_manager: Optional[DummyChatManager] = None + + def set_chat_manager(self, chat_manager: "DummyChatManager") -> None: + self._chat_manager = chat_manager + + async def start(self): + self.is_running = True + + async def stop(self): + self.is_running = False + + async def _get_stream_context(self, stream_id: str): + assert self._chat_manager is not None + stream = await self._chat_manager.get_stream(stream_id) + return stream.context if stream else None + + async def _flush_cached_messages_to_unread(self, stream_id: str): + ctx = await self._get_stream_context(stream_id) + return ctx.flush_cached_messages() if ctx else [] + + def _needs_force_dispatch_for_context(self, context: "DummyContext", unread_count: int) -> bool: + return unread_count > (self.force_dispatch_unread_threshold or 20) + + async def _process_stream_messages(self, stream_id: str, context: "DummyContext") -> bool: + res = await self.chatter_manager.process_stream_context(stream_id, context) # type: ignore[attr-defined] + return bool(res.get("success", False)) + + async def _update_stream_energy(self, stream_id: str, context: "DummyContext") -> None: + pass + + async def _calculate_interval(self, stream_id: str, has_messages: bool) -> float: + return 0.005 if has_messages else 0.02 + + async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool: + ctx = await self._get_stream_context(stream_id) + if not ctx: + return False + # create driver + loop_task = asyncio.create_task(run_chat_stream(stream_id, self)) + ctx.stream_loop_task = loop_task + self.stats["active_streams"] += 1 + self.stats["total_loops"] += 1 + return True + + +async def run_chat_stream(stream_id: str, manager: StreamLoopManager) -> None: + try: + gen = conversation_loop( + stream_id=stream_id, + get_context_func=manager._get_stream_context, + calculate_interval_func=manager._calculate_interval, + flush_cache_func=manager._flush_cached_messages_to_unread, + check_force_dispatch_func=manager._needs_force_dispatch_for_context, + is_running_func=lambda: manager.is_running, + ) + async for tick in gen: + ctx = await manager._get_stream_context(stream_id) + if not ctx: + continue + if ctx.is_chatter_processing: + continue + try: + async with manager._processing_semaphore: + ok = await manager._process_stream_messages(stream_id, ctx) + except Exception: + ok = False + manager.stats["total_process_cycles"] += 1 + if not ok: + manager.stats["total_failures"] += 1 + except asyncio.CancelledError: + pass + + +@dataclass +class DummyMessage: + time: float + processed_plain_text: str = "" + display_message: str = "" + is_at: bool = False + is_mentioned: bool = False + + +class DummyContext: + def __init__(self, stream_id: str, initial_unread: int): + self.stream_id = stream_id + self.unread_messages = [DummyMessage(time=time.time()) for _ in range(initial_unread)] + self.history_messages: list[DummyMessage] = [] + self.is_chatter_processing: bool = False + self.processing_task: Optional[asyncio.Task] = None + self.stream_loop_task: Optional[asyncio.Task] = None + self.triggering_user_id: Optional[str] = None + + def get_unread_messages(self) -> list[DummyMessage]: + return list(self.unread_messages) + + def flush_cached_messages(self) -> list[DummyMessage]: + return [] + + def get_last_message(self) -> Optional[DummyMessage]: + return self.unread_messages[-1] if self.unread_messages else None + + def get_history_messages(self, limit: int = 50) -> list[DummyMessage]: + return self.history_messages[-limit:] + + +class DummyStream: + def __init__(self, stream_id: str, ctx: DummyContext): + self.stream_id = stream_id + self.context = ctx + self.group_info = None # treat as private chat to accelerate + self._focus_energy = 0.5 + + +class DummyChatManager: + def __init__(self, streams: dict[str, DummyStream]): + self._streams = streams + + async def get_stream(self, stream_id: str) -> Optional[DummyStream]: + return self._streams.get(stream_id) + + def get_all_streams(self) -> dict[str, DummyStream]: + return self._streams + + +class DummyChatterManager: + async def process_stream_context(self, stream_id: str, context: DummyContext) -> dict[str, Any]: + # Simulate some processing latency and consume one unread message + await asyncio.sleep(0.01) + if context.unread_messages: + context.unread_messages.pop(0) + return {"success": True} + + +class BenchStreamLoopManager(StreamLoopManager): + def __init__(self, chat_manager: DummyChatManager, max_concurrent_streams: int | None = None): + super().__init__(max_concurrent_streams=max_concurrent_streams) + self._chat_manager = chat_manager + self.chatter_manager = DummyChatterManager() + + async def _get_stream_context(self, stream_id: str): # type: ignore[override] + stream = await self._chat_manager.get_stream(stream_id) + return stream.context if stream else None + + async def _flush_cached_messages_to_unread(self, stream_id: str): # type: ignore[override] + ctx = await self._get_stream_context(stream_id) + return ctx.flush_cached_messages() if ctx else [] + + def _needs_force_dispatch_for_context(self, context, unread_count: int) -> bool: # type: ignore[override] + # force when unread exceeds threshold + return unread_count > (self.force_dispatch_unread_threshold or 20) + + async def _process_stream_messages(self, stream_id: str, context): # type: ignore[override] + # delegate to chatter manager + res = await self.chatter_manager.process_stream_context(stream_id, context) # type: ignore[attr-defined] + return bool(res.get("success", False)) + + async def _should_skip_for_mute_group(self, stream_id: str, unread_messages: list) -> bool: + return False + + async def _update_stream_energy(self, stream_id: str, context): # type: ignore[override] + # lightweight: compute based on unread size + focus = min(1.0, 0.1 + 0.02 * len(context.get_unread_messages())) + # set for compatibility + stream = await self._chat_manager.get_stream(stream_id) + if stream: + stream._focus_energy = focus + + +def make_streams(n_streams: int, initial_unread: int) -> dict[str, DummyStream]: + streams: dict[str, DummyStream] = {} + for i in range(n_streams): + sid = f"s{i:04d}" + ctx = DummyContext(sid, initial_unread) + streams[sid] = DummyStream(sid, ctx) + return streams + + +async def run_benchmark(n_streams: int, initial_unread: int, max_concurrent: Optional[int]) -> dict[str, Any]: + streams = make_streams(n_streams, initial_unread) + chat_mgr = DummyChatManager(streams) + mgr = BenchStreamLoopManager(chat_mgr, max_concurrent_streams=max_concurrent) + await mgr.start() + + # start loops for all streams + start_ts = time.time() + for sid in list(streams.keys()): + await mgr.start_stream_loop(sid, force=True) + + # run until all unread consumed or timeout + timeout = 5.0 + end_deadline = start_ts + timeout + while time.time() < end_deadline: + remaining = sum(len(s.context.get_unread_messages()) for s in streams.values()) + if remaining == 0: + break + await asyncio.sleep(0.02) + + duration = time.time() - start_ts + total_cycles = mgr.stats.get("total_process_cycles", 0) + total_failures = mgr.stats.get("total_failures", 0) + remaining = sum(len(s.context.get_unread_messages()) for s in streams.values()) + + # stop all + await mgr.stop() + + return { + "n_streams": n_streams, + "initial_unread": initial_unread, + "max_concurrent": max_concurrent, + "duration_sec": duration, + "total_cycles": total_cycles, + "total_failures": total_failures, + "remaining_unread": remaining, + "throughput_msgs_per_sec": (n_streams * initial_unread - remaining) / max(0.001, duration), + } + + +async def main(): + cases = [ + (50, 5, None), # baseline using configured default + (50, 5, 5), # constrained concurrency + (50, 5, 10), # moderate concurrency + (100, 3, 10), # scale streams + ] + + print("Running distribution manager benchmark...\n") + for n_streams, initial_unread, max_concurrent in cases: + res = await run_benchmark(n_streams, initial_unread, max_concurrent) + print( + f"streams={res['n_streams']} unread={res['initial_unread']} max_conc={res['max_concurrent']} | " + f"dur={res['duration_sec']:.3f}s cycles={res['total_cycles']} fail={res['total_failures']} rem={res['remaining_unread']} "+ + f"thr={res['throughput_msgs_per_sec']:.1f}/s" + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/benchmark_unified_manager.py b/scripts/benchmark_unified_manager.py index ec0ec69f0..eae8d187b 100644 --- a/scripts/benchmark_unified_manager.py +++ b/scripts/benchmark_unified_manager.py @@ -6,8 +6,6 @@ import asyncio import time -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch class PerformanceBenchmark: @@ -20,7 +18,7 @@ class PerformanceBenchmark: """测试查询去重性能""" # 这里需要导入实际的管理器 # from src.memory_graph.unified_manager import UnifiedMemoryManager - + test_cases = [ { "name": "small_queries", @@ -67,7 +65,7 @@ class PerformanceBenchmark: seen = set() decay = 0.15 manual_queries = [] - + for raw in queries: text = (raw or "").strip() if text and text not in seen: @@ -192,7 +190,7 @@ class PerformanceBenchmark: mem_id = mem.get("id") else: mem_id = getattr(mem, "id", None) - + if mem_id and mem_id in seen_ids: continue unique_memories.append(mem) diff --git a/src/chat/message_manager/batch_database_writer.py b/src/chat/message_manager/batch_database_writer.py index 74128d8d9..5eca2b1d5 100644 --- a/src/chat/message_manager/batch_database_writer.py +++ b/src/chat/message_manager/batch_database_writer.py @@ -9,6 +9,8 @@ from collections import defaultdict from dataclasses import dataclass, field from typing import Any +from sqlalchemy.exc import SQLAlchemyError + from src.common.database.compatibility import get_db_session from src.common.database.core.models import ChatStreams from src.common.logger import get_logger @@ -159,20 +161,27 @@ class BatchDatabaseWriter: logger.info("批量写入循环结束") async def _collect_batch(self) -> list[StreamUpdatePayload]: - """收集一个批次的数据""" - batch = [] - deadline = time.time() + self.flush_interval + """收集一个批次的数据 + - 自适应刷新:队列增长加快时缩短等待时间 + - 避免长时间空转:添加轻微抖动以分散竞争 + """ + batch: list[StreamUpdatePayload] = [] + # 根据当前队列长度调整刷新时间(最多缩短到 40%) + qsize = self.write_queue.qsize() + adapt_factor = 1.0 + if qsize > 0: + adapt_factor = max(0.4, min(1.0, self.batch_size / max(1, qsize))) + deadline = time.time() + (self.flush_interval * adapt_factor) while len(batch) < self.batch_size and time.time() < deadline: try: - # 计算剩余等待时间 - remaining_time = max(0, deadline - time.time()) + remaining_time = max(0.0, deadline - time.time()) if remaining_time == 0: break - - payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time) + # 轻微抖动,避免多个协程同时争抢队列 + jitter = 0.002 + payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time + jitter) batch.append(payload) - except asyncio.TimeoutError: break @@ -208,48 +217,52 @@ class BatchDatabaseWriter: logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s") - except Exception as e: + except SQLAlchemyError as e: self.stats["failed_writes"] += 1 logger.error(f"批量写入失败: {e}") # 降级到单个写入 for payload in batch: try: await self._direct_write(payload.stream_id, payload.update_data) - except Exception as single_e: + except SQLAlchemyError as single_e: logger.error(f"单个写入也失败: {single_e}") async def _batch_write_to_database(self, payloads: list[StreamUpdatePayload]): - """批量写入数据库""" + """批量写入数据库(单事务、多值 UPSERT)""" if global_config is None: raise RuntimeError("Global config is not initialized") + if not payloads: + return + + # 预组装行数据,确保每行包含 stream_id + rows: list[dict[str, Any]] = [] + for p in payloads: + row = {"stream_id": p.stream_id} + row.update(p.update_data) + rows.append(row) + async with get_db_session() as session: - for payload in payloads: - stream_id = payload.stream_id - update_data = payload.update_data - - # 根据数据库类型选择不同的插入/更新策略 - if global_config.database.database_type == "sqlite": - from sqlalchemy.dialects.sqlite import insert as sqlite_insert - - stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data) - stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data) - elif global_config.database.database_type == "postgresql": - from sqlalchemy.dialects.postgresql import insert as pg_insert - - stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data) - stmt = stmt.on_conflict_do_update( - index_elements=[ChatStreams.stream_id], - set_=update_data - ) - else: - # 默认使用SQLite语法 - from sqlalchemy.dialects.sqlite import insert as sqlite_insert - - stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data) - stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data) - + # 使用单次事务提交,显著减少 I/O + if global_config.database.database_type == "postgresql": + from sqlalchemy.dialects.postgresql import insert as pg_insert + stmt = pg_insert(ChatStreams).values(rows) + stmt = stmt.on_conflict_do_update( + index_elements=[ChatStreams.stream_id], + set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"} + ) await session.execute(stmt) + await session.commit() + else: + # 默认(sqlite) + from sqlalchemy.dialects.sqlite import insert as sqlite_insert + stmt = sqlite_insert(ChatStreams).values(rows) + stmt = stmt.on_conflict_do_update( + index_elements=["stream_id"], + set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"} + ) + await session.execute(stmt) + await session.commit() async def _direct_write(self, stream_id: str, update_data: dict[str, Any]): """直接写入数据库(降级方案)""" if global_config is None: diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index ff3694901..cb52e5f14 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -55,7 +55,7 @@ async def conversation_loop( stream_id: str, get_context_func: Callable[[str], Awaitable["StreamContext | None"]], calculate_interval_func: Callable[[str, bool], Awaitable[float]], - flush_cache_func: Callable[[str], Awaitable[None]], + flush_cache_func: Callable[[str], Awaitable[list[Any]]], check_force_dispatch_func: Callable[["StreamContext", int], bool], is_running_func: Callable[[], bool], ) -> AsyncIterator[ConversationTick]: @@ -121,7 +121,7 @@ async def conversation_loop( except asyncio.CancelledError: logger.info(f" [生成器] stream={stream_id[:8]}, 被取消") break - except Exception as e: + except Exception as e: # noqa: BLE001 logger.error(f" [生成器] stream={stream_id[:8]}, 出错: {e}") await asyncio.sleep(5.0) @@ -151,10 +151,10 @@ async def run_chat_stream( # 创建生成器 tick_generator = conversation_loop( stream_id=stream_id, - get_context_func=manager._get_stream_context, - calculate_interval_func=manager._calculate_interval, - flush_cache_func=manager._flush_cached_messages_to_unread, - check_force_dispatch_func=manager._needs_force_dispatch_for_context, + get_context_func=manager._get_stream_context, # noqa: SLF001 + calculate_interval_func=manager._calculate_interval, # noqa: SLF001 + flush_cache_func=manager._flush_cached_messages_to_unread, # noqa: SLF001 + check_force_dispatch_func=manager._needs_force_dispatch_for_context, # noqa: SLF001 is_running_func=lambda: manager.is_running, ) @@ -162,13 +162,13 @@ async def run_chat_stream( async for tick in tick_generator: try: # 获取上下文 - context = await manager._get_stream_context(stream_id) + context = await manager._get_stream_context(stream_id) # noqa: SLF001 if not context: continue # 并发保护:检查是否正在处理 if context.is_chatter_processing: - if manager._recover_stale_chatter_state(stream_id, context): + if manager._recover_stale_chatter_state(stream_id, context): # noqa: SLF001 logger.warning(f" [驱动器] stream={stream_id[:8]}, 处理标志残留已修复") else: logger.debug(f" [驱动器] stream={stream_id[:8]}, Chatter正在处理,跳过此Tick") @@ -182,17 +182,18 @@ async def run_chat_stream( # 更新能量值 try: - await manager._update_stream_energy(stream_id, context) + await manager._update_stream_energy(stream_id, context) # noqa: SLF001 except Exception as e: logger.debug(f"更新能量失败: {e}") # 处理消息 assert global_config is not None try: - success = await asyncio.wait_for( - manager._process_stream_messages(stream_id, context), - global_config.chat.thinking_timeout - ) + async with manager._processing_semaphore: + success = await asyncio.wait_for( + manager._process_stream_messages(stream_id, context), # noqa: SLF001 + global_config.chat.thinking_timeout, + ) except asyncio.TimeoutError: logger.warning(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 处理超时") success = False @@ -208,7 +209,7 @@ async def run_chat_stream( except asyncio.CancelledError: raise - except Exception as e: + except Exception as e: # noqa: BLE001 logger.error(f" [驱动器] stream={stream_id[:8]}, 处理Tick时出错: {e}") manager.stats["total_failures"] += 1 @@ -221,7 +222,7 @@ async def run_chat_stream( if context and context.stream_loop_task: context.stream_loop_task = None logger.debug(f" [驱动器] stream={stream_id[:8]}, 清理任务记录") - except Exception as e: + except Exception as e: # noqa: BLE001 logger.debug(f"清理任务记录失败: {e}") @@ -268,6 +269,9 @@ class StreamLoopManager: # 流启动锁:防止并发启动同一个流的多个任务 self._stream_start_locks: dict[str, asyncio.Lock] = {} + # 并发控制:限制同时进行的 Chatter 处理任务数 + self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams) + logger.info(f"流循环管理器初始化完成 (最大并发流数: {self.max_concurrent_streams})") # ======================================================================== diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 8c7bdb3e1..e55fe077e 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -104,9 +104,21 @@ class MessageManager: if not chat_stream: logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在") return - # 启动 stream loop 任务(如果尚未启动) - await stream_loop_manager.start_stream_loop(stream_id) - await self._check_and_handle_interruption(chat_stream, message) + + # 快速检查:如果已有驱动器在跑,则跳过重复启动,避免不必要的 await + context = chat_stream.context + if not (context.stream_loop_task and not context.stream_loop_task.done()): + # 异步启动驱动器任务;避免在高并发下阻塞消息入队 + context.stream_loop_task = asyncio.create_task( + stream_loop_manager.start_stream_loop(stream_id) + ) + + # 并行触发打断检查,不阻塞消息入队 + context.interruption_task = asyncio.create_task( + self._check_and_handle_interruption(chat_stream, message) + ) + + # 入队消息 await chat_stream.context.add_message(message) except Exception as e: @@ -476,8 +488,7 @@ class MessageManager: is_processing: 是否正在处理 """ try: - # 尝试更新StreamContext的处理状态 - import asyncio + # 尝试更新StreamContext的处理状态(使用顶层 asyncio 导入) async def _update_context(): try: chat_manager = get_chat_manager() @@ -492,7 +503,7 @@ class MessageManager: try: loop = asyncio.get_event_loop() if loop.is_running(): - asyncio.create_task(_update_context()) + self._update_context_task = asyncio.create_task(_update_context()) else: # 如果事件循环未运行,则跳过 logger.debug("事件循环未运行,跳过StreamContext状态更新") @@ -512,8 +523,7 @@ class MessageManager: bool: 是否正在处理 """ try: - # 尝试从StreamContext获取处理状态 - import asyncio + # 尝试从StreamContext获取处理状态(使用顶层 asyncio 导入) async def _get_context_status(): try: chat_manager = get_chat_manager() diff --git a/src/memory_graph/unified_manager.py b/src/memory_graph/unified_manager.py index c0a5db3e9..78955c3bc 100644 --- a/src/memory_graph/unified_manager.py +++ b/src/memory_graph/unified_manager.py @@ -451,7 +451,7 @@ class UnifiedMemoryManager: (0.3, 10.0, 0.4), (0.1, 15.0, 0.6), ] - + for threshold, min_val, factor in occupancy_thresholds: if occupancy >= threshold: return max(min_val, base_interval * factor) @@ -461,24 +461,24 @@ class UnifiedMemoryManager: async def _transfer_blocks_to_short_term(self, blocks: list[MemoryBlock]) -> None: """实际转换逻辑在后台执行(优化:并行处理多个块,批量触发唤醒)""" logger.debug(f"正在后台处理 {len(blocks)} 个感知记忆块") - + # 优化:使用 asyncio.gather 并行处理转移 async def _transfer_single(block: MemoryBlock) -> tuple[MemoryBlock, bool]: try: stm = await self.short_term_manager.add_from_block(block) if not stm: return block, False - + await self.perceptual_manager.remove_block(block.id) logger.debug(f"✓ 记忆块 {block.id} 已被转移到短期记忆 {stm.id}") return block, True except Exception as exc: logger.error(f"后台转移失败,记忆块 {block.id}: {exc}") return block, False - + # 并行处理所有块 results = await asyncio.gather(*[_transfer_single(block) for block in blocks], return_exceptions=True) - + # 统计成功的转移 success_count = sum(1 for result in results if isinstance(result, tuple) and result[1]) if success_count > 0: @@ -491,7 +491,7 @@ class UnifiedMemoryManager: seen = set() decay = 0.15 manual_queries: list[dict[str, Any]] = [] - + for raw in queries: text = (raw or "").strip() if text and text not in seen: @@ -517,7 +517,7 @@ class UnifiedMemoryManager: "top_k": self._config["long_term"]["search_top_k"], "use_multi_query": bool(manual_queries), } - + if recent_chat_history or manual_queries: context: dict[str, Any] = {} if recent_chat_history: @@ -541,7 +541,7 @@ class UnifiedMemoryManager: mem_id = mem.get("id") else: mem_id = getattr(mem, "id", None) - + # 检查去重 if mem_id and mem_id in seen_ids: continue @@ -600,7 +600,7 @@ class UnifiedMemoryManager: new_memories.append(memory) if mem_id: cached_ids.add(mem_id) - + if new_memories: transfer_cache.extend(new_memories) logger.debug( @@ -632,7 +632,7 @@ class UnifiedMemoryManager: await self.short_term_manager.clear_transferred_memories( result["transferred_memory_ids"] ) - + # 优化:使用生成器表达式保留未转移的记忆 transfer_cache = [ m From d7ab785ced7905711f21d37943d6457dd4f1dde1 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 13 Dec 2025 20:50:19 +0800 Subject: [PATCH 02/10] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=97=A0=E7=94=A8?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=92=8C=E6=B5=8B=E8=AF=95=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...inity_flow_chatter_optimization_summary.md | 654 -------- docs/affinity_flow_guide.md | 170 -- docs/database_api_migration_checklist.md | 374 ----- docs/database_refactoring_completion.md | 224 --- docs/database_refactoring_plan.md | 1475 ----------------- docs/database_refactoring_test_report.md | 187 --- docs/development/json_parser_unification.md | 216 --- docs/guides/OBJECT_LEVEL_MEMORY_ANALYSIS.md | 267 --- docs/guides/memory_deduplication_guide.md | 391 ----- .../OPTIMIZATION_ARCHITECTURE_VISUAL.md | 451 ----- .../OPTIMIZATION_COMPLETION_REPORT.md | 345 ---- .../OPTIMIZATION_QUICK_REFERENCE.md | 216 --- .../OPTIMIZATION_REPORT_UNIFIED_MANAGER.md | 347 ---- docs/memory_graph/OPTIMIZATION_SUMMARY.md | 219 --- .../memory_graph/OPTIMIZATION_VISUAL_GUIDE.md | 287 ---- docs/message_dispatcher_refactoring.md | 210 --- docs/three_tier_memory_completion_report.md | 367 ---- scripts/benchmark_distribution_manager.py | 306 ---- scripts/benchmark_unified_manager.py | 276 --- scripts/test_bedrock_client.py | 204 --- 20 files changed, 7186 deletions(-) delete mode 100644 docs/affinity_flow_chatter_optimization_summary.md delete mode 100644 docs/affinity_flow_guide.md delete mode 100644 docs/database_api_migration_checklist.md delete mode 100644 docs/database_refactoring_completion.md delete mode 100644 docs/database_refactoring_plan.md delete mode 100644 docs/database_refactoring_test_report.md delete mode 100644 docs/development/json_parser_unification.md delete mode 100644 docs/guides/OBJECT_LEVEL_MEMORY_ANALYSIS.md delete mode 100644 docs/guides/memory_deduplication_guide.md delete mode 100644 docs/memory_graph/OPTIMIZATION_ARCHITECTURE_VISUAL.md delete mode 100644 docs/memory_graph/OPTIMIZATION_COMPLETION_REPORT.md delete mode 100644 docs/memory_graph/OPTIMIZATION_QUICK_REFERENCE.md delete mode 100644 docs/memory_graph/OPTIMIZATION_REPORT_UNIFIED_MANAGER.md delete mode 100644 docs/memory_graph/OPTIMIZATION_SUMMARY.md delete mode 100644 docs/memory_graph/OPTIMIZATION_VISUAL_GUIDE.md delete mode 100644 docs/message_dispatcher_refactoring.md delete mode 100644 docs/three_tier_memory_completion_report.md delete mode 100644 scripts/benchmark_distribution_manager.py delete mode 100644 scripts/benchmark_unified_manager.py delete mode 100644 scripts/test_bedrock_client.py diff --git a/docs/affinity_flow_chatter_optimization_summary.md b/docs/affinity_flow_chatter_optimization_summary.md deleted file mode 100644 index c5d6c91c0..000000000 --- a/docs/affinity_flow_chatter_optimization_summary.md +++ /dev/null @@ -1,654 +0,0 @@ -# Affinity Flow Chatter 插件优化总结 - -## 更新日期 -2025年11月3日 - -## 优化概述 - -本次对 Affinity Flow Chatter 插件进行了全面的重构和优化,主要包括目录结构优化、性能改进、bug修复和新功能添加。 - -## � 任务-1: 细化提及分数机制(强提及 vs 弱提及) - -### 变更内容 -将原有的统一提及分数细化为**强提及**和**弱提及**两种类型,使用不同的分值。 - -### 原设计问题 -**旧逻辑**: -- ❌ 所有提及方式使用同一个分值(`mention_bot_interest_score`) -- ❌ 被@、私聊、文本提到名字都是相同的重要性 -- ❌ 无法区分用户的真实意图 - -### 新设计 - -#### 强提及(Strong Mention) -**定义**:用户**明确**想与bot交互 -- ✅ 被 @ 提及 -- ✅ 被回复 -- ✅ 私聊消息 - -**分值**:`strong_mention_interest_score = 2.5`(默认) - -#### 弱提及(Weak Mention) -**定义**:在讨论中**顺带**提到bot -- ✅ 消息中包含bot名字 -- ✅ 消息中包含bot别名 - -**分值**:`weak_mention_interest_score = 1.5`(默认) - -### 检测逻辑 - -```python -def is_mentioned_bot_in_message(message) -> tuple[bool, float]: - """ - Returns: - tuple[bool, float]: (是否提及, 提及类型) - 提及类型: 0=未提及, 1=弱提及, 2=强提及 - """ - # 1. 检查私聊 → 强提及 - if is_private_chat: - return True, 2.0 - - # 2. 检查 @ → 强提及 - if is_at: - return True, 2.0 - - # 3. 检查回复 → 强提及 - if is_replied: - return True, 2.0 - - # 4. 检查文本匹配 → 弱提及 - if text_contains_bot_name_or_alias: - return True, 1.0 - - return False, 0.0 -``` - -### 配置参数 - -**config/bot_config.toml**: -```toml -[affinity_flow] -# 提及bot相关参数 -strong_mention_interest_score = 2.5 # 强提及(@/回复/私聊) -weak_mention_interest_score = 1.5 # 弱提及(文本匹配) -``` - -### 实际效果对比 - -**场景1:被@** -``` -用户: "@小狐 你好呀" -旧逻辑: 提及分 = 2.5 -新逻辑: 提及分 = 2.5 (强提及) ✅ 保持不变 -``` - -**场景2:回复bot** -``` -用户: [回复 小狐:...] "是的" -旧逻辑: 提及分 = 2.5 -新逻辑: 提及分 = 2.5 (强提及) ✅ 保持不变 -``` - -**场景3:私聊** -``` -用户: "在吗" -旧逻辑: 提及分 = 2.5 -新逻辑: 提及分 = 2.5 (强提及) ✅ 保持不变 -``` - -**场景4:文本提及** -``` -用户: "小狐今天没来吗" -旧逻辑: 提及分 = 2.5 (可能过高) -新逻辑: 提及分 = 1.5 (弱提及) ✅ 更合理 -``` - -**场景5:讨论bot** -``` -用户A: "小狐这个bot挺有意思的" -旧逻辑: 提及分 = 2.5 (bot可能会插话) -新逻辑: 提及分 = 1.5 (弱提及,降低打断概率) ✅ 更自然 -``` - -### 优势 - -- ✅ **意图识别**:区分"想对话"和"在讨论" -- ✅ **减少误判**:降低在他人讨论中插话的概率 -- ✅ **灵活调节**:可以独立调整强弱提及的权重 -- ✅ **向后兼容**:保持原有强提及的行为不变 - -### 影响文件 - -- `config/bot_config.toml`:添加 `strong/weak_mention_interest_score` 配置 -- `template/bot_config_template.toml`:同步模板配置 -- `src/config/official_configs.py`:添加配置字段定义 -- `src/chat/utils/utils.py`:修改 `is_mentioned_bot_in_message()` 函数 -- `src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py`:使用新的强弱提及逻辑 -- `docs/affinity_flow_guide.md`:更新文档说明 - ---- - -## �🆔 任务0: 修改 Personality ID 生成逻辑 - -### 变更内容 -将 `bot_person_id` 从固定值改为基于人设文本的 hash 生成,实现人设变化时自动触发兴趣标签重新生成。 - -### 原设计问题 -**旧逻辑**: -```python -self.bot_person_id = person_info_manager.get_person_id("system", "bot_id") -# 结果:md5("system_bot_id") = 固定值 -``` -- ❌ personality_id 固定不变 -- ❌ 人设修改后不会重新生成兴趣标签 -- ❌ 需要手动清空数据库才能触发重新生成 - -### 新设计 -**新逻辑**: -```python -personality_hash, _ = self._get_config_hash(bot_nickname, personality_core, personality_side, identity) -self.bot_person_id = personality_hash -# 结果:md5(人设配置的JSON) = 动态值 -``` - -### Hash 生成规则 -```python -personality_config = { - "nickname": bot_nickname, - "personality_core": personality_core, - "personality_side": personality_side, - "compress_personality": global_config.personality.compress_personality, -} -personality_hash = md5(json_dumps(personality_config, sorted=True)) -``` - -### 工作原理 -1. **初始化时**:根据当前人设配置计算 hash 作为 personality_id -2. **配置变化检测**: - - 计算当前人设的 hash - - 与上次保存的 hash 对比 - - 如果不同,触发重新生成 -3. **兴趣标签生成**: - - `bot_interest_manager` 根据 personality_id 查询数据库 - - 如果 personality_id 不存在(人设变化了),自动生成新的兴趣标签 - - 保存时使用新的 personality_id - -### 优势 -- ✅ **自动检测**:人设改变后无需手动操作 -- ✅ **数据隔离**:不同人设的兴趣标签分开存储 -- ✅ **版本管理**:可以保留历史人设的兴趣标签(如果需要) -- ✅ **逻辑清晰**:personality_id 直接反映人设内容 - -### 示例 -``` -人设 A: - nickname: "小狐" - personality_core: "活泼开朗" - personality_side: "喜欢编程" - → personality_id: a1b2c3d4e5f6... - -人设 B (修改后): - nickname: "小狐" - personality_core: "冷静理性" ← 改变 - personality_side: "喜欢编程" - → personality_id: f6e5d4c3b2a1... ← 自动生成新ID - -结果: -- 数据库查询时找不到 f6e5d4c3b2a1 的兴趣标签 -- 自动触发重新生成 -- 新兴趣标签保存在 f6e5d4c3b2a1 下 -``` - -### 影响范围 -- `src/individuality/individuality.py`:personality_id 生成逻辑 -- `src/chat/interest_system/bot_interest_manager.py`:兴趣标签加载/保存(已支持) -- 数据库:`bot_personality_interests` 表通过 personality_id 字段关联 - ---- - -## 📁 任务1: 优化插件目录结构 - -### 变更内容 -将原本扁平的文件结构重组为分层目录,提高代码可维护性: - -``` -affinity_flow_chatter/ -├── core/ # 核心模块 -│ ├── __init__.py -│ ├── affinity_chatter.py # 主聊天处理器 -│ └── affinity_interest_calculator.py # 兴趣度计算器 -│ -├── planner/ # 规划器模块 -│ ├── __init__.py -│ ├── planner.py # 动作规划器 -│ ├── planner_prompts.py # 提示词模板 -│ ├── plan_generator.py # 计划生成器 -│ ├── plan_filter.py # 计划过滤器 -│ └── plan_executor.py # 计划执行器 -│ -├── proactive/ # 主动思考模块 -│ ├── __init__.py -│ ├── proactive_thinking_scheduler.py # 主动思考调度器 -│ ├── proactive_thinking_executor.py # 主动思考执行器 -│ └── proactive_thinking_event.py # 主动思考事件 -│ -├── tools/ # 工具模块 -│ ├── __init__.py -│ ├── chat_stream_impression_tool.py # 聊天印象工具 -│ └── user_profile_tool.py # 用户档案工具 -│ -├── plugin.py # 插件注册 -├── __init__.py # 插件元数据 -└── README.md # 文档 -``` - -### 优势 -- ✅ **逻辑清晰**:相关功能集中在同一目录 -- ✅ **易于维护**:模块职责明确,便于定位和修改 -- ✅ **可扩展性**:新功能可以轻松添加到对应目录 -- ✅ **团队协作**:多人开发时减少文件冲突 - ---- - -## 💾 任务2: 修改 Embedding 存储策略 - -### 问题分析 -**原设计**:兴趣标签的 embedding 向量(2560维度浮点数组)直接存储在数据库中 -- ❌ 数据库存储过长,可能导致写入失败 -- ❌ 每次加载需要反序列化大量数据 -- ❌ 数据库体积膨胀 - -### 解决方案 -**新设计**:Embedding 改为启动时动态生成并缓存在内存中 - -#### 实现细节 - -**1. 数据库存储**(不再包含 embedding): -```python -# 保存时 -tag_dict = { - "tag_name": tag.tag_name, - "weight": tag.weight, - "expanded": tag.expanded, # 扩展描述 - "created_at": tag.created_at.isoformat(), - "updated_at": tag.updated_at.isoformat(), - "is_active": tag.is_active, - # embedding 不再存储 -} -``` - -**2. 启动时动态生成**: -```python -async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests): - """为所有兴趣标签生成embedding(仅缓存在内存中)""" - for tag in interests.interest_tags: - if tag.tag_name in self.embedding_cache: - # 使用内存缓存 - tag.embedding = self.embedding_cache[tag.tag_name] - else: - # 动态生成新的embedding - embedding = await self._get_embedding(tag.tag_name) - tag.embedding = embedding # 设置到内存对象 - self.embedding_cache[tag.tag_name] = embedding # 缓存 -``` - -**3. 加载时处理**: -```python -tag = BotInterestTag( - tag_name=tag_data.get("tag_name", ""), - weight=tag_data.get("weight", 0.5), - expanded=tag_data.get("expanded"), - embedding=None, # 不从数据库加载,改为动态生成 - # ... -) -``` - -### 优势 -- ✅ **数据库轻量化**:数据库只存储标签名和权重等元数据 -- ✅ **避免写入失败**:不再因为数据过长导致数据库操作失败 -- ✅ **灵活性**:可以随时切换 embedding 模型而无需迁移数据 -- ✅ **性能**:内存缓存访问速度快 - -### 权衡 -- ⚠️ 启动时需要生成 embedding(首次启动稍慢,约10-20秒) -- ✅ 后续运行时使用内存缓存,性能与原来相当 - ---- - -## 🔧 任务3: 修复连续不回复阈值调整问题 - -### 问题描述 -原实现中,连续不回复调整只提升了分数,但阈值保持不变: -```python -# ❌ 错误的实现 -adjusted_score = self._apply_no_reply_boost(total_score) # 只提升分数 -should_reply = adjusted_score >= self.reply_threshold # 阈值不变 -``` - -**问题**:动作阈值(`non_reply_action_interest_threshold`)没有被调整,导致即使回复阈值满足,动作阈值可能仍然不满足。 - -### 解决方案 -改为**同时降低回复阈值和动作阈值**: - -```python -def _apply_no_reply_threshold_adjustment(self) -> tuple[float, float]: - """应用阈值调整(包括连续不回复和回复后降低机制)""" - base_reply_threshold = self.reply_threshold - base_action_threshold = global_config.affinity_flow.non_reply_action_interest_threshold - - total_reduction = 0.0 - - # 连续不回复的阈值降低 - if self.no_reply_count > 0: - no_reply_reduction = self.no_reply_count * self.probability_boost_per_no_reply - total_reduction += no_reply_reduction - - # 应用到两个阈值 - adjusted_reply_threshold = max(0.0, base_reply_threshold - total_reduction) - adjusted_action_threshold = max(0.0, base_action_threshold - total_reduction) - - return adjusted_reply_threshold, adjusted_action_threshold -``` - -**使用**: -```python -# ✅ 正确的实现 -adjusted_reply_threshold, adjusted_action_threshold = self._apply_no_reply_threshold_adjustment() -should_reply = adjusted_score >= adjusted_reply_threshold -should_take_action = adjusted_score >= adjusted_action_threshold -``` - -### 优势 -- ✅ **逻辑一致**:回复阈值和动作阈值同步调整 -- ✅ **避免矛盾**:不会出现"满足回复但不满足动作"的情况 -- ✅ **更合理**:连续不回复时,bot更容易采取任何行动 - ---- - -## ⏱️ 任务4: 添加兴趣度计算超时机制 - -### 问题描述 -兴趣匹配计算调用 embedding API,可能因为网络问题或模型响应慢导致: -- ❌ 长时间等待(>5秒) -- ❌ 整体超时导致强制使用默认分值 -- ❌ **丢失了提及分和关系分**(因为整个计算被中断) - -### 解决方案 -为兴趣匹配计算添加**1.5秒超时保护**,超时时返回默认分值: - -```python -async def _calculate_interest_match_score(self, content: str, keywords: list[str] | None = None) -> float: - """计算兴趣匹配度(带超时保护)""" - try: - # 使用 asyncio.wait_for 添加1.5秒超时 - match_result = await asyncio.wait_for( - bot_interest_manager.calculate_interest_match(content, keywords or []), - timeout=1.5 - ) - - if match_result: - # 正常计算分数 - final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus - return final_score - else: - return 0.0 - - except asyncio.TimeoutError: - # 超时时返回默认分值 0.5 - logger.warning("⏱️ 兴趣匹配计算超时(>5秒),返回默认分值0.5以保留其他分数") - return 0.5 # 避免丢失提及分和关系分 - - except Exception as e: - logger.warning(f"智能兴趣匹配失败: {e}") - return 0.0 -``` - -### 工作流程 -``` -正常情况(<1.5秒): - 兴趣匹配分: 0.8 + 关系分: 0.3 + 提及分: 2.5 = 3.6 ✅ - -超时情况(>1.5秒): - 兴趣匹配分: 0.5(默认)+ 关系分: 0.3 + 提及分: 2.5 = 3.3 ✅ - (保留了关系分和提及分) - -强制中断(无超时保护): - 整体计算失败 = 0.0(默认) ❌ - (丢失了所有分数) -``` - -### 优势 -- ✅ **防止阻塞**:不会因为一个API调用卡住整个流程 -- ✅ **保留分数**:即使兴趣匹配超时,提及分和关系分依然有效 -- ✅ **用户体验**:响应更快,不会长时间无反应 -- ✅ **降级优雅**:超时时仍能给出合理的默认值 - ---- - -## 🔄 任务5: 实现回复后阈值降低机制 - -### 需求背景 -**目标**:让bot在回复后更容易进行连续对话,提升对话的连贯性和自然性。 - -**场景示例**: -``` -用户: "你好呀" -Bot: "你好!今天过得怎么样?" ← 此时激活连续对话模式 - -用户: "还不错" -Bot: "那就好~有什么有趣的事情吗?" ← 阈值降低,更容易回复 - -用户: "没什么" -Bot: "嗯嗯,那要不要聊聊别的?" ← 仍然更容易回复 - -用户: "..." -(如果一直不回复,降低效果会逐渐衰减) -``` - -### 配置项 -在 `bot_config.toml` 中添加: - -```toml -# 回复后连续对话机制参数 -enable_post_reply_boost = true # 是否启用回复后阈值降低机制 -post_reply_threshold_reduction = 0.15 # 回复后初始阈值降低值 -post_reply_boost_max_count = 3 # 回复后阈值降低的最大持续次数 -post_reply_boost_decay_rate = 0.5 # 每次回复后阈值降低衰减率(0-1) -``` - -### 实现细节 - -**1. 初始化计数器**: -```python -def __init__(self): - # 回复后阈值降低机制 - self.enable_post_reply_boost = affinity_config.enable_post_reply_boost - self.post_reply_boost_remaining = 0 # 剩余的回复后降低次数 - self.post_reply_threshold_reduction = affinity_config.post_reply_threshold_reduction - self.post_reply_boost_max_count = affinity_config.post_reply_boost_max_count - self.post_reply_boost_decay_rate = affinity_config.post_reply_boost_decay_rate -``` - -**2. 阈值调整**: -```python -def _apply_no_reply_threshold_adjustment(self) -> tuple[float, float]: - """应用阈值调整""" - total_reduction = 0.0 - - # 1. 连续不回复的降低 - if self.no_reply_count > 0: - no_reply_reduction = self.no_reply_count * self.probability_boost_per_no_reply - total_reduction += no_reply_reduction - - # 2. 回复后的降低(带衰减) - if self.enable_post_reply_boost and self.post_reply_boost_remaining > 0: - # 计算衰减因子 - decay_factor = self.post_reply_boost_decay_rate ** ( - self.post_reply_boost_max_count - self.post_reply_boost_remaining - ) - post_reply_reduction = self.post_reply_threshold_reduction * decay_factor - total_reduction += post_reply_reduction - - # 应用总降低量 - adjusted_reply_threshold = max(0.0, base_reply_threshold - total_reduction) - adjusted_action_threshold = max(0.0, base_action_threshold - total_reduction) - - return adjusted_reply_threshold, adjusted_action_threshold -``` - -**3. 状态更新**: -```python -def on_reply_sent(self): - """当机器人发送回复后调用""" - if self.enable_post_reply_boost: - # 重置回复后降低计数器 - self.post_reply_boost_remaining = self.post_reply_boost_max_count - # 同时重置不回复计数 - self.no_reply_count = 0 - -def on_message_processed(self, replied: bool): - """消息处理完成后调用""" - # 更新不回复计数 - self.update_no_reply_count(replied) - - # 如果已回复,激活回复后降低机制 - if replied: - self.on_reply_sent() - else: - # 如果没有回复,减少回复后降低剩余次数 - if self.post_reply_boost_remaining > 0: - self.post_reply_boost_remaining -= 1 -``` - -### 衰减机制说明 - -**衰减公式**: -``` -decay_factor = decay_rate ^ (max_count - remaining_count) -actual_reduction = base_reduction * decay_factor -``` - -**示例**(`base_reduction=0.15`, `decay_rate=0.5`, `max_count=3`): -``` -第1次回复后: decay_factor = 0.5^0 = 1.00, reduction = 0.15 * 1.00 = 0.15 -第2次回复后: decay_factor = 0.5^1 = 0.50, reduction = 0.15 * 0.50 = 0.075 -第3次回复后: decay_factor = 0.5^2 = 0.25, reduction = 0.15 * 0.25 = 0.0375 -``` - -### 实际效果 - -**配置示例**: -- 回复阈值: 0.7 -- 初始降低值: 0.15 -- 最大次数: 3 -- 衰减率: 0.5 - -**对话流程**: -``` -初始状态: - 回复阈值: 0.7 - -Bot发送回复 → 激活连续对话模式: - 剩余次数: 3 - -第1条消息: - 阈值降低: 0.15 - 实际阈值: 0.7 - 0.15 = 0.55 ✅ 更容易回复 - -第2条消息: - 阈值降低: 0.075 (衰减) - 实际阈值: 0.7 - 0.075 = 0.625 - -第3条消息: - 阈值降低: 0.0375 (继续衰减) - 实际阈值: 0.7 - 0.0375 = 0.6625 - -第4条消息: - 降低结束,恢复正常阈值: 0.7 -``` - -### 优势 -- ✅ **连贯对话**:bot回复后更容易继续对话 -- ✅ **自然衰减**:避免无限连续回复,逐渐恢复正常 -- ✅ **可配置**:可以根据需求调整降低值、次数和衰减率 -- ✅ **灵活控制**:可以随时启用/禁用此功能 - ---- - -## 📊 整体影响 - -### 性能优化 -- ✅ **内存优化**:不再在数据库中存储大量 embedding 数据 -- ✅ **响应速度**:超时保护避免长时间等待 -- ✅ **启动速度**:首次启动需要生成 embedding(10-20秒),后续运行使用缓存 - -### 功能增强 -- ✅ **阈值调整**:修复了回复和动作阈值不一致的问题 -- ✅ **连续对话**:新增回复后阈值降低机制,提升对话连贯性 -- ✅ **容错能力**:超时保护确保即使API失败也能保留其他分数 - -### 代码质量 -- ✅ **目录结构**:清晰的模块划分,易于维护 -- ✅ **可扩展性**:新功能可以轻松添加到对应目录 -- ✅ **可配置性**:关键参数可通过配置文件调整 - ---- - -## 🔧 使用说明 - -### 配置调整 - -在 `config/bot_config.toml` 中调整回复后连续对话参数: - -```toml -[affinity_flow] -# 回复后连续对话机制 -enable_post_reply_boost = true # 启用/禁用 -post_reply_threshold_reduction = 0.15 # 初始降低值(建议0.1-0.2) -post_reply_boost_max_count = 3 # 持续次数(建议2-5) -post_reply_boost_decay_rate = 0.5 # 衰减率(建议0.3-0.7) -``` - -### 调用方式 - -在 planner 或其他需要的地方调用: - -```python -# 计算兴趣值 -result = await interest_calculator.execute(message) - -# 消息处理完成后更新状态 -interest_calculator.on_message_processed(replied=result.should_reply) -``` - ---- - -## 🐛 已知问题 - -暂无 - ---- - -## 📝 后续优化建议 - -1. **监控日志**:观察实际使用中的阈值调整效果 -2. **A/B测试**:对比启用/禁用回复后降低机制的对话质量 -3. **参数调优**:根据实际使用情况调整默认配置值 -4. **性能监控**:监控 embedding 生成的时间和缓存命中率 - ---- - -## 👥 贡献者 - -- GitHub Copilot - 代码实现和文档编写 - ---- - -## 📅 更新历史 - -- 2025-11-03: 完成所有5个任务的实现 - - ✅ 优化插件目录结构 - - ✅ 修改 embedding 存储策略 - - ✅ 修复连续不回复阈值调整 - - ✅ 添加超时保护机制 - - ✅ 实现回复后阈值降低 diff --git a/docs/affinity_flow_guide.md b/docs/affinity_flow_guide.md deleted file mode 100644 index d2929572c..000000000 --- a/docs/affinity_flow_guide.md +++ /dev/null @@ -1,170 +0,0 @@ -# affinity_flow 配置项详解与调整指南 - -本指南详细说明了 MoFox-Bot `bot_config.toml` 配置文件中 `[affinity_flow]` 区块的各项参数,帮助你根据实际需求调整兴趣评分系统与回复决策系统的行为。 - ---- - -## 一、affinity_flow 作用简介 - -`affinity_flow` 主要用于控制 AI 对消息的兴趣评分(afc),并据此决定是否回复、如何回复、是否发送表情包等。通过合理调整这些参数,可以让 Bot 的回复行为更贴合你的预期。 - ---- - -## 二、配置项说明 - -### 1. 兴趣评分相关参数 - -- `reply_action_interest_threshold` - 回复动作兴趣阈值。只有兴趣分高于此值,Bot 才会主动回复消息。 - - **建议调整**:提高此值,Bot 回复更谨慎;降低则更容易回复。 - -- `non_reply_action_interest_threshold` - 非回复动作兴趣阈值(如发送表情包等)。兴趣分高于此值时,Bot 可能采取非回复行为。 - -- `high_match_interest_threshold` - 高匹配兴趣阈值。关键词匹配度高于此值时,视为高匹配。 - -- `medium_match_interest_threshold` - 中匹配兴趣阈值。 - -- `low_match_interest_threshold` - 低匹配兴趣阈值。 - -- `high_match_keyword_multiplier` - 高匹配关键词兴趣倍率。高匹配关键词对兴趣分的加成倍数。 - -- `medium_match_keyword_multiplier` - 中匹配关键词兴趣倍率。 - -- `low_match_keyword_multiplier` - 低匹配关键词兴趣倍率。 - - 匹配关键词数量的加成值。匹配越多,兴趣分越高。 - -- `max_match_bonus` - 匹配数加成的最大值。 - -### 2. 回复决策相关参数 - -- `no_reply_threshold_adjustment` - 不回复兴趣阈值调整值。用于动态调整不回复的兴趣阈值。bot每不回复一次,就会在基础阈值上降低该值。 - -- `reply_cooldown_reduction` - 回复后减少的不回复计数。回复后,Bot 会更快恢复到基础阈值的状态。 - -- `max_no_reply_count` - 最大不回复计数次数。防止 Bot 的回复阈值被过度降低。 - -### 3. 综合评分权重 - -- `keyword_match_weight` - 兴趣关键词匹配度权重。关键词匹配对总兴趣分的影响比例。 - -- `mention_bot_weight` - 提及 Bot 分数权重。被提及时兴趣分提升的权重。 - -- `relationship_weight` - -### 4. 提及 Bot 相关参数 - -- `mention_bot_adjustment_threshold` - 提及 Bot 后的调整阈值。当bot被提及后,回复阈值会改变为这个值。 - -- `strong_mention_interest_score` - 强提及的兴趣分。强提及包括:被@、被回复、私聊消息。这类提及表示用户明确想与bot交互。 - -- `weak_mention_interest_score` - 弱提及的兴趣分。弱提及包括:消息中包含bot的名字或别名(文本匹配)。这类提及可能只是在讨论中提到bot。 - -- `base_relationship_score` ---- - -1. **Bot 太冷漠/回复太少** - - 降低 `reply_action_interest_threshold`,或降低高中低关键词匹配的阈值。 - -2. **Bot 太热情/回复太多** - - 提高 `reply_action_interest_threshold`,或降低关键词相关倍率。 - -3. **希望 Bot 更关注被 @ 或回复的消息** - - 提高 `strong_mention_interest_score` 或 `mention_bot_weight`。 - -4. **希望 Bot 对文本提及也积极回应** - - 提高 `weak_mention_interest_score`。 - -5. **希望 Bot 更看重关系好的用户** - - 提高 `relationship_weight` 或 `base_relationship_score`。 - -6. **表情包行为过于频繁/稀少** - - 调整 `non_reply_action_interest_threshold`。 - ---- - -## 四、参数调整建议流程 - -1. 明确你希望 Bot 的行为(如更活跃/更安静/更关注特定用户等)。 -2. 根据上表找到相关参数,优先调整权重和阈值。 -3. 每次只微调一两个参数,观察实际效果。 -4. 如需更细致的行为控制,可结合关键词、关系等多项参数综合调整。 - ---- - -## 五、示例配置片段 - -```toml -[affinity_flow] -reply_action_interest_threshold = 1.1 -non_reply_action_interest_threshold = 0.9 -high_match_interest_threshold = 0.7 -medium_match_interest_threshold = 0.4 -low_match_interest_threshold = 0.2 -high_match_keyword_multiplier = 5 -medium_match_keyword_multiplier = 3.75 -low_match_keyword_multiplier = 1.3 -match_count_bonus = 0.02 -max_match_bonus = 0.25 -no_reply_threshold_adjustment = 0.01 -reply_cooldown_reduction = 5 -max_no_reply_count = 20 -keyword_match_weight = 0.4 -mention_bot_weight = 0.3 -relationship_weight = 0.3 -mention_bot_adjustment_threshold = 0.5 -strong_mention_interest_score = 2.5 # 强提及(@/回复/私聊) -weak_mention_interest_score = 1.5 # 弱提及(文本匹配) -base_relationship_score = 0.3 -``` - -## 六、afc兴趣度评分决策流程详解 - -MoFox-Bot 在收到每条消息时,会通过一套“兴趣度评分(afc)”决策流程,综合多种因素计算出对该消息的兴趣分,并据此决定是否回复、如何回复或采取其他动作。以下为典型流程说明: - -### 1. 关键词匹配与兴趣加成 -- Bot 首先分析消息内容,查找是否包含高、中、低匹配的兴趣关键词。 -- 不同匹配度的关键词会乘以对应的倍率(high/medium/low_match_keyword_multiplier),并根据匹配数量叠加加成(match_count_bonus,max_match_bonus)。 - -### 2. 提及与关系加分 -- 如果消息中提及了 Bot,会根据提及类型获得不同的兴趣分: - * **强提及**(被@、被回复、私聊): 获得 `strong_mention_interest_score` 分值,表示用户明确想与bot交互 - * **弱提及**(文本中包含bot名字或别名): 获得 `weak_mention_interest_score` 分值,表示在讨论中提到bot - * 提及分按权重(`mention_bot_weight`)计入总分 -- 与用户的关系分(base_relationship_score 及动态关系分)也会按 relationship_weight 计入总分。 - -### 3. 综合评分计算 -- 最终兴趣分 = 关键词匹配分 × keyword_match_weight + 提及分 × mention_bot_weight + 关系分 × relationship_weight。 -- 你可以通过调整各权重,决定不同因素对总兴趣分的影响。 - -### 4. 阈值判定与回复决策 -- 若兴趣分高于 reply_action_interest_threshold,Bot 会主动回复。 -- 若兴趣分高于 non_reply_action_interest_threshold,但低于回复阈值,Bot 可能采取如发送表情包等非回复行为。 -- 若兴趣分均未达到阈值,则不回复。 - -### 5. 动态阈值调整机制 -- Bot 连续多次不回复时,reply_action_interest_threshold 会根据 no_reply_threshold_adjustment 逐步降低,最多降低 max_no_reply_count 次,防止长时间沉默。 -- 回复后,阈值通过 reply_cooldown_reduction 恢复。 -- 被@时,阈值可临时调整为 mention_bot_adjustment_threshold。 - -### 6. 典型决策流程图 - -1. 收到消息 → 2. 关键词/提及/关系分计算 → 3. 综合兴趣分加权 → 4. 与阈值比较 → 5. 决定回复/表情/忽略 - -通过理解上述流程,你可以有针对性地调整各项参数,让 Bot 的回复行为更贴合你的需求。 \ No newline at end of file diff --git a/docs/database_api_migration_checklist.md b/docs/database_api_migration_checklist.md deleted file mode 100644 index 08ff7ad3c..000000000 --- a/docs/database_api_migration_checklist.md +++ /dev/null @@ -1,374 +0,0 @@ -# 数据库API迁移检查清单 - -## 概述 - -本文档列出了项目中需要从直接数据库查询迁移到使用优化后API的代码位置。 - -## 为什么需要迁移? - -优化后的API具有以下优势: -1. **自动缓存**: 高频查询已集成多级缓存,减少90%+数据库访问 -2. **批量处理**: 消息存储使用批处理,减少连接池压力 -3. **统一接口**: 标准化的错误处理和日志记录 -4. **性能监控**: 内置性能统计和慢查询警告 -5. **代码简洁**: 简化的API调用,减少样板代码 - -## 迁移优先级 - -### 🔴 高优先级(高频查询) - -#### 1. PersonInfo 查询 - `src/person_info/person_info.py` - -**当前实现**:直接使用 SQLAlchemy `session.execute(select(PersonInfo)...)` - -**影响范围**: -- `get_value()` - 每条消息都会调用 -- `get_values()` - 批量查询用户信息 -- `update_one_field()` - 更新用户字段 -- `is_person_known()` - 检查用户是否已知 -- `get_person_info_by_name()` - 根据名称查询 - -**迁移目标**:使用 `src.common.database.api.specialized` 中的: -```python -from src.common.database.api.specialized import ( - get_or_create_person, - update_person_affinity, -) - -# 替代直接查询 -person, created = await get_or_create_person( - platform=platform, - person_id=person_id, - defaults={"nickname": nickname, ...} -) -``` - -**优势**: -- ✅ 10分钟缓存,减少90%+数据库查询 -- ✅ 自动缓存失效机制 -- ✅ 标准化的错误处理 - -**预计工作量**:⏱️ 2-4小时 - ---- - -#### 2. UserRelationships 查询 - `src/person_info/relationship_fetcher.py` - -**当前实现**:使用 `db_query(UserRelationships, ...)` - -**影响代码**: -- `build_relation_info()` 第189行 -- 查询用户关系数据 - -**迁移目标**: -```python -from src.common.database.api.specialized import ( - get_user_relationship, - update_relationship_affinity, -) - -# 替代 db_query -relationship = await get_user_relationship( - platform=platform, - user_id=user_id, - target_id=target_id, -) -``` - -**优势**: -- ✅ 5分钟缓存 -- ✅ 高频场景减少80%+数据库访问 -- ✅ 自动缓存失效 - -**预计工作量**:⏱️ 1-2小时 - ---- - -#### 3. ChatStreams 查询 - `src/person_info/relationship_fetcher.py` - -**当前实现**:使用 `db_query(ChatStreams, ...)` - -**影响代码**: -- `build_chat_stream_impression()` 第250行 - -**迁移目标**: -```python -from src.common.database.api.specialized import get_or_create_chat_stream - -stream, created = await get_or_create_chat_stream( - stream_id=stream_id, - platform=platform, - defaults={...} -) -``` - -**优势**: -- ✅ 5分钟缓存 -- ✅ 减少重复查询 -- ✅ 活跃会话期间性能提升75%+ - -**预计工作量**:⏱️ 30分钟-1小时 - ---- - -### 🟡 中优先级(中频查询) - -#### 4. ActionRecords 查询 - `src/chat/utils/statistic.py` - -**当前实现**:使用 `db_query(ActionRecords, ...)` - -**影响代码**: -- 第73行:更新行为记录 -- 第97行:插入新记录 -- 第105行:查询记录 - -**迁移目标**: -```python -from src.common.database.api.specialized import store_action_info, get_recent_actions - -# 存储行为 -await store_action_info( - user_id=user_id, - action_type=action_type, - ... -) - -# 获取最近行为 -actions = await get_recent_actions( - user_id=user_id, - limit=10 -) -``` - -**优势**: -- ✅ 标准化的API -- ✅ 更好的性能监控 -- ✅ 未来可添加缓存 - -**预计工作量**:⏱️ 1-2小时 - ---- - -#### 5. CacheEntries 查询 - `src/common/cache_manager.py` - -**当前实现**:使用 `db_query(CacheEntries, ...)` - -**注意**:这是旧的基于数据库的缓存系统 - -**建议**: -- ⚠️ 考虑完全迁移到新的 `MultiLevelCache` 系统 -- ⚠️ 新系统使用内存缓存,性能更好 -- ⚠️ 如需持久化,可以添加持久化层 - -**预计工作量**:⏱️ 4-8小时(如果重构整个缓存系统) - ---- - -### 🟢 低优先级(低频查询或测试代码) - -#### 6. 测试代码 - `tests/test_api_utils_compatibility.py` - -**当前实现**:测试中使用直接查询 - -**建议**: -- ℹ️ 测试代码可以保持现状 -- ℹ️ 但可以添加新的测试用例测试优化后的API - -**预计工作量**:⏱️ 可选 - ---- - -## 迁移步骤 - -### 第一阶段:高频查询(推荐立即进行) - -1. **迁移 PersonInfo 查询** - - [ ] 修改 `person_info.py` 的 `get_value()` - - [ ] 修改 `person_info.py` 的 `get_values()` - - [ ] 修改 `person_info.py` 的 `update_one_field()` - - [ ] 修改 `person_info.py` 的 `is_person_known()` - - [ ] 测试缓存效果 - -2. **迁移 UserRelationships 查询** - - [ ] 修改 `relationship_fetcher.py` 的关系查询 - - [ ] 测试缓存效果 - -3. **迁移 ChatStreams 查询** - - [ ] 修改 `relationship_fetcher.py` 的流查询 - - [ ] 测试缓存效果 - -### 第二阶段:中频查询(可以分批进行) - -4. **迁移 ActionRecords** - - [ ] 修改 `statistic.py` 的行为记录 - - [ ] 添加单元测试 - -### 第三阶段:系统优化(长期目标) - -5. **重构旧缓存系统** - - [ ] 评估 `cache_manager.py` 的使用情况 - - [ ] 制定迁移到 MultiLevelCache 的计划 - - [ ] 逐步迁移 - ---- - -## 性能提升预期 - -基于当前测试数据: - -| 查询类型 | 迁移前 QPS | 迁移后 QPS | 提升 | 数据库负载降低 | -|---------|-----------|-----------|------|--------------| -| PersonInfo | ~50 | ~500+ | **10倍** | **90%+** | -| UserRelationships | ~30 | ~150+ | **5倍** | **80%+** | -| ChatStreams | ~40 | ~160+ | **4倍** | **75%+** | - -**总体效果**: -- 📈 高峰期数据库连接数减少 **80%+** -- 📈 平均响应时间降低 **70%+** -- 📈 系统吞吐量提升 **5-10倍** - ---- - -## 注意事项 - -### 1. 缓存一致性 - -迁移后需要确保: -- ✅ 所有更新操作都正确使缓存失效 -- ✅ 缓存键的生成逻辑一致 -- ✅ TTL设置合理 - -### 2. 测试覆盖 - -每次迁移后需要: -- ✅ 运行单元测试 -- ✅ 测试缓存命中率 -- ✅ 监控性能指标 -- ✅ 检查日志中的缓存统计 - -### 3. 回滚计划 - -如果遇到问题: -- 🔄 保留原有代码在注释中 -- 🔄 使用 git 标签标记迁移点 -- 🔄 准备快速回滚脚本 - -### 4. 逐步迁移 - -建议: -- ⭐ 一次迁移一个模块 -- ⭐ 在测试环境充分验证 -- ⭐ 监控生产环境指标 -- ⭐ 根据反馈调整策略 - ---- - -## 迁移示例 - -### 示例1:PersonInfo 查询迁移 - -**迁移前**: -```python -# src/person_info/person_info.py -async def get_value(self, person_id: str, field_name: str): - async with get_db_session() as session: - result = await session.execute( - select(PersonInfo).where(PersonInfo.person_id == person_id) - ) - person = result.scalar_one_or_none() - if person: - return getattr(person, field_name, None) - return None -``` - -**迁移后**: -```python -# src/person_info/person_info.py -async def get_value(self, person_id: str, field_name: str): - from src.common.database.api.crud import CRUDBase - from src.common.database.core.models import PersonInfo - from src.common.database.utils.decorators import cached - - @cached(ttl=600, key_prefix=f"person_field_{field_name}") - async def _get_cached_value(pid: str): - crud = CRUDBase(PersonInfo) - person = await crud.get_by(person_id=pid) - if person: - return getattr(person, field_name, None) - return None - - return await _get_cached_value(person_id) -``` - -或者更简单,使用现有的 `get_or_create_person`: -```python -async def get_value(self, person_id: str, field_name: str): - from src.common.database.api.specialized import get_or_create_person - - # 解析 person_id 获取 platform 和 user_id - # (需要调整 get_or_create_person 支持 person_id 查询, - # 或者在 PersonInfoManager 中缓存映射关系) - person, _ = await get_or_create_person( - platform=self._platform_cache.get(person_id), - person_id=person_id, - ) - if person: - return getattr(person, field_name, None) - return None -``` - -### 示例2:UserRelationships 迁移 - -**迁移前**: -```python -# src/person_info/relationship_fetcher.py -relationships = await db_query( - UserRelationships, - filters={"user_id": user_id}, - limit=1, -) -``` - -**迁移后**: -```python -from src.common.database.api.specialized import get_user_relationship - -relationship = await get_user_relationship( - platform=platform, - user_id=user_id, - target_id=target_id, -) -# 如果需要查询某个用户的所有关系,可以添加新的API函数 -``` - ---- - -## 进度跟踪 - -| 任务 | 状态 | 负责人 | 预计完成时间 | 实际完成时间 | 备注 | -|-----|------|--------|------------|------------|------| -| PersonInfo 迁移 | ⏳ 待开始 | - | - | - | 高优先级 | -| UserRelationships 迁移 | ⏳ 待开始 | - | - | - | 高优先级 | -| ChatStreams 迁移 | ⏳ 待开始 | - | - | - | 高优先级 | -| ActionRecords 迁移 | ⏳ 待开始 | - | - | - | 中优先级 | -| 缓存系统重构 | ⏳ 待开始 | - | - | - | 长期目标 | - ---- - -## 相关文档 - -- [数据库缓存系统使用指南](./database_cache_guide.md) -- [数据库重构完成报告](./database_refactoring_completion.md) -- [优化后的API文档](../src/common/database/api/specialized.py) - ---- - -## 联系与支持 - -如果在迁移过程中遇到问题: -1. 查看相关文档 -2. 检查示例代码 -3. 运行测试验证 -4. 查看日志中的缓存统计 - -**最后更新**: 2025-11-01 diff --git a/docs/database_refactoring_completion.md b/docs/database_refactoring_completion.md deleted file mode 100644 index e8bfbe6dc..000000000 --- a/docs/database_refactoring_completion.md +++ /dev/null @@ -1,224 +0,0 @@ -# 数据库重构完成总结 - -## 📊 重构概览 - -**重构周期**: 2025年11月1日完成 -**分支**: `feature/database-refactoring` -**总提交数**: 8次 -**总测试通过率**: 26/26 (100%) - ---- - -## 🎯 重构目标达成 - -### ✅ 核心目标 - -1. **6层架构实现** - 完成所有6层的设计和实现 -2. **完全向后兼容** - 旧代码无需修改即可工作 -3. **性能优化** - 实现多级缓存、智能预加载、批量调度 -4. **代码质量** - 100%测试覆盖,清晰的架构设计 - -### ✅ 实施成果 - -#### 1. 核心层 (Core Layer) -- ✅ `DatabaseEngine`: 单例模式,SQLite优化 (WAL模式) -- ✅ `SessionFactory`: 异步会话工厂,连接池管理 -- ✅ `models.py`: 25个数据模型,统一定义 -- ✅ `migration.py`: 数据库迁移和检查 - -#### 2. API层 (API Layer) -- ✅ `CRUDBase`: 通用CRUD操作,支持缓存 -- ✅ `QueryBuilder`: 链式查询构建器 -- ✅ `AggregateQuery`: 聚合查询支持 (sum, avg, count等) -- ✅ `specialized.py`: 特殊业务API (人物、LLM统计等) - -#### 3. 优化层 (Optimization Layer) -- ✅ `CacheManager`: 3级缓存 (L1内存/L2 SQLite/L3预加载) -- ✅ `IntelligentPreloader`: 智能数据预加载,访问模式学习 -- ✅ `AdaptiveBatchScheduler`: 自适应批量调度器 - -#### 4. 配置层 (Config Layer) -- ✅ `DatabaseConfig`: 数据库配置管理 -- ✅ `CacheConfig`: 缓存策略配置 -- ✅ `PreloaderConfig`: 预加载器配置 - -#### 5. 工具层 (Utils Layer) -- ✅ `decorators.py`: 重试、超时、缓存、性能监控装饰器 -- ✅ `monitoring.py`: 数据库性能监控 - -#### 6. 兼容层 (Compatibility Layer) -- ✅ `adapter.py`: 向后兼容适配器 -- ✅ `MODEL_MAPPING`: 25个模型映射 -- ✅ 旧API兼容: `db_query`, `db_save`, `db_get`, `store_action_info` - ---- - -## 📈 测试结果 - -### Stage 4-6 测试 (兼容性层) -``` -✅ 26/26 测试通过 (100%) - -测试覆盖: -- CRUDBase: 6/6 ✅ -- QueryBuilder: 3/3 ✅ -- AggregateQuery: 1/1 ✅ -- SpecializedAPI: 3/3 ✅ -- Decorators: 4/4 ✅ -- Monitoring: 2/2 ✅ -- Compatibility: 6/6 ✅ -- Integration: 1/1 ✅ -``` - -### Stage 1-3 测试 (基础架构) -``` -✅ 18/21 测试通过 (85.7%) - -测试覆盖: -- Core Layer: 4/4 ✅ -- Cache Manager: 5/5 ✅ -- Preloader: 3/3 ✅ -- Batch Scheduler: 4/5 (1个超时测试) -- Integration: 1/2 (1个并发测试) -- Performance: 1/2 (1个吞吐量测试) -``` - -### 总体评估 -- **核心功能**: 100% 通过 ✅ -- **性能优化**: 85.7% 通过 (非关键超时测试失败) -- **向后兼容**: 100% 通过 ✅ - ---- - -## 🔄 导入路径迁移 - -### 批量更新统计 -- **更新文件数**: 37个 -- **修改次数**: 67处 -- **自动化工具**: `scripts/update_database_imports.py` - -### 导入映射表 - -| 旧路径 | 新路径 | 用途 | -|--------|--------|------| -| `sqlalchemy_models` | `core.models` | 数据模型 | -| `sqlalchemy_models` | `core` | get_db_session, get_engine | -| `sqlalchemy_database_api` | `compatibility` | db_*, MODEL_MAPPING | -| `database.database` | `core` | initialize, stop | - -### 更新文件列表 -主要更新了以下模块: -- `bot.py`, `main.py` - 主程序入口 -- `src/schedule/` - 日程管理 (3个文件) -- `src/plugin_system/` - 插件系统 (4个文件) -- `src/plugins/built_in/` - 内置插件 (8个文件) -- `src/chat/` - 聊天系统 (20+个文件) -- `src/person_info/` - 人物信息 (2个文件) -- `scripts/` - 工具脚本 (2个文件) - ---- - -## 🗃️ 旧文件归档 - -已将6个旧数据库文件移动到 `src/common/database/old/`: -- `sqlalchemy_models.py` (783行) → 已被 `core/models.py` 替代 -- `sqlalchemy_database_api.py` (600+行) → 已被 `compatibility/adapter.py` 替代 -- `database.py` (200+行) → 已被 `core/__init__.py` 替代 -- `db_migration.py` → 已被 `core/migration.py` 替代 -- `db_batch_scheduler.py` → 已被 `optimization/batch_scheduler.py` 替代 -- `sqlalchemy_init.py` → 已被 `core/engine.py` 替代 - ---- - -## 📝 提交历史 - -```bash -f6318fdb refactor: 清理旧数据库文件并完成导入更新 -a1dc03ca refactor: 完成数据库重构 - 批量更新导入路径 -62c644c1 fix: 修复get_or_create返回值和MODEL_MAPPING -51940f1d fix(database): 修复get_or_create返回元组的处理 -59d2a4e9 fix(database): 修复record_llm_usage函数的字段映射 -b58f69ec fix(database): 修复decorators循环导入问题 -61de975d feat(database): 完成API层、Utils层和兼容层重构 (Stage 4-6) -aae84ec4 docs(database): 添加重构测试报告 -``` - ---- - -## 🎉 重构收益 - -### 1. 性能提升 -- **3级缓存系统**: 减少数据库查询 ~70% -- **智能预加载**: 访问模式学习,命中率 >80% -- **批量调度**: 自适应批处理,吞吐量提升 ~50% -- **WAL模式**: 并发性能提升 ~3x - -### 2. 代码质量 -- **架构清晰**: 6层分离,职责明确 -- **高度模块化**: 每层独立,易于维护 -- **完全测试**: 26个测试用例,100%通过 -- **向后兼容**: 旧代码0改动即可工作 - -### 3. 可维护性 -- **统一接口**: CRUDBase提供一致的API -- **装饰器模式**: 重试、缓存、监控统一管理 -- **配置驱动**: 所有策略可通过配置调整 -- **文档完善**: 每层都有详细文档 - -### 4. 扩展性 -- **插件化设计**: 易于添加新的数据模型 -- **策略可配**: 缓存、预加载策略可灵活调整 -- **监控完善**: 实时性能数据,便于优化 -- **未来支持**: 预留PostgreSQL/MySQL适配接口 - ---- - -## 🔮 后续优化建议 - -### 短期 (1-2周) -1. ✅ **完成导入迁移** - 已完成 -2. ✅ **清理旧文件** - 已完成 -3. 📝 **更新文档** - 进行中 -4. 🔄 **合并到主分支** - 待进行 - -### 中期 (1-2月) -1. **监控优化**: 收集生产环境数据,调优缓存策略 -2. **压力测试**: 模拟高并发场景,验证性能 -3. **错误处理**: 完善异常处理和降级策略 -4. **日志完善**: 增加更详细的性能日志 - -### 长期 (3-6月) -1. **PostgreSQL支持**: 添加PostgreSQL适配器 -2. **分布式缓存**: Redis集成,支持多实例 -3. **读写分离**: 主从复制支持 -4. **数据分析**: 实现复杂的分析查询优化 - ---- - -## 📚 参考文档 - -- [数据库重构计划](./database_refactoring_plan.md) - 原始计划文档 -- [统一调度器指南](./unified_scheduler_guide.md) - 批量调度器使用 -- [测试报告](./database_refactoring_test_report.md) - 详细测试结果 - ---- - -## 🙏 致谢 - -感谢项目组成员在重构过程中的支持和反馈! - -本次重构历时约2周,涉及: -- **新增代码**: ~3000行 -- **重构代码**: ~1500行 -- **测试代码**: ~800行 -- **文档**: ~2000字 - ---- - -**重构状态**: ✅ **已完成** -**下一步**: 合并到主分支并部署 - ---- - -*生成时间: 2025-11-01* -*文档版本: v1.0* diff --git a/docs/database_refactoring_plan.md b/docs/database_refactoring_plan.md deleted file mode 100644 index 68703ec07..000000000 --- a/docs/database_refactoring_plan.md +++ /dev/null @@ -1,1475 +0,0 @@ -# 数据库模块重构方案 - -## 📋 目录 -1. [重构目标](#重构目标) -2. [对外API保持兼容](#对外api保持兼容) -3. [新架构设计](#新架构设计) -4. [高频读写优化](#高频读写优化) -5. [实施计划](#实施计划) -6. [风险评估与回滚方案](#风险评估与回滚方案) - ---- - -## 🎯 重构目标 - -### 核心目标 -1. **架构清晰化** - 消除职责重叠,明确模块边界 -2. **性能优化** - 针对高频读写场景进行深度优化 -3. **向后兼容** - 保持所有对外API接口不变 -4. **可维护性** - 提高代码质量和可测试性 - -### 关键指标 -- ✅ 零破坏性变更 -- ✅ 高频读取性能提升 50%+ -- ✅ 写入批量化率提升至 80%+ -- ✅ 连接池利用率 > 90% - ---- - -## 🔒 对外API保持兼容 - -### 识别的关键API接口 - -#### 1. 数据库会话管理 -```python -# ✅ 必须保持 -from src.common.database.sqlalchemy_models import get_db_session - -async with get_db_session() as session: - # 使用session -``` - -#### 2. 数据操作API -```python -# ✅ 必须保持 -from src.common.database.sqlalchemy_database_api import ( - db_query, # 通用查询 - db_save, # 保存/更新 - db_get, # 快捷查询 - store_action_info, # 存储动作 -) -``` - -#### 3. 模型导入 -```python -# ✅ 必须保持 -from src.common.database.sqlalchemy_models import ( - ChatStreams, - Messages, - PersonInfo, - LLMUsage, - Emoji, - Images, - # ... 所有30+模型 -) -``` - -#### 4. 初始化接口 -```python -# ✅ 必须保持 -from src.common.database.database import ( - db, - initialize_sql_database, - stop_database, -) -``` - -#### 5. 模型映射 -```python -# ✅ 必须保持 -from src.common.database.sqlalchemy_database_api import MODEL_MAPPING -``` - -### 兼容性策略 -所有现有导入路径将通过 `__init__.py` 重新导出,确保零破坏性变更。 - ---- - -## 🏗️ 新架构设计 - -### 当前架构问题 -``` -❌ 当前结构 - 职责混乱 -database/ -├── database.py (入口+初始化+代理) -├── sqlalchemy_init.py (重复的初始化逻辑) -├── sqlalchemy_models.py (模型+引擎+会话+初始化) -├── sqlalchemy_database_api.py -├── connection_pool_manager.py -├── db_batch_scheduler.py -└── db_migration.py -``` - -### 新架构设计 -``` -✅ 新结构 - 职责清晰 -database/ -├── __init__.py 【统一入口】导出所有API -│ -├── core/ 【核心层】 -│ ├── __init__.py -│ ├── engine.py 数据库引擎管理(单一职责) -│ ├── session.py 会话管理(单一职责) -│ ├── models.py 模型定义(纯模型) -│ └── migration.py 迁移工具 -│ -├── api/ 【API层】 -│ ├── __init__.py -│ ├── crud.py CRUD操作(db_query/save/get) -│ ├── specialized.py 特殊操作(store_action_info等) -│ └── query_builder.py 查询构建器 -│ -├── optimization/ 【优化层】 -│ ├── __init__.py -│ ├── connection_pool.py 连接池管理 -│ ├── batch_scheduler.py 批量调度 -│ ├── cache_manager.py 智能缓存 -│ ├── read_write_splitter.py 读写分离 -│ └── preloader.py 预加载器 -│ -├── config/ 【配置层】 -│ ├── __init__.py -│ ├── database_config.py 数据库配置 -│ └── optimization_config.py 优化配置 -│ -└── utils/ 【工具层】 - ├── __init__.py - ├── exceptions.py 统一异常 - ├── decorators.py 装饰器(缓存、重试等) - └── monitoring.py 性能监控 -``` - -### 职责划分 - -#### Core 层(核心层) -| 文件 | 职责 | 依赖 | -|------|------|------| -| `engine.py` | 创建和管理数据库引擎,单例模式 | config | -| `session.py` | 提供会话工厂和上下文管理器 | engine, optimization | -| `models.py` | 定义所有SQLAlchemy模型 | engine | -| `migration.py` | 数据库结构自动迁移 | engine, models | - -#### API 层(接口层) -| 文件 | 职责 | 依赖 | -|------|------|------| -| `crud.py` | 实现db_query/db_save/db_get | session, models | -| `specialized.py` | 特殊业务操作 | crud | -| `query_builder.py` | 构建复杂查询条件 | - | - -#### Optimization 层(优化层) -| 文件 | 职责 | 依赖 | -|------|------|------| -| `connection_pool.py` | 透明连接复用 | session | -| `batch_scheduler.py` | 批量操作调度 | session | -| `cache_manager.py` | 多级缓存管理 | - | -| `read_write_splitter.py` | 读写分离路由 | engine | -| `preloader.py` | 数据预加载 | cache_manager | - ---- - -## ⚡ 高频读写优化 - -### 问题分析 - -通过代码分析,识别出以下高频操作场景: - -#### 高频读取场景 -1. **ChatStreams 查询** - 每条消息都要查询聊天流 -2. **Messages 历史查询** - 构建上下文时频繁查询 -3. **PersonInfo 查询** - 每次交互都要查用户信息 -4. **Emoji/Images 查询** - 发送表情时查询 -5. **UserRelationships 查询** - 关系系统频繁读取 - -#### 高频写入场景 -1. **Messages 插入** - 每条消息都要写入 -2. **LLMUsage 插入** - 每次LLM调用都记录 -3. **ActionRecords 插入** - 每个动作都记录 -4. **ChatStreams 更新** - 更新活跃时间和状态 - -### 优化策略设计 - -#### 1️⃣ 多级缓存系统 - -```python -# optimization/cache_manager.py - -from typing import Any, Optional, Callable -from dataclasses import dataclass -from datetime import timedelta -import asyncio -from collections import OrderedDict - -@dataclass -class CacheConfig: - """缓存配置""" - l1_size: int = 1000 # L1缓存大小(内存LRU) - l1_ttl: float = 60.0 # L1 TTL(秒) - l2_size: int = 10000 # L2缓存大小(内存LRU) - l2_ttl: float = 300.0 # L2 TTL(秒) - enable_write_through: bool = True # 写穿透 - enable_write_back: bool = False # 写回(风险较高) - - -class MultiLevelCache: - """多级缓存管理器 - - L1: 热数据缓存(1000条,60秒)- 极高频访问 - L2: 温数据缓存(10000条,300秒)- 高频访问 - L3: 数据库 - - 策略: - - 读取:L1 → L2 → DB,回填到上层 - - 写入:写穿透(同步更新所有层) - - 失效:TTL + LRU - """ - - def __init__(self, config: CacheConfig): - self.config = config - self.l1_cache: OrderedDict = OrderedDict() - self.l2_cache: OrderedDict = OrderedDict() - self.l1_timestamps: dict = {} - self.l2_timestamps: dict = {} - self.stats = { - "l1_hits": 0, - "l2_hits": 0, - "db_hits": 0, - "writes": 0, - } - self._lock = asyncio.Lock() - - async def get( - self, - key: str, - fetch_func: Callable, - ttl_override: Optional[float] = None - ) -> Any: - """获取数据,自动回填""" - # L1 查找 - if key in self.l1_cache: - if self._is_valid(key, self.l1_timestamps, self.config.l1_ttl): - self.stats["l1_hits"] += 1 - # LRU更新 - self.l1_cache.move_to_end(key) - return self.l1_cache[key] - - # L2 查找 - if key in self.l2_cache: - if self._is_valid(key, self.l2_timestamps, self.config.l2_ttl): - self.stats["l2_hits"] += 1 - value = self.l2_cache[key] - # 回填到L1 - await self._set_l1(key, value) - return value - - # 从数据库获取 - self.stats["db_hits"] += 1 - value = await fetch_func() - - # 回填到L2和L1 - await self._set_l2(key, value) - await self._set_l1(key, value) - - return value - - async def set(self, key: str, value: Any): - """写入数据(写穿透)""" - async with self._lock: - self.stats["writes"] += 1 - await self._set_l1(key, value) - await self._set_l2(key, value) - - async def invalidate(self, key: str): - """失效指定key""" - async with self._lock: - self.l1_cache.pop(key, None) - self.l2_cache.pop(key, None) - self.l1_timestamps.pop(key, None) - self.l2_timestamps.pop(key, None) - - async def invalidate_pattern(self, pattern: str): - """失效匹配模式的key""" - import re - regex = re.compile(pattern) - - async with self._lock: - for key in list(self.l1_cache.keys()): - if regex.match(key): - del self.l1_cache[key] - self.l1_timestamps.pop(key, None) - - for key in list(self.l2_cache.keys()): - if regex.match(key): - del self.l2_cache[key] - self.l2_timestamps.pop(key, None) - - def _is_valid(self, key: str, timestamps: dict, ttl: float) -> bool: - """检查缓存是否有效""" - import time - if key not in timestamps: - return False - return time.time() - timestamps[key] < ttl - - async def _set_l1(self, key: str, value: Any): - """设置L1缓存""" - import time - if len(self.l1_cache) >= self.config.l1_size: - # LRU淘汰 - oldest = next(iter(self.l1_cache)) - del self.l1_cache[oldest] - self.l1_timestamps.pop(oldest, None) - - self.l1_cache[key] = value - self.l1_timestamps[key] = time.time() - - async def _set_l2(self, key: str, value: Any): - """设置L2缓存""" - import time - if len(self.l2_cache) >= self.config.l2_size: - # LRU淘汰 - oldest = next(iter(self.l2_cache)) - del self.l2_cache[oldest] - self.l2_timestamps.pop(oldest, None) - - self.l2_cache[key] = value - self.l2_timestamps[key] = time.time() - - def get_stats(self) -> dict: - """获取缓存统计""" - total_hits = self.stats["l1_hits"] + self.stats["l2_hits"] + self.stats["db_hits"] - if total_hits == 0: - hit_rate = 0 - else: - hit_rate = (self.stats["l1_hits"] + self.stats["l2_hits"]) / total_hits * 100 - - return { - **self.stats, - "l1_size": len(self.l1_cache), - "l2_size": len(self.l2_cache), - "hit_rate": f"{hit_rate:.2f}%", - "total_requests": total_hits, - } - - -# 全局缓存实例 -_cache_manager: Optional[MultiLevelCache] = None - - -def get_cache_manager() -> MultiLevelCache: - """获取全局缓存管理器""" - global _cache_manager - if _cache_manager is None: - _cache_manager = MultiLevelCache(CacheConfig()) - return _cache_manager -``` - -#### 2️⃣ 智能预加载器 - -```python -# optimization/preloader.py - -import asyncio -from typing import List, Dict, Any -from collections import defaultdict - -class DataPreloader: - """数据预加载器 - - 策略: - 1. 会话启动时预加载该聊天流的最近消息 - 2. 定期预加载热门用户的PersonInfo - 3. 预加载常用表情和图片 - """ - - def __init__(self): - self.preload_tasks: Dict[str, asyncio.Task] = {} - self.access_patterns = defaultdict(int) # 访问模式统计 - - async def preload_chat_stream_context( - self, - stream_id: str, - message_limit: int = 50 - ): - """预加载聊天流上下文""" - from ..api.crud import db_get - from ..core.models import Messages, ChatStreams, PersonInfo - from .cache_manager import get_cache_manager - - cache = get_cache_manager() - - # 1. 预加载ChatStream - stream_key = f"chat_stream:{stream_id}" - if stream_key not in cache.l1_cache: - stream = await db_get( - ChatStreams, - filters={"stream_id": stream_id}, - single_result=True - ) - if stream: - await cache.set(stream_key, stream) - - # 2. 预加载最近消息 - messages = await db_get( - Messages, - filters={"chat_id": stream_id}, - order_by="-time", - limit=message_limit - ) - - # 批量缓存消息 - for msg in messages: - msg_key = f"message:{msg['message_id']}" - await cache.set(msg_key, msg) - - # 3. 预加载相关用户信息 - user_ids = set() - for msg in messages: - if msg.get("user_id"): - user_ids.add(msg["user_id"]) - - # 批量查询用户信息 - if user_ids: - users = await db_get( - PersonInfo, - filters={"user_id": {"$in": list(user_ids)}} - ) - for user in users: - user_key = f"person_info:{user['user_id']}" - await cache.set(user_key, user) - - async def preload_hot_emojis(self, limit: int = 100): - """预加载热门表情""" - from ..api.crud import db_get - from ..core.models import Emoji - from .cache_manager import get_cache_manager - - cache = get_cache_manager() - - # 按使用次数排序 - hot_emojis = await db_get( - Emoji, - order_by="-usage_count", - limit=limit - ) - - for emoji in hot_emojis: - emoji_key = f"emoji:{emoji['emoji_hash']}" - await cache.set(emoji_key, emoji) - - async def schedule_preload_task( - self, - task_name: str, - coro, - interval: float = 300.0 # 5分钟 - ): - """定期执行预加载任务""" - async def _task(): - while True: - try: - await coro - await asyncio.sleep(interval) - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"预加载任务 {task_name} 失败: {e}") - await asyncio.sleep(interval) - - task = asyncio.create_task(_task()) - self.preload_tasks[task_name] = task - - async def stop_all_tasks(self): - """停止所有预加载任务""" - for task in self.preload_tasks.values(): - task.cancel() - - await asyncio.gather(*self.preload_tasks.values(), return_exceptions=True) - self.preload_tasks.clear() - - -# 全局预加载器 -_preloader: Optional[DataPreloader] = None - - -def get_preloader() -> DataPreloader: - """获取全局预加载器""" - global _preloader - if _preloader is None: - _preloader = DataPreloader() - return _preloader -``` - -#### 3️⃣ 增强批量调度器 - -```python -# optimization/batch_scheduler.py - -from typing import List, Dict, Any, Callable -from dataclasses import dataclass -import asyncio -import time - -@dataclass -class SmartBatchConfig: - """智能批量配置""" - # 基础配置 - batch_size: int = 100 # 增加批量大小 - max_wait_time: float = 0.05 # 减少等待时间(50ms) - - # 智能调整 - enable_adaptive: bool = True # 启用自适应批量大小 - min_batch_size: int = 10 - max_batch_size: int = 500 - - # 优先级配置 - high_priority_models: List[str] = None # 高优先级模型 - - # 自动降级 - enable_auto_degradation: bool = True - degradation_threshold: float = 1.0 # 超过1秒降级为直接写入 - - -class EnhancedBatchScheduler: - """增强的批量调度器 - - 改进: - 1. 自适应批量大小 - 2. 优先级队列 - 3. 自动降级保护 - 4. 写入确认机制 - """ - - def __init__(self, config: SmartBatchConfig): - self.config = config - self.queues: Dict[str, asyncio.Queue] = {} - self.pending_operations: Dict[str, List] = {} - self.scheduler_tasks: Dict[str, asyncio.Task] = {} - - # 性能监控 - self.performance_stats = { - "avg_batch_size": 0, - "avg_latency": 0, - "total_batches": 0, - } - - self._lock = asyncio.Lock() - self._running = False - - async def schedule_write( - self, - model_class: Any, - operation_type: str, # 'insert', 'update', 'delete' - data: Dict[str, Any], - priority: int = 0, # 0=normal, 1=high, -1=low - ) -> asyncio.Future: - """调度写入操作 - - Returns: - Future对象,可await等待操作完成 - """ - queue_key = f"{model_class.__name__}_{operation_type}" - - # 确保队列存在 - if queue_key not in self.queues: - async with self._lock: - if queue_key not in self.queues: - self.queues[queue_key] = asyncio.Queue() - self.pending_operations[queue_key] = [] - # 启动调度器 - task = asyncio.create_task( - self._scheduler_loop(queue_key, model_class, operation_type) - ) - self.scheduler_tasks[queue_key] = task - - # 创建Future - future = asyncio.get_event_loop().create_future() - - # 加入队列 - operation = { - "data": data, - "priority": priority, - "future": future, - "timestamp": time.time(), - } - - await self.queues[queue_key].put(operation) - - return future - - async def _scheduler_loop( - self, - queue_key: str, - model_class: Any, - operation_type: str - ): - """调度器主循环""" - while self._running: - try: - # 收集一批操作 - batch = [] - deadline = time.time() + self.config.max_wait_time - - while len(batch) < self.config.batch_size: - timeout = deadline - time.time() - if timeout <= 0: - break - - try: - operation = await asyncio.wait_for( - self.queues[queue_key].get(), - timeout=timeout - ) - batch.append(operation) - except asyncio.TimeoutError: - break - - if batch: - # 执行批量操作 - await self._execute_batch( - model_class, - operation_type, - batch - ) - - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"批量调度器错误 [{queue_key}]: {e}") - await asyncio.sleep(0.1) - - async def _execute_batch( - self, - model_class: Any, - operation_type: str, - batch: List[Dict] - ): - """执行批量操作""" - start_time = time.time() - - try: - from ..core.session import get_db_session - from sqlalchemy import insert, update, delete - - async with get_db_session() as session: - if operation_type == "insert": - # 批量插入 - data_list = [op["data"] for op in batch] - stmt = insert(model_class).values(data_list) - await session.execute(stmt) - await session.commit() - - # 标记所有Future为成功 - for op in batch: - if not op["future"].done(): - op["future"].set_result(True) - - elif operation_type == "update": - # 批量更新 - for op in batch: - stmt = update(model_class) - # 根据data中的条件更新 - # ... 实现细节 - await session.execute(stmt) - - await session.commit() - - for op in batch: - if not op["future"].done(): - op["future"].set_result(True) - - # 更新性能统计 - latency = time.time() - start_time - self._update_stats(len(batch), latency) - - except Exception as e: - # 标记所有Future为失败 - for op in batch: - if not op["future"].done(): - op["future"].set_exception(e) - - logger.error(f"批量操作失败: {e}") - - def _update_stats(self, batch_size: int, latency: float): - """更新性能统计""" - n = self.performance_stats["total_batches"] - - # 移动平均 - self.performance_stats["avg_batch_size"] = ( - (self.performance_stats["avg_batch_size"] * n + batch_size) / (n + 1) - ) - self.performance_stats["avg_latency"] = ( - (self.performance_stats["avg_latency"] * n + latency) / (n + 1) - ) - self.performance_stats["total_batches"] = n + 1 - - # 自适应调整批量大小 - if self.config.enable_adaptive: - if latency > 0.5: # 太慢,减小批量 - self.config.batch_size = max( - self.config.min_batch_size, - int(self.config.batch_size * 0.8) - ) - elif latency < 0.1: # 很快,增大批量 - self.config.batch_size = min( - self.config.max_batch_size, - int(self.config.batch_size * 1.2) - ) - - async def start(self): - """启动调度器""" - self._running = True - - async def stop(self): - """停止调度器""" - self._running = False - - # 取消所有任务 - for task in self.scheduler_tasks.values(): - task.cancel() - - await asyncio.gather( - *self.scheduler_tasks.values(), - return_exceptions=True - ) - - self.scheduler_tasks.clear() -``` - -#### 4️⃣ 装饰器工具 - -```python -# utils/decorators.py - -from functools import wraps -from typing import Callable, Optional -import asyncio -import time - -def cached( - key_func: Callable = None, - ttl: float = 60.0, - cache_none: bool = False -): - """缓存装饰器 - - Args: - key_func: 生成缓存键的函数 - ttl: 缓存时间 - cache_none: 是否缓存None值 - - Example: - @cached(key_func=lambda stream_id: f"stream:{stream_id}", ttl=300) - async def get_chat_stream(stream_id: str): - # ... - """ - def decorator(func: Callable): - @wraps(func) - async def wrapper(*args, **kwargs): - from ..optimization.cache_manager import get_cache_manager - - cache = get_cache_manager() - - # 生成缓存键 - if key_func: - cache_key = key_func(*args, **kwargs) - else: - # 默认键:函数名+参数 - cache_key = f"{func.__name__}:{args}:{kwargs}" - - # 尝试从缓存获取 - async def fetch(): - return await func(*args, **kwargs) - - result = await cache.get(cache_key, fetch, ttl_override=ttl) - - # 检查是否缓存None - if result is None and not cache_none: - result = await func(*args, **kwargs) - - return result - - return wrapper - return decorator - - -def batch_write( - model_class, - operation_type: str = "insert", - priority: int = 0 -): - """批量写入装饰器 - - 自动将写入操作加入批量调度器 - - Example: - @batch_write(Messages, operation_type="insert") - async def save_message(data: dict): - return data - """ - def decorator(func: Callable): - @wraps(func) - async def wrapper(*args, **kwargs): - from ..optimization.batch_scheduler import get_batch_scheduler - - # 执行原函数获取数据 - data = await func(*args, **kwargs) - - # 加入批量调度器 - scheduler = get_batch_scheduler() - future = await scheduler.schedule_write( - model_class, - operation_type, - data, - priority - ) - - # 等待完成 - result = await future - return result - - return wrapper - return decorator - - -def retry( - max_attempts: int = 3, - delay: float = 0.5, - backoff: float = 2.0, - exceptions: tuple = (Exception,) -): - """重试装饰器 - - Args: - max_attempts: 最大重试次数 - delay: 初始延迟 - backoff: 延迟倍数 - exceptions: 需要重试的异常类型 - """ - def decorator(func: Callable): - @wraps(func) - async def wrapper(*args, **kwargs): - current_delay = delay - - for attempt in range(max_attempts): - try: - return await func(*args, **kwargs) - except exceptions as e: - if attempt == max_attempts - 1: - raise - - logger.warning( - f"函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {e}," - f"{current_delay}秒后重试" - ) - await asyncio.sleep(current_delay) - current_delay *= backoff - - return wrapper - return decorator - - -def monitor_performance(func: Callable): - """性能监控装饰器""" - @wraps(func) - async def wrapper(*args, **kwargs): - start_time = time.time() - - try: - result = await func(*args, **kwargs) - return result - finally: - elapsed = time.time() - start_time - - # 记录性能数据 - from ..utils.monitoring import record_metric - record_metric( - func.__name__, - "execution_time", - elapsed - ) - - # 慢查询警告 - if elapsed > 1.0: - logger.warning( - f"慢操作检测: {func.__name__} 耗时 {elapsed:.2f}秒" - ) - - return wrapper -``` - -#### 5️⃣ 高频API优化版本 - -```python -# api/optimized_crud.py - -from typing import Optional, List, Dict, Any -from ..utils.decorators import cached, batch_write, monitor_performance -from ..core.models import ChatStreams, Messages, PersonInfo, Emoji - -class OptimizedCRUD: - """优化的CRUD操作 - - 针对高频场景提供优化版本的API - """ - - @staticmethod - @cached( - key_func=lambda stream_id: f"chat_stream:{stream_id}", - ttl=300.0 - ) - @monitor_performance - async def get_chat_stream(stream_id: str) -> Optional[Dict]: - """获取聊天流(高频优化)""" - from .crud import db_get - return await db_get( - ChatStreams, - filters={"stream_id": stream_id}, - single_result=True - ) - - @staticmethod - @cached( - key_func=lambda user_id: f"person_info:{user_id}", - ttl=600.0 # 10分钟 - ) - @monitor_performance - async def get_person_info(user_id: str) -> Optional[Dict]: - """获取用户信息(高频优化)""" - from .crud import db_get - return await db_get( - PersonInfo, - filters={"user_id": user_id}, - single_result=True - ) - - @staticmethod - @cached( - key_func=lambda chat_id, limit: f"messages:{chat_id}:{limit}", - ttl=120.0 # 2分钟 - ) - @monitor_performance - async def get_recent_messages( - chat_id: str, - limit: int = 50 - ) -> List[Dict]: - """获取最近消息(高频优化)""" - from .crud import db_get - return await db_get( - Messages, - filters={"chat_id": chat_id}, - order_by="-time", - limit=limit - ) - - @staticmethod - @batch_write(Messages, operation_type="insert", priority=1) - @monitor_performance - async def save_message(data: Dict) -> Dict: - """保存消息(高频优化,批量写入)""" - return data - - @staticmethod - @cached( - key_func=lambda emoji_hash: f"emoji:{emoji_hash}", - ttl=3600.0 # 1小时 - ) - @monitor_performance - async def get_emoji(emoji_hash: str) -> Optional[Dict]: - """获取表情(高频优化)""" - from .crud import db_get - return await db_get( - Emoji, - filters={"emoji_hash": emoji_hash}, - single_result=True - ) - - @staticmethod - async def update_chat_stream_active_time( - stream_id: str, - active_time: float - ): - """更新聊天流活跃时间(高频优化,异步批量)""" - from ..optimization.batch_scheduler import get_batch_scheduler - from ..optimization.cache_manager import get_cache_manager - - scheduler = get_batch_scheduler() - - # 加入批量更新 - await scheduler.schedule_write( - ChatStreams, - "update", - { - "stream_id": stream_id, - "last_active_time": active_time - }, - priority=0 # 低优先级 - ) - - # 失效缓存 - cache = get_cache_manager() - await cache.invalidate(f"chat_stream:{stream_id}") -``` - ---- - -## 📅 实施计划 - -### 阶段一:准备阶段(1-2天) - -#### 任务清单 -- [x] 完成需求分析和架构设计 -- [ ] 创建新目录结构 -- [ ] 编写测试用例(覆盖所有API) -- [ ] 设置性能基准测试 - -### 阶段二:核心层重构(2-3天) - -#### 任务清单 -- [ ] 创建 `core/engine.py` - 迁移引擎管理逻辑 -- [ ] 创建 `core/session.py` - 迁移会话管理逻辑 -- [ ] 创建 `core/models.py` - 迁移并统一所有模型定义 -- [ ] 更新所有模型到 SQLAlchemy 2.0 类型注解 -- [ ] 创建 `core/migration.py` - 迁移工具 -- [ ] 运行测试,确保核心功能正常 - -### 阶段三:优化层实现(3-4天) - -#### 任务清单 -- [ ] 实现 `optimization/cache_manager.py` - 多级缓存 -- [ ] 实现 `optimization/preloader.py` - 智能预加载 -- [ ] 增强 `optimization/batch_scheduler.py` - 智能批量调度 -- [ ] 实现 `optimization/connection_pool.py` - 优化连接池 -- [ ] 添加性能监控和统计 - -### 阶段四:API层重构(2-3天) - -#### 任务清单 -- [ ] 创建 `api/crud.py` - 重构 CRUD 操作 -- [ ] 创建 `api/optimized_crud.py` - 高频优化API -- [ ] 创建 `api/specialized.py` - 特殊业务操作 -- [ ] 创建 `api/query_builder.py` - 查询构建器 -- [ ] 实现向后兼容的API包装 - -### 阶段五:工具层完善(1-2天) - -#### 任务清单 -- [ ] 创建 `utils/exceptions.py` - 统一异常体系 -- [ ] 创建 `utils/decorators.py` - 装饰器工具 -- [ ] 创建 `utils/monitoring.py` - 性能监控 -- [ ] 添加日志增强 - -### 阶段六:兼容层和迁移(2-3天) - -#### 任务清单 -- [ ] 完善 `__init__.py` - 导出所有API -- [ ] 创建兼容性适配器(如果需要) -- [ ] 逐步迁移现有代码使用新API -- [ ] 添加弃用警告(对于将来要移除的API) - -### 阶段七:测试和优化(2-3天) - -#### 任务清单 -- [ ] 运行完整测试套件 -- [ ] 性能基准测试对比 -- [ ] 压力测试 -- [ ] 修复发现的问题 -- [ ] 性能调优 - -### 阶段八:文档和清理(1-2天) - -#### 任务清单 -- [ ] 编写使用文档 -- [ ] 更新API文档 -- [ ] 删除旧文件(如 .bak) -- [ ] 代码审查 -- [ ] 准备发布 - -### 总时间估计:14-22天 - ---- - -## 🔧 具体实施步骤 - -### 步骤1:创建新目录结构 - -```bash -cd src/common/database - -# 创建新目录 -mkdir -p core api optimization config utils - -# 创建__init__.py -touch core/__init__.py -touch api/__init__.py -touch optimization/__init__.py -touch config/__init__.py -touch utils/__init__.py -``` - -### 步骤2:实现核心层 - -#### core/engine.py -```python -"""数据库引擎管理 -单一职责:创建和管理SQLAlchemy引擎 -""" - -from typing import Optional -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from ..config.database_config import get_database_config -from ..utils.exceptions import DatabaseInitializationError - -_engine: Optional[AsyncEngine] = None -_engine_lock = None - - -async def get_engine() -> AsyncEngine: - """获取全局数据库引擎(单例)""" - global _engine, _engine_lock - - if _engine is not None: - return _engine - - # 延迟导入避免循环依赖 - import asyncio - if _engine_lock is None: - _engine_lock = asyncio.Lock() - - async with _engine_lock: - # 双重检查 - if _engine is not None: - return _engine - - try: - config = get_database_config() - _engine = create_async_engine( - config.url, - **config.engine_kwargs - ) - - # SQLite优化 - if config.db_type == "sqlite": - await _enable_sqlite_optimizations(_engine) - - logger.info(f"数据库引擎初始化成功: {config.db_type}") - return _engine - - except Exception as e: - raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e - - -async def close_engine(): - """关闭数据库引擎""" - global _engine - - if _engine is not None: - await _engine.dispose() - _engine = None - logger.info("数据库引擎已关闭") - - -async def _enable_sqlite_optimizations(engine: AsyncEngine): - """启用SQLite性能优化""" - from sqlalchemy import text - - async with engine.begin() as conn: - await conn.execute(text("PRAGMA journal_mode = WAL")) - await conn.execute(text("PRAGMA synchronous = NORMAL")) - await conn.execute(text("PRAGMA foreign_keys = ON")) - await conn.execute(text("PRAGMA busy_timeout = 60000")) - - logger.info("SQLite性能优化已启用") -``` - -#### core/session.py -```python -"""会话管理 -单一职责:提供数据库会话上下文管理器 -""" - -from contextlib import asynccontextmanager -from typing import AsyncGenerator -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from .engine import get_engine - -_session_factory: Optional[async_sessionmaker] = None - - -async def get_session_factory() -> async_sessionmaker: - """获取会话工厂""" - global _session_factory - - if _session_factory is None: - engine = await get_engine() - _session_factory = async_sessionmaker( - bind=engine, - class_=AsyncSession, - expire_on_commit=False - ) - - return _session_factory - - -@asynccontextmanager -async def get_db_session() -> AsyncGenerator[AsyncSession, None]: - """ - 获取数据库会话上下文管理器 - - 使用连接池优化,透明复用连接 - - Example: - async with get_db_session() as session: - result = await session.execute(select(User)) - """ - from ..optimization.connection_pool import get_connection_pool_manager - - session_factory = await get_session_factory() - pool_manager = get_connection_pool_manager() - - async with pool_manager.get_session(session_factory) as session: - # SQLite特定配置 - from ..config.database_config import get_database_config - config = get_database_config() - - if config.db_type == "sqlite": - from sqlalchemy import text - try: - await session.execute(text("PRAGMA busy_timeout = 60000")) - await session.execute(text("PRAGMA foreign_keys = ON")) - except Exception: - pass # 复用连接时可能已设置 - - yield session -``` - -### 步骤3:完善 `__init__.py` 保持兼容 - -```python -# src/common/database/__init__.py - -""" -数据库模块统一入口 - -导出所有对外API,确保向后兼容 -""" - -# === 核心层导出 === -from .core.engine import get_engine, close_engine -from .core.session import get_db_session -from .core.models import ( - Base, - ChatStreams, - Messages, - ActionRecords, - PersonInfo, - LLMUsage, - Emoji, - Images, - Videos, - OnlineTime, - Memory, - Expression, - ThinkingLog, - GraphNodes, - GraphEdges, - Schedule, - MonthlyPlan, - BanUser, - PermissionNodes, - UserPermissions, - UserRelationships, - ImageDescriptions, - CacheEntries, - MaiZoneScheduleStatus, - AntiInjectionStats, - # ... 所有模型 -) - -# === API层导出 === -from .api.crud import ( - db_query, - db_save, - db_get, -) -from .api.specialized import ( - store_action_info, -) -from .api.optimized_crud import OptimizedCRUD - -# === 优化层导出(可选) === -from .optimization.cache_manager import get_cache_manager -from .optimization.batch_scheduler import get_batch_scheduler -from .optimization.preloader import get_preloader - -# === 旧接口兼容 === -from .database import ( - db, # DatabaseProxy - initialize_sql_database, - stop_database, -) - -# === 模型映射(向后兼容) === -MODEL_MAPPING = { - "Messages": Messages, - "ActionRecords": ActionRecords, - "PersonInfo": PersonInfo, - "ChatStreams": ChatStreams, - "LLMUsage": LLMUsage, - "Emoji": Emoji, - "Images": Images, - "Videos": Videos, - "OnlineTime": OnlineTime, - "Memory": Memory, - "Expression": Expression, - "ThinkingLog": ThinkingLog, - "GraphNodes": GraphNodes, - "GraphEdges": GraphEdges, - "Schedule": Schedule, - "MonthlyPlan": MonthlyPlan, - "UserRelationships": UserRelationships, - # ... 完整映射 -} - -__all__ = [ - # 会话管理 - "get_db_session", - "get_engine", - - # CRUD操作 - "db_query", - "db_save", - "db_get", - "store_action_info", - - # 优化API - "OptimizedCRUD", - - # 模型 - "Base", - "ChatStreams", - "Messages", - # ... 所有模型 - - # 模型映射 - "MODEL_MAPPING", - - # 初始化 - "db", - "initialize_sql_database", - "stop_database", - - # 优化工具 - "get_cache_manager", - "get_batch_scheduler", - "get_preloader", -] -``` - ---- - -## ⚠️ 风险评估与回滚方案 - -### 风险识别 - -| 风险 | 等级 | 影响 | 缓解措施 | -|------|------|------|---------| -| API接口变更 | 高 | 现有代码崩溃 | 完整的兼容层 + 测试覆盖 | -| 性能下降 | 中 | 响应变慢 | 性能基准测试 + 监控 | -| 数据不一致 | 高 | 数据损坏 | 批量操作事务保证 + 备份 | -| 内存泄漏 | 中 | 资源耗尽 | 压力测试 + 监控 | -| 缓存穿透 | 中 | 数据库压力增大 | 布隆过滤器 + 空值缓存 | - -### 回滚方案 - -#### 快速回滚 -```bash -# 如果发现重大问题,立即回滚到旧版本 -git checkout -# 或使用feature分支开发,随时可切换 -git checkout main -``` - -#### 渐进式回滚 -```python -# 在新代码中添加开关 -from src.config.config import global_config - -if global_config.database.use_legacy_mode: - # 使用旧实现 - from .legacy.database import db_query -else: - # 使用新实现 - from .api.crud import db_query -``` - -### 监控指标 - -重构后需要监控的关键指标: -- API响应时间(P50, P95, P99) -- 数据库连接数 -- 缓存命中率 -- 批量操作成功率 -- 错误率和异常 -- 内存使用量 - ---- - -## 📊 预期效果 - -### 性能提升目标 - -| 指标 | 当前 | 目标 | 提升 | -|------|------|------|------| -| 高频读取延迟 | ~50ms | ~10ms | 80% ↓ | -| 缓存命中率 | 0% | 85%+ | ∞ | -| 写入吞吐量 | ~100/s | ~1000/s | 10x ↑ | -| 连接池利用率 | ~60% | >90% | 50% ↑ | -| 数据库连接数 | 动态 | 稳定 | 更稳定 | - -### 代码质量提升 - -- ✅ 减少文件数量和代码行数 -- ✅ 职责更清晰,易于维护 -- ✅ 完整的类型注解 -- ✅ 统一的错误处理 -- ✅ 完善的文档和示例 - ---- - -## ✅ 验收标准 - -### 功能验收 -- [ ] 所有现有测试通过 -- [ ] 所有API接口保持兼容 -- [ ] 无数据丢失或不一致 -- [ ] 无性能回归 - -### 性能验收 -- [ ] 高频读取延迟 < 15ms(P95) -- [ ] 缓存命中率 > 80% -- [ ] 写入吞吐量 > 500/s -- [ ] 连接池利用率 > 85% - -### 代码质量验收 -- [ ] 类型检查无错误 -- [ ] 代码覆盖率 > 80% -- [ ] 无重大代码异味 -- [ ] 文档完整 - ---- - -## 📝 总结 - -本重构方案在保持完全向后兼容的前提下,通过以下措施优化数据库模块: - -1. **架构清晰化** - 分层设计,职责明确 -2. **多级缓存** - L1/L2缓存 + 智能失效 -3. **智能预加载** - 减少冷启动延迟 -4. **批量调度增强** - 自适应批量大小 + 优先级队列 -5. **装饰器工具** - 简化高频操作的优化 -6. **性能监控** - 实时监控和告警 - -预期可实现: -- 高频读取延迟降低 80% -- 写入吞吐量提升 10 倍 -- 连接池利用率提升至 90% 以上 - -风险可控,可随时回滚。 diff --git a/docs/database_refactoring_test_report.md b/docs/database_refactoring_test_report.md deleted file mode 100644 index 7906f93b4..000000000 --- a/docs/database_refactoring_test_report.md +++ /dev/null @@ -1,187 +0,0 @@ -# 数据库重构测试报告 - -**测试时间**: 2025-11-01 13:00 -**测试环境**: Python 3.13.2, pytest 8.4.2 -**测试范围**: 核心层 + 优化层 - -## 📊 测试结果总览 - -**总计**: 21个测试 -**通过**: 19个 ✅ (90.5%) -**失败**: 1个 ❌ (超时) -**跳过**: 1个 ⏭️ - -## ✅ 通过的测试 (19/21) - -### 核心层 (Core Layer) - 4/4 ✅ - -1. **test_engine_singleton** ✅ - - 引擎单例模式正常工作 - - 多次调用返回同一实例 - -2. **test_session_factory** ✅ - - 会话工厂创建会话正常 - - 连接池复用机制工作 - -3. **test_database_migration** ✅ - - 数据库迁移成功 - - 25个表结构全部一致 - - 自动检测和更新功能正常 - -4. **test_model_crud** ✅ - - 模型CRUD操作正常 - - ChatStreams创建、查询、删除成功 - -### 缓存管理器 (Cache Manager) - 5/5 ✅ - -5. **test_cache_basic_operations** ✅ - - set/get/delete基本操作正常 - -6. **test_cache_levels** ✅ - - L1和L2两级缓存同时工作 - - 数据正确写入两级缓存 - -7. **test_cache_expiration** ✅ - - TTL过期机制正常 - - 过期数据自动清理 - -8. **test_cache_lru_eviction** ✅ - - LRU淘汰策略正确 - - 最近使用的数据保留 - -9. **test_cache_stats** ✅ - - 统计信息准确 - - 命中率/未命中率正确记录 - -### 数据预加载器 (Preloader) - 3/3 ✅ - -10. **test_access_pattern_tracking** ✅ - - 访问模式追踪正常 - - 访问次数统计准确 - -11. **test_preload_data** ✅ - - 数据预加载功能正常 - - 预加载的数据正确写入缓存 - -12. **test_related_keys** ✅ - - 关联键识别正确 - - 关联关系记录准确 - -### 批量调度器 (Batch Scheduler) - 4/5 ✅ - -13. **test_scheduler_lifecycle** ✅ - - 启动/停止生命周期正常 - - 状态管理正确 - -14. **test_batch_priority** ✅ - - 优先级队列工作正常 - - LOW/NORMAL/HIGH/URGENT四级优先级 - -15. **test_adaptive_parameters** ✅ - - 自适应参数调整正常 - - 根据拥塞评分动态调整批次大小 - -16. **test_batch_stats** ✅ - - 统计信息准确 - - 拥塞评分、操作数等指标正常 - -17. **test_batch_operations** - 跳过(待优化) - - 批量操作功能基本正常 - - 需要优化等待时间 - -### 集成测试 (Integration) - 1/2 ✅ - -18. **test_cache_and_preloader_integration** ✅ - - 缓存与预加载器协同工作 - - 预加载数据正确进入缓存 - -19. **test_full_stack_query** ❌ 超时 - - 完整查询流程测试超时 - - 需要优化批处理响应时间 - -### 性能测试 (Performance) - 1/2 ✅ - -20. **test_cache_performance** ✅ - - **写入性能**: 196k ops/s (0.51ms/100项) - - **读取性能**: 1.6k ops/s (59.53ms/100项) - - 性能达标,读取可进一步优化 - -21. **test_batch_throughput** - 跳过 - - 需要优化测试用例 - -## 📈 性能指标 - -### 缓存性能 -- **写入吞吐**: 195,996 ops/s -- **读取吞吐**: 1,680 ops/s -- **L1命中率**: >80% (预期) -- **L2命中率**: >60% (预期) - -### 批处理性能 -- **批次大小**: 10-100 (自适应) -- **等待时间**: 50-200ms (自适应) -- **拥塞控制**: 实时调节 - -### 数据库连接 -- **连接池**: 最大10个连接 -- **连接复用**: 正常工作 -- **WAL模式**: SQLite优化启用 - -## 🐛 待解决问题 - -### 1. 批处理超时 (优先级: 中) -- **问题**: `test_full_stack_query` 超时 -- **原因**: 批处理调度器等待时间过长 -- **影响**: 某些场景下响应慢 -- **方案**: 调整等待时间和批次触发条件 - -### 2. 警告信息 (优先级: 低) -- **SQLAlchemy 2.0**: `declarative_base()` 已废弃 - - 建议: 迁移到 `sqlalchemy.orm.declarative_base()` -- **pytest-asyncio**: fixture警告 - - 建议: 使用 `@pytest_asyncio.fixture` - -## ✨ 测试亮点 - -### 1. 核心功能稳定 -- ✅ 引擎单例、会话管理、模型迁移全部正常 -- ✅ 25个数据库表结构完整 - -### 2. 缓存系统高效 -- ✅ L1/L2两级缓存正常工作 -- ✅ LRU淘汰和TTL过期机制正确 -- ✅ 写入性能达到196k ops/s - -### 3. 预加载智能 -- ✅ 访问模式追踪准确 -- ✅ 关联数据识别正常 -- ✅ 与缓存系统集成良好 - -### 4. 批处理自适应 -- ✅ 动态调整批次大小 -- ✅ 优先级队列工作正常 -- ✅ 拥塞控制有效 - -## 📋 下一步建议 - -### 立即行动 (P0) -1. ✅ 核心层和优化层功能完整,可以进入阶段四 -2. ⏭️ 优化批处理超时问题可以并行进行 - -### 短期优化 (P1) -1. 优化批处理调度器的等待策略 -2. 提升缓存读取性能(目前1.6k ops/s) -3. 修复SQLAlchemy 2.0警告 - -### 长期改进 (P2) -1. 增加更多边界情况测试 -2. 添加并发测试和压力测试 -3. 完善性能基准测试 - -## 🎯 结论 - -**重构成功率**: 90.5% ✅ - -核心层和优化层的重构基本完成,功能测试通过率高,性能指标达标。仅有1个超时问题不影响核心功能使用,可以进入下一阶段的API层重构工作。 - -**建议**: 继续推进阶段四(API层重构),同时并行优化批处理性能。 diff --git a/docs/development/json_parser_unification.md b/docs/development/json_parser_unification.md deleted file mode 100644 index 2e38383b6..000000000 --- a/docs/development/json_parser_unification.md +++ /dev/null @@ -1,216 +0,0 @@ -# JSON 解析统一化改进文档 - -## 改进目标 -统一项目中所有 LLM 响应的 JSON 解析逻辑,使用 `json_repair` 库和统一的解析工具,简化代码并提高解析成功率。 - -## 创建的新工具模块 - -### `src/utils/json_parser.py` -提供统一的 JSON 解析功能: - -#### 主要函数: -1. **`extract_and_parse_json(response, strict=False)`** - - 从 LLM 响应中提取并解析 JSON - - 自动处理 Markdown 代码块标记 - - 使用 json_repair 修复格式问题 - - 支持严格模式和容错模式 - -2. **`safe_parse_json(json_str, default=None)`** - - 安全解析 JSON,失败时返回默认值 - -3. **`extract_json_field(response, field_name, default=None)`** - - 从 LLM 响应中提取特定字段的值 - -#### 处理策略: -1. 清理 Markdown 代码块标记(```json 和 ```) -2. 提取 JSON 对象或数组(使用栈匹配算法) -3. 尝试直接解析 -4. 如果失败,使用 json_repair 修复后解析 -5. 容错模式下返回空字典或空列表 - -## 已修改的文件 - -### 1. `src/chat/memory_system/memory_query_planner.py` ✅ -- 移除了自定义的 `_extract_json_payload` 方法 -- 使用 `extract_and_parse_json` 替代原有的解析逻辑 -- 简化了代码,提高了可维护性 - -**修改前:** -```python -payload = self._extract_json_payload(response) -if not payload: - return self._default_plan(query_text) -try: - data = orjson.loads(payload) -except orjson.JSONDecodeError as exc: - ... -``` - -**修改后:** -```python -data = extract_and_parse_json(response, strict=False) -if not data or not isinstance(data, dict): - return self._default_plan(query_text) -``` - -### 2. `src/chat/memory_system/memory_system.py` ✅ -- 移除了自定义的 `_extract_json_payload` 方法 -- 在 `_evaluate_information_value` 方法中使用统一解析工具 -- 简化了错误处理逻辑 - -### 3. `src/chat/interest_system/bot_interest_manager.py` ✅ -- 移除了自定义的 `_clean_llm_response` 方法 -- 使用 `extract_and_parse_json` 解析兴趣标签数据 -- 改进了错误处理和日志输出 - -### 4. `src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py` ✅ -- 将 `_clean_llm_json_response` 标记为已废弃 -- 使用 `extract_and_parse_json` 解析聊天流印象数据 -- 添加了类型检查和错误处理 - -## 待修改的文件 - -### 需要类似修改的其他文件: -1. `src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py` - - 包含自定义的 JSON 清理逻辑 - -2. `src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py` - - 包含自定义的 JSON 清理逻辑 - -3. 其他包含自定义 JSON 解析逻辑的文件 - -## 改进效果 - -### 1. 代码简化 -- 消除了重复的 JSON 提取和清理代码 -- 减少了代码行数和维护成本 -- 统一了错误处理模式 - -### 2. 解析成功率提升 -- 使用 json_repair 自动修复常见的 JSON 格式问题 -- 支持多种 JSON 包装格式(代码块、纯文本等) -- 更好的容错处理 - -### 3. 可维护性提升 -- 集中管理 JSON 解析逻辑 -- 易于添加新的解析策略 -- 便于调试和日志记录 - -### 4. 一致性提升 -- 所有 LLM 响应使用相同的解析流程 -- 统一的日志输出格式 -- 一致的错误处理 - -## 使用示例 - -### 基本用法: -```python -from src.utils.json_parser import extract_and_parse_json - -# LLM 响应可能包含 Markdown 代码块或其他文本 -llm_response = '```json\\n{"key": "value"}\\n```' - -# 自动提取和解析 -data = extract_and_parse_json(llm_response, strict=False) -# 返回: {'key': 'value'} - -# 如果解析失败,返回空字典(非严格模式) -# 严格模式下返回 None -``` - -### 提取特定字段: -```python -from src.utils.json_parser import extract_json_field - -llm_response = '{"score": 0.85, "reason": "Good quality"}' -score = extract_json_field(llm_response, "score", default=0.0) -# 返回: 0.85 -``` - -## 测试建议 - -1. **单元测试**: - - 测试各种 JSON 格式(带/不带代码块标记) - - 测试格式错误的 JSON(验证 json_repair 的修复能力) - - 测试嵌套 JSON 结构 - - 测试空响应和无效响应 - -2. **集成测试**: - - 在实际 LLM 调用场景中测试 - - 验证不同模型的响应格式兼容性 - - 测试错误处理和日志输出 - -3. **性能测试**: - - 测试大型 JSON 的解析性能 - - 验证缓存和优化策略 - -## 迁移指南 - -### 旧代码模式: -```python -# 旧的自定义解析逻辑 -def _extract_json(response: str) -> str | None: - stripped = response.strip() - code_block_match = re.search(r"```(?:json)?\\s*(.*?)```", stripped, re.DOTALL) - if code_block_match: - return code_block_match.group(1) - # ... 更多自定义逻辑 - -# 使用 -payload = self._extract_json(response) -if payload: - data = orjson.loads(payload) -``` - -### 新代码模式: -```python -# 使用统一工具 -from src.utils.json_parser import extract_and_parse_json - -# 直接解析 -data = extract_and_parse_json(response, strict=False) -if data and isinstance(data, dict): - # 使用数据 - pass -``` - -## 注意事项 - -1. **导入语句**:确保添加正确的导入 - ```python - from src.utils.json_parser import extract_and_parse_json - ``` - -2. **错误处理**:统一工具已包含错误处理,无需额外 try-except - ```python - # 不需要 - try: - data = extract_and_parse_json(response) - except Exception: - ... - - # 应该 - data = extract_and_parse_json(response, strict=False) - if not data: - # 处理失败情况 - pass - ``` - -3. **类型检查**:始终验证返回值类型 - ```python - data = extract_and_parse_json(response) - if isinstance(data, dict): - # 处理字典 - elif isinstance(data, list): - # 处理列表 - ``` - -## 后续工作 - -1. 完成剩余文件的迁移 -2. 添加完整的单元测试 -3. 更新相关文档 -4. 考虑添加性能监控和统计 - -## 日期 -2025年11月2日 diff --git a/docs/guides/OBJECT_LEVEL_MEMORY_ANALYSIS.md b/docs/guides/OBJECT_LEVEL_MEMORY_ANALYSIS.md deleted file mode 100644 index c7aad3ffc..000000000 --- a/docs/guides/OBJECT_LEVEL_MEMORY_ANALYSIS.md +++ /dev/null @@ -1,267 +0,0 @@ -# 对象级内存分析指南 - -## 🎯 概述 - -对象级内存分析可以帮助你: -- 查看哪些 Python 对象类型占用最多内存 -- 追踪对象数量和大小的变化 -- 识别内存泄漏的具体对象 -- 监控垃圾回收效率 - -## 🚀 快速开始 - -### 1. 安装依赖 - -```powershell -pip install pympler -``` - -### 2. 启用对象级分析 - -```powershell -# 基本用法 - 启用对象分析 -python scripts/run_bot_with_tracking.py --objects - -# 自定义监控间隔(10 秒) -python scripts/run_bot_with_tracking.py --objects --interval 10 - -# 显示更多对象类型(前 20 个) -python scripts/run_bot_with_tracking.py --objects --object-limit 20 - -# 完整示例(简写参数) -python scripts/run_bot_with_tracking.py -o -i 10 -l 20 -``` - -## 📊 输出示例 - -### 进程级信息 - -``` -================================================================================ -检查点 #1 - 12:34:56 -Bot 进程 (PID: 12345) - RSS: 45.23 MB - VMS: 125.45 MB - 占比: 0.35% - 子进程: 1 个 - 子进程内存: 32.10 MB - 总内存: 77.33 MB - -变化: - RSS: +2.15 MB -``` - -### 对象级分析信息 - -``` -📦 对象级内存分析 (检查点 #1) --------------------------------------------------------------------------------- -类型 数量 总大小 --------------------------------------------------------------------------------- -dict 12,345 15.23 MB -str 45,678 8.92 MB -list 8,901 5.67 MB -tuple 23,456 4.32 MB -type 1,234 3.21 MB -code 2,345 2.10 MB -set 1,567 1.85 MB -function 3,456 1.23 MB -method 4,567 890.45 KB -weakref 2,345 678.12 KB - -🗑️ 垃圾回收统计: - - 代 0 回收: 125 次 - - 代 1 回收: 12 次 - - 代 2 回收: 2 次 - - 未回收对象: 0 - - 追踪对象数: 89,456 - -📊 总对象数: 123,456 --------------------------------------------------------------------------------- -``` - -## 🔍 如何解读输出 - -### 1. 对象类型统计 - -每一行显示: -- **类型名称**: Python 对象类型(dict、str、list 等) -- **数量**: 该类型的对象实例数量 -- **总大小**: 该类型所有对象占用的总内存 - -**关键指标**: -- `dict` 多是正常的(Python 大量使用字典) -- `str` 多也是正常的(字符串无处不在) -- 如果看到某个自定义类型数量异常增长 → 可能存在泄漏 -- 如果某个类型占用内存异常大 → 需要优化 - -### 2. 垃圾回收统计 - -**代 0/1/2 回收次数**: -- 代 0:最频繁,新创建的对象 -- 代 1:中等频率,存活一段时间的对象 -- 代 2:最少,长期存活的对象 - -**未回收对象**: -- 应该是 0 或很小的数字 -- 如果持续增长 → 可能存在循环引用导致的内存泄漏 - -**追踪对象数**: -- Python 垃圾回收器追踪的对象总数 -- 持续增长可能表示内存泄漏 - -### 3. 总对象数 - -当前进程中所有 Python 对象的数量。 - -## 🎯 常见使用场景 - -### 场景 1: 查找内存泄漏 - -```powershell -# 长时间运行,频繁检查 -python scripts/run_bot_with_tracking.py -o -i 5 -``` - -**观察**: -- 哪些对象类型数量持续增长? -- RSS 内存增长和对象数量增长是否一致? -- 垃圾回收是否正常工作? - -### 场景 2: 优化内存占用 - -```powershell -# 较长间隔,查看稳定状态 -python scripts/run_bot_with_tracking.py -o -i 30 -l 25 -``` - -**分析**: -- 前 25 个对象类型中,哪些是你的代码创建的? -- 是否有不必要的大对象缓存? -- 能否使用更轻量的数据结构? - -### 场景 3: 调试特定功能 - -```powershell -# 短间隔,快速反馈 -python scripts/run_bot_with_tracking.py -o -i 3 -``` - -**用途**: -- 触发某个功能后立即观察内存变化 -- 检查对象是否正确释放 -- 验证优化效果 - -## 📝 保存的历史文件 - -监控结束后,历史数据会自动保存到: -``` -data/memory_diagnostics/bot_memory_monitor_YYYYMMDD_HHMMSS_pidXXXXX.txt -``` - -文件内容包括: -- 每个检查点的进程内存信息 -- 每个检查点的对象统计(前 10 个类型) -- 总体统计信息(起始/结束/峰值/平均) - -## 🔧 高级技巧 - -### 1. 结合代码修改 - -在你的代码中添加检查点: - -```python -import gc -from pympler import muppy, summary - -def debug_memory(): - """在关键位置调用此函数""" - gc.collect() - all_objects = muppy.get_objects() - sum_data = summary.summarize(all_objects) - summary.print_(sum_data, limit=10) -``` - -### 2. 比较不同时间点 - -```powershell -# 运行 1 分钟 -python scripts/run_bot_with_tracking.py -o -i 10 -# Ctrl+C 停止,查看文件 - -# 等待 5 分钟后再运行 -python scripts/run_bot_with_tracking.py -o -i 10 -# 比较两次的对象统计 -``` - -### 3. 专注特定对象类型 - -修改 `run_bot_with_tracking.py` 中的 `get_object_stats()` 函数,添加过滤: - -```python -def get_object_stats(limit: int = 10) -> Dict: - # ...现有代码... - - # 只显示特定类型 - filtered_summary = [ - row for row in sum_data - if 'YourClassName' in row[0] - ] - - return { - "summary": filtered_summary[:limit], - # ... - } -``` - -## ⚠️ 注意事项 - -### 性能影响 - -对象级分析会影响性能: -- **pympler 分析**: ~10-20% 性能影响 -- **gc.collect()**: 每次检查点触发垃圾回收,可能导致短暂卡顿 - -**建议**: -- 开发/调试时使用对象分析 -- 生产环境使用普通监控(不加 `--objects`) - -### 内存开销 - -对象分析本身也会占用内存: -- `muppy.get_objects()` 会创建对象列表 -- 统计数据会保存在历史中 - -**建议**: -- 不要设置过小的 `--interval`(建议 >= 5 秒) -- 长时间运行时考虑关闭对象分析 - -### 准确性 - -- 对象统计是**快照**,不是实时的 -- `gc.collect()` 后才统计,确保垃圾已回收 -- 子进程的对象无法统计(只统计主进程) - -## 📚 相关工具 - -| 工具 | 用途 | 对象级分析 | -|------|------|----------| -| `run_bot_with_tracking.py` | 一键启动+监控 | ✅ 支持 | -| `memory_monitor.py` | 手动监控 | ✅ 支持 | -| `windows_memory_profiler.py` | 详细分析 | ✅ 支持 | -| `run_bot_with_pympler.py` | 专门的对象追踪 | ✅ 专注此功能 | - -## 🎓 学习资源 - -- [Pympler 文档](https://pympler.readthedocs.io/) -- [Python GC 模块](https://docs.python.org/3/library/gc.html) -- [内存泄漏调试技巧](https://docs.python.org/3/library/tracemalloc.html) - ---- - -**快速开始**: -```powershell -pip install pympler -python scripts/run_bot_with_tracking.py --objects -``` -🎉 diff --git a/docs/guides/memory_deduplication_guide.md b/docs/guides/memory_deduplication_guide.md deleted file mode 100644 index 77d346a0c..000000000 --- a/docs/guides/memory_deduplication_guide.md +++ /dev/null @@ -1,391 +0,0 @@ -# 记忆去重工具使用指南 - -## 📋 功能说明 - -`deduplicate_memories.py` 是一个用于清理重复记忆的工具。它会: - -1. 扫描所有标记为"相似"关系的记忆对 -2. 根据重要性、激活度和创建时间决定保留哪个 -3. 删除重复的记忆,保留最有价值的那个 -4. 提供详细的去重报告 - -## 🚀 快速开始 - -### 步骤1: 预览模式(推荐) - -**首次使用前,建议先运行预览模式,查看会删除哪些记忆:** - -```bash -python scripts/deduplicate_memories.py --dry-run -``` - -输出示例: -``` -============================================================ -记忆去重工具 -============================================================ -数据目录: data/memory_graph -相似度阈值: 0.85 -模式: 预览模式(不实际删除) -============================================================ - -✅ 记忆管理器初始化成功,共 156 条记忆 -找到 23 对相似记忆(阈值>=0.85) - -[预览] 去重相似记忆对 (相似度=0.904): - 保留: mem_20251106_202832_887727 - - 主题: 今天天气很好 - - 重要性: 0.60 - - 激活度: 0.55 - - 创建时间: 2024-11-06 20:28:32 - 删除: mem_20251106_202828_883440 - - 主题: 今天天气晴朗 - - 重要性: 0.50 - - 激活度: 0.50 - - 创建时间: 2024-11-06 20:28:28 - [预览模式] 不执行实际删除 - -============================================================ -去重报告 -============================================================ -总记忆数: 156 -相似记忆对: 23 -发现重复: 23 -预览通过: 23 -错误数: 0 -耗时: 2.35秒 - -⚠️ 这是预览模式,未实际删除任何记忆 -💡 要执行实际删除,请运行: python scripts/deduplicate_memories.py -============================================================ -``` - -### 步骤2: 执行去重 - -**确认预览结果无误后,执行实际去重:** - -```bash -python scripts/deduplicate_memories.py -``` - -输出示例: -``` -============================================================ -记忆去重工具 -============================================================ -数据目录: data/memory_graph -相似度阈值: 0.85 -模式: 执行模式(会实际删除) -============================================================ - -✅ 记忆管理器初始化成功,共 156 条记忆 -找到 23 对相似记忆(阈值>=0.85) - -[执行] 去重相似记忆对 (相似度=0.904): - 保留: mem_20251106_202832_887727 - ... - 删除: mem_20251106_202828_883440 - ... - ✅ 删除成功 - -正在保存数据... -✅ 数据已保存 - -============================================================ -去重报告 -============================================================ -总记忆数: 156 -相似记忆对: 23 -成功删除: 23 -错误数: 0 -耗时: 5.67秒 - -✅ 去重完成! -📊 最终记忆数: 133 (减少 23 条) -============================================================ -``` - -## 🎛️ 命令行参数 - -### `--dry-run`(推荐先使用) - -预览模式,不实际删除任何记忆。 - -```bash -python scripts/deduplicate_memories.py --dry-run -``` - -### `--threshold <相似度>` - -指定相似度阈值,只处理相似度大于等于此值的记忆对。 - -```bash -# 只处理高度相似(>=0.95)的记忆 -python scripts/deduplicate_memories.py --threshold 0.95 - -# 处理中等相似(>=0.8)的记忆 -python scripts/deduplicate_memories.py --threshold 0.8 -``` - -**阈值建议**: -- `0.95-1.0`: 极高相似度,几乎完全相同(最安全) -- `0.9-0.95`: 高度相似,内容基本一致(推荐) -- `0.85-0.9`: 中等相似,可能有细微差别(谨慎使用) -- `<0.85`: 低相似度,可能误删(不推荐) - -### `--data-dir <目录>` - -指定记忆数据目录。 - -```bash -# 对测试数据去重 -python scripts/deduplicate_memories.py --data-dir data/test_memory - -# 对备份数据去重 -python scripts/deduplicate_memories.py --data-dir data/memory_backup -``` - -## 📖 使用场景 - -### 场景1: 定期维护 - -**建议频率**: 每周或每月运行一次 - -```bash -# 1. 先预览 -python scripts/deduplicate_memories.py --dry-run --threshold 0.92 - -# 2. 确认后执行 -python scripts/deduplicate_memories.py --threshold 0.92 -``` - -### 场景2: 清理大量重复 - -**适用于**: 导入外部数据后,或发现大量重复记忆 - -```bash -# 使用较低阈值,清理更多重复 -python scripts/deduplicate_memories.py --threshold 0.85 -``` - -### 场景3: 保守清理 - -**适用于**: 担心误删,只想删除极度相似的记忆 - -```bash -# 使用高阈值,只删除几乎完全相同的记忆 -python scripts/deduplicate_memories.py --threshold 0.98 -``` - -### 场景4: 测试环境 - -**适用于**: 在测试数据上验证效果 - -```bash -# 对测试数据执行去重 -python scripts/deduplicate_memories.py --data-dir data/test_memory --dry-run -``` - -## 🔍 去重策略 - -### 保留原则(按优先级) - -脚本会按以下优先级决定保留哪个记忆: - -1. **重要性更高** (`importance` 值更大) -2. **激活度更高** (`activation` 值更大) -3. **创建时间更早** (更早创建的记忆) - -### 增强保留记忆 - -保留的记忆会获得以下增强: - -- **重要性** +0.05(最高1.0) -- **激活度** +0.05(最高1.0) -- **访问次数** 累加被删除记忆的访问次数 - -### 示例 - -``` -记忆A: 重要性0.8, 激活度0.6, 创建于 2024-11-01 -记忆B: 重要性0.7, 激活度0.9, 创建于 2024-11-05 - -结果: 保留记忆A(重要性更高) -增强: 重要性 0.8 → 0.85, 激活度 0.6 → 0.65 -``` - -## ⚠️ 注意事项 - -### 1. 备份数据 - -**在执行实际去重前,建议备份数据:** - -```bash -# Windows -xcopy data\memory_graph data\memory_graph_backup /E /I /Y - -# Linux/Mac -cp -r data/memory_graph data/memory_graph_backup -``` - -### 2. 先预览再执行 - -**务必先运行 `--dry-run` 预览:** - -```bash -# 错误示范 ❌ -python scripts/deduplicate_memories.py # 直接执行 - -# 正确示范 ✅ -python scripts/deduplicate_memories.py --dry-run # 先预览 -python scripts/deduplicate_memories.py # 再执行 -``` - -### 3. 阈值选择 - -**过低的阈值可能导致误删:** - -```bash -# 风险较高 ⚠️ -python scripts/deduplicate_memories.py --threshold 0.7 - -# 推荐范围 ✅ -python scripts/deduplicate_memories.py --threshold 0.92 -``` - -### 4. 不可恢复 - -**删除的记忆无法恢复!** 如果不确定,请: - -1. 先备份数据 -2. 使用 `--dry-run` 预览 -3. 使用较高的阈值(如 0.95) - -### 5. 中断恢复 - -如果执行过程中中断(Ctrl+C),已删除的记忆无法恢复。建议: - -- 在低负载时段运行 -- 确保足够的执行时间 -- 使用 `--threshold` 限制处理数量 - -## 🐛 故障排查 - -### 问题1: 找不到相似记忆对 - -``` -找到 0 对相似记忆(阈值>=0.85) -``` - -**原因**: -- 没有标记为"相似"的边 -- 阈值设置过高 - -**解决**: -1. 降低阈值:`--threshold 0.7` -2. 检查记忆系统是否正确创建了相似关系 -3. 先运行自动关联任务 - -### 问题2: 初始化失败 - -``` -❌ 记忆管理器初始化失败 -``` - -**原因**: -- 数据目录不存在 -- 配置文件错误 -- 数据文件损坏 - -**解决**: -1. 检查数据目录是否存在 -2. 验证配置文件:`config/bot_config.toml` -3. 查看详细日志定位问题 - -### 问题3: 删除失败 - -``` -❌ 删除失败: ... -``` - -**原因**: -- 权限不足 -- 数据库锁定 -- 文件损坏 - -**解决**: -1. 检查文件权限 -2. 确保没有其他进程占用数据 -3. 恢复备份后重试 - -## 📊 性能参考 - -| 记忆数量 | 相似对数 | 执行时间(预览) | 执行时间(实际) | -|---------|---------|----------------|----------------| -| 100 | 10 | ~1秒 | ~2秒 | -| 500 | 50 | ~3秒 | ~6秒 | -| 1000 | 100 | ~5秒 | ~12秒 | -| 5000 | 500 | ~15秒 | ~45秒 | - -**注**: 实际时间取决于服务器性能和数据复杂度 - -## 🔗 相关工具 - -- **记忆整理**: `src/memory_graph/manager.py::consolidate_memories()` -- **自动关联**: `src/memory_graph/manager.py::auto_link_memories()` -- **配置验证**: `scripts/verify_config_update.py` - -## 💡 最佳实践 - -### 1. 定期维护流程 - -```bash -# 每周执行 -cd /path/to/bot - -# 1. 备份 -cp -r data/memory_graph data/memory_graph_backup_$(date +%Y%m%d) - -# 2. 预览 -python scripts/deduplicate_memories.py --dry-run --threshold 0.92 - -# 3. 执行 -python scripts/deduplicate_memories.py --threshold 0.92 - -# 4. 验证 -python scripts/verify_config_update.py -``` - -### 2. 保守去重策略 - -```bash -# 只删除极度相似的记忆 -python scripts/deduplicate_memories.py --dry-run --threshold 0.98 -python scripts/deduplicate_memories.py --threshold 0.98 -``` - -### 3. 批量清理策略 - -```bash -# 先清理高相似度的 -python scripts/deduplicate_memories.py --threshold 0.95 - -# 再清理中相似度的(可选) -python scripts/deduplicate_memories.py --dry-run --threshold 0.9 -python scripts/deduplicate_memories.py --threshold 0.9 -``` - -## 📝 总结 - -- ✅ **务必先备份数据** -- ✅ **务必先运行 `--dry-run`** -- ✅ **建议使用阈值 >= 0.92** -- ✅ **定期运行,保持记忆库清洁** -- ❌ **避免过低阈值(< 0.85)** -- ❌ **避免跳过预览直接执行** - ---- - -**创建日期**: 2024-11-06 -**版本**: v1.0 -**维护者**: MoFox-Bot Team diff --git a/docs/memory_graph/OPTIMIZATION_ARCHITECTURE_VISUAL.md b/docs/memory_graph/OPTIMIZATION_ARCHITECTURE_VISUAL.md deleted file mode 100644 index 31d0e5f1c..000000000 --- a/docs/memory_graph/OPTIMIZATION_ARCHITECTURE_VISUAL.md +++ /dev/null @@ -1,451 +0,0 @@ -# 优化架构可视化 - -## 📐 优化前后架构对比 - -### ❌ 优化前:线性+串行架构 - -``` - 搜索记忆请求 - | - v - ┌─────────────┐ - │ 生成查询向量 │ - └──────┬──────┘ - | - v - ┌─────────────────────────────┐ - │ for each memory in list: │ - │ - 线性扫描 O(n) │ - │ - 计算相似度 await │ - │ - 串行等待 1500ms │ - │ - 每次都重复计算! │ - └──────┬──────────────────────┘ - | - v - ┌──────────────┐ - │ 排序结果 │ - │ Top-K 返回 │ - └──────────────┘ - -查询记忆流程: - ID 查找 → for 循环遍历 O(n) → 30 次比较 - -性能问题: - - ❌ 串行计算: 等待太久 - - ❌ 重复计算: 缓存为空 - - ❌ 线性查找: 列表遍历太多 -``` - ---- - -### ✅ 优化后:哈希+并发+缓存架构 - -``` - 搜索记忆请求 - | - v - ┌─────────────┐ - │ 生成查询向量 │ - └──────┬──────┘ - | - v - ┌──────────────────────┐ - │ 检查缓存存在? │ - │ cache[query_id]? │ - └────────┬────────┬───┘ - 命中 YES | | NO (首次查询) - | v v - ┌────┴──────┐ ┌────────────────────────┐ - │ 直接返回 │ │ 创建并发任务列表 │ - │ 缓存结果 │ │ │ - │ < 1ms ⚡ │ │ tasks = [ │ - └──────┬────┘ │ sim_async(...), │ - | │ sim_async(...), │ - | │ ... (30 个任务) │ - | │ ] │ - | └────────┬───────────────┘ - | | - | v - | ┌────────────────────────┐ - | │ 并发执行所有任务 │ - | │ await asyncio.gather() │ - | │ │ - | │ 任务1 ─┐ │ - | │ 任务2 ─┼─ 并发执行 │ - | │ 任务3 ─┤ 只需 50ms │ - | │ ... │ │ - | │ 任务30 ┘ │ - | └────────┬───────────────┘ - | | - | v - | ┌────────────────────────┐ - | │ 存储到缓存 │ - | │ cache[query_id] = ... │ - | │ (下次查询直接用) │ - | └────────┬───────────────┘ - | | - └──────────┬──────┘ - | - v - ┌──────────────┐ - │ 排序结果 │ - │ Top-K 返回 │ - └──────────────┘ - -ID 查找流程: - _memory_id_index.get(id) → O(1) 直接返回 - -性能优化: - - ✅ 并发计算: asyncio.gather() 并行 - - ✅ 智能缓存: 缓存命中 < 1ms - - ✅ 哈希查找: O(1) 恒定时间 -``` - ---- - -## 🏗️ 数据结构演进 - -### ❌ 优化前:单一列表 - -``` -ShortTermMemoryManager -├── memories: List[ShortTermMemory] -│ ├── Memory#1 {id: "stm_123", content: "...", ...} -│ ├── Memory#2 {id: "stm_456", content: "...", ...} -│ ├── Memory#3 {id: "stm_789", content: "...", ...} -│ └── ... (30 个记忆) -│ -└── 查找: 线性扫描 - for mem in memories: - if mem.id == "stm_456": - return mem ← O(n) 最坏 30 次比较 - -缺点: - - 查找慢: O(n) - - 删除慢: O(n²) - - 无缓存: 重复计算 -``` - ---- - -### ✅ 优化后:多层索引+缓存 - -``` -ShortTermMemoryManager -├── memories: List[ShortTermMemory] 主存储 -│ ├── Memory#1 -│ ├── Memory#2 -│ ├── Memory#3 -│ └── ... -│ -├── _memory_id_index: Dict[str, Memory] 哈希索引 -│ ├── "stm_123" → Memory#1 ⭐ O(1) -│ ├── "stm_456" → Memory#2 ⭐ O(1) -│ ├── "stm_789" → Memory#3 ⭐ O(1) -│ └── ... -│ -└── _similarity_cache: Dict[str, Dict] 相似度缓存 - ├── "query_1" → { - │ ├── "mem_id_1": 0.85 - │ ├── "mem_id_2": 0.72 - │ └── ... - │ } ⭐ O(1) 命中 < 1ms - │ - ├── "query_2" → {...} - │ - └── ... - -优化: - - 查找快: O(1) 恒定 - - 删除快: O(n) 一次遍历 - - 有缓存: 复用计算结果 - - 同步安全: 三个结构保持一致 -``` - ---- - -## 🔄 操作流程演进 - -### 内存添加流程 - -``` -优化前: -添加记忆 → 追加到列表 → 完成 - ├─ self.memories.append(mem) - └─ (不更新索引!) - -问题: 后续查找需要 O(n) 扫描 - -优化后: -添加记忆 → 追加到列表 → 同步索引 → 完成 - ├─ self.memories.append(mem) - ├─ self._memory_id_index[mem.id] = mem ⭐ - └─ 后续查找 O(1) 完成! -``` - ---- - -### 记忆删除流程 - -``` -优化前 (O(n²)): -───────────────────── -to_remove = [mem1, mem2, mem3] - -for mem in to_remove: - self.memories.remove(mem) ← O(n) 每次都要搜索 - # 第一次: 30 次比较 - # 第二次: 29 次比较 - # 第三次: 28 次比较 - # 总计: 87 次 😭 - -优化后 (O(n)): -───────────────────── -remove_ids = {"id1", "id2", "id3"} - -# 一次遍历 -self.memories = [m for m in self.memories - if m.id not in remove_ids] - -# 同步清理索引 -for mem_id in remove_ids: - del self._memory_id_index[mem_id] - self._similarity_cache.pop(mem_id, None) - -总计: 3 次遍历 O(n) ✅ 快 87/30 = 3 倍! -``` - ---- - -### 相似度计算流程 - -``` -优化前 (串行): -───────────────────────────────────────── -embedding = generate_embedding(query) - -results = [] -for mem in memories: ← 30 次迭代 - sim = await cosine_similarity_async(embedding, mem.embedding) - # 第 1 次: 等待 50ms ⏳ - # 第 2 次: 等待 50ms ⏳ - # ... - # 第 30 次: 等待 50ms ⏳ - # 总计: 1500ms 😭 - -时间线: - 0ms 50ms 100ms ... 1500ms - |──T1─|──T2─|──T3─| ... |──T30─| - 串行执行,一个一个等待 - - -优化后 (并发): -───────────────────────────────────────── -embedding = generate_embedding(query) - -# 创建任务列表 -tasks = [ - cosine_similarity_async(embedding, m.embedding) for m in memories -] - -# 并发执行 -results = await asyncio.gather(*tasks) -# 第 1 次: 启动任务 (不等待) -# 第 2 次: 启动任务 (不等待) -# ... -# 第 30 次: 启动任务 (不等待) -# 等待所有: 等待 50ms ✅ - -时间线: - 0ms 50ms - |─T1─T2─T3─...─T30─────────| - 并发启动,同时等待 - - -缓存优化: -───────────────────────────────────────── -首次查询: 50ms (并发计算) -第二次查询 (相同): < 1ms (缓存命中) ✅ - -多次相同查询: - 1500ms (串行) → 50ms + <1ms + <1ms + ... = ~50ms - 性能提升: 30 倍! 🚀 -``` - ---- - -## 💾 内存状态演变 - -### 单个记忆的生命周期 - -``` -创建阶段: -───────────────── -memory = ShortTermMemory(id="stm_123", ...) - -执行决策: -───────────────── -if decision == CREATE_NEW: - ✅ self.memories.append(memory) - ✅ self._memory_id_index["stm_123"] = memory ⭐ - -if decision == MERGE: - target = self._find_memory_by_id(id) ← O(1) 快速找到 - target.content = ... ✅ 修改内容 - ✅ self._similarity_cache.pop(target.id, None) ⭐ 清除缓存 - - -使用阶段: -───────────────── -search_memories("query") - → 缓存命中? - → 是: 使用缓存结果 < 1ms - → 否: 计算相似度, 存储到缓存 - - -转移/删除阶段: -───────────────── -if importance >= threshold: - return memory ← 转移到长期记忆 -else: - ✅ 从列表移除 - ✅ del index["stm_123"] ⭐ - ✅ cache.pop("stm_123", None) ⭐ -``` - ---- - -## 🧵 并发执行时间线 - -### 搜索 30 个记忆的时间对比 - -#### ❌ 优化前:串行等待 - -``` -时间 → -0ms │ 查询编码 -50ms │ 等待mem1计算 -100ms│ 等待mem2计算 -150ms│ 等待mem3计算 -... -1500ms│ 等待mem30计算 ← 完成! (总耗时 1500ms) - -任务执行: - [mem1] ─────────────→ - [mem2] ─────────────→ - [mem3] ─────────────→ - ... - [mem30] ─────────────→ - -资源利用: ❌ CPU 大部分时间空闲,等待 I/O -``` - ---- - -#### ✅ 优化后:并发执行 - -``` -时间 → -0ms │ 查询编码 -5ms │ 启动所有任务 (mem1~mem30) -50ms │ 所有任务完成! ← 完成 (总耗时 50ms, 提升 30 倍!) - -任务执行: - [mem1] ───────────→ - [mem2] ───────────→ - [mem3] ───────────→ - ... - [mem30] ───────────→ - 并行执行, 同时完成 - -资源利用: ✅ CPU 和网络充分利用, 高效并发 -``` - ---- - -## 📈 性能增长曲线 - -### 随着记忆数量增加的性能对比 - -``` -耗时 -(ms) - | - | ❌ 优化前 (线性增长) - | / - |/ -2000├─── ╱ - │ ╱ -1500├──╱ - │ ╱ -1000├╱ - │ - 500│ ✅ 优化后 (常数时间) - │ ────────────── - 100│ - │ - 0└───────────────────────────────── - 0 10 20 30 40 50 - 记忆数量 - -优化前: 串行计算 - y = n × 50ms (n = 记忆数) - 30 条: 1500ms - 60 条: 3000ms - 100 条: 5000ms - -优化后: 并发计算 - y = 50ms (恒定) - 无论 30 条还是 100 条都是 50ms! - -缓存命中时: - y = 1ms (超低) -``` - ---- - -## 🎯 关键优化点速览表 - -``` -┌──────────────────────────────────────────────────────┐ -│ │ -│ 优化 1: 哈希索引 ├─ O(n) → O(1) │ -│ ─────────────────────────────────┤ 查找加速 30 倍 │ -│ _memory_id_index[id] = memory │ 应用: 全局 │ -│ │ │ -│ 优化 2: 相似度缓存 ├─ 无 → LRU │ -│ ─────────────────────────────────┤ 热查询 5-10x │ -│ _similarity_cache[query] = {...} │ 应用: 频繁查询│ -│ │ │ -│ 优化 3: 并发计算 ├─ 串行 → 并发 │ -│ ─────────────────────────────────┤ 搜索加速 30 倍 │ -│ await asyncio.gather(*tasks) │ 应用: I/O密集 │ -│ │ │ -│ 优化 4: 单次遍历 ├─ 多次 → 单次 │ -│ ─────────────────────────────────┤ 管理加速 2-3x │ -│ for mem in memories: 分类 │ 应用: 容量管理│ -│ │ │ -│ 优化 5: 批量删除 ├─ O(n²) → O(n)│ -│ ─────────────────────────────────┤ 清理加速 n 倍 │ -│ [m for m if id not in remove_ids] │ 应用: 批量操作│ -│ │ │ -│ 优化 6: 索引同步 ├─ 无 → 完整 │ -│ ─────────────────────────────────┤ 数据一致性保证│ -│ 所有修改都同步三个数据结构 │ 应用: 数据完整│ -│ │ │ -└──────────────────────────────────────────────────────┘ - -总体效果: - ⚡ 平均性能提升: 10-15 倍 - 🚀 最大提升场景: 37.5 倍 (多次搜索) - 💾 额外内存: < 1% - ✅ 向后兼容: 100% -``` - ---- - ---- - -**最后更新**: 2025-12-13 -**可视化版本**: v1.0 -**类型**: 架构图表 diff --git a/docs/memory_graph/OPTIMIZATION_COMPLETION_REPORT.md b/docs/memory_graph/OPTIMIZATION_COMPLETION_REPORT.md deleted file mode 100644 index 8c6cbb973..000000000 --- a/docs/memory_graph/OPTIMIZATION_COMPLETION_REPORT.md +++ /dev/null @@ -1,345 +0,0 @@ -# 🎯 MoFox-Core 统一记忆管理器优化完成报告 - -## 📋 执行概览 - -**优化目标**: 提升 `src/memory_graph/unified_manager.py` 运行速度 - -**执行状态**: ✅ **已完成** - -**关键数据**: -- 优化项数: **8 项** -- 代码改进: **735 行文件** -- 性能提升: **25-40%** (典型场景) / **5-50x** (批量操作) -- 兼容性: **100% 向后兼容** - ---- - -## 🚀 优化成果详表 - -### 优化项列表 - -| 序号 | 优化项 | 方法名 | 优化内容 | 预期提升 | 状态 | -|------|--------|--------|----------|----------|------| -| 1 | **任务创建消除** | `search_memories()` | 消除不必要的 Task 对象创建 | 2-3% | ✅ | -| 2 | **查询去重单遍** | `_build_manual_multi_queries()` | 从两次扫描优化为一次 | 5-15% | ✅ | -| 3 | **多态支持** | `_deduplicate_memories()` | 支持 dict 和 object 去重 | 1-3% | ✅ | -| 4 | **查表法优化** | `_calculate_auto_sleep_interval()` | 链式判断 → 查表法 | 1-2% | ✅ | -| 5 | **块转移并行化** ⭐⭐⭐ | `_transfer_blocks_to_short_term()` | 串行 → 并行处理块 | **5-50x** | ✅ | -| 6 | **缓存批量构建** | `_auto_transfer_loop()` | 逐条 append → 批量 extend | 2-4% | ✅ | -| 7 | **直接转移列表** | `_auto_transfer_loop()` | 避免不必要的 list() 复制 | 1-2% | ✅ | -| 8 | **上下文延迟创建** | `_retrieve_long_term_memories()` | 条件化创建 dict | <1% | ✅ | - ---- - -## 📊 性能基准测试结果 - -### 关键性能指标 - -#### 块转移并行化 (最重要) -``` -块数 串行耗时 并行耗时 加速比 -─────────────────────────────────── -1 14.11ms 15.49ms 0.91x -5 77.28ms 15.49ms 4.99x ⚡ -10 155.50ms 15.66ms 9.93x ⚡⚡ -20 311.02ms 15.53ms 20.03x ⚡⚡⚡ -``` - -**关键发现**: 块数≥5时,并行处理的优势明显,10+ 块时加速比超过 10x - -#### 查询去重优化 -``` -场景 旧算法 新算法 改善 -────────────────────────────────────── -小查询 (2项) 2.90μs 0.79μs 72.7% ↓ -中查询 (50项) 3.46μs 3.19μs 8.1% ↓ -``` - -**发现**: 小规模查询优化最显著,大规模时优势减弱(Python 对象开销) - ---- - -## 💡 关键优化详解 - -### 1️⃣ 块转移并行化(核心优化) - -**问题**: 块转移采用串行循环,N 个块需要 N×T 时间 - -```python -# ❌ 原代码 (串行,性能瓶颈) -for block in blocks: - stm = await self.short_term_manager.add_from_block(block) - await self.perceptual_manager.remove_block(block.id) - self._trigger_transfer_wakeup() # 每个块都触发 - # → 总耗时: 50个块 = 750ms -``` - -**优化**: 使用 `asyncio.gather()` 并行处理所有块 - -```python -# ✅ 优化后 (并行,高效) -async def _transfer_single(block: MemoryBlock) -> tuple[MemoryBlock, bool]: - stm = await self.short_term_manager.add_from_block(block) - await self.perceptual_manager.remove_block(block.id) - return block, True - -results = await asyncio.gather(*[_transfer_single(block) for block in blocks]) -# → 总耗时: 50个块 ≈ 15ms (I/O 并行) -``` - -**收益**: -- **5 块**: 5x 加速 -- **10 块**: 10x 加速 -- **20+ 块**: 20x+ 加速 - ---- - -### 2️⃣ 查询去重单遍扫描 - -**问题**: 先构建去重列表,再遍历添加权重,共两次扫描 - -```python -# ❌ 原代码 (O(2n)) -deduplicated = [] -for raw in queries: # 第一次扫描 - text = (raw or "").strip() - if not text or text in seen: - continue - deduplicated.append(text) - -for idx, text in enumerate(deduplicated): # 第二次扫描 - weight = max(0.3, 1.0 - idx * decay) - manual_queries.append({"text": text, "weight": round(weight, 2)}) -``` - -**优化**: 合并为单遍扫描 - -```python -# ✅ 优化后 (O(n)) -manual_queries = [] -for raw in queries: # 单次扫描 - text = (raw or "").strip() - if text and text not in seen: - seen.add(text) - weight = max(0.3, 1.0 - len(manual_queries) * decay) - manual_queries.append({"text": text, "weight": round(weight, 2)}) -``` - -**收益**: 50% 扫描时间节省,特别是大查询列表 - ---- - -### 3️⃣ 多态支持 (dict 和 object) - -**问题**: 仅支持对象类型,字典对象去重失败 - -```python -# ❌ 原代码 (仅对象) -mem_id = getattr(mem, "id", None) # 字典会返回 None -``` - -**优化**: 支持两种访问方式 - -```python -# ✅ 优化后 (对象 + 字典) -if isinstance(mem, dict): - mem_id = mem.get("id") -else: - mem_id = getattr(mem, "id", None) -``` - -**收益**: 数据源兼容性提升,支持混合格式数据 - ---- - -## 📈 性能提升预测 - -### 典型场景的综合提升 - -``` -场景 A: 日常消息处理 (每秒 1-5 条) -├─ search_memories() 并行: +3% -├─ 查询去重: +8% -└─ 总体: +10-15% ⬆️ - -场景 B: 高负载批量转移 (30+ 块) -├─ 块转移并行化: +10-50x ⬆️⬆️⬆️ -└─ 总体: +10-50x ⬆️⬆️⬆️ (显著!) - -场景 C: 混合工作 (消息 + 转移) -├─ 消息处理: +5% -├─ 内存管理: +30% -└─ 总体: +25-40% ⬆️⬆️ -``` - ---- - -## 📁 生成的文档和工具 - -### 1. 详细优化报告 -📄 **[OPTIMIZATION_REPORT_UNIFIED_MANAGER.md](docs/OPTIMIZATION_REPORT_UNIFIED_MANAGER.md)** -- 8 项优化的完整技术说明 -- 性能数据和基准数据 -- 风险评估和测试建议 - -### 2. 可视化指南 -📊 **[OPTIMIZATION_VISUAL_GUIDE.md](OPTIMIZATION_VISUAL_GUIDE.md)** -- 性能对比可视化 -- 算法演进图解 -- 时间轴和场景分析 - -### 3. 性能基准工具 -🧪 **[scripts/benchmark_unified_manager.py](scripts/benchmark_unified_manager.py)** -- 可重复运行的基准测试 -- 3 个核心优化的性能验证 -- 多个测试场景 - -### 4. 本优化总结 -📋 **[OPTIMIZATION_SUMMARY.md](OPTIMIZATION_SUMMARY.md)** -- 快速参考指南 -- 成果总结和验证清单 - ---- - -## ✅ 质量保证 - -### 代码质量 -- ✅ **语法检查通过** - Python 编译检查 -- ✅ **类型兼容** - 支持 dict 和 object -- ✅ **异常处理** - 完善的错误处理 - -### 兼容性 -- ✅ **100% 向后兼容** - API 签名不变 -- ✅ **无破坏性变更** - 仅内部实现优化 -- ✅ **透明优化** - 调用方无感知 - -### 性能验证 -- ✅ **基准测试完成** - 关键优化已验证 -- ✅ **性能数据真实** - 基于实际测试 -- ✅ **可重复测试** - 提供基准工具 - ---- - -## 🎯 使用说明 - -### 立即生效 -优化已自动应用,无需额外配置: -```python -from src.memory_graph.unified_manager import UnifiedMemoryManager - -manager = UnifiedMemoryManager() -await manager.initialize() - -# 所有操作已自动获得优化效果 -await manager.search_memories("query") -``` - -### 性能监控 -```python -# 获取统计信息 -stats = manager.get_statistics() -print(f"系统总记忆数: {stats['total_system_memories']}") -``` - -### 运行基准测试 -```bash -python scripts/benchmark_unified_manager.py -``` - ---- - -## 🔮 后续优化空间 - -### 第一梯队 (可立即实施) -- [ ] **Embedding 缓存** - 为高频查询缓存 embedding,预期 20-30% 提升 -- [ ] **批量查询并行化** - 多查询并行检索,预期 5-10% 提升 -- [ ] **内存池管理** - 减少对象创建/销毁,预期 5-8% 提升 - -### 第二梯队 (需要架构调整) -- [ ] **数据库连接池** - 优化 I/O,预期 10-15% 提升 -- [ ] **查询结果缓存** - 热点缓存,预期 15-20% 提升 - -### 第三梯队 (算法创新) -- [ ] **BloomFilter 去重** - O(1) 去重检查 -- [ ] **缓存预热策略** - 减少冷启动延迟 - ---- - -## 📊 优化效果总结表 - -| 维度 | 原状态 | 优化后 | 改善 | -|------|--------|--------|------| -| **块转移** (20块) | 311ms | 16ms | **19x** | -| **块转移** (5块) | 77ms | 15ms | **5x** | -| **查询去重** (小) | 2.90μs | 0.79μs | **73%** | -| **综合场景** | 100ms | 70ms | **30%** | -| **代码行数** | 721 | 735 | +14行 | -| **API 兼容性** | - | 100% | ✓ | - ---- - -## 🏆 优化成就 - -### 技术成就 -✅ 实现 8 项有针对性的优化 -✅ 核心算法提升 5-50x -✅ 综合性能提升 25-40% -✅ 完全向后兼容 - -### 交付物 -✅ 优化代码 (735 行) -✅ 详细文档 (4 个) -✅ 基准工具 (1 套) -✅ 验证报告 (完整) - -### 质量指标 -✅ 语法检查: PASS -✅ 兼容性: 100% -✅ 文档完整度: 100% -✅ 可重复性: 支持 - ---- - -## 📞 支持与反馈 - -### 文档参考 -- 快速参考: [OPTIMIZATION_SUMMARY.md](OPTIMIZATION_SUMMARY.md) -- 技术细节: [OPTIMIZATION_REPORT_UNIFIED_MANAGER.md](docs/OPTIMIZATION_REPORT_UNIFIED_MANAGER.md) -- 可视化: [OPTIMIZATION_VISUAL_GUIDE.md](OPTIMIZATION_VISUAL_GUIDE.md) - -### 性能测试 -运行基准测试验证优化效果: -```bash -python scripts/benchmark_unified_manager.py -``` - -### 监控与优化 -使用 `manager.get_statistics()` 监控系统状态,持续迭代改进 - ---- - -## 🎉 总结 - -通过 8 项目标性能优化,MoFox-Core 的统一记忆管理器获得了显著的性能提升,特别是在高负载批量操作中展现出 5-50x 的加速优势。所有优化都保持了 100% 的向后兼容性,无需修改调用代码即可立即生效。 - -**优化完成时间**: 2025 年 12 月 13 日 -**优化文件**: `src/memory_graph/unified_manager.py` -**代码变更**: +14 行,涉及 8 个关键方法 -**预期收益**: 25-40% 综合提升 / 5-50x 批量操作提升 - -🚀 **立即开始享受性能提升!** - ---- - -## 附录: 快速对比 - -``` -性能改善等级 (以块转移为例) - -原始性能: ████████████████████ (75ms) -优化后: ████ (15ms) - -加速比: 5x ⚡ (基础) - 10x ⚡⚡ (10块) - 50x ⚡⚡⚡ (50块+) -``` diff --git a/docs/memory_graph/OPTIMIZATION_QUICK_REFERENCE.md b/docs/memory_graph/OPTIMIZATION_QUICK_REFERENCE.md deleted file mode 100644 index 04dfc2f52..000000000 --- a/docs/memory_graph/OPTIMIZATION_QUICK_REFERENCE.md +++ /dev/null @@ -1,216 +0,0 @@ -# 🚀 优化快速参考卡 - -## 📌 一句话总结 -通过 8 项算法优化,统一记忆管理器性能提升 **25-40%**(典型场景)或 **5-50x**(批量操作)。 - ---- - -## ⚡ 核心优化排名 - -| 排名 | 优化 | 性能提升 | 重要度 | -|------|------|----------|--------| -| 🥇 1 | 块转移并行化 | **5-50x** | ⭐⭐⭐⭐⭐ | -| 🥈 2 | 查询去重单遍 | **5-15%** | ⭐⭐⭐⭐ | -| 🥉 3 | 缓存批量构建 | **2-4%** | ⭐⭐⭐ | -| 4 | 任务创建消除 | **2-3%** | ⭐⭐⭐ | -| 5-8 | 其他微优化 | **<3%** | ⭐⭐ | - ---- - -## 🎯 场景性能收益 - -``` -日常消息处理 +5-10% ⬆️ -高负载批量转移 +10-50x ⬆️⬆️⬆️ (★最显著) -裁判模型评估 +5-15% ⬆️ -综合场景 +25-40% ⬆️⬆️ -``` - ---- - -## 📊 基准数据一览 - -### 块转移 (最重要) -- 5 块: 77ms → 15ms = **5x** -- 10 块: 155ms → 16ms = **10x** -- 20 块: 311ms → 16ms = **20x** ⚡ - -### 查询去重 -- 小 (2项): 2.90μs → 0.79μs = **73%** ↓ -- 中 (50项): 3.46μs → 3.19μs = **8%** ↓ - -### 去重性能 (混合数据) -- 对象 100 个: 高效支持 -- 字典 100 个: 高效支持 -- 混合数据: 新增支持 ✓ - ---- - -## 🔧 关键改进代码片段 - -### 改进 1: 并行块转移 -```python -# ✅ 新 -results = await asyncio.gather( - *[_transfer_single(block) for block in blocks] -) -# 加速: 5-50x -``` - -### 改进 2: 单遍去重 -```python -# ✅ 新 (O(n) vs O(2n)) -for raw in queries: - if text and text not in seen: - seen.add(text) - manual_queries.append({...}) -# 加速: 50% 扫描时间 -``` - -### 改进 3: 多态支持 -```python -# ✅ 新 (dict + object) -mem_id = mem.get("id") if isinstance(mem, dict) else getattr(mem, "id", None) -# 兼容性: +100% -``` - ---- - -## ✅ 验证清单 - -- [x] 8 项优化已实施 -- [x] 语法检查通过 -- [x] 性能基准验证 -- [x] 向后兼容确认 -- [x] 文档完整生成 -- [x] 工具脚本提供 - ---- - -## 📚 关键文档 - -| 文档 | 用途 | 查看时间 | -|------|------|----------| -| [OPTIMIZATION_SUMMARY.md](OPTIMIZATION_SUMMARY.md) | 优化总结 | 5 分钟 | -| [OPTIMIZATION_REPORT_UNIFIED_MANAGER.md](docs/OPTIMIZATION_REPORT_UNIFIED_MANAGER.md) | 技术细节 | 15 分钟 | -| [OPTIMIZATION_VISUAL_GUIDE.md](OPTIMIZATION_VISUAL_GUIDE.md) | 可视化 | 10 分钟 | -| [OPTIMIZATION_COMPLETION_REPORT.md](OPTIMIZATION_COMPLETION_REPORT.md) | 完成报告 | 10 分钟 | - ---- - -## 🧪 运行基准测试 - -```bash -python scripts/benchmark_unified_manager.py -``` - -**输出示例**: -``` -块转移并行化性能基准测试 -╔══════════════════════════════════════╗ -║ 块数 串行(ms) 并行(ms) 加速比 ║ -║ 5 77.28 15.49 4.99x ║ -║ 10 155.50 15.66 9.93x ║ -║ 20 311.02 15.53 20.03x ║ -╚══════════════════════════════════════╝ -``` - ---- - -## 💡 如何使用优化后的代码 - -### 自动生效 -```python -from src.memory_graph.unified_manager import UnifiedMemoryManager - -manager = UnifiedMemoryManager() -await manager.initialize() - -# 无需任何改动,自动获得所有优化效果 -await manager.search_memories("query") -await manager._auto_transfer_loop() # 优化的自动转移 -``` - -### 监控效果 -```python -stats = manager.get_statistics() -print(f"总记忆数: {stats['total_system_memories']}") -``` - ---- - -## 🎯 优化前后对比 - -```python -# ❌ 优化前 (低效) -for block in blocks: # 串行 - await process(block) # 逐个处理 - -# ✅ 优化后 (高效) -await asyncio.gather(*[process(block) for block in blocks]) # 并行 -``` - -**结果**: -- 5 块: 5 倍快 -- 10 块: 10 倍快 -- 20 块: 20 倍快 - ---- - -## 🚀 性能等级 - -``` -⭐⭐⭐⭐⭐ 优秀 (块转移: 5-50x) -⭐⭐⭐⭐☆ 很好 (查询去重: 5-15%) -⭐⭐⭐☆☆ 良好 (其他: 1-5%) -════════════════════════════ -总体评分: ⭐⭐⭐⭐⭐ 优秀 -``` - ---- - -## 📞 常见问题 - -### Q: 是否需要修改调用代码? -**A**: 不需要。所有优化都是透明的,100% 向后兼容。 - -### Q: 性能提升是否可信? -**A**: 是的。基于真实性能测试,可通过 `benchmark_unified_manager.py` 验证。 - -### Q: 优化是否会影响功能? -**A**: 不会。所有优化仅涉及实现细节,功能完全相同。 - -### Q: 能否回退到原版本? -**A**: 可以,但建议保留优化版本。新版本全面优于原版。 - ---- - -## 🎉 立即体验 - -1. **查看优化**: `src/memory_graph/unified_manager.py` (已优化) -2. **验证性能**: `python scripts/benchmark_unified_manager.py` -3. **阅读文档**: `OPTIMIZATION_SUMMARY.md` (快速参考) -4. **了解细节**: `docs/OPTIMIZATION_REPORT_UNIFIED_MANAGER.md` (技术详解) - ---- - -## 📈 预期收益 - -| 场景 | 性能提升 | 体验改善 | -|------|----------|----------| -| 日常聊天 | 5-10% | 更流畅 ✓ | -| 批量操作 | 10-50x | 显著加速 ⚡ | -| 整体系统 | 25-40% | 明显改善 ⚡⚡ | - ---- - -## 最后一句话 - -**8 项精心设计的优化,让你的 AI 聊天机器人的内存管理速度提升 5-50 倍!** 🚀 - ---- - -**优化完成**: 2025-12-13 -**状态**: ✅ 就绪投入使用 -**兼容性**: ✅ 完全兼容 -**性能**: ✅ 验证通过 diff --git a/docs/memory_graph/OPTIMIZATION_REPORT_UNIFIED_MANAGER.md b/docs/memory_graph/OPTIMIZATION_REPORT_UNIFIED_MANAGER.md deleted file mode 100644 index 8d8906163..000000000 --- a/docs/memory_graph/OPTIMIZATION_REPORT_UNIFIED_MANAGER.md +++ /dev/null @@ -1,347 +0,0 @@ -# 统一记忆管理器性能优化报告 - -## 优化概述 - -对 `src/memory_graph/unified_manager.py` 进行了深度性能优化,实现了**8项关键算法改进**,预期性能提升 **25-40%**。 - ---- - -## 优化项详解 - -### 1. **并行任务创建开销消除** ⭐ 高优先级 -**位置**: `search_memories()` 方法 -**问题**: 创建了两个不必要的 `asyncio.Task` 对象 - -```python -# ❌ 原代码(低效) -perceptual_blocks_task = asyncio.create_task(self.perceptual_manager.recall_blocks(query_text)) -short_term_memories_task = asyncio.create_task(self.short_term_manager.search_memories(query_text)) -perceptual_blocks, short_term_memories = await asyncio.gather( - perceptual_blocks_task, - short_term_memories_task, -) - -# ✅ 优化后(高效) -perceptual_blocks, short_term_memories = await asyncio.gather( - self.perceptual_manager.recall_blocks(query_text), - self.short_term_manager.search_memories(query_text), -) -``` - -**性能提升**: 消除了 2 个任务对象创建的开销 -**影响**: 高(每次搜索都会调用) - ---- - -### 2. **去重查询单遍扫描优化** ⭐ 高优先级 -**位置**: `_build_manual_multi_queries()` 方法 -**问题**: 先构建 `deduplicated` 列表再遍历,导致二次扫描 - -```python -# ❌ 原代码(两次扫描) -deduplicated: list[str] = [] -for raw in queries: - text = (raw or "").strip() - if not text or text in seen: - continue - deduplicated.append(text) - -for idx, text in enumerate(deduplicated): - weight = max(0.3, 1.0 - idx * decay) - manual_queries.append({...}) - -# ✅ 优化后(单次扫描) -for raw in queries: - text = (raw or "").strip() - if text and text not in seen: - seen.add(text) - weight = max(0.3, 1.0 - len(manual_queries) * decay) - manual_queries.append({...}) -``` - -**性能提升**: O(2n) → O(n),减少 50% 扫描次数 -**影响**: 中(在裁判模型评估时调用) - ---- - -### 3. **内存去重函数多态优化** ⭐ 中优先级 -**位置**: `_deduplicate_memories()` 方法 -**问题**: 仅支持对象类型,遗漏字典类型支持 - -```python -# ❌ 原代码 -mem_id = getattr(mem, "id", None) - -# ✅ 优化后 -if isinstance(mem, dict): - mem_id = mem.get("id") -else: - mem_id = getattr(mem, "id", None) -``` - -**性能提升**: 避免类型转换,支持多数据源 -**影响**: 中(在长期记忆去重时调用) - ---- - -### 4. **睡眠间隔计算查表法优化** ⭐ 中优先级 -**位置**: `_calculate_auto_sleep_interval()` 方法 -**问题**: 链式 if 判断(线性扫描),存在分支预测失败 - -```python -# ❌ 原代码(链式判断) -if occupancy >= 0.8: - return max(2.0, base_interval * 0.1) -if occupancy >= 0.5: - return max(5.0, base_interval * 0.2) -if occupancy >= 0.3: - ... - -# ✅ 优化后(查表法) -occupancy_thresholds = [ - (0.8, 2.0, 0.1), - (0.5, 5.0, 0.2), - (0.3, 10.0, 0.4), - (0.1, 15.0, 0.6), -] - -for threshold, min_val, factor in occupancy_thresholds: - if occupancy >= threshold: - return max(min_val, base_interval * factor) -``` - -**性能提升**: 改善分支预测性能,代码更简洁 -**影响**: 低(每次检查调用一次,但调用频繁) - ---- - -### 5. **后台块转移并行化** ⭐⭐ 最高优先级 -**位置**: `_transfer_blocks_to_short_term()` 方法 -**问题**: 串行处理多个块的转移操作 - -```python -# ❌ 原代码(串行) -for block in blocks: - try: - stm = await self.short_term_manager.add_from_block(block) - await self.perceptual_manager.remove_block(block.id) - self._trigger_transfer_wakeup() # 每个块都触发 - except Exception as exc: - logger.error(...) - -# ✅ 优化后(并行) -async def _transfer_single(block: MemoryBlock) -> tuple[MemoryBlock, bool]: - try: - stm = await self.short_term_manager.add_from_block(block) - if not stm: - return block, False - - await self.perceptual_manager.remove_block(block.id) - return block, True - except Exception as exc: - return block, False - -results = await asyncio.gather(*[_transfer_single(block) for block in blocks]) - -# 批量触发唤醒 -success_count = sum(1 for result in results if isinstance(result, tuple) and result[1]) -if success_count > 0: - self._trigger_transfer_wakeup() -``` - -**性能提升**: 串行 → 并行,取决于块数(2-10 倍) -**影响**: 最高(后台大量块转移时效果显著) - ---- - -### 6. **缓存批量构建优化** ⭐ 中优先级 -**位置**: `_auto_transfer_loop()` 方法 -**问题**: 逐条添加到缓存,ID 去重计数不高效 - -```python -# ❌ 原代码(逐条) -for memory in memories_to_transfer: - mem_id = getattr(memory, "id", None) - if mem_id and mem_id in cached_ids: - continue - transfer_cache.append(memory) - if mem_id: - cached_ids.add(mem_id) - added += 1 - -# ✅ 优化后(批量) -new_memories = [] -for memory in memories_to_transfer: - mem_id = getattr(memory, "id", None) - if not (mem_id and mem_id in cached_ids): - new_memories.append(memory) - if mem_id: - cached_ids.add(mem_id) - -if new_memories: - transfer_cache.extend(new_memories) -``` - -**性能提升**: 减少单个 append 调用,使用 extend 批量操作 -**影响**: 低(优化内存分配,当缓存较大时有效) - ---- - -### 7. **直接转移列表避免复制** ⭐ 低优先级 -**位置**: `_auto_transfer_loop()` 和 `_schedule_perceptual_block_transfer()` 方法 -**问题**: 不必要的 `list(transfer_cache)` 和 `list(blocks)` 复制 - -```python -# ❌ 原代码 -result = await self.long_term_manager.transfer_from_short_term(list(transfer_cache)) -task = asyncio.create_task(self._transfer_blocks_to_short_term(list(blocks))) - -# ✅ 优化后 -result = await self.long_term_manager.transfer_from_short_term(transfer_cache) -task = asyncio.create_task(self._transfer_blocks_to_short_term(blocks)) -``` - -**性能提升**: O(n) 复制消除 -**影响**: 低(当列表较小时影响微弱) - ---- - -### 8. **长期检索上下文延迟创建** ⭐ 低优先级 -**位置**: `_retrieve_long_term_memories()` 方法 -**问题**: 总是创建 context 字典,即使为空 - -```python -# ❌ 原代码 -context: dict[str, Any] = {} -if recent_chat_history: - context["chat_history"] = recent_chat_history -if manual_queries: - context["manual_multi_queries"] = manual_queries - -if context: - search_params["context"] = context - -# ✅ 优化后(条件创建) -if recent_chat_history or manual_queries: - context: dict[str, Any] = {} - if recent_chat_history: - context["chat_history"] = recent_chat_history - if manual_queries: - context["manual_multi_queries"] = manual_queries - search_params["context"] = context -``` - -**性能提升**: 避免不必要的字典创建 -**影响**: 极低(仅内存分配,不影响逻辑路径) - ---- - -## 性能数据 - -### 预期性能提升估计 - -| 优化项 | 场景 | 提升幅度 | 优先级 | -|--------|------|----------|--------| -| 并行任务创建消除 | 每次搜索 | 2-3% | ⭐⭐⭐⭐ | -| 查询去重单遍扫描 | 裁判评估 | 5-8% | ⭐⭐⭐ | -| 块转移并行化 | 批量转移(≥5块) | 8-15% | ⭐⭐⭐⭐⭐ | -| 缓存批量构建 | 大批量缓存 | 2-4% | ⭐⭐ | -| 直接转移列表 | 小对象 | 1-2% | ⭐ | -| **综合提升** | **典型场景** | **25-40%** | - | - -### 基准测试建议 - -```python -# 在 tests/ 目录中创建性能测试 -import asyncio -import time -from src.memory_graph.unified_manager import UnifiedMemoryManager - -async def benchmark_transfer(): - manager = UnifiedMemoryManager() - await manager.initialize() - - # 构造 100 个块 - blocks = [...] - - start = time.perf_counter() - await manager._transfer_blocks_to_short_term(blocks) - end = time.perf_counter() - - print(f"转移 100 个块耗时: {(end - start) * 1000:.2f}ms") - -asyncio.run(benchmark_transfer()) -``` - ---- - -## 兼容性与风险评估 - -### ✅ 完全向后兼容 -- 所有公共 API 签名保持不变 -- 调用方无需修改代码 -- 内部优化对外部透明 - -### ⚠️ 风险评估 -| 优化项 | 风险等级 | 缓解措施 | -|--------|----------|----------| -| 块转移并行化 | 低 | 已测试异常处理 | -| 查询去重逻辑 | 极低 | 逻辑等价性已验证 | -| 其他优化 | 极低 | 仅涉及实现细节 | - ---- - -## 测试建议 - -### 1. 单元测试 -```python -# 验证 _build_manual_multi_queries 去重逻辑 -def test_deduplicate_queries(): - manager = UnifiedMemoryManager() - queries = ["hello", "hello", "world", "", "hello"] - result = manager._build_manual_multi_queries(queries) - assert len(result) == 2 - assert result[0]["text"] == "hello" - assert result[1]["text"] == "world" -``` - -### 2. 集成测试 -```python -# 测试转移并行化 -async def test_parallel_transfer(): - manager = UnifiedMemoryManager() - await manager.initialize() - - blocks = [create_test_block() for _ in range(10)] - await manager._transfer_blocks_to_short_term(blocks) - - # 验证所有块都被处理 - assert len(manager.short_term_manager.memories) > 0 -``` - -### 3. 性能测试 -```python -# 对比优化前后的转移速度 -# 使用 pytest-benchmark 进行基准测试 -``` - ---- - -## 后续优化空间 - -### 第一优先级 -1. **embedding 缓存优化**: 为高频查询 embedding 结果做缓存 -2. **批量搜索并行化**: 在 `_retrieve_long_term_memories` 中并行多个查询 - -### 第二优先级 -3. **内存池管理**: 使用对象池替代频繁的列表创建/销毁 -4. **异步 I/O 优化**: 数据库操作使用连接池 - -### 第三优先级 -5. **算法改进**: 使用更快的去重算法(BloomFilter 等) - ---- - -## 总结 - -通过 8 项目标性能优化,统一记忆管理器的运行速度预期提升 **25-40%**,尤其是在高并发场景和大规模块转移时效果最佳。所有优化都保持了完全的向后兼容性,无需修改调用代码。 diff --git a/docs/memory_graph/OPTIMIZATION_SUMMARY.md b/docs/memory_graph/OPTIMIZATION_SUMMARY.md deleted file mode 100644 index f16bd4e1f..000000000 --- a/docs/memory_graph/OPTIMIZATION_SUMMARY.md +++ /dev/null @@ -1,219 +0,0 @@ -# 🚀 统一记忆管理器优化总结 - -## 优化成果 - -已成功优化 `src/memory_graph/unified_manager.py`,实现了 **8 项关键性能改进**。 - ---- - -## 📊 性能基准测试结果 - -### 1️⃣ 查询去重性能(小规模查询提升最大) -``` -小查询 (2项): 72.7% ⬆️ (2.90μs → 0.79μs) -中等查询 (50项): 8.1% ⬆️ (3.46μs → 3.19μs) -``` - -### 2️⃣ 块转移并行化(核心优化,性能提升最显著) -``` -5 个块: 4.99x 加速 (77.28ms → 15.49ms) -10 个块: 9.93x 加速 (155.50ms → 15.66ms) -20 个块: 20.03x 加速 (311.02ms → 15.53ms) -50 个块: ~50x 加速 (预期值) -``` - -**说明**: 并行化后,由于异步并发处理,多个块的转移时间接近单个块的时间 - ---- - -## ✅ 实施的优化清单 - -| # | 优化项 | 文件位置 | 复杂度 | 预期提升 | -|---|--------|---------|--------|----------| -| 1 | 消除任务创建开销 | `search_memories()` | 低 | 2-3% | -| 2 | 查询去重单遍扫描 | `_build_manual_multi_queries()` | 中 | 5-15% | -| 3 | 内存去重多态支持 | `_deduplicate_memories()` | 低 | 1-3% | -| 4 | 睡眠间隔查表法 | `_calculate_auto_sleep_interval()` | 低 | 1-2% | -| 5 | **块转移并行化** | `_transfer_blocks_to_short_term()` | 中 | **8-50x** ⭐⭐⭐ | -| 6 | 缓存批量构建 | `_auto_transfer_loop()` | 低 | 2-4% | -| 7 | 直接转移列表 | `_auto_transfer_loop()` | 低 | 1-2% | -| 8 | 上下文延迟创建 | `_retrieve_long_term_memories()` | 低 | <1% | - ---- - -## 🎯 关键优化亮点 - -### 🏆 块转移并行化(最重要) -**改进前**: 逐个处理块,N 个块需要 N×T 时间 -```python -for block in blocks: - stm = await self.short_term_manager.add_from_block(block) - await self.perceptual_manager.remove_block(block.id) -``` - -**改进后**: 并行处理块,N 个块只需约 T 时间 -```python -async def _transfer_single(block): - stm = await self.short_term_manager.add_from_block(block) - await self.perceptual_manager.remove_block(block.id) - return block, True - -results = await asyncio.gather(*[_transfer_single(block) for block in blocks]) -``` - -**性能收益**: -- 5 块: **5x 加速** -- 10 块: **10x 加速** -- 20+ 块: **20x+ 加速** ⚡ - ---- - -## 📈 典型场景性能提升 - -### 场景 1: 日常聊天消息处理 -- 搜索 → 感知+短期记忆并行检索 -- 提升: **5-10%**(相对较小但持续) - -### 场景 2: 批量记忆转移(高负载) -- 10-50 个块的批量转移 → 并行化处理 -- 提升: **10-50x** (显著效果)⭐⭐⭐ - -### 场景 3: 裁判模型评估 -- 查询去重优化 -- 提升: **5-15%** - ---- - -## 🔧 技术细节 - -### 新增并行转移函数签名 -```python -async def _transfer_blocks_to_short_term(self, blocks: list[MemoryBlock]) -> None: - """实际转换逻辑在后台执行(优化:并行处理多个块,批量触发唤醒)""" - - async def _transfer_single(block: MemoryBlock) -> tuple[MemoryBlock, bool]: - # 单个块的转移逻辑 - ... - - # 并行处理所有块 - results = await asyncio.gather(*[_transfer_single(block) for block in blocks]) -``` - -### 优化后的自动转移循环 -```python -async def _auto_transfer_loop(self) -> None: - """自动转移循环(优化:更高效的缓存管理)""" - - # 批量构建缓存 - new_memories = [...] - transfer_cache.extend(new_memories) - - # 直接传递列表,避免复制 - result = await self.long_term_manager.transfer_from_short_term(transfer_cache) -``` - ---- - -## ⚠️ 兼容性与风险 - -### ✅ 完全向后兼容 -- ✓ 所有公开 API 保持不变 -- ✓ 内部实现优化,调用方无感知 -- ✓ 测试覆盖已验证核心逻辑 - -### 🛡️ 风险等级:极低 -| 优化项 | 风险等级 | 原因 | -|--------|---------|------| -| 并行转移 | 低 | 已有完善的异常处理机制 | -| 查询去重 | 极低 | 逻辑等价,结果一致 | -| 其他优化 | 极低 | 仅涉及实现细节 | - ---- - -## 📚 文档与工具 - -### 📖 生成的文档 -1. **[OPTIMIZATION_REPORT_UNIFIED_MANAGER.md](../docs/OPTIMIZATION_REPORT_UNIFIED_MANAGER.md)** - - 详细的优化说明和性能分析 - - 8 项优化的完整描述 - - 性能数据和测试建议 - -2. **[benchmark_unified_manager.py](../scripts/benchmark_unified_manager.py)** - - 性能基准测试脚本 - - 可重复运行验证优化效果 - - 包含多个测试场景 - -### 🧪 运行基准测试 -```bash -python scripts/benchmark_unified_manager.py -``` - ---- - -## 📋 验证清单 - -- [x] **代码优化完成** - 8 项改进已实施 -- [x] **静态代码分析** - 通过代码质量检查 -- [x] **性能基准测试** - 验证了关键优化的性能提升 -- [x] **兼容性验证** - 保持向后兼容 -- [x] **文档完成** - 详细的优化报告已生成 - ---- - -## 🎉 快速开始 - -### 使用优化后的代码 -优化已直接应用到源文件,无需额外配置: -```python -# 自动获得所有优化效果 -from src.memory_graph.unified_manager import UnifiedMemoryManager - -manager = UnifiedMemoryManager() -await manager.initialize() - -# 关键操作已自动优化: -# - search_memories() 并行检索 -# - _transfer_blocks_to_short_term() 并行转移 -# - _build_manual_multi_queries() 单遍去重 -``` - -### 监控性能 -```python -# 获取统计信息(包括转移速度等) -stats = manager.get_statistics() -print(f"已转移记忆: {stats['long_term']['total_memories']}") -``` - ---- - -## 📞 后续改进方向 - -### 优先级 1(可立即实施) -- [ ] Embedding 结果缓存(预期 20-30% 提升) -- [ ] 批量查询并行化(预期 5-10% 提升) - -### 优先级 2(需要架构调整) -- [ ] 对象池管理(减少内存分配) -- [ ] 数据库连接池(优化 I/O) - -### 优先级 3(算法创新) -- [ ] BloomFilter 去重(更快的去重) -- [ ] 缓存预热策略(减少冷启动) - ---- - -## 📊 预期收益总结 - -| 场景 | 原耗时 | 优化后 | 改善 | -|------|--------|--------|------| -| 单次搜索 | 10ms | 9.5ms | 5% | -| 转移 10 个块 | 155ms | 16ms | **9.6x** ⭐ | -| 转移 20 个块 | 311ms | 16ms | **19x** ⭐⭐ | -| 日常操作(综合) | 100ms | 70ms | **30%** | - ---- - -**优化完成时间**: 2025-12-13 -**优化文件**: `src/memory_graph/unified_manager.py` (721 行) -**代码变更**: 8 个关键优化点 -**预期性能提升**: **25-40%** (典型场景) / **10-50x** (批量操作) diff --git a/docs/memory_graph/OPTIMIZATION_VISUAL_GUIDE.md b/docs/memory_graph/OPTIMIZATION_VISUAL_GUIDE.md deleted file mode 100644 index 948053f44..000000000 --- a/docs/memory_graph/OPTIMIZATION_VISUAL_GUIDE.md +++ /dev/null @@ -1,287 +0,0 @@ -# 优化对比可视化 - -## 1. 块转移并行化 - 性能对比 - -``` -原始实现(串行处理) -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -块 1: [=====] (单个块 ~15ms) -块 2: [=====] -块 3: [=====] -块 4: [=====] -块 5: [=====] -总时间: ████████████████████ 75ms - -优化后(并行处理) -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -块 1,2,3,4,5: [=====] (并行 ~15ms) -总时间: ████ 15ms - -加速比: 75ms ÷ 15ms = 5x ⚡ -``` - -## 2. 查询去重 - 算法演进 - -``` -❌ 原始实现(两次扫描) -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -输入: ["hello", "hello", "world", "hello"] - ↓ 第一次扫描: 去重 -去重列表: ["hello", "world"] - ↓ 第二次扫描: 添加权重 -输出: [ - {"text": "hello", "weight": 1.0}, - {"text": "world", "weight": 0.85} -] -扫描次数: 2x - - -✅ 优化后(单次扫描) -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -输入: ["hello", "hello", "world", "hello"] - ↓ 单次扫描: 去重 + 权重 -输出: [ - {"text": "hello", "weight": 1.0}, - {"text": "world", "weight": 0.85} -] -扫描次数: 1x - -性能提升: 50% 扫描时间节省 ✓ -``` - -## 3. 内存去重 - 多态支持 - -``` -❌ 原始(仅支持对象) -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -记忆对象: Memory(id="001") ✓ -字典对象: {"id": "001"} ✗ (失败) -混合数据: [Memory(...), {...}] ✗ (部分失败) - - -✅ 优化后(支持对象和字典) -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -记忆对象: Memory(id="001") ✓ -字典对象: {"id": "001"} ✓ (支持) -混合数据: [Memory(...), {...}] ✓ (完全支持) - -数据源兼容性: +100% 提升 ✓ -``` - -## 4. 自动转移循环 - 缓存管理优化 - -``` -❌ 原始实现(逐条添加) -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -获取记忆列表: [M1, M2, M3, M4, M5] - for memory in list: - transfer_cache.append(memory) ← 逐条 append - cached_ids.add(memory.id) - -内存分配: 5x append 操作 - - -✅ 优化后(批量 extend) -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -获取记忆列表: [M1, M2, M3, M4, M5] - new_memories = [...] - transfer_cache.extend(new_memories) ← 单次 extend - -内存分配: 1x extend 操作 - -分配操作: -80% 减少 ✓ -``` - -## 5. 性能改善曲线 - -``` -块转移性能 (ms) -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - - 350 │ - │ ● 串行处理 - 300 │ / - │ / - 250 │ / - │ / - 200 │ ● - │ / - 150 │ ● - │ / - 100 │ / - │ / - 50 │ /● ━━ ● ━━ ● ─── ● ─── ● - │ / (并行处理,基本线性) - 0 │─────●────────────────────────────── - 0 5 10 15 20 25 - 块数量 - -结论: 块数 ≥ 5 时,并行处理性能优势明显 -``` - -## 6. 整体优化影响范围 - -``` -统一记忆管理器 -├─ search_memories() ← 优化 3% (并行任务) -│ ├─ recall_blocks() -│ └─ search_memories() -│ -├─ _judge_retrieval_sufficiency() ← 优化 8% (去重) -│ └─ _build_manual_multi_queries() -│ -├─ _retrieve_long_term_memories() ← 优化 2% (上下文) -│ └─ _deduplicate_memories() ← 优化 3% (多态) -│ -└─ _auto_transfer_loop() ← 优化 15% ⭐⭐ (批量+并行) - ├─ _calculate_auto_sleep_interval() ← 优化 1% - ├─ _schedule_perceptual_block_transfer() - │ └─ _transfer_blocks_to_short_term() ← 优化 50x ⭐⭐⭐ - └─ transfer_from_short_term() - -总体优化覆盖: 100% 关键路径 -``` - -## 7. 成本-收益矩阵 - -``` - 收益大小 - ▲ - 5 │ ●[5] 块转移并行化 - │ ○ 高收益,中等成本 - 4 │ - │ ●[2] ●[6] - 3 │ 查询去重 缓存批量 - │ ○ ○ - 2 │ ○[8] ○[3] ○[7] - │ 上下文 多态 列表 - 1 │ ○[4] ○[1] - │ 查表 任务 - 0 └────────────────────────────► - 0 1 2 3 4 5 - 实施成本 - -推荐优先级: [5] > [2] > [6] > [1] -``` - -## 8. 时间轴 - 优化历程 - -``` -优化历程 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -│ -│ 2025-12-13 -│ ├─ 分析瓶颈 [完成] ✓ -│ ├─ 设计优化方案 [完成] ✓ -│ ├─ 实施 8 项优化 [完成] ✓ -│ │ ├─ 并行化 [完成] ✓ -│ │ ├─ 单遍去重 [完成] ✓ -│ │ ├─ 多态支持 [完成] ✓ -│ │ ├─ 查表法 [完成] ✓ -│ │ ├─ 缓存批量 [完成] ✓ -│ │ └─ ... -│ ├─ 性能基准测试 [完成] ✓ -│ └─ 文档完成 [完成] ✓ -│ -└─ 下一步: 性能监控 & 迭代优化 -``` - -## 9. 实际应用场景对比 - -``` -场景 A: 日常对话消息处理 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ -消息处理流程: - message → add_message() → search_memories() → generate_response() - -性能改善: - add_message: 无明显改善 (感知层处理) - search_memories: ↓ 5% (并行检索) - judge + retrieve: ↓ 8% (查询去重) - ─────────────────────── - 总体改善: ~ 5-10% 持续加速 - -场景 B: 高负载批量转移 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ -内存压力场景 (50+ 条短期记忆待转移): - _auto_transfer_loop() - → get_memories_for_transfer() [50 条] - → transfer_from_short_term() - → _transfer_blocks_to_short_term() [并行处理] - -性能改善: - 原耗时: 50 * 15ms = 750ms - 优化后: ~15ms (并行) - ─────────────────────── - 加速比: 50x ⚡ (显著优化!) - -场景 C: 混合工作负载 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ -典型一小时运行: - 消息处理: 60% (每秒 1 条) = 3600 条消息 - 内存管理: 30% (转移 200 条) = 200 条转移 - 其他操作: 10% - -性能改善: - 消息处理: 3600 * 5% = 180 条消息快 - 转移操作: 1 * 50x ≈ 12ms 快 (缩放) - ─────────────────────── - 总体感受: 显著加速 ✓ -``` - -## 10. 优化效果等级 - -``` -性能提升等级评分 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -★★★★★ 优秀 (>10x 提升) - └─ 块转移并行化: 5-50x ⭐ 最重要 - -★★★★☆ 很好 (5-10% 提升) - ├─ 查询去重单遍: 5-15% - └─ 缓存批量构建: 2-4% - -★★★☆☆ 良好 (1-5% 提升) - ├─ 任务创建消除: 2-3% - ├─ 上下文延迟: 1-2% - └─ 多态支持: 1-3% - -★★☆☆☆ 可观 (<1% 提升) - └─ 列表复制避免: <1% - -总体评分: ★★★★★ 优秀 (25-40% 综合提升) -``` - ---- - -## 总结 - -✅ **8 项优化实施完成** -- 核心优化:块转移并行化 (5-50x) -- 支撑优化:查询去重、缓存管理、多态支持 -- 微优化:任务创建、列表复制、上下文延迟 - -📊 **性能基准验证** -- 块转移: **5-50x 加速** (关键场景) -- 查询处理: **5-15% 提升** -- 综合性能: **25-40% 提升** (典型场景) - -🎯 **预期收益** -- 日常使用:更流畅的消息处理 -- 高负载:内存管理显著加速 -- 整体:系统响应更快 - -🚀 **立即生效** -- 无需配置,自动应用所有优化 -- 完全向后兼容,无破坏性变更 -- 可通过基准测试验证效果 diff --git a/docs/message_dispatcher_refactoring.md b/docs/message_dispatcher_refactoring.md deleted file mode 100644 index de70b3522..000000000 --- a/docs/message_dispatcher_refactoring.md +++ /dev/null @@ -1,210 +0,0 @@ -# 消息分发器重构文档 - -## 重构日期 -2025-11-04 - -## 重构目标 -将基于异步任务循环的消息分发机制改为使用统一的 `unified_scheduler`,实现更优雅和可维护的消息处理流程。 - -## 重构内容 - -### 1. 修改 unified_scheduler 以支持完全并发执行 - -**文件**: `src/schedule/unified_scheduler.py` - -**主要改动**: -- 修改 `_check_and_trigger_tasks` 方法,使用 `asyncio.create_task` 为每个到期任务创建独立的异步任务 -- 新增 `_execute_task_callback` 方法,用于并发执行单个任务 -- 使用 `asyncio.gather` 并发等待所有任务完成,确保不同 schedule 之间完全异步执行,不会相互阻塞 - -**关键改进**: -```python -# 为每个任务创建独立的异步任务,确保并发执行 -execution_tasks = [] -for task in tasks_to_trigger: - execution_task = asyncio.create_task( - self._execute_task_callback(task, current_time), - name=f"execute_{task.task_name}" - ) - execution_tasks.append(execution_task) - -# 等待所有任务完成(使用 return_exceptions=True 避免单个任务失败影响其他任务) -results = await asyncio.gather(*execution_tasks, return_exceptions=True) -``` - -### 2. 创建新的 SchedulerDispatcher - -**文件**: `src/chat/message_manager/scheduler_dispatcher.py` - -**功能**: -基于 `unified_scheduler` 的消息分发器,替代原有的 `stream_loop_task` 循环机制。 - -**工作流程**: -1. **接收消息时**: 将消息添加到聊天流上下文(缓存) -2. **检查 schedule**: 查看该聊天流是否有活跃的 schedule -3. **打断判定**: 如果有活跃 schedule,检查是否需要打断 - - 如果需要打断,移除旧 schedule 并创建新的 - - 如果不需要打断,保持原有 schedule -4. **创建 schedule**: 如果没有活跃 schedule,创建新的 -5. **Schedule 触发**: 当 schedule 到期时,激活 chatter 进行处理 -6. **处理完成**: 计算下次间隔并根据需要注册新的 schedule - -**关键方法**: -- `on_message_received(stream_id)`: 消息接收时的处理入口 -- `_check_interruption(stream_id, context)`: 检查是否应该打断 -- `_create_schedule(stream_id, context)`: 创建新的 schedule -- `_cancel_and_recreate_schedule(stream_id, context)`: 取消并重新创建 schedule -- `_on_schedule_triggered(stream_id)`: schedule 触发时的回调 -- `_process_stream(stream_id, context)`: 激活 chatter 处理消息 - -### 3. 修改 MessageManager 集成新分发器 - -**文件**: `src/chat/message_manager/message_manager.py` - -**主要改动**: -1. 导入 `scheduler_dispatcher` -2. 启动时初始化 `scheduler_dispatcher` 而非 `stream_loop_manager` -3. 修改 `add_message` 方法: - - 将消息添加到上下文后 - - 调用 `scheduler_dispatcher.on_message_received(stream_id)` 处理消息接收事件 -4. 废弃 `_check_and_handle_interruption` 方法(打断逻辑已集成到 dispatcher) - -**新的消息接收流程**: -```python -async def add_message(self, stream_id: str, message: DatabaseMessages): - # 1. 检查 notice 消息 - if self._is_notice_message(message): - await self._handle_notice_message(stream_id, message) - if not global_config.notice.enable_notice_trigger_chat: - return - - # 2. 将消息添加到上下文 - chat_stream = await chat_manager.get_stream(stream_id) - await chat_stream.context_manager.add_message(message) - - # 3. 通知 scheduler_dispatcher 处理 - await scheduler_dispatcher.on_message_received(stream_id) -``` - -### 4. 更新模块导出 - -**文件**: `src/chat/message_manager/__init__.py` - -**改动**: -- 导出 `SchedulerDispatcher` 和 `scheduler_dispatcher` - -## 架构对比 - -### 旧架构 (基于 stream_loop_task) -``` -消息到达 -> add_message -> 添加到上下文 -> 检查打断 -> 取消 stream_loop_task - -> 重新创建 stream_loop_task - -stream_loop_task: while True: - 检查未读消息 -> 处理消息 -> 计算间隔 -> sleep(间隔) -``` - -**问题**: -- 每个聊天流维护一个独立的异步循环任务 -- 即使没有消息也需要持续轮询 -- 打断逻辑通过取消和重建任务实现,较为复杂 -- 难以统一管理和监控 - -### 新架构 (基于 unified_scheduler) -``` -消息到达 -> add_message -> 添加到上下文 -> dispatcher.on_message_received - -> 检查是否有活跃 schedule - -> 打断判定 - -> 创建/更新 schedule - -schedule 到期 -> _on_schedule_triggered -> 处理消息 -> 计算间隔 -> 创建新 schedule (如果需要) -``` - -**优势**: -- 使用统一的调度器管理所有聊天流 -- 按需创建 schedule,没有消息时不会创建 -- 打断逻辑清晰:移除旧 schedule + 创建新 schedule -- 易于监控和统计(统一的 scheduler 统计) -- 完全异步并发,多个 schedule 可以同时触发而不相互阻塞 - -## 兼容性 - -### 保留的组件 -- `stream_loop_manager`: 暂时保留但不启动,以便需要时回滚 -- `_check_and_handle_interruption`: 保留方法签名但不执行,避免破坏现有调用 - -### 移除的组件 -- 无(本次重构采用渐进式方式,先添加新功能,待稳定后再移除旧代码) - -## 配置项 - -所有配置项保持不变,新分发器完全兼容现有配置: -- `chat.interruption_enabled`: 是否启用打断 -- `chat.allow_reply_interruption`: 是否允许回复时打断 -- `chat.interruption_max_limit`: 最大打断次数 -- `chat.distribution_interval`: 基础分发间隔 -- `chat.force_dispatch_unread_threshold`: 强制分发阈值 -- `chat.force_dispatch_min_interval`: 强制分发最小间隔 - -## 测试建议 - -1. **基本功能测试** - - 单个聊天流接收消息并正常处理 - - 多个聊天流同时接收消息并并发处理 - -2. **打断测试** - - 在 chatter 处理过程中发送新消息,验证打断逻辑 - - 验证打断次数限制 - - 验证打断概率计算 - -3. **间隔计算测试** - - 验证基于能量的动态间隔计算 - - 验证强制分发阈值触发 - -4. **并发测试** - - 多个聊天流的 schedule 同时到期,验证并发执行 - - 验证不同 schedule 之间不会相互阻塞 - -5. **长时间稳定性测试** - - 运行较长时间,观察是否有内存泄漏 - - 观察 schedule 创建和销毁是否正常 - -## 回滚方案 - -如果新机制出现问题,可以通过以下步骤回滚: - -1. 在 `message_manager.py` 的 `start()` 方法中: - ```python - # 注释掉新分发器 - # await scheduler_dispatcher.start() - # scheduler_dispatcher.set_chatter_manager(self.chatter_manager) - - # 启用旧分发器 - await stream_loop_manager.start() - stream_loop_manager.set_chatter_manager(self.chatter_manager) - ``` - -2. 在 `add_message()` 方法中: - ```python - # 注释掉新逻辑 - # await scheduler_dispatcher.on_message_received(stream_id) - - # 恢复旧逻辑 - await self._check_and_handle_interruption(chat_stream, message) - ``` - -3. 在 `_check_and_handle_interruption()` 方法中移除开头的 `return` 语句 - -## 后续工作 - -1. 在确认新机制稳定后,完全移除 `stream_loop_manager` 相关代码 -2. 清理 `StreamContext` 中的 `stream_loop_task` 字段 -3. 移除 `_check_and_handle_interruption` 方法 -4. 更新相关文档和注释 - -## 性能预期 - -- **资源占用**: 减少(不再为每个流维护独立循环) -- **响应延迟**: 不变(仍基于相同的间隔计算) -- **并发能力**: 提升(完全异步执行,无阻塞) -- **可维护性**: 提升(逻辑更清晰,统一管理) diff --git a/docs/three_tier_memory_completion_report.md b/docs/three_tier_memory_completion_report.md deleted file mode 100644 index 904a78219..000000000 --- a/docs/three_tier_memory_completion_report.md +++ /dev/null @@ -1,367 +0,0 @@ -# 三层记忆系统集成完成报告 - -## ✅ 已完成的工作 - -### 1. 核心实现 (100%) - -#### 数据模型 (`src/memory_graph/three_tier/models.py`) -- ✅ `MemoryBlock`: 感知记忆块(5条消息/块) -- ✅ `ShortTermMemory`: 短期结构化记忆 -- ✅ `GraphOperation`: 11种图操作类型 -- ✅ `JudgeDecision`: Judge模型决策结果 -- ✅ `ShortTermDecision`: 短期记忆决策枚举 - -#### 感知记忆层 (`perceptual_manager.py`) -- ✅ 全局记忆堆管理(最多50块) -- ✅ 消息累积与分块(5条/块) -- ✅ 向量生成与相似度计算 -- ✅ TopK召回机制(top_k=3, threshold=0.55) -- ✅ 激活次数统计(≥3次激活→短期) -- ✅ FIFO淘汰策略 -- ✅ 持久化存储(JSON) -- ✅ 单例模式 (`get_perceptual_manager()`) - -#### 短期记忆层 (`short_term_manager.py`) -- ✅ 结构化记忆提取(主语/话题/宾语) -- ✅ LLM决策引擎(4种操作:MERGE/UPDATE/CREATE_NEW/DISCARD) -- ✅ 向量检索与相似度匹配 -- ✅ 重要性评分系统 -- ✅ 激活衰减机制(decay_factor=0.98) -- ✅ 转移阈值判断(importance≥0.6→长期) -- ✅ 持久化存储(JSON) -- ✅ 单例模式 (`get_short_term_manager()`) - -#### 长期记忆层 (`long_term_manager.py`) -- ✅ 批量转移处理(10条/批) -- ✅ LLM生成图操作语言 -- ✅ 11种图操作执行: - - `CREATE_MEMORY`: 创建新记忆节点 - - `UPDATE_MEMORY`: 更新现有记忆 - - `MERGE_MEMORIES`: 合并多个记忆 - - `CREATE_NODE`: 创建实体/事件节点 - - `UPDATE_NODE`: 更新节点属性 - - `DELETE_NODE`: 删除节点 - - `CREATE_EDGE`: 创建关系边 - - `UPDATE_EDGE`: 更新边属性 - - `DELETE_EDGE`: 删除边 - - `CREATE_SUBGRAPH`: 创建子图 - - `QUERY_GRAPH`: 图查询 -- ✅ 慢速衰减机制(decay_factor=0.95) -- ✅ 与现有MemoryManager集成 -- ✅ 单例模式 (`get_long_term_manager()`) - -#### 统一管理器 (`unified_manager.py`) -- ✅ 统一入口接口 -- ✅ `add_message()`: 消息添加流程 -- ✅ `search_memories()`: 智能检索(Judge模型决策) -- ✅ `transfer_to_long_term()`: 手动转移接口 -- ✅ 自动转移任务(每10分钟) -- ✅ 统计信息聚合 -- ✅ 生命周期管理 - -#### 单例管理 (`manager_singleton.py`) -- ✅ 全局单例访问器 -- ✅ `initialize_unified_memory_manager()`: 初始化 -- ✅ `get_unified_memory_manager()`: 获取实例 -- ✅ `shutdown_unified_memory_manager()`: 关闭清理 - -### 2. 系统集成 (100%) - -#### 配置系统集成 -- ✅ `config/bot_config.toml`: 添加 `[three_tier_memory]` 配置节 -- ✅ `src/config/official_configs.py`: 创建 `ThreeTierMemoryConfig` 类 -- ✅ `src/config/config.py`: - - 添加 `ThreeTierMemoryConfig` 导入 - - 在 `Config` 类中添加 `three_tier_memory` 字段 - -#### 消息处理集成 -- ✅ `src/chat/message_manager/context_manager.py`: - - 添加延迟导入机制(避免循环依赖) - - 在 `add_message()` 中调用三层记忆系统 - - 异常处理不影响主流程 - -#### 回复生成集成 -- ✅ `src/chat/replyer/default_generator.py`: - - 创建 `build_three_tier_memory_block()` 方法 - - 添加到并行任务列表 - - 合并三层记忆与原记忆图结果 - - 更新默认值字典和任务映射 - -#### 系统启动/关闭集成 -- ✅ `src/main.py`: - - 在 `_init_components()` 中初始化三层记忆 - - 检查配置启用状态 - - 在 `_async_cleanup()` 中添加关闭逻辑 - -### 3. 文档与测试 (100%) - -#### 用户文档 -- ✅ `docs/three_tier_memory_user_guide.md`: 完整使用指南 - - 快速启动教程 - - 工作流程图解 - - 使用示例(3个场景) - - 运维管理指南 - - 最佳实践建议 - - 故障排除FAQ - - 性能指标参考 - -#### 测试脚本 -- ✅ `scripts/test_three_tier_memory.py`: 集成测试脚本 - - 6个测试套件 - - 单元测试覆盖 - - 集成测试验证 - -#### 项目文档更新 -- ✅ 本报告(实现完成总结) - -## 📊 代码统计 - -### 新增文件 -| 文件 | 行数 | 说明 | -|------|------|------| -| `models.py` | 311 | 数据模型定义 | -| `perceptual_manager.py` | 517 | 感知记忆层管理器 | -| `short_term_manager.py` | 686 | 短期记忆层管理器 | -| `long_term_manager.py` | 664 | 长期记忆层管理器 | -| `unified_manager.py` | 495 | 统一管理器 | -| `manager_singleton.py` | 75 | 单例管理 | -| `__init__.py` | 25 | 模块初始化 | -| **总计** | **2773** | **核心代码** | - -### 修改文件 -| 文件 | 修改说明 | -|------|----------| -| `config/bot_config.toml` | 添加 `[three_tier_memory]` 配置(13个参数) | -| `src/config/official_configs.py` | 添加 `ThreeTierMemoryConfig` 类(27行) | -| `src/config/config.py` | 添加导入和字段(2处修改) | -| `src/chat/message_manager/context_manager.py` | 集成消息添加(18行新增) | -| `src/chat/replyer/default_generator.py` | 添加检索方法和集成(82行新增) | -| `src/main.py` | 启动/关闭集成(10行新增) | - -### 新增文档 -- `docs/three_tier_memory_user_guide.md`: 400+行完整指南 -- `scripts/test_three_tier_memory.py`: 400+行测试脚本 -- `docs/three_tier_memory_completion_report.md`: 本报告 - -## 🎯 关键特性 - -### 1. 智能分层 -- **感知层**: 短期缓冲,快速访问(<5ms) -- **短期层**: 活跃记忆,LLM结构化(<100ms) -- **长期层**: 持久图谱,深度推理(1-3s/条) - -### 2. LLM决策引擎 -- **短期决策**: 4种操作(合并/更新/新建/丢弃) -- **长期决策**: 11种图操作 -- **Judge模型**: 智能检索充分性判断 - -### 3. 性能优化 -- **异步执行**: 所有I/O操作非阻塞 -- **批量处理**: 长期转移批量10条 -- **缓存策略**: Judge结果缓存 -- **延迟导入**: 避免循环依赖 - -### 4. 数据安全 -- **JSON持久化**: 所有层次数据持久化 -- **崩溃恢复**: 自动从最后状态恢复 -- **异常隔离**: 记忆系统错误不影响主流程 - -## 🔄 工作流程 - -``` -新消息 - ↓ -[感知层] 累积到5条 → 生成向量 → TopK召回 - ↓ (激活3次) -[短期层] LLM提取结构 → 决策操作 → 更新/合并 - ↓ (重要性≥0.6) -[长期层] 批量转移 → LLM生成图操作 → 更新记忆图谱 - ↓ -持久化存储 -``` - -``` -查询 - ↓ -检索感知层 (TopK=3) - ↓ -检索短期层 (TopK=5) - ↓ -Judge评估充分性 - ↓ (不充分) -检索长期层 (图谱查询) - ↓ -返回综合结果 -``` - -## ⚙️ 配置参数 - -### 关键参数说明 -```toml -[three_tier_memory] -enable = true # 系统开关 -perceptual_max_blocks = 50 # 感知层容量 -perceptual_block_size = 5 # 块大小(固定) -activation_threshold = 3 # 激活阈值 -short_term_max_memories = 100 # 短期层容量 -short_term_transfer_threshold = 0.6 # 转移阈值 -long_term_batch_size = 10 # 批量大小 -judge_model_name = "utils_small" # Judge模型 -enable_judge_retrieval = true # 启用智能检索 -``` - -### 调优建议 -- **高频群聊**: 增大 `perceptual_max_blocks` 和 `short_term_max_memories` -- **私聊深度**: 降低 `activation_threshold` 和 `short_term_transfer_threshold` -- **性能优先**: 禁用 `enable_judge_retrieval`,减少LLM调用 - -## 🧪 测试结果 - -### 单元测试 -- ✅ 配置系统加载 -- ✅ 感知记忆添加/召回 -- ✅ 短期记忆提取/决策 -- ✅ 长期记忆转移/图操作 -- ✅ 统一管理器集成 -- ✅ 单例模式一致性 - -### 集成测试 -- ✅ 端到端消息流程 -- ✅ 跨层记忆转移 -- ✅ 智能检索(含Judge) -- ✅ 自动转移任务 -- ✅ 持久化与恢复 - -### 性能测试 -- **感知层添加**: 3-5ms ✅ -- **短期层检索**: 50-100ms ✅ -- **长期层转移**: 1-3s/条 ✅(LLM瓶颈) -- **智能检索**: 200-500ms ✅ - -## ⚠️ 已知问题与限制 - -### 静态分析警告 -- **Pylance类型检查**: 多处可选类型警告(不影响运行) -- **原因**: 初始化前的 `None` 类型 -- **解决方案**: 运行时检查 `_initialized` 标志 - -### LLM依赖 -- **短期提取**: 需要LLM支持(提取主谓宾) -- **短期决策**: 需要LLM支持(4种操作) -- **长期图操作**: 需要LLM支持(生成操作序列) -- **Judge检索**: 需要LLM支持(充分性判断) -- **缓解**: 提供降级策略(配置禁用Judge) - -### 性能瓶颈 -- **LLM调用延迟**: 每次转移需1-3秒 -- **缓解**: 批量处理(10条/批)+ 异步执行 -- **建议**: 使用快速模型(gpt-4o-mini, utils_small) - -### 数据迁移 -- **现有记忆图**: 不自动迁移到三层系统 -- **共存模式**: 两套系统并行运行 -- **建议**: 新项目启用,老项目可选 - -## 🚀 后续优化建议 - -### 短期优化 -1. **向量缓存**: ChromaDB持久化(减少重启损失) -2. **LLM池化**: 批量调用减少往返 -3. **异步保存**: 更频繁的异步持久化 - -### 中期优化 -4. **自适应参数**: 根据对话频率自动调整阈值 -5. **记忆压缩**: 低重要性记忆自动归档 -6. **智能预加载**: 基于上下文预测性加载 - -### 长期优化 -7. **图谱可视化**: WebUI展示记忆图谱 -8. **记忆编辑**: 用户界面手动管理记忆 -9. **跨实例共享**: 多机器人记忆同步 - -## 📝 使用方式 - -### 启用系统 -1. 编辑 `config/bot_config.toml` -2. 添加 `[three_tier_memory]` 配置 -3. 设置 `enable = true` -4. 重启机器人 - -### 验证运行 -```powershell -# 运行测试脚本 -python scripts/test_three_tier_memory.py - -# 查看日志 -# 应看到 "三层记忆系统初始化成功" -``` - -### 查看统计 -```python -from src.memory_graph.three_tier.manager_singleton import get_unified_memory_manager - -manager = get_unified_memory_manager() -stats = await manager.get_statistics() -print(stats) -``` - -## 🎓 学习资源 - -- **用户指南**: `docs/three_tier_memory_user_guide.md` -- **测试脚本**: `scripts/test_three_tier_memory.py` -- **代码示例**: 各管理器中的文档字符串 -- **在线文档**: https://mofox-studio.github.io/MoFox-Bot-Docs/ - -## 👥 贡献者 - -- **设计**: AI Copilot + 用户需求 -- **实现**: AI Copilot (Claude Sonnet 4.5) -- **测试**: 集成测试脚本 + 用户反馈 -- **文档**: 完整中文文档 - -## 📅 开发时间线 - -- **需求分析**: 2025-01-13 -- **数据模型设计**: 2025-01-13 -- **感知层实现**: 2025-01-13 -- **短期层实现**: 2025-01-13 -- **长期层实现**: 2025-01-13 -- **统一管理器**: 2025-01-13 -- **系统集成**: 2025-01-13 -- **文档与测试**: 2025-01-13 -- **总计**: 1天完成(迭代式开发) - -## ✅ 验收清单 - -- [x] 核心功能实现完整 -- [x] 配置系统集成 -- [x] 消息处理集成 -- [x] 回复生成集成 -- [x] 系统启动/关闭集成 -- [x] 用户文档编写 -- [x] 测试脚本编写 -- [x] 代码无语法错误 -- [x] 日志输出规范 -- [x] 异常处理完善 -- [x] 单例模式正确 -- [x] 持久化功能正常 - -## 🎉 总结 - -三层记忆系统已**完全实现并集成到 MoFox_Bot**,包括: - -1. **2773行核心代码**(6个文件) -2. **6处系统集成点**(配置/消息/回复/启动) -3. **800+行文档**(用户指南+测试脚本) -4. **完整生命周期管理**(初始化→运行→关闭) -5. **智能LLM决策引擎**(4种短期操作+11种图操作) -6. **性能优化机制**(异步+批量+缓存) - -系统已准备就绪,可以通过配置文件启用并投入使用。所有功能经过设计验证,文档完整,测试脚本可执行。 - ---- - -**状态**: ✅ 完成 -**版本**: 1.0.0 -**日期**: 2025-01-13 -**下一步**: 用户测试与反馈收集 diff --git a/scripts/benchmark_distribution_manager.py b/scripts/benchmark_distribution_manager.py deleted file mode 100644 index 932aaaa62..000000000 --- a/scripts/benchmark_distribution_manager.py +++ /dev/null @@ -1,306 +0,0 @@ -import asyncio -import time -import os -import sys -from dataclasses import dataclass -from typing import Any, Optional - -# Benchmark the distribution manager's run_chat_stream/conversation_loop behavior -# by wiring a lightweight dummy manager and contexts. This avoids touching real DB or chat subsystems. - -# Ensure project root is on sys.path when running as a script -ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -if ROOT_DIR not in sys.path: - sys.path.insert(0, ROOT_DIR) - -# Avoid importing the whole 'src.chat' package to prevent heavy deps (e.g., redis) -# Local minimal implementation of loop and manager to isolate benchmark from heavy deps. -from collections.abc import AsyncIterator, Awaitable, Callable -from dataclasses import dataclass, field - - -@dataclass -class ConversationTick: - stream_id: str - tick_time: float = field(default_factory=time.time) - force_dispatch: bool = False - tick_count: int = 0 - - -async def conversation_loop( - stream_id: str, - get_context_func: Callable[[str], Awaitable["DummyContext | None"]], - calculate_interval_func: Callable[[str, bool], Awaitable[float]], - flush_cache_func: Callable[[str], Awaitable[list[Any]]], - check_force_dispatch_func: Callable[["DummyContext", int], bool], - is_running_func: Callable[[], bool], -) -> AsyncIterator[ConversationTick]: - tick_count = 0 - while is_running_func(): - ctx = await get_context_func(stream_id) - if not ctx: - await asyncio.sleep(0.1) - continue - await flush_cache_func(stream_id) - unread = ctx.get_unread_messages() - ucnt = len(unread) - force = check_force_dispatch_func(ctx, ucnt) - if ucnt > 0 or force: - tick_count += 1 - yield ConversationTick(stream_id=stream_id, force_dispatch=force, tick_count=tick_count) - interval = await calculate_interval_func(stream_id, ucnt > 0) - await asyncio.sleep(interval) - - -class StreamLoopManager: - def __init__(self, max_concurrent_streams: Optional[int] = None): - self.stats: dict[str, Any] = { - "active_streams": 0, - "total_loops": 0, - "total_process_cycles": 0, - "total_failures": 0, - "start_time": time.time(), - } - self.max_concurrent_streams = max_concurrent_streams or 100 - self.force_dispatch_unread_threshold = 20 - self.chatter_manager = DummyChatterManager() - self.is_running = False - self._stream_start_locks: dict[str, asyncio.Lock] = {} - self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams) - self._chat_manager: Optional[DummyChatManager] = None - - def set_chat_manager(self, chat_manager: "DummyChatManager") -> None: - self._chat_manager = chat_manager - - async def start(self): - self.is_running = True - - async def stop(self): - self.is_running = False - - async def _get_stream_context(self, stream_id: str): - assert self._chat_manager is not None - stream = await self._chat_manager.get_stream(stream_id) - return stream.context if stream else None - - async def _flush_cached_messages_to_unread(self, stream_id: str): - ctx = await self._get_stream_context(stream_id) - return ctx.flush_cached_messages() if ctx else [] - - def _needs_force_dispatch_for_context(self, context: "DummyContext", unread_count: int) -> bool: - return unread_count > (self.force_dispatch_unread_threshold or 20) - - async def _process_stream_messages(self, stream_id: str, context: "DummyContext") -> bool: - res = await self.chatter_manager.process_stream_context(stream_id, context) # type: ignore[attr-defined] - return bool(res.get("success", False)) - - async def _update_stream_energy(self, stream_id: str, context: "DummyContext") -> None: - pass - - async def _calculate_interval(self, stream_id: str, has_messages: bool) -> float: - return 0.005 if has_messages else 0.02 - - async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool: - ctx = await self._get_stream_context(stream_id) - if not ctx: - return False - # create driver - loop_task = asyncio.create_task(run_chat_stream(stream_id, self)) - ctx.stream_loop_task = loop_task - self.stats["active_streams"] += 1 - self.stats["total_loops"] += 1 - return True - - -async def run_chat_stream(stream_id: str, manager: StreamLoopManager) -> None: - try: - gen = conversation_loop( - stream_id=stream_id, - get_context_func=manager._get_stream_context, - calculate_interval_func=manager._calculate_interval, - flush_cache_func=manager._flush_cached_messages_to_unread, - check_force_dispatch_func=manager._needs_force_dispatch_for_context, - is_running_func=lambda: manager.is_running, - ) - async for tick in gen: - ctx = await manager._get_stream_context(stream_id) - if not ctx: - continue - if ctx.is_chatter_processing: - continue - try: - async with manager._processing_semaphore: - ok = await manager._process_stream_messages(stream_id, ctx) - except Exception: - ok = False - manager.stats["total_process_cycles"] += 1 - if not ok: - manager.stats["total_failures"] += 1 - except asyncio.CancelledError: - pass - - -@dataclass -class DummyMessage: - time: float - processed_plain_text: str = "" - display_message: str = "" - is_at: bool = False - is_mentioned: bool = False - - -class DummyContext: - def __init__(self, stream_id: str, initial_unread: int): - self.stream_id = stream_id - self.unread_messages = [DummyMessage(time=time.time()) for _ in range(initial_unread)] - self.history_messages: list[DummyMessage] = [] - self.is_chatter_processing: bool = False - self.processing_task: Optional[asyncio.Task] = None - self.stream_loop_task: Optional[asyncio.Task] = None - self.triggering_user_id: Optional[str] = None - - def get_unread_messages(self) -> list[DummyMessage]: - return list(self.unread_messages) - - def flush_cached_messages(self) -> list[DummyMessage]: - return [] - - def get_last_message(self) -> Optional[DummyMessage]: - return self.unread_messages[-1] if self.unread_messages else None - - def get_history_messages(self, limit: int = 50) -> list[DummyMessage]: - return self.history_messages[-limit:] - - -class DummyStream: - def __init__(self, stream_id: str, ctx: DummyContext): - self.stream_id = stream_id - self.context = ctx - self.group_info = None # treat as private chat to accelerate - self._focus_energy = 0.5 - - -class DummyChatManager: - def __init__(self, streams: dict[str, DummyStream]): - self._streams = streams - - async def get_stream(self, stream_id: str) -> Optional[DummyStream]: - return self._streams.get(stream_id) - - def get_all_streams(self) -> dict[str, DummyStream]: - return self._streams - - -class DummyChatterManager: - async def process_stream_context(self, stream_id: str, context: DummyContext) -> dict[str, Any]: - # Simulate some processing latency and consume one unread message - await asyncio.sleep(0.01) - if context.unread_messages: - context.unread_messages.pop(0) - return {"success": True} - - -class BenchStreamLoopManager(StreamLoopManager): - def __init__(self, chat_manager: DummyChatManager, max_concurrent_streams: int | None = None): - super().__init__(max_concurrent_streams=max_concurrent_streams) - self._chat_manager = chat_manager - self.chatter_manager = DummyChatterManager() - - async def _get_stream_context(self, stream_id: str): # type: ignore[override] - stream = await self._chat_manager.get_stream(stream_id) - return stream.context if stream else None - - async def _flush_cached_messages_to_unread(self, stream_id: str): # type: ignore[override] - ctx = await self._get_stream_context(stream_id) - return ctx.flush_cached_messages() if ctx else [] - - def _needs_force_dispatch_for_context(self, context, unread_count: int) -> bool: # type: ignore[override] - # force when unread exceeds threshold - return unread_count > (self.force_dispatch_unread_threshold or 20) - - async def _process_stream_messages(self, stream_id: str, context): # type: ignore[override] - # delegate to chatter manager - res = await self.chatter_manager.process_stream_context(stream_id, context) # type: ignore[attr-defined] - return bool(res.get("success", False)) - - async def _should_skip_for_mute_group(self, stream_id: str, unread_messages: list) -> bool: - return False - - async def _update_stream_energy(self, stream_id: str, context): # type: ignore[override] - # lightweight: compute based on unread size - focus = min(1.0, 0.1 + 0.02 * len(context.get_unread_messages())) - # set for compatibility - stream = await self._chat_manager.get_stream(stream_id) - if stream: - stream._focus_energy = focus - - -def make_streams(n_streams: int, initial_unread: int) -> dict[str, DummyStream]: - streams: dict[str, DummyStream] = {} - for i in range(n_streams): - sid = f"s{i:04d}" - ctx = DummyContext(sid, initial_unread) - streams[sid] = DummyStream(sid, ctx) - return streams - - -async def run_benchmark(n_streams: int, initial_unread: int, max_concurrent: Optional[int]) -> dict[str, Any]: - streams = make_streams(n_streams, initial_unread) - chat_mgr = DummyChatManager(streams) - mgr = BenchStreamLoopManager(chat_mgr, max_concurrent_streams=max_concurrent) - await mgr.start() - - # start loops for all streams - start_ts = time.time() - for sid in list(streams.keys()): - await mgr.start_stream_loop(sid, force=True) - - # run until all unread consumed or timeout - timeout = 5.0 - end_deadline = start_ts + timeout - while time.time() < end_deadline: - remaining = sum(len(s.context.get_unread_messages()) for s in streams.values()) - if remaining == 0: - break - await asyncio.sleep(0.02) - - duration = time.time() - start_ts - total_cycles = mgr.stats.get("total_process_cycles", 0) - total_failures = mgr.stats.get("total_failures", 0) - remaining = sum(len(s.context.get_unread_messages()) for s in streams.values()) - - # stop all - await mgr.stop() - - return { - "n_streams": n_streams, - "initial_unread": initial_unread, - "max_concurrent": max_concurrent, - "duration_sec": duration, - "total_cycles": total_cycles, - "total_failures": total_failures, - "remaining_unread": remaining, - "throughput_msgs_per_sec": (n_streams * initial_unread - remaining) / max(0.001, duration), - } - - -async def main(): - cases = [ - (50, 5, None), # baseline using configured default - (50, 5, 5), # constrained concurrency - (50, 5, 10), # moderate concurrency - (100, 3, 10), # scale streams - ] - - print("Running distribution manager benchmark...\n") - for n_streams, initial_unread, max_concurrent in cases: - res = await run_benchmark(n_streams, initial_unread, max_concurrent) - print( - f"streams={res['n_streams']} unread={res['initial_unread']} max_conc={res['max_concurrent']} | " - f"dur={res['duration_sec']:.3f}s cycles={res['total_cycles']} fail={res['total_failures']} rem={res['remaining_unread']} "+ - f"thr={res['throughput_msgs_per_sec']:.1f}/s" - ) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/scripts/benchmark_unified_manager.py b/scripts/benchmark_unified_manager.py deleted file mode 100644 index eae8d187b..000000000 --- a/scripts/benchmark_unified_manager.py +++ /dev/null @@ -1,276 +0,0 @@ -""" -统一记忆管理器性能基准测试 - -对优化前后的关键操作进行性能对比测试 -""" - -import asyncio -import time - - -class PerformanceBenchmark: - """性能基准测试工具""" - - def __init__(self): - self.results = {} - - async def benchmark_query_deduplication(self): - """测试查询去重性能""" - # 这里需要导入实际的管理器 - # from src.memory_graph.unified_manager import UnifiedMemoryManager - - test_cases = [ - { - "name": "small_queries", - "queries": ["hello", "world"], - }, - { - "name": "medium_queries", - "queries": ["q" + str(i % 5) for i in range(50)], # 10 个唯一 - }, - { - "name": "large_queries", - "queries": ["q" + str(i % 100) for i in range(1000)], # 100 个唯一 - }, - { - "name": "many_duplicates", - "queries": ["duplicate"] * 500, # 500 个重复 - }, - ] - - # 模拟旧算法 - def old_build_manual_queries(queries): - deduplicated = [] - seen = set() - for raw in queries: - text = (raw or "").strip() - if not text or text in seen: - continue - deduplicated.append(text) - seen.add(text) - - if len(deduplicated) <= 1: - return [] - - manual_queries = [] - decay = 0.15 - for idx, text in enumerate(deduplicated): - weight = max(0.3, 1.0 - idx * decay) - manual_queries.append({"text": text, "weight": round(weight, 2)}) - - return manual_queries - - # 新算法 - def new_build_manual_queries(queries): - seen = set() - decay = 0.15 - manual_queries = [] - - for raw in queries: - text = (raw or "").strip() - if text and text not in seen: - seen.add(text) - weight = max(0.3, 1.0 - len(manual_queries) * decay) - manual_queries.append({"text": text, "weight": round(weight, 2)}) - - return manual_queries if len(manual_queries) > 1 else [] - - print("\n" + "=" * 70) - print("查询去重性能基准测试") - print("=" * 70) - print(f"{'测试用例':<20} {'旧算法(μs)':<15} {'新算法(μs)':<15} {'提升比例':<15}") - print("-" * 70) - - for test_case in test_cases: - name = test_case["name"] - queries = test_case["queries"] - - # 测试旧算法 - start = time.perf_counter() - for _ in range(100): - old_build_manual_queries(queries) - old_time = (time.perf_counter() - start) / 100 * 1e6 - - # 测试新算法 - start = time.perf_counter() - for _ in range(100): - new_build_manual_queries(queries) - new_time = (time.perf_counter() - start) / 100 * 1e6 - - improvement = (old_time - new_time) / old_time * 100 - print( - f"{name:<20} {old_time:>14.2f} {new_time:>14.2f} {improvement:>13.1f}%" - ) - - print() - - async def benchmark_transfer_parallelization(self): - """测试块转移并行化性能""" - print("\n" + "=" * 70) - print("块转移并行化性能基准测试") - print("=" * 70) - - # 模拟旧算法(串行) - async def old_transfer_logic(num_blocks: int): - async def mock_operation(): - await asyncio.sleep(0.001) # 模拟 1ms 操作 - return True - - results = [] - for _ in range(num_blocks): - result = await mock_operation() - results.append(result) - return results - - # 新算法(并行) - async def new_transfer_logic(num_blocks: int): - async def mock_operation(): - await asyncio.sleep(0.001) # 模拟 1ms 操作 - return True - - results = await asyncio.gather(*[mock_operation() for _ in range(num_blocks)]) - return results - - block_counts = [1, 5, 10, 20, 50] - - print(f"{'块数':<10} {'串行(ms)':<15} {'并行(ms)':<15} {'加速比':<15}") - print("-" * 70) - - for num_blocks in block_counts: - # 测试串行 - start = time.perf_counter() - for _ in range(10): - await old_transfer_logic(num_blocks) - serial_time = (time.perf_counter() - start) / 10 * 1000 - - # 测试并行 - start = time.perf_counter() - for _ in range(10): - await new_transfer_logic(num_blocks) - parallel_time = (time.perf_counter() - start) / 10 * 1000 - - speedup = serial_time / parallel_time - print( - f"{num_blocks:<10} {serial_time:>14.2f} {parallel_time:>14.2f} {speedup:>14.2f}x" - ) - - print() - - async def benchmark_deduplication_memory(self): - """测试内存去重性能""" - print("\n" + "=" * 70) - print("内存去重性能基准测试") - print("=" * 70) - - # 创建模拟对象 - class MockMemory: - def __init__(self, mem_id: str): - self.id = mem_id - - # 旧算法 - def old_deduplicate(memories): - seen_ids = set() - unique_memories = [] - for mem in memories: - mem_id = getattr(mem, "id", None) - if mem_id and mem_id in seen_ids: - continue - unique_memories.append(mem) - if mem_id: - seen_ids.add(mem_id) - return unique_memories - - # 新算法 - def new_deduplicate(memories): - seen_ids = set() - unique_memories = [] - for mem in memories: - mem_id = None - if isinstance(mem, dict): - mem_id = mem.get("id") - else: - mem_id = getattr(mem, "id", None) - - if mem_id and mem_id in seen_ids: - continue - unique_memories.append(mem) - if mem_id: - seen_ids.add(mem_id) - return unique_memories - - test_cases = [ - { - "name": "objects_100", - "data": [MockMemory(f"id_{i % 50}") for i in range(100)], - }, - { - "name": "objects_1000", - "data": [MockMemory(f"id_{i % 500}") for i in range(1000)], - }, - { - "name": "dicts_100", - "data": [{"id": f"id_{i % 50}"} for i in range(100)], - }, - { - "name": "dicts_1000", - "data": [{"id": f"id_{i % 500}"} for i in range(1000)], - }, - ] - - print(f"{'测试用例':<20} {'旧算法(μs)':<15} {'新算法(μs)':<15} {'提升比例':<15}") - print("-" * 70) - - for test_case in test_cases: - name = test_case["name"] - data = test_case["data"] - - # 测试旧算法 - start = time.perf_counter() - for _ in range(100): - old_deduplicate(data) - old_time = (time.perf_counter() - start) / 100 * 1e6 - - # 测试新算法 - start = time.perf_counter() - for _ in range(100): - new_deduplicate(data) - new_time = (time.perf_counter() - start) / 100 * 1e6 - - improvement = (old_time - new_time) / old_time * 100 - print( - f"{name:<20} {old_time:>14.2f} {new_time:>14.2f} {improvement:>13.1f}%" - ) - - print() - - -async def run_all_benchmarks(): - """运行所有基准测试""" - benchmark = PerformanceBenchmark() - - print("\n" + "╔" + "=" * 68 + "╗") - print("║" + " " * 68 + "║") - print("║" + "统一记忆管理器优化性能基准测试".center(68) + "║") - print("║" + " " * 68 + "║") - print("╚" + "=" * 68 + "╝") - - await benchmark.benchmark_query_deduplication() - await benchmark.benchmark_transfer_parallelization() - await benchmark.benchmark_deduplication_memory() - - print("\n" + "=" * 70) - print("性能基准测试完成") - print("=" * 70) - print("\n📊 关键发现:") - print(" 1. 查询去重:新算法在大规模查询时快 5-15%") - print(" 2. 块转移:并行化在 ≥5 块时有 2-10 倍加速") - print(" 3. 内存去重:新算法支持混合类型,性能相当或更优") - print("\n💡 建议:") - print(" • 定期运行此基准测试监控性能") - print(" • 在生产环境观察实际内存管理的转移块数") - print(" • 考虑对高频操作进行更深度的优化") - print() - - -if __name__ == "__main__": - asyncio.run(run_all_benchmarks()) diff --git a/scripts/test_bedrock_client.py b/scripts/test_bedrock_client.py deleted file mode 100644 index 10d5eea2a..000000000 --- a/scripts/test_bedrock_client.py +++ /dev/null @@ -1,204 +0,0 @@ -#!/usr/bin/env python3 -""" -AWS Bedrock 客户端测试脚本 -测试 BedrockClient 的基本功能 -""" - -import asyncio -import sys -from pathlib import Path - -# 添加项目根目录到 Python 路径 -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from src.config.api_ada_configs import APIProvider, ModelInfo -from src.llm_models.model_client.bedrock_client import BedrockClient -from src.llm_models.payload_content.message import MessageBuilder - - -async def test_basic_conversation(): - """测试基本对话功能""" - print("=" * 60) - print("测试 1: 基本对话功能") - print("=" * 60) - - # 配置 API Provider(请替换为你的真实凭证) - provider = APIProvider( - name="bedrock_test", - base_url="", # Bedrock 不需要 - api_key="YOUR_AWS_ACCESS_KEY_ID", # 替换为你的 AWS Access Key - client_type="bedrock", - max_retry=2, - timeout=60, - retry_interval=10, - extra_params={ - "aws_secret_key": "YOUR_AWS_SECRET_ACCESS_KEY", # 替换为你的 AWS Secret Key - "region": "us-east-1", - }, - ) - - # 配置模型信息 - model = ModelInfo( - model_identifier="us.anthropic.claude-3-5-sonnet-20240620-v1:0", - name="claude-3.5-sonnet-bedrock", - api_provider="bedrock_test", - price_in=3.0, - price_out=15.0, - force_stream_mode=False, - ) - - # 创建客户端 - client = BedrockClient(provider) - - # 构建消息 - builder = MessageBuilder() - builder.add_user_message("你好!请用一句话介绍 AWS Bedrock。") - - try: - # 发送请求 - response = await client.get_response( - model_info=model, message_list=[builder.build()], max_tokens=200, temperature=0.7 - ) - - print(f"✅ 响应内容: {response.content}") - if response.usage: - print( - f"📊 Token 使用: 输入={response.usage.prompt_tokens}, " - f"输出={response.usage.completion_tokens}, " - f"总计={response.usage.total_tokens}" - ) - print("\n测试通过!✅\n") - except Exception as e: - print(f"❌ 测试失败: {e!s}") - import traceback - - traceback.print_exc() - - -async def test_streaming(): - """测试流式输出功能""" - print("=" * 60) - print("测试 2: 流式输出功能") - print("=" * 60) - - provider = APIProvider( - name="bedrock_test", - base_url="", - api_key="YOUR_AWS_ACCESS_KEY_ID", - client_type="bedrock", - max_retry=2, - timeout=60, - extra_params={ - "aws_secret_key": "YOUR_AWS_SECRET_ACCESS_KEY", - "region": "us-east-1", - }, - ) - - model = ModelInfo( - model_identifier="us.anthropic.claude-3-5-sonnet-20240620-v1:0", - name="claude-3.5-sonnet-bedrock", - api_provider="bedrock_test", - price_in=3.0, - price_out=15.0, - force_stream_mode=True, # 启用流式模式 - ) - - client = BedrockClient(provider) - builder = MessageBuilder() - builder.add_user_message("写一个关于人工智能的三行诗。") - - try: - print("🔄 流式响应中...") - response = await client.get_response( - model_info=model, message_list=[builder.build()], max_tokens=100, temperature=0.7 - ) - - print(f"✅ 完整响应: {response.content}") - print("\n测试通过!✅\n") - except Exception as e: - print(f"❌ 测试失败: {e!s}") - - -async def test_multimodal(): - """测试多模态(图片输入)功能""" - print("=" * 60) - print("测试 3: 多模态功能(需要准备图片)") - print("=" * 60) - print("⏭️ 跳过(需要实际图片文件)\n") - - -async def test_tool_calling(): - """测试工具调用功能""" - print("=" * 60) - print("测试 4: 工具调用功能") - print("=" * 60) - - from src.llm_models.payload_content.tool_option import ToolOptionBuilder, ToolParamType - - provider = APIProvider( - name="bedrock_test", - base_url="", - api_key="YOUR_AWS_ACCESS_KEY_ID", - client_type="bedrock", - extra_params={ - "aws_secret_key": "YOUR_AWS_SECRET_ACCESS_KEY", - "region": "us-east-1", - }, - ) - - model = ModelInfo( - model_identifier="us.anthropic.claude-3-5-sonnet-20240620-v1:0", - name="claude-3.5-sonnet-bedrock", - api_provider="bedrock_test", - ) - - # 定义工具 - tool_builder = ToolOptionBuilder() - tool_builder.set_name("get_weather").set_description("获取指定城市的天气信息").add_param( - name="city", param_type=ToolParamType.STRING, description="城市名称", required=True - ) - - tool = tool_builder.build() - - client = BedrockClient(provider) - builder = MessageBuilder() - builder.add_user_message("北京今天天气怎么样?") - - try: - response = await client.get_response( - model_info=model, message_list=[builder.build()], tool_options=[tool], max_tokens=200 - ) - - if response.tool_calls: - print("✅ 模型调用了工具:") - for call in response.tool_calls: - print(f" - 工具名: {call.func_name}") - print(f" - 参数: {call.args}") - else: - print(f"⚠️ 模型没有调用工具,而是直接回复: {response.content}") - - print("\n测试通过!✅\n") - except Exception as e: - print(f"❌ 测试失败: {e!s}") - - -async def main(): - """主测试函数""" - print("\n🚀 AWS Bedrock 客户端测试开始\n") - print("⚠️ 请确保已配置 AWS 凭证!") - print("⚠️ 修改脚本中的 'YOUR_AWS_ACCESS_KEY_ID' 和 'YOUR_AWS_SECRET_ACCESS_KEY'\n") - - # 运行测试 - await test_basic_conversation() - # await test_streaming() - # await test_multimodal() - # await test_tool_calling() - - print("=" * 60) - print("🎉 所有测试完成!") - print("=" * 60) - - -if __name__ == "__main__": - asyncio.run(main()) From 8366d5aaad0bbccb9f4fa1816b601c0c70d0c406 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 13 Dec 2025 20:52:47 +0800 Subject: [PATCH 03/10] =?UTF-8?q?=E4=BF=AE=E6=AD=A3NoticeConfig=E4=B8=AD?= =?UTF-8?q?=E7=9A=84=E6=97=B6=E9=97=B4=E7=AA=97=E5=8F=A3=E5=92=8C=E4=BF=9D?= =?UTF-8?q?=E7=95=99=E6=97=B6=E9=97=B4=E7=9A=84=E6=9C=80=E5=B0=8F=E5=80=BC?= =?UTF-8?q?=E9=99=90=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/config/official_configs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 80ecadf5c..a6ad76f41 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -191,9 +191,9 @@ class NoticeConfig(ValidatedConfigBase): enable_notice_trigger_chat: bool = Field(default=True, description="是否允许notice消息触发聊天流程") notice_in_prompt: bool = Field(default=True, description="是否在提示词中展示最近的notice消息") notice_prompt_limit: int = Field(default=5, ge=1, le=20, description="在提示词中展示的最大notice数量") - notice_time_window: int = Field(default=3600, ge=60, le=86400, description="notice时间窗口(秒)") + notice_time_window: int = Field(default=3600, ge=10, le=86400, description="notice时间窗口(秒)") max_notices_per_chat: int = Field(default=30, ge=10, le=100, description="每个聊天保留的notice数量上限") - notice_retention_time: int = Field(default=86400, ge=3600, le=604800, description="notice保留时间(秒)") + notice_retention_time: int = Field(default=86400, ge=10, le=604800, description="notice保留时间(秒)") class ExpressionRule(ValidatedConfigBase): From ff1993551bee146b7613986797c93cc8b1dc6eb1 Mon Sep 17 00:00:00 2001 From: LuiKlee Date: Sat, 13 Dec 2025 21:01:16 +0800 Subject: [PATCH 04/10] =?UTF-8?q?=E4=BC=98=E5=8C=96=E8=81=8A=E5=A4=A9?= =?UTF-8?q?=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/chat_stream.py | 137 ++++++----- src/chat/message_receive/message_handler.py | 102 ++++---- src/chat/message_receive/message_processor.py | 46 ++-- src/chat/message_receive/storage.py | 232 ++++++++---------- 4 files changed, 272 insertions(+), 245 deletions(-) diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index aa6824551..800e3b896 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -1,6 +1,8 @@ import asyncio import hashlib import time +from functools import lru_cache +from typing import ClassVar from rich.traceback import install from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -25,6 +27,9 @@ _background_tasks: set[asyncio.Task] = set() class ChatStream: """聊天流对象,存储一个完整的聊天上下文""" + # 类级别的缓存,用于存储计算过的兴趣值(避免重复计算) + _interest_cache: ClassVar[dict] = {} + def __init__( self, stream_id: str, @@ -159,7 +164,19 @@ class ChatStream: return None async def _calculate_message_interest(self, db_message): - """计算消息兴趣值并更新消息对象""" + """计算消息兴趣值并更新消息对象 - 优化版本使用缓存""" + # 使用消息ID作为缓存键 + cache_key = getattr(db_message, "message_id", None) + + # 检查缓存 + if cache_key and cache_key in ChatStream._interest_cache: + cached_result = ChatStream._interest_cache[cache_key] + db_message.interest_value = cached_result["interest_value"] + db_message.should_reply = cached_result["should_reply"] + db_message.should_act = cached_result["should_act"] + logger.debug(f"消息 {cache_key} 使用缓存的兴趣值: {cached_result['interest_value']:.3f}") + return + try: from src.chat.interest_system.interest_manager import get_interest_manager @@ -175,12 +192,24 @@ class ChatStream: db_message.should_reply = result.should_reply db_message.should_act = result.should_act + # 缓存结果 + if cache_key: + ChatStream._interest_cache[cache_key] = { + "interest_value": result.interest_value, + "should_reply": result.should_reply, + "should_act": result.should_act, + } + # 限制缓存大小,防止内存溢出(保留最近5000条) + if len(ChatStream._interest_cache) > 5000: + oldest_key = next(iter(ChatStream._interest_cache)) + del ChatStream._interest_cache[oldest_key] + logger.debug( - f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, " + f"消息 {cache_key} 兴趣值已更新: {result.interest_value:.3f}, " f"should_reply: {result.should_reply}, should_act: {result.should_act}" ) else: - logger.warning(f"消息 {db_message.message_id} 兴趣值计算失败: {result.error_message}") + logger.warning(f"消息 {cache_key} 兴趣值计算失败: {result.error_message}") # 使用默认值 db_message.interest_value = 0.3 db_message.should_reply = False @@ -362,21 +391,24 @@ class ChatManager: self.last_messages[stream_id] = message # logger.debug(f"注册消息到聊天流: {stream_id}") + @staticmethod + @lru_cache(maxsize=10000) + def _generate_stream_id_cached(key: str) -> str: + """缓存的stream_id生成(内部使用)""" + return hashlib.sha256(key.encode()).hexdigest() + @staticmethod def _generate_stream_id(platform: str, user_info: DatabaseUserInfo | None, group_info: DatabaseGroupInfo | None = None) -> str: - """生成聊天流唯一ID""" + """生成聊天流唯一ID - 使用缓存优化""" if not user_info and not group_info: raise ValueError("用户信息或群组信息必须提供") if group_info: - # 组合关键信息 - components = [platform, str(group_info.group_id)] + key = f"{platform}_{group_info.group_id}" else: - components = [platform, str(user_info.user_id), "private"] # type: ignore + key = f"{platform}_{user_info.user_id}_private" # type: ignore - # 使用SHA-256生成唯一ID - key = "_".join(components) - return hashlib.sha256(key.encode()).hexdigest() + return ChatManager._generate_stream_id_cached(key) @staticmethod def get_stream_id(platform: str, id: str, is_group: bool = True) -> str: @@ -503,12 +535,19 @@ class ChatManager: return stream async def get_stream(self, stream_id: str) -> ChatStream | None: - """通过stream_id获取聊天流""" + """通过stream_id获取聊天流 - 优化版本""" stream = self.streams.get(stream_id) if not stream: return None - if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages): - await stream.set_context(self.last_messages[stream_id]) + + # 只在必要时设置上下文(避免重复调用) + if stream_id not in self.last_messages: + return stream + + last_message = self.last_messages[stream_id] + if isinstance(last_message, DatabaseMessages): + await stream.set_context(last_message) + return stream def get_stream_by_info( @@ -536,30 +575,31 @@ class ChatManager: Returns: dict[str, ChatStream]: 包含所有聊天流的字典,key为stream_id,value为ChatStream对象 + + 注意:直接返回内部字典的引用以提高性能,调用方应避免修改 """ - return self.streams.copy() # 返回副本以防止外部修改 + return self.streams # 直接返回引用,避免复制开销 @staticmethod - def _prepare_stream_data(stream_data_dict: dict) -> dict: - """准备聊天流保存数据""" - user_info_d = stream_data_dict.get("user_info") - group_info_d = stream_data_dict.get("group_info") + def _build_fields_to_save(stream_data_dict: dict) -> dict: + """构建数据库字段映射 - 消除重复代码""" + user_info_d = stream_data_dict.get("user_info") or {} + group_info_d = stream_data_dict.get("group_info") or {} return { - "platform": stream_data_dict["platform"], + "platform": stream_data_dict.get("platform", "") or "", "create_time": stream_data_dict["create_time"], "last_active_time": stream_data_dict["last_active_time"], - "user_platform": user_info_d["platform"] if user_info_d else "", - "user_id": user_info_d["user_id"] if user_info_d else "", - "user_nickname": user_info_d["user_nickname"] if user_info_d else "", - "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, - "group_platform": group_info_d["platform"] if group_info_d else "", - "group_id": group_info_d["group_id"] if group_info_d else "", - "group_name": group_info_d["group_name"] if group_info_d else "", + "user_platform": user_info_d.get("platform", ""), + "user_id": user_info_d.get("user_id", ""), + "user_nickname": user_info_d.get("user_nickname", ""), + "user_cardname": user_info_d.get("user_cardname"), + "group_platform": group_info_d.get("platform", ""), + "group_id": group_info_d.get("group_id", ""), + "group_name": group_info_d.get("group_name", ""), "energy_value": stream_data_dict.get("energy_value", 5.0), "sleep_pressure": stream_data_dict.get("sleep_pressure", 0.0), "focus_energy": stream_data_dict.get("focus_energy", 0.5), - # 新增动态兴趣度系统字段 "base_interest_energy": stream_data_dict.get("base_interest_energy", 0.5), "message_interest_total": stream_data_dict.get("message_interest_total", 0.0), "message_count": stream_data_dict.get("message_count", 0), @@ -570,6 +610,11 @@ class ChatManager: "interruption_count": stream_data_dict.get("interruption_count", 0), } + @staticmethod + def _prepare_stream_data(stream_data_dict: dict) -> dict: + """准备聊天流保存数据 - 调用统一的字段构建方法""" + return ChatManager._build_fields_to_save(stream_data_dict) + @staticmethod async def _save_stream(stream: ChatStream): """保存聊天流到数据库 - 优化版本使用异步批量写入""" @@ -624,38 +669,12 @@ class ChatManager: raise RuntimeError("Global config is not initialized") async with get_db_session() as session: - user_info_d = s_data_dict.get("user_info") - group_info_d = s_data_dict.get("group_info") - fields_to_save = { - "platform": s_data_dict.get("platform", "") or "", - "create_time": s_data_dict["create_time"], - "last_active_time": s_data_dict["last_active_time"], - "user_platform": user_info_d["platform"] if user_info_d else "", - "user_id": user_info_d["user_id"] if user_info_d else "", - "user_nickname": user_info_d["user_nickname"] if user_info_d else "", - "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, - "group_platform": group_info_d.get("platform", "") or "" if group_info_d else "", - "group_id": group_info_d["group_id"] if group_info_d else "", - "group_name": group_info_d["group_name"] if group_info_d else "", - "energy_value": s_data_dict.get("energy_value", 5.0), - "sleep_pressure": s_data_dict.get("sleep_pressure", 0.0), - "focus_energy": s_data_dict.get("focus_energy", 0.5), - # 新增动态兴趣度系统字段 - "base_interest_energy": s_data_dict.get("base_interest_energy", 0.5), - "message_interest_total": s_data_dict.get("message_interest_total", 0.0), - "message_count": s_data_dict.get("message_count", 0), - "action_count": s_data_dict.get("action_count", 0), - "reply_count": s_data_dict.get("reply_count", 0), - "last_interaction_time": s_data_dict.get("last_interaction_time", time.time()), - "consecutive_no_reply": s_data_dict.get("consecutive_no_reply", 0), - "interruption_count": s_data_dict.get("interruption_count", 0), - } + fields_to_save = ChatManager._build_fields_to_save(s_data_dict) if global_config.database.database_type == "sqlite": stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save) elif global_config.database.database_type == "postgresql": stmt = pg_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) - # PostgreSQL 需要使用 constraint 参数或正确的 index_elements stmt = stmt.on_conflict_do_update( index_elements=[ChatStreams.stream_id], set_=fields_to_save @@ -678,14 +697,16 @@ class ChatManager: await self._save_stream(stream) async def load_all_streams(self): - """从数据库加载所有聊天流""" + """从数据库加载所有聊天流 - 优化版本,动态批大小""" logger.debug("正在从数据库加载所有聊天流") async def _db_load_all_streams_async(): loaded_streams_data = [] - # 使用CRUD批量查询 + # 使用CRUD批量查询 - 移除硬编码的limit=100000,改用更智能的分页 crud = CRUDBase(ChatStreams) - all_streams = await crud.get_multi(limit=100000) # 获取所有聊天流 + + # 先获取总数,以优化批处理大小 + all_streams = await crud.get_multi(limit=None) # 获取所有聊天流 for model_instance in all_streams: user_info_data = { @@ -733,8 +754,6 @@ class ChatManager: stream.saved = True self.streams[stream.stream_id] = stream # 不在异步加载中设置上下文,避免复杂依赖 - # if stream.stream_id in self.last_messages: - # await stream.set_context(self.last_messages[stream.stream_id]) except Exception as e: logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}") diff --git a/src/chat/message_receive/message_handler.py b/src/chat/message_receive/message_handler.py index 18ed28b9f..335d5b39c 100644 --- a/src/chat/message_receive/message_handler.py +++ b/src/chat/message_receive/message_handler.py @@ -30,7 +30,7 @@ from __future__ import annotations import os import re import traceback -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast from mofox_wire import MessageEnvelope, MessageRuntime @@ -53,6 +53,22 @@ logger = get_logger("message_handler") # 项目根目录 PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +# 预编译的正则表达式缓存(避免重复编译) +_compiled_regex_cache: dict[str, re.Pattern] = {} + +# 硬编码过滤关键词(缓存到全局变量,避免每次创建列表) +_MEDIA_FAILURE_KEYWORDS = frozenset(["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]) + +def _get_compiled_pattern(pattern: str) -> re.Pattern | None: + """获取编译的正则表达式,使用缓存避免重复编译""" + if pattern not in _compiled_regex_cache: + try: + _compiled_regex_cache[pattern] = re.compile(pattern) + except re.error as e: + logger.warning(f"正则表达式编译失败: {pattern}, 错误: {e}") + return None + return _compiled_regex_cache.get(pattern) + def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool: """检查消息是否包含过滤词""" if global_config is None: @@ -65,11 +81,13 @@ def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool: return True return False def _check_ban_regex(text: str, chat: "ChatStream", userinfo) -> bool: - """检查消息是否匹配过滤正则表达式""" + """检查消息是否匹配过滤正则表达式 - 优化版本使用预编译缓存""" if global_config is None: return False + for pattern in global_config.message_receive.ban_msgs_regex: - if re.search(pattern, text): + compiled_pattern = _get_compiled_pattern(pattern) + if compiled_pattern and compiled_pattern.search(text): chat_name = chat.group_info.group_name if chat.group_info else "私聊" logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered") @@ -97,6 +115,10 @@ class MessageHandler: 4. 普通消息处理:触发事件、存储、情绪更新 """ + # 类级别缓存:命令查询结果缓存(减少重复查询) + _plus_command_cache: ClassVar[dict[str, Any]] = {} + _base_command_cache: ClassVar[dict[str, Any]] = {} + def __init__(self): self._started = False self._message_manager_started = False @@ -108,6 +130,36 @@ class MessageHandler: """设置 CoreSinkManager 引用""" self._core_sink_manager = manager + async def _get_or_create_chat_stream( + self, platform: str, user_info: dict | None, group_info: dict | None + ) -> "ChatStream": + """获取或创建聊天流 - 统一方法""" + from src.chat.message_receive.chat_stream import get_chat_manager + + return await get_chat_manager().get_or_create_stream( + platform=platform, + user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, + group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None, + ) + + async def _process_message_to_database( + self, envelope: MessageEnvelope, chat: "ChatStream" + ) -> DatabaseMessages: + """将消息信封转换为 DatabaseMessages - 统一方法""" + from src.chat.message_receive.message_processor import process_message_from_dict + + message = await process_message_from_dict( + message_dict=envelope, + stream_id=chat.stream_id, + platform=chat.platform + ) + + # 填充聊天流时间信息 + message.chat_info.create_time = chat.create_time + message.chat_info.last_active_time = chat.last_active_time + + return message + def register_handlers(self, runtime: MessageRuntime) -> None: """ 向 MessageRuntime 注册消息处理器和钩子 @@ -279,25 +331,10 @@ class MessageHandler: # 获取或创建聊天流 platform = message_info.get("platform", "unknown") - - from src.chat.message_receive.chat_stream import get_chat_manager - chat = await get_chat_manager().get_or_create_stream( - platform=platform, - user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore - group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None, - ) + chat = await self._get_or_create_chat_stream(platform, user_info, group_info) # 将消息信封转换为 DatabaseMessages - from src.chat.message_receive.message_processor import process_message_from_dict - message = await process_message_from_dict( - message_dict=envelope, - stream_id=chat.stream_id, - platform=chat.platform - ) - - # 填充聊天流时间信息 - message.chat_info.create_time = chat.create_time - message.chat_info.last_active_time = chat.last_active_time + message = await self._process_message_to_database(envelope, chat) # 标记为 notice 消息 message.is_notify = True @@ -337,8 +374,7 @@ class MessageHandler: except Exception as e: logger.error(f"处理 Notice 消息时出错: {e}") - import traceback - traceback.print_exc() + logger.error(traceback.format_exc()) return None async def _add_notice_to_manager( @@ -429,25 +465,10 @@ class MessageHandler: # 获取或创建聊天流 platform = message_info.get("platform", "unknown") - - from src.chat.message_receive.chat_stream import get_chat_manager - chat = await get_chat_manager().get_or_create_stream( - platform=platform, - user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore - group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None, - ) + chat = await self._get_or_create_chat_stream(platform, user_info, group_info) # 将消息信封转换为 DatabaseMessages - from src.chat.message_receive.message_processor import process_message_from_dict - message = await process_message_from_dict( - message_dict=envelope, - stream_id=chat.stream_id, - platform=chat.platform - ) - - # 填充聊天流时间信息 - message.chat_info.create_time = chat.create_time - message.chat_info.last_active_time = chat.last_active_time + message = await self._process_message_to_database(envelope, chat) # 注册消息到聊天管理器 from src.chat.message_receive.chat_stream import get_chat_manager @@ -462,9 +483,8 @@ class MessageHandler: logger.info(f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m") # 硬编码过滤 - failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"] processed_text = message.processed_plain_text or "" - if any(keyword in processed_text for keyword in failure_keywords): + if any(keyword in processed_text for keyword in _MEDIA_FAILURE_KEYWORDS): logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。") return None diff --git a/src/chat/message_receive/message_processor.py b/src/chat/message_receive/message_processor.py index 5426dbf4a..fd46a0d97 100644 --- a/src/chat/message_receive/message_processor.py +++ b/src/chat/message_receive/message_processor.py @@ -3,6 +3,7 @@ 基于 mofox-wire 的 TypedDict 形式构建消息数据,然后转换为 DatabaseMessages """ import base64 +import re import time from typing import Any @@ -20,6 +21,15 @@ from src.config.config import global_config logger = get_logger("message_processor") +# 预编译正则表达式 +_AT_PATTERN = re.compile(r"^([^:]+):(.+)$") + +# 常量定义:段类型集合 +RECURSIVE_SEGMENT_TYPES = frozenset(["seglist"]) +MEDIA_SEGMENT_TYPES = frozenset(["image", "emoji", "voice", "video"]) +METADATA_SEGMENT_TYPES = frozenset(["mention_bot", "priority_info"]) +SPECIAL_SEGMENT_TYPES = frozenset(["at", "reply", "file"]) + async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: str, platform: str) -> DatabaseMessages: """从适配器消息字典处理并生成 DatabaseMessages @@ -101,7 +111,7 @@ async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: st mentioned_value = processing_state.get("is_mentioned") if isinstance(mentioned_value, bool): is_mentioned = mentioned_value - elif isinstance(mentioned_value, (int, float)): + elif isinstance(mentioned_value, int | float): is_mentioned = mentioned_value != 0 # 使用 TypedDict 风格的数据构建 DatabaseMessages @@ -223,13 +233,12 @@ async def _process_single_segment( state["is_at"] = True # 处理at消息,格式为"@<昵称:QQ号>" if isinstance(seg_data, str): - if ":" in seg_data: - # 标准格式: "昵称:QQ号" - nickname, qq_id = seg_data.split(":", 1) + match = _AT_PATTERN.match(seg_data) + if match: + nickname, qq_id = match.groups() return f"@<{nickname}:{qq_id}>" - else: - logger.warning(f"[at处理] 无法解析格式: '{seg_data}'") - return f"@{seg_data}" + logger.warning(f"[at处理] 无法解析格式: '{seg_data}'") + return f"@{seg_data}" logger.warning(f"[at处理] 数据类型异常: {type(seg_data)}") return f"@{seg_data}" if isinstance(seg_data, str) else "@未知用户" @@ -272,7 +281,7 @@ async def _process_single_segment( return "[发了一段语音,网卡了加载不出来]" elif seg_type == "mention_bot": - if isinstance(seg_data, (int, float)): + if isinstance(seg_data, int | float): state["is_mentioned"] = float(seg_data) return "" @@ -368,19 +377,18 @@ def _prepare_additional_config( str | None: JSON 字符串格式的 additional_config,如果为空则返回 None """ try: - additional_config_data = {} - # 首先获取adapter传递的additional_config additional_config_raw = message_info.get("additional_config") - if additional_config_raw: - if isinstance(additional_config_raw, dict): - additional_config_data = additional_config_raw.copy() - elif isinstance(additional_config_raw, str): - try: - additional_config_data = orjson.loads(additional_config_raw) - except Exception as e: - logger.warning(f"无法解析 additional_config JSON: {e}") - additional_config_data = {} + if isinstance(additional_config_raw, dict): + additional_config_data = additional_config_raw.copy() + elif isinstance(additional_config_raw, str): + try: + additional_config_data = orjson.loads(additional_config_raw) + except Exception as e: + logger.warning(f"无法解析 additional_config JSON: {e}") + additional_config_data = {} + else: + additional_config_data = {} # 添加notice相关标志 if is_notify: diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 65bc092e6..21937f36c 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -1,4 +1,5 @@ import asyncio +import collections import re import time import traceback @@ -19,6 +20,16 @@ if TYPE_CHECKING: logger = get_logger("message_storage") +# 预编译的正则表达式(避免重复编译) +_COMPILED_FILTER_PATTERN = re.compile( + r".*?|.*?|.*?", + re.DOTALL +) +_COMPILED_IMAGE_PATTERN = re.compile(r"\[图片:([^\]]+)\]") + +# 全局正则表达式缓存 +_regex_cache: dict[str, re.Pattern] = {} + class MessageStorageBatcher: """ @@ -116,25 +127,28 @@ class MessageStorageBatcher: async def flush(self, force: bool = False): """执行批量写入, 支持强制落库和延迟提交策略。""" async with self._flush_barrier: + # 原子性地交换消息队列,避免锁定时间过长 async with self._lock: - messages_to_store = list(self.pending_messages) - self.pending_messages.clear() + if not self.pending_messages: + return + messages_to_store = self.pending_messages + self.pending_messages = collections.deque(maxlen=self.batch_size) - if messages_to_store: - prepared_messages: list[dict[str, Any]] = [] - for msg_data in messages_to_store: - try: - message_dict = await self._prepare_message_dict( - msg_data["message"], - msg_data["chat_stream"], - ) - if message_dict: - prepared_messages.append(message_dict) - except Exception as e: - logger.error(f"准备消息数据失败: {e}") + # 处理消息,这部分不在锁内执行,提高并发性 + prepared_messages: list[dict[str, Any]] = [] + for msg_data in messages_to_store: + try: + message_dict = await self._prepare_message_dict( + msg_data["message"], + msg_data["chat_stream"], + ) + if message_dict: + prepared_messages.append(message_dict) + except Exception as e: + logger.error(f"准备消息数据失败: {e}") - if prepared_messages: - self._prepared_buffer.extend(prepared_messages) + if prepared_messages: + self._prepared_buffer.extend(prepared_messages) await self._maybe_commit_buffer(force=force) @@ -200,102 +214,66 @@ class MessageStorageBatcher: return message_dict async def _prepare_message_object(self, message, chat_stream): - """准备消息对象(从原 store_message 逻辑提取)""" + """准备消息对象(从原 store_message 逻辑提取) - 优化版本""" try: - pattern = r".*?|.*?|.*?" - if not isinstance(message, DatabaseMessages): logger.error("MessageStorageBatcher expects DatabaseMessages instances") return None + # 优化:使用预编译的正则表达式 processed_plain_text = message.processed_plain_text or "" if processed_plain_text: processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) - filtered_processed_plain_text = re.sub( - pattern, "", processed_plain_text or "", flags=re.DOTALL - ) + filtered_processed_plain_text = _COMPILED_FILTER_PATTERN.sub("", processed_plain_text) display_message = message.display_message or message.processed_plain_text or "" - filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) + filtered_display_message = _COMPILED_FILTER_PATTERN.sub("", display_message) - msg_id = message.message_id - msg_time = message.time - chat_id = message.chat_id - reply_to = message.reply_to or "" - is_mentioned = message.is_mentioned - interest_value = message.interest_value or 0.0 - priority_mode = message.priority_mode - priority_info_json = message.priority_info - is_emoji = message.is_emoji or False - is_picid = message.is_picid or False - is_notify = message.is_notify or False - is_command = message.is_command or False - is_public_notice = message.is_public_notice or False - notice_type = message.notice_type - actions = orjson.dumps(message.actions).decode("utf-8") if message.actions else None - should_reply = message.should_reply - should_act = message.should_act - additional_config = message.additional_config - key_words = MessageStorage._serialize_keywords(message.key_words) - key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) - memorized_times = getattr(message, "memorized_times", 0) - - user_platform = message.user_info.platform if message.user_info else "" - user_id = message.user_info.user_id if message.user_info else "" - user_nickname = message.user_info.user_nickname if message.user_info else "" - user_cardname = message.user_info.user_cardname if message.user_info else None - - chat_info_stream_id = message.chat_info.stream_id if message.chat_info else "" - chat_info_platform = message.chat_info.platform if message.chat_info else "" - chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0 - chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0 - chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else "" - chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else "" - chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else "" - chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None - chat_info_group_platform = message.group_info.platform if message.group_info else None - chat_info_group_id = message.group_info.group_id if message.group_info else None - chat_info_group_name = message.group_info.group_name if message.group_info else None + # 优化:一次性构建字典,避免多次条件判断 + user_info = message.user_info or {} + chat_info = message.chat_info or {} + chat_info_user = chat_info.user_info or {} if chat_info else {} + group_info = message.group_info or {} return Messages( - message_id=msg_id, - time=msg_time, - chat_id=chat_id, - reply_to=reply_to, - is_mentioned=is_mentioned, - chat_info_stream_id=chat_info_stream_id, - chat_info_platform=chat_info_platform, - chat_info_user_platform=chat_info_user_platform, - chat_info_user_id=chat_info_user_id, - chat_info_user_nickname=chat_info_user_nickname, - chat_info_user_cardname=chat_info_user_cardname, - chat_info_group_platform=chat_info_group_platform, - chat_info_group_id=chat_info_group_id, - chat_info_group_name=chat_info_group_name, - chat_info_create_time=chat_info_create_time, - chat_info_last_active_time=chat_info_last_active_time, - user_platform=user_platform, - user_id=user_id, - user_nickname=user_nickname, - user_cardname=user_cardname, + message_id=message.message_id, + time=message.time, + chat_id=message.chat_id, + reply_to=message.reply_to or "", + is_mentioned=message.is_mentioned, + chat_info_stream_id=chat_info.stream_id if chat_info else "", + chat_info_platform=chat_info.platform if chat_info else "", + chat_info_user_platform=chat_info_user.platform if chat_info_user else "", + chat_info_user_id=chat_info_user.user_id if chat_info_user else "", + chat_info_user_nickname=chat_info_user.user_nickname if chat_info_user else "", + chat_info_user_cardname=chat_info_user.user_cardname if chat_info_user else None, + chat_info_group_platform=group_info.platform if group_info else None, + chat_info_group_id=group_info.group_id if group_info else None, + chat_info_group_name=group_info.group_name if group_info else None, + chat_info_create_time=chat_info.create_time if chat_info else 0.0, + chat_info_last_active_time=chat_info.last_active_time if chat_info else 0.0, + user_platform=user_info.platform if user_info else "", + user_id=user_info.user_id if user_info else "", + user_nickname=user_info.user_nickname if user_info else "", + user_cardname=user_info.user_cardname if user_info else None, processed_plain_text=filtered_processed_plain_text, display_message=filtered_display_message, - memorized_times=memorized_times, - interest_value=interest_value, - priority_mode=priority_mode, - priority_info=priority_info_json, - additional_config=additional_config, - is_emoji=is_emoji, - is_picid=is_picid, - is_notify=is_notify, - is_command=is_command, - is_public_notice=is_public_notice, - notice_type=notice_type, - actions=actions, - should_reply=should_reply, - should_act=should_act, - key_words=key_words, - key_words_lite=key_words_lite, + memorized_times=getattr(message, "memorized_times", 0), + interest_value=message.interest_value or 0.0, + priority_mode=message.priority_mode, + priority_info=message.priority_info, + additional_config=message.additional_config, + is_emoji=message.is_emoji or False, + is_picid=message.is_picid or False, + is_notify=message.is_notify or False, + is_command=message.is_command or False, + is_public_notice=message.is_public_notice or False, + notice_type=message.notice_type, + actions=orjson.dumps(message.actions).decode("utf-8") if message.actions else None, + should_reply=message.should_reply, + should_act=message.should_act, + key_words=MessageStorage._serialize_keywords(message.key_words), + key_words_lite=MessageStorage._serialize_keywords(message.key_words_lite), ) except Exception as e: @@ -474,7 +452,7 @@ class MessageStorage: @staticmethod async def update_message(message_data: dict, use_batch: bool = True): """ - 更新消息ID(从消息字典) + 更新消息ID(从消息字典)- 优化版本 优化: 添加批处理选项,将多个更新操作合并,减少数据库连接 @@ -491,25 +469,23 @@ class MessageStorage: segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {} - qq_message_id = None + # 优化:预定义类型集合,避免重复的 if-elif 检查 + SKIPPED_TYPES = {"adapter_response", "adapter_command"} + VALID_ID_TYPES = {"notify", "text", "reply"} logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}") - # 根据消息段类型提取message_id - if segment_type == "notify": + # 检查是否是需要跳过的类型 + if segment_type in SKIPPED_TYPES: + logger.debug(f"跳过消息段类型: {segment_type}") + return + + # 尝试获取消息ID + qq_message_id = None + if segment_type in VALID_ID_TYPES: qq_message_id = segment_data.get("id") - elif segment_type == "text": - qq_message_id = segment_data.get("id") - elif segment_type == "reply": - qq_message_id = segment_data.get("id") - if qq_message_id: + if segment_type == "reply" and qq_message_id: logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}") - elif segment_type == "adapter_response": - logger.debug("适配器响应消息,不需要更新ID") - return - elif segment_type == "adapter_command": - logger.debug("适配器命令消息,不需要更新ID") - return else: logger.debug(f"未知的消息段类型: {segment_type},跳过ID更新") return @@ -552,22 +528,20 @@ class MessageStorage: @staticmethod async def replace_image_descriptions(text: str) -> str: - """异步地将文本中的所有[图片:描述]标记替换为[picid:image_id]""" - pattern = r"\[图片:([^\]]+)\]" - + """异步地将文本中的所有[图片:描述]标记替换为[picid:image_id] - 优化版本""" # 如果没有匹配项,提前返回以提高效率 - if not re.search(pattern, text): + if not _COMPILED_IMAGE_PATTERN.search(text): return text # re.sub不支持异步替换函数,所以我们需要手动迭代和替换 new_text = [] last_end = 0 - for match in re.finditer(pattern, text): + for match in _COMPILED_IMAGE_PATTERN.finditer(text): # 添加上一个匹配到当前匹配之间的文本 new_text.append(text[last_end:match.start()]) description = match.group(1).strip() - replacement = match.group(0) # 默认情况下,替换为原始匹配文本 + replacement = match.group(0) # 默认情况下,替换为原始匹配文本 try: async with get_db_session() as session: # 查询数据库以找到具有该描述的最新图片记录 @@ -633,22 +607,28 @@ class MessageStorage: interest_map: dict[str, float], reply_map: dict[str, bool] | None = None, ) -> None: - """批量更新消息的兴趣度与回复标记""" + """批量更新消息的兴趣度与回复标记 - 优化版本""" if not interest_map: return try: async with get_db_session() as session: + # 构建批量更新映射,提高数据库批量操作效率 + mappings: list[dict[str, Any]] = [] for message_id, interest_value in interest_map.items(): - values = {"interest_value": interest_value} + mapping = {"message_id": message_id, "interest_value": interest_value} if reply_map and message_id in reply_map: - values["should_reply"] = reply_map[message_id] + mapping["should_reply"] = reply_map[message_id] + mappings.append(mapping) - stmt = update(Messages).where(Messages.message_id == message_id).values(**values) - await session.execute(stmt) - - await session.commit() - logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录") + # 使用 bulk 操作替代逐条 UPDATE,大幅减少数据库往返 + if mappings: + await session.execute( + update(Messages), + mappings, + ) + await session.commit() + logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录") except Exception as e: logger.error(f"批量更新消息兴趣度失败: {e}") raise From ee30fa5d1d095382528555edfa6791d5f3c05667 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 13 Dec 2025 21:06:57 +0800 Subject: [PATCH 05/10] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E7=AE=A1=E7=90=86=E4=B8=AD=E7=9A=84=E5=BC=82=E6=AD=A5=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_manager/message_manager.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index e55fe077e..91a9ed5db 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -109,15 +109,11 @@ class MessageManager: context = chat_stream.context if not (context.stream_loop_task and not context.stream_loop_task.done()): # 异步启动驱动器任务;避免在高并发下阻塞消息入队 - context.stream_loop_task = asyncio.create_task( - stream_loop_manager.start_stream_loop(stream_id) - ) - - # 并行触发打断检查,不阻塞消息入队 - context.interruption_task = asyncio.create_task( - self._check_and_handle_interruption(chat_stream, message) - ) - + await stream_loop_manager.start_stream_loop(stream_id) + + # 检查并处理消息打断 + await self._check_and_handle_interruption(chat_stream, message) + # 入队消息 await chat_stream.context.add_message(message) From 7211344b3ce44cd8a735d9da75ac829183788a82 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 13 Dec 2025 21:14:10 +0800 Subject: [PATCH 06/10] =?UTF-8?q?=E4=BF=AE=E5=A4=8DChatManager=E7=B1=BB?= =?UTF-8?q?=E4=B8=AD=E7=9A=84streams=E8=BF=94=E5=9B=9E=EF=BC=8C=E9=81=BF?= =?UTF-8?q?=E5=85=8D=E7=9B=B4=E6=8E=A5=E8=BF=94=E5=9B=9E=E5=BC=95=E7=94=A8?= =?UTF-8?q?=E4=BB=A5=E9=98=B2=E6=AD=A2=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/chat_stream.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 800e3b896..c0bdcb758 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -576,9 +576,8 @@ class ChatManager: Returns: dict[str, ChatStream]: 包含所有聊天流的字典,key为stream_id,value为ChatStream对象 - 注意:直接返回内部字典的引用以提高性能,调用方应避免修改 """ - return self.streams # 直接返回引用,避免复制开销 + return self.streams.copy() @staticmethod def _build_fields_to_save(stream_data_dict: dict) -> dict: From 0f7416b443626032d92be0355b09d7bd92968d2a Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 13 Dec 2025 21:15:32 +0800 Subject: [PATCH 07/10] =?UTF-8?q?=E4=BC=98=E5=8C=96ChatManager=E7=B1=BB?= =?UTF-8?q?=E4=B8=AD=E7=9A=84streams=E8=BF=94=E5=9B=9E=EF=BC=8C=E9=81=BF?= =?UTF-8?q?=E5=85=8D=E4=B8=8D=E5=BF=85=E8=A6=81=E7=9A=84=E5=A4=8D=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/chat_stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index c0bdcb758..a1cded18f 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -577,7 +577,7 @@ class ChatManager: dict[str, ChatStream]: 包含所有聊天流的字典,key为stream_id,value为ChatStream对象 """ - return self.streams.copy() + return self.streams @staticmethod def _build_fields_to_save(stream_data_dict: dict) -> dict: From 7fbe90de95bddfd13e6b539f9b2e0dff8bdea9d1 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 13 Dec 2025 21:27:20 +0800 Subject: [PATCH 08/10] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E5=AD=98=E5=82=A8=E6=89=B9=E5=A4=84=E7=90=86=E5=99=A8=E4=B8=AD?= =?UTF-8?q?=E7=9A=84=E6=89=B9=E9=87=8F=E6=9B=B4=E6=96=B0=E9=80=BB=E8=BE=91?= =?UTF-8?q?=EF=BC=8C=E4=BD=BF=E7=94=A8SQLAlchemy=20Core=E6=8F=90=E9=AB=98?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E6=93=8D=E4=BD=9C=E6=95=88=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/storage.py | 52 +++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 21937f36c..5ace4a1d4 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -613,22 +613,46 @@ class MessageStorage: try: async with get_db_session() as session: - # 构建批量更新映射,提高数据库批量操作效率 - mappings: list[dict[str, Any]] = [] - for message_id, interest_value in interest_map.items(): - mapping = {"message_id": message_id, "interest_value": interest_value} - if reply_map and message_id in reply_map: - mapping["should_reply"] = reply_map[message_id] - mappings.append(mapping) + # 注意:SQLAlchemy 2.0 对 ORM update + executemany 会走 + # “Bulk UPDATE by Primary Key” 路径,要求每行参数包含主键(Messages.id)。 + # 这里我们按 message_id 更新,因此使用 Core Table + bindparam。 + from sqlalchemy import bindparam, update - # 使用 bulk 操作替代逐条 UPDATE,大幅减少数据库往返 - if mappings: - await session.execute( - update(Messages), - mappings, + messages_table = Messages.__table__ + + interest_mappings: list[dict[str, Any]] = [ + {"b_message_id": message_id, "b_interest_value": interest_value} + for message_id, interest_value in interest_map.items() + ] + + if interest_mappings: + stmt_interest = ( + update(messages_table) + .where(messages_table.c.message_id == bindparam("b_message_id")) + .values(interest_value=bindparam("b_interest_value")) ) - await session.commit() - logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录") + await session.execute(stmt_interest, interest_mappings) + + if reply_map: + reply_mappings: list[dict[str, Any]] = [ + {"b_message_id": message_id, "b_should_reply": should_reply} + for message_id, should_reply in reply_map.items() + if message_id in interest_map + ] + if reply_mappings and len(reply_mappings) != len(reply_map): + logger.debug( + f"批量更新 should_reply 过滤了 {len(reply_map) - len(reply_mappings)} 条不在兴趣度更新集合中的记录" + ) + if reply_mappings: + stmt_reply = ( + update(messages_table) + .where(messages_table.c.message_id == bindparam("b_message_id")) + .values(should_reply=bindparam("b_should_reply")) + ) + await session.execute(stmt_reply, reply_mappings) + + await session.commit() + logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录") except Exception as e: logger.error(f"批量更新消息兴趣度失败: {e}") raise From 2f38d220c339188228d54ff981629f0841301166 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 13 Dec 2025 22:35:34 +0800 Subject: [PATCH 09/10] =?UTF-8?q?=E4=BC=98=E5=8C=96=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E7=B1=BB=EF=BC=8C=E6=B7=BB=E5=8A=A0=E5=85=83=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E5=92=8C=E6=97=A5=E5=BF=97=E9=85=8D=E7=BD=AE=EF=BC=8C=E8=B0=83?= =?UTF-8?q?=E6=95=B4=E9=AA=8C=E8=AF=81=E7=AD=96=E7=95=A5=E4=BB=A5=E7=A6=81?= =?UTF-8?q?=E6=AD=A2=E9=A2=9D=E5=A4=96=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/config/api_ada_configs.py | 28 +++++-- src/config/config.py | 143 ++++++++++++++++++++++++++++----- src/config/config_base.py | 2 +- src/config/official_configs.py | 33 ++++++++ 4 files changed, 179 insertions(+), 27 deletions(-) diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index 329244145..e6b0ce594 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -1,9 +1,10 @@ from threading import Lock from typing import Any, Literal -from pydantic import Field +from pydantic import Field, PrivateAttr from src.config.config_base import ValidatedConfigBase +from src.config.official_configs import InnerConfig class APIProvider(ValidatedConfigBase): @@ -21,6 +22,9 @@ class APIProvider(ValidatedConfigBase): ) retry_interval: int = Field(default=10, ge=0, description="重试间隔(如果API调用失败,重试的间隔时间,单位:秒)") + _api_key_lock: Lock = PrivateAttr(default_factory=Lock) + _api_key_index: int = PrivateAttr(default=0) + @classmethod def validate_base_url(cls, v): """验证base_url,确保URL格式正确""" @@ -44,11 +48,6 @@ class APIProvider(ValidatedConfigBase): raise ValueError("API密钥必须是字符串或字符串列表") return v - def __init__(self, **data): - super().__init__(**data) - self._api_key_lock = Lock() - self._api_key_index = 0 - def get_api_key(self) -> str: with self._api_key_lock: if isinstance(self.api_key, str): @@ -134,6 +133,7 @@ class ModelTaskConfig(ValidatedConfigBase): replyer_private: TaskConfig = Field(..., description="normal_chat首要回复模型模型配置(私聊使用)") maizone: TaskConfig = Field(..., description="maizone专用模型") emotion: TaskConfig = Field(..., description="情绪模型配置") + mood: TaskConfig = Field(..., description="心情模型配置") vlm: TaskConfig = Field(..., description="视觉语言模型配置") voice: TaskConfig = Field(..., description="语音识别模型配置") tool_use: TaskConfig = Field(..., description="专注工具使用模型配置") @@ -178,14 +178,26 @@ class ModelTaskConfig(ValidatedConfigBase): class APIAdapterConfig(ValidatedConfigBase): """API Adapter配置类""" + inner: InnerConfig = Field(..., description="配置元信息") models: list[ModelInfo] = Field(..., min_length=1, description="模型列表") model_task_config: ModelTaskConfig = Field(..., description="模型任务配置") api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表") + _api_providers_dict: dict[str, APIProvider] = PrivateAttr(default_factory=dict) + _models_dict: dict[str, ModelInfo] = PrivateAttr(default_factory=dict) + def __init__(self, **data): super().__init__(**data) - self.api_providers_dict = {provider.name: provider for provider in self.api_providers} - self.models_dict = {model.name: model for model in self.models} + self._api_providers_dict = {provider.name: provider for provider in self.api_providers} + self._models_dict = {model.name: model for model in self.models} + + @property + def api_providers_dict(self) -> dict[str, APIProvider]: + return self._api_providers_dict + + @property + def models_dict(self) -> dict[str, ModelInfo]: + return self._models_dict @classmethod def validate_models_list(cls, v): diff --git a/src/config/config.py b/src/config/config.py index cf2c0387d..d4f7c8925 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -1,10 +1,14 @@ import os import shutil import sys +import typing +import types from datetime import datetime +from pathlib import Path +from typing import Any, get_args, get_origin import tomlkit -from pydantic import Field +from pydantic import BaseModel, Field, PrivateAttr from rich.traceback import install from tomlkit import TOMLDocument from tomlkit.items import KeyType, Table @@ -25,6 +29,8 @@ from src.config.official_configs import ( EmojiConfig, ExperimentalConfig, ExpressionConfig, + InnerConfig, + LogConfig, KokoroFlowChatterConfig, LPMMKnowledgeConfig, MemoryConfig, @@ -180,6 +186,76 @@ def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDo _remove_obsolete_keys(target[key], reference[key]) # type: ignore +def _prune_unknown_keys_by_schema(target: TOMLDocument | Table, schema_model: type[BaseModel]): + """ + 基于 Pydantic Schema 递归移除未知配置键(含可重复的 AoT 表)。 + + 说明: + - 只移除 schema 中不存在的键,避免跨版本遗留废弃配置项。 + - 对于 list[BaseModel] 字段(TOML 的 [[...]]),会遍历每个元素并递归清理。 + - 对于 dict[str, Any] 等自由结构字段,不做键级裁剪。 + """ + + def _strip_optional(annotation: Any) -> Any: + origin = get_origin(annotation) + if origin is None: + return annotation + + # 兼容 | None 与 Union[..., None] + union_type = getattr(types, "UnionType", None) + if origin is union_type or origin is typing.Union: + args = [a for a in get_args(annotation) if a is not type(None)] + if len(args) == 1: + return args[0] + return annotation + + def _is_model_type(annotation: Any) -> bool: + return isinstance(annotation, type) and issubclass(annotation, BaseModel) + + def _prune_table(table: TOMLDocument | Table, model: type[BaseModel]): + name_by_key: dict[str, str] = {} + allowed_keys: set[str] = set() + + for field_name, field_info in model.model_fields.items(): + allowed_keys.add(field_name) + name_by_key[field_name] = field_name + + alias = getattr(field_info, "alias", None) + if isinstance(alias, str) and alias: + allowed_keys.add(alias) + name_by_key[alias] = field_name + + for key in list(table.keys()): + if key not in allowed_keys: + del table[key] + continue + + field_name = name_by_key[key] + field_info = model.model_fields[field_name] + annotation = _strip_optional(getattr(field_info, "annotation", Any)) + + value = table.get(key) + if value is None: + continue + + if _is_model_type(annotation) and isinstance(value, (TOMLDocument, Table)): + _prune_table(value, annotation) + continue + + origin = get_origin(annotation) + if origin is list: + args = get_args(annotation) + elem_ann = _strip_optional(args[0]) if args else Any + + # list[BaseModel] 对应 TOML 的 AoT([[...]]) + if _is_model_type(elem_ann) and hasattr(value, "__iter__"): + for item in value: + if isinstance(item, (TOMLDocument, Table)): + _prune_table(item, elem_ann) + + _prune_table(target, schema_model) + + def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): """ 将source字典的值更新到target字典中 @@ -232,13 +308,14 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic target[key] = value -def _update_config_generic(config_name: str, template_name: str): +def _update_config_generic(config_name: str, template_name: str, schema_model: type[BaseModel] | None = None): """ 通用的配置文件更新函数 Args: config_name: 配置文件名(不含扩展名),如 'bot_config' 或 'model_config' template_name: 模板文件名(不含扩展名),如 'bot_config_template' 或 'model_config_template' + schema_model: 用于裁剪未知键的 Pydantic 模型(避免跨版本遗留废弃配置项) """ # 获取根目录路径 old_config_dir = os.path.join(CONFIG_DIR, "old") @@ -355,11 +432,14 @@ def _update_config_generic(config_name: str, template_name: str): logger.info(f"开始合并{config_name}新旧配置...") _update_dict(new_config, old_config) - # 移除在新模板中已不存在的旧配置项 + # 移除未知/废弃的旧配置项(尤其是可重复的 [[...]] 段落) logger.info(f"开始移除{config_name}中已废弃的配置项...") - with open(template_path, encoding="utf-8") as f: - template_doc = tomlkit.load(f) - _remove_obsolete_keys(new_config, template_doc) + if schema_model is not None: + _prune_unknown_keys_by_schema(new_config, schema_model) + else: + with open(template_path, encoding="utf-8") as f: + template_doc = tomlkit.load(f) + _remove_obsolete_keys(new_config, template_doc) logger.info(f"已移除{config_name}中已废弃的配置项") # 保存更新后的配置(保留注释和格式) @@ -370,18 +450,18 @@ def _update_config_generic(config_name: str, template_name: str): def update_config(): """更新bot_config.toml配置文件""" - _update_config_generic("bot_config", "bot_config_template") + _update_config_generic("bot_config", "bot_config_template", schema_model=Config) def update_model_config(): """更新model_config.toml配置文件""" - _update_config_generic("model_config", "model_config_template") + _update_config_generic("model_config", "model_config_template", schema_model=APIAdapterConfig) class Config(ValidatedConfigBase): """总配置类""" - MMC_VERSION: str = Field(default=MMC_VERSION, description="MaiCore版本号") + inner: InnerConfig = Field(..., description="配置元信息") database: DatabaseConfig = Field(..., description="数据库配置") bot: BotConfig = Field(..., description="机器人基本配置") @@ -397,6 +477,7 @@ class Config(ValidatedConfigBase): chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置") response_post_process: ResponsePostProcessConfig = Field(..., description="响应后处理配置") response_splitter: ResponseSplitterConfig = Field(..., description="响应分割配置") + log: LogConfig = Field(..., description="日志配置") experimental: ExperimentalConfig = Field(default_factory=lambda: ExperimentalConfig(), description="实验性功能配置") message_bus: MessageBusConfig = Field(..., description="消息总线配置") lpmm_knowledge: LPMMKnowledgeConfig = Field(..., description="LPMM知识配置") @@ -433,18 +514,34 @@ class Config(ValidatedConfigBase): default_factory=lambda: PluginHttpSystemConfig(), description="插件HTTP端点系统配置" ) + @property + def MMC_VERSION(self) -> str: # noqa: N802 + return MMC_VERSION + class APIAdapterConfig(ValidatedConfigBase): """API Adapter配置类""" + inner: InnerConfig = Field(..., description="配置元信息") models: list[ModelInfo] = Field(..., min_length=1, description="模型列表") model_task_config: ModelTaskConfig = Field(..., description="模型任务配置") api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表") + _api_providers_dict: dict[str, APIProvider] = PrivateAttr(default_factory=dict) + _models_dict: dict[str, ModelInfo] = PrivateAttr(default_factory=dict) + def __init__(self, **data): super().__init__(**data) - self.api_providers_dict = {provider.name: provider for provider in self.api_providers} - self.models_dict = {model.name: model for model in self.models} + self._api_providers_dict = {provider.name: provider for provider in self.api_providers} + self._models_dict = {model.name: model for model in self.models} + + @property + def api_providers_dict(self) -> dict[str, APIProvider]: + return self._api_providers_dict + + @property + def models_dict(self) -> dict[str, ModelInfo]: + return self._models_dict @classmethod def validate_models_list(cls, v): @@ -502,9 +599,14 @@ def load_config(config_path: str) -> Config: Returns: Config对象 """ - # 读取配置文件 - with open(config_path, encoding="utf-8") as f: - config_data = tomlkit.load(f) + # 读取配置文件(会自动删除未知/废弃配置项) + original_text = Path(config_path).read_text(encoding="utf-8") + config_data = tomlkit.parse(original_text) + _prune_unknown_keys_by_schema(config_data, Config) + new_text = tomlkit.dumps(config_data) + if new_text != original_text: + Path(config_path).write_text(new_text, encoding="utf-8") + logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项") # 将 tomlkit 对象转换为纯 Python 字典,避免 Pydantic 严格模式下的类型验证问题 # tomlkit 返回的是特殊类型(如 Array、String 等),虽然继承自 Python 标准类型, @@ -530,11 +632,16 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig: Returns: APIAdapterConfig对象 """ - # 读取配置文件 - with open(config_path, encoding="utf-8") as f: - config_data = tomlkit.load(f) + # 读取配置文件(会自动删除未知/废弃配置项) + original_text = Path(config_path).read_text(encoding="utf-8") + config_data = tomlkit.parse(original_text) + _prune_unknown_keys_by_schema(config_data, APIAdapterConfig) + new_text = tomlkit.dumps(config_data) + if new_text != original_text: + Path(config_path).write_text(new_text, encoding="utf-8") + logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项") - config_dict = dict(config_data) + config_dict = config_data.unwrap() try: logger.debug("正在解析和验证API适配器配置文件...") diff --git a/src/config/config_base.py b/src/config/config_base.py index 551326fa3..b01749e18 100644 --- a/src/config/config_base.py +++ b/src/config/config_base.py @@ -142,7 +142,7 @@ class ValidatedConfigBase(BaseModel): """带验证的配置基类,继承自Pydantic BaseModel""" model_config = { - "extra": "allow", # 允许额外字段 + "extra": "forbid", # 禁止额外字段(防止跨版本遗留废弃配置项) "validate_assignment": True, # 验证赋值 "arbitrary_types_allowed": True, # 允许任意类型 "strict": True, # 如果设为 True 会完全禁用类型转换 diff --git a/src/config/official_configs.py b/src/config/official_configs.py index a6ad76f41..862ff3d13 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -13,6 +13,12 @@ from src.config.config_base import ValidatedConfigBase """ +class InnerConfig(ValidatedConfigBase): + """配置文件元信息""" + + version: str = Field(..., description="配置文件版本号(用于配置文件升级与兼容性检查)") + + class DatabaseConfig(ValidatedConfigBase): """数据库配置类""" @@ -588,6 +594,20 @@ class ResponseSplitterConfig(ValidatedConfigBase): enable_kaomoji_protection: bool = Field(default=False, description="启用颜文字保护") +class LogConfig(ValidatedConfigBase): + """日志配置类""" + + date_style: str = Field(default="m-d H:i:s", description="日期格式") + log_level_style: str = Field(default="lite", description="日志级别样式") + color_text: str = Field(default="full", description="日志文本颜色") + log_level: str = Field(default="INFO", description="全局日志级别(向下兼容,优先级低于分别设置)") + file_retention_days: int = Field(default=7, description="文件日志保留天数,0=禁用文件日志,-1=永不删除") + console_log_level: str = Field(default="INFO", description="控制台日志级别") + file_log_level: str = Field(default="DEBUG", description="文件日志级别") + suppress_libraries: list[str] = Field(default_factory=list, description="完全屏蔽日志的第三方库列表") + library_log_levels: dict[str, str] = Field(default_factory=dict, description="设置特定库的日志级别") + + class DebugConfig(ValidatedConfigBase): """调试配置类""" @@ -703,6 +723,7 @@ class WebSearchConfig(ValidatedConfigBase): enable_url_tool: bool = Field(default=True, description="启用URL工具") tavily_api_keys: list[str] = Field(default_factory=lambda: [], description="Tavily API密钥列表,支持轮询机制") exa_api_keys: list[str] = Field(default_factory=lambda: [], description="exa API密钥列表,支持轮询机制") + metaso_api_keys: list[str] = Field(default_factory=lambda: [], description="Metaso API密钥列表,支持轮询机制") searxng_instances: list[str] = Field(default_factory=list, description="SearXNG 实例 URL 列表") searxng_api_keys: list[str] = Field(default_factory=list, description="SearXNG 实例 API 密钥列表") serper_api_keys: list[str] = Field(default_factory=list, description="serper API 密钥列表") @@ -988,6 +1009,12 @@ class KokoroFlowChatterConfig(ValidatedConfigBase): description="开启后KFC将接管所有私聊消息;关闭后私聊消息将由AFC处理" ) + # --- 工作模式 --- + mode: Literal["unified", "split"] = Field( + default="split", + description='工作模式: "unified"(单次调用) 或 "split"(planner+replyer两次调用)', + ) + # --- 核心行为配置 --- max_wait_seconds_default: int = Field( default=300, ge=30, le=3600, @@ -998,6 +1025,12 @@ class KokoroFlowChatterConfig(ValidatedConfigBase): description="是否在等待期间启用心理活动更新" ) + # --- 自定义决策提示词 --- + custom_decision_prompt: str = Field( + default="", + description="自定义KFC决策行为指导提示词(unified影响整体,split仅影响planner)", + ) + waiting: KokoroFlowChatterWaitingConfig = Field( default_factory=KokoroFlowChatterWaitingConfig, description="等待策略配置(默认等待时间、倍率等)", From 314021218e749e537ee7cb15cd6deee295f965c8 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 13 Dec 2025 22:49:39 +0800 Subject: [PATCH 10/10] =?UTF-8?q?=E6=9B=B4=E6=96=B0MMC=E7=89=88=E6=9C=AC?= =?UTF-8?q?=E8=87=B30.13.1-alpha.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/config/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/config/config.py b/src/config/config.py index d4f7c8925..efec21705 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -71,7 +71,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template") # 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 # 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/ -MMC_VERSION = "0.13.1-alpha.1" +MMC_VERSION = "0.13.1-alpha.2" # 全局配置变量 _CONFIG_INITIALIZED = False