feat: 优化事件管理,添加事件处理超时和并发限制功能

This commit is contained in:
Windpicker-owo
2025-11-19 01:26:23 +08:00
parent 22d43ede8c
commit 8c6242026d
7 changed files with 169 additions and 43 deletions

View File

@@ -30,7 +30,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.event_manager import event_manager
if message.chat_stream: if message.chat_stream:
await event_manager.trigger_event( event_manager.emit_event(
EventType.AFTER_SEND, EventType.AFTER_SEND,
permission_group="SYSTEM", permission_group="SYSTEM",
stream_id=message.chat_stream.stream_id, stream_id=message.chat_stream.stream_id,

View File

@@ -269,6 +269,16 @@ class ToolConfig(ValidatedConfigBase):
"""工具配置类""" """工具配置类"""
enable_tool: bool = Field(default=False, description="启用工具") enable_tool: bool = Field(default=False, description="启用工具")
force_parallel_execution: bool = Field(
default=True,
description="<EFBFBD><EFBFBD><EFBFBD><EFBFBD>LLM<EFBFBD><EFBFBD><EFBFBD><EFBFBD>ͬʱ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ҫʹ<EFBFBD>ù<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ʱǿ<EFBFBD><EFBFBD>ʹ<EFBFBD>ò<EFBFBD><EFBFBD><EFBFBD>ģʽ<EFBFBD><EFBFBD>ֹ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ϣ",
)
max_parallel_invocations: int = Field(
default=5, ge=1, le=50, description="<EFBFBD><EFBFBD>ͬһ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>п<EFBFBD><EFBFBD>Խ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ܹ<EFBFBD><EFBFBD>ߵ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>"
)
tool_timeout: float = Field(
default=60.0, ge=1.0, le=600.0, description="<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ߵ<EFBFBD><EFBFBD>õij<EFBFBD>ʱʱ<EFBFBD><EFBFBD>"
)
class VoiceConfig(ValidatedConfigBase): class VoiceConfig(ValidatedConfigBase):
@@ -779,7 +789,13 @@ class PluginHttpSystemConfig(ValidatedConfigBase):
default="100/minute", description="插件API的默认速率限制策略" default="100/minute", description="插件API的默认速率限制策略"
) )
plugin_api_valid_keys: list[str] = Field( plugin_api_valid_keys: list[str] = Field(
default_factory=list, description="有效的API密钥列表用于插件认证" default_factory=list, description="<EFBFBD><EFBFBD>Ч<EFBFBD><EFBFBD>API<EFBFBD><EFBFBD>Կ<EFBFBD>б<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ڲ<EFBFBD><EFBFBD><EFBFBD><EFBFBD>֤"
)
event_handler_timeout: float = Field(
default=30.0, ge=1.0, le=300.0, description="<EFBFBD>¼<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ִ<EFBFBD>г<EFBFBD>ʱʱ<EFBFBD><EFBFBD>"
)
event_handler_max_concurrency: int = Field(
default=20, ge=1, le=200, description="<EFBFBD><EFBFBD><EFBFBD><EFBFBD>ÿ<EFBFBD><EFBFBD><EFBFBD>¼<EFBFBD>ͬʱִ<EFBFBD>е<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>߸<EFBFBD><EFBFBD><EFBFBD>0<EFBFBD><EFBFBD>ʾ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>"
) )

View File

@@ -101,7 +101,9 @@ class BaseEvent:
def __name__(self): def __name__(self):
return self.name return self.name
async def activate(self, params: dict) -> HandlerResultsCollection: async def activate(
self, params: dict, handler_timeout: float | None = None, max_concurrency: int | None = None
) -> HandlerResultsCollection:
"""激活事件,执行所有订阅的处理器 """激活事件,执行所有订阅的处理器
Args: Args:
@@ -115,40 +117,71 @@ class BaseEvent:
# 使用锁确保同一个事件不能同时激活多次 # 使用锁确保同一个事件不能同时激活多次
async with self.event_handle_lock: async with self.event_handle_lock:
# 按权重从高到低排序订阅者
# 使用直接属性访问,-1代表自动权重
sorted_subscribers = sorted( sorted_subscribers = sorted(
self.subscribers, key=lambda h: h.weight if hasattr(h, "weight") and h.weight != -1 else 0, reverse=True self.subscribers, key=lambda h: h.weight if hasattr(h, "weight") and h.weight != -1 else 0, reverse=True
) )
# 并行执行所有订阅者 if not sorted_subscribers:
tasks = [] return HandlerResultsCollection([])
for subscriber in sorted_subscribers:
# 为每个订阅者创建执行任务
task = self._execute_subscriber(subscriber, params)
tasks.append(task)
# 等待所有任务完成 concurrency_limit = None
results = await asyncio.gather(*tasks, return_exceptions=True) if max_concurrency is not None:
concurrency_limit = max_concurrency if max_concurrency > 0 else None
if concurrency_limit:
concurrency_limit = min(concurrency_limit, len(sorted_subscribers))
# 处理执行结果 semaphore = (
processed_results = [] asyncio.Semaphore(concurrency_limit)
for i, result in enumerate(results): if concurrency_limit and concurrency_limit < len(sorted_subscribers)
subscriber = sorted_subscribers[i] else None
)
async def _run_handler(subscriber):
handler_name = ( handler_name = (
subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__ subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
) )
if result:
if isinstance(result, Exception): async def _invoke():
# 处理执行异常 return await self._execute_subscriber(subscriber, params)
logger.error(f"事件处理器 {handler_name} 执行失败: {result}")
processed_results.append(HandlerResult(False, True, str(result), handler_name)) try:
if handler_timeout and handler_timeout > 0:
result = await asyncio.wait_for(_invoke(), timeout=handler_timeout)
else: else:
# 正常执行结果 result = await _invoke()
if not result.handler_name: except asyncio.TimeoutError:
# 补充handler_name logger.warning(f"事件处理器 {handler_name} 执行超时 ({handler_timeout}s)")
result.handler_name = handler_name return HandlerResult(False, True, f"timeout after {handler_timeout}s", handler_name)
processed_results.append(result) except Exception as exc:
logger.error(f"事件处理器 {handler_name} 执行失败: {exc}")
return HandlerResult(False, True, str(exc), handler_name)
if not isinstance(result, HandlerResult):
return HandlerResult(True, True, result, handler_name)
if not result.handler_name:
result.handler_name = handler_name
return result
async def _guarded_run(subscriber):
if semaphore:
async with semaphore:
return await _run_handler(subscriber)
return await _run_handler(subscriber)
tasks = [asyncio.create_task(_guarded_run(subscriber)) for subscriber in sorted_subscribers]
results = await asyncio.gather(*tasks, return_exceptions=True)
processed_results: list[HandlerResult] = []
for subscriber, result in zip(sorted_subscribers, results):
handler_name = (
subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
)
if isinstance(result, Exception):
logger.error(f"事件处理器 {handler_name} 执行失败: {result}")
processed_results.append(HandlerResult(False, True, str(result), handler_name))
else:
processed_results.append(result)
return HandlerResultsCollection(processed_results) return HandlerResultsCollection(processed_results)

View File

@@ -7,6 +7,7 @@ from threading import Lock
from typing import Any, Optional from typing import Any, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection
from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.component_types import EventType from src.plugin_system.base.component_types import EventType
@@ -40,6 +41,15 @@ class EventManager:
self._event_handlers: dict[str, BaseEventHandler] = {} self._event_handlers: dict[str, BaseEventHandler] = {}
self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅 self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅
self._scheduler_callback: Any | None = None # scheduler 回调函数 self._scheduler_callback: Any | None = None # scheduler 回调函数
plugin_cfg = getattr(global_config, "plugin_http_system", None)
self._default_handler_timeout: float | None = (
getattr(plugin_cfg, "event_handler_timeout", 30.0) if plugin_cfg else 30.0
)
default_concurrency = getattr(plugin_cfg, "event_handler_max_concurrency", None) if plugin_cfg else None
self._default_handler_concurrency: int | None = (
default_concurrency if default_concurrency and default_concurrency > 0 else None
)
self._background_tasks: set[asyncio.Task[Any]] = set()
self._initialized = True self._initialized = True
logger.info("EventManager 单例初始化完成") logger.info("EventManager 单例初始化完成")
@@ -293,7 +303,13 @@ class EventManager:
return {handler.handler_name: handler for handler in event.subscribers} return {handler.handler_name: handler for handler in event.subscribers}
async def trigger_event( async def trigger_event(
self, event_name: EventType | str, permission_group: str | None = "", **kwargs self,
event_name: EventType | str,
permission_group: str | None = "",
*,
handler_timeout: float | None = None,
max_concurrency: int | None = None,
**kwargs,
) -> HandlerResultsCollection | None: ) -> HandlerResultsCollection | None:
"""触发指定事件 """触发指定事件
@@ -328,7 +344,10 @@ class EventManager:
except Exception as e: except Exception as e:
logger.error(f"调用 scheduler 回调时出错: {e}", exc_info=True) logger.error(f"调用 scheduler 回调时出错: {e}", exc_info=True)
return await event.activate(params) timeout = handler_timeout if handler_timeout is not None else self._default_handler_timeout
concurrency = max_concurrency if max_concurrency is not None else self._default_handler_concurrency
return await event.activate(params, handler_timeout=timeout, max_concurrency=concurrency)
def register_scheduler_callback(self, callback) -> None: def register_scheduler_callback(self, callback) -> None:
"""注册 scheduler 回调函数 """注册 scheduler 回调函数
@@ -344,6 +363,35 @@ class EventManager:
self._scheduler_callback = None self._scheduler_callback = None
logger.info("Scheduler 回调已取消注册") logger.info("Scheduler 回调已取消注册")
def emit_event(
self,
event_name: EventType | str,
permission_group: str | None = "",
*,
handler_timeout: float | None = None,
max_concurrency: int | None = None,
**kwargs,
) -> asyncio.Task[Any] | None:
"""调度事件但不等待结果,返回后台任务对象"""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
logger.warning(f"调度事件 {event_name} 失败:当前没有运行中的事件循环")
return None
task = loop.create_task(
self.trigger_event(
event_name,
permission_group=permission_group,
handler_timeout=handler_timeout,
max_concurrency=max_concurrency,
**kwargs,
),
name=f"event::{event_name}",
)
self._track_background_task(task)
return task
def init_default_events(self) -> None: def init_default_events(self) -> None:
"""初始化默认事件""" """初始化默认事件"""
default_events = [ default_events = [
@@ -437,5 +485,18 @@ class EventManager:
return processed_count return processed_count
def _track_background_task(self, task: asyncio.Task[Any]) -> None:
"""跟踪后台事件任务,避免被 GC 清理"""
self._background_tasks.add(task)
def _cleanup(fut: asyncio.Task[Any]) -> None:
self._background_tasks.discard(fut)
task.add_done_callback(_cleanup)
def get_background_task_count(self) -> int:
"""返回当前仍在运行的后台事件任务数量"""
return len(self._background_tasks)
# 创建全局事件管理器实例 # 创建全局事件管理器实例
event_manager = EventManager() event_manager = EventManager()

View File

@@ -108,6 +108,8 @@ class ToolExecutor:
""" """
self.chat_id = chat_id self.chat_id = chat_id
self.execution_config = execution_config or ToolExecutionConfig() self.execution_config = execution_config or ToolExecutionConfig()
if execution_config is None:
self._apply_config_defaults()
# chat_stream 和 log_prefix 将在异步方法中初始化 # chat_stream 和 log_prefix 将在异步方法中初始化
self.chat_stream = None # type: ignore self.chat_stream = None # type: ignore
@@ -115,6 +117,20 @@ class ToolExecutor:
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor") self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
def _apply_config_defaults(self) -> None:
tool_cfg = getattr(global_config, "tool", None)
if not tool_cfg:
return
if hasattr(tool_cfg, "force_parallel_execution"):
self.execution_config.enable_parallel = bool(tool_cfg.force_parallel_execution)
max_invocations = getattr(tool_cfg, "max_parallel_invocations", None)
if max_invocations:
self.execution_config.max_concurrent_tools = max(1, max_invocations)
timeout = getattr(tool_cfg, "tool_timeout", None)
if timeout:
self.execution_config.tool_timeout = max(1.0, float(timeout))
# 二步工具调用状态管理 # 二步工具调用状态管理
self._pending_step_two_tools: dict[str, dict[str, Any]] = {} self._pending_step_two_tools: dict[str, dict[str, Any]] = {}
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}""" """待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""

View File

@@ -356,7 +356,7 @@ class MessageHandler:
case RealMessageType.text: case RealMessageType.text:
ret_seg = await self.handle_text_message(sub_message) ret_seg = await self.handle_text_message(sub_message)
if ret_seg: if ret_seg:
await event_manager.trigger_event( event_manager.emit_event(
NapcatEvent.ON_RECEIVED.TEXT, permission_group=PLUGIN_NAME, message_seg=ret_seg NapcatEvent.ON_RECEIVED.TEXT, permission_group=PLUGIN_NAME, message_seg=ret_seg
) )
seg_message.append(ret_seg) seg_message.append(ret_seg)
@@ -365,7 +365,7 @@ class MessageHandler:
case RealMessageType.face: case RealMessageType.face:
ret_seg = await self.handle_face_message(sub_message) ret_seg = await self.handle_face_message(sub_message)
if ret_seg: if ret_seg:
await event_manager.trigger_event( event_manager.emit_event(
NapcatEvent.ON_RECEIVED.FACE, permission_group=PLUGIN_NAME, message_seg=ret_seg NapcatEvent.ON_RECEIVED.FACE, permission_group=PLUGIN_NAME, message_seg=ret_seg
) )
seg_message.append(ret_seg) seg_message.append(ret_seg)
@@ -375,7 +375,7 @@ class MessageHandler:
if not in_reply: if not in_reply:
ret_seg = await self.handle_reply_message(sub_message) ret_seg = await self.handle_reply_message(sub_message)
if ret_seg: if ret_seg:
await event_manager.trigger_event( event_manager.emit_event(
NapcatEvent.ON_RECEIVED.REPLY, permission_group=PLUGIN_NAME, message_seg=ret_seg NapcatEvent.ON_RECEIVED.REPLY, permission_group=PLUGIN_NAME, message_seg=ret_seg
) )
seg_message += ret_seg seg_message += ret_seg
@@ -385,7 +385,7 @@ class MessageHandler:
logger.debug("开始处理图片消息段") logger.debug("开始处理图片消息段")
ret_seg = await self.handle_image_message(sub_message) ret_seg = await self.handle_image_message(sub_message)
if ret_seg: if ret_seg:
await event_manager.trigger_event( event_manager.emit_event(
NapcatEvent.ON_RECEIVED.IMAGE, permission_group=PLUGIN_NAME, message_seg=ret_seg NapcatEvent.ON_RECEIVED.IMAGE, permission_group=PLUGIN_NAME, message_seg=ret_seg
) )
seg_message.append(ret_seg) seg_message.append(ret_seg)
@@ -396,7 +396,7 @@ class MessageHandler:
case RealMessageType.record: case RealMessageType.record:
ret_seg = await self.handle_record_message(sub_message) ret_seg = await self.handle_record_message(sub_message)
if ret_seg: if ret_seg:
await event_manager.trigger_event( event_manager.emit_event(
NapcatEvent.ON_RECEIVED.RECORD, permission_group=PLUGIN_NAME, message_seg=ret_seg NapcatEvent.ON_RECEIVED.RECORD, permission_group=PLUGIN_NAME, message_seg=ret_seg
) )
seg_message.clear() seg_message.clear()
@@ -408,7 +408,7 @@ class MessageHandler:
logger.debug(f"开始处理VIDEO消息段: {sub_message}") logger.debug(f"开始处理VIDEO消息段: {sub_message}")
ret_seg = await self.handle_video_message(sub_message) ret_seg = await self.handle_video_message(sub_message)
if ret_seg: if ret_seg:
await event_manager.trigger_event( event_manager.emit_event(
NapcatEvent.ON_RECEIVED.VIDEO, permission_group=PLUGIN_NAME, message_seg=ret_seg NapcatEvent.ON_RECEIVED.VIDEO, permission_group=PLUGIN_NAME, message_seg=ret_seg
) )
seg_message.append(ret_seg) seg_message.append(ret_seg)
@@ -422,7 +422,7 @@ class MessageHandler:
raw_message.get("group_id"), raw_message.get("group_id"),
) )
if ret_seg: if ret_seg:
await event_manager.trigger_event( event_manager.emit_event(
NapcatEvent.ON_RECEIVED.AT, permission_group=PLUGIN_NAME, message_seg=ret_seg NapcatEvent.ON_RECEIVED.AT, permission_group=PLUGIN_NAME, message_seg=ret_seg
) )
seg_message.append(ret_seg) seg_message.append(ret_seg)
@@ -431,7 +431,7 @@ class MessageHandler:
case RealMessageType.rps: case RealMessageType.rps:
ret_seg = await self.handle_rps_message(sub_message) ret_seg = await self.handle_rps_message(sub_message)
if ret_seg: if ret_seg:
await event_manager.trigger_event( event_manager.emit_event(
NapcatEvent.ON_RECEIVED.RPS, permission_group=PLUGIN_NAME, message_seg=ret_seg NapcatEvent.ON_RECEIVED.RPS, permission_group=PLUGIN_NAME, message_seg=ret_seg
) )
seg_message.append(ret_seg) seg_message.append(ret_seg)
@@ -440,7 +440,7 @@ class MessageHandler:
case RealMessageType.dice: case RealMessageType.dice:
ret_seg = await self.handle_dice_message(sub_message) ret_seg = await self.handle_dice_message(sub_message)
if ret_seg: if ret_seg:
await event_manager.trigger_event( event_manager.emit_event(
NapcatEvent.ON_RECEIVED.DICE, permission_group=PLUGIN_NAME, message_seg=ret_seg NapcatEvent.ON_RECEIVED.DICE, permission_group=PLUGIN_NAME, message_seg=ret_seg
) )
seg_message.append(ret_seg) seg_message.append(ret_seg)
@@ -449,7 +449,7 @@ class MessageHandler:
case RealMessageType.shake: case RealMessageType.shake:
ret_seg = await self.handle_shake_message(sub_message) ret_seg = await self.handle_shake_message(sub_message)
if ret_seg: if ret_seg:
await event_manager.trigger_event( event_manager.emit_event(
NapcatEvent.ON_RECEIVED.SHAKE, permission_group=PLUGIN_NAME, message_seg=ret_seg NapcatEvent.ON_RECEIVED.SHAKE, permission_group=PLUGIN_NAME, message_seg=ret_seg
) )
seg_message.append(ret_seg) seg_message.append(ret_seg)
@@ -478,7 +478,7 @@ class MessageHandler:
case RealMessageType.json: case RealMessageType.json:
ret_seg = await self.handle_json_message(sub_message) ret_seg = await self.handle_json_message(sub_message)
if ret_seg: if ret_seg:
await event_manager.trigger_event( event_manager.emit_event(
NapcatEvent.ON_RECEIVED.JSON, permission_group=PLUGIN_NAME, message_seg=ret_seg NapcatEvent.ON_RECEIVED.JSON, permission_group=PLUGIN_NAME, message_seg=ret_seg
) )
seg_message.append(ret_seg) seg_message.append(ret_seg)

View File

@@ -133,7 +133,7 @@ class NoticeHandler:
from ...event_types import NapcatEvent from ...event_types import NapcatEvent
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME) event_manager.emit_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME)
case _: case _:
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
case NoticeType.group_msg_emoji_like: case NoticeType.group_msg_emoji_like:
@@ -376,7 +376,7 @@ class NoticeHandler:
) )
like_emoji_id = raw_message.get("likes")[0].get("emoji_id") like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
await event_manager.trigger_event( event_manager.emit_event(
NapcatEvent.ON_RECEIVED.EMOJI_LIEK, NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
permission_group=PLUGIN_NAME, permission_group=PLUGIN_NAME,
group_id=group_id, group_id=group_id,
@@ -702,4 +702,4 @@ class NoticeHandler:
await asyncio.sleep(1) await asyncio.sleep(1)
notice_handler = NoticeHandler() notice_handler = NoticeHandler()