events manager and some typing fix

This commit is contained in:
UnCLAS-Prommer
2025-07-18 14:50:15 +08:00
parent d2b5019c24
commit ffa88b5462
6 changed files with 318 additions and 33 deletions

View File

@@ -1,8 +1,14 @@
from abc import abstractmethod
from typing import List, Tuple, Type, TYPE_CHECKING
from .plugin_base import PluginBase
from src.common.logger import get_logger
from .plugin_base import PluginBase
from .component_types import EventHandlerInfo
if TYPE_CHECKING:
from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("base_event_plugin")
class BaseEventPlugin(PluginBase):
"""基于事件的插件基类
@@ -12,3 +18,42 @@ class BaseEventPlugin(PluginBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@abstractmethod
def get_plugin_components(self) -> List[Tuple[EventHandlerInfo, Type[BaseEventHandler]]]:
"""获取插件包含的事件组件
子类必须实现此方法,返回事件组件
Returns:
List[Tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...]
"""
raise NotImplementedError("子类必须实现 get_plugin_components 方法")
def register_plugin(self) -> bool:
"""注册事件插件"""
from src.plugin_system.core.events_manager import events_manager
components = self.get_plugin_components()
# 检查依赖
if not self._check_dependencies():
logger.error(f"{self.log_prefix} 依赖检查失败,跳过注册")
return False
registered_components = []
for handler_info, handler_class in components:
handler_info.plugin_name = self.plugin_name
if events_manager.register_event_subscriber(handler_info, handler_class):
registered_components.append(handler_info)
else:
logger.error(f"{self.log_prefix} 事件处理器 {handler_info.name} 注册失败")
self.plugin_info.components = registered_components
if events_manager.register_plugins(self.plugin_info):
logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个事件处理器")
return True
else:
logger.error(f"{self.log_prefix} 插件注册失败")
return False

View File

@@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
from typing import Tuple, Optional
from src.common.logger import get_logger
from .component_types import MaiMessages, EventType
logger = get_logger("base_event_handler")
class BaseEventHandler(ABC):
"""事件处理器基类
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
"""
event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知
handler_name: str = ""
handler_description: str = ""
weight: int = 0 # 权重,数值越大优先级越高
intercept_message: bool = False # 是否拦截消息,默认为否
def __init__(self):
self.log_prefix = "[EventHandler]"
if self.event_type == EventType.UNKNOWN:
raise NotImplementedError("事件处理器必须指定 event_type")
@abstractmethod
async def execute(self, message: MaiMessages) -> Tuple[bool, Optional[str]]:
"""执行事件处理的抽象方法,子类必须实现
Returns:
Tuple[bool, Optional[str]]: (是否执行成功, 可选的返回消息)
"""
raise NotImplementedError("子类必须实现 execute 方法")

View File

@@ -1,6 +1,7 @@
from enum import Enum
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from maim_message import Seg
# 组件类型枚举
@@ -12,6 +13,9 @@ class ComponentType(Enum):
SCHEDULER = "scheduler" # 定时任务组件(预留)
LISTENER = "listener" # 事件监听组件(预留)
def __str__(self) -> str:
return self.value
# 动作激活类型枚举
class ActionActivationType(Enum):
@@ -46,12 +50,17 @@ class EventType(Enum):
事件类型枚举类
"""
ON_START = "on_start" # 启动事件,用于调用按时任务
ON_MESSAGE = "on_message"
ON_PLAN = "on_plan"
POST_LLM = "post_llm"
AFTER_LLM = "after_llm"
POST_SEND = "post_send"
AFTER_SEND = "after_send"
UNKNOWN = "unknown" # 未知事件类型
def __str__(self) -> str:
return self.value
@dataclass
@@ -142,6 +151,19 @@ class CommandInfo(ComponentInfo):
self.component_type = ComponentType.COMMAND
@dataclass
class EventHandlerInfo(ComponentInfo):
"""事件处理器组件信息"""
event_type: EventType = EventType.ON_MESSAGE # 监听事件类型
intercept_message: bool = False # 是否拦截消息处理(默认不拦截)
weight: int = 0 # 事件处理器权重,决定执行顺序
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.LISTENER
@dataclass
class PluginInfo:
"""插件信息"""
@@ -198,3 +220,42 @@ class PluginInfo:
def get_pip_requirements(self) -> List[str]:
"""获取所有pip安装格式的依赖"""
return [dep.get_pip_requirement() for dep in self.python_dependencies]
@dataclass
class MaiMessages:
"""MaiM插件消息"""
message_segments: List[Seg] = field(default_factory=list)
"""消息段列表,支持多段消息"""
message_base_info: Dict[str, Any] = field(default_factory=dict)
"""消息基本信息,包含平台,用户信息等数据"""
plain_text: str = ""
"""纯文本消息内容"""
raw_message: Optional[str] = None
"""原始消息内容"""
is_group_message: bool = False
"""是否为群组消息"""
is_private_message: bool = False
"""是否为私聊消息"""
stream_id: Optional[str] = None
"""流ID用于标识消息流"""
llm_prompt: Optional[str] = None
"""LLM提示词"""
llm_response: Optional[str] = None
"""LLM响应内容"""
additional_data: Dict[Any, Any] = field(default_factory=dict)
"""附加数据,可以存储额外信息"""
def __post_init__(self):
if self.message_segments is None:
self.message_segments = []

View File

@@ -1,11 +1,154 @@
from typing import List, Dict, Type
import asyncio
from typing import List, Dict, Optional, Type
from src.plugin_system.base.component_types import EventType
from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, PluginInfo
from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("events_manager")
class EventsManager:
def __init__(self):
# 有权重的 events 订阅者注册表
self.events_subscribers: Dict[EventType, List[Dict[int, Type]]] = {event: [] for event in EventType}
self.events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
self.handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
self._plugins: Dict[str, PluginInfo] = {} # 插件注册表
events_manager = EventsManager()
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
"""注册事件处理器
Args:
handler_info (EventHandlerInfo): 事件处理器信息
handler_class (Type[BaseEventHandler]): 事件处理器类
Returns:
bool: 是否注册成功
"""
handler_name = handler_info.name
plugin_name = getattr(handler_info, "plugin_name", "unknown")
namespace_name = f"{plugin_name}.{handler_name}"
if namespace_name in self.handler_mapping:
logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册")
return False
if not issubclass(handler_class, BaseEventHandler):
logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类")
return False
self.handler_mapping[namespace_name] = handler_class
return self._insert_event_handler(handler_class)
def register_plugins(self, plugin_info: PluginInfo) -> bool:
"""注册插件
Args:
plugin_info (PluginInfo): 插件信息
Returns:
bool: 是否注册成功
"""
if plugin_info.name in self._plugins:
logger.warning(f"插件 {plugin_info.name} 已存在,跳过注册")
return False
self._plugins[plugin_info.name] = plugin_info
logger.debug(f"插件 {plugin_info.name} 注册成功")
return True
async def handler_mai_events(
self,
event_type: EventType,
message: MessageRecv,
llm_prompt: Optional[str] = None,
llm_response: Optional[str] = None,
) -> None:
"""处理 events"""
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
for handler in self.events_subscribers.get(event_type, []):
if handler.intercept_message:
await handler.execute(transformed_message)
else:
asyncio.create_task(handler.execute(transformed_message))
def _insert_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
"""插入事件处理器到对应的事件类型列表中"""
if handler_class.event_type == EventType.UNKNOWN:
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
return False
self.events_subscribers[handler_class.event_type].append(handler_class())
self.events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
return True
def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
"""从事件类型列表中移除事件处理器"""
if handler_class.event_type == EventType.UNKNOWN:
logger.warning(f"事件处理器 {handler_class.__name__} 的事件类型未知,不存在于处理器列表中")
return False
handlers = self.events_subscribers[handler_class.event_type]
for i, handler in enumerate(handlers):
if isinstance(handler, handler_class):
del handlers[i]
logger.debug(f"事件处理器 {handler_class.__name__} 已移除")
return True
logger.warning(f"未找到事件处理器 {handler_class.__name__},无法移除")
return False
def _transform_event_message(
self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional[str] = None
) -> MaiMessages:
"""转换事件消息格式"""
# 直接赋值部分内容
transformed_message = MaiMessages(
llm_prompt=llm_prompt,
llm_response=llm_response,
raw_message=message.raw_message,
additional_data=message.message_info.additional_config or {},
)
# 消息段处理
if message.message_segment.type == "seglist":
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
else:
transformed_message.message_segments = [message.message_segment]
# stream_id 处理
if hasattr(message, "chat_stream"):
transformed_message.stream_id = message.chat_stream.stream_id
# 处理后文本
transformed_message.plain_text = message.processed_plain_text
# 基本信息
if message.message_info.platform:
transformed_message.message_base_info["platform"] = message.message_info.platform
if message.message_info.group_info:
transformed_message.is_group_message = True
transformed_message.message_base_info.update(
{
"group_id": message.message_info.group_info.group_id,
"group_name": message.message_info.group_info.group_name,
}
)
if message.message_info.user_info:
if not transformed_message.is_group_message:
transformed_message.is_private_message = True
transformed_message.message_base_info.update(
{
"user_id": message.message_info.user_info.user_id,
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名)
}
)
return transformed_message
events_manager = EventsManager()