diff --git a/changes.md b/changes.md index 0d6b507b9..407537d28 100644 --- a/changes.md +++ b/changes.md @@ -41,6 +41,13 @@ - 仅在插件 import 失败时会如此,正常注册过程中失败的插件不会显示包名,而是显示插件内部标识符。(这是特性,但是基本上不可能出现这个情况) 7. 现在不支持单文件插件了,加载方式已经完全删除。 8. 把`BaseEventPlugin`合并到了`BasePlugin`中,所有插件都应该继承自`BasePlugin`。 +9. `BaseEventHandler`现在有了`get_config`方法了。 +10. 修正了`main.py`中的错误输出。 +11. 修正了`command`所编译的`Pattern`注册时的错误输出。 +12. `events_manager`有了task相关逻辑了。 + +### TODO +把这个看起来就很别扭的config获取方式改一下 # 吐槽 diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index 2d645616b..14a9d16c5 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -102,11 +102,12 @@ class PrintMessage(BaseEventHandler): handler_name = "print_message_handler" handler_description = "打印接收到的消息" - async def execute(self, message: MaiMessages) -> Tuple[bool, str | None]: + async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]: """执行打印消息事件处理""" # 打印接收到的消息 - print(f"接收到消息: {message.raw_message}") - return True, "消息已打印" + if self.get_config("print_message.enabled", False): + print(f"接收到消息: {message.raw_message}") + return True, True, "消息已打印" # ===== 插件注册 ===== diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 5d761e5cb..4f211ede5 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -13,8 +13,8 @@ from src.chat.message_receive.message import MessageRecv, MessageRecvS4U from src.chat.message_receive.storage import MessageStorage from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.plugin_system.core.component_registry import component_registry # 导入新插件系统 -from src.plugin_system.base.base_command import BaseCommand +from src.plugin_system.core import component_registry, events_manager # 导入新插件系统 +from src.plugin_system.base import BaseCommand, EventType from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.llm_models.utils_model import LLMRequest @@ -140,57 +140,22 @@ class ChatBot: message = MessageRecvS4U(message_data) group_info = message.message_info.group_info user_info = message.message_info.user_info - - + get_chat_manager().register_message(message) chat = await get_chat_manager().get_or_create_stream( platform=message.message_info.platform, # type: ignore user_info=user_info, # type: ignore group_info=group_info, ) - + message.update_chat_stream(chat) # 处理消息内容 await message.process() - + await self.s4u_message_processor.process_message(message) - + return - - async def hanle_notice_message(self, message: MessageRecv): - if message.message_info.message_id == "notice": - logger.info(f"收到notice消息,暂时不支持处理") - return True - - async def check_ban_content(self,message:MessageRecv): - content = message.processed_plain_text - model = LLMRequest( - model=global_config.model.utils_small, - temperature=0.3, - request_type="text_ban", - ) - - if content.startswith("[picid:"): - return False - - prompt = f""" - 你是一个直播内容审核者,请判断以下内容是否符合直播内容规范,如果符合,请输出"符合",如果不符合,请输出"不符合"以及理由 - 如果存在反动,敏感内容,请输出"不符合" - - 内容:{content} - - 现在请输出: - """ - - response, (reasoning_content, model_name) = await model.generate_response_async(prompt=prompt) - logger.info(f"response: {response}") - logger.info(f"reasoning_content: {reasoning_content}") - - if "不符合" in response: - return True - else: - return False async def message_process(self, message_data: Dict[str, Any]) -> None: @@ -212,9 +177,9 @@ class ChatBot: # 确保所有任务已启动 await self._ensure_started() - + platform = message_data["message_info"].get("platform") - + if platform == "amaidesu_default": await self.do_s4u(message_data) return @@ -243,6 +208,9 @@ class ChatBot: await MessageStorage.update_message(message) return + if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message): + return + get_chat_manager().register_message(message) chat = await get_chat_manager().get_or_create_stream( diff --git a/src/chat/willing/mode_custom.py b/src/chat/willing/mode_custom.py index 36334df43..9987ba942 100644 --- a/src/chat/willing/mode_custom.py +++ b/src/chat/willing/mode_custom.py @@ -1,21 +1,23 @@ from .willing_manager import BaseWillingManager +NOT_IMPLEMENTED_MESSAGE = "\ncustom模式你实现了吗?没自行实现不要选custom。给你退了快点给你麦爹配置\n注:以上内容由gemini生成,如有不满请投诉gemini" class CustomWillingManager(BaseWillingManager): async def async_task_starter(self) -> None: - pass + raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE) async def before_generate_reply_handle(self, message_id: str): - pass + raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE) async def after_generate_reply_handle(self, message_id: str): - pass + raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE) async def not_reply_handle(self, message_id: str): - pass + raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE) async def get_reply_probability(self, message_id: str): - pass + raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE) def __init__(self): super().__init__() + raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE) diff --git a/src/main.py b/src/main.py index dbd12f1a4..3cd2107d1 100644 --- a/src/main.py +++ b/src/main.py @@ -78,8 +78,7 @@ class MainSystem: # logger.info("API服务器启动成功") # 加载所有actions,包括默认的和插件的 - plugin_count, component_count = plugin_manager.load_all_plugins() - logger.info(f"插件系统加载成功: {plugin_count} 个插件,{component_count} 个组件") + plugin_manager.load_all_plugins() # 初始化表情管理器 get_emoji_manager().initialize() diff --git a/src/plugin_system/base/base_events_handler.py b/src/plugin_system/base/base_events_handler.py index db6c20b62..b6c9e965d 100644 --- a/src/plugin_system/base/base_events_handler.py +++ b/src/plugin_system/base/base_events_handler.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Tuple, Optional +from typing import Tuple, Optional, Dict from src.common.logger import get_logger from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType @@ -21,15 +21,17 @@ class BaseEventHandler(ABC): def __init__(self): self.log_prefix = "[EventHandler]" + self.plugin_name = "" # 对应插件名 + self.plugin_config: Optional[Dict] = None # 插件配置字典 if self.event_type == EventType.UNKNOWN: raise NotImplementedError("事件处理器必须指定 event_type") @abstractmethod - async def execute(self, message: MaiMessages) -> Tuple[bool, Optional[str]]: + async def execute(self, message: MaiMessages) -> Tuple[bool, bool, Optional[str]]: """执行事件处理的抽象方法,子类必须实现 Returns: - Tuple[bool, Optional[str]]: (是否执行成功, 可选的返回消息) + Tuple[bool, bool, Optional[str]]: (是否执行成功, 是否需要继续处理, 可选的返回消息) """ raise NotImplementedError("子类必须实现 execute 方法") @@ -49,3 +51,44 @@ class BaseEventHandler(ABC): weight=cls.weight, intercept_message=cls.intercept_message, ) + + def set_plugin_config(self, plugin_config: Dict) -> None: + """设置插件配置 + + Args: + plugin_config (dict): 插件配置字典 + """ + self.plugin_config = plugin_config + + def set_plugin_name(self, plugin_name: str) -> None: + """设置插件名称 + + Args: + plugin_name (str): 插件名称 + """ + self.plugin_name = plugin_name + + def get_config(self, key: str, default=None): + """获取插件配置值,支持嵌套键访问 + + Args: + key: 配置键名,支持嵌套访问如 "section.subsection.key" + default: 默认值 + + Returns: + Any: 配置值或默认值 + """ + if not self.plugin_config: + return default + + # 支持嵌套键访问 + keys = key.split(".") + current = self.plugin_config + + for k in keys: + if isinstance(current, dict) and k in current: + current = current[k] + else: + return default + + return current diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 29a45c604..7283cf9eb 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -159,8 +159,8 @@ class ComponentRegistry: pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL) if pattern not in self._command_patterns: self._command_patterns[pattern] = command_name - - logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令") + else: + logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令") return True diff --git a/src/plugin_system/core/events_manager.py b/src/plugin_system/core/events_manager.py index 2c48f9d6d..6352c4a09 100644 --- a/src/plugin_system/core/events_manager.py +++ b/src/plugin_system/core/events_manager.py @@ -1,5 +1,6 @@ import asyncio -from typing import List, Dict, Optional, Type +import contextlib +from typing import List, Dict, Optional, Type, Tuple from src.chat.message_receive.message import MessageRecv from src.common.logger import get_logger @@ -12,8 +13,9 @@ logger = get_logger("events_manager") class EventsManager: def __init__(self): # 有权重的 events 订阅者注册表 - self.events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType} - self.handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表 + self._events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType} + self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表 + self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务 def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool: """注册事件处理器 @@ -29,7 +31,7 @@ class EventsManager: plugin_name = getattr(handler_info, "plugin_name", "unknown") namespace_name = f"{plugin_name}.{handler_name}" - if namespace_name in self.handler_mapping: + if namespace_name in self._handler_mapping: logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册") return False @@ -37,50 +39,73 @@ class EventsManager: logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类") return False - self.handler_mapping[namespace_name] = handler_class + self._handler_mapping[namespace_name] = handler_class + return self._insert_event_handler(handler_class, handler_info) - return self._insert_event_handler(handler_class) - - async def handler_mai_events( + async def handle_mai_events( self, event_type: EventType, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional[str] = None, - ) -> None: + ) -> bool: """处理 events""" - transformed_message = self._transform_event_message(message, llm_prompt, llm_response) - for handler in self.events_subscribers.get(event_type, []): - if handler.intercept_message: - await handler.execute(transformed_message) - else: - asyncio.create_task(handler.execute(transformed_message)) + from src.plugin_system.core import component_registry - def _insert_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool: - """插入事件处理器到对应的事件类型列表中""" + continue_flag = True + transformed_message = self._transform_event_message(message, llm_prompt, llm_response) + for handler in self._events_subscribers.get(event_type, []): + handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {}) + if handler.intercept_message: + try: + success, continue_processing, result = await handler.execute(transformed_message) + if not success: + logger.error(f"EventHandler {handler.handler_name} 执行失败: {result}") + else: + logger.debug(f"EventHandler {handler.handler_name} 执行成功: {result}") + continue_flag = continue_flag and continue_processing + except Exception as e: + logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}") + continue + else: + try: + handler_task = asyncio.create_task(handler.execute(transformed_message)) + handler_task.add_done_callback(self._task_done_callback) + handler_task.set_name(f"EventHandler-{handler.handler_name}-{event_type.name}") + self._handler_tasks[handler.handler_name].append(handler_task) + except Exception as e: + logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}") + continue + return continue_flag + + def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool: + """插入事件处理器到对应的事件类型列表中并设置其插件配置""" if handler_class.event_type == EventType.UNKNOWN: logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册") return False - self.events_subscribers[handler_class.event_type].append(handler_class()) - self.events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True) + handler_instance = handler_class() + handler_instance.set_plugin_name(handler_info.plugin_name or "unknown") + self._events_subscribers[handler_class.event_type].append(handler_instance) + self._events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True) return True def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool: """从事件类型列表中移除事件处理器""" + display_handler_name = handler_class.handler_name or handler_class.__name__ if handler_class.event_type == EventType.UNKNOWN: - logger.warning(f"事件处理器 {handler_class.__name__} 的事件类型未知,不存在于处理器列表中") + logger.warning(f"事件处理器 {display_handler_name} 的事件类型未知,不存在于处理器列表中") return False - handlers = self.events_subscribers[handler_class.event_type] + handlers = self._events_subscribers[handler_class.event_type] for i, handler in enumerate(handlers): if isinstance(handler, handler_class): del handlers[i] - logger.debug(f"事件处理器 {handler_class.__name__} 已移除") + logger.debug(f"事件处理器 {display_handler_name} 已移除") return True - logger.warning(f"未找到事件处理器 {handler_class.__name__},无法移除") + logger.warning(f"未找到事件处理器 {display_handler_name},无法移除") return False def _transform_event_message( @@ -102,35 +127,68 @@ class EventsManager: transformed_message.message_segments = [message.message_segment] # stream_id 处理 - if hasattr(message, "chat_stream"): + if hasattr(message, "chat_stream") and message.chat_stream: transformed_message.stream_id = message.chat_stream.stream_id # 处理后文本 transformed_message.plain_text = message.processed_plain_text # 基本信息 - if message.message_info.platform: - transformed_message.message_base_info["platform"] = message.message_info.platform - if message.message_info.group_info: - transformed_message.is_group_message = True - transformed_message.message_base_info.update( - { - "group_id": message.message_info.group_info.group_id, - "group_name": message.message_info.group_info.group_name, - } - ) - if message.message_info.user_info: - if not transformed_message.is_group_message: - transformed_message.is_private_message = True - transformed_message.message_base_info.update( - { - "user_id": message.message_info.user_info.user_id, - "user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称 - "user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名) - } - ) + if hasattr(message, "message_info") and message.message_info: + if message.message_info.platform: + transformed_message.message_base_info["platform"] = message.message_info.platform + if message.message_info.group_info: + transformed_message.is_group_message = True + transformed_message.message_base_info.update( + { + "group_id": message.message_info.group_info.group_id, + "group_name": message.message_info.group_info.group_name, + } + ) + if message.message_info.user_info: + if not transformed_message.is_group_message: + transformed_message.is_private_message = True + transformed_message.message_base_info.update( + { + "user_id": message.message_info.user_info.user_id, + "user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称 + "user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名) + } + ) return transformed_message + def _task_done_callback(self, task: asyncio.Task[Tuple[bool, bool, str | None]]): + """任务完成回调""" + task_name = task.get_name() or "Unknown Task" + try: + success, _, result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截 + if success: + logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}") + else: + logger.error(f"事件处理任务 {task_name} 执行失败: {result}") + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"事件处理任务 {task_name} 发生异常: {e}") + finally: + with contextlib.suppress(ValueError, KeyError): + self._handler_tasks[task_name].remove(task) + + async def cancel_handler_tasks(self, handler_name: str) -> None: + tasks_to_be_cancelled = self._handler_tasks.get(handler_name, []) + remaining_tasks = [task for task in tasks_to_be_cancelled if not task.done()] + for task in remaining_tasks: + task.cancel() + try: + await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5) + logger.info(f"已取消事件处理器 {handler_name} 的所有任务") + except asyncio.TimeoutError: + logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消") + except Exception as e: + logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}") + finally: + del self._handler_tasks[handler_name] + events_manager = EventsManager()