重构消息处理和信封转换

- 从代码库中移除了EnvelopeConverter类及其相关方法,因为它们已不再需要。
- 更新了主系统,使其能够直接处理MessageEnvelope对象,而无需将其转换为旧格式。
- 增强了MessageRuntime类,以支持多种消息类型并防止重复注册处理程序。
引入了一个新的MessageHandler类来管理消息处理,包括预处理和数据库存储。
- 改进了整个消息处理工作流程中的错误处理和日志记录。
- 更新了类型提示和数据模型,以确保消息结构的一致性和清晰度。
This commit is contained in:
Windpicker-owo
2025-11-24 22:36:33 +08:00
parent 81a209ed87
commit d30b0544b5
8 changed files with 685 additions and 724 deletions

View File

@@ -3,8 +3,8 @@ import re
import traceback import traceback
from typing import Any from typing import Any
from mofox_bus import UserInfo from mofox_bus.runtime import MessageRuntime
from mofox_bus import MessageEnvelope
from src.chat.message_manager import message_manager from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
@@ -63,13 +63,13 @@ def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
return True return True
return False return False
runtime = MessageRuntime() # 获取mofox-bus运行时环境
class ChatBot: class ChatBot:
def __init__(self): def __init__(self):
self.bot = None # bot 实例引用 self.bot = None # bot 实例引用
self._started = False self._started = False
self.mood_manager = mood_manager # 获取情绪管理器单例 self.mood_manager = mood_manager # 获取情绪管理器单例
# 启动消息管理器 # 启动消息管理器
self._message_manager_started = False self._message_manager_started = False
@@ -304,48 +304,29 @@ class ChatBot:
except Exception as e: except Exception as e:
logger.error(f"处理适配器响应时出错: {e}") logger.error(f"处理适配器响应时出错: {e}")
async def message_process(self, message_data: dict[str, Any]) -> None: @runtime.on_message
async def message_process(self, envelope: MessageEnvelope) -> None:
"""处理转化后的统一格式消息""" """处理转化后的统一格式消息"""
try:
# 首先处理可能的切片消息重组
from src.utils.message_chunker import reassembler
# 尝试重组切片消息
reassembled_message = await reassembler.process_chunk(message_data)
if reassembled_message is None:
# 这是一个切片,但还未完整,等待更多切片
logger.debug("等待更多切片,跳过此次处理")
return
elif reassembled_message != message_data:
# 消息已被重组,使用重组后的消息
logger.info("使用重组后的完整消息进行处理")
message_data = reassembled_message
# 确保所有任务已启动
await self._ensure_started()
# 控制握手等消息可能缺少 message_info这里直接跳过避免 KeyError # 控制握手等消息可能缺少 message_info这里直接跳过避免 KeyError
message_info = message_data.get("message_info") message_info = envelope.get("message_info")
if not isinstance(message_info, dict): if not isinstance(message_info, dict):
logger.debug( logger.debug(
"收到缺少 message_info 的消息,已跳过。可用字段: %s", "收到缺少 message_info 的消息,已跳过。可用字段: %s",
", ".join(message_data.keys()), ", ".join(envelope.keys()),
) )
return return
if message_info.get("group_info") is not None: if message_info.get("group_info") is not None:
message_info["group_info"]["group_id"] = str( message_info["group_info"]["group_id"] = str( # type: ignore
message_info["group_info"]["group_id"] message_info["group_info"]["group_id"] # type: ignore
) )
if message_info.get("user_info") is not None: if message_info.get("user_info") is not None:
message_info["user_info"]["user_id"] = str( message_info["user_info"]["user_id"] = str( # type: ignore
message_info["user_info"]["user_id"] message_info["user_info"]["user_id"] # type: ignore
) )
# print(message_data)
# logger.debug(str(message_data))
# 优先处理adapter_response消息在echo检查之前 # 优先处理adapter_response消息在echo检查之前
message_segment = message_data.get("message_segment") message_segment = envelope.get("message_segment")
if message_segment and isinstance(message_segment, dict): if message_segment and isinstance(message_segment, dict):
if message_segment.get("type") == "adapter_response": if message_segment.get("type") == "adapter_response":
logger.info("[DEBUG bot.py message_process] 检测到adapter_response立即处理") logger.info("[DEBUG bot.py message_process] 检测到adapter_response立即处理")
@@ -362,7 +343,7 @@ class ChatBot:
await MessageStorage.update_message(message_data) await MessageStorage.update_message(message_data)
return return
message_segment = message_data.get("message_segment") message_segment = envelope.get("message_segment")
group_info = temp_message_info.group_info group_info = temp_message_info.group_info
user_info = temp_message_info.user_info user_info = temp_message_info.user_info
@@ -376,7 +357,7 @@ class ChatBot:
# 使用新的消息处理器直接生成 DatabaseMessages # 使用新的消息处理器直接生成 DatabaseMessages
from src.chat.message_receive.message_processor import process_message_from_dict from src.chat.message_receive.message_processor import process_message_from_dict
message = await process_message_from_dict( message = await process_message_from_dict(
message_dict=message_data, message_dict=envelope,
stream_id=chat.stream_id, stream_id=chat.stream_id,
platform=chat.platform platform=chat.platform
) )

View File

@@ -4,7 +4,6 @@ from dataclasses import dataclass
from typing import Optional from typing import Optional
import urllib3 import urllib3
from mofox_bus import BaseMessageInfo, MessageBase, Seg, UserInfo
from rich.traceback import install from rich.traceback import install
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream

View File

@@ -0,0 +1,110 @@
import os
import traceback
from mofox_bus.runtime import MessageRuntime
from mofox_bus import MessageEnvelope
from src.chat.message_manager import message_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.common.data_models.database_data_model import DatabaseGroupInfo, DatabaseUserInfo, DatabaseMessages
runtime = MessageRuntime()
# 获取项目根目录假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
# 配置主程序日志格式
logger = get_logger("chat")
class MessageHandler:
def __init__(self):
self._started = False
async def preprocess(self, chat: ChatStream, message: DatabaseMessages):
# message 已经是 DatabaseMessages直接使用
group_info = chat.group_info
# 先交给消息管理器处理
try:
# 在将消息添加到管理器之前进行最终的静默检查
should_process_in_manager = True
if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list:
# 检查消息是否为图片或表情包
is_image_or_emoji = message.is_picid or message.is_emoji
if not message.is_mentioned and not is_image_or_emoji:
logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理")
should_process_in_manager = False
elif is_image_or_emoji:
logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理")
should_process_in_manager = False
if should_process_in_manager:
await message_manager.add_message(chat.stream_id, message)
logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
except Exception as e:
logger.error(f"消息添加到消息管理器失败: {e}")
# 存储消息到数据库,只进行一次写入
try:
await MessageStorage.store_message(message, chat)
except Exception as e:
logger.error(f"存储消息到数据库失败: {e}")
traceback.print_exc()
# 情绪系统更新 - 在消息存储后触发情绪更新
try:
if global_config.mood.enable_mood:
# 获取兴趣度用于情绪更新
interest_rate = message.interest_value
if interest_rate is None:
interest_rate = 0.0
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
# 获取当前聊天的情绪对象并更新情绪状态
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
await chat_mood.update_mood_by_message(message, interest_rate)
logger.debug("情绪状态更新完成")
except Exception as e:
logger.error(f"更新情绪状态失败: {e}")
traceback.print_exc()
async def handle_message(self, envelope: MessageEnvelope):
# 控制握手等消息可能缺少 message_info这里直接跳过避免 KeyError
message_info = envelope.get("message_info")
if not isinstance(message_info, dict):
logger.debug(
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
", ".join(envelope.keys()),
)
return
if message_info.get("group_info") is not None:
message_info["group_info"]["group_id"] = str( # type: ignore
message_info["group_info"]["group_id"] # type: ignore
)
if message_info.get("user_info") is not None:
message_info["user_info"]["user_id"] = str( # type: ignore
message_info["user_info"]["user_id"] # type: ignore
)
group_info = message_info.get("group_info")
user_info = message_info.get("user_info")
chat_stream = await get_chat_manager().get_or_create_stream(
platform=envelope["platform"], # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
# 生成 DatabaseMessages
from src.chat.message_receive.message_processor import process_message_from_dict
message = await process_message_from_dict(
message_dict=envelope,
stream_id=chat_stream.stream_id,
platform=chat_stream.platform
)

View File

@@ -1,13 +1,14 @@
"""消息处理工具模块 """消息处理工具模块
将原 MessageRecv 的消息处理逻辑提取为独立函数, 将原 MessageRecv 的消息处理逻辑提取为独立函数,
直接从适配器消息字典生成 DatabaseMessages 基于 mofox-bus 的 TypedDict 形式构建消息数据,然后转换为 DatabaseMessages
""" """
import base64 import base64
import time import time
from typing import Any from typing import Any
import orjson import orjson
from mofox_bus import BaseMessageInfo, Seg from mofox_bus import MessageEnvelope
from mofox_bus.types import MessageInfoPayload, SegPayload, UserInfoPayload, GroupInfoPayload
from src.chat.utils.self_voice_cache import consume_self_voice_text from src.chat.utils.self_voice_cache import consume_self_voice_text
from src.chat.utils.utils_image import get_image_manager from src.chat.utils.utils_image import get_image_manager
@@ -20,25 +21,25 @@ from src.config.config import global_config
logger = get_logger("message_processor") logger = get_logger("message_processor")
async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str, platform: str) -> DatabaseMessages: async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: str, platform: str) -> DatabaseMessages:
"""从适配器消息字典处理并生成 DatabaseMessages """从适配器消息字典处理并生成 DatabaseMessages
这个函数整合了原 MessageRecv 的所有处理逻辑: 这个函数整合了原 MessageRecv 的所有处理逻辑:
1. 解析 message_segment 并异步处理内容(图片、语音、视频等) 1. 解析 message_segment 并异步处理内容(图片、语音、视频等)
2. 提取所有消息元数据 2. 提取所有消息元数据
3. 直接构造 DatabaseMessages 对象 3. 基于 TypedDict 形式构建数据,然后转换为 DatabaseMessages
Args: Args:
message_dict: MessageCQ序列化后的字典 message_dict: MessageEnvelope 格式的消息字典
stream_id: 聊天流ID stream_id: 聊天流ID
platform: 平台标识 platform: 平台标识
Returns: Returns:
DatabaseMessages: 处理完成的数据库消息对象 DatabaseMessages: 处理完成的数据库消息对象
""" """
# 解析基础信息 # 提取核心数据(使用 TypedDict 类型)
message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) message_info: MessageInfoPayload = message_dict.get("message_info", {}) # type: ignore
message_segment = Seg.from_dict(message_dict.get("message_segment", {})) message_segment: SegPayload | list[SegPayload] = message_dict.get("message_segment", {"type": "text", "data": ""}) # type: ignore
# 初始化处理状态 # 初始化处理状态
processing_state = { processing_state = {
@@ -61,26 +62,26 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
is_notify = False is_notify = False
is_public_notice = False is_public_notice = False
notice_type = None notice_type = None
if message_info.additional_config and isinstance(message_info.additional_config, dict): additional_config_dict = message_info.get("additional_config", {})
is_notify = message_info.additional_config.get("is_notice", False) if isinstance(additional_config_dict, dict):
is_public_notice = message_info.additional_config.get("is_public_notice", False) is_notify = additional_config_dict.get("is_notice", False)
notice_type = message_info.additional_config.get("notice_type") is_public_notice = additional_config_dict.get("is_public_notice", False)
notice_type = additional_config_dict.get("notice_type")
# 提取用户信息 # 提取用户信息
user_info = message_info.user_info user_info_payload: UserInfoPayload = message_info.get("user_info", {}) # type: ignore
user_id = str(user_info.user_id) if user_info and user_info.user_id else "" user_id = str(user_info_payload.get("user_id", ""))
user_nickname = (user_info.user_nickname or "") if user_info else "" user_nickname = user_info_payload.get("user_nickname", "")
user_cardname = user_info.user_cardname if user_info else None user_cardname = user_info_payload.get("user_cardname")
user_platform = (user_info.platform or "") if user_info else "" user_platform = user_info_payload.get("platform", "")
# 提取群组信息 # 提取群组信息
group_info = message_info.group_info group_info_payload: GroupInfoPayload | None = message_info.get("group_info") # type: ignore
group_id = group_info.group_id if group_info else None group_id = group_info_payload.get("group_id") if group_info_payload else None
group_name = group_info.group_name if group_info else None group_name = group_info_payload.get("group_name") if group_info_payload else None
group_platform = group_info.platform if group_info else None group_platform = group_info_payload.get("platform") if group_info_payload else None
# chat_id 应该直接使用 stream_id与数据库存储格式一致 # chat_id 应该直接使用 stream_id与数据库存储格式一致
# stream_id 是通过 platform + user_id/group_id 的 SHA-256 哈希生成的
chat_id = stream_id chat_id = stream_id
# 准备 additional_config # 准备 additional_config
@@ -89,18 +90,19 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
# 提取 reply_to # 提取 reply_to
reply_to = _extract_reply_from_segment(message_segment) reply_to = _extract_reply_from_segment(message_segment)
# 构造 DatabaseMessages # 构造消息数据字典(基于 TypedDict 风格)
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time() message_time = message_info.get("time", time.time())
message_id = message_info.message_id or "" message_id = message_info.get("message_id", "")
# 处理 is_mentioned # 处理 is_mentioned
is_mentioned = None is_mentioned = None
mentioned_value = processing_state.get("is_mentioned") mentioned_value = processing_state.get("is_mentioned")
if isinstance(mentioned_value, bool): if isinstance(mentioned_value, bool):
is_mentioned = mentioned_value is_mentioned = mentioned_value
elif isinstance(mentioned_value, int | float): elif isinstance(mentioned_value, (int, float)):
is_mentioned = mentioned_value != 0 is_mentioned = mentioned_value != 0
# 使用 TypedDict 风格的数据构建 DatabaseMessages
db_message = DatabaseMessages( db_message = DatabaseMessages(
message_id=message_id, message_id=message_id,
time=float(message_time), time=float(message_time),
@@ -134,7 +136,7 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
chat_info_group_platform=group_platform, chat_info_group_platform=group_platform,
) )
# 设置优先级信息 # 设置优先级信息(运行时属性)
if processing_state.get("priority_mode"): if processing_state.get("priority_mode"):
setattr(db_message, "priority_mode", processing_state["priority_mode"]) setattr(db_message, "priority_mode", processing_state["priority_mode"])
if processing_state.get("priority_info"): if processing_state.get("priority_info"):
@@ -149,99 +151,127 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
return db_message return db_message
async def _process_message_segments(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str: async def _process_message_segments(
segment: SegPayload | list[SegPayload],
state: dict,
message_info: MessageInfoPayload
) -> str:
"""递归处理消息段,转换为文字描述 """递归处理消息段,转换为文字描述
Args: Args:
segment: 要处理的消息段 segment: 要处理的消息段TypedDict 或列表)
state: 处理状态字典(用于记录消息类型标记) state: 处理状态字典(用于记录消息类型标记)
message_info: 消息基础信息(用于某些处理逻辑 message_info: 消息基础信息(TypedDict 格式
Returns: Returns:
str: 处理后的文本 str: 处理后的文本
""" """
if segment.type == "seglist": # 如果是列表,遍历处理
# 处理消息段列表 if isinstance(segment, list):
segments_text = [] segments_text = []
for seg in segment.data: for seg in segment:
processed = await _process_message_segments(seg, state, message_info) processed = await _process_message_segments(seg, state, message_info)
if processed: if processed:
segments_text.append(processed) segments_text.append(processed)
return " ".join(segments_text) return " ".join(segments_text)
else:
# 处理单个消息 # 如果是单个
if isinstance(segment, dict):
seg_type = segment.get("type", "")
seg_data = segment.get("data")
# 处理 seglist 类型
if seg_type == "seglist" and isinstance(seg_data, list):
segments_text = []
for sub_seg in seg_data:
processed = await _process_message_segments(sub_seg, state, message_info)
if processed:
segments_text.append(processed)
return " ".join(segments_text)
# 处理其他类型
return await _process_single_segment(segment, state, message_info) return await _process_single_segment(segment, state, message_info)
return ""
async def _process_single_segment(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
async def _process_single_segment(
segment: SegPayload,
state: dict,
message_info: MessageInfoPayload
) -> str:
"""处理单个消息段 """处理单个消息段
Args: Args:
segment: 消息段 segment: 消息段TypedDict 格式)
state: 处理状态字典 state: 处理状态字典
message_info: 消息基础信息 message_info: 消息基础信息TypedDict 格式)
Returns: Returns:
str: 处理后的文本 str: 处理后的文本
""" """
seg_type = segment.get("type", "")
seg_data = segment.get("data")
try: try:
if segment.type == "text": if seg_type == "text":
state["is_picid"] = False state["is_picid"] = False
state["is_emoji"] = False state["is_emoji"] = False
state["is_video"] = False state["is_video"] = False
return segment.data return str(seg_data) if seg_data else ""
elif segment.type == "at": elif seg_type == "at":
state["is_picid"] = False state["is_picid"] = False
state["is_emoji"] = False state["is_emoji"] = False
state["is_video"] = False state["is_video"] = False
state["is_at"] = True state["is_at"] = True
# 处理at消息格式为"@<昵称:QQ号>" # 处理at消息格式为"@<昵称:QQ号>"
if isinstance(segment.data, str): if isinstance(seg_data, str):
if ":" in segment.data: if ":" in seg_data:
# 标准格式: "昵称:QQ号" # 标准格式: "昵称:QQ号"
nickname, qq_id = segment.data.split(":", 1) nickname, qq_id = seg_data.split(":", 1)
result = f"@<{nickname}:{qq_id}>" return f"@<{nickname}:{qq_id}>"
return result
else: else:
logger.warning(f"[at处理] 无法解析格式: '{segment.data}'") logger.warning(f"[at处理] 无法解析格式: '{seg_data}'")
return f"@{segment.data}" return f"@{seg_data}"
logger.warning(f"[at处理] 数据类型异常: {type(segment.data)}") logger.warning(f"[at处理] 数据类型异常: {type(seg_data)}")
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户" return f"@{seg_data}" if isinstance(seg_data, str) else "@未知用户"
elif segment.type == "image": elif seg_type == "image":
# 如果是base64图片数据 # 如果是base64图片数据
if isinstance(segment.data, str): if isinstance(seg_data, str):
state["has_picid"] = True state["has_picid"] = True
state["is_picid"] = True state["is_picid"] = True
state["is_emoji"] = False state["is_emoji"] = False
state["is_video"] = False state["is_video"] = False
image_manager = get_image_manager() image_manager = get_image_manager()
_, processed_text = await image_manager.process_image(segment.data) _, processed_text = await image_manager.process_image(seg_data)
return processed_text return processed_text
return "[发了一张图片,网卡了加载不出来]" return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji": elif seg_type == "emoji":
state["has_emoji"] = True state["has_emoji"] = True
state["is_emoji"] = True state["is_emoji"] = True
state["is_picid"] = False state["is_picid"] = False
state["is_voice"] = False state["is_voice"] = False
state["is_video"] = False state["is_video"] = False
if isinstance(segment.data, str): if isinstance(seg_data, str):
return await get_image_manager().get_emoji_description(segment.data) return await get_image_manager().get_emoji_description(seg_data)
return "[发了一个表情包,网卡了加载不出来]" return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice": elif seg_type == "voice":
state["is_picid"] = False state["is_picid"] = False
state["is_emoji"] = False state["is_emoji"] = False
state["is_voice"] = True state["is_voice"] = True
state["is_video"] = False state["is_video"] = False
# 检查消息是否由机器人自己发送 # 检查消息是否由机器人自己发送
if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account): user_info = message_info.get("user_info", {})
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。") user_id_str = str(user_info.get("user_id", ""))
if isinstance(segment.data, str): if user_id_str == str(global_config.bot.qq_account):
cached_text = consume_self_voice_text(segment.data) logger.info(f"检测到机器人自身发送的语音消息 (User ID: {user_id_str}),尝试从缓存获取文本。")
if isinstance(seg_data, str):
cached_text = consume_self_voice_text(seg_data)
if cached_text: if cached_text:
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'") logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
return f"[语音:{cached_text}]" return f"[语音:{cached_text}]"
@@ -249,41 +279,42 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。") logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
# 标准语音识别流程 # 标准语音识别流程
if isinstance(segment.data, str): if isinstance(seg_data, str):
return await get_voice_text(segment.data) return await get_voice_text(seg_data)
return "[发了一段语音,网卡了加载不出来]" return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot": elif seg_type == "mention_bot":
state["is_picid"] = False state["is_picid"] = False
state["is_emoji"] = False state["is_emoji"] = False
state["is_voice"] = False state["is_voice"] = False
state["is_video"] = False state["is_video"] = False
state["is_mentioned"] = float(segment.data) if isinstance(seg_data, (int, float)):
state["is_mentioned"] = float(seg_data)
return "" return ""
elif segment.type == "priority_info": elif seg_type == "priority_info":
state["is_picid"] = False state["is_picid"] = False
state["is_emoji"] = False state["is_emoji"] = False
state["is_voice"] = False state["is_voice"] = False
if isinstance(segment.data, dict): if isinstance(seg_data, dict):
# 处理优先级信息 # 处理优先级信息
state["priority_mode"] = "priority" state["priority_mode"] = "priority"
state["priority_info"] = segment.data state["priority_info"] = seg_data
return "" return ""
elif segment.type == "file": elif seg_type == "file":
if isinstance(segment.data, dict): if isinstance(seg_data, dict):
file_name = segment.data.get("name", "未知文件") file_name = seg_data.get("name", "未知文件")
file_size = segment.data.get("size", "未知大小") file_size = seg_data.get("size", "未知大小")
return f"[文件:{file_name} ({file_size}字节)]" return f"[文件:{file_name} ({file_size}字节)]"
return "[收到一个文件]" return "[收到一个文件]"
elif segment.type == "video": elif seg_type == "video":
state["is_picid"] = False state["is_picid"] = False
state["is_emoji"] = False state["is_emoji"] = False
state["is_voice"] = False state["is_voice"] = False
state["is_video"] = True state["is_video"] = True
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}") logger.info(f"接收到视频消息,数据类型: {type(seg_data)}")
# 检查视频分析功能是否可用 # 检查视频分析功能是否可用
if not is_video_analysis_available(): if not is_video_analysis_available():
@@ -292,11 +323,11 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
if global_config.video_analysis.enable: if global_config.video_analysis.enable:
logger.info("已启用视频识别,开始识别") logger.info("已启用视频识别,开始识别")
if isinstance(segment.data, dict): if isinstance(seg_data, dict):
try: try:
# 从Adapter接收的视频数据 # 从Adapter接收的视频数据
video_base64 = segment.data.get("base64") video_base64 = seg_data.get("base64")
filename = segment.data.get("filename", "video.mp4") filename = seg_data.get("filename", "video.mp4")
logger.info(f"视频文件名: {filename}") logger.info(f"视频文件名: {filename}")
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}") logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
@@ -329,24 +360,29 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
logger.error(f"错误详情: {traceback.format_exc()}") logger.error(f"错误详情: {traceback.format_exc()}")
return "[收到视频,但处理时出现错误]" return "[收到视频,但处理时出现错误]"
else: else:
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}") logger.warning(f"视频消息数据不是字典格式: {type(seg_data)}")
return "[发了一个视频,但格式不支持]" return "[发了一个视频,但格式不支持]"
else: else:
return "" return ""
else: else:
logger.warning(f"未知的消息段类型: {segment.type}") logger.warning(f"未知的消息段类型: {seg_type}")
return f"[{segment.type} 消息]" return f"[{seg_type} 消息]"
except Exception as e: except Exception as e:
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}") logger.error(f"处理消息段失败: {e!s}, 类型: {seg_type}, 数据: {seg_data}")
return f"[处理失败的{segment.type}消息]" return f"[处理失败的{seg_type}消息]"
def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, is_public_notice: bool, notice_type: str | None) -> str | None: def _prepare_additional_config(
message_info: MessageInfoPayload,
is_notify: bool,
is_public_notice: bool,
notice_type: str | None
) -> str | None:
"""准备 additional_config包含 format_info 和 notice 信息 """准备 additional_config包含 format_info 和 notice 信息
Args: Args:
message_info: 消息基础信息 message_info: 消息基础信息TypedDict 格式)
is_notify: 是否为notice消息 is_notify: 是否为notice消息
is_public_notice: 是否为公共notice is_public_notice: 是否为公共notice
notice_type: notice类型 notice_type: notice类型
@@ -358,12 +394,13 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
additional_config_data = {} additional_config_data = {}
# 首先获取adapter传递的additional_config # 首先获取adapter传递的additional_config
if hasattr(message_info, "additional_config") and message_info.additional_config: additional_config_raw = message_info.get("additional_config")
if isinstance(message_info.additional_config, dict): if additional_config_raw:
additional_config_data = message_info.additional_config.copy() if isinstance(additional_config_raw, dict):
elif isinstance(message_info.additional_config, str): additional_config_data = additional_config_raw.copy()
elif isinstance(additional_config_raw, str):
try: try:
additional_config_data = orjson.loads(message_info.additional_config) additional_config_data = orjson.loads(additional_config_raw)
except Exception as e: except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}") logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {} additional_config_data = {}
@@ -375,11 +412,11 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
additional_config_data["is_public_notice"] = bool(is_public_notice) additional_config_data["is_public_notice"] = bool(is_public_notice)
# 添加format_info到additional_config中 # 添加format_info到additional_config中
if hasattr(message_info, "format_info") and message_info.format_info: format_info = message_info.get("format_info")
if format_info:
try: try:
format_info_dict = message_info.format_info.to_dict() additional_config_data["format_info"] = format_info
additional_config_data["format_info"] = format_info_dict logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info}")
logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}")
except Exception as e: except Exception as e:
logger.warning(f"将 format_info 转换为字典失败: {e}") logger.warning(f"将 format_info 转换为字典失败: {e}")
@@ -392,28 +429,43 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
return None return None
def _extract_reply_from_segment(segment: Seg) -> str | None: def _extract_reply_from_segment(segment: SegPayload | list[SegPayload]) -> str | None:
"""从消息段中提取reply_to信息 """从消息段中提取reply_to信息
Args: Args:
segment: 消息段 segment: 消息段TypedDict 格式或列表)
Returns: Returns:
str | None: 回复的消息ID如果没有则返回None str | None: 回复的消息ID如果没有则返回None
""" """
try: try:
if hasattr(segment, "type") and segment.type == "seglist": # 如果是列表,遍历查找
# 递归搜索seglist中的reply段 if isinstance(segment, list):
if hasattr(segment, "data") and segment.data: for seg in segment:
for seg in segment.data:
reply_id = _extract_reply_from_segment(seg) reply_id = _extract_reply_from_segment(seg)
if reply_id: if reply_id:
return reply_id return reply_id
elif hasattr(segment, "type") and segment.type == "reply": return None
# 找到reply段返回message_id
return str(segment.data) if segment.data else None # 如果是字典
if isinstance(segment, dict):
seg_type = segment.get("type", "")
seg_data = segment.get("data")
# 如果是 seglist递归搜索
if seg_type == "seglist" and isinstance(seg_data, list):
for sub_seg in seg_data:
reply_id = _extract_reply_from_segment(sub_seg)
if reply_id:
return reply_id
# 如果是 reply 段,返回 message_id
elif seg_type == "reply":
return str(seg_data) if seg_data else None
except Exception as e: except Exception as e:
logger.warning(f"提取reply_to信息失败: {e}") logger.warning(f"提取reply_to信息失败: {e}")
return None return None
@@ -421,33 +473,31 @@ def _extract_reply_from_segment(segment: Seg) -> str | None:
# DatabaseMessages 扩展工具函数 # DatabaseMessages 扩展工具函数
# ============================================================================= # =============================================================================
def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessageInfo: def get_message_info_from_db_message(db_message: DatabaseMessages) -> MessageInfoPayload:
"""从 DatabaseMessages 重建 BaseMessageInfo(用于需要 message_info 的遗留代码 """从 DatabaseMessages 重建 MessageInfoPayloadTypedDict 格式
Args: Args:
db_message: DatabaseMessages 对象 db_message: DatabaseMessages 对象
Returns: Returns:
BaseMessageInfo: 重建的消息信息对象 MessageInfoPayload: 重建的消息信息对象TypedDict 格式)
""" """
from mofox_bus import GroupInfo, UserInfo # 构建用户信息
user_info: UserInfoPayload = {
"platform": db_message.user_info.platform,
"user_id": db_message.user_info.user_id,
"user_nickname": db_message.user_info.user_nickname,
"user_cardname": db_message.user_info.user_cardname or "",
}
# 从 DatabaseMessages 的 user_info 转换为 mofox_bus.UserInfo # 构建群组信息(如果存在)
user_info = UserInfo( group_info: GroupInfoPayload | None = None
platform=db_message.user_info.platform,
user_id=db_message.user_info.user_id,
user_nickname=db_message.user_info.user_nickname,
user_cardname=db_message.user_info.user_cardname or ""
)
# 从 DatabaseMessages 的 group_info 转换为 mofox_bus.GroupInfo如果存在
group_info = None
if db_message.group_info: if db_message.group_info:
group_info = GroupInfo( group_info = {
platform=db_message.group_info.group_platform or "", "platform": db_message.group_info.group_platform or "",
group_id=db_message.group_info.group_id, "group_id": db_message.group_info.group_id,
group_name=db_message.group_info.group_name "group_name": db_message.group_info.group_name,
) }
# 解析 additional_config从 JSON 字符串到字典) # 解析 additional_config从 JSON 字符串到字典)
additional_config = None additional_config = None
@@ -458,15 +508,19 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
# 如果解析失败,保持为字符串 # 如果解析失败,保持为字符串
pass pass
# 创建 BaseMessageInfo # 创建 MessageInfoPayload
message_info = BaseMessageInfo( message_info: MessageInfoPayload = {
platform=db_message.chat_info.platform, "platform": db_message.chat_info.platform,
message_id=db_message.message_id, "message_id": db_message.message_id,
time=db_message.time, "time": db_message.time,
user_info=user_info, "user_info": user_info,
group_info=group_info, }
additional_config=additional_config # type: ignore
) if group_info:
message_info["group_info"] = group_info
if additional_config:
message_info["additional_config"] = additional_config
return message_info return message_info

View File

@@ -1,341 +0,0 @@
"""
MessageEnvelope converter between mofox_bus schema and internal message structures.
- 优先处理 maim_message 风格的 message_info + message_segment。
- 兼容旧版 content/sender/channel 结构,方便逐步迁移。
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from mofox_bus import (
BaseMessageInfo,
MessageBase,
MessageEnvelope,
Seg,
UserInfo,
GroupInfo,
)
from src.common.logger import get_logger
logger = get_logger("envelope_converter")
class EnvelopeConverter:
"""MessageEnvelope <-> MessageBase converter."""
@staticmethod
def to_message_base(envelope: MessageEnvelope) -> MessageBase:
"""
Convert MessageEnvelope to MessageBase.
"""
try:
# 优先使用 maim_message 样式字段
info_payload = envelope.get("message_info") or {}
seg_payload = envelope.get("message_segment") or envelope.get("message_chain")
if info_payload:
message_info = BaseMessageInfo.from_dict(info_payload)
else:
message_info = EnvelopeConverter._build_info_from_legacy(envelope)
if seg_payload is None:
seg_list = EnvelopeConverter._content_to_segments(envelope.get("content"))
seg_payload = seg_list
message_segment = EnvelopeConverter._ensure_seg(seg_payload)
raw_message = envelope.get("raw_message") or envelope.get("raw_platform_message")
return MessageBase(
message_info=message_info,
message_segment=message_segment,
raw_message=raw_message,
)
except Exception as e:
logger.error(f"转换 MessageEnvelope 失败: {e}", exc_info=True)
raise
@staticmethod
def _build_info_from_legacy(envelope: MessageEnvelope) -> BaseMessageInfo:
"""将 legacy 字段映射为 BaseMessageInfo。"""
platform = envelope.get("platform")
channel = envelope.get("channel") or {}
sender = envelope.get("sender") or {}
message_id = envelope.get("id") or envelope.get("message_id")
timestamp_ms = envelope.get("timestamp_ms")
time_value = (timestamp_ms / 1000.0) if timestamp_ms is not None else None
group_info: Optional[GroupInfo] = None
channel_type = channel.get("channel_type")
if channel_type in ("group", "supergroup", "room"):
group_info = GroupInfo(
platform=platform,
group_id=channel.get("channel_id"),
group_name=channel.get("title"),
)
user_info: Optional[UserInfo] = None
if sender:
user_info = UserInfo(
platform=platform,
user_id=str(sender.get("user_id")) if sender.get("user_id") is not None else None,
user_nickname=sender.get("display_name") or sender.get("user_nickname"),
user_avatar=sender.get("avatar_url"),
)
return BaseMessageInfo(
platform=platform,
message_id=message_id,
time=time_value,
group_info=group_info,
user_info=user_info,
additional_config=envelope.get("metadata"),
)
@staticmethod
def _ensure_seg(payload: Any) -> Seg:
"""将任意 payload 转为 Seg dataclass。"""
if isinstance(payload, Seg):
return payload
if isinstance(payload, list):
# 直接传入 Seg 列表或 seglist data
return Seg(type="seglist", data=[EnvelopeConverter._ensure_seg(item) for item in payload])
if isinstance(payload, dict):
seg_type = payload.get("type") or "text"
data = payload.get("data")
if seg_type == "seglist" and isinstance(data, list):
data = [EnvelopeConverter._ensure_seg(item) for item in data]
return Seg(type=seg_type, data=data)
# 兜底:转成文本片段
return Seg(type="text", data="" if payload is None else str(payload))
@staticmethod
def _flatten_segments(seg: Seg) -> List[Seg]:
"""将 Seg/seglist 打平成列表,便于旧 content 转换。"""
if seg.type == "seglist" and isinstance(seg.data, list):
return [item if isinstance(item, Seg) else EnvelopeConverter._ensure_seg(item) for item in seg.data]
return [seg]
@staticmethod
def _content_to_segments(content: Any) -> List[Seg]:
"""
Convert legacy Content (type/data/metadata) to a flat list of Seg.
"""
segments: List[Seg] = []
def _walk(node: Any) -> None:
if node is None:
return
if isinstance(node, list):
for item in node:
_walk(item)
return
if not isinstance(node, dict):
logger.warning("未知的 content 节点类型: %s", type(node))
return
content_type = node.get("type")
data = node.get("data")
metadata = node.get("metadata") or {}
if content_type == "collection":
items = data if isinstance(data, list) else node.get("items", [])
for item in items:
_walk(item)
return
if content_type in ("text", "at"):
subtype = metadata.get("subtype") or ("at" if content_type == "at" else None)
text = "" if data is None else str(data)
if subtype in ("at", "mention"):
user_info = metadata.get("user") or {}
seg_data: Dict[str, Any] = {
"user_id": user_info.get("id") or user_info.get("user_id"),
"user_name": user_info.get("name") or user_info.get("display_name"),
"text": text,
"raw": user_info.get("raw") or user_info if user_info else None,
}
if any(v is not None for v in seg_data.values()):
segments.append(Seg(type="at", data=seg_data))
else:
segments.append(Seg(type="at", data=text))
else:
segments.append(Seg(type="text", data=text))
return
if content_type == "image":
url = ""
if isinstance(data, dict):
url = data.get("url") or data.get("file") or data.get("file_id") or ""
elif data is not None:
url = str(data)
segments.append(Seg(type="image", data=url))
return
if content_type == "audio":
url = ""
if isinstance(data, dict):
url = data.get("url") or data.get("file") or data.get("file_id") or ""
elif data is not None:
url = str(data)
segments.append(Seg(type="record", data=url))
return
if content_type == "video":
url = ""
if isinstance(data, dict):
url = data.get("url") or data.get("file") or data.get("file_id") or ""
elif data is not None:
url = str(data)
segments.append(Seg(type="video", data=url))
return
if content_type == "file":
file_name = ""
if isinstance(data, dict):
file_name = data.get("file_name") or data.get("name") or ""
text = file_name or "[file]"
segments.append(Seg(type="text", data=text))
return
if content_type == "command":
name = ""
args: Dict[str, Any] = {}
if isinstance(data, dict):
name = data.get("name", "")
args = data.get("args", {}) or {}
else:
name = str(data or "")
cmd_text = f"/{name}" if name else "/command"
if args:
cmd_text += " " + " ".join(f"{k}={v}" for k, v in args.items())
segments.append(Seg(type="text", data=cmd_text))
return
if content_type == "event":
event_type = ""
if isinstance(data, dict):
event_type = data.get("event_type", "")
else:
event_type = str(data or "")
segments.append(Seg(type="text", data=f"[事件: {event_type or 'unknown'}]"))
return
if content_type == "system":
text = "" if data is None else str(data)
segments.append(Seg(type="text", data=f"[系统] {text}"))
return
logger.warning(f"未知的消息类型: {content_type}")
segments.append(Seg(type="text", data=f"[未知消息类型: {content_type}]"))
_walk(content)
return segments
@staticmethod
def to_legacy_dict(envelope: MessageEnvelope) -> Dict[str, Any]:
"""
Convert MessageEnvelope to legacy dict for backward compatibility.
"""
message_base = EnvelopeConverter.to_message_base(envelope)
return message_base.to_dict()
@staticmethod
def from_message_base(message: MessageBase, direction: str = "outgoing") -> MessageEnvelope:
"""
Convert MessageBase to MessageEnvelope (maim_message style preferred).
"""
try:
info_dict = message.message_info.to_dict()
seg_dict = message.message_segment.to_dict()
envelope: MessageEnvelope = {
"direction": direction,
"message_info": info_dict,
"message_segment": seg_dict,
"platform": info_dict.get("platform"),
"message_id": info_dict.get("message_id"),
"schema_version": 1,
}
if message.message_info.time is not None:
envelope["timestamp_ms"] = int(message.message_info.time * 1000)
if message.raw_message is not None:
envelope["raw_message"] = message.raw_message
# legacy 补充,方便老代码继续工作
segments = EnvelopeConverter._flatten_segments(message.message_segment)
envelope["content"] = EnvelopeConverter._segments_to_content(segments)
if message.message_info.user_info:
envelope["sender"] = {
"user_id": message.message_info.user_info.user_id,
"role": "assistant" if direction == "outgoing" else "user",
"display_name": message.message_info.user_info.user_nickname,
"avatar_url": getattr(message.message_info.user_info, "user_avatar", None),
}
if message.message_info.group_info:
envelope["channel"] = {
"channel_id": message.message_info.group_info.group_id,
"channel_type": "group",
"title": message.message_info.group_info.group_name,
}
return envelope
except Exception as e:
logger.error(f"转换 MessageBase 失败: {e}", exc_info=True)
raise
@staticmethod
def _segments_to_content(segments: List[Seg]) -> Dict[str, Any]:
"""
Convert Seg list to legacy Content (type/data/metadata).
"""
if not segments:
return {"type": "text", "data": ""}
def _seg_to_content(seg: Seg) -> Dict[str, Any]:
data = seg.data
if seg.type == "text":
return {"type": "text", "data": data}
if seg.type == "at":
content: Dict[str, Any] = {"type": "text", "data": ""}
metadata: Dict[str, Any] = {"subtype": "at"}
if isinstance(data, dict):
content["data"] = data.get("text", "")
user = {
"id": data.get("user_id"),
"name": data.get("user_name"),
"raw": data.get("raw"),
}
if any(v is not None for v in user.values()):
metadata["user"] = user
else:
content["data"] = data
if metadata:
content["metadata"] = metadata
return content
if seg.type == "image":
return {"type": "image", "data": data}
if seg.type in ("record", "voice", "audio"):
return {"type": "audio", "data": data}
if seg.type == "video":
return {"type": "video", "data": data}
return {"type": seg.type, "data": data}
if len(segments) == 1:
return _seg_to_content(segments[0])
return {"type": "collection", "data": [_seg_to_content(seg) for seg in segments]}
__all__ = ["EnvelopeConverter"]

View File

@@ -18,7 +18,6 @@ from src.chat.message_receive.bot import chat_bot
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.message.envelope_converter import EnvelopeConverter
# 全局背景任务集合 # 全局背景任务集合
_background_tasks = set() _background_tasks = set()
@@ -77,7 +76,7 @@ class MainSystem:
self.individuality: Individuality = get_individuality() self.individuality: Individuality = get_individuality()
# 创建核心消息接收器 # 创建核心消息接收器
self.core_sink: InProcessCoreSink = InProcessCoreSink(self._handle_message_envelope) self.core_sink: InProcessCoreSink = InProcessCoreSink(self._message_process_wrapper)
# 使用服务器 # 使用服务器
self.server: Server = get_global_server() self.server: Server = get_global_server()
@@ -353,19 +352,18 @@ class MainSystem:
except Exception as e: except Exception as e:
logger.error(f"同步清理资源时出错: {e}") logger.error(f"同步清理资源时出错: {e}")
async def _message_process_wrapper(self, message_data: dict[str, Any]) -> None: async def _message_process_wrapper(self, envelope: MessageEnvelope) -> None:
"""并行处理消息的包装器""" """并行处理消息的包装器"""
try: try:
start_time = time.time() start_time = time.time()
message_id = message_data.get("message_info", {}).get("message_id", "UNKNOWN") message_id = envelope.get("message_info", {}).get("message_id", "UNKNOWN")
# 检查系统是否正在关闭 # 检查系统是否正在关闭
if self._shutting_down: if self._shutting_down:
logger.warning(f"系统正在关闭,拒绝处理消息 {message_id}") logger.warning(f"系统正在关闭,拒绝处理消息 {message_id}")
return return
# 创建后台任务 # 创建后台任务
task = asyncio.create_task(chat_bot.message_process(message_data)) task = asyncio.create_task(chat_bot.message_process(envelope))
logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})") logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})")
# 添加一个回调函数,当任务完成时,它会被调用 # 添加一个回调函数,当任务完成时,它会被调用
@@ -374,22 +372,6 @@ class MainSystem:
logger.error("在创建消息处理任务时发生严重错误:") logger.error("在创建消息处理任务时发生严重错误:")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
async def _handle_message_envelope(self, envelope: MessageEnvelope) -> None:
"""
处理来自适配器的 MessageEnvelope
Args:
envelope: 统一的消息信封
"""
try:
# 转换为旧版格式
message_data = EnvelopeConverter.to_legacy_dict(envelope)
# 使用现有的消息处理流程
await self._message_process_wrapper(message_data)
except Exception as e:
logger.error(f"处理 MessageEnvelope 时出错: {e}", exc_info=True)
async def initialize(self) -> None: async def initialize(self) -> None:
"""初始化系统组件""" """初始化系统组件"""

View File

@@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import functools
import inspect import inspect
import threading import threading
import weakref
from dataclasses import dataclass from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, Iterable, List, Protocol from typing import Awaitable, Callable, Dict, Iterable, List, Protocol
@@ -37,6 +39,7 @@ class MessageRoute:
handler: MessageHandler handler: MessageHandler
name: str | None = None name: str | None = None
message_type: str | None = None message_type: str | None = None
message_types: set[str] | None = None # 支持多个消息类型
event_types: set[str] | None = None event_types: set[str] | None = None
@@ -55,6 +58,8 @@ class MessageRuntime:
self._middlewares: list[Middleware] = [] self._middlewares: list[Middleware] = []
self._type_routes: Dict[str, list[MessageRoute]] = {} self._type_routes: Dict[str, list[MessageRoute]] = {}
self._event_routes: Dict[str, list[MessageRoute]] = {} self._event_routes: Dict[str, list[MessageRoute]] = {}
# 用于检测同一类型的重复注册
self._explicit_type_handlers: Dict[str, str] = {} # message_type -> handler_name
def add_route( def add_route(
self, self,
@@ -62,7 +67,7 @@ class MessageRuntime:
handler: MessageHandler, handler: MessageHandler,
name: str | None = None, name: str | None = None,
*, *,
message_type: str | None = None, message_type: str | list[str] | None = None,
event_types: Iterable[str] | None = None, event_types: Iterable[str] | None = None,
) -> None: ) -> None:
""" """
@@ -72,28 +77,72 @@ class MessageRuntime:
predicate: 路由匹配条件 predicate: 路由匹配条件
handler: 消息处理函数 handler: 消息处理函数
name: 路由名称(可选) name: 路由名称(可选)
message_type: 消息类型(可选) message_type: 消息类型,可以是字符串或字符串列表(可选)
event_types: 事件类型列表(可选) event_types: 事件类型列表(可选)
""" """
with self._lock: with self._lock:
# 处理 message_type 参数,支持字符串或列表
message_types_set: set[str] | None = None
single_message_type: str | None = None
if message_type is not None:
if isinstance(message_type, str):
message_types_set = {message_type}
single_message_type = message_type
elif isinstance(message_type, list):
message_types_set = set(message_type)
if len(message_types_set) == 1:
single_message_type = next(iter(message_types_set))
else:
raise TypeError(f"message_type must be str or list[str], got {type(message_type)}")
# 检测重复注册:如果明确指定了某个类型,不允许重复
handler_name = name or getattr(handler, "__name__", str(handler))
for msg_type in message_types_set:
if msg_type in self._explicit_type_handlers:
existing_handler = self._explicit_type_handlers[msg_type]
raise ValueError(
f"消息类型 '{msg_type}' 已被处理器 '{existing_handler}' 明确注册,"
f"不能再由 '{handler_name}' 注册。同一消息类型只能有一个明确的处理器。"
)
self._explicit_type_handlers[msg_type] = handler_name
route = MessageRoute( route = MessageRoute(
predicate=predicate, predicate=predicate,
handler=handler, handler=handler,
name=name, name=name,
message_type=message_type, message_type=single_message_type,
message_types=message_types_set,
event_types=set(event_types) if event_types is not None else None, event_types=set(event_types) if event_types is not None else None,
) )
self._routes.append(route) self._routes.append(route)
if message_type:
self._type_routes.setdefault(message_type, []).append(route) # 为每个消息类型建立索引
if message_types_set:
for msg_type in message_types_set:
self._type_routes.setdefault(msg_type, []).append(route)
if route.event_types: if route.event_types:
for et in route.event_types: for et in route.event_types:
self._event_routes.setdefault(et, []).append(route) self._event_routes.setdefault(et, []).append(route)
def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]: def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]:
"""装饰器写法,便于在核心逻辑中声明式注册。""" """装饰器写法,便于在核心逻辑中声明式注册。
支持普通函数和类方法。对于类方法,会在实例创建时自动绑定并注册路由。
"""
def decorator(func: MessageHandler) -> MessageHandler: def decorator(func: MessageHandler) -> MessageHandler:
# Support decorating instance methods: defer binding until the object is created.
if _looks_like_method(func):
return _InstanceMethodRoute(
runtime=self,
func=func,
predicate=predicate,
name=name,
message_type=None,
)
self.add_route(predicate, func, name=name) self.add_route(predicate, func, name=name)
return func return func
@@ -101,16 +150,45 @@ class MessageRuntime:
def on_message( def on_message(
self, self,
func: MessageHandler | None = None,
*, *,
message_type: str | None = None, message_type: str | list[str] | None = None,
platform: str | None = None, platform: str | None = None,
predicate: Predicate | None = None, predicate: Predicate | None = None,
name: str | None = None, name: str | None = None,
) -> Callable[[MessageHandler], MessageHandler]: ) -> Callable[[MessageHandler], MessageHandler] | MessageHandler:
"""Sugar 装饰器,基于 Seg.type/platform 及可选额外谓词匹配。""" """Sugar decorator with optional Seg.type/platform predicate matching.
Args:
func: 被装饰的函数
message_type: 消息类型,可以是单个字符串或字符串列表
platform: 平台名称
predicate: 自定义匹配条件
name: 路由名称
Usages:
- @runtime.on_message(...)
- @runtime.on_message
- @runtime.on_message(message_type="text")
- @runtime.on_message(message_type=["text", "image"])
If the target looks like an instance method (first arg is self), it will be
auto-bound to the instance and registered when the object is constructed.
"""
# 将 message_type 转换为集合以便统一处理
message_types_set: set[str] | None = None
if message_type is not None:
if isinstance(message_type, str):
message_types_set = {message_type}
elif isinstance(message_type, list):
message_types_set = set(message_type)
else:
raise TypeError(f"message_type must be str or list[str], got {type(message_type)}")
async def combined_predicate(message: MessageEnvelope) -> bool: async def combined_predicate(message: MessageEnvelope) -> bool:
if message_type is not None and _extract_segment_type(message) != message_type: if message_types_set is not None:
extracted_type = _extract_segment_type(message)
if extracted_type not in message_types_set:
return False return False
if platform is not None: if platform is not None:
info_platform = message.get("message_info", {}).get("platform") info_platform = message.get("message_info", {}).get("platform")
@@ -123,35 +201,24 @@ class MessageRuntime:
return await _invoke_callable(predicate, message, prefer_thread=False) return await _invoke_callable(predicate, message, prefer_thread=False)
def decorator(func: MessageHandler) -> MessageHandler: def decorator(func: MessageHandler) -> MessageHandler:
# Support decorating instance methods: defer binding until the object is created.
if _looks_like_method(func):
return _InstanceMethodRoute(
runtime=self,
func=func,
predicate=combined_predicate,
name=name,
message_type=message_type,
)
self.add_route(combined_predicate, func, name=name, message_type=message_type) self.add_route(combined_predicate, func, name=name, message_type=message_type)
return func return func
if func is not None:
return decorator(func)
return decorator return decorator
def on_event(
self,
event_type: str | Iterable[str],
*,
name: str | None = None,
) -> Callable[[MessageHandler], MessageHandler]:
"""装饰器,基于 message 或 message_info.additional_config 中的 event_type 匹配。"""
allowed = {event_type} if isinstance(event_type, str) else set(event_type)
async def predicate(message: MessageEnvelope) -> bool:
current = (
message.get("event_type")
or message.get("message_info", {})
.get("additional_config", {})
.get("event_type")
)
return current in allowed
def decorator(func: MessageHandler) -> MessageHandler:
self.add_route(predicate, func, name=name, event_types=allowed)
return func
return decorator
def set_batch_handler(self, handler: BatchHandler) -> None: def set_batch_handler(self, handler: BatchHandler) -> None:
self._batch_handler = handler self._batch_handler = handler
@@ -199,7 +266,7 @@ class MessageRuntime:
return responses return responses
async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None: async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None:
candidates: list[MessageRoute] = [] """匹配消息路由,优先匹配明确指定了消息类型的处理器"""
message_type = _extract_segment_type(message) message_type = _extract_segment_type(message)
event_type = ( event_type = (
message.get("event_type") message.get("event_type")
@@ -207,15 +274,29 @@ class MessageRuntime:
.get("additional_config", {}) .get("additional_config", {})
.get("event_type") .get("event_type")
) )
with self._lock:
if event_type and event_type in self._event_routes:
candidates.extend(self._event_routes[event_type])
if message_type and message_type in self._type_routes:
candidates.extend(self._type_routes[message_type])
candidates.extend(self._routes)
# 分为两层候选:优先级和普通
priority_candidates: list[MessageRoute] = [] # 明确指定了消息类型的
normal_candidates: list[MessageRoute] = [] # 没有指定或通配的
with self._lock:
# 事件路由(优先级最高)
if event_type and event_type in self._event_routes:
priority_candidates.extend(self._event_routes[event_type])
# 消息类型路由(明确指定的有优先级)
if message_type and message_type in self._type_routes:
priority_candidates.extend(self._type_routes[message_type])
# 通用路由(没有明确指定类型的)
for route in self._routes:
# 如果路由没有指定 message_types则是通用路由
if route.message_types is None and route.event_types is None:
normal_candidates.append(route)
# 先尝试优先级候选
seen: set[int] = set() seen: set[int] = set()
for route in candidates: for route in priority_candidates:
rid = id(route) rid = id(route)
if rid in seen: if rid in seen:
continue continue
@@ -223,6 +304,17 @@ class MessageRuntime:
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False) should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
if should_handle: if should_handle:
return route return route
# 如果没有匹配到优先级候选,再尝试普通候选
for route in normal_candidates:
rid = id(route)
if rid in seen:
continue
seen.add(rid)
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
if should_handle:
return route
return None return None
async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None: async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None:
@@ -257,7 +349,27 @@ class MessageRuntime:
async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False): async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False):
"""支持 sync/async 调用,并可选择在线程中执行。""" """支持 sync/async 调用,并可选择在线程中执行。
自动处理普通函数、类方法和绑定方法。
"""
# 如果是绑定方法bound method直接使用不需要额外处理
# 因为绑定方法已经包含了 self 参数
if inspect.ismethod(func):
# 绑定方法可以直接调用args 中不应包含 self
if inspect.iscoroutinefunction(func):
return await func(*args)
if prefer_thread:
result = await asyncio.to_thread(func, *args)
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
return await result
return result
result = func(*args)
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
return await result
return result
# 对于普通函数(未绑定的),按原有逻辑处理
if inspect.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):
return await func(*args) return await func(*args)
if prefer_thread: if prefer_thread:
@@ -282,6 +394,70 @@ def _extract_segment_type(message: MessageEnvelope) -> str | None:
return None return None
def _looks_like_method(func: Callable[..., object]) -> bool:
"""Return True if callable signature suggests an instance method (first arg named self)."""
if inspect.ismethod(func):
return True
if not inspect.isfunction(func):
return False
params = inspect.signature(func).parameters
if not params:
return False
first = next(iter(params.values()))
return first.name == "self"
class _InstanceMethodRoute:
"""Descriptor that binds decorated instance methods and registers routes per-instance."""
def __init__(
self,
runtime: MessageRuntime,
func: MessageHandler,
predicate: Predicate,
name: str | None,
message_type: str | None,
) -> None:
self._runtime = runtime
self._func = func
self._predicate = predicate
self._name = name
self._message_type = message_type
self._owner: type | None = None
self._registered_instances: weakref.WeakSet[object] = weakref.WeakSet()
def __set_name__(self, owner: type, name: str) -> None:
self._owner = owner
registry: list[_InstanceMethodRoute] | None = getattr(owner, "_mofox_instance_routes", None)
if registry is None:
registry = []
setattr(owner, "_mofox_instance_routes", registry)
original_init = owner.__init__
@functools.wraps(original_init)
def wrapped_init(inst, *args, **kwargs):
original_init(inst, *args, **kwargs)
for descriptor in getattr(inst.__class__, "_mofox_instance_routes", []):
descriptor._register_instance(inst)
owner.__init__ = wrapped_init # type: ignore[assignment]
registry.append(self)
def _register_instance(self, instance: object) -> None:
if instance in self._registered_instances:
return
owner = self._owner or instance.__class__
bound = self._func.__get__(instance, owner) # type: ignore[arg-type]
self._runtime.add_route(self._predicate, bound, name=self._name, message_type=self._message_type)
self._registered_instances.add(instance)
def __get__(self, instance: object | None, owner: type | None = None):
if instance is None:
return self._func
self._register_instance(instance)
return self._func.__get__(instance, owner) # type: ignore[arg-type]
__all__ = [ __all__ = [
"BatchHandler", "BatchHandler",
"Hook", "Hook",

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Literal, NotRequired, TypedDict from typing import Any, Dict, List, Literal, NotRequired, TypedDict, Required
MessageDirection = Literal["incoming", "outgoing"] MessageDirection = Literal["incoming", "outgoing"]
@@ -14,14 +14,14 @@ class SegPayload(TypedDict, total=False):
对齐 maim_message.Seg 的片段定义,使用纯 dict 便于 JSON 传输。 对齐 maim_message.Seg 的片段定义,使用纯 dict 便于 JSON 传输。
""" """
type: str type: Required[str]
data: str | List["SegPayload"] data: Required[str | List["SegPayload"]]
translated_data: NotRequired[str | List["SegPayload"]] translated_data: NotRequired[str | List["SegPayload"]]
class UserInfoPayload(TypedDict, total=False): class UserInfoPayload(TypedDict, total=False):
platform: NotRequired[str] platform: NotRequired[str]
user_id: NotRequired[str] user_id: Required[str]
user_nickname: NotRequired[str] user_nickname: NotRequired[str]
user_cardname: NotRequired[str] user_cardname: NotRequired[str]
user_avatar: NotRequired[str] user_avatar: NotRequired[str]
@@ -29,7 +29,7 @@ class UserInfoPayload(TypedDict, total=False):
class GroupInfoPayload(TypedDict, total=False): class GroupInfoPayload(TypedDict, total=False):
platform: NotRequired[str] platform: NotRequired[str]
group_id: NotRequired[str] group_id: Required[str]
group_name: NotRequired[str] group_name: NotRequired[str]
@@ -45,8 +45,8 @@ class TemplateInfoPayload(TypedDict, total=False):
class MessageInfoPayload(TypedDict, total=False): class MessageInfoPayload(TypedDict, total=False):
platform: NotRequired[str] platform: Required[str]
message_id: NotRequired[str] message_id: Required[str]
time: NotRequired[float] time: NotRequired[float]
group_info: NotRequired[GroupInfoPayload] group_info: NotRequired[GroupInfoPayload]
user_info: NotRequired[UserInfoPayload] user_info: NotRequired[UserInfoPayload]
@@ -67,8 +67,8 @@ class MessageEnvelope(TypedDict, total=False):
""" """
direction: MessageDirection direction: MessageDirection
message_info: MessageInfoPayload message_info: Required[MessageInfoPayload]
message_segment: SegPayload | List[SegPayload] message_segment: Required[SegPayload] | List[SegPayload]
raw_message: NotRequired[Any] raw_message: NotRequired[Any]
raw_bytes: NotRequired[bytes] raw_bytes: NotRequired[bytes]
message_chain: NotRequired[List[SegPayload]] # seglist 的直观别名 message_chain: NotRequired[List[SegPayload]] # seglist 的直观别名