重构消息处理和信封转换
- 从代码库中移除了EnvelopeConverter类及其相关方法,因为它们已不再需要。 - 更新了主系统,使其能够直接处理MessageEnvelope对象,而无需将其转换为旧格式。 - 增强了MessageRuntime类,以支持多种消息类型并防止重复注册处理程序。 引入了一个新的MessageHandler类来管理消息处理,包括预处理和数据库存储。 - 改进了整个消息处理工作流程中的错误处理和日志记录。 - 更新了类型提示和数据模型,以确保消息结构的一致性和清晰度。
This commit is contained in:
@@ -3,8 +3,8 @@ import re
|
|||||||
import traceback
|
import traceback
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from mofox_bus import UserInfo
|
from mofox_bus.runtime import MessageRuntime
|
||||||
|
from mofox_bus import MessageEnvelope
|
||||||
from src.chat.message_manager import message_manager
|
from src.chat.message_manager import message_manager
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
@@ -63,13 +63,13 @@ def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
runtime = MessageRuntime() # 获取mofox-bus运行时环境
|
||||||
|
|
||||||
class ChatBot:
|
class ChatBot:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.bot = None # bot 实例引用
|
self.bot = None # bot 实例引用
|
||||||
self._started = False
|
self._started = False
|
||||||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||||
|
|
||||||
# 启动消息管理器
|
# 启动消息管理器
|
||||||
self._message_manager_started = False
|
self._message_manager_started = False
|
||||||
|
|
||||||
@@ -304,48 +304,29 @@ class ChatBot:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理适配器响应时出错: {e}")
|
logger.error(f"处理适配器响应时出错: {e}")
|
||||||
|
|
||||||
async def message_process(self, message_data: dict[str, Any]) -> None:
|
@runtime.on_message
|
||||||
|
async def message_process(self, envelope: MessageEnvelope) -> None:
|
||||||
"""处理转化后的统一格式消息"""
|
"""处理转化后的统一格式消息"""
|
||||||
try:
|
|
||||||
# 首先处理可能的切片消息重组
|
|
||||||
from src.utils.message_chunker import reassembler
|
|
||||||
|
|
||||||
# 尝试重组切片消息
|
|
||||||
reassembled_message = await reassembler.process_chunk(message_data)
|
|
||||||
if reassembled_message is None:
|
|
||||||
# 这是一个切片,但还未完整,等待更多切片
|
|
||||||
logger.debug("等待更多切片,跳过此次处理")
|
|
||||||
return
|
|
||||||
elif reassembled_message != message_data:
|
|
||||||
# 消息已被重组,使用重组后的消息
|
|
||||||
logger.info("使用重组后的完整消息进行处理")
|
|
||||||
message_data = reassembled_message
|
|
||||||
|
|
||||||
# 确保所有任务已启动
|
|
||||||
await self._ensure_started()
|
|
||||||
|
|
||||||
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
|
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
|
||||||
message_info = message_data.get("message_info")
|
message_info = envelope.get("message_info")
|
||||||
if not isinstance(message_info, dict):
|
if not isinstance(message_info, dict):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
|
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
|
||||||
", ".join(message_data.keys()),
|
", ".join(envelope.keys()),
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if message_info.get("group_info") is not None:
|
if message_info.get("group_info") is not None:
|
||||||
message_info["group_info"]["group_id"] = str(
|
message_info["group_info"]["group_id"] = str( # type: ignore
|
||||||
message_info["group_info"]["group_id"]
|
message_info["group_info"]["group_id"] # type: ignore
|
||||||
)
|
)
|
||||||
if message_info.get("user_info") is not None:
|
if message_info.get("user_info") is not None:
|
||||||
message_info["user_info"]["user_id"] = str(
|
message_info["user_info"]["user_id"] = str( # type: ignore
|
||||||
message_info["user_info"]["user_id"]
|
message_info["user_info"]["user_id"] # type: ignore
|
||||||
)
|
)
|
||||||
# print(message_data)
|
|
||||||
# logger.debug(str(message_data))
|
|
||||||
|
|
||||||
# 优先处理adapter_response消息(在echo检查之前!)
|
# 优先处理adapter_response消息(在echo检查之前!)
|
||||||
message_segment = message_data.get("message_segment")
|
message_segment = envelope.get("message_segment")
|
||||||
if message_segment and isinstance(message_segment, dict):
|
if message_segment and isinstance(message_segment, dict):
|
||||||
if message_segment.get("type") == "adapter_response":
|
if message_segment.get("type") == "adapter_response":
|
||||||
logger.info("[DEBUG bot.py message_process] 检测到adapter_response,立即处理")
|
logger.info("[DEBUG bot.py message_process] 检测到adapter_response,立即处理")
|
||||||
@@ -362,7 +343,7 @@ class ChatBot:
|
|||||||
await MessageStorage.update_message(message_data)
|
await MessageStorage.update_message(message_data)
|
||||||
return
|
return
|
||||||
|
|
||||||
message_segment = message_data.get("message_segment")
|
message_segment = envelope.get("message_segment")
|
||||||
group_info = temp_message_info.group_info
|
group_info = temp_message_info.group_info
|
||||||
user_info = temp_message_info.user_info
|
user_info = temp_message_info.user_info
|
||||||
|
|
||||||
@@ -376,7 +357,7 @@ class ChatBot:
|
|||||||
# 使用新的消息处理器直接生成 DatabaseMessages
|
# 使用新的消息处理器直接生成 DatabaseMessages
|
||||||
from src.chat.message_receive.message_processor import process_message_from_dict
|
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||||
message = await process_message_from_dict(
|
message = await process_message_from_dict(
|
||||||
message_dict=message_data,
|
message_dict=envelope,
|
||||||
stream_id=chat.stream_id,
|
stream_id=chat.stream_id,
|
||||||
platform=chat.platform
|
platform=chat.platform
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from dataclasses import dataclass
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import urllib3
|
import urllib3
|
||||||
from mofox_bus import BaseMessageInfo, MessageBase, Seg, UserInfo
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
|
|||||||
110
src/chat/message_receive/message_handler.py
Normal file
110
src/chat/message_receive/message_handler.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import os
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from mofox_bus.runtime import MessageRuntime
|
||||||
|
from mofox_bus import MessageEnvelope
|
||||||
|
from src.chat.message_manager import message_manager
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
|
from src.common.data_models.database_data_model import DatabaseGroupInfo, DatabaseUserInfo, DatabaseMessages
|
||||||
|
|
||||||
|
runtime = MessageRuntime()
|
||||||
|
|
||||||
|
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||||
|
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||||
|
|
||||||
|
# 配置主程序日志格式
|
||||||
|
logger = get_logger("chat")
|
||||||
|
|
||||||
|
class MessageHandler:
|
||||||
|
def __init__(self):
|
||||||
|
self._started = False
|
||||||
|
|
||||||
|
async def preprocess(self, chat: ChatStream, message: DatabaseMessages):
|
||||||
|
# message 已经是 DatabaseMessages,直接使用
|
||||||
|
group_info = chat.group_info
|
||||||
|
|
||||||
|
# 先交给消息管理器处理
|
||||||
|
try:
|
||||||
|
# 在将消息添加到管理器之前进行最终的静默检查
|
||||||
|
should_process_in_manager = True
|
||||||
|
if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list:
|
||||||
|
# 检查消息是否为图片或表情包
|
||||||
|
is_image_or_emoji = message.is_picid or message.is_emoji
|
||||||
|
if not message.is_mentioned and not is_image_or_emoji:
|
||||||
|
logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理")
|
||||||
|
should_process_in_manager = False
|
||||||
|
elif is_image_or_emoji:
|
||||||
|
logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理")
|
||||||
|
should_process_in_manager = False
|
||||||
|
|
||||||
|
if should_process_in_manager:
|
||||||
|
await message_manager.add_message(chat.stream_id, message)
|
||||||
|
logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"消息添加到消息管理器失败: {e}")
|
||||||
|
|
||||||
|
# 存储消息到数据库,只进行一次写入
|
||||||
|
try:
|
||||||
|
await MessageStorage.store_message(message, chat)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"存储消息到数据库失败: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# 情绪系统更新 - 在消息存储后触发情绪更新
|
||||||
|
try:
|
||||||
|
if global_config.mood.enable_mood:
|
||||||
|
# 获取兴趣度用于情绪更新
|
||||||
|
interest_rate = message.interest_value
|
||||||
|
if interest_rate is None:
|
||||||
|
interest_rate = 0.0
|
||||||
|
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
|
||||||
|
|
||||||
|
# 获取当前聊天的情绪对象并更新情绪状态
|
||||||
|
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||||
|
await chat_mood.update_mood_by_message(message, interest_rate)
|
||||||
|
logger.debug("情绪状态更新完成")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新情绪状态失败: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_message(self, envelope: MessageEnvelope):
|
||||||
|
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
|
||||||
|
message_info = envelope.get("message_info")
|
||||||
|
if not isinstance(message_info, dict):
|
||||||
|
logger.debug(
|
||||||
|
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
|
||||||
|
", ".join(envelope.keys()),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if message_info.get("group_info") is not None:
|
||||||
|
message_info["group_info"]["group_id"] = str( # type: ignore
|
||||||
|
message_info["group_info"]["group_id"] # type: ignore
|
||||||
|
)
|
||||||
|
if message_info.get("user_info") is not None:
|
||||||
|
message_info["user_info"]["user_id"] = str( # type: ignore
|
||||||
|
message_info["user_info"]["user_id"] # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
group_info = message_info.get("group_info")
|
||||||
|
user_info = message_info.get("user_info")
|
||||||
|
|
||||||
|
chat_stream = await get_chat_manager().get_or_create_stream(
|
||||||
|
platform=envelope["platform"], # type: ignore
|
||||||
|
user_info=user_info, # type: ignore
|
||||||
|
group_info=group_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 生成 DatabaseMessages
|
||||||
|
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||||
|
message = await process_message_from_dict(
|
||||||
|
message_dict=envelope,
|
||||||
|
stream_id=chat_stream.stream_id,
|
||||||
|
platform=chat_stream.platform
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1,13 +1,14 @@
|
|||||||
"""消息处理工具模块
|
"""消息处理工具模块
|
||||||
将原 MessageRecv 的消息处理逻辑提取为独立函数,
|
将原 MessageRecv 的消息处理逻辑提取为独立函数,
|
||||||
直接从适配器消息字典生成 DatabaseMessages
|
基于 mofox-bus 的 TypedDict 形式构建消息数据,然后转换为 DatabaseMessages
|
||||||
"""
|
"""
|
||||||
import base64
|
import base64
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
from mofox_bus import BaseMessageInfo, Seg
|
from mofox_bus import MessageEnvelope
|
||||||
|
from mofox_bus.types import MessageInfoPayload, SegPayload, UserInfoPayload, GroupInfoPayload
|
||||||
|
|
||||||
from src.chat.utils.self_voice_cache import consume_self_voice_text
|
from src.chat.utils.self_voice_cache import consume_self_voice_text
|
||||||
from src.chat.utils.utils_image import get_image_manager
|
from src.chat.utils.utils_image import get_image_manager
|
||||||
@@ -20,25 +21,25 @@ from src.config.config import global_config
|
|||||||
logger = get_logger("message_processor")
|
logger = get_logger("message_processor")
|
||||||
|
|
||||||
|
|
||||||
async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str, platform: str) -> DatabaseMessages:
|
async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: str, platform: str) -> DatabaseMessages:
|
||||||
"""从适配器消息字典处理并生成 DatabaseMessages
|
"""从适配器消息字典处理并生成 DatabaseMessages
|
||||||
|
|
||||||
这个函数整合了原 MessageRecv 的所有处理逻辑:
|
这个函数整合了原 MessageRecv 的所有处理逻辑:
|
||||||
1. 解析 message_segment 并异步处理内容(图片、语音、视频等)
|
1. 解析 message_segment 并异步处理内容(图片、语音、视频等)
|
||||||
2. 提取所有消息元数据
|
2. 提取所有消息元数据
|
||||||
3. 直接构造 DatabaseMessages 对象
|
3. 基于 TypedDict 形式构建数据,然后转换为 DatabaseMessages
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_dict: MessageCQ序列化后的字典
|
message_dict: MessageEnvelope 格式的消息字典
|
||||||
stream_id: 聊天流ID
|
stream_id: 聊天流ID
|
||||||
platform: 平台标识
|
platform: 平台标识
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DatabaseMessages: 处理完成的数据库消息对象
|
DatabaseMessages: 处理完成的数据库消息对象
|
||||||
"""
|
"""
|
||||||
# 解析基础信息
|
# 提取核心数据(使用 TypedDict 类型)
|
||||||
message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
message_info: MessageInfoPayload = message_dict.get("message_info", {}) # type: ignore
|
||||||
message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
message_segment: SegPayload | list[SegPayload] = message_dict.get("message_segment", {"type": "text", "data": ""}) # type: ignore
|
||||||
|
|
||||||
# 初始化处理状态
|
# 初始化处理状态
|
||||||
processing_state = {
|
processing_state = {
|
||||||
@@ -61,26 +62,26 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
|||||||
is_notify = False
|
is_notify = False
|
||||||
is_public_notice = False
|
is_public_notice = False
|
||||||
notice_type = None
|
notice_type = None
|
||||||
if message_info.additional_config and isinstance(message_info.additional_config, dict):
|
additional_config_dict = message_info.get("additional_config", {})
|
||||||
is_notify = message_info.additional_config.get("is_notice", False)
|
if isinstance(additional_config_dict, dict):
|
||||||
is_public_notice = message_info.additional_config.get("is_public_notice", False)
|
is_notify = additional_config_dict.get("is_notice", False)
|
||||||
notice_type = message_info.additional_config.get("notice_type")
|
is_public_notice = additional_config_dict.get("is_public_notice", False)
|
||||||
|
notice_type = additional_config_dict.get("notice_type")
|
||||||
|
|
||||||
# 提取用户信息
|
# 提取用户信息
|
||||||
user_info = message_info.user_info
|
user_info_payload: UserInfoPayload = message_info.get("user_info", {}) # type: ignore
|
||||||
user_id = str(user_info.user_id) if user_info and user_info.user_id else ""
|
user_id = str(user_info_payload.get("user_id", ""))
|
||||||
user_nickname = (user_info.user_nickname or "") if user_info else ""
|
user_nickname = user_info_payload.get("user_nickname", "")
|
||||||
user_cardname = user_info.user_cardname if user_info else None
|
user_cardname = user_info_payload.get("user_cardname")
|
||||||
user_platform = (user_info.platform or "") if user_info else ""
|
user_platform = user_info_payload.get("platform", "")
|
||||||
|
|
||||||
# 提取群组信息
|
# 提取群组信息
|
||||||
group_info = message_info.group_info
|
group_info_payload: GroupInfoPayload | None = message_info.get("group_info") # type: ignore
|
||||||
group_id = group_info.group_id if group_info else None
|
group_id = group_info_payload.get("group_id") if group_info_payload else None
|
||||||
group_name = group_info.group_name if group_info else None
|
group_name = group_info_payload.get("group_name") if group_info_payload else None
|
||||||
group_platform = group_info.platform if group_info else None
|
group_platform = group_info_payload.get("platform") if group_info_payload else None
|
||||||
|
|
||||||
# chat_id 应该直接使用 stream_id(与数据库存储格式一致)
|
# chat_id 应该直接使用 stream_id(与数据库存储格式一致)
|
||||||
# stream_id 是通过 platform + user_id/group_id 的 SHA-256 哈希生成的
|
|
||||||
chat_id = stream_id
|
chat_id = stream_id
|
||||||
|
|
||||||
# 准备 additional_config
|
# 准备 additional_config
|
||||||
@@ -89,18 +90,19 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
|||||||
# 提取 reply_to
|
# 提取 reply_to
|
||||||
reply_to = _extract_reply_from_segment(message_segment)
|
reply_to = _extract_reply_from_segment(message_segment)
|
||||||
|
|
||||||
# 构造 DatabaseMessages
|
# 构造消息数据字典(基于 TypedDict 风格)
|
||||||
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
|
message_time = message_info.get("time", time.time())
|
||||||
message_id = message_info.message_id or ""
|
message_id = message_info.get("message_id", "")
|
||||||
|
|
||||||
# 处理 is_mentioned
|
# 处理 is_mentioned
|
||||||
is_mentioned = None
|
is_mentioned = None
|
||||||
mentioned_value = processing_state.get("is_mentioned")
|
mentioned_value = processing_state.get("is_mentioned")
|
||||||
if isinstance(mentioned_value, bool):
|
if isinstance(mentioned_value, bool):
|
||||||
is_mentioned = mentioned_value
|
is_mentioned = mentioned_value
|
||||||
elif isinstance(mentioned_value, int | float):
|
elif isinstance(mentioned_value, (int, float)):
|
||||||
is_mentioned = mentioned_value != 0
|
is_mentioned = mentioned_value != 0
|
||||||
|
|
||||||
|
# 使用 TypedDict 风格的数据构建 DatabaseMessages
|
||||||
db_message = DatabaseMessages(
|
db_message = DatabaseMessages(
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
time=float(message_time),
|
time=float(message_time),
|
||||||
@@ -134,7 +136,7 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
|||||||
chat_info_group_platform=group_platform,
|
chat_info_group_platform=group_platform,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设置优先级信息
|
# 设置优先级信息(运行时属性)
|
||||||
if processing_state.get("priority_mode"):
|
if processing_state.get("priority_mode"):
|
||||||
setattr(db_message, "priority_mode", processing_state["priority_mode"])
|
setattr(db_message, "priority_mode", processing_state["priority_mode"])
|
||||||
if processing_state.get("priority_info"):
|
if processing_state.get("priority_info"):
|
||||||
@@ -149,99 +151,127 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
|||||||
return db_message
|
return db_message
|
||||||
|
|
||||||
|
|
||||||
async def _process_message_segments(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
|
async def _process_message_segments(
|
||||||
|
segment: SegPayload | list[SegPayload],
|
||||||
|
state: dict,
|
||||||
|
message_info: MessageInfoPayload
|
||||||
|
) -> str:
|
||||||
"""递归处理消息段,转换为文字描述
|
"""递归处理消息段,转换为文字描述
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
segment: 要处理的消息段
|
segment: 要处理的消息段(TypedDict 或列表)
|
||||||
state: 处理状态字典(用于记录消息类型标记)
|
state: 处理状态字典(用于记录消息类型标记)
|
||||||
message_info: 消息基础信息(用于某些处理逻辑)
|
message_info: 消息基础信息(TypedDict 格式)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 处理后的文本
|
str: 处理后的文本
|
||||||
"""
|
"""
|
||||||
if segment.type == "seglist":
|
# 如果是列表,遍历处理
|
||||||
# 处理消息段列表
|
if isinstance(segment, list):
|
||||||
segments_text = []
|
segments_text = []
|
||||||
for seg in segment.data:
|
for seg in segment:
|
||||||
processed = await _process_message_segments(seg, state, message_info)
|
processed = await _process_message_segments(seg, state, message_info)
|
||||||
if processed:
|
if processed:
|
||||||
segments_text.append(processed)
|
segments_text.append(processed)
|
||||||
return " ".join(segments_text)
|
return " ".join(segments_text)
|
||||||
else:
|
|
||||||
# 处理单个消息段
|
# 如果是单个段
|
||||||
|
if isinstance(segment, dict):
|
||||||
|
seg_type = segment.get("type", "")
|
||||||
|
seg_data = segment.get("data")
|
||||||
|
|
||||||
|
# 处理 seglist 类型
|
||||||
|
if seg_type == "seglist" and isinstance(seg_data, list):
|
||||||
|
segments_text = []
|
||||||
|
for sub_seg in seg_data:
|
||||||
|
processed = await _process_message_segments(sub_seg, state, message_info)
|
||||||
|
if processed:
|
||||||
|
segments_text.append(processed)
|
||||||
|
return " ".join(segments_text)
|
||||||
|
|
||||||
|
# 处理其他类型
|
||||||
return await _process_single_segment(segment, state, message_info)
|
return await _process_single_segment(segment, state, message_info)
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
async def _process_single_segment(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
|
|
||||||
|
async def _process_single_segment(
|
||||||
|
segment: SegPayload,
|
||||||
|
state: dict,
|
||||||
|
message_info: MessageInfoPayload
|
||||||
|
) -> str:
|
||||||
"""处理单个消息段
|
"""处理单个消息段
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
segment: 消息段
|
segment: 消息段(TypedDict 格式)
|
||||||
state: 处理状态字典
|
state: 处理状态字典
|
||||||
message_info: 消息基础信息
|
message_info: 消息基础信息(TypedDict 格式)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 处理后的文本
|
str: 处理后的文本
|
||||||
"""
|
"""
|
||||||
|
seg_type = segment.get("type", "")
|
||||||
|
seg_data = segment.get("data")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if segment.type == "text":
|
if seg_type == "text":
|
||||||
state["is_picid"] = False
|
state["is_picid"] = False
|
||||||
state["is_emoji"] = False
|
state["is_emoji"] = False
|
||||||
state["is_video"] = False
|
state["is_video"] = False
|
||||||
return segment.data
|
return str(seg_data) if seg_data else ""
|
||||||
|
|
||||||
elif segment.type == "at":
|
elif seg_type == "at":
|
||||||
state["is_picid"] = False
|
state["is_picid"] = False
|
||||||
state["is_emoji"] = False
|
state["is_emoji"] = False
|
||||||
state["is_video"] = False
|
state["is_video"] = False
|
||||||
state["is_at"] = True
|
state["is_at"] = True
|
||||||
# 处理at消息,格式为"@<昵称:QQ号>"
|
# 处理at消息,格式为"@<昵称:QQ号>"
|
||||||
if isinstance(segment.data, str):
|
if isinstance(seg_data, str):
|
||||||
if ":" in segment.data:
|
if ":" in seg_data:
|
||||||
# 标准格式: "昵称:QQ号"
|
# 标准格式: "昵称:QQ号"
|
||||||
nickname, qq_id = segment.data.split(":", 1)
|
nickname, qq_id = seg_data.split(":", 1)
|
||||||
result = f"@<{nickname}:{qq_id}>"
|
return f"@<{nickname}:{qq_id}>"
|
||||||
return result
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[at处理] 无法解析格式: '{segment.data}'")
|
logger.warning(f"[at处理] 无法解析格式: '{seg_data}'")
|
||||||
return f"@{segment.data}"
|
return f"@{seg_data}"
|
||||||
logger.warning(f"[at处理] 数据类型异常: {type(segment.data)}")
|
logger.warning(f"[at处理] 数据类型异常: {type(seg_data)}")
|
||||||
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
|
return f"@{seg_data}" if isinstance(seg_data, str) else "@未知用户"
|
||||||
|
|
||||||
elif segment.type == "image":
|
elif seg_type == "image":
|
||||||
# 如果是base64图片数据
|
# 如果是base64图片数据
|
||||||
if isinstance(segment.data, str):
|
if isinstance(seg_data, str):
|
||||||
state["has_picid"] = True
|
state["has_picid"] = True
|
||||||
state["is_picid"] = True
|
state["is_picid"] = True
|
||||||
state["is_emoji"] = False
|
state["is_emoji"] = False
|
||||||
state["is_video"] = False
|
state["is_video"] = False
|
||||||
image_manager = get_image_manager()
|
image_manager = get_image_manager()
|
||||||
_, processed_text = await image_manager.process_image(segment.data)
|
_, processed_text = await image_manager.process_image(seg_data)
|
||||||
return processed_text
|
return processed_text
|
||||||
return "[发了一张图片,网卡了加载不出来]"
|
return "[发了一张图片,网卡了加载不出来]"
|
||||||
|
|
||||||
elif segment.type == "emoji":
|
elif seg_type == "emoji":
|
||||||
state["has_emoji"] = True
|
state["has_emoji"] = True
|
||||||
state["is_emoji"] = True
|
state["is_emoji"] = True
|
||||||
state["is_picid"] = False
|
state["is_picid"] = False
|
||||||
state["is_voice"] = False
|
state["is_voice"] = False
|
||||||
state["is_video"] = False
|
state["is_video"] = False
|
||||||
if isinstance(segment.data, str):
|
if isinstance(seg_data, str):
|
||||||
return await get_image_manager().get_emoji_description(segment.data)
|
return await get_image_manager().get_emoji_description(seg_data)
|
||||||
return "[发了一个表情包,网卡了加载不出来]"
|
return "[发了一个表情包,网卡了加载不出来]"
|
||||||
|
|
||||||
elif segment.type == "voice":
|
elif seg_type == "voice":
|
||||||
state["is_picid"] = False
|
state["is_picid"] = False
|
||||||
state["is_emoji"] = False
|
state["is_emoji"] = False
|
||||||
state["is_voice"] = True
|
state["is_voice"] = True
|
||||||
state["is_video"] = False
|
state["is_video"] = False
|
||||||
|
|
||||||
# 检查消息是否由机器人自己发送
|
# 检查消息是否由机器人自己发送
|
||||||
if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account):
|
user_info = message_info.get("user_info", {})
|
||||||
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。")
|
user_id_str = str(user_info.get("user_id", ""))
|
||||||
if isinstance(segment.data, str):
|
if user_id_str == str(global_config.bot.qq_account):
|
||||||
cached_text = consume_self_voice_text(segment.data)
|
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {user_id_str}),尝试从缓存获取文本。")
|
||||||
|
if isinstance(seg_data, str):
|
||||||
|
cached_text = consume_self_voice_text(seg_data)
|
||||||
if cached_text:
|
if cached_text:
|
||||||
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
|
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
|
||||||
return f"[语音:{cached_text}]"
|
return f"[语音:{cached_text}]"
|
||||||
@@ -249,41 +279,42 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
|||||||
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
|
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
|
||||||
|
|
||||||
# 标准语音识别流程
|
# 标准语音识别流程
|
||||||
if isinstance(segment.data, str):
|
if isinstance(seg_data, str):
|
||||||
return await get_voice_text(segment.data)
|
return await get_voice_text(seg_data)
|
||||||
return "[发了一段语音,网卡了加载不出来]"
|
return "[发了一段语音,网卡了加载不出来]"
|
||||||
|
|
||||||
elif segment.type == "mention_bot":
|
elif seg_type == "mention_bot":
|
||||||
state["is_picid"] = False
|
state["is_picid"] = False
|
||||||
state["is_emoji"] = False
|
state["is_emoji"] = False
|
||||||
state["is_voice"] = False
|
state["is_voice"] = False
|
||||||
state["is_video"] = False
|
state["is_video"] = False
|
||||||
state["is_mentioned"] = float(segment.data)
|
if isinstance(seg_data, (int, float)):
|
||||||
|
state["is_mentioned"] = float(seg_data)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
elif segment.type == "priority_info":
|
elif seg_type == "priority_info":
|
||||||
state["is_picid"] = False
|
state["is_picid"] = False
|
||||||
state["is_emoji"] = False
|
state["is_emoji"] = False
|
||||||
state["is_voice"] = False
|
state["is_voice"] = False
|
||||||
if isinstance(segment.data, dict):
|
if isinstance(seg_data, dict):
|
||||||
# 处理优先级信息
|
# 处理优先级信息
|
||||||
state["priority_mode"] = "priority"
|
state["priority_mode"] = "priority"
|
||||||
state["priority_info"] = segment.data
|
state["priority_info"] = seg_data
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
elif segment.type == "file":
|
elif seg_type == "file":
|
||||||
if isinstance(segment.data, dict):
|
if isinstance(seg_data, dict):
|
||||||
file_name = segment.data.get("name", "未知文件")
|
file_name = seg_data.get("name", "未知文件")
|
||||||
file_size = segment.data.get("size", "未知大小")
|
file_size = seg_data.get("size", "未知大小")
|
||||||
return f"[文件:{file_name} ({file_size}字节)]"
|
return f"[文件:{file_name} ({file_size}字节)]"
|
||||||
return "[收到一个文件]"
|
return "[收到一个文件]"
|
||||||
|
|
||||||
elif segment.type == "video":
|
elif seg_type == "video":
|
||||||
state["is_picid"] = False
|
state["is_picid"] = False
|
||||||
state["is_emoji"] = False
|
state["is_emoji"] = False
|
||||||
state["is_voice"] = False
|
state["is_voice"] = False
|
||||||
state["is_video"] = True
|
state["is_video"] = True
|
||||||
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
|
logger.info(f"接收到视频消息,数据类型: {type(seg_data)}")
|
||||||
|
|
||||||
# 检查视频分析功能是否可用
|
# 检查视频分析功能是否可用
|
||||||
if not is_video_analysis_available():
|
if not is_video_analysis_available():
|
||||||
@@ -292,11 +323,11 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
|||||||
|
|
||||||
if global_config.video_analysis.enable:
|
if global_config.video_analysis.enable:
|
||||||
logger.info("已启用视频识别,开始识别")
|
logger.info("已启用视频识别,开始识别")
|
||||||
if isinstance(segment.data, dict):
|
if isinstance(seg_data, dict):
|
||||||
try:
|
try:
|
||||||
# 从Adapter接收的视频数据
|
# 从Adapter接收的视频数据
|
||||||
video_base64 = segment.data.get("base64")
|
video_base64 = seg_data.get("base64")
|
||||||
filename = segment.data.get("filename", "video.mp4")
|
filename = seg_data.get("filename", "video.mp4")
|
||||||
|
|
||||||
logger.info(f"视频文件名: {filename}")
|
logger.info(f"视频文件名: {filename}")
|
||||||
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
|
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
|
||||||
@@ -329,24 +360,29 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
|||||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||||
return "[收到视频,但处理时出现错误]"
|
return "[收到视频,但处理时出现错误]"
|
||||||
else:
|
else:
|
||||||
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
|
logger.warning(f"视频消息数据不是字典格式: {type(seg_data)}")
|
||||||
return "[发了一个视频,但格式不支持]"
|
return "[发了一个视频,但格式不支持]"
|
||||||
else:
|
else:
|
||||||
return ""
|
return ""
|
||||||
else:
|
else:
|
||||||
logger.warning(f"未知的消息段类型: {segment.type}")
|
logger.warning(f"未知的消息段类型: {seg_type}")
|
||||||
return f"[{segment.type} 消息]"
|
return f"[{seg_type} 消息]"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
logger.error(f"处理消息段失败: {e!s}, 类型: {seg_type}, 数据: {seg_data}")
|
||||||
return f"[处理失败的{segment.type}消息]"
|
return f"[处理失败的{seg_type}消息]"
|
||||||
|
|
||||||
|
|
||||||
def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, is_public_notice: bool, notice_type: str | None) -> str | None:
|
def _prepare_additional_config(
|
||||||
|
message_info: MessageInfoPayload,
|
||||||
|
is_notify: bool,
|
||||||
|
is_public_notice: bool,
|
||||||
|
notice_type: str | None
|
||||||
|
) -> str | None:
|
||||||
"""准备 additional_config,包含 format_info 和 notice 信息
|
"""准备 additional_config,包含 format_info 和 notice 信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_info: 消息基础信息
|
message_info: 消息基础信息(TypedDict 格式)
|
||||||
is_notify: 是否为notice消息
|
is_notify: 是否为notice消息
|
||||||
is_public_notice: 是否为公共notice
|
is_public_notice: 是否为公共notice
|
||||||
notice_type: notice类型
|
notice_type: notice类型
|
||||||
@@ -358,12 +394,13 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
|
|||||||
additional_config_data = {}
|
additional_config_data = {}
|
||||||
|
|
||||||
# 首先获取adapter传递的additional_config
|
# 首先获取adapter传递的additional_config
|
||||||
if hasattr(message_info, "additional_config") and message_info.additional_config:
|
additional_config_raw = message_info.get("additional_config")
|
||||||
if isinstance(message_info.additional_config, dict):
|
if additional_config_raw:
|
||||||
additional_config_data = message_info.additional_config.copy()
|
if isinstance(additional_config_raw, dict):
|
||||||
elif isinstance(message_info.additional_config, str):
|
additional_config_data = additional_config_raw.copy()
|
||||||
|
elif isinstance(additional_config_raw, str):
|
||||||
try:
|
try:
|
||||||
additional_config_data = orjson.loads(message_info.additional_config)
|
additional_config_data = orjson.loads(additional_config_raw)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"无法解析 additional_config JSON: {e}")
|
logger.warning(f"无法解析 additional_config JSON: {e}")
|
||||||
additional_config_data = {}
|
additional_config_data = {}
|
||||||
@@ -375,11 +412,11 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
|
|||||||
additional_config_data["is_public_notice"] = bool(is_public_notice)
|
additional_config_data["is_public_notice"] = bool(is_public_notice)
|
||||||
|
|
||||||
# 添加format_info到additional_config中
|
# 添加format_info到additional_config中
|
||||||
if hasattr(message_info, "format_info") and message_info.format_info:
|
format_info = message_info.get("format_info")
|
||||||
|
if format_info:
|
||||||
try:
|
try:
|
||||||
format_info_dict = message_info.format_info.to_dict()
|
additional_config_data["format_info"] = format_info
|
||||||
additional_config_data["format_info"] = format_info_dict
|
logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info}")
|
||||||
logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
||||||
|
|
||||||
@@ -392,28 +429,43 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _extract_reply_from_segment(segment: Seg) -> str | None:
|
def _extract_reply_from_segment(segment: SegPayload | list[SegPayload]) -> str | None:
|
||||||
"""从消息段中提取reply_to信息
|
"""从消息段中提取reply_to信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
segment: 消息段
|
segment: 消息段(TypedDict 格式或列表)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str | None: 回复的消息ID,如果没有则返回None
|
str | None: 回复的消息ID,如果没有则返回None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if hasattr(segment, "type") and segment.type == "seglist":
|
# 如果是列表,遍历查找
|
||||||
# 递归搜索seglist中的reply段
|
if isinstance(segment, list):
|
||||||
if hasattr(segment, "data") and segment.data:
|
for seg in segment:
|
||||||
for seg in segment.data:
|
|
||||||
reply_id = _extract_reply_from_segment(seg)
|
reply_id = _extract_reply_from_segment(seg)
|
||||||
if reply_id:
|
if reply_id:
|
||||||
return reply_id
|
return reply_id
|
||||||
elif hasattr(segment, "type") and segment.type == "reply":
|
return None
|
||||||
# 找到reply段,返回message_id
|
|
||||||
return str(segment.data) if segment.data else None
|
# 如果是字典
|
||||||
|
if isinstance(segment, dict):
|
||||||
|
seg_type = segment.get("type", "")
|
||||||
|
seg_data = segment.get("data")
|
||||||
|
|
||||||
|
# 如果是 seglist,递归搜索
|
||||||
|
if seg_type == "seglist" and isinstance(seg_data, list):
|
||||||
|
for sub_seg in seg_data:
|
||||||
|
reply_id = _extract_reply_from_segment(sub_seg)
|
||||||
|
if reply_id:
|
||||||
|
return reply_id
|
||||||
|
|
||||||
|
# 如果是 reply 段,返回 message_id
|
||||||
|
elif seg_type == "reply":
|
||||||
|
return str(seg_data) if seg_data else None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"提取reply_to信息失败: {e}")
|
logger.warning(f"提取reply_to信息失败: {e}")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -421,33 +473,31 @@ def _extract_reply_from_segment(segment: Seg) -> str | None:
|
|||||||
# DatabaseMessages 扩展工具函数
|
# DatabaseMessages 扩展工具函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessageInfo:
|
def get_message_info_from_db_message(db_message: DatabaseMessages) -> MessageInfoPayload:
|
||||||
"""从 DatabaseMessages 重建 BaseMessageInfo(用于需要 message_info 的遗留代码)
|
"""从 DatabaseMessages 重建 MessageInfoPayload(TypedDict 格式)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_message: DatabaseMessages 对象
|
db_message: DatabaseMessages 对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BaseMessageInfo: 重建的消息信息对象
|
MessageInfoPayload: 重建的消息信息对象(TypedDict 格式)
|
||||||
"""
|
"""
|
||||||
from mofox_bus import GroupInfo, UserInfo
|
# 构建用户信息
|
||||||
|
user_info: UserInfoPayload = {
|
||||||
|
"platform": db_message.user_info.platform,
|
||||||
|
"user_id": db_message.user_info.user_id,
|
||||||
|
"user_nickname": db_message.user_info.user_nickname,
|
||||||
|
"user_cardname": db_message.user_info.user_cardname or "",
|
||||||
|
}
|
||||||
|
|
||||||
# 从 DatabaseMessages 的 user_info 转换为 mofox_bus.UserInfo
|
# 构建群组信息(如果存在)
|
||||||
user_info = UserInfo(
|
group_info: GroupInfoPayload | None = None
|
||||||
platform=db_message.user_info.platform,
|
|
||||||
user_id=db_message.user_info.user_id,
|
|
||||||
user_nickname=db_message.user_info.user_nickname,
|
|
||||||
user_cardname=db_message.user_info.user_cardname or ""
|
|
||||||
)
|
|
||||||
|
|
||||||
# 从 DatabaseMessages 的 group_info 转换为 mofox_bus.GroupInfo(如果存在)
|
|
||||||
group_info = None
|
|
||||||
if db_message.group_info:
|
if db_message.group_info:
|
||||||
group_info = GroupInfo(
|
group_info = {
|
||||||
platform=db_message.group_info.group_platform or "",
|
"platform": db_message.group_info.group_platform or "",
|
||||||
group_id=db_message.group_info.group_id,
|
"group_id": db_message.group_info.group_id,
|
||||||
group_name=db_message.group_info.group_name
|
"group_name": db_message.group_info.group_name,
|
||||||
)
|
}
|
||||||
|
|
||||||
# 解析 additional_config(从 JSON 字符串到字典)
|
# 解析 additional_config(从 JSON 字符串到字典)
|
||||||
additional_config = None
|
additional_config = None
|
||||||
@@ -458,15 +508,19 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
|
|||||||
# 如果解析失败,保持为字符串
|
# 如果解析失败,保持为字符串
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 创建 BaseMessageInfo
|
# 创建 MessageInfoPayload
|
||||||
message_info = BaseMessageInfo(
|
message_info: MessageInfoPayload = {
|
||||||
platform=db_message.chat_info.platform,
|
"platform": db_message.chat_info.platform,
|
||||||
message_id=db_message.message_id,
|
"message_id": db_message.message_id,
|
||||||
time=db_message.time,
|
"time": db_message.time,
|
||||||
user_info=user_info,
|
"user_info": user_info,
|
||||||
group_info=group_info,
|
}
|
||||||
additional_config=additional_config # type: ignore
|
|
||||||
)
|
if group_info:
|
||||||
|
message_info["group_info"] = group_info
|
||||||
|
|
||||||
|
if additional_config:
|
||||||
|
message_info["additional_config"] = additional_config
|
||||||
|
|
||||||
return message_info
|
return message_info
|
||||||
|
|
||||||
|
|||||||
@@ -1,341 +0,0 @@
|
|||||||
"""
|
|
||||||
MessageEnvelope converter between mofox_bus schema and internal message structures.
|
|
||||||
|
|
||||||
- 优先处理 maim_message 风格的 message_info + message_segment。
|
|
||||||
- 兼容旧版 content/sender/channel 结构,方便逐步迁移。
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from mofox_bus import (
|
|
||||||
BaseMessageInfo,
|
|
||||||
MessageBase,
|
|
||||||
MessageEnvelope,
|
|
||||||
Seg,
|
|
||||||
UserInfo,
|
|
||||||
GroupInfo,
|
|
||||||
)
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("envelope_converter")
|
|
||||||
|
|
||||||
|
|
||||||
class EnvelopeConverter:
|
|
||||||
"""MessageEnvelope <-> MessageBase converter."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def to_message_base(envelope: MessageEnvelope) -> MessageBase:
|
|
||||||
"""
|
|
||||||
Convert MessageEnvelope to MessageBase.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 优先使用 maim_message 样式字段
|
|
||||||
info_payload = envelope.get("message_info") or {}
|
|
||||||
seg_payload = envelope.get("message_segment") or envelope.get("message_chain")
|
|
||||||
|
|
||||||
if info_payload:
|
|
||||||
message_info = BaseMessageInfo.from_dict(info_payload)
|
|
||||||
else:
|
|
||||||
message_info = EnvelopeConverter._build_info_from_legacy(envelope)
|
|
||||||
|
|
||||||
if seg_payload is None:
|
|
||||||
seg_list = EnvelopeConverter._content_to_segments(envelope.get("content"))
|
|
||||||
seg_payload = seg_list
|
|
||||||
|
|
||||||
message_segment = EnvelopeConverter._ensure_seg(seg_payload)
|
|
||||||
raw_message = envelope.get("raw_message") or envelope.get("raw_platform_message")
|
|
||||||
|
|
||||||
return MessageBase(
|
|
||||||
message_info=message_info,
|
|
||||||
message_segment=message_segment,
|
|
||||||
raw_message=raw_message,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"转换 MessageEnvelope 失败: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_info_from_legacy(envelope: MessageEnvelope) -> BaseMessageInfo:
|
|
||||||
"""将 legacy 字段映射为 BaseMessageInfo。"""
|
|
||||||
platform = envelope.get("platform")
|
|
||||||
channel = envelope.get("channel") or {}
|
|
||||||
sender = envelope.get("sender") or {}
|
|
||||||
|
|
||||||
message_id = envelope.get("id") or envelope.get("message_id")
|
|
||||||
timestamp_ms = envelope.get("timestamp_ms")
|
|
||||||
time_value = (timestamp_ms / 1000.0) if timestamp_ms is not None else None
|
|
||||||
|
|
||||||
group_info: Optional[GroupInfo] = None
|
|
||||||
channel_type = channel.get("channel_type")
|
|
||||||
if channel_type in ("group", "supergroup", "room"):
|
|
||||||
group_info = GroupInfo(
|
|
||||||
platform=platform,
|
|
||||||
group_id=channel.get("channel_id"),
|
|
||||||
group_name=channel.get("title"),
|
|
||||||
)
|
|
||||||
|
|
||||||
user_info: Optional[UserInfo] = None
|
|
||||||
if sender:
|
|
||||||
user_info = UserInfo(
|
|
||||||
platform=platform,
|
|
||||||
user_id=str(sender.get("user_id")) if sender.get("user_id") is not None else None,
|
|
||||||
user_nickname=sender.get("display_name") or sender.get("user_nickname"),
|
|
||||||
user_avatar=sender.get("avatar_url"),
|
|
||||||
)
|
|
||||||
|
|
||||||
return BaseMessageInfo(
|
|
||||||
platform=platform,
|
|
||||||
message_id=message_id,
|
|
||||||
time=time_value,
|
|
||||||
group_info=group_info,
|
|
||||||
user_info=user_info,
|
|
||||||
additional_config=envelope.get("metadata"),
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _ensure_seg(payload: Any) -> Seg:
|
|
||||||
"""将任意 payload 转为 Seg dataclass。"""
|
|
||||||
if isinstance(payload, Seg):
|
|
||||||
return payload
|
|
||||||
if isinstance(payload, list):
|
|
||||||
# 直接传入 Seg 列表或 seglist data
|
|
||||||
return Seg(type="seglist", data=[EnvelopeConverter._ensure_seg(item) for item in payload])
|
|
||||||
if isinstance(payload, dict):
|
|
||||||
seg_type = payload.get("type") or "text"
|
|
||||||
data = payload.get("data")
|
|
||||||
if seg_type == "seglist" and isinstance(data, list):
|
|
||||||
data = [EnvelopeConverter._ensure_seg(item) for item in data]
|
|
||||||
return Seg(type=seg_type, data=data)
|
|
||||||
# 兜底:转成文本片段
|
|
||||||
return Seg(type="text", data="" if payload is None else str(payload))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _flatten_segments(seg: Seg) -> List[Seg]:
|
|
||||||
"""将 Seg/seglist 打平成列表,便于旧 content 转换。"""
|
|
||||||
if seg.type == "seglist" and isinstance(seg.data, list):
|
|
||||||
return [item if isinstance(item, Seg) else EnvelopeConverter._ensure_seg(item) for item in seg.data]
|
|
||||||
return [seg]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _content_to_segments(content: Any) -> List[Seg]:
|
|
||||||
"""
|
|
||||||
Convert legacy Content (type/data/metadata) to a flat list of Seg.
|
|
||||||
"""
|
|
||||||
segments: List[Seg] = []
|
|
||||||
|
|
||||||
def _walk(node: Any) -> None:
|
|
||||||
if node is None:
|
|
||||||
return
|
|
||||||
if isinstance(node, list):
|
|
||||||
for item in node:
|
|
||||||
_walk(item)
|
|
||||||
return
|
|
||||||
if not isinstance(node, dict):
|
|
||||||
logger.warning("未知的 content 节点类型: %s", type(node))
|
|
||||||
return
|
|
||||||
|
|
||||||
content_type = node.get("type")
|
|
||||||
data = node.get("data")
|
|
||||||
metadata = node.get("metadata") or {}
|
|
||||||
|
|
||||||
if content_type == "collection":
|
|
||||||
items = data if isinstance(data, list) else node.get("items", [])
|
|
||||||
for item in items:
|
|
||||||
_walk(item)
|
|
||||||
return
|
|
||||||
|
|
||||||
if content_type in ("text", "at"):
|
|
||||||
subtype = metadata.get("subtype") or ("at" if content_type == "at" else None)
|
|
||||||
text = "" if data is None else str(data)
|
|
||||||
if subtype in ("at", "mention"):
|
|
||||||
user_info = metadata.get("user") or {}
|
|
||||||
seg_data: Dict[str, Any] = {
|
|
||||||
"user_id": user_info.get("id") or user_info.get("user_id"),
|
|
||||||
"user_name": user_info.get("name") or user_info.get("display_name"),
|
|
||||||
"text": text,
|
|
||||||
"raw": user_info.get("raw") or user_info if user_info else None,
|
|
||||||
}
|
|
||||||
if any(v is not None for v in seg_data.values()):
|
|
||||||
segments.append(Seg(type="at", data=seg_data))
|
|
||||||
else:
|
|
||||||
segments.append(Seg(type="at", data=text))
|
|
||||||
else:
|
|
||||||
segments.append(Seg(type="text", data=text))
|
|
||||||
return
|
|
||||||
|
|
||||||
if content_type == "image":
|
|
||||||
url = ""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
url = data.get("url") or data.get("file") or data.get("file_id") or ""
|
|
||||||
elif data is not None:
|
|
||||||
url = str(data)
|
|
||||||
segments.append(Seg(type="image", data=url))
|
|
||||||
return
|
|
||||||
|
|
||||||
if content_type == "audio":
|
|
||||||
url = ""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
url = data.get("url") or data.get("file") or data.get("file_id") or ""
|
|
||||||
elif data is not None:
|
|
||||||
url = str(data)
|
|
||||||
segments.append(Seg(type="record", data=url))
|
|
||||||
return
|
|
||||||
|
|
||||||
if content_type == "video":
|
|
||||||
url = ""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
url = data.get("url") or data.get("file") or data.get("file_id") or ""
|
|
||||||
elif data is not None:
|
|
||||||
url = str(data)
|
|
||||||
segments.append(Seg(type="video", data=url))
|
|
||||||
return
|
|
||||||
|
|
||||||
if content_type == "file":
|
|
||||||
file_name = ""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
file_name = data.get("file_name") or data.get("name") or ""
|
|
||||||
text = file_name or "[file]"
|
|
||||||
segments.append(Seg(type="text", data=text))
|
|
||||||
return
|
|
||||||
|
|
||||||
if content_type == "command":
|
|
||||||
name = ""
|
|
||||||
args: Dict[str, Any] = {}
|
|
||||||
if isinstance(data, dict):
|
|
||||||
name = data.get("name", "")
|
|
||||||
args = data.get("args", {}) or {}
|
|
||||||
else:
|
|
||||||
name = str(data or "")
|
|
||||||
cmd_text = f"/{name}" if name else "/command"
|
|
||||||
if args:
|
|
||||||
cmd_text += " " + " ".join(f"{k}={v}" for k, v in args.items())
|
|
||||||
segments.append(Seg(type="text", data=cmd_text))
|
|
||||||
return
|
|
||||||
|
|
||||||
if content_type == "event":
|
|
||||||
event_type = ""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
event_type = data.get("event_type", "")
|
|
||||||
else:
|
|
||||||
event_type = str(data or "")
|
|
||||||
segments.append(Seg(type="text", data=f"[事件: {event_type or 'unknown'}]"))
|
|
||||||
return
|
|
||||||
|
|
||||||
if content_type == "system":
|
|
||||||
text = "" if data is None else str(data)
|
|
||||||
segments.append(Seg(type="text", data=f"[系统] {text}"))
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.warning(f"未知的消息类型: {content_type}")
|
|
||||||
segments.append(Seg(type="text", data=f"[未知消息类型: {content_type}]"))
|
|
||||||
|
|
||||||
_walk(content)
|
|
||||||
return segments
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def to_legacy_dict(envelope: MessageEnvelope) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Convert MessageEnvelope to legacy dict for backward compatibility.
|
|
||||||
"""
|
|
||||||
message_base = EnvelopeConverter.to_message_base(envelope)
|
|
||||||
return message_base.to_dict()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_message_base(message: MessageBase, direction: str = "outgoing") -> MessageEnvelope:
|
|
||||||
"""
|
|
||||||
Convert MessageBase to MessageEnvelope (maim_message style preferred).
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
info_dict = message.message_info.to_dict()
|
|
||||||
seg_dict = message.message_segment.to_dict()
|
|
||||||
|
|
||||||
envelope: MessageEnvelope = {
|
|
||||||
"direction": direction,
|
|
||||||
"message_info": info_dict,
|
|
||||||
"message_segment": seg_dict,
|
|
||||||
"platform": info_dict.get("platform"),
|
|
||||||
"message_id": info_dict.get("message_id"),
|
|
||||||
"schema_version": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if message.message_info.time is not None:
|
|
||||||
envelope["timestamp_ms"] = int(message.message_info.time * 1000)
|
|
||||||
if message.raw_message is not None:
|
|
||||||
envelope["raw_message"] = message.raw_message
|
|
||||||
|
|
||||||
# legacy 补充,方便老代码继续工作
|
|
||||||
segments = EnvelopeConverter._flatten_segments(message.message_segment)
|
|
||||||
envelope["content"] = EnvelopeConverter._segments_to_content(segments)
|
|
||||||
if message.message_info.user_info:
|
|
||||||
envelope["sender"] = {
|
|
||||||
"user_id": message.message_info.user_info.user_id,
|
|
||||||
"role": "assistant" if direction == "outgoing" else "user",
|
|
||||||
"display_name": message.message_info.user_info.user_nickname,
|
|
||||||
"avatar_url": getattr(message.message_info.user_info, "user_avatar", None),
|
|
||||||
}
|
|
||||||
if message.message_info.group_info:
|
|
||||||
envelope["channel"] = {
|
|
||||||
"channel_id": message.message_info.group_info.group_id,
|
|
||||||
"channel_type": "group",
|
|
||||||
"title": message.message_info.group_info.group_name,
|
|
||||||
}
|
|
||||||
|
|
||||||
return envelope
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"转换 MessageBase 失败: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _segments_to_content(segments: List[Seg]) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Convert Seg list to legacy Content (type/data/metadata).
|
|
||||||
"""
|
|
||||||
if not segments:
|
|
||||||
return {"type": "text", "data": ""}
|
|
||||||
|
|
||||||
def _seg_to_content(seg: Seg) -> Dict[str, Any]:
|
|
||||||
data = seg.data
|
|
||||||
|
|
||||||
if seg.type == "text":
|
|
||||||
return {"type": "text", "data": data}
|
|
||||||
|
|
||||||
if seg.type == "at":
|
|
||||||
content: Dict[str, Any] = {"type": "text", "data": ""}
|
|
||||||
metadata: Dict[str, Any] = {"subtype": "at"}
|
|
||||||
if isinstance(data, dict):
|
|
||||||
content["data"] = data.get("text", "")
|
|
||||||
user = {
|
|
||||||
"id": data.get("user_id"),
|
|
||||||
"name": data.get("user_name"),
|
|
||||||
"raw": data.get("raw"),
|
|
||||||
}
|
|
||||||
if any(v is not None for v in user.values()):
|
|
||||||
metadata["user"] = user
|
|
||||||
else:
|
|
||||||
content["data"] = data
|
|
||||||
if metadata:
|
|
||||||
content["metadata"] = metadata
|
|
||||||
return content
|
|
||||||
|
|
||||||
if seg.type == "image":
|
|
||||||
return {"type": "image", "data": data}
|
|
||||||
|
|
||||||
if seg.type in ("record", "voice", "audio"):
|
|
||||||
return {"type": "audio", "data": data}
|
|
||||||
|
|
||||||
if seg.type == "video":
|
|
||||||
return {"type": "video", "data": data}
|
|
||||||
|
|
||||||
return {"type": seg.type, "data": data}
|
|
||||||
|
|
||||||
if len(segments) == 1:
|
|
||||||
return _seg_to_content(segments[0])
|
|
||||||
|
|
||||||
return {"type": "collection", "data": [_seg_to_content(seg) for seg in segments]}
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["EnvelopeConverter"]
|
|
||||||
26
src/main.py
26
src/main.py
@@ -18,7 +18,6 @@ from src.chat.message_receive.bot import chat_bot
|
|||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.message.envelope_converter import EnvelopeConverter
|
|
||||||
|
|
||||||
# 全局背景任务集合
|
# 全局背景任务集合
|
||||||
_background_tasks = set()
|
_background_tasks = set()
|
||||||
@@ -77,7 +76,7 @@ class MainSystem:
|
|||||||
self.individuality: Individuality = get_individuality()
|
self.individuality: Individuality = get_individuality()
|
||||||
|
|
||||||
# 创建核心消息接收器
|
# 创建核心消息接收器
|
||||||
self.core_sink: InProcessCoreSink = InProcessCoreSink(self._handle_message_envelope)
|
self.core_sink: InProcessCoreSink = InProcessCoreSink(self._message_process_wrapper)
|
||||||
|
|
||||||
# 使用服务器
|
# 使用服务器
|
||||||
self.server: Server = get_global_server()
|
self.server: Server = get_global_server()
|
||||||
@@ -353,19 +352,18 @@ class MainSystem:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"同步清理资源时出错: {e}")
|
logger.error(f"同步清理资源时出错: {e}")
|
||||||
|
|
||||||
async def _message_process_wrapper(self, message_data: dict[str, Any]) -> None:
|
async def _message_process_wrapper(self, envelope: MessageEnvelope) -> None:
|
||||||
"""并行处理消息的包装器"""
|
"""并行处理消息的包装器"""
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
message_id = message_data.get("message_info", {}).get("message_id", "UNKNOWN")
|
message_id = envelope.get("message_info", {}).get("message_id", "UNKNOWN")
|
||||||
|
|
||||||
# 检查系统是否正在关闭
|
# 检查系统是否正在关闭
|
||||||
if self._shutting_down:
|
if self._shutting_down:
|
||||||
logger.warning(f"系统正在关闭,拒绝处理消息 {message_id}")
|
logger.warning(f"系统正在关闭,拒绝处理消息 {message_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 创建后台任务
|
# 创建后台任务
|
||||||
task = asyncio.create_task(chat_bot.message_process(message_data))
|
task = asyncio.create_task(chat_bot.message_process(envelope))
|
||||||
logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})")
|
logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})")
|
||||||
|
|
||||||
# 添加一个回调函数,当任务完成时,它会被调用
|
# 添加一个回调函数,当任务完成时,它会被调用
|
||||||
@@ -374,22 +372,6 @@ class MainSystem:
|
|||||||
logger.error("在创建消息处理任务时发生严重错误:")
|
logger.error("在创建消息处理任务时发生严重错误:")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
async def _handle_message_envelope(self, envelope: MessageEnvelope) -> None:
|
|
||||||
"""
|
|
||||||
处理来自适配器的 MessageEnvelope
|
|
||||||
|
|
||||||
Args:
|
|
||||||
envelope: 统一的消息信封
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 转换为旧版格式
|
|
||||||
message_data = EnvelopeConverter.to_legacy_dict(envelope)
|
|
||||||
|
|
||||||
# 使用现有的消息处理流程
|
|
||||||
await self._message_process_wrapper(message_data)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理 MessageEnvelope 时出错: {e}", exc_info=True)
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""初始化系统组件"""
|
"""初始化系统组件"""
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import threading
|
import threading
|
||||||
|
import weakref
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Awaitable, Callable, Dict, Iterable, List, Protocol
|
from typing import Awaitable, Callable, Dict, Iterable, List, Protocol
|
||||||
|
|
||||||
@@ -37,6 +39,7 @@ class MessageRoute:
|
|||||||
handler: MessageHandler
|
handler: MessageHandler
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
message_type: str | None = None
|
message_type: str | None = None
|
||||||
|
message_types: set[str] | None = None # 支持多个消息类型
|
||||||
event_types: set[str] | None = None
|
event_types: set[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
@@ -55,6 +58,8 @@ class MessageRuntime:
|
|||||||
self._middlewares: list[Middleware] = []
|
self._middlewares: list[Middleware] = []
|
||||||
self._type_routes: Dict[str, list[MessageRoute]] = {}
|
self._type_routes: Dict[str, list[MessageRoute]] = {}
|
||||||
self._event_routes: Dict[str, list[MessageRoute]] = {}
|
self._event_routes: Dict[str, list[MessageRoute]] = {}
|
||||||
|
# 用于检测同一类型的重复注册
|
||||||
|
self._explicit_type_handlers: Dict[str, str] = {} # message_type -> handler_name
|
||||||
|
|
||||||
def add_route(
|
def add_route(
|
||||||
self,
|
self,
|
||||||
@@ -62,7 +67,7 @@ class MessageRuntime:
|
|||||||
handler: MessageHandler,
|
handler: MessageHandler,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
*,
|
*,
|
||||||
message_type: str | None = None,
|
message_type: str | list[str] | None = None,
|
||||||
event_types: Iterable[str] | None = None,
|
event_types: Iterable[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -72,28 +77,72 @@ class MessageRuntime:
|
|||||||
predicate: 路由匹配条件
|
predicate: 路由匹配条件
|
||||||
handler: 消息处理函数
|
handler: 消息处理函数
|
||||||
name: 路由名称(可选)
|
name: 路由名称(可选)
|
||||||
message_type: 消息类型(可选)
|
message_type: 消息类型,可以是字符串或字符串列表(可选)
|
||||||
event_types: 事件类型列表(可选)
|
event_types: 事件类型列表(可选)
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
# 处理 message_type 参数,支持字符串或列表
|
||||||
|
message_types_set: set[str] | None = None
|
||||||
|
single_message_type: str | None = None
|
||||||
|
|
||||||
|
if message_type is not None:
|
||||||
|
if isinstance(message_type, str):
|
||||||
|
message_types_set = {message_type}
|
||||||
|
single_message_type = message_type
|
||||||
|
elif isinstance(message_type, list):
|
||||||
|
message_types_set = set(message_type)
|
||||||
|
if len(message_types_set) == 1:
|
||||||
|
single_message_type = next(iter(message_types_set))
|
||||||
|
else:
|
||||||
|
raise TypeError(f"message_type must be str or list[str], got {type(message_type)}")
|
||||||
|
|
||||||
|
# 检测重复注册:如果明确指定了某个类型,不允许重复
|
||||||
|
handler_name = name or getattr(handler, "__name__", str(handler))
|
||||||
|
for msg_type in message_types_set:
|
||||||
|
if msg_type in self._explicit_type_handlers:
|
||||||
|
existing_handler = self._explicit_type_handlers[msg_type]
|
||||||
|
raise ValueError(
|
||||||
|
f"消息类型 '{msg_type}' 已被处理器 '{existing_handler}' 明确注册,"
|
||||||
|
f"不能再由 '{handler_name}' 注册。同一消息类型只能有一个明确的处理器。"
|
||||||
|
)
|
||||||
|
self._explicit_type_handlers[msg_type] = handler_name
|
||||||
|
|
||||||
route = MessageRoute(
|
route = MessageRoute(
|
||||||
predicate=predicate,
|
predicate=predicate,
|
||||||
handler=handler,
|
handler=handler,
|
||||||
name=name,
|
name=name,
|
||||||
message_type=message_type,
|
message_type=single_message_type,
|
||||||
|
message_types=message_types_set,
|
||||||
event_types=set(event_types) if event_types is not None else None,
|
event_types=set(event_types) if event_types is not None else None,
|
||||||
)
|
)
|
||||||
self._routes.append(route)
|
self._routes.append(route)
|
||||||
if message_type:
|
|
||||||
self._type_routes.setdefault(message_type, []).append(route)
|
# 为每个消息类型建立索引
|
||||||
|
if message_types_set:
|
||||||
|
for msg_type in message_types_set:
|
||||||
|
self._type_routes.setdefault(msg_type, []).append(route)
|
||||||
|
|
||||||
if route.event_types:
|
if route.event_types:
|
||||||
for et in route.event_types:
|
for et in route.event_types:
|
||||||
self._event_routes.setdefault(et, []).append(route)
|
self._event_routes.setdefault(et, []).append(route)
|
||||||
|
|
||||||
def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]:
|
def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]:
|
||||||
"""装饰器写法,便于在核心逻辑中声明式注册。"""
|
"""装饰器写法,便于在核心逻辑中声明式注册。
|
||||||
|
|
||||||
|
支持普通函数和类方法。对于类方法,会在实例创建时自动绑定并注册路由。
|
||||||
|
"""
|
||||||
|
|
||||||
def decorator(func: MessageHandler) -> MessageHandler:
|
def decorator(func: MessageHandler) -> MessageHandler:
|
||||||
|
# Support decorating instance methods: defer binding until the object is created.
|
||||||
|
if _looks_like_method(func):
|
||||||
|
return _InstanceMethodRoute(
|
||||||
|
runtime=self,
|
||||||
|
func=func,
|
||||||
|
predicate=predicate,
|
||||||
|
name=name,
|
||||||
|
message_type=None,
|
||||||
|
)
|
||||||
|
|
||||||
self.add_route(predicate, func, name=name)
|
self.add_route(predicate, func, name=name)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@@ -101,16 +150,45 @@ class MessageRuntime:
|
|||||||
|
|
||||||
def on_message(
|
def on_message(
|
||||||
self,
|
self,
|
||||||
|
func: MessageHandler | None = None,
|
||||||
*,
|
*,
|
||||||
message_type: str | None = None,
|
message_type: str | list[str] | None = None,
|
||||||
platform: str | None = None,
|
platform: str | None = None,
|
||||||
predicate: Predicate | None = None,
|
predicate: Predicate | None = None,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
) -> Callable[[MessageHandler], MessageHandler]:
|
) -> Callable[[MessageHandler], MessageHandler] | MessageHandler:
|
||||||
"""Sugar 装饰器,基于 Seg.type/platform 及可选额外谓词匹配。"""
|
"""Sugar decorator with optional Seg.type/platform predicate matching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: 被装饰的函数
|
||||||
|
message_type: 消息类型,可以是单个字符串或字符串列表
|
||||||
|
platform: 平台名称
|
||||||
|
predicate: 自定义匹配条件
|
||||||
|
name: 路由名称
|
||||||
|
|
||||||
|
Usages:
|
||||||
|
- @runtime.on_message(...)
|
||||||
|
- @runtime.on_message
|
||||||
|
- @runtime.on_message(message_type="text")
|
||||||
|
- @runtime.on_message(message_type=["text", "image"])
|
||||||
|
|
||||||
|
If the target looks like an instance method (first arg is self), it will be
|
||||||
|
auto-bound to the instance and registered when the object is constructed.
|
||||||
|
"""
|
||||||
|
# 将 message_type 转换为集合以便统一处理
|
||||||
|
message_types_set: set[str] | None = None
|
||||||
|
if message_type is not None:
|
||||||
|
if isinstance(message_type, str):
|
||||||
|
message_types_set = {message_type}
|
||||||
|
elif isinstance(message_type, list):
|
||||||
|
message_types_set = set(message_type)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"message_type must be str or list[str], got {type(message_type)}")
|
||||||
|
|
||||||
async def combined_predicate(message: MessageEnvelope) -> bool:
|
async def combined_predicate(message: MessageEnvelope) -> bool:
|
||||||
if message_type is not None and _extract_segment_type(message) != message_type:
|
if message_types_set is not None:
|
||||||
|
extracted_type = _extract_segment_type(message)
|
||||||
|
if extracted_type not in message_types_set:
|
||||||
return False
|
return False
|
||||||
if platform is not None:
|
if platform is not None:
|
||||||
info_platform = message.get("message_info", {}).get("platform")
|
info_platform = message.get("message_info", {}).get("platform")
|
||||||
@@ -123,35 +201,24 @@ class MessageRuntime:
|
|||||||
return await _invoke_callable(predicate, message, prefer_thread=False)
|
return await _invoke_callable(predicate, message, prefer_thread=False)
|
||||||
|
|
||||||
def decorator(func: MessageHandler) -> MessageHandler:
|
def decorator(func: MessageHandler) -> MessageHandler:
|
||||||
|
# Support decorating instance methods: defer binding until the object is created.
|
||||||
|
if _looks_like_method(func):
|
||||||
|
return _InstanceMethodRoute(
|
||||||
|
runtime=self,
|
||||||
|
func=func,
|
||||||
|
predicate=combined_predicate,
|
||||||
|
name=name,
|
||||||
|
message_type=message_type,
|
||||||
|
)
|
||||||
|
|
||||||
self.add_route(combined_predicate, func, name=name, message_type=message_type)
|
self.add_route(combined_predicate, func, name=name, message_type=message_type)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
return decorator(func)
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def on_event(
|
|
||||||
self,
|
|
||||||
event_type: str | Iterable[str],
|
|
||||||
*,
|
|
||||||
name: str | None = None,
|
|
||||||
) -> Callable[[MessageHandler], MessageHandler]:
|
|
||||||
"""装饰器,基于 message 或 message_info.additional_config 中的 event_type 匹配。"""
|
|
||||||
|
|
||||||
allowed = {event_type} if isinstance(event_type, str) else set(event_type)
|
|
||||||
|
|
||||||
async def predicate(message: MessageEnvelope) -> bool:
|
|
||||||
current = (
|
|
||||||
message.get("event_type")
|
|
||||||
or message.get("message_info", {})
|
|
||||||
.get("additional_config", {})
|
|
||||||
.get("event_type")
|
|
||||||
)
|
|
||||||
return current in allowed
|
|
||||||
|
|
||||||
def decorator(func: MessageHandler) -> MessageHandler:
|
|
||||||
self.add_route(predicate, func, name=name, event_types=allowed)
|
|
||||||
return func
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def set_batch_handler(self, handler: BatchHandler) -> None:
|
def set_batch_handler(self, handler: BatchHandler) -> None:
|
||||||
self._batch_handler = handler
|
self._batch_handler = handler
|
||||||
@@ -199,7 +266,7 @@ class MessageRuntime:
|
|||||||
return responses
|
return responses
|
||||||
|
|
||||||
async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None:
|
async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None:
|
||||||
candidates: list[MessageRoute] = []
|
"""匹配消息路由,优先匹配明确指定了消息类型的处理器"""
|
||||||
message_type = _extract_segment_type(message)
|
message_type = _extract_segment_type(message)
|
||||||
event_type = (
|
event_type = (
|
||||||
message.get("event_type")
|
message.get("event_type")
|
||||||
@@ -207,15 +274,29 @@ class MessageRuntime:
|
|||||||
.get("additional_config", {})
|
.get("additional_config", {})
|
||||||
.get("event_type")
|
.get("event_type")
|
||||||
)
|
)
|
||||||
with self._lock:
|
|
||||||
if event_type and event_type in self._event_routes:
|
|
||||||
candidates.extend(self._event_routes[event_type])
|
|
||||||
if message_type and message_type in self._type_routes:
|
|
||||||
candidates.extend(self._type_routes[message_type])
|
|
||||||
candidates.extend(self._routes)
|
|
||||||
|
|
||||||
|
# 分为两层候选:优先级和普通
|
||||||
|
priority_candidates: list[MessageRoute] = [] # 明确指定了消息类型的
|
||||||
|
normal_candidates: list[MessageRoute] = [] # 没有指定或通配的
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
# 事件路由(优先级最高)
|
||||||
|
if event_type and event_type in self._event_routes:
|
||||||
|
priority_candidates.extend(self._event_routes[event_type])
|
||||||
|
|
||||||
|
# 消息类型路由(明确指定的有优先级)
|
||||||
|
if message_type and message_type in self._type_routes:
|
||||||
|
priority_candidates.extend(self._type_routes[message_type])
|
||||||
|
|
||||||
|
# 通用路由(没有明确指定类型的)
|
||||||
|
for route in self._routes:
|
||||||
|
# 如果路由没有指定 message_types,则是通用路由
|
||||||
|
if route.message_types is None and route.event_types is None:
|
||||||
|
normal_candidates.append(route)
|
||||||
|
|
||||||
|
# 先尝试优先级候选
|
||||||
seen: set[int] = set()
|
seen: set[int] = set()
|
||||||
for route in candidates:
|
for route in priority_candidates:
|
||||||
rid = id(route)
|
rid = id(route)
|
||||||
if rid in seen:
|
if rid in seen:
|
||||||
continue
|
continue
|
||||||
@@ -223,6 +304,17 @@ class MessageRuntime:
|
|||||||
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
|
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
|
||||||
if should_handle:
|
if should_handle:
|
||||||
return route
|
return route
|
||||||
|
|
||||||
|
# 如果没有匹配到优先级候选,再尝试普通候选
|
||||||
|
for route in normal_candidates:
|
||||||
|
rid = id(route)
|
||||||
|
if rid in seen:
|
||||||
|
continue
|
||||||
|
seen.add(rid)
|
||||||
|
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
|
||||||
|
if should_handle:
|
||||||
|
return route
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None:
|
async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None:
|
||||||
@@ -257,7 +349,27 @@ class MessageRuntime:
|
|||||||
|
|
||||||
|
|
||||||
async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False):
|
async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False):
|
||||||
"""支持 sync/async 调用,并可选择在线程中执行。"""
|
"""支持 sync/async 调用,并可选择在线程中执行。
|
||||||
|
|
||||||
|
自动处理普通函数、类方法和绑定方法。
|
||||||
|
"""
|
||||||
|
# 如果是绑定方法(bound method),直接使用,不需要额外处理
|
||||||
|
# 因为绑定方法已经包含了 self 参数
|
||||||
|
if inspect.ismethod(func):
|
||||||
|
# 绑定方法可以直接调用,args 中不应包含 self
|
||||||
|
if inspect.iscoroutinefunction(func):
|
||||||
|
return await func(*args)
|
||||||
|
if prefer_thread:
|
||||||
|
result = await asyncio.to_thread(func, *args)
|
||||||
|
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||||
|
return await result
|
||||||
|
return result
|
||||||
|
result = func(*args)
|
||||||
|
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||||
|
return await result
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 对于普通函数(未绑定的),按原有逻辑处理
|
||||||
if inspect.iscoroutinefunction(func):
|
if inspect.iscoroutinefunction(func):
|
||||||
return await func(*args)
|
return await func(*args)
|
||||||
if prefer_thread:
|
if prefer_thread:
|
||||||
@@ -282,6 +394,70 @@ def _extract_segment_type(message: MessageEnvelope) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _looks_like_method(func: Callable[..., object]) -> bool:
|
||||||
|
"""Return True if callable signature suggests an instance method (first arg named self)."""
|
||||||
|
if inspect.ismethod(func):
|
||||||
|
return True
|
||||||
|
if not inspect.isfunction(func):
|
||||||
|
return False
|
||||||
|
params = inspect.signature(func).parameters
|
||||||
|
if not params:
|
||||||
|
return False
|
||||||
|
first = next(iter(params.values()))
|
||||||
|
return first.name == "self"
|
||||||
|
|
||||||
|
|
||||||
|
class _InstanceMethodRoute:
|
||||||
|
"""Descriptor that binds decorated instance methods and registers routes per-instance."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
runtime: MessageRuntime,
|
||||||
|
func: MessageHandler,
|
||||||
|
predicate: Predicate,
|
||||||
|
name: str | None,
|
||||||
|
message_type: str | None,
|
||||||
|
) -> None:
|
||||||
|
self._runtime = runtime
|
||||||
|
self._func = func
|
||||||
|
self._predicate = predicate
|
||||||
|
self._name = name
|
||||||
|
self._message_type = message_type
|
||||||
|
self._owner: type | None = None
|
||||||
|
self._registered_instances: weakref.WeakSet[object] = weakref.WeakSet()
|
||||||
|
|
||||||
|
def __set_name__(self, owner: type, name: str) -> None:
|
||||||
|
self._owner = owner
|
||||||
|
registry: list[_InstanceMethodRoute] | None = getattr(owner, "_mofox_instance_routes", None)
|
||||||
|
if registry is None:
|
||||||
|
registry = []
|
||||||
|
setattr(owner, "_mofox_instance_routes", registry)
|
||||||
|
original_init = owner.__init__
|
||||||
|
|
||||||
|
@functools.wraps(original_init)
|
||||||
|
def wrapped_init(inst, *args, **kwargs):
|
||||||
|
original_init(inst, *args, **kwargs)
|
||||||
|
for descriptor in getattr(inst.__class__, "_mofox_instance_routes", []):
|
||||||
|
descriptor._register_instance(inst)
|
||||||
|
|
||||||
|
owner.__init__ = wrapped_init # type: ignore[assignment]
|
||||||
|
registry.append(self)
|
||||||
|
|
||||||
|
def _register_instance(self, instance: object) -> None:
|
||||||
|
if instance in self._registered_instances:
|
||||||
|
return
|
||||||
|
owner = self._owner or instance.__class__
|
||||||
|
bound = self._func.__get__(instance, owner) # type: ignore[arg-type]
|
||||||
|
self._runtime.add_route(self._predicate, bound, name=self._name, message_type=self._message_type)
|
||||||
|
self._registered_instances.add(instance)
|
||||||
|
|
||||||
|
def __get__(self, instance: object | None, owner: type | None = None):
|
||||||
|
if instance is None:
|
||||||
|
return self._func
|
||||||
|
self._register_instance(instance)
|
||||||
|
return self._func.__get__(instance, owner) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BatchHandler",
|
"BatchHandler",
|
||||||
"Hook",
|
"Hook",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, List, Literal, NotRequired, TypedDict
|
from typing import Any, Dict, List, Literal, NotRequired, TypedDict, Required
|
||||||
|
|
||||||
MessageDirection = Literal["incoming", "outgoing"]
|
MessageDirection = Literal["incoming", "outgoing"]
|
||||||
|
|
||||||
@@ -14,14 +14,14 @@ class SegPayload(TypedDict, total=False):
|
|||||||
对齐 maim_message.Seg 的片段定义,使用纯 dict 便于 JSON 传输。
|
对齐 maim_message.Seg 的片段定义,使用纯 dict 便于 JSON 传输。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str
|
type: Required[str]
|
||||||
data: str | List["SegPayload"]
|
data: Required[str | List["SegPayload"]]
|
||||||
translated_data: NotRequired[str | List["SegPayload"]]
|
translated_data: NotRequired[str | List["SegPayload"]]
|
||||||
|
|
||||||
|
|
||||||
class UserInfoPayload(TypedDict, total=False):
|
class UserInfoPayload(TypedDict, total=False):
|
||||||
platform: NotRequired[str]
|
platform: NotRequired[str]
|
||||||
user_id: NotRequired[str]
|
user_id: Required[str]
|
||||||
user_nickname: NotRequired[str]
|
user_nickname: NotRequired[str]
|
||||||
user_cardname: NotRequired[str]
|
user_cardname: NotRequired[str]
|
||||||
user_avatar: NotRequired[str]
|
user_avatar: NotRequired[str]
|
||||||
@@ -29,7 +29,7 @@ class UserInfoPayload(TypedDict, total=False):
|
|||||||
|
|
||||||
class GroupInfoPayload(TypedDict, total=False):
|
class GroupInfoPayload(TypedDict, total=False):
|
||||||
platform: NotRequired[str]
|
platform: NotRequired[str]
|
||||||
group_id: NotRequired[str]
|
group_id: Required[str]
|
||||||
group_name: NotRequired[str]
|
group_name: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
@@ -45,8 +45,8 @@ class TemplateInfoPayload(TypedDict, total=False):
|
|||||||
|
|
||||||
|
|
||||||
class MessageInfoPayload(TypedDict, total=False):
|
class MessageInfoPayload(TypedDict, total=False):
|
||||||
platform: NotRequired[str]
|
platform: Required[str]
|
||||||
message_id: NotRequired[str]
|
message_id: Required[str]
|
||||||
time: NotRequired[float]
|
time: NotRequired[float]
|
||||||
group_info: NotRequired[GroupInfoPayload]
|
group_info: NotRequired[GroupInfoPayload]
|
||||||
user_info: NotRequired[UserInfoPayload]
|
user_info: NotRequired[UserInfoPayload]
|
||||||
@@ -67,8 +67,8 @@ class MessageEnvelope(TypedDict, total=False):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
direction: MessageDirection
|
direction: MessageDirection
|
||||||
message_info: MessageInfoPayload
|
message_info: Required[MessageInfoPayload]
|
||||||
message_segment: SegPayload | List[SegPayload]
|
message_segment: Required[SegPayload] | List[SegPayload]
|
||||||
raw_message: NotRequired[Any]
|
raw_message: NotRequired[Any]
|
||||||
raw_bytes: NotRequired[bytes]
|
raw_bytes: NotRequired[bytes]
|
||||||
message_chain: NotRequired[List[SegPayload]] # seglist 的直观别名
|
message_chain: NotRequired[List[SegPayload]] # seglist 的直观别名
|
||||||
|
|||||||
Reference in New Issue
Block a user