优化消息管理
This commit is contained in:
306
scripts/benchmark_distribution_manager.py
Normal file
306
scripts/benchmark_distribution_manager.py
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# Benchmark the distribution manager's run_chat_stream/conversation_loop behavior
|
||||||
|
# by wiring a lightweight dummy manager and contexts. This avoids touching real DB or chat subsystems.
|
||||||
|
|
||||||
|
# Ensure project root is on sys.path when running as a script
|
||||||
|
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
if ROOT_DIR not in sys.path:
|
||||||
|
sys.path.insert(0, ROOT_DIR)
|
||||||
|
|
||||||
|
# Avoid importing the whole 'src.chat' package to prevent heavy deps (e.g., redis)
|
||||||
|
# Local minimal implementation of loop and manager to isolate benchmark from heavy deps.
|
||||||
|
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConversationTick:
|
||||||
|
stream_id: str
|
||||||
|
tick_time: float = field(default_factory=time.time)
|
||||||
|
force_dispatch: bool = False
|
||||||
|
tick_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
async def conversation_loop(
|
||||||
|
stream_id: str,
|
||||||
|
get_context_func: Callable[[str], Awaitable["DummyContext | None"]],
|
||||||
|
calculate_interval_func: Callable[[str, bool], Awaitable[float]],
|
||||||
|
flush_cache_func: Callable[[str], Awaitable[list[Any]]],
|
||||||
|
check_force_dispatch_func: Callable[["DummyContext", int], bool],
|
||||||
|
is_running_func: Callable[[], bool],
|
||||||
|
) -> AsyncIterator[ConversationTick]:
|
||||||
|
tick_count = 0
|
||||||
|
while is_running_func():
|
||||||
|
ctx = await get_context_func(stream_id)
|
||||||
|
if not ctx:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
continue
|
||||||
|
await flush_cache_func(stream_id)
|
||||||
|
unread = ctx.get_unread_messages()
|
||||||
|
ucnt = len(unread)
|
||||||
|
force = check_force_dispatch_func(ctx, ucnt)
|
||||||
|
if ucnt > 0 or force:
|
||||||
|
tick_count += 1
|
||||||
|
yield ConversationTick(stream_id=stream_id, force_dispatch=force, tick_count=tick_count)
|
||||||
|
interval = await calculate_interval_func(stream_id, ucnt > 0)
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamLoopManager:
|
||||||
|
def __init__(self, max_concurrent_streams: Optional[int] = None):
|
||||||
|
self.stats: dict[str, Any] = {
|
||||||
|
"active_streams": 0,
|
||||||
|
"total_loops": 0,
|
||||||
|
"total_process_cycles": 0,
|
||||||
|
"total_failures": 0,
|
||||||
|
"start_time": time.time(),
|
||||||
|
}
|
||||||
|
self.max_concurrent_streams = max_concurrent_streams or 100
|
||||||
|
self.force_dispatch_unread_threshold = 20
|
||||||
|
self.chatter_manager = DummyChatterManager()
|
||||||
|
self.is_running = False
|
||||||
|
self._stream_start_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams)
|
||||||
|
self._chat_manager: Optional[DummyChatManager] = None
|
||||||
|
|
||||||
|
def set_chat_manager(self, chat_manager: "DummyChatManager") -> None:
|
||||||
|
self._chat_manager = chat_manager
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
self.is_running = True
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
self.is_running = False
|
||||||
|
|
||||||
|
async def _get_stream_context(self, stream_id: str):
|
||||||
|
assert self._chat_manager is not None
|
||||||
|
stream = await self._chat_manager.get_stream(stream_id)
|
||||||
|
return stream.context if stream else None
|
||||||
|
|
||||||
|
async def _flush_cached_messages_to_unread(self, stream_id: str):
|
||||||
|
ctx = await self._get_stream_context(stream_id)
|
||||||
|
return ctx.flush_cached_messages() if ctx else []
|
||||||
|
|
||||||
|
def _needs_force_dispatch_for_context(self, context: "DummyContext", unread_count: int) -> bool:
|
||||||
|
return unread_count > (self.force_dispatch_unread_threshold or 20)
|
||||||
|
|
||||||
|
async def _process_stream_messages(self, stream_id: str, context: "DummyContext") -> bool:
|
||||||
|
res = await self.chatter_manager.process_stream_context(stream_id, context) # type: ignore[attr-defined]
|
||||||
|
return bool(res.get("success", False))
|
||||||
|
|
||||||
|
async def _update_stream_energy(self, stream_id: str, context: "DummyContext") -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _calculate_interval(self, stream_id: str, has_messages: bool) -> float:
|
||||||
|
return 0.005 if has_messages else 0.02
|
||||||
|
|
||||||
|
async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool:
|
||||||
|
ctx = await self._get_stream_context(stream_id)
|
||||||
|
if not ctx:
|
||||||
|
return False
|
||||||
|
# create driver
|
||||||
|
loop_task = asyncio.create_task(run_chat_stream(stream_id, self))
|
||||||
|
ctx.stream_loop_task = loop_task
|
||||||
|
self.stats["active_streams"] += 1
|
||||||
|
self.stats["total_loops"] += 1
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def run_chat_stream(stream_id: str, manager: StreamLoopManager) -> None:
|
||||||
|
try:
|
||||||
|
gen = conversation_loop(
|
||||||
|
stream_id=stream_id,
|
||||||
|
get_context_func=manager._get_stream_context,
|
||||||
|
calculate_interval_func=manager._calculate_interval,
|
||||||
|
flush_cache_func=manager._flush_cached_messages_to_unread,
|
||||||
|
check_force_dispatch_func=manager._needs_force_dispatch_for_context,
|
||||||
|
is_running_func=lambda: manager.is_running,
|
||||||
|
)
|
||||||
|
async for tick in gen:
|
||||||
|
ctx = await manager._get_stream_context(stream_id)
|
||||||
|
if not ctx:
|
||||||
|
continue
|
||||||
|
if ctx.is_chatter_processing:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
async with manager._processing_semaphore:
|
||||||
|
ok = await manager._process_stream_messages(stream_id, ctx)
|
||||||
|
except Exception:
|
||||||
|
ok = False
|
||||||
|
manager.stats["total_process_cycles"] += 1
|
||||||
|
if not ok:
|
||||||
|
manager.stats["total_failures"] += 1
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DummyMessage:
|
||||||
|
time: float
|
||||||
|
processed_plain_text: str = ""
|
||||||
|
display_message: str = ""
|
||||||
|
is_at: bool = False
|
||||||
|
is_mentioned: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class DummyContext:
|
||||||
|
def __init__(self, stream_id: str, initial_unread: int):
|
||||||
|
self.stream_id = stream_id
|
||||||
|
self.unread_messages = [DummyMessage(time=time.time()) for _ in range(initial_unread)]
|
||||||
|
self.history_messages: list[DummyMessage] = []
|
||||||
|
self.is_chatter_processing: bool = False
|
||||||
|
self.processing_task: Optional[asyncio.Task] = None
|
||||||
|
self.stream_loop_task: Optional[asyncio.Task] = None
|
||||||
|
self.triggering_user_id: Optional[str] = None
|
||||||
|
|
||||||
|
def get_unread_messages(self) -> list[DummyMessage]:
|
||||||
|
return list(self.unread_messages)
|
||||||
|
|
||||||
|
def flush_cached_messages(self) -> list[DummyMessage]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_last_message(self) -> Optional[DummyMessage]:
|
||||||
|
return self.unread_messages[-1] if self.unread_messages else None
|
||||||
|
|
||||||
|
def get_history_messages(self, limit: int = 50) -> list[DummyMessage]:
|
||||||
|
return self.history_messages[-limit:]
|
||||||
|
|
||||||
|
|
||||||
|
class DummyStream:
|
||||||
|
def __init__(self, stream_id: str, ctx: DummyContext):
|
||||||
|
self.stream_id = stream_id
|
||||||
|
self.context = ctx
|
||||||
|
self.group_info = None # treat as private chat to accelerate
|
||||||
|
self._focus_energy = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class DummyChatManager:
|
||||||
|
def __init__(self, streams: dict[str, DummyStream]):
|
||||||
|
self._streams = streams
|
||||||
|
|
||||||
|
async def get_stream(self, stream_id: str) -> Optional[DummyStream]:
|
||||||
|
return self._streams.get(stream_id)
|
||||||
|
|
||||||
|
def get_all_streams(self) -> dict[str, DummyStream]:
|
||||||
|
return self._streams
|
||||||
|
|
||||||
|
|
||||||
|
class DummyChatterManager:
|
||||||
|
async def process_stream_context(self, stream_id: str, context: DummyContext) -> dict[str, Any]:
|
||||||
|
# Simulate some processing latency and consume one unread message
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
if context.unread_messages:
|
||||||
|
context.unread_messages.pop(0)
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
|
||||||
|
class BenchStreamLoopManager(StreamLoopManager):
|
||||||
|
def __init__(self, chat_manager: DummyChatManager, max_concurrent_streams: int | None = None):
|
||||||
|
super().__init__(max_concurrent_streams=max_concurrent_streams)
|
||||||
|
self._chat_manager = chat_manager
|
||||||
|
self.chatter_manager = DummyChatterManager()
|
||||||
|
|
||||||
|
async def _get_stream_context(self, stream_id: str): # type: ignore[override]
|
||||||
|
stream = await self._chat_manager.get_stream(stream_id)
|
||||||
|
return stream.context if stream else None
|
||||||
|
|
||||||
|
async def _flush_cached_messages_to_unread(self, stream_id: str): # type: ignore[override]
|
||||||
|
ctx = await self._get_stream_context(stream_id)
|
||||||
|
return ctx.flush_cached_messages() if ctx else []
|
||||||
|
|
||||||
|
def _needs_force_dispatch_for_context(self, context, unread_count: int) -> bool: # type: ignore[override]
|
||||||
|
# force when unread exceeds threshold
|
||||||
|
return unread_count > (self.force_dispatch_unread_threshold or 20)
|
||||||
|
|
||||||
|
async def _process_stream_messages(self, stream_id: str, context): # type: ignore[override]
|
||||||
|
# delegate to chatter manager
|
||||||
|
res = await self.chatter_manager.process_stream_context(stream_id, context) # type: ignore[attr-defined]
|
||||||
|
return bool(res.get("success", False))
|
||||||
|
|
||||||
|
async def _should_skip_for_mute_group(self, stream_id: str, unread_messages: list) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _update_stream_energy(self, stream_id: str, context): # type: ignore[override]
|
||||||
|
# lightweight: compute based on unread size
|
||||||
|
focus = min(1.0, 0.1 + 0.02 * len(context.get_unread_messages()))
|
||||||
|
# set for compatibility
|
||||||
|
stream = await self._chat_manager.get_stream(stream_id)
|
||||||
|
if stream:
|
||||||
|
stream._focus_energy = focus
|
||||||
|
|
||||||
|
|
||||||
|
def make_streams(n_streams: int, initial_unread: int) -> dict[str, DummyStream]:
|
||||||
|
streams: dict[str, DummyStream] = {}
|
||||||
|
for i in range(n_streams):
|
||||||
|
sid = f"s{i:04d}"
|
||||||
|
ctx = DummyContext(sid, initial_unread)
|
||||||
|
streams[sid] = DummyStream(sid, ctx)
|
||||||
|
return streams
|
||||||
|
|
||||||
|
|
||||||
|
async def run_benchmark(n_streams: int, initial_unread: int, max_concurrent: Optional[int]) -> dict[str, Any]:
|
||||||
|
streams = make_streams(n_streams, initial_unread)
|
||||||
|
chat_mgr = DummyChatManager(streams)
|
||||||
|
mgr = BenchStreamLoopManager(chat_mgr, max_concurrent_streams=max_concurrent)
|
||||||
|
await mgr.start()
|
||||||
|
|
||||||
|
# start loops for all streams
|
||||||
|
start_ts = time.time()
|
||||||
|
for sid in list(streams.keys()):
|
||||||
|
await mgr.start_stream_loop(sid, force=True)
|
||||||
|
|
||||||
|
# run until all unread consumed or timeout
|
||||||
|
timeout = 5.0
|
||||||
|
end_deadline = start_ts + timeout
|
||||||
|
while time.time() < end_deadline:
|
||||||
|
remaining = sum(len(s.context.get_unread_messages()) for s in streams.values())
|
||||||
|
if remaining == 0:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
|
||||||
|
duration = time.time() - start_ts
|
||||||
|
total_cycles = mgr.stats.get("total_process_cycles", 0)
|
||||||
|
total_failures = mgr.stats.get("total_failures", 0)
|
||||||
|
remaining = sum(len(s.context.get_unread_messages()) for s in streams.values())
|
||||||
|
|
||||||
|
# stop all
|
||||||
|
await mgr.stop()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"n_streams": n_streams,
|
||||||
|
"initial_unread": initial_unread,
|
||||||
|
"max_concurrent": max_concurrent,
|
||||||
|
"duration_sec": duration,
|
||||||
|
"total_cycles": total_cycles,
|
||||||
|
"total_failures": total_failures,
|
||||||
|
"remaining_unread": remaining,
|
||||||
|
"throughput_msgs_per_sec": (n_streams * initial_unread - remaining) / max(0.001, duration),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
cases = [
|
||||||
|
(50, 5, None), # baseline using configured default
|
||||||
|
(50, 5, 5), # constrained concurrency
|
||||||
|
(50, 5, 10), # moderate concurrency
|
||||||
|
(100, 3, 10), # scale streams
|
||||||
|
]
|
||||||
|
|
||||||
|
print("Running distribution manager benchmark...\n")
|
||||||
|
for n_streams, initial_unread, max_concurrent in cases:
|
||||||
|
res = await run_benchmark(n_streams, initial_unread, max_concurrent)
|
||||||
|
print(
|
||||||
|
f"streams={res['n_streams']} unread={res['initial_unread']} max_conc={res['max_concurrent']} | "
|
||||||
|
f"dur={res['duration_sec']:.3f}s cycles={res['total_cycles']} fail={res['total_failures']} rem={res['remaining_unread']} "+
|
||||||
|
f"thr={res['throughput_msgs_per_sec']:.1f}/s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@@ -6,8 +6,6 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
|
|
||||||
class PerformanceBenchmark:
|
class PerformanceBenchmark:
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ from collections import defaultdict
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
from src.common.database.compatibility import get_db_session
|
from src.common.database.compatibility import get_db_session
|
||||||
from src.common.database.core.models import ChatStreams
|
from src.common.database.core.models import ChatStreams
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -159,20 +161,27 @@ class BatchDatabaseWriter:
|
|||||||
logger.info("批量写入循环结束")
|
logger.info("批量写入循环结束")
|
||||||
|
|
||||||
async def _collect_batch(self) -> list[StreamUpdatePayload]:
|
async def _collect_batch(self) -> list[StreamUpdatePayload]:
|
||||||
"""收集一个批次的数据"""
|
"""收集一个批次的数据
|
||||||
batch = []
|
- 自适应刷新:队列增长加快时缩短等待时间
|
||||||
deadline = time.time() + self.flush_interval
|
- 避免长时间空转:添加轻微抖动以分散竞争
|
||||||
|
"""
|
||||||
|
batch: list[StreamUpdatePayload] = []
|
||||||
|
# 根据当前队列长度调整刷新时间(最多缩短到 40%)
|
||||||
|
qsize = self.write_queue.qsize()
|
||||||
|
adapt_factor = 1.0
|
||||||
|
if qsize > 0:
|
||||||
|
adapt_factor = max(0.4, min(1.0, self.batch_size / max(1, qsize)))
|
||||||
|
deadline = time.time() + (self.flush_interval * adapt_factor)
|
||||||
|
|
||||||
while len(batch) < self.batch_size and time.time() < deadline:
|
while len(batch) < self.batch_size and time.time() < deadline:
|
||||||
try:
|
try:
|
||||||
# 计算剩余等待时间
|
remaining_time = max(0.0, deadline - time.time())
|
||||||
remaining_time = max(0, deadline - time.time())
|
|
||||||
if remaining_time == 0:
|
if remaining_time == 0:
|
||||||
break
|
break
|
||||||
|
# 轻微抖动,避免多个协程同时争抢队列
|
||||||
payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time)
|
jitter = 0.002
|
||||||
|
payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time + jitter)
|
||||||
batch.append(payload)
|
batch.append(payload)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -208,48 +217,52 @@ class BatchDatabaseWriter:
|
|||||||
|
|
||||||
logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s")
|
logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s")
|
||||||
|
|
||||||
except Exception as e:
|
except SQLAlchemyError as e:
|
||||||
self.stats["failed_writes"] += 1
|
self.stats["failed_writes"] += 1
|
||||||
logger.error(f"批量写入失败: {e}")
|
logger.error(f"批量写入失败: {e}")
|
||||||
# 降级到单个写入
|
# 降级到单个写入
|
||||||
for payload in batch:
|
for payload in batch:
|
||||||
try:
|
try:
|
||||||
await self._direct_write(payload.stream_id, payload.update_data)
|
await self._direct_write(payload.stream_id, payload.update_data)
|
||||||
except Exception as single_e:
|
except SQLAlchemyError as single_e:
|
||||||
logger.error(f"单个写入也失败: {single_e}")
|
logger.error(f"单个写入也失败: {single_e}")
|
||||||
|
|
||||||
async def _batch_write_to_database(self, payloads: list[StreamUpdatePayload]):
|
async def _batch_write_to_database(self, payloads: list[StreamUpdatePayload]):
|
||||||
"""批量写入数据库"""
|
"""批量写入数据库(单事务、多值 UPSERT)"""
|
||||||
if global_config is None:
|
if global_config is None:
|
||||||
raise RuntimeError("Global config is not initialized")
|
raise RuntimeError("Global config is not initialized")
|
||||||
|
|
||||||
|
if not payloads:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 预组装行数据,确保每行包含 stream_id
|
||||||
|
rows: list[dict[str, Any]] = []
|
||||||
|
for p in payloads:
|
||||||
|
row = {"stream_id": p.stream_id}
|
||||||
|
row.update(p.update_data)
|
||||||
|
rows.append(row)
|
||||||
|
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
for payload in payloads:
|
# 使用单次事务提交,显著减少 I/O
|
||||||
stream_id = payload.stream_id
|
if global_config.database.database_type == "postgresql":
|
||||||
update_data = payload.update_data
|
|
||||||
|
|
||||||
# 根据数据库类型选择不同的插入/更新策略
|
|
||||||
if global_config.database.database_type == "sqlite":
|
|
||||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
|
||||||
|
|
||||||
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
|
|
||||||
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
|
|
||||||
elif global_config.database.database_type == "postgresql":
|
|
||||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
stmt = pg_insert(ChatStreams).values(rows)
|
||||||
stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data)
|
|
||||||
stmt = stmt.on_conflict_do_update(
|
stmt = stmt.on_conflict_do_update(
|
||||||
index_elements=[ChatStreams.stream_id],
|
index_elements=[ChatStreams.stream_id],
|
||||||
set_=update_data
|
set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"}
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# 默认使用SQLite语法
|
|
||||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
|
||||||
|
|
||||||
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
|
|
||||||
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
|
|
||||||
|
|
||||||
await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
|
await session.commit()
|
||||||
|
else:
|
||||||
|
# 默认(sqlite)
|
||||||
|
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||||
|
stmt = sqlite_insert(ChatStreams).values(rows)
|
||||||
|
stmt = stmt.on_conflict_do_update(
|
||||||
|
index_elements=["stream_id"],
|
||||||
|
set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"}
|
||||||
|
)
|
||||||
|
await session.execute(stmt)
|
||||||
|
await session.commit()
|
||||||
async def _direct_write(self, stream_id: str, update_data: dict[str, Any]):
|
async def _direct_write(self, stream_id: str, update_data: dict[str, Any]):
|
||||||
"""直接写入数据库(降级方案)"""
|
"""直接写入数据库(降级方案)"""
|
||||||
if global_config is None:
|
if global_config is None:
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ async def conversation_loop(
|
|||||||
stream_id: str,
|
stream_id: str,
|
||||||
get_context_func: Callable[[str], Awaitable["StreamContext | None"]],
|
get_context_func: Callable[[str], Awaitable["StreamContext | None"]],
|
||||||
calculate_interval_func: Callable[[str, bool], Awaitable[float]],
|
calculate_interval_func: Callable[[str, bool], Awaitable[float]],
|
||||||
flush_cache_func: Callable[[str], Awaitable[None]],
|
flush_cache_func: Callable[[str], Awaitable[list[Any]]],
|
||||||
check_force_dispatch_func: Callable[["StreamContext", int], bool],
|
check_force_dispatch_func: Callable[["StreamContext", int], bool],
|
||||||
is_running_func: Callable[[], bool],
|
is_running_func: Callable[[], bool],
|
||||||
) -> AsyncIterator[ConversationTick]:
|
) -> AsyncIterator[ConversationTick]:
|
||||||
@@ -121,7 +121,7 @@ async def conversation_loop(
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info(f" [生成器] stream={stream_id[:8]}, 被取消")
|
logger.info(f" [生成器] stream={stream_id[:8]}, 被取消")
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e: # noqa: BLE001
|
||||||
logger.error(f" [生成器] stream={stream_id[:8]}, 出错: {e}")
|
logger.error(f" [生成器] stream={stream_id[:8]}, 出错: {e}")
|
||||||
await asyncio.sleep(5.0)
|
await asyncio.sleep(5.0)
|
||||||
|
|
||||||
@@ -151,10 +151,10 @@ async def run_chat_stream(
|
|||||||
# 创建生成器
|
# 创建生成器
|
||||||
tick_generator = conversation_loop(
|
tick_generator = conversation_loop(
|
||||||
stream_id=stream_id,
|
stream_id=stream_id,
|
||||||
get_context_func=manager._get_stream_context,
|
get_context_func=manager._get_stream_context, # noqa: SLF001
|
||||||
calculate_interval_func=manager._calculate_interval,
|
calculate_interval_func=manager._calculate_interval, # noqa: SLF001
|
||||||
flush_cache_func=manager._flush_cached_messages_to_unread,
|
flush_cache_func=manager._flush_cached_messages_to_unread, # noqa: SLF001
|
||||||
check_force_dispatch_func=manager._needs_force_dispatch_for_context,
|
check_force_dispatch_func=manager._needs_force_dispatch_for_context, # noqa: SLF001
|
||||||
is_running_func=lambda: manager.is_running,
|
is_running_func=lambda: manager.is_running,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -162,13 +162,13 @@ async def run_chat_stream(
|
|||||||
async for tick in tick_generator:
|
async for tick in tick_generator:
|
||||||
try:
|
try:
|
||||||
# 获取上下文
|
# 获取上下文
|
||||||
context = await manager._get_stream_context(stream_id)
|
context = await manager._get_stream_context(stream_id) # noqa: SLF001
|
||||||
if not context:
|
if not context:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 并发保护:检查是否正在处理
|
# 并发保护:检查是否正在处理
|
||||||
if context.is_chatter_processing:
|
if context.is_chatter_processing:
|
||||||
if manager._recover_stale_chatter_state(stream_id, context):
|
if manager._recover_stale_chatter_state(stream_id, context): # noqa: SLF001
|
||||||
logger.warning(f" [驱动器] stream={stream_id[:8]}, 处理标志残留已修复")
|
logger.warning(f" [驱动器] stream={stream_id[:8]}, 处理标志残留已修复")
|
||||||
else:
|
else:
|
||||||
logger.debug(f" [驱动器] stream={stream_id[:8]}, Chatter正在处理,跳过此Tick")
|
logger.debug(f" [驱动器] stream={stream_id[:8]}, Chatter正在处理,跳过此Tick")
|
||||||
@@ -182,16 +182,17 @@ async def run_chat_stream(
|
|||||||
|
|
||||||
# 更新能量值
|
# 更新能量值
|
||||||
try:
|
try:
|
||||||
await manager._update_stream_energy(stream_id, context)
|
await manager._update_stream_energy(stream_id, context) # noqa: SLF001
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"更新能量失败: {e}")
|
logger.debug(f"更新能量失败: {e}")
|
||||||
|
|
||||||
# 处理消息
|
# 处理消息
|
||||||
assert global_config is not None
|
assert global_config is not None
|
||||||
try:
|
try:
|
||||||
|
async with manager._processing_semaphore:
|
||||||
success = await asyncio.wait_for(
|
success = await asyncio.wait_for(
|
||||||
manager._process_stream_messages(stream_id, context),
|
manager._process_stream_messages(stream_id, context), # noqa: SLF001
|
||||||
global_config.chat.thinking_timeout
|
global_config.chat.thinking_timeout,
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 处理超时")
|
logger.warning(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 处理超时")
|
||||||
@@ -208,7 +209,7 @@ async def run_chat_stream(
|
|||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e: # noqa: BLE001
|
||||||
logger.error(f" [驱动器] stream={stream_id[:8]}, 处理Tick时出错: {e}")
|
logger.error(f" [驱动器] stream={stream_id[:8]}, 处理Tick时出错: {e}")
|
||||||
manager.stats["total_failures"] += 1
|
manager.stats["total_failures"] += 1
|
||||||
|
|
||||||
@@ -221,7 +222,7 @@ async def run_chat_stream(
|
|||||||
if context and context.stream_loop_task:
|
if context and context.stream_loop_task:
|
||||||
context.stream_loop_task = None
|
context.stream_loop_task = None
|
||||||
logger.debug(f" [驱动器] stream={stream_id[:8]}, 清理任务记录")
|
logger.debug(f" [驱动器] stream={stream_id[:8]}, 清理任务记录")
|
||||||
except Exception as e:
|
except Exception as e: # noqa: BLE001
|
||||||
logger.debug(f"清理任务记录失败: {e}")
|
logger.debug(f"清理任务记录失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
@@ -268,6 +269,9 @@ class StreamLoopManager:
|
|||||||
# 流启动锁:防止并发启动同一个流的多个任务
|
# 流启动锁:防止并发启动同一个流的多个任务
|
||||||
self._stream_start_locks: dict[str, asyncio.Lock] = {}
|
self._stream_start_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
|
||||||
|
# 并发控制:限制同时进行的 Chatter 处理任务数
|
||||||
|
self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams)
|
||||||
|
|
||||||
logger.info(f"流循环管理器初始化完成 (最大并发流数: {self.max_concurrent_streams})")
|
logger.info(f"流循环管理器初始化完成 (最大并发流数: {self.max_concurrent_streams})")
|
||||||
|
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
|
|||||||
@@ -104,9 +104,21 @@ class MessageManager:
|
|||||||
if not chat_stream:
|
if not chat_stream:
|
||||||
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
|
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
|
||||||
return
|
return
|
||||||
# 启动 stream loop 任务(如果尚未启动)
|
|
||||||
await stream_loop_manager.start_stream_loop(stream_id)
|
# 快速检查:如果已有驱动器在跑,则跳过重复启动,避免不必要的 await
|
||||||
await self._check_and_handle_interruption(chat_stream, message)
|
context = chat_stream.context
|
||||||
|
if not (context.stream_loop_task and not context.stream_loop_task.done()):
|
||||||
|
# 异步启动驱动器任务;避免在高并发下阻塞消息入队
|
||||||
|
context.stream_loop_task = asyncio.create_task(
|
||||||
|
stream_loop_manager.start_stream_loop(stream_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 并行触发打断检查,不阻塞消息入队
|
||||||
|
context.interruption_task = asyncio.create_task(
|
||||||
|
self._check_and_handle_interruption(chat_stream, message)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 入队消息
|
||||||
await chat_stream.context.add_message(message)
|
await chat_stream.context.add_message(message)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -476,8 +488,7 @@ class MessageManager:
|
|||||||
is_processing: 是否正在处理
|
is_processing: 是否正在处理
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 尝试更新StreamContext的处理状态
|
# 尝试更新StreamContext的处理状态(使用顶层 asyncio 导入)
|
||||||
import asyncio
|
|
||||||
async def _update_context():
|
async def _update_context():
|
||||||
try:
|
try:
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
@@ -492,7 +503,7 @@ class MessageManager:
|
|||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
if loop.is_running():
|
if loop.is_running():
|
||||||
asyncio.create_task(_update_context())
|
self._update_context_task = asyncio.create_task(_update_context())
|
||||||
else:
|
else:
|
||||||
# 如果事件循环未运行,则跳过
|
# 如果事件循环未运行,则跳过
|
||||||
logger.debug("事件循环未运行,跳过StreamContext状态更新")
|
logger.debug("事件循环未运行,跳过StreamContext状态更新")
|
||||||
@@ -512,8 +523,7 @@ class MessageManager:
|
|||||||
bool: 是否正在处理
|
bool: 是否正在处理
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 尝试从StreamContext获取处理状态
|
# 尝试从StreamContext获取处理状态(使用顶层 asyncio 导入)
|
||||||
import asyncio
|
|
||||||
async def _get_context_status():
|
async def _get_context_status():
|
||||||
try:
|
try:
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
|
|||||||
Reference in New Issue
Block a user