优化消息管理

This commit is contained in:
LuiKlee
2025-12-13 20:19:11 +08:00
parent 6af9780ff6
commit 9a0163d06b
6 changed files with 405 additions and 74 deletions

View File

@@ -0,0 +1,306 @@
import asyncio
import time
import os
import sys
from dataclasses import dataclass
from typing import Any, Optional
# Benchmark the distribution manager's run_chat_stream/conversation_loop behavior
# by wiring a lightweight dummy manager and contexts. This avoids touching real DB or chat subsystems.
# Ensure project root is on sys.path when running as a script
# Directory two levels above this file; placed at the front of sys.path so the
# benchmark can be launched directly as a script.
ROOT_DIR = os.path.split(os.path.split(os.path.abspath(__file__))[0])[0]
if ROOT_DIR not in sys.path:
    sys.path[:0] = [ROOT_DIR]
# Avoid importing the whole 'src.chat' package to prevent heavy deps (e.g., redis)
# Local minimal implementation of loop and manager to isolate benchmark from heavy deps.
from collections.abc import AsyncIterator, Awaitable, Callable
from dataclasses import dataclass, field
@dataclass
class ConversationTick:
    """Event object yielded by ``conversation_loop`` each time a stream needs work."""

    # Identifier of the stream the tick was produced for.
    stream_id: str
    # Creation timestamp; filled in with ``time.time()`` when not supplied.
    tick_time: float = field(default_factory=time.time)
    # Result of the force-dispatch check made by the producer.
    force_dispatch: bool = field(default=False)
    # Sequence number assigned by the producer.
    tick_count: int = field(default=0)
async def conversation_loop(
    stream_id: str,
    get_context_func: Callable[[str], Awaitable["DummyContext | None"]],
    calculate_interval_func: Callable[[str, bool], Awaitable[float]],
    flush_cache_func: Callable[[str], Awaitable[list[Any]]],
    check_force_dispatch_func: Callable[["DummyContext", int], bool],
    is_running_func: Callable[[], bool],
) -> AsyncIterator[ConversationTick]:
    """Yield a ``ConversationTick`` whenever the stream has work to process.

    The generator keeps polling while ``is_running_func()`` is true. On each
    pass it fetches the stream context, flushes cached messages, counts the
    unread messages and asks ``check_force_dispatch_func`` whether a dispatch
    must be forced. A tick is yielded when there is at least one unread
    message or a forced dispatch; afterwards it sleeps for the interval
    returned by ``calculate_interval_func``.

    Args:
        stream_id: Stream to poll.
        get_context_func: Coroutine function returning the stream context, or
            a falsy value when the stream is unavailable.
        calculate_interval_func: Coroutine function returning the sleep time
            in seconds; receives the stream id and whether unread messages exist.
        flush_cache_func: Coroutine function invoked before counting unread
            messages; its return value is ignored here.
        check_force_dispatch_func: Predicate receiving the context and the
            unread count.
        is_running_func: Returns False to end the loop.

    Yields:
        ConversationTick objects whose ``tick_count`` starts at 1.
    """
    emitted = 0
    while is_running_func():
        context = await get_context_func(stream_id)
        if not context:
            # Context not available yet: back off briefly and poll again.
            await asyncio.sleep(0.1)
            continue

        await flush_cache_func(stream_id)
        pending = len(context.get_unread_messages())
        has_pending = pending > 0
        forced = check_force_dispatch_func(context, pending)

        if has_pending or forced:
            emitted += 1
            yield ConversationTick(stream_id=stream_id, force_dispatch=forced, tick_count=emitted)

        await asyncio.sleep(await calculate_interval_func(stream_id, has_pending))
class StreamLoopManager:
    """Minimal local stand-in for the stream loop manager, used by the benchmark.

    It owns the run flag, the statistics dictionary, the semaphore that bounds
    concurrent processing, and the hooks that ``conversation_loop`` /
    ``run_chat_stream`` call back into.
    """

    def __init__(self, max_concurrent_streams: Optional[int] = None):
        """Create a stopped manager.

        Args:
            max_concurrent_streams: Upper bound on concurrent processing
                calls; ``None`` (or 0) falls back to 100.
        """
        self.stats: dict[str, Any] = {
            "active_streams": 0,
            "total_loops": 0,
            "total_process_cycles": 0,
            "total_failures": 0,
            "start_time": time.time(),
        }
        self.max_concurrent_streams = max_concurrent_streams or 100
        self.force_dispatch_unread_threshold = 20
        self.chatter_manager = DummyChatterManager()
        self.is_running = False
        self._stream_start_locks: dict[str, asyncio.Lock] = {}
        self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams)
        self._chat_manager: Optional[DummyChatManager] = None

    def set_chat_manager(self, chat_manager: "DummyChatManager") -> None:
        """Attach the chat manager used to look up streams."""
        self._chat_manager = chat_manager

    async def start(self):
        """Mark the manager as running so stream loops keep iterating."""
        self.is_running = True

    async def stop(self):
        """Clear the run flag; loops exit the next time they check it."""
        self.is_running = False

    async def _get_stream_context(self, stream_id: str):
        """Return the context of ``stream_id``, or None when the stream is unknown."""
        assert self._chat_manager is not None
        stream = await self._chat_manager.get_stream(stream_id)
        return stream.context if stream else None

    async def _flush_cached_messages_to_unread(self, stream_id: str):
        """Flush the stream's cached messages; returns [] for an unknown stream."""
        ctx = await self._get_stream_context(stream_id)
        return ctx.flush_cached_messages() if ctx else []

    def _needs_force_dispatch_for_context(self, context: "DummyContext", unread_count: int) -> bool:
        """Return True when ``unread_count`` exceeds the force-dispatch threshold."""
        return unread_count > (self.force_dispatch_unread_threshold or 20)

    async def _process_stream_messages(self, stream_id: str, context: "DummyContext") -> bool:
        """Delegate one processing cycle to the chatter manager and report success."""
        res = await self.chatter_manager.process_stream_context(stream_id, context)  # type: ignore[attr-defined]
        return bool(res.get("success", False))

    async def _update_stream_energy(self, stream_id: str, context: "DummyContext") -> None:
        """No-op hook; subclasses may override it."""
        pass

    async def _calculate_interval(self, stream_id: str, has_messages: bool) -> float:
        """Return the polling interval in seconds (shorter when messages are pending)."""
        return 0.005 if has_messages else 0.02

    async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool:
        """Start the driver task for ``stream_id``.

        Args:
            stream_id: Stream whose loop should run.
            force: When a driver is already running, cancel it and start a
                fresh one instead of reusing it.

        Returns:
            False when the stream has no context, True otherwise (including
            the case where a live driver is reused).
        """
        ctx = await self._get_stream_context(stream_id)
        if not ctx:
            return False

        existing = ctx.stream_loop_task
        replacing = existing is not None and not existing.done()
        if replacing and not force:
            # A driver is already running; a second one would process the same
            # stream twice and skew the statistics.
            return True
        if replacing:
            # Forced restart: retire the old driver before installing the new one.
            existing.cancel()

        # create driver
        ctx.stream_loop_task = asyncio.create_task(run_chat_stream(stream_id, self))
        if not replacing:
            # A restart swaps the task but the stream was already counted as active.
            self.stats["active_streams"] += 1
        self.stats["total_loops"] += 1
        return True
async def run_chat_stream(stream_id: str, manager: StreamLoopManager) -> None:
    """Drive one stream: consume ticks and run a processing cycle for each.

    For every tick produced by ``conversation_loop`` the context is looked up
    again; ticks are skipped when the context is missing or already flagged as
    being processed. Each processing call runs under the manager's semaphore.
    Every attempted cycle increments ``total_process_cycles``; cycles that
    raise or report failure also increment ``total_failures``. Task
    cancellation ends the driver silently.
    """
    try:
        ticks = conversation_loop(
            stream_id=stream_id,
            get_context_func=manager._get_stream_context,
            calculate_interval_func=manager._calculate_interval,
            flush_cache_func=manager._flush_cached_messages_to_unread,
            check_force_dispatch_func=manager._needs_force_dispatch_for_context,
            is_running_func=lambda: manager.is_running,
        )
        async for _tick in ticks:
            context = await manager._get_stream_context(stream_id)
            # Skip the tick when the stream vanished or is already being handled.
            if not context or context.is_chatter_processing:
                continue

            try:
                async with manager._processing_semaphore:
                    succeeded = await manager._process_stream_messages(stream_id, context)
            except Exception:
                # Any processing error is recorded as a failed cycle below.
                succeeded = False

            manager.stats["total_process_cycles"] += 1
            if not succeeded:
                manager.stats["total_failures"] += 1
    except asyncio.CancelledError:
        pass
@dataclass
class DummyMessage:
    """Lightweight placeholder for a chat message used by the benchmark."""

    # Timestamp supplied by the creator; the only required field.
    time: float
    # Text fields default to empty strings.
    processed_plain_text: str = field(default="")
    display_message: str = field(default="")
    # Flags default to False.
    is_at: bool = field(default=False)
    is_mentioned: bool = field(default=False)
class DummyContext:
    """In-memory stand-in for a stream context.

    Holds the unread and history message lists plus the task/flag attributes
    that the loop manager and driver read or write.
    """

    def __init__(self, stream_id: str, initial_unread: int):
        """Create a context pre-filled with ``initial_unread`` unread messages.

        Args:
            stream_id: Identifier of the owning stream.
            initial_unread: Number of ``DummyMessage`` objects to seed the
                unread list with; values <= 0 give an empty list.
        """
        self.stream_id = stream_id
        self.unread_messages = [DummyMessage(time=time.time()) for _ in range(initial_unread)]
        self.history_messages: list[DummyMessage] = []
        self.is_chatter_processing: bool = False
        self.processing_task: Optional[asyncio.Task] = None
        self.stream_loop_task: Optional[asyncio.Task] = None
        self.triggering_user_id: Optional[str] = None

    def get_unread_messages(self) -> "list[DummyMessage]":
        """Return a shallow copy of the unread list, so callers cannot mutate it."""
        return list(self.unread_messages)

    def flush_cached_messages(self) -> "list[DummyMessage]":
        """Return the flushed messages; this dummy has no cache, so always []."""
        return []

    def get_last_message(self) -> "Optional[DummyMessage]":
        """Return the newest unread message, or None when there is none."""
        return self.unread_messages[-1] if self.unread_messages else None

    def get_history_messages(self, limit: int = 50) -> "list[DummyMessage]":
        """Return up to ``limit`` of the most recent history messages.

        A non-positive ``limit`` yields an empty list. Without the guard,
        ``[-0:]`` would return the whole history and a negative value would
        drop items from the front instead of limiting the tail.
        """
        if limit <= 0:
            return []
        return self.history_messages[-limit:]
class DummyStream:
    """Bare-bones chat stream: an id, its context and a focus-energy value."""

    def __init__(self, stream_id: str, ctx: "DummyContext"):
        """Bind ``ctx`` to ``stream_id`` and set the default attributes."""
        # Starting focus energy; BenchStreamLoopManager may overwrite it.
        self._focus_energy = 0.5
        self.group_info = None  # treat as private chat to accelerate
        self.context = ctx
        self.stream_id = stream_id
class DummyChatManager:
    """Dictionary-backed registry of streams keyed by stream id."""

    def __init__(self, streams: "dict[str, DummyStream]"):
        """Keep a reference to ``streams``; the mapping is not copied."""
        self._streams = streams

    async def get_stream(self, stream_id: str) -> "Optional[DummyStream]":
        """Look up a stream by id; unknown ids give None."""
        found = self._streams.get(stream_id, None)
        return found

    def get_all_streams(self) -> "dict[str, DummyStream]":
        """Return the underlying mapping itself (not a copy)."""
        return self._streams
class DummyChatterManager:
    """Fake chatter that spends 10 ms per call and consumes one unread message."""

    async def process_stream_context(self, stream_id: str, context: "DummyContext") -> dict[str, Any]:
        """Simulate one processing cycle for ``context``.

        Sleeps for 0.01 s, removes the oldest unread message if there is one,
        and always reports success.
        """
        # Stand-in for real processing latency.
        await asyncio.sleep(0.01)
        backlog = context.unread_messages
        if backlog:
            backlog.pop(0)
        return {"success": True}
class BenchStreamLoopManager(StreamLoopManager):
    """Loop manager wired to a ``DummyChatManager`` for the benchmark runs."""

    def __init__(self, chat_manager: DummyChatManager, max_concurrent_streams: int | None = None):
        """Initialise the base manager, then attach the chat and chatter managers."""
        super().__init__(max_concurrent_streams=max_concurrent_streams)
        self._chat_manager = chat_manager
        self.chatter_manager = DummyChatterManager()

    async def _get_stream_context(self, stream_id: str):  # type: ignore[override]
        """Return the context of ``stream_id``, or None when the stream is unknown."""
        stream = await self._chat_manager.get_stream(stream_id)
        if not stream:
            return None
        return stream.context

    async def _flush_cached_messages_to_unread(self, stream_id: str):  # type: ignore[override]
        """Flush the stream's cached messages; an unknown stream gives []."""
        context = await self._get_stream_context(stream_id)
        if not context:
            return []
        return context.flush_cached_messages()

    def _needs_force_dispatch_for_context(self, context, unread_count: int) -> bool:  # type: ignore[override]
        """Force a dispatch once the unread count exceeds the threshold."""
        threshold = self.force_dispatch_unread_threshold or 20
        return unread_count > threshold

    async def _process_stream_messages(self, stream_id: str, context):  # type: ignore[override]
        """Hand the context to the chatter manager and report its success flag."""
        outcome = await self.chatter_manager.process_stream_context(stream_id, context)  # type: ignore[attr-defined]
        return bool(outcome.get("success", False))

    async def _should_skip_for_mute_group(self, stream_id: str, unread_messages: list) -> bool:
        """Always False: the benchmark never skips a stream."""
        return False

    async def _update_stream_energy(self, stream_id: str, context):  # type: ignore[override]
        """Set the stream's focus energy from its unread backlog, capped at 1.0."""
        backlog = len(context.get_unread_messages())
        focus = min(1.0, 0.1 + 0.02 * backlog)
        # Stored on the stream object for compatibility.
        stream = await self._chat_manager.get_stream(stream_id)
        if stream:
            stream._focus_energy = focus
def make_streams(n_streams: int, initial_unread: int) -> dict[str, DummyStream]:
    """Build ``n_streams`` dummy streams, each seeded with ``initial_unread`` messages.

    Stream ids are ``s0000``, ``s0001``, ... in creation order.
    """
    stream_ids = (f"s{index:04d}" for index in range(n_streams))
    return {sid: DummyStream(sid, DummyContext(sid, initial_unread)) for sid in stream_ids}
async def run_benchmark(
    n_streams: int,
    initial_unread: int,
    max_concurrent: Optional[int],
    timeout: float = 5.0,
) -> dict[str, Any]:
    """Run one benchmark case and return its measurements.

    Creates ``n_streams`` dummy streams with ``initial_unread`` messages each,
    starts a driver per stream and waits until every unread message has been
    consumed or ``timeout`` seconds have elapsed.

    Args:
        n_streams: Number of streams to create.
        initial_unread: Unread messages seeded into each stream.
        max_concurrent: Concurrency limit handed to the manager (None lets
            the manager pick its default).
        timeout: Maximum time in seconds to wait for the backlog to drain.

    Returns:
        A dict with the keys ``n_streams``, ``initial_unread``,
        ``max_concurrent``, ``duration_sec``, ``total_cycles``,
        ``total_failures``, ``remaining_unread`` and
        ``throughput_msgs_per_sec``.
    """
    streams = make_streams(n_streams, initial_unread)
    chat_mgr = DummyChatManager(streams)
    mgr = BenchStreamLoopManager(chat_mgr, max_concurrent_streams=max_concurrent)

    def count_unread() -> int:
        return sum(len(s.context.get_unread_messages()) for s in streams.values())

    await mgr.start()
    # perf_counter is monotonic, so the measurement cannot be skewed by
    # wall-clock adjustments.
    start_ts = time.perf_counter()
    try:
        # start loops for all streams
        for sid in streams:
            await mgr.start_stream_loop(sid, force=True)

        # run until all unread consumed or timeout
        end_deadline = start_ts + timeout
        while time.perf_counter() < end_deadline:
            if count_unread() == 0:
                break
            await asyncio.sleep(0.02)

        duration = time.perf_counter() - start_ts
        total_cycles = mgr.stats.get("total_process_cycles", 0)
        total_failures = mgr.stats.get("total_failures", 0)
        remaining = count_unread()
    finally:
        # stop all: clear the run flag, then cancel and await every driver so
        # no task from this case keeps running during the next one.
        await mgr.stop()
        drivers = [
            s.context.stream_loop_task
            for s in streams.values()
            if s.context.stream_loop_task is not None
        ]
        for task in drivers:
            task.cancel()
        await asyncio.gather(*drivers, return_exceptions=True)

    return {
        "n_streams": n_streams,
        "initial_unread": initial_unread,
        "max_concurrent": max_concurrent,
        "duration_sec": duration,
        "total_cycles": total_cycles,
        "total_failures": total_failures,
        "remaining_unread": remaining,
        # Guard against division by a near-zero duration.
        "throughput_msgs_per_sec": (n_streams * initial_unread - remaining) / max(0.001, duration),
    }
async def main():
    """Run every benchmark case in sequence and print one summary line per case."""
    # Each case is (n_streams, initial_unread, max_concurrent).
    cases = (
        (50, 5, None),  # baseline using configured default
        (50, 5, 5),  # constrained concurrency
        (50, 5, 10),  # moderate concurrency
        (100, 3, 10),  # scale streams
    )
    print("Running distribution manager benchmark...\n")
    for case in cases:
        res = await run_benchmark(*case)
        summary = (
            f"streams={res['n_streams']} unread={res['initial_unread']} max_conc={res['max_concurrent']} | "
            f"dur={res['duration_sec']:.3f}s cycles={res['total_cycles']} fail={res['total_failures']} rem={res['remaining_unread']} "
            f"thr={res['throughput_msgs_per_sec']:.1f}/s"
        )
        print(summary)
# Run the benchmark suite only when executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -6,8 +6,6 @@
import asyncio
import time
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
class PerformanceBenchmark:
@@ -20,7 +18,7 @@ class PerformanceBenchmark:
"""测试查询去重性能"""
# 这里需要导入实际的管理器
# from src.memory_graph.unified_manager import UnifiedMemoryManager
test_cases = [
{
"name": "small_queries",
@@ -67,7 +65,7 @@ class PerformanceBenchmark:
seen = set()
decay = 0.15
manual_queries = []
for raw in queries:
text = (raw or "").strip()
if text and text not in seen:
@@ -192,7 +190,7 @@ class PerformanceBenchmark:
mem_id = mem.get("id")
else:
mem_id = getattr(mem, "id", None)
if mem_id and mem_id in seen_ids:
continue
unique_memories.append(mem)

View File

@@ -9,6 +9,8 @@ from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.exc import SQLAlchemyError
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams
from src.common.logger import get_logger
@@ -159,20 +161,27 @@ class BatchDatabaseWriter:
logger.info("批量写入循环结束")
async def _collect_batch(self) -> list[StreamUpdatePayload]:
"""收集一个批次的数据"""
batch = []
deadline = time.time() + self.flush_interval
"""收集一个批次的数据
- 自适应刷新:队列增长加快时缩短等待时间
- 避免长时间空转:添加轻微抖动以分散竞争
"""
batch: list[StreamUpdatePayload] = []
# 根据当前队列长度调整刷新时间(最多缩短到 40%
qsize = self.write_queue.qsize()
adapt_factor = 1.0
if qsize > 0:
adapt_factor = max(0.4, min(1.0, self.batch_size / max(1, qsize)))
deadline = time.time() + (self.flush_interval * adapt_factor)
while len(batch) < self.batch_size and time.time() < deadline:
try:
# 计算剩余等待时间
remaining_time = max(0, deadline - time.time())
remaining_time = max(0.0, deadline - time.time())
if remaining_time == 0:
break
payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time)
# 轻微抖动,避免多个协程同时争抢队列
jitter = 0.002
payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time + jitter)
batch.append(payload)
except asyncio.TimeoutError:
break
@@ -208,48 +217,52 @@ class BatchDatabaseWriter:
logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s")
except Exception as e:
except SQLAlchemyError as e:
self.stats["failed_writes"] += 1
logger.error(f"批量写入失败: {e}")
# 降级到单个写入
for payload in batch:
try:
await self._direct_write(payload.stream_id, payload.update_data)
except Exception as single_e:
except SQLAlchemyError as single_e:
logger.error(f"单个写入也失败: {single_e}")
async def _batch_write_to_database(self, payloads: list[StreamUpdatePayload]):
"""批量写入数据库"""
"""批量写入数据库(单事务、多值 UPSERT"""
if global_config is None:
raise RuntimeError("Global config is not initialized")
if not payloads:
return
# 预组装行数据,确保每行包含 stream_id
rows: list[dict[str, Any]] = []
for p in payloads:
row = {"stream_id": p.stream_id}
row.update(p.update_data)
rows.append(row)
async with get_db_session() as session:
for payload in payloads:
stream_id = payload.stream_id
update_data = payload.update_data
# 根据数据库类型选择不同的插入/更新策略
if global_config.database.database_type == "sqlite":
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
elif global_config.database.database_type == "postgresql":
from sqlalchemy.dialects.postgresql import insert as pg_insert
stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data)
stmt = stmt.on_conflict_do_update(
index_elements=[ChatStreams.stream_id],
set_=update_data
)
else:
# 默认使用SQLite语法
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
# 使用单次事务提交,显著减少 I/O
if global_config.database.database_type == "postgresql":
from sqlalchemy.dialects.postgresql import insert as pg_insert
stmt = pg_insert(ChatStreams).values(rows)
stmt = stmt.on_conflict_do_update(
index_elements=[ChatStreams.stream_id],
set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"}
)
await session.execute(stmt)
await session.commit()
else:
# 默认sqlite
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
stmt = sqlite_insert(ChatStreams).values(rows)
stmt = stmt.on_conflict_do_update(
index_elements=["stream_id"],
set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"}
)
await session.execute(stmt)
await session.commit()
async def _direct_write(self, stream_id: str, update_data: dict[str, Any]):
"""直接写入数据库(降级方案)"""
if global_config is None:

View File

@@ -55,7 +55,7 @@ async def conversation_loop(
stream_id: str,
get_context_func: Callable[[str], Awaitable["StreamContext | None"]],
calculate_interval_func: Callable[[str, bool], Awaitable[float]],
flush_cache_func: Callable[[str], Awaitable[None]],
flush_cache_func: Callable[[str], Awaitable[list[Any]]],
check_force_dispatch_func: Callable[["StreamContext", int], bool],
is_running_func: Callable[[], bool],
) -> AsyncIterator[ConversationTick]:
@@ -121,7 +121,7 @@ async def conversation_loop(
except asyncio.CancelledError:
logger.info(f" [生成器] stream={stream_id[:8]}, 被取消")
break
except Exception as e:
except Exception as e: # noqa: BLE001
logger.error(f" [生成器] stream={stream_id[:8]}, 出错: {e}")
await asyncio.sleep(5.0)
@@ -151,10 +151,10 @@ async def run_chat_stream(
# 创建生成器
tick_generator = conversation_loop(
stream_id=stream_id,
get_context_func=manager._get_stream_context,
calculate_interval_func=manager._calculate_interval,
flush_cache_func=manager._flush_cached_messages_to_unread,
check_force_dispatch_func=manager._needs_force_dispatch_for_context,
get_context_func=manager._get_stream_context, # noqa: SLF001
calculate_interval_func=manager._calculate_interval, # noqa: SLF001
flush_cache_func=manager._flush_cached_messages_to_unread, # noqa: SLF001
check_force_dispatch_func=manager._needs_force_dispatch_for_context, # noqa: SLF001
is_running_func=lambda: manager.is_running,
)
@@ -162,13 +162,13 @@ async def run_chat_stream(
async for tick in tick_generator:
try:
# 获取上下文
context = await manager._get_stream_context(stream_id)
context = await manager._get_stream_context(stream_id) # noqa: SLF001
if not context:
continue
# 并发保护:检查是否正在处理
if context.is_chatter_processing:
if manager._recover_stale_chatter_state(stream_id, context):
if manager._recover_stale_chatter_state(stream_id, context): # noqa: SLF001
logger.warning(f" [驱动器] stream={stream_id[:8]}, 处理标志残留已修复")
else:
logger.debug(f" [驱动器] stream={stream_id[:8]}, Chatter正在处理跳过此Tick")
@@ -182,17 +182,18 @@ async def run_chat_stream(
# 更新能量值
try:
await manager._update_stream_energy(stream_id, context)
await manager._update_stream_energy(stream_id, context) # noqa: SLF001
except Exception as e:
logger.debug(f"更新能量失败: {e}")
# 处理消息
assert global_config is not None
try:
success = await asyncio.wait_for(
manager._process_stream_messages(stream_id, context),
global_config.chat.thinking_timeout
)
async with manager._processing_semaphore:
success = await asyncio.wait_for(
manager._process_stream_messages(stream_id, context), # noqa: SLF001
global_config.chat.thinking_timeout,
)
except asyncio.TimeoutError:
logger.warning(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 处理超时")
success = False
@@ -208,7 +209,7 @@ async def run_chat_stream(
except asyncio.CancelledError:
raise
except Exception as e:
except Exception as e: # noqa: BLE001
logger.error(f" [驱动器] stream={stream_id[:8]}, 处理Tick时出错: {e}")
manager.stats["total_failures"] += 1
@@ -221,7 +222,7 @@ async def run_chat_stream(
if context and context.stream_loop_task:
context.stream_loop_task = None
logger.debug(f" [驱动器] stream={stream_id[:8]}, 清理任务记录")
except Exception as e:
except Exception as e: # noqa: BLE001
logger.debug(f"清理任务记录失败: {e}")
@@ -268,6 +269,9 @@ class StreamLoopManager:
# 流启动锁:防止并发启动同一个流的多个任务
self._stream_start_locks: dict[str, asyncio.Lock] = {}
# 并发控制:限制同时进行的 Chatter 处理任务数
self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams)
logger.info(f"流循环管理器初始化完成 (最大并发流数: {self.max_concurrent_streams})")
# ========================================================================

View File

@@ -104,9 +104,21 @@ class MessageManager:
if not chat_stream:
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return
# 启动 stream loop 任务(如果尚未启动)
await stream_loop_manager.start_stream_loop(stream_id)
await self._check_and_handle_interruption(chat_stream, message)
# 快速检查:如果已有驱动器在跑,则跳过重复启动,避免不必要的 await
context = chat_stream.context
if not (context.stream_loop_task and not context.stream_loop_task.done()):
# 异步启动驱动器任务;避免在高并发下阻塞消息入队
context.stream_loop_task = asyncio.create_task(
stream_loop_manager.start_stream_loop(stream_id)
)
# 并行触发打断检查,不阻塞消息入队
context.interruption_task = asyncio.create_task(
self._check_and_handle_interruption(chat_stream, message)
)
# 入队消息
await chat_stream.context.add_message(message)
except Exception as e:
@@ -476,8 +488,7 @@ class MessageManager:
is_processing: 是否正在处理
"""
try:
# 尝试更新StreamContext的处理状态
import asyncio
# 尝试更新StreamContext的处理状态(使用顶层 asyncio 导入)
async def _update_context():
try:
chat_manager = get_chat_manager()
@@ -492,7 +503,7 @@ class MessageManager:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
asyncio.create_task(_update_context())
self._update_context_task = asyncio.create_task(_update_context())
else:
# 如果事件循环未运行,则跳过
logger.debug("事件循环未运行跳过StreamContext状态更新")
@@ -512,8 +523,7 @@ class MessageManager:
bool: 是否正在处理
"""
try:
# 尝试从StreamContext获取处理状态
import asyncio
# 尝试从StreamContext获取处理状态(使用顶层 asyncio 导入)
async def _get_context_status():
try:
chat_manager = get_chat_manager()

View File

@@ -451,7 +451,7 @@ class UnifiedMemoryManager:
(0.3, 10.0, 0.4),
(0.1, 15.0, 0.6),
]
for threshold, min_val, factor in occupancy_thresholds:
if occupancy >= threshold:
return max(min_val, base_interval * factor)
@@ -461,24 +461,24 @@ class UnifiedMemoryManager:
async def _transfer_blocks_to_short_term(self, blocks: list[MemoryBlock]) -> None:
"""实际转换逻辑在后台执行(优化:并行处理多个块,批量触发唤醒)"""
logger.debug(f"正在后台处理 {len(blocks)} 个感知记忆块")
# 优化:使用 asyncio.gather 并行处理转移
async def _transfer_single(block: MemoryBlock) -> tuple[MemoryBlock, bool]:
try:
stm = await self.short_term_manager.add_from_block(block)
if not stm:
return block, False
await self.perceptual_manager.remove_block(block.id)
logger.debug(f"✓ 记忆块 {block.id} 已被转移到短期记忆 {stm.id}")
return block, True
except Exception as exc:
logger.error(f"后台转移失败,记忆块 {block.id}: {exc}")
return block, False
# 并行处理所有块
results = await asyncio.gather(*[_transfer_single(block) for block in blocks], return_exceptions=True)
# 统计成功的转移
success_count = sum(1 for result in results if isinstance(result, tuple) and result[1])
if success_count > 0:
@@ -491,7 +491,7 @@ class UnifiedMemoryManager:
seen = set()
decay = 0.15
manual_queries: list[dict[str, Any]] = []
for raw in queries:
text = (raw or "").strip()
if text and text not in seen:
@@ -517,7 +517,7 @@ class UnifiedMemoryManager:
"top_k": self._config["long_term"]["search_top_k"],
"use_multi_query": bool(manual_queries),
}
if recent_chat_history or manual_queries:
context: dict[str, Any] = {}
if recent_chat_history:
@@ -541,7 +541,7 @@ class UnifiedMemoryManager:
mem_id = mem.get("id")
else:
mem_id = getattr(mem, "id", None)
# 检查去重
if mem_id and mem_id in seen_ids:
continue
@@ -600,7 +600,7 @@ class UnifiedMemoryManager:
new_memories.append(memory)
if mem_id:
cached_ids.add(mem_id)
if new_memories:
transfer_cache.extend(new_memories)
logger.debug(
@@ -632,7 +632,7 @@ class UnifiedMemoryManager:
await self.short_term_manager.clear_transferred_memories(
result["transferred_memory_ids"]
)
# 优化:使用生成器表达式保留未转移的记忆
transfer_cache = [
m