重构消息处理和信封转换
- 从代码库中移除了EnvelopeConverter类及其相关方法,因为它们已不再需要。 - 更新了主系统,使其能够直接处理MessageEnvelope对象,而无需将其转换为旧格式。 - 增强了MessageRuntime类,以支持多种消息类型并防止重复注册处理程序。 引入了一个新的MessageHandler类来管理消息处理,包括预处理和数据库存储。 - 改进了整个消息处理工作流程中的错误处理和日志记录。 - 更新了类型提示和数据模型,以确保消息结构的一致性和清晰度。
This commit is contained in:
@@ -3,8 +3,8 @@ import re
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from mofox_bus import UserInfo
|
||||
|
||||
from mofox_bus.runtime import MessageRuntime
|
||||
from mofox_bus import MessageEnvelope
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
@@ -63,13 +63,13 @@ def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||
return True
|
||||
return False
|
||||
|
||||
runtime = MessageRuntime() # 获取mofox-bus运行时环境
|
||||
|
||||
class ChatBot:
|
||||
def __init__(self):
|
||||
self.bot = None # bot 实例引用
|
||||
self._started = False
|
||||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||
|
||||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||
# 启动消息管理器
|
||||
self._message_manager_started = False
|
||||
|
||||
@@ -303,204 +303,185 @@ class ChatBot:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理适配器响应时出错: {e}")
|
||||
|
||||
async def message_process(self, message_data: dict[str, Any]) -> None:
|
||||
|
||||
@runtime.on_message
|
||||
async def message_process(self, envelope: MessageEnvelope) -> None:
|
||||
"""处理转化后的统一格式消息"""
|
||||
try:
|
||||
# 首先处理可能的切片消息重组
|
||||
from src.utils.message_chunker import reassembler
|
||||
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
|
||||
message_info = envelope.get("message_info")
|
||||
if not isinstance(message_info, dict):
|
||||
logger.debug(
|
||||
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
|
||||
", ".join(envelope.keys()),
|
||||
)
|
||||
return
|
||||
|
||||
# 尝试重组切片消息
|
||||
reassembled_message = await reassembler.process_chunk(message_data)
|
||||
if reassembled_message is None:
|
||||
# 这是一个切片,但还未完整,等待更多切片
|
||||
logger.debug("等待更多切片,跳过此次处理")
|
||||
return
|
||||
elif reassembled_message != message_data:
|
||||
# 消息已被重组,使用重组后的消息
|
||||
logger.info("使用重组后的完整消息进行处理")
|
||||
message_data = reassembled_message
|
||||
|
||||
# 确保所有任务已启动
|
||||
await self._ensure_started()
|
||||
|
||||
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
|
||||
message_info = message_data.get("message_info")
|
||||
if not isinstance(message_info, dict):
|
||||
logger.debug(
|
||||
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
|
||||
", ".join(message_data.keys()),
|
||||
)
|
||||
return
|
||||
|
||||
if message_info.get("group_info") is not None:
|
||||
message_info["group_info"]["group_id"] = str(
|
||||
message_info["group_info"]["group_id"]
|
||||
)
|
||||
if message_info.get("user_info") is not None:
|
||||
message_info["user_info"]["user_id"] = str(
|
||||
message_info["user_info"]["user_id"]
|
||||
)
|
||||
# print(message_data)
|
||||
# logger.debug(str(message_data))
|
||||
|
||||
# 优先处理adapter_response消息(在echo检查之前!)
|
||||
message_segment = message_data.get("message_segment")
|
||||
if message_segment and isinstance(message_segment, dict):
|
||||
if message_segment.get("type") == "adapter_response":
|
||||
logger.info("[DEBUG bot.py message_process] 检测到adapter_response,立即处理")
|
||||
await self._handle_adapter_response_from_dict(message_segment.get("data"))
|
||||
return
|
||||
|
||||
# 先提取基础信息检查是否是自身消息上报
|
||||
from mofox_bus import BaseMessageInfo
|
||||
temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {}))
|
||||
if temp_message_info.additional_config:
|
||||
sent_message = temp_message_info.additional_config.get("echo", False)
|
||||
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题
|
||||
# 直接使用消息字典更新,不再需要创建 MessageRecv
|
||||
await MessageStorage.update_message(message_data)
|
||||
return
|
||||
|
||||
message_segment = message_data.get("message_segment")
|
||||
group_info = temp_message_info.group_info
|
||||
user_info = temp_message_info.user_info
|
||||
|
||||
# 获取或创建聊天流
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=temp_message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
if message_info.get("group_info") is not None:
|
||||
message_info["group_info"]["group_id"] = str( # type: ignore
|
||||
message_info["group_info"]["group_id"] # type: ignore
|
||||
)
|
||||
if message_info.get("user_info") is not None:
|
||||
message_info["user_info"]["user_id"] = str( # type: ignore
|
||||
message_info["user_info"]["user_id"] # type: ignore
|
||||
)
|
||||
|
||||
# 使用新的消息处理器直接生成 DatabaseMessages
|
||||
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||
message = await process_message_from_dict(
|
||||
message_dict=message_data,
|
||||
stream_id=chat.stream_id,
|
||||
platform=chat.platform
|
||||
)
|
||||
|
||||
# 填充聊天流时间信息
|
||||
message.chat_info.create_time = chat.create_time
|
||||
message.chat_info.last_active_time = chat.last_active_time
|
||||
|
||||
# 注册消息到聊天管理器
|
||||
get_chat_manager().register_message(message)
|
||||
|
||||
# 检测是否提及机器人
|
||||
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
|
||||
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
user_nickname = message.user_info.user_nickname if message.user_info else "未知用户"
|
||||
logger.info(
|
||||
f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m"
|
||||
)
|
||||
|
||||
# 在此添加硬编码过滤,防止回复图片处理失败的消息
|
||||
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
|
||||
processed_text = message.processed_plain_text or ""
|
||||
if any(keyword in processed_text for keyword in failure_keywords):
|
||||
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
|
||||
# 优先处理adapter_response消息(在echo检查之前!)
|
||||
message_segment = envelope.get("message_segment")
|
||||
if message_segment and isinstance(message_segment, dict):
|
||||
if message_segment.get("type") == "adapter_response":
|
||||
logger.info("[DEBUG bot.py message_process] 检测到adapter_response,立即处理")
|
||||
await self._handle_adapter_response_from_dict(message_segment.get("data"))
|
||||
return
|
||||
|
||||
# 过滤检查
|
||||
# DatabaseMessages 使用 display_message 作为原始消息表示
|
||||
raw_text = message.display_message or message.processed_plain_text or ""
|
||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||
raw_text,
|
||||
chat,
|
||||
user_info, # type: ignore
|
||||
):
|
||||
# 先提取基础信息检查是否是自身消息上报
|
||||
from mofox_bus import BaseMessageInfo
|
||||
temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {}))
|
||||
if temp_message_info.additional_config:
|
||||
sent_message = temp_message_info.additional_config.get("echo", False)
|
||||
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题
|
||||
# 直接使用消息字典更新,不再需要创建 MessageRecv
|
||||
await MessageStorage.update_message(message_data)
|
||||
return
|
||||
|
||||
# 命令处理 - 首先尝试PlusCommand独立处理
|
||||
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat)
|
||||
message_segment = envelope.get("message_segment")
|
||||
group_info = temp_message_info.group_info
|
||||
user_info = temp_message_info.user_info
|
||||
|
||||
# 如果是PlusCommand且不需要继续处理,则直接返回
|
||||
if is_plus_command and not plus_continue_process:
|
||||
# 获取或创建聊天流
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=temp_message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
# 使用新的消息处理器直接生成 DatabaseMessages
|
||||
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||
message = await process_message_from_dict(
|
||||
message_dict=envelope,
|
||||
stream_id=chat.stream_id,
|
||||
platform=chat.platform
|
||||
)
|
||||
|
||||
# 填充聊天流时间信息
|
||||
message.chat_info.create_time = chat.create_time
|
||||
message.chat_info.last_active_time = chat.last_active_time
|
||||
|
||||
# 注册消息到聊天管理器
|
||||
get_chat_manager().register_message(message)
|
||||
|
||||
# 检测是否提及机器人
|
||||
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
|
||||
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
user_nickname = message.user_info.user_nickname if message.user_info else "未知用户"
|
||||
logger.info(
|
||||
f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m"
|
||||
)
|
||||
|
||||
# 在此添加硬编码过滤,防止回复图片处理失败的消息
|
||||
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
|
||||
processed_text = message.processed_plain_text or ""
|
||||
if any(keyword in processed_text for keyword in failure_keywords):
|
||||
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
|
||||
return
|
||||
|
||||
# 过滤检查
|
||||
# DatabaseMessages 使用 display_message 作为原始消息表示
|
||||
raw_text = message.display_message or message.processed_plain_text or ""
|
||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||
raw_text,
|
||||
chat,
|
||||
user_info, # type: ignore
|
||||
):
|
||||
return
|
||||
|
||||
# 命令处理 - 首先尝试PlusCommand独立处理
|
||||
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat)
|
||||
|
||||
# 如果是PlusCommand且不需要继续处理,则直接返回
|
||||
if is_plus_command and not plus_continue_process:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
logger.info(f"PlusCommand处理完成,跳过后续消息处理: {plus_cmd_result}")
|
||||
return
|
||||
|
||||
# 如果不是PlusCommand,尝试传统的BaseCommand处理
|
||||
if not is_plus_command:
|
||||
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat)
|
||||
|
||||
# 如果是命令且不需要继续处理,则直接返回
|
||||
if is_command and not continue_process:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
logger.info(f"PlusCommand处理完成,跳过后续消息处理: {plus_cmd_result}")
|
||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
return
|
||||
|
||||
# 如果不是PlusCommand,尝试传统的BaseCommand处理
|
||||
if not is_plus_command:
|
||||
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat)
|
||||
result = await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message=message)
|
||||
if result and not result.all_continue_process():
|
||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
|
||||
|
||||
# 如果是命令且不需要继续处理,则直接返回
|
||||
if is_command and not continue_process:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
return
|
||||
# TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
# 这个功能需要在 adapter 层通过 additional_config 传递
|
||||
template_group_name = None
|
||||
|
||||
result = await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message=message)
|
||||
if result and not result.all_continue_process():
|
||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
|
||||
async def preprocess():
|
||||
# message 已经是 DatabaseMessages,直接使用
|
||||
group_info = chat.group_info
|
||||
|
||||
# TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
# 这个功能需要在 adapter 层通过 additional_config 传递
|
||||
template_group_name = None
|
||||
# 先交给消息管理器处理,计算兴趣度等衍生数据
|
||||
try:
|
||||
# 在将消息添加到管理器之前进行最终的静默检查
|
||||
should_process_in_manager = True
|
||||
if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list:
|
||||
# 检查消息是否为图片或表情包
|
||||
is_image_or_emoji = message.is_picid or message.is_emoji
|
||||
if not message.is_mentioned and not is_image_or_emoji:
|
||||
logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理")
|
||||
should_process_in_manager = False
|
||||
elif is_image_or_emoji:
|
||||
logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理")
|
||||
should_process_in_manager = False
|
||||
|
||||
async def preprocess():
|
||||
# message 已经是 DatabaseMessages,直接使用
|
||||
group_info = chat.group_info
|
||||
if should_process_in_manager:
|
||||
await message_manager.add_message(chat.stream_id, message)
|
||||
logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
|
||||
|
||||
# 先交给消息管理器处理,计算兴趣度等衍生数据
|
||||
try:
|
||||
# 在将消息添加到管理器之前进行最终的静默检查
|
||||
should_process_in_manager = True
|
||||
if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list:
|
||||
# 检查消息是否为图片或表情包
|
||||
is_image_or_emoji = message.is_picid or message.is_emoji
|
||||
if not message.is_mentioned and not is_image_or_emoji:
|
||||
logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理")
|
||||
should_process_in_manager = False
|
||||
elif is_image_or_emoji:
|
||||
logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理")
|
||||
should_process_in_manager = False
|
||||
except Exception as e:
|
||||
logger.error(f"消息添加到消息管理器失败: {e}")
|
||||
|
||||
if should_process_in_manager:
|
||||
await message_manager.add_message(chat.stream_id, message)
|
||||
logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
|
||||
# 存储消息到数据库,只进行一次写入
|
||||
try:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
except Exception as e:
|
||||
logger.error(f"存储消息到数据库失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"消息添加到消息管理器失败: {e}")
|
||||
# 情绪系统更新 - 在消息存储后触发情绪更新
|
||||
try:
|
||||
if global_config.mood.enable_mood:
|
||||
# 获取兴趣度用于情绪更新
|
||||
interest_rate = message.interest_value
|
||||
if interest_rate is None:
|
||||
interest_rate = 0.0
|
||||
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
|
||||
|
||||
# 存储消息到数据库,只进行一次写入
|
||||
try:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
except Exception as e:
|
||||
logger.error(f"存储消息到数据库失败: {e}")
|
||||
traceback.print_exc()
|
||||
# 获取当前聊天的情绪对象并更新情绪状态
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||
await chat_mood.update_mood_by_message(message, interest_rate)
|
||||
logger.debug("情绪状态更新完成")
|
||||
except Exception as e:
|
||||
logger.error(f"更新情绪状态失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 情绪系统更新 - 在消息存储后触发情绪更新
|
||||
try:
|
||||
if global_config.mood.enable_mood:
|
||||
# 获取兴趣度用于情绪更新
|
||||
interest_rate = message.interest_value
|
||||
if interest_rate is None:
|
||||
interest_rate = 0.0
|
||||
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
|
||||
|
||||
# 获取当前聊天的情绪对象并更新情绪状态
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||
await chat_mood.update_mood_by_message(message, interest_rate)
|
||||
logger.debug("情绪状态更新完成")
|
||||
except Exception as e:
|
||||
logger.error(f"更新情绪状态失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
if template_group_name:
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
await preprocess()
|
||||
else:
|
||||
if template_group_name:
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
await preprocess()
|
||||
else:
|
||||
await preprocess()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预处理消息失败: {e}")
|
||||
traceback.print_exc()
|
||||
except Exception as e:
|
||||
logger.error(f"预处理消息失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# 创建全局ChatBot实例
|
||||
|
||||
@@ -4,7 +4,6 @@ from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import urllib3
|
||||
from mofox_bus import BaseMessageInfo, MessageBase, Seg, UserInfo
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
110
src/chat/message_receive/message_handler.py
Normal file
110
src/chat/message_receive/message_handler.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from mofox_bus.runtime import MessageRuntime
|
||||
from mofox_bus import MessageEnvelope
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.common.data_models.database_data_model import DatabaseGroupInfo, DatabaseUserInfo, DatabaseMessages
|
||||
|
||||
runtime = MessageRuntime()
|
||||
|
||||
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
|
||||
# 配置主程序日志格式
|
||||
logger = get_logger("chat")
|
||||
|
||||
class MessageHandler:
|
||||
def __init__(self):
|
||||
self._started = False
|
||||
|
||||
async def preprocess(self, chat: ChatStream, message: DatabaseMessages):
|
||||
# message 已经是 DatabaseMessages,直接使用
|
||||
group_info = chat.group_info
|
||||
|
||||
# 先交给消息管理器处理
|
||||
try:
|
||||
# 在将消息添加到管理器之前进行最终的静默检查
|
||||
should_process_in_manager = True
|
||||
if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list:
|
||||
# 检查消息是否为图片或表情包
|
||||
is_image_or_emoji = message.is_picid or message.is_emoji
|
||||
if not message.is_mentioned and not is_image_or_emoji:
|
||||
logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理")
|
||||
should_process_in_manager = False
|
||||
elif is_image_or_emoji:
|
||||
logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理")
|
||||
should_process_in_manager = False
|
||||
|
||||
if should_process_in_manager:
|
||||
await message_manager.add_message(chat.stream_id, message)
|
||||
logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"消息添加到消息管理器失败: {e}")
|
||||
|
||||
# 存储消息到数据库,只进行一次写入
|
||||
try:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
except Exception as e:
|
||||
logger.error(f"存储消息到数据库失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 情绪系统更新 - 在消息存储后触发情绪更新
|
||||
try:
|
||||
if global_config.mood.enable_mood:
|
||||
# 获取兴趣度用于情绪更新
|
||||
interest_rate = message.interest_value
|
||||
if interest_rate is None:
|
||||
interest_rate = 0.0
|
||||
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
|
||||
|
||||
# 获取当前聊天的情绪对象并更新情绪状态
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||
await chat_mood.update_mood_by_message(message, interest_rate)
|
||||
logger.debug("情绪状态更新完成")
|
||||
except Exception as e:
|
||||
logger.error(f"更新情绪状态失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def handle_message(self, envelope: MessageEnvelope):
|
||||
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
|
||||
message_info = envelope.get("message_info")
|
||||
if not isinstance(message_info, dict):
|
||||
logger.debug(
|
||||
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
|
||||
", ".join(envelope.keys()),
|
||||
)
|
||||
return
|
||||
|
||||
if message_info.get("group_info") is not None:
|
||||
message_info["group_info"]["group_id"] = str( # type: ignore
|
||||
message_info["group_info"]["group_id"] # type: ignore
|
||||
)
|
||||
if message_info.get("user_info") is not None:
|
||||
message_info["user_info"]["user_id"] = str( # type: ignore
|
||||
message_info["user_info"]["user_id"] # type: ignore
|
||||
)
|
||||
|
||||
group_info = message_info.get("group_info")
|
||||
user_info = message_info.get("user_info")
|
||||
|
||||
chat_stream = await get_chat_manager().get_or_create_stream(
|
||||
platform=envelope["platform"], # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
# 生成 DatabaseMessages
|
||||
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||
message = await process_message_from_dict(
|
||||
message_dict=envelope,
|
||||
stream_id=chat_stream.stream_id,
|
||||
platform=chat_stream.platform
|
||||
)
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
"""消息处理工具模块
|
||||
将原 MessageRecv 的消息处理逻辑提取为独立函数,
|
||||
直接从适配器消息字典生成 DatabaseMessages
|
||||
基于 mofox-bus 的 TypedDict 形式构建消息数据,然后转换为 DatabaseMessages
|
||||
"""
|
||||
import base64
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from mofox_bus import BaseMessageInfo, Seg
|
||||
from mofox_bus import MessageEnvelope
|
||||
from mofox_bus.types import MessageInfoPayload, SegPayload, UserInfoPayload, GroupInfoPayload
|
||||
|
||||
from src.chat.utils.self_voice_cache import consume_self_voice_text
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
@@ -20,26 +21,26 @@ from src.config.config import global_config
|
||||
logger = get_logger("message_processor")
|
||||
|
||||
|
||||
async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str, platform: str) -> DatabaseMessages:
|
||||
async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: str, platform: str) -> DatabaseMessages:
|
||||
"""从适配器消息字典处理并生成 DatabaseMessages
|
||||
|
||||
这个函数整合了原 MessageRecv 的所有处理逻辑:
|
||||
1. 解析 message_segment 并异步处理内容(图片、语音、视频等)
|
||||
2. 提取所有消息元数据
|
||||
3. 直接构造 DatabaseMessages 对象
|
||||
3. 基于 TypedDict 形式构建数据,然后转换为 DatabaseMessages
|
||||
|
||||
Args:
|
||||
message_dict: MessageCQ序列化后的字典
|
||||
message_dict: MessageEnvelope 格式的消息字典
|
||||
stream_id: 聊天流ID
|
||||
platform: 平台标识
|
||||
|
||||
Returns:
|
||||
DatabaseMessages: 处理完成的数据库消息对象
|
||||
"""
|
||||
# 解析基础信息
|
||||
message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
||||
message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
||||
|
||||
# 提取核心数据(使用 TypedDict 类型)
|
||||
message_info: MessageInfoPayload = message_dict.get("message_info", {}) # type: ignore
|
||||
message_segment: SegPayload | list[SegPayload] = message_dict.get("message_segment", {"type": "text", "data": ""}) # type: ignore
|
||||
|
||||
# 初始化处理状态
|
||||
processing_state = {
|
||||
"is_emoji": False,
|
||||
@@ -61,26 +62,26 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
||||
is_notify = False
|
||||
is_public_notice = False
|
||||
notice_type = None
|
||||
if message_info.additional_config and isinstance(message_info.additional_config, dict):
|
||||
is_notify = message_info.additional_config.get("is_notice", False)
|
||||
is_public_notice = message_info.additional_config.get("is_public_notice", False)
|
||||
notice_type = message_info.additional_config.get("notice_type")
|
||||
additional_config_dict = message_info.get("additional_config", {})
|
||||
if isinstance(additional_config_dict, dict):
|
||||
is_notify = additional_config_dict.get("is_notice", False)
|
||||
is_public_notice = additional_config_dict.get("is_public_notice", False)
|
||||
notice_type = additional_config_dict.get("notice_type")
|
||||
|
||||
# 提取用户信息
|
||||
user_info = message_info.user_info
|
||||
user_id = str(user_info.user_id) if user_info and user_info.user_id else ""
|
||||
user_nickname = (user_info.user_nickname or "") if user_info else ""
|
||||
user_cardname = user_info.user_cardname if user_info else None
|
||||
user_platform = (user_info.platform or "") if user_info else ""
|
||||
user_info_payload: UserInfoPayload = message_info.get("user_info", {}) # type: ignore
|
||||
user_id = str(user_info_payload.get("user_id", ""))
|
||||
user_nickname = user_info_payload.get("user_nickname", "")
|
||||
user_cardname = user_info_payload.get("user_cardname")
|
||||
user_platform = user_info_payload.get("platform", "")
|
||||
|
||||
# 提取群组信息
|
||||
group_info = message_info.group_info
|
||||
group_id = group_info.group_id if group_info else None
|
||||
group_name = group_info.group_name if group_info else None
|
||||
group_platform = group_info.platform if group_info else None
|
||||
group_info_payload: GroupInfoPayload | None = message_info.get("group_info") # type: ignore
|
||||
group_id = group_info_payload.get("group_id") if group_info_payload else None
|
||||
group_name = group_info_payload.get("group_name") if group_info_payload else None
|
||||
group_platform = group_info_payload.get("platform") if group_info_payload else None
|
||||
|
||||
# chat_id 应该直接使用 stream_id(与数据库存储格式一致)
|
||||
# stream_id 是通过 platform + user_id/group_id 的 SHA-256 哈希生成的
|
||||
chat_id = stream_id
|
||||
|
||||
# 准备 additional_config
|
||||
@@ -89,18 +90,19 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
||||
# 提取 reply_to
|
||||
reply_to = _extract_reply_from_segment(message_segment)
|
||||
|
||||
# 构造 DatabaseMessages
|
||||
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
|
||||
message_id = message_info.message_id or ""
|
||||
# 构造消息数据字典(基于 TypedDict 风格)
|
||||
message_time = message_info.get("time", time.time())
|
||||
message_id = message_info.get("message_id", "")
|
||||
|
||||
# 处理 is_mentioned
|
||||
is_mentioned = None
|
||||
mentioned_value = processing_state.get("is_mentioned")
|
||||
if isinstance(mentioned_value, bool):
|
||||
is_mentioned = mentioned_value
|
||||
elif isinstance(mentioned_value, int | float):
|
||||
elif isinstance(mentioned_value, (int, float)):
|
||||
is_mentioned = mentioned_value != 0
|
||||
|
||||
# 使用 TypedDict 风格的数据构建 DatabaseMessages
|
||||
db_message = DatabaseMessages(
|
||||
message_id=message_id,
|
||||
time=float(message_time),
|
||||
@@ -134,7 +136,7 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
||||
chat_info_group_platform=group_platform,
|
||||
)
|
||||
|
||||
# 设置优先级信息
|
||||
# 设置优先级信息(运行时属性)
|
||||
if processing_state.get("priority_mode"):
|
||||
setattr(db_message, "priority_mode", processing_state["priority_mode"])
|
||||
if processing_state.get("priority_info"):
|
||||
@@ -149,99 +151,127 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
||||
return db_message
|
||||
|
||||
|
||||
async def _process_message_segments(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
|
||||
async def _process_message_segments(
|
||||
segment: SegPayload | list[SegPayload],
|
||||
state: dict,
|
||||
message_info: MessageInfoPayload
|
||||
) -> str:
|
||||
"""递归处理消息段,转换为文字描述
|
||||
|
||||
Args:
|
||||
segment: 要处理的消息段
|
||||
segment: 要处理的消息段(TypedDict 或列表)
|
||||
state: 处理状态字典(用于记录消息类型标记)
|
||||
message_info: 消息基础信息(用于某些处理逻辑)
|
||||
message_info: 消息基础信息(TypedDict 格式)
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
if segment.type == "seglist":
|
||||
# 处理消息段列表
|
||||
# 如果是列表,遍历处理
|
||||
if isinstance(segment, list):
|
||||
segments_text = []
|
||||
for seg in segment.data:
|
||||
for seg in segment:
|
||||
processed = await _process_message_segments(seg, state, message_info)
|
||||
if processed:
|
||||
segments_text.append(processed)
|
||||
return " ".join(segments_text)
|
||||
else:
|
||||
# 处理单个消息段
|
||||
|
||||
# 如果是单个段
|
||||
if isinstance(segment, dict):
|
||||
seg_type = segment.get("type", "")
|
||||
seg_data = segment.get("data")
|
||||
|
||||
# 处理 seglist 类型
|
||||
if seg_type == "seglist" and isinstance(seg_data, list):
|
||||
segments_text = []
|
||||
for sub_seg in seg_data:
|
||||
processed = await _process_message_segments(sub_seg, state, message_info)
|
||||
if processed:
|
||||
segments_text.append(processed)
|
||||
return " ".join(segments_text)
|
||||
|
||||
# 处理其他类型
|
||||
return await _process_single_segment(segment, state, message_info)
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
async def _process_single_segment(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
|
||||
async def _process_single_segment(
|
||||
segment: SegPayload,
|
||||
state: dict,
|
||||
message_info: MessageInfoPayload
|
||||
) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
segment: 消息段(TypedDict 格式)
|
||||
state: 处理状态字典
|
||||
message_info: 消息基础信息
|
||||
message_info: 消息基础信息(TypedDict 格式)
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
seg_type = segment.get("type", "")
|
||||
seg_data = segment.get("data")
|
||||
|
||||
try:
|
||||
if segment.type == "text":
|
||||
if seg_type == "text":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_video"] = False
|
||||
return segment.data
|
||||
return str(seg_data) if seg_data else ""
|
||||
|
||||
elif segment.type == "at":
|
||||
elif seg_type == "at":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_video"] = False
|
||||
state["is_at"] = True
|
||||
# 处理at消息,格式为"@<昵称:QQ号>"
|
||||
if isinstance(segment.data, str):
|
||||
if ":" in segment.data:
|
||||
if isinstance(seg_data, str):
|
||||
if ":" in seg_data:
|
||||
# 标准格式: "昵称:QQ号"
|
||||
nickname, qq_id = segment.data.split(":", 1)
|
||||
result = f"@<{nickname}:{qq_id}>"
|
||||
return result
|
||||
nickname, qq_id = seg_data.split(":", 1)
|
||||
return f"@<{nickname}:{qq_id}>"
|
||||
else:
|
||||
logger.warning(f"[at处理] 无法解析格式: '{segment.data}'")
|
||||
return f"@{segment.data}"
|
||||
logger.warning(f"[at处理] 数据类型异常: {type(segment.data)}")
|
||||
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
|
||||
logger.warning(f"[at处理] 无法解析格式: '{seg_data}'")
|
||||
return f"@{seg_data}"
|
||||
logger.warning(f"[at处理] 数据类型异常: {type(seg_data)}")
|
||||
return f"@{seg_data}" if isinstance(seg_data, str) else "@未知用户"
|
||||
|
||||
elif segment.type == "image":
|
||||
elif seg_type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
if isinstance(seg_data, str):
|
||||
state["has_picid"] = True
|
||||
state["is_picid"] = True
|
||||
state["is_emoji"] = False
|
||||
state["is_video"] = False
|
||||
image_manager = get_image_manager()
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
_, processed_text = await image_manager.process_image(seg_data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
|
||||
elif segment.type == "emoji":
|
||||
elif seg_type == "emoji":
|
||||
state["has_emoji"] = True
|
||||
state["is_emoji"] = True
|
||||
state["is_picid"] = False
|
||||
state["is_voice"] = False
|
||||
state["is_video"] = False
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
if isinstance(seg_data, str):
|
||||
return await get_image_manager().get_emoji_description(seg_data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
|
||||
elif segment.type == "voice":
|
||||
elif seg_type == "voice":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = True
|
||||
state["is_video"] = False
|
||||
|
||||
# 检查消息是否由机器人自己发送
|
||||
if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account):
|
||||
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。")
|
||||
if isinstance(segment.data, str):
|
||||
cached_text = consume_self_voice_text(segment.data)
|
||||
user_info = message_info.get("user_info", {})
|
||||
user_id_str = str(user_info.get("user_id", ""))
|
||||
if user_id_str == str(global_config.bot.qq_account):
|
||||
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {user_id_str}),尝试从缓存获取文本。")
|
||||
if isinstance(seg_data, str):
|
||||
cached_text = consume_self_voice_text(seg_data)
|
||||
if cached_text:
|
||||
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
|
||||
return f"[语音:{cached_text}]"
|
||||
@@ -249,41 +279,42 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
||||
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
|
||||
|
||||
# 标准语音识别流程
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
if isinstance(seg_data, str):
|
||||
return await get_voice_text(seg_data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
|
||||
elif segment.type == "mention_bot":
|
||||
elif seg_type == "mention_bot":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = False
|
||||
state["is_video"] = False
|
||||
state["is_mentioned"] = float(segment.data)
|
||||
if isinstance(seg_data, (int, float)):
|
||||
state["is_mentioned"] = float(seg_data)
|
||||
return ""
|
||||
|
||||
elif segment.type == "priority_info":
|
||||
elif seg_type == "priority_info":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = False
|
||||
if isinstance(segment.data, dict):
|
||||
if isinstance(seg_data, dict):
|
||||
# 处理优先级信息
|
||||
state["priority_mode"] = "priority"
|
||||
state["priority_info"] = segment.data
|
||||
state["priority_info"] = seg_data
|
||||
return ""
|
||||
|
||||
elif segment.type == "file":
|
||||
if isinstance(segment.data, dict):
|
||||
file_name = segment.data.get("name", "未知文件")
|
||||
file_size = segment.data.get("size", "未知大小")
|
||||
elif seg_type == "file":
|
||||
if isinstance(seg_data, dict):
|
||||
file_name = seg_data.get("name", "未知文件")
|
||||
file_size = seg_data.get("size", "未知大小")
|
||||
return f"[文件:{file_name} ({file_size}字节)]"
|
||||
return "[收到一个文件]"
|
||||
|
||||
elif segment.type == "video":
|
||||
elif seg_type == "video":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = False
|
||||
state["is_video"] = True
|
||||
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
|
||||
logger.info(f"接收到视频消息,数据类型: {type(seg_data)}")
|
||||
|
||||
# 检查视频分析功能是否可用
|
||||
if not is_video_analysis_available():
|
||||
@@ -292,11 +323,11 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
||||
|
||||
if global_config.video_analysis.enable:
|
||||
logger.info("已启用视频识别,开始识别")
|
||||
if isinstance(segment.data, dict):
|
||||
if isinstance(seg_data, dict):
|
||||
try:
|
||||
# 从Adapter接收的视频数据
|
||||
video_base64 = segment.data.get("base64")
|
||||
filename = segment.data.get("filename", "video.mp4")
|
||||
video_base64 = seg_data.get("base64")
|
||||
filename = seg_data.get("filename", "video.mp4")
|
||||
|
||||
logger.info(f"视频文件名: {filename}")
|
||||
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
|
||||
@@ -329,24 +360,29 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||
return "[收到视频,但处理时出现错误]"
|
||||
else:
|
||||
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
|
||||
logger.warning(f"视频消息数据不是字典格式: {type(seg_data)}")
|
||||
return "[发了一个视频,但格式不支持]"
|
||||
else:
|
||||
return ""
|
||||
else:
|
||||
logger.warning(f"未知的消息段类型: {segment.type}")
|
||||
return f"[{segment.type} 消息]"
|
||||
logger.warning(f"未知的消息段类型: {seg_type}")
|
||||
return f"[{seg_type} 消息]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {seg_type}, 数据: {seg_data}")
|
||||
return f"[处理失败的{seg_type}消息]"
|
||||
|
||||
|
||||
def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, is_public_notice: bool, notice_type: str | None) -> str | None:
|
||||
def _prepare_additional_config(
|
||||
message_info: MessageInfoPayload,
|
||||
is_notify: bool,
|
||||
is_public_notice: bool,
|
||||
notice_type: str | None
|
||||
) -> str | None:
|
||||
"""准备 additional_config,包含 format_info 和 notice 信息
|
||||
|
||||
Args:
|
||||
message_info: 消息基础信息
|
||||
message_info: 消息基础信息(TypedDict 格式)
|
||||
is_notify: 是否为notice消息
|
||||
is_public_notice: 是否为公共notice
|
||||
notice_type: notice类型
|
||||
@@ -358,12 +394,13 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
|
||||
additional_config_data = {}
|
||||
|
||||
# 首先获取adapter传递的additional_config
|
||||
if hasattr(message_info, "additional_config") and message_info.additional_config:
|
||||
if isinstance(message_info.additional_config, dict):
|
||||
additional_config_data = message_info.additional_config.copy()
|
||||
elif isinstance(message_info.additional_config, str):
|
||||
additional_config_raw = message_info.get("additional_config")
|
||||
if additional_config_raw:
|
||||
if isinstance(additional_config_raw, dict):
|
||||
additional_config_data = additional_config_raw.copy()
|
||||
elif isinstance(additional_config_raw, str):
|
||||
try:
|
||||
additional_config_data = orjson.loads(message_info.additional_config)
|
||||
additional_config_data = orjson.loads(additional_config_raw)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法解析 additional_config JSON: {e}")
|
||||
additional_config_data = {}
|
||||
@@ -375,11 +412,11 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
|
||||
additional_config_data["is_public_notice"] = bool(is_public_notice)
|
||||
|
||||
# 添加format_info到additional_config中
|
||||
if hasattr(message_info, "format_info") and message_info.format_info:
|
||||
format_info = message_info.get("format_info")
|
||||
if format_info:
|
||||
try:
|
||||
format_info_dict = message_info.format_info.to_dict()
|
||||
additional_config_data["format_info"] = format_info_dict
|
||||
logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}")
|
||||
additional_config_data["format_info"] = format_info
|
||||
logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info}")
|
||||
except Exception as e:
|
||||
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
||||
|
||||
@@ -392,28 +429,43 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
|
||||
return None
|
||||
|
||||
|
||||
def _extract_reply_from_segment(segment: Seg) -> str | None:
|
||||
def _extract_reply_from_segment(segment: SegPayload | list[SegPayload]) -> str | None:
|
||||
"""从消息段中提取reply_to信息
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
segment: 消息段(TypedDict 格式或列表)
|
||||
|
||||
Returns:
|
||||
str | None: 回复的消息ID,如果没有则返回None
|
||||
"""
|
||||
try:
|
||||
if hasattr(segment, "type") and segment.type == "seglist":
|
||||
# 递归搜索seglist中的reply段
|
||||
if hasattr(segment, "data") and segment.data:
|
||||
for seg in segment.data:
|
||||
reply_id = _extract_reply_from_segment(seg)
|
||||
# 如果是列表,遍历查找
|
||||
if isinstance(segment, list):
|
||||
for seg in segment:
|
||||
reply_id = _extract_reply_from_segment(seg)
|
||||
if reply_id:
|
||||
return reply_id
|
||||
return None
|
||||
|
||||
# 如果是字典
|
||||
if isinstance(segment, dict):
|
||||
seg_type = segment.get("type", "")
|
||||
seg_data = segment.get("data")
|
||||
|
||||
# 如果是 seglist,递归搜索
|
||||
if seg_type == "seglist" and isinstance(seg_data, list):
|
||||
for sub_seg in seg_data:
|
||||
reply_id = _extract_reply_from_segment(sub_seg)
|
||||
if reply_id:
|
||||
return reply_id
|
||||
elif hasattr(segment, "type") and segment.type == "reply":
|
||||
# 找到reply段,返回message_id
|
||||
return str(segment.data) if segment.data else None
|
||||
|
||||
# 如果是 reply 段,返回 message_id
|
||||
elif seg_type == "reply":
|
||||
return str(seg_data) if seg_data else None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"提取reply_to信息失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -421,33 +473,31 @@ def _extract_reply_from_segment(segment: Seg) -> str | None:
|
||||
# DatabaseMessages 扩展工具函数
|
||||
# =============================================================================
|
||||
|
||||
def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessageInfo:
|
||||
"""从 DatabaseMessages 重建 BaseMessageInfo(用于需要 message_info 的遗留代码)
|
||||
def get_message_info_from_db_message(db_message: DatabaseMessages) -> MessageInfoPayload:
|
||||
"""从 DatabaseMessages 重建 MessageInfoPayload(TypedDict 格式)
|
||||
|
||||
Args:
|
||||
db_message: DatabaseMessages 对象
|
||||
|
||||
Returns:
|
||||
BaseMessageInfo: 重建的消息信息对象
|
||||
MessageInfoPayload: 重建的消息信息对象(TypedDict 格式)
|
||||
"""
|
||||
from mofox_bus import GroupInfo, UserInfo
|
||||
# 构建用户信息
|
||||
user_info: UserInfoPayload = {
|
||||
"platform": db_message.user_info.platform,
|
||||
"user_id": db_message.user_info.user_id,
|
||||
"user_nickname": db_message.user_info.user_nickname,
|
||||
"user_cardname": db_message.user_info.user_cardname or "",
|
||||
}
|
||||
|
||||
# 从 DatabaseMessages 的 user_info 转换为 mofox_bus.UserInfo
|
||||
user_info = UserInfo(
|
||||
platform=db_message.user_info.platform,
|
||||
user_id=db_message.user_info.user_id,
|
||||
user_nickname=db_message.user_info.user_nickname,
|
||||
user_cardname=db_message.user_info.user_cardname or ""
|
||||
)
|
||||
|
||||
# 从 DatabaseMessages 的 group_info 转换为 mofox_bus.GroupInfo(如果存在)
|
||||
group_info = None
|
||||
# 构建群组信息(如果存在)
|
||||
group_info: GroupInfoPayload | None = None
|
||||
if db_message.group_info:
|
||||
group_info = GroupInfo(
|
||||
platform=db_message.group_info.group_platform or "",
|
||||
group_id=db_message.group_info.group_id,
|
||||
group_name=db_message.group_info.group_name
|
||||
)
|
||||
group_info = {
|
||||
"platform": db_message.group_info.group_platform or "",
|
||||
"group_id": db_message.group_info.group_id,
|
||||
"group_name": db_message.group_info.group_name,
|
||||
}
|
||||
|
||||
# 解析 additional_config(从 JSON 字符串到字典)
|
||||
additional_config = None
|
||||
@@ -458,15 +508,19 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
|
||||
# 如果解析失败,保持为字符串
|
||||
pass
|
||||
|
||||
# 创建 BaseMessageInfo
|
||||
message_info = BaseMessageInfo(
|
||||
platform=db_message.chat_info.platform,
|
||||
message_id=db_message.message_id,
|
||||
time=db_message.time,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
additional_config=additional_config # type: ignore
|
||||
)
|
||||
# 创建 MessageInfoPayload
|
||||
message_info: MessageInfoPayload = {
|
||||
"platform": db_message.chat_info.platform,
|
||||
"message_id": db_message.message_id,
|
||||
"time": db_message.time,
|
||||
"user_info": user_info,
|
||||
}
|
||||
|
||||
if group_info:
|
||||
message_info["group_info"] = group_info
|
||||
|
||||
if additional_config:
|
||||
message_info["additional_config"] = additional_config
|
||||
|
||||
return message_info
|
||||
|
||||
|
||||
@@ -1,341 +0,0 @@
|
||||
"""
|
||||
MessageEnvelope converter between mofox_bus schema and internal message structures.
|
||||
|
||||
- 优先处理 maim_message 风格的 message_info + message_segment。
|
||||
- 兼容旧版 content/sender/channel 结构,方便逐步迁移。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from mofox_bus import (
|
||||
BaseMessageInfo,
|
||||
MessageBase,
|
||||
MessageEnvelope,
|
||||
Seg,
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
)
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("envelope_converter")
|
||||
|
||||
|
||||
class EnvelopeConverter:
|
||||
"""MessageEnvelope <-> MessageBase converter."""
|
||||
|
||||
@staticmethod
|
||||
def to_message_base(envelope: MessageEnvelope) -> MessageBase:
|
||||
"""
|
||||
Convert MessageEnvelope to MessageBase.
|
||||
"""
|
||||
try:
|
||||
# 优先使用 maim_message 样式字段
|
||||
info_payload = envelope.get("message_info") or {}
|
||||
seg_payload = envelope.get("message_segment") or envelope.get("message_chain")
|
||||
|
||||
if info_payload:
|
||||
message_info = BaseMessageInfo.from_dict(info_payload)
|
||||
else:
|
||||
message_info = EnvelopeConverter._build_info_from_legacy(envelope)
|
||||
|
||||
if seg_payload is None:
|
||||
seg_list = EnvelopeConverter._content_to_segments(envelope.get("content"))
|
||||
seg_payload = seg_list
|
||||
|
||||
message_segment = EnvelopeConverter._ensure_seg(seg_payload)
|
||||
raw_message = envelope.get("raw_message") or envelope.get("raw_platform_message")
|
||||
|
||||
return MessageBase(
|
||||
message_info=message_info,
|
||||
message_segment=message_segment,
|
||||
raw_message=raw_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"转换 MessageEnvelope 失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _build_info_from_legacy(envelope: MessageEnvelope) -> BaseMessageInfo:
|
||||
"""将 legacy 字段映射为 BaseMessageInfo。"""
|
||||
platform = envelope.get("platform")
|
||||
channel = envelope.get("channel") or {}
|
||||
sender = envelope.get("sender") or {}
|
||||
|
||||
message_id = envelope.get("id") or envelope.get("message_id")
|
||||
timestamp_ms = envelope.get("timestamp_ms")
|
||||
time_value = (timestamp_ms / 1000.0) if timestamp_ms is not None else None
|
||||
|
||||
group_info: Optional[GroupInfo] = None
|
||||
channel_type = channel.get("channel_type")
|
||||
if channel_type in ("group", "supergroup", "room"):
|
||||
group_info = GroupInfo(
|
||||
platform=platform,
|
||||
group_id=channel.get("channel_id"),
|
||||
group_name=channel.get("title"),
|
||||
)
|
||||
|
||||
user_info: Optional[UserInfo] = None
|
||||
if sender:
|
||||
user_info = UserInfo(
|
||||
platform=platform,
|
||||
user_id=str(sender.get("user_id")) if sender.get("user_id") is not None else None,
|
||||
user_nickname=sender.get("display_name") or sender.get("user_nickname"),
|
||||
user_avatar=sender.get("avatar_url"),
|
||||
)
|
||||
|
||||
return BaseMessageInfo(
|
||||
platform=platform,
|
||||
message_id=message_id,
|
||||
time=time_value,
|
||||
group_info=group_info,
|
||||
user_info=user_info,
|
||||
additional_config=envelope.get("metadata"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_seg(payload: Any) -> Seg:
|
||||
"""将任意 payload 转为 Seg dataclass。"""
|
||||
if isinstance(payload, Seg):
|
||||
return payload
|
||||
if isinstance(payload, list):
|
||||
# 直接传入 Seg 列表或 seglist data
|
||||
return Seg(type="seglist", data=[EnvelopeConverter._ensure_seg(item) for item in payload])
|
||||
if isinstance(payload, dict):
|
||||
seg_type = payload.get("type") or "text"
|
||||
data = payload.get("data")
|
||||
if seg_type == "seglist" and isinstance(data, list):
|
||||
data = [EnvelopeConverter._ensure_seg(item) for item in data]
|
||||
return Seg(type=seg_type, data=data)
|
||||
# 兜底:转成文本片段
|
||||
return Seg(type="text", data="" if payload is None else str(payload))
|
||||
|
||||
@staticmethod
|
||||
def _flatten_segments(seg: Seg) -> List[Seg]:
|
||||
"""将 Seg/seglist 打平成列表,便于旧 content 转换。"""
|
||||
if seg.type == "seglist" and isinstance(seg.data, list):
|
||||
return [item if isinstance(item, Seg) else EnvelopeConverter._ensure_seg(item) for item in seg.data]
|
||||
return [seg]
|
||||
|
||||
@staticmethod
|
||||
def _content_to_segments(content: Any) -> List[Seg]:
|
||||
"""
|
||||
Convert legacy Content (type/data/metadata) to a flat list of Seg.
|
||||
"""
|
||||
segments: List[Seg] = []
|
||||
|
||||
def _walk(node: Any) -> None:
|
||||
if node is None:
|
||||
return
|
||||
if isinstance(node, list):
|
||||
for item in node:
|
||||
_walk(item)
|
||||
return
|
||||
if not isinstance(node, dict):
|
||||
logger.warning("未知的 content 节点类型: %s", type(node))
|
||||
return
|
||||
|
||||
content_type = node.get("type")
|
||||
data = node.get("data")
|
||||
metadata = node.get("metadata") or {}
|
||||
|
||||
if content_type == "collection":
|
||||
items = data if isinstance(data, list) else node.get("items", [])
|
||||
for item in items:
|
||||
_walk(item)
|
||||
return
|
||||
|
||||
if content_type in ("text", "at"):
|
||||
subtype = metadata.get("subtype") or ("at" if content_type == "at" else None)
|
||||
text = "" if data is None else str(data)
|
||||
if subtype in ("at", "mention"):
|
||||
user_info = metadata.get("user") or {}
|
||||
seg_data: Dict[str, Any] = {
|
||||
"user_id": user_info.get("id") or user_info.get("user_id"),
|
||||
"user_name": user_info.get("name") or user_info.get("display_name"),
|
||||
"text": text,
|
||||
"raw": user_info.get("raw") or user_info if user_info else None,
|
||||
}
|
||||
if any(v is not None for v in seg_data.values()):
|
||||
segments.append(Seg(type="at", data=seg_data))
|
||||
else:
|
||||
segments.append(Seg(type="at", data=text))
|
||||
else:
|
||||
segments.append(Seg(type="text", data=text))
|
||||
return
|
||||
|
||||
if content_type == "image":
|
||||
url = ""
|
||||
if isinstance(data, dict):
|
||||
url = data.get("url") or data.get("file") or data.get("file_id") or ""
|
||||
elif data is not None:
|
||||
url = str(data)
|
||||
segments.append(Seg(type="image", data=url))
|
||||
return
|
||||
|
||||
if content_type == "audio":
|
||||
url = ""
|
||||
if isinstance(data, dict):
|
||||
url = data.get("url") or data.get("file") or data.get("file_id") or ""
|
||||
elif data is not None:
|
||||
url = str(data)
|
||||
segments.append(Seg(type="record", data=url))
|
||||
return
|
||||
|
||||
if content_type == "video":
|
||||
url = ""
|
||||
if isinstance(data, dict):
|
||||
url = data.get("url") or data.get("file") or data.get("file_id") or ""
|
||||
elif data is not None:
|
||||
url = str(data)
|
||||
segments.append(Seg(type="video", data=url))
|
||||
return
|
||||
|
||||
if content_type == "file":
|
||||
file_name = ""
|
||||
if isinstance(data, dict):
|
||||
file_name = data.get("file_name") or data.get("name") or ""
|
||||
text = file_name or "[file]"
|
||||
segments.append(Seg(type="text", data=text))
|
||||
return
|
||||
|
||||
if content_type == "command":
|
||||
name = ""
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(data, dict):
|
||||
name = data.get("name", "")
|
||||
args = data.get("args", {}) or {}
|
||||
else:
|
||||
name = str(data or "")
|
||||
cmd_text = f"/{name}" if name else "/command"
|
||||
if args:
|
||||
cmd_text += " " + " ".join(f"{k}={v}" for k, v in args.items())
|
||||
segments.append(Seg(type="text", data=cmd_text))
|
||||
return
|
||||
|
||||
if content_type == "event":
|
||||
event_type = ""
|
||||
if isinstance(data, dict):
|
||||
event_type = data.get("event_type", "")
|
||||
else:
|
||||
event_type = str(data or "")
|
||||
segments.append(Seg(type="text", data=f"[事件: {event_type or 'unknown'}]"))
|
||||
return
|
||||
|
||||
if content_type == "system":
|
||||
text = "" if data is None else str(data)
|
||||
segments.append(Seg(type="text", data=f"[系统] {text}"))
|
||||
return
|
||||
|
||||
logger.warning(f"未知的消息类型: {content_type}")
|
||||
segments.append(Seg(type="text", data=f"[未知消息类型: {content_type}]"))
|
||||
|
||||
_walk(content)
|
||||
return segments
|
||||
|
||||
@staticmethod
|
||||
def to_legacy_dict(envelope: MessageEnvelope) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert MessageEnvelope to legacy dict for backward compatibility.
|
||||
"""
|
||||
message_base = EnvelopeConverter.to_message_base(envelope)
|
||||
return message_base.to_dict()
|
||||
|
||||
@staticmethod
|
||||
def from_message_base(message: MessageBase, direction: str = "outgoing") -> MessageEnvelope:
|
||||
"""
|
||||
Convert MessageBase to MessageEnvelope (maim_message style preferred).
|
||||
"""
|
||||
try:
|
||||
info_dict = message.message_info.to_dict()
|
||||
seg_dict = message.message_segment.to_dict()
|
||||
|
||||
envelope: MessageEnvelope = {
|
||||
"direction": direction,
|
||||
"message_info": info_dict,
|
||||
"message_segment": seg_dict,
|
||||
"platform": info_dict.get("platform"),
|
||||
"message_id": info_dict.get("message_id"),
|
||||
"schema_version": 1,
|
||||
}
|
||||
|
||||
if message.message_info.time is not None:
|
||||
envelope["timestamp_ms"] = int(message.message_info.time * 1000)
|
||||
if message.raw_message is not None:
|
||||
envelope["raw_message"] = message.raw_message
|
||||
|
||||
# legacy 补充,方便老代码继续工作
|
||||
segments = EnvelopeConverter._flatten_segments(message.message_segment)
|
||||
envelope["content"] = EnvelopeConverter._segments_to_content(segments)
|
||||
if message.message_info.user_info:
|
||||
envelope["sender"] = {
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"role": "assistant" if direction == "outgoing" else "user",
|
||||
"display_name": message.message_info.user_info.user_nickname,
|
||||
"avatar_url": getattr(message.message_info.user_info, "user_avatar", None),
|
||||
}
|
||||
if message.message_info.group_info:
|
||||
envelope["channel"] = {
|
||||
"channel_id": message.message_info.group_info.group_id,
|
||||
"channel_type": "group",
|
||||
"title": message.message_info.group_info.group_name,
|
||||
}
|
||||
|
||||
return envelope
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换 MessageBase 失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _segments_to_content(segments: List[Seg]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert Seg list to legacy Content (type/data/metadata).
|
||||
"""
|
||||
if not segments:
|
||||
return {"type": "text", "data": ""}
|
||||
|
||||
def _seg_to_content(seg: Seg) -> Dict[str, Any]:
|
||||
data = seg.data
|
||||
|
||||
if seg.type == "text":
|
||||
return {"type": "text", "data": data}
|
||||
|
||||
if seg.type == "at":
|
||||
content: Dict[str, Any] = {"type": "text", "data": ""}
|
||||
metadata: Dict[str, Any] = {"subtype": "at"}
|
||||
if isinstance(data, dict):
|
||||
content["data"] = data.get("text", "")
|
||||
user = {
|
||||
"id": data.get("user_id"),
|
||||
"name": data.get("user_name"),
|
||||
"raw": data.get("raw"),
|
||||
}
|
||||
if any(v is not None for v in user.values()):
|
||||
metadata["user"] = user
|
||||
else:
|
||||
content["data"] = data
|
||||
if metadata:
|
||||
content["metadata"] = metadata
|
||||
return content
|
||||
|
||||
if seg.type == "image":
|
||||
return {"type": "image", "data": data}
|
||||
|
||||
if seg.type in ("record", "voice", "audio"):
|
||||
return {"type": "audio", "data": data}
|
||||
|
||||
if seg.type == "video":
|
||||
return {"type": "video", "data": data}
|
||||
|
||||
return {"type": seg.type, "data": data}
|
||||
|
||||
if len(segments) == 1:
|
||||
return _seg_to_content(segments[0])
|
||||
|
||||
return {"type": "collection", "data": [_seg_to_content(seg) for seg in segments]}
|
||||
|
||||
|
||||
__all__ = ["EnvelopeConverter"]
|
||||
26
src/main.py
26
src/main.py
@@ -18,7 +18,6 @@ from src.chat.message_receive.bot import chat_bot
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message.envelope_converter import EnvelopeConverter
|
||||
|
||||
# 全局背景任务集合
|
||||
_background_tasks = set()
|
||||
@@ -77,7 +76,7 @@ class MainSystem:
|
||||
self.individuality: Individuality = get_individuality()
|
||||
|
||||
# 创建核心消息接收器
|
||||
self.core_sink: InProcessCoreSink = InProcessCoreSink(self._handle_message_envelope)
|
||||
self.core_sink: InProcessCoreSink = InProcessCoreSink(self._message_process_wrapper)
|
||||
|
||||
# 使用服务器
|
||||
self.server: Server = get_global_server()
|
||||
@@ -353,19 +352,18 @@ class MainSystem:
|
||||
except Exception as e:
|
||||
logger.error(f"同步清理资源时出错: {e}")
|
||||
|
||||
async def _message_process_wrapper(self, message_data: dict[str, Any]) -> None:
|
||||
async def _message_process_wrapper(self, envelope: MessageEnvelope) -> None:
|
||||
"""并行处理消息的包装器"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
message_id = message_data.get("message_info", {}).get("message_id", "UNKNOWN")
|
||||
|
||||
message_id = envelope.get("message_info", {}).get("message_id", "UNKNOWN")
|
||||
# 检查系统是否正在关闭
|
||||
if self._shutting_down:
|
||||
logger.warning(f"系统正在关闭,拒绝处理消息 {message_id}")
|
||||
return
|
||||
|
||||
# 创建后台任务
|
||||
task = asyncio.create_task(chat_bot.message_process(message_data))
|
||||
task = asyncio.create_task(chat_bot.message_process(envelope))
|
||||
logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})")
|
||||
|
||||
# 添加一个回调函数,当任务完成时,它会被调用
|
||||
@@ -374,22 +372,6 @@ class MainSystem:
|
||||
logger.error("在创建消息处理任务时发生严重错误:")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _handle_message_envelope(self, envelope: MessageEnvelope) -> None:
|
||||
"""
|
||||
处理来自适配器的 MessageEnvelope
|
||||
|
||||
Args:
|
||||
envelope: 统一的消息信封
|
||||
"""
|
||||
try:
|
||||
# 转换为旧版格式
|
||||
message_data = EnvelopeConverter.to_legacy_dict(envelope)
|
||||
|
||||
# 使用现有的消息处理流程
|
||||
await self._message_process_wrapper(message_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理 MessageEnvelope 时出错: {e}", exc_info=True)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化系统组件"""
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import threading
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Dict, Iterable, List, Protocol
|
||||
|
||||
@@ -37,6 +39,7 @@ class MessageRoute:
|
||||
handler: MessageHandler
|
||||
name: str | None = None
|
||||
message_type: str | None = None
|
||||
message_types: set[str] | None = None # 支持多个消息类型
|
||||
event_types: set[str] | None = None
|
||||
|
||||
|
||||
@@ -55,6 +58,8 @@ class MessageRuntime:
|
||||
self._middlewares: list[Middleware] = []
|
||||
self._type_routes: Dict[str, list[MessageRoute]] = {}
|
||||
self._event_routes: Dict[str, list[MessageRoute]] = {}
|
||||
# 用于检测同一类型的重复注册
|
||||
self._explicit_type_handlers: Dict[str, str] = {} # message_type -> handler_name
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
@@ -62,7 +67,7 @@ class MessageRuntime:
|
||||
handler: MessageHandler,
|
||||
name: str | None = None,
|
||||
*,
|
||||
message_type: str | None = None,
|
||||
message_type: str | list[str] | None = None,
|
||||
event_types: Iterable[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -72,28 +77,72 @@ class MessageRuntime:
|
||||
predicate: 路由匹配条件
|
||||
handler: 消息处理函数
|
||||
name: 路由名称(可选)
|
||||
message_type: 消息类型(可选)
|
||||
message_type: 消息类型,可以是字符串或字符串列表(可选)
|
||||
event_types: 事件类型列表(可选)
|
||||
"""
|
||||
with self._lock:
|
||||
# 处理 message_type 参数,支持字符串或列表
|
||||
message_types_set: set[str] | None = None
|
||||
single_message_type: str | None = None
|
||||
|
||||
if message_type is not None:
|
||||
if isinstance(message_type, str):
|
||||
message_types_set = {message_type}
|
||||
single_message_type = message_type
|
||||
elif isinstance(message_type, list):
|
||||
message_types_set = set(message_type)
|
||||
if len(message_types_set) == 1:
|
||||
single_message_type = next(iter(message_types_set))
|
||||
else:
|
||||
raise TypeError(f"message_type must be str or list[str], got {type(message_type)}")
|
||||
|
||||
# 检测重复注册:如果明确指定了某个类型,不允许重复
|
||||
handler_name = name or getattr(handler, "__name__", str(handler))
|
||||
for msg_type in message_types_set:
|
||||
if msg_type in self._explicit_type_handlers:
|
||||
existing_handler = self._explicit_type_handlers[msg_type]
|
||||
raise ValueError(
|
||||
f"消息类型 '{msg_type}' 已被处理器 '{existing_handler}' 明确注册,"
|
||||
f"不能再由 '{handler_name}' 注册。同一消息类型只能有一个明确的处理器。"
|
||||
)
|
||||
self._explicit_type_handlers[msg_type] = handler_name
|
||||
|
||||
route = MessageRoute(
|
||||
predicate=predicate,
|
||||
handler=handler,
|
||||
name=name,
|
||||
message_type=message_type,
|
||||
message_type=single_message_type,
|
||||
message_types=message_types_set,
|
||||
event_types=set(event_types) if event_types is not None else None,
|
||||
)
|
||||
self._routes.append(route)
|
||||
if message_type:
|
||||
self._type_routes.setdefault(message_type, []).append(route)
|
||||
|
||||
# 为每个消息类型建立索引
|
||||
if message_types_set:
|
||||
for msg_type in message_types_set:
|
||||
self._type_routes.setdefault(msg_type, []).append(route)
|
||||
|
||||
if route.event_types:
|
||||
for et in route.event_types:
|
||||
self._event_routes.setdefault(et, []).append(route)
|
||||
|
||||
def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]:
|
||||
"""装饰器写法,便于在核心逻辑中声明式注册。"""
|
||||
"""装饰器写法,便于在核心逻辑中声明式注册。
|
||||
|
||||
支持普通函数和类方法。对于类方法,会在实例创建时自动绑定并注册路由。
|
||||
"""
|
||||
|
||||
def decorator(func: MessageHandler) -> MessageHandler:
|
||||
# Support decorating instance methods: defer binding until the object is created.
|
||||
if _looks_like_method(func):
|
||||
return _InstanceMethodRoute(
|
||||
runtime=self,
|
||||
func=func,
|
||||
predicate=predicate,
|
||||
name=name,
|
||||
message_type=None,
|
||||
)
|
||||
|
||||
self.add_route(predicate, func, name=name)
|
||||
return func
|
||||
|
||||
@@ -101,17 +150,46 @@ class MessageRuntime:
|
||||
|
||||
def on_message(
|
||||
self,
|
||||
func: MessageHandler | None = None,
|
||||
*,
|
||||
message_type: str | None = None,
|
||||
message_type: str | list[str] | None = None,
|
||||
platform: str | None = None,
|
||||
predicate: Predicate | None = None,
|
||||
name: str | None = None,
|
||||
) -> Callable[[MessageHandler], MessageHandler]:
|
||||
"""Sugar 装饰器,基于 Seg.type/platform 及可选额外谓词匹配。"""
|
||||
) -> Callable[[MessageHandler], MessageHandler] | MessageHandler:
|
||||
"""Sugar decorator with optional Seg.type/platform predicate matching.
|
||||
|
||||
Args:
|
||||
func: 被装饰的函数
|
||||
message_type: 消息类型,可以是单个字符串或字符串列表
|
||||
platform: 平台名称
|
||||
predicate: 自定义匹配条件
|
||||
name: 路由名称
|
||||
|
||||
Usages:
|
||||
- @runtime.on_message(...)
|
||||
- @runtime.on_message
|
||||
- @runtime.on_message(message_type="text")
|
||||
- @runtime.on_message(message_type=["text", "image"])
|
||||
|
||||
If the target looks like an instance method (first arg is self), it will be
|
||||
auto-bound to the instance and registered when the object is constructed.
|
||||
"""
|
||||
# 将 message_type 转换为集合以便统一处理
|
||||
message_types_set: set[str] | None = None
|
||||
if message_type is not None:
|
||||
if isinstance(message_type, str):
|
||||
message_types_set = {message_type}
|
||||
elif isinstance(message_type, list):
|
||||
message_types_set = set(message_type)
|
||||
else:
|
||||
raise TypeError(f"message_type must be str or list[str], got {type(message_type)}")
|
||||
|
||||
async def combined_predicate(message: MessageEnvelope) -> bool:
|
||||
if message_type is not None and _extract_segment_type(message) != message_type:
|
||||
return False
|
||||
if message_types_set is not None:
|
||||
extracted_type = _extract_segment_type(message)
|
||||
if extracted_type not in message_types_set:
|
||||
return False
|
||||
if platform is not None:
|
||||
info_platform = message.get("message_info", {}).get("platform")
|
||||
if message.get("platform") not in (None, platform) and info_platform is None:
|
||||
@@ -123,35 +201,24 @@ class MessageRuntime:
|
||||
return await _invoke_callable(predicate, message, prefer_thread=False)
|
||||
|
||||
def decorator(func: MessageHandler) -> MessageHandler:
|
||||
# Support decorating instance methods: defer binding until the object is created.
|
||||
if _looks_like_method(func):
|
||||
return _InstanceMethodRoute(
|
||||
runtime=self,
|
||||
func=func,
|
||||
predicate=combined_predicate,
|
||||
name=name,
|
||||
message_type=message_type,
|
||||
)
|
||||
|
||||
self.add_route(combined_predicate, func, name=name, message_type=message_type)
|
||||
return func
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
return decorator
|
||||
|
||||
def on_event(
|
||||
self,
|
||||
event_type: str | Iterable[str],
|
||||
*,
|
||||
name: str | None = None,
|
||||
) -> Callable[[MessageHandler], MessageHandler]:
|
||||
"""装饰器,基于 message 或 message_info.additional_config 中的 event_type 匹配。"""
|
||||
|
||||
allowed = {event_type} if isinstance(event_type, str) else set(event_type)
|
||||
|
||||
async def predicate(message: MessageEnvelope) -> bool:
|
||||
current = (
|
||||
message.get("event_type")
|
||||
or message.get("message_info", {})
|
||||
.get("additional_config", {})
|
||||
.get("event_type")
|
||||
)
|
||||
return current in allowed
|
||||
|
||||
def decorator(func: MessageHandler) -> MessageHandler:
|
||||
self.add_route(predicate, func, name=name, event_types=allowed)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def set_batch_handler(self, handler: BatchHandler) -> None:
|
||||
self._batch_handler = handler
|
||||
@@ -199,7 +266,7 @@ class MessageRuntime:
|
||||
return responses
|
||||
|
||||
async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None:
|
||||
candidates: list[MessageRoute] = []
|
||||
"""匹配消息路由,优先匹配明确指定了消息类型的处理器"""
|
||||
message_type = _extract_segment_type(message)
|
||||
event_type = (
|
||||
message.get("event_type")
|
||||
@@ -207,15 +274,29 @@ class MessageRuntime:
|
||||
.get("additional_config", {})
|
||||
.get("event_type")
|
||||
)
|
||||
|
||||
# 分为两层候选:优先级和普通
|
||||
priority_candidates: list[MessageRoute] = [] # 明确指定了消息类型的
|
||||
normal_candidates: list[MessageRoute] = [] # 没有指定或通配的
|
||||
|
||||
with self._lock:
|
||||
# 事件路由(优先级最高)
|
||||
if event_type and event_type in self._event_routes:
|
||||
candidates.extend(self._event_routes[event_type])
|
||||
priority_candidates.extend(self._event_routes[event_type])
|
||||
|
||||
# 消息类型路由(明确指定的有优先级)
|
||||
if message_type and message_type in self._type_routes:
|
||||
candidates.extend(self._type_routes[message_type])
|
||||
candidates.extend(self._routes)
|
||||
priority_candidates.extend(self._type_routes[message_type])
|
||||
|
||||
# 通用路由(没有明确指定类型的)
|
||||
for route in self._routes:
|
||||
# 如果路由没有指定 message_types,则是通用路由
|
||||
if route.message_types is None and route.event_types is None:
|
||||
normal_candidates.append(route)
|
||||
|
||||
# 先尝试优先级候选
|
||||
seen: set[int] = set()
|
||||
for route in candidates:
|
||||
for route in priority_candidates:
|
||||
rid = id(route)
|
||||
if rid in seen:
|
||||
continue
|
||||
@@ -223,6 +304,17 @@ class MessageRuntime:
|
||||
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
|
||||
if should_handle:
|
||||
return route
|
||||
|
||||
# 如果没有匹配到优先级候选,再尝试普通候选
|
||||
for route in normal_candidates:
|
||||
rid = id(route)
|
||||
if rid in seen:
|
||||
continue
|
||||
seen.add(rid)
|
||||
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
|
||||
if should_handle:
|
||||
return route
|
||||
|
||||
return None
|
||||
|
||||
async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None:
|
||||
@@ -257,7 +349,27 @@ class MessageRuntime:
|
||||
|
||||
|
||||
async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False):
|
||||
"""支持 sync/async 调用,并可选择在线程中执行。"""
|
||||
"""支持 sync/async 调用,并可选择在线程中执行。
|
||||
|
||||
自动处理普通函数、类方法和绑定方法。
|
||||
"""
|
||||
# 如果是绑定方法(bound method),直接使用,不需要额外处理
|
||||
# 因为绑定方法已经包含了 self 参数
|
||||
if inspect.ismethod(func):
|
||||
# 绑定方法可以直接调用,args 中不应包含 self
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args)
|
||||
if prefer_thread:
|
||||
result = await asyncio.to_thread(func, *args)
|
||||
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||
return await result
|
||||
return result
|
||||
result = func(*args)
|
||||
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||
return await result
|
||||
return result
|
||||
|
||||
# 对于普通函数(未绑定的),按原有逻辑处理
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args)
|
||||
if prefer_thread:
|
||||
@@ -282,6 +394,70 @@ def _extract_segment_type(message: MessageEnvelope) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _looks_like_method(func: Callable[..., object]) -> bool:
|
||||
"""Return True if callable signature suggests an instance method (first arg named self)."""
|
||||
if inspect.ismethod(func):
|
||||
return True
|
||||
if not inspect.isfunction(func):
|
||||
return False
|
||||
params = inspect.signature(func).parameters
|
||||
if not params:
|
||||
return False
|
||||
first = next(iter(params.values()))
|
||||
return first.name == "self"
|
||||
|
||||
|
||||
class _InstanceMethodRoute:
|
||||
"""Descriptor that binds decorated instance methods and registers routes per-instance."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runtime: MessageRuntime,
|
||||
func: MessageHandler,
|
||||
predicate: Predicate,
|
||||
name: str | None,
|
||||
message_type: str | None,
|
||||
) -> None:
|
||||
self._runtime = runtime
|
||||
self._func = func
|
||||
self._predicate = predicate
|
||||
self._name = name
|
||||
self._message_type = message_type
|
||||
self._owner: type | None = None
|
||||
self._registered_instances: weakref.WeakSet[object] = weakref.WeakSet()
|
||||
|
||||
def __set_name__(self, owner: type, name: str) -> None:
|
||||
self._owner = owner
|
||||
registry: list[_InstanceMethodRoute] | None = getattr(owner, "_mofox_instance_routes", None)
|
||||
if registry is None:
|
||||
registry = []
|
||||
setattr(owner, "_mofox_instance_routes", registry)
|
||||
original_init = owner.__init__
|
||||
|
||||
@functools.wraps(original_init)
|
||||
def wrapped_init(inst, *args, **kwargs):
|
||||
original_init(inst, *args, **kwargs)
|
||||
for descriptor in getattr(inst.__class__, "_mofox_instance_routes", []):
|
||||
descriptor._register_instance(inst)
|
||||
|
||||
owner.__init__ = wrapped_init # type: ignore[assignment]
|
||||
registry.append(self)
|
||||
|
||||
def _register_instance(self, instance: object) -> None:
|
||||
if instance in self._registered_instances:
|
||||
return
|
||||
owner = self._owner or instance.__class__
|
||||
bound = self._func.__get__(instance, owner) # type: ignore[arg-type]
|
||||
self._runtime.add_route(self._predicate, bound, name=self._name, message_type=self._message_type)
|
||||
self._registered_instances.add(instance)
|
||||
|
||||
def __get__(self, instance: object | None, owner: type | None = None):
|
||||
if instance is None:
|
||||
return self._func
|
||||
self._register_instance(instance)
|
||||
return self._func.__get__(instance, owner) # type: ignore[arg-type]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BatchHandler",
|
||||
"Hook",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Literal, NotRequired, TypedDict
|
||||
from typing import Any, Dict, List, Literal, NotRequired, TypedDict, Required
|
||||
|
||||
MessageDirection = Literal["incoming", "outgoing"]
|
||||
|
||||
@@ -14,14 +14,14 @@ class SegPayload(TypedDict, total=False):
|
||||
对齐 maim_message.Seg 的片段定义,使用纯 dict 便于 JSON 传输。
|
||||
"""
|
||||
|
||||
type: str
|
||||
data: str | List["SegPayload"]
|
||||
type: Required[str]
|
||||
data: Required[str | List["SegPayload"]]
|
||||
translated_data: NotRequired[str | List["SegPayload"]]
|
||||
|
||||
|
||||
class UserInfoPayload(TypedDict, total=False):
|
||||
platform: NotRequired[str]
|
||||
user_id: NotRequired[str]
|
||||
user_id: Required[str]
|
||||
user_nickname: NotRequired[str]
|
||||
user_cardname: NotRequired[str]
|
||||
user_avatar: NotRequired[str]
|
||||
@@ -29,7 +29,7 @@ class UserInfoPayload(TypedDict, total=False):
|
||||
|
||||
class GroupInfoPayload(TypedDict, total=False):
|
||||
platform: NotRequired[str]
|
||||
group_id: NotRequired[str]
|
||||
group_id: Required[str]
|
||||
group_name: NotRequired[str]
|
||||
|
||||
|
||||
@@ -45,8 +45,8 @@ class TemplateInfoPayload(TypedDict, total=False):
|
||||
|
||||
|
||||
class MessageInfoPayload(TypedDict, total=False):
|
||||
platform: NotRequired[str]
|
||||
message_id: NotRequired[str]
|
||||
platform: Required[str]
|
||||
message_id: Required[str]
|
||||
time: NotRequired[float]
|
||||
group_info: NotRequired[GroupInfoPayload]
|
||||
user_info: NotRequired[UserInfoPayload]
|
||||
@@ -67,8 +67,8 @@ class MessageEnvelope(TypedDict, total=False):
|
||||
"""
|
||||
|
||||
direction: MessageDirection
|
||||
message_info: MessageInfoPayload
|
||||
message_segment: SegPayload | List[SegPayload]
|
||||
message_info: Required[MessageInfoPayload]
|
||||
message_segment: Required[SegPayload] | List[SegPayload]
|
||||
raw_message: NotRequired[Any]
|
||||
raw_bytes: NotRequired[bytes]
|
||||
message_chain: NotRequired[List[SegPayload]] # seglist 的直观别名
|
||||
|
||||
Reference in New Issue
Block a user