Optimize message management
scripts/benchmark_distribution_manager.py (new file, +306 lines)
@@ -0,0 +1,306 @@
import asyncio
import os
import sys
import time
from collections.abc import AsyncIterator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Optional

# Benchmark the distribution manager's run_chat_stream/conversation_loop
# behavior by wiring up a lightweight dummy manager and contexts. This avoids
# touching the real DB or chat subsystems.

# Ensure the project root is on sys.path when running as a script.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

# Avoid importing the whole 'src.chat' package, which pulls in heavy deps
# (e.g., redis): the loop and manager below are minimal local
# re-implementations that isolate the benchmark from those dependencies.


@dataclass
class ConversationTick:
    stream_id: str
    tick_time: float = field(default_factory=time.time)
    force_dispatch: bool = False
    tick_count: int = 0


async def conversation_loop(
    stream_id: str,
    get_context_func: Callable[[str], Awaitable["DummyContext | None"]],
    calculate_interval_func: Callable[[str, bool], Awaitable[float]],
    flush_cache_func: Callable[[str], Awaitable[list[Any]]],
    check_force_dispatch_func: Callable[["DummyContext", int], bool],
    is_running_func: Callable[[], bool],
) -> AsyncIterator[ConversationTick]:
    # Yield a ConversationTick whenever the stream has unread messages (or a
    # force dispatch is requested), then sleep for the computed interval.
    tick_count = 0
    while is_running_func():
        ctx = await get_context_func(stream_id)
        if not ctx:
            await asyncio.sleep(0.1)
            continue
        await flush_cache_func(stream_id)
        unread = ctx.get_unread_messages()
        ucnt = len(unread)
        force = check_force_dispatch_func(ctx, ucnt)
        if ucnt > 0 or force:
            tick_count += 1
            yield ConversationTick(stream_id=stream_id, force_dispatch=force, tick_count=tick_count)
        interval = await calculate_interval_func(stream_id, ucnt > 0)
        await asyncio.sleep(interval)


class StreamLoopManager:
    def __init__(self, max_concurrent_streams: Optional[int] = None):
        self.stats: dict[str, Any] = {
            "active_streams": 0,
            "total_loops": 0,
            "total_process_cycles": 0,
            "total_failures": 0,
            "start_time": time.time(),
        }
        self.max_concurrent_streams = max_concurrent_streams or 100
        self.force_dispatch_unread_threshold = 20
        self.chatter_manager = DummyChatterManager()
        self.is_running = False
        self._stream_start_locks: dict[str, asyncio.Lock] = {}
        self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams)
        self._chat_manager: Optional[DummyChatManager] = None

    def set_chat_manager(self, chat_manager: "DummyChatManager") -> None:
        self._chat_manager = chat_manager

    async def start(self):
        self.is_running = True

    async def stop(self):
        self.is_running = False

    async def _get_stream_context(self, stream_id: str):
        assert self._chat_manager is not None
        stream = await self._chat_manager.get_stream(stream_id)
        return stream.context if stream else None

    async def _flush_cached_messages_to_unread(self, stream_id: str):
        ctx = await self._get_stream_context(stream_id)
        return ctx.flush_cached_messages() if ctx else []

    def _needs_force_dispatch_for_context(self, context: "DummyContext", unread_count: int) -> bool:
        return unread_count > (self.force_dispatch_unread_threshold or 20)

    async def _process_stream_messages(self, stream_id: str, context: "DummyContext") -> bool:
        res = await self.chatter_manager.process_stream_context(stream_id, context)  # type: ignore[attr-defined]
        return bool(res.get("success", False))

    async def _update_stream_energy(self, stream_id: str, context: "DummyContext") -> None:
        pass

    async def _calculate_interval(self, stream_id: str, has_messages: bool) -> float:
        return 0.005 if has_messages else 0.02

    async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool:
        ctx = await self._get_stream_context(stream_id)
        if not ctx:
            return False
        # Create the driver task for this stream.
        loop_task = asyncio.create_task(run_chat_stream(stream_id, self))
        ctx.stream_loop_task = loop_task
        self.stats["active_streams"] += 1
        self.stats["total_loops"] += 1
        return True


async def run_chat_stream(stream_id: str, manager: StreamLoopManager) -> None:
    try:
        gen = conversation_loop(
            stream_id=stream_id,
            get_context_func=manager._get_stream_context,
            calculate_interval_func=manager._calculate_interval,
            flush_cache_func=manager._flush_cached_messages_to_unread,
            check_force_dispatch_func=manager._needs_force_dispatch_for_context,
            is_running_func=lambda: manager.is_running,
        )
        async for tick in gen:
            ctx = await manager._get_stream_context(stream_id)
            if not ctx:
                continue
            if ctx.is_chatter_processing:
                continue
            try:
                async with manager._processing_semaphore:
                    ok = await manager._process_stream_messages(stream_id, ctx)
            except Exception:
                ok = False
            manager.stats["total_process_cycles"] += 1
            if not ok:
                manager.stats["total_failures"] += 1
    except asyncio.CancelledError:
        pass


@dataclass
class DummyMessage:
    time: float
    processed_plain_text: str = ""
    display_message: str = ""
    is_at: bool = False
    is_mentioned: bool = False


class DummyContext:
    def __init__(self, stream_id: str, initial_unread: int):
        self.stream_id = stream_id
        self.unread_messages = [DummyMessage(time=time.time()) for _ in range(initial_unread)]
        self.history_messages: list[DummyMessage] = []
        self.is_chatter_processing: bool = False
        self.processing_task: Optional[asyncio.Task] = None
        self.stream_loop_task: Optional[asyncio.Task] = None
        self.triggering_user_id: Optional[str] = None

    def get_unread_messages(self) -> list[DummyMessage]:
        return list(self.unread_messages)

    def flush_cached_messages(self) -> list[DummyMessage]:
        return []

    def get_last_message(self) -> Optional[DummyMessage]:
        return self.unread_messages[-1] if self.unread_messages else None

    def get_history_messages(self, limit: int = 50) -> list[DummyMessage]:
        return self.history_messages[-limit:]


class DummyStream:
    def __init__(self, stream_id: str, ctx: DummyContext):
        self.stream_id = stream_id
        self.context = ctx
        self.group_info = None  # treat as private chat to accelerate
        self._focus_energy = 0.5


class DummyChatManager:
    def __init__(self, streams: dict[str, DummyStream]):
        self._streams = streams

    async def get_stream(self, stream_id: str) -> Optional[DummyStream]:
        return self._streams.get(stream_id)

    def get_all_streams(self) -> dict[str, DummyStream]:
        return self._streams


class DummyChatterManager:
    async def process_stream_context(self, stream_id: str, context: DummyContext) -> dict[str, Any]:
        # Simulate some processing latency and consume one unread message.
        await asyncio.sleep(0.01)
        if context.unread_messages:
            context.unread_messages.pop(0)
        return {"success": True}


class BenchStreamLoopManager(StreamLoopManager):
    def __init__(self, chat_manager: DummyChatManager, max_concurrent_streams: int | None = None):
        super().__init__(max_concurrent_streams=max_concurrent_streams)
        self._chat_manager = chat_manager
        self.chatter_manager = DummyChatterManager()

    async def _get_stream_context(self, stream_id: str):  # type: ignore[override]
        stream = await self._chat_manager.get_stream(stream_id)
        return stream.context if stream else None

    async def _flush_cached_messages_to_unread(self, stream_id: str):  # type: ignore[override]
        ctx = await self._get_stream_context(stream_id)
        return ctx.flush_cached_messages() if ctx else []

    def _needs_force_dispatch_for_context(self, context, unread_count: int) -> bool:  # type: ignore[override]
        # Force a dispatch when unread count exceeds the threshold.
        return unread_count > (self.force_dispatch_unread_threshold or 20)

    async def _process_stream_messages(self, stream_id: str, context):  # type: ignore[override]
        # Delegate to the chatter manager.
        res = await self.chatter_manager.process_stream_context(stream_id, context)  # type: ignore[attr-defined]
        return bool(res.get("success", False))

    async def _should_skip_for_mute_group(self, stream_id: str, unread_messages: list) -> bool:
        return False

    async def _update_stream_energy(self, stream_id: str, context):  # type: ignore[override]
        # Lightweight: compute focus based on unread size.
        focus = min(1.0, 0.1 + 0.02 * len(context.get_unread_messages()))
        # Set it on the stream for compatibility.
        stream = await self._chat_manager.get_stream(stream_id)
        if stream:
            stream._focus_energy = focus


def make_streams(n_streams: int, initial_unread: int) -> dict[str, DummyStream]:
    streams: dict[str, DummyStream] = {}
    for i in range(n_streams):
        sid = f"s{i:04d}"
        ctx = DummyContext(sid, initial_unread)
        streams[sid] = DummyStream(sid, ctx)
    return streams


async def run_benchmark(n_streams: int, initial_unread: int, max_concurrent: Optional[int]) -> dict[str, Any]:
    streams = make_streams(n_streams, initial_unread)
    chat_mgr = DummyChatManager(streams)
    mgr = BenchStreamLoopManager(chat_mgr, max_concurrent_streams=max_concurrent)
    await mgr.start()

    # Start loops for all streams.
    start_ts = time.time()
    for sid in list(streams.keys()):
        await mgr.start_stream_loop(sid, force=True)

    # Run until all unread messages are consumed, or until the timeout.
    timeout = 5.0
    end_deadline = start_ts + timeout
    while time.time() < end_deadline:
        remaining = sum(len(s.context.get_unread_messages()) for s in streams.values())
        if remaining == 0:
            break
        await asyncio.sleep(0.02)

    duration = time.time() - start_ts
    total_cycles = mgr.stats.get("total_process_cycles", 0)
    total_failures = mgr.stats.get("total_failures", 0)
    remaining = sum(len(s.context.get_unread_messages()) for s in streams.values())

    # Stop all loops.
    await mgr.stop()

    return {
        "n_streams": n_streams,
        "initial_unread": initial_unread,
        "max_concurrent": max_concurrent,
        "duration_sec": duration,
        "total_cycles": total_cycles,
        "total_failures": total_failures,
        "remaining_unread": remaining,
        "throughput_msgs_per_sec": (n_streams * initial_unread - remaining) / max(0.001, duration),
    }


async def main():
    cases = [
        (50, 5, None),  # baseline using the configured default
        (50, 5, 5),     # constrained concurrency
        (50, 5, 10),    # moderate concurrency
        (100, 3, 10),   # scale the number of streams
    ]

    print("Running distribution manager benchmark...\n")
    for n_streams, initial_unread, max_concurrent in cases:
        res = await run_benchmark(n_streams, initial_unread, max_concurrent)
        print(
            f"streams={res['n_streams']} unread={res['initial_unread']} max_conc={res['max_concurrent']} | "
            f"dur={res['duration_sec']:.3f}s cycles={res['total_cycles']} fail={res['total_failures']} "
            f"rem={res['remaining_unread']} thr={res['throughput_msgs_per_sec']:.1f}/s"
        )


if __name__ == "__main__":
    asyncio.run(main())
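The script is self-contained, so it can be run directly with python scripts/benchmark_distribution_manager.py. A single case can also be driven programmatically; a minimal sketch using only names defined above (the numbers are illustrative):

# Minimal sketch: drive one benchmark case by hand (illustrative numbers).
result = asyncio.run(run_benchmark(n_streams=10, initial_unread=4, max_concurrent=4))
print(f"{result['throughput_msgs_per_sec']:.1f} msgs/s, failures={result['total_failures']}")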
@@ -6,8 +6,6 @@

 import asyncio
 import time
 from typing import Any
 from unittest.mock import AsyncMock, MagicMock, patch


 class PerformanceBenchmark:
@@ -20,7 +18,7 @@ class PerformanceBenchmark:
         """Test query-deduplication performance."""
         # The real manager would be imported here:
         # from src.memory_graph.unified_manager import UnifiedMemoryManager

         test_cases = [
             {
                 "name": "small_queries",
@@ -67,7 +65,7 @@ class PerformanceBenchmark:
         seen = set()
         decay = 0.15
         manual_queries = []

         for raw in queries:
             text = (raw or "").strip()
             if text and text not in seen:
@@ -192,7 +190,7 @@ class PerformanceBenchmark:
                 mem_id = mem.get("id")
             else:
                 mem_id = getattr(mem, "id", None)

             if mem_id and mem_id in seen_ids:
                 continue
             unique_memories.append(mem)
@@ -9,6 +9,8 @@ from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import Any

+from sqlalchemy.exc import SQLAlchemyError
+
 from src.common.database.compatibility import get_db_session
 from src.common.database.core.models import ChatStreams
 from src.common.logger import get_logger
@@ -159,20 +161,27 @@ class BatchDatabaseWriter:
         logger.info("Batch write loop finished")

     async def _collect_batch(self) -> list[StreamUpdatePayload]:
-        """Collect one batch of data."""
-        batch = []
-        deadline = time.time() + self.flush_interval
+        """Collect one batch of data.
+
+        - Adaptive flushing: shorten the wait when the queue grows quickly.
+        - Avoid long idle spins: add slight jitter to spread out contention.
+        """
+        batch: list[StreamUpdatePayload] = []
+        # Scale the flush window by the current queue length (down to 40% at most).
+        qsize = self.write_queue.qsize()
+        adapt_factor = 1.0
+        if qsize > 0:
+            adapt_factor = max(0.4, min(1.0, self.batch_size / max(1, qsize)))
+        deadline = time.time() + (self.flush_interval * adapt_factor)

         while len(batch) < self.batch_size and time.time() < deadline:
             try:
                 # Compute the remaining wait time.
-                remaining_time = max(0, deadline - time.time())
+                remaining_time = max(0.0, deadline - time.time())
                 if remaining_time == 0:
                     break

-                payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time)
+                # Slight jitter so multiple coroutines don't contend for the queue at once.
+                jitter = 0.002
+                payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time + jitter)
                 batch.append(payload)

             except asyncio.TimeoutError:
                 break
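To make the adaptive window concrete: the wait only starts shrinking once the backlog exceeds one full batch, and bottoms out at 40% of the normal interval. A standalone sketch of the adapt_factor math, assuming illustrative values batch_size=50 and flush_interval=1.0 (the real values come from the writer's configuration):

# Standalone sketch of the adapt_factor math in _collect_batch above.
batch_size, flush_interval = 50, 1.0  # illustrative, not the real config
for qsize in (0, 25, 50, 100, 250):
    adapt_factor = 1.0
    if qsize > 0:
        adapt_factor = max(0.4, min(1.0, batch_size / max(1, qsize)))
    print(f"qsize={qsize:3d} -> wait {flush_interval * adapt_factor:.2f}s")
# qsize 0/25/50 wait 1.00s; qsize=100 waits 0.50s; qsize>=125 waits 0.40s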
@@ -208,48 +217,52 @@ class BatchDatabaseWriter:
             logger.debug(f"Batch write complete: {len(batch)} updates in {time.time() - start_time:.3f}s")

-        except Exception as e:
+        except SQLAlchemyError as e:
             self.stats["failed_writes"] += 1
             logger.error(f"Batch write failed: {e}")
             # Fall back to individual writes.
             for payload in batch:
                 try:
                     await self._direct_write(payload.stream_id, payload.update_data)
-                except Exception as single_e:
+                except SQLAlchemyError as single_e:
                     logger.error(f"Single write also failed: {single_e}")

     async def _batch_write_to_database(self, payloads: list[StreamUpdatePayload]):
-        """Batch-write to the database."""
+        """Batch-write to the database (single transaction, multi-row UPSERT)."""
         if global_config is None:
             raise RuntimeError("Global config is not initialized")

+        if not payloads:
+            return
+
+        # Pre-assemble the row data, making sure every row carries its stream_id.
+        rows: list[dict[str, Any]] = []
+        for p in payloads:
+            row = {"stream_id": p.stream_id}
+            row.update(p.update_data)
+            rows.append(row)
+
         async with get_db_session() as session:
-            for payload in payloads:
-                stream_id = payload.stream_id
-                update_data = payload.update_data
-
-                # Choose the insert/update strategy by database type.
-                if global_config.database.database_type == "sqlite":
-                    from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-
-                    stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
-                    stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
-                elif global_config.database.database_type == "postgresql":
-                    from sqlalchemy.dialects.postgresql import insert as pg_insert
-
-                    stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data)
-                    stmt = stmt.on_conflict_do_update(
-                        index_elements=[ChatStreams.stream_id],
-                        set_=update_data
-                    )
-                else:
-                    # Default to the SQLite syntax.
-                    from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-
-                    stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
-                    stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
-
+            # A single transaction commit significantly reduces I/O.
+            if global_config.database.database_type == "postgresql":
+                from sqlalchemy.dialects.postgresql import insert as pg_insert
+                stmt = pg_insert(ChatStreams).values(rows)
+                stmt = stmt.on_conflict_do_update(
+                    index_elements=[ChatStreams.stream_id],
+                    set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"}
+                )
+                await session.execute(stmt)
+                await session.commit()
+            else:
+                # Default (sqlite).
+                from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+                stmt = sqlite_insert(ChatStreams).values(rows)
+                stmt = stmt.on_conflict_do_update(
+                    index_elements=["stream_id"],
+                    set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"}
+                )
+                await session.execute(stmt)
+                await session.commit()

     async def _direct_write(self, stream_id: str, update_data: dict[str, Any]):
         """Write directly to the database (fallback path)."""
         if global_config is None:
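In the rewritten method, stmt.excluded refers to the row each conflicting INSERT proposed, which is what lets a single multi-row statement update every non-key column. One caveat: a multi-row VALUES clause expects all row dicts to share the same keys, so the rows[0].keys() projection assumes homogeneous payloads. A minimal standalone sketch of the same pattern against a toy table (the model here is hypothetical, standing in for ChatStreams):

# Hypothetical toy model illustrating the multi-row UPSERT pattern above.
from sqlalchemy import Column, Float, String
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Stream(Base):  # stand-in for ChatStreams
    __tablename__ = "streams"
    stream_id = Column(String, primary_key=True)
    last_active_time = Column(Float)

rows = [
    {"stream_id": "s1", "last_active_time": 1.0},
    {"stream_id": "s2", "last_active_time": 2.0},
]
stmt = sqlite_insert(Stream).values(rows)
stmt = stmt.on_conflict_do_update(
    index_elements=["stream_id"],
    # stmt.excluded is the would-be-inserted row; overwrite non-key columns.
    set_={k: getattr(stmt.excluded, k) for k in rows[0] if k != "stream_id"},
)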
@@ -55,7 +55,7 @@ async def conversation_loop(
     stream_id: str,
     get_context_func: Callable[[str], Awaitable["StreamContext | None"]],
     calculate_interval_func: Callable[[str, bool], Awaitable[float]],
-    flush_cache_func: Callable[[str], Awaitable[None]],
+    flush_cache_func: Callable[[str], Awaitable[list[Any]]],
     check_force_dispatch_func: Callable[["StreamContext", int], bool],
     is_running_func: Callable[[], bool],
 ) -> AsyncIterator[ConversationTick]:
@@ -121,7 +121,7 @@ async def conversation_loop(
         except asyncio.CancelledError:
             logger.info(f" [generator] stream={stream_id[:8]}, cancelled")
             break
-        except Exception as e:
+        except Exception as e:  # noqa: BLE001
             logger.error(f" [generator] stream={stream_id[:8]}, error: {e}")
             await asyncio.sleep(5.0)
@@ -151,10 +151,10 @@ async def run_chat_stream(
     # Create the tick generator.
     tick_generator = conversation_loop(
         stream_id=stream_id,
-        get_context_func=manager._get_stream_context,
-        calculate_interval_func=manager._calculate_interval,
-        flush_cache_func=manager._flush_cached_messages_to_unread,
-        check_force_dispatch_func=manager._needs_force_dispatch_for_context,
+        get_context_func=manager._get_stream_context,  # noqa: SLF001
+        calculate_interval_func=manager._calculate_interval,  # noqa: SLF001
+        flush_cache_func=manager._flush_cached_messages_to_unread,  # noqa: SLF001
+        check_force_dispatch_func=manager._needs_force_dispatch_for_context,  # noqa: SLF001
         is_running_func=lambda: manager.is_running,
     )
@@ -162,13 +162,13 @@ async def run_chat_stream(
     async for tick in tick_generator:
         try:
             # Fetch the context.
-            context = await manager._get_stream_context(stream_id)
+            context = await manager._get_stream_context(stream_id)  # noqa: SLF001
             if not context:
                 continue

             # Concurrency guard: check whether processing is already underway.
             if context.is_chatter_processing:
-                if manager._recover_stale_chatter_state(stream_id, context):
+                if manager._recover_stale_chatter_state(stream_id, context):  # noqa: SLF001
                     logger.warning(f" [driver] stream={stream_id[:8]}, stale processing flag recovered")
                 else:
                     logger.debug(f" [driver] stream={stream_id[:8]}, chatter is processing, skipping this tick")
@@ -182,17 +182,18 @@ async def run_chat_stream(
             # Update the energy value.
             try:
-                await manager._update_stream_energy(stream_id, context)
+                await manager._update_stream_energy(stream_id, context)  # noqa: SLF001
             except Exception as e:
                 logger.debug(f"Energy update failed: {e}")

             # Process messages.
             assert global_config is not None
             try:
-                success = await asyncio.wait_for(
-                    manager._process_stream_messages(stream_id, context),
-                    global_config.chat.thinking_timeout
-                )
+                async with manager._processing_semaphore:
+                    success = await asyncio.wait_for(
+                        manager._process_stream_messages(stream_id, context),  # noqa: SLF001
+                        global_config.chat.thinking_timeout,
+                    )
             except asyncio.TimeoutError:
                 logger.warning(f" [driver] stream={stream_id[:8]}, Tick#{tick.tick_count}, processing timed out")
                 success = False
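The async with manager._processing_semaphore wrapper bounds how many chatter tasks run at once, while wait_for still caps each task individually; because the timeout is applied inside the semaphore, it covers only the work itself, not the time spent waiting for a slot. The essential shape, as a minimal sketch with illustrative names:

# Minimal sketch of the bounded-processing pattern introduced above.
import asyncio

async def process_bounded(semaphore: asyncio.Semaphore, work, timeout: float) -> bool:
    async with semaphore:  # acquire a concurrency slot first
        try:
            await asyncio.wait_for(work(), timeout)  # timeout covers the work only
            return True
        except asyncio.TimeoutError:
            return False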
@@ -208,7 +209,7 @@ async def run_chat_stream(
         except asyncio.CancelledError:
             raise
-        except Exception as e:
+        except Exception as e:  # noqa: BLE001
             logger.error(f" [driver] stream={stream_id[:8]}, error while handling tick: {e}")
             manager.stats["total_failures"] += 1
@@ -221,7 +222,7 @@ async def run_chat_stream(
         if context and context.stream_loop_task:
             context.stream_loop_task = None
             logger.debug(f" [driver] stream={stream_id[:8]}, cleaned up task record")
-    except Exception as e:
+    except Exception as e:  # noqa: BLE001
         logger.debug(f"Failed to clean up task record: {e}")
@@ -268,6 +269,9 @@ class StreamLoopManager:
         # Per-stream start locks: prevent concurrent starts of the same stream.
         self._stream_start_locks: dict[str, asyncio.Lock] = {}

+        # Concurrency control: cap the number of simultaneous chatter-processing tasks.
+        self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams)
+
         logger.info(f"Stream loop manager initialized (max concurrent streams: {self.max_concurrent_streams})")

     # ========================================================================
@@ -104,9 +104,21 @@ class MessageManager:
             if not chat_stream:
                 logger.warning(f"MessageManager.add_message: chat stream {stream_id} does not exist")
                 return
-            # Start the stream loop task (if not already started).
-            await stream_loop_manager.start_stream_loop(stream_id)
-            await self._check_and_handle_interruption(chat_stream, message)
+
+            # Fast check: if a driver is already running, skip the redundant
+            # start and the unnecessary await.
+            context = chat_stream.context
+            if not (context.stream_loop_task and not context.stream_loop_task.done()):
+                # Start the driver task asynchronously; avoid blocking message
+                # enqueueing under high concurrency.
+                context.stream_loop_task = asyncio.create_task(
+                    stream_loop_manager.start_stream_loop(stream_id)
+                )
+
+            # Trigger the interruption check in parallel, without blocking enqueueing.
+            context.interruption_task = asyncio.create_task(
+                self._check_and_handle_interruption(chat_stream, message)
+            )

             # Enqueue the message.
             await chat_stream.context.add_message(message)

         except Exception as e:
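A detail worth noting about this fire-and-forget style: the event loop holds only weak references to tasks, so storing the handles on the context (stream_loop_task, interruption_task) is also what protects them from being garbage-collected mid-flight. The standard pattern from the asyncio documentation, for reference (background_tasks is an illustrative container, not part of this commit):

# Keeping fire-and-forget tasks alive via strong references.
import asyncio

background_tasks: set[asyncio.Task] = set()

def spawn(coro) -> asyncio.Task:
    task = asyncio.create_task(coro)
    background_tasks.add(task)  # strong reference while running
    task.add_done_callback(background_tasks.discard)  # release when done
    return task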
@@ -476,8 +488,7 @@ class MessageManager:
             is_processing: whether processing is in progress
         """
         try:
-            # Try to update the StreamContext processing state.
-            import asyncio
+            # Try to update the StreamContext processing state (uses the top-level asyncio import).
             async def _update_context():
                 try:
                     chat_manager = get_chat_manager()
@@ -492,7 +503,7 @@ class MessageManager:
             try:
                 loop = asyncio.get_event_loop()
                 if loop.is_running():
-                    asyncio.create_task(_update_context())
+                    self._update_context_task = asyncio.create_task(_update_context())
                 else:
                     # If no event loop is running, skip the update.
                     logger.debug("Event loop not running; skipping StreamContext state update")
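Since asyncio.get_event_loop() is deprecated when no loop is running in recent Python versions, the same check can be expressed with get_running_loop(). A possible variant, offered as a suggestion rather than what the commit does (factory is an illustrative parameter standing in for _update_context):

# Deprecation-safe variant of the is-a-loop-running branch above.
import asyncio
from typing import Optional

def schedule_update(factory) -> Optional[asyncio.Task]:
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        return None  # no running event loop: skip, as the original branch does
    return loop.create_task(factory())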
@@ -512,8 +523,7 @@ class MessageManager:
             bool: whether processing is in progress
         """
         try:
-            # Try to read the processing state from the StreamContext.
-            import asyncio
+            # Try to read the processing state from the StreamContext (uses the top-level asyncio import).
             async def _get_context_status():
                 try:
                     chat_manager = get_chat_manager()
@@ -451,7 +451,7 @@ class UnifiedMemoryManager:
             (0.3, 10.0, 0.4),
             (0.1, 15.0, 0.6),
         ]

         for threshold, min_val, factor in occupancy_thresholds:
             if occupancy >= threshold:
                 return max(min_val, base_interval * factor)
@@ -461,24 +461,24 @@ class UnifiedMemoryManager:
     async def _transfer_blocks_to_short_term(self, blocks: list[MemoryBlock]) -> None:
         """The actual transfer logic runs in the background (optimized: process blocks in parallel, batch the wake-up trigger)."""
         logger.debug(f"Processing {len(blocks)} perceptual memory blocks in the background")

         # Optimization: transfer in parallel with asyncio.gather.
         async def _transfer_single(block: MemoryBlock) -> tuple[MemoryBlock, bool]:
             try:
                 stm = await self.short_term_manager.add_from_block(block)
                 if not stm:
                     return block, False

                 await self.perceptual_manager.remove_block(block.id)
                 logger.debug(f"✓ Memory block {block.id} transferred to short-term memory {stm.id}")
                 return block, True
             except Exception as exc:
                 logger.error(f"Background transfer failed for memory block {block.id}: {exc}")
                 return block, False

         # Process all blocks in parallel.
         results = await asyncio.gather(*[_transfer_single(block) for block in blocks], return_exceptions=True)

         # Count the successful transfers.
         success_count = sum(1 for result in results if isinstance(result, tuple) and result[1])
         if success_count > 0:
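Because the gather call passes return_exceptions=True, an exception escaping _transfer_single comes back as an object inside results instead of cancelling the whole batch; the isinstance(result, tuple) filter then counts it as a failure. A compact illustration:

# Compact illustration of counting successes under return_exceptions=True.
import asyncio

async def demo() -> None:
    async def transfer(i: int) -> tuple[str, bool]:
        if i == 1:
            raise ValueError("boom")  # returned as an object, not raised
        return (f"block-{i}", True)

    results = await asyncio.gather(*[transfer(i) for i in range(3)], return_exceptions=True)
    print(sum(1 for r in results if isinstance(r, tuple) and r[1]))  # 2

asyncio.run(demo())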
@@ -491,7 +491,7 @@ class UnifiedMemoryManager:
         seen = set()
         decay = 0.15
         manual_queries: list[dict[str, Any]] = []

         for raw in queries:
             text = (raw or "").strip()
             if text and text not in seen:
@@ -517,7 +517,7 @@ class UnifiedMemoryManager:
             "top_k": self._config["long_term"]["search_top_k"],
             "use_multi_query": bool(manual_queries),
         }

         if recent_chat_history or manual_queries:
             context: dict[str, Any] = {}
             if recent_chat_history:
@@ -541,7 +541,7 @@ class UnifiedMemoryManager:
                 mem_id = mem.get("id")
             else:
                 mem_id = getattr(mem, "id", None)

             # Dedup check.
             if mem_id and mem_id in seen_ids:
                 continue
@@ -600,7 +600,7 @@ class UnifiedMemoryManager:
             new_memories.append(memory)
             if mem_id:
                 cached_ids.add(mem_id)

         if new_memories:
             transfer_cache.extend(new_memories)
             logger.debug(
@@ -632,7 +632,7 @@ class UnifiedMemoryManager:
             await self.short_term_manager.clear_transferred_memories(
                 result["transferred_memory_ids"]
             )

             # Optimization: keep only the untransferred memories via a comprehension.
             transfer_cache = [
                 m