feat: improve event management; add event handler timeout and concurrency limits

Author: Windpicker-owo
Date: 2025-11-19 01:26:23 +08:00
parent 22d43ede8c
commit 8c6242026d
7 changed files with 169 additions and 43 deletions

View File

@@ -101,7 +101,9 @@ class BaseEvent:
     def __name__(self):
         return self.name

-    async def activate(self, params: dict) -> HandlerResultsCollection:
+    async def activate(
+        self, params: dict, handler_timeout: float | None = None, max_concurrency: int | None = None
+    ) -> HandlerResultsCollection:
         """Activate the event and run all subscribed handlers

         Args:
@@ -115,40 +117,71 @@ class BaseEvent:
         # Use the lock so the same event cannot be activated more than once at a time
         async with self.event_handle_lock:
            # Sort subscribers by weight, highest first
            # Direct attribute access; -1 means automatic weight
            sorted_subscribers = sorted(
                self.subscribers, key=lambda h: h.weight if hasattr(h, "weight") and h.weight != -1 else 0, reverse=True
            )

-            # Run all subscribers in parallel
-            tasks = []
-            for subscriber in sorted_subscribers:
-                # Create an execution task for each subscriber
-                task = self._execute_subscriber(subscriber, params)
-                tasks.append(task)
-
-            # Wait for all tasks to finish
-            results = await asyncio.gather(*tasks, return_exceptions=True)
-
-            # Process the execution results
-            processed_results = []
-            for i, result in enumerate(results):
-                subscriber = sorted_subscribers[i]
-                handler_name = (
-                    subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
-                )
-                if result:
-                    if isinstance(result, Exception):
-                        # Handle an execution exception
-                        logger.error(f"Event handler {handler_name} failed: {result}")
-                        processed_results.append(HandlerResult(False, True, str(result), handler_name))
-                    else:
-                        # Normal execution result
-                        if not result.handler_name:
-                            # Fill in the missing handler_name
-                            result.handler_name = handler_name
-                        processed_results.append(result)
+            if not sorted_subscribers:
+                return HandlerResultsCollection([])
+
+            concurrency_limit = None
+            if max_concurrency is not None:
+                concurrency_limit = max_concurrency if max_concurrency > 0 else None
+            if concurrency_limit:
+                concurrency_limit = min(concurrency_limit, len(sorted_subscribers))
+
+            semaphore = (
+                asyncio.Semaphore(concurrency_limit)
+                if concurrency_limit and concurrency_limit < len(sorted_subscribers)
+                else None
+            )
+
+            async def _run_handler(subscriber):
+                handler_name = (
+                    subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
+                )
+
+                async def _invoke():
+                    return await self._execute_subscriber(subscriber, params)
+
+                try:
+                    if handler_timeout and handler_timeout > 0:
+                        result = await asyncio.wait_for(_invoke(), timeout=handler_timeout)
+                    else:
+                        result = await _invoke()
+                except asyncio.TimeoutError:
+                    logger.warning(f"Event handler {handler_name} timed out ({handler_timeout}s)")
+                    return HandlerResult(False, True, f"timeout after {handler_timeout}s", handler_name)
+                except Exception as exc:
+                    logger.error(f"Event handler {handler_name} failed: {exc}")
+                    return HandlerResult(False, True, str(exc), handler_name)
+
+                if not isinstance(result, HandlerResult):
+                    return HandlerResult(True, True, result, handler_name)
+                if not result.handler_name:
+                    result.handler_name = handler_name
+                return result
+
+            async def _guarded_run(subscriber):
+                if semaphore:
+                    async with semaphore:
+                        return await _run_handler(subscriber)
+                return await _run_handler(subscriber)
+
+            tasks = [asyncio.create_task(_guarded_run(subscriber)) for subscriber in sorted_subscribers]
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            processed_results: list[HandlerResult] = []
+            for subscriber, result in zip(sorted_subscribers, results):
+                handler_name = (
+                    subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
+                )
+                if isinstance(result, Exception):
+                    logger.error(f"Event handler {handler_name} failed: {result}")
+                    processed_results.append(HandlerResult(False, True, str(result), handler_name))
+                else:
+                    processed_results.append(result)

            return HandlerResultsCollection(processed_results)
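
Note: the rewritten activate() wraps each subscriber in asyncio.wait_for for the per-handler timeout and gates the tasks with an optional asyncio.Semaphore for the concurrency cap; the semaphore is only created when the limit is positive and smaller than the subscriber count. A minimal standalone sketch of the same pattern (the fan_out helper and the demo handlers below are illustrative, not part of this commit):

import asyncio


async def fan_out(handlers, timeout=None, max_concurrency=None):
    # Optional gate, mirroring the commit: only create a semaphore when the
    # limit is positive and smaller than the number of handlers.
    semaphore = (
        asyncio.Semaphore(max_concurrency)
        if max_concurrency and max_concurrency < len(handlers)
        else None
    )

    async def run_one(handler):
        try:
            if timeout and timeout > 0:
                return await asyncio.wait_for(handler(), timeout=timeout)
            return await handler()
        except asyncio.TimeoutError:
            return f"{handler.__name__}: timeout after {timeout}s"

    async def guarded(handler):
        if semaphore:
            async with semaphore:
                return await run_one(handler)
        return await run_one(handler)

    tasks = [asyncio.create_task(guarded(h)) for h in handlers]
    return await asyncio.gather(*tasks, return_exceptions=True)


async def fast_handler():
    await asyncio.sleep(0.1)
    return "fast done"


async def slow_handler():
    await asyncio.sleep(5)
    return "slow done"


print(asyncio.run(fan_out([fast_handler, slow_handler], timeout=1.0, max_concurrency=1)))
# ['fast done', 'slow_handler: timeout after 1.0s']

With max_concurrency=1 the handlers run one at a time, and the slow handler is reported as timed out instead of stalling the whole gather, which is the failure mode the new parameters guard against.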

View File

@@ -7,6 +7,7 @@
 from threading import Lock
 from typing import Any, Optional
 from src.common.logger import get_logger
+from src.config.config import global_config
 from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection
 from src.plugin_system.base.base_events_handler import BaseEventHandler
 from src.plugin_system.base.component_types import EventType
@@ -40,6 +41,15 @@ class EventManager:
         self._event_handlers: dict[str, BaseEventHandler] = {}
         self._pending_subscriptions: dict[str, list[str]] = {}  # cache subscriptions that failed to register
         self._scheduler_callback: Any | None = None  # scheduler callback
+        plugin_cfg = getattr(global_config, "plugin_http_system", None)
+        self._default_handler_timeout: float | None = (
+            getattr(plugin_cfg, "event_handler_timeout", 30.0) if plugin_cfg else 30.0
+        )
+        default_concurrency = getattr(plugin_cfg, "event_handler_max_concurrency", None) if plugin_cfg else None
+        self._default_handler_concurrency: int | None = (
+            default_concurrency if default_concurrency and default_concurrency > 0 else None
+        )
+        self._background_tasks: set[asyncio.Task[Any]] = set()

         self._initialized = True
         logger.info("EventManager singleton initialized")
@@ -293,7 +303,13 @@ class EventManager:
         return {handler.handler_name: handler for handler in event.subscribers}

     async def trigger_event(
-        self, event_name: EventType | str, permission_group: str | None = "", **kwargs
+        self,
+        event_name: EventType | str,
+        permission_group: str | None = "",
+        *,
+        handler_timeout: float | None = None,
+        max_concurrency: int | None = None,
+        **kwargs,
     ) -> HandlerResultsCollection | None:
         """Trigger the specified event
@@ -328,7 +344,10 @@ class EventManager:
             except Exception as e:
                 logger.error(f"Error while invoking the scheduler callback: {e}", exc_info=True)

-        return await event.activate(params)
+        timeout = handler_timeout if handler_timeout is not None else self._default_handler_timeout
+        concurrency = max_concurrency if max_concurrency is not None else self._default_handler_concurrency
+        return await event.activate(params, handler_timeout=timeout, max_concurrency=concurrency)

     def register_scheduler_callback(self, callback) -> None:
         """Register the scheduler callback
@@ -344,6 +363,35 @@ class EventManager:
             self._scheduler_callback = None
             logger.info("Scheduler callback unregistered")

+    def emit_event(
+        self,
+        event_name: EventType | str,
+        permission_group: str | None = "",
+        *,
+        handler_timeout: float | None = None,
+        max_concurrency: int | None = None,
+        **kwargs,
+    ) -> asyncio.Task[Any] | None:
+        """Schedule an event without waiting for the result; returns the background task object"""
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            logger.warning(f"Failed to schedule event {event_name}: no running event loop")
+            return None
+
+        task = loop.create_task(
+            self.trigger_event(
+                event_name,
+                permission_group=permission_group,
+                handler_timeout=handler_timeout,
+                max_concurrency=max_concurrency,
+                **kwargs,
+            ),
+            name=f"event::{event_name}",
+        )
+        self._track_background_task(task)
+        return task
+
     def init_default_events(self) -> None:
         """Initialize default events"""
         default_events = [
@@ -437,5 +485,18 @@ class EventManager:
         return processed_count

+    def _track_background_task(self, task: asyncio.Task[Any]) -> None:
+        """Track background event tasks so they are not garbage-collected"""
+        self._background_tasks.add(task)
+
+        def _cleanup(fut: asyncio.Task[Any]) -> None:
+            self._background_tasks.discard(fut)
+
+        task.add_done_callback(_cleanup)
+
+    def get_background_task_count(self) -> int:
+        """Return the number of background event tasks that are still running"""
+        return len(self._background_tasks)
+

 # Create the global event manager instance
 event_manager = EventManager()
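
emit_event() is the fire-and-forget counterpart of trigger_event(): it requires a running event loop (otherwise it logs a warning and returns None), schedules trigger_event() as a named task, and keeps a strong reference in _background_tasks until the done callback removes it. A hedged usage sketch (import path and event name are illustrative):

import asyncio

from src.plugin_system.core.event_manager import event_manager  # import path assumed


async def main() -> None:
    # Fire and forget: the event is scheduled on the current loop; "on_startup" is illustrative.
    task = event_manager.emit_event("on_startup", handler_timeout=10.0)
    print("pending background events:", event_manager.get_background_task_count())

    if task is not None:
        # The returned Task can still be awaited (or cancelled) by the caller.
        results = await task
        print(results)


asyncio.run(main())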

View File

@@ -108,6 +108,8 @@ class ToolExecutor:
         """
         self.chat_id = chat_id
         self.execution_config = execution_config or ToolExecutionConfig()
+        if execution_config is None:
+            self._apply_config_defaults()

         # chat_stream and log_prefix are initialized in the async methods
         self.chat_stream = None  # type: ignore
@@ -115,6 +117,20 @@ class ToolExecutor:
         self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")

+    def _apply_config_defaults(self) -> None:
+        tool_cfg = getattr(global_config, "tool", None)
+        if not tool_cfg:
+            return
+        if hasattr(tool_cfg, "force_parallel_execution"):
+            self.execution_config.enable_parallel = bool(tool_cfg.force_parallel_execution)
+        max_invocations = getattr(tool_cfg, "max_parallel_invocations", None)
+        if max_invocations:
+            self.execution_config.max_concurrent_tools = max(1, max_invocations)
+        timeout = getattr(tool_cfg, "tool_timeout", None)
+        if timeout:
+            self.execution_config.tool_timeout = max(1.0, float(timeout))
+
         # State management for two-step tool calls
         self._pending_step_two_tools: dict[str, dict[str, Any]] = {}
         """Pending step-two tool calls, keyed as {tool_name: step_two_definition}"""