feat: 优化事件管理,添加事件处理超时和并发限制功能
This commit is contained in:
@@ -30,7 +30,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
if message.chat_stream:
|
||||
await event_manager.trigger_event(
|
||||
event_manager.emit_event(
|
||||
EventType.AFTER_SEND,
|
||||
permission_group="SYSTEM",
|
||||
stream_id=message.chat_stream.stream_id,
|
||||
|
||||
@@ -269,6 +269,16 @@ class ToolConfig(ValidatedConfigBase):
|
||||
"""工具配置类"""
|
||||
|
||||
enable_tool: bool = Field(default=False, description="启用工具")
|
||||
force_parallel_execution: bool = Field(
|
||||
default=True,
|
||||
description="<EFBFBD><EFBFBD><EFBFBD><EFBFBD>LLM<EFBFBD><EFBFBD><EFBFBD><EFBFBD>ͬʱ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ҫʹ<EFBFBD>ù<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ʱǿ<EFBFBD><EFBFBD>ʹ<EFBFBD>ò<EFBFBD><EFBFBD><EFBFBD>ģʽ<EFBFBD><EFBFBD>ֹ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ϣ",
|
||||
)
|
||||
max_parallel_invocations: int = Field(
|
||||
default=5, ge=1, le=50, description="<EFBFBD><EFBFBD>ͬһ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>п<EFBFBD><EFBFBD>Խ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ܹ<EFBFBD><EFBFBD>ߵ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>"
|
||||
)
|
||||
tool_timeout: float = Field(
|
||||
default=60.0, ge=1.0, le=600.0, description="<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ߵ<EFBFBD><EFBFBD>õij<EFBFBD>ʱʱ<EFBFBD>䣨<EFBFBD>룩"
|
||||
)
|
||||
|
||||
|
||||
class VoiceConfig(ValidatedConfigBase):
|
||||
@@ -779,7 +789,13 @@ class PluginHttpSystemConfig(ValidatedConfigBase):
|
||||
default="100/minute", description="插件API的默认速率限制策略"
|
||||
)
|
||||
plugin_api_valid_keys: list[str] = Field(
|
||||
default_factory=list, description="有效的API密钥列表,用于插件认证"
|
||||
default_factory=list, description="<EFBFBD><EFBFBD>Ч<EFBFBD><EFBFBD>API<EFBFBD><EFBFBD>Կ<EFBFBD>б<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ڲ<EFBFBD><EFBFBD><EFBFBD><EFBFBD>֤"
|
||||
)
|
||||
event_handler_timeout: float = Field(
|
||||
default=30.0, ge=1.0, le=300.0, description="<EFBFBD>¼<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ִ<EFBFBD>г<EFBFBD>ʱʱ<EFBFBD>䣨<EFBFBD>룩"
|
||||
)
|
||||
event_handler_max_concurrency: int = Field(
|
||||
default=20, ge=1, le=200, description="<EFBFBD><EFBFBD><EFBFBD><EFBFBD>ÿ<EFBFBD><EFBFBD><EFBFBD>¼<EFBFBD>ͬʱִ<EFBFBD>е<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>߸<EFBFBD><EFBFBD><EFBFBD>0<EFBFBD><EFBFBD>ʾ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -101,7 +101,9 @@ class BaseEvent:
|
||||
def __name__(self):
|
||||
return self.name
|
||||
|
||||
async def activate(self, params: dict) -> HandlerResultsCollection:
|
||||
async def activate(
|
||||
self, params: dict, handler_timeout: float | None = None, max_concurrency: int | None = None
|
||||
) -> HandlerResultsCollection:
|
||||
"""激活事件,执行所有订阅的处理器
|
||||
|
||||
Args:
|
||||
@@ -115,40 +117,71 @@ class BaseEvent:
|
||||
|
||||
# 使用锁确保同一个事件不能同时激活多次
|
||||
async with self.event_handle_lock:
|
||||
# 按权重从高到低排序订阅者
|
||||
# 使用直接属性访问,-1代表自动权重
|
||||
sorted_subscribers = sorted(
|
||||
self.subscribers, key=lambda h: h.weight if hasattr(h, "weight") and h.weight != -1 else 0, reverse=True
|
||||
)
|
||||
|
||||
# 并行执行所有订阅者
|
||||
tasks = []
|
||||
for subscriber in sorted_subscribers:
|
||||
# 为每个订阅者创建执行任务
|
||||
task = self._execute_subscriber(subscriber, params)
|
||||
tasks.append(task)
|
||||
if not sorted_subscribers:
|
||||
return HandlerResultsCollection([])
|
||||
|
||||
# 等待所有任务完成
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
concurrency_limit = None
|
||||
if max_concurrency is not None:
|
||||
concurrency_limit = max_concurrency if max_concurrency > 0 else None
|
||||
if concurrency_limit:
|
||||
concurrency_limit = min(concurrency_limit, len(sorted_subscribers))
|
||||
|
||||
# 处理执行结果
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
subscriber = sorted_subscribers[i]
|
||||
semaphore = (
|
||||
asyncio.Semaphore(concurrency_limit)
|
||||
if concurrency_limit and concurrency_limit < len(sorted_subscribers)
|
||||
else None
|
||||
)
|
||||
|
||||
async def _run_handler(subscriber):
|
||||
handler_name = (
|
||||
subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
|
||||
)
|
||||
if result:
|
||||
if isinstance(result, Exception):
|
||||
# 处理执行异常
|
||||
logger.error(f"事件处理器 {handler_name} 执行失败: {result}")
|
||||
processed_results.append(HandlerResult(False, True, str(result), handler_name))
|
||||
|
||||
async def _invoke():
|
||||
return await self._execute_subscriber(subscriber, params)
|
||||
|
||||
try:
|
||||
if handler_timeout and handler_timeout > 0:
|
||||
result = await asyncio.wait_for(_invoke(), timeout=handler_timeout)
|
||||
else:
|
||||
# 正常执行结果
|
||||
if not result.handler_name:
|
||||
# 补充handler_name
|
||||
result.handler_name = handler_name
|
||||
processed_results.append(result)
|
||||
result = await _invoke()
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"事件处理器 {handler_name} 执行超时 ({handler_timeout}s)")
|
||||
return HandlerResult(False, True, f"timeout after {handler_timeout}s", handler_name)
|
||||
except Exception as exc:
|
||||
logger.error(f"事件处理器 {handler_name} 执行失败: {exc}")
|
||||
return HandlerResult(False, True, str(exc), handler_name)
|
||||
|
||||
if not isinstance(result, HandlerResult):
|
||||
return HandlerResult(True, True, result, handler_name)
|
||||
|
||||
if not result.handler_name:
|
||||
result.handler_name = handler_name
|
||||
return result
|
||||
|
||||
async def _guarded_run(subscriber):
|
||||
if semaphore:
|
||||
async with semaphore:
|
||||
return await _run_handler(subscriber)
|
||||
return await _run_handler(subscriber)
|
||||
|
||||
tasks = [asyncio.create_task(_guarded_run(subscriber)) for subscriber in sorted_subscribers]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
processed_results: list[HandlerResult] = []
|
||||
for subscriber, result in zip(sorted_subscribers, results):
|
||||
handler_name = (
|
||||
subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
|
||||
)
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"事件处理器 {handler_name} 执行失败: {result}")
|
||||
processed_results.append(HandlerResult(False, True, str(result), handler_name))
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
return HandlerResultsCollection(processed_results)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from threading import Lock
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
@@ -40,6 +41,15 @@ class EventManager:
|
||||
self._event_handlers: dict[str, BaseEventHandler] = {}
|
||||
self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅
|
||||
self._scheduler_callback: Any | None = None # scheduler 回调函数
|
||||
plugin_cfg = getattr(global_config, "plugin_http_system", None)
|
||||
self._default_handler_timeout: float | None = (
|
||||
getattr(plugin_cfg, "event_handler_timeout", 30.0) if plugin_cfg else 30.0
|
||||
)
|
||||
default_concurrency = getattr(plugin_cfg, "event_handler_max_concurrency", None) if plugin_cfg else None
|
||||
self._default_handler_concurrency: int | None = (
|
||||
default_concurrency if default_concurrency and default_concurrency > 0 else None
|
||||
)
|
||||
self._background_tasks: set[asyncio.Task[Any]] = set()
|
||||
self._initialized = True
|
||||
logger.info("EventManager 单例初始化完成")
|
||||
|
||||
@@ -293,7 +303,13 @@ class EventManager:
|
||||
return {handler.handler_name: handler for handler in event.subscribers}
|
||||
|
||||
async def trigger_event(
|
||||
self, event_name: EventType | str, permission_group: str | None = "", **kwargs
|
||||
self,
|
||||
event_name: EventType | str,
|
||||
permission_group: str | None = "",
|
||||
*,
|
||||
handler_timeout: float | None = None,
|
||||
max_concurrency: int | None = None,
|
||||
**kwargs,
|
||||
) -> HandlerResultsCollection | None:
|
||||
"""触发指定事件
|
||||
|
||||
@@ -328,7 +344,10 @@ class EventManager:
|
||||
except Exception as e:
|
||||
logger.error(f"调用 scheduler 回调时出错: {e}", exc_info=True)
|
||||
|
||||
return await event.activate(params)
|
||||
timeout = handler_timeout if handler_timeout is not None else self._default_handler_timeout
|
||||
concurrency = max_concurrency if max_concurrency is not None else self._default_handler_concurrency
|
||||
|
||||
return await event.activate(params, handler_timeout=timeout, max_concurrency=concurrency)
|
||||
|
||||
def register_scheduler_callback(self, callback) -> None:
|
||||
"""注册 scheduler 回调函数
|
||||
@@ -344,6 +363,35 @@ class EventManager:
|
||||
self._scheduler_callback = None
|
||||
logger.info("Scheduler 回调已取消注册")
|
||||
|
||||
def emit_event(
|
||||
self,
|
||||
event_name: EventType | str,
|
||||
permission_group: str | None = "",
|
||||
*,
|
||||
handler_timeout: float | None = None,
|
||||
max_concurrency: int | None = None,
|
||||
**kwargs,
|
||||
) -> asyncio.Task[Any] | None:
|
||||
"""调度事件但不等待结果,返回后台任务对象"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
logger.warning(f"调度事件 {event_name} 失败:当前没有运行中的事件循环")
|
||||
return None
|
||||
|
||||
task = loop.create_task(
|
||||
self.trigger_event(
|
||||
event_name,
|
||||
permission_group=permission_group,
|
||||
handler_timeout=handler_timeout,
|
||||
max_concurrency=max_concurrency,
|
||||
**kwargs,
|
||||
),
|
||||
name=f"event::{event_name}",
|
||||
)
|
||||
self._track_background_task(task)
|
||||
return task
|
||||
|
||||
def init_default_events(self) -> None:
|
||||
"""初始化默认事件"""
|
||||
default_events = [
|
||||
@@ -437,5 +485,18 @@ class EventManager:
|
||||
return processed_count
|
||||
|
||||
|
||||
def _track_background_task(self, task: asyncio.Task[Any]) -> None:
|
||||
"""跟踪后台事件任务,避免被 GC 清理"""
|
||||
self._background_tasks.add(task)
|
||||
|
||||
def _cleanup(fut: asyncio.Task[Any]) -> None:
|
||||
self._background_tasks.discard(fut)
|
||||
|
||||
task.add_done_callback(_cleanup)
|
||||
|
||||
def get_background_task_count(self) -> int:
|
||||
"""返回当前仍在运行的后台事件任务数量"""
|
||||
return len(self._background_tasks)
|
||||
|
||||
# 创建全局事件管理器实例
|
||||
event_manager = EventManager()
|
||||
|
||||
@@ -108,6 +108,8 @@ class ToolExecutor:
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.execution_config = execution_config or ToolExecutionConfig()
|
||||
if execution_config is None:
|
||||
self._apply_config_defaults()
|
||||
|
||||
# chat_stream 和 log_prefix 将在异步方法中初始化
|
||||
self.chat_stream = None # type: ignore
|
||||
@@ -115,6 +117,20 @@ class ToolExecutor:
|
||||
|
||||
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
|
||||
|
||||
def _apply_config_defaults(self) -> None:
|
||||
tool_cfg = getattr(global_config, "tool", None)
|
||||
if not tool_cfg:
|
||||
return
|
||||
if hasattr(tool_cfg, "force_parallel_execution"):
|
||||
self.execution_config.enable_parallel = bool(tool_cfg.force_parallel_execution)
|
||||
max_invocations = getattr(tool_cfg, "max_parallel_invocations", None)
|
||||
if max_invocations:
|
||||
self.execution_config.max_concurrent_tools = max(1, max_invocations)
|
||||
timeout = getattr(tool_cfg, "tool_timeout", None)
|
||||
if timeout:
|
||||
self.execution_config.tool_timeout = max(1.0, float(timeout))
|
||||
|
||||
|
||||
# 二步工具调用状态管理
|
||||
self._pending_step_two_tools: dict[str, dict[str, Any]] = {}
|
||||
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
|
||||
|
||||
@@ -356,7 +356,7 @@ class MessageHandler:
|
||||
case RealMessageType.text:
|
||||
ret_seg = await self.handle_text_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(
|
||||
event_manager.emit_event(
|
||||
NapcatEvent.ON_RECEIVED.TEXT, permission_group=PLUGIN_NAME, message_seg=ret_seg
|
||||
)
|
||||
seg_message.append(ret_seg)
|
||||
@@ -365,7 +365,7 @@ class MessageHandler:
|
||||
case RealMessageType.face:
|
||||
ret_seg = await self.handle_face_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(
|
||||
event_manager.emit_event(
|
||||
NapcatEvent.ON_RECEIVED.FACE, permission_group=PLUGIN_NAME, message_seg=ret_seg
|
||||
)
|
||||
seg_message.append(ret_seg)
|
||||
@@ -375,7 +375,7 @@ class MessageHandler:
|
||||
if not in_reply:
|
||||
ret_seg = await self.handle_reply_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(
|
||||
event_manager.emit_event(
|
||||
NapcatEvent.ON_RECEIVED.REPLY, permission_group=PLUGIN_NAME, message_seg=ret_seg
|
||||
)
|
||||
seg_message += ret_seg
|
||||
@@ -385,7 +385,7 @@ class MessageHandler:
|
||||
logger.debug("开始处理图片消息段")
|
||||
ret_seg = await self.handle_image_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(
|
||||
event_manager.emit_event(
|
||||
NapcatEvent.ON_RECEIVED.IMAGE, permission_group=PLUGIN_NAME, message_seg=ret_seg
|
||||
)
|
||||
seg_message.append(ret_seg)
|
||||
@@ -396,7 +396,7 @@ class MessageHandler:
|
||||
case RealMessageType.record:
|
||||
ret_seg = await self.handle_record_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(
|
||||
event_manager.emit_event(
|
||||
NapcatEvent.ON_RECEIVED.RECORD, permission_group=PLUGIN_NAME, message_seg=ret_seg
|
||||
)
|
||||
seg_message.clear()
|
||||
@@ -408,7 +408,7 @@ class MessageHandler:
|
||||
logger.debug(f"开始处理VIDEO消息段: {sub_message}")
|
||||
ret_seg = await self.handle_video_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(
|
||||
event_manager.emit_event(
|
||||
NapcatEvent.ON_RECEIVED.VIDEO, permission_group=PLUGIN_NAME, message_seg=ret_seg
|
||||
)
|
||||
seg_message.append(ret_seg)
|
||||
@@ -422,7 +422,7 @@ class MessageHandler:
|
||||
raw_message.get("group_id"),
|
||||
)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(
|
||||
event_manager.emit_event(
|
||||
NapcatEvent.ON_RECEIVED.AT, permission_group=PLUGIN_NAME, message_seg=ret_seg
|
||||
)
|
||||
seg_message.append(ret_seg)
|
||||
@@ -431,7 +431,7 @@ class MessageHandler:
|
||||
case RealMessageType.rps:
|
||||
ret_seg = await self.handle_rps_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(
|
||||
event_manager.emit_event(
|
||||
NapcatEvent.ON_RECEIVED.RPS, permission_group=PLUGIN_NAME, message_seg=ret_seg
|
||||
)
|
||||
seg_message.append(ret_seg)
|
||||
@@ -440,7 +440,7 @@ class MessageHandler:
|
||||
case RealMessageType.dice:
|
||||
ret_seg = await self.handle_dice_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(
|
||||
event_manager.emit_event(
|
||||
NapcatEvent.ON_RECEIVED.DICE, permission_group=PLUGIN_NAME, message_seg=ret_seg
|
||||
)
|
||||
seg_message.append(ret_seg)
|
||||
@@ -449,7 +449,7 @@ class MessageHandler:
|
||||
case RealMessageType.shake:
|
||||
ret_seg = await self.handle_shake_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(
|
||||
event_manager.emit_event(
|
||||
NapcatEvent.ON_RECEIVED.SHAKE, permission_group=PLUGIN_NAME, message_seg=ret_seg
|
||||
)
|
||||
seg_message.append(ret_seg)
|
||||
@@ -478,7 +478,7 @@ class MessageHandler:
|
||||
case RealMessageType.json:
|
||||
ret_seg = await self.handle_json_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(
|
||||
event_manager.emit_event(
|
||||
NapcatEvent.ON_RECEIVED.JSON, permission_group=PLUGIN_NAME, message_seg=ret_seg
|
||||
)
|
||||
seg_message.append(ret_seg)
|
||||
|
||||
@@ -133,7 +133,7 @@ class NoticeHandler:
|
||||
|
||||
from ...event_types import NapcatEvent
|
||||
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME)
|
||||
event_manager.emit_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME)
|
||||
case _:
|
||||
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
|
||||
case NoticeType.group_msg_emoji_like:
|
||||
@@ -376,7 +376,7 @@ class NoticeHandler:
|
||||
)
|
||||
|
||||
like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
|
||||
await event_manager.trigger_event(
|
||||
event_manager.emit_event(
|
||||
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
||||
permission_group=PLUGIN_NAME,
|
||||
group_id=group_id,
|
||||
@@ -702,4 +702,4 @@ class NoticeHandler:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
notice_handler = NoticeHandler()
|
||||
notice_handler = NoticeHandler()
|
||||
|
||||
Reference in New Issue
Block a user