重构并增强Napcat适配器的功能

- 更新了`BaseAdapter`以简化子进程处理。
- 对`AdapterManager`进行了重构,以便根据适配器的`run_in_subprocess`属性来管理适配器。
- 增强了`NapcatAdapter`,以利用新的`CoreSinkManager`实现更优的进程管理。
- 在`utils.py`中实现了针对群组和成员信息的缓存机制。
- 改进了`message_handler.py`中的消息处理,以支持各种消息类型和格式。
- 已将插件配置版本更新至7.8.3。
This commit is contained in:
Windpicker-owo
2025-11-25 19:55:36 +08:00
parent 1ebdc37b22
commit 6b3b2a8245
38 changed files with 2082 additions and 3277 deletions

View File

@@ -0,0 +1,222 @@
# MoFox Bot 消息运行时架构 (MessageRuntime)
本文档描述了 MoFox Bot 使用 `mofox_bus.MessageRuntime` 简化消息处理链条的架构设计。
## 架构概述
```
┌─────────────────────────────────────────────────────────────────────────┐
│ CoreSinkManager │
│ ┌─────────────────────────────────────────────────────────────────────┐│
│ │ MessageRuntime ││
│ │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ││
│ │ │ before_hook │→ │ Routes │→ │ after_hook │ ││
│ │ │ (预处理/过滤) │ │ (消息路由) │ │ (后处理) │ ││
│ │ └──────────────┘ └──────────────┘ └──────────────┘ ││
│ │ ↓ ↓ ↓ ││
│ │ ┌──────────────────────────────────────────────────────────────┐ ││
│ │ │ error_hook (错误处理) │ ││
│ │ └──────────────────────────────────────────────────────────────┘ ││
│ └─────────────────────────────────────────────────────────────────────┘│
│ │
│ ┌──────────────────────┐ ┌──────────────────────────────────────┐ │
│ │ InProcessCoreSink │ │ ProcessCoreSinkServer (子进程适配器) │ │
│ │ (同进程适配器) │ │ │ │
│ └──────────────────────┘ └──────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────┘
↑ ↑
│ │
┌─────────────────────┐ ┌─────────────────────┐
│ 同进程适配器 │ │ 子进程适配器 │
│ (run_in_subprocess │ │ (run_in_subprocess │
│ = False) │ │ = True) │
└─────────────────────┘ └─────────────────────┘
```
## 核心组件
### 1. CoreSinkManager (`src/common/core_sink_manager.py`)
统一管理 CoreSink 双实例和 MessageRuntime
```python
from src.common.core_sink_manager import get_core_sink_manager, get_message_runtime
# 获取管理器
manager = get_core_sink_manager()
# 获取 MessageRuntime
runtime = get_message_runtime()
# 发送消息到适配器
await manager.send_outgoing(envelope)
```
### 2. MessageRuntime
`MessageRuntime` 是 mofox_bus 提供的消息路由核心,支持:
- **消息路由**:通过 `add_route()``@on_message` 装饰器按消息类型路由
- **钩子机制**`before_hook`(前置处理)、`after_hook`(后置处理)、`error_hook`(错误处理)
- **中间件**:洋葱模型的中间件机制
- **批量处理**:支持 `handle_batch()` 批量处理消息
### 3. MessageHandler (`src/chat/message_receive/message_handler.py`)
将消息处理逻辑注册为 MessageRuntime 的路由和钩子:
```python
class MessageHandler:
def register_handlers(self, runtime: MessageRuntime) -> None:
# 注册前置钩子
runtime.register_before_hook(self._before_hook)
# 注册后置钩子
runtime.register_after_hook(self._after_hook)
# 注册错误钩子
runtime.register_error_hook(self._error_hook)
# 注册适配器响应处理器
runtime.add_route(
predicate=_is_adapter_response,
handler=self._handle_adapter_response_route,
name="adapter_response_handler",
message_type="adapter_response",
)
# 注册默认消息处理器
runtime.add_route(
predicate=lambda _: True,
handler=self._handle_normal_message,
name="default_message_handler",
)
```
## 消息流向
### 接收消息
```
适配器 → InProcessCoreSink/ProcessCoreSinkServer → CoreSinkManager._dispatch_to_runtime()
→ MessageRuntime.handle_message()
→ before_hook (预处理、过滤)
→ 匹配路由 (adapter_response / normal_message)
→ 执行处理器
→ after_hook (后处理)
```
### 发送消息
```
消息发送请求 → CoreSinkManager.send_outgoing()
→ InProcessCoreSink.push_outgoing()
→ ProcessCoreSinkServer.push_outgoing()
→ 适配器
```
## 钩子功能
### before_hook
在消息路由之前执行,用于:
- 标准化 ID 为字符串
- 检查 echo 消息(自身消息上报)
- 通过抛出 `UserWarning` 跳过消息处理
### after_hook
在消息处理完成后执行,用于:
- 清理工作
- 日志记录
### error_hook
在处理过程中出现异常时执行,用于:
- 区分预期的流程控制UserWarning和真正的错误
- 统一异常日志记录
## 路由优先级
1. **明确指定 message_type 的路由**(优先级最高)
2. **事件路由**(基于 event_type
3. **通用路由**(无 message_type 限制)
## 扩展消息处理
### 注册自定义处理器
```python
from src.common.core_sink_manager import get_message_runtime
from mofox_bus import MessageEnvelope
runtime = get_message_runtime()
# 使用装饰器
@runtime.on_message(message_type="image")
async def handle_image(envelope: MessageEnvelope):
# 处理图片消息
pass
# 或使用 add_route
runtime.add_route(
predicate=lambda env: env.get("platform") == "qq",
handler=my_handler,
name="qq_handler",
)
```
### 注册钩子
```python
runtime = get_message_runtime()
# 前置钩子
async def my_before_hook(envelope: MessageEnvelope) -> None:
# 预处理逻辑
pass
runtime.register_before_hook(my_before_hook)
# 错误钩子
async def my_error_hook(envelope: MessageEnvelope, exc: BaseException) -> None:
# 错误处理逻辑
pass
runtime.register_error_hook(my_error_hook)
```
## 初始化流程
`MainSystem.initialize()` 中:
1. 初始化 `CoreSinkManager`(包含 `MessageRuntime`
2. 获取 `MessageHandler` 并设置 `CoreSinkManager` 引用
3. 调用 `MessageHandler.register_handlers()``MessageRuntime` 注册处理器和钩子
4. 初始化其他组件
```python
async def initialize(self) -> None:
# 初始化 CoreSinkManager包含 MessageRuntime
self.core_sink_manager = await initialize_core_sink_manager()
# 获取 MessageHandler 并向 MessageRuntime 注册处理器
self.message_handler = get_message_handler()
self.message_handler.set_core_sink_manager(self.core_sink_manager)
self.message_handler.register_handlers(self.core_sink_manager.runtime)
```
## 优势
1. **简化消息处理链**:不再需要手动管理处理流程,使用声明式路由
2. **更好的可扩展性**:通过 `add_route()` 或装饰器轻松添加新的处理器
3. **统一的错误处理**:通过 `error_hook` 集中处理异常
4. **支持中间件**:可以添加洋葱模型的中间件
5. **更清晰的代码结构**:处理逻辑按类型分离
## 参考
- `packages/mofox-bus/src/mofox_bus/runtime.py` - MessageRuntime 实现
- `src/common/core_sink_manager.py` - CoreSinkManager 实现
- `src/chat/message_receive/message_handler.py` - MessageHandler 实现
- `docs/mofox_bus.md` - MoFox Bus 消息库说明

View File

@@ -2,6 +2,8 @@
MoFox Bus 是 MoFox Bot 自研的统一消息中台,替换第三方 `maim_message`,将核心与各平台适配器之间的通信抽象成可拓展、可热插拔的组件。该库完全异步、面向高吞吐,覆盖消息建模、序列化、传输层、运行时路由、适配器工具等多个层面。
> 现在已拆分为独立 pip 包:在项目根目录执行 `pip install -e ./packages/mofox-bus` 即可安装到当前 Python 环境。
---
## 1. 设计目标
@@ -14,7 +16,7 @@ MoFox Bus 是 MoFox Bot 自研的统一消息中台,替换第三方 `maim_mess
---
## 2. 包结构概览(`src/mofox_bus/`
## 2. 包结构概览(`packages/mofox-bus/src/mofox_bus/`
| 模块 | 主要职责 |
| --- | --- |

View File

@@ -8,7 +8,6 @@
from __future__ import annotations
import asyncio
import sys
import time
import uuid
from pathlib import Path
@@ -17,8 +16,6 @@ from typing import Any, Dict, Optional
import orjson
import websockets
# 追加 src 目录,便于直接运行示例
sys.path.append(str(Path(__file__).resolve().parents[1] / "src"))
from mofox_bus import (
AdapterBase,

View File

@@ -78,6 +78,7 @@ dependencies = [
"inkfox>=0.1.1",
"rjieba>=0.1.13",
"fastmcp>=2.13.0",
"mofox-bus",
]
[[tool.uv.index]]

View File

@@ -1,4 +1,3 @@
sqlalchemy
aiosqlite
aiofiles
aiomysql

View File

@@ -121,8 +121,6 @@ class ChatterManager:
chatter_class = self.get_chatter_class(chat_type)
if not chatter_class:
from src.plugin_system.base.component_types import ChatType
all_chatter_class = self.get_chatter_class(ChatType.ALL)
if all_chatter_class:
chatter_class = all_chatter_class

View File

@@ -8,9 +8,10 @@ import random
import time
from typing import TYPE_CHECKING, Any
from src.chat.chatter_manager import ChatterManager
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.planner_actions.action_manager import ChatterActionManager
if TYPE_CHECKING:
from src.chat.chatter_manager import ChatterManager
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats
from src.common.logger import get_logger
@@ -21,7 +22,7 @@ from .distribution_manager import stream_loop_manager
from .global_notice_manager import NoticeScope, global_notice_manager
if TYPE_CHECKING:
pass
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("message_manager")
@@ -39,6 +40,8 @@ class MessageManager:
# 初始化chatter manager
self.action_manager = ChatterActionManager()
# 延迟导入ChatterManager以避免循环导入
from src.chat.chatter_manager import ChatterManager
self.chatter_manager = ChatterManager(self.action_manager)
# 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context

View File

@@ -1,488 +0,0 @@
import os
import re
import traceback
from typing import Any
from mofox_bus.runtime import MessageRuntime
from mofox_bus import MessageEnvelope
from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.prompt import global_prompt_manager
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.plugin_system.base import BaseCommand, EventType
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
# 获取项目根目录假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
# 配置主程序日志格式
logger = get_logger("chat")
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息是否包含过滤词
Args:
text: 待检查的文本
chat: 聊天对象
userinfo: 用户信息
Returns:
bool: 是否包含过滤词
"""
for word in global_config.message_receive.ban_words:
if word in text:
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[过滤词识别]消息中含有{word}filtered")
return True
return False
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息是否匹配过滤正则表达式
Args:
text: 待检查的文本
chat: 聊天对象
userinfo: 用户信息
Returns:
bool: 是否匹配过滤正则
"""
for pattern in global_config.message_receive.ban_msgs_regex:
if re.search(pattern, text):
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
return True
return False
runtime = MessageRuntime() # 获取mofox-bus运行时环境
class ChatBot:
def __init__(self):
self.bot = None # bot 实例引用
self._started = False
self.mood_manager = mood_manager # 获取情绪管理器单例
# 启动消息管理器
self._message_manager_started = False
async def _ensure_started(self):
"""确保所有任务已启动"""
if not self._started:
logger.debug("确保ChatBot所有任务已启动")
# 启动消息管理器
if not self._message_manager_started:
await message_manager.start()
self._message_manager_started = True
logger.info("消息管理器已启动")
self._started = True
async def _process_plus_commands(self, message: DatabaseMessages, chat: ChatStream):
"""独立处理PlusCommand系统"""
try:
text = message.processed_plain_text or ""
# 获取配置的命令前缀
from src.config.config import global_config
prefixes = global_config.command.command_prefixes
# 检查是否以任何前缀开头
matched_prefix = None
for prefix in prefixes:
if text.startswith(prefix):
matched_prefix = prefix
break
if not matched_prefix:
return False, None, True # 不是命令,继续处理
# 移除前缀
command_part = text[len(matched_prefix) :].strip()
# 分离命令名和参数
parts = command_part.split(None, 1)
if not parts:
return False, None, True # 没有命令名,继续处理
command_word = parts[0].lower()
args_text = parts[1] if len(parts) > 1 else ""
# 查找匹配的PlusCommand
plus_command_registry = component_registry.get_plus_command_registry()
matching_commands = []
for plus_command_name, plus_command_class in plus_command_registry.items():
plus_command_info = component_registry.get_registered_plus_command_info(plus_command_name)
if not plus_command_info:
continue
# 检查命令名是否匹配(命令名和别名)
all_commands = [plus_command_name.lower()] + [
alias.lower() for alias in plus_command_info.command_aliases
]
if command_word in all_commands:
matching_commands.append((plus_command_class, plus_command_info, plus_command_name))
if not matching_commands:
return False, None, True # 没有找到匹配的PlusCommand继续处理
# 如果有多个匹配,按优先级排序
if len(matching_commands) > 1:
matching_commands.sort(key=lambda x: x[1].priority, reverse=True)
logger.warning(
f"文本 '{text}' 匹配到多个PlusCommand: {[cmd[2] for cmd in matching_commands]},使用优先级最高的"
)
plus_command_class, plus_command_info, plus_command_name = matching_commands[0]
# 检查命令是否被禁用
if (
chat
and chat.stream_id
and plus_command_name
in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
):
logger.info("用户禁用的PlusCommand跳过处理")
return False, None, True
message.is_command = True
# 获取插件配置
plugin_config = component_registry.get_plugin_config(plus_command_name)
# 创建PlusCommand实例
plus_command_instance = plus_command_class(message, plugin_config)
# 为插件实例设置 chat_stream 运行时属性
setattr(plus_command_instance, "chat_stream", chat)
try:
# 检查聊天类型限制
if not plus_command_instance.is_chat_type_allowed():
is_group = chat.group_info is not None
logger.info(
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
)
return False, None, True # 跳过此命令,继续处理其他消息
# 设置参数
from src.plugin_system.base.command_args import CommandArgs
command_args = CommandArgs(args_text)
plus_command_instance.args = command_args
# 执行命令
success, response, intercept_message = await plus_command_instance.execute(command_args)
# 记录命令执行结果
if success:
logger.info(f"PlusCommand执行成功: {plus_command_class.__name__} (拦截: {intercept_message})")
else:
logger.warning(f"PlusCommand执行失败: {plus_command_class.__name__} - {response}")
# 根据命令的拦截设置决定是否继续处理消息
return True, response, not intercept_message # 找到命令根据intercept_message决定是否继续
except Exception as e:
logger.error(f"执行PlusCommand时出错: {plus_command_class.__name__} - {e}")
logger.error(traceback.format_exc())
try:
await plus_command_instance.send_text(f"命令执行出错: {e!s}")
except Exception as send_error:
logger.error(f"发送错误消息失败: {send_error}")
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
return True, str(e), False # 出错时继续处理消息
except Exception as e:
logger.error(f"处理PlusCommand时出错: {e}")
return False, None, True # 出错时继续处理消息
async def _process_commands_with_new_system(self, message: DatabaseMessages, chat: ChatStream):
# sourcery skip: use-named-expression
"""使用新插件系统处理命令"""
try:
text = message.processed_plain_text or ""
# 使用新的组件注册中心查找命令
command_result = component_registry.find_command_by_text(text)
if command_result:
command_class, matched_groups, command_info = command_result
plugin_name = command_info.plugin_name
command_name = command_info.name
if (
chat
and chat.stream_id
and command_name
in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
):
logger.info("用户禁用的命令,跳过处理")
return False, None, True
message.is_command = True
# 获取插件配置
plugin_config = component_registry.get_plugin_config(plugin_name)
# 创建命令实例
command_instance: BaseCommand = command_class(message, plugin_config)
command_instance.set_matched_groups(matched_groups)
# 为插件实例设置 chat_stream 运行时属性
setattr(command_instance, "chat_stream", chat)
try:
# 检查聊天类型限制
if not command_instance.is_chat_type_allowed():
is_group = chat.group_info is not None
logger.info(
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
)
return False, None, True # 跳过此命令,继续处理其他消息
# 执行命令
success, response, intercept_message = await command_instance.execute()
# 记录命令执行结果
if success:
logger.info(f"命令执行成功: {command_class.__name__} (拦截: {intercept_message})")
else:
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
# 根据命令的拦截设置决定是否继续处理消息
return True, response, not intercept_message # 找到命令根据intercept_message决定是否继续
except Exception as e:
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
logger.error(traceback.format_exc())
try:
await command_instance.send_text(f"命令执行出错: {e!s}")
except Exception as send_error:
logger.error(f"发送错误消息失败: {send_error}")
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
return True, str(e), False # 出错时继续处理消息
# 没有找到命令,继续处理消息
return False, None, True
except Exception as e:
logger.error(f"处理命令时出错: {e}")
return False, None, True # 出错时继续处理消息
async def _handle_adapter_response_from_dict(self, seg_data: dict | None):
"""处理适配器命令响应(从字典数据)"""
try:
from src.plugin_system.apis.send_api import put_adapter_response
if isinstance(seg_data, dict):
request_id = seg_data.get("request_id")
response_data = seg_data.get("response")
else:
request_id = None
response_data = None
if request_id and response_data:
logger.info(f"[DEBUG bot.py] 收到适配器响应request_id={request_id}")
put_adapter_response(request_id, response_data)
else:
logger.warning(f"适配器响应消息格式不正确: request_id={request_id}, response_data={response_data}")
except Exception as e:
logger.error(f"处理适配器响应时出错: {e}")
@runtime.on_message
async def message_process(self, envelope: MessageEnvelope) -> None:
"""处理转化后的统一格式消息"""
# 控制握手等消息可能缺少 message_info这里直接跳过避免 KeyError
message_info = envelope.get("message_info")
if not isinstance(message_info, dict):
logger.debug(
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
", ".join(envelope.keys()),
)
return
if message_info.get("group_info") is not None:
message_info["group_info"]["group_id"] = str( # type: ignore
message_info["group_info"]["group_id"] # type: ignore
)
if message_info.get("user_info") is not None:
message_info["user_info"]["user_id"] = str( # type: ignore
message_info["user_info"]["user_id"] # type: ignore
)
# 优先处理adapter_response消息在echo检查之前
message_segment = envelope.get("message_segment")
if message_segment and isinstance(message_segment, dict):
if message_segment.get("type") == "adapter_response":
logger.info("[DEBUG bot.py message_process] 检测到adapter_response立即处理")
await self._handle_adapter_response_from_dict(message_segment.get("data"))
return
# 先提取基础信息检查是否是自身消息上报
from mofox_bus import BaseMessageInfo
temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {}))
if temp_message_info.additional_config:
sent_message = temp_message_info.additional_config.get("echo", False)
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息用于更新message_id需要ada支持上报事件实际测试中不会对正常使用造成任何问题
# 直接使用消息字典更新,不再需要创建 MessageRecv
await MessageStorage.update_message(message_data)
return
message_segment = envelope.get("message_segment")
group_info = temp_message_info.group_info
user_info = temp_message_info.user_info
# 获取或创建聊天流
chat = await get_chat_manager().get_or_create_stream(
platform=temp_message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
# 使用新的消息处理器直接生成 DatabaseMessages
from src.chat.message_receive.message_processor import process_message_from_dict
message = await process_message_from_dict(
message_dict=envelope,
stream_id=chat.stream_id,
platform=chat.platform
)
# 填充聊天流时间信息
message.chat_info.create_time = chat.create_time
message.chat_info.last_active_time = chat.last_active_time
# 注册消息到聊天管理器
get_chat_manager().register_message(message)
# 检测是否提及机器人
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
user_nickname = message.user_info.user_nickname if message.user_info else "未知用户"
logger.info(
f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m"
)
# 在此添加硬编码过滤,防止回复图片处理失败的消息
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
processed_text = message.processed_plain_text or ""
if any(keyword in processed_text for keyword in failure_keywords):
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
return
# 过滤检查
# DatabaseMessages 使用 display_message 作为原始消息表示
raw_text = message.display_message or message.processed_plain_text or ""
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
raw_text,
chat,
user_info, # type: ignore
):
return
# 命令处理 - 首先尝试PlusCommand独立处理
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat)
# 如果是PlusCommand且不需要继续处理则直接返回
if is_plus_command and not plus_continue_process:
await MessageStorage.store_message(message, chat)
logger.info(f"PlusCommand处理完成跳过后续消息处理: {plus_cmd_result}")
return
# 如果不是PlusCommand尝试传统的BaseCommand处理
if not is_plus_command:
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat)
# 如果是命令且不需要继续处理,则直接返回
if is_command and not continue_process:
await MessageStorage.store_message(message, chat)
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
return
result = await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message=message)
if result and not result.all_continue_process():
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
# TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info
# 确认从接口发来的message是否有自定义的prompt模板信息
# 这个功能需要在 adapter 层通过 additional_config 传递
template_group_name = None
async def preprocess():
# message 已经是 DatabaseMessages直接使用
group_info = chat.group_info
# 先交给消息管理器处理,计算兴趣度等衍生数据
try:
# 在将消息添加到管理器之前进行最终的静默检查
should_process_in_manager = True
if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list:
# 检查消息是否为图片或表情包
is_image_or_emoji = message.is_picid or message.is_emoji
if not message.is_mentioned and not is_image_or_emoji:
logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理")
should_process_in_manager = False
elif is_image_or_emoji:
logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理")
should_process_in_manager = False
if should_process_in_manager:
await message_manager.add_message(chat.stream_id, message)
logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
except Exception as e:
logger.error(f"消息添加到消息管理器失败: {e}")
# 存储消息到数据库,只进行一次写入
try:
await MessageStorage.store_message(message, chat)
except Exception as e:
logger.error(f"存储消息到数据库失败: {e}")
traceback.print_exc()
# 情绪系统更新 - 在消息存储后触发情绪更新
try:
if global_config.mood.enable_mood:
# 获取兴趣度用于情绪更新
interest_rate = message.interest_value
if interest_rate is None:
interest_rate = 0.0
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
# 获取当前聊天的情绪对象并更新情绪状态
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
await chat_mood.update_mood_by_message(message, interest_rate)
logger.debug("情绪状态更新完成")
except Exception as e:
logger.error(f"更新情绪状态失败: {e}")
traceback.print_exc()
if template_group_name:
async with global_prompt_manager.async_message_scope(template_group_name):
await preprocess()
else:
await preprocess()
except Exception as e:
logger.error(f"预处理消息失败: {e}")
traceback.print_exc()
# 创建全局ChatBot实例
chat_bot = ChatBot()

View File

@@ -2,11 +2,11 @@ import asyncio
import hashlib
import time
from mofox_bus import GroupInfo, UserInfo
from rich.traceback import install
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.data_models.database_data_model import DatabaseGroupInfo,DatabaseUserInfo
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
@@ -30,8 +30,8 @@ class ChatStream:
self,
stream_id: str,
platform: str,
user_info: UserInfo | None = None,
group_info: GroupInfo | None = None,
user_info: DatabaseUserInfo | None = None,
group_info: DatabaseGroupInfo | None = None,
data: dict | None = None,
):
self.stream_id = stream_id
@@ -77,8 +77,8 @@ class ChatStream:
@classmethod
def from_dict(cls, data: dict) -> "ChatStream":
"""从字典创建实例"""
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
user_info = DatabaseUserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
group_info = DatabaseGroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
instance = cls(
stream_id=data["stream_id"],
@@ -369,7 +369,7 @@ class ChatManager:
# logger.debug(f"注册消息到聊天流: {stream_id}")
@staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo | None, group_info: GroupInfo | None = None) -> str:
def _generate_stream_id(platform: str, user_info: DatabaseUserInfo | None, group_info: DatabaseGroupInfo | None = None) -> str:
"""生成聊天流唯一ID"""
if not user_info and not group_info:
raise ValueError("用户信息或群组信息必须提供")
@@ -392,7 +392,7 @@ class ChatManager:
return hashlib.sha256(key.encode()).hexdigest()
async def get_or_create_stream(
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
self, platform: str, user_info: DatabaseUserInfo, group_info: DatabaseGroupInfo | None = None
) -> ChatStream:
"""获取或创建聊天流 - 优化版本使用缓存机制"""
try:
@@ -483,7 +483,7 @@ class ChatManager:
return stream
def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
self, platform: str, user_info: DatabaseUserInfo, group_info: DatabaseGroupInfo | None = None
) -> ChatStream | None:
"""通过信息获取聊天流"""
stream_id = self._generate_stream_id(platform, user_info, group_info)

View File

@@ -1,12 +1,11 @@
import time
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Optional
from typing import Optional, TYPE_CHECKING
import urllib3
from rich.traceback import install
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.self_voice_cache import consume_self_voice_text
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_voice import get_voice_text
@@ -14,6 +13,9 @@ from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
install(extra_lines=3)

View File

@@ -1,39 +1,566 @@
import os
import traceback
"""
统一消息处理器 (Message Handler)
利用 mofox_bus.MessageRuntime 的路由功能,简化消息处理链条:
1. 使用 @runtime.on_message() 装饰器注册按消息类型路由的处理器
2. 使用 before_hook 进行消息预处理ID标准化、过滤等
3. 使用 after_hook 进行消息后处理(存储、情绪更新等)
4. 使用 error_hook 统一处理异常
消息流向:
适配器 → CoreSinkManager → MessageRuntime
[before_hook] 消息预处理、过滤
[on_message] 按类型路由处理(命令、普通消息等)
[after_hook] 存储、情绪更新等
回复生成 → CoreSinkManager.send_outgoing() → 适配器
重构说明2025-11
- 移除手动的消息处理链,改用 MessageRuntime 路由
- MessageHandler 变成处理器注册器,在初始化时注册各种处理器
- 利用 runtime 的钩子机制简化前置/后置处理
"""
from __future__ import annotations
import asyncio
import os
import re
import time
import traceback
from functools import partial
from typing import TYPE_CHECKING, Any
from mofox_bus import MessageEnvelope, MessageRuntime
from mofox_bus.runtime import MessageRuntime
from mofox_bus import MessageEnvelope
from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.prompt import global_prompt_manager
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.common.data_models.database_data_model import DatabaseGroupInfo, DatabaseUserInfo, DatabaseMessages
from src.mood.mood_manager import mood_manager
from src.plugin_system.base import BaseCommand, EventType
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
runtime = MessageRuntime()
if TYPE_CHECKING:
from src.common.core_sink_manager import CoreSinkManager
from src.chat.message_receive.chat_stream import ChatStream
# 获取项目根目录假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
logger = get_logger("message_handler")
# 项目根目录
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
"""检查消息是否包含过滤词"""
for word in global_config.message_receive.ban_words:
if word in text:
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[过滤词识别]消息中含有{word}filtered")
return True
return False
def _check_ban_regex(text: str, chat: "ChatStream", userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式"""
for pattern in global_config.message_receive.ban_msgs_regex:
if re.search(pattern, text):
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
return True
return False
# 配置主程序日志格式
logger = get_logger("chat")
class MessageHandler:
"""
统一消息处理器
利用 MessageRuntime 的路由功能,将消息处理逻辑注册为路由和钩子。
架构说明:
- 在 register_handlers() 中向 MessageRuntime 注册各种处理器
- 使用 @runtime.on_message(message_type=...) 按消息类型路由
- 使用 before_hook 进行消息预处理
- 使用 after_hook 进行消息后处理
- 使用 error_hook 统一处理异常
主要功能:
1. 消息预处理ID标准化、过滤检查
2. 适配器响应处理:处理 adapter_response 类型消息
3. 命令处理PlusCommand 和 BaseCommand
4. 普通消息处理:触发事件、存储、情绪更新
"""
def __init__(self):
self._started = False
self._message_manager_started = False
self._core_sink_manager: CoreSinkManager | None = None
self._shutting_down = False
self._runtime: MessageRuntime | None = None
async def preprocess(self, chat: ChatStream, message: DatabaseMessages):
# message 已经是 DatabaseMessages直接使用
group_info = chat.group_info
def set_core_sink_manager(self, manager: "CoreSinkManager") -> None:
"""设置 CoreSinkManager 引用"""
self._core_sink_manager = manager
# 先交给消息管理器处理
def register_handlers(self, runtime: MessageRuntime) -> None:
"""
向 MessageRuntime 注册消息处理器和钩子
这是核心方法,在系统初始化时调用,将所有处理逻辑注册到 runtime。
Args:
runtime: MessageRuntime 实例
"""
self._runtime = runtime
# 注册前置钩子:消息预处理和过滤
runtime.register_before_hook(self._before_hook)
# 注册后置钩子:存储、情绪更新等
runtime.register_after_hook(self._after_hook)
# 注册错误钩子:统一异常处理
runtime.register_error_hook(self._error_hook)
# 注册适配器响应处理器(最高优先级)
def _is_adapter_response(env: MessageEnvelope) -> bool:
segment = env.get("message_segment")
if isinstance(segment, dict):
return segment.get("type") == "adapter_response"
return False
runtime.add_route(
predicate=_is_adapter_response,
handler=self._handle_adapter_response_route,
name="adapter_response_handler",
message_type="adapter_response",
)
# 注册默认消息处理器(处理所有其他消息)
runtime.add_route(
predicate=lambda _: True, # 匹配所有消息
handler=self._handle_normal_message,
name="default_message_handler",
)
logger.info("MessageHandler 已向 MessageRuntime 注册处理器和钩子")
async def ensure_started(self) -> None:
"""确保所有依赖任务已启动"""
if not self._started:
logger.debug("确保 MessageHandler 所有任务已启动")
# 启动消息管理器
if not self._message_manager_started:
await message_manager.start()
self._message_manager_started = True
logger.info("消息管理器已启动")
self._started = True
async def _before_hook(self, envelope: MessageEnvelope) -> None:
"""
前置钩子:消息预处理
1. 标准化 ID 为字符串
2. 检查是否为 echo 消息(自身发送的消息上报)
3. 附加预处理数据到 envelopechat_stream, message 等)
"""
if self._shutting_down:
raise UserWarning("系统正在关闭,拒绝处理消息")
# 确保依赖服务已启动
await self.ensure_started()
# 提取消息信息
message_info = envelope.get("message_info")
if not isinstance(message_info, dict):
logger.debug(
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
", ".join(envelope.keys()),
)
raise UserWarning("消息缺少 message_info")
# 标准化 ID 为字符串
if message_info.get("group_info") is not None:
message_info["group_info"]["group_id"] = str( # type: ignore
message_info["group_info"]["group_id"] # type: ignore
)
if message_info.get("user_info") is not None:
message_info["user_info"]["user_id"] = str( # type: ignore
message_info["user_info"]["user_id"] # type: ignore
)
# 处理自身消息上报echo
additional_config = message_info.get("additional_config", {})
if additional_config and isinstance(additional_config, dict):
sent_message = additional_config.get("echo", False)
if sent_message:
# 更新消息ID
await MessageStorage.update_message(dict(envelope))
raise UserWarning("Echo 消息已处理")
async def _after_hook(self, envelope: MessageEnvelope) -> None:
"""
后置钩子:消息后处理
在消息处理完成后执行的清理工作
"""
# 后置处理逻辑(如有需要)
pass
async def _error_hook(self, envelope: MessageEnvelope, exc: BaseException) -> None:
"""
错误钩子:统一异常处理
"""
if isinstance(exc, UserWarning):
# UserWarning 是预期的流程控制,只记录 debug 日志
logger.debug(f"消息处理流程控制: {exc}")
else:
message_id = envelope.get("message_info", {}).get("message_id", "UNKNOWN")
logger.error(f"处理消息 {message_id} 时出错: {exc}", exc_info=True)
async def _handle_adapter_response_route(self, envelope: MessageEnvelope) -> MessageEnvelope | None:
"""
处理适配器响应消息的路由处理器
"""
message_segment = envelope.get("message_segment")
if message_segment and isinstance(message_segment, dict):
seg_data = message_segment.get("data")
if isinstance(seg_data, dict):
await self._handle_adapter_response(seg_data)
return None
async def _handle_normal_message(self, envelope: MessageEnvelope) -> MessageEnvelope | None:
"""
默认消息处理器:处理普通消息
1. 获取或创建聊天流
2. 转换为 DatabaseMessages
3. 过滤检查
4. 命令处理
5. 触发事件、存储、情绪更新
"""
try:
# 在将消息添加到管理器之前进行最终的静默检查
message_info = envelope.get("message_info")
if not isinstance(message_info, dict):
return None
# 获取用户和群组信息
group_info = message_info.get("group_info")
user_info = message_info.get("user_info")
# 获取或创建聊天流
platform = message_info.get("platform", "unknown")
chat = await get_chat_manager().get_or_create_stream(
platform=platform,
user_info=user_info, # type: ignore
group_info=group_info,
)
# 将消息信封转换为 DatabaseMessages
from src.chat.message_receive.message_processor import process_message_from_dict
message = await process_message_from_dict(
message_dict=envelope,
stream_id=chat.stream_id,
platform=chat.platform
)
# 填充聊天流时间信息
message.chat_info.create_time = chat.create_time
message.chat_info.last_active_time = chat.last_active_time
# 注册消息到聊天管理器
get_chat_manager().register_message(message)
# 检测是否提及机器人
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
# 打印接收日志
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
user_nickname = message.user_info.user_nickname if message.user_info else "未知用户"
logger.info(f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m")
# 硬编码过滤
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
processed_text = message.processed_plain_text or ""
if any(keyword in processed_text for keyword in failure_keywords):
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
return None
# 过滤检查
raw_text = message.display_message or message.processed_plain_text or ""
if _check_ban_words(processed_text, chat, user_info) or _check_ban_regex(
raw_text, chat, user_info
):
return None
# 处理命令和后续流程
await self._process_commands(message, chat)
except UserWarning as uw:
logger.info(str(uw))
except Exception as e:
logger.error(f"处理消息时出错: {e}")
logger.error(traceback.format_exc())
return None
# 保留旧的 process_message 方法用于向后兼容
async def process_message(self, envelope: MessageEnvelope) -> None:
"""
处理接收到的消息信封(向后兼容)
注意:此方法已被 MessageRuntime 路由取代。
如果直接调用此方法,它会委托给 runtime.handle_message()。
Args:
envelope: 消息信封(来自适配器)
"""
if self._runtime:
await self._runtime.handle_message(envelope)
else:
# 如果 runtime 未设置,使用旧的处理流程
await self._handle_normal_message(envelope)
async def _process_commands(self, message: DatabaseMessages, chat: "ChatStream") -> None:
"""处理命令和继续消息流程"""
try:
# 首先尝试 PlusCommand
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat)
if is_plus_command and not plus_continue_process:
await MessageStorage.store_message(message, chat)
logger.info(f"PlusCommand处理完成跳过后续消息处理: {plus_cmd_result}")
return
# 如果不是 PlusCommand尝试传统 BaseCommand
if not is_plus_command:
is_command, cmd_result, continue_process = await self._process_base_commands(message, chat)
if is_command and not continue_process:
await MessageStorage.store_message(message, chat)
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
return
# 触发消息事件
result = await event_manager.trigger_event(
EventType.ON_MESSAGE,
permission_group="SYSTEM",
message=message
)
if result and not result.all_continue_process():
raise UserWarning(
f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理"
)
# 预处理消息
await self._preprocess_message(message, chat)
except UserWarning as uw:
logger.info(str(uw))
except Exception as e:
logger.error(f"处理命令时出错: {e}")
logger.error(traceback.format_exc())
async def _process_plus_commands(
self,
message: DatabaseMessages,
chat: "ChatStream"
) -> tuple[bool, Any, bool]:
"""处理 PlusCommand 系统"""
try:
text = message.processed_plain_text or ""
# 获取配置的命令前缀
prefixes = global_config.command.command_prefixes
# 检查是否以任何前缀开头
matched_prefix = None
for prefix in prefixes:
if text.startswith(prefix):
matched_prefix = prefix
break
if not matched_prefix:
return False, None, True
# 移除前缀
command_part = text[len(matched_prefix):].strip()
# 分离命令名和参数
parts = command_part.split(None, 1)
if not parts:
return False, None, True
command_word = parts[0].lower()
args_text = parts[1] if len(parts) > 1 else ""
# 查找匹配的 PlusCommand
plus_command_registry = component_registry.get_plus_command_registry()
matching_commands = []
for plus_command_name, plus_command_class in plus_command_registry.items():
plus_command_info = component_registry.get_registered_plus_command_info(plus_command_name)
if not plus_command_info:
continue
all_commands = [plus_command_name.lower()] + [
alias.lower() for alias in plus_command_info.command_aliases
]
if command_word in all_commands:
matching_commands.append((plus_command_class, plus_command_info, plus_command_name))
if not matching_commands:
return False, None, True
# 按优先级排序
if len(matching_commands) > 1:
matching_commands.sort(key=lambda x: x[1].priority, reverse=True)
plus_command_class, plus_command_info, plus_command_name = matching_commands[0]
# 检查是否被禁用
if (
chat
and chat.stream_id
and plus_command_name in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
):
logger.info("用户禁用的PlusCommand跳过处理")
return False, None, True
message.is_command = True
# 获取插件配置
plugin_config = component_registry.get_plugin_config(plus_command_name)
# 创建实例
plus_command_instance = plus_command_class(message, plugin_config)
setattr(plus_command_instance, "chat_stream", chat)
try:
if not plus_command_instance.is_chat_type_allowed():
is_group = chat.group_info is not None
logger.info(
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
)
return False, None, True
from src.plugin_system.base.command_args import CommandArgs
command_args = CommandArgs(args_text)
plus_command_instance.args = command_args
success, response, intercept_message = await plus_command_instance.execute(command_args)
if success:
logger.info(f"PlusCommand执行成功: {plus_command_class.__name__} (拦截: {intercept_message})")
else:
logger.warning(f"PlusCommand执行失败: {plus_command_class.__name__} - {response}")
return True, response, not intercept_message
except Exception as e:
logger.error(f"执行PlusCommand时出错: {plus_command_class.__name__} - {e}")
logger.error(traceback.format_exc())
try:
await plus_command_instance.send_text(f"命令执行出错: {e!s}")
except Exception:
pass
return True, str(e), False
except Exception as e:
logger.error(f"处理PlusCommand时出错: {e}")
return False, None, True
async def _process_base_commands(
self,
message: DatabaseMessages,
chat: "ChatStream"
) -> tuple[bool, Any, bool]:
"""处理传统 BaseCommand 系统"""
try:
text = message.processed_plain_text or ""
command_result = component_registry.find_command_by_text(text)
if command_result:
command_class, matched_groups, command_info = command_result
plugin_name = command_info.plugin_name
command_name = command_info.name
if (
chat
and chat.stream_id
and command_name in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
):
logger.info("用户禁用的命令,跳过处理")
return False, None, True
message.is_command = True
plugin_config = component_registry.get_plugin_config(plugin_name)
command_instance: BaseCommand = command_class(message, plugin_config)
command_instance.set_matched_groups(matched_groups)
setattr(command_instance, "chat_stream", chat)
try:
if not command_instance.is_chat_type_allowed():
is_group = chat.group_info is not None
logger.info(
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
)
return False, None, True
success, response, intercept_message = await command_instance.execute()
if success:
logger.info(f"命令执行成功: {command_class.__name__} (拦截: {intercept_message})")
else:
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
return True, response, not intercept_message
except Exception as e:
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
logger.error(traceback.format_exc())
try:
await command_instance.send_text(f"命令执行出错: {e!s}")
except Exception:
pass
return True, str(e), False
return False, None, True
except Exception as e:
logger.error(f"处理命令时出错: {e}")
return False, None, True
async def _preprocess_message(self, message: DatabaseMessages, chat: "ChatStream") -> None:
"""预处理消息:存储、情绪更新等"""
try:
group_info = chat.group_info
# 检查是否需要处理消息
should_process_in_manager = True
if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list:
# 检查消息是否为图片或表情包
is_image_or_emoji = message.is_picid or message.is_emoji
if not message.is_mentioned and not is_image_or_emoji:
logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理")
logger.debug(
f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理"
)
should_process_in_manager = False
elif is_image_or_emoji:
logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理")
@@ -43,68 +570,81 @@ class MessageHandler:
await message_manager.add_message(chat.stream_id, message)
logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
except Exception as e:
logger.error(f"消息添加到消息管理器失败: {e}")
# 存储消息
try:
await MessageStorage.store_message(message, chat)
except Exception as e:
logger.error(f"存储消息到数据库失败: {e}")
traceback.print_exc()
# 情绪系统更新
try:
if global_config.mood.enable_mood:
interest_rate = message.interest_value or 0.0
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
await chat_mood.update_mood_by_message(message, interest_rate)
logger.debug("情绪状态更新完成")
except Exception as e:
logger.error(f"更新情绪状态失败: {e}")
traceback.print_exc()
# 存储消息到数据库,只进行一次写入
try:
await MessageStorage.store_message(message, chat)
except Exception as e:
logger.error(f"存储消息到数据库失败: {e}")
logger.error(f"预处理消息失败: {e}")
traceback.print_exc()
# 情绪系统更新 - 在消息存储后触发情绪更新
async def _handle_adapter_response(self, seg_data: dict | None) -> None:
"""处理适配器命令响应"""
try:
if global_config.mood.enable_mood:
# 获取兴趣度用于情绪更新
interest_rate = message.interest_value
if interest_rate is None:
interest_rate = 0.0
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
from src.plugin_system.apis.send_api import put_adapter_response
if isinstance(seg_data, dict):
request_id = seg_data.get("request_id")
response_data = seg_data.get("response")
else:
request_id = None
response_data = None
if request_id and response_data:
logger.debug(f"收到适配器响应request_id={request_id}")
put_adapter_response(request_id, response_data)
else:
logger.warning(
f"适配器响应消息格式不正确: request_id={request_id}, response_data={response_data}"
)
# 获取当前聊天的情绪对象并更新情绪状态
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
await chat_mood.update_mood_by_message(message, interest_rate)
logger.debug("情绪状态更新完成")
except Exception as e:
logger.error(f"更新情绪状态失败: {e}")
traceback.print_exc()
logger.error(f"处理适配器响应时出错: {e}")
async def shutdown(self) -> None:
"""关闭消息处理器"""
self._shutting_down = True
logger.info("MessageHandler 正在关闭...")
async def handle_message(self, envelope: MessageEnvelope):
# 控制握手等消息可能缺少 message_info这里直接跳过避免 KeyError
message_info = envelope.get("message_info")
if not isinstance(message_info, dict):
logger.debug(
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
", ".join(envelope.keys()),
)
return
# 全局单例
_message_handler: MessageHandler | None = None
if message_info.get("group_info") is not None:
message_info["group_info"]["group_id"] = str( # type: ignore
message_info["group_info"]["group_id"] # type: ignore
)
if message_info.get("user_info") is not None:
message_info["user_info"]["user_id"] = str( # type: ignore
message_info["user_info"]["user_id"] # type: ignore
)
group_info = message_info.get("group_info")
user_info = message_info.get("user_info")
def get_message_handler() -> MessageHandler:
"""获取 MessageHandler 单例"""
global _message_handler
if _message_handler is None:
_message_handler = MessageHandler()
return _message_handler
chat_stream = await get_chat_manager().get_or_create_stream(
platform=envelope["platform"], # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
# 生成 DatabaseMessages
from src.chat.message_receive.message_processor import process_message_from_dict
message = await process_message_from_dict(
message_dict=envelope,
stream_id=chat_stream.stream_id,
platform=chat_stream.platform
)
async def shutdown_message_handler() -> None:
"""关闭 MessageHandler"""
global _message_handler
if _message_handler:
await _message_handler.shutdown()
_message_handler = None
__all__ = [
"MessageHandler",
"get_message_handler",
"shutdown_message_handler",
]

View File

@@ -1,13 +1,26 @@
"""
统一消息发送器
重构说明2025-11
- 使用 CoreSinkManager 发送消息,而不是直接通过 WS 连接
- MessageServer 仅作为与旧适配器的兼容层
- 所有发送的消息都通过 CoreSinkManager.send_outgoing() 路由到适配器
"""
import asyncio
import traceback
import time
import uuid
from typing import Any, cast
from rich.traceback import install
from mofox_bus import MessageEnvelope
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import calculate_typing_time, truncate_message
from src.common.logger import get_logger
from src.common.message.api import get_global_api
install(extra_lines=3)
@@ -15,12 +28,30 @@ logger = get_logger("sender")
async def send_message(message: MessageSending, show_log=True) -> bool:
"""合并后的消息发送函数包含WS发送和日志记录"""
"""
合并后的消息发送函数
重构后使用 CoreSinkManager 发送消息,而不是直接调用 MessageServer
Args:
message: 要发送的消息
show_log: 是否显示日志
Returns:
bool: 是否发送成功
"""
message_preview = truncate_message(message.processed_plain_text, max_length=120)
try:
# 直接调用API发送消息
await get_global_api().send_message(message)
# 将 MessageSending 转换为 MessageEnvelope
envelope = _message_sending_to_envelope(message)
# 通过 CoreSinkManager 发送
from src.common.core_sink_manager import get_core_sink_manager
manager = get_core_sink_manager()
await manager.send_outgoing(envelope)
if show_log:
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
@@ -44,7 +75,67 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
except Exception as e:
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {e!s}")
traceback.print_exc()
raise e # 重新抛出其他异常
raise e
def _message_sending_to_envelope(message: MessageSending) -> MessageEnvelope:
"""
将 MessageSending 转换为 MessageEnvelope
Args:
message: MessageSending 对象
Returns:
MessageEnvelope: 消息信封
"""
# 构建消息信息
message_info: dict[str, Any] = {
"message_id": message.message_info.message_id,
"time": message.message_info.time or time.time(),
"platform": message.message_info.platform,
"user_info": {
"user_id": message.message_info.user_info.user_id,
"user_nickname": message.message_info.user_info.user_nickname,
"platform": message.message_info.user_info.platform,
} if message.message_info.user_info else None,
}
# 添加群组信息(如果有)
if message.chat_stream and message.chat_stream.group_info:
message_info["group_info"] = {
"group_id": message.chat_stream.group_info.group_id,
"group_name": message.chat_stream.group_info.group_name,
"platform": message.chat_stream.group_info.group_platform,
}
# 构建消息段
message_segment: dict[str, Any]
if message.message_segment:
message_segment = {
"type": message.message_segment.type,
"data": message.message_segment.data,
}
else:
# 默认为文本消息
message_segment = {
"type": "text",
"data": message.processed_plain_text or "",
}
# 添加回复信息(如果有)
if message.reply_to:
message_segment["reply_to"] = message.reply_to
# 构建消息信封
envelope = cast(MessageEnvelope, {
"id": str(uuid.uuid4()),
"direction": "outgoing",
"platform": message.message_info.platform,
"message_info": message_info,
"message_segment": message_segment,
})
return envelope
class HeartFCSender:

View File

@@ -1,9 +1,9 @@
import asyncio
import time
import traceback
from typing import Any
from typing import Any, TYPE_CHECKING
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.timer_calculator import Timer
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
@@ -14,6 +14,9 @@ from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.component_types import ActionInfo, ComponentType
from src.plugin_system.core.component_registry import component_registry
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("action_manager")
@@ -48,7 +51,7 @@ class ChatterActionManager:
reasoning: str,
cycle_timers: dict,
thinking_id: str,
chat_stream: ChatStream,
chat_stream: "ChatStream",
log_prefix: str,
shutting_down: bool = False,
action_message: DatabaseMessages | None = None,
@@ -476,7 +479,7 @@ class ChatterActionManager:
async def _send_and_store_reply(
self,
chat_stream: ChatStream,
chat_stream: "ChatStream",
response_set,
loop_start_time,
action_message,

View File

@@ -1,35 +0,0 @@
"""
从 src.main 导出 core_sink 的辅助函数
由于 src.main 中实际使用的是 InProcessCoreSink
我们需要创建一个全局访问点
"""
from mofox_bus import CoreSink, InProcessCoreSink
_global_core_sink: CoreSink | None = None
def set_core_sink(sink: CoreSink) -> None:
"""设置全局 core sink"""
global _global_core_sink
_global_core_sink = sink
def get_core_sink() -> CoreSink:
"""获取全局 core sink"""
global _global_core_sink
if _global_core_sink is None:
raise RuntimeError("Core sink 尚未初始化")
return _global_core_sink
async def push_outgoing(envelope) -> None:
"""将消息推送到 core sink 的 outgoing 通道"""
sink = get_core_sink()
push = getattr(sink, "push_outgoing", None)
if push is None:
raise RuntimeError("当前 core sink 不支持 push_outgoing 方法")
await push(envelope)
__all__ = ["set_core_sink", "get_core_sink", "push_outgoing"]

View File

@@ -0,0 +1,401 @@
"""
CoreSink 统一管理器
负责管理 InProcessCoreSink 和 ProcessCoreSink 双实例,
提供统一的消息收发接口,自动维护与适配器子进程的通信管道。
核心职责:
1. 创建和管理 InProcessCoreSink进程内消息和 ProcessCoreSink跨进程消息
2. 自动维护 ProcessCoreSink 与子进程的通信管道
3. 使用 MessageRuntime 进行消息路由和处理
4. 提供统一的消息发送接口
架构说明2025-11 重构):
- 集成 mofox_bus.MessageRuntime 作为消息路由中心
- 使用 @runtime.on_message() 装饰器注册消息处理器
- 利用 before_hook/after_hook/error_hook 处理前置/后置/错误逻辑
- 简化消息处理链条,提高可扩展性
"""
from __future__ import annotations
import asyncio
import contextlib
import multiprocessing as mp
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
from mofox_bus import (
InProcessCoreSink,
MessageEnvelope,
MessageRuntime,
ProcessCoreSinkServer,
)
from src.common.logger import get_logger
if TYPE_CHECKING:
from chat.message_receive.message_handler import MessageHandler
logger = get_logger("core_sink_manager")
# 消息处理器类型
MessageHandlerCallback = Callable[[MessageEnvelope], Awaitable[None]]
class CoreSinkManager:
"""
CoreSink 统一管理器
管理 InProcessCoreSink 和 ProcessCoreSinkServer 双实例,
集成 MessageRuntime 提供统一的消息路由和收发接口。
架构说明:
- InProcessCoreSink: 用于同进程内的适配器run_in_subprocess=False
- ProcessCoreSinkServer: 用于管理与子进程适配器的通信
- MessageRuntime: 统一消息路由,支持 @on_message 装饰器和钩子机制
消息流向:
1. 适配器(同进程)→ InProcessCoreSink → MessageRuntime.handle_message() → 注册的处理器
2. 适配器(子进程)→ ProcessCoreSinkServer → MessageRuntime.handle_message() → 注册的处理器
3. 核心回复 → CoreSinkManager.send_outgoing → 适配器
使用 MessageRuntime 的优势:
- 支持 @runtime.on_message(message_type="xxx") 按消息类型路由
- 支持 before_hook/after_hook/error_hook 统一处理流程
- 支持中间件机制(洋葱模型)
- 自动处理同步/异步处理器
"""
def __init__(self):
# MessageRuntime 实例
self._runtime: MessageRuntime = MessageRuntime()
# InProcessCoreSink 实例(用于同进程适配器)
self._in_process_sink: InProcessCoreSink | None = None
# 子进程通信管理
# key: adapter_name, value: (ProcessCoreSinkServer, incoming_queue, outgoing_queue)
self._process_sinks: Dict[str, tuple[ProcessCoreSinkServer, mp.Queue, mp.Queue]] = {}
# multiprocessing context
self._mp_ctx = mp.get_context("spawn")
# 运行状态
self._running = False
self._initialized = False
@property
def runtime(self) -> MessageRuntime:
"""
获取 MessageRuntime 实例
外部模块可以通过此属性注册消息处理器、钩子等:
```python
manager = get_core_sink_manager()
# 注册消息处理器
@manager.runtime.on_message(message_type="text")
async def handle_text(envelope: MessageEnvelope):
...
# 注册前置钩子
manager.runtime.register_before_hook(my_before_hook)
```
Returns:
MessageRuntime 实例
"""
return self._runtime
async def initialize(self) -> None:
"""
初始化 CoreSink 管理器
创建 InProcessCoreSink将收到的消息交给 MessageRuntime 处理。
"""
if self._initialized:
logger.warning("CoreSinkManager 已经初始化,跳过重复初始化")
return
logger.info("正在初始化 CoreSink 管理器...")
# 创建 InProcessCoreSink使用 MessageRuntime 作为消息处理入口
self._in_process_sink = InProcessCoreSink(self._dispatch_to_runtime)
self._running = True
self._initialized = True
logger.info("CoreSink 管理器初始化完成(已集成 MessageRuntime")
async def shutdown(self) -> None:
"""关闭 CoreSink 管理器"""
if not self._running:
return
logger.info("正在关闭 CoreSink 管理器...")
self._running = False
# 关闭所有 ProcessCoreSinkServer
for adapter_name, (server, _, _) in list(self._process_sinks.items()):
try:
await server.close()
logger.info(f"已关闭适配器 {adapter_name} 的 ProcessCoreSinkServer")
except Exception as e:
logger.error(f"关闭适配器 {adapter_name} 的 ProcessCoreSinkServer 时出错: {e}")
self._process_sinks.clear()
# 关闭 InProcessCoreSink
if self._in_process_sink:
await self._in_process_sink.close()
self._in_process_sink = None
self._initialized = False
logger.info("CoreSink 管理器已关闭")
def get_in_process_sink(self) -> InProcessCoreSink:
"""
获取 InProcessCoreSink 实例
用于同进程运行的适配器
Returns:
InProcessCoreSink 实例
Raises:
RuntimeError: 如果管理器未初始化
"""
if self._in_process_sink is None:
raise RuntimeError("CoreSinkManager 未初始化,请先调用 initialize()")
return self._in_process_sink
def create_process_sink_queues(self, adapter_name: str) -> tuple[mp.Queue, mp.Queue]:
"""
为子进程适配器创建通信队列
创建 incoming 和 outgoing 队列对,用于与子进程适配器通信。
同时创建 ProcessCoreSinkServer 来处理消息转发。
Args:
adapter_name: 适配器名称
Returns:
(to_core_queue, from_core_queue) 元组
- to_core_queue: 子进程发送到核心的队列
- from_core_queue: 核心发送到子进程的队列
Raises:
RuntimeError: 如果管理器未初始化
"""
if not self._initialized:
raise RuntimeError("CoreSinkManager 未初始化,请先调用 initialize()")
if adapter_name in self._process_sinks:
logger.warning(f"适配器 {adapter_name} 的队列已存在,将被覆盖")
# 先关闭旧的
old_server, _, _ = self._process_sinks[adapter_name]
asyncio.create_task(old_server.close())
# 创建通信队列
incoming_queue = self._mp_ctx.Queue() # 子进程 → 核心
outgoing_queue = self._mp_ctx.Queue() # 核心 → 子进程
# 创建 ProcessCoreSinkServer使用 MessageRuntime 处理消息
server = ProcessCoreSinkServer(
incoming_queue=incoming_queue,
outgoing_queue=outgoing_queue,
core_handler=self._dispatch_to_runtime,
name=adapter_name,
)
# 启动服务器
server.start()
# 存储引用
self._process_sinks[adapter_name] = (server, incoming_queue, outgoing_queue)
logger.info(f"为适配器 {adapter_name} 创建了 ProcessCoreSink 通信队列")
return incoming_queue, outgoing_queue
def remove_process_sink(self, adapter_name: str) -> None:
"""
移除子进程适配器的通信队列
Args:
adapter_name: 适配器名称
"""
if adapter_name not in self._process_sinks:
logger.warning(f"适配器 {adapter_name} 的队列不存在")
return
server, _, _ = self._process_sinks.pop(adapter_name)
asyncio.create_task(server.close())
logger.info(f"已移除适配器 {adapter_name} 的 ProcessCoreSink 通信队列")
async def send_outgoing(
self,
envelope: MessageEnvelope,
platform: str | None = None,
adapter_name: str | None = None
) -> None:
"""
发送消息到适配器
根据 platform 或 adapter_name 路由到正确的适配器。
Args:
envelope: 消息信封
platform: 目标平台(可选)
adapter_name: 目标适配器名称(可选)
路由规则:
1. 如果指定了 adapter_name直接发送到该适配器
2. 如果指定了 platform发送到所有匹配平台的适配器
3. 如果都没指定,从 envelope 中提取 platform 并广播
"""
# 从 envelope 中获取 platform
if platform is None:
platform = envelope.get("platform") or envelope.get("message_info", {}).get("platform")
# 发送到 InProcessCoreSink会自动广播到所有注册的 outgoing handler
if self._in_process_sink:
await self._in_process_sink.push_outgoing(envelope)
# 发送到所有 ProcessCoreSinkServer
for name, (server, _, _) in self._process_sinks.items():
if adapter_name and name != adapter_name:
continue
try:
await server.push_outgoing(envelope)
except Exception as e:
logger.error(f"发送消息到适配器 {name} 失败: {e}")
async def _dispatch_to_runtime(self, envelope: MessageEnvelope) -> None:
"""
将消息分发给 MessageRuntime 处理
这是内部方法,由 InProcessCoreSink 和 ProcessCoreSinkServer 调用。
所有从适配器接收到的消息都会经过这里,然后交给 MessageRuntime 路由。
Args:
envelope: 消息信封
"""
if not self._running:
logger.warning("CoreSinkManager 未运行,忽略接收到的消息")
return
try:
# 使用 MessageRuntime 处理消息
await self._runtime.handle_message(envelope)
except Exception as e:
logger.error(f"MessageRuntime 处理消息时出错: {e}", exc_info=True)
# 全局单例
_core_sink_manager: CoreSinkManager | None = None
def get_core_sink_manager() -> CoreSinkManager:
"""获取 CoreSinkManager 单例"""
global _core_sink_manager
if _core_sink_manager is None:
_core_sink_manager = CoreSinkManager()
return _core_sink_manager
def get_message_runtime() -> MessageRuntime:
"""
获取全局 MessageRuntime 实例
这是获取 MessageRuntime 的推荐方式,用于注册消息处理器、钩子等:
```python
from src.common.core_sink_manager import get_message_runtime
runtime = get_message_runtime()
@runtime.on_message(message_type="text")
async def handle_text(envelope: MessageEnvelope):
...
```
Returns:
MessageRuntime 实例
"""
return get_core_sink_manager().runtime
async def initialize_core_sink_manager() -> CoreSinkManager:
"""
初始化 CoreSinkManager 单例
Returns:
初始化后的 CoreSinkManager 实例
"""
manager = get_core_sink_manager()
await manager.initialize()
return manager
async def shutdown_core_sink_manager() -> None:
"""关闭 CoreSinkManager 单例"""
global _core_sink_manager
if _core_sink_manager:
await _core_sink_manager.shutdown()
_core_sink_manager = None
# ============================================================================
# 向后兼容的 API
# ============================================================================
def get_core_sink() -> InProcessCoreSink:
"""
获取 InProcessCoreSink 实例(向后兼容)
这是旧版 API推荐使用 get_core_sink_manager().get_in_process_sink()
Returns:
InProcessCoreSink 实例
"""
return get_core_sink_manager().get_in_process_sink()
def set_core_sink(sink: Any) -> None:
"""
设置 CoreSink向后兼容现已弃用
新架构中 CoreSink 由 CoreSinkManager 统一管理,不再支持外部设置。
此函数保留仅为兼容旧代码,调用会记录警告日志。
"""
logger.warning(
"set_core_sink() 已弃用CoreSink 现由 CoreSinkManager 统一管理。"
"请使用 initialize_core_sink_manager() 初始化。"
)
async def push_outgoing(envelope: MessageEnvelope) -> None:
"""
将消息推送到所有适配器(向后兼容)
Args:
envelope: 消息信封
"""
manager = get_core_sink_manager()
await manager.send_outgoing(envelope)
__all__ = [
"CoreSinkManager",
"get_core_sink_manager",
"get_message_runtime",
"initialize_core_sink_manager",
"shutdown_core_sink_manager",
# 向后兼容
"get_core_sink",
"set_core_sink",
"push_outgoing",
]

View File

@@ -16,6 +16,24 @@ class DatabaseUserInfo(BaseDataModel):
user_nickname: str = field(default_factory=str) # 用户昵称
user_cardname: str | None = None # 用户备注名或群名片,可为空
@classmethod
def from_dict(cls, data: dict) -> "DatabaseUserInfo":
"""从字典创建实例"""
return cls(
platform=data.get("platform", ""),
user_id=data.get("user_id", ""),
user_nickname=data.get("user_nickname", ""),
user_cardname=data.get("user_cardname"),
)
def to_dict(self) -> dict:
"""将实例转换为字典"""
return {
"platform": self.platform,
"user_id": self.user_id,
"user_nickname": self.user_nickname,
"user_cardname": self.user_cardname,
}
@dataclass
class DatabaseGroupInfo(BaseDataModel):
@@ -26,7 +44,23 @@ class DatabaseGroupInfo(BaseDataModel):
group_name: str = field(default_factory=str) # 群组名称
group_platform: str | None = None # 群组所在平台,可为空
@classmethod
def from_dict(cls, data: dict) -> "DatabaseGroupInfo":
"""从字典创建实例"""
return cls(
group_id=data.get("group_id", ""),
group_name=data.get("group_name", ""),
group_platform=data.get("group_platform"),
)
def to_dict(self) -> dict:
"""将实例转换为字典"""
return {
"group_id": self.group_id,
"group_name": self.group_name,
"group_platform": self.group_platform,
}
@dataclass
class DatabaseChatInfo(BaseDataModel):
"""

View File

@@ -6,17 +6,21 @@ import sys
import time
import traceback
from collections.abc import Callable, Coroutine
from functools import partial
from random import choices
from typing import Any
from mofox_bus import InProcessCoreSink, MessageEnvelope
from rich.traceback import install
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.message_receive.bot import chat_bot
from chat.message_receive.message_handler import get_message_handler, shutdown_message_handler
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
from src.common.core_sink_manager import (
CoreSinkManager,
get_core_sink_manager,
initialize_core_sink_manager,
shutdown_core_sink_manager,
)
from src.common.logger import get_logger
# 全局背景任务集合
@@ -55,28 +59,15 @@ EGG_PHRASES: list[tuple[str, int]] = [
]
def _task_done_callback(task: asyncio.Task, message_id: str, start_time: float) -> None:
"""后台任务完成时的回调函数"""
end_time = time.time()
duration = end_time - start_time
try:
task.result() # 如果任务有异常,这里会重新抛出
logger.debug(f"消息 {message_id} 的后台任务 (ID: {id(task)}) 已成功完成, 耗时: {duration:.2f}s")
except asyncio.CancelledError:
logger.warning(f"消息 {message_id} 的后台任务 (ID: {id(task)}) 被取消, 耗时: {duration:.2f}s")
except Exception:
logger.error(f"处理消息 {message_id} 的后台任务 (ID: {id(task)}) 出现未捕获的异常, 耗时: {duration:.2f}s:")
logger.error(traceback.format_exc())
class MainSystem:
"""主系统类,负责协调所有组件"""
def __init__(self) -> None:
self.individuality: Individuality = get_individuality()
# 创建核心消息接收器
self.core_sink: InProcessCoreSink = InProcessCoreSink(self._message_process_wrapper)
# CoreSinkManager 和 MessageHandler 将在 initialize() 中创建
self.core_sink_manager: CoreSinkManager | None = None
self.message_handler = None
# 使用服务器
self.server: Server = get_global_server()
@@ -163,10 +154,11 @@ class MainSystem:
continue
try:
from src.plugin_system.base.component_types import ComponentType as CT
from src.plugin_system.core.component_registry import component_registry
component_class = component_registry.get_component_class(
calc_name, ComponentType.INTEREST_CALCULATOR
calc_name, CT.INTEREST_CALCULATOR
)
if not component_class:
@@ -299,6 +291,18 @@ class MainSystem:
except Exception as e:
logger.error(f"准备停止适配器管理器时出错: {e}")
# 停止 CoreSinkManager
try:
cleanup_tasks.append(("CoreSinkManager", shutdown_core_sink_manager()))
except Exception as e:
logger.error(f"准备停止 CoreSinkManager 时出错: {e}")
# 停止 MessageHandler
try:
cleanup_tasks.append(("MessageHandler", shutdown_message_handler()))
except Exception as e:
logger.error(f"准备停止 MessageHandler 时出错: {e}")
# 并行执行所有清理任务
if cleanup_tasks:
logger.info(f"开始并行执行 {len(cleanup_tasks)} 个清理任务...")
@@ -352,27 +356,6 @@ class MainSystem:
except Exception as e:
logger.error(f"同步清理资源时出错: {e}")
async def _message_process_wrapper(self, envelope: MessageEnvelope) -> None:
"""并行处理消息的包装器"""
try:
start_time = time.time()
message_id = envelope.get("message_info", {}).get("message_id", "UNKNOWN")
# 检查系统是否正在关闭
if self._shutting_down:
logger.warning(f"系统正在关闭,拒绝处理消息 {message_id}")
return
# 创建后台任务
task = asyncio.create_task(chat_bot.message_process(envelope))
logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})")
# 添加一个回调函数,当任务完成时,它会被调用
task.add_done_callback(partial(_task_done_callback, message_id=message_id, start_time=start_time))
except Exception:
logger.error("在创建消息处理任务时发生严重错误:")
logger.error(traceback.format_exc())
async def initialize(self) -> None:
"""初始化系统组件"""
# 检查必要的配置
@@ -382,6 +365,18 @@ class MainSystem:
logger.info(f"正在唤醒{global_config.bot.nickname}......")
# 初始化 CoreSinkManager包含 MessageRuntime
logger.info("正在初始化 CoreSinkManager...")
self.core_sink_manager = await initialize_core_sink_manager()
# 获取 MessageHandler 并向 MessageRuntime 注册处理器
self.message_handler = get_message_handler()
self.message_handler.set_core_sink_manager(self.core_sink_manager)
# 向 MessageRuntime 注册消息处理器和钩子
self.message_handler.register_handlers(self.core_sink_manager.runtime)
logger.info("CoreSinkManager 和 MessageHandler 初始化完成(使用 MessageRuntime 路由)")
# 初始化组件
await self._init_components()
@@ -453,7 +448,11 @@ MoFox_Bot(第三方修改版)
logger.error(f"统一调度器初始化失败: {e}")
# 设置核心消息接收器到插件管理器
plugin_manager.set_core_sink(self.core_sink)
# 使用 CoreSinkManager 的 InProcessCoreSink
if self.core_sink_manager:
plugin_manager.set_core_sink(self.core_sink_manager.get_in_process_sink())
else:
logger.error("CoreSinkManager 未初始化,无法设置核心消息接收器")
# 加载所有插件
plugin_manager.load_all_plugins()
@@ -505,8 +504,8 @@ MoFox_Bot(第三方修改版)
except Exception as e:
logger.error(f"LPMM知识库初始化失败: {e}")
# 消息接收器已__init__ 中创建,无需再次注册
logger.info("核心消息接收器已就绪")
# 消息接收器已在 initialize() 中通过 CoreSinkManager 创建
logger.info("核心消息接收器已就绪(通过 CoreSinkManager")
# 启动消息重组器
try:

View File

@@ -1,79 +0,0 @@
"""
MoFox 内部通用消息总线实现。
该模块导出 TypedDict 消息模型、序列化工具、传输层封装以及适配器辅助工具,
供核心进程与各类平台适配器共享。
"""
from . import codec, types
from .adapter_utils import (
AdapterTransportOptions,
AdapterBase,
BatchDispatcher,
CoreSink,
CoreMessageSink,
HttpAdapterOptions,
InProcessCoreSink,
ProcessCoreSink,
ProcessCoreSinkServer,
WebSocketLike,
WebSocketAdapterOptions,
)
from .api import MessageClient, MessageServer
from .codec import dumps_message, dumps_messages, loads_message, loads_messages
from .builder import MessageBuilder
from .router import RouteConfig, Router, TargetConfig
from .runtime import MessageProcessingError, MessageRoute, MessageRuntime, Middleware
from .types import (
FormatInfoPayload,
GroupInfoPayload,
MessageDirection,
MessageEnvelope,
MessageInfoPayload,
SegPayload,
TemplateInfoPayload,
UserInfoPayload,
)
__all__ = [
# TypedDict model
"MessageDirection",
"MessageEnvelope",
"SegPayload",
"UserInfoPayload",
"GroupInfoPayload",
"FormatInfoPayload",
"TemplateInfoPayload",
"MessageInfoPayload",
# Codec helpers
"codec",
"dumps_message",
"dumps_messages",
"loads_message",
"loads_messages",
"MessageBuilder",
# Runtime / routing
"MessageRoute",
"MessageRuntime",
"MessageProcessingError",
"Middleware",
# Server/client/router
"MessageServer",
"MessageClient",
"Router",
"RouteConfig",
"TargetConfig",
# Adapter helpers
"AdapterTransportOptions",
"AdapterBase",
"BatchDispatcher",
"CoreSink",
"CoreMessageSink",
"InProcessCoreSink",
"ProcessCoreSink",
"ProcessCoreSinkServer",
"WebSocketLike",
"WebSocketAdapterOptions",
"HttpAdapterOptions",
]

View File

@@ -1,485 +0,0 @@
from __future__ import annotations
import asyncio
import contextlib
import logging
import multiprocessing as mp
from dataclasses import dataclass
from typing import Any, AsyncIterator, Awaitable, Callable, Protocol
import orjson
from aiohttp import web as aiohttp_web
import websockets
from .types import MessageEnvelope
logger = logging.getLogger("mofox_bus.adapter")
OutgoingHandler = Callable[[MessageEnvelope], Awaitable[None]]
class CoreMessageSink(Protocol):
async def send(self, message: MessageEnvelope) -> None: ...
async def send_many(self, messages: list[MessageEnvelope]) -> None: ... # pragma: no cover - optional
class CoreSink(CoreMessageSink, Protocol):
"""
双向 CoreSink 协议:
- send/send_many: 适配器 → 核心incoming
- push_outgoing: 核心 → 适配器outgoing
"""
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None: ...
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None: ...
async def push_outgoing(self, envelope: MessageEnvelope) -> None: ...
async def close(self) -> None: ... # pragma: no cover - lifecycle hook
class WebSocketLike(Protocol):
def __aiter__(self) -> AsyncIterator[str | bytes]: ...
@property
def closed(self) -> bool: ...
async def send(self, data: str | bytes) -> None: ...
async def close(self) -> None: ...
@dataclass
class WebSocketAdapterOptions:
url: str
headers: dict[str, str] | None = None
incoming_parser: Callable[[str | bytes], Any] | None = None
outgoing_encoder: Callable[[MessageEnvelope], str | bytes] | None = None
@dataclass
class HttpAdapterOptions:
host: str = "0.0.0.0"
port: int = 8089
path: str = "/adapter/messages"
app: aiohttp_web.Application | None = None
AdapterTransportOptions = WebSocketAdapterOptions | HttpAdapterOptions | None
class AdapterBase:
"""
适配器基类:负责平台原始消息与 MessageEnvelope 之间的互转。
子类需要实现平台入站解析与出站发送逻辑。
"""
platform: str = "unknown"
def __init__(self, core_sink: CoreSink, transport: AdapterTransportOptions = None):
"""
Args:
core_sink: 核心消息入口,通常是 InProcessCoreSink 或自定义客户端。
transport: 传入 WebSocketAdapterOptions / HttpAdapterOptions 即可自动管理监听逻辑。
"""
self.core_sink = core_sink
self._transport_config = transport
self._ws: WebSocketLike | None = None
self._ws_task: asyncio.Task | None = None
self._http_runner: aiohttp_web.AppRunner | None = None
self._http_site: aiohttp_web.BaseSite | None = None
async def start(self) -> None:
"""启动适配器的传输层监听(如果配置了传输选项)。"""
if hasattr(self.core_sink, "set_outgoing_handler"):
try:
self.core_sink.set_outgoing_handler(self._on_outgoing_from_core)
except Exception:
logger.exception("注册 outgoing 处理程序到核心接收器失败")
if isinstance(self._transport_config, WebSocketAdapterOptions):
await self._start_ws_transport(self._transport_config)
elif isinstance(self._transport_config, HttpAdapterOptions):
await self._start_http_transport(self._transport_config)
async def stop(self) -> None:
"""停止适配器的传输层监听(如果配置了传输选项)。"""
remove = getattr(self.core_sink, "remove_outgoing_handler", None)
if callable(remove):
try:
remove(self._on_outgoing_from_core)
except Exception:
logger.exception("从核心接收器分离 outgoing 处理程序失败")
elif hasattr(self.core_sink, "set_outgoing_handler"):
try:
self.core_sink.set_outgoing_handler(None) # type: ignore[arg-type]
except Exception:
logger.exception("从核心接收器分离 outgoing 处理程序失败")
if self._ws_task:
self._ws_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._ws_task
self._ws_task = None
if self._ws:
await self._ws.close()
self._ws = None
if self._http_site:
await self._http_site.stop()
self._http_site = None
if self._http_runner:
await self._http_runner.cleanup()
self._http_runner = None
async def on_platform_message(self, raw: Any) -> None:
"""处理平台下发的单条消息并交给核心。"""
envelope = await self.from_platform_message(raw)
await self.core_sink.send(envelope)
async def on_platform_messages(self, raw_messages: list[Any]) -> None:
"""批量推送入口,内部自动批量或逐条送入核心。"""
envelopes = [await self.from_platform_message(raw) for raw in raw_messages]
await _send_many(self.core_sink, envelopes)
async def send_to_platform(self, envelope: MessageEnvelope) -> None:
"""核心生成单条消息时调用,由子类或自动传输层发送。"""
await self._send_platform_message(envelope)
async def send_batch_to_platform(self, envelopes: list[MessageEnvelope]) -> None:
"""默认串行发送整批消息,子类可根据平台特性重写。"""
for env in envelopes:
await self._send_platform_message(env)
async def _on_outgoing_from_core(self, envelope: MessageEnvelope) -> None:
"""核心生成 outgoing envelope 时的内部处理逻辑"""
platform = envelope.get("platform") or envelope.get("message_info", {}).get("platform")
if platform and platform != getattr(self, "platform", None):
return
await self._send_platform_message(envelope)
async def from_platform_message(self, raw: Any) -> MessageEnvelope:
"""子类必须实现:将平台原始结构转换为统一 MessageEnvelope。"""
raise NotImplementedError
async def _send_platform_message(self, envelope: MessageEnvelope) -> None:
"""子类必须实现:把 MessageEnvelope 转为平台格式并发送出去。"""
if isinstance(self._transport_config, WebSocketAdapterOptions):
await self._send_via_ws(envelope)
return
raise NotImplementedError
async def _start_ws_transport(self, options: WebSocketAdapterOptions) -> None:
self._ws = await websockets.connect(options.url, extra_headers=options.headers)
self._ws_task = asyncio.create_task(self._ws_listen_loop(options))
async def _ws_listen_loop(self, options: WebSocketAdapterOptions) -> None:
assert self._ws is not None
parser = options.incoming_parser or self._default_ws_parser
try:
async for raw in self._ws:
payload = parser(raw)
await self.on_platform_message(payload)
finally:
pass
async def _send_via_ws(self, envelope: MessageEnvelope) -> None:
if self._ws is None or self._ws.closed:
raise RuntimeError("WebSocket transport is not active")
encoder = None
if isinstance(self._transport_config, WebSocketAdapterOptions):
encoder = self._transport_config.outgoing_encoder
data = encoder(envelope) if encoder else self._default_ws_encoder(envelope)
await self._ws.send(data)
async def _start_http_transport(self, options: HttpAdapterOptions) -> None:
app = options.app or aiohttp_web.Application()
app.add_routes([aiohttp_web.post(options.path, self._handle_http_request)])
self._http_runner = aiohttp_web.AppRunner(app)
await self._http_runner.setup()
self._http_site = aiohttp_web.TCPSite(self._http_runner, options.host, options.port)
await self._http_site.start()
async def _handle_http_request(self, request: aiohttp_web.Request) -> aiohttp_web.Response:
raw = await request.read()
data = orjson.loads(raw) if raw else {}
if isinstance(data, list):
await self.on_platform_messages(data)
else:
await self.on_platform_message(data)
return aiohttp_web.json_response({"status": "ok"})
@staticmethod
def _default_ws_parser(raw: str | bytes) -> Any:
data = orjson.loads(raw)
if isinstance(data, dict) and data.get("type") == "message" and "payload" in data:
return data["payload"]
return data
@staticmethod
def _default_ws_encoder(envelope: MessageEnvelope) -> bytes:
return orjson.dumps({"type": "send", "payload": envelope})
class InProcessCoreSink(CoreSink):
"""
进程内核心消息 sink实现 CoreSink 协议。
"""
def __init__(self, handler: Callable[[MessageEnvelope], Awaitable[None]]):
self._handler = handler
self._outgoing_handlers: set[OutgoingHandler] = set()
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None:
if handler is None:
return
self._outgoing_handlers.add(handler)
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None:
self._outgoing_handlers.discard(handler)
async def send(self, message: MessageEnvelope) -> None:
await self._handler(message)
async def send_many(self, messages: list[MessageEnvelope]) -> None:
for message in messages:
await self._handler(message)
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
if not self._outgoing_handlers:
logger.debug("Outgoing envelope dropped: no handler registered")
return
for callback in list(self._outgoing_handlers):
await callback(envelope)
async def close(self) -> None: # pragma: no cover - symmetry
self._outgoing_handlers.clear()
class ProcessCoreSink(CoreSink):
"""
进程间核心消息 sink实现 CoreSink 协议,使用 multiprocessing.Queue 初始化
"""
_CONTROL_STOP = {"__core_sink_control__": "stop"}
def __init__(self, *, to_core_queue: mp.Queue, from_core_queue: mp.Queue) -> None:
self._to_core_queue = to_core_queue
self._from_core_queue = from_core_queue
self._outgoing_handler: OutgoingHandler | None = None
self._closed = False
self._listener_task: asyncio.Task | None = None
self._loop = asyncio.get_event_loop()
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None:
self._outgoing_handler = handler
if handler is not None and (self._listener_task is None or self._listener_task.done()):
self._listener_task = self._loop.create_task(self._listen_from_core())
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None:
if self._outgoing_handler is handler:
self._outgoing_handler = None
if self._listener_task and not self._listener_task.done():
self._listener_task.cancel()
async def send(self, message: MessageEnvelope) -> None:
await asyncio.to_thread(self._to_core_queue.put, {"kind": "incoming", "payload": message})
async def send_many(self, messages: list[MessageEnvelope]) -> None:
for message in messages:
await self.send(message)
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
logger.debug("ProcessCoreSink.push_outgoing 在子进程中调用; 被忽略")
async def close(self) -> None:
if self._closed:
return
self._closed = True
await asyncio.to_thread(self._from_core_queue.put, self._CONTROL_STOP)
if self._listener_task:
self._listener_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._listener_task
self._listener_task = None
async def _listen_from_core(self) -> None:
while not self._closed:
try:
item = await asyncio.to_thread(self._from_core_queue.get)
except asyncio.CancelledError:
break
if item == self._CONTROL_STOP:
break
if isinstance(item, dict) and item.get("kind") == "outgoing":
envelope = item.get("payload")
if self._outgoing_handler:
try:
await self._outgoing_handler(envelope)
except Exception: # pragma: no cover
logger.exception("处理 ProcessCoreSink 中的 outgoing 信封失败")
else:
logger.debug(f"ProcessCoreSink 接受到未知负载: {item}")
class ProcessCoreSinkServer:
"""
进程间核心消息 sink 服务器,实现 CoreSink 协议,使用 multiprocessing.Queue 初始化。
- 将传入的 incoming 消息转发给指定的 handler
- 将接收到的 outgoing 消息放入 outgoing 队列
"""
def __init__(
self,
*,
incoming_queue: mp.Queue,
outgoing_queue: mp.Queue,
core_handler: Callable[[MessageEnvelope], Awaitable[None]],
name: str | None = None,
) -> None:
self._incoming_queue = incoming_queue
self._outgoing_queue = outgoing_queue
self._core_handler = core_handler
self._task: asyncio.Task | None = None
self._closed = False
self._name = name or "adapter"
def start(self) -> None:
if self._task is None or self._task.done():
self._task = asyncio.create_task(self._consume_incoming())
async def _consume_incoming(self) -> None:
while not self._closed:
try:
item = await asyncio.to_thread(self._incoming_queue.get)
except asyncio.CancelledError:
break
if isinstance(item, dict) and item.get("__core_sink_control__") == "stop":
break
if isinstance(item, dict) and item.get("kind") == "incoming":
envelope = item.get("payload")
try:
await self._core_handler(envelope)
except Exception: # pragma: no cover
logger.exception(f"处理来自 {self._name} 的 incoming 信封时失败")
else:
logger.debug(f"ProcessCoreSinkServer 忽略来自 {self._name} 的未知负载: {item}")
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
await asyncio.to_thread(self._outgoing_queue.put, {"kind": "outgoing", "payload": envelope})
async def close(self) -> None:
if self._closed:
return
self._closed = True
await asyncio.to_thread(self._incoming_queue.put, {"__core_sink_control__": "stop"})
await asyncio.to_thread(self._outgoing_queue.put, ProcessCoreSink._CONTROL_STOP)
if self._task:
self._task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._task
self._task = None
async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) -> None:
send_many = getattr(sink, "send_many", None)
if callable(send_many):
await send_many(envelopes)
return
for env in envelopes:
await sink.send(env)
class BatchDispatcher:
"""
批量消息分发器,负责将消息批量发送到核心 sink。
"""
_STOP = object()
def __init__(
self,
sink: CoreMessageSink,
*,
max_batch_size: int = 50,
flush_interval: float = 0.2,
) -> None:
self._sink = sink
self._max_batch_size = max_batch_size
self._flush_interval = flush_interval
self._queue: asyncio.Queue[MessageEnvelope | object] = asyncio.Queue()
self._worker: asyncio.Task | None = None
self._closed = False
async def add(self, message: MessageEnvelope) -> None:
if self._closed:
raise RuntimeError("Dispatcher closed")
self._ensure_worker()
await self._queue.put(message)
async def close(self) -> None:
if self._closed:
return
self._closed = True
self._ensure_worker()
await self._queue.put(self._STOP)
if self._worker:
await self._worker
self._worker = None
def _ensure_worker(self) -> None:
if self._worker is not None and not self._worker.done():
return
self._worker = asyncio.create_task(self._worker_loop())
async def _worker_loop(self) -> None:
buffer: list[MessageEnvelope] = []
try:
while True:
try:
item = await asyncio.wait_for(self._queue.get(), timeout=self._flush_interval)
except asyncio.TimeoutError:
item = None
if item is self._STOP:
await self._flush_buffer(buffer)
return
if item is not None:
buffer.append(item) # type: ignore[arg-type]
while len(buffer) < self._max_batch_size:
try:
item = self._queue.get_nowait()
except asyncio.QueueEmpty:
break
if item is self._STOP:
await self._flush_buffer(buffer)
return
buffer.append(item) # type: ignore[arg-type]
if buffer and (len(buffer) >= self._max_batch_size or item is None):
await self._flush_buffer(buffer)
except asyncio.CancelledError: # pragma: no cover - worker cancellation
if buffer:
await self._flush_buffer(buffer)
async def _flush_buffer(self, buffer: list[MessageEnvelope]) -> None:
if not buffer:
return
payload = list(buffer)
buffer.clear()
await _send_many(self._sink, payload)
__all__ = [
"AdapterTransportOptions",
"AdapterBase",
"BatchDispatcher",
"CoreSink",
"CoreMessageSink",
"HttpAdapterOptions",
"InProcessCoreSink",
"ProcessCoreSink",
"ProcessCoreSinkServer",
"WebSocketLike",
"WebSocketAdapterOptions",
]

View File

@@ -1,515 +0,0 @@
from __future__ import annotations
import asyncio
import contextlib
import logging
import ssl
from typing import Any, Awaitable, Callable, Dict, Literal, Optional
import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
MessagePayload = Dict[str, Any]
MessageHandler = Callable[[MessagePayload], Awaitable[None] | None]
DisconnectCallback = Callable[[str, str], Awaitable[None] | None]
def _attach_raw_bytes(payload: Any, raw_bytes: bytes) -> Any:
"""
将原始字节数据附加到消息负载中
Args:
payload: 消息负载
raw_bytes: 原始字节数据
Returns:
附加了原始数据的消息负载
"""
if isinstance(payload, dict):
payload.setdefault("raw_bytes", raw_bytes)
elif isinstance(payload, list):
for item in payload:
if isinstance(item, dict):
item.setdefault("raw_bytes", raw_bytes)
return payload
def _encode_for_ws_send(message: Any, *, use_raw_bytes: bool = False) -> tuple[str | bytes, bool]:
"""
编码消息用于 WebSocket 发送
Args:
message: 要发送的消息
use_raw_bytes: 是否使用原始字节数据
Returns:
(编码后的数据, 是否为二进制格式)
"""
if isinstance(message, (bytes, bytearray)):
return bytes(message), True
if use_raw_bytes and isinstance(message, dict):
raw = message.get("raw_bytes")
if isinstance(raw, (bytes, bytearray)):
return bytes(raw), True
payload = message
if isinstance(payload, dict) and "raw_bytes" in payload and not use_raw_bytes:
payload = {k: v for k, v in payload.items() if k != "raw_bytes"}
data = orjson.dumps(payload)
if use_raw_bytes:
return data, True
return data.decode("utf-8"), False
class BaseMessageHandler:
"""基础消息处理器,提供消息处理和任务管理功能"""
def __init__(self) -> None:
self.message_handlers: list[MessageHandler] = []
self.background_tasks: set[asyncio.Task] = set()
def register_message_handler(self, handler: MessageHandler) -> None:
"""
注册消息处理器
Args:
handler: 消息处理函数
"""
if handler not in self.message_handlers:
self.message_handlers.append(handler)
async def process_message(self, message: MessagePayload) -> None:
"""
处理单条消息,并发执行所有注册的处理器
Args:
message: 消息负载
"""
tasks: list[asyncio.Task] = []
for handler in self.message_handlers:
try:
result = handler(message)
if asyncio.iscoroutine(result):
task = asyncio.create_task(result)
tasks.append(task)
self.background_tasks.add(task)
task.add_done_callback(self.background_tasks.discard)
except Exception: # pragma: no cover - logging only
logging.getLogger("mofox_bus.server").exception("消息处理失败")
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
class MessageServer(BaseMessageHandler):
"""
WebSocket 消息服务器,支持与 FastAPI 应用共享事件循环。
"""
def __init__(
self,
host: str = "0.0.0.0",
port: int = 18000,
*,
enable_token: bool = False,
app: FastAPI | None = None,
path: str = "/ws",
ssl_certfile: str | None = None,
ssl_keyfile: str | None = None,
mode: Literal["ws", "tcp"] = "ws",
custom_logger: logging.Logger | None = None,
enable_custom_uvicorn_logger: bool = False,
queue_maxsize: int = 1000,
worker_count: int = 1,
) -> None:
super().__init__()
if mode != "ws":
raise NotImplementedError("Only WebSocket mode is supported in mofox_bus")
if custom_logger:
logging.getLogger("mofox_bus.server").handlers = custom_logger.handlers
self.host = host
self.port = port
self._app = app or FastAPI()
self._own_app = app is None
self._path = path
self._ssl_certfile = ssl_certfile
self._ssl_keyfile = ssl_keyfile
self._enable_token = enable_token
self._valid_tokens: set[str] = set()
self._connections: set[WebSocket] = set()
self._platform_connections: dict[str, WebSocket] = {}
self._conn_lock = asyncio.Lock()
self._server: uvicorn.Server | None = None
self._running = False
self._message_queue: asyncio.Queue[MessagePayload] = asyncio.Queue(maxsize=queue_maxsize)
self._worker_count = max(1, worker_count)
self._worker_tasks: list[asyncio.Task] = []
self._setup_routes()
def _setup_routes(self) -> None:
@_self_websocket(self._app, self._path)
async def websocket_endpoint(websocket: WebSocket) -> None:
platform = websocket.headers.get("platform", "unknown")
token = websocket.headers.get("authorization") or websocket.headers.get("Authorization")
if self._enable_token and not await self.verify_token(token):
await websocket.close(code=1008, reason="invalid token")
return
await websocket.accept()
await self._register_connection(websocket, platform)
try:
while True:
msg = await websocket.receive()
if msg["type"] == "websocket.receive":
raw_bytes = msg.get("bytes")
if raw_bytes is None and msg.get("text") is not None:
raw_bytes = msg["text"].encode("utf-8")
if not raw_bytes:
continue
try:
payload = orjson.loads(raw_bytes)
except orjson.JSONDecodeError:
logging.getLogger("mofox_bus.server").warning("Invalid JSON payload")
continue
payload = _attach_raw_bytes(payload, raw_bytes)
if isinstance(payload, list):
for item in payload:
await self._enqueue_message(item)
else:
await self._enqueue_message(payload)
elif msg["type"] == "websocket.disconnect":
break
except WebSocketDisconnect:
pass
finally:
await self._remove_connection(websocket, platform)
async def _enqueue_message(self, payload: MessagePayload) -> None:
if not self._worker_tasks:
self._start_workers()
try:
self._message_queue.put_nowait(payload)
except asyncio.QueueFull:
logging.getLogger("mofox_bus.server").warning("Message queue full, dropping message")
def _start_workers(self) -> None:
if self._worker_tasks:
return
self._running = True
for _ in range(self._worker_count):
task = asyncio.create_task(self._consumer_worker())
self._worker_tasks.append(task)
async def _stop_workers(self) -> None:
if not self._worker_tasks:
return
self._running = False
for task in self._worker_tasks:
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await asyncio.gather(*self._worker_tasks, return_exceptions=True)
self._worker_tasks.clear()
while not self._message_queue.empty():
with contextlib.suppress(asyncio.QueueEmpty):
self._message_queue.get_nowait()
self._message_queue.task_done()
async def _consumer_worker(self) -> None:
while self._running:
try:
payload = await self._message_queue.get()
except asyncio.CancelledError:
break
try:
await self.process_message(payload)
except Exception: # pragma: no cover - best effort logging
logging.getLogger("mofox_bus.server").exception("Error processing message")
finally:
self._message_queue.task_done()
async def verify_token(self, token: str | None) -> bool:
if not self._enable_token:
return True
return token in self._valid_tokens
def add_valid_token(self, token: str) -> None:
self._valid_tokens.add(token)
def remove_valid_token(self, token: str) -> None:
self._valid_tokens.discard(token)
async def _register_connection(self, websocket: WebSocket, platform: str) -> None:
async with self._conn_lock:
self._connections.add(websocket)
if platform:
previous = self._platform_connections.get(platform)
if previous and previous.client_state.name != "DISCONNECTED":
await previous.close(code=1000, reason="replaced")
self._platform_connections[platform] = websocket
async def _remove_connection(self, websocket: WebSocket, platform: str) -> None:
async with self._conn_lock:
self._connections.discard(websocket)
if platform and self._platform_connections.get(platform) is websocket:
del self._platform_connections[platform]
async def broadcast_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> None:
payload: MessagePayload | bytes = message
data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes)
async with self._conn_lock:
targets = list(self._connections)
for ws in targets:
if is_binary:
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
else:
await ws.send_text(data if isinstance(data, str) else data.decode("utf-8"))
async def broadcast_to_platform(
self, platform: str, message: MessagePayload | bytes, *, use_raw_bytes: bool = False
) -> None:
ws = self._platform_connections.get(platform)
if ws is None:
raise RuntimeError(f"平台 {platform} 没有活跃的连接")
payload: MessagePayload | bytes = message
data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes)
if is_binary:
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
else:
await ws.send_text(data if isinstance(data, str) else data.decode("utf-8"))
async def send_message(
self, message: MessagePayload, *, prefer_raw_bytes: bool = False
) -> None:
platform = message.get("message_info", {}).get("platform")
if not platform:
raise ValueError("message_info.platform is required to route the message")
await self.broadcast_to_platform(platform, message, use_raw_bytes=prefer_raw_bytes)
def run_sync(self) -> None:
if not self._own_app:
return
asyncio.run(self.run())
async def run(self) -> None:
self._start_workers()
if not self._own_app:
return
config = uvicorn.Config(
self._app,
host=self.host,
port=self.port,
ssl_certfile=self._ssl_certfile,
ssl_keyfile=self._ssl_keyfile,
log_config=None,
access_log=False,
)
self._server = uvicorn.Server(config)
try:
await self._server.serve()
except asyncio.CancelledError: # pragma: no cover - shutdown path
pass
async def stop(self) -> None:
self._running = False
await self._stop_workers()
if self._server:
self._server.should_exit = True
await self._server.shutdown()
self._server = None
async with self._conn_lock:
targets = list(self._connections)
self._connections.clear()
self._platform_connections.clear()
for ws in targets:
try:
await ws.close(code=1001, reason="server shutting down")
except Exception: # pragma: no cover - best effort
pass
for task in list(self.background_tasks):
if not task.done():
task.cancel()
if self.background_tasks:
await asyncio.gather(*self.background_tasks, return_exceptions=True)
self.background_tasks.clear()
class MessageClient(BaseMessageHandler):
"""
WebSocket 消息客户端,实现双向传输。
"""
def __init__(
self,
mode: Literal["ws", "tcp"] = "ws",
*,
reconnect_interval: float = 5.0,
logger: logging.Logger | None = None,
) -> None:
super().__init__()
if mode != "ws":
raise NotImplementedError("Only WebSocket mode is supported in mofox_bus")
self._mode = mode
self._session: aiohttp.ClientSession | None = None
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._receive_task: asyncio.Task | None = None
self._url: str = ""
self._platform: str = ""
self._token: str | None = None
self._ssl_verify: str | None = None
self._closed = False
self._on_disconnect: DisconnectCallback | None = None
self._reconnect_interval = reconnect_interval
self._logger = logger or logging.getLogger("mofox_bus.client")
async def connect(
self,
*,
url: str,
platform: str,
token: str | None = None,
ssl_verify: str | None = None,
) -> None:
self._url = url
self._platform = platform
self._token = token
self._ssl_verify = ssl_verify
self._closed = False
await self._establish_connection()
def set_disconnect_callback(self, callback: DisconnectCallback) -> None:
self._on_disconnect = callback
async def _establish_connection(self) -> None:
if self._session is None:
self._session = aiohttp.ClientSession()
headers = {"platform": self._platform}
if self._token:
headers["authorization"] = self._token
ssl_context = None
if self._ssl_verify:
ssl_context = ssl.create_default_context(cafile=self._ssl_verify)
self._ws = await self._session.ws_connect(self._url, headers=headers, ssl=ssl_context)
self._receive_task = asyncio.create_task(self._receive_loop())
async def _connect_once(self) -> None:
await self._establish_connection()
async def _receive_loop(self) -> None:
assert self._ws is not None
try:
async for msg in self._ws:
if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
raw_bytes = msg.data if isinstance(msg.data, (bytes, bytearray)) else msg.data.encode("utf-8")
try:
payload = orjson.loads(raw_bytes)
except orjson.JSONDecodeError:
logging.getLogger("mofox_bus.client").warning("Invalid JSON payload")
continue
payload = _attach_raw_bytes(payload, raw_bytes)
if isinstance(payload, list):
for item in payload:
await self.process_message(item)
else:
await self.process_message(payload)
elif msg.type == aiohttp.WSMsgType.ERROR:
break
except asyncio.CancelledError: # pragma: no cover - cancellation path
pass
finally:
if not self._closed:
await self._notify_disconnect("websocket disconnected")
await self._reconnect()
if self._ws:
await self._ws.close()
self._ws = None
async def run(self) -> None:
self._closed = False
while not self._closed:
if self._receive_task is None:
await self._establish_connection()
task = self._receive_task
if task is None:
break
try:
await task
except asyncio.CancelledError: # pragma: no cover - cancellation path
raise
async def send_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> bool:
ws = await self._ensure_ws()
data, is_binary = _encode_for_ws_send(message, use_raw_bytes=use_raw_bytes)
if is_binary:
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
else:
await ws.send_str(data if isinstance(data, str) else data.decode("utf-8"))
return True
def is_connected(self) -> bool:
return self._ws is not None and not self._ws.closed
async def stop(self) -> None:
self._closed = True
if self._receive_task and not self._receive_task.done():
self._receive_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._receive_task
if self._ws:
await self._ws.close()
self._ws = None
if self._session:
await self._session.close()
self._session = None
async def _notify_disconnect(self, reason: str) -> None:
if self._on_disconnect is None:
return
try:
result = self._on_disconnect(self._platform, reason)
if asyncio.iscoroutine(result):
await result
except Exception: # pragma: no cover - best effort notification
logging.getLogger("mofox_bus.client").exception("Disconnect callback failed")
async def _reconnect(self) -> None:
self._logger.info(f"WebSocket 连接断开, 正在 {self._reconnect_interval:.1f} 秒后重试")
await asyncio.sleep(self._reconnect_interval)
await self._connect_once()
async def _ensure_session(self) -> aiohttp.ClientSession:
if self._session is None:
self._session = aiohttp.ClientSession()
return self._session
async def _ensure_ws(self) -> aiohttp.ClientWebSocketResponse:
if self._ws is None or self._ws.closed:
await self._connect_once()
assert self._ws is not None
return self._ws
async def __aenter__(self) -> "MessageClient":
if not self._url or not self._platform:
raise RuntimeError("connect() must be called before using MessageClient as a context manager")
await self._ensure_session()
await self._ensure_ws()
return self
async def __aexit__(self, exc_type, exc, tb) -> None:
await self.stop()
def _self_websocket(app: FastAPI, path: str):
"""
装饰器工厂,兼容 FastAPI websocket 路由的声明方式。
FastAPI 不允许直接重复注册同一路径,因此这里封装一个可复用的装饰器。
"""
def decorator(func):
app.add_api_websocket_route(path, func)
return func
return decorator
__all__ = ["BaseMessageHandler", "MessageClient", "MessageServer"]

View File

@@ -1,111 +0,0 @@
from __future__ import annotations
import time
import uuid
from typing import Any, Dict, List
from .types import GroupInfoPayload, MessageEnvelope, MessageInfoPayload, SegPayload, UserInfoPayload
class MessageBuilder:
"""
流式构建 MessageEnvelope 的助手工具,提供类型安全的构建方法。
使用示例:
msg = (
MessageBuilder()
.text("Hello")
.image("http://example.com/1.png")
.to_user("123", platform="qq")
.build()
)
"""
def __init__(self) -> None:
self._direction: str = "outgoing"
self._message_info: MessageInfoPayload = {}
self._segments: List[SegPayload] = []
self._metadata: Dict[str, Any] | None = None
self._timestamp_ms: int | None = None
self._message_id: str | None = None
def direction(self, value: str) -> "MessageBuilder":
self._direction = value
return self
def message_id(self, value: str) -> "MessageBuilder":
self._message_id = value
return self
def timestamp_ms(self, value: int | None = None) -> "MessageBuilder":
self._timestamp_ms = value or int(time.time() * 1000)
return self
def metadata(self, value: Dict[str, Any]) -> "MessageBuilder":
self._metadata = value
return self
def platform(self, value: str) -> "MessageBuilder":
self._message_info["platform"] = value
return self
def from_user(self, user_id: str, *, platform: str | None = None, nickname: str | None = None) -> "MessageBuilder":
if platform:
self.platform(platform)
user_info: UserInfoPayload = {"user_id": user_id}
if nickname:
user_info["user_nickname"] = nickname
self._message_info["user_info"] = user_info
return self
def from_group(self, group_id: str, *, platform: str | None = None, name: str | None = None) -> "MessageBuilder":
if platform:
self.platform(platform)
group_info: GroupInfoPayload = {"group_id": group_id}
if name:
group_info["group_name"] = name
self._message_info["group_info"] = group_info
return self
def seg(self, type_: str, data: Any) -> "MessageBuilder":
self._segments.append({"type": type_, "data": data})
return self
def text(self, content: str) -> "MessageBuilder":
return self.seg("text", content)
def image(self, url: str) -> "MessageBuilder":
return self.seg("image", url)
def reply(self, target_message_id: str) -> "MessageBuilder":
return self.seg("reply", target_message_id)
def raw_segment(self, segment: SegPayload) -> "MessageBuilder":
self._segments.append(segment)
return self
def build(self) -> MessageEnvelope:
"""构建最终的消息信封"""
# 设置 message_info 默认值
if not self._segments:
raise ValueError("需要至少添加一个消息段才能构建消息")
if self._message_id is None:
self._message_id = str(uuid.uuid4())
info = dict(self._message_info)
info.setdefault("message_id", self._message_id)
info.setdefault("time", time.time())
segments = [seg.copy() if isinstance(seg, dict) else seg for seg in self._segments]
envelope: MessageEnvelope = {
"direction": self._direction, # type: ignore[assignment]
"message_info": info,
"message_segment": segments[0] if len(segments) == 1 else list(segments),
}
if self._metadata is not None:
envelope["metadata"] = self._metadata
if self._timestamp_ms is not None:
envelope["timestamp_ms"] = self._timestamp_ms
return envelope
__all__ = ["MessageBuilder"]

View File

@@ -1,94 +0,0 @@
from __future__ import annotations
import json as _stdlib_json
from typing import Any, Dict, Iterable, List
try:
import orjson as _json_impl
except Exception: # pragma: no cover - fallback when orjson is unavailable
_json_impl = None
from .types import MessageEnvelope
DEFAULT_SCHEMA_VERSION = 1
def _dumps(obj: Any) -> bytes:
if _json_impl is not None:
return _json_impl.dumps(obj)
return _stdlib_json.dumps(obj, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
def _loads(data: bytes) -> Dict[str, Any]:
if _json_impl is not None:
return _json_impl.loads(data)
return _stdlib_json.loads(data.decode("utf-8"))
def dumps_message(msg: MessageEnvelope) -> bytes:
"""
将单条消息序列化为 JSON bytes。
"""
sanitized = _strip_raw_bytes(msg)
if "schema_version" not in sanitized:
sanitized["schema_version"] = DEFAULT_SCHEMA_VERSION
return _dumps(sanitized)
def dumps_messages(messages: Iterable[MessageEnvelope]) -> bytes:
"""
将批量消息序列化为 JSON bytes。
"""
payload = {
"schema_version": DEFAULT_SCHEMA_VERSION,
"items": [_strip_raw_bytes(msg) for msg in messages],
}
return _dumps(payload)
def loads_message(data: bytes | str) -> MessageEnvelope:
"""
反序列化单条消息。
"""
if isinstance(data, str):
data = data.encode("utf-8")
obj = _loads(data)
return _upgrade_schema_if_needed(obj)
def loads_messages(data: bytes | str) -> List[MessageEnvelope]:
"""
反序列化批量消息。
"""
if isinstance(data, str):
data = data.encode("utf-8")
obj = _loads(data)
version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION)
if version != DEFAULT_SCHEMA_VERSION:
raise ValueError(f"不支持的 schema_version={version}")
return [_upgrade_schema_if_needed(item) for item in obj.get("items", [])]
def _upgrade_schema_if_needed(obj: Dict[str, Any]) -> MessageEnvelope:
"""
针对未来的 schema 版本演进预留兼容入口。
"""
version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION)
if version == DEFAULT_SCHEMA_VERSION:
return obj # type: ignore[return-value]
raise ValueError(f"不支持的 schema_version={version}")
def _strip_raw_bytes(msg: MessageEnvelope) -> MessageEnvelope:
if isinstance(msg, dict) and "raw_bytes" in msg:
new_msg = dict(msg)
new_msg.pop("raw_bytes", None)
return new_msg # type: ignore[return-value]
return msg
__all__ = [
"DEFAULT_SCHEMA_VERSION",
"dumps_message",
"dumps_messages",
"loads_message",
"loads_messages",
]

View File

@@ -1,265 +0,0 @@
from __future__ import annotations
import asyncio
import contextlib
import logging
from dataclasses import asdict, dataclass
from typing import Callable, Dict, Optional
from .api import MessageClient
from .types import MessageEnvelope
logger = logging.getLogger("mofox_bus.router")
@dataclass
class TargetConfig:
"""路由目标配置,包含连接信息和认证配置"""
url: str
token: str | None = None
ssl_verify: str | None = None
def to_dict(self) -> Dict[str, str | None]:
"""转换为字典格式"""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, str | None]) -> "TargetConfig":
"""从字典创建配置对象"""
return cls(
url=data.get("url", ""),
token=data.get("token"),
ssl_verify=data.get("ssl_verify"),
)
@dataclass
class RouteConfig:
"""路由配置,包含多个平台的路由目标"""
route_config: Dict[str, TargetConfig]
def to_dict(self) -> Dict[str, Dict[str, str | None]]:
"""转换为字典格式"""
return {"route_config": {k: v.to_dict() for k, v in self.route_config.items()}}
@classmethod
def from_dict(cls, data: Dict[str, Dict[str, str | None]]) -> "RouteConfig":
"""从字典创建路由配置对象"""
cfg = {
platform: TargetConfig.from_dict(target)
for platform, target in data.get("route_config", {}).items()
}
return cls(route_config=cfg)
class Router:
"""消息路由器,负责管理多个平台的消息客户端连接"""
def __init__(self, config: RouteConfig, custom_logger: logging.Logger | None = None) -> None:
"""
初始化路由器
Args:
config: 路由配置
custom_logger: 自定义日志记录器
"""
if custom_logger:
logger.handlers = custom_logger.handlers
self.config = config
self.clients: Dict[str, MessageClient] = {}
self.handlers: list[Callable[[Dict], None]] = []
self._running = False
self._client_tasks: Dict[str, asyncio.Task] = {}
self._stop_event: asyncio.Event | None = None
async def connect(self, platform: str) -> None:
"""
连接到指定平台
Args:
platform: 平台标识
Raises:
ValueError: 未知平台
NotImplementedError: 不支持的模式
"""
if platform not in self.config.route_config:
raise ValueError(f"未知平台: {platform}")
target = self.config.route_config[platform]
mode = "tcp" if target.url.startswith(("tcp://", "tcps://")) else "ws"
if mode != "ws":
raise NotImplementedError("TCP 模式暂未实现")
client = MessageClient(mode="ws")
client.set_disconnect_callback(self._handle_client_disconnect)
await client.connect(
url=target.url,
platform=platform,
token=target.token,
ssl_verify=target.ssl_verify,
)
for handler in self.handlers:
client.register_message_handler(handler)
self.clients[platform] = client
if self._running:
self._start_client_task(platform, client)
def register_class_handler(self, handler: Callable[[Dict], None]) -> None:
self.handlers.append(handler)
for client in self.clients.values():
client.register_message_handler(handler)
async def run(self) -> None:
"""启动路由器,连接所有配置的平台并开始运行"""
self._running = True
self._stop_event = asyncio.Event()
for platform in self.config.route_config:
if platform not in self.clients:
await self.connect(platform)
for platform, client in self.clients.items():
if platform not in self._client_tasks:
self._start_client_task(platform, client)
try:
await self._stop_event.wait()
except asyncio.CancelledError: # pragma: no cover
raise
async def remove_platform(self, platform: str) -> None:
"""
移除指定平台的连接
Args:
platform: 平台标识
"""
if platform in self._client_tasks:
task = self._client_tasks.pop(platform)
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
client = self.clients.pop(platform, None)
if client:
await client.stop()
async def _handle_client_disconnect(self, platform: str, reason: str) -> None:
"""
处理客户端断开连接
Args:
platform: 平台标识
reason: 断开原因
"""
logger.info(f"平台 {platform} 的客户端断开连接: {reason} (客户端将自动重连)")
task = self._client_tasks.get(platform)
if task is not None and not task.done():
return
client = self.clients.get(platform)
if client and self._running:
self._start_client_task(platform, client)
async def stop(self) -> None:
"""停止路由器,关闭所有连接"""
self._running = False
if self._stop_event:
self._stop_event.set()
for platform in list(self.clients.keys()):
await self.remove_platform(platform)
self.clients.clear()
def _start_client_task(self, platform: str, client: MessageClient) -> None:
"""
启动客户端任务
Args:
platform: 平台标识
client: 消息客户端
"""
task = asyncio.create_task(client.run())
task.add_done_callback(lambda t, plat=platform: asyncio.create_task(self._restart_if_needed(plat, t)))
self._client_tasks[platform] = task
async def _restart_if_needed(self, platform: str, task: asyncio.Task) -> None:
"""
必要时重启客户端任务
Args:
platform: 平台标识
task: 已完成的任务
"""
if not self._running:
return
if task.cancelled():
return
exc = task.exception()
if exc:
logger.warning(f"平台 {platform} 的客户端任务异常结束: {exc}")
client = self.clients.get(platform)
if client:
self._start_client_task(platform, client)
def get_target_url(self, message: MessageEnvelope) -> Optional[str]:
"""
根据消息获取目标 URL
Args:
message: 消息信封
Returns:
目标 URL 或 None
"""
platform = message.get("message_info", {}).get("platform")
if not platform:
return None
target = self.config.route_config.get(platform)
return target.url if target else None
async def send_message(self, message: MessageEnvelope):
"""
发送消息到指定平台
Args:
message: 消息信封
Raises:
ValueError: 缺少平台信息
RuntimeError: 未找到对应平台的客户端
"""
platform = message.get("message_info", {}).get("platform")
if not platform:
raise ValueError("消息中缺少必需的 message_info.platform 字段")
client = self.clients.get(platform)
if client is None:
raise RuntimeError(f"平台 {platform} 没有已连接的客户端")
return await client.send_message(message)
async def update_config(self, config_data: Dict[str, Dict[str, str | None]]) -> None:
"""
更新路由配置
Args:
config_data: 新的配置数据
"""
new_config = RouteConfig.from_dict(config_data)
await self._adjust_connections(new_config)
self.config = new_config
async def _adjust_connections(self, new_config: RouteConfig) -> None:
"""
调整连接以匹配新配置
Args:
new_config: 新的路由配置
"""
current = set(self.config.route_config.keys())
updated = set(new_config.route_config.keys())
# 移除不再存在的平台
for platform in current - updated:
await self.remove_platform(platform)
# 添加或更新平台
for platform in updated:
if platform not in current:
await self.connect(platform)
else:
old = self.config.route_config[platform]
new = new_config.route_config[platform]
if old.url != new.url or old.token != new.token:
await self.remove_platform(platform)
await self.connect(platform)

View File

@@ -1,470 +0,0 @@
from __future__ import annotations
import asyncio
import functools
import inspect
import threading
import weakref
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, Iterable, List, Protocol
from .types import MessageEnvelope
Hook = Callable[[MessageEnvelope], Awaitable[None] | None]
ErrorHook = Callable[[MessageEnvelope, BaseException], Awaitable[None] | None]
Predicate = Callable[[MessageEnvelope], bool | Awaitable[bool]]
MessageHandler = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None] | MessageEnvelope | None]
BatchHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None] | List[MessageEnvelope] | None]
MiddlewareCallable = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None]]
class Middleware(Protocol):
async def __call__(self, message: MessageEnvelope, handler: MiddlewareCallable) -> MessageEnvelope | None: ...
class MessageProcessingError(RuntimeError):
"""封装处理链路中发生的异常。"""
def __init__(self, message: MessageEnvelope, original: BaseException):
detail = message.get("id", "<unknown>")
super().__init__(f"处理消息 {detail} 时出错: {original}")
self.message_envelope = message
self.original = original
@dataclass
class MessageRoute:
"""消息路由配置,包含匹配条件和处理函数"""
predicate: Predicate
handler: MessageHandler
name: str | None = None
message_type: str | None = None
message_types: set[str] | None = None # 支持多个消息类型
event_types: set[str] | None = None
class MessageRuntime:
"""
消息运行时环境,负责调度消息路由、执行前后处理钩子以及批量处理消息
"""
def __init__(self) -> None:
self._routes: list[MessageRoute] = []
self._before_hooks: list[Hook] = []
self._after_hooks: list[Hook] = []
self._error_hooks: list[ErrorHook] = []
self._batch_handler: BatchHandler | None = None
self._lock = threading.RLock()
self._middlewares: list[Middleware] = []
self._type_routes: Dict[str, list[MessageRoute]] = {}
self._event_routes: Dict[str, list[MessageRoute]] = {}
# 用于检测同一类型的重复注册
self._explicit_type_handlers: Dict[str, str] = {} # message_type -> handler_name
def add_route(
self,
predicate: Predicate,
handler: MessageHandler,
name: str | None = None,
*,
message_type: str | list[str] | None = None,
event_types: Iterable[str] | None = None,
) -> None:
"""
添加消息路由
Args:
predicate: 路由匹配条件
handler: 消息处理函数
name: 路由名称(可选)
message_type: 消息类型,可以是字符串或字符串列表(可选)
event_types: 事件类型列表(可选)
"""
with self._lock:
# 处理 message_type 参数,支持字符串或列表
message_types_set: set[str] | None = None
single_message_type: str | None = None
if message_type is not None:
if isinstance(message_type, str):
message_types_set = {message_type}
single_message_type = message_type
elif isinstance(message_type, list):
message_types_set = set(message_type)
if len(message_types_set) == 1:
single_message_type = next(iter(message_types_set))
else:
raise TypeError(f"message_type must be str or list[str], got {type(message_type)}")
# 检测重复注册:如果明确指定了某个类型,不允许重复
handler_name = name or getattr(handler, "__name__", str(handler))
for msg_type in message_types_set:
if msg_type in self._explicit_type_handlers:
existing_handler = self._explicit_type_handlers[msg_type]
raise ValueError(
f"消息类型 '{msg_type}' 已被处理器 '{existing_handler}' 明确注册,"
f"不能再由 '{handler_name}' 注册。同一消息类型只能有一个明确的处理器。"
)
self._explicit_type_handlers[msg_type] = handler_name
route = MessageRoute(
predicate=predicate,
handler=handler,
name=name,
message_type=single_message_type,
message_types=message_types_set,
event_types=set(event_types) if event_types is not None else None,
)
self._routes.append(route)
# 为每个消息类型建立索引
if message_types_set:
for msg_type in message_types_set:
self._type_routes.setdefault(msg_type, []).append(route)
if route.event_types:
for et in route.event_types:
self._event_routes.setdefault(et, []).append(route)
def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]:
"""装饰器写法,便于在核心逻辑中声明式注册。
支持普通函数和类方法。对于类方法,会在实例创建时自动绑定并注册路由。
"""
def decorator(func: MessageHandler) -> MessageHandler:
# Support decorating instance methods: defer binding until the object is created.
if _looks_like_method(func):
return _InstanceMethodRoute(
runtime=self,
func=func,
predicate=predicate,
name=name,
message_type=None,
)
self.add_route(predicate, func, name=name)
return func
return decorator
def on_message(
self,
func: MessageHandler | None = None,
*,
message_type: str | list[str] | None = None,
platform: str | None = None,
predicate: Predicate | None = None,
name: str | None = None,
) -> Callable[[MessageHandler], MessageHandler] | MessageHandler:
"""Sugar decorator with optional Seg.type/platform predicate matching.
Args:
func: 被装饰的函数
message_type: 消息类型,可以是单个字符串或字符串列表
platform: 平台名称
predicate: 自定义匹配条件
name: 路由名称
Usages:
- @runtime.on_message(...)
- @runtime.on_message
- @runtime.on_message(message_type="text")
- @runtime.on_message(message_type=["text", "image"])
If the target looks like an instance method (first arg is self), it will be
auto-bound to the instance and registered when the object is constructed.
"""
# 将 message_type 转换为集合以便统一处理
message_types_set: set[str] | None = None
if message_type is not None:
if isinstance(message_type, str):
message_types_set = {message_type}
elif isinstance(message_type, list):
message_types_set = set(message_type)
else:
raise TypeError(f"message_type must be str or list[str], got {type(message_type)}")
async def combined_predicate(message: MessageEnvelope) -> bool:
if message_types_set is not None:
extracted_type = _extract_segment_type(message)
if extracted_type not in message_types_set:
return False
if platform is not None:
info_platform = message.get("message_info", {}).get("platform")
if message.get("platform") not in (None, platform) and info_platform is None:
return False
if info_platform not in (None, platform):
return False
if predicate is None:
return True
return await _invoke_callable(predicate, message, prefer_thread=False)
def decorator(func: MessageHandler) -> MessageHandler:
# Support decorating instance methods: defer binding until the object is created.
if _looks_like_method(func):
return _InstanceMethodRoute(
runtime=self,
func=func,
predicate=combined_predicate,
name=name,
message_type=message_type,
)
self.add_route(combined_predicate, func, name=name, message_type=message_type)
return func
if func is not None:
return decorator(func)
return decorator
def set_batch_handler(self, handler: BatchHandler) -> None:
self._batch_handler = handler
def register_before_hook(self, hook: Hook) -> None:
self._before_hooks.append(hook)
def register_after_hook(self, hook: Hook) -> None:
self._after_hooks.append(hook)
def register_error_hook(self, hook: ErrorHook) -> None:
self._error_hooks.append(hook)
def register_middleware(self, middleware: Middleware) -> None:
"""注册洋葱模型中间件,围绕处理器执行。"""
self._middlewares.append(middleware)
async def handle_message(self, message: MessageEnvelope) -> MessageEnvelope | None:
await self._run_hooks(self._before_hooks, message)
try:
route = await self._match_route(message)
if route is None:
return None
handler = self._wrap_with_middlewares(route.handler)
result = await handler(message)
except Exception as exc:
await self._run_error_hooks(message, exc)
raise MessageProcessingError(message, exc) from exc
await self._run_hooks(self._after_hooks, message)
return result
async def handle_batch(self, messages: Iterable[MessageEnvelope]) -> List[MessageEnvelope]:
batch = list(messages)
if not batch:
return []
if self._batch_handler is not None:
result = await _invoke_callable(self._batch_handler, batch, prefer_thread=True)
return result or []
responses: list[MessageEnvelope] = []
for message in batch:
response = await self.handle_message(message)
if response is not None:
responses.append(response)
return responses
async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None:
"""匹配消息路由,优先匹配明确指定了消息类型的处理器"""
message_type = _extract_segment_type(message)
event_type = (
message.get("event_type")
or message.get("message_info", {})
.get("additional_config", {})
.get("event_type")
)
# 分为两层候选:优先级和普通
priority_candidates: list[MessageRoute] = [] # 明确指定了消息类型的
normal_candidates: list[MessageRoute] = [] # 没有指定或通配的
with self._lock:
# 事件路由(优先级最高)
if event_type and event_type in self._event_routes:
priority_candidates.extend(self._event_routes[event_type])
# 消息类型路由(明确指定的有优先级)
if message_type and message_type in self._type_routes:
priority_candidates.extend(self._type_routes[message_type])
# 通用路由(没有明确指定类型的)
for route in self._routes:
# 如果路由没有指定 message_types则是通用路由
if route.message_types is None and route.event_types is None:
normal_candidates.append(route)
# 先尝试优先级候选
seen: set[int] = set()
for route in priority_candidates:
rid = id(route)
if rid in seen:
continue
seen.add(rid)
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
if should_handle:
return route
# 如果没有匹配到优先级候选,再尝试普通候选
for route in normal_candidates:
rid = id(route)
if rid in seen:
continue
seen.add(rid)
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
if should_handle:
return route
return None
async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None:
coro_list = [self._call_hook(hook, message) for hook in hooks]
if coro_list:
await asyncio.gather(*coro_list)
async def _call_hook(self, hook: Hook, message: MessageEnvelope) -> None:
await _invoke_callable(hook, message, prefer_thread=True)
async def _run_error_hooks(self, message: MessageEnvelope, exc: BaseException) -> None:
coros = [self._call_error_hook(hook, message, exc) for hook in self._error_hooks]
if coros:
await asyncio.gather(*coros)
async def _call_error_hook(self, hook: ErrorHook, message: MessageEnvelope, exc: BaseException) -> None:
await _invoke_callable(hook, message, exc, prefer_thread=True)
def _wrap_with_middlewares(self, handler: MessageHandler) -> MiddlewareCallable:
async def base_handler(message: MessageEnvelope) -> MessageEnvelope | None:
return await _invoke_callable(handler, message, prefer_thread=True)
wrapped: MiddlewareCallable = base_handler
for middleware in reversed(self._middlewares):
current = wrapped
async def wrapper(msg: MessageEnvelope, mw=middleware, nxt=current) -> MessageEnvelope | None:
return await _invoke_callable(mw, msg, nxt, prefer_thread=False)
wrapped = wrapper
return wrapped
async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False):
"""支持 sync/async 调用,并可选择在线程中执行。
自动处理普通函数、类方法和绑定方法。
"""
# 如果是绑定方法bound method直接使用不需要额外处理
# 因为绑定方法已经包含了 self 参数
if inspect.ismethod(func):
# 绑定方法可以直接调用args 中不应包含 self
if inspect.iscoroutinefunction(func):
return await func(*args)
if prefer_thread:
result = await asyncio.to_thread(func, *args)
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
return await result
return result
result = func(*args)
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
return await result
return result
# 对于普通函数(未绑定的),按原有逻辑处理
if inspect.iscoroutinefunction(func):
return await func(*args)
if prefer_thread:
result = await asyncio.to_thread(func, *args)
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
return await result
return result
result = func(*args)
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
return await result
return result
def _extract_segment_type(message: MessageEnvelope) -> str | None:
seg = message.get("message_segment") or message.get("message_chain")
if isinstance(seg, dict):
return seg.get("type")
if isinstance(seg, list) and seg:
first = seg[0]
if isinstance(first, dict):
return first.get("type")
return None
def _looks_like_method(func: Callable[..., object]) -> bool:
"""Return True if callable signature suggests an instance method (first arg named self)."""
if inspect.ismethod(func):
return True
if not inspect.isfunction(func):
return False
params = inspect.signature(func).parameters
if not params:
return False
first = next(iter(params.values()))
return first.name == "self"
class _InstanceMethodRoute:
"""Descriptor that binds decorated instance methods and registers routes per-instance."""
def __init__(
self,
runtime: MessageRuntime,
func: MessageHandler,
predicate: Predicate,
name: str | None,
message_type: str | None,
) -> None:
self._runtime = runtime
self._func = func
self._predicate = predicate
self._name = name
self._message_type = message_type
self._owner: type | None = None
self._registered_instances: weakref.WeakSet[object] = weakref.WeakSet()
def __set_name__(self, owner: type, name: str) -> None:
self._owner = owner
registry: list[_InstanceMethodRoute] | None = getattr(owner, "_mofox_instance_routes", None)
if registry is None:
registry = []
setattr(owner, "_mofox_instance_routes", registry)
original_init = owner.__init__
@functools.wraps(original_init)
def wrapped_init(inst, *args, **kwargs):
original_init(inst, *args, **kwargs)
for descriptor in getattr(inst.__class__, "_mofox_instance_routes", []):
descriptor._register_instance(inst)
owner.__init__ = wrapped_init # type: ignore[assignment]
registry.append(self)
def _register_instance(self, instance: object) -> None:
if instance in self._registered_instances:
return
owner = self._owner or instance.__class__
bound = self._func.__get__(instance, owner) # type: ignore[arg-type]
self._runtime.add_route(self._predicate, bound, name=self._name, message_type=self._message_type)
self._registered_instances.add(instance)
def __get__(self, instance: object | None, owner: type | None = None):
if instance is None:
return self._func
self._register_instance(instance)
return self._func.__get__(instance, owner) # type: ignore[arg-type]
__all__ = [
"BatchHandler",
"Hook",
"MessageHandler",
"MessageProcessingError",
"MessageRoute",
"MessageRuntime",
"Middleware",
"Predicate",
]

View File

@@ -1,10 +0,0 @@
"""
传输层封装,提供 HTTP / WebSocket server & client。
"""
from .http_client import HttpMessageClient
from .http_server import HttpMessageServer
from .ws_client import WsMessageClient
from .ws_server import WsMessageServer
__all__ = ["HttpMessageClient", "HttpMessageServer", "WsMessageClient", "WsMessageServer"]

View File

@@ -1,67 +0,0 @@
from __future__ import annotations
import logging
from typing import Iterable, List, Sequence
import aiohttp
from ..codec import dumps_messages, loads_messages
from ..types import MessageEnvelope
class HttpMessageClient:
"""
面向消息批量传输的 HTTP 客户端封装
"""
def __init__(
self,
base_url: str,
*,
session: aiohttp.ClientSession | None = None,
timeout: aiohttp.ClientTimeout | None = None,
) -> None:
self._base_url = base_url.rstrip("/")
self._session = session
self._timeout = timeout
self._owns_session = session is None
self._logger = logging.getLogger("mofox_bus.http_client")
async def send_messages(
self,
messages: Sequence[MessageEnvelope],
*,
expect_reply: bool = False,
path: str = "/messages",
) -> List[MessageEnvelope] | None:
if not messages:
return []
session = await self._ensure_session()
url = f"{self._base_url}{path}"
payload = dumps_messages(messages)
self._logger.debug(f"正在发送 {len(messages)} 条消息 -> {url}")
async with session.post(url, data=payload, timeout=self._timeout) as resp:
resp.raise_for_status()
if not expect_reply:
return None
raw = await resp.read()
replies = loads_messages(raw)
self._logger.debug(f"接收到 {len(replies)} 条回复消息")
return replies
async def close(self) -> None:
if self._owns_session and self._session is not None:
await self._session.close()
self._session = None
async def _ensure_session(self) -> aiohttp.ClientSession:
if self._session is None:
self._session = aiohttp.ClientSession()
return self._session
async def __aenter__(self) -> "HttpMessageClient":
await self._ensure_session()
return self
async def __aexit__(self, exc_type, exc, tb) -> None:
await self.close()

View File

@@ -1,53 +0,0 @@
from __future__ import annotations
import logging
from typing import Awaitable, Callable, List
from aiohttp import web
from ..codec import dumps_messages, loads_messages
from ..types import MessageEnvelope
MessageHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None]]
class HttpMessageServer:
"""
轻量级 HTTP 消息入口,可独立运行,也可挂载到现有 FastAPI / aiohttp 应用下
"""
def __init__(self, handler: MessageHandler, *, path: str = "/messages") -> None:
self._handler = handler
self._app = web.Application()
self._path = path
self._app.add_routes([web.post(path, self._handle_messages)])
self._logger = logging.getLogger("mofox_bus.http_server")
async def _handle_messages(self, request: web.Request) -> web.Response:
try:
raw = await request.read()
envelopes = loads_messages(raw)
self._logger.debug(f"接收到 {len(envelopes)} 条消息")
except Exception as exc: # pragma: no cover - network errors are integration tested
self._logger.exception(f"解析请求失败: {exc}")
raise web.HTTPBadRequest(reason=f"无效的负载: {exc}") from exc
result = await self._handler(envelopes)
if result is None:
return web.Response(status=200, text="ok")
payload = dumps_messages(result)
return web.Response(status=200, body=payload, content_type="application/json")
def make_app(self) -> web.Application:
"""
返回 aiohttp Application可被外部 servergunicorn/uvicorn直接使用。
"""
return self._app
def add_to_app(self, app: web.Application) -> None:
"""
将消息路由注册到给定的 aiohttp app方便与既有服务整合。
"""
app.router.add_post(self._path, self._handle_messages)

View File

@@ -1,108 +0,0 @@
from __future__ import annotations
import asyncio
import logging
from typing import Awaitable, Callable, Iterable, List, Sequence
import aiohttp
from ..codec import dumps_messages, loads_messages
from ..types import MessageEnvelope
IncomingHandler = Callable[[MessageEnvelope], Awaitable[None]]
class WsMessageClient:
"""
管理 WebSocket 连接,提供 send/receive API并在后台读取消息
"""
def __init__(
self,
url: str,
*,
handler: IncomingHandler | None = None,
session: aiohttp.ClientSession | None = None,
reconnect_interval: float = 5.0,
) -> None:
self._url = url
self._handler = handler
self._session = session
self._reconnect_interval = reconnect_interval
self._owns_session = session is None
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._receive_task: asyncio.Task | None = None
self._closed = False
self._logger = logging.getLogger("mofox_bus.ws_client")
async def connect(self) -> None:
await self._ensure_session()
await self._connect_once()
async def _connect_once(self) -> None:
assert self._session is not None
self._ws = await self._session.ws_connect(self._url)
self._logger.info(f"已连接到 {self._url}")
self._receive_task = asyncio.create_task(self._receive_loop())
async def send_messages(self, messages: Sequence[MessageEnvelope]) -> None:
if not messages:
return
ws = await self._ensure_ws()
payload = dumps_messages(messages)
await ws.send_bytes(payload)
async def send_message(self, message: MessageEnvelope) -> None:
await self.send_messages([message])
async def close(self) -> None:
self._closed = True
if self._receive_task:
self._receive_task.cancel()
if self._ws:
await self._ws.close()
self._ws = None
if self._owns_session and self._session:
await self._session.close()
self._session = None
async def _receive_loop(self) -> None:
assert self._ws is not None
try:
async for msg in self._ws:
if msg.type in (aiohttp.WSMsgType.BINARY, aiohttp.WSMsgType.TEXT):
envelopes = loads_messages(msg.data)
for env in envelopes:
if self._handler is not None:
await self._handler(env)
elif msg.type == aiohttp.WSMsgType.ERROR:
self._logger.warning(f"WebSocket 错误: {msg.data}")
break
except asyncio.CancelledError: # pragma: no cover - cancellation path
return
finally:
if not self._closed:
await self._reconnect()
async def _reconnect(self) -> None:
self._logger.info(f"WebSocket 断开, 正在 {self._reconnect_interval:.1f} 秒后重试")
await asyncio.sleep(self._reconnect_interval)
await self._connect_once()
async def _ensure_session(self) -> aiohttp.ClientSession:
if self._session is None:
self._session = aiohttp.ClientSession()
return self._session
async def _ensure_ws(self) -> aiohttp.ClientWebSocketResponse:
if self._ws is None or self._ws.closed:
await self._connect_once()
assert self._ws is not None
return self._ws
async def __aenter__(self) -> "WsMessageClient":
await self.connect()
return self
async def __aexit__(self, exc_type, exc, tb) -> None:
await self.close()

View File

@@ -1,66 +0,0 @@
from __future__ import annotations
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Awaitable, Callable, Iterable, List, Set
from aiohttp import WSMsgType, web
from ..codec import dumps_messages, loads_messages
from ..types import MessageEnvelope
WsMessageHandler = Callable[[MessageEnvelope], Awaitable[None]]
class WsMessageServer:
"""
封装 WebSocket 服务端逻辑,负责接收消息并广播响应
"""
def __init__(self, handler: WsMessageHandler, *, path: str = "/ws") -> None:
self._handler = handler
self._app = web.Application()
self._path = path
self._app.add_routes([web.get(path, self._handle_ws)])
self._connections: Set[web.WebSocketResponse] = set()
self._lock = asyncio.Lock()
self._logger = logging.getLogger("mofox_bus.ws_server")
async def _handle_ws(self, request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
self._logger.info(f"WebSocket 连接打开: {request.remote}")
async with self._track_connection(ws):
async for message in ws:
if message.type in (WSMsgType.BINARY, WSMsgType.TEXT):
envelopes = loads_messages(message.data)
for env in envelopes:
await self._handler(env)
elif message.type == WSMsgType.ERROR:
self._logger.warning(f"WebSocket 连接错误: {ws.exception()}")
break
self._logger.info(f"WebSocket 连接关闭: {request.remote}")
return ws
@asynccontextmanager
async def _track_connection(self, ws: web.WebSocketResponse):
async with self._lock:
self._connections.add(ws)
try:
yield
finally:
async with self._lock:
self._connections.discard(ws)
async def broadcast(self, messages: Iterable[MessageEnvelope]) -> None:
payload = dumps_messages(list(messages))
async with self._lock:
targets = list(self._connections)
for ws in targets:
await ws.send_bytes(payload)
def make_app(self) -> web.Application:
return self._app

View File

@@ -1,93 +0,0 @@
from __future__ import annotations
from typing import Any, Dict, List, Literal, NotRequired, TypedDict, Required
MessageDirection = Literal["incoming", "outgoing"]
# ----------------------------
# maim_message 风格的 TypedDict
# ----------------------------
class SegPayload(TypedDict, total=False):
"""
对齐 maim_message.Seg 的片段定义,使用纯 dict 便于 JSON 传输。
"""
type: Required[str]
data: Required[str | List["SegPayload"]]
translated_data: NotRequired[str | List["SegPayload"]]
class UserInfoPayload(TypedDict, total=False):
platform: NotRequired[str]
user_id: Required[str]
user_nickname: NotRequired[str]
user_cardname: NotRequired[str]
user_avatar: NotRequired[str]
class GroupInfoPayload(TypedDict, total=False):
platform: NotRequired[str]
group_id: Required[str]
group_name: NotRequired[str]
class FormatInfoPayload(TypedDict, total=False):
content_format: NotRequired[List[str]]
accept_format: NotRequired[List[str]]
class TemplateInfoPayload(TypedDict, total=False):
template_items: NotRequired[Dict[str, str]]
template_name: NotRequired[Dict[str, str]]
template_default: NotRequired[bool]
class MessageInfoPayload(TypedDict, total=False):
platform: Required[str]
message_id: Required[str]
time: NotRequired[float]
group_info: NotRequired[GroupInfoPayload]
user_info: NotRequired[UserInfoPayload]
format_info: NotRequired[FormatInfoPayload]
template_info: NotRequired[TemplateInfoPayload]
additional_config: NotRequired[Dict[str, Any]]
# ----------------------------
# MessageEnvelope
# ----------------------------
class MessageEnvelope(TypedDict, total=False):
"""
mofox-bus 传输层统一使用的消息信封。
- 采用 maim_message 风格message_info + message_segment。
"""
direction: MessageDirection
message_info: Required[MessageInfoPayload]
message_segment: Required[SegPayload] | List[SegPayload]
raw_message: NotRequired[Any]
raw_bytes: NotRequired[bytes]
message_chain: NotRequired[List[SegPayload]] # seglist 的直观别名
platform: NotRequired[str] # 快捷访问,等价于 message_info.platform
message_id: NotRequired[str] # 快捷访问,等价于 message_info.message_id
timestamp_ms: NotRequired[int]
correlation_id: NotRequired[str]
schema_version: NotRequired[int]
metadata: NotRequired[Dict[str, Any]]
__all__ = [
# maim_message style payloads
"SegPayload",
"UserInfoPayload",
"GroupInfoPayload",
"FormatInfoPayload",
"TemplateInfoPayload",
"MessageInfoPayload",
# legacy content style
"MessageDirection",
"MessageEnvelope",
]

View File

@@ -62,8 +62,8 @@ class BaseAdapter(MoFoxAdapterBase, ABC):
self._config: Dict[str, Any] = {}
self._health_check_task: Optional[asyncio.Task] = None
self._running = False
# 标记是否在子进程中运行(由核心管理器传入 ProcessCoreSink 时自动生效)
self._is_subprocess = isinstance(core_sink, ProcessCoreSink)
# 标记是否在子进程中运行
self._is_subprocess = False
@classmethod
def from_process_queues(

View File

@@ -2,6 +2,11 @@
Adapter 管理器
负责管理所有注册的适配器,支持子进程自动启动和生命周期管理。
重构说明2025-11
- 使用 CoreSinkManager 统一管理 InProcessCoreSink 和 ProcessCoreSink
- 根据适配器的 run_in_subprocess 属性自动选择 CoreSink 类型
- 子进程适配器通过 CoreSinkManager 的通信队列与主进程交互
"""
from __future__ import annotations
@@ -15,7 +20,6 @@ if TYPE_CHECKING:
from src.plugin_system.base.base_adapter import BaseAdapter
from mofox_bus import ProcessCoreSinkServer
from src.common.core_sink import get_core_sink
from src.common.logger import get_logger
logger = get_logger("adapter_manager")
@@ -34,6 +38,11 @@ def _adapter_process_entry(
incoming_queue: mp.Queue,
outgoing_queue: mp.Queue,
):
"""
子进程适配器入口函数
在子进程中运行,创建 ProcessCoreSink 与主进程通信
"""
import asyncio
import contextlib
from mofox_bus import ProcessCoreSink
@@ -44,9 +53,14 @@ def _adapter_process_entry(
if plugin_info:
plugin_cls = _load_class(plugin_info["module"], plugin_info["class"])
plugin_instance = plugin_cls(plugin_info["plugin_dir"], plugin_info["metadata"])
# 创建 ProcessCoreSink 用于与主进程通信
core_sink = ProcessCoreSink(to_core_queue=incoming_queue, from_core_queue=outgoing_queue)
# 创建并启动适配器
adapter = adapter_cls(core_sink, plugin=plugin_instance)
await adapter.start()
try:
while not getattr(core_sink, "_closed", False):
await asyncio.sleep(0.2)
@@ -61,21 +75,23 @@ def _adapter_process_entry(
class AdapterProcess:
"""适配器子进程封装:管理子进程的生命周期与通信桥接"""
"""
适配器子进程封装:管理子进程的生命周期与通信桥接
使用 CoreSinkManager 创建通信队列,自动维护与子进程的消息通道
"""
def __init__(self, adapter_cls: "type[BaseAdapter]", plugin, core_sink) -> None:
def __init__(self, adapter_cls: "type[BaseAdapter]", plugin) -> None:
self.adapter_cls = adapter_cls
self.adapter_name = adapter_cls.adapter_name
self.plugin = plugin
self.process: mp.Process | None = None
self._ctx = mp.get_context("spawn")
self._incoming_queue: mp.Queue = self._ctx.Queue()
self._outgoing_queue: mp.Queue = self._ctx.Queue()
self._incoming_queue: mp.Queue | None = None
self._outgoing_queue: mp.Queue | None = None
self._bridge: ProcessCoreSinkServer | None = None
self._core_sink = core_sink
self._adapter_path: tuple[str, str] = (adapter_cls.__module__, adapter_cls.__name__)
self._plugin_info = self._extract_plugin_info(plugin)
self._outgoing_handler = None
@staticmethod
def _extract_plugin_info(plugin) -> dict | None:
@@ -88,37 +104,28 @@ class AdapterProcess:
"metadata": getattr(plugin, "plugin_meta", None),
}
def _make_outgoing_handler(self):
async def _handler(envelope):
if self._bridge:
await self._bridge.push_outgoing(envelope)
return _handler
async def start(self) -> bool:
"""启动适配器子进程"""
try:
logger.info(f"启动适配器子进程: {self.adapter_name}")
self._bridge = ProcessCoreSinkServer(
incoming_queue=self._incoming_queue,
outgoing_queue=self._outgoing_queue,
core_handler=self._core_sink.send,
name=self.adapter_name,
)
self._bridge.start()
if hasattr(self._core_sink, "set_outgoing_handler"):
self._outgoing_handler = self._make_outgoing_handler()
try:
self._core_sink.set_outgoing_handler(self._outgoing_handler)
except Exception:
logger.exception("Failed to register outgoing bridge for %s", self.adapter_name)
# 从 CoreSinkManager 获取通信队列
from src.common.core_sink_manager import get_core_sink_manager
manager = get_core_sink_manager()
self._incoming_queue, self._outgoing_queue = manager.create_process_sink_queues(self.adapter_name)
# 启动子进程
self.process = self._ctx.Process(
target=_adapter_process_entry,
args=(self._adapter_path, self._plugin_info, self._incoming_queue, self._outgoing_queue),
name=f"{self.adapter_name}-proc",
)
self.process.start()
logger.info(f"启动适配器子进程 {self.adapter_name} (PID: {self.process.pid})")
return True
except Exception as e:
logger.error(f"启动适配器子进程 {self.adapter_name} 失败: {e}", exc_info=True)
return False
@@ -127,26 +134,31 @@ class AdapterProcess:
"""停止适配器子进程"""
if not self.process:
return
logger.info(f"停止适配器子进程: {self.adapter_name} (PID: {self.process.pid})")
try:
remover = getattr(self._core_sink, "remove_outgoing_handler", None)
if callable(remover) and self._outgoing_handler:
try:
remover(self._outgoing_handler)
except Exception:
logger.exception(f"移除 {self.adapter_name} 的 outgoing bridge 失败")
if self._bridge:
await self._bridge.close()
# 从 CoreSinkManager 移除通信队列
from src.common.core_sink_manager import get_core_sink_manager
manager = get_core_sink_manager()
manager.remove_process_sink(self.adapter_name)
# 等待子进程结束
if self.process.is_alive():
self.process.join(timeout=5.0)
if self.process.is_alive():
logger.warning(f"适配器 {self.adapter_name} 未能及时停止,强制终止中")
self.process.terminate()
self.process.join()
except Exception as e:
logger.error(f"停止适配器子进程 {self.adapter_name} 时发生错误: {e}", exc_info=True)
finally:
self.process = None
self._incoming_queue = None
self._outgoing_queue = None
def is_running(self) -> bool:
"""适配器是否正在运行"""
@@ -155,7 +167,13 @@ class AdapterProcess:
return self.process.is_alive()
class AdapterManager:
"""适配器管理器"""
"""
适配器管理器
负责管理所有注册的适配器,根据 run_in_subprocess 属性自动选择:
- run_in_subprocess=True: 在子进程中运行,使用 ProcessCoreSink
- run_in_subprocess=False: 在主进程中运行,使用 InProcessCoreSink
"""
def __init__(self):
# 注册信息name -> (adapter class, plugin instance | None)
@@ -178,14 +196,26 @@ class AdapterManager:
self._adapter_defs[adapter_name] = (adapter_cls, plugin)
adapter_version = getattr(adapter_cls, 'adapter_version', 'unknown')
logger.info(f"注册适配器: {adapter_name} v{adapter_version}")
run_in_subprocess = getattr(adapter_cls, 'run_in_subprocess', False)
logger.info(
f"注册适配器: {adapter_name} v{adapter_version} "
f"(子进程: {'' if run_in_subprocess else ''})"
)
async def start_adapter(self, adapter_name: str) -> bool:
"""启动指定适配器"""
"""
启动指定适配器
根据适配器的 run_in_subprocess 属性自动选择:
- True: 创建子进程,使用 ProcessCoreSink
- False: 在当前进程,使用 InProcessCoreSink
"""
definition = self._adapter_defs.get(adapter_name)
if not definition:
logger.error(f"适配器 {adapter_name} 未注册")
return False
adapter_cls, plugin = definition
run_in_subprocess = getattr(adapter_cls, "run_in_subprocess", False)
@@ -193,15 +223,14 @@ class AdapterManager:
return await self._start_adapter_subprocess(adapter_name, adapter_cls, plugin)
return await self._start_adapter_in_process(adapter_name, adapter_cls, plugin)
async def _start_adapter_subprocess(self, adapter_name: str, adapter_cls: type[BaseAdapter], plugin) -> bool:
"""在子进程中启动适配器"""
try:
core_sink = get_core_sink()
except Exception as e:
logger.error(f"无法获取 core_sink启动子进程 {adapter_name} 失败: {e}", exc_info=True)
return False
adapter_process = AdapterProcess(adapter_cls, plugin, core_sink)
async def _start_adapter_subprocess(
self,
adapter_name: str,
adapter_cls: type[BaseAdapter],
plugin
) -> bool:
"""在子进程中启动适配器(使用 ProcessCoreSink"""
adapter_process = AdapterProcess(adapter_cls, plugin)
success = await adapter_process.start()
if success:
@@ -209,15 +238,25 @@ class AdapterManager:
return success
async def _start_adapter_in_process(self, adapter_name: str, adapter_cls: type[BaseAdapter], plugin) -> bool:
"""在当前进程中启动适配器"""
async def _start_adapter_in_process(
self,
adapter_name: str,
adapter_cls: type[BaseAdapter],
plugin
) -> bool:
"""在当前进程中启动适配器(使用 InProcessCoreSink"""
try:
core_sink = get_core_sink()
# 从 CoreSinkManager 获取 InProcessCoreSink
from src.common.core_sink_manager import get_core_sink_manager
core_sink = get_core_sink_manager().get_in_process_sink()
adapter = adapter_cls(core_sink, plugin=plugin) # type: ignore[call-arg]
await adapter.start()
self._in_process_adapters[adapter_name] = adapter
logger.info(f"适配器 {adapter_name} 已在当前进程启动")
return True
except Exception as e:
logger.error(f"启动适配器 {adapter_name} 失败: {e}", exc_info=True)
return False

View File

@@ -17,12 +17,13 @@ from typing import Any, ClassVar, Dict, List, Optional
import orjson
import websockets
from mofox_bus import CoreMessageSink, MessageEnvelope, WebSocketAdapterOptions
from mofox_bus import CoreSink, MessageEnvelope, WebSocketAdapterOptions
from src.common.logger import get_logger
from src.plugin_system import register_plugin
from src.plugin_system.base import BaseAdapter, BasePlugin
from src.plugin_system.apis import config_api
from .src.handlers import utils as handler_utils
from .src.handlers.to_core.message_handler import MessageHandler
from .src.handlers.to_core.notice_handler import NoticeHandler
from .src.handlers.to_core.meta_event_handler import MetaEventHandler
@@ -41,22 +42,16 @@ class NapcatAdapter(BaseAdapter):
platform = "qq"
run_in_subprocess = False
subprocess_entry = None
def __init__(self, core_sink: CoreMessageSink, plugin: Optional[BasePlugin] = None):
def __init__(self, core_sink: CoreSink, plugin: Optional[BasePlugin] = None):
"""初始化 Napcat 适配器"""
# 从插件配置读取 WebSocket URL
if plugin:
mode = config_api.get_plugin_config(plugin.config, "napcat_server.mode", "reverse")
host = config_api.get_plugin_config(plugin.config, "napcat_server.host", "localhost")
port = config_api.get_plugin_config(plugin.config, "napcat_server.port", 8095)
url = config_api.get_plugin_config(plugin.config, "napcat_server.url", "")
access_token = config_api.get_plugin_config(plugin.config, "napcat_server.access_token", "")
if mode == "forward" and url:
ws_url = url
else:
ws_url = f"ws://{host}:{port}"
ws_url = f"ws://{host}:{port}"
headers = {}
if access_token:
@@ -69,8 +64,6 @@ class NapcatAdapter(BaseAdapter):
transport = WebSocketAdapterOptions(
url=ws_url,
headers=headers if headers else None,
incoming_parser=self._parse_napcat_message,
outgoing_encoder=self._encode_napcat_response,
)
super().__init__(core_sink, plugin=plugin, transport=transport)
@@ -89,6 +82,9 @@ class NapcatAdapter(BaseAdapter):
# 注意_ws 继承自 BaseAdapter是 WebSocketLike 协议类型
self._napcat_ws = None # 可选的额外连接引用
# 注册 utils 内部使用的适配器实例,便于工具方法自动获取 WS
handler_utils.register_adapter(self)
async def on_adapter_loaded(self) -> None:
"""适配器加载时的初始化"""
logger.info("Napcat 适配器正在启动...")
@@ -114,22 +110,6 @@ class NapcatAdapter(BaseAdapter):
logger.info("Napcat 适配器已关闭")
def _parse_napcat_message(self, raw: str | bytes) -> Any:
"""解析 Napcat/OneBot 消息"""
try:
if isinstance(raw, bytes):
data = orjson.loads(raw)
else:
data = orjson.loads(raw)
return data
except Exception as e:
logger.error(f"解析 Napcat 消息失败: {e}")
raise
def _encode_napcat_response(self, envelope: MessageEnvelope) -> bytes:
"""编码响应消息为 Napcat 格式(暂未使用,通过 API 调用发送)"""
return orjson.dumps(envelope)
async def from_platform_message(self, raw: Dict[str, Any]) -> MessageEnvelope: # type: ignore[override]
"""
将 Napcat/OneBot 原始消息转换为 MessageEnvelope
@@ -178,20 +158,6 @@ class NapcatAdapter(BaseAdapter):
"""
await self.send_handler.handle_message(envelope)
def _create_empty_envelope(self) -> MessageEnvelope: # type: ignore[return]
"""创建一个空的消息信封(用于不需要处理的事件)"""
import time
return {
"direction": "incoming",
"message_info": {
"platform": self.platform,
"message_id": str(uuid.uuid4()),
"time": time.time(),
},
"message_segment": {"type": "text", "data": "[系统事件]"},
"timestamp_ms": int(time.time() * 1000),
}
async def send_napcat_api(self, action: str, params: Dict[str, Any], timeout: float = 30.0) -> Dict[str, Any]:
"""
发送 Napcat API 请求并等待响应
@@ -260,18 +226,12 @@ class NapcatAdapterPlugin(BasePlugin):
config_schema: ClassVar[dict] = {
"plugin": {
"name": {"type": str, "default": "napcat_adapter_plugin"},
"version": {"type": str, "default": "2.0.0"},
"version": {"type": str, "default": "1.0.0"},
"enabled": {"type": bool, "default": True},
},
"napcat_server": {
"mode": {
"type": str,
"default": "reverse",
"description": "连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)",
},
"host": {"type": str, "default": "localhost"},
"port": {"type": int, "default": 8095},
"url": {"type": str, "default": "", "description": "正向连接时的完整URL"},
"access_token": {"type": str, "default": ""},
},
"features": {
@@ -284,34 +244,18 @@ class NapcatAdapterPlugin(BasePlugin):
},
}
def __init__(self, plugin_dir: str = "", metadata: Any = None):
# 如果没有提供参数,创建一个默认的元数据
if metadata is None:
from src.plugin_system.base.plugin_metadata import PluginMetadata
metadata = PluginMetadata(
name=self.plugin_name,
version=self.plugin_version,
author=self.plugin_author,
description=self.plugin_description,
usage="",
dependencies=[],
python_dependencies=[],
)
if not plugin_dir:
from pathlib import Path
plugin_dir = str(Path(__file__).parent)
super().__init__(plugin_dir, metadata)
def __init__(self):
self._adapter: Optional[NapcatAdapter] = None
async def on_plugin_loaded(self):
"""插件加载时启动适配器"""
logger.info("Napcat 适配器插件正在加载...")
# 获取核心 Sink
from src.common.core_sink import get_core_sink
core_sink = get_core_sink()
# 从 CoreSinkManager 获取 InProcessCoreSink
from src.common.core_sink_manager import get_core_sink_manager
core_sink_manager = get_core_sink_manager()
core_sink = core_sink_manager.get_in_process_sink()
# 创建并启动适配器
self._adapter = NapcatAdapter(core_sink, plugin=self)

View File

@@ -5,11 +5,28 @@ from __future__ import annotations
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from mofox_bus import MessageBuilder
from src.common.logger import get_logger
from src.plugin_system.apis import config_api
from mofox_bus import (
MessageEnvelope,
SegPayload,
MessageInfoPayload,
UserInfoPayload,
GroupInfoPayload,
)
from ...event_models import ACCEPT_FORMAT, QQ_FACE
from ..utils import (
get_group_info,
get_image_base64,
get_self_info,
get_member_info,
get_message_detail,
)
if TYPE_CHECKING:
from ...plugin import NapcatAdapter
from ....plugin import NapcatAdapter
logger = get_logger("napcat_adapter.message_handler")
@@ -28,99 +45,190 @@ class MessageHandler:
async def handle_raw_message(self, raw: Dict[str, Any]):
"""
处理原始消息并转换为 MessageEnvelope
Args:
raw: OneBot 原始消息数据
Returns:
MessageEnvelope (dict)
"""
from mofox_bus import MessageEnvelope, SegPayload, MessageInfoPayload, UserInfoPayload, GroupInfoPayload
message_type = raw.get("message_type")
message_id = str(raw.get("message_id", ""))
message_time = time.time()
msg_builder = MessageBuilder()
# 构造用户信息
sender_info = raw.get("sender", {})
user_info: UserInfoPayload = {
"platform": "qq",
"user_id": str(sender_info.get("user_id", "")),
"user_nickname": sender_info.get("nickname", ""),
"user_cardname": sender_info.get("card", ""),
"user_avatar": sender_info.get("avatar", ""),
}
(
msg_builder.direction("incoming")
.message_id(message_id)
.timestamp_ms(int(message_time * 1000))
.from_user(
user_id=str(sender_info.get("user_id", "")),
platform="qq",
nickname=sender_info.get("nickname", ""),
cardname=sender_info.get("card", ""),
user_avatar=sender_info.get("avatar", ""),
)
)
# 构造群组信息(如果是群消息)
group_info: Optional[GroupInfoPayload] = None
if message_type == "group":
group_id = raw.get("group_id")
if group_id:
group_info = {
"platform": "qq",
"group_id": str(group_id),
"group_name": "", # 可以通过 API 获取
}
fetched_group_info = await get_group_info(group_id)
(
msg_builder.from_group(
group_id=str(group_id),
platform="qq",
name=(
fetched_group_info.get("group_name", "")
if fetched_group_info
else raw.get("group_name", "")
),
)
)
# 解析消息段
message_segments = raw.get("message", [])
seg_list: List[SegPayload] = []
for seg in message_segments:
seg_type = seg.get("type", "")
seg_data = seg.get("data", {})
# 转换为 SegPayload
if seg_type == "text":
seg_list.append({
"type": "text",
"data": seg_data.get("text", "")
})
elif seg_type == "image":
# 这里需要下载图片并转换为 base64简化版本
seg_list.append({
"type": "image",
"data": seg_data.get("url", "") # 实际应该转换为 base64
})
elif seg_type == "at":
seg_list.append({
"type": "at",
"data": f"{seg_data.get('qq', '')}"
})
# 其他消息类型...
# 构造 MessageInfoPayload
message_info = {
"platform": "qq",
"message_id": message_id,
"time": message_time,
"user_info": user_info,
"format_info": {
"content_format": ["text", "image"], # 根据实际消息类型设置
"accept_format": ["text", "image", "emoji", "voice"],
},
}
# 添加群组信息(如果存在)
if group_info:
message_info["group_info"] = group_info
for segment in message_segments:
seg_message = await self.handle_single_segment(segment, raw)
if seg_message:
seg_list.append(seg_message)
# 构造 MessageEnvelope
envelope = {
"direction": "incoming",
"message_info": message_info,
"message_segment": {"type": "seglist", "data": seg_list} if len(seg_list) > 1 else (seg_list[0] if seg_list else {"type": "text", "data": ""}),
"raw_message": raw.get("raw_message", ""),
"platform": "qq",
"message_id": message_id,
"timestamp_ms": int(message_time * 1000),
}
msg_builder.format_info(
content_format=[seg["type"] for seg in seg_list],
accept_format=ACCEPT_FORMAT,
)
return envelope
return msg_builder.build()
async def handle_single_segment(
self, segment: dict, raw_message: dict, in_reply: bool = False
) -> SegPayload | List[SegPayload] | None:
"""
处理单一消息段并转换为 MessageEnvelope
Args:
segment: 单一原始消息段
raw_message: 完整的原始消息数据
Returns:
SegPayload | List[SegPayload] | None
"""
seg_type = segment.get("type")
seg_data: dict = segment.get("data", {})
match seg_type:
case "text":
return {"type": "text", "data": seg_data.get("text", "")}
case "image":
image_sub_type = seg_data.get("sub_type")
try:
image_base64 = await get_image_base64(seg_data.get("url", ""))
except Exception as e:
logger.error(f"图片消息处理失败: {str(e)}")
return None
if image_sub_type == 0:
"""这部分认为是图片"""
return {"type": "image", "data": image_base64}
elif image_sub_type not in [4, 9]:
"""这部分认为是表情包"""
return {"type": "emoji", "data": image_base64}
else:
logger.warning(f"不支持的图片子类型:{image_sub_type}")
return None
case "face":
message_data: dict = segment.get("data", {})
face_raw_id: str = str(message_data.get("id"))
if face_raw_id in QQ_FACE:
face_content: str = QQ_FACE.get(face_raw_id, "[未知表情]")
return {"type": "text", "data": face_content}
else:
logger.warning(f"不支持的表情:{face_raw_id}")
return None
case "at":
if seg_data:
qq_id = seg_data.get("qq")
self_id = raw_message.get("self_id")
group_id = raw_message.get("group_id")
if str(self_id) == str(qq_id):
logger.debug("机器人被at")
self_info = await get_self_info()
if self_info:
# 返回包含昵称和用户ID的at格式便于后续处理
return {
"type": "at",
"data": f"{self_info.get('nickname')}:{self_info.get('user_id')}",
}
else:
return None
else:
if qq_id and group_id:
member_info = await get_member_info(
group_id=group_id, user_id=qq_id
)
if member_info:
# 返回包含昵称和用户ID的at格式便于后续处理
return {
"type": "at",
"data": f"{member_info.get('nickname')}:{member_info.get('user_id')}",
}
else:
return None
case "emoji":
seg_data = segment.get("id", "")
case "reply":
if not in_reply:
message_id = None
if seg_data:
message_id = seg_data.get("id")
else:
return None
message_detail = await get_message_detail(message_id)
if not message_detail:
logger.warning("获取被引用的消息详情失败")
return None
reply_message = await self.handle_single_segment(
message_detail, raw_message, in_reply=True
)
if reply_message is None:
reply_message = [
{"type": "text", "data": "[无法获取被引用的消息]"}
]
sender_info: dict = message_detail.get("sender", {})
sender_nickname: str = sender_info.get("nickname", "")
sender_id = sender_info.get("user_id")
seg_message: List[SegPayload] = []
if not sender_nickname:
logger.warning("无法获取被引用的人的昵称,返回默认值")
seg_message.append(
{
"type": "text",
"data": f"[回复<未知用户>{reply_message}],说:",
}
)
else:
if sender_id:
seg_message.append(
{
"type": "text",
"data": f"[回复<{sender_nickname}({sender_id})>{reply_message}],说:",
}
)
else:
seg_message.append(
{
"type": "text",
"data": f"[回复<{sender_nickname}>{reply_message}],说:",
}
)
return seg_message
case "voice":
seg_data = segment.get("url", "")
case _:
logger.warning(f"Unsupported segment type: {seg_type}")

View File

@@ -0,0 +1,361 @@
import asyncio
import base64
import io
import ssl
import time
import weakref
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
import orjson
import urllib3
from PIL import Image
from src.common.logger import get_logger
if TYPE_CHECKING:
from ...plugin import NapcatAdapter
logger = get_logger("napcat_adapter")
# 简单的缓存实现,通过 JSON 文件实现磁盘一价存储
_CACHE_FILE = Path(__file__).resolve().parent / "napcat_cache.json"
_CACHE_LOCK = asyncio.Lock()
_CACHE: Dict[str, Dict[str, Dict[str, Any]]] = {
"group_info": {},
"group_detail_info": {},
"member_info": {},
"stranger_info": {},
"self_info": {},
}
# 各类信息的 TTL 缓存过期时间设置
GROUP_INFO_TTL = 300 # 5 min
GROUP_DETAIL_TTL = 300
MEMBER_INFO_TTL = 180
STRANGER_INFO_TTL = 300
SELF_INFO_TTL = 300
_adapter_ref: weakref.ReferenceType["NapcatAdapter"] | None = None
def register_adapter(adapter: "NapcatAdapter") -> None:
"""注册 NapcatAdapter 实例,以便 utils 模块可以获取 WebSocket"""
global _adapter_ref
_adapter_ref = weakref.ref(adapter)
logger.debug("Napcat adapter registered in utils for websocket access")
def _load_cache_from_disk() -> None:
if not _CACHE_FILE.exists():
return
try:
data = orjson.loads(_CACHE_FILE.read_bytes())
if isinstance(data, dict):
for key, section in _CACHE.items():
cached_section = data.get(key)
if isinstance(cached_section, dict):
section.update(cached_section)
except Exception as e:
logger.debug(f"Failed to load napcat cache: {e}")
def _save_cache_to_disk_locked() -> None:
"""重要提示:不要在持有 _CACHE_LOCK 时调用此函数"""
_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
_CACHE_FILE.write_bytes(orjson.dumps(_CACHE))
async def _get_cached(section: str, key: str, ttl: int) -> Any | None:
now = time.time()
async with _CACHE_LOCK:
entry = _CACHE.get(section, {}).get(key)
if not entry:
return None
ts = entry.get("ts", 0)
if ts and now - ts <= ttl:
return entry.get("data")
_CACHE.get(section, {}).pop(key, None)
try:
_save_cache_to_disk_locked()
except Exception:
pass
return None
async def _set_cached(section: str, key: str, data: Any) -> None:
async with _CACHE_LOCK:
_CACHE.setdefault(section, {})[key] = {"data": data, "ts": time.time()}
try:
_save_cache_to_disk_locked()
except Exception:
logger.debug("Write napcat cache failed", exc_info=True)
def _get_adapter(adapter: "NapcatAdapter | None" = None) -> "NapcatAdapter":
target = adapter
if target is None and _adapter_ref:
target = _adapter_ref()
if target is None:
raise RuntimeError("NapcatAdapter 未注册,请确保已调用 utils.register_adapter 注册")
return target
async def _call_adapter_api(
action: str,
params: Dict[str, Any],
adapter: "NapcatAdapter | None" = None,
timeout: float = 30.0,
) -> Optional[Dict[str, Any]]:
"""统一通过 adapter 发送和接收 API 调用"""
try:
target = _get_adapter(adapter)
# 确保 WS 已连接
target.get_ws_connection()
except Exception as e: # pragma: no cover - 难以在单元测试中查看
logger.error(f"WebSocket 未准备好,无法调用 API: {e}")
return None
try:
return await target.send_napcat_api(action, params, timeout=timeout)
except Exception as e:
logger.error(f"{action} 调用失败: {e}")
return None
# 加载缓存到内存一次,避免在每次调用缓存时重复加载
_load_cache_from_disk()
class SSLAdapter(urllib3.PoolManager):
def __init__(self, *args, **kwargs):
context = ssl.create_default_context()
context.set_ciphers("DEFAULT@SECLEVEL=1")
context.minimum_version = ssl.TLSVersion.TLSv1_2
kwargs["ssl_context"] = context
super().__init__(*args, **kwargs)
async def get_group_info(
group_id: int,
*,
use_cache: bool = True,
force_refresh: bool = False,
adapter: "NapcatAdapter | None" = None,
) -> dict | None:
"""
获取群组基本信息
返回值可能是None需要调用方检查空值
"""
logger.debug("获取群组基本信息中")
cache_key = str(group_id)
if use_cache and not force_refresh:
cached = await _get_cached("group_info", cache_key, GROUP_INFO_TTL)
if cached is not None:
return cached
socket_response = await _call_adapter_api(
"get_group_info",
{"group_id": group_id},
adapter=adapter,
)
data = socket_response.get("data") if socket_response else None
if data is not None and use_cache:
await _set_cached("group_info", cache_key, data)
return data
async def get_group_detail_info(
group_id: int,
*,
use_cache: bool = True,
force_refresh: bool = False,
adapter: "NapcatAdapter | None" = None,
) -> dict | None:
"""
获取群组详细信息
返回值可能是None需要调用方检查空值
"""
logger.debug("获取群组详细信息中")
cache_key = str(group_id)
if use_cache and not force_refresh:
cached = await _get_cached("group_detail_info", cache_key, GROUP_DETAIL_TTL)
if cached is not None:
return cached
socket_response = await _call_adapter_api(
"get_group_detail_info",
{"group_id": group_id},
adapter=adapter,
)
data = socket_response.get("data") if socket_response else None
if data is not None and use_cache:
await _set_cached("group_detail_info", cache_key, data)
return data
async def get_member_info(
group_id: int,
user_id: int,
*,
use_cache: bool = True,
force_refresh: bool = False,
adapter: "NapcatAdapter | None" = None,
) -> dict | None:
"""
获取群组成员信息
返回值可能是None需要调用方检查空值
"""
logger.debug("获取群组成员信息中")
cache_key = f"{group_id}:{user_id}"
if use_cache and not force_refresh:
cached = await _get_cached("member_info", cache_key, MEMBER_INFO_TTL)
if cached is not None:
return cached
socket_response = await _call_adapter_api(
"get_group_member_info",
{"group_id": group_id, "user_id": user_id, "no_cache": True},
adapter=adapter,
)
data = socket_response.get("data") if socket_response else None
if data is not None and use_cache:
await _set_cached("member_info", cache_key, data)
return data
async def get_image_base64(url: str) -> str:
# sourcery skip: raise-specific-error
"""下载图片/视频并返回Base64"""
logger.debug(f"下载图片: {url}")
http = SSLAdapter()
try:
response = http.request("GET", url, timeout=10)
if response.status != 200:
raise Exception(f"HTTP Error: {response.status}")
image_bytes = response.data
return base64.b64encode(image_bytes).decode("utf-8")
except Exception as e:
logger.error(f"图片下载失败: {str(e)}")
raise
def convert_image_to_gif(image_base64: str) -> str:
# sourcery skip: extract-method
"""
将Base64编码的图片转换为GIF格式
Parameters:
image_base64: str: Base64编码的图片数据
Returns:
str: Base64编码的GIF图片数据
"""
logger.debug("转换图片为GIF格式")
try:
image_bytes = base64.b64decode(image_base64)
image = Image.open(io.BytesIO(image_bytes))
output_buffer = io.BytesIO()
image.save(output_buffer, format="GIF")
output_buffer.seek(0)
return base64.b64encode(output_buffer.read()).decode("utf-8")
except Exception as e:
logger.error(f"图片转换为GIF失败: {str(e)}")
return image_base64
async def get_self_info(
*,
use_cache: bool = True,
force_refresh: bool = False,
adapter: "NapcatAdapter | None" = None,
) -> dict | None:
"""
获取机器人信息
"""
logger.debug("获取机器人信息中")
cache_key = "self"
if use_cache and not force_refresh:
cached = await _get_cached("self_info", cache_key, SELF_INFO_TTL)
if cached is not None:
return cached
response = await _call_adapter_api("get_login_info", {}, adapter=adapter)
data = response.get("data") if response else None
if data is not None and use_cache:
await _set_cached("self_info", cache_key, data)
return data
def get_image_format(raw_data: str) -> str:
"""
从Base64编码的数据中确定图片的格式类型
Parameters:
raw_data: str: Base64编码的图片数据
Returns:
format: str: 图片的格式类型,如 'jpeg', 'png', 'gif'
"""
image_bytes = base64.b64decode(raw_data)
return Image.open(io.BytesIO(image_bytes)).format.lower()
async def get_stranger_info(
user_id: int,
*,
use_cache: bool = True,
force_refresh: bool = False,
adapter: "NapcatAdapter | None" = None,
) -> dict | None:
"""
获取陌生人信息
"""
logger.debug("获取陌生人信息中")
cache_key = str(user_id)
if use_cache and not force_refresh:
cached = await _get_cached("stranger_info", cache_key, STRANGER_INFO_TTL)
if cached is not None:
return cached
response = await _call_adapter_api("get_stranger_info", {"user_id": user_id}, adapter=adapter)
data = response.get("data") if response else None
if data is not None and use_cache:
await _set_cached("stranger_info", cache_key, data)
return data
async def get_message_detail(
message_id: Union[str, int],
*,
adapter: "NapcatAdapter | None" = None,
) -> dict | None:
"""
获取消息详情,仅作为参考
"""
logger.debug("获取消息详情中")
response = await _call_adapter_api(
"get_msg",
{"message_id": message_id},
adapter=adapter,
timeout=30,
)
return response.get("data") if response else None
async def get_record_detail(
file: str,
file_id: Optional[str] = None,
*,
adapter: "NapcatAdapter | None" = None,
) -> dict | None:
"""
获取语音信息详情
"""
logger.debug("获取语音信息详情中")
response = await _call_adapter_api(
"get_record",
{"file": file, "file_id": file_id, "out_format": "wav"},
adapter=adapter,
timeout=30,
)
return response.get("data") if response else None

View File

@@ -3,7 +3,7 @@ import random
import time
import websockets as Server
import uuid
from mofox_bus import (
from maim_message import (
UserInfo,
GroupInfo,
Seg,

View File

@@ -1,5 +1,5 @@
[inner]
version = "7.8.2"
version = "7.8.3"
#----以下是给开发人员阅读的如果你只是部署了MoFox-Bot不需要阅读----
#如果你想要修改配置文件请递增version的值