diff --git a/docs/mofox_bus.md b/docs/mofox_bus.md new file mode 100644 index 000000000..aaa4a3951 --- /dev/null +++ b/docs/mofox_bus.md @@ -0,0 +1,189 @@ +# MoFox Bus 消息库说明 + +MoFox Bus 是 MoFox Bot 自研的统一消息中台,替换第三方 `maim_message`,将核心与各平台适配器之间的通信抽象成可拓展、可热插拔的组件。该库完全异步、面向高吞吐,覆盖消息建模、序列化、传输层、运行时路由、适配器工具等多个层面。 + +--- + +## 1. 设计目标 + +- **通用消息模型**:统一 envelope / content,使核心逻辑不关心平台差异。 +- **零拷贝字典结构**:TypedDict + dataclass,方便直接序列化 JSON。 +- **高性能传输**:批量收发 + orjson 序列化 + WS/HTTP 封装。 +- **适配器友好**:提供 BaseAdapter、Sink、Router 与批处理工具。 +- **渐进可扩展**:未来扩充 gRPC/MQ 仅需在 `transport/` 下新增实现。 + +--- + +## 2. 包结构概览(`src/mofox_bus/`) + +| 模块 | 主要职责 | +| --- | --- | +| `types.py` | TypedDict 消息模型(MessageEnvelope、Content、Sender/ChannelInfo 等)。 | +| `message_models.py` | dataclass 版 `Seg` / `MessageBase`,兼容老的消息段语义。 | +| `codec.py` | 高性能 JSON 编解码,含批量接口与 schema version 升级钩子。 | +| `runtime.py` | 消息路由/Hook/批处理调度器,支撑核心处理链。 | +| `adapter_utils.py` | BaseAdapter、CoreMessageSink、BatchDispatcher 等工具。 | +| `api.py` | WebSocket `MessageServer`/`MessageClient`,提供 token、复用 FastAPI。 | +| `router.py` | 多平台客户端统一管理,自动重连与动态路由。 | +| `transport/` | HTTP/WS server&client 轻量封装,可独立复用。 | +| `__init__.py` | 导出常用符号供外部按需引用。 | + +--- + +## 3. 消息模型 + +### 3.1 Envelope TypedDict(`types.py`) + +- `MessageEnvelope`:核心字段包括 `id`、`direction`、`platform`、`timestamp_ms`、`channel`、`sender`、`content` 等,一律使用毫秒时间戳,保留 `raw_platform_message` 与 `metadata` 便于调试 / 扩展。 +- `Content` 联合类型支持文本、图片、音频、文件、视频、事件、命令、系统消息,后续可扩展更多 literal。 +- `SenderInfo` / `ChannelInfo` / `MessageDirection` / `Role` 等均以 `Literal` 控制取值,方便 IDE 静态检查。 + +### 3.2 dataclass 消息段(`message_models.py`) + +- `Seg`:表示一段内容,支持嵌套 `seglist`。 +- `UserInfo` / `GroupInfo` / `FormatInfo` / `TemplateInfo`:保留旧结构字段,但新增 `user_avatar` 等业务常用字段。 +- `BaseMessageInfo` + `MessageBase`:方便适配器沿用原始 `MessageBase` API,也使核心可以在内存中直接传递 dataclass。 + +> **何时使用 TypedDict vs dataclass?** +TypedDict 更适合网络传输和依赖注入;dataclass 版 MessageBase 则保留分段消息特性,适合适配器内部加工。 + +--- + +## 4. 序列化与版本(`codec.py`) + +- `dumps_message` / `loads_message`:处理单条 `MessageEnvelope`,自动补充 `schema_version`(默认 1)。 +- `dumps_messages` / `loads_messages`:批量传输 `{schema_version, items: [...]}`,减少 HTTP/WS 次数。 +- 预留 `_upgrade_schema_if_needed` 钩子,可在引入 v2/v3 时集中兼容逻辑。 + +默认使用 `orjson`,若运行环境缺失会自动 fallback 到标准库 `json`,保证兼容性。 + +--- + +## 5. 运行时调度(`runtime.py`) + +- `MessageRuntime`: + - `add_route(predicate, handler)` 或 `@runtime.route(...)` 装饰器注册消息处理器。 + - `register_before_hook` / `register_after_hook` / `register_error_hook` 注入监控、埋点、Trace。 + - `set_batch_handler` 支持一次处理整批消息(例如批量落库)。 +- `MessageProcessingError` 在 handler 抛出异常时封装上下文,便于日志追踪。 + +运行时内部使用 `RLock` 保护路由表,适合多协程并发读写,`_maybe_await` 自动兼容同步/异步 handler。 + +--- + +## 6. 传输层封装(`transport/`) + +### 6.1 HTTP +- `HttpMessageServer`:使用 `aiohttp.web` 监听 `POST /messages`,调用业务 handler 后可返回响应批。 +- `HttpMessageClient`:管理 `aiohttp.ClientSession`,`send_messages(messages, expect_reply=True)` 支持同步等待回复。 + +### 6.2 WebSocket +- `WsMessageServer`:基于 `aiohttp`,维护连接集合,支持 `broadcast`。 +- `WsMessageClient`:自动重连、后台读取,`send_messages`/`send_message` 直接发送批量。 + +以上都复用了 `codec` 批量协议,统一上下游格式。 + +--- + +## 7. Server / Client / Router(`api.py`、`router.py`) + +### 7.1 MessageServer +- 可复用已有 FastAPI 实例(`app=get_global_server().get_app()`),在同一进程内共享路由。 +- 支持 header token 校验 (`enable_token` + `add_valid_token`)。 +- `broadcast_message`、`broadcast_to_platform`、`send_message(message: MessageBase)` 满足不同场景。 + +### 7.2 MessageClient +- 仅 WebSocket 模式,管理 `aiohttp` 连接、自动收消息,供适配器推送到核心。 + +### 7.3 Router +- `RouteConfig` + `TargetConfig` 描述平台到 URL 的映射。 +- `Router.run()` 会为每个平台创建 `MessageClient` 并保持心跳,`_monitor_connections` 自动重连。 +- `register_class_handler` 可绑定 Napcat 适配器那样的 class handler。 + +--- + +## 8. 适配器工具(`adapter_utils.py`) + +- `BaseAdapter`:约定入站 `from_platform_message` / 出站 `_send_platform_message`,默认提供批量入口。 +- `CoreMessageSink` 协议 + `InProcessCoreSink`:方便在同进程中把适配器消息直接推给核心协程。 +- `BatchDispatcher`:封装缓冲 + 定时 flush 的发送管道,可与 HTTP/WS 客户端组合提升吞吐。 + +--- + +## 9. 集成与配置 + +1. **配置文件**:在 `config.*.toml` 中新增 `[message_bus]` 段(参考 `template/bot_config_template.toml`),控制 host/port/token/wss 等。 +2. **服务启动**:`src/common/message/api.py` 中的 `get_global_api()` 已默认实例化 `MessageServer`,并将 token 写入服务器。 +3. **适配器更新**:所有使用原 `maim_message` 的模块已改为 `from mofox_bus import ...`,无需额外适配即可继续利用 `MessageBase` / `Router` API。 + +--- + +## 10. 快速上手示例 + +```python +from mofox_bus import MessageRuntime, types +from mofox_bus.transport import HttpMessageServer + +runtime = MessageRuntime() + +@runtime.route(lambda env: env["content"]["type"] == "text") +async def handle_text(env: types.MessageEnvelope): + print("收到文本:", env["content"]["text"]) + +async def http_handler(messages: list[types.MessageEnvelope]): + await runtime.handle_batch(messages) + +server = HttpMessageServer(http_handler) +app = server.make_app() # 交给 aiohttp/uvicorn 运行 +``` + +**适配器 Skeleton:** +```python +from mofox_bus import ( + BaseAdapter, + MessageEnvelope, + WebSocketAdapterOptions, +) + +class MyAdapter(BaseAdapter): + platform = "custom" + + def __init__(self, core_sink): + super().__init__( + core_sink, + transport=WebSocketAdapterOptions( + url="ws://127.0.0.1:19898", + incoming_parser=lambda raw: orjson.loads(raw)["payload"], + ), + ) + + def from_platform_message(self, raw: dict) -> MessageEnvelope: + return { + "id": raw["id"], + "direction": "incoming", + "platform": self.platform, + "timestamp_ms": raw["ts"], + "channel": {"channel_id": raw["room_id"], "channel_type": "dm"}, + "sender": {"user_id": raw["user_id"], "role": "user"}, + "content": {"type": "text", "text": raw["content"]}, + "conversation_id": raw["room_id"], + } +``` + +- 如果传入 `WebSocketAdapterOptions`,BaseAdapter 会自动建立连接、监听、默认封装 `{"type":"message","payload":...}` 的标准 JSON,并允许通过 `outgoing_encoder` 自定义下行格式。 +- 如果传入 `HttpAdapterOptions`,BaseAdapter 会自动启动一个 aiohttp Webhook(`POST /adapter/messages`)并将收到的 JSON 批量投递给核心。 + +> 完整的 WebSocket 适配器示例见 `examples/mofox_bus_demo_adapter.py`:演示了平台提供 WS 接口、适配器通过 `WebSocketAdapterOptions` 自动启动监听、接收/处理/回发的全过程,可直接运行观察日志。 + +--- + +## 11. 调试与最佳实践 + +- 利用 `MessageRuntime.register_error_hook` 打印 `correlation_id` / `id`,快速定位异常消息。 +- 如果适配器与核心同进程,优先使用 `InProcessCoreSink` 以避免 JSON 编解码。 +- 批量吞吐场景(如 HTTP 推送)优先通过 `BatchDispatcher` 聚合再发送,可显著降低连接开销。 +- 自定义传输实现可参考 `transport/http_server.py` / `ws_client.py`,保持 `loads_messages` / `dumps_messages` 协议即可与现有核心互通。 + +--- + +通过以上结构,MoFox Bus 提供了一套端到端的统一消息能力,满足 AI Bot 在多平台、多形态场景下的高性能传输与扩展需求。若需要扩展新的传输协议或内容类型,只需在对应模块增加 Literal/TypedDict/Transport 实现即可。祝使用愉快! diff --git a/examples/mofox_bus_demo_adapter.py b/examples/mofox_bus_demo_adapter.py new file mode 100644 index 000000000..40cff4f35 --- /dev/null +++ b/examples/mofox_bus_demo_adapter.py @@ -0,0 +1,196 @@ +""" +示例:演示一个最小可用的 WebSocket 适配器如何使用 BaseAdapter 的自动传输封装: +1) 通过 WS 接入平台; +2) 将平台推送的消息转成 MessageEnvelope 并交给核心; +3) 接收核心回复并通过 WS 再发回平台。 +""" + +from __future__ import annotations + +import asyncio +import sys +import time +import uuid +from pathlib import Path +from typing import Any, Dict, Optional + +import orjson +import websockets + +# 追加 src 目录,便于直接运行示例 +sys.path.append(str(Path(__file__).resolve().parents[1] / "src")) + +from mofox_bus import ( + BaseAdapter, + InProcessCoreSink, + MessageEnvelope, + MessageRuntime, + WebSocketAdapterOptions, +) + + +# --------------------------------------------------------------------------- +# 1. 模拟一个提供 WebSocket 接口的平台 +# --------------------------------------------------------------------------- + + +class FakePlatformServer: + """ + 适配器将通过 WS 连接到这个模拟平台。 + 平台会广播消息给所有连接,适配器发送的响应也会被打印出来。 + """ + + def __init__(self, host: str = "127.0.0.1", port: int = 19898) -> None: + self._host = host + self._port = port + self._connections: set[Any] = set() + self._server = None + + @property + def url(self) -> str: + return f"ws://{self._host}:{self._port}" + + async def start(self) -> None: + self._server = await websockets.serve(self._handler, self._host, self._port) + print(f"[Platform] WebSocket server listening on {self.url}") + + async def stop(self) -> None: + if self._server: + self._server.close() + await self._server.wait_closed() + self._server = None + + async def _handler(self, ws) -> None: + self._connections.add(ws) + print("[Platform] adapter connected") + try: + async for raw in ws: + data = orjson.loads(raw) + if data["type"] == "send": + print(f"[Platform] <- Bot: {data['payload']['text']}") + finally: + self._connections.discard(ws) + print("[Platform] adapter disconnected") + + async def simulate_incoming_message(self, text: str) -> None: + payload = { + "message_id": str(uuid.uuid4()), + "channel_id": "room-42", + "user_id": "demo-user", + "text": text, + "timestamp": time.time(), + } + message = orjson.dumps({"type": "message", "payload": payload}).decode() + for ws in list(self._connections): + await ws.send(message) + + +# --------------------------------------------------------------------------- +# 2. 适配器实现:仅关注核心转换逻辑,网络层交由 BaseAdapter 管理 +# --------------------------------------------------------------------------- + + +class DemoWsAdapter(BaseAdapter): + platform = "demo" + + def from_platform_message(self, raw: Dict[str, Any]) -> MessageEnvelope: + return { + "id": raw["message_id"], + "direction": "incoming", + "platform": self.platform, + "timestamp_ms": int(raw["timestamp"] * 1000), + "channel": {"channel_id": raw["channel_id"], "channel_type": "room"}, + "sender": {"user_id": raw["user_id"], "role": "user"}, + "conversation_id": raw["channel_id"], + "content": {"type": "text", "text": raw["text"]}, + } + + +def incoming_parser(raw: str | bytes) -> Any: + data = orjson.loads(raw) + if data.get("type") == "message": + return data["payload"] + return data + + +def outgoing_encoder(envelope: MessageEnvelope) -> str: + return orjson.dumps( + { + "type": "send", + "payload": { + "channel_id": envelope["channel"]["channel_id"], + "text": envelope["content"]["text"], + }, + } + ).decode() + + +# --------------------------------------------------------------------------- +# 3. 核心 Runtime:注册处理器并通过 InProcessCoreSink 接收消息 +# --------------------------------------------------------------------------- + +runtime = MessageRuntime() + + +@runtime.route(lambda env: env["direction"] == "incoming") +async def handle_incoming(env: MessageEnvelope) -> MessageEnvelope: + user_text = env["content"]["text"] + reply_text = f"核心收到:{user_text}" + return { + "id": str(uuid.uuid4()), + "direction": "outgoing", + "platform": env["platform"], + "timestamp_ms": int(time.time() * 1000), + "channel": env["channel"], + "sender": { + "user_id": "bot", + "role": "assistant", + "display_name": "DemoBot", + }, + "conversation_id": env["conversation_id"], + "content": {"type": "text", "text": reply_text}, + } + + +adapter: Optional[DemoWsAdapter] = None + + +async def core_entry(message: MessageEnvelope) -> None: + response = await runtime.handle_message(message) + if response and adapter is not None: + await adapter.send_to_platform(response) + + +core_sink = InProcessCoreSink(core_entry) + + +# --------------------------------------------------------------------------- +# 4. 串起来并运行 Demo +# --------------------------------------------------------------------------- + +async def main() -> None: + platform = FakePlatformServer() + await platform.start() + + global adapter + adapter = DemoWsAdapter( + core_sink, + transport=WebSocketAdapterOptions( + url=platform.url, + incoming_parser=incoming_parser, + outgoing_encoder=outgoing_encoder, + ), + ) + await adapter.start() + + await asyncio.sleep(0.1) + await platform.simulate_incoming_message("你好,MoFox Bus!") + await platform.simulate_incoming_message("请问你是谁?") + + await asyncio.sleep(0.5) + await adapter.stop() + await platform.stop() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 2f70c2c4c..d9d250a9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ dependencies = [ "langfuse==3.7.0", "lunar-python>=1.4.4", "lxml>=6.0.0", - "maim-message>=0.3.8", "matplotlib>=3.10.3", "networkx>=3.4.2", "orjson>=3.10", diff --git a/requirements.txt b/requirements.txt index 4fa4c3705..2aacffef1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,6 @@ filetype slowapi rjieba jsonlines -maim_message quick_algo matplotlib networkx diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 6c9b78ba6..31b2b5e71 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -3,7 +3,7 @@ import re import traceback from typing import Any -from maim_message import UserInfo +from mofox_bus import UserInfo from src.chat.message_manager import message_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager @@ -353,7 +353,7 @@ class ChatBot: return # 先提取基础信息检查是否是自身消息上报 - from maim_message import BaseMessageInfo + from mofox_bus import BaseMessageInfo temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {})) if temp_message_info.additional_config: sent_message = temp_message_info.additional_config.get("echo", False) diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index e16930ffe..2dc1f5696 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -2,7 +2,7 @@ import asyncio import hashlib import time -from maim_message import GroupInfo, UserInfo +from mofox_bus import GroupInfo, UserInfo from rich.traceback import install from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert @@ -358,7 +358,7 @@ class ChatManager: def register_message(self, message: DatabaseMessages): """注册消息到聊天流""" # 从 DatabaseMessages 提取平台和用户/群组信息 - from maim_message import GroupInfo, UserInfo + from mofox_bus import GroupInfo, UserInfo user_info = UserInfo( platform=message.user_info.platform, diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 68fc4f1bf..fc2dcacc5 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Optional import urllib3 -from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo +from mofox_bus import BaseMessageInfo, MessageBase, Seg, UserInfo from rich.traceback import install from src.chat.message_receive.chat_stream import ChatStream diff --git a/src/chat/message_receive/message_processor.py b/src/chat/message_receive/message_processor.py index 0c9e89951..9f848f819 100644 --- a/src/chat/message_receive/message_processor.py +++ b/src/chat/message_receive/message_processor.py @@ -7,7 +7,7 @@ import time from typing import Any import orjson -from maim_message import BaseMessageInfo, Seg +from mofox_bus import BaseMessageInfo, Seg from src.chat.utils.self_voice_cache import consume_self_voice_text from src.chat.utils.utils_image import get_image_manager @@ -430,9 +430,9 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag Returns: BaseMessageInfo: 重建的消息信息对象 """ - from maim_message import GroupInfo, UserInfo + from mofox_bus import GroupInfo, UserInfo - # 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo + # 从 DatabaseMessages 的 user_info 转换为 mofox_bus.UserInfo user_info = UserInfo( platform=db_message.user_info.platform, user_id=db_message.user_info.user_id, @@ -440,7 +440,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag user_cardname=db_message.user_info.user_cardname or "" ) - # 从 DatabaseMessages 的 group_info 转换为 maim_message.GroupInfo(如果存在) + # 从 DatabaseMessages 的 group_info 转换为 mofox_bus.GroupInfo(如果存在) group_info = None if db_message.group_info: group_info = GroupInfo( diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 71d2d1861..2c2713cda 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -8,7 +8,7 @@ from typing import Any import numpy as np import rjieba -from maim_message import UserInfo +from mofox_bus import UserInfo from src.chat.message_receive.chat_stream import get_chat_manager diff --git a/src/common/logger.py b/src/common/logger.py index 0e4c50fa3..39bc89c3c 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -534,7 +534,7 @@ DEFAULT_MODULE_COLORS = { # 数据库和消息 "database_model": "#875F00", # 橙褐色 "database": "#00FF00", # 橙褐色 - "maim_message": "#AF87D7", # 紫褐色 + "mofox_bus": "#AF87D7", # 紫褐色 # 日志系统 "logger": "#808080", # 深灰色 "confirm": "#FFFF00", # 黄色+粗体 diff --git a/src/common/message/api.py b/src/common/message/api.py index 2d797a5a8..f49474df7 100644 --- a/src/common/message/api.py +++ b/src/common/message/api.py @@ -1,70 +1,53 @@ -import importlib.metadata import os -from maim_message import MessageServer +from mofox_bus import MessageServer from src.common.logger import get_logger from src.common.server import get_global_server from src.config.config import global_config -global_api = None +global_api: MessageServer | None = None -def get_global_api() -> MessageServer: # sourcery skip: extract-method - """获取全局MessageServer实例""" +def get_global_api() -> MessageServer: + """ + 获取全局 MessageServer 单例。 + """ + global global_api - if global_api is None: - # 检查maim_message版本 - try: - maim_message_version = importlib.metadata.version("maim_message") - version_compatible = [int(x) for x in maim_message_version.split(".")] >= [0, 3, 3] - except (importlib.metadata.PackageNotFoundError, ValueError): - version_compatible = False + if global_api is not None: + return global_api - # 读取配置项 - maim_message_config = global_config.maim_message + bus_config = global_config.message_bus + host = os.getenv("HOST", "127.0.0.1") + port_str = os.getenv("PORT", "8000") - # 设置基本参数 + try: + port = int(port_str) + except ValueError: + port = 8000 - host = os.getenv("HOST", "127.0.0.1") - port_str = os.getenv("PORT", "8000") + kwargs: dict[str, object] = { + "host": host, + "port": port, + "app": get_global_server().get_app(), + } - try: - port = int(port_str) - except ValueError: - port = 8000 + if bus_config.use_custom: + kwargs["host"] = bus_config.host + kwargs["port"] = bus_config.port + kwargs.pop("app", None) + if bus_config.use_wss: + if bus_config.cert_file: + kwargs["ssl_certfile"] = bus_config.cert_file + if bus_config.key_file: + kwargs["ssl_keyfile"] = bus_config.key_file - kwargs = { - "host": host, - "port": port, - "app": get_global_server().get_app(), - } + if bus_config.auth_token: + kwargs["enable_token"] = True + kwargs["custom_logger"] = get_logger("mofox_bus") - # 只有在版本 >= 0.3.0 时才使用高级特性 - if version_compatible: - # 添加自定义logger - maim_message_logger = get_logger("maim_message") - kwargs["custom_logger"] = maim_message_logger - - # 添加token认证 - if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0: - kwargs["enable_token"] = True - - if maim_message_config.use_custom: - # 添加WSS模式支持 - del kwargs["app"] - kwargs["host"] = maim_message_config.host - kwargs["port"] = maim_message_config.port - kwargs["mode"] = maim_message_config.mode - if maim_message_config.use_wss: - if maim_message_config.cert_file: - kwargs["ssl_certfile"] = maim_message_config.cert_file - if maim_message_config.key_file: - kwargs["ssl_keyfile"] = maim_message_config.key_file - kwargs["enable_custom_uvicorn_logger"] = False - - global_api = MessageServer(**kwargs) - if version_compatible and maim_message_config.auth_token: - for token in maim_message_config.auth_token: - global_api.add_valid_token(token) + global_api = MessageServer(**kwargs) + for token in bus_config.auth_token: + global_api.add_valid_token(token) return global_api diff --git a/src/config/config.py b/src/config/config.py index 07cee3688..5b40a7242 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -26,7 +26,7 @@ from src.config.official_configs import ( ExperimentalConfig, ExpressionConfig, LPMMKnowledgeConfig, - MaimMessageConfig, + MessageBusConfig, MemoryConfig, MessageReceiveConfig, MoodConfig, @@ -392,7 +392,7 @@ class Config(ValidatedConfigBase): response_post_process: ResponsePostProcessConfig = Field(..., description="响应后处理配置") response_splitter: ResponseSplitterConfig = Field(..., description="响应分割配置") experimental: ExperimentalConfig = Field(default_factory=lambda: ExperimentalConfig(), description="实验性功能配置") - maim_message: MaimMessageConfig = Field(..., description="Maim消息配置") + message_bus: MessageBusConfig = Field(..., description="消息总线配置") lpmm_knowledge: LPMMKnowledgeConfig = Field(..., description="LPMM知识配置") tool: ToolConfig = Field(..., description="工具配置") debug: DebugConfig = Field(..., description="调试配置") diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 7a98d76f7..bb1716ddf 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -574,8 +574,18 @@ class ExperimentalConfig(ValidatedConfigBase): pfc_chatting: bool = Field(default=False, description="启用PFC聊天") -class MaimMessageConfig(ValidatedConfigBase): - """maim_message配置类""" +class MessageBusConfig(ValidatedConfigBase): + """mofox_bus 消息服务配置""" + + use_custom: bool = Field(default=False, description="是否使用自定义地址") + host: str = Field(default="127.0.0.1", description="消息服务主机") + port: int = Field(default=8090, description="消息服务端口") + mode: Literal["ws", "tcp"] = Field(default="ws", description="传输模式") + use_wss: bool = Field(default=False, description="是否启用 WSS") + cert_file: str = Field(default="", description="证书文件路径") + key_file: str = Field(default="", description="密钥文件路径") + auth_token: list[str] = Field(default_factory=lambda: [], description="认证 token 列表") + use_custom: bool = Field(default=False, description="启用自定义") host: str = Field(default="127.0.0.1", description="主机") diff --git a/src/main.py b/src/main.py index 143feae51..acba32aa8 100644 --- a/src/main.py +++ b/src/main.py @@ -10,7 +10,7 @@ from functools import partial from random import choices from typing import Any -from maim_message import MessageServer +from mofox_bus import MessageServer from rich.traceback import install from src.chat.emoji_system.emoji_manager import get_emoji_manager diff --git a/src/mofox_bus/__init__.py b/src/mofox_bus/__init__.py new file mode 100644 index 000000000..955736e8a --- /dev/null +++ b/src/mofox_bus/__init__.py @@ -0,0 +1,94 @@ +""" +MoFox 内部通用消息总线实现。 + +该模块导出 TypedDict 消息模型、序列化工具、传输层封装以及适配器辅助工具, +供核心进程与各类平台适配器共享。 +""" + +from . import codec, types +from .adapter_utils import ( + AdapterTransportOptions, + BaseAdapter, + BatchDispatcher, + CoreMessageSink, + HttpAdapterOptions, + InProcessCoreSink, + WebSocketLike, + WebSocketAdapterOptions, +) +from .api import MessageClient, MessageServer +from .codec import dumps_message, dumps_messages, loads_message, loads_messages +from .message_models import BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, TemplateInfo, UserInfo +from .router import RouteConfig, Router, TargetConfig +from .runtime import MessageProcessingError, MessageRoute, MessageRuntime +from .types import ( + AudioContent, + ChannelInfo, + CommandContent, + Content, + ContentType, + EventContent, + EventType, + FileContent, + ImageContent, + MessageDirection, + MessageEnvelope, + Role, + SenderInfo, + SystemContent, + TextContent, + VideoContent, +) + +__all__ = [ + # TypedDict model + "AudioContent", + "ChannelInfo", + "CommandContent", + "Content", + "ContentType", + "EventContent", + "EventType", + "FileContent", + "ImageContent", + "MessageDirection", + "MessageEnvelope", + "Role", + "SenderInfo", + "SystemContent", + "TextContent", + "VideoContent", + # Codec helpers + "codec", + "dumps_message", + "dumps_messages", + "loads_message", + "loads_messages", + # Runtime / routing + "MessageRoute", + "MessageRuntime", + "MessageProcessingError", + # Message dataclasses + "Seg", + "GroupInfo", + "UserInfo", + "FormatInfo", + "TemplateInfo", + "BaseMessageInfo", + "MessageBase", + # Server/client/router + "MessageServer", + "MessageClient", + "Router", + "RouteConfig", + "TargetConfig", + # Adapter helpers + "AdapterTransportOptions", + "BaseAdapter", + "BatchDispatcher", + "CoreMessageSink", + "InProcessCoreSink", + "WebSocketLike", + "WebSocketAdapterOptions", + "HttpAdapterOptions", +] diff --git a/src/mofox_bus/adapter_utils.py b/src/mofox_bus/adapter_utils.py new file mode 100644 index 000000000..5c4b01be3 --- /dev/null +++ b/src/mofox_bus/adapter_utils.py @@ -0,0 +1,270 @@ +from __future__ import annotations + +import asyncio +import contextlib +from dataclasses import dataclass +from typing import Any, AsyncIterator, Awaitable, Callable, Protocol + +import orjson +from aiohttp import web as aiohttp_web +import websockets + +from .types import MessageEnvelope + + +class CoreMessageSink(Protocol): + async def send(self, message: MessageEnvelope) -> None: ... + + async def send_many(self, messages: list[MessageEnvelope]) -> None: ... # pragma: no cover - optional + + +class WebSocketLike(Protocol): + def __aiter__(self) -> AsyncIterator[str | bytes]: ... + + @property + def closed(self) -> bool: ... + + async def send(self, data: str | bytes) -> None: ... + + async def close(self) -> None: ... + + +@dataclass +class WebSocketAdapterOptions: + url: str + headers: dict[str, str] | None = None + incoming_parser: Callable[[str | bytes], Any] | None = None + outgoing_encoder: Callable[[MessageEnvelope], str | bytes] | None = None + + +@dataclass +class HttpAdapterOptions: + host: str = "0.0.0.0" + port: int = 8089 + path: str = "/adapter/messages" + app: aiohttp_web.Application | None = None + + +AdapterTransportOptions = WebSocketAdapterOptions | HttpAdapterOptions | None + + +class BaseAdapter: + """ + 适配器基类:负责平台原始消息与 MessageEnvelope 之间的互转。 + 子类需要实现平台入站解析与出站发送逻辑。 + """ + + platform: str = "unknown" + + def __init__(self, core_sink: CoreMessageSink, transport: AdapterTransportOptions = None): + """ + Args: + core_sink: 核心消息入口,通常是 InProcessCoreSink 或自定义客户端。 + transport: 传入 WebSocketAdapterOptions / HttpAdapterOptions 即可自动管理监听逻辑。 + """ + self.core_sink = core_sink + self._transport_config = transport + self._ws: WebSocketLike | None = None + self._ws_task: asyncio.Task | None = None + self._http_runner: aiohttp_web.AppRunner | None = None + self._http_site: aiohttp_web.BaseSite | None = None + + async def start(self) -> None: + """根据配置自动启动 WS/HTTP 监听。""" + if isinstance(self._transport_config, WebSocketAdapterOptions): + await self._start_ws_transport(self._transport_config) + elif isinstance(self._transport_config, HttpAdapterOptions): + await self._start_http_transport(self._transport_config) + + async def stop(self) -> None: + """停止自动管理的传输层。""" + if self._ws_task: + self._ws_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._ws_task + self._ws_task = None + if self._ws: + await self._ws.close() + self._ws = None + if self._http_site: + await self._http_site.stop() + self._http_site = None + if self._http_runner: + await self._http_runner.cleanup() + self._http_runner = None + + async def on_platform_message(self, raw: Any) -> None: + """处理平台下发的单条消息并交给核心。""" + envelope = self.from_platform_message(raw) + await self.core_sink.send(envelope) + + async def on_platform_messages(self, raw_messages: list[Any]) -> None: + """批量推送入口,内部自动批量或逐条送入核心。""" + envelopes = [self.from_platform_message(raw) for raw in raw_messages] + await _send_many(self.core_sink, envelopes) + + async def send_to_platform(self, envelope: MessageEnvelope) -> None: + """核心生成单条消息时调用,由子类或自动传输层发送。""" + await self._send_platform_message(envelope) + + async def send_batch_to_platform(self, envelopes: list[MessageEnvelope]) -> None: + """默认串行发送整批消息,子类可根据平台特性重写。""" + for env in envelopes: + await self._send_platform_message(env) + + def from_platform_message(self, raw: Any) -> MessageEnvelope: + """子类必须实现:将平台原始结构转换为统一 MessageEnvelope。""" + raise NotImplementedError + + async def _send_platform_message(self, envelope: MessageEnvelope) -> None: + """子类必须实现:把 MessageEnvelope 转为平台格式并发送出去。""" + if isinstance(self._transport_config, WebSocketAdapterOptions): + await self._send_via_ws(envelope) + return + raise NotImplementedError + + async def _start_ws_transport(self, options: WebSocketAdapterOptions) -> None: + self._ws = await websockets.connect(options.url, extra_headers=options.headers) + self._ws_task = asyncio.create_task(self._ws_listen_loop(options)) + + async def _ws_listen_loop(self, options: WebSocketAdapterOptions) -> None: + assert self._ws is not None + parser = options.incoming_parser or self._default_ws_parser + try: + async for raw in self._ws: + payload = parser(raw) + await self.on_platform_message(payload) + finally: + pass + + async def _send_via_ws(self, envelope: MessageEnvelope) -> None: + if self._ws is None or self._ws.closed: + raise RuntimeError("WebSocket transport is not active") + encoder = None + if isinstance(self._transport_config, WebSocketAdapterOptions): + encoder = self._transport_config.outgoing_encoder + data = encoder(envelope) if encoder else self._default_ws_encoder(envelope) + await self._ws.send(data) + + async def _start_http_transport(self, options: HttpAdapterOptions) -> None: + app = options.app or aiohttp_web.Application() + app.add_routes([aiohttp_web.post(options.path, self._handle_http_request)]) + self._http_runner = aiohttp_web.AppRunner(app) + await self._http_runner.setup() + self._http_site = aiohttp_web.TCPSite(self._http_runner, options.host, options.port) + await self._http_site.start() + + async def _handle_http_request(self, request: aiohttp_web.Request) -> aiohttp_web.Response: + raw = await request.read() + data = orjson.loads(raw) if raw else {} + if isinstance(data, list): + await self.on_platform_messages(data) + else: + await self.on_platform_message(data) + return aiohttp_web.json_response({"status": "ok"}) + + @staticmethod + def _default_ws_parser(raw: str | bytes) -> Any: + data = orjson.loads(raw) + if isinstance(data, dict) and data.get("type") == "message" and "payload" in data: + return data["payload"] + return data + + @staticmethod + def _default_ws_encoder(envelope: MessageEnvelope) -> bytes: + return orjson.dumps({"type": "send", "payload": envelope}) + + +class InProcessCoreSink: + """ + 简单的进程内 sink,实现 CoreMessageSink 协议。 + """ + + def __init__(self, handler: Callable[[MessageEnvelope], Awaitable[None]]): + self._handler = handler + + async def send(self, message: MessageEnvelope) -> None: + await self._handler(message) + + async def send_many(self, messages: list[MessageEnvelope]) -> None: + for message in messages: + await self._handler(message) + + +async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) -> None: + send_many = getattr(sink, "send_many", None) + if callable(send_many): + await send_many(envelopes) + return + for env in envelopes: + await sink.send(env) + + +class BatchDispatcher: + """ + 将 send 操作合并为批量发送,适合网络 IO 密集场景。 + """ + + def __init__( + self, + sink: CoreMessageSink, + *, + max_batch_size: int = 50, + flush_interval: float = 0.2, + ) -> None: + self._sink = sink + self._max_batch_size = max_batch_size + self._flush_interval = flush_interval + self._buffer: list[MessageEnvelope] = [] + self._lock = asyncio.Lock() + self._flush_task: asyncio.Task | None = None + self._closed = False + + async def add(self, message: MessageEnvelope) -> None: + async with self._lock: + if self._closed: + raise RuntimeError("Dispatcher closed") + self._buffer.append(message) + self._ensure_timer() + if len(self._buffer) >= self._max_batch_size: + await self._flush_locked() + + async def close(self) -> None: + async with self._lock: + self._closed = True + await self._flush_locked() + if self._flush_task: + self._flush_task.cancel() + self._flush_task = None + + def _ensure_timer(self) -> None: + if self._flush_task is not None and not self._flush_task.done(): + return + loop = asyncio.get_running_loop() + self._flush_task = loop.create_task(self._flush_loop()) + + async def _flush_loop(self) -> None: + try: + await asyncio.sleep(self._flush_interval) + async with self._lock: + await self._flush_locked() + except asyncio.CancelledError: # pragma: no cover - timer cancellation + pass + + async def _flush_locked(self) -> None: + if not self._buffer: + return + payload = list(self._buffer) + self._buffer.clear() + await self._sink.send_many(payload) + + +__all__ = [ + "AdapterTransportOptions", + "BaseAdapter", + "BatchDispatcher", + "CoreMessageSink", + "HttpAdapterOptions", + "InProcessCoreSink", + "WebSocketAdapterOptions", +] diff --git a/src/mofox_bus/api.py b/src/mofox_bus/api.py new file mode 100644 index 000000000..11ae1d4ef --- /dev/null +++ b/src/mofox_bus/api.py @@ -0,0 +1,330 @@ +from __future__ import annotations + +import asyncio +import contextlib +import logging +import ssl +from typing import Any, Awaitable, Callable, Dict, Literal, Optional + +import aiohttp +import orjson +import uvicorn +from fastapi import FastAPI, WebSocket, WebSocketDisconnect + +from .message_models import MessageBase + +MessagePayload = Dict[str, Any] +MessageHandler = Callable[[MessagePayload], Awaitable[None] | None] + + +class BaseMessageHandler: + def __init__(self) -> None: + self.message_handlers: list[MessageHandler] = [] + self.background_tasks: set[asyncio.Task] = set() + + def register_message_handler(self, handler: MessageHandler) -> None: + if handler not in self.message_handlers: + self.message_handlers.append(handler) + + async def process_message(self, message: MessagePayload) -> None: + tasks: list[asyncio.Task] = [] + for handler in self.message_handlers: + try: + result = handler(message) + if asyncio.iscoroutine(result): + task = asyncio.create_task(result) + tasks.append(task) + self.background_tasks.add(task) + task.add_done_callback(self.background_tasks.discard) + except Exception: # pragma: no cover - logging only + logging.getLogger("mofox_bus.server").exception("Failed to handle message") + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + +class MessageServer(BaseMessageHandler): + """ + WebSocket 消息服务器,支持与 FastAPI 应用共享事件循环。 + """ + + def __init__( + self, + host: str = "0.0.0.0", + port: int = 18000, + *, + enable_token: bool = False, + app: FastAPI | None = None, + path: str = "/ws", + ssl_certfile: str | None = None, + ssl_keyfile: str | None = None, + mode: Literal["ws", "tcp"] = "ws", + custom_logger: logging.Logger | None = None, + enable_custom_uvicorn_logger: bool = False, + ) -> None: + super().__init__() + if mode != "ws": + raise NotImplementedError("Only WebSocket mode is supported in mofox_bus") + if custom_logger: + logging.getLogger("mofox_bus.server").handlers = custom_logger.handlers + self.host = host + self.port = port + self._app = app or FastAPI() + self._own_app = app is None + self._path = path + self._ssl_certfile = ssl_certfile + self._ssl_keyfile = ssl_keyfile + self._enable_token = enable_token + self._valid_tokens: set[str] = set() + self._connections: set[WebSocket] = set() + self._platform_connections: dict[str, WebSocket] = {} + self._conn_lock = asyncio.Lock() + self._server: uvicorn.Server | None = None + self._running = False + self._setup_routes() + + def _setup_routes(self) -> None: + @_self_websocket(self._app, self._path) + async def websocket_endpoint(websocket: WebSocket) -> None: + platform = websocket.headers.get("platform", "unknown") + token = websocket.headers.get("authorization") or websocket.headers.get("Authorization") + if self._enable_token and not await self.verify_token(token): + await websocket.close(code=1008, reason="invalid token") + return + + await websocket.accept() + await self._register_connection(websocket, platform) + try: + while True: + msg = await websocket.receive() + if msg["type"] == "websocket.receive": + data = msg.get("text") + if data is None and msg.get("bytes") is not None: + data = msg["bytes"].decode("utf-8") + if not data: + continue + try: + payload = orjson.loads(data) + except orjson.JSONDecodeError: + logging.getLogger("mofox_bus.server").warning("Invalid JSON payload") + continue + if isinstance(payload, list): + for item in payload: + await self.process_message(item) + else: + await self.process_message(payload) + elif msg["type"] == "websocket.disconnect": + break + except WebSocketDisconnect: + pass + finally: + await self._remove_connection(websocket, platform) + + async def verify_token(self, token: str | None) -> bool: + if not self._enable_token: + return True + return token in self._valid_tokens + + def add_valid_token(self, token: str) -> None: + self._valid_tokens.add(token) + + def remove_valid_token(self, token: str) -> None: + self._valid_tokens.discard(token) + + async def _register_connection(self, websocket: WebSocket, platform: str) -> None: + async with self._conn_lock: + self._connections.add(websocket) + if platform: + previous = self._platform_connections.get(platform) + if previous and previous.client_state.name != "DISCONNECTED": + await previous.close(code=1000, reason="replaced") + self._platform_connections[platform] = websocket + + async def _remove_connection(self, websocket: WebSocket, platform: str) -> None: + async with self._conn_lock: + self._connections.discard(websocket) + if platform and self._platform_connections.get(platform) is websocket: + del self._platform_connections[platform] + + async def broadcast_message(self, message: MessagePayload) -> None: + data = orjson.dumps(message).decode("utf-8") + async with self._conn_lock: + targets = list(self._connections) + for ws in targets: + await ws.send_text(data) + + async def broadcast_to_platform(self, platform: str, message: MessagePayload) -> None: + ws = self._platform_connections.get(platform) + if ws is None: + raise RuntimeError(f"No active connection for platform {platform}") + await ws.send_text(orjson.dumps(message).decode("utf-8")) + + async def send_message(self, message: MessageBase | MessagePayload) -> None: + payload = message.to_dict() if isinstance(message, MessageBase) else message + platform = payload.get("message_info", {}).get("platform") + if not platform: + raise ValueError("message_info.platform is required to route the message") + await self.broadcast_to_platform(platform, payload) + + def run_sync(self) -> None: + if not self._own_app: + return + asyncio.run(self.run()) + + async def run(self) -> None: + self._running = True + if not self._own_app: + return + config = uvicorn.Config( + self._app, + host=self.host, + port=self.port, + ssl_certfile=self._ssl_certfile, + ssl_keyfile=self._ssl_keyfile, + log_config=None, + access_log=False, + ) + self._server = uvicorn.Server(config) + try: + await self._server.serve() + except asyncio.CancelledError: # pragma: no cover - shutdown path + pass + + async def stop(self) -> None: + self._running = False + if self._server: + self._server.should_exit = True + await self._server.shutdown() + self._server = None + async with self._conn_lock: + targets = list(self._connections) + self._connections.clear() + self._platform_connections.clear() + for ws in targets: + try: + await ws.close(code=1001, reason="server shutting down") + except Exception: # pragma: no cover - best effort + pass + for task in list(self.background_tasks): + if not task.done(): + task.cancel() + if self.background_tasks: + await asyncio.gather(*self.background_tasks, return_exceptions=True) + self.background_tasks.clear() + + +class MessageClient(BaseMessageHandler): + """ + WebSocket 消息客户端,实现双向传输。 + """ + + def __init__(self, mode: Literal["ws", "tcp"] = "ws") -> None: + super().__init__() + if mode != "ws": + raise NotImplementedError("Only WebSocket mode is supported in mofox_bus") + self._mode = mode + self._session: aiohttp.ClientSession | None = None + self._ws: aiohttp.ClientWebSocketResponse | None = None + self._receive_task: asyncio.Task | None = None + self._url: str = "" + self._platform: str = "" + self._token: str | None = None + self._ssl_verify: str | None = None + self._closed = False + + async def connect( + self, + *, + url: str, + platform: str, + token: str | None = None, + ssl_verify: str | None = None, + ) -> None: + self._url = url + self._platform = platform + self._token = token + self._ssl_verify = ssl_verify + await self._establish_connection() + + async def _establish_connection(self) -> None: + if self._session is None: + self._session = aiohttp.ClientSession() + headers = {"platform": self._platform} + if self._token: + headers["authorization"] = self._token + ssl_context = None + if self._ssl_verify: + ssl_context = ssl.create_default_context(cafile=self._ssl_verify) + self._ws = await self._session.ws_connect(self._url, headers=headers, ssl=ssl_context) + self._receive_task = asyncio.create_task(self._receive_loop()) + + async def _receive_loop(self) -> None: + assert self._ws is not None + try: + async for msg in self._ws: + if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY): + data = msg.data if isinstance(msg.data, str) else msg.data.decode("utf-8") + try: + payload = orjson.loads(data) + except orjson.JSONDecodeError: + logging.getLogger("mofox_bus.client").warning("Invalid JSON payload") + continue + if isinstance(payload, list): + for item in payload: + await self.process_message(item) + else: + await self.process_message(payload) + elif msg.type == aiohttp.WSMsgType.ERROR: + break + except asyncio.CancelledError: # pragma: no cover - cancellation path + pass + finally: + if self._ws: + await self._ws.close() + self._ws = None + + async def run(self) -> None: + if self._receive_task is None: + await self._establish_connection() + try: + if self._receive_task: + await self._receive_task + except asyncio.CancelledError: # pragma: no cover - cancellation path + pass + + async def send_message(self, message: MessagePayload) -> bool: + if self._ws is None or self._ws.closed: + raise RuntimeError("WebSocket connection is not established") + await self._ws.send_str(orjson.dumps(message).decode("utf-8")) + return True + + def is_connected(self) -> bool: + return self._ws is not None and not self._ws.closed + + async def stop(self) -> None: + self._closed = True + if self._receive_task and not self._receive_task.done(): + self._receive_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._receive_task + if self._ws: + await self._ws.close() + self._ws = None + if self._session: + await self._session.close() + self._session = None + + +def _self_websocket(app: FastAPI, path: str): + """ + 装饰器工厂,兼容 FastAPI websocket 路由的声明方式。 + FastAPI 不允许直接重复注册同一路径,因此这里封装一个可复用的装饰器。 + """ + + def decorator(func): + app.add_api_websocket_route(path, func) + return func + + return decorator + + +__all__ = ["BaseMessageHandler", "MessageClient", "MessageServer"] diff --git a/src/mofox_bus/codec.py b/src/mofox_bus/codec.py new file mode 100644 index 000000000..6f8d23cc6 --- /dev/null +++ b/src/mofox_bus/codec.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import json as _stdlib_json +from typing import Any, Dict, Iterable, List + +try: + import orjson as _json_impl +except Exception: # pragma: no cover - fallback when orjson is unavailable + _json_impl = None + +from .types import MessageEnvelope + +DEFAULT_SCHEMA_VERSION = 1 + + +def _dumps(obj: Any) -> bytes: + if _json_impl is not None: + return _json_impl.dumps(obj) + return _stdlib_json.dumps(obj, ensure_ascii=False, separators=(",", ":")).encode("utf-8") + + +def _loads(data: bytes) -> Dict[str, Any]: + if _json_impl is not None: + return _json_impl.loads(data) + return _stdlib_json.loads(data.decode("utf-8")) + + +def dumps_message(msg: MessageEnvelope) -> bytes: + """ + 将单条 MessageEnvelope 序列化为 JSON bytes。 + """ + if "schema_version" not in msg: + msg["schema_version"] = DEFAULT_SCHEMA_VERSION + return _dumps(msg) + + +def dumps_messages(messages: Iterable[MessageEnvelope]) -> bytes: + """ + 将多条消息批量序列化,以提升吞吐。 + """ + payload = { + "schema_version": DEFAULT_SCHEMA_VERSION, + "items": list(messages), + } + return _dumps(payload) + + +def loads_message(data: bytes | str) -> MessageEnvelope: + """ + 反序列化单条消息。 + """ + if isinstance(data, str): + data = data.encode("utf-8") + obj = _loads(data) + return _upgrade_schema_if_needed(obj) + + +def loads_messages(data: bytes | str) -> List[MessageEnvelope]: + """ + 反序列化批量消息。 + """ + if isinstance(data, str): + data = data.encode("utf-8") + obj = _loads(data) + version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION) + if version != DEFAULT_SCHEMA_VERSION: + raise ValueError(f"Unsupported schema_version={version}") + return [_upgrade_schema_if_needed(item) for item in obj.get("items", [])] + + +def _upgrade_schema_if_needed(obj: Dict[str, Any]) -> MessageEnvelope: + """ + 针对未来的 schema 版本演进预留兼容入口。 + """ + version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION) + if version == DEFAULT_SCHEMA_VERSION: + return obj # type: ignore[return-value] + raise ValueError(f"Unsupported schema_version={version}") + + +__all__ = [ + "DEFAULT_SCHEMA_VERSION", + "dumps_message", + "dumps_messages", + "loads_message", + "loads_messages", +] diff --git a/src/mofox_bus/message_models.py b/src/mofox_bus/message_models.py new file mode 100644 index 000000000..ad408b63f --- /dev/null +++ b/src/mofox_bus/message_models.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +from dataclasses import asdict, dataclass +from typing import Any, Dict, List, Optional + + +@dataclass +class Seg: + """ + 消息段,表示一段文本/图片/结构化内容。 + """ + + type: str + data: Any + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Seg": + seg_type = data.get("type") + seg_data = data.get("data") + if seg_type == "seglist" and isinstance(seg_data, list): + seg_data = [Seg.from_dict(item) for item in seg_data] + return cls(type=seg_type, data=seg_data) + + def to_dict(self) -> Dict[str, Any]: + if self.type == "seglist" and isinstance(self.data, list): + payload = [seg.to_dict() if isinstance(seg, Seg) else seg for seg in self.data] + else: + payload = self.data + return {"type": self.type, "data": payload} + + +@dataclass +class GroupInfo: + platform: Optional[str] = None + group_id: Optional[str] = None + group_name: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return {k: v for k, v in asdict(self).items() if v is not None} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> Optional["GroupInfo"]: + if not data or data.get("group_id") is None: + return None + return cls( + platform=data.get("platform"), + group_id=data.get("group_id"), + group_name=data.get("group_name"), + ) + + +@dataclass +class UserInfo: + platform: Optional[str] = None + user_id: Optional[str] = None + user_nickname: Optional[str] = None + user_cardname: Optional[str] = None + user_avatar: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return {k: v for k, v in asdict(self).items() if v is not None} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "UserInfo": + return cls( + platform=data.get("platform"), + user_id=data.get("user_id"), + user_nickname=data.get("user_nickname"), + user_cardname=data.get("user_cardname"), + user_avatar=data.get("user_avatar"), + ) + + +@dataclass +class FormatInfo: + content_format: Optional[List[str]] = None + accept_format: Optional[List[str]] = None + + def to_dict(self) -> Dict[str, Any]: + return {k: v for k, v in asdict(self).items() if v is not None} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> Optional["FormatInfo"]: + if not data: + return None + return cls( + content_format=data.get("content_format"), + accept_format=data.get("accept_format"), + ) + + +@dataclass +class TemplateInfo: + template_items: Optional[Dict[str, str]] = None + template_name: Optional[Dict[str, str]] = None + template_default: bool = True + + def to_dict(self) -> Dict[str, Any]: + return {k: v for k, v in asdict(self).items() if v is not None} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> Optional["TemplateInfo"]: + if not data: + return None + return cls( + template_items=data.get("template_items"), + template_name=data.get("template_name"), + template_default=data.get("template_default", True), + ) + + +@dataclass +class BaseMessageInfo: + platform: Optional[str] = None + message_id: Optional[str] = None + time: Optional[float] = None + group_info: Optional[GroupInfo] = None + user_info: Optional[UserInfo] = None + format_info: Optional[FormatInfo] = None + template_info: Optional[TemplateInfo] = None + additional_config: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + result: Dict[str, Any] = {} + if self.platform is not None: + result["platform"] = self.platform + if self.message_id is not None: + result["message_id"] = self.message_id + if self.time is not None: + result["time"] = self.time + if self.additional_config is not None: + result["additional_config"] = self.additional_config + if self.group_info is not None: + result["group_info"] = self.group_info.to_dict() + if self.user_info is not None: + result["user_info"] = self.user_info.to_dict() + if self.format_info is not None: + result["format_info"] = self.format_info.to_dict() + if self.template_info is not None: + result["template_info"] = self.template_info.to_dict() + return result + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "BaseMessageInfo": + return cls( + platform=data.get("platform"), + message_id=data.get("message_id"), + time=data.get("time"), + additional_config=data.get("additional_config"), + group_info=GroupInfo.from_dict(data.get("group_info", {})), + user_info=UserInfo.from_dict(data.get("user_info", {})), + format_info=FormatInfo.from_dict(data.get("format_info", {})), + template_info=TemplateInfo.from_dict(data.get("template_info", {})), + ) + + +@dataclass +class MessageBase: + message_info: BaseMessageInfo + message_segment: Seg + raw_message: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = { + "message_info": self.message_info.to_dict(), + "message_segment": self.message_segment.to_dict(), + } + if self.raw_message is not None: + payload["raw_message"] = self.raw_message + return payload + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MessageBase": + return cls( + message_info=BaseMessageInfo.from_dict(data.get("message_info", {})), + message_segment=Seg.from_dict(data.get("message_segment", {})), + raw_message=data.get("raw_message"), + ) + + +__all__ = [ + "BaseMessageInfo", + "FormatInfo", + "GroupInfo", + "MessageBase", + "Seg", + "TemplateInfo", + "UserInfo", +] diff --git a/src/mofox_bus/router.py b/src/mofox_bus/router.py new file mode 100644 index 000000000..46e29de27 --- /dev/null +++ b/src/mofox_bus/router.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import asyncio +import contextlib +import logging +from dataclasses import asdict, dataclass +from typing import Callable, Dict, Optional + +from .api import MessageClient +from .message_models import MessageBase + +logger = logging.getLogger("mofox_bus.router") + + +@dataclass +class TargetConfig: + url: str + token: str | None = None + ssl_verify: str | None = None + + def to_dict(self) -> Dict[str, str | None]: + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, str | None]) -> "TargetConfig": + return cls( + url=data.get("url", ""), + token=data.get("token"), + ssl_verify=data.get("ssl_verify"), + ) + + +@dataclass +class RouteConfig: + route_config: Dict[str, TargetConfig] + + def to_dict(self) -> Dict[str, Dict[str, str | None]]: + return {"route_config": {k: v.to_dict() for k, v in self.route_config.items()}} + + @classmethod + def from_dict(cls, data: Dict[str, Dict[str, str | None]]) -> "RouteConfig": + cfg = { + platform: TargetConfig.from_dict(target) + for platform, target in data.get("route_config", {}).items() + } + return cls(route_config=cfg) + + +class Router: + def __init__(self, config: RouteConfig, custom_logger: logging.Logger | None = None) -> None: + if custom_logger: + logger.handlers = custom_logger.handlers + self.config = config + self.clients: Dict[str, MessageClient] = {} + self.handlers: list[Callable[[Dict], None]] = [] + self._running = False + self._client_tasks: Dict[str, asyncio.Task] = {} + self._monitor_task: asyncio.Task | None = None + + async def connect(self, platform: str) -> None: + if platform not in self.config.route_config: + raise ValueError(f"Unknown platform {platform}") + target = self.config.route_config[platform] + mode = "tcp" if target.url.startswith(("tcp://", "tcps://")) else "ws" + if mode != "ws": + raise NotImplementedError("TCP mode is not implemented yet") + client = MessageClient(mode="ws") + await client.connect( + url=target.url, + platform=platform, + token=target.token, + ssl_verify=target.ssl_verify, + ) + for handler in self.handlers: + client.register_message_handler(handler) + self.clients[platform] = client + if self._running: + self._client_tasks[platform] = asyncio.create_task(client.run()) + + def register_class_handler(self, handler: Callable[[Dict], None]) -> None: + self.handlers.append(handler) + for client in self.clients.values(): + client.register_message_handler(handler) + + async def run(self) -> None: + self._running = True + for platform in self.config.route_config: + if platform not in self.clients: + await self.connect(platform) + for platform, client in self.clients.items(): + if platform not in self._client_tasks: + self._client_tasks[platform] = asyncio.create_task(client.run()) + self._monitor_task = asyncio.create_task(self._monitor_connections()) + try: + while self._running: + await asyncio.sleep(1) + except asyncio.CancelledError: # pragma: no cover + raise + + async def _monitor_connections(self) -> None: + await asyncio.sleep(3) + while self._running: + for platform in list(self.clients.keys()): + client = self.clients.get(platform) + if client is None: + continue + if not client.is_connected(): + logger.info("Detected disconnect from %s, attempting reconnect", platform) + await self._reconnect_platform(platform) + await asyncio.sleep(5) + + async def _reconnect_platform(self, platform: str) -> None: + await self.remove_platform(platform) + if platform in self.config.route_config: + await self.connect(platform) + + async def remove_platform(self, platform: str) -> None: + if platform in self._client_tasks: + task = self._client_tasks.pop(platform) + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + client = self.clients.pop(platform, None) + if client: + await client.stop() + + async def stop(self) -> None: + self._running = False + if self._monitor_task: + self._monitor_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._monitor_task + self._monitor_task = None + for platform in list(self.clients.keys()): + await self.remove_platform(platform) + self.clients.clear() + + def get_target_url(self, message: MessageBase) -> Optional[str]: + platform = message.message_info.platform + if not platform: + return None + target = self.config.route_config.get(platform) + return target.url if target else None + + async def send_message(self, message: MessageBase): + platform = message.message_info.platform + if not platform: + raise ValueError("message_info.platform is required") + client = self.clients.get(platform) + if client is None: + raise RuntimeError(f"No client connected for platform {platform}") + return await client.send_message(message.to_dict()) + + async def update_config(self, config_data: Dict[str, Dict[str, str | None]]) -> None: + new_config = RouteConfig.from_dict(config_data) + await self._adjust_connections(new_config) + self.config = new_config + + async def _adjust_connections(self, new_config: RouteConfig) -> None: + current = set(self.config.route_config.keys()) + updated = set(new_config.route_config.keys()) + for platform in current - updated: + await self.remove_platform(platform) + for platform in updated: + if platform not in current: + await self.connect(platform) + else: + old = self.config.route_config[platform] + new = new_config.route_config[platform] + if old.url != new.url or old.token != new.token: + await self.remove_platform(platform) + await self.connect(platform) diff --git a/src/mofox_bus/runtime.py b/src/mofox_bus/runtime.py new file mode 100644 index 000000000..1d41a445b --- /dev/null +++ b/src/mofox_bus/runtime.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import asyncio +import threading +from dataclasses import dataclass +from typing import Awaitable, Callable, Iterable, List + +from .types import MessageEnvelope + +Hook = Callable[[MessageEnvelope], Awaitable[None] | None] +ErrorHook = Callable[[MessageEnvelope, BaseException], Awaitable[None] | None] +Predicate = Callable[[MessageEnvelope], bool | Awaitable[bool]] +MessageHandler = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None] | MessageEnvelope | None] +BatchHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None] | List[MessageEnvelope] | None] + + +class MessageProcessingError(RuntimeError): + """封装处理链路中发生的异常。""" + + def __init__(self, message: MessageEnvelope, original: BaseException): + detail = message.get("id", "") + super().__init__(f"Failed to handle message {detail}: {original}") # pragma: no cover - str repr only + self.message_envelope = message + self.original = original + + +@dataclass +class MessageRoute: + predicate: Predicate + handler: MessageHandler + name: str | None = None + + +class MessageRuntime: + """ + 负责调度消息路由、执行前后 hook 以及批量处理。 + """ + + def __init__(self) -> None: + self._routes: list[MessageRoute] = [] + self._before_hooks: list[Hook] = [] + self._after_hooks: list[Hook] = [] + self._error_hooks: list[ErrorHook] = [] + self._batch_handler: BatchHandler | None = None + self._lock = threading.RLock() + + def add_route(self, predicate: Predicate, handler: MessageHandler, name: str | None = None) -> None: + with self._lock: + self._routes.append(MessageRoute(predicate=predicate, handler=handler, name=name)) + + def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]: + """ + 装饰器写法,便于在核心逻辑中声明式注册。 + """ + + def decorator(func: MessageHandler) -> MessageHandler: + self.add_route(predicate, func, name=name) + return func + + return decorator + + def set_batch_handler(self, handler: BatchHandler) -> None: + self._batch_handler = handler + + def register_before_hook(self, hook: Hook) -> None: + self._before_hooks.append(hook) + + def register_after_hook(self, hook: Hook) -> None: + self._after_hooks.append(hook) + + def register_error_hook(self, hook: ErrorHook) -> None: + self._error_hooks.append(hook) + + async def handle_message(self, message: MessageEnvelope) -> MessageEnvelope | None: + await self._run_hooks(self._before_hooks, message) + try: + route = await self._match_route(message) + if route is None: + return None + result = await _maybe_await(route.handler(message)) + except Exception as exc: # pragma: no cover - tested indirectly + await self._run_error_hooks(message, exc) + raise MessageProcessingError(message, exc) from exc + await self._run_hooks(self._after_hooks, message) + return result + + async def handle_batch(self, messages: Iterable[MessageEnvelope]) -> List[MessageEnvelope]: + batch = list(messages) + if not batch: + return [] + if self._batch_handler is not None: + result = await _maybe_await(self._batch_handler(batch)) + return result or [] + responses: list[MessageEnvelope] = [] + for message in batch: + response = await self.handle_message(message) + if response is not None: + responses.append(response) + return responses + + async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None: + with self._lock: + routes = list(self._routes) + for route in routes: + should_handle = await _maybe_await(route.predicate(message)) + if should_handle: + return route + return None + + async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None: + for hook in hooks: + await _maybe_await(hook(message)) + + async def _run_error_hooks(self, message: MessageEnvelope, exc: BaseException) -> None: + for hook in self._error_hooks: + await _maybe_await(hook(message, exc)) + + +async def _maybe_await(result): + if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future): + return await result + return result + + +__all__ = [ + "BatchHandler", + "Hook", + "MessageHandler", + "MessageProcessingError", + "MessageRoute", + "MessageRuntime", + "Predicate", +] diff --git a/src/mofox_bus/transport/__init__.py b/src/mofox_bus/transport/__init__.py new file mode 100644 index 000000000..5915116d5 --- /dev/null +++ b/src/mofox_bus/transport/__init__.py @@ -0,0 +1,10 @@ +""" +传输层封装,提供 HTTP / WebSocket server & client。 +""" + +from .http_client import HttpMessageClient +from .http_server import HttpMessageServer +from .ws_client import WsMessageClient +from .ws_server import WsMessageServer + +__all__ = ["HttpMessageClient", "HttpMessageServer", "WsMessageClient", "WsMessageServer"] diff --git a/src/mofox_bus/transport/http_client.py b/src/mofox_bus/transport/http_client.py new file mode 100644 index 000000000..7472aa49e --- /dev/null +++ b/src/mofox_bus/transport/http_client.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import logging +from typing import Iterable, List, Sequence + +import aiohttp + +from ..codec import dumps_messages, loads_messages +from ..types import MessageEnvelope + + +class HttpMessageClient: + """ + 面向消息批量传输的 HTTP 客户端封装。 + """ + + def __init__( + self, + base_url: str, + *, + session: aiohttp.ClientSession | None = None, + timeout: aiohttp.ClientTimeout | None = None, + ) -> None: + self._base_url = base_url.rstrip("/") + self._session = session + self._timeout = timeout + self._owns_session = session is None + self._logger = logging.getLogger("mofox_bus.http_client") + + async def send_messages( + self, + messages: Sequence[MessageEnvelope], + *, + expect_reply: bool = False, + path: str = "/messages", + ) -> List[MessageEnvelope] | None: + if not messages: + return [] + session = await self._ensure_session() + url = f"{self._base_url}{path}" + payload = dumps_messages(messages) + self._logger.debug("Sending %d message(s) -> %s", len(messages), url) + async with session.post(url, data=payload, timeout=self._timeout) as resp: + resp.raise_for_status() + if not expect_reply: + return None + raw = await resp.read() + replies = loads_messages(raw) + self._logger.debug("Received %d reply message(s)", len(replies)) + return replies + + async def close(self) -> None: + if self._owns_session and self._session is not None: + await self._session.close() + self._session = None + + async def _ensure_session(self) -> aiohttp.ClientSession: + if self._session is None: + self._session = aiohttp.ClientSession() + return self._session + + async def __aenter__(self) -> "HttpMessageClient": + await self._ensure_session() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() diff --git a/src/mofox_bus/transport/http_server.py b/src/mofox_bus/transport/http_server.py new file mode 100644 index 000000000..e86869a9f --- /dev/null +++ b/src/mofox_bus/transport/http_server.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import logging +from typing import Awaitable, Callable, List + +from aiohttp import web + +from ..codec import dumps_messages, loads_messages +from ..types import MessageEnvelope + +MessageHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None]] + + +class HttpMessageServer: + """ + 轻量级 HTTP 消息入口。可独立运行,也可挂载到现有 FastAPI / aiohttp 应用下。 + """ + + def __init__(self, handler: MessageHandler, *, path: str = "/messages") -> None: + self._handler = handler + self._app = web.Application() + self._path = path + self._app.add_routes([web.post(path, self._handle_messages)]) + self._logger = logging.getLogger("mofox_bus.http_server") + + async def _handle_messages(self, request: web.Request) -> web.Response: + try: + raw = await request.read() + envelopes = loads_messages(raw) + self._logger.debug("Received %d message(s)", len(envelopes)) + except Exception as exc: # pragma: no cover - network errors are integration tested + self._logger.exception("Failed to parse incoming messages: %s", exc) + raise web.HTTPBadRequest(reason=f"Invalid payload: {exc}") from exc + + result = await self._handler(envelopes) + if result is None: + return web.Response(status=200, text="ok") + payload = dumps_messages(result) + return web.Response(status=200, body=payload, content_type="application/json") + + def make_app(self) -> web.Application: + """ + 返回 aiohttp Application,可被外部 server(gunicorn/uvicorn)直接使用。 + """ + + return self._app + + def add_to_app(self, app: web.Application) -> None: + """ + 将消息路由注册到给定的 aiohttp app,方便与既有服务整合。 + """ + + app.router.add_post(self._path, self._handle_messages) diff --git a/src/mofox_bus/transport/ws_client.py b/src/mofox_bus/transport/ws_client.py new file mode 100644 index 000000000..c9e26955e --- /dev/null +++ b/src/mofox_bus/transport/ws_client.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import Awaitable, Callable, Iterable, List, Sequence + +import aiohttp + +from ..codec import dumps_messages, loads_messages +from ..types import MessageEnvelope + +IncomingHandler = Callable[[MessageEnvelope], Awaitable[None]] + + +class WsMessageClient: + """ + 管理 WebSocket 连接,提供 send/receive API,并在后台读取消息。 + """ + + def __init__( + self, + url: str, + *, + handler: IncomingHandler | None = None, + session: aiohttp.ClientSession | None = None, + reconnect_interval: float = 5.0, + ) -> None: + self._url = url + self._handler = handler + self._session = session + self._reconnect_interval = reconnect_interval + self._owns_session = session is None + self._ws: aiohttp.ClientWebSocketResponse | None = None + self._receive_task: asyncio.Task | None = None + self._closed = False + self._logger = logging.getLogger("mofox_bus.ws_client") + + async def connect(self) -> None: + await self._ensure_session() + await self._connect_once() + + async def _connect_once(self) -> None: + assert self._session is not None + self._ws = await self._session.ws_connect(self._url) + self._logger.info("Connected to %s", self._url) + self._receive_task = asyncio.create_task(self._receive_loop()) + + async def send_messages(self, messages: Sequence[MessageEnvelope]) -> None: + if not messages: + return + ws = await self._ensure_ws() + payload = dumps_messages(messages) + await ws.send_bytes(payload) + + async def send_message(self, message: MessageEnvelope) -> None: + await self.send_messages([message]) + + async def close(self) -> None: + self._closed = True + if self._receive_task: + self._receive_task.cancel() + if self._ws: + await self._ws.close() + self._ws = None + if self._owns_session and self._session: + await self._session.close() + self._session = None + + async def _receive_loop(self) -> None: + assert self._ws is not None + try: + async for msg in self._ws: + if msg.type in (aiohttp.WSMsgType.BINARY, aiohttp.WSMsgType.TEXT): + envelopes = loads_messages(msg.data) + for env in envelopes: + if self._handler is not None: + await self._handler(env) + elif msg.type == aiohttp.WSMsgType.ERROR: + self._logger.warning("WebSocket error: %s", msg.data) + break + except asyncio.CancelledError: # pragma: no cover - cancellation path + return + finally: + if not self._closed: + await self._reconnect() + + async def _reconnect(self) -> None: + self._logger.info("WebSocket disconnected, retrying in %.1fs", self._reconnect_interval) + await asyncio.sleep(self._reconnect_interval) + await self._connect_once() + + async def _ensure_session(self) -> aiohttp.ClientSession: + if self._session is None: + self._session = aiohttp.ClientSession() + return self._session + + async def _ensure_ws(self) -> aiohttp.ClientWebSocketResponse: + if self._ws is None or self._ws.closed: + await self._connect_once() + assert self._ws is not None + return self._ws + + async def __aenter__(self) -> "WsMessageClient": + await self.connect() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() diff --git a/src/mofox_bus/transport/ws_server.py b/src/mofox_bus/transport/ws_server.py new file mode 100644 index 000000000..7e75cc749 --- /dev/null +++ b/src/mofox_bus/transport/ws_server.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import asyncio +import logging +from contextlib import asynccontextmanager +from typing import Awaitable, Callable, Iterable, List, Set + +from aiohttp import WSMsgType, web + +from ..codec import dumps_messages, loads_messages +from ..types import MessageEnvelope + +WsMessageHandler = Callable[[MessageEnvelope], Awaitable[None]] + + +class WsMessageServer: + """ + 封装 WebSocket 服务端逻辑,负责接收消息并广播响应。 + """ + + def __init__(self, handler: WsMessageHandler, *, path: str = "/ws") -> None: + self._handler = handler + self._app = web.Application() + self._path = path + self._app.add_routes([web.get(path, self._handle_ws)]) + self._connections: Set[web.WebSocketResponse] = set() + self._lock = asyncio.Lock() + self._logger = logging.getLogger("mofox_bus.ws_server") + + async def _handle_ws(self, request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + self._logger.info("WebSocket connection opened: %s", request.remote) + + async with self._track_connection(ws): + async for message in ws: + if message.type in (WSMsgType.BINARY, WSMsgType.TEXT): + envelopes = loads_messages(message.data) + for env in envelopes: + await self._handler(env) + elif message.type == WSMsgType.ERROR: + self._logger.warning("WebSocket connection error: %s", ws.exception()) + break + + self._logger.info("WebSocket connection closed: %s", request.remote) + return ws + + @asynccontextmanager + async def _track_connection(self, ws: web.WebSocketResponse): + async with self._lock: + self._connections.add(ws) + try: + yield + finally: + async with self._lock: + self._connections.discard(ws) + + async def broadcast(self, messages: Iterable[MessageEnvelope]) -> None: + payload = dumps_messages(list(messages)) + async with self._lock: + targets = list(self._connections) + for ws in targets: + await ws.send_bytes(payload) + + def make_app(self) -> web.Application: + return self._app diff --git a/src/mofox_bus/types.py b/src/mofox_bus/types.py new file mode 100644 index 000000000..f4a9eebae --- /dev/null +++ b/src/mofox_bus/types.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Literal, NotRequired, TypedDict + +MessageDirection = Literal["incoming", "outgoing"] +Role = Literal["user", "assistant", "system", "tool", "platform"] +ContentType = Literal[ + "text", + "image", + "audio", + "file", + "video", + "event", + "command", + "system", +] + +EventType = Literal[ + "message_created", + "message_updated", + "message_deleted", + "member_join", + "member_leave", + "typing", + "reaction_add", + "reaction_remove", +] + + +class TextContent(TypedDict, total=False): + type: Literal["text"] + text: str + markdown: NotRequired[bool] + entities: NotRequired[List[Dict[str, Any]]] + + +class ImageContent(TypedDict, total=False): + type: Literal["image"] + url: str + mime_type: NotRequired[str] + width: NotRequired[int] + height: NotRequired[int] + file_id: NotRequired[str] + + +class FileContent(TypedDict, total=False): + type: Literal["file"] + url: str + mime_type: NotRequired[str] + file_name: NotRequired[str] + file_size: NotRequired[int] + file_id: NotRequired[str] + + +class AudioContent(TypedDict, total=False): + type: Literal["audio"] + url: str + mime_type: NotRequired[str] + duration_ms: NotRequired[int] + file_id: NotRequired[str] + + +class VideoContent(TypedDict, total=False): + type: Literal["video"] + url: str + mime_type: NotRequired[str] + duration_ms: NotRequired[int] + width: NotRequired[int] + height: NotRequired[int] + file_id: NotRequired[str] + + +class EventContent(TypedDict): + type: Literal["event"] + event_type: EventType + raw: Dict[str, Any] + + +class CommandContent(TypedDict, total=False): + type: Literal["command"] + name: str + args: Dict[str, Any] + + +class SystemContent(TypedDict): + type: Literal["system"] + text: str + + +Content = ( + TextContent + | ImageContent + | FileContent + | AudioContent + | VideoContent + | EventContent + | CommandContent + | SystemContent +) + + +class SenderInfo(TypedDict, total=False): + user_id: str + role: Role + display_name: NotRequired[str] + avatar_url: NotRequired[str] + raw: NotRequired[Dict[str, Any]] + + +class ChannelInfo(TypedDict, total=False): + channel_id: str + channel_type: Literal[ + "private", + "group", + "supergroup", + "channel", + "dm", + "room", + "thread", + ] + title: NotRequired[str] + workspace_id: NotRequired[str] + raw: NotRequired[Dict[str, Any]] + + +class MessageEnvelope(TypedDict, total=False): + id: str + direction: MessageDirection + platform: str + timestamp_ms: int + channel: ChannelInfo + sender: SenderInfo + content: Content + conversation_id: str + thread_id: NotRequired[str] + reply_to_message_id: NotRequired[str] + correlation_id: NotRequired[str] + is_edited: NotRequired[bool] + is_ephemeral: NotRequired[bool] + raw_platform_message: NotRequired[Dict[str, Any]] + metadata: NotRequired[Dict[str, Any]] + schema_version: NotRequired[int] + + +__all__ = [ + "AudioContent", + "ChannelInfo", + "CommandContent", + "Content", + "ContentType", + "EventContent", + "EventType", + "FileContent", + "ImageContent", + "MessageDirection", + "MessageEnvelope", + "Role", + "SenderInfo", + "SystemContent", + "TextContent", + "VideoContent", +] diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 7214cd874..fb8fac8f8 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -91,7 +91,7 @@ import time import traceback from typing import TYPE_CHECKING, Any -from maim_message import Seg, UserInfo +from mofox_bus import Seg, UserInfo if TYPE_CHECKING: from src.common.data_models.database_data_model import DatabaseMessages diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index d58a5d2e9..10d1e2ede 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -34,7 +34,7 @@ class InjectionRule: raise ValueError(f"'{self.injection_type.value}'类型的注入规则必须提供 'target_content'。") -from maim_message import Seg +from mofox_bus import Seg from src.llm_models.payload_content.tool_option import ToolCall as ToolCall from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType diff --git a/src/plugins/built_in/maizone_refactored/services/content_service.py b/src/plugins/built_in/maizone_refactored/services/content_service.py index 38442fd09..bab11db14 100644 --- a/src/plugins/built_in/maizone_refactored/services/content_service.py +++ b/src/plugins/built_in/maizone_refactored/services/content_service.py @@ -10,7 +10,7 @@ from collections.abc import Callable import aiohttp import filetype -from maim_message import UserInfo +from mofox_bus import UserInfo from src.chat.message_receive.chat_stream import get_chat_manager from src.common.logger import get_logger diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py b/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py index 3abf48b18..c15aba88c 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py @@ -1,4 +1,4 @@ -from maim_message import RouteConfig, Router, TargetConfig +from mofox_bus import RouteConfig, Router, TargetConfig from src.common.logger import get_logger from src.common.server import get_global_server diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py index b7e7b2c25..7824aa3a0 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import websockets as Server -from maim_message import ( +from mofox_bus import ( BaseMessageInfo, FormatInfo, GroupInfo, diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py index b64db620e..02b1a0a12 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py @@ -1,6 +1,6 @@ import asyncio -from maim_message import MessageBase, Router +from mofox_bus import MessageBase, Router from src.common.logger import get_logger from src.plugin_system.apis import config_api diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py index 1f6bf104e..7e64556aa 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py @@ -4,7 +4,7 @@ import time from typing import ClassVar, Optional, Tuple import websockets as Server -from maim_message import BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, UserInfo +from mofox_bus import BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, UserInfo from src.common.logger import get_logger from src.plugin_system.apis import config_api diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py index 9ec950bc8..91f84b5ad 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py @@ -3,7 +3,7 @@ import random import time import websockets as Server import uuid -from maim_message import ( +from mofox_bus import ( UserInfo, GroupInfo, Seg, diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index ec3910f9a..74877c9be 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -406,7 +406,7 @@ console_log_level = "INFO" # 控制台日志级别,可选: DEBUG, INFO, WARNIN file_log_level = "DEBUG" # 文件日志级别,可选: DEBUG, INFO, WARNING, ERROR, CRITICAL # 第三方库日志控制 -suppress_libraries = ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "aiosqlite", "openai","uvicorn","rjieba","maim_message"] # 完全屏蔽的库 +suppress_libraries = ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "aiosqlite", "openai","uvicorn","rjieba","message_bus"] # 完全屏蔽的库 library_log_levels = { "aiohttp" = "WARNING"} # 设置特定库的日志级别 [dependency_management] # 插件Python依赖管理配置 @@ -423,10 +423,10 @@ install_log_level = "INFO" [debug] show_prompt = false # 是否显示prompt -[maim_message] +[message_bus] auth_token = [] # 认证令牌,用于API验证,为空则不启用验证 -# 以下项目若要使用需要打开use_custom,并单独配置maim_message的服务器 -use_custom = false # 是否启用自定义的maim_message服务器,注意这需要设置新的端口,不能与.env重复 +# 以下项目若要使用需要打开use_custom,并单独配置message_bus的服务器 +use_custom = false # 是否启用自定义的message_bus服务器,注意这需要设置新的端口,不能与.env重复 host="127.0.0.1" port=8090 mode="ws" # 支持ws和tcp两种模式 @@ -631,4 +631,4 @@ throw_topic_weight = 0.3 # throw_topic动作的基础权重 # --- 调试与监控 --- enable_statistics = false # 是否启用统计功能(记录触发次数、决策分布等) -log_decisions = false # 是否记录每次决策的详细日志(用于调试) \ No newline at end of file +log_decisions = false # 是否记录每次决策的详细日志(用于调试)