diff --git a/docs/message_runtime_architecture.md b/docs/message_runtime_architecture.md new file mode 100644 index 000000000..30a7741e7 --- /dev/null +++ b/docs/message_runtime_architecture.md @@ -0,0 +1,222 @@ +# MoFox Bot 消息运行时架构 (MessageRuntime) + +本文档描述了 MoFox Bot 使用 `mofox_bus.MessageRuntime` 简化消息处理链条的架构设计。 + +## 架构概述 + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ CoreSinkManager │ +│ ┌─────────────────────────────────────────────────────────────────────┐│ +│ │ MessageRuntime ││ +│ │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ││ +│ │ │ before_hook │→ │ Routes │→ │ after_hook │ ││ +│ │ │ (预处理/过滤) │ │ (消息路由) │ │ (后处理) │ ││ +│ │ └──────────────┘ └──────────────┘ └──────────────┘ ││ +│ │ ↓ ↓ ↓ ││ +│ │ ┌──────────────────────────────────────────────────────────────┐ ││ +│ │ │ error_hook (错误处理) │ ││ +│ │ └──────────────────────────────────────────────────────────────┘ ││ +│ └─────────────────────────────────────────────────────────────────────┘│ +│ │ +│ ┌──────────────────────┐ ┌──────────────────────────────────────┐ │ +│ │ InProcessCoreSink │ │ ProcessCoreSinkServer (子进程适配器) │ │ +│ │ (同进程适配器) │ │ │ │ +│ └──────────────────────┘ └──────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────────────┘ + ↑ ↑ + │ │ + ┌─────────────────────┐ ┌─────────────────────┐ + │ 同进程适配器 │ │ 子进程适配器 │ + │ (run_in_subprocess │ │ (run_in_subprocess │ + │ = False) │ │ = True) │ + └─────────────────────┘ └─────────────────────┘ +``` + +## 核心组件 + +### 1. CoreSinkManager (`src/common/core_sink_manager.py`) + +统一管理 CoreSink 双实例和 MessageRuntime: + +```python +from src.common.core_sink_manager import get_core_sink_manager, get_message_runtime + +# 获取管理器 +manager = get_core_sink_manager() + +# 获取 MessageRuntime +runtime = get_message_runtime() + +# 发送消息到适配器 +await manager.send_outgoing(envelope) +``` + +### 2. MessageRuntime + +`MessageRuntime` 是 mofox_bus 提供的消息路由核心,支持: + +- **消息路由**:通过 `add_route()` 或 `@on_message` 装饰器按消息类型路由 +- **钩子机制**:`before_hook`(前置处理)、`after_hook`(后置处理)、`error_hook`(错误处理) +- **中间件**:洋葱模型的中间件机制 +- **批量处理**:支持 `handle_batch()` 批量处理消息 + +### 3. MessageHandler (`src/chat/message_receive/message_handler.py`) + +将消息处理逻辑注册为 MessageRuntime 的路由和钩子: + +```python +class MessageHandler: + def register_handlers(self, runtime: MessageRuntime) -> None: + # 注册前置钩子 + runtime.register_before_hook(self._before_hook) + + # 注册后置钩子 + runtime.register_after_hook(self._after_hook) + + # 注册错误钩子 + runtime.register_error_hook(self._error_hook) + + # 注册适配器响应处理器 + runtime.add_route( + predicate=_is_adapter_response, + handler=self._handle_adapter_response_route, + name="adapter_response_handler", + message_type="adapter_response", + ) + + # 注册默认消息处理器 + runtime.add_route( + predicate=lambda _: True, + handler=self._handle_normal_message, + name="default_message_handler", + ) +``` + +## 消息流向 + +### 接收消息 + +``` +适配器 → InProcessCoreSink/ProcessCoreSinkServer → CoreSinkManager._dispatch_to_runtime() + → MessageRuntime.handle_message() + → before_hook (预处理、过滤) + → 匹配路由 (adapter_response / normal_message) + → 执行处理器 + → after_hook (后处理) +``` + +### 发送消息 + +``` +消息发送请求 → CoreSinkManager.send_outgoing() + → InProcessCoreSink.push_outgoing() + → ProcessCoreSinkServer.push_outgoing() + → 适配器 +``` + +## 钩子功能 + +### before_hook + +在消息路由之前执行,用于: +- 标准化 ID 为字符串 +- 检查 echo 消息(自身消息上报) +- 通过抛出 `UserWarning` 跳过消息处理 + +### after_hook + +在消息处理完成后执行,用于: +- 清理工作 +- 日志记录 + +### error_hook + +在处理过程中出现异常时执行,用于: +- 区分预期的流程控制(UserWarning)和真正的错误 +- 统一异常日志记录 + +## 路由优先级 + +1. **明确指定 message_type 的路由**(优先级最高) +2. **事件路由**(基于 event_type) +3. **通用路由**(无 message_type 限制) + +## 扩展消息处理 + +### 注册自定义处理器 + +```python +from src.common.core_sink_manager import get_message_runtime +from mofox_bus import MessageEnvelope + +runtime = get_message_runtime() + +# 使用装饰器 +@runtime.on_message(message_type="image") +async def handle_image(envelope: MessageEnvelope): + # 处理图片消息 + pass + +# 或使用 add_route +runtime.add_route( + predicate=lambda env: env.get("platform") == "qq", + handler=my_handler, + name="qq_handler", +) +``` + +### 注册钩子 + +```python +runtime = get_message_runtime() + +# 前置钩子 +async def my_before_hook(envelope: MessageEnvelope) -> None: + # 预处理逻辑 + pass + +runtime.register_before_hook(my_before_hook) + +# 错误钩子 +async def my_error_hook(envelope: MessageEnvelope, exc: BaseException) -> None: + # 错误处理逻辑 + pass + +runtime.register_error_hook(my_error_hook) +``` + +## 初始化流程 + +在 `MainSystem.initialize()` 中: + +1. 初始化 `CoreSinkManager`(包含 `MessageRuntime`) +2. 获取 `MessageHandler` 并设置 `CoreSinkManager` 引用 +3. 调用 `MessageHandler.register_handlers()` 向 `MessageRuntime` 注册处理器和钩子 +4. 初始化其他组件 + +```python +async def initialize(self) -> None: + # 初始化 CoreSinkManager(包含 MessageRuntime) + self.core_sink_manager = await initialize_core_sink_manager() + + # 获取 MessageHandler 并向 MessageRuntime 注册处理器 + self.message_handler = get_message_handler() + self.message_handler.set_core_sink_manager(self.core_sink_manager) + self.message_handler.register_handlers(self.core_sink_manager.runtime) +``` + +## 优势 + +1. **简化消息处理链**:不再需要手动管理处理流程,使用声明式路由 +2. **更好的可扩展性**:通过 `add_route()` 或装饰器轻松添加新的处理器 +3. **统一的错误处理**:通过 `error_hook` 集中处理异常 +4. **支持中间件**:可以添加洋葱模型的中间件 +5. **更清晰的代码结构**:处理逻辑按类型分离 + +## 参考 + +- `packages/mofox-bus/src/mofox_bus/runtime.py` - MessageRuntime 实现 +- `src/common/core_sink_manager.py` - CoreSinkManager 实现 +- `src/chat/message_receive/message_handler.py` - MessageHandler 实现 +- `docs/mofox_bus.md` - MoFox Bus 消息库说明 diff --git a/docs/mofox_bus.md b/docs/mofox_bus.md index d7dc68427..febe0c401 100644 --- a/docs/mofox_bus.md +++ b/docs/mofox_bus.md @@ -2,6 +2,8 @@ MoFox Bus 是 MoFox Bot 自研的统一消息中台,替换第三方 `maim_message`,将核心与各平台适配器之间的通信抽象成可拓展、可热插拔的组件。该库完全异步、面向高吞吐,覆盖消息建模、序列化、传输层、运行时路由、适配器工具等多个层面。 +> 现在已拆分为独立 pip 包:在项目根目录执行 `pip install -e ./packages/mofox-bus` 即可安装到当前 Python 环境。 + --- ## 1. 设计目标 @@ -14,7 +16,7 @@ MoFox Bus 是 MoFox Bot 自研的统一消息中台,替换第三方 `maim_mess --- -## 2. 包结构概览(`src/mofox_bus/`) +## 2. 包结构概览(`packages/mofox-bus/src/mofox_bus/`) | 模块 | 主要职责 | | --- | --- | diff --git a/examples/mofox_bus_demo_adapter.py b/examples/mofox_bus_demo_adapter.py index c5b75b58e..8ca040dc0 100644 --- a/examples/mofox_bus_demo_adapter.py +++ b/examples/mofox_bus_demo_adapter.py @@ -8,7 +8,6 @@ from __future__ import annotations import asyncio -import sys import time import uuid from pathlib import Path @@ -17,8 +16,6 @@ from typing import Any, Dict, Optional import orjson import websockets -# 追加 src 目录,便于直接运行示例 -sys.path.append(str(Path(__file__).resolve().parents[1] / "src")) from mofox_bus import ( AdapterBase, diff --git a/pyproject.toml b/pyproject.toml index d9d250a9f..f965ce1d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ dependencies = [ "inkfox>=0.1.1", "rjieba>=0.1.13", "fastmcp>=2.13.0", + "mofox-bus", ] [[tool.uv.index]] diff --git a/requirements.txt b/requirements.txt index 2aacffef1..d69d3876d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -sqlalchemy aiosqlite aiofiles aiomysql diff --git a/src/chat/chatter_manager.py b/src/chat/chatter_manager.py index a4405358b..7a2e76d68 100644 --- a/src/chat/chatter_manager.py +++ b/src/chat/chatter_manager.py @@ -121,8 +121,6 @@ class ChatterManager: chatter_class = self.get_chatter_class(chat_type) if not chatter_class: - from src.plugin_system.base.component_types import ChatType - all_chatter_class = self.get_chatter_class(ChatType.ALL) if all_chatter_class: chatter_class = all_chatter_class diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 516c56456..0bd809d4a 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -8,9 +8,10 @@ import random import time from typing import TYPE_CHECKING, Any -from src.chat.chatter_manager import ChatterManager -from src.chat.message_receive.chat_stream import ChatStream from src.chat.planner_actions.action_manager import ChatterActionManager + +if TYPE_CHECKING: + from src.chat.chatter_manager import ChatterManager from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats from src.common.logger import get_logger @@ -21,7 +22,7 @@ from .distribution_manager import stream_loop_manager from .global_notice_manager import NoticeScope, global_notice_manager if TYPE_CHECKING: - pass + from src.chat.message_receive.chat_stream import ChatStream logger = get_logger("message_manager") @@ -39,6 +40,8 @@ class MessageManager: # 初始化chatter manager self.action_manager = ChatterActionManager() + # 延迟导入ChatterManager以避免循环导入 + from src.chat.chatter_manager import ChatterManager self.chatter_manager = ChatterManager(self.action_manager) # 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py deleted file mode 100644 index 6524ea8d3..000000000 --- a/src/chat/message_receive/bot.py +++ /dev/null @@ -1,488 +0,0 @@ -import os -import re -import traceback -from typing import Any - -from mofox_bus.runtime import MessageRuntime -from mofox_bus import MessageEnvelope -from src.chat.message_manager import message_manager -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from src.chat.message_receive.storage import MessageStorage -from src.chat.utils.prompt import global_prompt_manager -from src.chat.utils.utils import is_mentioned_bot_in_message -from src.common.data_models.database_data_model import DatabaseMessages -from src.common.logger import get_logger -from src.config.config import global_config -from src.mood.mood_manager import mood_manager # 导入情绪管理器 -from src.plugin_system.base import BaseCommand, EventType -from src.plugin_system.core import component_registry, event_manager, global_announcement_manager - -# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录) -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) - -# 配置主程序日志格式 -logger = get_logger("chat") - - -def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: - """检查消息是否包含过滤词 - - Args: - text: 待检查的文本 - chat: 聊天对象 - userinfo: 用户信息 - - Returns: - bool: 是否包含过滤词 - """ - for word in global_config.message_receive.ban_words: - if word in text: - chat_name = chat.group_info.group_name if chat.group_info else "私聊" - logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") - logger.info(f"[过滤词识别]消息中含有{word},filtered") - return True - return False - - -def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: - """检查消息是否匹配过滤正则表达式 - - Args: - text: 待检查的文本 - chat: 聊天对象 - userinfo: 用户信息 - - Returns: - bool: 是否匹配过滤正则 - """ - for pattern in global_config.message_receive.ban_msgs_regex: - if re.search(pattern, text): - chat_name = chat.group_info.group_name if chat.group_info else "私聊" - logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") - logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered") - return True - return False - -runtime = MessageRuntime() # 获取mofox-bus运行时环境 - -class ChatBot: - def __init__(self): - self.bot = None # bot 实例引用 - self._started = False - self.mood_manager = mood_manager # 获取情绪管理器单例 - # 启动消息管理器 - self._message_manager_started = False - - async def _ensure_started(self): - """确保所有任务已启动""" - if not self._started: - logger.debug("确保ChatBot所有任务已启动") - - # 启动消息管理器 - if not self._message_manager_started: - await message_manager.start() - self._message_manager_started = True - logger.info("消息管理器已启动") - - self._started = True - - async def _process_plus_commands(self, message: DatabaseMessages, chat: ChatStream): - """独立处理PlusCommand系统""" - try: - text = message.processed_plain_text or "" - - # 获取配置的命令前缀 - from src.config.config import global_config - - prefixes = global_config.command.command_prefixes - - # 检查是否以任何前缀开头 - matched_prefix = None - for prefix in prefixes: - if text.startswith(prefix): - matched_prefix = prefix - break - - if not matched_prefix: - return False, None, True # 不是命令,继续处理 - - # 移除前缀 - command_part = text[len(matched_prefix) :].strip() - - # 分离命令名和参数 - parts = command_part.split(None, 1) - if not parts: - return False, None, True # 没有命令名,继续处理 - - command_word = parts[0].lower() - args_text = parts[1] if len(parts) > 1 else "" - - # 查找匹配的PlusCommand - plus_command_registry = component_registry.get_plus_command_registry() - matching_commands = [] - - for plus_command_name, plus_command_class in plus_command_registry.items(): - plus_command_info = component_registry.get_registered_plus_command_info(plus_command_name) - if not plus_command_info: - continue - - # 检查命令名是否匹配(命令名和别名) - all_commands = [plus_command_name.lower()] + [ - alias.lower() for alias in plus_command_info.command_aliases - ] - if command_word in all_commands: - matching_commands.append((plus_command_class, plus_command_info, plus_command_name)) - - if not matching_commands: - return False, None, True # 没有找到匹配的PlusCommand,继续处理 - - # 如果有多个匹配,按优先级排序 - if len(matching_commands) > 1: - matching_commands.sort(key=lambda x: x[1].priority, reverse=True) - logger.warning( - f"文本 '{text}' 匹配到多个PlusCommand: {[cmd[2] for cmd in matching_commands]},使用优先级最高的" - ) - - plus_command_class, plus_command_info, plus_command_name = matching_commands[0] - - # 检查命令是否被禁用 - if ( - chat - and chat.stream_id - and plus_command_name - in global_announcement_manager.get_disabled_chat_commands(chat.stream_id) - ): - logger.info("用户禁用的PlusCommand,跳过处理") - return False, None, True - - message.is_command = True - - # 获取插件配置 - plugin_config = component_registry.get_plugin_config(plus_command_name) - - # 创建PlusCommand实例 - plus_command_instance = plus_command_class(message, plugin_config) - - # 为插件实例设置 chat_stream 运行时属性 - setattr(plus_command_instance, "chat_stream", chat) - - try: - # 检查聊天类型限制 - if not plus_command_instance.is_chat_type_allowed(): - is_group = chat.group_info is not None - logger.info( - f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" - ) - return False, None, True # 跳过此命令,继续处理其他消息 - - # 设置参数 - from src.plugin_system.base.command_args import CommandArgs - - command_args = CommandArgs(args_text) - plus_command_instance.args = command_args - - # 执行命令 - success, response, intercept_message = await plus_command_instance.execute(command_args) - - # 记录命令执行结果 - if success: - logger.info(f"PlusCommand执行成功: {plus_command_class.__name__} (拦截: {intercept_message})") - else: - logger.warning(f"PlusCommand执行失败: {plus_command_class.__name__} - {response}") - - # 根据命令的拦截设置决定是否继续处理消息 - return True, response, not intercept_message # 找到命令,根据intercept_message决定是否继续 - - except Exception as e: - logger.error(f"执行PlusCommand时出错: {plus_command_class.__name__} - {e}") - logger.error(traceback.format_exc()) - - try: - await plus_command_instance.send_text(f"命令执行出错: {e!s}") - except Exception as send_error: - logger.error(f"发送错误消息失败: {send_error}") - - # 命令出错时,根据命令的拦截设置决定是否继续处理消息 - return True, str(e), False # 出错时继续处理消息 - - except Exception as e: - logger.error(f"处理PlusCommand时出错: {e}") - return False, None, True # 出错时继续处理消息 - - async def _process_commands_with_new_system(self, message: DatabaseMessages, chat: ChatStream): - # sourcery skip: use-named-expression - """使用新插件系统处理命令""" - try: - text = message.processed_plain_text or "" - - # 使用新的组件注册中心查找命令 - command_result = component_registry.find_command_by_text(text) - if command_result: - command_class, matched_groups, command_info = command_result - plugin_name = command_info.plugin_name - command_name = command_info.name - if ( - chat - and chat.stream_id - and command_name - in global_announcement_manager.get_disabled_chat_commands(chat.stream_id) - ): - logger.info("用户禁用的命令,跳过处理") - return False, None, True - - message.is_command = True - - # 获取插件配置 - plugin_config = component_registry.get_plugin_config(plugin_name) - - # 创建命令实例 - command_instance: BaseCommand = command_class(message, plugin_config) - command_instance.set_matched_groups(matched_groups) - - # 为插件实例设置 chat_stream 运行时属性 - setattr(command_instance, "chat_stream", chat) - - try: - # 检查聊天类型限制 - if not command_instance.is_chat_type_allowed(): - is_group = chat.group_info is not None - logger.info( - f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" - ) - return False, None, True # 跳过此命令,继续处理其他消息 - - # 执行命令 - success, response, intercept_message = await command_instance.execute() - - # 记录命令执行结果 - if success: - logger.info(f"命令执行成功: {command_class.__name__} (拦截: {intercept_message})") - else: - logger.warning(f"命令执行失败: {command_class.__name__} - {response}") - - # 根据命令的拦截设置决定是否继续处理消息 - return True, response, not intercept_message # 找到命令,根据intercept_message决定是否继续 - - except Exception as e: - logger.error(f"执行命令时出错: {command_class.__name__} - {e}") - logger.error(traceback.format_exc()) - - try: - await command_instance.send_text(f"命令执行出错: {e!s}") - except Exception as send_error: - logger.error(f"发送错误消息失败: {send_error}") - - # 命令出错时,根据命令的拦截设置决定是否继续处理消息 - return True, str(e), False # 出错时继续处理消息 - - # 没有找到命令,继续处理消息 - return False, None, True - - except Exception as e: - logger.error(f"处理命令时出错: {e}") - return False, None, True # 出错时继续处理消息 - - - async def _handle_adapter_response_from_dict(self, seg_data: dict | None): - """处理适配器命令响应(从字典数据)""" - try: - from src.plugin_system.apis.send_api import put_adapter_response - - if isinstance(seg_data, dict): - request_id = seg_data.get("request_id") - response_data = seg_data.get("response") - else: - request_id = None - response_data = None - - if request_id and response_data: - logger.info(f"[DEBUG bot.py] 收到适配器响应,request_id={request_id}") - put_adapter_response(request_id, response_data) - else: - logger.warning(f"适配器响应消息格式不正确: request_id={request_id}, response_data={response_data}") - - except Exception as e: - logger.error(f"处理适配器响应时出错: {e}") - - @runtime.on_message - async def message_process(self, envelope: MessageEnvelope) -> None: - """处理转化后的统一格式消息""" - # 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError - message_info = envelope.get("message_info") - if not isinstance(message_info, dict): - logger.debug( - "收到缺少 message_info 的消息,已跳过。可用字段: %s", - ", ".join(envelope.keys()), - ) - return - - if message_info.get("group_info") is not None: - message_info["group_info"]["group_id"] = str( # type: ignore - message_info["group_info"]["group_id"] # type: ignore - ) - if message_info.get("user_info") is not None: - message_info["user_info"]["user_id"] = str( # type: ignore - message_info["user_info"]["user_id"] # type: ignore - ) - - # 优先处理adapter_response消息(在echo检查之前!) - message_segment = envelope.get("message_segment") - if message_segment and isinstance(message_segment, dict): - if message_segment.get("type") == "adapter_response": - logger.info("[DEBUG bot.py message_process] 检测到adapter_response,立即处理") - await self._handle_adapter_response_from_dict(message_segment.get("data")) - return - - # 先提取基础信息检查是否是自身消息上报 - from mofox_bus import BaseMessageInfo - temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {})) - if temp_message_info.additional_config: - sent_message = temp_message_info.additional_config.get("echo", False) - if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题 - # 直接使用消息字典更新,不再需要创建 MessageRecv - await MessageStorage.update_message(message_data) - return - - message_segment = envelope.get("message_segment") - group_info = temp_message_info.group_info - user_info = temp_message_info.user_info - - # 获取或创建聊天流 - chat = await get_chat_manager().get_or_create_stream( - platform=temp_message_info.platform, # type: ignore - user_info=user_info, # type: ignore - group_info=group_info, - ) - - # 使用新的消息处理器直接生成 DatabaseMessages - from src.chat.message_receive.message_processor import process_message_from_dict - message = await process_message_from_dict( - message_dict=envelope, - stream_id=chat.stream_id, - platform=chat.platform - ) - - # 填充聊天流时间信息 - message.chat_info.create_time = chat.create_time - message.chat_info.last_active_time = chat.last_active_time - - # 注册消息到聊天管理器 - get_chat_manager().register_message(message) - - # 检测是否提及机器人 - message.is_mentioned, _ = is_mentioned_bot_in_message(message) - - # 在这里打印[所见]日志,确保在所有处理和过滤之前记录 - chat_name = chat.group_info.group_name if chat.group_info else "私聊" - user_nickname = message.user_info.user_nickname if message.user_info else "未知用户" - logger.info( - f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m" - ) - - # 在此添加硬编码过滤,防止回复图片处理失败的消息 - failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"] - processed_text = message.processed_plain_text or "" - if any(keyword in processed_text for keyword in failure_keywords): - logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。") - return - - # 过滤检查 - # DatabaseMessages 使用 display_message 作为原始消息表示 - raw_text = message.display_message or message.processed_plain_text or "" - if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore - raw_text, - chat, - user_info, # type: ignore - ): - return - - # 命令处理 - 首先尝试PlusCommand独立处理 - is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat) - - # 如果是PlusCommand且不需要继续处理,则直接返回 - if is_plus_command and not plus_continue_process: - await MessageStorage.store_message(message, chat) - logger.info(f"PlusCommand处理完成,跳过后续消息处理: {plus_cmd_result}") - return - - # 如果不是PlusCommand,尝试传统的BaseCommand处理 - if not is_plus_command: - is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat) - - # 如果是命令且不需要继续处理,则直接返回 - if is_command and not continue_process: - await MessageStorage.store_message(message, chat) - logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}") - return - - result = await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message=message) - if result and not result.all_continue_process(): - raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理") - - # TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info - # 确认从接口发来的message是否有自定义的prompt模板信息 - # 这个功能需要在 adapter 层通过 additional_config 传递 - template_group_name = None - - async def preprocess(): - # message 已经是 DatabaseMessages,直接使用 - group_info = chat.group_info - - # 先交给消息管理器处理,计算兴趣度等衍生数据 - try: - # 在将消息添加到管理器之前进行最终的静默检查 - should_process_in_manager = True - if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list: - # 检查消息是否为图片或表情包 - is_image_or_emoji = message.is_picid or message.is_emoji - if not message.is_mentioned and not is_image_or_emoji: - logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理") - should_process_in_manager = False - elif is_image_or_emoji: - logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理") - should_process_in_manager = False - - if should_process_in_manager: - await message_manager.add_message(chat.stream_id, message) - logger.debug(f"消息已添加到消息管理器: {chat.stream_id}") - - except Exception as e: - logger.error(f"消息添加到消息管理器失败: {e}") - - # 存储消息到数据库,只进行一次写入 - try: - await MessageStorage.store_message(message, chat) - except Exception as e: - logger.error(f"存储消息到数据库失败: {e}") - traceback.print_exc() - - # 情绪系统更新 - 在消息存储后触发情绪更新 - try: - if global_config.mood.enable_mood: - # 获取兴趣度用于情绪更新 - interest_rate = message.interest_value - if interest_rate is None: - interest_rate = 0.0 - logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}") - - # 获取当前聊天的情绪对象并更新情绪状态 - chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id) - await chat_mood.update_mood_by_message(message, interest_rate) - logger.debug("情绪状态更新完成") - except Exception as e: - logger.error(f"更新情绪状态失败: {e}") - traceback.print_exc() - - if template_group_name: - async with global_prompt_manager.async_message_scope(template_group_name): - await preprocess() - else: - await preprocess() - - except Exception as e: - logger.error(f"预处理消息失败: {e}") - traceback.print_exc() - - -# 创建全局ChatBot实例 -chat_bot = ChatBot() diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 8777e852f..82f0ac659 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -2,11 +2,11 @@ import asyncio import hashlib import time -from mofox_bus import GroupInfo, UserInfo from rich.traceback import install from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert +from src.common.data_models.database_data_model import DatabaseGroupInfo,DatabaseUserInfo from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.message_manager_data_model import StreamContext from src.plugin_system.base.component_types import ChatMode, ChatType @@ -30,8 +30,8 @@ class ChatStream: self, stream_id: str, platform: str, - user_info: UserInfo | None = None, - group_info: GroupInfo | None = None, + user_info: DatabaseUserInfo | None = None, + group_info: DatabaseGroupInfo | None = None, data: dict | None = None, ): self.stream_id = stream_id @@ -77,8 +77,8 @@ class ChatStream: @classmethod def from_dict(cls, data: dict) -> "ChatStream": """从字典创建实例""" - user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None - group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None + user_info = DatabaseUserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None + group_info = DatabaseGroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None instance = cls( stream_id=data["stream_id"], @@ -369,7 +369,7 @@ class ChatManager: # logger.debug(f"注册消息到聊天流: {stream_id}") @staticmethod - def _generate_stream_id(platform: str, user_info: UserInfo | None, group_info: GroupInfo | None = None) -> str: + def _generate_stream_id(platform: str, user_info: DatabaseUserInfo | None, group_info: DatabaseGroupInfo | None = None) -> str: """生成聊天流唯一ID""" if not user_info and not group_info: raise ValueError("用户信息或群组信息必须提供") @@ -392,7 +392,7 @@ class ChatManager: return hashlib.sha256(key.encode()).hexdigest() async def get_or_create_stream( - self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None + self, platform: str, user_info: DatabaseUserInfo, group_info: DatabaseGroupInfo | None = None ) -> ChatStream: """获取或创建聊天流 - 优化版本使用缓存机制""" try: @@ -483,7 +483,7 @@ class ChatManager: return stream def get_stream_by_info( - self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None + self, platform: str, user_info: DatabaseUserInfo, group_info: DatabaseGroupInfo | None = None ) -> ChatStream | None: """通过信息获取聊天流""" stream_id = self._generate_stream_id(platform, user_info, group_info) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 305e73de2..8baf1705c 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -1,12 +1,11 @@ import time from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Optional +from typing import Optional, TYPE_CHECKING import urllib3 from rich.traceback import install -from src.chat.message_receive.chat_stream import ChatStream from src.chat.utils.self_voice_cache import consume_self_voice_text from src.chat.utils.utils_image import get_image_manager from src.chat.utils.utils_voice import get_voice_text @@ -14,6 +13,9 @@ from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config +if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream + install(extra_lines=3) diff --git a/src/chat/message_receive/message_handler.py b/src/chat/message_receive/message_handler.py index be81ab0eb..3b922660b 100644 --- a/src/chat/message_receive/message_handler.py +++ b/src/chat/message_receive/message_handler.py @@ -1,39 +1,566 @@ -import os -import traceback +""" +统一消息处理器 (Message Handler) + +利用 mofox_bus.MessageRuntime 的路由功能,简化消息处理链条: + +1. 使用 @runtime.on_message() 装饰器注册按消息类型路由的处理器 +2. 使用 before_hook 进行消息预处理(ID标准化、过滤等) +3. 使用 after_hook 进行消息后处理(存储、情绪更新等) +4. 使用 error_hook 统一处理异常 + +消息流向: + 适配器 → CoreSinkManager → MessageRuntime + ↓ + [before_hook] 消息预处理、过滤 + ↓ + [on_message] 按类型路由处理(命令、普通消息等) + ↓ + [after_hook] 存储、情绪更新等 + ↓ + 回复生成 → CoreSinkManager.send_outgoing() → 适配器 + +重构说明(2025-11): +- 移除手动的消息处理链,改用 MessageRuntime 路由 +- MessageHandler 变成处理器注册器,在初始化时注册各种处理器 +- 利用 runtime 的钩子机制简化前置/后置处理 +""" + +from __future__ import annotations + +import asyncio +import os +import re +import time +import traceback +from functools import partial +from typing import TYPE_CHECKING, Any + +from mofox_bus import MessageEnvelope, MessageRuntime -from mofox_bus.runtime import MessageRuntime -from mofox_bus import MessageEnvelope from src.chat.message_manager import message_manager +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.storage import MessageStorage +from src.chat.utils.prompt import global_prompt_manager +from src.chat.utils.utils import is_mentioned_bot_in_message +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from src.common.data_models.database_data_model import DatabaseGroupInfo, DatabaseUserInfo, DatabaseMessages +from src.mood.mood_manager import mood_manager +from src.plugin_system.base import BaseCommand, EventType +from src.plugin_system.core import component_registry, event_manager, global_announcement_manager -runtime = MessageRuntime() +if TYPE_CHECKING: + from src.common.core_sink_manager import CoreSinkManager + from src.chat.message_receive.chat_stream import ChatStream -# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录) -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +logger = get_logger("message_handler") + +# 项目根目录 +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) + +def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool: + """检查消息是否包含过滤词""" + for word in global_config.message_receive.ban_words: + if word in text: + chat_name = chat.group_info.group_name if chat.group_info else "私聊" + logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") + logger.info(f"[过滤词识别]消息中含有{word},filtered") + return True + return False + + +def _check_ban_regex(text: str, chat: "ChatStream", userinfo) -> bool: + """检查消息是否匹配过滤正则表达式""" + for pattern in global_config.message_receive.ban_msgs_regex: + if re.search(pattern, text): + chat_name = chat.group_info.group_name if chat.group_info else "私聊" + logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") + logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered") + return True + return False -# 配置主程序日志格式 -logger = get_logger("chat") class MessageHandler: + """ + 统一消息处理器 + + 利用 MessageRuntime 的路由功能,将消息处理逻辑注册为路由和钩子。 + + 架构说明: + - 在 register_handlers() 中向 MessageRuntime 注册各种处理器 + - 使用 @runtime.on_message(message_type=...) 按消息类型路由 + - 使用 before_hook 进行消息预处理 + - 使用 after_hook 进行消息后处理 + - 使用 error_hook 统一处理异常 + + 主要功能: + 1. 消息预处理:ID标准化、过滤检查 + 2. 适配器响应处理:处理 adapter_response 类型消息 + 3. 命令处理:PlusCommand 和 BaseCommand + 4. 普通消息处理:触发事件、存储、情绪更新 + """ + def __init__(self): self._started = False + self._message_manager_started = False + self._core_sink_manager: CoreSinkManager | None = None + self._shutting_down = False + self._runtime: MessageRuntime | None = None - async def preprocess(self, chat: ChatStream, message: DatabaseMessages): - # message 已经是 DatabaseMessages,直接使用 - group_info = chat.group_info + def set_core_sink_manager(self, manager: "CoreSinkManager") -> None: + """设置 CoreSinkManager 引用""" + self._core_sink_manager = manager - # 先交给消息管理器处理 + def register_handlers(self, runtime: MessageRuntime) -> None: + """ + 向 MessageRuntime 注册消息处理器和钩子 + + 这是核心方法,在系统初始化时调用,将所有处理逻辑注册到 runtime。 + + Args: + runtime: MessageRuntime 实例 + """ + self._runtime = runtime + + # 注册前置钩子:消息预处理和过滤 + runtime.register_before_hook(self._before_hook) + + # 注册后置钩子:存储、情绪更新等 + runtime.register_after_hook(self._after_hook) + + # 注册错误钩子:统一异常处理 + runtime.register_error_hook(self._error_hook) + + # 注册适配器响应处理器(最高优先级) + def _is_adapter_response(env: MessageEnvelope) -> bool: + segment = env.get("message_segment") + if isinstance(segment, dict): + return segment.get("type") == "adapter_response" + return False + + runtime.add_route( + predicate=_is_adapter_response, + handler=self._handle_adapter_response_route, + name="adapter_response_handler", + message_type="adapter_response", + ) + + # 注册默认消息处理器(处理所有其他消息) + runtime.add_route( + predicate=lambda _: True, # 匹配所有消息 + handler=self._handle_normal_message, + name="default_message_handler", + ) + + logger.info("MessageHandler 已向 MessageRuntime 注册处理器和钩子") + + async def ensure_started(self) -> None: + """确保所有依赖任务已启动""" + if not self._started: + logger.debug("确保 MessageHandler 所有任务已启动") + + # 启动消息管理器 + if not self._message_manager_started: + await message_manager.start() + self._message_manager_started = True + logger.info("消息管理器已启动") + + self._started = True + + async def _before_hook(self, envelope: MessageEnvelope) -> None: + """ + 前置钩子:消息预处理 + + 1. 标准化 ID 为字符串 + 2. 检查是否为 echo 消息(自身发送的消息上报) + 3. 附加预处理数据到 envelope(chat_stream, message 等) + """ + if self._shutting_down: + raise UserWarning("系统正在关闭,拒绝处理消息") + + # 确保依赖服务已启动 + await self.ensure_started() + + # 提取消息信息 + message_info = envelope.get("message_info") + if not isinstance(message_info, dict): + logger.debug( + "收到缺少 message_info 的消息,已跳过。可用字段: %s", + ", ".join(envelope.keys()), + ) + raise UserWarning("消息缺少 message_info") + + # 标准化 ID 为字符串 + if message_info.get("group_info") is not None: + message_info["group_info"]["group_id"] = str( # type: ignore + message_info["group_info"]["group_id"] # type: ignore + ) + if message_info.get("user_info") is not None: + message_info["user_info"]["user_id"] = str( # type: ignore + message_info["user_info"]["user_id"] # type: ignore + ) + + # 处理自身消息上报(echo) + additional_config = message_info.get("additional_config", {}) + if additional_config and isinstance(additional_config, dict): + sent_message = additional_config.get("echo", False) + if sent_message: + # 更新消息ID + await MessageStorage.update_message(dict(envelope)) + raise UserWarning("Echo 消息已处理") + + async def _after_hook(self, envelope: MessageEnvelope) -> None: + """ + 后置钩子:消息后处理 + + 在消息处理完成后执行的清理工作 + """ + # 后置处理逻辑(如有需要) + pass + + async def _error_hook(self, envelope: MessageEnvelope, exc: BaseException) -> None: + """ + 错误钩子:统一异常处理 + """ + if isinstance(exc, UserWarning): + # UserWarning 是预期的流程控制,只记录 debug 日志 + logger.debug(f"消息处理流程控制: {exc}") + else: + message_id = envelope.get("message_info", {}).get("message_id", "UNKNOWN") + logger.error(f"处理消息 {message_id} 时出错: {exc}", exc_info=True) + + async def _handle_adapter_response_route(self, envelope: MessageEnvelope) -> MessageEnvelope | None: + """ + 处理适配器响应消息的路由处理器 + """ + message_segment = envelope.get("message_segment") + if message_segment and isinstance(message_segment, dict): + seg_data = message_segment.get("data") + if isinstance(seg_data, dict): + await self._handle_adapter_response(seg_data) + return None + + async def _handle_normal_message(self, envelope: MessageEnvelope) -> MessageEnvelope | None: + """ + 默认消息处理器:处理普通消息 + + 1. 获取或创建聊天流 + 2. 转换为 DatabaseMessages + 3. 过滤检查 + 4. 命令处理 + 5. 触发事件、存储、情绪更新 + """ try: - # 在将消息添加到管理器之前进行最终的静默检查 + message_info = envelope.get("message_info") + if not isinstance(message_info, dict): + return None + + # 获取用户和群组信息 + group_info = message_info.get("group_info") + user_info = message_info.get("user_info") + + # 获取或创建聊天流 + platform = message_info.get("platform", "unknown") + + chat = await get_chat_manager().get_or_create_stream( + platform=platform, + user_info=user_info, # type: ignore + group_info=group_info, + ) + + # 将消息信封转换为 DatabaseMessages + from src.chat.message_receive.message_processor import process_message_from_dict + message = await process_message_from_dict( + message_dict=envelope, + stream_id=chat.stream_id, + platform=chat.platform + ) + + # 填充聊天流时间信息 + message.chat_info.create_time = chat.create_time + message.chat_info.last_active_time = chat.last_active_time + + # 注册消息到聊天管理器 + get_chat_manager().register_message(message) + + # 检测是否提及机器人 + message.is_mentioned, _ = is_mentioned_bot_in_message(message) + + # 打印接收日志 + chat_name = chat.group_info.group_name if chat.group_info else "私聊" + user_nickname = message.user_info.user_nickname if message.user_info else "未知用户" + logger.info(f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m") + + # 硬编码过滤 + failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"] + processed_text = message.processed_plain_text or "" + if any(keyword in processed_text for keyword in failure_keywords): + logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。") + return None + + # 过滤检查 + raw_text = message.display_message or message.processed_plain_text or "" + if _check_ban_words(processed_text, chat, user_info) or _check_ban_regex( + raw_text, chat, user_info + ): + return None + + # 处理命令和后续流程 + await self._process_commands(message, chat) + + except UserWarning as uw: + logger.info(str(uw)) + except Exception as e: + logger.error(f"处理消息时出错: {e}") + logger.error(traceback.format_exc()) + + return None + + # 保留旧的 process_message 方法用于向后兼容 + async def process_message(self, envelope: MessageEnvelope) -> None: + """ + 处理接收到的消息信封(向后兼容) + + 注意:此方法已被 MessageRuntime 路由取代。 + 如果直接调用此方法,它会委托给 runtime.handle_message()。 + + Args: + envelope: 消息信封(来自适配器) + """ + if self._runtime: + await self._runtime.handle_message(envelope) + else: + # 如果 runtime 未设置,使用旧的处理流程 + await self._handle_normal_message(envelope) + + async def _process_commands(self, message: DatabaseMessages, chat: "ChatStream") -> None: + """处理命令和继续消息流程""" + try: + # 首先尝试 PlusCommand + is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat) + + if is_plus_command and not plus_continue_process: + await MessageStorage.store_message(message, chat) + logger.info(f"PlusCommand处理完成,跳过后续消息处理: {plus_cmd_result}") + return + + # 如果不是 PlusCommand,尝试传统 BaseCommand + if not is_plus_command: + is_command, cmd_result, continue_process = await self._process_base_commands(message, chat) + + if is_command and not continue_process: + await MessageStorage.store_message(message, chat) + logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}") + return + + # 触发消息事件 + result = await event_manager.trigger_event( + EventType.ON_MESSAGE, + permission_group="SYSTEM", + message=message + ) + if result and not result.all_continue_process(): + raise UserWarning( + f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理" + ) + + # 预处理消息 + await self._preprocess_message(message, chat) + + except UserWarning as uw: + logger.info(str(uw)) + except Exception as e: + logger.error(f"处理命令时出错: {e}") + logger.error(traceback.format_exc()) + + async def _process_plus_commands( + self, + message: DatabaseMessages, + chat: "ChatStream" + ) -> tuple[bool, Any, bool]: + """处理 PlusCommand 系统""" + try: + text = message.processed_plain_text or "" + + # 获取配置的命令前缀 + prefixes = global_config.command.command_prefixes + + # 检查是否以任何前缀开头 + matched_prefix = None + for prefix in prefixes: + if text.startswith(prefix): + matched_prefix = prefix + break + + if not matched_prefix: + return False, None, True + + # 移除前缀 + command_part = text[len(matched_prefix):].strip() + + # 分离命令名和参数 + parts = command_part.split(None, 1) + if not parts: + return False, None, True + + command_word = parts[0].lower() + args_text = parts[1] if len(parts) > 1 else "" + + # 查找匹配的 PlusCommand + plus_command_registry = component_registry.get_plus_command_registry() + matching_commands = [] + + for plus_command_name, plus_command_class in plus_command_registry.items(): + plus_command_info = component_registry.get_registered_plus_command_info(plus_command_name) + if not plus_command_info: + continue + + all_commands = [plus_command_name.lower()] + [ + alias.lower() for alias in plus_command_info.command_aliases + ] + if command_word in all_commands: + matching_commands.append((plus_command_class, plus_command_info, plus_command_name)) + + if not matching_commands: + return False, None, True + + # 按优先级排序 + if len(matching_commands) > 1: + matching_commands.sort(key=lambda x: x[1].priority, reverse=True) + + plus_command_class, plus_command_info, plus_command_name = matching_commands[0] + + # 检查是否被禁用 + if ( + chat + and chat.stream_id + and plus_command_name in global_announcement_manager.get_disabled_chat_commands(chat.stream_id) + ): + logger.info("用户禁用的PlusCommand,跳过处理") + return False, None, True + + message.is_command = True + + # 获取插件配置 + plugin_config = component_registry.get_plugin_config(plus_command_name) + + # 创建实例 + plus_command_instance = plus_command_class(message, plugin_config) + setattr(plus_command_instance, "chat_stream", chat) + + try: + if not plus_command_instance.is_chat_type_allowed(): + is_group = chat.group_info is not None + logger.info( + f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" + ) + return False, None, True + + from src.plugin_system.base.command_args import CommandArgs + command_args = CommandArgs(args_text) + plus_command_instance.args = command_args + + success, response, intercept_message = await plus_command_instance.execute(command_args) + + if success: + logger.info(f"PlusCommand执行成功: {plus_command_class.__name__} (拦截: {intercept_message})") + else: + logger.warning(f"PlusCommand执行失败: {plus_command_class.__name__} - {response}") + + return True, response, not intercept_message + + except Exception as e: + logger.error(f"执行PlusCommand时出错: {plus_command_class.__name__} - {e}") + logger.error(traceback.format_exc()) + + try: + await plus_command_instance.send_text(f"命令执行出错: {e!s}") + except Exception: + pass + + return True, str(e), False + + except Exception as e: + logger.error(f"处理PlusCommand时出错: {e}") + return False, None, True + + async def _process_base_commands( + self, + message: DatabaseMessages, + chat: "ChatStream" + ) -> tuple[bool, Any, bool]: + """处理传统 BaseCommand 系统""" + try: + text = message.processed_plain_text or "" + + command_result = component_registry.find_command_by_text(text) + if command_result: + command_class, matched_groups, command_info = command_result + plugin_name = command_info.plugin_name + command_name = command_info.name + + if ( + chat + and chat.stream_id + and command_name in global_announcement_manager.get_disabled_chat_commands(chat.stream_id) + ): + logger.info("用户禁用的命令,跳过处理") + return False, None, True + + message.is_command = True + + plugin_config = component_registry.get_plugin_config(plugin_name) + command_instance: BaseCommand = command_class(message, plugin_config) + command_instance.set_matched_groups(matched_groups) + setattr(command_instance, "chat_stream", chat) + + try: + if not command_instance.is_chat_type_allowed(): + is_group = chat.group_info is not None + logger.info( + f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" + ) + return False, None, True + + success, response, intercept_message = await command_instance.execute() + + if success: + logger.info(f"命令执行成功: {command_class.__name__} (拦截: {intercept_message})") + else: + logger.warning(f"命令执行失败: {command_class.__name__} - {response}") + + return True, response, not intercept_message + + except Exception as e: + logger.error(f"执行命令时出错: {command_class.__name__} - {e}") + logger.error(traceback.format_exc()) + + try: + await command_instance.send_text(f"命令执行出错: {e!s}") + except Exception: + pass + + return True, str(e), False + + return False, None, True + + except Exception as e: + logger.error(f"处理命令时出错: {e}") + return False, None, True + + async def _preprocess_message(self, message: DatabaseMessages, chat: "ChatStream") -> None: + """预处理消息:存储、情绪更新等""" + try: + group_info = chat.group_info + + # 检查是否需要处理消息 should_process_in_manager = True if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list: - # 检查消息是否为图片或表情包 is_image_or_emoji = message.is_picid or message.is_emoji if not message.is_mentioned and not is_image_or_emoji: - logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理") + logger.debug( + f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理" + ) should_process_in_manager = False elif is_image_or_emoji: logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理") @@ -43,68 +570,81 @@ class MessageHandler: await message_manager.add_message(chat.stream_id, message) logger.debug(f"消息已添加到消息管理器: {chat.stream_id}") - except Exception as e: - logger.error(f"消息添加到消息管理器失败: {e}") + # 存储消息 + try: + await MessageStorage.store_message(message, chat) + except Exception as e: + logger.error(f"存储消息到数据库失败: {e}") + traceback.print_exc() + + # 情绪系统更新 + try: + if global_config.mood.enable_mood: + interest_rate = message.interest_value or 0.0 + logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}") + + chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id) + await chat_mood.update_mood_by_message(message, interest_rate) + logger.debug("情绪状态更新完成") + except Exception as e: + logger.error(f"更新情绪状态失败: {e}") + traceback.print_exc() - # 存储消息到数据库,只进行一次写入 - try: - await MessageStorage.store_message(message, chat) except Exception as e: - logger.error(f"存储消息到数据库失败: {e}") + logger.error(f"预处理消息失败: {e}") traceback.print_exc() - # 情绪系统更新 - 在消息存储后触发情绪更新 + async def _handle_adapter_response(self, seg_data: dict | None) -> None: + """处理适配器命令响应""" try: - if global_config.mood.enable_mood: - # 获取兴趣度用于情绪更新 - interest_rate = message.interest_value - if interest_rate is None: - interest_rate = 0.0 - logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}") + from src.plugin_system.apis.send_api import put_adapter_response + + if isinstance(seg_data, dict): + request_id = seg_data.get("request_id") + response_data = seg_data.get("response") + else: + request_id = None + response_data = None + + if request_id and response_data: + logger.debug(f"收到适配器响应,request_id={request_id}") + put_adapter_response(request_id, response_data) + else: + logger.warning( + f"适配器响应消息格式不正确: request_id={request_id}, response_data={response_data}" + ) - # 获取当前聊天的情绪对象并更新情绪状态 - chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id) - await chat_mood.update_mood_by_message(message, interest_rate) - logger.debug("情绪状态更新完成") except Exception as e: - logger.error(f"更新情绪状态失败: {e}") - traceback.print_exc() + logger.error(f"处理适配器响应时出错: {e}") + + async def shutdown(self) -> None: + """关闭消息处理器""" + self._shutting_down = True + logger.info("MessageHandler 正在关闭...") - async def handle_message(self, envelope: MessageEnvelope): - # 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError - message_info = envelope.get("message_info") - if not isinstance(message_info, dict): - logger.debug( - "收到缺少 message_info 的消息,已跳过。可用字段: %s", - ", ".join(envelope.keys()), - ) - return +# 全局单例 +_message_handler: MessageHandler | None = None - if message_info.get("group_info") is not None: - message_info["group_info"]["group_id"] = str( # type: ignore - message_info["group_info"]["group_id"] # type: ignore - ) - if message_info.get("user_info") is not None: - message_info["user_info"]["user_id"] = str( # type: ignore - message_info["user_info"]["user_id"] # type: ignore - ) - group_info = message_info.get("group_info") - user_info = message_info.get("user_info") +def get_message_handler() -> MessageHandler: + """获取 MessageHandler 单例""" + global _message_handler + if _message_handler is None: + _message_handler = MessageHandler() + return _message_handler - chat_stream = await get_chat_manager().get_or_create_stream( - platform=envelope["platform"], # type: ignore - user_info=user_info, # type: ignore - group_info=group_info, - ) - # 生成 DatabaseMessages - from src.chat.message_receive.message_processor import process_message_from_dict - message = await process_message_from_dict( - message_dict=envelope, - stream_id=chat_stream.stream_id, - platform=chat_stream.platform - ) +async def shutdown_message_handler() -> None: + """关闭 MessageHandler""" + global _message_handler + if _message_handler: + await _message_handler.shutdown() + _message_handler = None - \ No newline at end of file + +__all__ = [ + "MessageHandler", + "get_message_handler", + "shutdown_message_handler", +] diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 073356d44..9d72504c4 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -1,13 +1,26 @@ +""" +统一消息发送器 + +重构说明(2025-11): +- 使用 CoreSinkManager 发送消息,而不是直接通过 WS 连接 +- MessageServer 仅作为与旧适配器的兼容层 +- 所有发送的消息都通过 CoreSinkManager.send_outgoing() 路由到适配器 +""" + import asyncio import traceback +import time +import uuid +from typing import Any, cast from rich.traceback import install +from mofox_bus import MessageEnvelope + from src.chat.message_receive.message import MessageSending from src.chat.message_receive.storage import MessageStorage from src.chat.utils.utils import calculate_typing_time, truncate_message from src.common.logger import get_logger -from src.common.message.api import get_global_api install(extra_lines=3) @@ -15,12 +28,30 @@ logger = get_logger("sender") async def send_message(message: MessageSending, show_log=True) -> bool: - """合并后的消息发送函数,包含WS发送和日志记录""" + """ + 合并后的消息发送函数 + + 重构后使用 CoreSinkManager 发送消息,而不是直接调用 MessageServer + + Args: + message: 要发送的消息 + show_log: 是否显示日志 + + Returns: + bool: 是否发送成功 + """ message_preview = truncate_message(message.processed_plain_text, max_length=120) try: - # 直接调用API发送消息 - await get_global_api().send_message(message) + # 将 MessageSending 转换为 MessageEnvelope + envelope = _message_sending_to_envelope(message) + + # 通过 CoreSinkManager 发送 + from src.common.core_sink_manager import get_core_sink_manager + + manager = get_core_sink_manager() + await manager.send_outgoing(envelope) + if show_log: logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'") @@ -44,7 +75,67 @@ async def send_message(message: MessageSending, show_log=True) -> bool: except Exception as e: logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {e!s}") traceback.print_exc() - raise e # 重新抛出其他异常 + raise e + + +def _message_sending_to_envelope(message: MessageSending) -> MessageEnvelope: + """ + 将 MessageSending 转换为 MessageEnvelope + + Args: + message: MessageSending 对象 + + Returns: + MessageEnvelope: 消息信封 + """ + # 构建消息信息 + message_info: dict[str, Any] = { + "message_id": message.message_info.message_id, + "time": message.message_info.time or time.time(), + "platform": message.message_info.platform, + "user_info": { + "user_id": message.message_info.user_info.user_id, + "user_nickname": message.message_info.user_info.user_nickname, + "platform": message.message_info.user_info.platform, + } if message.message_info.user_info else None, + } + + # 添加群组信息(如果有) + if message.chat_stream and message.chat_stream.group_info: + message_info["group_info"] = { + "group_id": message.chat_stream.group_info.group_id, + "group_name": message.chat_stream.group_info.group_name, + "platform": message.chat_stream.group_info.group_platform, + } + + # 构建消息段 + message_segment: dict[str, Any] + if message.message_segment: + message_segment = { + "type": message.message_segment.type, + "data": message.message_segment.data, + } + else: + # 默认为文本消息 + message_segment = { + "type": "text", + "data": message.processed_plain_text or "", + } + + # 添加回复信息(如果有) + if message.reply_to: + message_segment["reply_to"] = message.reply_to + + # 构建消息信封 + envelope = cast(MessageEnvelope, { + "id": str(uuid.uuid4()), + "direction": "outgoing", + "platform": message.message_info.platform, + "message_info": message_info, + "message_segment": message_segment, + }) + + return envelope class HeartFCSender: diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 92f94ef64..2e4c717fb 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -1,9 +1,9 @@ import asyncio import time import traceback -from typing import Any +from typing import Any, TYPE_CHECKING -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.timer_calculator import Timer from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger @@ -14,6 +14,9 @@ from src.plugin_system.base.base_action import BaseAction from src.plugin_system.base.component_types import ActionInfo, ComponentType from src.plugin_system.core.component_registry import component_registry +if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream + logger = get_logger("action_manager") @@ -48,7 +51,7 @@ class ChatterActionManager: reasoning: str, cycle_timers: dict, thinking_id: str, - chat_stream: ChatStream, + chat_stream: "ChatStream", log_prefix: str, shutting_down: bool = False, action_message: DatabaseMessages | None = None, @@ -476,7 +479,7 @@ class ChatterActionManager: async def _send_and_store_reply( self, - chat_stream: ChatStream, + chat_stream: "ChatStream", response_set, loop_start_time, action_message, diff --git a/src/common/core_sink.py b/src/common/core_sink.py deleted file mode 100644 index 261adf397..000000000 --- a/src/common/core_sink.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -从 src.main 导出 core_sink 的辅助函数 - -由于 src.main 中实际使用的是 InProcessCoreSink, -我们需要创建一个全局访问点 -""" - -from mofox_bus import CoreSink, InProcessCoreSink - -_global_core_sink: CoreSink | None = None - - -def set_core_sink(sink: CoreSink) -> None: - """设置全局 core sink""" - global _global_core_sink - _global_core_sink = sink - - -def get_core_sink() -> CoreSink: - """获取全局 core sink""" - global _global_core_sink - if _global_core_sink is None: - raise RuntimeError("Core sink 尚未初始化") - return _global_core_sink - - -async def push_outgoing(envelope) -> None: - """将消息推送到 core sink 的 outgoing 通道""" - sink = get_core_sink() - push = getattr(sink, "push_outgoing", None) - if push is None: - raise RuntimeError("当前 core sink 不支持 push_outgoing 方法") - await push(envelope) - -__all__ = ["set_core_sink", "get_core_sink", "push_outgoing"] diff --git a/src/common/core_sink_manager.py b/src/common/core_sink_manager.py new file mode 100644 index 000000000..2ebcfd3eb --- /dev/null +++ b/src/common/core_sink_manager.py @@ -0,0 +1,401 @@ +""" +CoreSink 统一管理器 + +负责管理 InProcessCoreSink 和 ProcessCoreSink 双实例, +提供统一的消息收发接口,自动维护与适配器子进程的通信管道。 + +核心职责: +1. 创建和管理 InProcessCoreSink(进程内消息)和 ProcessCoreSink(跨进程消息) +2. 自动维护 ProcessCoreSink 与子进程的通信管道 +3. 使用 MessageRuntime 进行消息路由和处理 +4. 提供统一的消息发送接口 + +架构说明(2025-11 重构): +- 集成 mofox_bus.MessageRuntime 作为消息路由中心 +- 使用 @runtime.on_message() 装饰器注册消息处理器 +- 利用 before_hook/after_hook/error_hook 处理前置/后置/错误逻辑 +- 简化消息处理链条,提高可扩展性 +""" + +from __future__ import annotations + +import asyncio +import contextlib +import multiprocessing as mp +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional + +from mofox_bus import ( + InProcessCoreSink, + MessageEnvelope, + MessageRuntime, + ProcessCoreSinkServer, +) + +from src.common.logger import get_logger + +if TYPE_CHECKING: + from chat.message_receive.message_handler import MessageHandler + +logger = get_logger("core_sink_manager") + + +# 消息处理器类型 +MessageHandlerCallback = Callable[[MessageEnvelope], Awaitable[None]] + + +class CoreSinkManager: + """ + CoreSink 统一管理器 + + 管理 InProcessCoreSink 和 ProcessCoreSinkServer 双实例, + 集成 MessageRuntime 提供统一的消息路由和收发接口。 + + 架构说明: + - InProcessCoreSink: 用于同进程内的适配器(run_in_subprocess=False) + - ProcessCoreSinkServer: 用于管理与子进程适配器的通信 + - MessageRuntime: 统一消息路由,支持 @on_message 装饰器和钩子机制 + + 消息流向: + 1. 适配器(同进程)→ InProcessCoreSink → MessageRuntime.handle_message() → 注册的处理器 + 2. 适配器(子进程)→ ProcessCoreSinkServer → MessageRuntime.handle_message() → 注册的处理器 + 3. 核心回复 → CoreSinkManager.send_outgoing → 适配器 + + 使用 MessageRuntime 的优势: + - 支持 @runtime.on_message(message_type="xxx") 按消息类型路由 + - 支持 before_hook/after_hook/error_hook 统一处理流程 + - 支持中间件机制(洋葱模型) + - 自动处理同步/异步处理器 + """ + + def __init__(self): + # MessageRuntime 实例 + self._runtime: MessageRuntime = MessageRuntime() + + # InProcessCoreSink 实例(用于同进程适配器) + self._in_process_sink: InProcessCoreSink | None = None + + # 子进程通信管理 + # key: adapter_name, value: (ProcessCoreSinkServer, incoming_queue, outgoing_queue) + self._process_sinks: Dict[str, tuple[ProcessCoreSinkServer, mp.Queue, mp.Queue]] = {} + + # multiprocessing context + self._mp_ctx = mp.get_context("spawn") + + # 运行状态 + self._running = False + self._initialized = False + + @property + def runtime(self) -> MessageRuntime: + """ + 获取 MessageRuntime 实例 + + 外部模块可以通过此属性注册消息处理器、钩子等: + + ```python + manager = get_core_sink_manager() + + # 注册消息处理器 + @manager.runtime.on_message(message_type="text") + async def handle_text(envelope: MessageEnvelope): + ... + + # 注册前置钩子 + manager.runtime.register_before_hook(my_before_hook) + ``` + + Returns: + MessageRuntime 实例 + """ + return self._runtime + + async def initialize(self) -> None: + """ + 初始化 CoreSink 管理器 + + 创建 InProcessCoreSink,将收到的消息交给 MessageRuntime 处理。 + """ + if self._initialized: + logger.warning("CoreSinkManager 已经初始化,跳过重复初始化") + return + + logger.info("正在初始化 CoreSink 管理器...") + + # 创建 InProcessCoreSink,使用 MessageRuntime 作为消息处理入口 + self._in_process_sink = InProcessCoreSink(self._dispatch_to_runtime) + + self._running = True + self._initialized = True + + logger.info("CoreSink 管理器初始化完成(已集成 MessageRuntime)") + + async def shutdown(self) -> None: + """关闭 CoreSink 管理器""" + if not self._running: + return + + logger.info("正在关闭 CoreSink 管理器...") + self._running = False + + # 关闭所有 ProcessCoreSinkServer + for adapter_name, (server, _, _) in list(self._process_sinks.items()): + try: + await server.close() + logger.info(f"已关闭适配器 {adapter_name} 的 ProcessCoreSinkServer") + except Exception as e: + logger.error(f"关闭适配器 {adapter_name} 的 ProcessCoreSinkServer 时出错: {e}") + + self._process_sinks.clear() + + # 关闭 InProcessCoreSink + if self._in_process_sink: + await self._in_process_sink.close() + self._in_process_sink = None + + self._initialized = False + logger.info("CoreSink 管理器已关闭") + + def get_in_process_sink(self) -> InProcessCoreSink: + """ + 获取 InProcessCoreSink 实例 + + 用于同进程运行的适配器 + + Returns: + InProcessCoreSink 实例 + + Raises: + RuntimeError: 如果管理器未初始化 + """ + if self._in_process_sink is None: + raise RuntimeError("CoreSinkManager 未初始化,请先调用 initialize()") + return self._in_process_sink + + def create_process_sink_queues(self, adapter_name: str) -> tuple[mp.Queue, mp.Queue]: + """ + 为子进程适配器创建通信队列 + + 创建 incoming 和 outgoing 队列对,用于与子进程适配器通信。 + 同时创建 ProcessCoreSinkServer 来处理消息转发。 + + Args: + adapter_name: 适配器名称 + + Returns: + (to_core_queue, from_core_queue) 元组 + - to_core_queue: 子进程发送到核心的队列 + - from_core_queue: 核心发送到子进程的队列 + + Raises: + RuntimeError: 如果管理器未初始化 + """ + if not self._initialized: + raise RuntimeError("CoreSinkManager 未初始化,请先调用 initialize()") + + if adapter_name in self._process_sinks: + logger.warning(f"适配器 {adapter_name} 的队列已存在,将被覆盖") + # 先关闭旧的 + old_server, _, _ = self._process_sinks[adapter_name] + asyncio.create_task(old_server.close()) + + # 创建通信队列 + incoming_queue = self._mp_ctx.Queue() # 子进程 → 核心 + outgoing_queue = self._mp_ctx.Queue() # 核心 → 子进程 + + # 创建 ProcessCoreSinkServer,使用 MessageRuntime 处理消息 + server = ProcessCoreSinkServer( + incoming_queue=incoming_queue, + outgoing_queue=outgoing_queue, + core_handler=self._dispatch_to_runtime, + name=adapter_name, + ) + + # 启动服务器 + server.start() + + # 存储引用 + self._process_sinks[adapter_name] = (server, incoming_queue, outgoing_queue) + + logger.info(f"为适配器 {adapter_name} 创建了 ProcessCoreSink 通信队列") + + return incoming_queue, outgoing_queue + + def remove_process_sink(self, adapter_name: str) -> None: + """ + 移除子进程适配器的通信队列 + + Args: + adapter_name: 适配器名称 + """ + if adapter_name not in self._process_sinks: + logger.warning(f"适配器 {adapter_name} 的队列不存在") + return + + server, _, _ = self._process_sinks.pop(adapter_name) + asyncio.create_task(server.close()) + logger.info(f"已移除适配器 {adapter_name} 的 ProcessCoreSink 通信队列") + + async def send_outgoing( + self, + envelope: MessageEnvelope, + platform: str | None = None, + adapter_name: str | None = None + ) -> None: + """ + 发送消息到适配器 + + 根据 platform 或 adapter_name 路由到正确的适配器。 + + Args: + envelope: 消息信封 + platform: 目标平台(可选) + adapter_name: 目标适配器名称(可选) + + 路由规则: + 1. 如果指定了 adapter_name,直接发送到该适配器 + 2. 如果指定了 platform,发送到所有匹配平台的适配器 + 3. 如果都没指定,从 envelope 中提取 platform 并广播 + """ + # 从 envelope 中获取 platform + if platform is None: + platform = envelope.get("platform") or envelope.get("message_info", {}).get("platform") + + # 发送到 InProcessCoreSink(会自动广播到所有注册的 outgoing handler) + if self._in_process_sink: + await self._in_process_sink.push_outgoing(envelope) + + # 发送到所有 ProcessCoreSinkServer + for name, (server, _, _) in self._process_sinks.items(): + if adapter_name and name != adapter_name: + continue + try: + await server.push_outgoing(envelope) + except Exception as e: + logger.error(f"发送消息到适配器 {name} 失败: {e}") + + async def _dispatch_to_runtime(self, envelope: MessageEnvelope) -> None: + """ + 将消息分发给 MessageRuntime 处理 + + 这是内部方法,由 InProcessCoreSink 和 ProcessCoreSinkServer 调用。 + 所有从适配器接收到的消息都会经过这里,然后交给 MessageRuntime 路由。 + + Args: + envelope: 消息信封 + """ + if not self._running: + logger.warning("CoreSinkManager 未运行,忽略接收到的消息") + return + + try: + # 使用 MessageRuntime 处理消息 + await self._runtime.handle_message(envelope) + except Exception as e: + logger.error(f"MessageRuntime 处理消息时出错: {e}", exc_info=True) + + +# 全局单例 +_core_sink_manager: CoreSinkManager | None = None + + +def get_core_sink_manager() -> CoreSinkManager: + """获取 CoreSinkManager 单例""" + global _core_sink_manager + if _core_sink_manager is None: + _core_sink_manager = CoreSinkManager() + return _core_sink_manager + + +def get_message_runtime() -> MessageRuntime: + """ + 获取全局 MessageRuntime 实例 + + 这是获取 MessageRuntime 的推荐方式,用于注册消息处理器、钩子等: + + ```python + from src.common.core_sink_manager import get_message_runtime + + runtime = get_message_runtime() + + @runtime.on_message(message_type="text") + async def handle_text(envelope: MessageEnvelope): + ... + ``` + + Returns: + MessageRuntime 实例 + """ + return get_core_sink_manager().runtime + + +async def initialize_core_sink_manager() -> CoreSinkManager: + """ + 初始化 CoreSinkManager 单例 + + Returns: + 初始化后的 CoreSinkManager 实例 + """ + manager = get_core_sink_manager() + await manager.initialize() + return manager + + +async def shutdown_core_sink_manager() -> None: + """关闭 CoreSinkManager 单例""" + global _core_sink_manager + if _core_sink_manager: + await _core_sink_manager.shutdown() + _core_sink_manager = None + + +# ============================================================================ +# 向后兼容的 API +# ============================================================================ + +def get_core_sink() -> InProcessCoreSink: + """ + 获取 InProcessCoreSink 实例(向后兼容) + + 这是旧版 API,推荐使用 get_core_sink_manager().get_in_process_sink() + + Returns: + InProcessCoreSink 实例 + """ + return get_core_sink_manager().get_in_process_sink() + + +def set_core_sink(sink: Any) -> None: + """ + 设置 CoreSink(向后兼容,现已弃用) + + 新架构中 CoreSink 由 CoreSinkManager 统一管理,不再支持外部设置。 + 此函数保留仅为兼容旧代码,调用会记录警告日志。 + """ + logger.warning( + "set_core_sink() 已弃用,CoreSink 现由 CoreSinkManager 统一管理。" + "请使用 initialize_core_sink_manager() 初始化。" + ) + + +async def push_outgoing(envelope: MessageEnvelope) -> None: + """ + 将消息推送到所有适配器(向后兼容) + + Args: + envelope: 消息信封 + """ + manager = get_core_sink_manager() + await manager.send_outgoing(envelope) + + +__all__ = [ + "CoreSinkManager", + "get_core_sink_manager", + "get_message_runtime", + "initialize_core_sink_manager", + "shutdown_core_sink_manager", + # 向后兼容 + "get_core_sink", + "set_core_sink", + "push_outgoing", +] diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index af06eb7b5..d3bdaf892 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -16,6 +16,24 @@ class DatabaseUserInfo(BaseDataModel): user_nickname: str = field(default_factory=str) # 用户昵称 user_cardname: str | None = None # 用户备注名或群名片,可为空 + @classmethod + def from_dict(cls, data: dict) -> "DatabaseUserInfo": + """从字典创建实例""" + return cls( + platform=data.get("platform", ""), + user_id=data.get("user_id", ""), + user_nickname=data.get("user_nickname", ""), + user_cardname=data.get("user_cardname"), + ) + + def to_dict(self) -> dict: + """将实例转换为字典""" + return { + "platform": self.platform, + "user_id": self.user_id, + "user_nickname": self.user_nickname, + "user_cardname": self.user_cardname, + } @dataclass class DatabaseGroupInfo(BaseDataModel): @@ -26,7 +44,23 @@ class DatabaseGroupInfo(BaseDataModel): group_name: str = field(default_factory=str) # 群组名称 group_platform: str | None = None # 群组所在平台,可为空 - + @classmethod + def from_dict(cls, data: dict) -> "DatabaseGroupInfo": + """从字典创建实例""" + return cls( + group_id=data.get("group_id", ""), + group_name=data.get("group_name", ""), + group_platform=data.get("group_platform"), + ) + + def to_dict(self) -> dict: + """将实例转换为字典""" + return { + "group_id": self.group_id, + "group_name": self.group_name, + "group_platform": self.group_platform, + } + @dataclass class DatabaseChatInfo(BaseDataModel): """ diff --git a/src/main.py b/src/main.py index 5070f6696..ae62b18a5 100644 --- a/src/main.py +++ b/src/main.py @@ -6,17 +6,21 @@ import sys import time import traceback from collections.abc import Callable, Coroutine -from functools import partial from random import choices from typing import Any -from mofox_bus import InProcessCoreSink, MessageEnvelope from rich.traceback import install from src.chat.emoji_system.emoji_manager import get_emoji_manager -from src.chat.message_receive.bot import chat_bot +from chat.message_receive.message_handler import get_message_handler, shutdown_message_handler from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask +from src.common.core_sink_manager import ( + CoreSinkManager, + get_core_sink_manager, + initialize_core_sink_manager, + shutdown_core_sink_manager, +) from src.common.logger import get_logger # 全局背景任务集合 @@ -55,28 +59,15 @@ EGG_PHRASES: list[tuple[str, int]] = [ ] -def _task_done_callback(task: asyncio.Task, message_id: str, start_time: float) -> None: - """后台任务完成时的回调函数""" - end_time = time.time() - duration = end_time - start_time - try: - task.result() # 如果任务有异常,这里会重新抛出 - logger.debug(f"消息 {message_id} 的后台任务 (ID: {id(task)}) 已成功完成, 耗时: {duration:.2f}s") - except asyncio.CancelledError: - logger.warning(f"消息 {message_id} 的后台任务 (ID: {id(task)}) 被取消, 耗时: {duration:.2f}s") - except Exception: - logger.error(f"处理消息 {message_id} 的后台任务 (ID: {id(task)}) 出现未捕获的异常, 耗时: {duration:.2f}s:") - logger.error(traceback.format_exc()) - - class MainSystem: """主系统类,负责协调所有组件""" def __init__(self) -> None: self.individuality: Individuality = get_individuality() - # 创建核心消息接收器 - self.core_sink: InProcessCoreSink = InProcessCoreSink(self._message_process_wrapper) + # CoreSinkManager 和 MessageHandler 将在 initialize() 中创建 + self.core_sink_manager: CoreSinkManager | None = None + self.message_handler = None # 使用服务器 self.server: Server = get_global_server() @@ -163,10 +154,11 @@ class MainSystem: continue try: + from src.plugin_system.base.component_types import ComponentType as CT from src.plugin_system.core.component_registry import component_registry component_class = component_registry.get_component_class( - calc_name, ComponentType.INTEREST_CALCULATOR + calc_name, CT.INTEREST_CALCULATOR ) if not component_class: @@ -299,6 +291,18 @@ class MainSystem: except Exception as e: logger.error(f"准备停止适配器管理器时出错: {e}") + # 停止 CoreSinkManager + try: + cleanup_tasks.append(("CoreSinkManager", shutdown_core_sink_manager())) + except Exception as e: + logger.error(f"准备停止 CoreSinkManager 时出错: {e}") + + # 停止 MessageHandler + try: + cleanup_tasks.append(("MessageHandler", shutdown_message_handler())) + except Exception as e: + logger.error(f"准备停止 MessageHandler 时出错: {e}") + # 并行执行所有清理任务 if cleanup_tasks: logger.info(f"开始并行执行 {len(cleanup_tasks)} 个清理任务...") @@ -352,27 +356,6 @@ class MainSystem: except Exception as e: logger.error(f"同步清理资源时出错: {e}") - async def _message_process_wrapper(self, envelope: MessageEnvelope) -> None: - """并行处理消息的包装器""" - try: - start_time = time.time() - message_id = envelope.get("message_info", {}).get("message_id", "UNKNOWN") - # 检查系统是否正在关闭 - if self._shutting_down: - logger.warning(f"系统正在关闭,拒绝处理消息 {message_id}") - return - - # 创建后台任务 - task = asyncio.create_task(chat_bot.message_process(envelope)) - logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})") - - # 添加一个回调函数,当任务完成时,它会被调用 - task.add_done_callback(partial(_task_done_callback, message_id=message_id, start_time=start_time)) - except Exception: - logger.error("在创建消息处理任务时发生严重错误:") - logger.error(traceback.format_exc()) - - async def initialize(self) -> None: """初始化系统组件""" # 检查必要的配置 @@ -382,6 +365,18 @@ class MainSystem: logger.info(f"正在唤醒{global_config.bot.nickname}......") + # 初始化 CoreSinkManager(包含 MessageRuntime) + logger.info("正在初始化 CoreSinkManager...") + self.core_sink_manager = await initialize_core_sink_manager() + + # 获取 MessageHandler 并向 MessageRuntime 注册处理器 + self.message_handler = get_message_handler() + self.message_handler.set_core_sink_manager(self.core_sink_manager) + + # 向 MessageRuntime 注册消息处理器和钩子 + self.message_handler.register_handlers(self.core_sink_manager.runtime) + logger.info("CoreSinkManager 和 MessageHandler 初始化完成(使用 MessageRuntime 路由)") + # 初始化组件 await self._init_components() @@ -453,7 +448,11 @@ MoFox_Bot(第三方修改版) logger.error(f"统一调度器初始化失败: {e}") # 设置核心消息接收器到插件管理器 - plugin_manager.set_core_sink(self.core_sink) + # 使用 CoreSinkManager 的 InProcessCoreSink + if self.core_sink_manager: + plugin_manager.set_core_sink(self.core_sink_manager.get_in_process_sink()) + else: + logger.error("CoreSinkManager 未初始化,无法设置核心消息接收器") # 加载所有插件 plugin_manager.load_all_plugins() @@ -505,8 +504,8 @@ MoFox_Bot(第三方修改版) except Exception as e: logger.error(f"LPMM知识库初始化失败: {e}") - # 消息接收器已经在 __init__ 中创建,无需再次注册 - logger.info("核心消息接收器已就绪") + # 消息接收器已在 initialize() 中通过 CoreSinkManager 创建 + logger.info("核心消息接收器已就绪(通过 CoreSinkManager)") # 启动消息重组器 try: diff --git a/src/mofox_bus/__init__.py b/src/mofox_bus/__init__.py deleted file mode 100644 index b0868c14b..000000000 --- a/src/mofox_bus/__init__.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -MoFox 内部通用消息总线实现。 - -该模块导出 TypedDict 消息模型、序列化工具、传输层封装以及适配器辅助工具, -供核心进程与各类平台适配器共享。 -""" - -from . import codec, types -from .adapter_utils import ( - AdapterTransportOptions, - AdapterBase, - BatchDispatcher, - CoreSink, - CoreMessageSink, - HttpAdapterOptions, - InProcessCoreSink, - ProcessCoreSink, - ProcessCoreSinkServer, - WebSocketLike, - WebSocketAdapterOptions, -) -from .api import MessageClient, MessageServer -from .codec import dumps_message, dumps_messages, loads_message, loads_messages -from .builder import MessageBuilder -from .router import RouteConfig, Router, TargetConfig -from .runtime import MessageProcessingError, MessageRoute, MessageRuntime, Middleware -from .types import ( - FormatInfoPayload, - GroupInfoPayload, - MessageDirection, - MessageEnvelope, - MessageInfoPayload, - SegPayload, - - TemplateInfoPayload, - UserInfoPayload, -) - -__all__ = [ - # TypedDict model - "MessageDirection", - "MessageEnvelope", - "SegPayload", - "UserInfoPayload", - "GroupInfoPayload", - "FormatInfoPayload", - "TemplateInfoPayload", - "MessageInfoPayload", - # Codec helpers - "codec", - "dumps_message", - "dumps_messages", - "loads_message", - "loads_messages", - "MessageBuilder", - # Runtime / routing - "MessageRoute", - "MessageRuntime", - "MessageProcessingError", - "Middleware", - # Server/client/router - "MessageServer", - "MessageClient", - "Router", - "RouteConfig", - "TargetConfig", - # Adapter helpers - "AdapterTransportOptions", - "AdapterBase", - "BatchDispatcher", - "CoreSink", - "CoreMessageSink", - "InProcessCoreSink", - "ProcessCoreSink", - "ProcessCoreSinkServer", - "WebSocketLike", - "WebSocketAdapterOptions", - "HttpAdapterOptions", -] diff --git a/src/mofox_bus/adapter_utils.py b/src/mofox_bus/adapter_utils.py deleted file mode 100644 index 275c35d2e..000000000 --- a/src/mofox_bus/adapter_utils.py +++ /dev/null @@ -1,485 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import logging -import multiprocessing as mp -from dataclasses import dataclass -from typing import Any, AsyncIterator, Awaitable, Callable, Protocol - -import orjson -from aiohttp import web as aiohttp_web -import websockets - -from .types import MessageEnvelope - -logger = logging.getLogger("mofox_bus.adapter") - - -OutgoingHandler = Callable[[MessageEnvelope], Awaitable[None]] - - -class CoreMessageSink(Protocol): - async def send(self, message: MessageEnvelope) -> None: ... - - async def send_many(self, messages: list[MessageEnvelope]) -> None: ... # pragma: no cover - optional - - -class CoreSink(CoreMessageSink, Protocol): - """ - 双向 CoreSink 协议: - - send/send_many: 适配器 → 核心(incoming) - - push_outgoing: 核心 → 适配器(outgoing) - """ - - def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None: ... - - def remove_outgoing_handler(self, handler: OutgoingHandler) -> None: ... - - async def push_outgoing(self, envelope: MessageEnvelope) -> None: ... - - async def close(self) -> None: ... # pragma: no cover - lifecycle hook - - -class WebSocketLike(Protocol): - def __aiter__(self) -> AsyncIterator[str | bytes]: ... - - @property - def closed(self) -> bool: ... - - async def send(self, data: str | bytes) -> None: ... - - async def close(self) -> None: ... - - -@dataclass -class WebSocketAdapterOptions: - url: str - headers: dict[str, str] | None = None - incoming_parser: Callable[[str | bytes], Any] | None = None - outgoing_encoder: Callable[[MessageEnvelope], str | bytes] | None = None - - -@dataclass -class HttpAdapterOptions: - host: str = "0.0.0.0" - port: int = 8089 - path: str = "/adapter/messages" - app: aiohttp_web.Application | None = None - - -AdapterTransportOptions = WebSocketAdapterOptions | HttpAdapterOptions | None - - -class AdapterBase: - """ - 适配器基类:负责平台原始消息与 MessageEnvelope 之间的互转。 - 子类需要实现平台入站解析与出站发送逻辑。 - """ - - platform: str = "unknown" - - def __init__(self, core_sink: CoreSink, transport: AdapterTransportOptions = None): - """ - Args: - core_sink: 核心消息入口,通常是 InProcessCoreSink 或自定义客户端。 - transport: 传入 WebSocketAdapterOptions / HttpAdapterOptions 即可自动管理监听逻辑。 - """ - self.core_sink = core_sink - self._transport_config = transport - self._ws: WebSocketLike | None = None - self._ws_task: asyncio.Task | None = None - self._http_runner: aiohttp_web.AppRunner | None = None - self._http_site: aiohttp_web.BaseSite | None = None - - async def start(self) -> None: - """启动适配器的传输层监听(如果配置了传输选项)。""" - if hasattr(self.core_sink, "set_outgoing_handler"): - try: - self.core_sink.set_outgoing_handler(self._on_outgoing_from_core) - except Exception: - logger.exception("注册 outgoing 处理程序到核心接收器失败") - if isinstance(self._transport_config, WebSocketAdapterOptions): - await self._start_ws_transport(self._transport_config) - elif isinstance(self._transport_config, HttpAdapterOptions): - await self._start_http_transport(self._transport_config) - - - async def stop(self) -> None: - """停止适配器的传输层监听(如果配置了传输选项)。""" - remove = getattr(self.core_sink, "remove_outgoing_handler", None) - if callable(remove): - try: - remove(self._on_outgoing_from_core) - except Exception: - logger.exception("从核心接收器分离 outgoing 处理程序失败") - elif hasattr(self.core_sink, "set_outgoing_handler"): - try: - self.core_sink.set_outgoing_handler(None) # type: ignore[arg-type] - except Exception: - logger.exception("从核心接收器分离 outgoing 处理程序失败") - if self._ws_task: - self._ws_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._ws_task - self._ws_task = None - if self._ws: - await self._ws.close() - self._ws = None - if self._http_site: - await self._http_site.stop() - self._http_site = None - if self._http_runner: - await self._http_runner.cleanup() - self._http_runner = None - - async def on_platform_message(self, raw: Any) -> None: - """处理平台下发的单条消息并交给核心。""" - envelope = await self.from_platform_message(raw) - await self.core_sink.send(envelope) - - async def on_platform_messages(self, raw_messages: list[Any]) -> None: - """批量推送入口,内部自动批量或逐条送入核心。""" - envelopes = [await self.from_platform_message(raw) for raw in raw_messages] - await _send_many(self.core_sink, envelopes) - - async def send_to_platform(self, envelope: MessageEnvelope) -> None: - """核心生成单条消息时调用,由子类或自动传输层发送。""" - await self._send_platform_message(envelope) - - async def send_batch_to_platform(self, envelopes: list[MessageEnvelope]) -> None: - """默认串行发送整批消息,子类可根据平台特性重写。""" - for env in envelopes: - await self._send_platform_message(env) - - async def _on_outgoing_from_core(self, envelope: MessageEnvelope) -> None: - """核心生成 outgoing envelope 时的内部处理逻辑""" - platform = envelope.get("platform") or envelope.get("message_info", {}).get("platform") - if platform and platform != getattr(self, "platform", None): - return - await self._send_platform_message(envelope) - - async def from_platform_message(self, raw: Any) -> MessageEnvelope: - """子类必须实现:将平台原始结构转换为统一 MessageEnvelope。""" - raise NotImplementedError - - async def _send_platform_message(self, envelope: MessageEnvelope) -> None: - """子类必须实现:把 MessageEnvelope 转为平台格式并发送出去。""" - if isinstance(self._transport_config, WebSocketAdapterOptions): - await self._send_via_ws(envelope) - return - raise NotImplementedError - - async def _start_ws_transport(self, options: WebSocketAdapterOptions) -> None: - self._ws = await websockets.connect(options.url, extra_headers=options.headers) - self._ws_task = asyncio.create_task(self._ws_listen_loop(options)) - - async def _ws_listen_loop(self, options: WebSocketAdapterOptions) -> None: - assert self._ws is not None - parser = options.incoming_parser or self._default_ws_parser - try: - async for raw in self._ws: - payload = parser(raw) - await self.on_platform_message(payload) - finally: - pass - - async def _send_via_ws(self, envelope: MessageEnvelope) -> None: - if self._ws is None or self._ws.closed: - raise RuntimeError("WebSocket transport is not active") - encoder = None - if isinstance(self._transport_config, WebSocketAdapterOptions): - encoder = self._transport_config.outgoing_encoder - data = encoder(envelope) if encoder else self._default_ws_encoder(envelope) - await self._ws.send(data) - - async def _start_http_transport(self, options: HttpAdapterOptions) -> None: - app = options.app or aiohttp_web.Application() - app.add_routes([aiohttp_web.post(options.path, self._handle_http_request)]) - self._http_runner = aiohttp_web.AppRunner(app) - await self._http_runner.setup() - self._http_site = aiohttp_web.TCPSite(self._http_runner, options.host, options.port) - await self._http_site.start() - - async def _handle_http_request(self, request: aiohttp_web.Request) -> aiohttp_web.Response: - raw = await request.read() - data = orjson.loads(raw) if raw else {} - if isinstance(data, list): - await self.on_platform_messages(data) - else: - await self.on_platform_message(data) - return aiohttp_web.json_response({"status": "ok"}) - - @staticmethod - def _default_ws_parser(raw: str | bytes) -> Any: - data = orjson.loads(raw) - if isinstance(data, dict) and data.get("type") == "message" and "payload" in data: - return data["payload"] - return data - - @staticmethod - def _default_ws_encoder(envelope: MessageEnvelope) -> bytes: - return orjson.dumps({"type": "send", "payload": envelope}) - - -class InProcessCoreSink(CoreSink): - """ - 进程内核心消息 sink,实现 CoreSink 协议。 - """ - - def __init__(self, handler: Callable[[MessageEnvelope], Awaitable[None]]): - self._handler = handler - self._outgoing_handlers: set[OutgoingHandler] = set() - - def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None: - if handler is None: - return - self._outgoing_handlers.add(handler) - - def remove_outgoing_handler(self, handler: OutgoingHandler) -> None: - self._outgoing_handlers.discard(handler) - - async def send(self, message: MessageEnvelope) -> None: - await self._handler(message) - - async def send_many(self, messages: list[MessageEnvelope]) -> None: - for message in messages: - await self._handler(message) - - async def push_outgoing(self, envelope: MessageEnvelope) -> None: - if not self._outgoing_handlers: - logger.debug("Outgoing envelope dropped: no handler registered") - return - for callback in list(self._outgoing_handlers): - await callback(envelope) - - async def close(self) -> None: # pragma: no cover - symmetry - self._outgoing_handlers.clear() - - -class ProcessCoreSink(CoreSink): - """ - 进程间核心消息 sink,实现 CoreSink 协议,使用 multiprocessing.Queue 初始化 - """ - - _CONTROL_STOP = {"__core_sink_control__": "stop"} - - def __init__(self, *, to_core_queue: mp.Queue, from_core_queue: mp.Queue) -> None: - self._to_core_queue = to_core_queue - self._from_core_queue = from_core_queue - self._outgoing_handler: OutgoingHandler | None = None - self._closed = False - self._listener_task: asyncio.Task | None = None - self._loop = asyncio.get_event_loop() - - def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None: - self._outgoing_handler = handler - if handler is not None and (self._listener_task is None or self._listener_task.done()): - self._listener_task = self._loop.create_task(self._listen_from_core()) - - def remove_outgoing_handler(self, handler: OutgoingHandler) -> None: - if self._outgoing_handler is handler: - self._outgoing_handler = None - if self._listener_task and not self._listener_task.done(): - self._listener_task.cancel() - - async def send(self, message: MessageEnvelope) -> None: - await asyncio.to_thread(self._to_core_queue.put, {"kind": "incoming", "payload": message}) - - async def send_many(self, messages: list[MessageEnvelope]) -> None: - for message in messages: - await self.send(message) - - async def push_outgoing(self, envelope: MessageEnvelope) -> None: - logger.debug("ProcessCoreSink.push_outgoing 在子进程中调用; 被忽略") - - async def close(self) -> None: - if self._closed: - return - self._closed = True - await asyncio.to_thread(self._from_core_queue.put, self._CONTROL_STOP) - if self._listener_task: - self._listener_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._listener_task - self._listener_task = None - - async def _listen_from_core(self) -> None: - while not self._closed: - try: - item = await asyncio.to_thread(self._from_core_queue.get) - except asyncio.CancelledError: - break - if item == self._CONTROL_STOP: - break - if isinstance(item, dict) and item.get("kind") == "outgoing": - envelope = item.get("payload") - if self._outgoing_handler: - try: - await self._outgoing_handler(envelope) - except Exception: # pragma: no cover - logger.exception("处理 ProcessCoreSink 中的 outgoing 信封失败") - else: - logger.debug(f"ProcessCoreSink 接受到未知负载: {item}") - - -class ProcessCoreSinkServer: - """ - 进程间核心消息 sink 服务器,实现 CoreSink 协议,使用 multiprocessing.Queue 初始化。 - - 将传入的 incoming 消息转发给指定的 handler - - 将接收到的 outgoing 消息放入 outgoing 队列 - """ - - def __init__( - self, - *, - incoming_queue: mp.Queue, - outgoing_queue: mp.Queue, - core_handler: Callable[[MessageEnvelope], Awaitable[None]], - name: str | None = None, - ) -> None: - self._incoming_queue = incoming_queue - self._outgoing_queue = outgoing_queue - self._core_handler = core_handler - self._task: asyncio.Task | None = None - self._closed = False - self._name = name or "adapter" - - def start(self) -> None: - if self._task is None or self._task.done(): - self._task = asyncio.create_task(self._consume_incoming()) - - async def _consume_incoming(self) -> None: - while not self._closed: - try: - item = await asyncio.to_thread(self._incoming_queue.get) - except asyncio.CancelledError: - break - if isinstance(item, dict) and item.get("__core_sink_control__") == "stop": - break - if isinstance(item, dict) and item.get("kind") == "incoming": - envelope = item.get("payload") - try: - await self._core_handler(envelope) - except Exception: # pragma: no cover - logger.exception(f"处理来自 {self._name} 的 incoming 信封时失败") - else: - logger.debug(f"ProcessCoreSinkServer 忽略来自 {self._name} 的未知负载: {item}") - - async def push_outgoing(self, envelope: MessageEnvelope) -> None: - await asyncio.to_thread(self._outgoing_queue.put, {"kind": "outgoing", "payload": envelope}) - - async def close(self) -> None: - if self._closed: - return - self._closed = True - await asyncio.to_thread(self._incoming_queue.put, {"__core_sink_control__": "stop"}) - await asyncio.to_thread(self._outgoing_queue.put, ProcessCoreSink._CONTROL_STOP) - if self._task: - self._task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._task - self._task = None - -async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) -> None: - send_many = getattr(sink, "send_many", None) - if callable(send_many): - await send_many(envelopes) - return - for env in envelopes: - await sink.send(env) - - -class BatchDispatcher: - """ - 批量消息分发器,负责将消息批量发送到核心 sink。 - """ - - _STOP = object() - - def __init__( - self, - sink: CoreMessageSink, - *, - max_batch_size: int = 50, - flush_interval: float = 0.2, - ) -> None: - self._sink = sink - self._max_batch_size = max_batch_size - self._flush_interval = flush_interval - self._queue: asyncio.Queue[MessageEnvelope | object] = asyncio.Queue() - self._worker: asyncio.Task | None = None - self._closed = False - - async def add(self, message: MessageEnvelope) -> None: - if self._closed: - raise RuntimeError("Dispatcher closed") - self._ensure_worker() - await self._queue.put(message) - - async def close(self) -> None: - if self._closed: - return - self._closed = True - self._ensure_worker() - await self._queue.put(self._STOP) - if self._worker: - await self._worker - self._worker = None - - def _ensure_worker(self) -> None: - if self._worker is not None and not self._worker.done(): - return - self._worker = asyncio.create_task(self._worker_loop()) - - async def _worker_loop(self) -> None: - buffer: list[MessageEnvelope] = [] - try: - while True: - try: - item = await asyncio.wait_for(self._queue.get(), timeout=self._flush_interval) - except asyncio.TimeoutError: - item = None - - if item is self._STOP: - await self._flush_buffer(buffer) - return - if item is not None: - buffer.append(item) # type: ignore[arg-type] - - while len(buffer) < self._max_batch_size: - try: - item = self._queue.get_nowait() - except asyncio.QueueEmpty: - break - if item is self._STOP: - await self._flush_buffer(buffer) - return - buffer.append(item) # type: ignore[arg-type] - - if buffer and (len(buffer) >= self._max_batch_size or item is None): - await self._flush_buffer(buffer) - except asyncio.CancelledError: # pragma: no cover - worker cancellation - if buffer: - await self._flush_buffer(buffer) - - async def _flush_buffer(self, buffer: list[MessageEnvelope]) -> None: - if not buffer: - return - payload = list(buffer) - buffer.clear() - await _send_many(self._sink, payload) - -__all__ = [ - "AdapterTransportOptions", - "AdapterBase", - "BatchDispatcher", - "CoreSink", - "CoreMessageSink", - "HttpAdapterOptions", - "InProcessCoreSink", - "ProcessCoreSink", - "ProcessCoreSinkServer", - "WebSocketLike", - "WebSocketAdapterOptions", -] diff --git a/src/mofox_bus/api.py b/src/mofox_bus/api.py deleted file mode 100644 index 0140052bc..000000000 --- a/src/mofox_bus/api.py +++ /dev/null @@ -1,515 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import logging -import ssl -from typing import Any, Awaitable, Callable, Dict, Literal, Optional - -import aiohttp -import orjson -import uvicorn -from fastapi import FastAPI, WebSocket, WebSocketDisconnect - - -MessagePayload = Dict[str, Any] -MessageHandler = Callable[[MessagePayload], Awaitable[None] | None] -DisconnectCallback = Callable[[str, str], Awaitable[None] | None] - - -def _attach_raw_bytes(payload: Any, raw_bytes: bytes) -> Any: - """ - 将原始字节数据附加到消息负载中 - - Args: - payload: 消息负载 - raw_bytes: 原始字节数据 - - Returns: - 附加了原始数据的消息负载 - """ - if isinstance(payload, dict): - payload.setdefault("raw_bytes", raw_bytes) - elif isinstance(payload, list): - for item in payload: - if isinstance(item, dict): - item.setdefault("raw_bytes", raw_bytes) - return payload - - -def _encode_for_ws_send(message: Any, *, use_raw_bytes: bool = False) -> tuple[str | bytes, bool]: - """ - 编码消息用于 WebSocket 发送 - - Args: - message: 要发送的消息 - use_raw_bytes: 是否使用原始字节数据 - - Returns: - (编码后的数据, 是否为二进制格式) - """ - if isinstance(message, (bytes, bytearray)): - return bytes(message), True - if use_raw_bytes and isinstance(message, dict): - raw = message.get("raw_bytes") - if isinstance(raw, (bytes, bytearray)): - return bytes(raw), True - payload = message - if isinstance(payload, dict) and "raw_bytes" in payload and not use_raw_bytes: - payload = {k: v for k, v in payload.items() if k != "raw_bytes"} - data = orjson.dumps(payload) - if use_raw_bytes: - return data, True - return data.decode("utf-8"), False - - -class BaseMessageHandler: - """基础消息处理器,提供消息处理和任务管理功能""" - - def __init__(self) -> None: - self.message_handlers: list[MessageHandler] = [] - self.background_tasks: set[asyncio.Task] = set() - - def register_message_handler(self, handler: MessageHandler) -> None: - """ - 注册消息处理器 - - Args: - handler: 消息处理函数 - """ - if handler not in self.message_handlers: - self.message_handlers.append(handler) - - async def process_message(self, message: MessagePayload) -> None: - """ - 处理单条消息,并发执行所有注册的处理器 - - Args: - message: 消息负载 - """ - tasks: list[asyncio.Task] = [] - for handler in self.message_handlers: - try: - result = handler(message) - if asyncio.iscoroutine(result): - task = asyncio.create_task(result) - tasks.append(task) - self.background_tasks.add(task) - task.add_done_callback(self.background_tasks.discard) - except Exception: # pragma: no cover - logging only - logging.getLogger("mofox_bus.server").exception("消息处理失败") - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - -class MessageServer(BaseMessageHandler): - """ - WebSocket 消息服务器,支持与 FastAPI 应用共享事件循环。 - """ - - def __init__( - self, - host: str = "0.0.0.0", - port: int = 18000, - *, - enable_token: bool = False, - app: FastAPI | None = None, - path: str = "/ws", - ssl_certfile: str | None = None, - ssl_keyfile: str | None = None, - mode: Literal["ws", "tcp"] = "ws", - custom_logger: logging.Logger | None = None, - enable_custom_uvicorn_logger: bool = False, - queue_maxsize: int = 1000, - worker_count: int = 1, - ) -> None: - super().__init__() - if mode != "ws": - raise NotImplementedError("Only WebSocket mode is supported in mofox_bus") - if custom_logger: - logging.getLogger("mofox_bus.server").handlers = custom_logger.handlers - self.host = host - self.port = port - self._app = app or FastAPI() - self._own_app = app is None - self._path = path - self._ssl_certfile = ssl_certfile - self._ssl_keyfile = ssl_keyfile - self._enable_token = enable_token - self._valid_tokens: set[str] = set() - self._connections: set[WebSocket] = set() - self._platform_connections: dict[str, WebSocket] = {} - self._conn_lock = asyncio.Lock() - self._server: uvicorn.Server | None = None - self._running = False - self._message_queue: asyncio.Queue[MessagePayload] = asyncio.Queue(maxsize=queue_maxsize) - self._worker_count = max(1, worker_count) - self._worker_tasks: list[asyncio.Task] = [] - self._setup_routes() - - def _setup_routes(self) -> None: - @_self_websocket(self._app, self._path) - async def websocket_endpoint(websocket: WebSocket) -> None: - platform = websocket.headers.get("platform", "unknown") - token = websocket.headers.get("authorization") or websocket.headers.get("Authorization") - if self._enable_token and not await self.verify_token(token): - await websocket.close(code=1008, reason="invalid token") - return - - await websocket.accept() - await self._register_connection(websocket, platform) - try: - while True: - msg = await websocket.receive() - if msg["type"] == "websocket.receive": - raw_bytes = msg.get("bytes") - if raw_bytes is None and msg.get("text") is not None: - raw_bytes = msg["text"].encode("utf-8") - if not raw_bytes: - continue - try: - payload = orjson.loads(raw_bytes) - except orjson.JSONDecodeError: - logging.getLogger("mofox_bus.server").warning("Invalid JSON payload") - continue - payload = _attach_raw_bytes(payload, raw_bytes) - if isinstance(payload, list): - for item in payload: - await self._enqueue_message(item) - else: - await self._enqueue_message(payload) - elif msg["type"] == "websocket.disconnect": - break - except WebSocketDisconnect: - pass - finally: - await self._remove_connection(websocket, platform) - - async def _enqueue_message(self, payload: MessagePayload) -> None: - if not self._worker_tasks: - self._start_workers() - try: - self._message_queue.put_nowait(payload) - except asyncio.QueueFull: - logging.getLogger("mofox_bus.server").warning("Message queue full, dropping message") - - def _start_workers(self) -> None: - if self._worker_tasks: - return - self._running = True - for _ in range(self._worker_count): - task = asyncio.create_task(self._consumer_worker()) - self._worker_tasks.append(task) - - async def _stop_workers(self) -> None: - if not self._worker_tasks: - return - self._running = False - for task in self._worker_tasks: - task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await asyncio.gather(*self._worker_tasks, return_exceptions=True) - self._worker_tasks.clear() - while not self._message_queue.empty(): - with contextlib.suppress(asyncio.QueueEmpty): - self._message_queue.get_nowait() - self._message_queue.task_done() - - async def _consumer_worker(self) -> None: - while self._running: - try: - payload = await self._message_queue.get() - except asyncio.CancelledError: - break - try: - await self.process_message(payload) - except Exception: # pragma: no cover - best effort logging - logging.getLogger("mofox_bus.server").exception("Error processing message") - finally: - self._message_queue.task_done() - - async def verify_token(self, token: str | None) -> bool: - if not self._enable_token: - return True - return token in self._valid_tokens - - def add_valid_token(self, token: str) -> None: - self._valid_tokens.add(token) - - def remove_valid_token(self, token: str) -> None: - self._valid_tokens.discard(token) - - async def _register_connection(self, websocket: WebSocket, platform: str) -> None: - async with self._conn_lock: - self._connections.add(websocket) - if platform: - previous = self._platform_connections.get(platform) - if previous and previous.client_state.name != "DISCONNECTED": - await previous.close(code=1000, reason="replaced") - self._platform_connections[platform] = websocket - - async def _remove_connection(self, websocket: WebSocket, platform: str) -> None: - async with self._conn_lock: - self._connections.discard(websocket) - if platform and self._platform_connections.get(platform) is websocket: - del self._platform_connections[platform] - - async def broadcast_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> None: - payload: MessagePayload | bytes = message - data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes) - async with self._conn_lock: - targets = list(self._connections) - for ws in targets: - if is_binary: - await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8")) - else: - await ws.send_text(data if isinstance(data, str) else data.decode("utf-8")) - - async def broadcast_to_platform( - self, platform: str, message: MessagePayload | bytes, *, use_raw_bytes: bool = False - ) -> None: - ws = self._platform_connections.get(platform) - if ws is None: - raise RuntimeError(f"平台 {platform} 没有活跃的连接") - payload: MessagePayload | bytes = message - data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes) - if is_binary: - await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8")) - else: - await ws.send_text(data if isinstance(data, str) else data.decode("utf-8")) - - async def send_message( - self, message: MessagePayload, *, prefer_raw_bytes: bool = False - ) -> None: - platform = message.get("message_info", {}).get("platform") - if not platform: - raise ValueError("message_info.platform is required to route the message") - await self.broadcast_to_platform(platform, message, use_raw_bytes=prefer_raw_bytes) - - def run_sync(self) -> None: - if not self._own_app: - return - asyncio.run(self.run()) - - async def run(self) -> None: - self._start_workers() - if not self._own_app: - return - config = uvicorn.Config( - self._app, - host=self.host, - port=self.port, - ssl_certfile=self._ssl_certfile, - ssl_keyfile=self._ssl_keyfile, - log_config=None, - access_log=False, - ) - self._server = uvicorn.Server(config) - try: - await self._server.serve() - except asyncio.CancelledError: # pragma: no cover - shutdown path - pass - - async def stop(self) -> None: - self._running = False - await self._stop_workers() - if self._server: - self._server.should_exit = True - await self._server.shutdown() - self._server = None - async with self._conn_lock: - targets = list(self._connections) - self._connections.clear() - self._platform_connections.clear() - for ws in targets: - try: - await ws.close(code=1001, reason="server shutting down") - except Exception: # pragma: no cover - best effort - pass - for task in list(self.background_tasks): - if not task.done(): - task.cancel() - if self.background_tasks: - await asyncio.gather(*self.background_tasks, return_exceptions=True) - self.background_tasks.clear() - - -class MessageClient(BaseMessageHandler): - """ - WebSocket 消息客户端,实现双向传输。 - """ - - def __init__( - self, - mode: Literal["ws", "tcp"] = "ws", - *, - reconnect_interval: float = 5.0, - logger: logging.Logger | None = None, - ) -> None: - super().__init__() - if mode != "ws": - raise NotImplementedError("Only WebSocket mode is supported in mofox_bus") - self._mode = mode - self._session: aiohttp.ClientSession | None = None - self._ws: aiohttp.ClientWebSocketResponse | None = None - self._receive_task: asyncio.Task | None = None - self._url: str = "" - self._platform: str = "" - self._token: str | None = None - self._ssl_verify: str | None = None - self._closed = False - self._on_disconnect: DisconnectCallback | None = None - self._reconnect_interval = reconnect_interval - self._logger = logger or logging.getLogger("mofox_bus.client") - - async def connect( - self, - *, - url: str, - platform: str, - token: str | None = None, - ssl_verify: str | None = None, - ) -> None: - self._url = url - self._platform = platform - self._token = token - self._ssl_verify = ssl_verify - self._closed = False - await self._establish_connection() - - def set_disconnect_callback(self, callback: DisconnectCallback) -> None: - self._on_disconnect = callback - - async def _establish_connection(self) -> None: - if self._session is None: - self._session = aiohttp.ClientSession() - headers = {"platform": self._platform} - if self._token: - headers["authorization"] = self._token - ssl_context = None - if self._ssl_verify: - ssl_context = ssl.create_default_context(cafile=self._ssl_verify) - self._ws = await self._session.ws_connect(self._url, headers=headers, ssl=ssl_context) - self._receive_task = asyncio.create_task(self._receive_loop()) - - async def _connect_once(self) -> None: - await self._establish_connection() - - async def _receive_loop(self) -> None: - assert self._ws is not None - try: - async for msg in self._ws: - if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY): - raw_bytes = msg.data if isinstance(msg.data, (bytes, bytearray)) else msg.data.encode("utf-8") - try: - payload = orjson.loads(raw_bytes) - except orjson.JSONDecodeError: - logging.getLogger("mofox_bus.client").warning("Invalid JSON payload") - continue - payload = _attach_raw_bytes(payload, raw_bytes) - if isinstance(payload, list): - for item in payload: - await self.process_message(item) - else: - await self.process_message(payload) - elif msg.type == aiohttp.WSMsgType.ERROR: - break - except asyncio.CancelledError: # pragma: no cover - cancellation path - pass - finally: - if not self._closed: - await self._notify_disconnect("websocket disconnected") - await self._reconnect() - if self._ws: - await self._ws.close() - self._ws = None - - async def run(self) -> None: - self._closed = False - while not self._closed: - if self._receive_task is None: - await self._establish_connection() - task = self._receive_task - if task is None: - break - try: - await task - except asyncio.CancelledError: # pragma: no cover - cancellation path - raise - - async def send_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> bool: - ws = await self._ensure_ws() - data, is_binary = _encode_for_ws_send(message, use_raw_bytes=use_raw_bytes) - if is_binary: - await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8")) - else: - await ws.send_str(data if isinstance(data, str) else data.decode("utf-8")) - return True - - def is_connected(self) -> bool: - return self._ws is not None and not self._ws.closed - - async def stop(self) -> None: - self._closed = True - if self._receive_task and not self._receive_task.done(): - self._receive_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._receive_task - if self._ws: - await self._ws.close() - self._ws = None - if self._session: - await self._session.close() - self._session = None - - async def _notify_disconnect(self, reason: str) -> None: - if self._on_disconnect is None: - return - try: - result = self._on_disconnect(self._platform, reason) - if asyncio.iscoroutine(result): - await result - except Exception: # pragma: no cover - best effort notification - logging.getLogger("mofox_bus.client").exception("Disconnect callback failed") - - async def _reconnect(self) -> None: - self._logger.info(f"WebSocket 连接断开, 正在 {self._reconnect_interval:.1f} 秒后重试") - await asyncio.sleep(self._reconnect_interval) - await self._connect_once() - - async def _ensure_session(self) -> aiohttp.ClientSession: - if self._session is None: - self._session = aiohttp.ClientSession() - return self._session - - async def _ensure_ws(self) -> aiohttp.ClientWebSocketResponse: - if self._ws is None or self._ws.closed: - await self._connect_once() - assert self._ws is not None - return self._ws - - async def __aenter__(self) -> "MessageClient": - if not self._url or not self._platform: - raise RuntimeError("connect() must be called before using MessageClient as a context manager") - await self._ensure_session() - await self._ensure_ws() - return self - - async def __aexit__(self, exc_type, exc, tb) -> None: - await self.stop() - - -def _self_websocket(app: FastAPI, path: str): - """ - 装饰器工厂,兼容 FastAPI websocket 路由的声明方式。 - FastAPI 不允许直接重复注册同一路径,因此这里封装一个可复用的装饰器。 - """ - - def decorator(func): - app.add_api_websocket_route(path, func) - return func - - return decorator - - -__all__ = ["BaseMessageHandler", "MessageClient", "MessageServer"] diff --git a/src/mofox_bus/builder.py b/src/mofox_bus/builder.py deleted file mode 100644 index e041a9271..000000000 --- a/src/mofox_bus/builder.py +++ /dev/null @@ -1,111 +0,0 @@ -from __future__ import annotations - -import time -import uuid -from typing import Any, Dict, List - -from .types import GroupInfoPayload, MessageEnvelope, MessageInfoPayload, SegPayload, UserInfoPayload - - -class MessageBuilder: - """ - 流式构建 MessageEnvelope 的助手工具,提供类型安全的构建方法。 - - 使用示例: - msg = ( - MessageBuilder() - .text("Hello") - .image("http://example.com/1.png") - .to_user("123", platform="qq") - .build() - ) - """ - - def __init__(self) -> None: - self._direction: str = "outgoing" - self._message_info: MessageInfoPayload = {} - self._segments: List[SegPayload] = [] - self._metadata: Dict[str, Any] | None = None - self._timestamp_ms: int | None = None - self._message_id: str | None = None - - def direction(self, value: str) -> "MessageBuilder": - self._direction = value - return self - - def message_id(self, value: str) -> "MessageBuilder": - self._message_id = value - return self - - def timestamp_ms(self, value: int | None = None) -> "MessageBuilder": - self._timestamp_ms = value or int(time.time() * 1000) - return self - - def metadata(self, value: Dict[str, Any]) -> "MessageBuilder": - self._metadata = value - return self - - def platform(self, value: str) -> "MessageBuilder": - self._message_info["platform"] = value - return self - - def from_user(self, user_id: str, *, platform: str | None = None, nickname: str | None = None) -> "MessageBuilder": - if platform: - self.platform(platform) - user_info: UserInfoPayload = {"user_id": user_id} - if nickname: - user_info["user_nickname"] = nickname - self._message_info["user_info"] = user_info - return self - - def from_group(self, group_id: str, *, platform: str | None = None, name: str | None = None) -> "MessageBuilder": - if platform: - self.platform(platform) - group_info: GroupInfoPayload = {"group_id": group_id} - if name: - group_info["group_name"] = name - self._message_info["group_info"] = group_info - return self - - def seg(self, type_: str, data: Any) -> "MessageBuilder": - self._segments.append({"type": type_, "data": data}) - return self - - def text(self, content: str) -> "MessageBuilder": - return self.seg("text", content) - - def image(self, url: str) -> "MessageBuilder": - return self.seg("image", url) - - def reply(self, target_message_id: str) -> "MessageBuilder": - return self.seg("reply", target_message_id) - - def raw_segment(self, segment: SegPayload) -> "MessageBuilder": - self._segments.append(segment) - return self - - def build(self) -> MessageEnvelope: - """构建最终的消息信封""" - # 设置 message_info 默认值 - if not self._segments: - raise ValueError("需要至少添加一个消息段才能构建消息") - if self._message_id is None: - self._message_id = str(uuid.uuid4()) - info = dict(self._message_info) - info.setdefault("message_id", self._message_id) - info.setdefault("time", time.time()) - - segments = [seg.copy() if isinstance(seg, dict) else seg for seg in self._segments] - envelope: MessageEnvelope = { - "direction": self._direction, # type: ignore[assignment] - "message_info": info, - "message_segment": segments[0] if len(segments) == 1 else list(segments), - } - if self._metadata is not None: - envelope["metadata"] = self._metadata - if self._timestamp_ms is not None: - envelope["timestamp_ms"] = self._timestamp_ms - return envelope - - -__all__ = ["MessageBuilder"] diff --git a/src/mofox_bus/codec.py b/src/mofox_bus/codec.py deleted file mode 100644 index 29e8d0ebe..000000000 --- a/src/mofox_bus/codec.py +++ /dev/null @@ -1,94 +0,0 @@ -from __future__ import annotations - -import json as _stdlib_json -from typing import Any, Dict, Iterable, List - -try: - import orjson as _json_impl -except Exception: # pragma: no cover - fallback when orjson is unavailable - _json_impl = None - -from .types import MessageEnvelope - -DEFAULT_SCHEMA_VERSION = 1 - - -def _dumps(obj: Any) -> bytes: - if _json_impl is not None: - return _json_impl.dumps(obj) - return _stdlib_json.dumps(obj, ensure_ascii=False, separators=(",", ":")).encode("utf-8") - - -def _loads(data: bytes) -> Dict[str, Any]: - if _json_impl is not None: - return _json_impl.loads(data) - return _stdlib_json.loads(data.decode("utf-8")) - - -def dumps_message(msg: MessageEnvelope) -> bytes: - """ - 将单条消息序列化为 JSON bytes。 - """ - sanitized = _strip_raw_bytes(msg) - if "schema_version" not in sanitized: - sanitized["schema_version"] = DEFAULT_SCHEMA_VERSION - return _dumps(sanitized) - -def dumps_messages(messages: Iterable[MessageEnvelope]) -> bytes: - """ - 将批量消息序列化为 JSON bytes。 - """ - payload = { - "schema_version": DEFAULT_SCHEMA_VERSION, - "items": [_strip_raw_bytes(msg) for msg in messages], - } - return _dumps(payload) - -def loads_message(data: bytes | str) -> MessageEnvelope: - """ - 反序列化单条消息。 - """ - if isinstance(data, str): - data = data.encode("utf-8") - obj = _loads(data) - return _upgrade_schema_if_needed(obj) - - -def loads_messages(data: bytes | str) -> List[MessageEnvelope]: - """ - 反序列化批量消息。 - """ - if isinstance(data, str): - data = data.encode("utf-8") - obj = _loads(data) - version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION) - if version != DEFAULT_SCHEMA_VERSION: - raise ValueError(f"不支持的 schema_version={version}") - return [_upgrade_schema_if_needed(item) for item in obj.get("items", [])] - - -def _upgrade_schema_if_needed(obj: Dict[str, Any]) -> MessageEnvelope: - """ - 针对未来的 schema 版本演进预留兼容入口。 - """ - version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION) - if version == DEFAULT_SCHEMA_VERSION: - return obj # type: ignore[return-value] - raise ValueError(f"不支持的 schema_version={version}") - - - -def _strip_raw_bytes(msg: MessageEnvelope) -> MessageEnvelope: - if isinstance(msg, dict) and "raw_bytes" in msg: - new_msg = dict(msg) - new_msg.pop("raw_bytes", None) - return new_msg # type: ignore[return-value] - return msg - -__all__ = [ - "DEFAULT_SCHEMA_VERSION", - "dumps_message", - "dumps_messages", - "loads_message", - "loads_messages", -] diff --git a/src/mofox_bus/router.py b/src/mofox_bus/router.py deleted file mode 100644 index 8783502d5..000000000 --- a/src/mofox_bus/router.py +++ /dev/null @@ -1,265 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import logging -from dataclasses import asdict, dataclass -from typing import Callable, Dict, Optional - -from .api import MessageClient -from .types import MessageEnvelope - -logger = logging.getLogger("mofox_bus.router") - - -@dataclass -class TargetConfig: - """路由目标配置,包含连接信息和认证配置""" - url: str - token: str | None = None - ssl_verify: str | None = None - - def to_dict(self) -> Dict[str, str | None]: - """转换为字典格式""" - return asdict(self) - - @classmethod - def from_dict(cls, data: Dict[str, str | None]) -> "TargetConfig": - """从字典创建配置对象""" - return cls( - url=data.get("url", ""), - token=data.get("token"), - ssl_verify=data.get("ssl_verify"), - ) - - -@dataclass -class RouteConfig: - """路由配置,包含多个平台的路由目标""" - route_config: Dict[str, TargetConfig] - - def to_dict(self) -> Dict[str, Dict[str, str | None]]: - """转换为字典格式""" - return {"route_config": {k: v.to_dict() for k, v in self.route_config.items()}} - - @classmethod - def from_dict(cls, data: Dict[str, Dict[str, str | None]]) -> "RouteConfig": - """从字典创建路由配置对象""" - cfg = { - platform: TargetConfig.from_dict(target) - for platform, target in data.get("route_config", {}).items() - } - return cls(route_config=cfg) - - -class Router: - """消息路由器,负责管理多个平台的消息客户端连接""" - - def __init__(self, config: RouteConfig, custom_logger: logging.Logger | None = None) -> None: - """ - 初始化路由器 - - Args: - config: 路由配置 - custom_logger: 自定义日志记录器 - """ - if custom_logger: - logger.handlers = custom_logger.handlers - self.config = config - self.clients: Dict[str, MessageClient] = {} - self.handlers: list[Callable[[Dict], None]] = [] - self._running = False - self._client_tasks: Dict[str, asyncio.Task] = {} - self._stop_event: asyncio.Event | None = None - - async def connect(self, platform: str) -> None: - """ - 连接到指定平台 - - Args: - platform: 平台标识 - - Raises: - ValueError: 未知平台 - NotImplementedError: 不支持的模式 - """ - if platform not in self.config.route_config: - raise ValueError(f"未知平台: {platform}") - target = self.config.route_config[platform] - mode = "tcp" if target.url.startswith(("tcp://", "tcps://")) else "ws" - if mode != "ws": - raise NotImplementedError("TCP 模式暂未实现") - client = MessageClient(mode="ws") - client.set_disconnect_callback(self._handle_client_disconnect) - await client.connect( - url=target.url, - platform=platform, - token=target.token, - ssl_verify=target.ssl_verify, - ) - for handler in self.handlers: - client.register_message_handler(handler) - self.clients[platform] = client - if self._running: - self._start_client_task(platform, client) - - def register_class_handler(self, handler: Callable[[Dict], None]) -> None: - self.handlers.append(handler) - for client in self.clients.values(): - client.register_message_handler(handler) - - async def run(self) -> None: - """启动路由器,连接所有配置的平台并开始运行""" - self._running = True - self._stop_event = asyncio.Event() - for platform in self.config.route_config: - if platform not in self.clients: - await self.connect(platform) - for platform, client in self.clients.items(): - if platform not in self._client_tasks: - self._start_client_task(platform, client) - try: - await self._stop_event.wait() - except asyncio.CancelledError: # pragma: no cover - raise - - async def remove_platform(self, platform: str) -> None: - """ - 移除指定平台的连接 - - Args: - platform: 平台标识 - """ - if platform in self._client_tasks: - task = self._client_tasks.pop(platform) - task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await task - client = self.clients.pop(platform, None) - if client: - await client.stop() - - async def _handle_client_disconnect(self, platform: str, reason: str) -> None: - """ - 处理客户端断开连接 - - Args: - platform: 平台标识 - reason: 断开原因 - """ - logger.info(f"平台 {platform} 的客户端断开连接: {reason} (客户端将自动重连)") - task = self._client_tasks.get(platform) - if task is not None and not task.done(): - return - client = self.clients.get(platform) - if client and self._running: - self._start_client_task(platform, client) - - async def stop(self) -> None: - """停止路由器,关闭所有连接""" - self._running = False - if self._stop_event: - self._stop_event.set() - for platform in list(self.clients.keys()): - await self.remove_platform(platform) - self.clients.clear() - - def _start_client_task(self, platform: str, client: MessageClient) -> None: - """ - 启动客户端任务 - - Args: - platform: 平台标识 - client: 消息客户端 - """ - task = asyncio.create_task(client.run()) - task.add_done_callback(lambda t, plat=platform: asyncio.create_task(self._restart_if_needed(plat, t))) - self._client_tasks[platform] = task - - async def _restart_if_needed(self, platform: str, task: asyncio.Task) -> None: - """ - 必要时重启客户端任务 - - Args: - platform: 平台标识 - task: 已完成的任务 - """ - if not self._running: - return - if task.cancelled(): - return - exc = task.exception() - if exc: - logger.warning(f"平台 {platform} 的客户端任务异常结束: {exc}") - client = self.clients.get(platform) - if client: - self._start_client_task(platform, client) - - def get_target_url(self, message: MessageEnvelope) -> Optional[str]: - """ - 根据消息获取目标 URL - - Args: - message: 消息信封 - - Returns: - 目标 URL 或 None - """ - platform = message.get("message_info", {}).get("platform") - if not platform: - return None - target = self.config.route_config.get(platform) - return target.url if target else None - - async def send_message(self, message: MessageEnvelope): - """ - 发送消息到指定平台 - - Args: - message: 消息信封 - - Raises: - ValueError: 缺少平台信息 - RuntimeError: 未找到对应平台的客户端 - """ - platform = message.get("message_info", {}).get("platform") - if not platform: - raise ValueError("消息中缺少必需的 message_info.platform 字段") - client = self.clients.get(platform) - if client is None: - raise RuntimeError(f"平台 {platform} 没有已连接的客户端") - return await client.send_message(message) - - async def update_config(self, config_data: Dict[str, Dict[str, str | None]]) -> None: - """ - 更新路由配置 - - Args: - config_data: 新的配置数据 - """ - new_config = RouteConfig.from_dict(config_data) - await self._adjust_connections(new_config) - self.config = new_config - - async def _adjust_connections(self, new_config: RouteConfig) -> None: - """ - 调整连接以匹配新配置 - - Args: - new_config: 新的路由配置 - """ - current = set(self.config.route_config.keys()) - updated = set(new_config.route_config.keys()) - # 移除不再存在的平台 - for platform in current - updated: - await self.remove_platform(platform) - # 添加或更新平台 - for platform in updated: - if platform not in current: - await self.connect(platform) - else: - old = self.config.route_config[platform] - new = new_config.route_config[platform] - if old.url != new.url or old.token != new.token: - await self.remove_platform(platform) - await self.connect(platform) diff --git a/src/mofox_bus/runtime.py b/src/mofox_bus/runtime.py deleted file mode 100644 index 225f6607a..000000000 --- a/src/mofox_bus/runtime.py +++ /dev/null @@ -1,470 +0,0 @@ -from __future__ import annotations - -import asyncio -import functools -import inspect -import threading -import weakref -from dataclasses import dataclass -from typing import Awaitable, Callable, Dict, Iterable, List, Protocol - -from .types import MessageEnvelope - -Hook = Callable[[MessageEnvelope], Awaitable[None] | None] -ErrorHook = Callable[[MessageEnvelope, BaseException], Awaitable[None] | None] -Predicate = Callable[[MessageEnvelope], bool | Awaitable[bool]] -MessageHandler = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None] | MessageEnvelope | None] -BatchHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None] | List[MessageEnvelope] | None] -MiddlewareCallable = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None]] - - -class Middleware(Protocol): - async def __call__(self, message: MessageEnvelope, handler: MiddlewareCallable) -> MessageEnvelope | None: ... - - -class MessageProcessingError(RuntimeError): - """封装处理链路中发生的异常。""" - - def __init__(self, message: MessageEnvelope, original: BaseException): - detail = message.get("id", "") - super().__init__(f"处理消息 {detail} 时出错: {original}") - self.message_envelope = message - self.original = original - - -@dataclass -class MessageRoute: - """消息路由配置,包含匹配条件和处理函数""" - predicate: Predicate - handler: MessageHandler - name: str | None = None - message_type: str | None = None - message_types: set[str] | None = None # 支持多个消息类型 - event_types: set[str] | None = None - - -class MessageRuntime: - """ - 消息运行时环境,负责调度消息路由、执行前后处理钩子以及批量处理消息 - """ - - def __init__(self) -> None: - self._routes: list[MessageRoute] = [] - self._before_hooks: list[Hook] = [] - self._after_hooks: list[Hook] = [] - self._error_hooks: list[ErrorHook] = [] - self._batch_handler: BatchHandler | None = None - self._lock = threading.RLock() - self._middlewares: list[Middleware] = [] - self._type_routes: Dict[str, list[MessageRoute]] = {} - self._event_routes: Dict[str, list[MessageRoute]] = {} - # 用于检测同一类型的重复注册 - self._explicit_type_handlers: Dict[str, str] = {} # message_type -> handler_name - - def add_route( - self, - predicate: Predicate, - handler: MessageHandler, - name: str | None = None, - *, - message_type: str | list[str] | None = None, - event_types: Iterable[str] | None = None, - ) -> None: - """ - 添加消息路由 - - Args: - predicate: 路由匹配条件 - handler: 消息处理函数 - name: 路由名称(可选) - message_type: 消息类型,可以是字符串或字符串列表(可选) - event_types: 事件类型列表(可选) - """ - with self._lock: - # 处理 message_type 参数,支持字符串或列表 - message_types_set: set[str] | None = None - single_message_type: str | None = None - - if message_type is not None: - if isinstance(message_type, str): - message_types_set = {message_type} - single_message_type = message_type - elif isinstance(message_type, list): - message_types_set = set(message_type) - if len(message_types_set) == 1: - single_message_type = next(iter(message_types_set)) - else: - raise TypeError(f"message_type must be str or list[str], got {type(message_type)}") - - # 检测重复注册:如果明确指定了某个类型,不允许重复 - handler_name = name or getattr(handler, "__name__", str(handler)) - for msg_type in message_types_set: - if msg_type in self._explicit_type_handlers: - existing_handler = self._explicit_type_handlers[msg_type] - raise ValueError( - f"消息类型 '{msg_type}' 已被处理器 '{existing_handler}' 明确注册," - f"不能再由 '{handler_name}' 注册。同一消息类型只能有一个明确的处理器。" - ) - self._explicit_type_handlers[msg_type] = handler_name - - route = MessageRoute( - predicate=predicate, - handler=handler, - name=name, - message_type=single_message_type, - message_types=message_types_set, - event_types=set(event_types) if event_types is not None else None, - ) - self._routes.append(route) - - # 为每个消息类型建立索引 - if message_types_set: - for msg_type in message_types_set: - self._type_routes.setdefault(msg_type, []).append(route) - - if route.event_types: - for et in route.event_types: - self._event_routes.setdefault(et, []).append(route) - - def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]: - """装饰器写法,便于在核心逻辑中声明式注册。 - - 支持普通函数和类方法。对于类方法,会在实例创建时自动绑定并注册路由。 - """ - - def decorator(func: MessageHandler) -> MessageHandler: - # Support decorating instance methods: defer binding until the object is created. - if _looks_like_method(func): - return _InstanceMethodRoute( - runtime=self, - func=func, - predicate=predicate, - name=name, - message_type=None, - ) - - self.add_route(predicate, func, name=name) - return func - - return decorator - - def on_message( - self, - func: MessageHandler | None = None, - *, - message_type: str | list[str] | None = None, - platform: str | None = None, - predicate: Predicate | None = None, - name: str | None = None, - ) -> Callable[[MessageHandler], MessageHandler] | MessageHandler: - """Sugar decorator with optional Seg.type/platform predicate matching. - - Args: - func: 被装饰的函数 - message_type: 消息类型,可以是单个字符串或字符串列表 - platform: 平台名称 - predicate: 自定义匹配条件 - name: 路由名称 - - Usages: - - @runtime.on_message(...) - - @runtime.on_message - - @runtime.on_message(message_type="text") - - @runtime.on_message(message_type=["text", "image"]) - - If the target looks like an instance method (first arg is self), it will be - auto-bound to the instance and registered when the object is constructed. - """ - # 将 message_type 转换为集合以便统一处理 - message_types_set: set[str] | None = None - if message_type is not None: - if isinstance(message_type, str): - message_types_set = {message_type} - elif isinstance(message_type, list): - message_types_set = set(message_type) - else: - raise TypeError(f"message_type must be str or list[str], got {type(message_type)}") - - async def combined_predicate(message: MessageEnvelope) -> bool: - if message_types_set is not None: - extracted_type = _extract_segment_type(message) - if extracted_type not in message_types_set: - return False - if platform is not None: - info_platform = message.get("message_info", {}).get("platform") - if message.get("platform") not in (None, platform) and info_platform is None: - return False - if info_platform not in (None, platform): - return False - if predicate is None: - return True - return await _invoke_callable(predicate, message, prefer_thread=False) - - def decorator(func: MessageHandler) -> MessageHandler: - # Support decorating instance methods: defer binding until the object is created. - if _looks_like_method(func): - return _InstanceMethodRoute( - runtime=self, - func=func, - predicate=combined_predicate, - name=name, - message_type=message_type, - ) - - self.add_route(combined_predicate, func, name=name, message_type=message_type) - return func - - if func is not None: - return decorator(func) - return decorator - - - - def set_batch_handler(self, handler: BatchHandler) -> None: - self._batch_handler = handler - - def register_before_hook(self, hook: Hook) -> None: - self._before_hooks.append(hook) - - def register_after_hook(self, hook: Hook) -> None: - self._after_hooks.append(hook) - - def register_error_hook(self, hook: ErrorHook) -> None: - self._error_hooks.append(hook) - - def register_middleware(self, middleware: Middleware) -> None: - """注册洋葱模型中间件,围绕处理器执行。""" - - self._middlewares.append(middleware) - - async def handle_message(self, message: MessageEnvelope) -> MessageEnvelope | None: - await self._run_hooks(self._before_hooks, message) - try: - route = await self._match_route(message) - if route is None: - return None - handler = self._wrap_with_middlewares(route.handler) - result = await handler(message) - except Exception as exc: - await self._run_error_hooks(message, exc) - raise MessageProcessingError(message, exc) from exc - await self._run_hooks(self._after_hooks, message) - return result - - async def handle_batch(self, messages: Iterable[MessageEnvelope]) -> List[MessageEnvelope]: - batch = list(messages) - if not batch: - return [] - if self._batch_handler is not None: - result = await _invoke_callable(self._batch_handler, batch, prefer_thread=True) - return result or [] - responses: list[MessageEnvelope] = [] - for message in batch: - response = await self.handle_message(message) - if response is not None: - responses.append(response) - return responses - - async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None: - """匹配消息路由,优先匹配明确指定了消息类型的处理器""" - message_type = _extract_segment_type(message) - event_type = ( - message.get("event_type") - or message.get("message_info", {}) - .get("additional_config", {}) - .get("event_type") - ) - - # 分为两层候选:优先级和普通 - priority_candidates: list[MessageRoute] = [] # 明确指定了消息类型的 - normal_candidates: list[MessageRoute] = [] # 没有指定或通配的 - - with self._lock: - # 事件路由(优先级最高) - if event_type and event_type in self._event_routes: - priority_candidates.extend(self._event_routes[event_type]) - - # 消息类型路由(明确指定的有优先级) - if message_type and message_type in self._type_routes: - priority_candidates.extend(self._type_routes[message_type]) - - # 通用路由(没有明确指定类型的) - for route in self._routes: - # 如果路由没有指定 message_types,则是通用路由 - if route.message_types is None and route.event_types is None: - normal_candidates.append(route) - - # 先尝试优先级候选 - seen: set[int] = set() - for route in priority_candidates: - rid = id(route) - if rid in seen: - continue - seen.add(rid) - should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False) - if should_handle: - return route - - # 如果没有匹配到优先级候选,再尝试普通候选 - for route in normal_candidates: - rid = id(route) - if rid in seen: - continue - seen.add(rid) - should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False) - if should_handle: - return route - - return None - - async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None: - coro_list = [self._call_hook(hook, message) for hook in hooks] - if coro_list: - await asyncio.gather(*coro_list) - - async def _call_hook(self, hook: Hook, message: MessageEnvelope) -> None: - await _invoke_callable(hook, message, prefer_thread=True) - - async def _run_error_hooks(self, message: MessageEnvelope, exc: BaseException) -> None: - coros = [self._call_error_hook(hook, message, exc) for hook in self._error_hooks] - if coros: - await asyncio.gather(*coros) - - async def _call_error_hook(self, hook: ErrorHook, message: MessageEnvelope, exc: BaseException) -> None: - await _invoke_callable(hook, message, exc, prefer_thread=True) - - def _wrap_with_middlewares(self, handler: MessageHandler) -> MiddlewareCallable: - async def base_handler(message: MessageEnvelope) -> MessageEnvelope | None: - return await _invoke_callable(handler, message, prefer_thread=True) - - wrapped: MiddlewareCallable = base_handler - for middleware in reversed(self._middlewares): - current = wrapped - - async def wrapper(msg: MessageEnvelope, mw=middleware, nxt=current) -> MessageEnvelope | None: - return await _invoke_callable(mw, msg, nxt, prefer_thread=False) - - wrapped = wrapper - return wrapped - - -async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False): - """支持 sync/async 调用,并可选择在线程中执行。 - - 自动处理普通函数、类方法和绑定方法。 - """ - # 如果是绑定方法(bound method),直接使用,不需要额外处理 - # 因为绑定方法已经包含了 self 参数 - if inspect.ismethod(func): - # 绑定方法可以直接调用,args 中不应包含 self - if inspect.iscoroutinefunction(func): - return await func(*args) - if prefer_thread: - result = await asyncio.to_thread(func, *args) - if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future): - return await result - return result - result = func(*args) - if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future): - return await result - return result - - # 对于普通函数(未绑定的),按原有逻辑处理 - if inspect.iscoroutinefunction(func): - return await func(*args) - if prefer_thread: - result = await asyncio.to_thread(func, *args) - if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future): - return await result - return result - result = func(*args) - if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future): - return await result - return result - - -def _extract_segment_type(message: MessageEnvelope) -> str | None: - seg = message.get("message_segment") or message.get("message_chain") - if isinstance(seg, dict): - return seg.get("type") - if isinstance(seg, list) and seg: - first = seg[0] - if isinstance(first, dict): - return first.get("type") - return None - - -def _looks_like_method(func: Callable[..., object]) -> bool: - """Return True if callable signature suggests an instance method (first arg named self).""" - if inspect.ismethod(func): - return True - if not inspect.isfunction(func): - return False - params = inspect.signature(func).parameters - if not params: - return False - first = next(iter(params.values())) - return first.name == "self" - - -class _InstanceMethodRoute: - """Descriptor that binds decorated instance methods and registers routes per-instance.""" - - def __init__( - self, - runtime: MessageRuntime, - func: MessageHandler, - predicate: Predicate, - name: str | None, - message_type: str | None, - ) -> None: - self._runtime = runtime - self._func = func - self._predicate = predicate - self._name = name - self._message_type = message_type - self._owner: type | None = None - self._registered_instances: weakref.WeakSet[object] = weakref.WeakSet() - - def __set_name__(self, owner: type, name: str) -> None: - self._owner = owner - registry: list[_InstanceMethodRoute] | None = getattr(owner, "_mofox_instance_routes", None) - if registry is None: - registry = [] - setattr(owner, "_mofox_instance_routes", registry) - original_init = owner.__init__ - - @functools.wraps(original_init) - def wrapped_init(inst, *args, **kwargs): - original_init(inst, *args, **kwargs) - for descriptor in getattr(inst.__class__, "_mofox_instance_routes", []): - descriptor._register_instance(inst) - - owner.__init__ = wrapped_init # type: ignore[assignment] - registry.append(self) - - def _register_instance(self, instance: object) -> None: - if instance in self._registered_instances: - return - owner = self._owner or instance.__class__ - bound = self._func.__get__(instance, owner) # type: ignore[arg-type] - self._runtime.add_route(self._predicate, bound, name=self._name, message_type=self._message_type) - self._registered_instances.add(instance) - - def __get__(self, instance: object | None, owner: type | None = None): - if instance is None: - return self._func - self._register_instance(instance) - return self._func.__get__(instance, owner) # type: ignore[arg-type] - - -__all__ = [ - "BatchHandler", - "Hook", - "MessageHandler", - "MessageProcessingError", - "MessageRoute", - "MessageRuntime", - "Middleware", - "Predicate", -] diff --git a/src/mofox_bus/transport/__init__.py b/src/mofox_bus/transport/__init__.py deleted file mode 100644 index 5915116d5..000000000 --- a/src/mofox_bus/transport/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -传输层封装,提供 HTTP / WebSocket server & client。 -""" - -from .http_client import HttpMessageClient -from .http_server import HttpMessageServer -from .ws_client import WsMessageClient -from .ws_server import WsMessageServer - -__all__ = ["HttpMessageClient", "HttpMessageServer", "WsMessageClient", "WsMessageServer"] diff --git a/src/mofox_bus/transport/http_client.py b/src/mofox_bus/transport/http_client.py deleted file mode 100644 index f88c89091..000000000 --- a/src/mofox_bus/transport/http_client.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Iterable, List, Sequence - -import aiohttp - -from ..codec import dumps_messages, loads_messages -from ..types import MessageEnvelope - - -class HttpMessageClient: - """ - 面向消息批量传输的 HTTP 客户端封装 - """ - - def __init__( - self, - base_url: str, - *, - session: aiohttp.ClientSession | None = None, - timeout: aiohttp.ClientTimeout | None = None, - ) -> None: - self._base_url = base_url.rstrip("/") - self._session = session - self._timeout = timeout - self._owns_session = session is None - self._logger = logging.getLogger("mofox_bus.http_client") - - async def send_messages( - self, - messages: Sequence[MessageEnvelope], - *, - expect_reply: bool = False, - path: str = "/messages", - ) -> List[MessageEnvelope] | None: - if not messages: - return [] - session = await self._ensure_session() - url = f"{self._base_url}{path}" - payload = dumps_messages(messages) - self._logger.debug(f"正在发送 {len(messages)} 条消息 -> {url}") - async with session.post(url, data=payload, timeout=self._timeout) as resp: - resp.raise_for_status() - if not expect_reply: - return None - raw = await resp.read() - replies = loads_messages(raw) - self._logger.debug(f"接收到 {len(replies)} 条回复消息") - return replies - - async def close(self) -> None: - if self._owns_session and self._session is not None: - await self._session.close() - self._session = None - - async def _ensure_session(self) -> aiohttp.ClientSession: - if self._session is None: - self._session = aiohttp.ClientSession() - return self._session - - async def __aenter__(self) -> "HttpMessageClient": - await self._ensure_session() - return self - - async def __aexit__(self, exc_type, exc, tb) -> None: - await self.close() diff --git a/src/mofox_bus/transport/http_server.py b/src/mofox_bus/transport/http_server.py deleted file mode 100644 index 0a47a24ef..000000000 --- a/src/mofox_bus/transport/http_server.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Awaitable, Callable, List - -from aiohttp import web - -from ..codec import dumps_messages, loads_messages -from ..types import MessageEnvelope - -MessageHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None]] - - -class HttpMessageServer: - """ - 轻量级 HTTP 消息入口,可独立运行,也可挂载到现有 FastAPI / aiohttp 应用下 - """ - - def __init__(self, handler: MessageHandler, *, path: str = "/messages") -> None: - self._handler = handler - self._app = web.Application() - self._path = path - self._app.add_routes([web.post(path, self._handle_messages)]) - self._logger = logging.getLogger("mofox_bus.http_server") - - async def _handle_messages(self, request: web.Request) -> web.Response: - try: - raw = await request.read() - envelopes = loads_messages(raw) - self._logger.debug(f"接收到 {len(envelopes)} 条消息") - except Exception as exc: # pragma: no cover - network errors are integration tested - self._logger.exception(f"解析请求失败: {exc}") - raise web.HTTPBadRequest(reason=f"无效的负载: {exc}") from exc - - result = await self._handler(envelopes) - if result is None: - return web.Response(status=200, text="ok") - payload = dumps_messages(result) - return web.Response(status=200, body=payload, content_type="application/json") - - def make_app(self) -> web.Application: - """ - 返回 aiohttp Application,可被外部 server(gunicorn/uvicorn)直接使用。 - """ - - return self._app - - def add_to_app(self, app: web.Application) -> None: - """ - 将消息路由注册到给定的 aiohttp app,方便与既有服务整合。 - """ - - app.router.add_post(self._path, self._handle_messages) diff --git a/src/mofox_bus/transport/ws_client.py b/src/mofox_bus/transport/ws_client.py deleted file mode 100644 index 65366c4e0..000000000 --- a/src/mofox_bus/transport/ws_client.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations - -import asyncio -import logging -from typing import Awaitable, Callable, Iterable, List, Sequence - -import aiohttp - -from ..codec import dumps_messages, loads_messages -from ..types import MessageEnvelope - -IncomingHandler = Callable[[MessageEnvelope], Awaitable[None]] - - -class WsMessageClient: - """ - 管理 WebSocket 连接,提供 send/receive API,并在后台读取消息 - """ - - def __init__( - self, - url: str, - *, - handler: IncomingHandler | None = None, - session: aiohttp.ClientSession | None = None, - reconnect_interval: float = 5.0, - ) -> None: - self._url = url - self._handler = handler - self._session = session - self._reconnect_interval = reconnect_interval - self._owns_session = session is None - self._ws: aiohttp.ClientWebSocketResponse | None = None - self._receive_task: asyncio.Task | None = None - self._closed = False - self._logger = logging.getLogger("mofox_bus.ws_client") - - async def connect(self) -> None: - await self._ensure_session() - await self._connect_once() - - async def _connect_once(self) -> None: - assert self._session is not None - self._ws = await self._session.ws_connect(self._url) - self._logger.info(f"已连接到 {self._url}") - self._receive_task = asyncio.create_task(self._receive_loop()) - - async def send_messages(self, messages: Sequence[MessageEnvelope]) -> None: - if not messages: - return - ws = await self._ensure_ws() - payload = dumps_messages(messages) - await ws.send_bytes(payload) - - async def send_message(self, message: MessageEnvelope) -> None: - await self.send_messages([message]) - - async def close(self) -> None: - self._closed = True - if self._receive_task: - self._receive_task.cancel() - if self._ws: - await self._ws.close() - self._ws = None - if self._owns_session and self._session: - await self._session.close() - self._session = None - - async def _receive_loop(self) -> None: - assert self._ws is not None - try: - async for msg in self._ws: - if msg.type in (aiohttp.WSMsgType.BINARY, aiohttp.WSMsgType.TEXT): - envelopes = loads_messages(msg.data) - for env in envelopes: - if self._handler is not None: - await self._handler(env) - elif msg.type == aiohttp.WSMsgType.ERROR: - self._logger.warning(f"WebSocket 错误: {msg.data}") - break - except asyncio.CancelledError: # pragma: no cover - cancellation path - return - finally: - if not self._closed: - await self._reconnect() - - async def _reconnect(self) -> None: - self._logger.info(f"WebSocket 断开, 正在 {self._reconnect_interval:.1f} 秒后重试") - await asyncio.sleep(self._reconnect_interval) - await self._connect_once() - - async def _ensure_session(self) -> aiohttp.ClientSession: - if self._session is None: - self._session = aiohttp.ClientSession() - return self._session - - async def _ensure_ws(self) -> aiohttp.ClientWebSocketResponse: - if self._ws is None or self._ws.closed: - await self._connect_once() - assert self._ws is not None - return self._ws - - async def __aenter__(self) -> "WsMessageClient": - await self.connect() - return self - - async def __aexit__(self, exc_type, exc, tb) -> None: - await self.close() diff --git a/src/mofox_bus/transport/ws_server.py b/src/mofox_bus/transport/ws_server.py deleted file mode 100644 index e82a3b758..000000000 --- a/src/mofox_bus/transport/ws_server.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -import asyncio -import logging -from contextlib import asynccontextmanager -from typing import Awaitable, Callable, Iterable, List, Set - -from aiohttp import WSMsgType, web - -from ..codec import dumps_messages, loads_messages -from ..types import MessageEnvelope - -WsMessageHandler = Callable[[MessageEnvelope], Awaitable[None]] - - -class WsMessageServer: - """ - 封装 WebSocket 服务端逻辑,负责接收消息并广播响应 - """ - - def __init__(self, handler: WsMessageHandler, *, path: str = "/ws") -> None: - self._handler = handler - self._app = web.Application() - self._path = path - self._app.add_routes([web.get(path, self._handle_ws)]) - self._connections: Set[web.WebSocketResponse] = set() - self._lock = asyncio.Lock() - self._logger = logging.getLogger("mofox_bus.ws_server") - - async def _handle_ws(self, request: web.Request) -> web.WebSocketResponse: - ws = web.WebSocketResponse() - await ws.prepare(request) - self._logger.info(f"WebSocket 连接打开: {request.remote}") - - async with self._track_connection(ws): - async for message in ws: - if message.type in (WSMsgType.BINARY, WSMsgType.TEXT): - envelopes = loads_messages(message.data) - for env in envelopes: - await self._handler(env) - elif message.type == WSMsgType.ERROR: - self._logger.warning(f"WebSocket 连接错误: {ws.exception()}") - break - - self._logger.info(f"WebSocket 连接关闭: {request.remote}") - return ws - - @asynccontextmanager - async def _track_connection(self, ws: web.WebSocketResponse): - async with self._lock: - self._connections.add(ws) - try: - yield - finally: - async with self._lock: - self._connections.discard(ws) - - async def broadcast(self, messages: Iterable[MessageEnvelope]) -> None: - payload = dumps_messages(list(messages)) - async with self._lock: - targets = list(self._connections) - for ws in targets: - await ws.send_bytes(payload) - - def make_app(self) -> web.Application: - return self._app diff --git a/src/mofox_bus/types.py b/src/mofox_bus/types.py deleted file mode 100644 index fb69e3bd5..000000000 --- a/src/mofox_bus/types.py +++ /dev/null @@ -1,93 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, List, Literal, NotRequired, TypedDict, Required - -MessageDirection = Literal["incoming", "outgoing"] - -# ---------------------------- -# maim_message 风格的 TypedDict -# ---------------------------- - - -class SegPayload(TypedDict, total=False): - """ - 对齐 maim_message.Seg 的片段定义,使用纯 dict 便于 JSON 传输。 - """ - - type: Required[str] - data: Required[str | List["SegPayload"]] - translated_data: NotRequired[str | List["SegPayload"]] - - -class UserInfoPayload(TypedDict, total=False): - platform: NotRequired[str] - user_id: Required[str] - user_nickname: NotRequired[str] - user_cardname: NotRequired[str] - user_avatar: NotRequired[str] - - -class GroupInfoPayload(TypedDict, total=False): - platform: NotRequired[str] - group_id: Required[str] - group_name: NotRequired[str] - - -class FormatInfoPayload(TypedDict, total=False): - content_format: NotRequired[List[str]] - accept_format: NotRequired[List[str]] - - -class TemplateInfoPayload(TypedDict, total=False): - template_items: NotRequired[Dict[str, str]] - template_name: NotRequired[Dict[str, str]] - template_default: NotRequired[bool] - - -class MessageInfoPayload(TypedDict, total=False): - platform: Required[str] - message_id: Required[str] - time: NotRequired[float] - group_info: NotRequired[GroupInfoPayload] - user_info: NotRequired[UserInfoPayload] - format_info: NotRequired[FormatInfoPayload] - template_info: NotRequired[TemplateInfoPayload] - additional_config: NotRequired[Dict[str, Any]] - -# ---------------------------- -# MessageEnvelope -# ---------------------------- - - -class MessageEnvelope(TypedDict, total=False): - """ - mofox-bus 传输层统一使用的消息信封。 - - - 采用 maim_message 风格:message_info + message_segment。 - """ - - direction: MessageDirection - message_info: Required[MessageInfoPayload] - message_segment: Required[SegPayload] | List[SegPayload] - raw_message: NotRequired[Any] - raw_bytes: NotRequired[bytes] - message_chain: NotRequired[List[SegPayload]] # seglist 的直观别名 - platform: NotRequired[str] # 快捷访问,等价于 message_info.platform - message_id: NotRequired[str] # 快捷访问,等价于 message_info.message_id - timestamp_ms: NotRequired[int] - correlation_id: NotRequired[str] - schema_version: NotRequired[int] - metadata: NotRequired[Dict[str, Any]] - -__all__ = [ - # maim_message style payloads - "SegPayload", - "UserInfoPayload", - "GroupInfoPayload", - "FormatInfoPayload", - "TemplateInfoPayload", - "MessageInfoPayload", - # legacy content style - "MessageDirection", - "MessageEnvelope", -] diff --git a/src/plugin_system/base/base_adapter.py b/src/plugin_system/base/base_adapter.py index a93ad93dd..d2566a4f2 100644 --- a/src/plugin_system/base/base_adapter.py +++ b/src/plugin_system/base/base_adapter.py @@ -62,8 +62,8 @@ class BaseAdapter(MoFoxAdapterBase, ABC): self._config: Dict[str, Any] = {} self._health_check_task: Optional[asyncio.Task] = None self._running = False - # 标记是否在子进程中运行(由核心管理器传入 ProcessCoreSink 时自动生效) - self._is_subprocess = isinstance(core_sink, ProcessCoreSink) + # 标记是否在子进程中运行 + self._is_subprocess = False @classmethod def from_process_queues( diff --git a/src/plugin_system/core/adapter_manager.py b/src/plugin_system/core/adapter_manager.py index f64cff5ca..ae0ebec0d 100644 --- a/src/plugin_system/core/adapter_manager.py +++ b/src/plugin_system/core/adapter_manager.py @@ -2,6 +2,11 @@ Adapter 管理器 负责管理所有注册的适配器,支持子进程自动启动和生命周期管理。 + +重构说明(2025-11): +- 使用 CoreSinkManager 统一管理 InProcessCoreSink 和 ProcessCoreSink +- 根据适配器的 run_in_subprocess 属性自动选择 CoreSink 类型 +- 子进程适配器通过 CoreSinkManager 的通信队列与主进程交互 """ from __future__ import annotations @@ -15,7 +20,6 @@ if TYPE_CHECKING: from src.plugin_system.base.base_adapter import BaseAdapter from mofox_bus import ProcessCoreSinkServer -from src.common.core_sink import get_core_sink from src.common.logger import get_logger logger = get_logger("adapter_manager") @@ -34,6 +38,11 @@ def _adapter_process_entry( incoming_queue: mp.Queue, outgoing_queue: mp.Queue, ): + """ + 子进程适配器入口函数 + + 在子进程中运行,创建 ProcessCoreSink 与主进程通信 + """ import asyncio import contextlib from mofox_bus import ProcessCoreSink @@ -44,9 +53,14 @@ def _adapter_process_entry( if plugin_info: plugin_cls = _load_class(plugin_info["module"], plugin_info["class"]) plugin_instance = plugin_cls(plugin_info["plugin_dir"], plugin_info["metadata"]) + + # 创建 ProcessCoreSink 用于与主进程通信 core_sink = ProcessCoreSink(to_core_queue=incoming_queue, from_core_queue=outgoing_queue) + + # 创建并启动适配器 adapter = adapter_cls(core_sink, plugin=plugin_instance) await adapter.start() + try: while not getattr(core_sink, "_closed", False): await asyncio.sleep(0.2) @@ -61,21 +75,23 @@ def _adapter_process_entry( class AdapterProcess: - """适配器子进程封装:管理子进程的生命周期与通信桥接""" + """ + 适配器子进程封装:管理子进程的生命周期与通信桥接 + + 使用 CoreSinkManager 创建通信队列,自动维护与子进程的消息通道 + """ - def __init__(self, adapter_cls: "type[BaseAdapter]", plugin, core_sink) -> None: + def __init__(self, adapter_cls: "type[BaseAdapter]", plugin) -> None: self.adapter_cls = adapter_cls self.adapter_name = adapter_cls.adapter_name self.plugin = plugin self.process: mp.Process | None = None self._ctx = mp.get_context("spawn") - self._incoming_queue: mp.Queue = self._ctx.Queue() - self._outgoing_queue: mp.Queue = self._ctx.Queue() + self._incoming_queue: mp.Queue | None = None + self._outgoing_queue: mp.Queue | None = None self._bridge: ProcessCoreSinkServer | None = None - self._core_sink = core_sink self._adapter_path: tuple[str, str] = (adapter_cls.__module__, adapter_cls.__name__) self._plugin_info = self._extract_plugin_info(plugin) - self._outgoing_handler = None @staticmethod def _extract_plugin_info(plugin) -> dict | None: @@ -88,37 +104,28 @@ class AdapterProcess: "metadata": getattr(plugin, "plugin_meta", None), } - def _make_outgoing_handler(self): - async def _handler(envelope): - if self._bridge: - await self._bridge.push_outgoing(envelope) - return _handler - async def start(self) -> bool: """启动适配器子进程""" try: logger.info(f"启动适配器子进程: {self.adapter_name}") - self._bridge = ProcessCoreSinkServer( - incoming_queue=self._incoming_queue, - outgoing_queue=self._outgoing_queue, - core_handler=self._core_sink.send, - name=self.adapter_name, - ) - self._bridge.start() - if hasattr(self._core_sink, "set_outgoing_handler"): - self._outgoing_handler = self._make_outgoing_handler() - try: - self._core_sink.set_outgoing_handler(self._outgoing_handler) - except Exception: - logger.exception("Failed to register outgoing bridge for %s", self.adapter_name) + + # 从 CoreSinkManager 获取通信队列 + from src.common.core_sink_manager import get_core_sink_manager + + manager = get_core_sink_manager() + self._incoming_queue, self._outgoing_queue = manager.create_process_sink_queues(self.adapter_name) + + # 启动子进程 self.process = self._ctx.Process( target=_adapter_process_entry, args=(self._adapter_path, self._plugin_info, self._incoming_queue, self._outgoing_queue), name=f"{self.adapter_name}-proc", ) self.process.start() + logger.info(f"启动适配器子进程 {self.adapter_name} (PID: {self.process.pid})") return True + except Exception as e: logger.error(f"启动适配器子进程 {self.adapter_name} 失败: {e}", exc_info=True) return False @@ -127,26 +134,31 @@ class AdapterProcess: """停止适配器子进程""" if not self.process: return + logger.info(f"停止适配器子进程: {self.adapter_name} (PID: {self.process.pid})") + try: - remover = getattr(self._core_sink, "remove_outgoing_handler", None) - if callable(remover) and self._outgoing_handler: - try: - remover(self._outgoing_handler) - except Exception: - logger.exception(f"移除 {self.adapter_name} 的 outgoing bridge 失败") - if self._bridge: - await self._bridge.close() + # 从 CoreSinkManager 移除通信队列 + from src.common.core_sink_manager import get_core_sink_manager + + manager = get_core_sink_manager() + manager.remove_process_sink(self.adapter_name) + + # 等待子进程结束 if self.process.is_alive(): self.process.join(timeout=5.0) + if self.process.is_alive(): logger.warning(f"适配器 {self.adapter_name} 未能及时停止,强制终止中") self.process.terminate() self.process.join() + except Exception as e: logger.error(f"停止适配器子进程 {self.adapter_name} 时发生错误: {e}", exc_info=True) finally: self.process = None + self._incoming_queue = None + self._outgoing_queue = None def is_running(self) -> bool: """适配器是否正在运行""" @@ -155,7 +167,13 @@ class AdapterProcess: return self.process.is_alive() class AdapterManager: - """适配器管理器""" + """ + 适配器管理器 + + 负责管理所有注册的适配器,根据 run_in_subprocess 属性自动选择: + - run_in_subprocess=True: 在子进程中运行,使用 ProcessCoreSink + - run_in_subprocess=False: 在主进程中运行,使用 InProcessCoreSink + """ def __init__(self): # 注册信息:name -> (adapter class, plugin instance | None) @@ -178,14 +196,26 @@ class AdapterManager: self._adapter_defs[adapter_name] = (adapter_cls, plugin) adapter_version = getattr(adapter_cls, 'adapter_version', 'unknown') - logger.info(f"注册适配器: {adapter_name} v{adapter_version}") + run_in_subprocess = getattr(adapter_cls, 'run_in_subprocess', False) + + logger.info( + f"注册适配器: {adapter_name} v{adapter_version} " + f"(子进程: {'是' if run_in_subprocess else '否'})" + ) async def start_adapter(self, adapter_name: str) -> bool: - """启动指定适配器""" + """ + 启动指定适配器 + + 根据适配器的 run_in_subprocess 属性自动选择: + - True: 创建子进程,使用 ProcessCoreSink + - False: 在当前进程,使用 InProcessCoreSink + """ definition = self._adapter_defs.get(adapter_name) if not definition: logger.error(f"适配器 {adapter_name} 未注册") return False + adapter_cls, plugin = definition run_in_subprocess = getattr(adapter_cls, "run_in_subprocess", False) @@ -193,15 +223,14 @@ class AdapterManager: return await self._start_adapter_subprocess(adapter_name, adapter_cls, plugin) return await self._start_adapter_in_process(adapter_name, adapter_cls, plugin) - async def _start_adapter_subprocess(self, adapter_name: str, adapter_cls: type[BaseAdapter], plugin) -> bool: - """在子进程中启动适配器""" - try: - core_sink = get_core_sink() - except Exception as e: - logger.error(f"无法获取 core_sink,启动子进程 {adapter_name} 失败: {e}", exc_info=True) - return False - - adapter_process = AdapterProcess(adapter_cls, plugin, core_sink) + async def _start_adapter_subprocess( + self, + adapter_name: str, + adapter_cls: type[BaseAdapter], + plugin + ) -> bool: + """在子进程中启动适配器(使用 ProcessCoreSink)""" + adapter_process = AdapterProcess(adapter_cls, plugin) success = await adapter_process.start() if success: @@ -209,15 +238,25 @@ class AdapterManager: return success - async def _start_adapter_in_process(self, adapter_name: str, adapter_cls: type[BaseAdapter], plugin) -> bool: - """在当前进程中启动适配器""" + async def _start_adapter_in_process( + self, + adapter_name: str, + adapter_cls: type[BaseAdapter], + plugin + ) -> bool: + """在当前进程中启动适配器(使用 InProcessCoreSink)""" try: - core_sink = get_core_sink() + # 从 CoreSinkManager 获取 InProcessCoreSink + from src.common.core_sink_manager import get_core_sink_manager + + core_sink = get_core_sink_manager().get_in_process_sink() adapter = adapter_cls(core_sink, plugin=plugin) # type: ignore[call-arg] await adapter.start() + self._in_process_adapters[adapter_name] = adapter logger.info(f"适配器 {adapter_name} 已在当前进程启动") return True + except Exception as e: logger.error(f"启动适配器 {adapter_name} 失败: {e}", exc_info=True) return False diff --git a/src/plugins/built_in/NEW_napcat_adapter/plugin.py b/src/plugins/built_in/NEW_napcat_adapter/plugin.py index aeda7e44e..186bd6341 100644 --- a/src/plugins/built_in/NEW_napcat_adapter/plugin.py +++ b/src/plugins/built_in/NEW_napcat_adapter/plugin.py @@ -17,12 +17,13 @@ from typing import Any, ClassVar, Dict, List, Optional import orjson import websockets -from mofox_bus import CoreMessageSink, MessageEnvelope, WebSocketAdapterOptions +from mofox_bus import CoreSink, MessageEnvelope, WebSocketAdapterOptions from src.common.logger import get_logger from src.plugin_system import register_plugin from src.plugin_system.base import BaseAdapter, BasePlugin from src.plugin_system.apis import config_api +from .src.handlers import utils as handler_utils from .src.handlers.to_core.message_handler import MessageHandler from .src.handlers.to_core.notice_handler import NoticeHandler from .src.handlers.to_core.meta_event_handler import MetaEventHandler @@ -41,22 +42,16 @@ class NapcatAdapter(BaseAdapter): platform = "qq" run_in_subprocess = False - subprocess_entry = None - def __init__(self, core_sink: CoreMessageSink, plugin: Optional[BasePlugin] = None): + def __init__(self, core_sink: CoreSink, plugin: Optional[BasePlugin] = None): """初始化 Napcat 适配器""" # 从插件配置读取 WebSocket URL if plugin: - mode = config_api.get_plugin_config(plugin.config, "napcat_server.mode", "reverse") host = config_api.get_plugin_config(plugin.config, "napcat_server.host", "localhost") port = config_api.get_plugin_config(plugin.config, "napcat_server.port", 8095) - url = config_api.get_plugin_config(plugin.config, "napcat_server.url", "") access_token = config_api.get_plugin_config(plugin.config, "napcat_server.access_token", "") - - if mode == "forward" and url: - ws_url = url - else: - ws_url = f"ws://{host}:{port}" + + ws_url = f"ws://{host}:{port}" headers = {} if access_token: @@ -69,8 +64,6 @@ class NapcatAdapter(BaseAdapter): transport = WebSocketAdapterOptions( url=ws_url, headers=headers if headers else None, - incoming_parser=self._parse_napcat_message, - outgoing_encoder=self._encode_napcat_response, ) super().__init__(core_sink, plugin=plugin, transport=transport) @@ -89,6 +82,9 @@ class NapcatAdapter(BaseAdapter): # 注意:_ws 继承自 BaseAdapter,是 WebSocketLike 协议类型 self._napcat_ws = None # 可选的额外连接引用 + # 注册 utils 内部使用的适配器实例,便于工具方法自动获取 WS + handler_utils.register_adapter(self) + async def on_adapter_loaded(self) -> None: """适配器加载时的初始化""" logger.info("Napcat 适配器正在启动...") @@ -114,22 +110,6 @@ class NapcatAdapter(BaseAdapter): logger.info("Napcat 适配器已关闭") - def _parse_napcat_message(self, raw: str | bytes) -> Any: - """解析 Napcat/OneBot 消息""" - try: - if isinstance(raw, bytes): - data = orjson.loads(raw) - else: - data = orjson.loads(raw) - return data - except Exception as e: - logger.error(f"解析 Napcat 消息失败: {e}") - raise - - def _encode_napcat_response(self, envelope: MessageEnvelope) -> bytes: - """编码响应消息为 Napcat 格式(暂未使用,通过 API 调用发送)""" - return orjson.dumps(envelope) - async def from_platform_message(self, raw: Dict[str, Any]) -> MessageEnvelope: # type: ignore[override] """ 将 Napcat/OneBot 原始消息转换为 MessageEnvelope @@ -178,20 +158,6 @@ class NapcatAdapter(BaseAdapter): """ await self.send_handler.handle_message(envelope) - def _create_empty_envelope(self) -> MessageEnvelope: # type: ignore[return] - """创建一个空的消息信封(用于不需要处理的事件)""" - import time - return { - "direction": "incoming", - "message_info": { - "platform": self.platform, - "message_id": str(uuid.uuid4()), - "time": time.time(), - }, - "message_segment": {"type": "text", "data": "[系统事件]"}, - "timestamp_ms": int(time.time() * 1000), - } - async def send_napcat_api(self, action: str, params: Dict[str, Any], timeout: float = 30.0) -> Dict[str, Any]: """ 发送 Napcat API 请求并等待响应 @@ -260,18 +226,12 @@ class NapcatAdapterPlugin(BasePlugin): config_schema: ClassVar[dict] = { "plugin": { "name": {"type": str, "default": "napcat_adapter_plugin"}, - "version": {"type": str, "default": "2.0.0"}, + "version": {"type": str, "default": "1.0.0"}, "enabled": {"type": bool, "default": True}, }, "napcat_server": { - "mode": { - "type": str, - "default": "reverse", - "description": "连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", - }, "host": {"type": str, "default": "localhost"}, "port": {"type": int, "default": 8095}, - "url": {"type": str, "default": "", "description": "正向连接时的完整URL"}, "access_token": {"type": str, "default": ""}, }, "features": { @@ -284,34 +244,18 @@ class NapcatAdapterPlugin(BasePlugin): }, } - def __init__(self, plugin_dir: str = "", metadata: Any = None): - # 如果没有提供参数,创建一个默认的元数据 - if metadata is None: - from src.plugin_system.base.plugin_metadata import PluginMetadata - metadata = PluginMetadata( - name=self.plugin_name, - version=self.plugin_version, - author=self.plugin_author, - description=self.plugin_description, - usage="", - dependencies=[], - python_dependencies=[], - ) - - if not plugin_dir: - from pathlib import Path - plugin_dir = str(Path(__file__).parent) - - super().__init__(plugin_dir, metadata) + def __init__(self): self._adapter: Optional[NapcatAdapter] = None async def on_plugin_loaded(self): """插件加载时启动适配器""" logger.info("Napcat 适配器插件正在加载...") - # 获取核心 Sink - from src.common.core_sink import get_core_sink - core_sink = get_core_sink() + # 从 CoreSinkManager 获取 InProcessCoreSink + from src.common.core_sink_manager import get_core_sink_manager + + core_sink_manager = get_core_sink_manager() + core_sink = core_sink_manager.get_in_process_sink() # 创建并启动适配器 self._adapter = NapcatAdapter(core_sink, plugin=self) diff --git a/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/message_handler.py b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/message_handler.py index fc9655b0f..8f92f123a 100644 --- a/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/message_handler.py +++ b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/message_handler.py @@ -5,11 +5,28 @@ from __future__ import annotations import time from typing import TYPE_CHECKING, Any, Dict, List, Optional +from mofox_bus import MessageBuilder from src.common.logger import get_logger from src.plugin_system.apis import config_api +from mofox_bus import ( + MessageEnvelope, + SegPayload, + MessageInfoPayload, + UserInfoPayload, + GroupInfoPayload, +) + +from ...event_models import ACCEPT_FORMAT, QQ_FACE +from ..utils import ( + get_group_info, + get_image_base64, + get_self_info, + get_member_info, + get_message_detail, +) if TYPE_CHECKING: - from ...plugin import NapcatAdapter + from ....plugin import NapcatAdapter logger = get_logger("napcat_adapter.message_handler") @@ -28,99 +45,190 @@ class MessageHandler: async def handle_raw_message(self, raw: Dict[str, Any]): """ 处理原始消息并转换为 MessageEnvelope - + Args: raw: OneBot 原始消息数据 - + Returns: MessageEnvelope (dict) """ - from mofox_bus import MessageEnvelope, SegPayload, MessageInfoPayload, UserInfoPayload, GroupInfoPayload message_type = raw.get("message_type") message_id = str(raw.get("message_id", "")) message_time = time.time() - + + msg_builder = MessageBuilder() + # 构造用户信息 sender_info = raw.get("sender", {}) - user_info: UserInfoPayload = { - "platform": "qq", - "user_id": str(sender_info.get("user_id", "")), - "user_nickname": sender_info.get("nickname", ""), - "user_cardname": sender_info.get("card", ""), - "user_avatar": sender_info.get("avatar", ""), - } + + ( + msg_builder.direction("incoming") + .message_id(message_id) + .timestamp_ms(int(message_time * 1000)) + .from_user( + user_id=str(sender_info.get("user_id", "")), + platform="qq", + nickname=sender_info.get("nickname", ""), + cardname=sender_info.get("card", ""), + user_avatar=sender_info.get("avatar", ""), + ) + ) # 构造群组信息(如果是群消息) - group_info: Optional[GroupInfoPayload] = None if message_type == "group": group_id = raw.get("group_id") if group_id: - group_info = { - "platform": "qq", - "group_id": str(group_id), - "group_name": "", # 可以通过 API 获取 - } + fetched_group_info = await get_group_info(group_id) + ( + msg_builder.from_group( + group_id=str(group_id), + platform="qq", + name=( + fetched_group_info.get("group_name", "") + if fetched_group_info + else raw.get("group_name", "") + ), + ) + ) # 解析消息段 message_segments = raw.get("message", []) seg_list: List[SegPayload] = [] - - for seg in message_segments: - seg_type = seg.get("type", "") - seg_data = seg.get("data", {}) - - # 转换为 SegPayload - if seg_type == "text": - seg_list.append({ - "type": "text", - "data": seg_data.get("text", "") - }) - elif seg_type == "image": - # 这里需要下载图片并转换为 base64(简化版本) - seg_list.append({ - "type": "image", - "data": seg_data.get("url", "") # 实际应该转换为 base64 - }) - elif seg_type == "at": - seg_list.append({ - "type": "at", - "data": f"{seg_data.get('qq', '')}" - }) - # 其他消息类型... - # 构造 MessageInfoPayload - message_info = { - "platform": "qq", - "message_id": message_id, - "time": message_time, - "user_info": user_info, - "format_info": { - "content_format": ["text", "image"], # 根据实际消息类型设置 - "accept_format": ["text", "image", "emoji", "voice"], - }, - } - - # 添加群组信息(如果存在) - if group_info: - message_info["group_info"] = group_info + for segment in message_segments: + seg_message = await self.handle_single_segment(segment, raw) + if seg_message: + seg_list.append(seg_message) - # 构造 MessageEnvelope - envelope = { - "direction": "incoming", - "message_info": message_info, - "message_segment": {"type": "seglist", "data": seg_list} if len(seg_list) > 1 else (seg_list[0] if seg_list else {"type": "text", "data": ""}), - "raw_message": raw.get("raw_message", ""), - "platform": "qq", - "message_id": message_id, - "timestamp_ms": int(message_time * 1000), - } + msg_builder.format_info( + content_format=[seg["type"] for seg in seg_list], + accept_format=ACCEPT_FORMAT, + ) - return envelope + return msg_builder.build() + async def handle_single_segment( + self, segment: dict, raw_message: dict, in_reply: bool = False + ) -> SegPayload | List[SegPayload] | None: + """ + 处理单一消息段并转换为 MessageEnvelope + Args: + segment: 单一原始消息段 + raw_message: 完整的原始消息数据 - - - - + Returns: + SegPayload | List[SegPayload] | None + """ + seg_type = segment.get("type") + seg_data: dict = segment.get("data", {}) + match seg_type: + case "text": + return {"type": "text", "data": seg_data.get("text", "")} + case "image": + image_sub_type = seg_data.get("sub_type") + try: + image_base64 = await get_image_base64(seg_data.get("url", "")) + except Exception as e: + logger.error(f"图片消息处理失败: {str(e)}") + return None + if image_sub_type == 0: + """这部分认为是图片""" + return {"type": "image", "data": image_base64} + elif image_sub_type not in [4, 9]: + """这部分认为是表情包""" + return {"type": "emoji", "data": image_base64} + else: + logger.warning(f"不支持的图片子类型:{image_sub_type}") + return None + case "face": + message_data: dict = segment.get("data", {}) + face_raw_id: str = str(message_data.get("id")) + if face_raw_id in QQ_FACE: + face_content: str = QQ_FACE.get(face_raw_id, "[未知表情]") + return {"type": "text", "data": face_content} + else: + logger.warning(f"不支持的表情:{face_raw_id}") + return None + case "at": + if seg_data: + qq_id = seg_data.get("qq") + self_id = raw_message.get("self_id") + group_id = raw_message.get("group_id") + if str(self_id) == str(qq_id): + logger.debug("机器人被at") + self_info = await get_self_info() + if self_info: + # 返回包含昵称和用户ID的at格式,便于后续处理 + return { + "type": "at", + "data": f"{self_info.get('nickname')}:{self_info.get('user_id')}", + } + else: + return None + else: + if qq_id and group_id: + member_info = await get_member_info( + group_id=group_id, user_id=qq_id + ) + if member_info: + # 返回包含昵称和用户ID的at格式,便于后续处理 + return { + "type": "at", + "data": f"{member_info.get('nickname')}:{member_info.get('user_id')}", + } + else: + return None + case "emoji": + seg_data = segment.get("id", "") + case "reply": + if not in_reply: + message_id = None + if seg_data: + message_id = seg_data.get("id") + else: + return None + message_detail = await get_message_detail(message_id) + if not message_detail: + logger.warning("获取被引用的消息详情失败") + return None + reply_message = await self.handle_single_segment( + message_detail, raw_message, in_reply=True + ) + if reply_message is None: + reply_message = [ + {"type": "text", "data": "[无法获取被引用的消息]"} + ] + sender_info: dict = message_detail.get("sender", {}) + sender_nickname: str = sender_info.get("nickname", "") + sender_id = sender_info.get("user_id") + seg_message: List[SegPayload] = [] + if not sender_nickname: + logger.warning("无法获取被引用的人的昵称,返回默认值") + seg_message.append( + { + "type": "text", + "data": f"[回复<未知用户>:{reply_message}],说:", + } + ) + else: + if sender_id: + seg_message.append( + { + "type": "text", + "data": f"[回复<{sender_nickname}({sender_id})>:{reply_message}],说:", + } + ) + else: + seg_message.append( + { + "type": "text", + "data": f"[回复<{sender_nickname}>:{reply_message}],说:", + } + ) + return seg_message + case "voice": + seg_data = segment.get("url", "") + case _: + logger.warning(f"Unsupported segment type: {seg_type}") diff --git a/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/request_handler.py b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/request_handler.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/plugins/built_in/NEW_napcat_adapter/src/handlers/utils.py b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/utils.py new file mode 100644 index 000000000..37bdf81dc --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/utils.py @@ -0,0 +1,361 @@ +import asyncio +import base64 +import io +import ssl +import time +import weakref +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +import orjson +import urllib3 +from PIL import Image + +from src.common.logger import get_logger + +if TYPE_CHECKING: + from ...plugin import NapcatAdapter + +logger = get_logger("napcat_adapter") + +# 简单的缓存实现,通过 JSON 文件实现磁盘一价存储 +_CACHE_FILE = Path(__file__).resolve().parent / "napcat_cache.json" +_CACHE_LOCK = asyncio.Lock() +_CACHE: Dict[str, Dict[str, Dict[str, Any]]] = { + "group_info": {}, + "group_detail_info": {}, + "member_info": {}, + "stranger_info": {}, + "self_info": {}, +} + +# 各类信息的 TTL 缓存过期时间设置 +GROUP_INFO_TTL = 300 # 5 min +GROUP_DETAIL_TTL = 300 +MEMBER_INFO_TTL = 180 +STRANGER_INFO_TTL = 300 +SELF_INFO_TTL = 300 + +_adapter_ref: weakref.ReferenceType["NapcatAdapter"] | None = None + + +def register_adapter(adapter: "NapcatAdapter") -> None: + """注册 NapcatAdapter 实例,以便 utils 模块可以获取 WebSocket""" + global _adapter_ref + _adapter_ref = weakref.ref(adapter) + logger.debug("Napcat adapter registered in utils for websocket access") + + +def _load_cache_from_disk() -> None: + if not _CACHE_FILE.exists(): + return + try: + data = orjson.loads(_CACHE_FILE.read_bytes()) + if isinstance(data, dict): + for key, section in _CACHE.items(): + cached_section = data.get(key) + if isinstance(cached_section, dict): + section.update(cached_section) + except Exception as e: + logger.debug(f"Failed to load napcat cache: {e}") + + +def _save_cache_to_disk_locked() -> None: + """重要提示:不要在持有 _CACHE_LOCK 时调用此函数""" + _CACHE_FILE.parent.mkdir(parents=True, exist_ok=True) + _CACHE_FILE.write_bytes(orjson.dumps(_CACHE)) + + +async def _get_cached(section: str, key: str, ttl: int) -> Any | None: + now = time.time() + async with _CACHE_LOCK: + entry = _CACHE.get(section, {}).get(key) + if not entry: + return None + ts = entry.get("ts", 0) + if ts and now - ts <= ttl: + return entry.get("data") + _CACHE.get(section, {}).pop(key, None) + try: + _save_cache_to_disk_locked() + except Exception: + pass + return None + + +async def _set_cached(section: str, key: str, data: Any) -> None: + async with _CACHE_LOCK: + _CACHE.setdefault(section, {})[key] = {"data": data, "ts": time.time()} + try: + _save_cache_to_disk_locked() + except Exception: + logger.debug("Write napcat cache failed", exc_info=True) + + +def _get_adapter(adapter: "NapcatAdapter | None" = None) -> "NapcatAdapter": + target = adapter + if target is None and _adapter_ref: + target = _adapter_ref() + if target is None: + raise RuntimeError("NapcatAdapter 未注册,请确保已调用 utils.register_adapter 注册") + return target + + +async def _call_adapter_api( + action: str, + params: Dict[str, Any], + adapter: "NapcatAdapter | None" = None, + timeout: float = 30.0, +) -> Optional[Dict[str, Any]]: + """统一通过 adapter 发送和接收 API 调用""" + try: + target = _get_adapter(adapter) + # 确保 WS 已连接 + target.get_ws_connection() + except Exception as e: # pragma: no cover - 难以在单元测试中查看 + logger.error(f"WebSocket 未准备好,无法调用 API: {e}") + return None + + try: + return await target.send_napcat_api(action, params, timeout=timeout) + except Exception as e: + logger.error(f"{action} 调用失败: {e}") + return None + + +# 加载缓存到内存一次,避免在每次调用缓存时重复加载 +_load_cache_from_disk() + + +class SSLAdapter(urllib3.PoolManager): + def __init__(self, *args, **kwargs): + context = ssl.create_default_context() + context.set_ciphers("DEFAULT@SECLEVEL=1") + context.minimum_version = ssl.TLSVersion.TLSv1_2 + kwargs["ssl_context"] = context + super().__init__(*args, **kwargs) + + +async def get_group_info( + group_id: int, + *, + use_cache: bool = True, + force_refresh: bool = False, + adapter: "NapcatAdapter | None" = None, +) -> dict | None: + """ + 获取群组基本信息 + + 返回值可能是None,需要调用方检查空值 + """ + logger.debug("获取群组基本信息中") + cache_key = str(group_id) + if use_cache and not force_refresh: + cached = await _get_cached("group_info", cache_key, GROUP_INFO_TTL) + if cached is not None: + return cached + + socket_response = await _call_adapter_api( + "get_group_info", + {"group_id": group_id}, + adapter=adapter, + ) + data = socket_response.get("data") if socket_response else None + if data is not None and use_cache: + await _set_cached("group_info", cache_key, data) + return data + + +async def get_group_detail_info( + group_id: int, + *, + use_cache: bool = True, + force_refresh: bool = False, + adapter: "NapcatAdapter | None" = None, +) -> dict | None: + """ + 获取群组详细信息 + + 返回值可能是None,需要调用方检查空值 + """ + logger.debug("获取群组详细信息中") + cache_key = str(group_id) + if use_cache and not force_refresh: + cached = await _get_cached("group_detail_info", cache_key, GROUP_DETAIL_TTL) + if cached is not None: + return cached + + socket_response = await _call_adapter_api( + "get_group_detail_info", + {"group_id": group_id}, + adapter=adapter, + ) + data = socket_response.get("data") if socket_response else None + if data is not None and use_cache: + await _set_cached("group_detail_info", cache_key, data) + return data + + +async def get_member_info( + group_id: int, + user_id: int, + *, + use_cache: bool = True, + force_refresh: bool = False, + adapter: "NapcatAdapter | None" = None, +) -> dict | None: + """ + 获取群组成员信息 + + 返回值可能是None,需要调用方检查空值 + """ + logger.debug("获取群组成员信息中") + cache_key = f"{group_id}:{user_id}" + if use_cache and not force_refresh: + cached = await _get_cached("member_info", cache_key, MEMBER_INFO_TTL) + if cached is not None: + return cached + + socket_response = await _call_adapter_api( + "get_group_member_info", + {"group_id": group_id, "user_id": user_id, "no_cache": True}, + adapter=adapter, + ) + data = socket_response.get("data") if socket_response else None + if data is not None and use_cache: + await _set_cached("member_info", cache_key, data) + return data + + +async def get_image_base64(url: str) -> str: + # sourcery skip: raise-specific-error + """下载图片/视频并返回Base64""" + logger.debug(f"下载图片: {url}") + http = SSLAdapter() + try: + response = http.request("GET", url, timeout=10) + if response.status != 200: + raise Exception(f"HTTP Error: {response.status}") + image_bytes = response.data + return base64.b64encode(image_bytes).decode("utf-8") + except Exception as e: + logger.error(f"图片下载失败: {str(e)}") + raise + + +def convert_image_to_gif(image_base64: str) -> str: + # sourcery skip: extract-method + """ + 将Base64编码的图片转换为GIF格式 + Parameters: + image_base64: str: Base64编码的图片数据 + Returns: + str: Base64编码的GIF图片数据 + """ + logger.debug("转换图片为GIF格式") + try: + image_bytes = base64.b64decode(image_base64) + image = Image.open(io.BytesIO(image_bytes)) + output_buffer = io.BytesIO() + image.save(output_buffer, format="GIF") + output_buffer.seek(0) + return base64.b64encode(output_buffer.read()).decode("utf-8") + except Exception as e: + logger.error(f"图片转换为GIF失败: {str(e)}") + return image_base64 + + +async def get_self_info( + *, + use_cache: bool = True, + force_refresh: bool = False, + adapter: "NapcatAdapter | None" = None, +) -> dict | None: + """ + 获取机器人信息 + """ + logger.debug("获取机器人信息中") + cache_key = "self" + if use_cache and not force_refresh: + cached = await _get_cached("self_info", cache_key, SELF_INFO_TTL) + if cached is not None: + return cached + + response = await _call_adapter_api("get_login_info", {}, adapter=adapter) + data = response.get("data") if response else None + if data is not None and use_cache: + await _set_cached("self_info", cache_key, data) + return data + + +def get_image_format(raw_data: str) -> str: + """ + 从Base64编码的数据中确定图片的格式类型 + Parameters: + raw_data: str: Base64编码的图片数据 + Returns: + format: str: 图片的格式类型,如 'jpeg', 'png', 'gif'等 + """ + image_bytes = base64.b64decode(raw_data) + return Image.open(io.BytesIO(image_bytes)).format.lower() + + +async def get_stranger_info( + user_id: int, + *, + use_cache: bool = True, + force_refresh: bool = False, + adapter: "NapcatAdapter | None" = None, +) -> dict | None: + """ + 获取陌生人信息 + """ + logger.debug("获取陌生人信息中") + cache_key = str(user_id) + if use_cache and not force_refresh: + cached = await _get_cached("stranger_info", cache_key, STRANGER_INFO_TTL) + if cached is not None: + return cached + + response = await _call_adapter_api("get_stranger_info", {"user_id": user_id}, adapter=adapter) + data = response.get("data") if response else None + if data is not None and use_cache: + await _set_cached("stranger_info", cache_key, data) + return data + + +async def get_message_detail( + message_id: Union[str, int], + *, + adapter: "NapcatAdapter | None" = None, +) -> dict | None: + """ + 获取消息详情,仅作为参考 + """ + logger.debug("获取消息详情中") + response = await _call_adapter_api( + "get_msg", + {"message_id": message_id}, + adapter=adapter, + timeout=30, + ) + return response.get("data") if response else None + + +async def get_record_detail( + file: str, + file_id: Optional[str] = None, + *, + adapter: "NapcatAdapter | None" = None, +) -> dict | None: + """ + 获取语音信息详情 + """ + logger.debug("获取语音信息详情中") + response = await _call_adapter_api( + "get_record", + {"file": file, "file_id": file_id, "out_format": "wav"}, + adapter=adapter, + timeout=30, + ) + return response.get("data") if response else None diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py index 91f84b5ad..9ec950bc8 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py @@ -3,7 +3,7 @@ import random import time import websockets as Server import uuid -from mofox_bus import ( +from maim_message import ( UserInfo, GroupInfo, Seg, diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 74877c9be..775ef6c36 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.8.2" +version = "7.8.3" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值