typing and plugins
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional
|
||||
from typing import Tuple, Optional, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType
|
||||
@@ -21,15 +21,17 @@ class BaseEventHandler(ABC):
|
||||
|
||||
def __init__(self):
|
||||
self.log_prefix = "[EventHandler]"
|
||||
self.plugin_name = "" # 对应插件名
|
||||
self.plugin_config: Optional[Dict] = None # 插件配置字典
|
||||
if self.event_type == EventType.UNKNOWN:
|
||||
raise NotImplementedError("事件处理器必须指定 event_type")
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, message: MaiMessages) -> Tuple[bool, Optional[str]]:
|
||||
async def execute(self, message: MaiMessages) -> Tuple[bool, bool, Optional[str]]:
|
||||
"""执行事件处理的抽象方法,子类必须实现
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否执行成功, 可选的返回消息)
|
||||
Tuple[bool, bool, Optional[str]]: (是否执行成功, 是否需要继续处理, 可选的返回消息)
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现 execute 方法")
|
||||
|
||||
@@ -49,3 +51,44 @@ class BaseEventHandler(ABC):
|
||||
weight=cls.weight,
|
||||
intercept_message=cls.intercept_message,
|
||||
)
|
||||
|
||||
def set_plugin_config(self, plugin_config: Dict) -> None:
|
||||
"""设置插件配置
|
||||
|
||||
Args:
|
||||
plugin_config (dict): 插件配置字典
|
||||
"""
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
def set_plugin_name(self, plugin_name: str) -> None:
|
||||
"""设置插件名称
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件名称
|
||||
"""
|
||||
self.plugin_name = plugin_name
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self.plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
|
||||
@@ -159,8 +159,8 @@ class ComponentRegistry:
|
||||
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
|
||||
if pattern not in self._command_patterns:
|
||||
self._command_patterns[pattern] = command_name
|
||||
|
||||
logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令")
|
||||
else:
|
||||
logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional, Type
|
||||
import contextlib
|
||||
from typing import List, Dict, Optional, Type, Tuple
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
@@ -12,8 +13,9 @@ logger = get_logger("events_manager")
|
||||
class EventsManager:
|
||||
def __init__(self):
|
||||
# 有权重的 events 订阅者注册表
|
||||
self.events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
|
||||
self.handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
|
||||
self._events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
|
||||
self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
|
||||
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
|
||||
|
||||
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
"""注册事件处理器
|
||||
@@ -29,7 +31,7 @@ class EventsManager:
|
||||
plugin_name = getattr(handler_info, "plugin_name", "unknown")
|
||||
|
||||
namespace_name = f"{plugin_name}.{handler_name}"
|
||||
if namespace_name in self.handler_mapping:
|
||||
if namespace_name in self._handler_mapping:
|
||||
logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
@@ -37,50 +39,73 @@ class EventsManager:
|
||||
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
||||
return False
|
||||
|
||||
self.handler_mapping[namespace_name] = handler_class
|
||||
self._handler_mapping[namespace_name] = handler_class
|
||||
return self._insert_event_handler(handler_class, handler_info)
|
||||
|
||||
return self._insert_event_handler(handler_class)
|
||||
|
||||
async def handler_mai_events(
|
||||
async def handle_mai_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
message: MessageRecv,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional[str] = None,
|
||||
) -> None:
|
||||
) -> bool:
|
||||
"""处理 events"""
|
||||
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
||||
for handler in self.events_subscribers.get(event_type, []):
|
||||
if handler.intercept_message:
|
||||
await handler.execute(transformed_message)
|
||||
else:
|
||||
asyncio.create_task(handler.execute(transformed_message))
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
def _insert_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
"""插入事件处理器到对应的事件类型列表中"""
|
||||
continue_flag = True
|
||||
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
||||
for handler in self._events_subscribers.get(event_type, []):
|
||||
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
|
||||
if handler.intercept_message:
|
||||
try:
|
||||
success, continue_processing, result = await handler.execute(transformed_message)
|
||||
if not success:
|
||||
logger.error(f"EventHandler {handler.handler_name} 执行失败: {result}")
|
||||
else:
|
||||
logger.debug(f"EventHandler {handler.handler_name} 执行成功: {result}")
|
||||
continue_flag = continue_flag and continue_processing
|
||||
except Exception as e:
|
||||
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}")
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
handler_task = asyncio.create_task(handler.execute(transformed_message))
|
||||
handler_task.add_done_callback(self._task_done_callback)
|
||||
handler_task.set_name(f"EventHandler-{handler.handler_name}-{event_type.name}")
|
||||
self._handler_tasks[handler.handler_name].append(handler_task)
|
||||
except Exception as e:
|
||||
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
|
||||
continue
|
||||
return continue_flag
|
||||
|
||||
def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool:
|
||||
"""插入事件处理器到对应的事件类型列表中并设置其插件配置"""
|
||||
if handler_class.event_type == EventType.UNKNOWN:
|
||||
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
|
||||
return False
|
||||
|
||||
self.events_subscribers[handler_class.event_type].append(handler_class())
|
||||
self.events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
|
||||
handler_instance = handler_class()
|
||||
handler_instance.set_plugin_name(handler_info.plugin_name or "unknown")
|
||||
self._events_subscribers[handler_class.event_type].append(handler_instance)
|
||||
self._events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
|
||||
|
||||
return True
|
||||
|
||||
def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
"""从事件类型列表中移除事件处理器"""
|
||||
display_handler_name = handler_class.handler_name or handler_class.__name__
|
||||
if handler_class.event_type == EventType.UNKNOWN:
|
||||
logger.warning(f"事件处理器 {handler_class.__name__} 的事件类型未知,不存在于处理器列表中")
|
||||
logger.warning(f"事件处理器 {display_handler_name} 的事件类型未知,不存在于处理器列表中")
|
||||
return False
|
||||
|
||||
handlers = self.events_subscribers[handler_class.event_type]
|
||||
handlers = self._events_subscribers[handler_class.event_type]
|
||||
for i, handler in enumerate(handlers):
|
||||
if isinstance(handler, handler_class):
|
||||
del handlers[i]
|
||||
logger.debug(f"事件处理器 {handler_class.__name__} 已移除")
|
||||
logger.debug(f"事件处理器 {display_handler_name} 已移除")
|
||||
return True
|
||||
|
||||
logger.warning(f"未找到事件处理器 {handler_class.__name__},无法移除")
|
||||
logger.warning(f"未找到事件处理器 {display_handler_name},无法移除")
|
||||
return False
|
||||
|
||||
def _transform_event_message(
|
||||
@@ -102,35 +127,68 @@ class EventsManager:
|
||||
transformed_message.message_segments = [message.message_segment]
|
||||
|
||||
# stream_id 处理
|
||||
if hasattr(message, "chat_stream"):
|
||||
if hasattr(message, "chat_stream") and message.chat_stream:
|
||||
transformed_message.stream_id = message.chat_stream.stream_id
|
||||
|
||||
# 处理后文本
|
||||
transformed_message.plain_text = message.processed_plain_text
|
||||
|
||||
# 基本信息
|
||||
if message.message_info.platform:
|
||||
transformed_message.message_base_info["platform"] = message.message_info.platform
|
||||
if message.message_info.group_info:
|
||||
transformed_message.is_group_message = True
|
||||
transformed_message.message_base_info.update(
|
||||
{
|
||||
"group_id": message.message_info.group_info.group_id,
|
||||
"group_name": message.message_info.group_info.group_name,
|
||||
}
|
||||
)
|
||||
if message.message_info.user_info:
|
||||
if not transformed_message.is_group_message:
|
||||
transformed_message.is_private_message = True
|
||||
transformed_message.message_base_info.update(
|
||||
{
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称
|
||||
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名)
|
||||
}
|
||||
)
|
||||
if hasattr(message, "message_info") and message.message_info:
|
||||
if message.message_info.platform:
|
||||
transformed_message.message_base_info["platform"] = message.message_info.platform
|
||||
if message.message_info.group_info:
|
||||
transformed_message.is_group_message = True
|
||||
transformed_message.message_base_info.update(
|
||||
{
|
||||
"group_id": message.message_info.group_info.group_id,
|
||||
"group_name": message.message_info.group_info.group_name,
|
||||
}
|
||||
)
|
||||
if message.message_info.user_info:
|
||||
if not transformed_message.is_group_message:
|
||||
transformed_message.is_private_message = True
|
||||
transformed_message.message_base_info.update(
|
||||
{
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称
|
||||
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名)
|
||||
}
|
||||
)
|
||||
|
||||
return transformed_message
|
||||
|
||||
def _task_done_callback(self, task: asyncio.Task[Tuple[bool, bool, str | None]]):
|
||||
"""任务完成回调"""
|
||||
task_name = task.get_name() or "Unknown Task"
|
||||
try:
|
||||
success, _, result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
|
||||
if success:
|
||||
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
|
||||
else:
|
||||
logger.error(f"事件处理任务 {task_name} 执行失败: {result}")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"事件处理任务 {task_name} 发生异常: {e}")
|
||||
finally:
|
||||
with contextlib.suppress(ValueError, KeyError):
|
||||
self._handler_tasks[task_name].remove(task)
|
||||
|
||||
async def cancel_handler_tasks(self, handler_name: str) -> None:
|
||||
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
|
||||
remaining_tasks = [task for task in tasks_to_be_cancelled if not task.done()]
|
||||
for task in remaining_tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5)
|
||||
logger.info(f"已取消事件处理器 {handler_name} 的所有任务")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消")
|
||||
except Exception as e:
|
||||
logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}")
|
||||
finally:
|
||||
del self._handler_tasks[handler_name]
|
||||
|
||||
|
||||
events_manager = EventsManager()
|
||||
|
||||
Reference in New Issue
Block a user