typing and plugins

This commit is contained in:
UnCLAS-Prommer
2025-07-21 01:23:23 +08:00
parent f2c901bc98
commit 484fc20983
8 changed files with 215 additions and 108 deletions

View File

@@ -41,6 +41,13 @@
- 仅在插件 import 失败时会如此,正常注册过程中失败的插件不会显示包名,而是显示插件内部标识符。(这是特性,但是基本上不可能出现这个情况) - 仅在插件 import 失败时会如此,正常注册过程中失败的插件不会显示包名,而是显示插件内部标识符。(这是特性,但是基本上不可能出现这个情况)
7. 现在不支持单文件插件了,加载方式已经完全删除。 7. 现在不支持单文件插件了,加载方式已经完全删除。
8. 把`BaseEventPlugin`合并到了`BasePlugin`中,所有插件都应该继承自`BasePlugin`。 8. 把`BaseEventPlugin`合并到了`BasePlugin`中,所有插件都应该继承自`BasePlugin`。
9. `BaseEventHandler`现在有了`get_config`方法了。
10. 修正了`main.py`中的错误输出。
11. 修正了`command`所编译的`Pattern`注册时的错误输出。
12. `events_manager`有了task相关逻辑了。
### TODO
把这个看起来就很别扭的config获取方式改一下
# 吐槽 # 吐槽

View File

@@ -102,11 +102,12 @@ class PrintMessage(BaseEventHandler):
handler_name = "print_message_handler" handler_name = "print_message_handler"
handler_description = "打印接收到的消息" handler_description = "打印接收到的消息"
async def execute(self, message: MaiMessages) -> Tuple[bool, str | None]: async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]:
"""执行打印消息事件处理""" """执行打印消息事件处理"""
# 打印接收到的消息 # 打印接收到的消息
if self.get_config("print_message.enabled", False):
print(f"接收到消息: {message.raw_message}") print(f"接收到消息: {message.raw_message}")
return True, "消息已打印" return True, True, "消息已打印"
# ===== 插件注册 ===== # ===== 插件注册 =====

View File

@@ -13,8 +13,8 @@ from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugin_system.core.component_registry import component_registry # 导入新插件系统 from src.plugin_system.core import component_registry, events_manager # 导入新插件系统
from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
@@ -141,7 +141,6 @@ class ChatBot:
group_info = message.message_info.group_info group_info = message.message_info.group_info
user_info = message.message_info.user_info user_info = message.message_info.user_info
get_chat_manager().register_message(message) get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream( chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform, # type: ignore platform=message.message_info.platform, # type: ignore
@@ -158,7 +157,6 @@ class ChatBot:
return return
async def message_process(self, message_data: Dict[str, Any]) -> None: async def message_process(self, message_data: Dict[str, Any]) -> None:
"""处理转化后的统一格式消息 """处理转化后的统一格式消息
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中 这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
@@ -202,6 +200,9 @@ class ChatBot:
await MessageStorage.update_message(message) await MessageStorage.update_message(message)
return return
if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message):
return
get_chat_manager().register_message(message) get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream( chat = await get_chat_manager().get_or_create_stream(

View File

@@ -508,7 +508,7 @@ class DefaultReplyer:
for msg_dict in message_list_before_now: for msg_dict in message_list_before_now:
try: try:
msg_user_id = str(msg_dict.get("user_id")) msg_user_id = str(msg_dict.get("user_id"))
if msg_user_id == bot_id or msg_user_id == target_user_id: if msg_user_id in [bot_id, target_user_id]:
# bot 和目标用户的对话 # bot 和目标用户的对话
core_dialogue_list.append(msg_dict) core_dialogue_list.append(msg_dict)
else: else:
@@ -724,8 +724,30 @@ class DefaultReplyer:
# 根据sender通过person_info_manager反向查找person_id再获取user_id # 根据sender通过person_info_manager反向查找person_id再获取user_id
person_id = person_info_manager.get_person_id_by_person_name(sender) person_id = person_info_manager.get_person_id_by_person_name(sender)
# 根据配置选择使用哪种 prompt 构建模式 if not global_config.chat.use_s4u_prompt_mode or not person_id:
if global_config.chat.use_s4u_prompt_mode and person_id: # 使用原有的模式
return await global_prompt_manager.format_prompt(
template_name,
expression_habits_block=expression_habits_block,
chat_target=chat_target_1,
chat_info=chat_talking_prompt,
memory_block=memory_block,
tool_info_block=tool_info_block,
knowledge_prompt=prompt_info,
extra_info_block=extra_info_block,
relation_info_block=relation_info,
time_block=time_block,
reply_target_block=reply_target_block,
moderation_prompt=moderation_prompt_block,
keywords_reaction_prompt=keywords_reaction_prompt,
identity=identity_block,
target_message=target,
sender_name=sender,
config_expression_style=global_config.expression.expression_style,
action_descriptions=action_descriptions,
chat_target_2=chat_target_2,
mood_state=mood_prompt,
)
# 使用 s4u 对话构建模式:分离当前对话对象和其他对话 # 使用 s4u 对话构建模式:分离当前对话对象和其他对话
try: try:
user_id_value = await person_info_manager.get_value(person_id, "user_id") user_id_value = await person_info_manager.get_value(person_id, "user_id")
@@ -764,30 +786,6 @@ class DefaultReplyer:
keywords_reaction_prompt=keywords_reaction_prompt, keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block, moderation_prompt=moderation_prompt_block,
) )
else:
# 使用原有的模式
return await global_prompt_manager.format_prompt(
template_name,
expression_habits_block=expression_habits_block,
chat_target=chat_target_1,
chat_info=chat_talking_prompt,
memory_block=memory_block,
tool_info_block=tool_info_block,
knowledge_prompt=prompt_info,
extra_info_block=extra_info_block,
relation_info_block=relation_info,
time_block=time_block,
reply_target_block=reply_target_block,
moderation_prompt=moderation_prompt_block,
keywords_reaction_prompt=keywords_reaction_prompt,
identity=identity_block,
target_message=target,
sender_name=sender,
config_expression_style=global_config.expression.expression_style,
action_descriptions=action_descriptions,
chat_target_2=chat_target_2,
mood_state=mood_prompt,
)
async def build_prompt_rewrite_context( async def build_prompt_rewrite_context(
self, self,

View File

@@ -78,8 +78,7 @@ class MainSystem:
# logger.info("API服务器启动成功") # logger.info("API服务器启动成功")
# 加载所有actions包括默认的和插件的 # 加载所有actions包括默认的和插件的
plugin_count, component_count = plugin_manager.load_all_plugins() plugin_manager.load_all_plugins()
logger.info(f"插件系统加载成功: {plugin_count} 个插件,{component_count} 个组件")
# 初始化表情管理器 # 初始化表情管理器
get_emoji_manager().initialize() get_emoji_manager().initialize()

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Tuple, Optional from typing import Tuple, Optional, Dict
from src.common.logger import get_logger from src.common.logger import get_logger
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType
@@ -21,15 +21,17 @@ class BaseEventHandler(ABC):
def __init__(self): def __init__(self):
self.log_prefix = "[EventHandler]" self.log_prefix = "[EventHandler]"
self.plugin_name = "" # 对应插件名
self.plugin_config: Optional[Dict] = None # 插件配置字典
if self.event_type == EventType.UNKNOWN: if self.event_type == EventType.UNKNOWN:
raise NotImplementedError("事件处理器必须指定 event_type") raise NotImplementedError("事件处理器必须指定 event_type")
@abstractmethod @abstractmethod
async def execute(self, message: MaiMessages) -> Tuple[bool, Optional[str]]: async def execute(self, message: MaiMessages) -> Tuple[bool, bool, Optional[str]]:
"""执行事件处理的抽象方法,子类必须实现 """执行事件处理的抽象方法,子类必须实现
Returns: Returns:
Tuple[bool, Optional[str]]: (是否执行成功, 可选的返回消息) Tuple[bool, bool, Optional[str]]: (是否执行成功, 是否需要继续处理, 可选的返回消息)
""" """
raise NotImplementedError("子类必须实现 execute 方法") raise NotImplementedError("子类必须实现 execute 方法")
@@ -49,3 +51,44 @@ class BaseEventHandler(ABC):
weight=cls.weight, weight=cls.weight,
intercept_message=cls.intercept_message, intercept_message=cls.intercept_message,
) )
def set_plugin_config(self, plugin_config: Dict) -> None:
"""设置插件配置
Args:
plugin_config (dict): 插件配置字典
"""
self.plugin_config = plugin_config
def set_plugin_name(self, plugin_name: str) -> None:
"""设置插件名称
Args:
plugin_name (str): 插件名称
"""
self.plugin_name = plugin_name
def get_config(self, key: str, default=None):
"""获取插件配置值,支持嵌套键访问
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current

View File

@@ -159,7 +159,7 @@ class ComponentRegistry:
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL) pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
if pattern not in self._command_patterns: if pattern not in self._command_patterns:
self._command_patterns[pattern] = command_name self._command_patterns[pattern] = command_name
else:
logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令") logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令")
return True return True

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
from typing import List, Dict, Optional, Type import contextlib
from typing import List, Dict, Optional, Type, Tuple
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -12,8 +13,9 @@ logger = get_logger("events_manager")
class EventsManager: class EventsManager:
def __init__(self): def __init__(self):
# 有权重的 events 订阅者注册表 # 有权重的 events 订阅者注册表
self.events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType} self._events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
self.handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表 self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool: def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
"""注册事件处理器 """注册事件处理器
@@ -29,7 +31,7 @@ class EventsManager:
plugin_name = getattr(handler_info, "plugin_name", "unknown") plugin_name = getattr(handler_info, "plugin_name", "unknown")
namespace_name = f"{plugin_name}.{handler_name}" namespace_name = f"{plugin_name}.{handler_name}"
if namespace_name in self.handler_mapping: if namespace_name in self._handler_mapping:
logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册") logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册")
return False return False
@@ -37,50 +39,73 @@ class EventsManager:
logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类") logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类")
return False return False
self.handler_mapping[namespace_name] = handler_class self._handler_mapping[namespace_name] = handler_class
return self._insert_event_handler(handler_class, handler_info)
return self._insert_event_handler(handler_class) async def handle_mai_events(
async def handler_mai_events(
self, self,
event_type: EventType, event_type: EventType,
message: MessageRecv, message: MessageRecv,
llm_prompt: Optional[str] = None, llm_prompt: Optional[str] = None,
llm_response: Optional[str] = None, llm_response: Optional[str] = None,
) -> None: ) -> bool:
"""处理 events""" """处理 events"""
transformed_message = self._transform_event_message(message, llm_prompt, llm_response) from src.plugin_system.core import component_registry
for handler in self.events_subscribers.get(event_type, []):
if handler.intercept_message:
await handler.execute(transformed_message)
else:
asyncio.create_task(handler.execute(transformed_message))
def _insert_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool: continue_flag = True
"""插入事件处理器到对应的事件类型列表中""" transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
for handler in self._events_subscribers.get(event_type, []):
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
if handler.intercept_message:
try:
success, continue_processing, result = await handler.execute(transformed_message)
if not success:
logger.error(f"EventHandler {handler.handler_name} 执行失败: {result}")
else:
logger.debug(f"EventHandler {handler.handler_name} 执行成功: {result}")
continue_flag = continue_flag and continue_processing
except Exception as e:
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}")
continue
else:
try:
handler_task = asyncio.create_task(handler.execute(transformed_message))
handler_task.add_done_callback(self._task_done_callback)
handler_task.set_name(f"EventHandler-{handler.handler_name}-{event_type.name}")
self._handler_tasks[handler.handler_name].append(handler_task)
except Exception as e:
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
continue
return continue_flag
def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool:
"""插入事件处理器到对应的事件类型列表中并设置其插件配置"""
if handler_class.event_type == EventType.UNKNOWN: if handler_class.event_type == EventType.UNKNOWN:
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册") logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
return False return False
self.events_subscribers[handler_class.event_type].append(handler_class()) handler_instance = handler_class()
self.events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True) handler_instance.set_plugin_name(handler_info.plugin_name or "unknown")
self._events_subscribers[handler_class.event_type].append(handler_instance)
self._events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
return True return True
def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool: def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
"""从事件类型列表中移除事件处理器""" """从事件类型列表中移除事件处理器"""
display_handler_name = handler_class.handler_name or handler_class.__name__
if handler_class.event_type == EventType.UNKNOWN: if handler_class.event_type == EventType.UNKNOWN:
logger.warning(f"事件处理器 {handler_class.__name__} 的事件类型未知,不存在于处理器列表中") logger.warning(f"事件处理器 {display_handler_name} 的事件类型未知,不存在于处理器列表中")
return False return False
handlers = self.events_subscribers[handler_class.event_type] handlers = self._events_subscribers[handler_class.event_type]
for i, handler in enumerate(handlers): for i, handler in enumerate(handlers):
if isinstance(handler, handler_class): if isinstance(handler, handler_class):
del handlers[i] del handlers[i]
logger.debug(f"事件处理器 {handler_class.__name__} 已移除") logger.debug(f"事件处理器 {display_handler_name} 已移除")
return True return True
logger.warning(f"未找到事件处理器 {handler_class.__name__},无法移除") logger.warning(f"未找到事件处理器 {display_handler_name},无法移除")
return False return False
def _transform_event_message( def _transform_event_message(
@@ -102,13 +127,14 @@ class EventsManager:
transformed_message.message_segments = [message.message_segment] transformed_message.message_segments = [message.message_segment]
# stream_id 处理 # stream_id 处理
if hasattr(message, "chat_stream"): if hasattr(message, "chat_stream") and message.chat_stream:
transformed_message.stream_id = message.chat_stream.stream_id transformed_message.stream_id = message.chat_stream.stream_id
# 处理后文本 # 处理后文本
transformed_message.plain_text = message.processed_plain_text transformed_message.plain_text = message.processed_plain_text
# 基本信息 # 基本信息
if hasattr(message, "message_info") and message.message_info:
if message.message_info.platform: if message.message_info.platform:
transformed_message.message_base_info["platform"] = message.message_info.platform transformed_message.message_base_info["platform"] = message.message_info.platform
if message.message_info.group_info: if message.message_info.group_info:
@@ -132,5 +158,37 @@ class EventsManager:
return transformed_message return transformed_message
def _task_done_callback(self, task: asyncio.Task[Tuple[bool, bool, str | None]]):
"""任务完成回调"""
task_name = task.get_name() or "Unknown Task"
try:
success, _, result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
if success:
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
else:
logger.error(f"事件处理任务 {task_name} 执行失败: {result}")
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"事件处理任务 {task_name} 发生异常: {e}")
finally:
with contextlib.suppress(ValueError, KeyError):
self._handler_tasks[task_name].remove(task)
async def cancel_handler_tasks(self, handler_name: str) -> None:
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
remaining_tasks = [task for task in tasks_to_be_cancelled if not task.done()]
for task in remaining_tasks:
task.cancel()
try:
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5)
logger.info(f"已取消事件处理器 {handler_name} 的所有任务")
except asyncio.TimeoutError:
logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消")
except Exception as e:
logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}")
finally:
del self._handler_tasks[handler_name]
events_manager = EventsManager() events_manager = EventsManager()