diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 7dd7df940..90459545b 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -30,7 +30,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool: from src.plugin_system.core.event_manager import event_manager if message.chat_stream: - await event_manager.trigger_event( + event_manager.emit_event( EventType.AFTER_SEND, permission_group="SYSTEM", stream_id=message.chat_stream.stream_id, diff --git a/src/config/official_configs.py b/src/config/official_configs.py index e686f0702..dc16d1005 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -269,6 +269,16 @@ class ToolConfig(ValidatedConfigBase): """工具配置类""" enable_tool: bool = Field(default=False, description="启用工具") + force_parallel_execution: bool = Field( + default=True, + description="����LLM����ͬʱ������Ҫʹ�ù�����ʱǿ��ʹ�ò���ģʽ��ֹ���������Ϣ", + ) + max_parallel_invocations: int = Field( + default=5, ge=1, le=50, description="��ͬһ�������п��Խ������ܹ��ߵ�������" + ) + tool_timeout: float = Field( + default=60.0, ge=1.0, le=600.0, description="�������ߵ��õij�ʱʱ�䣨�룩" + ) class VoiceConfig(ValidatedConfigBase): @@ -779,7 +789,13 @@ class PluginHttpSystemConfig(ValidatedConfigBase): default="100/minute", description="插件API的默认速率限制策略" ) plugin_api_valid_keys: list[str] = Field( - default_factory=list, description="有效的API密钥列表,用于插件认证" + default_factory=list, description="��Ч��API��Կ�б������ڲ����֤" + ) + event_handler_timeout: float = Field( + default=30.0, ge=1.0, le=300.0, description="�¼����������ִ�г�ʱʱ�䣨�룩" + ) + event_handler_max_concurrency: int = Field( + default=20, ge=1, le=200, description="����ÿ���¼�ͬʱִ�е�������߸���0��ʾ����������" ) diff --git a/src/plugin_system/base/base_event.py b/src/plugin_system/base/base_event.py index 47d410c60..110903c39 100644 --- a/src/plugin_system/base/base_event.py +++ b/src/plugin_system/base/base_event.py @@ -101,7 +101,9 @@ class BaseEvent: def __name__(self): return self.name - async def activate(self, params: dict) -> HandlerResultsCollection: + async def activate( + self, params: dict, handler_timeout: float | None = None, max_concurrency: int | None = None + ) -> HandlerResultsCollection: """激活事件,执行所有订阅的处理器 Args: @@ -115,40 +117,71 @@ class BaseEvent: # 使用锁确保同一个事件不能同时激活多次 async with self.event_handle_lock: - # 按权重从高到低排序订阅者 - # 使用直接属性访问,-1代表自动权重 sorted_subscribers = sorted( self.subscribers, key=lambda h: h.weight if hasattr(h, "weight") and h.weight != -1 else 0, reverse=True ) - # 并行执行所有订阅者 - tasks = [] - for subscriber in sorted_subscribers: - # 为每个订阅者创建执行任务 - task = self._execute_subscriber(subscriber, params) - tasks.append(task) + if not sorted_subscribers: + return HandlerResultsCollection([]) - # 等待所有任务完成 - results = await asyncio.gather(*tasks, return_exceptions=True) + concurrency_limit = None + if max_concurrency is not None: + concurrency_limit = max_concurrency if max_concurrency > 0 else None + if concurrency_limit: + concurrency_limit = min(concurrency_limit, len(sorted_subscribers)) - # 处理执行结果 - processed_results = [] - for i, result in enumerate(results): - subscriber = sorted_subscribers[i] + semaphore = ( + asyncio.Semaphore(concurrency_limit) + if concurrency_limit and concurrency_limit < len(sorted_subscribers) + else None + ) + + async def _run_handler(subscriber): handler_name = ( subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__ ) - if result: - if isinstance(result, Exception): - # 处理执行异常 - logger.error(f"事件处理器 {handler_name} 执行失败: {result}") - processed_results.append(HandlerResult(False, True, str(result), handler_name)) + + async def _invoke(): + return await self._execute_subscriber(subscriber, params) + + try: + if handler_timeout and handler_timeout > 0: + result = await asyncio.wait_for(_invoke(), timeout=handler_timeout) else: - # 正常执行结果 - if not result.handler_name: - # 补充handler_name - result.handler_name = handler_name - processed_results.append(result) + result = await _invoke() + except asyncio.TimeoutError: + logger.warning(f"事件处理器 {handler_name} 执行超时 ({handler_timeout}s)") + return HandlerResult(False, True, f"timeout after {handler_timeout}s", handler_name) + except Exception as exc: + logger.error(f"事件处理器 {handler_name} 执行失败: {exc}") + return HandlerResult(False, True, str(exc), handler_name) + + if not isinstance(result, HandlerResult): + return HandlerResult(True, True, result, handler_name) + + if not result.handler_name: + result.handler_name = handler_name + return result + + async def _guarded_run(subscriber): + if semaphore: + async with semaphore: + return await _run_handler(subscriber) + return await _run_handler(subscriber) + + tasks = [asyncio.create_task(_guarded_run(subscriber)) for subscriber in sorted_subscribers] + results = await asyncio.gather(*tasks, return_exceptions=True) + + processed_results: list[HandlerResult] = [] + for subscriber, result in zip(sorted_subscribers, results): + handler_name = ( + subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__ + ) + if isinstance(result, Exception): + logger.error(f"事件处理器 {handler_name} 执行失败: {result}") + processed_results.append(HandlerResult(False, True, str(result), handler_name)) + else: + processed_results.append(result) return HandlerResultsCollection(processed_results) diff --git a/src/plugin_system/core/event_manager.py b/src/plugin_system/core/event_manager.py index ed773e31b..cdb3fdb19 100644 --- a/src/plugin_system/core/event_manager.py +++ b/src/plugin_system/core/event_manager.py @@ -7,6 +7,7 @@ from threading import Lock from typing import Any, Optional from src.common.logger import get_logger +from src.config.config import global_config from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.component_types import EventType @@ -40,6 +41,15 @@ class EventManager: self._event_handlers: dict[str, BaseEventHandler] = {} self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅 self._scheduler_callback: Any | None = None # scheduler 回调函数 + plugin_cfg = getattr(global_config, "plugin_http_system", None) + self._default_handler_timeout: float | None = ( + getattr(plugin_cfg, "event_handler_timeout", 30.0) if plugin_cfg else 30.0 + ) + default_concurrency = getattr(plugin_cfg, "event_handler_max_concurrency", None) if plugin_cfg else None + self._default_handler_concurrency: int | None = ( + default_concurrency if default_concurrency and default_concurrency > 0 else None + ) + self._background_tasks: set[asyncio.Task[Any]] = set() self._initialized = True logger.info("EventManager 单例初始化完成") @@ -293,7 +303,13 @@ class EventManager: return {handler.handler_name: handler for handler in event.subscribers} async def trigger_event( - self, event_name: EventType | str, permission_group: str | None = "", **kwargs + self, + event_name: EventType | str, + permission_group: str | None = "", + *, + handler_timeout: float | None = None, + max_concurrency: int | None = None, + **kwargs, ) -> HandlerResultsCollection | None: """触发指定事件 @@ -328,7 +344,10 @@ class EventManager: except Exception as e: logger.error(f"调用 scheduler 回调时出错: {e}", exc_info=True) - return await event.activate(params) + timeout = handler_timeout if handler_timeout is not None else self._default_handler_timeout + concurrency = max_concurrency if max_concurrency is not None else self._default_handler_concurrency + + return await event.activate(params, handler_timeout=timeout, max_concurrency=concurrency) def register_scheduler_callback(self, callback) -> None: """注册 scheduler 回调函数 @@ -344,6 +363,35 @@ class EventManager: self._scheduler_callback = None logger.info("Scheduler 回调已取消注册") + def emit_event( + self, + event_name: EventType | str, + permission_group: str | None = "", + *, + handler_timeout: float | None = None, + max_concurrency: int | None = None, + **kwargs, + ) -> asyncio.Task[Any] | None: + """调度事件但不等待结果,返回后台任务对象""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + logger.warning(f"调度事件 {event_name} 失败:当前没有运行中的事件循环") + return None + + task = loop.create_task( + self.trigger_event( + event_name, + permission_group=permission_group, + handler_timeout=handler_timeout, + max_concurrency=max_concurrency, + **kwargs, + ), + name=f"event::{event_name}", + ) + self._track_background_task(task) + return task + def init_default_events(self) -> None: """初始化默认事件""" default_events = [ @@ -437,5 +485,18 @@ class EventManager: return processed_count + def _track_background_task(self, task: asyncio.Task[Any]) -> None: + """跟踪后台事件任务,避免被 GC 清理""" + self._background_tasks.add(task) + + def _cleanup(fut: asyncio.Task[Any]) -> None: + self._background_tasks.discard(fut) + + task.add_done_callback(_cleanup) + + def get_background_task_count(self) -> int: + """返回当前仍在运行的后台事件任务数量""" + return len(self._background_tasks) + # 创建全局事件管理器实例 event_manager = EventManager() diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index 3f321236c..2614ab2bf 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -108,6 +108,8 @@ class ToolExecutor: """ self.chat_id = chat_id self.execution_config = execution_config or ToolExecutionConfig() + if execution_config is None: + self._apply_config_defaults() # chat_stream 和 log_prefix 将在异步方法中初始化 self.chat_stream = None # type: ignore @@ -115,6 +117,20 @@ class ToolExecutor: self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor") + def _apply_config_defaults(self) -> None: + tool_cfg = getattr(global_config, "tool", None) + if not tool_cfg: + return + if hasattr(tool_cfg, "force_parallel_execution"): + self.execution_config.enable_parallel = bool(tool_cfg.force_parallel_execution) + max_invocations = getattr(tool_cfg, "max_parallel_invocations", None) + if max_invocations: + self.execution_config.max_concurrent_tools = max(1, max_invocations) + timeout = getattr(tool_cfg, "tool_timeout", None) + if timeout: + self.execution_config.tool_timeout = max(1.0, float(timeout)) + + # 二步工具调用状态管理 self._pending_step_two_tools: dict[str, dict[str, Any]] = {} """待处理的第二步工具调用,格式为 {tool_name: step_two_definition}""" diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py index 415d2ed13..b7e7b2c25 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py @@ -356,7 +356,7 @@ class MessageHandler: case RealMessageType.text: ret_seg = await self.handle_text_message(sub_message) if ret_seg: - await event_manager.trigger_event( + event_manager.emit_event( NapcatEvent.ON_RECEIVED.TEXT, permission_group=PLUGIN_NAME, message_seg=ret_seg ) seg_message.append(ret_seg) @@ -365,7 +365,7 @@ class MessageHandler: case RealMessageType.face: ret_seg = await self.handle_face_message(sub_message) if ret_seg: - await event_manager.trigger_event( + event_manager.emit_event( NapcatEvent.ON_RECEIVED.FACE, permission_group=PLUGIN_NAME, message_seg=ret_seg ) seg_message.append(ret_seg) @@ -375,7 +375,7 @@ class MessageHandler: if not in_reply: ret_seg = await self.handle_reply_message(sub_message) if ret_seg: - await event_manager.trigger_event( + event_manager.emit_event( NapcatEvent.ON_RECEIVED.REPLY, permission_group=PLUGIN_NAME, message_seg=ret_seg ) seg_message += ret_seg @@ -385,7 +385,7 @@ class MessageHandler: logger.debug("开始处理图片消息段") ret_seg = await self.handle_image_message(sub_message) if ret_seg: - await event_manager.trigger_event( + event_manager.emit_event( NapcatEvent.ON_RECEIVED.IMAGE, permission_group=PLUGIN_NAME, message_seg=ret_seg ) seg_message.append(ret_seg) @@ -396,7 +396,7 @@ class MessageHandler: case RealMessageType.record: ret_seg = await self.handle_record_message(sub_message) if ret_seg: - await event_manager.trigger_event( + event_manager.emit_event( NapcatEvent.ON_RECEIVED.RECORD, permission_group=PLUGIN_NAME, message_seg=ret_seg ) seg_message.clear() @@ -408,7 +408,7 @@ class MessageHandler: logger.debug(f"开始处理VIDEO消息段: {sub_message}") ret_seg = await self.handle_video_message(sub_message) if ret_seg: - await event_manager.trigger_event( + event_manager.emit_event( NapcatEvent.ON_RECEIVED.VIDEO, permission_group=PLUGIN_NAME, message_seg=ret_seg ) seg_message.append(ret_seg) @@ -422,7 +422,7 @@ class MessageHandler: raw_message.get("group_id"), ) if ret_seg: - await event_manager.trigger_event( + event_manager.emit_event( NapcatEvent.ON_RECEIVED.AT, permission_group=PLUGIN_NAME, message_seg=ret_seg ) seg_message.append(ret_seg) @@ -431,7 +431,7 @@ class MessageHandler: case RealMessageType.rps: ret_seg = await self.handle_rps_message(sub_message) if ret_seg: - await event_manager.trigger_event( + event_manager.emit_event( NapcatEvent.ON_RECEIVED.RPS, permission_group=PLUGIN_NAME, message_seg=ret_seg ) seg_message.append(ret_seg) @@ -440,7 +440,7 @@ class MessageHandler: case RealMessageType.dice: ret_seg = await self.handle_dice_message(sub_message) if ret_seg: - await event_manager.trigger_event( + event_manager.emit_event( NapcatEvent.ON_RECEIVED.DICE, permission_group=PLUGIN_NAME, message_seg=ret_seg ) seg_message.append(ret_seg) @@ -449,7 +449,7 @@ class MessageHandler: case RealMessageType.shake: ret_seg = await self.handle_shake_message(sub_message) if ret_seg: - await event_manager.trigger_event( + event_manager.emit_event( NapcatEvent.ON_RECEIVED.SHAKE, permission_group=PLUGIN_NAME, message_seg=ret_seg ) seg_message.append(ret_seg) @@ -478,7 +478,7 @@ class MessageHandler: case RealMessageType.json: ret_seg = await self.handle_json_message(sub_message) if ret_seg: - await event_manager.trigger_event( + event_manager.emit_event( NapcatEvent.ON_RECEIVED.JSON, permission_group=PLUGIN_NAME, message_seg=ret_seg ) seg_message.append(ret_seg) diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py index 866028472..1f6bf104e 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py @@ -133,7 +133,7 @@ class NoticeHandler: from ...event_types import NapcatEvent - await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME) + event_manager.emit_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME) case _: logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") case NoticeType.group_msg_emoji_like: @@ -376,7 +376,7 @@ class NoticeHandler: ) like_emoji_id = raw_message.get("likes")[0].get("emoji_id") - await event_manager.trigger_event( + event_manager.emit_event( NapcatEvent.ON_RECEIVED.EMOJI_LIEK, permission_group=PLUGIN_NAME, group_id=group_id, @@ -702,4 +702,4 @@ class NoticeHandler: await asyncio.sleep(1) -notice_handler = NoticeHandler() \ No newline at end of file +notice_handler = NoticeHandler()