diff --git a/src/chat/chat_loop/cycle_processor.py b/src/chat/chat_loop/cycle_processor.py index 781d9dde4..4a946bb72 100644 --- a/src/chat/chat_loop/cycle_processor.py +++ b/src/chat/chat_loop/cycle_processor.py @@ -100,7 +100,7 @@ class CycleProcessor: from src.plugin_system.core.event_manager import event_manager from src.plugin_system.base.component_types import EventType # 触发 ON_PLAN 事件 - result = await event_manager.trigger_event(EventType.ON_PLAN, stream_id=self.context.stream_id) + result = await event_manager.trigger_event(EventType.ON_PLAN, plugin_name="SYSTEM", stream_id=self.context.stream_id) if result and not result.all_continue_process(): return diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 760d69062..9a29998d8 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -437,7 +437,7 @@ class ChatBot: logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}") return - result = await event_manager.trigger_event(EventType.ON_MESSAGE,message=message) + result = await event_manager.trigger_event(EventType.ON_MESSAGE,plugin_name="SYSTEM",message=message) if not result.all_continue_process(): raise UserWarning(f"插件{result.get_summary().get('stopped_handlers','')}于消息到达时取消了消息处理") diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 9c9dc334f..77033472d 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -370,7 +370,7 @@ class DefaultReplyer: from src.plugin_system.core.event_manager import event_manager if not from_plugin: - result = await event_manager.trigger_event(EventType.POST_LLM,prompt=prompt,stream_id=stream_id) + result = await event_manager.trigger_event(EventType.POST_LLM,plugin_name="SYSTEM",prompt=prompt,stream_id=stream_id) if not result.all_continue_process(): raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于请求前中断了内容生成") @@ -390,7 +390,7 @@ class DefaultReplyer: } # 触发 AFTER_LLM 事件 if not from_plugin: - result = await event_manager.trigger_event(EventType.AFTER_LLM,prompt=prompt,llm_response=llm_response,stream_id=stream_id) + result = await event_manager.trigger_event(EventType.AFTER_LLM,plugin_name="SYSTEM",prompt=prompt,llm_response=llm_response,stream_id=stream_id) if not result.all_continue_process(): raise UserWarning(f"插件{result.get_summary().get('stopped_handlers','')}于请求后取消了内容生成") except UserWarning as e: diff --git a/src/main.py b/src/main.py index 7ac8cb76d..6a5f989b0 100644 --- a/src/main.py +++ b/src/main.py @@ -254,7 +254,7 @@ MoFox_Bot(第三方修改版) try: - await event_manager.trigger_event(EventType.ON_START) + await event_manager.trigger_event(EventType.ON_START,plugin_name="SYSTEM") init_time = int(1000 * (time.time() - init_start_time)) logger.info(f"初始化完成,神经元放电{init_time}次") except Exception as e: diff --git a/src/plugin_system/base/base_event.py b/src/plugin_system/base/base_event.py index 9e85dc34d..c527752d5 100644 --- a/src/plugin_system/base/base_event.py +++ b/src/plugin_system/base/base_event.py @@ -3,18 +3,18 @@ from typing import List, Dict, Any, Optional from src.common.logger import get_logger logger = get_logger("base_event") - + class HandlerResult: """事件处理器执行结果 所有事件处理器必须返回此类的实例 """ - def __init__(self, success: bool, continue_process: bool, message: str = "", handler_name: str = ""): + def __init__(self, success: bool, continue_process: bool, message: Any = {}, handler_name: str = ""): self.success = success self.continue_process = continue_process self.message = message self.handler_name = handler_name - + def __repr__(self): return f"HandlerResult(success={self.success}, continue_process={self.continue_process}, message='{self.message}', handler_name='{self.handler_name}')" @@ -67,9 +67,16 @@ class HandlerResultsCollection: } class BaseEvent: - def __init__(self, name: str): + def __init__( + self, + name: str, + allowed_subscribers: List[str]=[], + allowed_triggers: List[str]=[] + ): self.name = name self.enabled = True + self.allowed_subscribers = allowed_subscribers # 记录事件处理器名 + self.allowed_triggers = allowed_triggers # 记录插件名 from src.plugin_system.base.base_events_handler import BaseEventHandler self.subscribers: List["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表 diff --git a/src/plugin_system/core/event_manager.py b/src/plugin_system/core/event_manager.py index b6845c3ef..38d5775da 100644 --- a/src/plugin_system/core/event_manager.py +++ b/src/plugin_system/core/event_manager.py @@ -40,12 +40,18 @@ class EventManager: self._initialized = True logger.info("EventManager 单例初始化完成") - def register_event(self, event_name: Union[EventType, str]) -> bool: + def register_event( + self, + event_name: Union[EventType, str], + allowed_subscribers: List[str]=[], + allowed_triggers: List[str]=[] + ) -> bool: """注册一个新的事件 Args: event_name Union[EventType, str]: 事件名称 - + allowed_subscribers: List[str]: 事件订阅者白名单, + allowed_triggers: List[str]: 事件触发插件白名单 Returns: bool: 注册成功返回True,已存在返回False """ @@ -53,7 +59,7 @@ class EventManager: logger.warning(f"事件 {event_name} 已存在,跳过注册") return False - event = BaseEvent(event_name) + event = BaseEvent(event_name,allowed_subscribers,allowed_triggers) self._events[event_name] = event logger.info(f"事件 {event_name} 注册成功") @@ -210,7 +216,12 @@ class EventManager: if handler_instance in event.subscribers: logger.warning(f"事件处理器 {handler_name} 已经订阅了事件 {event_name},跳过重复订阅") return True - + + # 白名单检查 + if event.allowed_subscribers and handler_name not in event.allowed_subscribers: + logger.warning(f"事件处理器 {handler_name} 不在事件 {event_name} 的订阅者白名单中,无法订阅") + return False + event.subscribers.append(handler_instance) # 按权重从高到低排序订阅者 @@ -264,11 +275,12 @@ class EventManager: return {handler.handler_name: handler for handler in event.subscribers} - async def trigger_event(self, event_name: Union[EventType, str], **kwargs) -> Optional[HandlerResultsCollection]: + async def trigger_event(self, event_name: Union[EventType, str], plugin_name: Optional[str]="", **kwargs) -> Optional[HandlerResultsCollection]: """触发指定事件 Args: event_name Union[EventType, str]: 事件名称 + plugin_name str: 触发事件的插件名 **kwargs: 传递给处理器的参数 Returns: @@ -280,7 +292,15 @@ class EventManager: if event is None: logger.error(f"事件 {event_name} 不存在,无法触发") return None - + + # 插件白名单检查 + if event.allowed_triggers and not plugin_name: + logger.warning(f"事件 {event_name} 存在触发者白名单,缺少plugin_name无法验证权限,已拒绝触发!") + return None + elif event.allowed_triggers and plugin_name not in event.allowed_triggers: + logger.warning(f"插件 {plugin_name} 没有权限触发事件 {event_name},已拒绝触发!") + return None + return await event.activate(params) def init_default_events(self) -> None: @@ -297,7 +317,7 @@ class EventManager: ] for event_name in default_events: - self.register_event(event_name) + self.register_event(event_name,allowed_triggers=["SYSTEM"]) logger.info("默认事件初始化完成")