feat: 实现消息编解码器和消息处理模型
- 添加编解码器,用于序列化和反序列化MessageEnvelope对象。 - 创建消息模型,包括分段(Seg)、群组信息(GroupInfo)、用户信息(UserInfo)、格式信息(FormatInfo)、模板信息(TemplateInfo)、基础消息信息(BaseMessageInfo)和消息基础(MessageBase)。 引入路由器以管理消息路由和连接。 - 实现运行时机制,通过钩子和路由来处理消息处理。 - 使用HTTP和WebSocket客户端和服务器开发传输层,以进行消息传输。 - 为消息内容和信封定义类型,以标准化消息结构。
This commit is contained in:
189
docs/mofox_bus.md
Normal file
189
docs/mofox_bus.md
Normal file
@@ -0,0 +1,189 @@
|
||||
# MoFox Bus 消息库说明
|
||||
|
||||
MoFox Bus 是 MoFox Bot 自研的统一消息中台,替换第三方 `maim_message`,将核心与各平台适配器之间的通信抽象成可拓展、可热插拔的组件。该库完全异步、面向高吞吐,覆盖消息建模、序列化、传输层、运行时路由、适配器工具等多个层面。
|
||||
|
||||
---
|
||||
|
||||
## 1. 设计目标
|
||||
|
||||
- **通用消息模型**:统一 envelope / content,使核心逻辑不关心平台差异。
|
||||
- **零拷贝字典结构**:TypedDict + dataclass,方便直接序列化 JSON。
|
||||
- **高性能传输**:批量收发 + orjson 序列化 + WS/HTTP 封装。
|
||||
- **适配器友好**:提供 BaseAdapter、Sink、Router 与批处理工具。
|
||||
- **渐进可扩展**:未来扩充 gRPC/MQ 仅需在 `transport/` 下新增实现。
|
||||
|
||||
---
|
||||
|
||||
## 2. 包结构概览(`src/mofox_bus/`)
|
||||
|
||||
| 模块 | 主要职责 |
|
||||
| --- | --- |
|
||||
| `types.py` | TypedDict 消息模型(MessageEnvelope、Content、Sender/ChannelInfo 等)。 |
|
||||
| `message_models.py` | dataclass 版 `Seg` / `MessageBase`,兼容老的消息段语义。 |
|
||||
| `codec.py` | 高性能 JSON 编解码,含批量接口与 schema version 升级钩子。 |
|
||||
| `runtime.py` | 消息路由/Hook/批处理调度器,支撑核心处理链。 |
|
||||
| `adapter_utils.py` | BaseAdapter、CoreMessageSink、BatchDispatcher 等工具。 |
|
||||
| `api.py` | WebSocket `MessageServer`/`MessageClient`,提供 token、复用 FastAPI。 |
|
||||
| `router.py` | 多平台客户端统一管理,自动重连与动态路由。 |
|
||||
| `transport/` | HTTP/WS server&client 轻量封装,可独立复用。 |
|
||||
| `__init__.py` | 导出常用符号供外部按需引用。 |
|
||||
|
||||
---
|
||||
|
||||
## 3. 消息模型
|
||||
|
||||
### 3.1 Envelope TypedDict(`types.py`)
|
||||
|
||||
- `MessageEnvelope`:核心字段包括 `id`、`direction`、`platform`、`timestamp_ms`、`channel`、`sender`、`content` 等,一律使用毫秒时间戳,保留 `raw_platform_message` 与 `metadata` 便于调试 / 扩展。
|
||||
- `Content` 联合类型支持文本、图片、音频、文件、视频、事件、命令、系统消息,后续可扩展更多 literal。
|
||||
- `SenderInfo` / `ChannelInfo` / `MessageDirection` / `Role` 等均以 `Literal` 控制取值,方便 IDE 静态检查。
|
||||
|
||||
### 3.2 dataclass 消息段(`message_models.py`)
|
||||
|
||||
- `Seg`:表示一段内容,支持嵌套 `seglist`。
|
||||
- `UserInfo` / `GroupInfo` / `FormatInfo` / `TemplateInfo`:保留旧结构字段,但新增 `user_avatar` 等业务常用字段。
|
||||
- `BaseMessageInfo` + `MessageBase`:方便适配器沿用原始 `MessageBase` API,也使核心可以在内存中直接传递 dataclass。
|
||||
|
||||
> **何时使用 TypedDict vs dataclass?**
|
||||
TypedDict 更适合网络传输和依赖注入;dataclass 版 MessageBase 则保留分段消息特性,适合适配器内部加工。
|
||||
|
||||
---
|
||||
|
||||
## 4. 序列化与版本(`codec.py`)
|
||||
|
||||
- `dumps_message` / `loads_message`:处理单条 `MessageEnvelope`,自动补充 `schema_version`(默认 1)。
|
||||
- `dumps_messages` / `loads_messages`:批量传输 `{schema_version, items: [...]}`,减少 HTTP/WS 次数。
|
||||
- 预留 `_upgrade_schema_if_needed` 钩子,可在引入 v2/v3 时集中兼容逻辑。
|
||||
|
||||
默认使用 `orjson`,若运行环境缺失会自动 fallback 到标准库 `json`,保证兼容性。
|
||||
|
||||
---
|
||||
|
||||
## 5. 运行时调度(`runtime.py`)
|
||||
|
||||
- `MessageRuntime`:
|
||||
- `add_route(predicate, handler)` 或 `@runtime.route(...)` 装饰器注册消息处理器。
|
||||
- `register_before_hook` / `register_after_hook` / `register_error_hook` 注入监控、埋点、Trace。
|
||||
- `set_batch_handler` 支持一次处理整批消息(例如批量落库)。
|
||||
- `MessageProcessingError` 在 handler 抛出异常时封装上下文,便于日志追踪。
|
||||
|
||||
运行时内部使用 `RLock` 保护路由表,适合多协程并发读写,`_maybe_await` 自动兼容同步/异步 handler。
|
||||
|
||||
---
|
||||
|
||||
## 6. 传输层封装(`transport/`)
|
||||
|
||||
### 6.1 HTTP
|
||||
- `HttpMessageServer`:使用 `aiohttp.web` 监听 `POST /messages`,调用业务 handler 后可返回响应批。
|
||||
- `HttpMessageClient`:管理 `aiohttp.ClientSession`,`send_messages(messages, expect_reply=True)` 支持同步等待回复。
|
||||
|
||||
### 6.2 WebSocket
|
||||
- `WsMessageServer`:基于 `aiohttp`,维护连接集合,支持 `broadcast`。
|
||||
- `WsMessageClient`:自动重连、后台读取,`send_messages`/`send_message` 直接发送批量。
|
||||
|
||||
以上都复用了 `codec` 批量协议,统一上下游格式。
|
||||
|
||||
---
|
||||
|
||||
## 7. Server / Client / Router(`api.py`、`router.py`)
|
||||
|
||||
### 7.1 MessageServer
|
||||
- 可复用已有 FastAPI 实例(`app=get_global_server().get_app()`),在同一进程内共享路由。
|
||||
- 支持 header token 校验 (`enable_token` + `add_valid_token`)。
|
||||
- `broadcast_message`、`broadcast_to_platform`、`send_message(message: MessageBase)` 满足不同场景。
|
||||
|
||||
### 7.2 MessageClient
|
||||
- 仅 WebSocket 模式,管理 `aiohttp` 连接、自动收消息,供适配器推送到核心。
|
||||
|
||||
### 7.3 Router
|
||||
- `RouteConfig` + `TargetConfig` 描述平台到 URL 的映射。
|
||||
- `Router.run()` 会为每个平台创建 `MessageClient` 并保持心跳,`_monitor_connections` 自动重连。
|
||||
- `register_class_handler` 可绑定 Napcat 适配器那样的 class handler。
|
||||
|
||||
---
|
||||
|
||||
## 8. 适配器工具(`adapter_utils.py`)
|
||||
|
||||
- `BaseAdapter`:约定入站 `from_platform_message` / 出站 `_send_platform_message`,默认提供批量入口。
|
||||
- `CoreMessageSink` 协议 + `InProcessCoreSink`:方便在同进程中把适配器消息直接推给核心协程。
|
||||
- `BatchDispatcher`:封装缓冲 + 定时 flush 的发送管道,可与 HTTP/WS 客户端组合提升吞吐。
|
||||
|
||||
---
|
||||
|
||||
## 9. 集成与配置
|
||||
|
||||
1. **配置文件**:在 `config.*.toml` 中新增 `[message_bus]` 段(参考 `template/bot_config_template.toml`),控制 host/port/token/wss 等。
|
||||
2. **服务启动**:`src/common/message/api.py` 中的 `get_global_api()` 已默认实例化 `MessageServer`,并将 token 写入服务器。
|
||||
3. **适配器更新**:所有使用原 `maim_message` 的模块已改为 `from mofox_bus import ...`,无需额外适配即可继续利用 `MessageBase` / `Router` API。
|
||||
|
||||
---
|
||||
|
||||
## 10. 快速上手示例
|
||||
|
||||
```python
|
||||
from mofox_bus import MessageRuntime, types
|
||||
from mofox_bus.transport import HttpMessageServer
|
||||
|
||||
runtime = MessageRuntime()
|
||||
|
||||
@runtime.route(lambda env: env["content"]["type"] == "text")
|
||||
async def handle_text(env: types.MessageEnvelope):
|
||||
print("收到文本:", env["content"]["text"])
|
||||
|
||||
async def http_handler(messages: list[types.MessageEnvelope]):
|
||||
await runtime.handle_batch(messages)
|
||||
|
||||
server = HttpMessageServer(http_handler)
|
||||
app = server.make_app() # 交给 aiohttp/uvicorn 运行
|
||||
```
|
||||
|
||||
**适配器 Skeleton:**
|
||||
```python
|
||||
from mofox_bus import (
|
||||
BaseAdapter,
|
||||
MessageEnvelope,
|
||||
WebSocketAdapterOptions,
|
||||
)
|
||||
|
||||
class MyAdapter(BaseAdapter):
|
||||
platform = "custom"
|
||||
|
||||
def __init__(self, core_sink):
|
||||
super().__init__(
|
||||
core_sink,
|
||||
transport=WebSocketAdapterOptions(
|
||||
url="ws://127.0.0.1:19898",
|
||||
incoming_parser=lambda raw: orjson.loads(raw)["payload"],
|
||||
),
|
||||
)
|
||||
|
||||
def from_platform_message(self, raw: dict) -> MessageEnvelope:
|
||||
return {
|
||||
"id": raw["id"],
|
||||
"direction": "incoming",
|
||||
"platform": self.platform,
|
||||
"timestamp_ms": raw["ts"],
|
||||
"channel": {"channel_id": raw["room_id"], "channel_type": "dm"},
|
||||
"sender": {"user_id": raw["user_id"], "role": "user"},
|
||||
"content": {"type": "text", "text": raw["content"]},
|
||||
"conversation_id": raw["room_id"],
|
||||
}
|
||||
```
|
||||
|
||||
- 如果传入 `WebSocketAdapterOptions`,BaseAdapter 会自动建立连接、监听、默认封装 `{"type":"message","payload":...}` 的标准 JSON,并允许通过 `outgoing_encoder` 自定义下行格式。
|
||||
- 如果传入 `HttpAdapterOptions`,BaseAdapter 会自动启动一个 aiohttp Webhook(`POST /adapter/messages`)并将收到的 JSON 批量投递给核心。
|
||||
|
||||
> 完整的 WebSocket 适配器示例见 `examples/mofox_bus_demo_adapter.py`:演示了平台提供 WS 接口、适配器通过 `WebSocketAdapterOptions` 自动启动监听、接收/处理/回发的全过程,可直接运行观察日志。
|
||||
|
||||
---
|
||||
|
||||
## 11. 调试与最佳实践
|
||||
|
||||
- 利用 `MessageRuntime.register_error_hook` 打印 `correlation_id` / `id`,快速定位异常消息。
|
||||
- 如果适配器与核心同进程,优先使用 `InProcessCoreSink` 以避免 JSON 编解码。
|
||||
- 批量吞吐场景(如 HTTP 推送)优先通过 `BatchDispatcher` 聚合再发送,可显著降低连接开销。
|
||||
- 自定义传输实现可参考 `transport/http_server.py` / `ws_client.py`,保持 `loads_messages` / `dumps_messages` 协议即可与现有核心互通。
|
||||
|
||||
---
|
||||
|
||||
通过以上结构,MoFox Bus 提供了一套端到端的统一消息能力,满足 AI Bot 在多平台、多形态场景下的高性能传输与扩展需求。若需要扩展新的传输协议或内容类型,只需在对应模块增加 Literal/TypedDict/Transport 实现即可。祝使用愉快!
|
||||
196
examples/mofox_bus_demo_adapter.py
Normal file
196
examples/mofox_bus_demo_adapter.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
示例:演示一个最小可用的 WebSocket 适配器如何使用 BaseAdapter 的自动传输封装:
|
||||
1) 通过 WS 接入平台;
|
||||
2) 将平台推送的消息转成 MessageEnvelope 并交给核心;
|
||||
3) 接收核心回复并通过 WS 再发回平台。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import orjson
|
||||
import websockets
|
||||
|
||||
# 追加 src 目录,便于直接运行示例
|
||||
sys.path.append(str(Path(__file__).resolve().parents[1] / "src"))
|
||||
|
||||
from mofox_bus import (
|
||||
BaseAdapter,
|
||||
InProcessCoreSink,
|
||||
MessageEnvelope,
|
||||
MessageRuntime,
|
||||
WebSocketAdapterOptions,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. 模拟一个提供 WebSocket 接口的平台
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FakePlatformServer:
|
||||
"""
|
||||
适配器将通过 WS 连接到这个模拟平台。
|
||||
平台会广播消息给所有连接,适配器发送的响应也会被打印出来。
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "127.0.0.1", port: int = 19898) -> None:
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._connections: set[Any] = set()
|
||||
self._server = None
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return f"ws://{self._host}:{self._port}"
|
||||
|
||||
async def start(self) -> None:
|
||||
self._server = await websockets.serve(self._handler, self._host, self._port)
|
||||
print(f"[Platform] WebSocket server listening on {self.url}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._server:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
self._server = None
|
||||
|
||||
async def _handler(self, ws) -> None:
|
||||
self._connections.add(ws)
|
||||
print("[Platform] adapter connected")
|
||||
try:
|
||||
async for raw in ws:
|
||||
data = orjson.loads(raw)
|
||||
if data["type"] == "send":
|
||||
print(f"[Platform] <- Bot: {data['payload']['text']}")
|
||||
finally:
|
||||
self._connections.discard(ws)
|
||||
print("[Platform] adapter disconnected")
|
||||
|
||||
async def simulate_incoming_message(self, text: str) -> None:
|
||||
payload = {
|
||||
"message_id": str(uuid.uuid4()),
|
||||
"channel_id": "room-42",
|
||||
"user_id": "demo-user",
|
||||
"text": text,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
message = orjson.dumps({"type": "message", "payload": payload}).decode()
|
||||
for ws in list(self._connections):
|
||||
await ws.send(message)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. 适配器实现:仅关注核心转换逻辑,网络层交由 BaseAdapter 管理
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DemoWsAdapter(BaseAdapter):
|
||||
platform = "demo"
|
||||
|
||||
def from_platform_message(self, raw: Dict[str, Any]) -> MessageEnvelope:
|
||||
return {
|
||||
"id": raw["message_id"],
|
||||
"direction": "incoming",
|
||||
"platform": self.platform,
|
||||
"timestamp_ms": int(raw["timestamp"] * 1000),
|
||||
"channel": {"channel_id": raw["channel_id"], "channel_type": "room"},
|
||||
"sender": {"user_id": raw["user_id"], "role": "user"},
|
||||
"conversation_id": raw["channel_id"],
|
||||
"content": {"type": "text", "text": raw["text"]},
|
||||
}
|
||||
|
||||
|
||||
def incoming_parser(raw: str | bytes) -> Any:
|
||||
data = orjson.loads(raw)
|
||||
if data.get("type") == "message":
|
||||
return data["payload"]
|
||||
return data
|
||||
|
||||
|
||||
def outgoing_encoder(envelope: MessageEnvelope) -> str:
|
||||
return orjson.dumps(
|
||||
{
|
||||
"type": "send",
|
||||
"payload": {
|
||||
"channel_id": envelope["channel"]["channel_id"],
|
||||
"text": envelope["content"]["text"],
|
||||
},
|
||||
}
|
||||
).decode()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. 核心 Runtime:注册处理器并通过 InProcessCoreSink 接收消息
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
runtime = MessageRuntime()
|
||||
|
||||
|
||||
@runtime.route(lambda env: env["direction"] == "incoming")
|
||||
async def handle_incoming(env: MessageEnvelope) -> MessageEnvelope:
|
||||
user_text = env["content"]["text"]
|
||||
reply_text = f"核心收到:{user_text}"
|
||||
return {
|
||||
"id": str(uuid.uuid4()),
|
||||
"direction": "outgoing",
|
||||
"platform": env["platform"],
|
||||
"timestamp_ms": int(time.time() * 1000),
|
||||
"channel": env["channel"],
|
||||
"sender": {
|
||||
"user_id": "bot",
|
||||
"role": "assistant",
|
||||
"display_name": "DemoBot",
|
||||
},
|
||||
"conversation_id": env["conversation_id"],
|
||||
"content": {"type": "text", "text": reply_text},
|
||||
}
|
||||
|
||||
|
||||
adapter: Optional[DemoWsAdapter] = None
|
||||
|
||||
|
||||
async def core_entry(message: MessageEnvelope) -> None:
|
||||
response = await runtime.handle_message(message)
|
||||
if response and adapter is not None:
|
||||
await adapter.send_to_platform(response)
|
||||
|
||||
|
||||
core_sink = InProcessCoreSink(core_entry)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. 串起来并运行 Demo
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def main() -> None:
|
||||
platform = FakePlatformServer()
|
||||
await platform.start()
|
||||
|
||||
global adapter
|
||||
adapter = DemoWsAdapter(
|
||||
core_sink,
|
||||
transport=WebSocketAdapterOptions(
|
||||
url=platform.url,
|
||||
incoming_parser=incoming_parser,
|
||||
outgoing_encoder=outgoing_encoder,
|
||||
),
|
||||
)
|
||||
await adapter.start()
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await platform.simulate_incoming_message("你好,MoFox Bus!")
|
||||
await platform.simulate_incoming_message("请问你是谁?")
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
await adapter.stop()
|
||||
await platform.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -30,7 +30,6 @@ dependencies = [
|
||||
"langfuse==3.7.0",
|
||||
"lunar-python>=1.4.4",
|
||||
"lxml>=6.0.0",
|
||||
"maim-message>=0.3.8",
|
||||
"matplotlib>=3.10.3",
|
||||
"networkx>=3.4.2",
|
||||
"orjson>=3.10",
|
||||
|
||||
@@ -15,7 +15,6 @@ filetype
|
||||
slowapi
|
||||
rjieba
|
||||
jsonlines
|
||||
maim_message
|
||||
quick_algo
|
||||
matplotlib
|
||||
networkx
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from maim_message import UserInfo
|
||||
from mofox_bus import UserInfo
|
||||
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
@@ -353,7 +353,7 @@ class ChatBot:
|
||||
return
|
||||
|
||||
# 先提取基础信息检查是否是自身消息上报
|
||||
from maim_message import BaseMessageInfo
|
||||
from mofox_bus import BaseMessageInfo
|
||||
temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {}))
|
||||
if temp_message_info.additional_config:
|
||||
sent_message = temp_message_info.additional_config.get("echo", False)
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from mofox_bus import GroupInfo, UserInfo
|
||||
from rich.traceback import install
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
@@ -358,7 +358,7 @@ class ChatManager:
|
||||
def register_message(self, message: DatabaseMessages):
|
||||
"""注册消息到聊天流"""
|
||||
# 从 DatabaseMessages 提取平台和用户/群组信息
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from mofox_bus import GroupInfo, UserInfo
|
||||
|
||||
user_info = UserInfo(
|
||||
platform=message.user_info.platform,
|
||||
|
||||
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import urllib3
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
|
||||
from mofox_bus import BaseMessageInfo, MessageBase, Seg, UserInfo
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
@@ -7,7 +7,7 @@ import time
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from maim_message import BaseMessageInfo, Seg
|
||||
from mofox_bus import BaseMessageInfo, Seg
|
||||
|
||||
from src.chat.utils.self_voice_cache import consume_self_voice_text
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
@@ -430,9 +430,9 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
|
||||
Returns:
|
||||
BaseMessageInfo: 重建的消息信息对象
|
||||
"""
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from mofox_bus import GroupInfo, UserInfo
|
||||
|
||||
# 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo
|
||||
# 从 DatabaseMessages 的 user_info 转换为 mofox_bus.UserInfo
|
||||
user_info = UserInfo(
|
||||
platform=db_message.user_info.platform,
|
||||
user_id=db_message.user_info.user_id,
|
||||
@@ -440,7 +440,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
|
||||
user_cardname=db_message.user_info.user_cardname or ""
|
||||
)
|
||||
|
||||
# 从 DatabaseMessages 的 group_info 转换为 maim_message.GroupInfo(如果存在)
|
||||
# 从 DatabaseMessages 的 group_info 转换为 mofox_bus.GroupInfo(如果存在)
|
||||
group_info = None
|
||||
if db_message.group_info:
|
||||
group_info = GroupInfo(
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import rjieba
|
||||
from maim_message import UserInfo
|
||||
from mofox_bus import UserInfo
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
|
||||
@@ -534,7 +534,7 @@ DEFAULT_MODULE_COLORS = {
|
||||
# 数据库和消息
|
||||
"database_model": "#875F00", # 橙褐色
|
||||
"database": "#00FF00", # 橙褐色
|
||||
"maim_message": "#AF87D7", # 紫褐色
|
||||
"mofox_bus": "#AF87D7", # 紫褐色
|
||||
# 日志系统
|
||||
"logger": "#808080", # 深灰色
|
||||
"confirm": "#FFFF00", # 黄色+粗体
|
||||
|
||||
@@ -1,70 +1,53 @@
|
||||
import importlib.metadata
|
||||
import os
|
||||
|
||||
from maim_message import MessageServer
|
||||
from mofox_bus import MessageServer
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.server import get_global_server
|
||||
from src.config.config import global_config
|
||||
|
||||
global_api = None
|
||||
global_api: MessageServer | None = None
|
||||
|
||||
|
||||
def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
"""获取全局MessageServer实例"""
|
||||
def get_global_api() -> MessageServer:
|
||||
"""
|
||||
获取全局 MessageServer 单例。
|
||||
"""
|
||||
|
||||
global global_api
|
||||
if global_api is None:
|
||||
# 检查maim_message版本
|
||||
try:
|
||||
maim_message_version = importlib.metadata.version("maim_message")
|
||||
version_compatible = [int(x) for x in maim_message_version.split(".")] >= [0, 3, 3]
|
||||
except (importlib.metadata.PackageNotFoundError, ValueError):
|
||||
version_compatible = False
|
||||
if global_api is not None:
|
||||
return global_api
|
||||
|
||||
# 读取配置项
|
||||
maim_message_config = global_config.maim_message
|
||||
bus_config = global_config.message_bus
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port_str = os.getenv("PORT", "8000")
|
||||
|
||||
# 设置基本参数
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
port = 8000
|
||||
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port_str = os.getenv("PORT", "8000")
|
||||
kwargs: dict[str, object] = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"app": get_global_server().get_app(),
|
||||
}
|
||||
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
port = 8000
|
||||
if bus_config.use_custom:
|
||||
kwargs["host"] = bus_config.host
|
||||
kwargs["port"] = bus_config.port
|
||||
kwargs.pop("app", None)
|
||||
if bus_config.use_wss:
|
||||
if bus_config.cert_file:
|
||||
kwargs["ssl_certfile"] = bus_config.cert_file
|
||||
if bus_config.key_file:
|
||||
kwargs["ssl_keyfile"] = bus_config.key_file
|
||||
|
||||
kwargs = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"app": get_global_server().get_app(),
|
||||
}
|
||||
if bus_config.auth_token:
|
||||
kwargs["enable_token"] = True
|
||||
kwargs["custom_logger"] = get_logger("mofox_bus")
|
||||
|
||||
# 只有在版本 >= 0.3.0 时才使用高级特性
|
||||
if version_compatible:
|
||||
# 添加自定义logger
|
||||
maim_message_logger = get_logger("maim_message")
|
||||
kwargs["custom_logger"] = maim_message_logger
|
||||
|
||||
# 添加token认证
|
||||
if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
|
||||
kwargs["enable_token"] = True
|
||||
|
||||
if maim_message_config.use_custom:
|
||||
# 添加WSS模式支持
|
||||
del kwargs["app"]
|
||||
kwargs["host"] = maim_message_config.host
|
||||
kwargs["port"] = maim_message_config.port
|
||||
kwargs["mode"] = maim_message_config.mode
|
||||
if maim_message_config.use_wss:
|
||||
if maim_message_config.cert_file:
|
||||
kwargs["ssl_certfile"] = maim_message_config.cert_file
|
||||
if maim_message_config.key_file:
|
||||
kwargs["ssl_keyfile"] = maim_message_config.key_file
|
||||
kwargs["enable_custom_uvicorn_logger"] = False
|
||||
|
||||
global_api = MessageServer(**kwargs)
|
||||
if version_compatible and maim_message_config.auth_token:
|
||||
for token in maim_message_config.auth_token:
|
||||
global_api.add_valid_token(token)
|
||||
global_api = MessageServer(**kwargs)
|
||||
for token in bus_config.auth_token:
|
||||
global_api.add_valid_token(token)
|
||||
return global_api
|
||||
|
||||
@@ -26,7 +26,7 @@ from src.config.official_configs import (
|
||||
ExperimentalConfig,
|
||||
ExpressionConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
MaimMessageConfig,
|
||||
MessageBusConfig,
|
||||
MemoryConfig,
|
||||
MessageReceiveConfig,
|
||||
MoodConfig,
|
||||
@@ -392,7 +392,7 @@ class Config(ValidatedConfigBase):
|
||||
response_post_process: ResponsePostProcessConfig = Field(..., description="响应后处理配置")
|
||||
response_splitter: ResponseSplitterConfig = Field(..., description="响应分割配置")
|
||||
experimental: ExperimentalConfig = Field(default_factory=lambda: ExperimentalConfig(), description="实验性功能配置")
|
||||
maim_message: MaimMessageConfig = Field(..., description="Maim消息配置")
|
||||
message_bus: MessageBusConfig = Field(..., description="消息总线配置")
|
||||
lpmm_knowledge: LPMMKnowledgeConfig = Field(..., description="LPMM知识配置")
|
||||
tool: ToolConfig = Field(..., description="工具配置")
|
||||
debug: DebugConfig = Field(..., description="调试配置")
|
||||
|
||||
@@ -574,8 +574,18 @@ class ExperimentalConfig(ValidatedConfigBase):
|
||||
pfc_chatting: bool = Field(default=False, description="启用PFC聊天")
|
||||
|
||||
|
||||
class MaimMessageConfig(ValidatedConfigBase):
|
||||
"""maim_message配置类"""
|
||||
class MessageBusConfig(ValidatedConfigBase):
|
||||
"""mofox_bus 消息服务配置"""
|
||||
|
||||
use_custom: bool = Field(default=False, description="是否使用自定义地址")
|
||||
host: str = Field(default="127.0.0.1", description="消息服务主机")
|
||||
port: int = Field(default=8090, description="消息服务端口")
|
||||
mode: Literal["ws", "tcp"] = Field(default="ws", description="传输模式")
|
||||
use_wss: bool = Field(default=False, description="是否启用 WSS")
|
||||
cert_file: str = Field(default="", description="证书文件路径")
|
||||
key_file: str = Field(default="", description="密钥文件路径")
|
||||
auth_token: list[str] = Field(default_factory=lambda: [], description="认证 token 列表")
|
||||
|
||||
|
||||
use_custom: bool = Field(default=False, description="启用自定义")
|
||||
host: str = Field(default="127.0.0.1", description="主机")
|
||||
|
||||
@@ -10,7 +10,7 @@ from functools import partial
|
||||
from random import choices
|
||||
from typing import Any
|
||||
|
||||
from maim_message import MessageServer
|
||||
from mofox_bus import MessageServer
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
|
||||
94
src/mofox_bus/__init__.py
Normal file
94
src/mofox_bus/__init__.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
MoFox 内部通用消息总线实现。
|
||||
|
||||
该模块导出 TypedDict 消息模型、序列化工具、传输层封装以及适配器辅助工具,
|
||||
供核心进程与各类平台适配器共享。
|
||||
"""
|
||||
|
||||
from . import codec, types
|
||||
from .adapter_utils import (
|
||||
AdapterTransportOptions,
|
||||
BaseAdapter,
|
||||
BatchDispatcher,
|
||||
CoreMessageSink,
|
||||
HttpAdapterOptions,
|
||||
InProcessCoreSink,
|
||||
WebSocketLike,
|
||||
WebSocketAdapterOptions,
|
||||
)
|
||||
from .api import MessageClient, MessageServer
|
||||
from .codec import dumps_message, dumps_messages, loads_message, loads_messages
|
||||
from .message_models import BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, TemplateInfo, UserInfo
|
||||
from .router import RouteConfig, Router, TargetConfig
|
||||
from .runtime import MessageProcessingError, MessageRoute, MessageRuntime
|
||||
from .types import (
|
||||
AudioContent,
|
||||
ChannelInfo,
|
||||
CommandContent,
|
||||
Content,
|
||||
ContentType,
|
||||
EventContent,
|
||||
EventType,
|
||||
FileContent,
|
||||
ImageContent,
|
||||
MessageDirection,
|
||||
MessageEnvelope,
|
||||
Role,
|
||||
SenderInfo,
|
||||
SystemContent,
|
||||
TextContent,
|
||||
VideoContent,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# TypedDict model
|
||||
"AudioContent",
|
||||
"ChannelInfo",
|
||||
"CommandContent",
|
||||
"Content",
|
||||
"ContentType",
|
||||
"EventContent",
|
||||
"EventType",
|
||||
"FileContent",
|
||||
"ImageContent",
|
||||
"MessageDirection",
|
||||
"MessageEnvelope",
|
||||
"Role",
|
||||
"SenderInfo",
|
||||
"SystemContent",
|
||||
"TextContent",
|
||||
"VideoContent",
|
||||
# Codec helpers
|
||||
"codec",
|
||||
"dumps_message",
|
||||
"dumps_messages",
|
||||
"loads_message",
|
||||
"loads_messages",
|
||||
# Runtime / routing
|
||||
"MessageRoute",
|
||||
"MessageRuntime",
|
||||
"MessageProcessingError",
|
||||
# Message dataclasses
|
||||
"Seg",
|
||||
"GroupInfo",
|
||||
"UserInfo",
|
||||
"FormatInfo",
|
||||
"TemplateInfo",
|
||||
"BaseMessageInfo",
|
||||
"MessageBase",
|
||||
# Server/client/router
|
||||
"MessageServer",
|
||||
"MessageClient",
|
||||
"Router",
|
||||
"RouteConfig",
|
||||
"TargetConfig",
|
||||
# Adapter helpers
|
||||
"AdapterTransportOptions",
|
||||
"BaseAdapter",
|
||||
"BatchDispatcher",
|
||||
"CoreMessageSink",
|
||||
"InProcessCoreSink",
|
||||
"WebSocketLike",
|
||||
"WebSocketAdapterOptions",
|
||||
"HttpAdapterOptions",
|
||||
]
|
||||
270
src/mofox_bus/adapter_utils.py
Normal file
270
src/mofox_bus/adapter_utils.py
Normal file
@@ -0,0 +1,270 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Protocol
|
||||
|
||||
import orjson
|
||||
from aiohttp import web as aiohttp_web
|
||||
import websockets
|
||||
|
||||
from .types import MessageEnvelope
|
||||
|
||||
|
||||
class CoreMessageSink(Protocol):
|
||||
async def send(self, message: MessageEnvelope) -> None: ...
|
||||
|
||||
async def send_many(self, messages: list[MessageEnvelope]) -> None: ... # pragma: no cover - optional
|
||||
|
||||
|
||||
class WebSocketLike(Protocol):
|
||||
def __aiter__(self) -> AsyncIterator[str | bytes]: ...
|
||||
|
||||
@property
|
||||
def closed(self) -> bool: ...
|
||||
|
||||
async def send(self, data: str | bytes) -> None: ...
|
||||
|
||||
async def close(self) -> None: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketAdapterOptions:
|
||||
url: str
|
||||
headers: dict[str, str] | None = None
|
||||
incoming_parser: Callable[[str | bytes], Any] | None = None
|
||||
outgoing_encoder: Callable[[MessageEnvelope], str | bytes] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpAdapterOptions:
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8089
|
||||
path: str = "/adapter/messages"
|
||||
app: aiohttp_web.Application | None = None
|
||||
|
||||
|
||||
AdapterTransportOptions = WebSocketAdapterOptions | HttpAdapterOptions | None
|
||||
|
||||
|
||||
class BaseAdapter:
|
||||
"""
|
||||
适配器基类:负责平台原始消息与 MessageEnvelope 之间的互转。
|
||||
子类需要实现平台入站解析与出站发送逻辑。
|
||||
"""
|
||||
|
||||
platform: str = "unknown"
|
||||
|
||||
def __init__(self, core_sink: CoreMessageSink, transport: AdapterTransportOptions = None):
|
||||
"""
|
||||
Args:
|
||||
core_sink: 核心消息入口,通常是 InProcessCoreSink 或自定义客户端。
|
||||
transport: 传入 WebSocketAdapterOptions / HttpAdapterOptions 即可自动管理监听逻辑。
|
||||
"""
|
||||
self.core_sink = core_sink
|
||||
self._transport_config = transport
|
||||
self._ws: WebSocketLike | None = None
|
||||
self._ws_task: asyncio.Task | None = None
|
||||
self._http_runner: aiohttp_web.AppRunner | None = None
|
||||
self._http_site: aiohttp_web.BaseSite | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""根据配置自动启动 WS/HTTP 监听。"""
|
||||
if isinstance(self._transport_config, WebSocketAdapterOptions):
|
||||
await self._start_ws_transport(self._transport_config)
|
||||
elif isinstance(self._transport_config, HttpAdapterOptions):
|
||||
await self._start_http_transport(self._transport_config)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止自动管理的传输层。"""
|
||||
if self._ws_task:
|
||||
self._ws_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._ws_task
|
||||
self._ws_task = None
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._http_site:
|
||||
await self._http_site.stop()
|
||||
self._http_site = None
|
||||
if self._http_runner:
|
||||
await self._http_runner.cleanup()
|
||||
self._http_runner = None
|
||||
|
||||
async def on_platform_message(self, raw: Any) -> None:
|
||||
"""处理平台下发的单条消息并交给核心。"""
|
||||
envelope = self.from_platform_message(raw)
|
||||
await self.core_sink.send(envelope)
|
||||
|
||||
async def on_platform_messages(self, raw_messages: list[Any]) -> None:
|
||||
"""批量推送入口,内部自动批量或逐条送入核心。"""
|
||||
envelopes = [self.from_platform_message(raw) for raw in raw_messages]
|
||||
await _send_many(self.core_sink, envelopes)
|
||||
|
||||
async def send_to_platform(self, envelope: MessageEnvelope) -> None:
|
||||
"""核心生成单条消息时调用,由子类或自动传输层发送。"""
|
||||
await self._send_platform_message(envelope)
|
||||
|
||||
async def send_batch_to_platform(self, envelopes: list[MessageEnvelope]) -> None:
|
||||
"""默认串行发送整批消息,子类可根据平台特性重写。"""
|
||||
for env in envelopes:
|
||||
await self._send_platform_message(env)
|
||||
|
||||
def from_platform_message(self, raw: Any) -> MessageEnvelope:
|
||||
"""子类必须实现:将平台原始结构转换为统一 MessageEnvelope。"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _send_platform_message(self, envelope: MessageEnvelope) -> None:
|
||||
"""子类必须实现:把 MessageEnvelope 转为平台格式并发送出去。"""
|
||||
if isinstance(self._transport_config, WebSocketAdapterOptions):
|
||||
await self._send_via_ws(envelope)
|
||||
return
|
||||
raise NotImplementedError
|
||||
|
||||
async def _start_ws_transport(self, options: WebSocketAdapterOptions) -> None:
|
||||
self._ws = await websockets.connect(options.url, extra_headers=options.headers)
|
||||
self._ws_task = asyncio.create_task(self._ws_listen_loop(options))
|
||||
|
||||
async def _ws_listen_loop(self, options: WebSocketAdapterOptions) -> None:
|
||||
assert self._ws is not None
|
||||
parser = options.incoming_parser or self._default_ws_parser
|
||||
try:
|
||||
async for raw in self._ws:
|
||||
payload = parser(raw)
|
||||
await self.on_platform_message(payload)
|
||||
finally:
|
||||
pass
|
||||
|
||||
async def _send_via_ws(self, envelope: MessageEnvelope) -> None:
|
||||
if self._ws is None or self._ws.closed:
|
||||
raise RuntimeError("WebSocket transport is not active")
|
||||
encoder = None
|
||||
if isinstance(self._transport_config, WebSocketAdapterOptions):
|
||||
encoder = self._transport_config.outgoing_encoder
|
||||
data = encoder(envelope) if encoder else self._default_ws_encoder(envelope)
|
||||
await self._ws.send(data)
|
||||
|
||||
async def _start_http_transport(self, options: HttpAdapterOptions) -> None:
|
||||
app = options.app or aiohttp_web.Application()
|
||||
app.add_routes([aiohttp_web.post(options.path, self._handle_http_request)])
|
||||
self._http_runner = aiohttp_web.AppRunner(app)
|
||||
await self._http_runner.setup()
|
||||
self._http_site = aiohttp_web.TCPSite(self._http_runner, options.host, options.port)
|
||||
await self._http_site.start()
|
||||
|
||||
async def _handle_http_request(self, request: aiohttp_web.Request) -> aiohttp_web.Response:
|
||||
raw = await request.read()
|
||||
data = orjson.loads(raw) if raw else {}
|
||||
if isinstance(data, list):
|
||||
await self.on_platform_messages(data)
|
||||
else:
|
||||
await self.on_platform_message(data)
|
||||
return aiohttp_web.json_response({"status": "ok"})
|
||||
|
||||
@staticmethod
|
||||
def _default_ws_parser(raw: str | bytes) -> Any:
|
||||
data = orjson.loads(raw)
|
||||
if isinstance(data, dict) and data.get("type") == "message" and "payload" in data:
|
||||
return data["payload"]
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _default_ws_encoder(envelope: MessageEnvelope) -> bytes:
|
||||
return orjson.dumps({"type": "send", "payload": envelope})
|
||||
|
||||
|
||||
class InProcessCoreSink:
|
||||
"""
|
||||
简单的进程内 sink,实现 CoreMessageSink 协议。
|
||||
"""
|
||||
|
||||
def __init__(self, handler: Callable[[MessageEnvelope], Awaitable[None]]):
|
||||
self._handler = handler
|
||||
|
||||
async def send(self, message: MessageEnvelope) -> None:
|
||||
await self._handler(message)
|
||||
|
||||
async def send_many(self, messages: list[MessageEnvelope]) -> None:
|
||||
for message in messages:
|
||||
await self._handler(message)
|
||||
|
||||
|
||||
async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) -> None:
|
||||
send_many = getattr(sink, "send_many", None)
|
||||
if callable(send_many):
|
||||
await send_many(envelopes)
|
||||
return
|
||||
for env in envelopes:
|
||||
await sink.send(env)
|
||||
|
||||
|
||||
class BatchDispatcher:
|
||||
"""
|
||||
将 send 操作合并为批量发送,适合网络 IO 密集场景。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sink: CoreMessageSink,
|
||||
*,
|
||||
max_batch_size: int = 50,
|
||||
flush_interval: float = 0.2,
|
||||
) -> None:
|
||||
self._sink = sink
|
||||
self._max_batch_size = max_batch_size
|
||||
self._flush_interval = flush_interval
|
||||
self._buffer: list[MessageEnvelope] = []
|
||||
self._lock = asyncio.Lock()
|
||||
self._flush_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
|
||||
async def add(self, message: MessageEnvelope) -> None:
|
||||
async with self._lock:
|
||||
if self._closed:
|
||||
raise RuntimeError("Dispatcher closed")
|
||||
self._buffer.append(message)
|
||||
self._ensure_timer()
|
||||
if len(self._buffer) >= self._max_batch_size:
|
||||
await self._flush_locked()
|
||||
|
||||
async def close(self) -> None:
|
||||
async with self._lock:
|
||||
self._closed = True
|
||||
await self._flush_locked()
|
||||
if self._flush_task:
|
||||
self._flush_task.cancel()
|
||||
self._flush_task = None
|
||||
|
||||
def _ensure_timer(self) -> None:
|
||||
if self._flush_task is not None and not self._flush_task.done():
|
||||
return
|
||||
loop = asyncio.get_running_loop()
|
||||
self._flush_task = loop.create_task(self._flush_loop())
|
||||
|
||||
async def _flush_loop(self) -> None:
|
||||
try:
|
||||
await asyncio.sleep(self._flush_interval)
|
||||
async with self._lock:
|
||||
await self._flush_locked()
|
||||
except asyncio.CancelledError: # pragma: no cover - timer cancellation
|
||||
pass
|
||||
|
||||
async def _flush_locked(self) -> None:
|
||||
if not self._buffer:
|
||||
return
|
||||
payload = list(self._buffer)
|
||||
self._buffer.clear()
|
||||
await self._sink.send_many(payload)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AdapterTransportOptions",
|
||||
"BaseAdapter",
|
||||
"BatchDispatcher",
|
||||
"CoreMessageSink",
|
||||
"HttpAdapterOptions",
|
||||
"InProcessCoreSink",
|
||||
"WebSocketAdapterOptions",
|
||||
]
|
||||
330
src/mofox_bus/api.py
Normal file
330
src/mofox_bus/api.py
Normal file
@@ -0,0 +1,330 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any, Awaitable, Callable, Dict, Literal, Optional
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
|
||||
from .message_models import MessageBase
|
||||
|
||||
MessagePayload = Dict[str, Any]
|
||||
MessageHandler = Callable[[MessagePayload], Awaitable[None] | None]
|
||||
|
||||
|
||||
class BaseMessageHandler:
|
||||
def __init__(self) -> None:
|
||||
self.message_handlers: list[MessageHandler] = []
|
||||
self.background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
def register_message_handler(self, handler: MessageHandler) -> None:
|
||||
if handler not in self.message_handlers:
|
||||
self.message_handlers.append(handler)
|
||||
|
||||
async def process_message(self, message: MessagePayload) -> None:
|
||||
tasks: list[asyncio.Task] = []
|
||||
for handler in self.message_handlers:
|
||||
try:
|
||||
result = handler(message)
|
||||
if asyncio.iscoroutine(result):
|
||||
task = asyncio.create_task(result)
|
||||
tasks.append(task)
|
||||
self.background_tasks.add(task)
|
||||
task.add_done_callback(self.background_tasks.discard)
|
||||
except Exception: # pragma: no cover - logging only
|
||||
logging.getLogger("mofox_bus.server").exception("Failed to handle message")
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
class MessageServer(BaseMessageHandler):
|
||||
"""
|
||||
WebSocket 消息服务器,支持与 FastAPI 应用共享事件循环。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 18000,
|
||||
*,
|
||||
enable_token: bool = False,
|
||||
app: FastAPI | None = None,
|
||||
path: str = "/ws",
|
||||
ssl_certfile: str | None = None,
|
||||
ssl_keyfile: str | None = None,
|
||||
mode: Literal["ws", "tcp"] = "ws",
|
||||
custom_logger: logging.Logger | None = None,
|
||||
enable_custom_uvicorn_logger: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if mode != "ws":
|
||||
raise NotImplementedError("Only WebSocket mode is supported in mofox_bus")
|
||||
if custom_logger:
|
||||
logging.getLogger("mofox_bus.server").handlers = custom_logger.handlers
|
||||
self.host = host
|
||||
self.port = port
|
||||
self._app = app or FastAPI()
|
||||
self._own_app = app is None
|
||||
self._path = path
|
||||
self._ssl_certfile = ssl_certfile
|
||||
self._ssl_keyfile = ssl_keyfile
|
||||
self._enable_token = enable_token
|
||||
self._valid_tokens: set[str] = set()
|
||||
self._connections: set[WebSocket] = set()
|
||||
self._platform_connections: dict[str, WebSocket] = {}
|
||||
self._conn_lock = asyncio.Lock()
|
||||
self._server: uvicorn.Server | None = None
|
||||
self._running = False
|
||||
self._setup_routes()
|
||||
|
||||
def _setup_routes(self) -> None:
|
||||
@_self_websocket(self._app, self._path)
|
||||
async def websocket_endpoint(websocket: WebSocket) -> None:
|
||||
platform = websocket.headers.get("platform", "unknown")
|
||||
token = websocket.headers.get("authorization") or websocket.headers.get("Authorization")
|
||||
if self._enable_token and not await self.verify_token(token):
|
||||
await websocket.close(code=1008, reason="invalid token")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
await self._register_connection(websocket, platform)
|
||||
try:
|
||||
while True:
|
||||
msg = await websocket.receive()
|
||||
if msg["type"] == "websocket.receive":
|
||||
data = msg.get("text")
|
||||
if data is None and msg.get("bytes") is not None:
|
||||
data = msg["bytes"].decode("utf-8")
|
||||
if not data:
|
||||
continue
|
||||
try:
|
||||
payload = orjson.loads(data)
|
||||
except orjson.JSONDecodeError:
|
||||
logging.getLogger("mofox_bus.server").warning("Invalid JSON payload")
|
||||
continue
|
||||
if isinstance(payload, list):
|
||||
for item in payload:
|
||||
await self.process_message(item)
|
||||
else:
|
||||
await self.process_message(payload)
|
||||
elif msg["type"] == "websocket.disconnect":
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
await self._remove_connection(websocket, platform)
|
||||
|
||||
async def verify_token(self, token: str | None) -> bool:
|
||||
if not self._enable_token:
|
||||
return True
|
||||
return token in self._valid_tokens
|
||||
|
||||
def add_valid_token(self, token: str) -> None:
|
||||
self._valid_tokens.add(token)
|
||||
|
||||
def remove_valid_token(self, token: str) -> None:
|
||||
self._valid_tokens.discard(token)
|
||||
|
||||
async def _register_connection(self, websocket: WebSocket, platform: str) -> None:
|
||||
async with self._conn_lock:
|
||||
self._connections.add(websocket)
|
||||
if platform:
|
||||
previous = self._platform_connections.get(platform)
|
||||
if previous and previous.client_state.name != "DISCONNECTED":
|
||||
await previous.close(code=1000, reason="replaced")
|
||||
self._platform_connections[platform] = websocket
|
||||
|
||||
async def _remove_connection(self, websocket: WebSocket, platform: str) -> None:
|
||||
async with self._conn_lock:
|
||||
self._connections.discard(websocket)
|
||||
if platform and self._platform_connections.get(platform) is websocket:
|
||||
del self._platform_connections[platform]
|
||||
|
||||
async def broadcast_message(self, message: MessagePayload) -> None:
|
||||
data = orjson.dumps(message).decode("utf-8")
|
||||
async with self._conn_lock:
|
||||
targets = list(self._connections)
|
||||
for ws in targets:
|
||||
await ws.send_text(data)
|
||||
|
||||
async def broadcast_to_platform(self, platform: str, message: MessagePayload) -> None:
|
||||
ws = self._platform_connections.get(platform)
|
||||
if ws is None:
|
||||
raise RuntimeError(f"No active connection for platform {platform}")
|
||||
await ws.send_text(orjson.dumps(message).decode("utf-8"))
|
||||
|
||||
async def send_message(self, message: MessageBase | MessagePayload) -> None:
|
||||
payload = message.to_dict() if isinstance(message, MessageBase) else message
|
||||
platform = payload.get("message_info", {}).get("platform")
|
||||
if not platform:
|
||||
raise ValueError("message_info.platform is required to route the message")
|
||||
await self.broadcast_to_platform(platform, payload)
|
||||
|
||||
def run_sync(self) -> None:
|
||||
if not self._own_app:
|
||||
return
|
||||
asyncio.run(self.run())
|
||||
|
||||
async def run(self) -> None:
|
||||
self._running = True
|
||||
if not self._own_app:
|
||||
return
|
||||
config = uvicorn.Config(
|
||||
self._app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
ssl_certfile=self._ssl_certfile,
|
||||
ssl_keyfile=self._ssl_keyfile,
|
||||
log_config=None,
|
||||
access_log=False,
|
||||
)
|
||||
self._server = uvicorn.Server(config)
|
||||
try:
|
||||
await self._server.serve()
|
||||
except asyncio.CancelledError: # pragma: no cover - shutdown path
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
if self._server:
|
||||
self._server.should_exit = True
|
||||
await self._server.shutdown()
|
||||
self._server = None
|
||||
async with self._conn_lock:
|
||||
targets = list(self._connections)
|
||||
self._connections.clear()
|
||||
self._platform_connections.clear()
|
||||
for ws in targets:
|
||||
try:
|
||||
await ws.close(code=1001, reason="server shutting down")
|
||||
except Exception: # pragma: no cover - best effort
|
||||
pass
|
||||
for task in list(self.background_tasks):
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
if self.background_tasks:
|
||||
await asyncio.gather(*self.background_tasks, return_exceptions=True)
|
||||
self.background_tasks.clear()
|
||||
|
||||
|
||||
class MessageClient(BaseMessageHandler):
|
||||
"""
|
||||
WebSocket 消息客户端,实现双向传输。
|
||||
"""
|
||||
|
||||
def __init__(self, mode: Literal["ws", "tcp"] = "ws") -> None:
|
||||
super().__init__()
|
||||
if mode != "ws":
|
||||
raise NotImplementedError("Only WebSocket mode is supported in mofox_bus")
|
||||
self._mode = mode
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._url: str = ""
|
||||
self._platform: str = ""
|
||||
self._token: str | None = None
|
||||
self._ssl_verify: str | None = None
|
||||
self._closed = False
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
*,
|
||||
url: str,
|
||||
platform: str,
|
||||
token: str | None = None,
|
||||
ssl_verify: str | None = None,
|
||||
) -> None:
|
||||
self._url = url
|
||||
self._platform = platform
|
||||
self._token = token
|
||||
self._ssl_verify = ssl_verify
|
||||
await self._establish_connection()
|
||||
|
||||
async def _establish_connection(self) -> None:
|
||||
if self._session is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
headers = {"platform": self._platform}
|
||||
if self._token:
|
||||
headers["authorization"] = self._token
|
||||
ssl_context = None
|
||||
if self._ssl_verify:
|
||||
ssl_context = ssl.create_default_context(cafile=self._ssl_verify)
|
||||
self._ws = await self._session.ws_connect(self._url, headers=headers, ssl=ssl_context)
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
assert self._ws is not None
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
|
||||
data = msg.data if isinstance(msg.data, str) else msg.data.decode("utf-8")
|
||||
try:
|
||||
payload = orjson.loads(data)
|
||||
except orjson.JSONDecodeError:
|
||||
logging.getLogger("mofox_bus.client").warning("Invalid JSON payload")
|
||||
continue
|
||||
if isinstance(payload, list):
|
||||
for item in payload:
|
||||
await self.process_message(item)
|
||||
else:
|
||||
await self.process_message(payload)
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
break
|
||||
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
||||
pass
|
||||
finally:
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
async def run(self) -> None:
|
||||
if self._receive_task is None:
|
||||
await self._establish_connection()
|
||||
try:
|
||||
if self._receive_task:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
||||
pass
|
||||
|
||||
async def send_message(self, message: MessagePayload) -> bool:
|
||||
if self._ws is None or self._ws.closed:
|
||||
raise RuntimeError("WebSocket connection is not established")
|
||||
await self._ws.send_str(orjson.dumps(message).decode("utf-8"))
|
||||
return True
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
return self._ws is not None and not self._ws.closed
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._closed = True
|
||||
if self._receive_task and not self._receive_task.done():
|
||||
self._receive_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._receive_task
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
|
||||
def _self_websocket(app: FastAPI, path: str):
|
||||
"""
|
||||
装饰器工厂,兼容 FastAPI websocket 路由的声明方式。
|
||||
FastAPI 不允许直接重复注册同一路径,因此这里封装一个可复用的装饰器。
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
app.add_api_websocket_route(path, func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
__all__ = ["BaseMessageHandler", "MessageClient", "MessageServer"]
|
||||
87
src/mofox_bus/codec.py
Normal file
87
src/mofox_bus/codec.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json as _stdlib_json
|
||||
from typing import Any, Dict, Iterable, List
|
||||
|
||||
try:
|
||||
import orjson as _json_impl
|
||||
except Exception: # pragma: no cover - fallback when orjson is unavailable
|
||||
_json_impl = None
|
||||
|
||||
from .types import MessageEnvelope
|
||||
|
||||
DEFAULT_SCHEMA_VERSION = 1
|
||||
|
||||
|
||||
def _dumps(obj: Any) -> bytes:
|
||||
if _json_impl is not None:
|
||||
return _json_impl.dumps(obj)
|
||||
return _stdlib_json.dumps(obj, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
|
||||
|
||||
|
||||
def _loads(data: bytes) -> Dict[str, Any]:
|
||||
if _json_impl is not None:
|
||||
return _json_impl.loads(data)
|
||||
return _stdlib_json.loads(data.decode("utf-8"))
|
||||
|
||||
|
||||
def dumps_message(msg: MessageEnvelope) -> bytes:
|
||||
"""
|
||||
将单条 MessageEnvelope 序列化为 JSON bytes。
|
||||
"""
|
||||
if "schema_version" not in msg:
|
||||
msg["schema_version"] = DEFAULT_SCHEMA_VERSION
|
||||
return _dumps(msg)
|
||||
|
||||
|
||||
def dumps_messages(messages: Iterable[MessageEnvelope]) -> bytes:
|
||||
"""
|
||||
将多条消息批量序列化,以提升吞吐。
|
||||
"""
|
||||
payload = {
|
||||
"schema_version": DEFAULT_SCHEMA_VERSION,
|
||||
"items": list(messages),
|
||||
}
|
||||
return _dumps(payload)
|
||||
|
||||
|
||||
def loads_message(data: bytes | str) -> MessageEnvelope:
|
||||
"""
|
||||
反序列化单条消息。
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode("utf-8")
|
||||
obj = _loads(data)
|
||||
return _upgrade_schema_if_needed(obj)
|
||||
|
||||
|
||||
def loads_messages(data: bytes | str) -> List[MessageEnvelope]:
|
||||
"""
|
||||
反序列化批量消息。
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode("utf-8")
|
||||
obj = _loads(data)
|
||||
version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION)
|
||||
if version != DEFAULT_SCHEMA_VERSION:
|
||||
raise ValueError(f"Unsupported schema_version={version}")
|
||||
return [_upgrade_schema_if_needed(item) for item in obj.get("items", [])]
|
||||
|
||||
|
||||
def _upgrade_schema_if_needed(obj: Dict[str, Any]) -> MessageEnvelope:
|
||||
"""
|
||||
针对未来的 schema 版本演进预留兼容入口。
|
||||
"""
|
||||
version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION)
|
||||
if version == DEFAULT_SCHEMA_VERSION:
|
||||
return obj # type: ignore[return-value]
|
||||
raise ValueError(f"Unsupported schema_version={version}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_SCHEMA_VERSION",
|
||||
"dumps_message",
|
||||
"dumps_messages",
|
||||
"loads_message",
|
||||
"loads_messages",
|
||||
]
|
||||
189
src/mofox_bus/message_models.py
Normal file
189
src/mofox_bus/message_models.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class Seg:
|
||||
"""
|
||||
消息段,表示一段文本/图片/结构化内容。
|
||||
"""
|
||||
|
||||
type: str
|
||||
data: Any
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Seg":
|
||||
seg_type = data.get("type")
|
||||
seg_data = data.get("data")
|
||||
if seg_type == "seglist" and isinstance(seg_data, list):
|
||||
seg_data = [Seg.from_dict(item) for item in seg_data]
|
||||
return cls(type=seg_type, data=seg_data)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
if self.type == "seglist" and isinstance(self.data, list):
|
||||
payload = [seg.to_dict() if isinstance(seg, Seg) else seg for seg in self.data]
|
||||
else:
|
||||
payload = self.data
|
||||
return {"type": self.type, "data": payload}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupInfo:
|
||||
platform: Optional[str] = None
|
||||
group_id: Optional[str] = None
|
||||
group_name: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> Optional["GroupInfo"]:
|
||||
if not data or data.get("group_id") is None:
|
||||
return None
|
||||
return cls(
|
||||
platform=data.get("platform"),
|
||||
group_id=data.get("group_id"),
|
||||
group_name=data.get("group_name"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInfo:
|
||||
platform: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
user_nickname: Optional[str] = None
|
||||
user_cardname: Optional[str] = None
|
||||
user_avatar: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "UserInfo":
|
||||
return cls(
|
||||
platform=data.get("platform"),
|
||||
user_id=data.get("user_id"),
|
||||
user_nickname=data.get("user_nickname"),
|
||||
user_cardname=data.get("user_cardname"),
|
||||
user_avatar=data.get("user_avatar"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FormatInfo:
|
||||
content_format: Optional[List[str]] = None
|
||||
accept_format: Optional[List[str]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> Optional["FormatInfo"]:
|
||||
if not data:
|
||||
return None
|
||||
return cls(
|
||||
content_format=data.get("content_format"),
|
||||
accept_format=data.get("accept_format"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TemplateInfo:
|
||||
template_items: Optional[Dict[str, str]] = None
|
||||
template_name: Optional[Dict[str, str]] = None
|
||||
template_default: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> Optional["TemplateInfo"]:
|
||||
if not data:
|
||||
return None
|
||||
return cls(
|
||||
template_items=data.get("template_items"),
|
||||
template_name=data.get("template_name"),
|
||||
template_default=data.get("template_default", True),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseMessageInfo:
|
||||
platform: Optional[str] = None
|
||||
message_id: Optional[str] = None
|
||||
time: Optional[float] = None
|
||||
group_info: Optional[GroupInfo] = None
|
||||
user_info: Optional[UserInfo] = None
|
||||
format_info: Optional[FormatInfo] = None
|
||||
template_info: Optional[TemplateInfo] = None
|
||||
additional_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result: Dict[str, Any] = {}
|
||||
if self.platform is not None:
|
||||
result["platform"] = self.platform
|
||||
if self.message_id is not None:
|
||||
result["message_id"] = self.message_id
|
||||
if self.time is not None:
|
||||
result["time"] = self.time
|
||||
if self.additional_config is not None:
|
||||
result["additional_config"] = self.additional_config
|
||||
if self.group_info is not None:
|
||||
result["group_info"] = self.group_info.to_dict()
|
||||
if self.user_info is not None:
|
||||
result["user_info"] = self.user_info.to_dict()
|
||||
if self.format_info is not None:
|
||||
result["format_info"] = self.format_info.to_dict()
|
||||
if self.template_info is not None:
|
||||
result["template_info"] = self.template_info.to_dict()
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BaseMessageInfo":
|
||||
return cls(
|
||||
platform=data.get("platform"),
|
||||
message_id=data.get("message_id"),
|
||||
time=data.get("time"),
|
||||
additional_config=data.get("additional_config"),
|
||||
group_info=GroupInfo.from_dict(data.get("group_info", {})),
|
||||
user_info=UserInfo.from_dict(data.get("user_info", {})),
|
||||
format_info=FormatInfo.from_dict(data.get("format_info", {})),
|
||||
template_info=TemplateInfo.from_dict(data.get("template_info", {})),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageBase:
|
||||
message_info: BaseMessageInfo
|
||||
message_segment: Seg
|
||||
raw_message: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {
|
||||
"message_info": self.message_info.to_dict(),
|
||||
"message_segment": self.message_segment.to_dict(),
|
||||
}
|
||||
if self.raw_message is not None:
|
||||
payload["raw_message"] = self.raw_message
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "MessageBase":
|
||||
return cls(
|
||||
message_info=BaseMessageInfo.from_dict(data.get("message_info", {})),
|
||||
message_segment=Seg.from_dict(data.get("message_segment", {})),
|
||||
raw_message=data.get("raw_message"),
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseMessageInfo",
|
||||
"FormatInfo",
|
||||
"GroupInfo",
|
||||
"MessageBase",
|
||||
"Seg",
|
||||
"TemplateInfo",
|
||||
"UserInfo",
|
||||
]
|
||||
172
src/mofox_bus/router.py
Normal file
172
src/mofox_bus/router.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from .api import MessageClient
|
||||
from .message_models import MessageBase
|
||||
|
||||
logger = logging.getLogger("mofox_bus.router")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TargetConfig:
|
||||
url: str
|
||||
token: str | None = None
|
||||
ssl_verify: str | None = None
|
||||
|
||||
def to_dict(self) -> Dict[str, str | None]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, str | None]) -> "TargetConfig":
|
||||
return cls(
|
||||
url=data.get("url", ""),
|
||||
token=data.get("token"),
|
||||
ssl_verify=data.get("ssl_verify"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RouteConfig:
|
||||
route_config: Dict[str, TargetConfig]
|
||||
|
||||
def to_dict(self) -> Dict[str, Dict[str, str | None]]:
|
||||
return {"route_config": {k: v.to_dict() for k, v in self.route_config.items()}}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Dict[str, str | None]]) -> "RouteConfig":
|
||||
cfg = {
|
||||
platform: TargetConfig.from_dict(target)
|
||||
for platform, target in data.get("route_config", {}).items()
|
||||
}
|
||||
return cls(route_config=cfg)
|
||||
|
||||
|
||||
class Router:
|
||||
def __init__(self, config: RouteConfig, custom_logger: logging.Logger | None = None) -> None:
|
||||
if custom_logger:
|
||||
logger.handlers = custom_logger.handlers
|
||||
self.config = config
|
||||
self.clients: Dict[str, MessageClient] = {}
|
||||
self.handlers: list[Callable[[Dict], None]] = []
|
||||
self._running = False
|
||||
self._client_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._monitor_task: asyncio.Task | None = None
|
||||
|
||||
async def connect(self, platform: str) -> None:
|
||||
if platform not in self.config.route_config:
|
||||
raise ValueError(f"Unknown platform {platform}")
|
||||
target = self.config.route_config[platform]
|
||||
mode = "tcp" if target.url.startswith(("tcp://", "tcps://")) else "ws"
|
||||
if mode != "ws":
|
||||
raise NotImplementedError("TCP mode is not implemented yet")
|
||||
client = MessageClient(mode="ws")
|
||||
await client.connect(
|
||||
url=target.url,
|
||||
platform=platform,
|
||||
token=target.token,
|
||||
ssl_verify=target.ssl_verify,
|
||||
)
|
||||
for handler in self.handlers:
|
||||
client.register_message_handler(handler)
|
||||
self.clients[platform] = client
|
||||
if self._running:
|
||||
self._client_tasks[platform] = asyncio.create_task(client.run())
|
||||
|
||||
def register_class_handler(self, handler: Callable[[Dict], None]) -> None:
|
||||
self.handlers.append(handler)
|
||||
for client in self.clients.values():
|
||||
client.register_message_handler(handler)
|
||||
|
||||
async def run(self) -> None:
|
||||
self._running = True
|
||||
for platform in self.config.route_config:
|
||||
if platform not in self.clients:
|
||||
await self.connect(platform)
|
||||
for platform, client in self.clients.items():
|
||||
if platform not in self._client_tasks:
|
||||
self._client_tasks[platform] = asyncio.create_task(client.run())
|
||||
self._monitor_task = asyncio.create_task(self._monitor_connections())
|
||||
try:
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
except asyncio.CancelledError: # pragma: no cover
|
||||
raise
|
||||
|
||||
async def _monitor_connections(self) -> None:
|
||||
await asyncio.sleep(3)
|
||||
while self._running:
|
||||
for platform in list(self.clients.keys()):
|
||||
client = self.clients.get(platform)
|
||||
if client is None:
|
||||
continue
|
||||
if not client.is_connected():
|
||||
logger.info("Detected disconnect from %s, attempting reconnect", platform)
|
||||
await self._reconnect_platform(platform)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def _reconnect_platform(self, platform: str) -> None:
|
||||
await self.remove_platform(platform)
|
||||
if platform in self.config.route_config:
|
||||
await self.connect(platform)
|
||||
|
||||
async def remove_platform(self, platform: str) -> None:
|
||||
if platform in self._client_tasks:
|
||||
task = self._client_tasks.pop(platform)
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
client = self.clients.pop(platform, None)
|
||||
if client:
|
||||
await client.stop()
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
if self._monitor_task:
|
||||
self._monitor_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._monitor_task
|
||||
self._monitor_task = None
|
||||
for platform in list(self.clients.keys()):
|
||||
await self.remove_platform(platform)
|
||||
self.clients.clear()
|
||||
|
||||
def get_target_url(self, message: MessageBase) -> Optional[str]:
|
||||
platform = message.message_info.platform
|
||||
if not platform:
|
||||
return None
|
||||
target = self.config.route_config.get(platform)
|
||||
return target.url if target else None
|
||||
|
||||
async def send_message(self, message: MessageBase):
|
||||
platform = message.message_info.platform
|
||||
if not platform:
|
||||
raise ValueError("message_info.platform is required")
|
||||
client = self.clients.get(platform)
|
||||
if client is None:
|
||||
raise RuntimeError(f"No client connected for platform {platform}")
|
||||
return await client.send_message(message.to_dict())
|
||||
|
||||
async def update_config(self, config_data: Dict[str, Dict[str, str | None]]) -> None:
|
||||
new_config = RouteConfig.from_dict(config_data)
|
||||
await self._adjust_connections(new_config)
|
||||
self.config = new_config
|
||||
|
||||
async def _adjust_connections(self, new_config: RouteConfig) -> None:
|
||||
current = set(self.config.route_config.keys())
|
||||
updated = set(new_config.route_config.keys())
|
||||
for platform in current - updated:
|
||||
await self.remove_platform(platform)
|
||||
for platform in updated:
|
||||
if platform not in current:
|
||||
await self.connect(platform)
|
||||
else:
|
||||
old = self.config.route_config[platform]
|
||||
new = new_config.route_config[platform]
|
||||
if old.url != new.url or old.token != new.token:
|
||||
await self.remove_platform(platform)
|
||||
await self.connect(platform)
|
||||
133
src/mofox_bus/runtime.py
Normal file
133
src/mofox_bus/runtime.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Iterable, List
|
||||
|
||||
from .types import MessageEnvelope
|
||||
|
||||
Hook = Callable[[MessageEnvelope], Awaitable[None] | None]
|
||||
ErrorHook = Callable[[MessageEnvelope, BaseException], Awaitable[None] | None]
|
||||
Predicate = Callable[[MessageEnvelope], bool | Awaitable[bool]]
|
||||
MessageHandler = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None] | MessageEnvelope | None]
|
||||
BatchHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None] | List[MessageEnvelope] | None]
|
||||
|
||||
|
||||
class MessageProcessingError(RuntimeError):
|
||||
"""封装处理链路中发生的异常。"""
|
||||
|
||||
def __init__(self, message: MessageEnvelope, original: BaseException):
|
||||
detail = message.get("id", "<unknown>")
|
||||
super().__init__(f"Failed to handle message {detail}: {original}") # pragma: no cover - str repr only
|
||||
self.message_envelope = message
|
||||
self.original = original
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageRoute:
|
||||
predicate: Predicate
|
||||
handler: MessageHandler
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class MessageRuntime:
|
||||
"""
|
||||
负责调度消息路由、执行前后 hook 以及批量处理。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._routes: list[MessageRoute] = []
|
||||
self._before_hooks: list[Hook] = []
|
||||
self._after_hooks: list[Hook] = []
|
||||
self._error_hooks: list[ErrorHook] = []
|
||||
self._batch_handler: BatchHandler | None = None
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def add_route(self, predicate: Predicate, handler: MessageHandler, name: str | None = None) -> None:
|
||||
with self._lock:
|
||||
self._routes.append(MessageRoute(predicate=predicate, handler=handler, name=name))
|
||||
|
||||
def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]:
|
||||
"""
|
||||
装饰器写法,便于在核心逻辑中声明式注册。
|
||||
"""
|
||||
|
||||
def decorator(func: MessageHandler) -> MessageHandler:
|
||||
self.add_route(predicate, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def set_batch_handler(self, handler: BatchHandler) -> None:
|
||||
self._batch_handler = handler
|
||||
|
||||
def register_before_hook(self, hook: Hook) -> None:
|
||||
self._before_hooks.append(hook)
|
||||
|
||||
def register_after_hook(self, hook: Hook) -> None:
|
||||
self._after_hooks.append(hook)
|
||||
|
||||
def register_error_hook(self, hook: ErrorHook) -> None:
|
||||
self._error_hooks.append(hook)
|
||||
|
||||
async def handle_message(self, message: MessageEnvelope) -> MessageEnvelope | None:
|
||||
await self._run_hooks(self._before_hooks, message)
|
||||
try:
|
||||
route = await self._match_route(message)
|
||||
if route is None:
|
||||
return None
|
||||
result = await _maybe_await(route.handler(message))
|
||||
except Exception as exc: # pragma: no cover - tested indirectly
|
||||
await self._run_error_hooks(message, exc)
|
||||
raise MessageProcessingError(message, exc) from exc
|
||||
await self._run_hooks(self._after_hooks, message)
|
||||
return result
|
||||
|
||||
async def handle_batch(self, messages: Iterable[MessageEnvelope]) -> List[MessageEnvelope]:
|
||||
batch = list(messages)
|
||||
if not batch:
|
||||
return []
|
||||
if self._batch_handler is not None:
|
||||
result = await _maybe_await(self._batch_handler(batch))
|
||||
return result or []
|
||||
responses: list[MessageEnvelope] = []
|
||||
for message in batch:
|
||||
response = await self.handle_message(message)
|
||||
if response is not None:
|
||||
responses.append(response)
|
||||
return responses
|
||||
|
||||
async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None:
|
||||
with self._lock:
|
||||
routes = list(self._routes)
|
||||
for route in routes:
|
||||
should_handle = await _maybe_await(route.predicate(message))
|
||||
if should_handle:
|
||||
return route
|
||||
return None
|
||||
|
||||
async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None:
|
||||
for hook in hooks:
|
||||
await _maybe_await(hook(message))
|
||||
|
||||
async def _run_error_hooks(self, message: MessageEnvelope, exc: BaseException) -> None:
|
||||
for hook in self._error_hooks:
|
||||
await _maybe_await(hook(message, exc))
|
||||
|
||||
|
||||
async def _maybe_await(result):
|
||||
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||
return await result
|
||||
return result
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BatchHandler",
|
||||
"Hook",
|
||||
"MessageHandler",
|
||||
"MessageProcessingError",
|
||||
"MessageRoute",
|
||||
"MessageRuntime",
|
||||
"Predicate",
|
||||
]
|
||||
10
src/mofox_bus/transport/__init__.py
Normal file
10
src/mofox_bus/transport/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
传输层封装,提供 HTTP / WebSocket server & client。
|
||||
"""
|
||||
|
||||
from .http_client import HttpMessageClient
|
||||
from .http_server import HttpMessageServer
|
||||
from .ws_client import WsMessageClient
|
||||
from .ws_server import WsMessageServer
|
||||
|
||||
__all__ = ["HttpMessageClient", "HttpMessageServer", "WsMessageClient", "WsMessageServer"]
|
||||
67
src/mofox_bus/transport/http_client.py
Normal file
67
src/mofox_bus/transport/http_client.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Iterable, List, Sequence
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..codec import dumps_messages, loads_messages
|
||||
from ..types import MessageEnvelope
|
||||
|
||||
|
||||
class HttpMessageClient:
|
||||
"""
|
||||
面向消息批量传输的 HTTP 客户端封装。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
*,
|
||||
session: aiohttp.ClientSession | None = None,
|
||||
timeout: aiohttp.ClientTimeout | None = None,
|
||||
) -> None:
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._session = session
|
||||
self._timeout = timeout
|
||||
self._owns_session = session is None
|
||||
self._logger = logging.getLogger("mofox_bus.http_client")
|
||||
|
||||
async def send_messages(
|
||||
self,
|
||||
messages: Sequence[MessageEnvelope],
|
||||
*,
|
||||
expect_reply: bool = False,
|
||||
path: str = "/messages",
|
||||
) -> List[MessageEnvelope] | None:
|
||||
if not messages:
|
||||
return []
|
||||
session = await self._ensure_session()
|
||||
url = f"{self._base_url}{path}"
|
||||
payload = dumps_messages(messages)
|
||||
self._logger.debug("Sending %d message(s) -> %s", len(messages), url)
|
||||
async with session.post(url, data=payload, timeout=self._timeout) as resp:
|
||||
resp.raise_for_status()
|
||||
if not expect_reply:
|
||||
return None
|
||||
raw = await resp.read()
|
||||
replies = loads_messages(raw)
|
||||
self._logger.debug("Received %d reply message(s)", len(replies))
|
||||
return replies
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._owns_session and self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def _ensure_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def __aenter__(self) -> "HttpMessageClient":
|
||||
await self._ensure_session()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
await self.close()
|
||||
53
src/mofox_bus/transport/http_server.py
Normal file
53
src/mofox_bus/transport/http_server.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Awaitable, Callable, List
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ..codec import dumps_messages, loads_messages
|
||||
from ..types import MessageEnvelope
|
||||
|
||||
MessageHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None]]
|
||||
|
||||
|
||||
class HttpMessageServer:
|
||||
"""
|
||||
轻量级 HTTP 消息入口。可独立运行,也可挂载到现有 FastAPI / aiohttp 应用下。
|
||||
"""
|
||||
|
||||
def __init__(self, handler: MessageHandler, *, path: str = "/messages") -> None:
|
||||
self._handler = handler
|
||||
self._app = web.Application()
|
||||
self._path = path
|
||||
self._app.add_routes([web.post(path, self._handle_messages)])
|
||||
self._logger = logging.getLogger("mofox_bus.http_server")
|
||||
|
||||
async def _handle_messages(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
raw = await request.read()
|
||||
envelopes = loads_messages(raw)
|
||||
self._logger.debug("Received %d message(s)", len(envelopes))
|
||||
except Exception as exc: # pragma: no cover - network errors are integration tested
|
||||
self._logger.exception("Failed to parse incoming messages: %s", exc)
|
||||
raise web.HTTPBadRequest(reason=f"Invalid payload: {exc}") from exc
|
||||
|
||||
result = await self._handler(envelopes)
|
||||
if result is None:
|
||||
return web.Response(status=200, text="ok")
|
||||
payload = dumps_messages(result)
|
||||
return web.Response(status=200, body=payload, content_type="application/json")
|
||||
|
||||
def make_app(self) -> web.Application:
|
||||
"""
|
||||
返回 aiohttp Application,可被外部 server(gunicorn/uvicorn)直接使用。
|
||||
"""
|
||||
|
||||
return self._app
|
||||
|
||||
def add_to_app(self, app: web.Application) -> None:
|
||||
"""
|
||||
将消息路由注册到给定的 aiohttp app,方便与既有服务整合。
|
||||
"""
|
||||
|
||||
app.router.add_post(self._path, self._handle_messages)
|
||||
108
src/mofox_bus/transport/ws_client.py
Normal file
108
src/mofox_bus/transport/ws_client.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Awaitable, Callable, Iterable, List, Sequence
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..codec import dumps_messages, loads_messages
|
||||
from ..types import MessageEnvelope
|
||||
|
||||
IncomingHandler = Callable[[MessageEnvelope], Awaitable[None]]
|
||||
|
||||
|
||||
class WsMessageClient:
|
||||
"""
|
||||
管理 WebSocket 连接,提供 send/receive API,并在后台读取消息。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
handler: IncomingHandler | None = None,
|
||||
session: aiohttp.ClientSession | None = None,
|
||||
reconnect_interval: float = 5.0,
|
||||
) -> None:
|
||||
self._url = url
|
||||
self._handler = handler
|
||||
self._session = session
|
||||
self._reconnect_interval = reconnect_interval
|
||||
self._owns_session = session is None
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
self._logger = logging.getLogger("mofox_bus.ws_client")
|
||||
|
||||
async def connect(self) -> None:
|
||||
await self._ensure_session()
|
||||
await self._connect_once()
|
||||
|
||||
async def _connect_once(self) -> None:
|
||||
assert self._session is not None
|
||||
self._ws = await self._session.ws_connect(self._url)
|
||||
self._logger.info("Connected to %s", self._url)
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def send_messages(self, messages: Sequence[MessageEnvelope]) -> None:
|
||||
if not messages:
|
||||
return
|
||||
ws = await self._ensure_ws()
|
||||
payload = dumps_messages(messages)
|
||||
await ws.send_bytes(payload)
|
||||
|
||||
async def send_message(self, message: MessageEnvelope) -> None:
|
||||
await self.send_messages([message])
|
||||
|
||||
async def close(self) -> None:
|
||||
self._closed = True
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._owns_session and self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
assert self._ws is not None
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
if msg.type in (aiohttp.WSMsgType.BINARY, aiohttp.WSMsgType.TEXT):
|
||||
envelopes = loads_messages(msg.data)
|
||||
for env in envelopes:
|
||||
if self._handler is not None:
|
||||
await self._handler(env)
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
self._logger.warning("WebSocket error: %s", msg.data)
|
||||
break
|
||||
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
||||
return
|
||||
finally:
|
||||
if not self._closed:
|
||||
await self._reconnect()
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
self._logger.info("WebSocket disconnected, retrying in %.1fs", self._reconnect_interval)
|
||||
await asyncio.sleep(self._reconnect_interval)
|
||||
await self._connect_once()
|
||||
|
||||
async def _ensure_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def _ensure_ws(self) -> aiohttp.ClientWebSocketResponse:
|
||||
if self._ws is None or self._ws.closed:
|
||||
await self._connect_once()
|
||||
assert self._ws is not None
|
||||
return self._ws
|
||||
|
||||
async def __aenter__(self) -> "WsMessageClient":
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
await self.close()
|
||||
66
src/mofox_bus/transport/ws_server.py
Normal file
66
src/mofox_bus/transport/ws_server.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Awaitable, Callable, Iterable, List, Set
|
||||
|
||||
from aiohttp import WSMsgType, web
|
||||
|
||||
from ..codec import dumps_messages, loads_messages
|
||||
from ..types import MessageEnvelope
|
||||
|
||||
WsMessageHandler = Callable[[MessageEnvelope], Awaitable[None]]
|
||||
|
||||
|
||||
class WsMessageServer:
|
||||
"""
|
||||
封装 WebSocket 服务端逻辑,负责接收消息并广播响应。
|
||||
"""
|
||||
|
||||
def __init__(self, handler: WsMessageHandler, *, path: str = "/ws") -> None:
|
||||
self._handler = handler
|
||||
self._app = web.Application()
|
||||
self._path = path
|
||||
self._app.add_routes([web.get(path, self._handle_ws)])
|
||||
self._connections: Set[web.WebSocketResponse] = set()
|
||||
self._lock = asyncio.Lock()
|
||||
self._logger = logging.getLogger("mofox_bus.ws_server")
|
||||
|
||||
async def _handle_ws(self, request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
self._logger.info("WebSocket connection opened: %s", request.remote)
|
||||
|
||||
async with self._track_connection(ws):
|
||||
async for message in ws:
|
||||
if message.type in (WSMsgType.BINARY, WSMsgType.TEXT):
|
||||
envelopes = loads_messages(message.data)
|
||||
for env in envelopes:
|
||||
await self._handler(env)
|
||||
elif message.type == WSMsgType.ERROR:
|
||||
self._logger.warning("WebSocket connection error: %s", ws.exception())
|
||||
break
|
||||
|
||||
self._logger.info("WebSocket connection closed: %s", request.remote)
|
||||
return ws
|
||||
|
||||
@asynccontextmanager
|
||||
async def _track_connection(self, ws: web.WebSocketResponse):
|
||||
async with self._lock:
|
||||
self._connections.add(ws)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
async with self._lock:
|
||||
self._connections.discard(ws)
|
||||
|
||||
async def broadcast(self, messages: Iterable[MessageEnvelope]) -> None:
|
||||
payload = dumps_messages(list(messages))
|
||||
async with self._lock:
|
||||
targets = list(self._connections)
|
||||
for ws in targets:
|
||||
await ws.send_bytes(payload)
|
||||
|
||||
def make_app(self) -> web.Application:
|
||||
return self._app
|
||||
162
src/mofox_bus/types.py
Normal file
162
src/mofox_bus/types.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Literal, NotRequired, TypedDict
|
||||
|
||||
MessageDirection = Literal["incoming", "outgoing"]
|
||||
Role = Literal["user", "assistant", "system", "tool", "platform"]
|
||||
ContentType = Literal[
|
||||
"text",
|
||||
"image",
|
||||
"audio",
|
||||
"file",
|
||||
"video",
|
||||
"event",
|
||||
"command",
|
||||
"system",
|
||||
]
|
||||
|
||||
EventType = Literal[
|
||||
"message_created",
|
||||
"message_updated",
|
||||
"message_deleted",
|
||||
"member_join",
|
||||
"member_leave",
|
||||
"typing",
|
||||
"reaction_add",
|
||||
"reaction_remove",
|
||||
]
|
||||
|
||||
|
||||
class TextContent(TypedDict, total=False):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
markdown: NotRequired[bool]
|
||||
entities: NotRequired[List[Dict[str, Any]]]
|
||||
|
||||
|
||||
class ImageContent(TypedDict, total=False):
|
||||
type: Literal["image"]
|
||||
url: str
|
||||
mime_type: NotRequired[str]
|
||||
width: NotRequired[int]
|
||||
height: NotRequired[int]
|
||||
file_id: NotRequired[str]
|
||||
|
||||
|
||||
class FileContent(TypedDict, total=False):
|
||||
type: Literal["file"]
|
||||
url: str
|
||||
mime_type: NotRequired[str]
|
||||
file_name: NotRequired[str]
|
||||
file_size: NotRequired[int]
|
||||
file_id: NotRequired[str]
|
||||
|
||||
|
||||
class AudioContent(TypedDict, total=False):
|
||||
type: Literal["audio"]
|
||||
url: str
|
||||
mime_type: NotRequired[str]
|
||||
duration_ms: NotRequired[int]
|
||||
file_id: NotRequired[str]
|
||||
|
||||
|
||||
class VideoContent(TypedDict, total=False):
|
||||
type: Literal["video"]
|
||||
url: str
|
||||
mime_type: NotRequired[str]
|
||||
duration_ms: NotRequired[int]
|
||||
width: NotRequired[int]
|
||||
height: NotRequired[int]
|
||||
file_id: NotRequired[str]
|
||||
|
||||
|
||||
class EventContent(TypedDict):
|
||||
type: Literal["event"]
|
||||
event_type: EventType
|
||||
raw: Dict[str, Any]
|
||||
|
||||
|
||||
class CommandContent(TypedDict, total=False):
|
||||
type: Literal["command"]
|
||||
name: str
|
||||
args: Dict[str, Any]
|
||||
|
||||
|
||||
class SystemContent(TypedDict):
|
||||
type: Literal["system"]
|
||||
text: str
|
||||
|
||||
|
||||
Content = (
|
||||
TextContent
|
||||
| ImageContent
|
||||
| FileContent
|
||||
| AudioContent
|
||||
| VideoContent
|
||||
| EventContent
|
||||
| CommandContent
|
||||
| SystemContent
|
||||
)
|
||||
|
||||
|
||||
class SenderInfo(TypedDict, total=False):
|
||||
user_id: str
|
||||
role: Role
|
||||
display_name: NotRequired[str]
|
||||
avatar_url: NotRequired[str]
|
||||
raw: NotRequired[Dict[str, Any]]
|
||||
|
||||
|
||||
class ChannelInfo(TypedDict, total=False):
|
||||
channel_id: str
|
||||
channel_type: Literal[
|
||||
"private",
|
||||
"group",
|
||||
"supergroup",
|
||||
"channel",
|
||||
"dm",
|
||||
"room",
|
||||
"thread",
|
||||
]
|
||||
title: NotRequired[str]
|
||||
workspace_id: NotRequired[str]
|
||||
raw: NotRequired[Dict[str, Any]]
|
||||
|
||||
|
||||
class MessageEnvelope(TypedDict, total=False):
|
||||
id: str
|
||||
direction: MessageDirection
|
||||
platform: str
|
||||
timestamp_ms: int
|
||||
channel: ChannelInfo
|
||||
sender: SenderInfo
|
||||
content: Content
|
||||
conversation_id: str
|
||||
thread_id: NotRequired[str]
|
||||
reply_to_message_id: NotRequired[str]
|
||||
correlation_id: NotRequired[str]
|
||||
is_edited: NotRequired[bool]
|
||||
is_ephemeral: NotRequired[bool]
|
||||
raw_platform_message: NotRequired[Dict[str, Any]]
|
||||
metadata: NotRequired[Dict[str, Any]]
|
||||
schema_version: NotRequired[int]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AudioContent",
|
||||
"ChannelInfo",
|
||||
"CommandContent",
|
||||
"Content",
|
||||
"ContentType",
|
||||
"EventContent",
|
||||
"EventType",
|
||||
"FileContent",
|
||||
"ImageContent",
|
||||
"MessageDirection",
|
||||
"MessageEnvelope",
|
||||
"Role",
|
||||
"SenderInfo",
|
||||
"SystemContent",
|
||||
"TextContent",
|
||||
"VideoContent",
|
||||
]
|
||||
@@ -91,7 +91,7 @@ import time
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from maim_message import Seg, UserInfo
|
||||
from mofox_bus import Seg, UserInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
@@ -34,7 +34,7 @@ class InjectionRule:
|
||||
raise ValueError(f"'{self.injection_type.value}'类型的注入规则必须提供 'target_content'。")
|
||||
|
||||
|
||||
from maim_message import Seg
|
||||
from mofox_bus import Seg
|
||||
|
||||
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||
|
||||
@@ -10,7 +10,7 @@ from collections.abc import Callable
|
||||
|
||||
import aiohttp
|
||||
import filetype
|
||||
from maim_message import UserInfo
|
||||
from mofox_bus import UserInfo
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from maim_message import RouteConfig, Router, TargetConfig
|
||||
from mofox_bus import RouteConfig, Router, TargetConfig
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.server import get_global_server
|
||||
|
||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import websockets as Server
|
||||
from maim_message import (
|
||||
from mofox_bus import (
|
||||
BaseMessageInfo,
|
||||
FormatInfo,
|
||||
GroupInfo,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
from maim_message import MessageBase, Router
|
||||
from mofox_bus import MessageBase, Router
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
@@ -4,7 +4,7 @@ import time
|
||||
from typing import ClassVar, Optional, Tuple
|
||||
|
||||
import websockets as Server
|
||||
from maim_message import BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, UserInfo
|
||||
from mofox_bus import BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, UserInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
import time
|
||||
import websockets as Server
|
||||
import uuid
|
||||
from maim_message import (
|
||||
from mofox_bus import (
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
Seg,
|
||||
|
||||
@@ -406,7 +406,7 @@ console_log_level = "INFO" # 控制台日志级别,可选: DEBUG, INFO, WARNIN
|
||||
file_log_level = "DEBUG" # 文件日志级别,可选: DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||
|
||||
# 第三方库日志控制
|
||||
suppress_libraries = ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "aiosqlite", "openai","uvicorn","rjieba","maim_message"] # 完全屏蔽的库
|
||||
suppress_libraries = ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "aiosqlite", "openai","uvicorn","rjieba","message_bus"] # 完全屏蔽的库
|
||||
library_log_levels = { "aiohttp" = "WARNING"} # 设置特定库的日志级别
|
||||
|
||||
[dependency_management] # 插件Python依赖管理配置
|
||||
@@ -423,10 +423,10 @@ install_log_level = "INFO"
|
||||
[debug]
|
||||
show_prompt = false # 是否显示prompt
|
||||
|
||||
[maim_message]
|
||||
[message_bus]
|
||||
auth_token = [] # 认证令牌,用于API验证,为空则不启用验证
|
||||
# 以下项目若要使用需要打开use_custom,并单独配置maim_message的服务器
|
||||
use_custom = false # 是否启用自定义的maim_message服务器,注意这需要设置新的端口,不能与.env重复
|
||||
# 以下项目若要使用需要打开use_custom,并单独配置message_bus的服务器
|
||||
use_custom = false # 是否启用自定义的message_bus服务器,注意这需要设置新的端口,不能与.env重复
|
||||
host="127.0.0.1"
|
||||
port=8090
|
||||
mode="ws" # 支持ws和tcp两种模式
|
||||
@@ -631,4 +631,4 @@ throw_topic_weight = 0.3 # throw_topic动作的基础权重
|
||||
|
||||
# --- 调试与监控 ---
|
||||
enable_statistics = false # 是否启用统计功能(记录触发次数、决策分布等)
|
||||
log_decisions = false # 是否记录每次决策的详细日志(用于调试)
|
||||
log_decisions = false # 是否记录每次决策的详细日志(用于调试)
|
||||
|
||||
Reference in New Issue
Block a user