feat: 添加带有消息处理和路由功能的NEW_napcat_adapter插件

- 为NEW_napcat_adapter插件实现了核心模块,包括消息处理、事件处理和路由。
- 创建了MessageHandler、MetaEventHandler和NoticeHandler来处理收到的消息和事件。
- 开发了SendHandler,用于向Napcat发送回消息。
引入了StreamRouter来管理多个聊天流,确保消息的顺序和高效处理。
- 增加了对各种消息类型和格式的支持,包括文本、图像和通知。
- 建立了一个用于监控和调试的日志系统。
This commit is contained in:
Windpicker-owo
2025-11-24 13:24:55 +08:00
parent b08c70dfa6
commit 36fce6ca98
28 changed files with 3041 additions and 824 deletions

View File

@@ -32,11 +32,11 @@ MoFox Bus 是 MoFox Bot 自研的统一消息中台,替换第三方 `maim_mess
## 3. 消息模型
### 3.1 Envelope TypedDict`types.py`
### 3.1 Envelope TypedDict<EFBFBD><EFBFBD>`types.py`<EFBFBD><EFBFBD>
- `MessageEnvelope`:核心字段包括 `id``direction``platform``timestamp_ms``channel``sender``content` 等,一律使用毫秒时间戳,保留 `raw_platform_message``metadata` 便于调试 / 扩展。
- `Content` 联合类型支持文本、图片、音频、文件、视频、事件、命令、系统消息,后续可扩展更多 literal。
- `SenderInfo` / `ChannelInfo` / `MessageDirection` / `Role` 等均以 `Literal` 控制取值,方便 IDE 静态检查。
- `MessageEnvelope` <20><>ȫ<EFBFBD><C8AB>Ƶ<EFBFBD> maim_message <20><EFBFBD><E1B9B9><EFBFBD><EFBFBD><EFBFBD>ĵ<EFBFBD><C4B5><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD> `message_info` + `message_segment` (SegPayload)<29><>`direction`<EFBFBD><EFBFBD>`schema_version` <20><> raw <20><><EFBFBD><EFBFBD><EFBFBD>ֶβ<D6B6><CEB2><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ˣ<EFBFBD><CBA3><EFBFBD>Ժ<EFBFBD><D4BA><EFBFBD><EFBFBD><EFBFBD> `channel`<EFBFBD><EFBFBD>`sender`<EFBFBD><EFBFBD>`content` <EFBFBD><EFBFBD> v0 <20>ֶΪ<D6B6><CEAA>ѡ<EFBFBD><D1A1>
- `SegPayload` / `MessageInfoPayload` / `UserInfoPayload` / `GroupInfoPayload` / `FormatInfoPayload` / `TemplateInfoPayload` <20><> maim_message dataclass <20>Դ<EFBFBD>TypedDict <20><>Ӧ<EFBFBD><D3A6><EFBFBD>ʺ<EFBFBD>ֱ<EFBFBD><D6B1> JSON <20><><EFBFBD><EFBFBD>
- `Content` / `SenderInfo` / `ChannelInfo` <EFBFBD>Ȳ<EFBFBD>Ȼ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ڣ<EFBFBD><EFBFBD><EFBFBD><EFBFBD>ܻ<EFBFBD><EFBFBD><EFBFBD> IDE ע<>⣬Ҳ<E2A3AC>Ƕ<EFBFBD> v0 content ģ<>͵Ļ<CDB5>֧
### 3.2 dataclass 消息段(`message_models.py`
@@ -62,15 +62,14 @@ TypedDict 更适合网络传输和依赖注入dataclass 版 MessageBase 则
## 5. 运行时调度(`runtime.py`
- `MessageRuntime`
- `add_route(predicate, handler)` `@runtime.route(...)` 装饰器注册消息处理器
- `register_before_hook` / `register_after_hook` / `register_error_hook`入监控、埋点、Trace
- `set_batch_handler` 支持一次处理批消息(例如批量落库)。
- `MessageProcessingError` 在 handler 抛出异常时封装上下文,便于日志追踪。
- `add_route(predicate, handler)` `@runtime.route(...)` 装饰器注册消息处理器
- `register_before_hook` / `register_after_hook` / `register_error_hook`册前置、后置、Trace 处理
- `set_batch_handler` 支持一次处理批消息(可用于 batch IO 优化)
- `MessageProcessingError` 在 handler 抛出异常时包装原因,方便日志追踪。
运行时内部使用 `RLock` 保护路由表,适合多协程并发读写,`_maybe_await` 自动兼容同步/异步 handler。
---
## 6. 传输层封装(`transport/`
### 6.1 HTTP
@@ -126,9 +125,9 @@ from mofox_bus.transport import HttpMessageServer
runtime = MessageRuntime()
@runtime.route(lambda env: env["content"]["type"] == "text")
@runtime.route(lambda env: (env.get("message_segment") or {}).get("type") == "text")
async def handle_text(env: types.MessageEnvelope):
print("收到文本", env["content"]["text"])
print("收到文本", env["message_segment"]["data"])
async def http_handler(messages: list[types.MessageEnvelope]):
await runtime.handle_batch(messages)

35
src/common/core_sink.py Normal file
View File

@@ -0,0 +1,35 @@
"""
从 src.main 导出 core_sink 的辅助函数
由于 src.main 中实际使用的是 InProcessCoreSink
我们需要创建一个全局访问点
"""
from mofox_bus import CoreSink, InProcessCoreSink
_global_core_sink: CoreSink | None = None
def set_core_sink(sink: CoreSink) -> None:
"""设置全局 core sink"""
global _global_core_sink
_global_core_sink = sink
def get_core_sink() -> CoreSink:
"""获取全局 core sink"""
global _global_core_sink
if _global_core_sink is None:
raise RuntimeError("Core sink 尚未初始化")
return _global_core_sink
async def push_outgoing(envelope) -> None:
"""将消息推送到 core sink 的 outgoing 通道"""
sink = get_core_sink()
push = getattr(sink, "push_outgoing", None)
if push is None:
raise RuntimeError("当前 core sink 不支持 push_outgoing 方法")
await push(envelope)
__all__ = ["set_core_sink", "get_core_sink", "push_outgoing"]

View File

@@ -1,15 +1,22 @@
"""
MessageEnvelope 转换器
MessageEnvelope converter between mofox_bus schema and internal message structures.
将 mofox_bus MessageEnvelope 转换为 MoFox Bot 内部使用的消息格式
"""
- 优先处理 maim_message 风格message_info + message_segment
- 兼容旧版 content/sender/channel 结构,方便逐步迁移。
"""
from __future__ import annotations
import time
from typing import Any, Dict, List, Optional
from mofox_bus import MessageEnvelope, BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, UserInfo
from mofox_bus import (
BaseMessageInfo,
MessageBase,
MessageEnvelope,
Seg,
UserInfo,
GroupInfo,
)
from src.common.logger import get_logger
@@ -17,151 +24,221 @@ logger = get_logger("envelope_converter")
class EnvelopeConverter:
"""MessageEnvelope 到内部消息格式的转换器"""
"""MessageEnvelope <-> MessageBase converter."""
@staticmethod
def to_message_base(envelope: MessageEnvelope) -> MessageBase:
"""
MessageEnvelope 转换为 MessageBase
Args:
envelope: 统一的消息信封
Returns:
MessageBase: 内部消息格式
Convert MessageEnvelope to MessageBase.
"""
try:
# 提取基本信息
platform = envelope["platform"]
channel = envelope["channel"]
sender = envelope["sender"]
content = envelope["content"]
# 创建 UserInfo
user_info = UserInfo(
user_id=sender["user_id"],
user_nickname=sender.get("display_name", sender["user_id"]),
user_avatar=sender.get("avatar_url"),
)
# 创建 GroupInfo (如果是群组消息)
group_info: Optional[GroupInfo] = None
if channel["channel_type"] in ("group", "supergroup", "room"):
group_info = GroupInfo(
group_id=channel["channel_id"],
group_name=channel.get("title", channel["channel_id"]),
)
# 创建 BaseMessageInfo
message_info = BaseMessageInfo(
platform=platform,
chat_type="group" if group_info else "private",
message_id=envelope["id"],
user_info=user_info,
group_info=group_info,
timestamp=envelope["timestamp_ms"] / 1000.0, # 转换为秒
)
# 转换 Content 为 Seg 列表
segments = EnvelopeConverter._content_to_segments(content)
# 创建 MessageBase
message_base = MessageBase(
# 优先使用 maim_message 样式字段
info_payload = envelope.get("message_info") or {}
seg_payload = envelope.get("message_segment") or envelope.get("message_chain")
if info_payload:
message_info = BaseMessageInfo.from_dict(info_payload)
else:
message_info = EnvelopeConverter._build_info_from_legacy(envelope)
if seg_payload is None:
seg_list = EnvelopeConverter._content_to_segments(envelope.get("content"))
seg_payload = seg_list
message_segment = EnvelopeConverter._ensure_seg(seg_payload)
raw_message = envelope.get("raw_message") or envelope.get("raw_platform_message")
return MessageBase(
message_info=message_info,
message=segments,
message_segment=message_segment,
raw_message=raw_message,
)
# 保存原始 envelope 到 raw 字段
if hasattr(message_base, "raw"):
message_base.raw = envelope
return message_base
except Exception as e:
logger.error(f"转换 MessageEnvelope 失败: {e}", exc_info=True)
raise
@staticmethod
def _content_to_segments(content: Dict[str, Any]) -> List[Seg]:
def _build_info_from_legacy(envelope: MessageEnvelope) -> BaseMessageInfo:
"""将 legacy 字段映射为 BaseMessageInfo。"""
platform = envelope.get("platform")
channel = envelope.get("channel") or {}
sender = envelope.get("sender") or {}
message_id = envelope.get("id") or envelope.get("message_id")
timestamp_ms = envelope.get("timestamp_ms")
time_value = (timestamp_ms / 1000.0) if timestamp_ms is not None else None
group_info: Optional[GroupInfo] = None
channel_type = channel.get("channel_type")
if channel_type in ("group", "supergroup", "room"):
group_info = GroupInfo(
platform=platform,
group_id=channel.get("channel_id"),
group_name=channel.get("title"),
)
user_info: Optional[UserInfo] = None
if sender:
user_info = UserInfo(
platform=platform,
user_id=str(sender.get("user_id")) if sender.get("user_id") is not None else None,
user_nickname=sender.get("display_name") or sender.get("user_nickname"),
user_avatar=sender.get("avatar_url"),
)
return BaseMessageInfo(
platform=platform,
message_id=message_id,
time=time_value,
group_info=group_info,
user_info=user_info,
additional_config=envelope.get("metadata"),
)
@staticmethod
def _ensure_seg(payload: Any) -> Seg:
"""将任意 payload 转为 Seg dataclass。"""
if isinstance(payload, Seg):
return payload
if isinstance(payload, list):
# 直接传入 Seg 列表或 seglist data
return Seg(type="seglist", data=[EnvelopeConverter._ensure_seg(item) for item in payload])
if isinstance(payload, dict):
seg_type = payload.get("type") or "text"
data = payload.get("data")
if seg_type == "seglist" and isinstance(data, list):
data = [EnvelopeConverter._ensure_seg(item) for item in data]
return Seg(type=seg_type, data=data)
# 兜底:转成文本片段
return Seg(type="text", data="" if payload is None else str(payload))
@staticmethod
def _flatten_segments(seg: Seg) -> List[Seg]:
"""将 Seg/seglist 打平成列表,便于旧 content 转换。"""
if seg.type == "seglist" and isinstance(seg.data, list):
return [item if isinstance(item, Seg) else EnvelopeConverter._ensure_seg(item) for item in seg.data]
return [seg]
@staticmethod
def _content_to_segments(content: Any) -> List[Seg]:
"""
将 Content 转换为 Seg 列表
Args:
content: 消息内容
Returns:
List[Seg]: 消息段列表
Convert legacy Content (type/data/metadata) to a flat list of Seg.
"""
segments: List[Seg] = []
content_type = content.get("type")
if content_type == "text":
# 文本消息
text = content.get("text", "")
segments.append(Seg.text(text))
elif content_type == "image":
# 图片消息
url = content.get("url", "")
file_id = content.get("file_id")
segments.append(Seg.image(url if url else file_id))
elif content_type == "audio":
# 音频消息
url = content.get("url", "")
file_id = content.get("file_id")
segments.append(Seg.record(url if url else file_id))
elif content_type == "video":
# 视频消息
url = content.get("url", "")
file_id = content.get("file_id")
segments.append(Seg.video(url if url else file_id))
elif content_type == "file":
# 文件消息
url = content.get("url", "")
file_name = content.get("file_name", "file")
# 使用 text 表示文件(或者可以自定义一个 file seg type
segments.append(Seg.text(f"[文件: {file_name}]"))
elif content_type == "command":
# 命令消息
name = content.get("name", "")
args = content.get("args", {})
# 重构为文本格式
cmd_text = f"/{name}"
if args:
cmd_text += " " + " ".join(f"{k}={v}" for k, v in args.items())
segments.append(Seg.text(cmd_text))
elif content_type == "event":
# 事件消息 - 转换为文本表示
event_type = content.get("event_type", "unknown")
segments.append(Seg.text(f"[事件: {event_type}]"))
elif content_type == "system":
# 系统消息
text = content.get("text", "")
segments.append(Seg.text(f"[系统] {text}"))
else:
# 未知类型 - 转换为文本
def _walk(node: Any) -> None:
if node is None:
return
if isinstance(node, list):
for item in node:
_walk(item)
return
if not isinstance(node, dict):
logger.warning("未知的 content 节点类型: %s", type(node))
return
content_type = node.get("type")
data = node.get("data")
metadata = node.get("metadata") or {}
if content_type == "collection":
items = data if isinstance(data, list) else node.get("items", [])
for item in items:
_walk(item)
return
if content_type in ("text", "at"):
subtype = metadata.get("subtype") or ("at" if content_type == "at" else None)
text = "" if data is None else str(data)
if subtype in ("at", "mention"):
user_info = metadata.get("user") or {}
seg_data: Dict[str, Any] = {
"user_id": user_info.get("id") or user_info.get("user_id"),
"user_name": user_info.get("name") or user_info.get("display_name"),
"text": text,
"raw": user_info.get("raw") or user_info if user_info else None,
}
if any(v is not None for v in seg_data.values()):
segments.append(Seg(type="at", data=seg_data))
else:
segments.append(Seg(type="at", data=text))
else:
segments.append(Seg(type="text", data=text))
return
if content_type == "image":
url = ""
if isinstance(data, dict):
url = data.get("url") or data.get("file") or data.get("file_id") or ""
elif data is not None:
url = str(data)
segments.append(Seg(type="image", data=url))
return
if content_type == "audio":
url = ""
if isinstance(data, dict):
url = data.get("url") or data.get("file") or data.get("file_id") or ""
elif data is not None:
url = str(data)
segments.append(Seg(type="record", data=url))
return
if content_type == "video":
url = ""
if isinstance(data, dict):
url = data.get("url") or data.get("file") or data.get("file_id") or ""
elif data is not None:
url = str(data)
segments.append(Seg(type="video", data=url))
return
if content_type == "file":
file_name = ""
if isinstance(data, dict):
file_name = data.get("file_name") or data.get("name") or ""
text = file_name or "[file]"
segments.append(Seg(type="text", data=text))
return
if content_type == "command":
name = ""
args: Dict[str, Any] = {}
if isinstance(data, dict):
name = data.get("name", "")
args = data.get("args", {}) or {}
else:
name = str(data or "")
cmd_text = f"/{name}" if name else "/command"
if args:
cmd_text += " " + " ".join(f"{k}={v}" for k, v in args.items())
segments.append(Seg(type="text", data=cmd_text))
return
if content_type == "event":
event_type = ""
if isinstance(data, dict):
event_type = data.get("event_type", "")
else:
event_type = str(data or "")
segments.append(Seg(type="text", data=f"[事件: {event_type or 'unknown'}]"))
return
if content_type == "system":
text = "" if data is None else str(data)
segments.append(Seg(type="text", data=f"[系统] {text}"))
return
logger.warning(f"未知的消息类型: {content_type}")
segments.append(Seg.text(f"[未知消息类型: {content_type}]"))
segments.append(Seg(type="text", data=f"[未知消息类型: {content_type}]"))
_walk(content)
return segments
@staticmethod
def to_legacy_dict(envelope: MessageEnvelope) -> Dict[str, Any]:
"""
MessageEnvelope 转换为旧版字典格式(用于向后兼容)
Args:
envelope: 统一的消息信封
Returns:
Dict[str, Any]: 旧版消息字典
Convert MessageEnvelope to legacy dict for backward compatibility.
"""
message_base = EnvelopeConverter.to_message_base(envelope)
return message_base.to_dict()
@@ -169,61 +246,45 @@ class EnvelopeConverter:
@staticmethod
def from_message_base(message: MessageBase, direction: str = "outgoing") -> MessageEnvelope:
"""
MessageBase 转换为 MessageEnvelope (反向转换)
Args:
message: 内部消息格式
direction: 消息方向 ("incoming""outgoing")
Returns:
MessageEnvelope: 统一的消息信封
Convert MessageBase to MessageEnvelope (maim_message style preferred).
"""
try:
message_info = message.message_info
user_info = message_info.user_info
group_info = message_info.group_info
# 创建 SenderInfo
sender = {
"user_id": user_info.user_id,
"role": "assistant" if direction == "outgoing" else "user",
}
if user_info.user_nickname:
sender["display_name"] = user_info.user_nickname
if user_info.user_avatar:
sender["avatar_url"] = user_info.user_avatar
# 创建 ChannelInfo
if group_info:
channel = {
"channel_id": group_info.group_id,
"channel_type": "group",
}
if group_info.group_name:
channel["title"] = group_info.group_name
else:
channel = {
"channel_id": user_info.user_id,
"channel_type": "private",
}
# 转换 segments 为 Content
content = EnvelopeConverter._segments_to_content(message.message)
# 创建 MessageEnvelope
info_dict = message.message_info.to_dict()
seg_dict = message.message_segment.to_dict()
envelope: MessageEnvelope = {
"id": message_info.message_id,
"direction": direction,
"platform": message_info.platform,
"timestamp_ms": int(message_info.timestamp * 1000),
"channel": channel,
"sender": sender,
"content": content,
"conversation_id": group_info.group_id if group_info else user_info.user_id,
"message_info": info_dict,
"message_segment": seg_dict,
"platform": info_dict.get("platform"),
"message_id": info_dict.get("message_id"),
"schema_version": 1,
}
if message.message_info.time is not None:
envelope["timestamp_ms"] = int(message.message_info.time * 1000)
if message.raw_message is not None:
envelope["raw_message"] = message.raw_message
# legacy 补充,方便老代码继续工作
segments = EnvelopeConverter._flatten_segments(message.message_segment)
envelope["content"] = EnvelopeConverter._segments_to_content(segments)
if message.message_info.user_info:
envelope["sender"] = {
"user_id": message.message_info.user_info.user_id,
"role": "assistant" if direction == "outgoing" else "user",
"display_name": message.message_info.user_info.user_nickname,
"avatar_url": getattr(message.message_info.user_info, "user_avatar", None),
}
if message.message_info.group_info:
envelope["channel"] = {
"channel_id": message.message_info.group_info.group_id,
"channel_type": "group",
"title": message.message_info.group_info.group_name,
}
return envelope
except Exception as e:
logger.error(f"转换 MessageBase 失败: {e}", exc_info=True)
raise
@@ -231,45 +292,50 @@ class EnvelopeConverter:
@staticmethod
def _segments_to_content(segments: List[Seg]) -> Dict[str, Any]:
"""
将 Seg 列表转换为 Content
Args:
segments: 消息段列表
Returns:
Dict[str, Any]: 消息内容
Convert Seg list to legacy Content (type/data/metadata).
"""
if not segments:
return {"type": "text", "text": ""}
# 简化处理:如果有多个段,合并为文本
return {"type": "text", "data": ""}
def _seg_to_content(seg: Seg) -> Dict[str, Any]:
data = seg.data
if seg.type == "text":
return {"type": "text", "data": data}
if seg.type == "at":
content: Dict[str, Any] = {"type": "text", "data": ""}
metadata: Dict[str, Any] = {"subtype": "at"}
if isinstance(data, dict):
content["data"] = data.get("text", "")
user = {
"id": data.get("user_id"),
"name": data.get("user_name"),
"raw": data.get("raw"),
}
if any(v is not None for v in user.values()):
metadata["user"] = user
else:
content["data"] = data
if metadata:
content["metadata"] = metadata
return content
if seg.type == "image":
return {"type": "image", "data": data}
if seg.type in ("record", "voice", "audio"):
return {"type": "audio", "data": data}
if seg.type == "video":
return {"type": "video", "data": data}
return {"type": seg.type, "data": data}
if len(segments) == 1:
seg = segments[0]
if seg.type == "text":
return {"type": "text", "text": seg.data.get("text", "")}
elif seg.type == "image":
return {"type": "image", "url": seg.data.get("file", "")}
elif seg.type == "record":
return {"type": "audio", "url": seg.data.get("file", "")}
elif seg.type == "video":
return {"type": "video", "url": seg.data.get("file", "")}
# 多个段或未知类型 - 合并为文本
text_parts = []
for seg in segments:
if seg.type == "text":
text_parts.append(seg.data.get("text", ""))
elif seg.type == "image":
text_parts.append("[图片]")
elif seg.type == "record":
text_parts.append("[语音]")
elif seg.type == "video":
text_parts.append("[视频]")
else:
text_parts.append(f"[{seg.type}]")
return {"type": "text", "text": "".join(text_parts)}
return _seg_to_content(segments[0])
return {"type": "collection", "data": [_seg_to_content(seg) for seg in segments]}
__all__ = ["EnvelopeConverter"]

View File

@@ -10,72 +10,54 @@ from .adapter_utils import (
AdapterTransportOptions,
AdapterBase,
BatchDispatcher,
CoreSink,
CoreMessageSink,
HttpAdapterOptions,
InProcessCoreSink,
ProcessCoreSink,
ProcessCoreSinkServer,
WebSocketLike,
WebSocketAdapterOptions,
)
from .api import MessageClient, MessageServer
from .codec import dumps_message, dumps_messages, loads_message, loads_messages
from .message_models import BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, TemplateInfo, UserInfo
from .builder import MessageBuilder
from .router import RouteConfig, Router, TargetConfig
from .runtime import MessageProcessingError, MessageRoute, MessageRuntime
from .runtime import MessageProcessingError, MessageRoute, MessageRuntime, Middleware
from .types import (
AudioContent,
ChannelInfo,
CommandContent,
Content,
ContentType,
EventContent,
EventType,
FileContent,
ImageContent,
FormatInfoPayload,
GroupInfoPayload,
MessageDirection,
MessageEnvelope,
Role,
SenderInfo,
SystemContent,
TextContent,
VideoContent,
MessageInfoPayload,
SegPayload,
TemplateInfoPayload,
UserInfoPayload,
)
__all__ = [
# TypedDict model
"AudioContent",
"ChannelInfo",
"CommandContent",
"Content",
"ContentType",
"EventContent",
"EventType",
"FileContent",
"ImageContent",
"MessageDirection",
"MessageEnvelope",
"Role",
"SenderInfo",
"SystemContent",
"TextContent",
"VideoContent",
"SegPayload",
"UserInfoPayload",
"GroupInfoPayload",
"FormatInfoPayload",
"TemplateInfoPayload",
"MessageInfoPayload",
# Codec helpers
"codec",
"dumps_message",
"dumps_messages",
"loads_message",
"loads_messages",
"MessageBuilder",
# Runtime / routing
"MessageRoute",
"MessageRuntime",
"MessageProcessingError",
# Message dataclasses
"Seg",
"GroupInfo",
"UserInfo",
"FormatInfo",
"TemplateInfo",
"BaseMessageInfo",
"MessageBase",
"Middleware",
# Server/client/router
"MessageServer",
"MessageClient",
@@ -86,8 +68,11 @@ __all__ = [
"AdapterTransportOptions",
"AdapterBase",
"BatchDispatcher",
"CoreSink",
"CoreMessageSink",
"InProcessCoreSink",
"ProcessCoreSink",
"ProcessCoreSinkServer",
"WebSocketLike",
"WebSocketAdapterOptions",
"HttpAdapterOptions",

View File

@@ -2,6 +2,8 @@ from __future__ import annotations
import asyncio
import contextlib
import logging
import multiprocessing as mp
from dataclasses import dataclass
from typing import Any, AsyncIterator, Awaitable, Callable, Protocol
@@ -11,6 +13,11 @@ import websockets
from .types import MessageEnvelope
logger = logging.getLogger("mofox_bus.adapter")
OutgoingHandler = Callable[[MessageEnvelope], Awaitable[None]]
class CoreMessageSink(Protocol):
async def send(self, message: MessageEnvelope) -> None: ...
@@ -18,6 +25,22 @@ class CoreMessageSink(Protocol):
async def send_many(self, messages: list[MessageEnvelope]) -> None: ... # pragma: no cover - optional
class CoreSink(CoreMessageSink, Protocol):
"""
双向 CoreSink 协议:
- send/send_many: 适配器 → 核心incoming
- push_outgoing: 核心 → 适配器outgoing
"""
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None: ...
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None: ...
async def push_outgoing(self, envelope: MessageEnvelope) -> None: ...
async def close(self) -> None: ... # pragma: no cover - lifecycle hook
class WebSocketLike(Protocol):
def __aiter__(self) -> AsyncIterator[str | bytes]: ...
@@ -56,7 +79,7 @@ class AdapterBase:
platform: str = "unknown"
def __init__(self, core_sink: CoreMessageSink, transport: AdapterTransportOptions = None):
def __init__(self, core_sink: CoreSink, transport: AdapterTransportOptions = None):
"""
Args:
core_sink: 核心消息入口,通常是 InProcessCoreSink 或自定义客户端。
@@ -70,14 +93,31 @@ class AdapterBase:
self._http_site: aiohttp_web.BaseSite | None = None
async def start(self) -> None:
"""根据配置自动启动 WS/HTTP 监听"""
"""启动适配器的传输层监听(如果配置了传输选项)"""
if hasattr(self.core_sink, "set_outgoing_handler"):
try:
self.core_sink.set_outgoing_handler(self._on_outgoing_from_core)
except Exception:
logger.exception("Failed to register outgoing handler on core sink")
if isinstance(self._transport_config, WebSocketAdapterOptions):
await self._start_ws_transport(self._transport_config)
elif isinstance(self._transport_config, HttpAdapterOptions):
await self._start_http_transport(self._transport_config)
async def stop(self) -> None:
"""停止自动管理的传输层"""
"""停止适配器的传输层监听(如果配置了传输选项)"""
remove = getattr(self.core_sink, "remove_outgoing_handler", None)
if callable(remove):
try:
remove(self._on_outgoing_from_core)
except Exception:
logger.exception("Failed to detach outgoing handler on core sink")
elif hasattr(self.core_sink, "set_outgoing_handler"):
try:
self.core_sink.set_outgoing_handler(None) # type: ignore[arg-type]
except Exception:
logger.exception("Failed to detach outgoing handler on core sink")
if self._ws_task:
self._ws_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
@@ -95,12 +135,12 @@ class AdapterBase:
async def on_platform_message(self, raw: Any) -> None:
"""处理平台下发的单条消息并交给核心。"""
envelope = self.from_platform_message(raw)
envelope = await _maybe_await(self.from_platform_message(raw))
await self.core_sink.send(envelope)
async def on_platform_messages(self, raw_messages: list[Any]) -> None:
"""批量推送入口,内部自动批量或逐条送入核心。"""
envelopes = [self.from_platform_message(raw) for raw in raw_messages]
envelopes = [await _maybe_await(self.from_platform_message(raw)) for raw in raw_messages]
await _send_many(self.core_sink, envelopes)
async def send_to_platform(self, envelope: MessageEnvelope) -> None:
@@ -112,7 +152,14 @@ class AdapterBase:
for env in envelopes:
await self._send_platform_message(env)
def from_platform_message(self, raw: Any) -> MessageEnvelope:
async def _on_outgoing_from_core(self, envelope: MessageEnvelope) -> None:
"""核心生成 outgoing envelope 时的内部处理逻辑"""
platform = envelope.get("platform") or envelope.get("message_info", {}).get("platform")
if platform and platform != getattr(self, "platform", None):
return
await self._send_platform_message(envelope)
def from_platform_message(self, raw: Any) -> MessageEnvelope | Awaitable[MessageEnvelope]:
"""子类必须实现:将平台原始结构转换为统一 MessageEnvelope。"""
raise NotImplementedError
@@ -175,13 +222,22 @@ class AdapterBase:
return orjson.dumps({"type": "send", "payload": envelope})
class InProcessCoreSink:
class InProcessCoreSink(CoreSink):
"""
简单的进程内 sink实现 CoreMessageSink 协议。
进程内核心消息 sink实现 CoreSink 协议。
"""
def __init__(self, handler: Callable[[MessageEnvelope], Awaitable[None]]):
self._handler = handler
self._outgoing_handlers: set[OutgoingHandler] = set()
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None:
if handler is None:
return
self._outgoing_handlers.add(handler)
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None:
self._outgoing_handlers.discard(handler)
async def send(self, message: MessageEnvelope) -> None:
await self._handler(message)
@@ -190,6 +246,140 @@ class InProcessCoreSink:
for message in messages:
await self._handler(message)
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
if not self._outgoing_handlers:
logger.debug("Outgoing envelope dropped: no handler registered")
return
for callback in list(self._outgoing_handlers):
await callback(envelope)
async def close(self) -> None: # pragma: no cover - symmetry
self._outgoing_handlers.clear()
class ProcessCoreSink(CoreSink):
"""
进程间核心消息 sink实现 CoreSink 协议,使用 multiprocessing.Queue 初始化
"""
_CONTROL_STOP = {"__core_sink_control__": "stop"}
def __init__(self, *, to_core_queue: mp.Queue, from_core_queue: mp.Queue) -> None:
self._to_core_queue = to_core_queue
self._from_core_queue = from_core_queue
self._outgoing_handler: OutgoingHandler | None = None
self._closed = False
self._listener_task: asyncio.Task | None = None
self._loop = asyncio.get_event_loop()
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None:
self._outgoing_handler = handler
if handler is not None and (self._listener_task is None or self._listener_task.done()):
self._listener_task = self._loop.create_task(self._listen_from_core())
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None:
if self._outgoing_handler is handler:
self._outgoing_handler = None
if self._listener_task and not self._listener_task.done():
self._listener_task.cancel()
async def send(self, message: MessageEnvelope) -> None:
await asyncio.to_thread(self._to_core_queue.put, {"kind": "incoming", "payload": message})
async def send_many(self, messages: list[MessageEnvelope]) -> None:
for message in messages:
await self.send(message)
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
logger.debug("ProcessCoreSink.push_outgoing called in child; ignored")
async def close(self) -> None:
if self._closed:
return
self._closed = True
await asyncio.to_thread(self._from_core_queue.put, self._CONTROL_STOP)
if self._listener_task:
self._listener_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._listener_task
self._listener_task = None
async def _listen_from_core(self) -> None:
while not self._closed:
try:
item = await asyncio.to_thread(self._from_core_queue.get)
except asyncio.CancelledError:
break
if item == self._CONTROL_STOP:
break
if isinstance(item, dict) and item.get("kind") == "outgoing":
envelope = item.get("payload")
if self._outgoing_handler:
try:
await self._outgoing_handler(envelope)
except Exception: # pragma: no cover
logger.exception("Failed to handle outgoing envelope in ProcessCoreSink")
else:
logger.debug("ProcessCoreSink received unknown payload: %r", item)
class ProcessCoreSinkServer:
"""
进程间核心消息 sink 服务器,实现 CoreSink 协议,使用 multiprocessing.Queue 初始化。
- 将传入的 incoming 消息转发给指定的 handler
- 将接收到的 outgoing 消息放入 outgoing 队列
"""
def __init__(
self,
*,
incoming_queue: mp.Queue,
outgoing_queue: mp.Queue,
core_handler: Callable[[MessageEnvelope], Awaitable[None]],
name: str | None = None,
) -> None:
self._incoming_queue = incoming_queue
self._outgoing_queue = outgoing_queue
self._core_handler = core_handler
self._task: asyncio.Task | None = None
self._closed = False
self._name = name or "adapter"
def start(self) -> None:
if self._task is None or self._task.done():
self._task = asyncio.create_task(self._consume_incoming())
async def _consume_incoming(self) -> None:
while not self._closed:
try:
item = await asyncio.to_thread(self._incoming_queue.get)
except asyncio.CancelledError:
break
if isinstance(item, dict) and item.get("__core_sink_control__") == "stop":
break
if isinstance(item, dict) and item.get("kind") == "incoming":
envelope = item.get("payload")
try:
await self._core_handler(envelope)
except Exception: # pragma: no cover
logger.exception("Failed to dispatch incoming envelope from %s", self._name)
else:
logger.debug("ProcessCoreSinkServer ignored unknown payload from %s: %r", self._name, item)
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
await asyncio.to_thread(self._outgoing_queue.put, {"kind": "outgoing", "payload": envelope})
async def close(self) -> None:
if self._closed:
return
self._closed = True
await asyncio.to_thread(self._incoming_queue.put, {"__core_sink_control__": "stop"})
await asyncio.to_thread(self._outgoing_queue.put, ProcessCoreSink._CONTROL_STOP)
if self._task:
self._task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._task
self._task = None
async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) -> None:
send_many = getattr(sink, "send_many", None)
@@ -200,11 +390,19 @@ async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) ->
await sink.send(env)
async def _maybe_await(result: Any) -> Any:
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
return await result
return result
class BatchDispatcher:
"""
将 send 操作合并为批量发送,适合网络 IO 密集场景
批量消息分发器,负责将消息批量发送到核心 sink
"""
_STOP = object()
def __init__(
self,
sink: CoreMessageSink,
@@ -215,56 +413,79 @@ class BatchDispatcher:
self._sink = sink
self._max_batch_size = max_batch_size
self._flush_interval = flush_interval
self._buffer: list[MessageEnvelope] = []
self._lock = asyncio.Lock()
self._flush_task: asyncio.Task | None = None
self._queue: asyncio.Queue[MessageEnvelope | object] = asyncio.Queue()
self._worker: asyncio.Task | None = None
self._closed = False
async def add(self, message: MessageEnvelope) -> None:
async with self._lock:
if self._closed:
raise RuntimeError("Dispatcher closed")
self._buffer.append(message)
self._ensure_timer()
if len(self._buffer) >= self._max_batch_size:
await self._flush_locked()
if self._closed:
raise RuntimeError("Dispatcher closed")
self._ensure_worker()
await self._queue.put(message)
async def close(self) -> None:
async with self._lock:
self._closed = True
await self._flush_locked()
if self._flush_task:
self._flush_task.cancel()
self._flush_task = None
def _ensure_timer(self) -> None:
if self._flush_task is not None and not self._flush_task.done():
if self._closed:
return
loop = asyncio.get_running_loop()
self._flush_task = loop.create_task(self._flush_loop())
self._closed = True
self._ensure_worker()
await self._queue.put(self._STOP)
if self._worker:
await self._worker
self._worker = None
async def _flush_loop(self) -> None:
def _ensure_worker(self) -> None:
if self._worker is not None and not self._worker.done():
return
self._worker = asyncio.create_task(self._worker_loop())
async def _worker_loop(self) -> None:
buffer: list[MessageEnvelope] = []
try:
await asyncio.sleep(self._flush_interval)
async with self._lock:
await self._flush_locked()
except asyncio.CancelledError: # pragma: no cover - timer cancellation
pass
while True:
try:
item = await asyncio.wait_for(self._queue.get(), timeout=self._flush_interval)
except asyncio.TimeoutError:
item = None
async def _flush_locked(self) -> None:
if not self._buffer:
if item is self._STOP:
await self._flush_buffer(buffer)
return
if item is not None:
buffer.append(item) # type: ignore[arg-type]
while len(buffer) < self._max_batch_size:
try:
item = self._queue.get_nowait()
except asyncio.QueueEmpty:
break
if item is self._STOP:
await self._flush_buffer(buffer)
return
buffer.append(item) # type: ignore[arg-type]
if buffer and (len(buffer) >= self._max_batch_size or item is None):
await self._flush_buffer(buffer)
except asyncio.CancelledError: # pragma: no cover - worker cancellation
if buffer:
await self._flush_buffer(buffer)
async def _flush_buffer(self, buffer: list[MessageEnvelope]) -> None:
if not buffer:
return
payload = list(self._buffer)
self._buffer.clear()
await self._sink.send_many(payload)
payload = list(buffer)
buffer.clear()
await _send_many(self._sink, payload)
__all__ = [
"AdapterTransportOptions",
"AdapterBase",
"BatchDispatcher",
"CoreSink",
"CoreMessageSink",
"HttpAdapterOptions",
"InProcessCoreSink",
"ProcessCoreSink",
"ProcessCoreSinkServer",
"WebSocketLike",
"WebSocketAdapterOptions",
]

View File

@@ -11,10 +11,36 @@ import orjson
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from .message_models import MessageBase
MessagePayload = Dict[str, Any]
MessageHandler = Callable[[MessagePayload], Awaitable[None] | None]
DisconnectCallback = Callable[[str, str], Awaitable[None] | None]
def _attach_raw_bytes(payload: Any, raw_bytes: bytes) -> Any:
if isinstance(payload, dict):
payload.setdefault("raw_bytes", raw_bytes)
elif isinstance(payload, list):
for item in payload:
if isinstance(item, dict):
item.setdefault("raw_bytes", raw_bytes)
return payload
def _encode_for_ws_send(message: Any, *, use_raw_bytes: bool = False) -> tuple[str | bytes, bool]:
if isinstance(message, (bytes, bytearray)):
return bytes(message), True
if use_raw_bytes and isinstance(message, dict):
raw = message.get("raw_bytes")
if isinstance(raw, (bytes, bytearray)):
return bytes(raw), True
payload = message
if isinstance(payload, dict) and "raw_bytes" in payload and not use_raw_bytes:
payload = {k: v for k, v in payload.items() if k != "raw_bytes"}
data = orjson.dumps(payload)
if use_raw_bytes:
return data, True
return data.decode("utf-8"), False
class BaseMessageHandler:
@@ -60,6 +86,8 @@ class MessageServer(BaseMessageHandler):
mode: Literal["ws", "tcp"] = "ws",
custom_logger: logging.Logger | None = None,
enable_custom_uvicorn_logger: bool = False,
queue_maxsize: int = 1000,
worker_count: int = 1,
) -> None:
super().__init__()
if mode != "ws":
@@ -80,6 +108,9 @@ class MessageServer(BaseMessageHandler):
self._conn_lock = asyncio.Lock()
self._server: uvicorn.Server | None = None
self._running = False
self._message_queue: asyncio.Queue[MessagePayload] = asyncio.Queue(maxsize=queue_maxsize)
self._worker_count = max(1, worker_count)
self._worker_tasks: list[asyncio.Task] = []
self._setup_routes()
def _setup_routes(self) -> None:
@@ -97,21 +128,22 @@ class MessageServer(BaseMessageHandler):
while True:
msg = await websocket.receive()
if msg["type"] == "websocket.receive":
data = msg.get("text")
if data is None and msg.get("bytes") is not None:
data = msg["bytes"].decode("utf-8")
if not data:
raw_bytes = msg.get("bytes")
if raw_bytes is None and msg.get("text") is not None:
raw_bytes = msg["text"].encode("utf-8")
if not raw_bytes:
continue
try:
payload = orjson.loads(data)
payload = orjson.loads(raw_bytes)
except orjson.JSONDecodeError:
logging.getLogger("mofox_bus.server").warning("Invalid JSON payload")
continue
payload = _attach_raw_bytes(payload, raw_bytes)
if isinstance(payload, list):
for item in payload:
await self.process_message(item)
await self._enqueue_message(item)
else:
await self.process_message(payload)
await self._enqueue_message(payload)
elif msg["type"] == "websocket.disconnect":
break
except WebSocketDisconnect:
@@ -119,6 +151,49 @@ class MessageServer(BaseMessageHandler):
finally:
await self._remove_connection(websocket, platform)
async def _enqueue_message(self, payload: MessagePayload) -> None:
if not self._worker_tasks:
self._start_workers()
try:
self._message_queue.put_nowait(payload)
except asyncio.QueueFull:
logging.getLogger("mofox_bus.server").warning("Message queue full, dropping message")
def _start_workers(self) -> None:
if self._worker_tasks:
return
self._running = True
for _ in range(self._worker_count):
task = asyncio.create_task(self._consumer_worker())
self._worker_tasks.append(task)
async def _stop_workers(self) -> None:
if not self._worker_tasks:
return
self._running = False
for task in self._worker_tasks:
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await asyncio.gather(*self._worker_tasks, return_exceptions=True)
self._worker_tasks.clear()
while not self._message_queue.empty():
with contextlib.suppress(asyncio.QueueEmpty):
self._message_queue.get_nowait()
self._message_queue.task_done()
async def _consumer_worker(self) -> None:
while self._running:
try:
payload = await self._message_queue.get()
except asyncio.CancelledError:
break
try:
await self.process_message(payload)
except Exception: # pragma: no cover - best effort logging
logging.getLogger("mofox_bus.server").exception("Error processing message")
finally:
self._message_queue.task_done()
async def verify_token(self, token: str | None) -> bool:
if not self._enable_token:
return True
@@ -145,33 +220,45 @@ class MessageServer(BaseMessageHandler):
if platform and self._platform_connections.get(platform) is websocket:
del self._platform_connections[platform]
async def broadcast_message(self, message: MessagePayload) -> None:
data = orjson.dumps(message).decode("utf-8")
async def broadcast_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> None:
payload: MessagePayload | bytes = message
data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes)
async with self._conn_lock:
targets = list(self._connections)
for ws in targets:
await ws.send_text(data)
if is_binary:
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
else:
await ws.send_text(data if isinstance(data, str) else data.decode("utf-8"))
async def broadcast_to_platform(self, platform: str, message: MessagePayload) -> None:
async def broadcast_to_platform(
self, platform: str, message: MessagePayload | bytes, *, use_raw_bytes: bool = False
) -> None:
ws = self._platform_connections.get(platform)
if ws is None:
raise RuntimeError(f"No active connection for platform {platform}")
await ws.send_text(orjson.dumps(message).decode("utf-8"))
payload: MessagePayload | bytes = message.to_dict() if isinstance(message, MessageBase) else message
data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes)
if is_binary:
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
else:
await ws.send_text(data if isinstance(data, str) else data.decode("utf-8"))
async def send_message(self, message: MessageBase | MessagePayload) -> None:
payload = message.to_dict() if isinstance(message, MessageBase) else message
platform = payload.get("message_info", {}).get("platform")
async def send_message(
self, message: MessagePayload, *, prefer_raw_bytes: bool = False
) -> None:
platform = message.get("message_info", {}).get("platform")
if not platform:
raise ValueError("message_info.platform is required to route the message")
await self.broadcast_to_platform(platform, payload)
await self.broadcast_to_platform(platform, message, use_raw_bytes=prefer_raw_bytes)
def run_sync(self) -> None:
if not self._own_app:
return
asyncio.run(self.run())
async def run(self) -> None:
self._running = True
self._start_workers()
if not self._own_app:
return
config = uvicorn.Config(
@@ -191,6 +278,7 @@ class MessageServer(BaseMessageHandler):
async def stop(self) -> None:
self._running = False
await self._stop_workers()
if self._server:
self._server.should_exit = True
await self._server.shutdown()
@@ -217,7 +305,13 @@ class MessageClient(BaseMessageHandler):
WebSocket 消息客户端,实现双向传输。
"""
def __init__(self, mode: Literal["ws", "tcp"] = "ws") -> None:
def __init__(
self,
mode: Literal["ws", "tcp"] = "ws",
*,
reconnect_interval: float = 5.0,
logger: logging.Logger | None = None,
) -> None:
super().__init__()
if mode != "ws":
raise NotImplementedError("Only WebSocket mode is supported in mofox_bus")
@@ -230,6 +324,9 @@ class MessageClient(BaseMessageHandler):
self._token: str | None = None
self._ssl_verify: str | None = None
self._closed = False
self._on_disconnect: DisconnectCallback | None = None
self._reconnect_interval = reconnect_interval
self._logger = logger or logging.getLogger("mofox_bus.client")
async def connect(
self,
@@ -243,8 +340,12 @@ class MessageClient(BaseMessageHandler):
self._platform = platform
self._token = token
self._ssl_verify = ssl_verify
self._closed = False
await self._establish_connection()
def set_disconnect_callback(self, callback: DisconnectCallback) -> None:
self._on_disconnect = callback
async def _establish_connection(self) -> None:
if self._session is None:
self._session = aiohttp.ClientSession()
@@ -257,17 +358,21 @@ class MessageClient(BaseMessageHandler):
self._ws = await self._session.ws_connect(self._url, headers=headers, ssl=ssl_context)
self._receive_task = asyncio.create_task(self._receive_loop())
async def _connect_once(self) -> None:
await self._establish_connection()
async def _receive_loop(self) -> None:
assert self._ws is not None
try:
async for msg in self._ws:
if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
data = msg.data if isinstance(msg.data, str) else msg.data.decode("utf-8")
raw_bytes = msg.data if isinstance(msg.data, (bytes, bytearray)) else msg.data.encode("utf-8")
try:
payload = orjson.loads(data)
payload = orjson.loads(raw_bytes)
except orjson.JSONDecodeError:
logging.getLogger("mofox_bus.client").warning("Invalid JSON payload")
continue
payload = _attach_raw_bytes(payload, raw_bytes)
if isinstance(payload, list):
for item in payload:
await self.process_message(item)
@@ -278,23 +383,33 @@ class MessageClient(BaseMessageHandler):
except asyncio.CancelledError: # pragma: no cover - cancellation path
pass
finally:
if not self._closed:
await self._notify_disconnect("websocket disconnected")
await self._reconnect()
if self._ws:
await self._ws.close()
self._ws = None
async def run(self) -> None:
if self._receive_task is None:
await self._establish_connection()
try:
if self._receive_task:
await self._receive_task
except asyncio.CancelledError: # pragma: no cover - cancellation path
pass
self._closed = False
while not self._closed:
if self._receive_task is None:
await self._establish_connection()
task = self._receive_task
if task is None:
break
try:
await task
except asyncio.CancelledError: # pragma: no cover - cancellation path
raise
async def send_message(self, message: MessagePayload) -> bool:
if self._ws is None or self._ws.closed:
raise RuntimeError("WebSocket connection is not established")
await self._ws.send_str(orjson.dumps(message).decode("utf-8"))
async def send_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> bool:
ws = await self._ensure_ws()
data, is_binary = _encode_for_ws_send(message, use_raw_bytes=use_raw_bytes)
if is_binary:
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
else:
await ws.send_str(data if isinstance(data, str) else data.decode("utf-8"))
return True
def is_connected(self) -> bool:
@@ -313,6 +428,42 @@ class MessageClient(BaseMessageHandler):
await self._session.close()
self._session = None
async def _notify_disconnect(self, reason: str) -> None:
if self._on_disconnect is None:
return
try:
result = self._on_disconnect(self._platform, reason)
if asyncio.iscoroutine(result):
await result
except Exception: # pragma: no cover - best effort notification
logging.getLogger("mofox_bus.client").exception("Disconnect callback failed")
async def _reconnect(self) -> None:
self._logger.info("WebSocket disconnected, retrying in %.1fs", self._reconnect_interval)
await asyncio.sleep(self._reconnect_interval)
await self._connect_once()
async def _ensure_session(self) -> aiohttp.ClientSession:
if self._session is None:
self._session = aiohttp.ClientSession()
return self._session
async def _ensure_ws(self) -> aiohttp.ClientWebSocketResponse:
if self._ws is None or self._ws.closed:
await self._connect_once()
assert self._ws is not None
return self._ws
async def __aenter__(self) -> "MessageClient":
if not self._url or not self._platform:
raise RuntimeError("connect() must be called before using MessageClient as a context manager")
await self._ensure_session()
await self._ensure_ws()
return self
async def __aexit__(self, exc_type, exc, tb) -> None:
await self.stop()
def _self_websocket(app: FastAPI, path: str):
"""

110
src/mofox_bus/builder.py Normal file
View File

@@ -0,0 +1,110 @@
from __future__ import annotations
import time
import uuid
from typing import Any, Dict, List
from .types import GroupInfoPayload, MessageEnvelope, MessageInfoPayload, SegPayload, UserInfoPayload
class MessageBuilder:
"""
Fluent helper to build MessageEnvelope safely with type hints.
Example:
msg = (
MessageBuilder()
.text("Hello")
.image("http://example.com/1.png")
.to_user("123", platform="qq")
.build()
)
"""
def __init__(self) -> None:
self._direction: str = "outgoing"
self._message_info: MessageInfoPayload = {}
self._segments: List[SegPayload] = []
self._metadata: Dict[str, Any] | None = None
self._timestamp_ms: int | None = None
self._message_id: str | None = None
def direction(self, value: str) -> "MessageBuilder":
self._direction = value
return self
def message_id(self, value: str) -> "MessageBuilder":
self._message_id = value
return self
def timestamp_ms(self, value: int | None = None) -> "MessageBuilder":
self._timestamp_ms = value or int(time.time() * 1000)
return self
def metadata(self, value: Dict[str, Any]) -> "MessageBuilder":
self._metadata = value
return self
def platform(self, value: str) -> "MessageBuilder":
self._message_info["platform"] = value
return self
def from_user(self, user_id: str, *, platform: str | None = None, nickname: str | None = None) -> "MessageBuilder":
if platform:
self.platform(platform)
user_info: UserInfoPayload = {"user_id": user_id}
if nickname:
user_info["user_nickname"] = nickname
self._message_info["user_info"] = user_info
return self
def from_group(self, group_id: str, *, platform: str | None = None, name: str | None = None) -> "MessageBuilder":
if platform:
self.platform(platform)
group_info: GroupInfoPayload = {"group_id": group_id}
if name:
group_info["group_name"] = name
self._message_info["group_info"] = group_info
return self
def seg(self, type_: str, data: Any) -> "MessageBuilder":
self._segments.append({"type": type_, "data": data})
return self
def text(self, content: str) -> "MessageBuilder":
return self.seg("text", content)
def image(self, url: str) -> "MessageBuilder":
return self.seg("image", url)
def reply(self, target_message_id: str) -> "MessageBuilder":
return self.seg("reply", target_message_id)
def raw_segment(self, segment: SegPayload) -> "MessageBuilder":
self._segments.append(segment)
return self
def build(self) -> MessageEnvelope:
# message_info defaults
if not self._segments:
raise ValueError("message_segment is required, add at least one segment before build()")
if self._message_id is None:
self._message_id = str(uuid.uuid4())
info = dict(self._message_info)
info.setdefault("message_id", self._message_id)
info.setdefault("time", time.time())
segments = [seg.copy() if isinstance(seg, dict) else seg for seg in self._segments]
envelope: MessageEnvelope = {
"direction": self._direction, # type: ignore[assignment]
"message_info": info,
"message_segment": segments[0] if len(segments) == 1 else list(segments),
}
if self._metadata is not None:
envelope["metadata"] = self._metadata
if self._timestamp_ms is not None:
envelope["timestamp_ms"] = self._timestamp_ms
return envelope
__all__ = ["MessageBuilder"]

View File

@@ -27,24 +27,23 @@ def _loads(data: bytes) -> Dict[str, Any]:
def dumps_message(msg: MessageEnvelope) -> bytes:
"""
将单条 MessageEnvelope 序列化为 JSON bytes。
将单条消息序列化为 JSON bytes。
"""
if "schema_version" not in msg:
msg["schema_version"] = DEFAULT_SCHEMA_VERSION
return _dumps(msg)
sanitized = _strip_raw_bytes(msg)
if "schema_version" not in sanitized:
sanitized["schema_version"] = DEFAULT_SCHEMA_VERSION
return _dumps(sanitized)
def dumps_messages(messages: Iterable[MessageEnvelope]) -> bytes:
"""
多条消息批量序列化,以提升吞吐
批量消息序列化为 JSON bytes
"""
payload = {
"schema_version": DEFAULT_SCHEMA_VERSION,
"items": list(messages),
"items": [_strip_raw_bytes(msg) for msg in messages],
}
return _dumps(payload)
def loads_message(data: bytes | str) -> MessageEnvelope:
"""
反序列化单条消息。
@@ -78,6 +77,14 @@ def _upgrade_schema_if_needed(obj: Dict[str, Any]) -> MessageEnvelope:
raise ValueError(f"Unsupported schema_version={version}")
def _strip_raw_bytes(msg: MessageEnvelope) -> MessageEnvelope:
if isinstance(msg, dict) and "raw_bytes" in msg:
new_msg = dict(msg)
new_msg.pop("raw_bytes", None)
return new_msg # type: ignore[return-value]
return msg
__all__ = [
"DEFAULT_SCHEMA_VERSION",
"dumps_message",

View File

@@ -1,189 +0,0 @@
from __future__ import annotations
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional
@dataclass
class Seg:
"""
消息段,表示一段文本/图片/结构化内容。
"""
type: str
data: Any
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Seg":
seg_type = data.get("type")
seg_data = data.get("data")
if seg_type == "seglist" and isinstance(seg_data, list):
seg_data = [Seg.from_dict(item) for item in seg_data]
return cls(type=seg_type, data=seg_data)
def to_dict(self) -> Dict[str, Any]:
if self.type == "seglist" and isinstance(self.data, list):
payload = [seg.to_dict() if isinstance(seg, Seg) else seg for seg in self.data]
else:
payload = self.data
return {"type": self.type, "data": payload}
@dataclass
class GroupInfo:
platform: Optional[str] = None
group_id: Optional[str] = None
group_name: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Optional["GroupInfo"]:
if not data or data.get("group_id") is None:
return None
return cls(
platform=data.get("platform"),
group_id=data.get("group_id"),
group_name=data.get("group_name"),
)
@dataclass
class UserInfo:
platform: Optional[str] = None
user_id: Optional[str] = None
user_nickname: Optional[str] = None
user_cardname: Optional[str] = None
user_avatar: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "UserInfo":
return cls(
platform=data.get("platform"),
user_id=data.get("user_id"),
user_nickname=data.get("user_nickname"),
user_cardname=data.get("user_cardname"),
user_avatar=data.get("user_avatar"),
)
@dataclass
class FormatInfo:
content_format: Optional[List[str]] = None
accept_format: Optional[List[str]] = None
def to_dict(self) -> Dict[str, Any]:
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Optional["FormatInfo"]:
if not data:
return None
return cls(
content_format=data.get("content_format"),
accept_format=data.get("accept_format"),
)
@dataclass
class TemplateInfo:
template_items: Optional[Dict[str, str]] = None
template_name: Optional[Dict[str, str]] = None
template_default: bool = True
def to_dict(self) -> Dict[str, Any]:
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Optional["TemplateInfo"]:
if not data:
return None
return cls(
template_items=data.get("template_items"),
template_name=data.get("template_name"),
template_default=data.get("template_default", True),
)
@dataclass
class BaseMessageInfo:
platform: Optional[str] = None
message_id: Optional[str] = None
time: Optional[float] = None
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
format_info: Optional[FormatInfo] = None
template_info: Optional[TemplateInfo] = None
additional_config: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]:
result: Dict[str, Any] = {}
if self.platform is not None:
result["platform"] = self.platform
if self.message_id is not None:
result["message_id"] = self.message_id
if self.time is not None:
result["time"] = self.time
if self.additional_config is not None:
result["additional_config"] = self.additional_config
if self.group_info is not None:
result["group_info"] = self.group_info.to_dict()
if self.user_info is not None:
result["user_info"] = self.user_info.to_dict()
if self.format_info is not None:
result["format_info"] = self.format_info.to_dict()
if self.template_info is not None:
result["template_info"] = self.template_info.to_dict()
return result
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "BaseMessageInfo":
return cls(
platform=data.get("platform"),
message_id=data.get("message_id"),
time=data.get("time"),
additional_config=data.get("additional_config"),
group_info=GroupInfo.from_dict(data.get("group_info", {})),
user_info=UserInfo.from_dict(data.get("user_info", {})),
format_info=FormatInfo.from_dict(data.get("format_info", {})),
template_info=TemplateInfo.from_dict(data.get("template_info", {})),
)
@dataclass
class MessageBase:
message_info: BaseMessageInfo
message_segment: Seg
raw_message: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"message_info": self.message_info.to_dict(),
"message_segment": self.message_segment.to_dict(),
}
if self.raw_message is not None:
payload["raw_message"] = self.raw_message
return payload
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MessageBase":
return cls(
message_info=BaseMessageInfo.from_dict(data.get("message_info", {})),
message_segment=Seg.from_dict(data.get("message_segment", {})),
raw_message=data.get("raw_message"),
)
__all__ = [
"BaseMessageInfo",
"FormatInfo",
"GroupInfo",
"MessageBase",
"Seg",
"TemplateInfo",
"UserInfo",
]

View File

@@ -7,7 +7,7 @@ from dataclasses import asdict, dataclass
from typing import Callable, Dict, Optional
from .api import MessageClient
from .message_models import MessageBase
from .types import MessageEnvelope
logger = logging.getLogger("mofox_bus.router")
@@ -55,7 +55,7 @@ class Router:
self.handlers: list[Callable[[Dict], None]] = []
self._running = False
self._client_tasks: Dict[str, asyncio.Task] = {}
self._monitor_task: asyncio.Task | None = None
self._stop_event: asyncio.Event | None = None
async def connect(self, platform: str) -> None:
if platform not in self.config.route_config:
@@ -65,6 +65,7 @@ class Router:
if mode != "ws":
raise NotImplementedError("TCP mode is not implemented yet")
client = MessageClient(mode="ws")
client.set_disconnect_callback(self._handle_client_disconnect)
await client.connect(
url=target.url,
platform=platform,
@@ -75,7 +76,7 @@ class Router:
client.register_message_handler(handler)
self.clients[platform] = client
if self._running:
self._client_tasks[platform] = asyncio.create_task(client.run())
self._start_client_task(platform, client)
def register_class_handler(self, handler: Callable[[Dict], None]) -> None:
self.handlers.append(handler)
@@ -84,36 +85,18 @@ class Router:
async def run(self) -> None:
self._running = True
self._stop_event = asyncio.Event()
for platform in self.config.route_config:
if platform not in self.clients:
await self.connect(platform)
for platform, client in self.clients.items():
if platform not in self._client_tasks:
self._client_tasks[platform] = asyncio.create_task(client.run())
self._monitor_task = asyncio.create_task(self._monitor_connections())
self._start_client_task(platform, client)
try:
while self._running:
await asyncio.sleep(1)
await self._stop_event.wait()
except asyncio.CancelledError: # pragma: no cover
raise
async def _monitor_connections(self) -> None:
await asyncio.sleep(3)
while self._running:
for platform in list(self.clients.keys()):
client = self.clients.get(platform)
if client is None:
continue
if not client.is_connected():
logger.info("Detected disconnect from %s, attempting reconnect", platform)
await self._reconnect_platform(platform)
await asyncio.sleep(5)
async def _reconnect_platform(self, platform: str) -> None:
await self.remove_platform(platform)
if platform in self.config.route_config:
await self.connect(platform)
async def remove_platform(self, platform: str) -> None:
if platform in self._client_tasks:
task = self._client_tasks.pop(platform)
@@ -124,32 +107,55 @@ class Router:
if client:
await client.stop()
async def _handle_client_disconnect(self, platform: str, reason: str) -> None:
logger.info("Client for %s disconnected: %s (auto-reconnect handled by client)", platform, reason)
task = self._client_tasks.get(platform)
if task is not None and not task.done():
return
client = self.clients.get(platform)
if client and self._running:
self._start_client_task(platform, client)
async def stop(self) -> None:
self._running = False
if self._monitor_task:
self._monitor_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._monitor_task
self._monitor_task = None
if self._stop_event:
self._stop_event.set()
for platform in list(self.clients.keys()):
await self.remove_platform(platform)
self.clients.clear()
def get_target_url(self, message: MessageBase) -> Optional[str]:
platform = message.message_info.platform
def _start_client_task(self, platform: str, client: MessageClient) -> None:
task = asyncio.create_task(client.run())
task.add_done_callback(lambda t, plat=platform: asyncio.create_task(self._restart_if_needed(plat, t)))
self._client_tasks[platform] = task
async def _restart_if_needed(self, platform: str, task: asyncio.Task) -> None:
if not self._running:
return
if task.cancelled():
return
exc = task.exception()
if exc:
logger.warning("Client task for %s ended with exception: %s", platform, exc)
client = self.clients.get(platform)
if client:
self._start_client_task(platform, client)
def get_target_url(self, message: MessageEnvelope) -> Optional[str]:
platform = message.get("message_info", {}).get("platform")
if not platform:
return None
target = self.config.route_config.get(platform)
return target.url if target else None
async def send_message(self, message: MessageBase):
platform = message.message_info.platform
async def send_message(self, message: MessageEnvelope):
platform = message.get("message_info", {}).get("platform")
if not platform:
raise ValueError("message_info.platform is required")
client = self.clients.get(platform)
if client is None:
raise RuntimeError(f"No client connected for platform {platform}")
return await client.send_message(message.to_dict())
return await client.send_message(message)
async def update_config(self, config_data: Dict[str, Dict[str, str | None]]) -> None:
new_config = RouteConfig.from_dict(config_data)

View File

@@ -1,9 +1,10 @@
from __future__ import annotations
import asyncio
import inspect
import threading
from dataclasses import dataclass
from typing import Awaitable, Callable, Iterable, List
from typing import Awaitable, Callable, Dict, Iterable, List, Protocol
from .types import MessageEnvelope
@@ -12,6 +13,11 @@ ErrorHook = Callable[[MessageEnvelope, BaseException], Awaitable[None] | None]
Predicate = Callable[[MessageEnvelope], bool | Awaitable[bool]]
MessageHandler = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None] | MessageEnvelope | None]
BatchHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None] | List[MessageEnvelope] | None]
MiddlewareCallable = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None]]
class Middleware(Protocol):
async def __call__(self, message: MessageEnvelope, handler: MiddlewareCallable) -> MessageEnvelope | None: ...
class MessageProcessingError(RuntimeError):
@@ -19,7 +25,7 @@ class MessageProcessingError(RuntimeError):
def __init__(self, message: MessageEnvelope, original: BaseException):
detail = message.get("id", "<unknown>")
super().__init__(f"Failed to handle message {detail}: {original}") # pragma: no cover - str repr only
super().__init__(f"Failed to handle message {detail}: {original}")
self.message_envelope = message
self.original = original
@@ -29,6 +35,8 @@ class MessageRoute:
predicate: Predicate
handler: MessageHandler
name: str | None = None
message_type: str | None = None
event_types: set[str] | None = None
class MessageRuntime:
@@ -43,15 +51,36 @@ class MessageRuntime:
self._error_hooks: list[ErrorHook] = []
self._batch_handler: BatchHandler | None = None
self._lock = threading.RLock()
self._middlewares: list[Middleware] = []
self._type_routes: Dict[str, list[MessageRoute]] = {}
self._event_routes: Dict[str, list[MessageRoute]] = {}
def add_route(self, predicate: Predicate, handler: MessageHandler, name: str | None = None) -> None:
def add_route(
self,
predicate: Predicate,
handler: MessageHandler,
name: str | None = None,
*,
message_type: str | None = None,
event_types: Iterable[str] | None = None,
) -> None:
with self._lock:
self._routes.append(MessageRoute(predicate=predicate, handler=handler, name=name))
route = MessageRoute(
predicate=predicate,
handler=handler,
name=name,
message_type=message_type,
event_types=set(event_types) if event_types is not None else None,
)
self._routes.append(route)
if message_type:
self._type_routes.setdefault(message_type, []).append(route)
if route.event_types:
for et in route.event_types:
self._event_routes.setdefault(et, []).append(route)
def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]:
"""
装饰器写法,便于在核心逻辑中声明式注册。
"""
"""装饰器写法,便于在核心逻辑中声明式注册。"""
def decorator(func: MessageHandler) -> MessageHandler:
self.add_route(predicate, func, name=name)
@@ -59,6 +88,60 @@ class MessageRuntime:
return decorator
def on_message(
self,
*,
message_type: str | None = None,
platform: str | None = None,
predicate: Predicate | None = None,
name: str | None = None,
) -> Callable[[MessageHandler], MessageHandler]:
"""Sugar 装饰器,基于 Seg.type/platform 及可选额外谓词匹配。"""
async def combined_predicate(message: MessageEnvelope) -> bool:
if message_type is not None and _extract_segment_type(message) != message_type:
return False
if platform is not None:
info_platform = message.get("message_info", {}).get("platform")
if message.get("platform") not in (None, platform) and info_platform is None:
return False
if info_platform not in (None, platform):
return False
if predicate is None:
return True
return await _invoke_callable(predicate, message, prefer_thread=False)
def decorator(func: MessageHandler) -> MessageHandler:
self.add_route(combined_predicate, func, name=name, message_type=message_type)
return func
return decorator
def on_event(
self,
event_type: str | Iterable[str],
*,
name: str | None = None,
) -> Callable[[MessageHandler], MessageHandler]:
"""装饰器,基于 message 或 message_info.additional_config 中的 event_type 匹配。"""
allowed = {event_type} if isinstance(event_type, str) else set(event_type)
async def predicate(message: MessageEnvelope) -> bool:
current = (
message.get("event_type")
or message.get("message_info", {})
.get("additional_config", {})
.get("event_type")
)
return current in allowed
def decorator(func: MessageHandler) -> MessageHandler:
self.add_route(predicate, func, name=name, event_types=allowed)
return func
return decorator
def set_batch_handler(self, handler: BatchHandler) -> None:
self._batch_handler = handler
@@ -71,14 +154,20 @@ class MessageRuntime:
def register_error_hook(self, hook: ErrorHook) -> None:
self._error_hooks.append(hook)
def register_middleware(self, middleware: Middleware) -> None:
"""注册洋葱模型中间件,围绕处理器执行。"""
self._middlewares.append(middleware)
async def handle_message(self, message: MessageEnvelope) -> MessageEnvelope | None:
await self._run_hooks(self._before_hooks, message)
try:
route = await self._match_route(message)
if route is None:
return None
result = await _maybe_await(route.handler(message))
except Exception as exc: # pragma: no cover - tested indirectly
handler = self._wrap_with_middlewares(route.handler)
result = await handler(message)
except Exception as exc:
await self._run_error_hooks(message, exc)
raise MessageProcessingError(message, exc) from exc
await self._run_hooks(self._after_hooks, message)
@@ -89,7 +178,7 @@ class MessageRuntime:
if not batch:
return []
if self._batch_handler is not None:
result = await _maybe_await(self._batch_handler(batch))
result = await _invoke_callable(self._batch_handler, batch, prefer_thread=True)
return result or []
responses: list[MessageEnvelope] = []
for message in batch:
@@ -99,21 +188,61 @@ class MessageRuntime:
return responses
async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None:
candidates: list[MessageRoute] = []
message_type = _extract_segment_type(message)
event_type = (
message.get("event_type")
or message.get("message_info", {})
.get("additional_config", {})
.get("event_type")
)
with self._lock:
routes = list(self._routes)
for route in routes:
should_handle = await _maybe_await(route.predicate(message))
if event_type and event_type in self._event_routes:
candidates.extend(self._event_routes[event_type])
if message_type and message_type in self._type_routes:
candidates.extend(self._type_routes[message_type])
candidates.extend(self._routes)
seen: set[int] = set()
for route in candidates:
rid = id(route)
if rid in seen:
continue
seen.add(rid)
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
if should_handle:
return route
return None
async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None:
for hook in hooks:
await _maybe_await(hook(message))
coro_list = [self._call_hook(hook, message) for hook in hooks]
if coro_list:
await asyncio.gather(*coro_list)
async def _call_hook(self, hook: Hook, message: MessageEnvelope) -> None:
await _invoke_callable(hook, message, prefer_thread=True)
async def _run_error_hooks(self, message: MessageEnvelope, exc: BaseException) -> None:
for hook in self._error_hooks:
await _maybe_await(hook(message, exc))
coros = [self._call_error_hook(hook, message, exc) for hook in self._error_hooks]
if coros:
await asyncio.gather(*coros)
async def _call_error_hook(self, hook: ErrorHook, message: MessageEnvelope, exc: BaseException) -> None:
await _invoke_callable(hook, message, exc, prefer_thread=True)
def _wrap_with_middlewares(self, handler: MessageHandler) -> MiddlewareCallable:
async def base_handler(message: MessageEnvelope) -> MessageEnvelope | None:
return await _invoke_callable(handler, message, prefer_thread=True)
wrapped: MiddlewareCallable = base_handler
for middleware in reversed(self._middlewares):
current = wrapped
async def wrapper(msg: MessageEnvelope, mw=middleware, nxt=current) -> MessageEnvelope | None:
return await _invoke_callable(mw, msg, nxt, prefer_thread=False)
wrapped = wrapper
return wrapped
async def _maybe_await(result):
@@ -122,6 +251,32 @@ async def _maybe_await(result):
return result
async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False):
"""Support sync/async callables with optional thread offloading."""
if inspect.iscoroutinefunction(func):
return await func(*args)
if prefer_thread:
result = await asyncio.to_thread(func, *args)
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
return await result
return result
result = func(*args)
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
return await result
return result
def _extract_segment_type(message: MessageEnvelope) -> str | None:
seg = message.get("message_segment") or message.get("message_chain")
if isinstance(seg, dict):
return seg.get("type")
if isinstance(seg, list) and seg:
first = seg[0]
if isinstance(first, dict):
return first.get("type")
return None
__all__ = [
"BatchHandler",
"Hook",
@@ -129,5 +284,6 @@ __all__ = [
"MessageProcessingError",
"MessageRoute",
"MessageRuntime",
"Middleware",
"Predicate",
]

View File

@@ -3,160 +3,91 @@ from __future__ import annotations
from typing import Any, Dict, List, Literal, NotRequired, TypedDict
MessageDirection = Literal["incoming", "outgoing"]
Role = Literal["user", "assistant", "system", "tool", "platform"]
ContentType = Literal[
"text",
"image",
"audio",
"file",
"video",
"event",
"command",
"system",
]
EventType = Literal[
"message_created",
"message_updated",
"message_deleted",
"member_join",
"member_leave",
"typing",
"reaction_add",
"reaction_remove",
]
# ----------------------------
# maim_message 风格的 TypedDict
# ----------------------------
class TextContent(TypedDict, total=False):
type: Literal["text"]
text: str
markdown: NotRequired[bool]
entities: NotRequired[List[Dict[str, Any]]]
class SegPayload(TypedDict, total=False):
"""
对齐 maim_message.Seg 的片段定义,使用纯 dict 便于 JSON 传输。
"""
type: str
data: str | List["SegPayload"]
translated_data: NotRequired[str | List["SegPayload"]]
class ImageContent(TypedDict, total=False):
type: Literal["image"]
url: str
mime_type: NotRequired[str]
width: NotRequired[int]
height: NotRequired[int]
file_id: NotRequired[str]
class UserInfoPayload(TypedDict, total=False):
platform: NotRequired[str]
user_id: NotRequired[str]
user_nickname: NotRequired[str]
user_cardname: NotRequired[str]
user_avatar: NotRequired[str]
class FileContent(TypedDict, total=False):
type: Literal["file"]
url: str
mime_type: NotRequired[str]
file_name: NotRequired[str]
file_size: NotRequired[int]
file_id: NotRequired[str]
class GroupInfoPayload(TypedDict, total=False):
platform: NotRequired[str]
group_id: NotRequired[str]
group_name: NotRequired[str]
class AudioContent(TypedDict, total=False):
type: Literal["audio"]
url: str
mime_type: NotRequired[str]
duration_ms: NotRequired[int]
file_id: NotRequired[str]
class FormatInfoPayload(TypedDict, total=False):
content_format: NotRequired[List[str]]
accept_format: NotRequired[List[str]]
class VideoContent(TypedDict, total=False):
type: Literal["video"]
url: str
mime_type: NotRequired[str]
duration_ms: NotRequired[int]
width: NotRequired[int]
height: NotRequired[int]
file_id: NotRequired[str]
class TemplateInfoPayload(TypedDict, total=False):
template_items: NotRequired[Dict[str, str]]
template_name: NotRequired[Dict[str, str]]
template_default: NotRequired[bool]
class EventContent(TypedDict):
type: Literal["event"]
event_type: EventType
raw: Dict[str, Any]
class MessageInfoPayload(TypedDict, total=False):
platform: NotRequired[str]
message_id: NotRequired[str]
time: NotRequired[float]
group_info: NotRequired[GroupInfoPayload]
user_info: NotRequired[UserInfoPayload]
format_info: NotRequired[FormatInfoPayload]
template_info: NotRequired[TemplateInfoPayload]
additional_config: NotRequired[Dict[str, Any]]
class CommandContent(TypedDict, total=False):
type: Literal["command"]
name: str
args: Dict[str, Any]
class SystemContent(TypedDict):
type: Literal["system"]
text: str
Content = (
TextContent
| ImageContent
| FileContent
| AudioContent
| VideoContent
| EventContent
| CommandContent
| SystemContent
)
class SenderInfo(TypedDict, total=False):
user_id: str
role: Role
display_name: NotRequired[str]
avatar_url: NotRequired[str]
raw: NotRequired[Dict[str, Any]]
class ChannelInfo(TypedDict, total=False):
channel_id: str
channel_type: Literal[
"private",
"group",
"supergroup",
"channel",
"dm",
"room",
"thread",
]
title: NotRequired[str]
workspace_id: NotRequired[str]
raw: NotRequired[Dict[str, Any]]
# ----------------------------
# MessageEnvelope
# ----------------------------
class MessageEnvelope(TypedDict, total=False):
id: str
direction: MessageDirection
platform: str
timestamp_ms: int
channel: ChannelInfo
sender: SenderInfo
content: Content
conversation_id: str
thread_id: NotRequired[str]
reply_to_message_id: NotRequired[str]
correlation_id: NotRequired[str]
is_edited: NotRequired[bool]
is_ephemeral: NotRequired[bool]
raw_platform_message: NotRequired[Dict[str, Any]]
metadata: NotRequired[Dict[str, Any]]
schema_version: NotRequired[int]
"""
mofox-bus 传输层统一使用的消息信封。
- 采用 maim_message 风格message_info + message_segment。
"""
direction: MessageDirection
message_info: MessageInfoPayload
message_segment: SegPayload | List[SegPayload]
raw_message: NotRequired[Any]
raw_bytes: NotRequired[bytes]
message_chain: NotRequired[List[SegPayload]] # seglist 的直观别名
platform: NotRequired[str] # 快捷访问,等价于 message_info.platform
message_id: NotRequired[str] # 快捷访问,等价于 message_info.message_id
timestamp_ms: NotRequired[int]
correlation_id: NotRequired[str]
schema_version: NotRequired[int]
metadata: NotRequired[Dict[str, Any]]
__all__ = [
"AudioContent",
"ChannelInfo",
"CommandContent",
"Content",
"ContentType",
"EventContent",
"EventType",
"FileContent",
"ImageContent",
# maim_message style payloads
"SegPayload",
"UserInfoPayload",
"GroupInfoPayload",
"FormatInfoPayload",
"TemplateInfoPayload",
"MessageInfoPayload",
# legacy content style
"MessageDirection",
"MessageEnvelope",
"Role",
"SenderInfo",
"SystemContent",
"TextContent",
"VideoContent",
]

View File

@@ -12,7 +12,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional
from mofox_bus import AdapterBase as MoFoxAdapterBase, CoreMessageSink, MessageEnvelope
from mofox_bus import AdapterBase as MoFoxAdapterBase, CoreSink, MessageEnvelope
if TYPE_CHECKING:
from src.plugin_system.base.base_plugin import BasePlugin
@@ -47,7 +47,7 @@ class BaseAdapter(MoFoxAdapterBase, ABC):
def __init__(
self,
core_sink: CoreMessageSink,
core_sink: CoreSink,
plugin: Optional[BasePlugin] = None,
**kwargs
):
@@ -227,7 +227,7 @@ class BaseAdapter(MoFoxAdapterBase, ABC):
)
@abstractmethod
def from_platform_message(self, raw: Any) -> MessageEnvelope:
async def from_platform_message(self, raw: Any) -> MessageEnvelope:
"""
将平台原始消息转换为 MessageEnvelope

View File

@@ -7,130 +7,152 @@ Adapter 管理器
from __future__ import annotations
import asyncio
import subprocess
import sys
from pathlib import Path
import importlib
import multiprocessing as mp
from typing import TYPE_CHECKING, Dict, Optional
if TYPE_CHECKING:
from src.plugin_system.base.base_adapter import BaseAdapter
from mofox_bus import ProcessCoreSinkServer
from src.common.core_sink import get_core_sink
from src.common.logger import get_logger
logger = get_logger("adapter_manager")
class AdapterProcess:
"""适配器子进程包装器"""
def __init__(
self,
adapter_name: str,
entry_path: Path,
python_executable: Optional[str] = None,
):
self.adapter_name = adapter_name
self.entry_path = entry_path
self.python_executable = python_executable or sys.executable
self.process: Optional[subprocess.Popen] = None
self._monitor_task: Optional[asyncio.Task] = None
def _load_class(module_name: str, class_name: str):
module = importlib.import_module(module_name)
return getattr(module, class_name)
def _adapter_process_entry(
adapter_path: tuple[str, str],
plugin_info: dict | None,
incoming_queue: mp.Queue,
outgoing_queue: mp.Queue,
):
import asyncio
import contextlib
from mofox_bus import ProcessCoreSink
async def _run() -> None:
adapter_cls = _load_class(*adapter_path)
plugin_instance = None
if plugin_info:
plugin_cls = _load_class(plugin_info["module"], plugin_info["class"])
plugin_instance = plugin_cls(plugin_info["plugin_dir"], plugin_info["metadata"])
core_sink = ProcessCoreSink(to_core_queue=incoming_queue, from_core_queue=outgoing_queue)
adapter = adapter_cls(core_sink, plugin=plugin_instance)
await adapter.start()
try:
while not getattr(core_sink, "_closed", False):
await asyncio.sleep(0.2)
finally:
with contextlib.suppress(Exception):
await adapter.stop()
with contextlib.suppress(Exception):
await core_sink.close()
asyncio.run(_run())
class AdapterProcess:
"""适配器子进程包装器,负责适配器子进程的启动和生命周期管理"""
def __init__(self, adapter: "BaseAdapter", core_sink) -> None:
self.adapter = adapter
self.adapter_name = adapter.adapter_name
self.process: mp.Process | None = None
self._ctx = mp.get_context("spawn")
self._incoming_queue: mp.Queue = self._ctx.Queue()
self._outgoing_queue: mp.Queue = self._ctx.Queue()
self._bridge: ProcessCoreSinkServer | None = None
self._core_sink = core_sink
self._adapter_path: tuple[str, str] = (adapter.__class__.__module__, adapter.__class__.__name__)
self._plugin_info = self._extract_plugin_info(adapter)
self._outgoing_handler = None
@staticmethod
def _extract_plugin_info(adapter: "BaseAdapter") -> dict | None:
plugin = getattr(adapter, "plugin", None)
if plugin is None:
return None
return {
"module": plugin.__class__.__module__,
"class": plugin.__class__.__name__,
"plugin_dir": getattr(plugin, "plugin_dir", ""),
"metadata": getattr(plugin, "plugin_meta", None),
}
def _make_outgoing_handler(self):
async def _handler(envelope):
if self._bridge:
await self._bridge.push_outgoing(envelope)
return _handler
async def start(self) -> bool:
"""启动适配器子进程"""
try:
logger.info(f"启动适配器子进程: {self.adapter_name}")
logger.debug(f"Python: {self.python_executable}")
logger.debug(f"Entry: {self.entry_path}")
# 启动子进程
self.process = subprocess.Popen(
[self.python_executable, str(self.entry_path)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1,
self._bridge = ProcessCoreSinkServer(
incoming_queue=self._incoming_queue,
outgoing_queue=self._outgoing_queue,
core_handler=self._core_sink.send,
name=self.adapter_name,
)
# 启动监控任务
self._monitor_task = asyncio.create_task(self._monitor_process())
logger.info(f"适配器 {self.adapter_name} 子进程已启动 (PID: {self.process.pid})")
self._bridge.start()
if hasattr(self._core_sink, "set_outgoing_handler"):
self._outgoing_handler = self._make_outgoing_handler()
try:
self._core_sink.set_outgoing_handler(self._outgoing_handler)
except Exception:
logger.exception("Failed to register outgoing bridge for %s", self.adapter_name)
self.process = self._ctx.Process(
target=_adapter_process_entry,
args=(self._adapter_path, self._plugin_info, self._incoming_queue, self._outgoing_queue),
name=f"{self.adapter_name}-proc",
)
self.process.start()
logger.info(f"启动适配器子进程 {self.adapter_name} (PID: {self.process.pid})")
return True
except Exception as e:
logger.error(f"启动适配器 {self.adapter_name} 子进程失败: {e}", exc_info=True)
logger.error(f"启动适配器子进程 {self.adapter_name} 失败: {e}", exc_info=True)
return False
async def stop(self) -> None:
"""停止适配器子进程"""
if not self.process:
return
logger.info(f"停止适配器子进程: {self.adapter_name} (PID: {self.process.pid})")
try:
# 取消监控任务
if self._monitor_task and not self._monitor_task.done():
self._monitor_task.cancel()
remover = getattr(self._core_sink, "remove_outgoing_handler", None)
if callable(remover) and self._outgoing_handler:
try:
await self._monitor_task
except asyncio.CancelledError:
pass
# 终止进程
self.process.terminate()
# 等待进程退出最多等待5秒
try:
await asyncio.wait_for(
asyncio.to_thread(self.process.wait),
timeout=5.0
)
except asyncio.TimeoutError:
logger.warning(f"适配器 {self.adapter_name} 未能在5秒内退出强制终止")
self.process.kill()
await asyncio.to_thread(self.process.wait)
logger.info(f"适配器 {self.adapter_name} 子进程已停止")
remover(self._outgoing_handler)
except Exception:
logger.exception(f"移除 {self.adapter_name} 的 outgoing bridge 失败")
if self._bridge:
await self._bridge.close()
if self.process.is_alive():
self.process.join(timeout=5.0)
if self.process.is_alive():
logger.warning(f"适配器 {self.adapter_name} 未能及时停止,强制终止中")
self.process.terminate()
self.process.join()
except Exception as e:
logger.error(f"停止适配器 {self.adapter_name} 子进程时出错: {e}", exc_info=True)
logger.error(f"停止适配器子进程 {self.adapter_name} 时发生错误: {e}", exc_info=True)
finally:
self.process = None
async def _monitor_process(self) -> None:
"""监控子进程状态"""
if not self.process:
return
try:
# 在后台线程中等待进程退出
return_code = await asyncio.to_thread(self.process.wait)
if return_code != 0:
logger.error(
f"适配器 {self.adapter_name} 子进程异常退出 (返回码: {return_code})"
)
# 读取 stderr 输出
if self.process.stderr:
stderr = self.process.stderr.read()
if stderr:
logger.error(f"错误输出:\n{stderr}")
else:
logger.info(f"适配器 {self.adapter_name} 子进程正常退出")
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"监控适配器 {self.adapter_name} 子进程时出错: {e}", exc_info=True)
def is_running(self) -> bool:
"""检查进程是否正在运行"""
"""适配器是否正在运行"""
if not self.process:
return False
return self.process.poll() is None
return self.process.is_alive()
class AdapterManager:
"""适配器管理器"""
@@ -176,20 +198,17 @@ class AdapterManager:
else:
return await self._start_adapter_in_process(adapter)
async def _start_adapter_subprocess(self, adapter: BaseAdapter) -> bool:
"""在子进程中启动适配器"""
adapter_name = adapter.adapter_name
# 获取子进程入口脚本
entry_path = adapter.get_subprocess_entry_path()
if not entry_path:
logger.error(
f"适配器 {adapter_name} 配置为子进程运行,但未提供有效的入口脚本"
)
async def _start_adapter_subprocess(self, adapter: BaseAdapter) -> bool:
"""启动适配器子进程"""
adapter_name = adapter.adapter_name
try:
core_sink = get_core_sink()
except Exception as e:
logger.error(f"无法获取 core_sink启动适配器子进程 {adapter_name} 失败: {e}", exc_info=True)
return False
# 创建并启动子进程
adapter_process = AdapterProcess(adapter_name, entry_path)
adapter_process = AdapterProcess(adapter, core_sink)
success = await adapter_process.start()
if success:

View File

@@ -0,0 +1,427 @@
# NEW_napcat_adapter
基于 mofox-bus v2.x 的 Napcat 适配器(使用 BaseAdapter 架构)
## 🏗️ 架构设计
本插件采用 **BaseAdapter 继承模式** 重写,完全抛弃旧版 maim_message 库,改用 mofox-bus 的 TypedDict 数据结构。
### 核心组件
- **NapcatAdapter**: 继承自 `mofox_bus.AdapterBase`,负责 OneBot 11 协议与 MessageEnvelope 的双向转换
- **WebSocketAdapterOptions**: 自动管理 WebSocket 连接,提供 incoming_parser 和 outgoing_encoder
- **CoreMessageSink**: 通过 `InProcessCoreSink` 将消息递送到核心系统
- **Handlers**: 独立的消息处理器,分为 to_core接收和 to_napcat发送两个方向
## 📁 项目结构
```
NEW_napcat_adapter/
├── plugin.py # ✅ 主插件文件BaseAdapter实现
├── _manifest.json # 插件清单
└── src/
├── event_models.py # ✅ OneBot事件类型常量
├── common/
│ └── core_sink.py # ✅ 全局CoreSink访问点
├── utils/
│ ├── utils.py # ⏳ 工具函数(待实现)
│ ├── qq_emoji_list.py # ⏳ QQ表情映射待实现
│ ├── video_handler.py # ⏳ 视频处理(待实现)
│ └── message_chunker.py # ⏳ 消息切片(待实现)
├── websocket/
│ └── (无需单独实现使用WebSocketAdapterOptions)
├── database/
│ └── database.py # ⏳ 数据库模型(待实现)
└── handlers/
├── to_core/ # Napcat → MessageEnvelope 方向
│ ├── message_handler.py # ⏳ 消息处理(部分完成)
│ ├── notice_handler.py # ⏳ 通知处理(待完成)
│ └── meta_event_handler.py # ⏳ 元事件(待完成)
└── to_napcat/ # MessageEnvelope → Napcat API 方向
└── send_handler.py # ⏳ 发送处理(部分完成)
```
## 🚀 快速开始
### 使用方式
1. **配置文件**: 在 `config/plugins/NEW_napcat_adapter.toml` 中配置 WebSocket URL 和其他参数
2. **启动插件**: 插件自动在系统启动时加载
3. **WebSocket连接**: 自动连接到 Napcat OneBot 11 服务器
## 🔑 核心数据结构
### MessageEnvelope (mofox-bus v2.x)
```python
from mofox_bus import MessageEnvelope, SegPayload, MessageInfoPayload
# 创建消息信封
envelope: MessageEnvelope = {
"direction": "input",
"message_info": {
"message_type": "group",
"message_id": "12345",
"self_id": "bot_qq",
"user_info": {
"user_id": "sender_qq",
"user_name": "发送者",
"user_displayname": "昵称"
},
"group_info": {
"group_id": "group_id",
"group_name": "群名"
},
"to_me": False
},
"message_segment": {
"type": "seglist",
"data": [
{"type": "text", "data": "hello"},
{"type": "image", "data": "base64_data"}
]
},
"raw_message": "hello[图片]",
"platform": "napcat",
"message_id": "12345",
"timestamp_ms": 1234567890
}
```
### BaseAdapter 核心方法
```python
class NapcatAdapter(BaseAdapter):
async def from_platform_message(self, message: dict[str, Any]) -> MessageEnvelope | None:
"""将 OneBot 11 事件转换为 MessageEnvelope"""
# 路由到对应的 Handler
async def _send_platform_message(self, envelope: MessageEnvelope) -> dict[str, Any]:
"""将 MessageEnvelope 转换为 OneBot 11 API 调用"""
# 调用 SendHandler 处理
```
## 📝 实现进度
### ✅ 已完成的核心架构
1. **BaseAdapter 实现** (plugin.py)
- ✅ WebSocket 自动连接管理
- ✅ from_platform_message() 事件路由
- ✅ _send_platform_message() 消息发送
- ✅ API 响应池机制echo-based request-response
- ✅ CoreSink 集成
2. **Handler 基础结构**
- ✅ MessageHandler 骨架text、image、at 基本实现)
- ✅ NoticeHandler 骨架
- ✅ MetaEventHandler 骨架
- ✅ SendHandler 骨架(基本类型转换)
3. **辅助组件**
- ✅ event_models.py事件类型常量
- ✅ core_sink.py全局 CoreSink 访问)
- ✅ 配置 Schema 定义
### ⏳ 部分完成的功能
4. **消息类型处理** (MessageHandler)
- ✅ 基础消息类型text, image, at
- ❌ 高级消息类型face, reply, forward, video, json, file, rps, dice, shake
5. **发送处理** (SendHandler)
- ✅ 基础 SegPayload 转换text, image
- ❌ 高级 Seg 类型emoji, voice, voiceurl, music, videourl, file, command
### ❌ 待实现的功能
6. **通知事件处理** (NoticeHandler)
- ❌ 戳一戳事件
- ❌ 表情回应事件
- ❌ 撤回事件
- ❌ 禁言事件
7. **工具函数** (utils.py)
- ❌ get_group_info
- ❌ get_member_info
- ❌ get_image_base64
- ❌ get_message_detail
- ❌ get_record_detail
8. **权限系统**
- ❌ check_allow_to_chat()
- ❌ 群组黑名单/白名单
- ❌ 私聊黑名单/白名单
- ❌ QQ机器人检测
9. **其他组件**
- ❌ 视频处理器
- ❌ 消息切片器
- ❌ 数据库模型
- ❌ QQ 表情映射表
## 📋 下一步工作
### 优先级 1完善消息处理参考旧版 recv_handler/message_handler.py
1. **完整实现 MessageHandler.handle_raw_message()**
- [ ] face表情消息段
- [ ] reply回复消息段
- [ ] forward转发消息段解析
- [ ] video视频消息段
- [ ] jsonJSON卡片消息段
- [ ] file文件消息段
- [ ] rps/dice/shake特殊消息
2. **实现工具函数**(参考旧版 utils.py
- [ ] `get_group_info()` - 获取群组信息
- [ ] `get_member_info()` - 获取成员信息
- [ ] `get_image_base64()` - 下载图片并转Base64
- [ ] `get_message_detail()` - 获取消息详情
- [ ] `get_record_detail()` - 获取语音详情
3. **实现权限检查**
- [ ] `check_allow_to_chat()` - 检查是否允许聊天
- [ ] 群组白名单/黑名单逻辑
- [ ] 私聊白名单/黑名单逻辑
- [ ] QQ机器人检测ban_qq_bot
### 优先级 2完善发送处理参考旧版 send_handler.py
4. **完整实现 SendHandler._convert_seg_to_onebot()**
- [ ] emoji表情回应命令
- [ ] voice语音消息段
- [ ] voiceurl语音URL消息段
- [ ] music音乐卡片消息段
- [ ] videourl视频URL消息段
- [ ] file文件消息段
- [ ] command命令消息段
5. **实现命令处理**
- [ ] GROUP_BAN禁言
- [ ] GROUP_KICK踢人
- [ ] SEND_POKE戳一戳
- [ ] DELETE_MSG撤回消息
- [ ] GROUP_WHOLE_BAN全员禁言
- [ ] SET_GROUP_CARD设置群名片
- [ ] SET_GROUP_ADMIN设置管理员
### 优先级 3补全其他组件参考旧版对应文件
6. **NoticeHandler 实现**
- [ ] 戳一戳通知notify.poke
- [ ] 表情回应通知notice.group_emoji_like
- [ ] 消息撤回通知notice.group_recall
- [ ] 禁言通知notice.group_ban
7. **辅助组件**
- [ ] `qq_emoji_list.py` - QQ表情ID映射表
- [ ] `video_handler.py` - 视频处理ffmpeg封面提取
- [ ] `message_chunker.py` - 消息分块与重组
- [ ] `database.py` - 数据库模型(如有需要)
### 优先级 4测试与优化
8. **功能测试**
- [ ] 文本消息收发
- [ ] 图片消息收发
- [ ] @消息处理
- [ ] 表情/语音/视频消息
- [ ] 转发消息解析
- [ ] 所有命令功能
- [ ] 通知事件处理
9. **性能优化**
- [ ] 消息处理并发性能
- [ ] API响应池性能
- [ ] 内存占用优化
## 🔍 关键实现细节
### 1. MessageEnvelope vs 旧版 MessageBase
**不再使用 Seg dataclass**,全部使用 TypedDict
```python
# ❌ 旧版maim_message
from mofox_bus import Seg, MessageBase
seg = Seg(type="text", data="hello")
message = MessageBase(message_info=info, message_segment=seg)
# ✅ 新版mofox-bus v2.x
from mofox_bus import SegPayload, MessageEnvelope
seg_payload: SegPayload = {"type": "text", "data": "hello"}
envelope: MessageEnvelope = {
"direction": "input",
"message_info": {...},
"message_segment": seg_payload,
...
}
```
### 2. Handler 架构模式
**接收方向** (to_core):
```python
class MessageHandler:
def __init__(self, adapter: "NapcatAdapter"):
self.adapter = adapter
async def handle_raw_message(self, data: dict[str, Any]) -> MessageEnvelope:
# 1. 解析 OneBot 11 数据
# 2. 构建 message_infoMessageInfoPayload
# 3. 转换消息段为 SegPayload
# 4. 返回完整的 MessageEnvelope
```
**发送方向** (to_napcat):
```python
class SendHandler:
def __init__(self, adapter: "NapcatAdapter"):
self.adapter = adapter
async def handle_message(self, envelope: MessageEnvelope) -> dict[str, Any]:
# 1. 从 envelope 提取 message_segment
# 2. 递归转换 SegPayload → OneBot 格式
# 3. 调用 adapter.send_napcat_api() 发送
```
### 3. API 调用模式(响应池)
```python
# 在 NapcatAdapter 中
async def send_napcat_api(self, action: str, params: dict[str, Any]) -> dict[str, Any]:
# 1. 生成唯一 echo
echo = f"{action}_{uuid.uuid4()}"
# 2. 创建 Future 等待响应
future = asyncio.Future()
self._response_pool[echo] = future
# 3. 发送请求(通过 WebSocket
await self._send_request({"action": action, "params": params, "echo": echo})
# 4. 等待响应(带超时)
try:
result = await asyncio.wait_for(future, timeout=10.0)
return result
finally:
self._response_pool.pop(echo, None)
# 响应回来时(在 incoming_parser 中)
def _handle_api_response(data: dict[str, Any]):
echo = data.get("echo")
if echo in adapter._response_pool:
adapter._response_pool[echo].set_result(data)
```
### 4. 类型提示技巧
处理 TypedDict 的严格类型检查:
```python
# 使用 type: ignore 标注(编译时是 TypedDict运行时是 dict
envelope: MessageEnvelope = {
"direction": "input",
...
} # type: ignore[typeddict-item]
# 或在函数签名中使用 dict[str, Any]
async def from_platform_message(self, message: dict[str, Any]) -> MessageEnvelope | None:
...
return envelope # type: ignore[return-value]
```
## 🔍 测试检查清单
- [ ] 文本消息接收/发送
- [ ] 图片消息接收/发送
- [ ] 语音消息接收/发送
- [ ] 视频消息接收/发送
- [ ] @消息接收/发送
- [ ] 回复消息接收/发送
- [ ] 转发消息接收
- [ ] JSON消息接收
- [ ] 文件消息接收/发送
- [ ] 禁言命令
- [ ] 踢人命令
- [ ] 戳一戳命令
- [ ] 表情回应命令
- [ ] 通知事件处理
- [ ] 元事件处理
## 📚 参考资料
- **mofox-bus 文档**: 查看 `mofox_bus/types.py` 了解 TypedDict 定义
- **BaseAdapter 示例**: 参考 `docs/mofox_bus_demo_adapter.py`
- **旧版实现**: `src/plugins/built_in/napcat_adapter_plugin/` (仅参考逻辑)
- **OneBot 11 协议**: [OneBot 11 标准](https://github.com/botuniverse/onebot-11)
## ⚠️ 重要注意事项
1. **完全抛弃旧版数据结构**
- ❌ 不再使用 `Seg` dataclass
- ❌ 不再使用 `MessageBase`
- ✅ 全部使用 `SegPayload`TypedDict
- ✅ 全部使用 `MessageEnvelope`TypedDict
2. **BaseAdapter 生命周期**
- `__init__()` 中初始化同步资源
- `start()` 中执行异步初始化WebSocket连接自动建立
- `stop()` 中清理资源WebSocket自动断开
3. **WebSocketAdapterOptions 自动管理**
- 无需手动管理 WebSocket 连接
- incoming_parser 自动解析接收数据
- outgoing_encoder 自动编码发送数据
- 重连机制由基类处理
4. **CoreSink 依赖注入**
- 必须在插件加载后调用 `set_core_sink(sink)`
- 通过 `get_core_sink()` 全局访问
- 用于将消息递送到核心系统
5. **类型安全与灵活性平衡**
- TypedDict 在编译时提供类型检查
- 运行时仍是普通 dict可灵活操作
- 必要时使用 `type: ignore` 抑制误报
6. **参考旧版但不照搬**
- 旧版逻辑流程可参考
- 数据结构需完全重写
- API调用模式已改变响应池
## 📊 预估工作量
- ✅ 核心架构: **已完成** (BaseAdapter + Handlers 骨架)
- ⏳ 消息处理完善: **4-6 小时** (所有消息类型 + 工具函数)
- ⏳ 发送处理完善: **3-4 小时** (所有 Seg 类型 + 命令)
- ⏳ 通知事件处理: **2-3 小时** (poke/emoji_like/recall/ban)
- ⏳ 测试调试: **2-4 小时** (全流程测试)
- **总剩余时间: 11-17 小时**
## ✅ 完成标准
当以下条件全部满足时,重写完成:
1. ✅ BaseAdapter 架构实现完成
2. ⏳ 所有 OneBot 11 消息类型支持
3. ⏳ 所有发送消息段类型支持
4. ⏳ 所有通知事件正确处理
5. ⏳ 权限系统集成完成
6. ⏳ 与旧版功能完全对等
7. ⏳ 所有测试用例通过
---
**最后更新**: 2025-11-23
**架构状态**: ✅ 核心架构完成
**实现状态**: ⏳ 消息处理部分完成,需完善细节
**预计完成**: 根据优先级,核心功能预计 1-2 个工作日

View File

@@ -0,0 +1,16 @@
from src.plugin_system.base.plugin_metadata import PluginMetadata
__plugin_meta__ = PluginMetadata(
name="napcat_plugin",
description="基于OneBot 11协议的NapCat QQ协议插件提供完整的QQ机器人API接口使用现有adapter连接",
usage="该插件提供 `napcat_tool` tool。",
version="1.0.0",
author="Windpicker_owo",
license="GPL-v3.0-or-later",
repository_url="https://github.com/Windpicker-owo",
keywords=["qq", "bot", "napcat", "onebot", "api", "websocket"],
categories=["protocol"],
extra={
"is_built_in": False,
},
)

View File

@@ -0,0 +1,330 @@
"""
Napcat 适配器(基于 MoFox-Bus 完全重写版)
核心流程:
1. Napcat WebSocket 连接 → 接收 OneBot 格式消息
2. from_platform_message: OneBot dict → MessageEnvelope
3. CoreSink → 推送到 MoFox-Bot 核心
4. 核心回复 → _send_platform_message: MessageEnvelope → OneBot API 调用
"""
from __future__ import annotations
import asyncio
import uuid
from typing import Any, ClassVar, Dict, List, Optional
import orjson
import websockets
from mofox_bus import CoreMessageSink, MessageEnvelope, WebSocketAdapterOptions
from src.common.logger import get_logger
from src.plugin_system import register_plugin
from src.plugin_system.base import BaseAdapter, BasePlugin
from src.plugin_system.apis import config_api
from .src.handlers.to_core.message_handler import MessageHandler
from .src.handlers.to_core.notice_handler import NoticeHandler
from .src.handlers.to_core.meta_event_handler import MetaEventHandler
from .src.handlers.to_napcat.send_handler import SendHandler
logger = get_logger("napcat_adapter")
class NapcatAdapter(BaseAdapter):
"""Napcat 适配器 - 完全基于 mofox-bus 架构"""
adapter_name = "napcat_adapter"
adapter_version = "2.0.0"
adapter_author = "MoFox Team"
adapter_description = "基于 MoFox-Bus 的 Napcat/OneBot 11 适配器"
platform = "qq"
run_in_subprocess = False
subprocess_entry = None
def __init__(self, core_sink: CoreMessageSink, plugin: Optional[BasePlugin] = None):
"""初始化 Napcat 适配器"""
# 从插件配置读取 WebSocket URL
if plugin:
mode = config_api.get_plugin_config(plugin.config, "napcat_server.mode", "reverse")
host = config_api.get_plugin_config(plugin.config, "napcat_server.host", "localhost")
port = config_api.get_plugin_config(plugin.config, "napcat_server.port", 8095)
url = config_api.get_plugin_config(plugin.config, "napcat_server.url", "")
access_token = config_api.get_plugin_config(plugin.config, "napcat_server.access_token", "")
if mode == "forward" and url:
ws_url = url
else:
ws_url = f"ws://{host}:{port}"
headers = {}
if access_token:
headers["Authorization"] = f"Bearer {access_token}"
else:
ws_url = "ws://127.0.0.1:8095"
headers = {}
# 配置 WebSocket 传输
transport = WebSocketAdapterOptions(
url=ws_url,
headers=headers if headers else None,
incoming_parser=self._parse_napcat_message,
outgoing_encoder=self._encode_napcat_response,
)
super().__init__(core_sink, plugin=plugin, transport=transport)
# 初始化处理器
self.message_handler = MessageHandler(self)
self.notice_handler = NoticeHandler(self)
self.meta_event_handler = MetaEventHandler(self)
self.send_handler = SendHandler(self)
# 响应池:用于存储等待的 API 响应
self._response_pool: Dict[str, asyncio.Future] = {}
self._response_timeout = 30.0
# WebSocket 连接(用于发送 API 请求)
# 注意_ws 继承自 BaseAdapter是 WebSocketLike 协议类型
self._napcat_ws = None # 可选的额外连接引用
async def on_adapter_loaded(self) -> None:
"""适配器加载时的初始化"""
logger.info("Napcat 适配器正在启动...")
# 设置处理器配置
if self.plugin:
self.message_handler.set_plugin_config(self.plugin.config)
self.notice_handler.set_plugin_config(self.plugin.config)
self.meta_event_handler.set_plugin_config(self.plugin.config)
self.send_handler.set_plugin_config(self.plugin.config)
logger.info("Napcat 适配器已加载")
async def on_adapter_unloaded(self) -> None:
"""适配器卸载时的清理"""
logger.info("Napcat 适配器正在关闭...")
# 清理响应池
for future in self._response_pool.values():
if not future.done():
future.cancel()
self._response_pool.clear()
logger.info("Napcat 适配器已关闭")
def _parse_napcat_message(self, raw: str | bytes) -> Any:
"""解析 Napcat/OneBot 消息"""
try:
if isinstance(raw, bytes):
data = orjson.loads(raw)
else:
data = orjson.loads(raw)
return data
except Exception as e:
logger.error(f"解析 Napcat 消息失败: {e}")
raise
def _encode_napcat_response(self, envelope: MessageEnvelope) -> bytes:
"""编码响应消息为 Napcat 格式(暂未使用,通过 API 调用发送)"""
return orjson.dumps(envelope)
async def from_platform_message(self, raw: Dict[str, Any]) -> MessageEnvelope: # type: ignore[override]
"""
将 Napcat/OneBot 原始消息转换为 MessageEnvelope
这是核心转换方法,处理:
- message 事件 → 消息
- notice 事件 → 通知(戳一戳、表情回复等)
- meta_event 事件 → 元事件(心跳、生命周期)
- API 响应 → 存入响应池
"""
post_type = raw.get("post_type")
# API 响应(没有 post_type有 echo
if post_type is None and "echo" in raw:
echo = raw.get("echo")
if echo and echo in self._response_pool:
future = self._response_pool[echo]
if not future.done():
future.set_result(raw)
# API 响应不需要转换为 MessageEnvelope返回空信封
return self._create_empty_envelope()
# 消息事件
if post_type == "message":
return await self.message_handler.handle_raw_message(raw) # type: ignore[return-value]
# 通知事件
elif post_type == "notice":
return await self.notice_handler.handle_notice(raw) # type: ignore[return-value]
# 元事件
elif post_type == "meta_event":
return await self.meta_event_handler.handle_meta_event(raw) # type: ignore[return-value]
# 未知事件类型
else:
logger.warning(f"未知的事件类型: {post_type}")
return self._create_empty_envelope() # type: ignore[return-value]
async def _send_platform_message(self, envelope: MessageEnvelope) -> None: # type: ignore[override]
"""
将 MessageEnvelope 转换并发送到 Napcat
这里不直接通过 WebSocket 发送 envelope
而是调用 Napcat APIsend_group_msg, send_private_msg 等)
"""
await self.send_handler.handle_message(envelope)
def _create_empty_envelope(self) -> MessageEnvelope: # type: ignore[return]
"""创建一个空的消息信封(用于不需要处理的事件)"""
import time
return {
"direction": "incoming",
"message_info": {
"platform": self.platform,
"message_id": str(uuid.uuid4()),
"time": time.time(),
},
"message_segment": {"type": "text", "data": "[系统事件]"},
"timestamp_ms": int(time.time() * 1000),
}
async def send_napcat_api(self, action: str, params: Dict[str, Any], timeout: float = 30.0) -> Dict[str, Any]:
"""
发送 Napcat API 请求并等待响应
Args:
action: API 动作名称(如 send_group_msg
params: API 参数
timeout: 超时时间(秒)
Returns:
API 响应数据
"""
if not self._ws:
raise RuntimeError("WebSocket 连接未建立")
# 生成唯一的 echo ID
echo = str(uuid.uuid4())
# 创建 Future 用于等待响应
future = asyncio.Future()
self._response_pool[echo] = future
# 构造请求
request = orjson.dumps({
"action": action,
"params": params,
"echo": echo,
})
try:
# 发送请求
await self._ws.send(request)
# 等待响应
response = await asyncio.wait_for(future, timeout=timeout)
return response
except asyncio.TimeoutError:
logger.error(f"API 请求超时: {action}")
raise
except Exception as e:
logger.error(f"API 请求失败: {action}, 错误: {e}")
raise
finally:
# 清理响应池
self._response_pool.pop(echo, None)
def get_ws_connection(self):
"""获取 WebSocket 连接(用于发送 API 请求)"""
if not self._ws:
raise RuntimeError("WebSocket 连接未建立")
return self._ws
@register_plugin
class NapcatAdapterPlugin(BasePlugin):
"""Napcat 适配器插件"""
plugin_name = "napcat_adapter_plugin"
enable_plugin = True
plugin_version = "2.0.0"
plugin_author = "MoFox Team"
plugin_description = "Napcat/OneBot 11 适配器(基于 MoFox-Bus 重写)"
# 配置 Schema
config_schema: ClassVar[dict] = {
"plugin": {
"name": {"type": str, "default": "napcat_adapter_plugin"},
"version": {"type": str, "default": "2.0.0"},
"enabled": {"type": bool, "default": True},
},
"napcat_server": {
"mode": {
"type": str,
"default": "reverse",
"description": "连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)",
},
"host": {"type": str, "default": "localhost"},
"port": {"type": int, "default": 8095},
"url": {"type": str, "default": "", "description": "正向连接时的完整URL"},
"access_token": {"type": str, "default": ""},
},
"features": {
"group_list_type": {"type": str, "default": "blacklist"},
"group_list": {"type": list, "default": []},
"private_list_type": {"type": str, "default": "blacklist"},
"private_list": {"type": list, "default": []},
"ban_user_id": {"type": list, "default": []},
"ban_qq_bot": {"type": bool, "default": False},
},
}
def __init__(self, plugin_dir: str = "", metadata: Any = None):
# 如果没有提供参数,创建一个默认的元数据
if metadata is None:
from src.plugin_system.base.plugin_metadata import PluginMetadata
metadata = PluginMetadata(
name=self.plugin_name,
version=self.plugin_version,
author=self.plugin_author,
description=self.plugin_description,
usage="",
dependencies=[],
python_dependencies=[],
)
if not plugin_dir:
from pathlib import Path
plugin_dir = str(Path(__file__).parent)
super().__init__(plugin_dir, metadata)
self._adapter: Optional[NapcatAdapter] = None
async def on_plugin_loaded(self):
"""插件加载时启动适配器"""
logger.info("Napcat 适配器插件正在加载...")
# 获取核心 Sink
from src.common.core_sink import get_core_sink
core_sink = get_core_sink()
# 创建并启动适配器
self._adapter = NapcatAdapter(core_sink, plugin=self)
await self._adapter.start()
logger.info("Napcat 适配器插件已加载")
async def on_plugin_unloaded(self):
"""插件卸载时停止适配器"""
if self._adapter:
await self._adapter.stop()
logger.info("Napcat 适配器插件已卸载")
def get_plugin_components(self) -> list:
"""返回适配器组件"""
return [(NapcatAdapter.get_adapter_info(), NapcatAdapter)]

View File

@@ -0,0 +1 @@
"""工具模块"""

View File

@@ -0,0 +1,310 @@
from enum import Enum
class MetaEventType:
lifecycle = "lifecycle" # 生命周期
class Lifecycle:
connect = "connect" # 生命周期 - WebSocket 连接成功
heartbeat = "heartbeat" # 心跳
class MessageType: # 接受消息大类
private = "private" # 私聊消息
class Private:
friend = "friend" # 私聊消息 - 好友
group = "group" # 私聊消息 - 群临时
group_self = "group_self" # 私聊消息 - 群中自身发送
other = "other" # 私聊消息 - 其他
group = "group" # 群聊消息
class Group:
normal = "normal" # 群聊消息 - 普通
anonymous = "anonymous" # 群聊消息 - 匿名消息
notice = "notice" # 群聊消息 - 系统提示
class NoticeType: # 通知事件
friend_recall = "friend_recall" # 私聊消息撤回
group_recall = "group_recall" # 群聊消息撤回
notify = "notify"
group_ban = "group_ban" # 群禁言
group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复
group_upload = "group_upload" # 群文件上传
class Notify:
poke = "poke" # 戳一戳
input_status = "input_status" # 正在输入
class GroupBan:
ban = "ban" # 禁言
lift_ban = "lift_ban" # 解除禁言
class RealMessageType: # 实际消息分类
text = "text" # 纯文本
face = "face" # qq表情
image = "image" # 图片
record = "record" # 语音
video = "video" # 视频
at = "at" # @某人
rps = "rps" # 猜拳魔法表情
dice = "dice" # 骰子
shake = "shake" # 私聊窗口抖动(只收)
poke = "poke" # 群聊戳一戳
share = "share" # 链接分享json形式
reply = "reply" # 回复消息
forward = "forward" # 转发消息
node = "node" # 转发消息节点
json = "json" # json消息
file = "file" # 文件
class MessageSentType:
private = "private"
class Private:
friend = "friend"
group = "group"
group = "group"
class Group:
normal = "normal"
class CommandType(Enum):
"""命令类型"""
GROUP_BAN = "set_group_ban" # 禁言用户
GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
GROUP_KICK = "set_group_kick" # 踢出群聊
SEND_POKE = "send_poke" # 戳一戳
DELETE_MSG = "delete_msg" # 撤回消息
AI_VOICE_SEND = "ai_voice_send" # AI语音发送
SET_EMOJI_LIKE = "set_msg_emoji_like" # 设置表情回应
SEND_AT_MESSAGE = "send_at_message" # 发送@消息
SEND_LIKE = "send_like" # 发送点赞
def __str__(self) -> str:
return self.value
# 支持的消息格式
ACCEPT_FORMAT = [
"text",
"image",
"emoji",
"reply",
"voice",
"command",
"voiceurl",
"music",
"videourl",
"file",
]
# 插件名称
PLUGIN_NAME = "NEW_napcat_adapter"
# QQ表情映射表
QQ_FACE = {
"0": "[表情:惊讶]",
"1": "[表情:撇嘴]",
"2": "[表情:色]",
"3": "[表情:发呆]",
"4": "[表情:得意]",
"5": "[表情:流泪]",
"6": "[表情:害羞]",
"7": "[表情:闭嘴]",
"8": "[表情:睡]",
"9": "[表情:大哭]",
"10": "[表情:尴尬]",
"11": "[表情:发怒]",
"12": "[表情:调皮]",
"13": "[表情:呲牙]",
"14": "[表情:微笑]",
"15": "[表情:难过]",
"16": "[表情:酷]",
"18": "[表情:抓狂]",
"19": "[表情:吐]",
"20": "[表情:偷笑]",
"21": "[表情:可爱]",
"22": "[表情:白眼]",
"23": "[表情:傲慢]",
"24": "[表情:饥饿]",
"25": "[表情:困]",
"26": "[表情:惊恐]",
"27": "[表情:流汗]",
"28": "[表情:憨笑]",
"29": "[表情:悠闲]",
"30": "[表情:奋斗]",
"31": "[表情:咒骂]",
"32": "[表情:疑问]",
"33": "[表情:嘘]",
"34": "[表情:晕]",
"35": "[表情:折磨]",
"36": "[表情:衰]",
"37": "[表情:骷髅]",
"38": "[表情:敲打]",
"39": "[表情:再见]",
"41": "[表情:发抖]",
"42": "[表情:爱情]",
"43": "[表情:跳跳]",
"46": "[表情:猪头]",
"49": "[表情:拥抱]",
"53": "[表情:蛋糕]",
"56": "[表情:刀]",
"59": "[表情:便便]",
"60": "[表情:咖啡]",
"63": "[表情:玫瑰]",
"64": "[表情:凋谢]",
"66": "[表情:爱心]",
"67": "[表情:心碎]",
"74": "[表情:太阳]",
"75": "[表情:月亮]",
"76": "[表情:赞]",
"77": "[表情:踩]",
"78": "[表情:握手]",
"79": "[表情:胜利]",
"85": "[表情:飞吻]",
"86": "[表情:怄火]",
"89": "[表情:西瓜]",
"96": "[表情:冷汗]",
"97": "[表情:擦汗]",
"98": "[表情:抠鼻]",
"99": "[表情:鼓掌]",
"100": "[表情:糗大了]",
"101": "[表情:坏笑]",
"102": "[表情:左哼哼]",
"103": "[表情:右哼哼]",
"104": "[表情:哈欠]",
"105": "[表情:鄙视]",
"106": "[表情:委屈]",
"107": "[表情:快哭了]",
"108": "[表情:阴险]",
"109": "[表情:左亲亲]",
"110": "[表情:吓]",
"111": "[表情:可怜]",
"112": "[表情:菜刀]",
"114": "[表情:篮球]",
"116": "[表情:示爱]",
"118": "[表情:抱拳]",
"119": "[表情:勾引]",
"120": "[表情:拳头]",
"121": "[表情:差劲]",
"123": "[表情NO]",
"124": "[表情OK]",
"125": "[表情:转圈]",
"129": "[表情:挥手]",
"137": "[表情:鞭炮]",
"144": "[表情:喝彩]",
"146": "[表情:爆筋]",
"147": "[表情:棒棒糖]",
"169": "[表情:手枪]",
"171": "[表情:茶]",
"172": "[表情:眨眼睛]",
"173": "[表情:泪奔]",
"174": "[表情:无奈]",
"175": "[表情:卖萌]",
"176": "[表情:小纠结]",
"177": "[表情:喷血]",
"178": "[表情:斜眼笑]",
"179": "[表情doge]",
"181": "[表情:戳一戳]",
"182": "[表情:笑哭]",
"183": "[表情:我最美]",
"185": "[表情:羊驼]",
"187": "[表情:幽灵]",
"201": "[表情:点赞]",
"212": "[表情:托腮]",
"262": "[表情:脑阔疼]",
"263": "[表情:沧桑]",
"264": "[表情:捂脸]",
"265": "[表情:辣眼睛]",
"266": "[表情:哦哟]",
"267": "[表情:头秃]",
"268": "[表情:问号脸]",
"269": "[表情:暗中观察]",
"270": "[表情emm]",
"271": "[表情:吃瓜]",
"272": "[表情:呵呵哒]",
"273": "[表情:我酸了]",
"277": "[表情:滑稽狗头]",
"281": "[表情:翻白眼]",
"282": "[表情:敬礼]",
"283": "[表情:狂笑]",
"284": "[表情:面无表情]",
"285": "[表情:摸鱼]",
"286": "[表情:魔鬼笑]",
"287": "[表情:哦]",
"289": "[表情:睁眼]",
"293": "[表情:摸锦鲤]",
"294": "[表情:期待]",
"295": "[表情:拿到红包]",
"297": "[表情:拜谢]",
"298": "[表情:元宝]",
"299": "[表情:牛啊]",
"300": "[表情:胖三斤]",
"302": "[表情:左拜年]",
"303": "[表情:右拜年]",
"305": "[表情:右亲亲]",
"306": "[表情:牛气冲天]",
"307": "[表情:喵喵]",
"311": "[表情打call]",
"312": "[表情:变形]",
"314": "[表情:仔细分析]",
"317": "[表情:菜汪]",
"318": "[表情:崇拜]",
"319": "[表情:比心]",
"320": "[表情:庆祝]",
"323": "[表情:嫌弃]",
"324": "[表情:吃糖]",
"325": "[表情:惊吓]",
"326": "[表情:生气]",
"332": "[表情:举牌牌]",
"333": "[表情:烟花]",
"334": "[表情:虎虎生威]",
"336": "[表情:豹富]",
"337": "[表情:花朵脸]",
"338": "[表情:我想开了]",
"339": "[表情:舔屏]",
"341": "[表情:打招呼]",
"342": "[表情酸Q]",
"343": "[表情:我方了]",
"344": "[表情:大怨种]",
"345": "[表情:红包多多]",
"346": "[表情:你真棒棒]",
"347": "[表情:大展宏兔]",
"349": "[表情:坚强]",
"350": "[表情:贴贴]",
"351": "[表情:敲敲]",
"352": "[表情:咦]",
"353": "[表情:拜托]",
"354": "[表情:尊嘟假嘟]",
"355": "[表情:耶]",
"356": "[表情666]",
"357": "[表情:裂开]",
"392": "[表情:龙年快乐]",
"393": "[表情:新年中龙]",
"394": "[表情:新年大龙]",
"395": "[表情:略略略]",
"396": "[表情:龙年快乐]",
"424": "[表情:按钮]",
}
__all__ = [
"MetaEventType",
"MessageType",
"NoticeType",
"RealMessageType",
"MessageSentType",
"CommandType",
"ACCEPT_FORMAT",
"PLUGIN_NAME",
"QQ_FACE",
]

View File

@@ -0,0 +1 @@
"""处理器模块"""

View File

@@ -0,0 +1 @@
"""接收方向处理器"""

View File

@@ -0,0 +1,126 @@
"""消息处理器 - 将 Napcat OneBot 消息转换为 MessageEnvelope"""
from __future__ import annotations
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from src.common.logger import get_logger
from src.plugin_system.apis import config_api
if TYPE_CHECKING:
from ...plugin import NapcatAdapter
logger = get_logger("napcat_adapter.message_handler")
class MessageHandler:
"""处理来自 Napcat 的消息事件"""
def __init__(self, adapter: "NapcatAdapter"):
self.adapter = adapter
self.plugin_config: Optional[Dict[str, Any]] = None
def set_plugin_config(self, config: Dict[str, Any]) -> None:
"""设置插件配置"""
self.plugin_config = config
async def handle_raw_message(self, raw: Dict[str, Any]):
"""
处理原始消息并转换为 MessageEnvelope
Args:
raw: OneBot 原始消息数据
Returns:
MessageEnvelope (dict)
"""
from mofox_bus import MessageEnvelope, SegPayload, MessageInfoPayload, UserInfoPayload, GroupInfoPayload
message_type = raw.get("message_type")
message_id = str(raw.get("message_id", ""))
message_time = time.time()
# 构造用户信息
sender_info = raw.get("sender", {})
user_info: UserInfoPayload = {
"platform": "qq",
"user_id": str(sender_info.get("user_id", "")),
"user_nickname": sender_info.get("nickname", ""),
"user_cardname": sender_info.get("card", ""),
"user_avatar": sender_info.get("avatar", ""),
}
# 构造群组信息(如果是群消息)
group_info: Optional[GroupInfoPayload] = None
if message_type == "group":
group_id = raw.get("group_id")
if group_id:
group_info = {
"platform": "qq",
"group_id": str(group_id),
"group_name": "", # 可以通过 API 获取
}
# 解析消息段
message_segments = raw.get("message", [])
seg_list: List[SegPayload] = []
for seg in message_segments:
seg_type = seg.get("type", "")
seg_data = seg.get("data", {})
# 转换为 SegPayload
if seg_type == "text":
seg_list.append({
"type": "text",
"data": seg_data.get("text", "")
})
elif seg_type == "image":
# 这里需要下载图片并转换为 base64简化版本
seg_list.append({
"type": "image",
"data": seg_data.get("url", "") # 实际应该转换为 base64
})
elif seg_type == "at":
seg_list.append({
"type": "at",
"data": f"{seg_data.get('qq', '')}"
})
# 其他消息类型...
# 构造 MessageInfoPayload
message_info = {
"platform": "qq",
"message_id": message_id,
"time": message_time,
"user_info": user_info,
"format_info": {
"content_format": ["text", "image"], # 根据实际消息类型设置
"accept_format": ["text", "image", "emoji", "voice"],
},
}
# 添加群组信息(如果存在)
if group_info:
message_info["group_info"] = group_info
# 构造 MessageEnvelope
envelope = {
"direction": "incoming",
"message_info": message_info,
"message_segment": {"type": "seglist", "data": seg_list} if len(seg_list) > 1 else (seg_list[0] if seg_list else {"type": "text", "data": ""}),
"raw_message": raw.get("raw_message", ""),
"platform": "qq",
"message_id": message_id,
"timestamp_ms": int(message_time * 1000),
}
return envelope

View File

@@ -0,0 +1,41 @@
"""元事件处理器"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, Optional
from src.common.logger import get_logger
if TYPE_CHECKING:
from ...plugin import NapcatAdapter
logger = get_logger("napcat_adapter.meta_event_handler")
class MetaEventHandler:
"""处理 Napcat 元事件(心跳、生命周期)"""
def __init__(self, adapter: "NapcatAdapter"):
self.adapter = adapter
self.plugin_config: Optional[Dict[str, Any]] = None
def set_plugin_config(self, config: Dict[str, Any]) -> None:
"""设置插件配置"""
self.plugin_config = config
async def handle_meta_event(self, raw: Dict[str, Any]):
"""处理元事件"""
# 简化版本:返回一个空的 MessageEnvelope
import time
import uuid
return {
"direction": "incoming",
"message_info": {
"platform": "qq",
"message_id": str(uuid.uuid4()),
"time": time.time(),
},
"message_segment": {"type": "text", "data": "[元事件]"},
"timestamp_ms": int(time.time() * 1000),
}

View File

@@ -0,0 +1,41 @@
"""通知事件处理器"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, Optional
from src.common.logger import get_logger
if TYPE_CHECKING:
from ...plugin import NapcatAdapter
logger = get_logger("napcat_adapter.notice_handler")
class NoticeHandler:
"""处理 Napcat 通知事件(戳一戳、表情回复等)"""
def __init__(self, adapter: "NapcatAdapter"):
self.adapter = adapter
self.plugin_config: Optional[Dict[str, Any]] = None
def set_plugin_config(self, config: Dict[str, Any]) -> None:
"""设置插件配置"""
self.plugin_config = config
async def handle_notice(self, raw: Dict[str, Any]):
"""处理通知事件"""
# 简化版本:返回一个空的 MessageEnvelope
import time
import uuid
return {
"direction": "incoming",
"message_info": {
"platform": "qq",
"message_id": str(uuid.uuid4()),
"time": time.time(),
},
"message_segment": {"type": "text", "data": "[通知事件]"},
"timestamp_ms": int(time.time() * 1000),
}

View File

@@ -0,0 +1 @@
"""发送方向处理器"""

View File

@@ -0,0 +1,77 @@
"""发送处理器 - 将 MessageEnvelope 转换并发送到 Napcat"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, Optional
from src.common.logger import get_logger
if TYPE_CHECKING:
from ...plugin import NapcatAdapter
logger = get_logger("napcat_adapter.send_handler")
class SendHandler:
"""处理向 Napcat 发送消息"""
def __init__(self, adapter: "NapcatAdapter"):
self.adapter = adapter
self.plugin_config: Optional[Dict[str, Any]] = None
def set_plugin_config(self, config: Dict[str, Any]) -> None:
"""设置插件配置"""
self.plugin_config = config
async def handle_message(self, envelope) -> None:
"""
处理发送消息
将 MessageEnvelope 转换为 OneBot API 调用
"""
message_info = envelope.get("message_info", {})
message_segment = envelope.get("message_segment", {})
# 获取群组和用户信息
group_info = message_info.get("group_info")
user_info = message_info.get("user_info")
# 构造消息内容
message = self._convert_seg_to_onebot(message_segment)
# 发送消息
if group_info:
# 发送群消息
group_id = group_info.get("group_id")
if group_id:
await self.adapter.send_napcat_api("send_group_msg", {
"group_id": int(group_id),
"message": message,
})
elif user_info:
# 发送私聊消息
user_id = user_info.get("user_id")
if user_id:
await self.adapter.send_napcat_api("send_private_msg", {
"user_id": int(user_id),
"message": message,
})
def _convert_seg_to_onebot(self, seg: Dict[str, Any]) -> list:
"""将 SegPayload 转换为 OneBot 消息格式"""
seg_type = seg.get("type", "")
seg_data = seg.get("data", "")
if seg_type == "text":
return [{"type": "text", "data": {"text": seg_data}}]
elif seg_type == "image":
return [{"type": "image", "data": {"file": f"base64://{seg_data}"}}]
elif seg_type == "seglist":
# 递归处理列表
result = []
for sub_seg in seg_data:
result.extend(self._convert_seg_to_onebot(sub_seg))
return result
else:
# 默认作为文本
return [{"type": "text", "data": {"text": str(seg_data)}}]

View File

@@ -0,0 +1,350 @@
"""
按聊天流分配消费者的消息路由系统
核心思想:
- 为每个活跃的聊天流stream_id创建独立的消息队列和消费者协程
- 同一聊天流的消息由同一个 worker 处理,保证顺序性
- 不同聊天流的消息并发处理,提高吞吐量
- 动态管理流的生命周期,自动清理不活跃的流
"""
import asyncio
import time
from typing import Dict, Optional
from src.common.logger import get_logger
logger = get_logger("stream_router")
class StreamConsumer:
"""单个聊天流的消息消费者
维护独立的消息队列和处理协程
"""
def __init__(self, stream_id: str, queue_maxsize: int = 100):
self.stream_id = stream_id
self.queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
self.worker_task: Optional[asyncio.Task] = None
self.last_active_time = time.time()
self.is_running = False
# 性能统计
self.stats = {
"total_messages": 0,
"total_processing_time": 0.0,
"queue_overflow_count": 0,
}
async def start(self) -> None:
"""启动消费者"""
if not self.is_running:
self.is_running = True
self.worker_task = asyncio.create_task(self._process_loop())
logger.debug(f"Stream Consumer 启动: {self.stream_id}")
async def stop(self) -> None:
"""停止消费者"""
self.is_running = False
if self.worker_task:
self.worker_task.cancel()
try:
await self.worker_task
except asyncio.CancelledError:
pass
logger.debug(f"Stream Consumer 停止: {self.stream_id}")
async def enqueue(self, message: dict) -> None:
"""将消息加入队列"""
self.last_active_time = time.time()
try:
# 使用 put_nowait 避免阻塞路由器
self.queue.put_nowait(message)
except asyncio.QueueFull:
self.stats["queue_overflow_count"] += 1
logger.warning(
f"Stream {self.stream_id} 队列已满 "
f"({self.queue.qsize()}/{self.queue.maxsize})"
)
try:
self.queue.get_nowait()
self.queue.put_nowait(message)
logger.debug(f"Stream {self.stream_id} 丢弃最旧消息,添加新消息")
except asyncio.QueueEmpty:
pass
async def _process_loop(self) -> None:
"""消息处理循环"""
# 延迟导入,避免循环依赖
from .recv_handler.message_handler import message_handler
from .recv_handler.meta_event_handler import meta_event_handler
from .recv_handler.notice_handler import notice_handler
logger.info(f"Stream {self.stream_id} 处理循环启动")
try:
while self.is_running:
try:
# 等待消息1秒超时
message = await asyncio.wait_for(
self.queue.get(),
timeout=1.0
)
start_time = time.time()
# 处理消息
post_type = message.get("post_type")
if post_type == "message":
await message_handler.handle_raw_message(message)
elif post_type == "meta_event":
await meta_event_handler.handle_meta_event(message)
elif post_type == "notice":
await notice_handler.handle_notice(message)
else:
logger.warning(f"未知的 post_type: {post_type}")
processing_time = time.time() - start_time
# 更新统计
self.stats["total_messages"] += 1
self.stats["total_processing_time"] += processing_time
self.last_active_time = time.time()
self.queue.task_done()
# 性能监控每100条消息输出一次
if self.stats["total_messages"] % 100 == 0:
avg_time = self.stats["total_processing_time"] / self.stats["total_messages"]
logger.info(
f"Stream {self.stream_id[:30]}... 统计: "
f"消息数={self.stats['total_messages']}, "
f"平均耗时={avg_time:.3f}秒, "
f"队列长度={self.queue.qsize()}"
)
# 动态延迟:队列空时短暂休眠
if self.queue.qsize() == 0:
await asyncio.sleep(0.01)
except asyncio.TimeoutError:
# 超时是正常的,继续循环
continue
except asyncio.CancelledError:
logger.info(f"Stream {self.stream_id} 处理循环被取消")
break
except Exception as e:
logger.error(f"Stream {self.stream_id} 处理消息时出错: {e}", exc_info=True)
# 继续处理下一条消息
await asyncio.sleep(0.1)
finally:
logger.info(f"Stream {self.stream_id} 处理循环结束")
def get_stats(self) -> dict:
"""获取性能统计"""
avg_time = (
self.stats["total_processing_time"] / self.stats["total_messages"]
if self.stats["total_messages"] > 0
else 0
)
return {
"stream_id": self.stream_id,
"queue_size": self.queue.qsize(),
"total_messages": self.stats["total_messages"],
"avg_processing_time": avg_time,
"queue_overflow_count": self.stats["queue_overflow_count"],
"last_active_time": self.last_active_time,
}
class StreamRouter:
"""流路由器
负责将消息路由到对应的聊天流队列
动态管理聊天流的生命周期
"""
def __init__(
self,
max_streams: int = 500,
stream_timeout: int = 600,
stream_queue_size: int = 100,
cleanup_interval: int = 60,
):
self.streams: Dict[str, StreamConsumer] = {}
self.lock = asyncio.Lock()
self.max_streams = max_streams
self.stream_timeout = stream_timeout
self.stream_queue_size = stream_queue_size
self.cleanup_interval = cleanup_interval
self.cleanup_task: Optional[asyncio.Task] = None
self.is_running = False
async def start(self) -> None:
"""启动路由器"""
if not self.is_running:
self.is_running = True
self.cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info(
f"StreamRouter 已启动 - "
f"最大流数: {self.max_streams}, "
f"超时: {self.stream_timeout}秒, "
f"队列大小: {self.stream_queue_size}"
)
async def stop(self) -> None:
"""停止路由器"""
self.is_running = False
if self.cleanup_task:
self.cleanup_task.cancel()
try:
await self.cleanup_task
except asyncio.CancelledError:
pass
# 停止所有流消费者
logger.info(f"正在停止 {len(self.streams)} 个流消费者...")
for consumer in self.streams.values():
await consumer.stop()
self.streams.clear()
logger.info("StreamRouter 已停止")
async def route_message(self, message: dict) -> None:
"""路由消息到对应的流"""
stream_id = self._extract_stream_id(message)
# 快速路径:流已存在
if stream_id in self.streams:
await self.streams[stream_id].enqueue(message)
return
# 慢路径:需要创建新流
async with self.lock:
# 双重检查
if stream_id not in self.streams:
# 检查流数量限制
if len(self.streams) >= self.max_streams:
logger.warning(
f"达到最大流数量限制 ({self.max_streams})"
f"尝试清理不活跃的流..."
)
await self._cleanup_inactive_streams()
# 清理后仍然超限,记录警告但继续创建
if len(self.streams) >= self.max_streams:
logger.error(
f"清理后仍达到最大流数量 ({len(self.streams)}/{self.max_streams})"
)
# 创建新流
consumer = StreamConsumer(stream_id, self.stream_queue_size)
self.streams[stream_id] = consumer
await consumer.start()
logger.info(f"创建新的 Stream Consumer: {stream_id} (总流数: {len(self.streams)})")
await self.streams[stream_id].enqueue(message)
def _extract_stream_id(self, message: dict) -> str:
"""从消息中提取 stream_id
返回格式: platform:id:type
例如: qq:123456:group 或 qq:789012:private
"""
post_type = message.get("post_type")
# 非消息类型,使用默认流(避免创建过多流)
if post_type not in ["message", "notice"]:
return "system:meta_event"
# 消息类型
if post_type == "message":
message_type = message.get("message_type")
if message_type == "group":
group_id = message.get("group_id")
return f"qq:{group_id}:group"
elif message_type == "private":
user_id = message.get("user_id")
return f"qq:{user_id}:private"
# notice 类型
elif post_type == "notice":
group_id = message.get("group_id")
if group_id:
return f"qq:{group_id}:group"
user_id = message.get("user_id")
if user_id:
return f"qq:{user_id}:private"
# 未知类型,使用通用流
return "unknown:unknown"
async def _cleanup_inactive_streams(self) -> None:
"""清理不活跃的流"""
current_time = time.time()
to_remove = []
for stream_id, consumer in self.streams.items():
if current_time - consumer.last_active_time > self.stream_timeout:
to_remove.append(stream_id)
for stream_id in to_remove:
await self.streams[stream_id].stop()
del self.streams[stream_id]
logger.debug(f"清理不活跃的流: {stream_id}")
if to_remove:
logger.info(
f"清理了 {len(to_remove)} 个不活跃的流 "
f"(当前活跃流: {len(self.streams)}/{self.max_streams})"
)
async def _cleanup_loop(self) -> None:
"""定期清理循环"""
logger.info(f"清理循环已启动,间隔: {self.cleanup_interval}")
try:
while self.is_running:
await asyncio.sleep(self.cleanup_interval)
await self._cleanup_inactive_streams()
except asyncio.CancelledError:
logger.info("清理循环已停止")
def get_all_stats(self) -> list[dict]:
"""获取所有流的统计信息"""
return [consumer.get_stats() for consumer in self.streams.values()]
def get_summary(self) -> dict:
"""获取路由器摘要"""
total_messages = sum(c.stats["total_messages"] for c in self.streams.values())
total_queue_size = sum(c.queue.qsize() for c in self.streams.values())
total_overflows = sum(c.stats["queue_overflow_count"] for c in self.streams.values())
# 计算平均队列长度
avg_queue_size = total_queue_size / len(self.streams) if self.streams else 0
# 找出最繁忙的流
busiest_stream = None
if self.streams:
busiest_stream = max(
self.streams.values(),
key=lambda c: c.stats["total_messages"]
).stream_id
return {
"total_streams": len(self.streams),
"max_streams": self.max_streams,
"total_messages_processed": total_messages,
"total_queue_size": total_queue_size,
"avg_queue_size": avg_queue_size,
"total_queue_overflows": total_overflows,
"busiest_stream": busiest_stream,
}
# 全局路由器实例
stream_router = StreamRouter()

View File

@@ -236,8 +236,6 @@ class NapcatAdapterPlugin(BasePlugin):
def enable_plugin(self) -> bool:
"""通过配置文件动态控制插件启用状态"""
# 如果已经通过配置加载了状态,使用配置中的值
if hasattr(self, "_is_enabled"):
return self._is_enabled
# 否则使用默认值(禁用状态)
return False