优化消息管理 (Optimize message management)

This commit is contained in:
LuiKlee
2025-12-13 20:19:11 +08:00
parent 6af9780ff6
commit 9a0163d06b
6 changed files with 405 additions and 74 deletions

View File

@@ -0,0 +1,306 @@
import asyncio
import time
import os
import sys
from dataclasses import dataclass
from typing import Any, Optional
# Benchmark the distribution manager's run_chat_stream/conversation_loop behavior
# by wiring a lightweight dummy manager and contexts. This avoids touching real DB or chat subsystems.
# Ensure project root is on sys.path when running as a script
# (ROOT_DIR is two levels up from this file — presumably <repo>/scripts/<this file>;
# confirm against the repository layout).
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT_DIR not in sys.path:
    # Prepend so project-local modules win over site-packages.
    sys.path.insert(0, ROOT_DIR)
# Avoid importing the whole 'src.chat' package to prevent heavy deps (e.g., redis)
# Local minimal implementation of loop and manager to isolate benchmark from heavy deps.
from collections.abc import AsyncIterator, Awaitable, Callable
from dataclasses import dataclass, field
@dataclass
class ConversationTick:
    """A single dispatch tick produced by conversation_loop for one stream."""

    # Identifier of the stream this tick belongs to.
    stream_id: str
    # Wall-clock time (seconds since epoch) the tick object was created.
    tick_time: float = field(default_factory=time.time)
    # True when the tick was forced by the dispatch-check callback rather
    # than triggered by pending unread messages alone.
    force_dispatch: bool = False
    # 1-based ordinal of this tick within its producing loop.
    tick_count: int = 0
async def conversation_loop(
    stream_id: str,
    get_context_func: Callable[[str], Awaitable["DummyContext | None"]],
    calculate_interval_func: Callable[[str, bool], Awaitable[float]],
    flush_cache_func: Callable[[str], Awaitable[list[Any]]],
    check_force_dispatch_func: Callable[["DummyContext", int], bool],
    is_running_func: Callable[[], bool],
) -> AsyncIterator[ConversationTick]:
    """Drive the tick loop for one stream.

    Each pass flushes cached messages into the unread backlog, then yields a
    ConversationTick whenever unread messages exist or the force-dispatch
    callback fires. The loop sleeps between passes for an interval supplied
    by calculate_interval_func, and exits once is_running_func() is False.
    """
    emitted = 0
    while is_running_func():
        context = await get_context_func(stream_id)
        if not context:
            # Stream unknown (yet) — back off briefly and retry.
            await asyncio.sleep(0.1)
            continue
        await flush_cache_func(stream_id)
        pending = context.get_unread_messages()
        forced = check_force_dispatch_func(context, len(pending))
        if pending or forced:
            emitted += 1
            yield ConversationTick(
                stream_id=stream_id,
                force_dispatch=forced,
                tick_count=emitted,
            )
        delay = await calculate_interval_func(stream_id, len(pending) > 0)
        await asyncio.sleep(delay)
class StreamLoopManager:
    """Manage per-stream conversation loops and shared processing stats.

    One loop task per stream is spawned via start_stream_loop(); each task is
    driven by run_chat_stream() and globally throttled through a semaphore
    sized by ``max_concurrent_streams``.
    """

    def __init__(self, max_concurrent_streams: Optional[int] = None):
        # Aggregate counters read by the benchmark reporter.
        self.stats: dict[str, Any] = {
            "active_streams": 0,
            "total_loops": 0,
            "total_process_cycles": 0,
            "total_failures": 0,
            "start_time": time.time(),
        }
        self.max_concurrent_streams = max_concurrent_streams or 100
        # A stream is force-dispatched once its unread backlog exceeds this.
        self.force_dispatch_unread_threshold = 20
        self.chatter_manager = DummyChatterManager()
        self.is_running = False
        # Per-stream locks so concurrent start_stream_loop() calls for the
        # same stream cannot spawn duplicate loop tasks.
        self._stream_start_locks: dict[str, asyncio.Lock] = {}
        self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams)
        self._chat_manager: Optional[DummyChatManager] = None

    def set_chat_manager(self, chat_manager: "DummyChatManager") -> None:
        """Inject the chat manager used to resolve stream ids to streams."""
        self._chat_manager = chat_manager

    async def start(self):
        """Allow conversation loops to run."""
        self.is_running = True

    async def stop(self):
        """Signal every conversation loop to exit at its next check."""
        self.is_running = False

    async def _get_stream_context(self, stream_id: str):
        """Return the context for *stream_id*, or None if unknown."""
        assert self._chat_manager is not None
        stream = await self._chat_manager.get_stream(stream_id)
        return stream.context if stream else None

    async def _flush_cached_messages_to_unread(self, stream_id: str):
        """Flush the stream's message cache; returns the flushed messages."""
        ctx = await self._get_stream_context(stream_id)
        return ctx.flush_cached_messages() if ctx else []

    def _needs_force_dispatch_for_context(self, context: "DummyContext", unread_count: int) -> bool:
        """True when the unread backlog exceeds the force-dispatch threshold."""
        return unread_count > (self.force_dispatch_unread_threshold or 20)

    async def _process_stream_messages(self, stream_id: str, context: "DummyContext") -> bool:
        """Run the chatter over the stream context; True on reported success."""
        res = await self.chatter_manager.process_stream_context(stream_id, context)  # type: ignore[attr-defined]
        return bool(res.get("success", False))

    async def _update_stream_energy(self, stream_id: str, context: "DummyContext") -> None:
        """No-op in the base manager; subclasses may track focus energy."""
        pass

    async def _calculate_interval(self, stream_id: str, has_messages: bool) -> float:
        """Loop sleep interval: short when messages are pending, longer when idle."""
        return 0.005 if has_messages else 0.02

    async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool:
        """Spawn the conversation-loop task for a stream.

        Returns False when the stream is unknown, True otherwise.

        Fix: the per-stream lock (previously allocated but never used) now
        serializes starts, and ``force`` (previously ignored) controls whether
        a fresh loop is spawned while an earlier task is still alive. Without
        these checks, repeated calls created duplicate loops for one stream.
        """
        lock = self._stream_start_locks.setdefault(stream_id, asyncio.Lock())
        async with lock:
            ctx = await self._get_stream_context(stream_id)
            if not ctx:
                return False
            existing = ctx.stream_loop_task
            if existing is not None and not existing.done() and not force:
                # A live loop already drives this stream; nothing to do.
                return True
            loop_task = asyncio.create_task(run_chat_stream(stream_id, self))
            ctx.stream_loop_task = loop_task
            self.stats["active_streams"] += 1
            self.stats["total_loops"] += 1
            return True
async def run_chat_stream(stream_id: str, manager: StreamLoopManager) -> None:
    """Consume conversation_loop ticks for one stream until the manager stops.

    For each tick, runs one processing cycle under the manager's global
    semaphore and updates the shared cycle/failure counters. Cancellation is
    treated as a normal shutdown.
    """
    ticks = conversation_loop(
        stream_id=stream_id,
        get_context_func=manager._get_stream_context,
        calculate_interval_func=manager._calculate_interval,
        flush_cache_func=manager._flush_cached_messages_to_unread,
        check_force_dispatch_func=manager._needs_force_dispatch_for_context,
        is_running_func=lambda: manager.is_running,
    )
    try:
        async for _tick in ticks:
            context = await manager._get_stream_context(stream_id)
            if not context:
                continue
            if context.is_chatter_processing:
                # Another cycle is already in flight for this stream.
                continue
            try:
                async with manager._processing_semaphore:
                    succeeded = await manager._process_stream_messages(stream_id, context)
            except Exception:
                # Any processing error counts as a failed cycle, not a crash.
                succeeded = False
            manager.stats["total_process_cycles"] += 1
            if not succeeded:
                manager.stats["total_failures"] += 1
    except asyncio.CancelledError:
        # Cooperative shutdown: swallow cancellation and return.
        pass
@dataclass
class DummyMessage:
    """Minimal stand-in for a chat message used by the benchmark."""

    # Creation timestamp (seconds since epoch).
    time: float
    processed_plain_text: str = ""
    display_message: str = ""
    # Whether the bot was @-ed / mentioned — unused by the benchmark logic here.
    is_at: bool = False
    is_mentioned: bool = False
class DummyContext:
    """Stand-in for a chat-stream context holding unread/history messages."""

    def __init__(self, stream_id: str, initial_unread: int):
        self.stream_id = stream_id
        # Seed the backlog with `initial_unread` freshly-timestamped messages.
        self.unread_messages = [
            DummyMessage(time=time.time()) for _ in range(initial_unread)
        ]
        self.history_messages: list[DummyMessage] = []
        self.is_chatter_processing: bool = False
        self.processing_task: Optional[asyncio.Task] = None
        self.stream_loop_task: Optional[asyncio.Task] = None
        self.triggering_user_id: Optional[str] = None

    def get_unread_messages(self) -> list[DummyMessage]:
        # Shallow copy so callers cannot mutate the backlog in place.
        return self.unread_messages[:]

    def flush_cached_messages(self) -> list[DummyMessage]:
        # The dummy context has no cache layer, so there is nothing to flush.
        return []

    def get_last_message(self) -> Optional[DummyMessage]:
        if not self.unread_messages:
            return None
        return self.unread_messages[-1]

    def get_history_messages(self, limit: int = 50) -> list[DummyMessage]:
        # Most recent `limit` history entries, oldest first.
        return self.history_messages[-limit:]
class DummyStream:
    """Pairs a stream id with its context and a mutable focus-energy value."""

    def __init__(self, stream_id: str, ctx: DummyContext):
        self.stream_id = stream_id
        self.context = ctx
        # None group info => treated as a private chat to accelerate the run.
        self.group_info = None
        # Read/written by the benchmark manager's energy update.
        self._focus_energy = 0.5
class DummyChatManager:
    """In-memory registry mapping stream ids to DummyStream objects."""

    def __init__(self, streams: dict[str, DummyStream]):
        self._streams = streams

    async def get_stream(self, stream_id: str) -> Optional[DummyStream]:
        """Look up one stream; None when the id is unknown."""
        return self._streams.get(stream_id)

    def get_all_streams(self) -> dict[str, DummyStream]:
        """Return the live registry (not a copy)."""
        return self._streams
class DummyChatterManager:
    """Fake chatter: waits ~10 ms then consumes one unread message."""

    async def process_stream_context(self, stream_id: str, context: DummyContext) -> dict[str, Any]:
        """Simulate one processing cycle over *context*; always succeeds."""
        # Simulated processing latency.
        await asyncio.sleep(0.01)
        backlog = context.unread_messages
        if backlog:
            # Consume exactly one message per cycle, oldest first.
            del backlog[0]
        return {"success": True}
class BenchStreamLoopManager(StreamLoopManager):
    """Benchmark specialization of StreamLoopManager wired to dummy objects.

    Fix: the previous version re-declared ``_get_stream_context``,
    ``_flush_cached_messages_to_unread``, ``_needs_force_dispatch_for_context``
    and ``_process_stream_messages`` as byte-identical copies of the base-class
    implementations; those redundant overrides are removed and inherited.
    """

    def __init__(self, chat_manager: DummyChatManager, max_concurrent_streams: int | None = None):
        super().__init__(max_concurrent_streams=max_concurrent_streams)
        # Wire the dummy chat manager directly instead of set_chat_manager().
        self._chat_manager = chat_manager
        self.chatter_manager = DummyChatterManager()

    async def _should_skip_for_mute_group(self, stream_id: str, unread_messages: list) -> bool:
        """The benchmark never mutes a group."""
        return False

    async def _update_stream_energy(self, stream_id: str, context) -> None:  # type: ignore[override]
        """Lightweight focus estimate based only on unread backlog size."""
        focus = min(1.0, 0.1 + 0.02 * len(context.get_unread_messages()))
        # Store on the stream for compatibility with code reading _focus_energy.
        stream = await self._chat_manager.get_stream(stream_id)
        if stream:
            stream._focus_energy = focus
def make_streams(n_streams: int, initial_unread: int) -> dict[str, DummyStream]:
    """Build *n_streams* dummy streams, each preloaded with *initial_unread*
    unread messages, keyed by ids of the form ``s0000``, ``s0001``, ...
    """
    return {
        sid: DummyStream(sid, DummyContext(sid, initial_unread))
        for sid in (f"s{i:04d}" for i in range(n_streams))
    }
async def run_benchmark(n_streams: int, initial_unread: int, max_concurrent: Optional[int]) -> dict[str, Any]:
    """Run one benchmark case and return its stats.

    Spawns a loop per stream, polls until every unread message is consumed
    (or a 5 s timeout elapses), then reports duration, cycle/failure counts,
    leftover unread messages and throughput.
    """
    streams = make_streams(n_streams, initial_unread)
    chat_mgr = DummyChatManager(streams)
    mgr = BenchStreamLoopManager(chat_mgr, max_concurrent_streams=max_concurrent)
    await mgr.start()
    # Start a conversation loop for every stream.
    start_ts = time.time()
    for sid in list(streams.keys()):
        await mgr.start_stream_loop(sid, force=True)
    # Poll until all unread messages are consumed or the deadline passes.
    timeout = 5.0
    end_deadline = start_ts + timeout
    while time.time() < end_deadline:
        remaining = sum(len(s.context.get_unread_messages()) for s in streams.values())
        if remaining == 0:
            break
        await asyncio.sleep(0.02)
    duration = time.time() - start_ts
    total_cycles = mgr.stats.get("total_process_cycles", 0)
    total_failures = mgr.stats.get("total_failures", 0)
    remaining = sum(len(s.context.get_unread_messages()) for s in streams.values())
    await mgr.stop()
    # Fix: loop tasks were previously left pending after stop(), which could
    # raise "Task was destroyed but it is pending" warnings when the event
    # loop closed. Cancel and await them for a clean shutdown
    # (run_chat_stream absorbs CancelledError).
    loop_tasks = [
        s.context.stream_loop_task
        for s in streams.values()
        if s.context.stream_loop_task is not None
    ]
    for task in loop_tasks:
        task.cancel()
    if loop_tasks:
        await asyncio.gather(*loop_tasks, return_exceptions=True)
    return {
        "n_streams": n_streams,
        "initial_unread": initial_unread,
        "max_concurrent": max_concurrent,
        "duration_sec": duration,
        "total_cycles": total_cycles,
        "total_failures": total_failures,
        "remaining_unread": remaining,
        "throughput_msgs_per_sec": (n_streams * initial_unread - remaining) / max(0.001, duration),
    }
async def main():
    """Run the benchmark matrix and print one summary line per case."""
    cases = [
        (50, 5, None),   # baseline using configured default concurrency
        (50, 5, 5),      # constrained concurrency
        (50, 5, 10),     # moderate concurrency
        (100, 3, 10),    # scale the number of streams
    ]
    print("Running distribution manager benchmark...\n")
    for n_streams, initial_unread, max_concurrent in cases:
        res = await run_benchmark(n_streams, initial_unread, max_concurrent)
        summary = (
            f"streams={res['n_streams']} unread={res['initial_unread']} max_conc={res['max_concurrent']} | "
            f"dur={res['duration_sec']:.3f}s cycles={res['total_cycles']} fail={res['total_failures']} rem={res['remaining_unread']} "
            f"thr={res['throughput_msgs_per_sec']:.1f}/s"
        )
        print(summary)


if __name__ == "__main__":
    asyncio.run(main())