From d30b0544b557780f69b05c7fadd0d18e0527cc51 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Mon, 24 Nov 2025 22:36:33 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=B6=88=E6=81=AF=E5=A4=84?= =?UTF-8?q?=E7=90=86=E5=92=8C=E4=BF=A1=E5=B0=81=E8=BD=AC=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 从代码库中移除了EnvelopeConverter类及其相关方法,因为它们已不再需要。 - 更新了主系统,使其能够直接处理MessageEnvelope对象,而无需将其转换为旧格式。 - 增强了MessageRuntime类,以支持多种消息类型并防止重复注册处理程序。 引入了一个新的MessageHandler类来管理消息处理,包括预处理和数据库存储。 - 改进了整个消息处理工作流程中的错误处理和日志记录。 - 更新了类型提示和数据模型,以确保消息结构的一致性和清晰度。 --- src/chat/message_receive/bot.py | 339 ++++++++--------- src/chat/message_receive/message.py | 1 - src/chat/message_receive/message_handler.py | 110 ++++++ src/chat/message_receive/message_processor.py | 318 +++++++++------- src/common/message/envelope_converter.py | 341 ------------------ src/main.py | 26 +- src/mofox_bus/runtime.py | 256 +++++++++++-- src/mofox_bus/types.py | 18 +- 8 files changed, 685 insertions(+), 724 deletions(-) create mode 100644 src/chat/message_receive/message_handler.py delete mode 100644 src/common/message/envelope_converter.py diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 31b2b5e71..6524ea8d3 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -3,8 +3,8 @@ import re import traceback from typing import Any -from mofox_bus import UserInfo - +from mofox_bus.runtime import MessageRuntime +from mofox_bus import MessageEnvelope from src.chat.message_manager import message_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.storage import MessageStorage @@ -63,13 +63,13 @@ def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: return True return False +runtime = MessageRuntime() # 获取mofox-bus运行时环境 class ChatBot: def __init__(self): self.bot = None # bot 实例引用 self._started = False - self.mood_manager = mood_manager # 获取情绪管理器单例 - + self.mood_manager = mood_manager # 获取情绪管理器单例 # 启动消息管理器 self._message_manager_started = False @@ -303,204 +303,185 @@ class ChatBot: except Exception as e: logger.error(f"处理适配器响应时出错: {e}") - - async def message_process(self, message_data: dict[str, Any]) -> None: + + @runtime.on_message + async def message_process(self, envelope: MessageEnvelope) -> None: """处理转化后的统一格式消息""" - try: - # 首先处理可能的切片消息重组 - from src.utils.message_chunker import reassembler + # 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError + message_info = envelope.get("message_info") + if not isinstance(message_info, dict): + logger.debug( + "收到缺少 message_info 的消息,已跳过。可用字段: %s", + ", ".join(envelope.keys()), + ) + return - # 尝试重组切片消息 - reassembled_message = await reassembler.process_chunk(message_data) - if reassembled_message is None: - # 这是一个切片,但还未完整,等待更多切片 - logger.debug("等待更多切片,跳过此次处理") - return - elif reassembled_message != message_data: - # 消息已被重组,使用重组后的消息 - logger.info("使用重组后的完整消息进行处理") - message_data = reassembled_message - - # 确保所有任务已启动 - await self._ensure_started() - - # 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError - message_info = message_data.get("message_info") - if not isinstance(message_info, dict): - logger.debug( - "收到缺少 message_info 的消息,已跳过。可用字段: %s", - ", ".join(message_data.keys()), - ) - return - - if message_info.get("group_info") is not None: - message_info["group_info"]["group_id"] = str( - message_info["group_info"]["group_id"] - ) - if message_info.get("user_info") is not None: - message_info["user_info"]["user_id"] = str( - message_info["user_info"]["user_id"] - ) - # print(message_data) - # logger.debug(str(message_data)) - - # 优先处理adapter_response消息(在echo检查之前!) - message_segment = message_data.get("message_segment") - if message_segment and isinstance(message_segment, dict): - if message_segment.get("type") == "adapter_response": - logger.info("[DEBUG bot.py message_process] 检测到adapter_response,立即处理") - await self._handle_adapter_response_from_dict(message_segment.get("data")) - return - - # 先提取基础信息检查是否是自身消息上报 - from mofox_bus import BaseMessageInfo - temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {})) - if temp_message_info.additional_config: - sent_message = temp_message_info.additional_config.get("echo", False) - if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题 - # 直接使用消息字典更新,不再需要创建 MessageRecv - await MessageStorage.update_message(message_data) - return - - message_segment = message_data.get("message_segment") - group_info = temp_message_info.group_info - user_info = temp_message_info.user_info - - # 获取或创建聊天流 - chat = await get_chat_manager().get_or_create_stream( - platform=temp_message_info.platform, # type: ignore - user_info=user_info, # type: ignore - group_info=group_info, + if message_info.get("group_info") is not None: + message_info["group_info"]["group_id"] = str( # type: ignore + message_info["group_info"]["group_id"] # type: ignore + ) + if message_info.get("user_info") is not None: + message_info["user_info"]["user_id"] = str( # type: ignore + message_info["user_info"]["user_id"] # type: ignore ) - # 使用新的消息处理器直接生成 DatabaseMessages - from src.chat.message_receive.message_processor import process_message_from_dict - message = await process_message_from_dict( - message_dict=message_data, - stream_id=chat.stream_id, - platform=chat.platform - ) - - # 填充聊天流时间信息 - message.chat_info.create_time = chat.create_time - message.chat_info.last_active_time = chat.last_active_time - - # 注册消息到聊天管理器 - get_chat_manager().register_message(message) - - # 检测是否提及机器人 - message.is_mentioned, _ = is_mentioned_bot_in_message(message) - - # 在这里打印[所见]日志,确保在所有处理和过滤之前记录 - chat_name = chat.group_info.group_name if chat.group_info else "私聊" - user_nickname = message.user_info.user_nickname if message.user_info else "未知用户" - logger.info( - f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m" - ) - - # 在此添加硬编码过滤,防止回复图片处理失败的消息 - failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"] - processed_text = message.processed_plain_text or "" - if any(keyword in processed_text for keyword in failure_keywords): - logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。") + # 优先处理adapter_response消息(在echo检查之前!) + message_segment = envelope.get("message_segment") + if message_segment and isinstance(message_segment, dict): + if message_segment.get("type") == "adapter_response": + logger.info("[DEBUG bot.py message_process] 检测到adapter_response,立即处理") + await self._handle_adapter_response_from_dict(message_segment.get("data")) return - # 过滤检查 - # DatabaseMessages 使用 display_message 作为原始消息表示 - raw_text = message.display_message or message.processed_plain_text or "" - if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore - raw_text, - chat, - user_info, # type: ignore - ): + # 先提取基础信息检查是否是自身消息上报 + from mofox_bus import BaseMessageInfo + temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {})) + if temp_message_info.additional_config: + sent_message = temp_message_info.additional_config.get("echo", False) + if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题 + # 直接使用消息字典更新,不再需要创建 MessageRecv + await MessageStorage.update_message(message_data) return - # 命令处理 - 首先尝试PlusCommand独立处理 - is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat) + message_segment = envelope.get("message_segment") + group_info = temp_message_info.group_info + user_info = temp_message_info.user_info - # 如果是PlusCommand且不需要继续处理,则直接返回 - if is_plus_command and not plus_continue_process: + # 获取或创建聊天流 + chat = await get_chat_manager().get_or_create_stream( + platform=temp_message_info.platform, # type: ignore + user_info=user_info, # type: ignore + group_info=group_info, + ) + + # 使用新的消息处理器直接生成 DatabaseMessages + from src.chat.message_receive.message_processor import process_message_from_dict + message = await process_message_from_dict( + message_dict=envelope, + stream_id=chat.stream_id, + platform=chat.platform + ) + + # 填充聊天流时间信息 + message.chat_info.create_time = chat.create_time + message.chat_info.last_active_time = chat.last_active_time + + # 注册消息到聊天管理器 + get_chat_manager().register_message(message) + + # 检测是否提及机器人 + message.is_mentioned, _ = is_mentioned_bot_in_message(message) + + # 在这里打印[所见]日志,确保在所有处理和过滤之前记录 + chat_name = chat.group_info.group_name if chat.group_info else "私聊" + user_nickname = message.user_info.user_nickname if message.user_info else "未知用户" + logger.info( + f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m" + ) + + # 在此添加硬编码过滤,防止回复图片处理失败的消息 + failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"] + processed_text = message.processed_plain_text or "" + if any(keyword in processed_text for keyword in failure_keywords): + logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。") + return + + # 过滤检查 + # DatabaseMessages 使用 display_message 作为原始消息表示 + raw_text = message.display_message or message.processed_plain_text or "" + if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore + raw_text, + chat, + user_info, # type: ignore + ): + return + + # 命令处理 - 首先尝试PlusCommand独立处理 + is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat) + + # 如果是PlusCommand且不需要继续处理,则直接返回 + if is_plus_command and not plus_continue_process: + await MessageStorage.store_message(message, chat) + logger.info(f"PlusCommand处理完成,跳过后续消息处理: {plus_cmd_result}") + return + + # 如果不是PlusCommand,尝试传统的BaseCommand处理 + if not is_plus_command: + is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat) + + # 如果是命令且不需要继续处理,则直接返回 + if is_command and not continue_process: await MessageStorage.store_message(message, chat) - logger.info(f"PlusCommand处理完成,跳过后续消息处理: {plus_cmd_result}") + logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}") return - # 如果不是PlusCommand,尝试传统的BaseCommand处理 - if not is_plus_command: - is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat) + result = await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message=message) + if result and not result.all_continue_process(): + raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理") - # 如果是命令且不需要继续处理,则直接返回 - if is_command and not continue_process: - await MessageStorage.store_message(message, chat) - logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}") - return + # TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info + # 确认从接口发来的message是否有自定义的prompt模板信息 + # 这个功能需要在 adapter 层通过 additional_config 传递 + template_group_name = None - result = await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message=message) - if result and not result.all_continue_process(): - raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理") + async def preprocess(): + # message 已经是 DatabaseMessages,直接使用 + group_info = chat.group_info - # TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info - # 确认从接口发来的message是否有自定义的prompt模板信息 - # 这个功能需要在 adapter 层通过 additional_config 传递 - template_group_name = None + # 先交给消息管理器处理,计算兴趣度等衍生数据 + try: + # 在将消息添加到管理器之前进行最终的静默检查 + should_process_in_manager = True + if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list: + # 检查消息是否为图片或表情包 + is_image_or_emoji = message.is_picid or message.is_emoji + if not message.is_mentioned and not is_image_or_emoji: + logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理") + should_process_in_manager = False + elif is_image_or_emoji: + logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理") + should_process_in_manager = False - async def preprocess(): - # message 已经是 DatabaseMessages,直接使用 - group_info = chat.group_info + if should_process_in_manager: + await message_manager.add_message(chat.stream_id, message) + logger.debug(f"消息已添加到消息管理器: {chat.stream_id}") - # 先交给消息管理器处理,计算兴趣度等衍生数据 - try: - # 在将消息添加到管理器之前进行最终的静默检查 - should_process_in_manager = True - if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list: - # 检查消息是否为图片或表情包 - is_image_or_emoji = message.is_picid or message.is_emoji - if not message.is_mentioned and not is_image_or_emoji: - logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理") - should_process_in_manager = False - elif is_image_or_emoji: - logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理") - should_process_in_manager = False + except Exception as e: + logger.error(f"消息添加到消息管理器失败: {e}") - if should_process_in_manager: - await message_manager.add_message(chat.stream_id, message) - logger.debug(f"消息已添加到消息管理器: {chat.stream_id}") + # 存储消息到数据库,只进行一次写入 + try: + await MessageStorage.store_message(message, chat) + except Exception as e: + logger.error(f"存储消息到数据库失败: {e}") + traceback.print_exc() - except Exception as e: - logger.error(f"消息添加到消息管理器失败: {e}") + # 情绪系统更新 - 在消息存储后触发情绪更新 + try: + if global_config.mood.enable_mood: + # 获取兴趣度用于情绪更新 + interest_rate = message.interest_value + if interest_rate is None: + interest_rate = 0.0 + logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}") - # 存储消息到数据库,只进行一次写入 - try: - await MessageStorage.store_message(message, chat) - except Exception as e: - logger.error(f"存储消息到数据库失败: {e}") - traceback.print_exc() + # 获取当前聊天的情绪对象并更新情绪状态 + chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id) + await chat_mood.update_mood_by_message(message, interest_rate) + logger.debug("情绪状态更新完成") + except Exception as e: + logger.error(f"更新情绪状态失败: {e}") + traceback.print_exc() - # 情绪系统更新 - 在消息存储后触发情绪更新 - try: - if global_config.mood.enable_mood: - # 获取兴趣度用于情绪更新 - interest_rate = message.interest_value - if interest_rate is None: - interest_rate = 0.0 - logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}") - - # 获取当前聊天的情绪对象并更新情绪状态 - chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id) - await chat_mood.update_mood_by_message(message, interest_rate) - logger.debug("情绪状态更新完成") - except Exception as e: - logger.error(f"更新情绪状态失败: {e}") - traceback.print_exc() - - if template_group_name: - async with global_prompt_manager.async_message_scope(template_group_name): - await preprocess() - else: + if template_group_name: + async with global_prompt_manager.async_message_scope(template_group_name): await preprocess() + else: + await preprocess() - except Exception as e: - logger.error(f"预处理消息失败: {e}") - traceback.print_exc() + except Exception as e: + logger.error(f"预处理消息失败: {e}") + traceback.print_exc() # 创建全局ChatBot实例 diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index fc2dcacc5..305e73de2 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from typing import Optional import urllib3 -from mofox_bus import BaseMessageInfo, MessageBase, Seg, UserInfo from rich.traceback import install from src.chat.message_receive.chat_stream import ChatStream diff --git a/src/chat/message_receive/message_handler.py b/src/chat/message_receive/message_handler.py new file mode 100644 index 000000000..be81ab0eb --- /dev/null +++ b/src/chat/message_receive/message_handler.py @@ -0,0 +1,110 @@ +import os +import traceback + +from mofox_bus.runtime import MessageRuntime +from mofox_bus import MessageEnvelope +from src.chat.message_manager import message_manager +from src.common.logger import get_logger +from src.config.config import global_config +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.common.data_models.database_data_model import DatabaseGroupInfo, DatabaseUserInfo, DatabaseMessages + +runtime = MessageRuntime() + +# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录) +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) + +# 配置主程序日志格式 +logger = get_logger("chat") + +class MessageHandler: + def __init__(self): + self._started = False + + async def preprocess(self, chat: ChatStream, message: DatabaseMessages): + # message 已经是 DatabaseMessages,直接使用 + group_info = chat.group_info + + # 先交给消息管理器处理 + try: + # 在将消息添加到管理器之前进行最终的静默检查 + should_process_in_manager = True + if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list: + # 检查消息是否为图片或表情包 + is_image_or_emoji = message.is_picid or message.is_emoji + if not message.is_mentioned and not is_image_or_emoji: + logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理") + should_process_in_manager = False + elif is_image_or_emoji: + logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理") + should_process_in_manager = False + + if should_process_in_manager: + await message_manager.add_message(chat.stream_id, message) + logger.debug(f"消息已添加到消息管理器: {chat.stream_id}") + + except Exception as e: + logger.error(f"消息添加到消息管理器失败: {e}") + + # 存储消息到数据库,只进行一次写入 + try: + await MessageStorage.store_message(message, chat) + except Exception as e: + logger.error(f"存储消息到数据库失败: {e}") + traceback.print_exc() + + # 情绪系统更新 - 在消息存储后触发情绪更新 + try: + if global_config.mood.enable_mood: + # 获取兴趣度用于情绪更新 + interest_rate = message.interest_value + if interest_rate is None: + interest_rate = 0.0 + logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}") + + # 获取当前聊天的情绪对象并更新情绪状态 + chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id) + await chat_mood.update_mood_by_message(message, interest_rate) + logger.debug("情绪状态更新完成") + except Exception as e: + logger.error(f"更新情绪状态失败: {e}") + traceback.print_exc() + + + async def handle_message(self, envelope: MessageEnvelope): + # 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError + message_info = envelope.get("message_info") + if not isinstance(message_info, dict): + logger.debug( + "收到缺少 message_info 的消息,已跳过。可用字段: %s", + ", ".join(envelope.keys()), + ) + return + + if message_info.get("group_info") is not None: + message_info["group_info"]["group_id"] = str( # type: ignore + message_info["group_info"]["group_id"] # type: ignore + ) + if message_info.get("user_info") is not None: + message_info["user_info"]["user_id"] = str( # type: ignore + message_info["user_info"]["user_id"] # type: ignore + ) + + group_info = message_info.get("group_info") + user_info = message_info.get("user_info") + + chat_stream = await get_chat_manager().get_or_create_stream( + platform=envelope["platform"], # type: ignore + user_info=user_info, # type: ignore + group_info=group_info, + ) + + # 生成 DatabaseMessages + from src.chat.message_receive.message_processor import process_message_from_dict + message = await process_message_from_dict( + message_dict=envelope, + stream_id=chat_stream.stream_id, + platform=chat_stream.platform + ) + + \ No newline at end of file diff --git a/src/chat/message_receive/message_processor.py b/src/chat/message_receive/message_processor.py index 9f848f819..e278c545a 100644 --- a/src/chat/message_receive/message_processor.py +++ b/src/chat/message_receive/message_processor.py @@ -1,13 +1,14 @@ """消息处理工具模块 将原 MessageRecv 的消息处理逻辑提取为独立函数, -直接从适配器消息字典生成 DatabaseMessages +基于 mofox-bus 的 TypedDict 形式构建消息数据,然后转换为 DatabaseMessages """ import base64 import time from typing import Any import orjson -from mofox_bus import BaseMessageInfo, Seg +from mofox_bus import MessageEnvelope +from mofox_bus.types import MessageInfoPayload, SegPayload, UserInfoPayload, GroupInfoPayload from src.chat.utils.self_voice_cache import consume_self_voice_text from src.chat.utils.utils_image import get_image_manager @@ -20,26 +21,26 @@ from src.config.config import global_config logger = get_logger("message_processor") -async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str, platform: str) -> DatabaseMessages: +async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: str, platform: str) -> DatabaseMessages: """从适配器消息字典处理并生成 DatabaseMessages 这个函数整合了原 MessageRecv 的所有处理逻辑: 1. 解析 message_segment 并异步处理内容(图片、语音、视频等) 2. 提取所有消息元数据 - 3. 直接构造 DatabaseMessages 对象 + 3. 基于 TypedDict 形式构建数据,然后转换为 DatabaseMessages Args: - message_dict: MessageCQ序列化后的字典 + message_dict: MessageEnvelope 格式的消息字典 stream_id: 聊天流ID platform: 平台标识 Returns: DatabaseMessages: 处理完成的数据库消息对象 """ - # 解析基础信息 - message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) - message_segment = Seg.from_dict(message_dict.get("message_segment", {})) - + # 提取核心数据(使用 TypedDict 类型) + message_info: MessageInfoPayload = message_dict.get("message_info", {}) # type: ignore + message_segment: SegPayload | list[SegPayload] = message_dict.get("message_segment", {"type": "text", "data": ""}) # type: ignore + # 初始化处理状态 processing_state = { "is_emoji": False, @@ -61,26 +62,26 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str is_notify = False is_public_notice = False notice_type = None - if message_info.additional_config and isinstance(message_info.additional_config, dict): - is_notify = message_info.additional_config.get("is_notice", False) - is_public_notice = message_info.additional_config.get("is_public_notice", False) - notice_type = message_info.additional_config.get("notice_type") + additional_config_dict = message_info.get("additional_config", {}) + if isinstance(additional_config_dict, dict): + is_notify = additional_config_dict.get("is_notice", False) + is_public_notice = additional_config_dict.get("is_public_notice", False) + notice_type = additional_config_dict.get("notice_type") # 提取用户信息 - user_info = message_info.user_info - user_id = str(user_info.user_id) if user_info and user_info.user_id else "" - user_nickname = (user_info.user_nickname or "") if user_info else "" - user_cardname = user_info.user_cardname if user_info else None - user_platform = (user_info.platform or "") if user_info else "" + user_info_payload: UserInfoPayload = message_info.get("user_info", {}) # type: ignore + user_id = str(user_info_payload.get("user_id", "")) + user_nickname = user_info_payload.get("user_nickname", "") + user_cardname = user_info_payload.get("user_cardname") + user_platform = user_info_payload.get("platform", "") # 提取群组信息 - group_info = message_info.group_info - group_id = group_info.group_id if group_info else None - group_name = group_info.group_name if group_info else None - group_platform = group_info.platform if group_info else None + group_info_payload: GroupInfoPayload | None = message_info.get("group_info") # type: ignore + group_id = group_info_payload.get("group_id") if group_info_payload else None + group_name = group_info_payload.get("group_name") if group_info_payload else None + group_platform = group_info_payload.get("platform") if group_info_payload else None # chat_id 应该直接使用 stream_id(与数据库存储格式一致) - # stream_id 是通过 platform + user_id/group_id 的 SHA-256 哈希生成的 chat_id = stream_id # 准备 additional_config @@ -89,18 +90,19 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str # 提取 reply_to reply_to = _extract_reply_from_segment(message_segment) - # 构造 DatabaseMessages - message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time() - message_id = message_info.message_id or "" + # 构造消息数据字典(基于 TypedDict 风格) + message_time = message_info.get("time", time.time()) + message_id = message_info.get("message_id", "") # 处理 is_mentioned is_mentioned = None mentioned_value = processing_state.get("is_mentioned") if isinstance(mentioned_value, bool): is_mentioned = mentioned_value - elif isinstance(mentioned_value, int | float): + elif isinstance(mentioned_value, (int, float)): is_mentioned = mentioned_value != 0 + # 使用 TypedDict 风格的数据构建 DatabaseMessages db_message = DatabaseMessages( message_id=message_id, time=float(message_time), @@ -134,7 +136,7 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str chat_info_group_platform=group_platform, ) - # 设置优先级信息 + # 设置优先级信息(运行时属性) if processing_state.get("priority_mode"): setattr(db_message, "priority_mode", processing_state["priority_mode"]) if processing_state.get("priority_info"): @@ -149,99 +151,127 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str return db_message -async def _process_message_segments(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str: +async def _process_message_segments( + segment: SegPayload | list[SegPayload], + state: dict, + message_info: MessageInfoPayload +) -> str: """递归处理消息段,转换为文字描述 Args: - segment: 要处理的消息段 + segment: 要处理的消息段(TypedDict 或列表) state: 处理状态字典(用于记录消息类型标记) - message_info: 消息基础信息(用于某些处理逻辑) + message_info: 消息基础信息(TypedDict 格式) Returns: str: 处理后的文本 """ - if segment.type == "seglist": - # 处理消息段列表 + # 如果是列表,遍历处理 + if isinstance(segment, list): segments_text = [] - for seg in segment.data: + for seg in segment: processed = await _process_message_segments(seg, state, message_info) if processed: segments_text.append(processed) return " ".join(segments_text) - else: - # 处理单个消息段 + + # 如果是单个段 + if isinstance(segment, dict): + seg_type = segment.get("type", "") + seg_data = segment.get("data") + + # 处理 seglist 类型 + if seg_type == "seglist" and isinstance(seg_data, list): + segments_text = [] + for sub_seg in seg_data: + processed = await _process_message_segments(sub_seg, state, message_info) + if processed: + segments_text.append(processed) + return " ".join(segments_text) + + # 处理其他类型 return await _process_single_segment(segment, state, message_info) + + return "" -async def _process_single_segment(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str: +async def _process_single_segment( + segment: SegPayload, + state: dict, + message_info: MessageInfoPayload +) -> str: """处理单个消息段 Args: - segment: 消息段 + segment: 消息段(TypedDict 格式) state: 处理状态字典 - message_info: 消息基础信息 + message_info: 消息基础信息(TypedDict 格式) Returns: str: 处理后的文本 """ + seg_type = segment.get("type", "") + seg_data = segment.get("data") + try: - if segment.type == "text": + if seg_type == "text": state["is_picid"] = False state["is_emoji"] = False state["is_video"] = False - return segment.data + return str(seg_data) if seg_data else "" - elif segment.type == "at": + elif seg_type == "at": state["is_picid"] = False state["is_emoji"] = False state["is_video"] = False state["is_at"] = True # 处理at消息,格式为"@<昵称:QQ号>" - if isinstance(segment.data, str): - if ":" in segment.data: + if isinstance(seg_data, str): + if ":" in seg_data: # 标准格式: "昵称:QQ号" - nickname, qq_id = segment.data.split(":", 1) - result = f"@<{nickname}:{qq_id}>" - return result + nickname, qq_id = seg_data.split(":", 1) + return f"@<{nickname}:{qq_id}>" else: - logger.warning(f"[at处理] 无法解析格式: '{segment.data}'") - return f"@{segment.data}" - logger.warning(f"[at处理] 数据类型异常: {type(segment.data)}") - return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户" + logger.warning(f"[at处理] 无法解析格式: '{seg_data}'") + return f"@{seg_data}" + logger.warning(f"[at处理] 数据类型异常: {type(seg_data)}") + return f"@{seg_data}" if isinstance(seg_data, str) else "@未知用户" - elif segment.type == "image": + elif seg_type == "image": # 如果是base64图片数据 - if isinstance(segment.data, str): + if isinstance(seg_data, str): state["has_picid"] = True state["is_picid"] = True state["is_emoji"] = False state["is_video"] = False image_manager = get_image_manager() - _, processed_text = await image_manager.process_image(segment.data) + _, processed_text = await image_manager.process_image(seg_data) return processed_text return "[发了一张图片,网卡了加载不出来]" - elif segment.type == "emoji": + elif seg_type == "emoji": state["has_emoji"] = True state["is_emoji"] = True state["is_picid"] = False state["is_voice"] = False state["is_video"] = False - if isinstance(segment.data, str): - return await get_image_manager().get_emoji_description(segment.data) + if isinstance(seg_data, str): + return await get_image_manager().get_emoji_description(seg_data) return "[发了一个表情包,网卡了加载不出来]" - elif segment.type == "voice": + elif seg_type == "voice": state["is_picid"] = False state["is_emoji"] = False state["is_voice"] = True state["is_video"] = False # 检查消息是否由机器人自己发送 - if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account): - logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。") - if isinstance(segment.data, str): - cached_text = consume_self_voice_text(segment.data) + user_info = message_info.get("user_info", {}) + user_id_str = str(user_info.get("user_id", "")) + if user_id_str == str(global_config.bot.qq_account): + logger.info(f"检测到机器人自身发送的语音消息 (User ID: {user_id_str}),尝试从缓存获取文本。") + if isinstance(seg_data, str): + cached_text = consume_self_voice_text(seg_data) if cached_text: logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'") return f"[语音:{cached_text}]" @@ -249,41 +279,42 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。") # 标准语音识别流程 - if isinstance(segment.data, str): - return await get_voice_text(segment.data) + if isinstance(seg_data, str): + return await get_voice_text(seg_data) return "[发了一段语音,网卡了加载不出来]" - elif segment.type == "mention_bot": + elif seg_type == "mention_bot": state["is_picid"] = False state["is_emoji"] = False state["is_voice"] = False state["is_video"] = False - state["is_mentioned"] = float(segment.data) + if isinstance(seg_data, (int, float)): + state["is_mentioned"] = float(seg_data) return "" - elif segment.type == "priority_info": + elif seg_type == "priority_info": state["is_picid"] = False state["is_emoji"] = False state["is_voice"] = False - if isinstance(segment.data, dict): + if isinstance(seg_data, dict): # 处理优先级信息 state["priority_mode"] = "priority" - state["priority_info"] = segment.data + state["priority_info"] = seg_data return "" - elif segment.type == "file": - if isinstance(segment.data, dict): - file_name = segment.data.get("name", "未知文件") - file_size = segment.data.get("size", "未知大小") + elif seg_type == "file": + if isinstance(seg_data, dict): + file_name = seg_data.get("name", "未知文件") + file_size = seg_data.get("size", "未知大小") return f"[文件:{file_name} ({file_size}字节)]" return "[收到一个文件]" - elif segment.type == "video": + elif seg_type == "video": state["is_picid"] = False state["is_emoji"] = False state["is_voice"] = False state["is_video"] = True - logger.info(f"接收到视频消息,数据类型: {type(segment.data)}") + logger.info(f"接收到视频消息,数据类型: {type(seg_data)}") # 检查视频分析功能是否可用 if not is_video_analysis_available(): @@ -292,11 +323,11 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM if global_config.video_analysis.enable: logger.info("已启用视频识别,开始识别") - if isinstance(segment.data, dict): + if isinstance(seg_data, dict): try: # 从Adapter接收的视频数据 - video_base64 = segment.data.get("base64") - filename = segment.data.get("filename", "video.mp4") + video_base64 = seg_data.get("base64") + filename = seg_data.get("filename", "video.mp4") logger.info(f"视频文件名: {filename}") logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}") @@ -329,24 +360,29 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM logger.error(f"错误详情: {traceback.format_exc()}") return "[收到视频,但处理时出现错误]" else: - logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}") + logger.warning(f"视频消息数据不是字典格式: {type(seg_data)}") return "[发了一个视频,但格式不支持]" else: return "" else: - logger.warning(f"未知的消息段类型: {segment.type}") - return f"[{segment.type} 消息]" + logger.warning(f"未知的消息段类型: {seg_type}") + return f"[{seg_type} 消息]" except Exception as e: - logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}") - return f"[处理失败的{segment.type}消息]" + logger.error(f"处理消息段失败: {e!s}, 类型: {seg_type}, 数据: {seg_data}") + return f"[处理失败的{seg_type}消息]" -def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, is_public_notice: bool, notice_type: str | None) -> str | None: +def _prepare_additional_config( + message_info: MessageInfoPayload, + is_notify: bool, + is_public_notice: bool, + notice_type: str | None +) -> str | None: """准备 additional_config,包含 format_info 和 notice 信息 Args: - message_info: 消息基础信息 + message_info: 消息基础信息(TypedDict 格式) is_notify: 是否为notice消息 is_public_notice: 是否为公共notice notice_type: notice类型 @@ -358,12 +394,13 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i additional_config_data = {} # 首先获取adapter传递的additional_config - if hasattr(message_info, "additional_config") and message_info.additional_config: - if isinstance(message_info.additional_config, dict): - additional_config_data = message_info.additional_config.copy() - elif isinstance(message_info.additional_config, str): + additional_config_raw = message_info.get("additional_config") + if additional_config_raw: + if isinstance(additional_config_raw, dict): + additional_config_data = additional_config_raw.copy() + elif isinstance(additional_config_raw, str): try: - additional_config_data = orjson.loads(message_info.additional_config) + additional_config_data = orjson.loads(additional_config_raw) except Exception as e: logger.warning(f"无法解析 additional_config JSON: {e}") additional_config_data = {} @@ -375,11 +412,11 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i additional_config_data["is_public_notice"] = bool(is_public_notice) # 添加format_info到additional_config中 - if hasattr(message_info, "format_info") and message_info.format_info: + format_info = message_info.get("format_info") + if format_info: try: - format_info_dict = message_info.format_info.to_dict() - additional_config_data["format_info"] = format_info_dict - logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}") + additional_config_data["format_info"] = format_info + logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info}") except Exception as e: logger.warning(f"将 format_info 转换为字典失败: {e}") @@ -392,28 +429,43 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i return None -def _extract_reply_from_segment(segment: Seg) -> str | None: +def _extract_reply_from_segment(segment: SegPayload | list[SegPayload]) -> str | None: """从消息段中提取reply_to信息 Args: - segment: 消息段 + segment: 消息段(TypedDict 格式或列表) Returns: str | None: 回复的消息ID,如果没有则返回None """ try: - if hasattr(segment, "type") and segment.type == "seglist": - # 递归搜索seglist中的reply段 - if hasattr(segment, "data") and segment.data: - for seg in segment.data: - reply_id = _extract_reply_from_segment(seg) + # 如果是列表,遍历查找 + if isinstance(segment, list): + for seg in segment: + reply_id = _extract_reply_from_segment(seg) + if reply_id: + return reply_id + return None + + # 如果是字典 + if isinstance(segment, dict): + seg_type = segment.get("type", "") + seg_data = segment.get("data") + + # 如果是 seglist,递归搜索 + if seg_type == "seglist" and isinstance(seg_data, list): + for sub_seg in seg_data: + reply_id = _extract_reply_from_segment(sub_seg) if reply_id: return reply_id - elif hasattr(segment, "type") and segment.type == "reply": - # 找到reply段,返回message_id - return str(segment.data) if segment.data else None + + # 如果是 reply 段,返回 message_id + elif seg_type == "reply": + return str(seg_data) if seg_data else None + except Exception as e: logger.warning(f"提取reply_to信息失败: {e}") + return None @@ -421,33 +473,31 @@ def _extract_reply_from_segment(segment: Seg) -> str | None: # DatabaseMessages 扩展工具函数 # ============================================================================= -def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessageInfo: - """从 DatabaseMessages 重建 BaseMessageInfo(用于需要 message_info 的遗留代码) +def get_message_info_from_db_message(db_message: DatabaseMessages) -> MessageInfoPayload: + """从 DatabaseMessages 重建 MessageInfoPayload(TypedDict 格式) Args: db_message: DatabaseMessages 对象 Returns: - BaseMessageInfo: 重建的消息信息对象 + MessageInfoPayload: 重建的消息信息对象(TypedDict 格式) """ - from mofox_bus import GroupInfo, UserInfo + # 构建用户信息 + user_info: UserInfoPayload = { + "platform": db_message.user_info.platform, + "user_id": db_message.user_info.user_id, + "user_nickname": db_message.user_info.user_nickname, + "user_cardname": db_message.user_info.user_cardname or "", + } - # 从 DatabaseMessages 的 user_info 转换为 mofox_bus.UserInfo - user_info = UserInfo( - platform=db_message.user_info.platform, - user_id=db_message.user_info.user_id, - user_nickname=db_message.user_info.user_nickname, - user_cardname=db_message.user_info.user_cardname or "" - ) - - # 从 DatabaseMessages 的 group_info 转换为 mofox_bus.GroupInfo(如果存在) - group_info = None + # 构建群组信息(如果存在) + group_info: GroupInfoPayload | None = None if db_message.group_info: - group_info = GroupInfo( - platform=db_message.group_info.group_platform or "", - group_id=db_message.group_info.group_id, - group_name=db_message.group_info.group_name - ) + group_info = { + "platform": db_message.group_info.group_platform or "", + "group_id": db_message.group_info.group_id, + "group_name": db_message.group_info.group_name, + } # 解析 additional_config(从 JSON 字符串到字典) additional_config = None @@ -458,15 +508,19 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag # 如果解析失败,保持为字符串 pass - # 创建 BaseMessageInfo - message_info = BaseMessageInfo( - platform=db_message.chat_info.platform, - message_id=db_message.message_id, - time=db_message.time, - user_info=user_info, - group_info=group_info, - additional_config=additional_config # type: ignore - ) + # 创建 MessageInfoPayload + message_info: MessageInfoPayload = { + "platform": db_message.chat_info.platform, + "message_id": db_message.message_id, + "time": db_message.time, + "user_info": user_info, + } + + if group_info: + message_info["group_info"] = group_info + + if additional_config: + message_info["additional_config"] = additional_config return message_info diff --git a/src/common/message/envelope_converter.py b/src/common/message/envelope_converter.py deleted file mode 100644 index 8434565f7..000000000 --- a/src/common/message/envelope_converter.py +++ /dev/null @@ -1,341 +0,0 @@ -""" -MessageEnvelope converter between mofox_bus schema and internal message structures. - -- 优先处理 maim_message 风格的 message_info + message_segment。 -- 兼容旧版 content/sender/channel 结构,方便逐步迁移。 -""" - -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -from mofox_bus import ( - BaseMessageInfo, - MessageBase, - MessageEnvelope, - Seg, - UserInfo, - GroupInfo, -) - -from src.common.logger import get_logger - -logger = get_logger("envelope_converter") - - -class EnvelopeConverter: - """MessageEnvelope <-> MessageBase converter.""" - - @staticmethod - def to_message_base(envelope: MessageEnvelope) -> MessageBase: - """ - Convert MessageEnvelope to MessageBase. - """ - try: - # 优先使用 maim_message 样式字段 - info_payload = envelope.get("message_info") or {} - seg_payload = envelope.get("message_segment") or envelope.get("message_chain") - - if info_payload: - message_info = BaseMessageInfo.from_dict(info_payload) - else: - message_info = EnvelopeConverter._build_info_from_legacy(envelope) - - if seg_payload is None: - seg_list = EnvelopeConverter._content_to_segments(envelope.get("content")) - seg_payload = seg_list - - message_segment = EnvelopeConverter._ensure_seg(seg_payload) - raw_message = envelope.get("raw_message") or envelope.get("raw_platform_message") - - return MessageBase( - message_info=message_info, - message_segment=message_segment, - raw_message=raw_message, - ) - except Exception as e: - logger.error(f"转换 MessageEnvelope 失败: {e}", exc_info=True) - raise - - @staticmethod - def _build_info_from_legacy(envelope: MessageEnvelope) -> BaseMessageInfo: - """将 legacy 字段映射为 BaseMessageInfo。""" - platform = envelope.get("platform") - channel = envelope.get("channel") or {} - sender = envelope.get("sender") or {} - - message_id = envelope.get("id") or envelope.get("message_id") - timestamp_ms = envelope.get("timestamp_ms") - time_value = (timestamp_ms / 1000.0) if timestamp_ms is not None else None - - group_info: Optional[GroupInfo] = None - channel_type = channel.get("channel_type") - if channel_type in ("group", "supergroup", "room"): - group_info = GroupInfo( - platform=platform, - group_id=channel.get("channel_id"), - group_name=channel.get("title"), - ) - - user_info: Optional[UserInfo] = None - if sender: - user_info = UserInfo( - platform=platform, - user_id=str(sender.get("user_id")) if sender.get("user_id") is not None else None, - user_nickname=sender.get("display_name") or sender.get("user_nickname"), - user_avatar=sender.get("avatar_url"), - ) - - return BaseMessageInfo( - platform=platform, - message_id=message_id, - time=time_value, - group_info=group_info, - user_info=user_info, - additional_config=envelope.get("metadata"), - ) - - @staticmethod - def _ensure_seg(payload: Any) -> Seg: - """将任意 payload 转为 Seg dataclass。""" - if isinstance(payload, Seg): - return payload - if isinstance(payload, list): - # 直接传入 Seg 列表或 seglist data - return Seg(type="seglist", data=[EnvelopeConverter._ensure_seg(item) for item in payload]) - if isinstance(payload, dict): - seg_type = payload.get("type") or "text" - data = payload.get("data") - if seg_type == "seglist" and isinstance(data, list): - data = [EnvelopeConverter._ensure_seg(item) for item in data] - return Seg(type=seg_type, data=data) - # 兜底:转成文本片段 - return Seg(type="text", data="" if payload is None else str(payload)) - - @staticmethod - def _flatten_segments(seg: Seg) -> List[Seg]: - """将 Seg/seglist 打平成列表,便于旧 content 转换。""" - if seg.type == "seglist" and isinstance(seg.data, list): - return [item if isinstance(item, Seg) else EnvelopeConverter._ensure_seg(item) for item in seg.data] - return [seg] - - @staticmethod - def _content_to_segments(content: Any) -> List[Seg]: - """ - Convert legacy Content (type/data/metadata) to a flat list of Seg. - """ - segments: List[Seg] = [] - - def _walk(node: Any) -> None: - if node is None: - return - if isinstance(node, list): - for item in node: - _walk(item) - return - if not isinstance(node, dict): - logger.warning("未知的 content 节点类型: %s", type(node)) - return - - content_type = node.get("type") - data = node.get("data") - metadata = node.get("metadata") or {} - - if content_type == "collection": - items = data if isinstance(data, list) else node.get("items", []) - for item in items: - _walk(item) - return - - if content_type in ("text", "at"): - subtype = metadata.get("subtype") or ("at" if content_type == "at" else None) - text = "" if data is None else str(data) - if subtype in ("at", "mention"): - user_info = metadata.get("user") or {} - seg_data: Dict[str, Any] = { - "user_id": user_info.get("id") or user_info.get("user_id"), - "user_name": user_info.get("name") or user_info.get("display_name"), - "text": text, - "raw": user_info.get("raw") or user_info if user_info else None, - } - if any(v is not None for v in seg_data.values()): - segments.append(Seg(type="at", data=seg_data)) - else: - segments.append(Seg(type="at", data=text)) - else: - segments.append(Seg(type="text", data=text)) - return - - if content_type == "image": - url = "" - if isinstance(data, dict): - url = data.get("url") or data.get("file") or data.get("file_id") or "" - elif data is not None: - url = str(data) - segments.append(Seg(type="image", data=url)) - return - - if content_type == "audio": - url = "" - if isinstance(data, dict): - url = data.get("url") or data.get("file") or data.get("file_id") or "" - elif data is not None: - url = str(data) - segments.append(Seg(type="record", data=url)) - return - - if content_type == "video": - url = "" - if isinstance(data, dict): - url = data.get("url") or data.get("file") or data.get("file_id") or "" - elif data is not None: - url = str(data) - segments.append(Seg(type="video", data=url)) - return - - if content_type == "file": - file_name = "" - if isinstance(data, dict): - file_name = data.get("file_name") or data.get("name") or "" - text = file_name or "[file]" - segments.append(Seg(type="text", data=text)) - return - - if content_type == "command": - name = "" - args: Dict[str, Any] = {} - if isinstance(data, dict): - name = data.get("name", "") - args = data.get("args", {}) or {} - else: - name = str(data or "") - cmd_text = f"/{name}" if name else "/command" - if args: - cmd_text += " " + " ".join(f"{k}={v}" for k, v in args.items()) - segments.append(Seg(type="text", data=cmd_text)) - return - - if content_type == "event": - event_type = "" - if isinstance(data, dict): - event_type = data.get("event_type", "") - else: - event_type = str(data or "") - segments.append(Seg(type="text", data=f"[事件: {event_type or 'unknown'}]")) - return - - if content_type == "system": - text = "" if data is None else str(data) - segments.append(Seg(type="text", data=f"[系统] {text}")) - return - - logger.warning(f"未知的消息类型: {content_type}") - segments.append(Seg(type="text", data=f"[未知消息类型: {content_type}]")) - - _walk(content) - return segments - - @staticmethod - def to_legacy_dict(envelope: MessageEnvelope) -> Dict[str, Any]: - """ - Convert MessageEnvelope to legacy dict for backward compatibility. - """ - message_base = EnvelopeConverter.to_message_base(envelope) - return message_base.to_dict() - - @staticmethod - def from_message_base(message: MessageBase, direction: str = "outgoing") -> MessageEnvelope: - """ - Convert MessageBase to MessageEnvelope (maim_message style preferred). - """ - try: - info_dict = message.message_info.to_dict() - seg_dict = message.message_segment.to_dict() - - envelope: MessageEnvelope = { - "direction": direction, - "message_info": info_dict, - "message_segment": seg_dict, - "platform": info_dict.get("platform"), - "message_id": info_dict.get("message_id"), - "schema_version": 1, - } - - if message.message_info.time is not None: - envelope["timestamp_ms"] = int(message.message_info.time * 1000) - if message.raw_message is not None: - envelope["raw_message"] = message.raw_message - - # legacy 补充,方便老代码继续工作 - segments = EnvelopeConverter._flatten_segments(message.message_segment) - envelope["content"] = EnvelopeConverter._segments_to_content(segments) - if message.message_info.user_info: - envelope["sender"] = { - "user_id": message.message_info.user_info.user_id, - "role": "assistant" if direction == "outgoing" else "user", - "display_name": message.message_info.user_info.user_nickname, - "avatar_url": getattr(message.message_info.user_info, "user_avatar", None), - } - if message.message_info.group_info: - envelope["channel"] = { - "channel_id": message.message_info.group_info.group_id, - "channel_type": "group", - "title": message.message_info.group_info.group_name, - } - - return envelope - - except Exception as e: - logger.error(f"转换 MessageBase 失败: {e}", exc_info=True) - raise - - @staticmethod - def _segments_to_content(segments: List[Seg]) -> Dict[str, Any]: - """ - Convert Seg list to legacy Content (type/data/metadata). - """ - if not segments: - return {"type": "text", "data": ""} - - def _seg_to_content(seg: Seg) -> Dict[str, Any]: - data = seg.data - - if seg.type == "text": - return {"type": "text", "data": data} - - if seg.type == "at": - content: Dict[str, Any] = {"type": "text", "data": ""} - metadata: Dict[str, Any] = {"subtype": "at"} - if isinstance(data, dict): - content["data"] = data.get("text", "") - user = { - "id": data.get("user_id"), - "name": data.get("user_name"), - "raw": data.get("raw"), - } - if any(v is not None for v in user.values()): - metadata["user"] = user - else: - content["data"] = data - if metadata: - content["metadata"] = metadata - return content - - if seg.type == "image": - return {"type": "image", "data": data} - - if seg.type in ("record", "voice", "audio"): - return {"type": "audio", "data": data} - - if seg.type == "video": - return {"type": "video", "data": data} - - return {"type": seg.type, "data": data} - - if len(segments) == 1: - return _seg_to_content(segments[0]) - - return {"type": "collection", "data": [_seg_to_content(seg) for seg in segments]} - - -__all__ = ["EnvelopeConverter"] diff --git a/src/main.py b/src/main.py index 69ea51630..5070f6696 100644 --- a/src/main.py +++ b/src/main.py @@ -18,7 +18,6 @@ from src.chat.message_receive.bot import chat_bot from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask from src.common.logger import get_logger -from src.common.message.envelope_converter import EnvelopeConverter # 全局背景任务集合 _background_tasks = set() @@ -77,7 +76,7 @@ class MainSystem: self.individuality: Individuality = get_individuality() # 创建核心消息接收器 - self.core_sink: InProcessCoreSink = InProcessCoreSink(self._handle_message_envelope) + self.core_sink: InProcessCoreSink = InProcessCoreSink(self._message_process_wrapper) # 使用服务器 self.server: Server = get_global_server() @@ -353,19 +352,18 @@ class MainSystem: except Exception as e: logger.error(f"同步清理资源时出错: {e}") - async def _message_process_wrapper(self, message_data: dict[str, Any]) -> None: + async def _message_process_wrapper(self, envelope: MessageEnvelope) -> None: """并行处理消息的包装器""" try: start_time = time.time() - message_id = message_data.get("message_info", {}).get("message_id", "UNKNOWN") - + message_id = envelope.get("message_info", {}).get("message_id", "UNKNOWN") # 检查系统是否正在关闭 if self._shutting_down: logger.warning(f"系统正在关闭,拒绝处理消息 {message_id}") return # 创建后台任务 - task = asyncio.create_task(chat_bot.message_process(message_data)) + task = asyncio.create_task(chat_bot.message_process(envelope)) logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})") # 添加一个回调函数,当任务完成时,它会被调用 @@ -374,22 +372,6 @@ class MainSystem: logger.error("在创建消息处理任务时发生严重错误:") logger.error(traceback.format_exc()) - async def _handle_message_envelope(self, envelope: MessageEnvelope) -> None: - """ - 处理来自适配器的 MessageEnvelope - - Args: - envelope: 统一的消息信封 - """ - try: - # 转换为旧版格式 - message_data = EnvelopeConverter.to_legacy_dict(envelope) - - # 使用现有的消息处理流程 - await self._message_process_wrapper(message_data) - - except Exception as e: - logger.error(f"处理 MessageEnvelope 时出错: {e}", exc_info=True) async def initialize(self) -> None: """初始化系统组件""" diff --git a/src/mofox_bus/runtime.py b/src/mofox_bus/runtime.py index b4db14053..225f6607a 100644 --- a/src/mofox_bus/runtime.py +++ b/src/mofox_bus/runtime.py @@ -1,8 +1,10 @@ from __future__ import annotations import asyncio +import functools import inspect import threading +import weakref from dataclasses import dataclass from typing import Awaitable, Callable, Dict, Iterable, List, Protocol @@ -37,6 +39,7 @@ class MessageRoute: handler: MessageHandler name: str | None = None message_type: str | None = None + message_types: set[str] | None = None # 支持多个消息类型 event_types: set[str] | None = None @@ -55,6 +58,8 @@ class MessageRuntime: self._middlewares: list[Middleware] = [] self._type_routes: Dict[str, list[MessageRoute]] = {} self._event_routes: Dict[str, list[MessageRoute]] = {} + # 用于检测同一类型的重复注册 + self._explicit_type_handlers: Dict[str, str] = {} # message_type -> handler_name def add_route( self, @@ -62,7 +67,7 @@ class MessageRuntime: handler: MessageHandler, name: str | None = None, *, - message_type: str | None = None, + message_type: str | list[str] | None = None, event_types: Iterable[str] | None = None, ) -> None: """ @@ -72,28 +77,72 @@ class MessageRuntime: predicate: 路由匹配条件 handler: 消息处理函数 name: 路由名称(可选) - message_type: 消息类型(可选) + message_type: 消息类型,可以是字符串或字符串列表(可选) event_types: 事件类型列表(可选) """ with self._lock: + # 处理 message_type 参数,支持字符串或列表 + message_types_set: set[str] | None = None + single_message_type: str | None = None + + if message_type is not None: + if isinstance(message_type, str): + message_types_set = {message_type} + single_message_type = message_type + elif isinstance(message_type, list): + message_types_set = set(message_type) + if len(message_types_set) == 1: + single_message_type = next(iter(message_types_set)) + else: + raise TypeError(f"message_type must be str or list[str], got {type(message_type)}") + + # 检测重复注册:如果明确指定了某个类型,不允许重复 + handler_name = name or getattr(handler, "__name__", str(handler)) + for msg_type in message_types_set: + if msg_type in self._explicit_type_handlers: + existing_handler = self._explicit_type_handlers[msg_type] + raise ValueError( + f"消息类型 '{msg_type}' 已被处理器 '{existing_handler}' 明确注册," + f"不能再由 '{handler_name}' 注册。同一消息类型只能有一个明确的处理器。" + ) + self._explicit_type_handlers[msg_type] = handler_name + route = MessageRoute( predicate=predicate, handler=handler, name=name, - message_type=message_type, + message_type=single_message_type, + message_types=message_types_set, event_types=set(event_types) if event_types is not None else None, ) self._routes.append(route) - if message_type: - self._type_routes.setdefault(message_type, []).append(route) + + # 为每个消息类型建立索引 + if message_types_set: + for msg_type in message_types_set: + self._type_routes.setdefault(msg_type, []).append(route) + if route.event_types: for et in route.event_types: self._event_routes.setdefault(et, []).append(route) def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]: - """装饰器写法,便于在核心逻辑中声明式注册。""" + """装饰器写法,便于在核心逻辑中声明式注册。 + + 支持普通函数和类方法。对于类方法,会在实例创建时自动绑定并注册路由。 + """ def decorator(func: MessageHandler) -> MessageHandler: + # Support decorating instance methods: defer binding until the object is created. + if _looks_like_method(func): + return _InstanceMethodRoute( + runtime=self, + func=func, + predicate=predicate, + name=name, + message_type=None, + ) + self.add_route(predicate, func, name=name) return func @@ -101,17 +150,46 @@ class MessageRuntime: def on_message( self, + func: MessageHandler | None = None, *, - message_type: str | None = None, + message_type: str | list[str] | None = None, platform: str | None = None, predicate: Predicate | None = None, name: str | None = None, - ) -> Callable[[MessageHandler], MessageHandler]: - """Sugar 装饰器,基于 Seg.type/platform 及可选额外谓词匹配。""" + ) -> Callable[[MessageHandler], MessageHandler] | MessageHandler: + """Sugar decorator with optional Seg.type/platform predicate matching. + + Args: + func: 被装饰的函数 + message_type: 消息类型,可以是单个字符串或字符串列表 + platform: 平台名称 + predicate: 自定义匹配条件 + name: 路由名称 + + Usages: + - @runtime.on_message(...) + - @runtime.on_message + - @runtime.on_message(message_type="text") + - @runtime.on_message(message_type=["text", "image"]) + + If the target looks like an instance method (first arg is self), it will be + auto-bound to the instance and registered when the object is constructed. + """ + # 将 message_type 转换为集合以便统一处理 + message_types_set: set[str] | None = None + if message_type is not None: + if isinstance(message_type, str): + message_types_set = {message_type} + elif isinstance(message_type, list): + message_types_set = set(message_type) + else: + raise TypeError(f"message_type must be str or list[str], got {type(message_type)}") async def combined_predicate(message: MessageEnvelope) -> bool: - if message_type is not None and _extract_segment_type(message) != message_type: - return False + if message_types_set is not None: + extracted_type = _extract_segment_type(message) + if extracted_type not in message_types_set: + return False if platform is not None: info_platform = message.get("message_info", {}).get("platform") if message.get("platform") not in (None, platform) and info_platform is None: @@ -123,35 +201,24 @@ class MessageRuntime: return await _invoke_callable(predicate, message, prefer_thread=False) def decorator(func: MessageHandler) -> MessageHandler: + # Support decorating instance methods: defer binding until the object is created. + if _looks_like_method(func): + return _InstanceMethodRoute( + runtime=self, + func=func, + predicate=combined_predicate, + name=name, + message_type=message_type, + ) + self.add_route(combined_predicate, func, name=name, message_type=message_type) return func + if func is not None: + return decorator(func) return decorator - def on_event( - self, - event_type: str | Iterable[str], - *, - name: str | None = None, - ) -> Callable[[MessageHandler], MessageHandler]: - """装饰器,基于 message 或 message_info.additional_config 中的 event_type 匹配。""" - allowed = {event_type} if isinstance(event_type, str) else set(event_type) - - async def predicate(message: MessageEnvelope) -> bool: - current = ( - message.get("event_type") - or message.get("message_info", {}) - .get("additional_config", {}) - .get("event_type") - ) - return current in allowed - - def decorator(func: MessageHandler) -> MessageHandler: - self.add_route(predicate, func, name=name, event_types=allowed) - return func - - return decorator def set_batch_handler(self, handler: BatchHandler) -> None: self._batch_handler = handler @@ -199,7 +266,7 @@ class MessageRuntime: return responses async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None: - candidates: list[MessageRoute] = [] + """匹配消息路由,优先匹配明确指定了消息类型的处理器""" message_type = _extract_segment_type(message) event_type = ( message.get("event_type") @@ -207,15 +274,29 @@ class MessageRuntime: .get("additional_config", {}) .get("event_type") ) + + # 分为两层候选:优先级和普通 + priority_candidates: list[MessageRoute] = [] # 明确指定了消息类型的 + normal_candidates: list[MessageRoute] = [] # 没有指定或通配的 + with self._lock: + # 事件路由(优先级最高) if event_type and event_type in self._event_routes: - candidates.extend(self._event_routes[event_type]) + priority_candidates.extend(self._event_routes[event_type]) + + # 消息类型路由(明确指定的有优先级) if message_type and message_type in self._type_routes: - candidates.extend(self._type_routes[message_type]) - candidates.extend(self._routes) + priority_candidates.extend(self._type_routes[message_type]) + + # 通用路由(没有明确指定类型的) + for route in self._routes: + # 如果路由没有指定 message_types,则是通用路由 + if route.message_types is None and route.event_types is None: + normal_candidates.append(route) + # 先尝试优先级候选 seen: set[int] = set() - for route in candidates: + for route in priority_candidates: rid = id(route) if rid in seen: continue @@ -223,6 +304,17 @@ class MessageRuntime: should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False) if should_handle: return route + + # 如果没有匹配到优先级候选,再尝试普通候选 + for route in normal_candidates: + rid = id(route) + if rid in seen: + continue + seen.add(rid) + should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False) + if should_handle: + return route + return None async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None: @@ -257,7 +349,27 @@ class MessageRuntime: async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False): - """支持 sync/async 调用,并可选择在线程中执行。""" + """支持 sync/async 调用,并可选择在线程中执行。 + + 自动处理普通函数、类方法和绑定方法。 + """ + # 如果是绑定方法(bound method),直接使用,不需要额外处理 + # 因为绑定方法已经包含了 self 参数 + if inspect.ismethod(func): + # 绑定方法可以直接调用,args 中不应包含 self + if inspect.iscoroutinefunction(func): + return await func(*args) + if prefer_thread: + result = await asyncio.to_thread(func, *args) + if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future): + return await result + return result + result = func(*args) + if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future): + return await result + return result + + # 对于普通函数(未绑定的),按原有逻辑处理 if inspect.iscoroutinefunction(func): return await func(*args) if prefer_thread: @@ -282,6 +394,70 @@ def _extract_segment_type(message: MessageEnvelope) -> str | None: return None +def _looks_like_method(func: Callable[..., object]) -> bool: + """Return True if callable signature suggests an instance method (first arg named self).""" + if inspect.ismethod(func): + return True + if not inspect.isfunction(func): + return False + params = inspect.signature(func).parameters + if not params: + return False + first = next(iter(params.values())) + return first.name == "self" + + +class _InstanceMethodRoute: + """Descriptor that binds decorated instance methods and registers routes per-instance.""" + + def __init__( + self, + runtime: MessageRuntime, + func: MessageHandler, + predicate: Predicate, + name: str | None, + message_type: str | None, + ) -> None: + self._runtime = runtime + self._func = func + self._predicate = predicate + self._name = name + self._message_type = message_type + self._owner: type | None = None + self._registered_instances: weakref.WeakSet[object] = weakref.WeakSet() + + def __set_name__(self, owner: type, name: str) -> None: + self._owner = owner + registry: list[_InstanceMethodRoute] | None = getattr(owner, "_mofox_instance_routes", None) + if registry is None: + registry = [] + setattr(owner, "_mofox_instance_routes", registry) + original_init = owner.__init__ + + @functools.wraps(original_init) + def wrapped_init(inst, *args, **kwargs): + original_init(inst, *args, **kwargs) + for descriptor in getattr(inst.__class__, "_mofox_instance_routes", []): + descriptor._register_instance(inst) + + owner.__init__ = wrapped_init # type: ignore[assignment] + registry.append(self) + + def _register_instance(self, instance: object) -> None: + if instance in self._registered_instances: + return + owner = self._owner or instance.__class__ + bound = self._func.__get__(instance, owner) # type: ignore[arg-type] + self._runtime.add_route(self._predicate, bound, name=self._name, message_type=self._message_type) + self._registered_instances.add(instance) + + def __get__(self, instance: object | None, owner: type | None = None): + if instance is None: + return self._func + self._register_instance(instance) + return self._func.__get__(instance, owner) # type: ignore[arg-type] + + __all__ = [ "BatchHandler", "Hook", diff --git a/src/mofox_bus/types.py b/src/mofox_bus/types.py index 9ff517629..fb69e3bd5 100644 --- a/src/mofox_bus/types.py +++ b/src/mofox_bus/types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, List, Literal, NotRequired, TypedDict +from typing import Any, Dict, List, Literal, NotRequired, TypedDict, Required MessageDirection = Literal["incoming", "outgoing"] @@ -14,14 +14,14 @@ class SegPayload(TypedDict, total=False): 对齐 maim_message.Seg 的片段定义,使用纯 dict 便于 JSON 传输。 """ - type: str - data: str | List["SegPayload"] + type: Required[str] + data: Required[str | List["SegPayload"]] translated_data: NotRequired[str | List["SegPayload"]] class UserInfoPayload(TypedDict, total=False): platform: NotRequired[str] - user_id: NotRequired[str] + user_id: Required[str] user_nickname: NotRequired[str] user_cardname: NotRequired[str] user_avatar: NotRequired[str] @@ -29,7 +29,7 @@ class UserInfoPayload(TypedDict, total=False): class GroupInfoPayload(TypedDict, total=False): platform: NotRequired[str] - group_id: NotRequired[str] + group_id: Required[str] group_name: NotRequired[str] @@ -45,8 +45,8 @@ class TemplateInfoPayload(TypedDict, total=False): class MessageInfoPayload(TypedDict, total=False): - platform: NotRequired[str] - message_id: NotRequired[str] + platform: Required[str] + message_id: Required[str] time: NotRequired[float] group_info: NotRequired[GroupInfoPayload] user_info: NotRequired[UserInfoPayload] @@ -67,8 +67,8 @@ class MessageEnvelope(TypedDict, total=False): """ direction: MessageDirection - message_info: MessageInfoPayload - message_segment: SegPayload | List[SegPayload] + message_info: Required[MessageInfoPayload] + message_segment: Required[SegPayload] | List[SegPayload] raw_message: NotRequired[Any] raw_bytes: NotRequired[bytes] message_chain: NotRequired[List[SegPayload]] # seglist 的直观别名