typing and plugins
This commit is contained in:
@@ -41,6 +41,13 @@
|
|||||||
- 仅在插件 import 失败时会如此,正常注册过程中失败的插件不会显示包名,而是显示插件内部标识符。(这是特性,但是基本上不可能出现这个情况)
|
- 仅在插件 import 失败时会如此,正常注册过程中失败的插件不会显示包名,而是显示插件内部标识符。(这是特性,但是基本上不可能出现这个情况)
|
||||||
7. 现在不支持单文件插件了,加载方式已经完全删除。
|
7. 现在不支持单文件插件了,加载方式已经完全删除。
|
||||||
8. 把`BaseEventPlugin`合并到了`BasePlugin`中,所有插件都应该继承自`BasePlugin`。
|
8. 把`BaseEventPlugin`合并到了`BasePlugin`中,所有插件都应该继承自`BasePlugin`。
|
||||||
|
9. `BaseEventHandler`现在有了`get_config`方法了。
|
||||||
|
10. 修正了`main.py`中的错误输出。
|
||||||
|
11. 修正了`command`所编译的`Pattern`注册时的错误输出。
|
||||||
|
12. `events_manager`有了task相关逻辑了。
|
||||||
|
|
||||||
|
### TODO
|
||||||
|
把这个看起来就很别扭的config获取方式改一下
|
||||||
|
|
||||||
|
|
||||||
# 吐槽
|
# 吐槽
|
||||||
|
|||||||
@@ -102,11 +102,12 @@ class PrintMessage(BaseEventHandler):
|
|||||||
handler_name = "print_message_handler"
|
handler_name = "print_message_handler"
|
||||||
handler_description = "打印接收到的消息"
|
handler_description = "打印接收到的消息"
|
||||||
|
|
||||||
async def execute(self, message: MaiMessages) -> Tuple[bool, str | None]:
|
async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]:
|
||||||
"""执行打印消息事件处理"""
|
"""执行打印消息事件处理"""
|
||||||
# 打印接收到的消息
|
# 打印接收到的消息
|
||||||
print(f"接收到消息: {message.raw_message}")
|
if self.get_config("print_message.enabled", False):
|
||||||
return True, "消息已打印"
|
print(f"接收到消息: {message.raw_message}")
|
||||||
|
return True, True, "消息已打印"
|
||||||
|
|
||||||
|
|
||||||
# ===== 插件注册 =====
|
# ===== 插件注册 =====
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
|||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.plugin_system.core.component_registry import component_registry # 导入新插件系统
|
from src.plugin_system.core import component_registry, events_manager # 导入新插件系统
|
||||||
from src.plugin_system.base.base_command import BaseCommand
|
from src.plugin_system.base import BaseCommand, EventType
|
||||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||||
|
|
||||||
|
|
||||||
@@ -140,24 +140,22 @@ class ChatBot:
|
|||||||
message = MessageRecvS4U(message_data)
|
message = MessageRecvS4U(message_data)
|
||||||
group_info = message.message_info.group_info
|
group_info = message.message_info.group_info
|
||||||
user_info = message.message_info.user_info
|
user_info = message.message_info.user_info
|
||||||
|
|
||||||
|
|
||||||
get_chat_manager().register_message(message)
|
get_chat_manager().register_message(message)
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
platform=message.message_info.platform, # type: ignore
|
platform=message.message_info.platform, # type: ignore
|
||||||
user_info=user_info, # type: ignore
|
user_info=user_info, # type: ignore
|
||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
message.update_chat_stream(chat)
|
message.update_chat_stream(chat)
|
||||||
|
|
||||||
# 处理消息内容
|
# 处理消息内容
|
||||||
await message.process()
|
await message.process()
|
||||||
|
|
||||||
await self.s4u_message_processor.process_message(message)
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
|
await self.s4u_message_processor.process_message(message)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||||
"""处理转化后的统一格式消息
|
"""处理转化后的统一格式消息
|
||||||
@@ -176,9 +174,9 @@ class ChatBot:
|
|||||||
try:
|
try:
|
||||||
# 确保所有任务已启动
|
# 确保所有任务已启动
|
||||||
await self._ensure_started()
|
await self._ensure_started()
|
||||||
|
|
||||||
platform = message_data["message_info"].get("platform")
|
platform = message_data["message_info"].get("platform")
|
||||||
|
|
||||||
if platform == "amaidesu_default":
|
if platform == "amaidesu_default":
|
||||||
await self.do_s4u(message_data)
|
await self.do_s4u(message_data)
|
||||||
return
|
return
|
||||||
@@ -202,6 +200,9 @@ class ChatBot:
|
|||||||
await MessageStorage.update_message(message)
|
await MessageStorage.update_message(message)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message):
|
||||||
|
return
|
||||||
|
|
||||||
get_chat_manager().register_message(message)
|
get_chat_manager().register_message(message)
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
|
|||||||
@@ -508,7 +508,7 @@ class DefaultReplyer:
|
|||||||
for msg_dict in message_list_before_now:
|
for msg_dict in message_list_before_now:
|
||||||
try:
|
try:
|
||||||
msg_user_id = str(msg_dict.get("user_id"))
|
msg_user_id = str(msg_dict.get("user_id"))
|
||||||
if msg_user_id == bot_id or msg_user_id == target_user_id:
|
if msg_user_id in [bot_id, target_user_id]:
|
||||||
# bot 和目标用户的对话
|
# bot 和目标用户的对话
|
||||||
core_dialogue_list.append(msg_dict)
|
core_dialogue_list.append(msg_dict)
|
||||||
else:
|
else:
|
||||||
@@ -553,7 +553,7 @@ class DefaultReplyer:
|
|||||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||||
enable_timeout: bool = False,
|
enable_timeout: bool = False,
|
||||||
enable_tool: bool = True,
|
enable_tool: bool = True,
|
||||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||||
"""
|
"""
|
||||||
构建回复器上下文
|
构建回复器上下文
|
||||||
|
|
||||||
@@ -724,47 +724,7 @@ class DefaultReplyer:
|
|||||||
# 根据sender通过person_info_manager反向查找person_id,再获取user_id
|
# 根据sender通过person_info_manager反向查找person_id,再获取user_id
|
||||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||||
|
|
||||||
# 根据配置选择使用哪种 prompt 构建模式
|
if not global_config.chat.use_s4u_prompt_mode or not person_id:
|
||||||
if global_config.chat.use_s4u_prompt_mode and person_id:
|
|
||||||
# 使用 s4u 对话构建模式:分离当前对话对象和其他对话
|
|
||||||
try:
|
|
||||||
user_id_value = await person_info_manager.get_value(person_id, "user_id")
|
|
||||||
if user_id_value:
|
|
||||||
target_user_id = str(user_id_value)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
|
|
||||||
target_user_id = ""
|
|
||||||
|
|
||||||
# 构建分离的对话 prompt
|
|
||||||
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
|
||||||
message_list_before_now_long, target_user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用 s4u 风格的模板
|
|
||||||
template_name = "s4u_style_prompt"
|
|
||||||
|
|
||||||
return await global_prompt_manager.format_prompt(
|
|
||||||
template_name,
|
|
||||||
expression_habits_block=expression_habits_block,
|
|
||||||
tool_info_block=tool_info_block,
|
|
||||||
knowledge_prompt=prompt_info,
|
|
||||||
memory_block=memory_block,
|
|
||||||
relation_info_block=relation_info,
|
|
||||||
extra_info_block=extra_info_block,
|
|
||||||
identity=identity_block,
|
|
||||||
action_descriptions=action_descriptions,
|
|
||||||
sender_name=sender,
|
|
||||||
mood_state=mood_prompt,
|
|
||||||
background_dialogue_prompt=background_dialogue_prompt,
|
|
||||||
time_block=time_block,
|
|
||||||
core_dialogue_prompt=core_dialogue_prompt,
|
|
||||||
reply_target_block=reply_target_block,
|
|
||||||
message_txt=target,
|
|
||||||
config_expression_style=global_config.expression.expression_style,
|
|
||||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
|
||||||
moderation_prompt=moderation_prompt_block,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 使用原有的模式
|
# 使用原有的模式
|
||||||
return await global_prompt_manager.format_prompt(
|
return await global_prompt_manager.format_prompt(
|
||||||
template_name,
|
template_name,
|
||||||
@@ -788,6 +748,44 @@ class DefaultReplyer:
|
|||||||
chat_target_2=chat_target_2,
|
chat_target_2=chat_target_2,
|
||||||
mood_state=mood_prompt,
|
mood_state=mood_prompt,
|
||||||
)
|
)
|
||||||
|
# 使用 s4u 对话构建模式:分离当前对话对象和其他对话
|
||||||
|
try:
|
||||||
|
user_id_value = await person_info_manager.get_value(person_id, "user_id")
|
||||||
|
if user_id_value:
|
||||||
|
target_user_id = str(user_id_value)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
|
||||||
|
target_user_id = ""
|
||||||
|
|
||||||
|
# 构建分离的对话 prompt
|
||||||
|
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
||||||
|
message_list_before_now_long, target_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用 s4u 风格的模板
|
||||||
|
template_name = "s4u_style_prompt"
|
||||||
|
|
||||||
|
return await global_prompt_manager.format_prompt(
|
||||||
|
template_name,
|
||||||
|
expression_habits_block=expression_habits_block,
|
||||||
|
tool_info_block=tool_info_block,
|
||||||
|
knowledge_prompt=prompt_info,
|
||||||
|
memory_block=memory_block,
|
||||||
|
relation_info_block=relation_info,
|
||||||
|
extra_info_block=extra_info_block,
|
||||||
|
identity=identity_block,
|
||||||
|
action_descriptions=action_descriptions,
|
||||||
|
sender_name=sender,
|
||||||
|
mood_state=mood_prompt,
|
||||||
|
background_dialogue_prompt=background_dialogue_prompt,
|
||||||
|
time_block=time_block,
|
||||||
|
core_dialogue_prompt=core_dialogue_prompt,
|
||||||
|
reply_target_block=reply_target_block,
|
||||||
|
message_txt=target,
|
||||||
|
config_expression_style=global_config.expression.expression_style,
|
||||||
|
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||||
|
moderation_prompt=moderation_prompt_block,
|
||||||
|
)
|
||||||
|
|
||||||
async def build_prompt_rewrite_context(
|
async def build_prompt_rewrite_context(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -78,8 +78,7 @@ class MainSystem:
|
|||||||
# logger.info("API服务器启动成功")
|
# logger.info("API服务器启动成功")
|
||||||
|
|
||||||
# 加载所有actions,包括默认的和插件的
|
# 加载所有actions,包括默认的和插件的
|
||||||
plugin_count, component_count = plugin_manager.load_all_plugins()
|
plugin_manager.load_all_plugins()
|
||||||
logger.info(f"插件系统加载成功: {plugin_count} 个插件,{component_count} 个组件")
|
|
||||||
|
|
||||||
# 初始化表情管理器
|
# 初始化表情管理器
|
||||||
get_emoji_manager().initialize()
|
get_emoji_manager().initialize()
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional, Dict
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType
|
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType
|
||||||
@@ -21,15 +21,17 @@ class BaseEventHandler(ABC):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.log_prefix = "[EventHandler]"
|
self.log_prefix = "[EventHandler]"
|
||||||
|
self.plugin_name = "" # 对应插件名
|
||||||
|
self.plugin_config: Optional[Dict] = None # 插件配置字典
|
||||||
if self.event_type == EventType.UNKNOWN:
|
if self.event_type == EventType.UNKNOWN:
|
||||||
raise NotImplementedError("事件处理器必须指定 event_type")
|
raise NotImplementedError("事件处理器必须指定 event_type")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute(self, message: MaiMessages) -> Tuple[bool, Optional[str]]:
|
async def execute(self, message: MaiMessages) -> Tuple[bool, bool, Optional[str]]:
|
||||||
"""执行事件处理的抽象方法,子类必须实现
|
"""执行事件处理的抽象方法,子类必须实现
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, Optional[str]]: (是否执行成功, 可选的返回消息)
|
Tuple[bool, bool, Optional[str]]: (是否执行成功, 是否需要继续处理, 可选的返回消息)
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("子类必须实现 execute 方法")
|
raise NotImplementedError("子类必须实现 execute 方法")
|
||||||
|
|
||||||
@@ -49,3 +51,44 @@ class BaseEventHandler(ABC):
|
|||||||
weight=cls.weight,
|
weight=cls.weight,
|
||||||
intercept_message=cls.intercept_message,
|
intercept_message=cls.intercept_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def set_plugin_config(self, plugin_config: Dict) -> None:
|
||||||
|
"""设置插件配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_config (dict): 插件配置字典
|
||||||
|
"""
|
||||||
|
self.plugin_config = plugin_config
|
||||||
|
|
||||||
|
def set_plugin_name(self, plugin_name: str) -> None:
|
||||||
|
"""设置插件名称
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name (str): 插件名称
|
||||||
|
"""
|
||||||
|
self.plugin_name = plugin_name
|
||||||
|
|
||||||
|
def get_config(self, key: str, default=None):
|
||||||
|
"""获取插件配置值,支持嵌套键访问
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||||
|
default: 默认值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: 配置值或默认值
|
||||||
|
"""
|
||||||
|
if not self.plugin_config:
|
||||||
|
return default
|
||||||
|
|
||||||
|
# 支持嵌套键访问
|
||||||
|
keys = key.split(".")
|
||||||
|
current = self.plugin_config
|
||||||
|
|
||||||
|
for k in keys:
|
||||||
|
if isinstance(current, dict) and k in current:
|
||||||
|
current = current[k]
|
||||||
|
else:
|
||||||
|
return default
|
||||||
|
|
||||||
|
return current
|
||||||
|
|||||||
@@ -159,8 +159,8 @@ class ComponentRegistry:
|
|||||||
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
|
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
|
||||||
if pattern not in self._command_patterns:
|
if pattern not in self._command_patterns:
|
||||||
self._command_patterns[pattern] = command_name
|
self._command_patterns[pattern] = command_name
|
||||||
|
else:
|
||||||
logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令")
|
logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Dict, Optional, Type
|
import contextlib
|
||||||
|
from typing import List, Dict, Optional, Type, Tuple
|
||||||
|
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -12,8 +13,9 @@ logger = get_logger("events_manager")
|
|||||||
class EventsManager:
|
class EventsManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 有权重的 events 订阅者注册表
|
# 有权重的 events 订阅者注册表
|
||||||
self.events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
|
self._events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
|
||||||
self.handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
|
self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
|
||||||
|
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
|
||||||
|
|
||||||
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
|
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
|
||||||
"""注册事件处理器
|
"""注册事件处理器
|
||||||
@@ -29,7 +31,7 @@ class EventsManager:
|
|||||||
plugin_name = getattr(handler_info, "plugin_name", "unknown")
|
plugin_name = getattr(handler_info, "plugin_name", "unknown")
|
||||||
|
|
||||||
namespace_name = f"{plugin_name}.{handler_name}"
|
namespace_name = f"{plugin_name}.{handler_name}"
|
||||||
if namespace_name in self.handler_mapping:
|
if namespace_name in self._handler_mapping:
|
||||||
logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册")
|
logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -37,50 +39,73 @@ class EventsManager:
|
|||||||
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self.handler_mapping[namespace_name] = handler_class
|
self._handler_mapping[namespace_name] = handler_class
|
||||||
|
return self._insert_event_handler(handler_class, handler_info)
|
||||||
|
|
||||||
return self._insert_event_handler(handler_class)
|
async def handle_mai_events(
|
||||||
|
|
||||||
async def handler_mai_events(
|
|
||||||
self,
|
self,
|
||||||
event_type: EventType,
|
event_type: EventType,
|
||||||
message: MessageRecv,
|
message: MessageRecv,
|
||||||
llm_prompt: Optional[str] = None,
|
llm_prompt: Optional[str] = None,
|
||||||
llm_response: Optional[str] = None,
|
llm_response: Optional[str] = None,
|
||||||
) -> None:
|
) -> bool:
|
||||||
"""处理 events"""
|
"""处理 events"""
|
||||||
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
from src.plugin_system.core import component_registry
|
||||||
for handler in self.events_subscribers.get(event_type, []):
|
|
||||||
if handler.intercept_message:
|
|
||||||
await handler.execute(transformed_message)
|
|
||||||
else:
|
|
||||||
asyncio.create_task(handler.execute(transformed_message))
|
|
||||||
|
|
||||||
def _insert_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
|
continue_flag = True
|
||||||
"""插入事件处理器到对应的事件类型列表中"""
|
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
||||||
|
for handler in self._events_subscribers.get(event_type, []):
|
||||||
|
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
|
||||||
|
if handler.intercept_message:
|
||||||
|
try:
|
||||||
|
success, continue_processing, result = await handler.execute(transformed_message)
|
||||||
|
if not success:
|
||||||
|
logger.error(f"EventHandler {handler.handler_name} 执行失败: {result}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"EventHandler {handler.handler_name} 执行成功: {result}")
|
||||||
|
continue_flag = continue_flag and continue_processing
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
handler_task = asyncio.create_task(handler.execute(transformed_message))
|
||||||
|
handler_task.add_done_callback(self._task_done_callback)
|
||||||
|
handler_task.set_name(f"EventHandler-{handler.handler_name}-{event_type.name}")
|
||||||
|
self._handler_tasks[handler.handler_name].append(handler_task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
|
||||||
|
continue
|
||||||
|
return continue_flag
|
||||||
|
|
||||||
|
def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool:
|
||||||
|
"""插入事件处理器到对应的事件类型列表中并设置其插件配置"""
|
||||||
if handler_class.event_type == EventType.UNKNOWN:
|
if handler_class.event_type == EventType.UNKNOWN:
|
||||||
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
|
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self.events_subscribers[handler_class.event_type].append(handler_class())
|
handler_instance = handler_class()
|
||||||
self.events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
|
handler_instance.set_plugin_name(handler_info.plugin_name or "unknown")
|
||||||
|
self._events_subscribers[handler_class.event_type].append(handler_instance)
|
||||||
|
self._events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
|
def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||||
"""从事件类型列表中移除事件处理器"""
|
"""从事件类型列表中移除事件处理器"""
|
||||||
|
display_handler_name = handler_class.handler_name or handler_class.__name__
|
||||||
if handler_class.event_type == EventType.UNKNOWN:
|
if handler_class.event_type == EventType.UNKNOWN:
|
||||||
logger.warning(f"事件处理器 {handler_class.__name__} 的事件类型未知,不存在于处理器列表中")
|
logger.warning(f"事件处理器 {display_handler_name} 的事件类型未知,不存在于处理器列表中")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
handlers = self.events_subscribers[handler_class.event_type]
|
handlers = self._events_subscribers[handler_class.event_type]
|
||||||
for i, handler in enumerate(handlers):
|
for i, handler in enumerate(handlers):
|
||||||
if isinstance(handler, handler_class):
|
if isinstance(handler, handler_class):
|
||||||
del handlers[i]
|
del handlers[i]
|
||||||
logger.debug(f"事件处理器 {handler_class.__name__} 已移除")
|
logger.debug(f"事件处理器 {display_handler_name} 已移除")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
logger.warning(f"未找到事件处理器 {handler_class.__name__},无法移除")
|
logger.warning(f"未找到事件处理器 {display_handler_name},无法移除")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _transform_event_message(
|
def _transform_event_message(
|
||||||
@@ -102,35 +127,68 @@ class EventsManager:
|
|||||||
transformed_message.message_segments = [message.message_segment]
|
transformed_message.message_segments = [message.message_segment]
|
||||||
|
|
||||||
# stream_id 处理
|
# stream_id 处理
|
||||||
if hasattr(message, "chat_stream"):
|
if hasattr(message, "chat_stream") and message.chat_stream:
|
||||||
transformed_message.stream_id = message.chat_stream.stream_id
|
transformed_message.stream_id = message.chat_stream.stream_id
|
||||||
|
|
||||||
# 处理后文本
|
# 处理后文本
|
||||||
transformed_message.plain_text = message.processed_plain_text
|
transformed_message.plain_text = message.processed_plain_text
|
||||||
|
|
||||||
# 基本信息
|
# 基本信息
|
||||||
if message.message_info.platform:
|
if hasattr(message, "message_info") and message.message_info:
|
||||||
transformed_message.message_base_info["platform"] = message.message_info.platform
|
if message.message_info.platform:
|
||||||
if message.message_info.group_info:
|
transformed_message.message_base_info["platform"] = message.message_info.platform
|
||||||
transformed_message.is_group_message = True
|
if message.message_info.group_info:
|
||||||
transformed_message.message_base_info.update(
|
transformed_message.is_group_message = True
|
||||||
{
|
transformed_message.message_base_info.update(
|
||||||
"group_id": message.message_info.group_info.group_id,
|
{
|
||||||
"group_name": message.message_info.group_info.group_name,
|
"group_id": message.message_info.group_info.group_id,
|
||||||
}
|
"group_name": message.message_info.group_info.group_name,
|
||||||
)
|
}
|
||||||
if message.message_info.user_info:
|
)
|
||||||
if not transformed_message.is_group_message:
|
if message.message_info.user_info:
|
||||||
transformed_message.is_private_message = True
|
if not transformed_message.is_group_message:
|
||||||
transformed_message.message_base_info.update(
|
transformed_message.is_private_message = True
|
||||||
{
|
transformed_message.message_base_info.update(
|
||||||
"user_id": message.message_info.user_info.user_id,
|
{
|
||||||
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称
|
"user_id": message.message_info.user_info.user_id,
|
||||||
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名)
|
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称
|
||||||
}
|
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名)
|
||||||
)
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return transformed_message
|
return transformed_message
|
||||||
|
|
||||||
|
def _task_done_callback(self, task: asyncio.Task[Tuple[bool, bool, str | None]]):
|
||||||
|
"""任务完成回调"""
|
||||||
|
task_name = task.get_name() or "Unknown Task"
|
||||||
|
try:
|
||||||
|
success, _, result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
|
||||||
|
if success:
|
||||||
|
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
|
||||||
|
else:
|
||||||
|
logger.error(f"事件处理任务 {task_name} 执行失败: {result}")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"事件处理任务 {task_name} 发生异常: {e}")
|
||||||
|
finally:
|
||||||
|
with contextlib.suppress(ValueError, KeyError):
|
||||||
|
self._handler_tasks[task_name].remove(task)
|
||||||
|
|
||||||
|
async def cancel_handler_tasks(self, handler_name: str) -> None:
|
||||||
|
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
|
||||||
|
remaining_tasks = [task for task in tasks_to_be_cancelled if not task.done()]
|
||||||
|
for task in remaining_tasks:
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5)
|
||||||
|
logger.info(f"已取消事件处理器 {handler_name} 的所有任务")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}")
|
||||||
|
finally:
|
||||||
|
del self._handler_tasks[handler_name]
|
||||||
|
|
||||||
|
|
||||||
events_manager = EventsManager()
|
events_manager = EventsManager()
|
||||||
|
|||||||
Reference in New Issue
Block a user