feat: 添加带有消息处理和路由功能的NEW_napcat_adapter插件
- 为NEW_napcat_adapter插件实现了核心模块,包括消息处理、事件处理和路由。 - 创建了MessageHandler、MetaEventHandler和NoticeHandler来处理收到的消息和事件。 - 开发了SendHandler,用于向Napcat发送回消息。 引入了StreamRouter来管理多个聊天流,确保消息的顺序和高效处理。 - 增加了对各种消息类型和格式的支持,包括文本、图像和通知。 - 建立了一个用于监控和调试的日志系统。
This commit is contained in:
@@ -32,11 +32,11 @@ MoFox Bus 是 MoFox Bot 自研的统一消息中台,替换第三方 `maim_mess
|
||||
|
||||
## 3. 消息模型
|
||||
|
||||
### 3.1 Envelope TypedDict(`types.py`)
|
||||
### 3.1 Envelope TypedDict<EFBFBD><EFBFBD>`types.py`<EFBFBD><EFBFBD>
|
||||
|
||||
- `MessageEnvelope`:核心字段包括 `id`、`direction`、`platform`、`timestamp_ms`、`channel`、`sender`、`content` 等,一律使用毫秒时间戳,保留 `raw_platform_message` 与 `metadata` 便于调试 / 扩展。
|
||||
- `Content` 联合类型支持文本、图片、音频、文件、视频、事件、命令、系统消息,后续可扩展更多 literal。
|
||||
- `SenderInfo` / `ChannelInfo` / `MessageDirection` / `Role` 等均以 `Literal` 控制取值,方便 IDE 静态检查。
|
||||
- `MessageEnvelope` <20><>ȫ<EFBFBD><C8AB>Ƶ<EFBFBD> maim_message <20>ṹ<EFBFBD><E1B9B9><EFBFBD><EFBFBD><EFBFBD>ĵ<EFBFBD><C4B5><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD> `message_info` + `message_segment` (SegPayload)<29><>`direction`<EFBFBD><EFBFBD>`schema_version` <20><> raw <20><><EFBFBD><EFBFBD><EFBFBD>ֶβ<D6B6><CEB2><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ˣ<EFBFBD><CBA3><EFBFBD>Ժ<EFBFBD><D4BA><EFBFBD><EFBFBD><EFBFBD> `channel`<EFBFBD><EFBFBD>`sender`<EFBFBD><EFBFBD>`content` <EFBFBD><EFBFBD> v0 <20>ֶΪ<D6B6><CEAA>ѡ<EFBFBD><D1A1>
|
||||
- `SegPayload` / `MessageInfoPayload` / `UserInfoPayload` / `GroupInfoPayload` / `FormatInfoPayload` / `TemplateInfoPayload` <20><> maim_message dataclass <20>Դ<EFBFBD>TypedDict <20><>Ӧ<EFBFBD><D3A6><EFBFBD>ʺ<EFBFBD>ֱ<EFBFBD><D6B1> JSON <20><><EFBFBD><EFBFBD>
|
||||
- `Content` / `SenderInfo` / `ChannelInfo` <EFBFBD>Ȳ<EFBFBD>Ȼ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ڣ<EFBFBD><EFBFBD><EFBFBD><EFBFBD>ܻ<EFBFBD><EFBFBD><EFBFBD> IDE ע<>⣬Ҳ<E2A3AC>Ƕ<EFBFBD> v0 content ģ<>͵Ļ<CDB5>֧
|
||||
|
||||
### 3.2 dataclass 消息段(`message_models.py`)
|
||||
|
||||
@@ -62,15 +62,14 @@ TypedDict 更适合网络传输和依赖注入;dataclass 版 MessageBase 则
|
||||
## 5. 运行时调度(`runtime.py`)
|
||||
|
||||
- `MessageRuntime`:
|
||||
- `add_route(predicate, handler)` 或 `@runtime.route(...)` 装饰器注册消息处理器。
|
||||
- `register_before_hook` / `register_after_hook` / `register_error_hook` 注入监控、埋点、Trace。
|
||||
- `set_batch_handler` 支持一次处理整批消息(例如批量落库)。
|
||||
- `MessageProcessingError` 在 handler 抛出异常时封装上下文,便于日志追踪。
|
||||
- `add_route(predicate, handler)` 和 `@runtime.route(...)` 装饰器注册消息处理器
|
||||
- `register_before_hook` / `register_after_hook` / `register_error_hook` 注册前置、后置、Trace 处理
|
||||
- `set_batch_handler` 支持一次处理一批消息(可用于 batch IO 优化)
|
||||
- `MessageProcessingError` 在 handler 抛出异常时包装原因,方便日志追踪。
|
||||
|
||||
运行时内部使用 `RLock` 保护路由表,适合多协程并发读写,`_maybe_await` 自动兼容同步/异步 handler。
|
||||
|
||||
---
|
||||
|
||||
## 6. 传输层封装(`transport/`)
|
||||
|
||||
### 6.1 HTTP
|
||||
@@ -126,9 +125,9 @@ from mofox_bus.transport import HttpMessageServer
|
||||
|
||||
runtime = MessageRuntime()
|
||||
|
||||
@runtime.route(lambda env: env["content"]["type"] == "text")
|
||||
@runtime.route(lambda env: (env.get("message_segment") or {}).get("type") == "text")
|
||||
async def handle_text(env: types.MessageEnvelope):
|
||||
print("收到文本:", env["content"]["text"])
|
||||
print("收到文本", env["message_segment"]["data"])
|
||||
|
||||
async def http_handler(messages: list[types.MessageEnvelope]):
|
||||
await runtime.handle_batch(messages)
|
||||
|
||||
35
src/common/core_sink.py
Normal file
35
src/common/core_sink.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
从 src.main 导出 core_sink 的辅助函数
|
||||
|
||||
由于 src.main 中实际使用的是 InProcessCoreSink,
|
||||
我们需要创建一个全局访问点
|
||||
"""
|
||||
|
||||
from mofox_bus import CoreSink, InProcessCoreSink
|
||||
|
||||
_global_core_sink: CoreSink | None = None
|
||||
|
||||
|
||||
def set_core_sink(sink: CoreSink) -> None:
|
||||
"""设置全局 core sink"""
|
||||
global _global_core_sink
|
||||
_global_core_sink = sink
|
||||
|
||||
|
||||
def get_core_sink() -> CoreSink:
|
||||
"""获取全局 core sink"""
|
||||
global _global_core_sink
|
||||
if _global_core_sink is None:
|
||||
raise RuntimeError("Core sink 尚未初始化")
|
||||
return _global_core_sink
|
||||
|
||||
|
||||
async def push_outgoing(envelope) -> None:
|
||||
"""将消息推送到 core sink 的 outgoing 通道"""
|
||||
sink = get_core_sink()
|
||||
push = getattr(sink, "push_outgoing", None)
|
||||
if push is None:
|
||||
raise RuntimeError("当前 core sink 不支持 push_outgoing 方法")
|
||||
await push(envelope)
|
||||
|
||||
__all__ = ["set_core_sink", "get_core_sink", "push_outgoing"]
|
||||
@@ -1,15 +1,22 @@
|
||||
"""
|
||||
MessageEnvelope 转换器
|
||||
MessageEnvelope converter between mofox_bus schema and internal message structures.
|
||||
|
||||
将 mofox_bus 的 MessageEnvelope 转换为 MoFox Bot 内部使用的消息格式。
|
||||
"""
|
||||
- 优先处理 maim_message 风格的 message_info + message_segment。
|
||||
- 兼容旧版 content/sender/channel 结构,方便逐步迁移。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from mofox_bus import MessageEnvelope, BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, UserInfo
|
||||
from mofox_bus import (
|
||||
BaseMessageInfo,
|
||||
MessageBase,
|
||||
MessageEnvelope,
|
||||
Seg,
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
)
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -17,151 +24,221 @@ logger = get_logger("envelope_converter")
|
||||
|
||||
|
||||
class EnvelopeConverter:
|
||||
"""MessageEnvelope 到内部消息格式的转换器"""
|
||||
"""MessageEnvelope <-> MessageBase converter."""
|
||||
|
||||
@staticmethod
|
||||
def to_message_base(envelope: MessageEnvelope) -> MessageBase:
|
||||
"""
|
||||
将 MessageEnvelope 转换为 MessageBase
|
||||
|
||||
Args:
|
||||
envelope: 统一的消息信封
|
||||
|
||||
Returns:
|
||||
MessageBase: 内部消息格式
|
||||
Convert MessageEnvelope to MessageBase.
|
||||
"""
|
||||
try:
|
||||
# 提取基本信息
|
||||
platform = envelope["platform"]
|
||||
channel = envelope["channel"]
|
||||
sender = envelope["sender"]
|
||||
content = envelope["content"]
|
||||
|
||||
# 创建 UserInfo
|
||||
user_info = UserInfo(
|
||||
user_id=sender["user_id"],
|
||||
user_nickname=sender.get("display_name", sender["user_id"]),
|
||||
user_avatar=sender.get("avatar_url"),
|
||||
)
|
||||
|
||||
# 创建 GroupInfo (如果是群组消息)
|
||||
group_info: Optional[GroupInfo] = None
|
||||
if channel["channel_type"] in ("group", "supergroup", "room"):
|
||||
group_info = GroupInfo(
|
||||
group_id=channel["channel_id"],
|
||||
group_name=channel.get("title", channel["channel_id"]),
|
||||
)
|
||||
|
||||
# 创建 BaseMessageInfo
|
||||
message_info = BaseMessageInfo(
|
||||
platform=platform,
|
||||
chat_type="group" if group_info else "private",
|
||||
message_id=envelope["id"],
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
timestamp=envelope["timestamp_ms"] / 1000.0, # 转换为秒
|
||||
)
|
||||
|
||||
# 转换 Content 为 Seg 列表
|
||||
segments = EnvelopeConverter._content_to_segments(content)
|
||||
|
||||
# 创建 MessageBase
|
||||
message_base = MessageBase(
|
||||
# 优先使用 maim_message 样式字段
|
||||
info_payload = envelope.get("message_info") or {}
|
||||
seg_payload = envelope.get("message_segment") or envelope.get("message_chain")
|
||||
|
||||
if info_payload:
|
||||
message_info = BaseMessageInfo.from_dict(info_payload)
|
||||
else:
|
||||
message_info = EnvelopeConverter._build_info_from_legacy(envelope)
|
||||
|
||||
if seg_payload is None:
|
||||
seg_list = EnvelopeConverter._content_to_segments(envelope.get("content"))
|
||||
seg_payload = seg_list
|
||||
|
||||
message_segment = EnvelopeConverter._ensure_seg(seg_payload)
|
||||
raw_message = envelope.get("raw_message") or envelope.get("raw_platform_message")
|
||||
|
||||
return MessageBase(
|
||||
message_info=message_info,
|
||||
message=segments,
|
||||
message_segment=message_segment,
|
||||
raw_message=raw_message,
|
||||
)
|
||||
|
||||
# 保存原始 envelope 到 raw 字段
|
||||
if hasattr(message_base, "raw"):
|
||||
message_base.raw = envelope
|
||||
|
||||
return message_base
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换 MessageEnvelope 失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _content_to_segments(content: Dict[str, Any]) -> List[Seg]:
|
||||
def _build_info_from_legacy(envelope: MessageEnvelope) -> BaseMessageInfo:
|
||||
"""将 legacy 字段映射为 BaseMessageInfo。"""
|
||||
platform = envelope.get("platform")
|
||||
channel = envelope.get("channel") or {}
|
||||
sender = envelope.get("sender") or {}
|
||||
|
||||
message_id = envelope.get("id") or envelope.get("message_id")
|
||||
timestamp_ms = envelope.get("timestamp_ms")
|
||||
time_value = (timestamp_ms / 1000.0) if timestamp_ms is not None else None
|
||||
|
||||
group_info: Optional[GroupInfo] = None
|
||||
channel_type = channel.get("channel_type")
|
||||
if channel_type in ("group", "supergroup", "room"):
|
||||
group_info = GroupInfo(
|
||||
platform=platform,
|
||||
group_id=channel.get("channel_id"),
|
||||
group_name=channel.get("title"),
|
||||
)
|
||||
|
||||
user_info: Optional[UserInfo] = None
|
||||
if sender:
|
||||
user_info = UserInfo(
|
||||
platform=platform,
|
||||
user_id=str(sender.get("user_id")) if sender.get("user_id") is not None else None,
|
||||
user_nickname=sender.get("display_name") or sender.get("user_nickname"),
|
||||
user_avatar=sender.get("avatar_url"),
|
||||
)
|
||||
|
||||
return BaseMessageInfo(
|
||||
platform=platform,
|
||||
message_id=message_id,
|
||||
time=time_value,
|
||||
group_info=group_info,
|
||||
user_info=user_info,
|
||||
additional_config=envelope.get("metadata"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_seg(payload: Any) -> Seg:
|
||||
"""将任意 payload 转为 Seg dataclass。"""
|
||||
if isinstance(payload, Seg):
|
||||
return payload
|
||||
if isinstance(payload, list):
|
||||
# 直接传入 Seg 列表或 seglist data
|
||||
return Seg(type="seglist", data=[EnvelopeConverter._ensure_seg(item) for item in payload])
|
||||
if isinstance(payload, dict):
|
||||
seg_type = payload.get("type") or "text"
|
||||
data = payload.get("data")
|
||||
if seg_type == "seglist" and isinstance(data, list):
|
||||
data = [EnvelopeConverter._ensure_seg(item) for item in data]
|
||||
return Seg(type=seg_type, data=data)
|
||||
# 兜底:转成文本片段
|
||||
return Seg(type="text", data="" if payload is None else str(payload))
|
||||
|
||||
@staticmethod
|
||||
def _flatten_segments(seg: Seg) -> List[Seg]:
|
||||
"""将 Seg/seglist 打平成列表,便于旧 content 转换。"""
|
||||
if seg.type == "seglist" and isinstance(seg.data, list):
|
||||
return [item if isinstance(item, Seg) else EnvelopeConverter._ensure_seg(item) for item in seg.data]
|
||||
return [seg]
|
||||
|
||||
@staticmethod
|
||||
def _content_to_segments(content: Any) -> List[Seg]:
|
||||
"""
|
||||
将 Content 转换为 Seg 列表
|
||||
|
||||
Args:
|
||||
content: 消息内容
|
||||
|
||||
Returns:
|
||||
List[Seg]: 消息段列表
|
||||
Convert legacy Content (type/data/metadata) to a flat list of Seg.
|
||||
"""
|
||||
segments: List[Seg] = []
|
||||
content_type = content.get("type")
|
||||
|
||||
if content_type == "text":
|
||||
# 文本消息
|
||||
text = content.get("text", "")
|
||||
segments.append(Seg.text(text))
|
||||
|
||||
elif content_type == "image":
|
||||
# 图片消息
|
||||
url = content.get("url", "")
|
||||
file_id = content.get("file_id")
|
||||
segments.append(Seg.image(url if url else file_id))
|
||||
|
||||
elif content_type == "audio":
|
||||
# 音频消息
|
||||
url = content.get("url", "")
|
||||
file_id = content.get("file_id")
|
||||
segments.append(Seg.record(url if url else file_id))
|
||||
|
||||
elif content_type == "video":
|
||||
# 视频消息
|
||||
url = content.get("url", "")
|
||||
file_id = content.get("file_id")
|
||||
segments.append(Seg.video(url if url else file_id))
|
||||
|
||||
elif content_type == "file":
|
||||
# 文件消息
|
||||
url = content.get("url", "")
|
||||
file_name = content.get("file_name", "file")
|
||||
# 使用 text 表示文件(或者可以自定义一个 file seg type)
|
||||
segments.append(Seg.text(f"[文件: {file_name}]"))
|
||||
|
||||
elif content_type == "command":
|
||||
# 命令消息
|
||||
name = content.get("name", "")
|
||||
args = content.get("args", {})
|
||||
# 重构为文本格式
|
||||
cmd_text = f"/{name}"
|
||||
if args:
|
||||
cmd_text += " " + " ".join(f"{k}={v}" for k, v in args.items())
|
||||
segments.append(Seg.text(cmd_text))
|
||||
|
||||
elif content_type == "event":
|
||||
# 事件消息 - 转换为文本表示
|
||||
event_type = content.get("event_type", "unknown")
|
||||
segments.append(Seg.text(f"[事件: {event_type}]"))
|
||||
|
||||
elif content_type == "system":
|
||||
# 系统消息
|
||||
text = content.get("text", "")
|
||||
segments.append(Seg.text(f"[系统] {text}"))
|
||||
|
||||
else:
|
||||
# 未知类型 - 转换为文本
|
||||
|
||||
def _walk(node: Any) -> None:
|
||||
if node is None:
|
||||
return
|
||||
if isinstance(node, list):
|
||||
for item in node:
|
||||
_walk(item)
|
||||
return
|
||||
if not isinstance(node, dict):
|
||||
logger.warning("未知的 content 节点类型: %s", type(node))
|
||||
return
|
||||
|
||||
content_type = node.get("type")
|
||||
data = node.get("data")
|
||||
metadata = node.get("metadata") or {}
|
||||
|
||||
if content_type == "collection":
|
||||
items = data if isinstance(data, list) else node.get("items", [])
|
||||
for item in items:
|
||||
_walk(item)
|
||||
return
|
||||
|
||||
if content_type in ("text", "at"):
|
||||
subtype = metadata.get("subtype") or ("at" if content_type == "at" else None)
|
||||
text = "" if data is None else str(data)
|
||||
if subtype in ("at", "mention"):
|
||||
user_info = metadata.get("user") or {}
|
||||
seg_data: Dict[str, Any] = {
|
||||
"user_id": user_info.get("id") or user_info.get("user_id"),
|
||||
"user_name": user_info.get("name") or user_info.get("display_name"),
|
||||
"text": text,
|
||||
"raw": user_info.get("raw") or user_info if user_info else None,
|
||||
}
|
||||
if any(v is not None for v in seg_data.values()):
|
||||
segments.append(Seg(type="at", data=seg_data))
|
||||
else:
|
||||
segments.append(Seg(type="at", data=text))
|
||||
else:
|
||||
segments.append(Seg(type="text", data=text))
|
||||
return
|
||||
|
||||
if content_type == "image":
|
||||
url = ""
|
||||
if isinstance(data, dict):
|
||||
url = data.get("url") or data.get("file") or data.get("file_id") or ""
|
||||
elif data is not None:
|
||||
url = str(data)
|
||||
segments.append(Seg(type="image", data=url))
|
||||
return
|
||||
|
||||
if content_type == "audio":
|
||||
url = ""
|
||||
if isinstance(data, dict):
|
||||
url = data.get("url") or data.get("file") or data.get("file_id") or ""
|
||||
elif data is not None:
|
||||
url = str(data)
|
||||
segments.append(Seg(type="record", data=url))
|
||||
return
|
||||
|
||||
if content_type == "video":
|
||||
url = ""
|
||||
if isinstance(data, dict):
|
||||
url = data.get("url") or data.get("file") or data.get("file_id") or ""
|
||||
elif data is not None:
|
||||
url = str(data)
|
||||
segments.append(Seg(type="video", data=url))
|
||||
return
|
||||
|
||||
if content_type == "file":
|
||||
file_name = ""
|
||||
if isinstance(data, dict):
|
||||
file_name = data.get("file_name") or data.get("name") or ""
|
||||
text = file_name or "[file]"
|
||||
segments.append(Seg(type="text", data=text))
|
||||
return
|
||||
|
||||
if content_type == "command":
|
||||
name = ""
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(data, dict):
|
||||
name = data.get("name", "")
|
||||
args = data.get("args", {}) or {}
|
||||
else:
|
||||
name = str(data or "")
|
||||
cmd_text = f"/{name}" if name else "/command"
|
||||
if args:
|
||||
cmd_text += " " + " ".join(f"{k}={v}" for k, v in args.items())
|
||||
segments.append(Seg(type="text", data=cmd_text))
|
||||
return
|
||||
|
||||
if content_type == "event":
|
||||
event_type = ""
|
||||
if isinstance(data, dict):
|
||||
event_type = data.get("event_type", "")
|
||||
else:
|
||||
event_type = str(data or "")
|
||||
segments.append(Seg(type="text", data=f"[事件: {event_type or 'unknown'}]"))
|
||||
return
|
||||
|
||||
if content_type == "system":
|
||||
text = "" if data is None else str(data)
|
||||
segments.append(Seg(type="text", data=f"[系统] {text}"))
|
||||
return
|
||||
|
||||
logger.warning(f"未知的消息类型: {content_type}")
|
||||
segments.append(Seg.text(f"[未知消息类型: {content_type}]"))
|
||||
|
||||
segments.append(Seg(type="text", data=f"[未知消息类型: {content_type}]"))
|
||||
|
||||
_walk(content)
|
||||
return segments
|
||||
|
||||
@staticmethod
|
||||
def to_legacy_dict(envelope: MessageEnvelope) -> Dict[str, Any]:
|
||||
"""
|
||||
将 MessageEnvelope 转换为旧版字典格式(用于向后兼容)
|
||||
|
||||
Args:
|
||||
envelope: 统一的消息信封
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 旧版消息字典
|
||||
Convert MessageEnvelope to legacy dict for backward compatibility.
|
||||
"""
|
||||
message_base = EnvelopeConverter.to_message_base(envelope)
|
||||
return message_base.to_dict()
|
||||
@@ -169,61 +246,45 @@ class EnvelopeConverter:
|
||||
@staticmethod
|
||||
def from_message_base(message: MessageBase, direction: str = "outgoing") -> MessageEnvelope:
|
||||
"""
|
||||
将 MessageBase 转换为 MessageEnvelope (反向转换)
|
||||
|
||||
Args:
|
||||
message: 内部消息格式
|
||||
direction: 消息方向 ("incoming" 或 "outgoing")
|
||||
|
||||
Returns:
|
||||
MessageEnvelope: 统一的消息信封
|
||||
Convert MessageBase to MessageEnvelope (maim_message style preferred).
|
||||
"""
|
||||
try:
|
||||
message_info = message.message_info
|
||||
user_info = message_info.user_info
|
||||
group_info = message_info.group_info
|
||||
|
||||
# 创建 SenderInfo
|
||||
sender = {
|
||||
"user_id": user_info.user_id,
|
||||
"role": "assistant" if direction == "outgoing" else "user",
|
||||
}
|
||||
if user_info.user_nickname:
|
||||
sender["display_name"] = user_info.user_nickname
|
||||
if user_info.user_avatar:
|
||||
sender["avatar_url"] = user_info.user_avatar
|
||||
|
||||
# 创建 ChannelInfo
|
||||
if group_info:
|
||||
channel = {
|
||||
"channel_id": group_info.group_id,
|
||||
"channel_type": "group",
|
||||
}
|
||||
if group_info.group_name:
|
||||
channel["title"] = group_info.group_name
|
||||
else:
|
||||
channel = {
|
||||
"channel_id": user_info.user_id,
|
||||
"channel_type": "private",
|
||||
}
|
||||
|
||||
# 转换 segments 为 Content
|
||||
content = EnvelopeConverter._segments_to_content(message.message)
|
||||
|
||||
# 创建 MessageEnvelope
|
||||
info_dict = message.message_info.to_dict()
|
||||
seg_dict = message.message_segment.to_dict()
|
||||
|
||||
envelope: MessageEnvelope = {
|
||||
"id": message_info.message_id,
|
||||
"direction": direction,
|
||||
"platform": message_info.platform,
|
||||
"timestamp_ms": int(message_info.timestamp * 1000),
|
||||
"channel": channel,
|
||||
"sender": sender,
|
||||
"content": content,
|
||||
"conversation_id": group_info.group_id if group_info else user_info.user_id,
|
||||
"message_info": info_dict,
|
||||
"message_segment": seg_dict,
|
||||
"platform": info_dict.get("platform"),
|
||||
"message_id": info_dict.get("message_id"),
|
||||
"schema_version": 1,
|
||||
}
|
||||
|
||||
|
||||
if message.message_info.time is not None:
|
||||
envelope["timestamp_ms"] = int(message.message_info.time * 1000)
|
||||
if message.raw_message is not None:
|
||||
envelope["raw_message"] = message.raw_message
|
||||
|
||||
# legacy 补充,方便老代码继续工作
|
||||
segments = EnvelopeConverter._flatten_segments(message.message_segment)
|
||||
envelope["content"] = EnvelopeConverter._segments_to_content(segments)
|
||||
if message.message_info.user_info:
|
||||
envelope["sender"] = {
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"role": "assistant" if direction == "outgoing" else "user",
|
||||
"display_name": message.message_info.user_info.user_nickname,
|
||||
"avatar_url": getattr(message.message_info.user_info, "user_avatar", None),
|
||||
}
|
||||
if message.message_info.group_info:
|
||||
envelope["channel"] = {
|
||||
"channel_id": message.message_info.group_info.group_id,
|
||||
"channel_type": "group",
|
||||
"title": message.message_info.group_info.group_name,
|
||||
}
|
||||
|
||||
return envelope
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换 MessageBase 失败: {e}", exc_info=True)
|
||||
raise
|
||||
@@ -231,45 +292,50 @@ class EnvelopeConverter:
|
||||
@staticmethod
|
||||
def _segments_to_content(segments: List[Seg]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Seg 列表转换为 Content
|
||||
|
||||
Args:
|
||||
segments: 消息段列表
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 消息内容
|
||||
Convert Seg list to legacy Content (type/data/metadata).
|
||||
"""
|
||||
if not segments:
|
||||
return {"type": "text", "text": ""}
|
||||
|
||||
# 简化处理:如果有多个段,合并为文本
|
||||
return {"type": "text", "data": ""}
|
||||
|
||||
def _seg_to_content(seg: Seg) -> Dict[str, Any]:
|
||||
data = seg.data
|
||||
|
||||
if seg.type == "text":
|
||||
return {"type": "text", "data": data}
|
||||
|
||||
if seg.type == "at":
|
||||
content: Dict[str, Any] = {"type": "text", "data": ""}
|
||||
metadata: Dict[str, Any] = {"subtype": "at"}
|
||||
if isinstance(data, dict):
|
||||
content["data"] = data.get("text", "")
|
||||
user = {
|
||||
"id": data.get("user_id"),
|
||||
"name": data.get("user_name"),
|
||||
"raw": data.get("raw"),
|
||||
}
|
||||
if any(v is not None for v in user.values()):
|
||||
metadata["user"] = user
|
||||
else:
|
||||
content["data"] = data
|
||||
if metadata:
|
||||
content["metadata"] = metadata
|
||||
return content
|
||||
|
||||
if seg.type == "image":
|
||||
return {"type": "image", "data": data}
|
||||
|
||||
if seg.type in ("record", "voice", "audio"):
|
||||
return {"type": "audio", "data": data}
|
||||
|
||||
if seg.type == "video":
|
||||
return {"type": "video", "data": data}
|
||||
|
||||
return {"type": seg.type, "data": data}
|
||||
|
||||
if len(segments) == 1:
|
||||
seg = segments[0]
|
||||
|
||||
if seg.type == "text":
|
||||
return {"type": "text", "text": seg.data.get("text", "")}
|
||||
elif seg.type == "image":
|
||||
return {"type": "image", "url": seg.data.get("file", "")}
|
||||
elif seg.type == "record":
|
||||
return {"type": "audio", "url": seg.data.get("file", "")}
|
||||
elif seg.type == "video":
|
||||
return {"type": "video", "url": seg.data.get("file", "")}
|
||||
|
||||
# 多个段或未知类型 - 合并为文本
|
||||
text_parts = []
|
||||
for seg in segments:
|
||||
if seg.type == "text":
|
||||
text_parts.append(seg.data.get("text", ""))
|
||||
elif seg.type == "image":
|
||||
text_parts.append("[图片]")
|
||||
elif seg.type == "record":
|
||||
text_parts.append("[语音]")
|
||||
elif seg.type == "video":
|
||||
text_parts.append("[视频]")
|
||||
else:
|
||||
text_parts.append(f"[{seg.type}]")
|
||||
|
||||
return {"type": "text", "text": "".join(text_parts)}
|
||||
return _seg_to_content(segments[0])
|
||||
|
||||
return {"type": "collection", "data": [_seg_to_content(seg) for seg in segments]}
|
||||
|
||||
|
||||
__all__ = ["EnvelopeConverter"]
|
||||
|
||||
@@ -10,72 +10,54 @@ from .adapter_utils import (
|
||||
AdapterTransportOptions,
|
||||
AdapterBase,
|
||||
BatchDispatcher,
|
||||
CoreSink,
|
||||
CoreMessageSink,
|
||||
HttpAdapterOptions,
|
||||
InProcessCoreSink,
|
||||
ProcessCoreSink,
|
||||
ProcessCoreSinkServer,
|
||||
WebSocketLike,
|
||||
WebSocketAdapterOptions,
|
||||
)
|
||||
from .api import MessageClient, MessageServer
|
||||
from .codec import dumps_message, dumps_messages, loads_message, loads_messages
|
||||
from .message_models import BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, TemplateInfo, UserInfo
|
||||
from .builder import MessageBuilder
|
||||
from .router import RouteConfig, Router, TargetConfig
|
||||
from .runtime import MessageProcessingError, MessageRoute, MessageRuntime
|
||||
from .runtime import MessageProcessingError, MessageRoute, MessageRuntime, Middleware
|
||||
from .types import (
|
||||
AudioContent,
|
||||
ChannelInfo,
|
||||
CommandContent,
|
||||
Content,
|
||||
ContentType,
|
||||
EventContent,
|
||||
EventType,
|
||||
FileContent,
|
||||
ImageContent,
|
||||
FormatInfoPayload,
|
||||
GroupInfoPayload,
|
||||
MessageDirection,
|
||||
MessageEnvelope,
|
||||
Role,
|
||||
SenderInfo,
|
||||
SystemContent,
|
||||
TextContent,
|
||||
VideoContent,
|
||||
MessageInfoPayload,
|
||||
SegPayload,
|
||||
|
||||
TemplateInfoPayload,
|
||||
UserInfoPayload,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# TypedDict model
|
||||
"AudioContent",
|
||||
"ChannelInfo",
|
||||
"CommandContent",
|
||||
"Content",
|
||||
"ContentType",
|
||||
"EventContent",
|
||||
"EventType",
|
||||
"FileContent",
|
||||
"ImageContent",
|
||||
"MessageDirection",
|
||||
"MessageEnvelope",
|
||||
"Role",
|
||||
"SenderInfo",
|
||||
"SystemContent",
|
||||
"TextContent",
|
||||
"VideoContent",
|
||||
"SegPayload",
|
||||
"UserInfoPayload",
|
||||
"GroupInfoPayload",
|
||||
"FormatInfoPayload",
|
||||
"TemplateInfoPayload",
|
||||
"MessageInfoPayload",
|
||||
# Codec helpers
|
||||
"codec",
|
||||
"dumps_message",
|
||||
"dumps_messages",
|
||||
"loads_message",
|
||||
"loads_messages",
|
||||
"MessageBuilder",
|
||||
# Runtime / routing
|
||||
"MessageRoute",
|
||||
"MessageRuntime",
|
||||
"MessageProcessingError",
|
||||
# Message dataclasses
|
||||
"Seg",
|
||||
"GroupInfo",
|
||||
"UserInfo",
|
||||
"FormatInfo",
|
||||
"TemplateInfo",
|
||||
"BaseMessageInfo",
|
||||
"MessageBase",
|
||||
"Middleware",
|
||||
# Server/client/router
|
||||
"MessageServer",
|
||||
"MessageClient",
|
||||
@@ -86,8 +68,11 @@ __all__ = [
|
||||
"AdapterTransportOptions",
|
||||
"AdapterBase",
|
||||
"BatchDispatcher",
|
||||
"CoreSink",
|
||||
"CoreMessageSink",
|
||||
"InProcessCoreSink",
|
||||
"ProcessCoreSink",
|
||||
"ProcessCoreSinkServer",
|
||||
"WebSocketLike",
|
||||
"WebSocketAdapterOptions",
|
||||
"HttpAdapterOptions",
|
||||
|
||||
@@ -2,6 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Protocol
|
||||
|
||||
@@ -11,6 +13,11 @@ import websockets
|
||||
|
||||
from .types import MessageEnvelope
|
||||
|
||||
logger = logging.getLogger("mofox_bus.adapter")
|
||||
|
||||
|
||||
OutgoingHandler = Callable[[MessageEnvelope], Awaitable[None]]
|
||||
|
||||
|
||||
class CoreMessageSink(Protocol):
|
||||
async def send(self, message: MessageEnvelope) -> None: ...
|
||||
@@ -18,6 +25,22 @@ class CoreMessageSink(Protocol):
|
||||
async def send_many(self, messages: list[MessageEnvelope]) -> None: ... # pragma: no cover - optional
|
||||
|
||||
|
||||
class CoreSink(CoreMessageSink, Protocol):
|
||||
"""
|
||||
双向 CoreSink 协议:
|
||||
- send/send_many: 适配器 → 核心(incoming)
|
||||
- push_outgoing: 核心 → 适配器(outgoing)
|
||||
"""
|
||||
|
||||
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None: ...
|
||||
|
||||
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None: ...
|
||||
|
||||
async def push_outgoing(self, envelope: MessageEnvelope) -> None: ...
|
||||
|
||||
async def close(self) -> None: ... # pragma: no cover - lifecycle hook
|
||||
|
||||
|
||||
class WebSocketLike(Protocol):
|
||||
def __aiter__(self) -> AsyncIterator[str | bytes]: ...
|
||||
|
||||
@@ -56,7 +79,7 @@ class AdapterBase:
|
||||
|
||||
platform: str = "unknown"
|
||||
|
||||
def __init__(self, core_sink: CoreMessageSink, transport: AdapterTransportOptions = None):
|
||||
def __init__(self, core_sink: CoreSink, transport: AdapterTransportOptions = None):
|
||||
"""
|
||||
Args:
|
||||
core_sink: 核心消息入口,通常是 InProcessCoreSink 或自定义客户端。
|
||||
@@ -70,14 +93,31 @@ class AdapterBase:
|
||||
self._http_site: aiohttp_web.BaseSite | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""根据配置自动启动 WS/HTTP 监听。"""
|
||||
"""启动适配器的传输层监听(如果配置了传输选项)。"""
|
||||
if hasattr(self.core_sink, "set_outgoing_handler"):
|
||||
try:
|
||||
self.core_sink.set_outgoing_handler(self._on_outgoing_from_core)
|
||||
except Exception:
|
||||
logger.exception("Failed to register outgoing handler on core sink")
|
||||
if isinstance(self._transport_config, WebSocketAdapterOptions):
|
||||
await self._start_ws_transport(self._transport_config)
|
||||
elif isinstance(self._transport_config, HttpAdapterOptions):
|
||||
await self._start_http_transport(self._transport_config)
|
||||
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止自动管理的传输层。"""
|
||||
"""停止适配器的传输层监听(如果配置了传输选项)。"""
|
||||
remove = getattr(self.core_sink, "remove_outgoing_handler", None)
|
||||
if callable(remove):
|
||||
try:
|
||||
remove(self._on_outgoing_from_core)
|
||||
except Exception:
|
||||
logger.exception("Failed to detach outgoing handler on core sink")
|
||||
elif hasattr(self.core_sink, "set_outgoing_handler"):
|
||||
try:
|
||||
self.core_sink.set_outgoing_handler(None) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
logger.exception("Failed to detach outgoing handler on core sink")
|
||||
if self._ws_task:
|
||||
self._ws_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
@@ -95,12 +135,12 @@ class AdapterBase:
|
||||
|
||||
async def on_platform_message(self, raw: Any) -> None:
|
||||
"""处理平台下发的单条消息并交给核心。"""
|
||||
envelope = self.from_platform_message(raw)
|
||||
envelope = await _maybe_await(self.from_platform_message(raw))
|
||||
await self.core_sink.send(envelope)
|
||||
|
||||
async def on_platform_messages(self, raw_messages: list[Any]) -> None:
|
||||
"""批量推送入口,内部自动批量或逐条送入核心。"""
|
||||
envelopes = [self.from_platform_message(raw) for raw in raw_messages]
|
||||
envelopes = [await _maybe_await(self.from_platform_message(raw)) for raw in raw_messages]
|
||||
await _send_many(self.core_sink, envelopes)
|
||||
|
||||
async def send_to_platform(self, envelope: MessageEnvelope) -> None:
|
||||
@@ -112,7 +152,14 @@ class AdapterBase:
|
||||
for env in envelopes:
|
||||
await self._send_platform_message(env)
|
||||
|
||||
def from_platform_message(self, raw: Any) -> MessageEnvelope:
|
||||
async def _on_outgoing_from_core(self, envelope: MessageEnvelope) -> None:
|
||||
"""核心生成 outgoing envelope 时的内部处理逻辑"""
|
||||
platform = envelope.get("platform") or envelope.get("message_info", {}).get("platform")
|
||||
if platform and platform != getattr(self, "platform", None):
|
||||
return
|
||||
await self._send_platform_message(envelope)
|
||||
|
||||
def from_platform_message(self, raw: Any) -> MessageEnvelope | Awaitable[MessageEnvelope]:
|
||||
"""子类必须实现:将平台原始结构转换为统一 MessageEnvelope。"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -175,13 +222,22 @@ class AdapterBase:
|
||||
return orjson.dumps({"type": "send", "payload": envelope})
|
||||
|
||||
|
||||
class InProcessCoreSink:
|
||||
class InProcessCoreSink(CoreSink):
|
||||
"""
|
||||
简单的进程内 sink,实现 CoreMessageSink 协议。
|
||||
进程内核心消息 sink,实现 CoreSink 协议。
|
||||
"""
|
||||
|
||||
def __init__(self, handler: Callable[[MessageEnvelope], Awaitable[None]]):
|
||||
self._handler = handler
|
||||
self._outgoing_handlers: set[OutgoingHandler] = set()
|
||||
|
||||
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None:
|
||||
if handler is None:
|
||||
return
|
||||
self._outgoing_handlers.add(handler)
|
||||
|
||||
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None:
|
||||
self._outgoing_handlers.discard(handler)
|
||||
|
||||
async def send(self, message: MessageEnvelope) -> None:
|
||||
await self._handler(message)
|
||||
@@ -190,6 +246,140 @@ class InProcessCoreSink:
|
||||
for message in messages:
|
||||
await self._handler(message)
|
||||
|
||||
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
|
||||
if not self._outgoing_handlers:
|
||||
logger.debug("Outgoing envelope dropped: no handler registered")
|
||||
return
|
||||
for callback in list(self._outgoing_handlers):
|
||||
await callback(envelope)
|
||||
|
||||
async def close(self) -> None: # pragma: no cover - symmetry
|
||||
self._outgoing_handlers.clear()
|
||||
|
||||
|
||||
class ProcessCoreSink(CoreSink):
|
||||
"""
|
||||
进程间核心消息 sink,实现 CoreSink 协议,使用 multiprocessing.Queue 初始化
|
||||
"""
|
||||
|
||||
_CONTROL_STOP = {"__core_sink_control__": "stop"}
|
||||
|
||||
def __init__(self, *, to_core_queue: mp.Queue, from_core_queue: mp.Queue) -> None:
|
||||
self._to_core_queue = to_core_queue
|
||||
self._from_core_queue = from_core_queue
|
||||
self._outgoing_handler: OutgoingHandler | None = None
|
||||
self._closed = False
|
||||
self._listener_task: asyncio.Task | None = None
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None:
|
||||
self._outgoing_handler = handler
|
||||
if handler is not None and (self._listener_task is None or self._listener_task.done()):
|
||||
self._listener_task = self._loop.create_task(self._listen_from_core())
|
||||
|
||||
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None:
|
||||
if self._outgoing_handler is handler:
|
||||
self._outgoing_handler = None
|
||||
if self._listener_task and not self._listener_task.done():
|
||||
self._listener_task.cancel()
|
||||
|
||||
async def send(self, message: MessageEnvelope) -> None:
|
||||
await asyncio.to_thread(self._to_core_queue.put, {"kind": "incoming", "payload": message})
|
||||
|
||||
async def send_many(self, messages: list[MessageEnvelope]) -> None:
|
||||
for message in messages:
|
||||
await self.send(message)
|
||||
|
||||
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
|
||||
logger.debug("ProcessCoreSink.push_outgoing called in child; ignored")
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
await asyncio.to_thread(self._from_core_queue.put, self._CONTROL_STOP)
|
||||
if self._listener_task:
|
||||
self._listener_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._listener_task
|
||||
self._listener_task = None
|
||||
|
||||
async def _listen_from_core(self) -> None:
|
||||
while not self._closed:
|
||||
try:
|
||||
item = await asyncio.to_thread(self._from_core_queue.get)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
if item == self._CONTROL_STOP:
|
||||
break
|
||||
if isinstance(item, dict) and item.get("kind") == "outgoing":
|
||||
envelope = item.get("payload")
|
||||
if self._outgoing_handler:
|
||||
try:
|
||||
await self._outgoing_handler(envelope)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Failed to handle outgoing envelope in ProcessCoreSink")
|
||||
else:
|
||||
logger.debug("ProcessCoreSink received unknown payload: %r", item)
|
||||
|
||||
|
||||
class ProcessCoreSinkServer:
|
||||
"""
|
||||
进程间核心消息 sink 服务器,实现 CoreSink 协议,使用 multiprocessing.Queue 初始化。
|
||||
- 将传入的 incoming 消息转发给指定的 handler
|
||||
- 将接收到的 outgoing 消息放入 outgoing 队列
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
incoming_queue: mp.Queue,
|
||||
outgoing_queue: mp.Queue,
|
||||
core_handler: Callable[[MessageEnvelope], Awaitable[None]],
|
||||
name: str | None = None,
|
||||
) -> None:
|
||||
self._incoming_queue = incoming_queue
|
||||
self._outgoing_queue = outgoing_queue
|
||||
self._core_handler = core_handler
|
||||
self._task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
self._name = name or "adapter"
|
||||
|
||||
def start(self) -> None:
|
||||
if self._task is None or self._task.done():
|
||||
self._task = asyncio.create_task(self._consume_incoming())
|
||||
|
||||
async def _consume_incoming(self) -> None:
|
||||
while not self._closed:
|
||||
try:
|
||||
item = await asyncio.to_thread(self._incoming_queue.get)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
if isinstance(item, dict) and item.get("__core_sink_control__") == "stop":
|
||||
break
|
||||
if isinstance(item, dict) and item.get("kind") == "incoming":
|
||||
envelope = item.get("payload")
|
||||
try:
|
||||
await self._core_handler(envelope)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Failed to dispatch incoming envelope from %s", self._name)
|
||||
else:
|
||||
logger.debug("ProcessCoreSinkServer ignored unknown payload from %s: %r", self._name, item)
|
||||
|
||||
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
|
||||
await asyncio.to_thread(self._outgoing_queue.put, {"kind": "outgoing", "payload": envelope})
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
await asyncio.to_thread(self._incoming_queue.put, {"__core_sink_control__": "stop"})
|
||||
await asyncio.to_thread(self._outgoing_queue.put, ProcessCoreSink._CONTROL_STOP)
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
self._task = None
|
||||
|
||||
async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) -> None:
|
||||
send_many = getattr(sink, "send_many", None)
|
||||
@@ -200,11 +390,19 @@ async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) ->
|
||||
await sink.send(env)
|
||||
|
||||
|
||||
async def _maybe_await(result: Any) -> Any:
|
||||
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||
return await result
|
||||
return result
|
||||
|
||||
|
||||
class BatchDispatcher:
|
||||
"""
|
||||
将 send 操作合并为批量发送,适合网络 IO 密集场景。
|
||||
批量消息分发器,负责将消息批量发送到核心 sink。
|
||||
"""
|
||||
|
||||
_STOP = object()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sink: CoreMessageSink,
|
||||
@@ -215,56 +413,79 @@ class BatchDispatcher:
|
||||
self._sink = sink
|
||||
self._max_batch_size = max_batch_size
|
||||
self._flush_interval = flush_interval
|
||||
self._buffer: list[MessageEnvelope] = []
|
||||
self._lock = asyncio.Lock()
|
||||
self._flush_task: asyncio.Task | None = None
|
||||
self._queue: asyncio.Queue[MessageEnvelope | object] = asyncio.Queue()
|
||||
self._worker: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
|
||||
async def add(self, message: MessageEnvelope) -> None:
|
||||
async with self._lock:
|
||||
if self._closed:
|
||||
raise RuntimeError("Dispatcher closed")
|
||||
self._buffer.append(message)
|
||||
self._ensure_timer()
|
||||
if len(self._buffer) >= self._max_batch_size:
|
||||
await self._flush_locked()
|
||||
if self._closed:
|
||||
raise RuntimeError("Dispatcher closed")
|
||||
self._ensure_worker()
|
||||
await self._queue.put(message)
|
||||
|
||||
async def close(self) -> None:
|
||||
async with self._lock:
|
||||
self._closed = True
|
||||
await self._flush_locked()
|
||||
if self._flush_task:
|
||||
self._flush_task.cancel()
|
||||
self._flush_task = None
|
||||
|
||||
def _ensure_timer(self) -> None:
|
||||
if self._flush_task is not None and not self._flush_task.done():
|
||||
if self._closed:
|
||||
return
|
||||
loop = asyncio.get_running_loop()
|
||||
self._flush_task = loop.create_task(self._flush_loop())
|
||||
self._closed = True
|
||||
self._ensure_worker()
|
||||
await self._queue.put(self._STOP)
|
||||
if self._worker:
|
||||
await self._worker
|
||||
self._worker = None
|
||||
|
||||
async def _flush_loop(self) -> None:
|
||||
def _ensure_worker(self) -> None:
|
||||
if self._worker is not None and not self._worker.done():
|
||||
return
|
||||
self._worker = asyncio.create_task(self._worker_loop())
|
||||
|
||||
async def _worker_loop(self) -> None:
|
||||
buffer: list[MessageEnvelope] = []
|
||||
try:
|
||||
await asyncio.sleep(self._flush_interval)
|
||||
async with self._lock:
|
||||
await self._flush_locked()
|
||||
except asyncio.CancelledError: # pragma: no cover - timer cancellation
|
||||
pass
|
||||
while True:
|
||||
try:
|
||||
item = await asyncio.wait_for(self._queue.get(), timeout=self._flush_interval)
|
||||
except asyncio.TimeoutError:
|
||||
item = None
|
||||
|
||||
async def _flush_locked(self) -> None:
|
||||
if not self._buffer:
|
||||
if item is self._STOP:
|
||||
await self._flush_buffer(buffer)
|
||||
return
|
||||
if item is not None:
|
||||
buffer.append(item) # type: ignore[arg-type]
|
||||
|
||||
while len(buffer) < self._max_batch_size:
|
||||
try:
|
||||
item = self._queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
if item is self._STOP:
|
||||
await self._flush_buffer(buffer)
|
||||
return
|
||||
buffer.append(item) # type: ignore[arg-type]
|
||||
|
||||
if buffer and (len(buffer) >= self._max_batch_size or item is None):
|
||||
await self._flush_buffer(buffer)
|
||||
except asyncio.CancelledError: # pragma: no cover - worker cancellation
|
||||
if buffer:
|
||||
await self._flush_buffer(buffer)
|
||||
|
||||
async def _flush_buffer(self, buffer: list[MessageEnvelope]) -> None:
|
||||
if not buffer:
|
||||
return
|
||||
payload = list(self._buffer)
|
||||
self._buffer.clear()
|
||||
await self._sink.send_many(payload)
|
||||
|
||||
payload = list(buffer)
|
||||
buffer.clear()
|
||||
await _send_many(self._sink, payload)
|
||||
|
||||
__all__ = [
|
||||
"AdapterTransportOptions",
|
||||
"AdapterBase",
|
||||
"BatchDispatcher",
|
||||
"CoreSink",
|
||||
"CoreMessageSink",
|
||||
"HttpAdapterOptions",
|
||||
"InProcessCoreSink",
|
||||
"ProcessCoreSink",
|
||||
"ProcessCoreSinkServer",
|
||||
"WebSocketLike",
|
||||
"WebSocketAdapterOptions",
|
||||
]
|
||||
|
||||
@@ -11,10 +11,36 @@ import orjson
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
|
||||
from .message_models import MessageBase
|
||||
|
||||
MessagePayload = Dict[str, Any]
|
||||
MessageHandler = Callable[[MessagePayload], Awaitable[None] | None]
|
||||
DisconnectCallback = Callable[[str, str], Awaitable[None] | None]
|
||||
|
||||
|
||||
def _attach_raw_bytes(payload: Any, raw_bytes: bytes) -> Any:
|
||||
if isinstance(payload, dict):
|
||||
payload.setdefault("raw_bytes", raw_bytes)
|
||||
elif isinstance(payload, list):
|
||||
for item in payload:
|
||||
if isinstance(item, dict):
|
||||
item.setdefault("raw_bytes", raw_bytes)
|
||||
return payload
|
||||
|
||||
|
||||
def _encode_for_ws_send(message: Any, *, use_raw_bytes: bool = False) -> tuple[str | bytes, bool]:
|
||||
if isinstance(message, (bytes, bytearray)):
|
||||
return bytes(message), True
|
||||
if use_raw_bytes and isinstance(message, dict):
|
||||
raw = message.get("raw_bytes")
|
||||
if isinstance(raw, (bytes, bytearray)):
|
||||
return bytes(raw), True
|
||||
payload = message
|
||||
if isinstance(payload, dict) and "raw_bytes" in payload and not use_raw_bytes:
|
||||
payload = {k: v for k, v in payload.items() if k != "raw_bytes"}
|
||||
data = orjson.dumps(payload)
|
||||
if use_raw_bytes:
|
||||
return data, True
|
||||
return data.decode("utf-8"), False
|
||||
|
||||
|
||||
class BaseMessageHandler:
|
||||
@@ -60,6 +86,8 @@ class MessageServer(BaseMessageHandler):
|
||||
mode: Literal["ws", "tcp"] = "ws",
|
||||
custom_logger: logging.Logger | None = None,
|
||||
enable_custom_uvicorn_logger: bool = False,
|
||||
queue_maxsize: int = 1000,
|
||||
worker_count: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if mode != "ws":
|
||||
@@ -80,6 +108,9 @@ class MessageServer(BaseMessageHandler):
|
||||
self._conn_lock = asyncio.Lock()
|
||||
self._server: uvicorn.Server | None = None
|
||||
self._running = False
|
||||
self._message_queue: asyncio.Queue[MessagePayload] = asyncio.Queue(maxsize=queue_maxsize)
|
||||
self._worker_count = max(1, worker_count)
|
||||
self._worker_tasks: list[asyncio.Task] = []
|
||||
self._setup_routes()
|
||||
|
||||
def _setup_routes(self) -> None:
|
||||
@@ -97,21 +128,22 @@ class MessageServer(BaseMessageHandler):
|
||||
while True:
|
||||
msg = await websocket.receive()
|
||||
if msg["type"] == "websocket.receive":
|
||||
data = msg.get("text")
|
||||
if data is None and msg.get("bytes") is not None:
|
||||
data = msg["bytes"].decode("utf-8")
|
||||
if not data:
|
||||
raw_bytes = msg.get("bytes")
|
||||
if raw_bytes is None and msg.get("text") is not None:
|
||||
raw_bytes = msg["text"].encode("utf-8")
|
||||
if not raw_bytes:
|
||||
continue
|
||||
try:
|
||||
payload = orjson.loads(data)
|
||||
payload = orjson.loads(raw_bytes)
|
||||
except orjson.JSONDecodeError:
|
||||
logging.getLogger("mofox_bus.server").warning("Invalid JSON payload")
|
||||
continue
|
||||
payload = _attach_raw_bytes(payload, raw_bytes)
|
||||
if isinstance(payload, list):
|
||||
for item in payload:
|
||||
await self.process_message(item)
|
||||
await self._enqueue_message(item)
|
||||
else:
|
||||
await self.process_message(payload)
|
||||
await self._enqueue_message(payload)
|
||||
elif msg["type"] == "websocket.disconnect":
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
@@ -119,6 +151,49 @@ class MessageServer(BaseMessageHandler):
|
||||
finally:
|
||||
await self._remove_connection(websocket, platform)
|
||||
|
||||
async def _enqueue_message(self, payload: MessagePayload) -> None:
|
||||
if not self._worker_tasks:
|
||||
self._start_workers()
|
||||
try:
|
||||
self._message_queue.put_nowait(payload)
|
||||
except asyncio.QueueFull:
|
||||
logging.getLogger("mofox_bus.server").warning("Message queue full, dropping message")
|
||||
|
||||
def _start_workers(self) -> None:
|
||||
if self._worker_tasks:
|
||||
return
|
||||
self._running = True
|
||||
for _ in range(self._worker_count):
|
||||
task = asyncio.create_task(self._consumer_worker())
|
||||
self._worker_tasks.append(task)
|
||||
|
||||
async def _stop_workers(self) -> None:
|
||||
if not self._worker_tasks:
|
||||
return
|
||||
self._running = False
|
||||
for task in self._worker_tasks:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await asyncio.gather(*self._worker_tasks, return_exceptions=True)
|
||||
self._worker_tasks.clear()
|
||||
while not self._message_queue.empty():
|
||||
with contextlib.suppress(asyncio.QueueEmpty):
|
||||
self._message_queue.get_nowait()
|
||||
self._message_queue.task_done()
|
||||
|
||||
async def _consumer_worker(self) -> None:
|
||||
while self._running:
|
||||
try:
|
||||
payload = await self._message_queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
try:
|
||||
await self.process_message(payload)
|
||||
except Exception: # pragma: no cover - best effort logging
|
||||
logging.getLogger("mofox_bus.server").exception("Error processing message")
|
||||
finally:
|
||||
self._message_queue.task_done()
|
||||
|
||||
async def verify_token(self, token: str | None) -> bool:
|
||||
if not self._enable_token:
|
||||
return True
|
||||
@@ -145,33 +220,45 @@ class MessageServer(BaseMessageHandler):
|
||||
if platform and self._platform_connections.get(platform) is websocket:
|
||||
del self._platform_connections[platform]
|
||||
|
||||
async def broadcast_message(self, message: MessagePayload) -> None:
|
||||
data = orjson.dumps(message).decode("utf-8")
|
||||
async def broadcast_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> None:
|
||||
payload: MessagePayload | bytes = message
|
||||
data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes)
|
||||
async with self._conn_lock:
|
||||
targets = list(self._connections)
|
||||
for ws in targets:
|
||||
await ws.send_text(data)
|
||||
if is_binary:
|
||||
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
|
||||
else:
|
||||
await ws.send_text(data if isinstance(data, str) else data.decode("utf-8"))
|
||||
|
||||
async def broadcast_to_platform(self, platform: str, message: MessagePayload) -> None:
|
||||
async def broadcast_to_platform(
|
||||
self, platform: str, message: MessagePayload | bytes, *, use_raw_bytes: bool = False
|
||||
) -> None:
|
||||
ws = self._platform_connections.get(platform)
|
||||
if ws is None:
|
||||
raise RuntimeError(f"No active connection for platform {platform}")
|
||||
await ws.send_text(orjson.dumps(message).decode("utf-8"))
|
||||
payload: MessagePayload | bytes = message.to_dict() if isinstance(message, MessageBase) else message
|
||||
data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes)
|
||||
if is_binary:
|
||||
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
|
||||
else:
|
||||
await ws.send_text(data if isinstance(data, str) else data.decode("utf-8"))
|
||||
|
||||
async def send_message(self, message: MessageBase | MessagePayload) -> None:
|
||||
payload = message.to_dict() if isinstance(message, MessageBase) else message
|
||||
platform = payload.get("message_info", {}).get("platform")
|
||||
async def send_message(
|
||||
self, message: MessagePayload, *, prefer_raw_bytes: bool = False
|
||||
) -> None:
|
||||
platform = message.get("message_info", {}).get("platform")
|
||||
if not platform:
|
||||
raise ValueError("message_info.platform is required to route the message")
|
||||
await self.broadcast_to_platform(platform, payload)
|
||||
|
||||
await self.broadcast_to_platform(platform, message, use_raw_bytes=prefer_raw_bytes)
|
||||
|
||||
def run_sync(self) -> None:
|
||||
if not self._own_app:
|
||||
return
|
||||
asyncio.run(self.run())
|
||||
|
||||
async def run(self) -> None:
|
||||
self._running = True
|
||||
self._start_workers()
|
||||
if not self._own_app:
|
||||
return
|
||||
config = uvicorn.Config(
|
||||
@@ -191,6 +278,7 @@ class MessageServer(BaseMessageHandler):
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
await self._stop_workers()
|
||||
if self._server:
|
||||
self._server.should_exit = True
|
||||
await self._server.shutdown()
|
||||
@@ -217,7 +305,13 @@ class MessageClient(BaseMessageHandler):
|
||||
WebSocket 消息客户端,实现双向传输。
|
||||
"""
|
||||
|
||||
def __init__(self, mode: Literal["ws", "tcp"] = "ws") -> None:
|
||||
def __init__(
|
||||
self,
|
||||
mode: Literal["ws", "tcp"] = "ws",
|
||||
*,
|
||||
reconnect_interval: float = 5.0,
|
||||
logger: logging.Logger | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if mode != "ws":
|
||||
raise NotImplementedError("Only WebSocket mode is supported in mofox_bus")
|
||||
@@ -230,6 +324,9 @@ class MessageClient(BaseMessageHandler):
|
||||
self._token: str | None = None
|
||||
self._ssl_verify: str | None = None
|
||||
self._closed = False
|
||||
self._on_disconnect: DisconnectCallback | None = None
|
||||
self._reconnect_interval = reconnect_interval
|
||||
self._logger = logger or logging.getLogger("mofox_bus.client")
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
@@ -243,8 +340,12 @@ class MessageClient(BaseMessageHandler):
|
||||
self._platform = platform
|
||||
self._token = token
|
||||
self._ssl_verify = ssl_verify
|
||||
self._closed = False
|
||||
await self._establish_connection()
|
||||
|
||||
def set_disconnect_callback(self, callback: DisconnectCallback) -> None:
|
||||
self._on_disconnect = callback
|
||||
|
||||
async def _establish_connection(self) -> None:
|
||||
if self._session is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
@@ -257,17 +358,21 @@ class MessageClient(BaseMessageHandler):
|
||||
self._ws = await self._session.ws_connect(self._url, headers=headers, ssl=ssl_context)
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def _connect_once(self) -> None:
|
||||
await self._establish_connection()
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
assert self._ws is not None
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
|
||||
data = msg.data if isinstance(msg.data, str) else msg.data.decode("utf-8")
|
||||
raw_bytes = msg.data if isinstance(msg.data, (bytes, bytearray)) else msg.data.encode("utf-8")
|
||||
try:
|
||||
payload = orjson.loads(data)
|
||||
payload = orjson.loads(raw_bytes)
|
||||
except orjson.JSONDecodeError:
|
||||
logging.getLogger("mofox_bus.client").warning("Invalid JSON payload")
|
||||
continue
|
||||
payload = _attach_raw_bytes(payload, raw_bytes)
|
||||
if isinstance(payload, list):
|
||||
for item in payload:
|
||||
await self.process_message(item)
|
||||
@@ -278,23 +383,33 @@ class MessageClient(BaseMessageHandler):
|
||||
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
||||
pass
|
||||
finally:
|
||||
if not self._closed:
|
||||
await self._notify_disconnect("websocket disconnected")
|
||||
await self._reconnect()
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
async def run(self) -> None:
|
||||
if self._receive_task is None:
|
||||
await self._establish_connection()
|
||||
try:
|
||||
if self._receive_task:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
||||
pass
|
||||
self._closed = False
|
||||
while not self._closed:
|
||||
if self._receive_task is None:
|
||||
await self._establish_connection()
|
||||
task = self._receive_task
|
||||
if task is None:
|
||||
break
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
||||
raise
|
||||
|
||||
async def send_message(self, message: MessagePayload) -> bool:
|
||||
if self._ws is None or self._ws.closed:
|
||||
raise RuntimeError("WebSocket connection is not established")
|
||||
await self._ws.send_str(orjson.dumps(message).decode("utf-8"))
|
||||
async def send_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> bool:
|
||||
ws = await self._ensure_ws()
|
||||
data, is_binary = _encode_for_ws_send(message, use_raw_bytes=use_raw_bytes)
|
||||
if is_binary:
|
||||
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
|
||||
else:
|
||||
await ws.send_str(data if isinstance(data, str) else data.decode("utf-8"))
|
||||
return True
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
@@ -313,6 +428,42 @@ class MessageClient(BaseMessageHandler):
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def _notify_disconnect(self, reason: str) -> None:
|
||||
if self._on_disconnect is None:
|
||||
return
|
||||
try:
|
||||
result = self._on_disconnect(self._platform, reason)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception: # pragma: no cover - best effort notification
|
||||
logging.getLogger("mofox_bus.client").exception("Disconnect callback failed")
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
self._logger.info("WebSocket disconnected, retrying in %.1fs", self._reconnect_interval)
|
||||
await asyncio.sleep(self._reconnect_interval)
|
||||
await self._connect_once()
|
||||
|
||||
async def _ensure_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def _ensure_ws(self) -> aiohttp.ClientWebSocketResponse:
|
||||
if self._ws is None or self._ws.closed:
|
||||
await self._connect_once()
|
||||
assert self._ws is not None
|
||||
return self._ws
|
||||
|
||||
async def __aenter__(self) -> "MessageClient":
|
||||
if not self._url or not self._platform:
|
||||
raise RuntimeError("connect() must be called before using MessageClient as a context manager")
|
||||
await self._ensure_session()
|
||||
await self._ensure_ws()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
await self.stop()
|
||||
|
||||
|
||||
def _self_websocket(app: FastAPI, path: str):
|
||||
"""
|
||||
|
||||
110
src/mofox_bus/builder.py
Normal file
110
src/mofox_bus/builder.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .types import GroupInfoPayload, MessageEnvelope, MessageInfoPayload, SegPayload, UserInfoPayload
|
||||
|
||||
|
||||
class MessageBuilder:
|
||||
"""
|
||||
Fluent helper to build MessageEnvelope safely with type hints.
|
||||
|
||||
Example:
|
||||
msg = (
|
||||
MessageBuilder()
|
||||
.text("Hello")
|
||||
.image("http://example.com/1.png")
|
||||
.to_user("123", platform="qq")
|
||||
.build()
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._direction: str = "outgoing"
|
||||
self._message_info: MessageInfoPayload = {}
|
||||
self._segments: List[SegPayload] = []
|
||||
self._metadata: Dict[str, Any] | None = None
|
||||
self._timestamp_ms: int | None = None
|
||||
self._message_id: str | None = None
|
||||
|
||||
def direction(self, value: str) -> "MessageBuilder":
|
||||
self._direction = value
|
||||
return self
|
||||
|
||||
def message_id(self, value: str) -> "MessageBuilder":
|
||||
self._message_id = value
|
||||
return self
|
||||
|
||||
def timestamp_ms(self, value: int | None = None) -> "MessageBuilder":
|
||||
self._timestamp_ms = value or int(time.time() * 1000)
|
||||
return self
|
||||
|
||||
def metadata(self, value: Dict[str, Any]) -> "MessageBuilder":
|
||||
self._metadata = value
|
||||
return self
|
||||
|
||||
def platform(self, value: str) -> "MessageBuilder":
|
||||
self._message_info["platform"] = value
|
||||
return self
|
||||
|
||||
def from_user(self, user_id: str, *, platform: str | None = None, nickname: str | None = None) -> "MessageBuilder":
|
||||
if platform:
|
||||
self.platform(platform)
|
||||
user_info: UserInfoPayload = {"user_id": user_id}
|
||||
if nickname:
|
||||
user_info["user_nickname"] = nickname
|
||||
self._message_info["user_info"] = user_info
|
||||
return self
|
||||
|
||||
def from_group(self, group_id: str, *, platform: str | None = None, name: str | None = None) -> "MessageBuilder":
|
||||
if platform:
|
||||
self.platform(platform)
|
||||
group_info: GroupInfoPayload = {"group_id": group_id}
|
||||
if name:
|
||||
group_info["group_name"] = name
|
||||
self._message_info["group_info"] = group_info
|
||||
return self
|
||||
|
||||
def seg(self, type_: str, data: Any) -> "MessageBuilder":
|
||||
self._segments.append({"type": type_, "data": data})
|
||||
return self
|
||||
|
||||
def text(self, content: str) -> "MessageBuilder":
|
||||
return self.seg("text", content)
|
||||
|
||||
def image(self, url: str) -> "MessageBuilder":
|
||||
return self.seg("image", url)
|
||||
|
||||
def reply(self, target_message_id: str) -> "MessageBuilder":
|
||||
return self.seg("reply", target_message_id)
|
||||
|
||||
def raw_segment(self, segment: SegPayload) -> "MessageBuilder":
|
||||
self._segments.append(segment)
|
||||
return self
|
||||
|
||||
def build(self) -> MessageEnvelope:
|
||||
# message_info defaults
|
||||
if not self._segments:
|
||||
raise ValueError("message_segment is required, add at least one segment before build()")
|
||||
if self._message_id is None:
|
||||
self._message_id = str(uuid.uuid4())
|
||||
info = dict(self._message_info)
|
||||
info.setdefault("message_id", self._message_id)
|
||||
info.setdefault("time", time.time())
|
||||
|
||||
segments = [seg.copy() if isinstance(seg, dict) else seg for seg in self._segments]
|
||||
envelope: MessageEnvelope = {
|
||||
"direction": self._direction, # type: ignore[assignment]
|
||||
"message_info": info,
|
||||
"message_segment": segments[0] if len(segments) == 1 else list(segments),
|
||||
}
|
||||
if self._metadata is not None:
|
||||
envelope["metadata"] = self._metadata
|
||||
if self._timestamp_ms is not None:
|
||||
envelope["timestamp_ms"] = self._timestamp_ms
|
||||
return envelope
|
||||
|
||||
|
||||
__all__ = ["MessageBuilder"]
|
||||
@@ -27,24 +27,23 @@ def _loads(data: bytes) -> Dict[str, Any]:
|
||||
|
||||
def dumps_message(msg: MessageEnvelope) -> bytes:
|
||||
"""
|
||||
将单条 MessageEnvelope 序列化为 JSON bytes。
|
||||
将单条消息序列化为 JSON bytes。
|
||||
"""
|
||||
if "schema_version" not in msg:
|
||||
msg["schema_version"] = DEFAULT_SCHEMA_VERSION
|
||||
return _dumps(msg)
|
||||
|
||||
sanitized = _strip_raw_bytes(msg)
|
||||
if "schema_version" not in sanitized:
|
||||
sanitized["schema_version"] = DEFAULT_SCHEMA_VERSION
|
||||
return _dumps(sanitized)
|
||||
|
||||
def dumps_messages(messages: Iterable[MessageEnvelope]) -> bytes:
|
||||
"""
|
||||
将多条消息批量序列化,以提升吞吐。
|
||||
将批量消息序列化为 JSON bytes。
|
||||
"""
|
||||
payload = {
|
||||
"schema_version": DEFAULT_SCHEMA_VERSION,
|
||||
"items": list(messages),
|
||||
"items": [_strip_raw_bytes(msg) for msg in messages],
|
||||
}
|
||||
return _dumps(payload)
|
||||
|
||||
|
||||
def loads_message(data: bytes | str) -> MessageEnvelope:
|
||||
"""
|
||||
反序列化单条消息。
|
||||
@@ -78,6 +77,14 @@ def _upgrade_schema_if_needed(obj: Dict[str, Any]) -> MessageEnvelope:
|
||||
raise ValueError(f"Unsupported schema_version={version}")
|
||||
|
||||
|
||||
|
||||
def _strip_raw_bytes(msg: MessageEnvelope) -> MessageEnvelope:
|
||||
if isinstance(msg, dict) and "raw_bytes" in msg:
|
||||
new_msg = dict(msg)
|
||||
new_msg.pop("raw_bytes", None)
|
||||
return new_msg # type: ignore[return-value]
|
||||
return msg
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_SCHEMA_VERSION",
|
||||
"dumps_message",
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class Seg:
|
||||
"""
|
||||
消息段,表示一段文本/图片/结构化内容。
|
||||
"""
|
||||
|
||||
type: str
|
||||
data: Any
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Seg":
|
||||
seg_type = data.get("type")
|
||||
seg_data = data.get("data")
|
||||
if seg_type == "seglist" and isinstance(seg_data, list):
|
||||
seg_data = [Seg.from_dict(item) for item in seg_data]
|
||||
return cls(type=seg_type, data=seg_data)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
if self.type == "seglist" and isinstance(self.data, list):
|
||||
payload = [seg.to_dict() if isinstance(seg, Seg) else seg for seg in self.data]
|
||||
else:
|
||||
payload = self.data
|
||||
return {"type": self.type, "data": payload}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupInfo:
|
||||
platform: Optional[str] = None
|
||||
group_id: Optional[str] = None
|
||||
group_name: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> Optional["GroupInfo"]:
|
||||
if not data or data.get("group_id") is None:
|
||||
return None
|
||||
return cls(
|
||||
platform=data.get("platform"),
|
||||
group_id=data.get("group_id"),
|
||||
group_name=data.get("group_name"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInfo:
|
||||
platform: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
user_nickname: Optional[str] = None
|
||||
user_cardname: Optional[str] = None
|
||||
user_avatar: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "UserInfo":
|
||||
return cls(
|
||||
platform=data.get("platform"),
|
||||
user_id=data.get("user_id"),
|
||||
user_nickname=data.get("user_nickname"),
|
||||
user_cardname=data.get("user_cardname"),
|
||||
user_avatar=data.get("user_avatar"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FormatInfo:
|
||||
content_format: Optional[List[str]] = None
|
||||
accept_format: Optional[List[str]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> Optional["FormatInfo"]:
|
||||
if not data:
|
||||
return None
|
||||
return cls(
|
||||
content_format=data.get("content_format"),
|
||||
accept_format=data.get("accept_format"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TemplateInfo:
|
||||
template_items: Optional[Dict[str, str]] = None
|
||||
template_name: Optional[Dict[str, str]] = None
|
||||
template_default: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> Optional["TemplateInfo"]:
|
||||
if not data:
|
||||
return None
|
||||
return cls(
|
||||
template_items=data.get("template_items"),
|
||||
template_name=data.get("template_name"),
|
||||
template_default=data.get("template_default", True),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseMessageInfo:
|
||||
platform: Optional[str] = None
|
||||
message_id: Optional[str] = None
|
||||
time: Optional[float] = None
|
||||
group_info: Optional[GroupInfo] = None
|
||||
user_info: Optional[UserInfo] = None
|
||||
format_info: Optional[FormatInfo] = None
|
||||
template_info: Optional[TemplateInfo] = None
|
||||
additional_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result: Dict[str, Any] = {}
|
||||
if self.platform is not None:
|
||||
result["platform"] = self.platform
|
||||
if self.message_id is not None:
|
||||
result["message_id"] = self.message_id
|
||||
if self.time is not None:
|
||||
result["time"] = self.time
|
||||
if self.additional_config is not None:
|
||||
result["additional_config"] = self.additional_config
|
||||
if self.group_info is not None:
|
||||
result["group_info"] = self.group_info.to_dict()
|
||||
if self.user_info is not None:
|
||||
result["user_info"] = self.user_info.to_dict()
|
||||
if self.format_info is not None:
|
||||
result["format_info"] = self.format_info.to_dict()
|
||||
if self.template_info is not None:
|
||||
result["template_info"] = self.template_info.to_dict()
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BaseMessageInfo":
|
||||
return cls(
|
||||
platform=data.get("platform"),
|
||||
message_id=data.get("message_id"),
|
||||
time=data.get("time"),
|
||||
additional_config=data.get("additional_config"),
|
||||
group_info=GroupInfo.from_dict(data.get("group_info", {})),
|
||||
user_info=UserInfo.from_dict(data.get("user_info", {})),
|
||||
format_info=FormatInfo.from_dict(data.get("format_info", {})),
|
||||
template_info=TemplateInfo.from_dict(data.get("template_info", {})),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageBase:
|
||||
message_info: BaseMessageInfo
|
||||
message_segment: Seg
|
||||
raw_message: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {
|
||||
"message_info": self.message_info.to_dict(),
|
||||
"message_segment": self.message_segment.to_dict(),
|
||||
}
|
||||
if self.raw_message is not None:
|
||||
payload["raw_message"] = self.raw_message
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "MessageBase":
|
||||
return cls(
|
||||
message_info=BaseMessageInfo.from_dict(data.get("message_info", {})),
|
||||
message_segment=Seg.from_dict(data.get("message_segment", {})),
|
||||
raw_message=data.get("raw_message"),
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseMessageInfo",
|
||||
"FormatInfo",
|
||||
"GroupInfo",
|
||||
"MessageBase",
|
||||
"Seg",
|
||||
"TemplateInfo",
|
||||
"UserInfo",
|
||||
]
|
||||
@@ -7,7 +7,7 @@ from dataclasses import asdict, dataclass
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from .api import MessageClient
|
||||
from .message_models import MessageBase
|
||||
from .types import MessageEnvelope
|
||||
|
||||
logger = logging.getLogger("mofox_bus.router")
|
||||
|
||||
@@ -55,7 +55,7 @@ class Router:
|
||||
self.handlers: list[Callable[[Dict], None]] = []
|
||||
self._running = False
|
||||
self._client_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._monitor_task: asyncio.Task | None = None
|
||||
self._stop_event: asyncio.Event | None = None
|
||||
|
||||
async def connect(self, platform: str) -> None:
|
||||
if platform not in self.config.route_config:
|
||||
@@ -65,6 +65,7 @@ class Router:
|
||||
if mode != "ws":
|
||||
raise NotImplementedError("TCP mode is not implemented yet")
|
||||
client = MessageClient(mode="ws")
|
||||
client.set_disconnect_callback(self._handle_client_disconnect)
|
||||
await client.connect(
|
||||
url=target.url,
|
||||
platform=platform,
|
||||
@@ -75,7 +76,7 @@ class Router:
|
||||
client.register_message_handler(handler)
|
||||
self.clients[platform] = client
|
||||
if self._running:
|
||||
self._client_tasks[platform] = asyncio.create_task(client.run())
|
||||
self._start_client_task(platform, client)
|
||||
|
||||
def register_class_handler(self, handler: Callable[[Dict], None]) -> None:
|
||||
self.handlers.append(handler)
|
||||
@@ -84,36 +85,18 @@ class Router:
|
||||
|
||||
async def run(self) -> None:
|
||||
self._running = True
|
||||
self._stop_event = asyncio.Event()
|
||||
for platform in self.config.route_config:
|
||||
if platform not in self.clients:
|
||||
await self.connect(platform)
|
||||
for platform, client in self.clients.items():
|
||||
if platform not in self._client_tasks:
|
||||
self._client_tasks[platform] = asyncio.create_task(client.run())
|
||||
self._monitor_task = asyncio.create_task(self._monitor_connections())
|
||||
self._start_client_task(platform, client)
|
||||
try:
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
await self._stop_event.wait()
|
||||
except asyncio.CancelledError: # pragma: no cover
|
||||
raise
|
||||
|
||||
async def _monitor_connections(self) -> None:
|
||||
await asyncio.sleep(3)
|
||||
while self._running:
|
||||
for platform in list(self.clients.keys()):
|
||||
client = self.clients.get(platform)
|
||||
if client is None:
|
||||
continue
|
||||
if not client.is_connected():
|
||||
logger.info("Detected disconnect from %s, attempting reconnect", platform)
|
||||
await self._reconnect_platform(platform)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def _reconnect_platform(self, platform: str) -> None:
|
||||
await self.remove_platform(platform)
|
||||
if platform in self.config.route_config:
|
||||
await self.connect(platform)
|
||||
|
||||
async def remove_platform(self, platform: str) -> None:
|
||||
if platform in self._client_tasks:
|
||||
task = self._client_tasks.pop(platform)
|
||||
@@ -124,32 +107,55 @@ class Router:
|
||||
if client:
|
||||
await client.stop()
|
||||
|
||||
async def _handle_client_disconnect(self, platform: str, reason: str) -> None:
|
||||
logger.info("Client for %s disconnected: %s (auto-reconnect handled by client)", platform, reason)
|
||||
task = self._client_tasks.get(platform)
|
||||
if task is not None and not task.done():
|
||||
return
|
||||
client = self.clients.get(platform)
|
||||
if client and self._running:
|
||||
self._start_client_task(platform, client)
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
if self._monitor_task:
|
||||
self._monitor_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._monitor_task
|
||||
self._monitor_task = None
|
||||
if self._stop_event:
|
||||
self._stop_event.set()
|
||||
for platform in list(self.clients.keys()):
|
||||
await self.remove_platform(platform)
|
||||
self.clients.clear()
|
||||
|
||||
def get_target_url(self, message: MessageBase) -> Optional[str]:
|
||||
platform = message.message_info.platform
|
||||
def _start_client_task(self, platform: str, client: MessageClient) -> None:
|
||||
task = asyncio.create_task(client.run())
|
||||
task.add_done_callback(lambda t, plat=platform: asyncio.create_task(self._restart_if_needed(plat, t)))
|
||||
self._client_tasks[platform] = task
|
||||
|
||||
async def _restart_if_needed(self, platform: str, task: asyncio.Task) -> None:
|
||||
if not self._running:
|
||||
return
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc:
|
||||
logger.warning("Client task for %s ended with exception: %s", platform, exc)
|
||||
client = self.clients.get(platform)
|
||||
if client:
|
||||
self._start_client_task(platform, client)
|
||||
|
||||
def get_target_url(self, message: MessageEnvelope) -> Optional[str]:
|
||||
platform = message.get("message_info", {}).get("platform")
|
||||
if not platform:
|
||||
return None
|
||||
target = self.config.route_config.get(platform)
|
||||
return target.url if target else None
|
||||
|
||||
async def send_message(self, message: MessageBase):
|
||||
platform = message.message_info.platform
|
||||
async def send_message(self, message: MessageEnvelope):
|
||||
platform = message.get("message_info", {}).get("platform")
|
||||
if not platform:
|
||||
raise ValueError("message_info.platform is required")
|
||||
client = self.clients.get(platform)
|
||||
if client is None:
|
||||
raise RuntimeError(f"No client connected for platform {platform}")
|
||||
return await client.send_message(message.to_dict())
|
||||
return await client.send_message(message)
|
||||
|
||||
async def update_config(self, config_data: Dict[str, Dict[str, str | None]]) -> None:
|
||||
new_config = RouteConfig.from_dict(config_data)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Iterable, List
|
||||
from typing import Awaitable, Callable, Dict, Iterable, List, Protocol
|
||||
|
||||
from .types import MessageEnvelope
|
||||
|
||||
@@ -12,6 +13,11 @@ ErrorHook = Callable[[MessageEnvelope, BaseException], Awaitable[None] | None]
|
||||
Predicate = Callable[[MessageEnvelope], bool | Awaitable[bool]]
|
||||
MessageHandler = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None] | MessageEnvelope | None]
|
||||
BatchHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None] | List[MessageEnvelope] | None]
|
||||
MiddlewareCallable = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None]]
|
||||
|
||||
|
||||
class Middleware(Protocol):
|
||||
async def __call__(self, message: MessageEnvelope, handler: MiddlewareCallable) -> MessageEnvelope | None: ...
|
||||
|
||||
|
||||
class MessageProcessingError(RuntimeError):
|
||||
@@ -19,7 +25,7 @@ class MessageProcessingError(RuntimeError):
|
||||
|
||||
def __init__(self, message: MessageEnvelope, original: BaseException):
|
||||
detail = message.get("id", "<unknown>")
|
||||
super().__init__(f"Failed to handle message {detail}: {original}") # pragma: no cover - str repr only
|
||||
super().__init__(f"Failed to handle message {detail}: {original}")
|
||||
self.message_envelope = message
|
||||
self.original = original
|
||||
|
||||
@@ -29,6 +35,8 @@ class MessageRoute:
|
||||
predicate: Predicate
|
||||
handler: MessageHandler
|
||||
name: str | None = None
|
||||
message_type: str | None = None
|
||||
event_types: set[str] | None = None
|
||||
|
||||
|
||||
class MessageRuntime:
|
||||
@@ -43,15 +51,36 @@ class MessageRuntime:
|
||||
self._error_hooks: list[ErrorHook] = []
|
||||
self._batch_handler: BatchHandler | None = None
|
||||
self._lock = threading.RLock()
|
||||
self._middlewares: list[Middleware] = []
|
||||
self._type_routes: Dict[str, list[MessageRoute]] = {}
|
||||
self._event_routes: Dict[str, list[MessageRoute]] = {}
|
||||
|
||||
def add_route(self, predicate: Predicate, handler: MessageHandler, name: str | None = None) -> None:
|
||||
def add_route(
|
||||
self,
|
||||
predicate: Predicate,
|
||||
handler: MessageHandler,
|
||||
name: str | None = None,
|
||||
*,
|
||||
message_type: str | None = None,
|
||||
event_types: Iterable[str] | None = None,
|
||||
) -> None:
|
||||
with self._lock:
|
||||
self._routes.append(MessageRoute(predicate=predicate, handler=handler, name=name))
|
||||
route = MessageRoute(
|
||||
predicate=predicate,
|
||||
handler=handler,
|
||||
name=name,
|
||||
message_type=message_type,
|
||||
event_types=set(event_types) if event_types is not None else None,
|
||||
)
|
||||
self._routes.append(route)
|
||||
if message_type:
|
||||
self._type_routes.setdefault(message_type, []).append(route)
|
||||
if route.event_types:
|
||||
for et in route.event_types:
|
||||
self._event_routes.setdefault(et, []).append(route)
|
||||
|
||||
def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]:
|
||||
"""
|
||||
装饰器写法,便于在核心逻辑中声明式注册。
|
||||
"""
|
||||
"""装饰器写法,便于在核心逻辑中声明式注册。"""
|
||||
|
||||
def decorator(func: MessageHandler) -> MessageHandler:
|
||||
self.add_route(predicate, func, name=name)
|
||||
@@ -59,6 +88,60 @@ class MessageRuntime:
|
||||
|
||||
return decorator
|
||||
|
||||
def on_message(
|
||||
self,
|
||||
*,
|
||||
message_type: str | None = None,
|
||||
platform: str | None = None,
|
||||
predicate: Predicate | None = None,
|
||||
name: str | None = None,
|
||||
) -> Callable[[MessageHandler], MessageHandler]:
|
||||
"""Sugar 装饰器,基于 Seg.type/platform 及可选额外谓词匹配。"""
|
||||
|
||||
async def combined_predicate(message: MessageEnvelope) -> bool:
|
||||
if message_type is not None and _extract_segment_type(message) != message_type:
|
||||
return False
|
||||
if platform is not None:
|
||||
info_platform = message.get("message_info", {}).get("platform")
|
||||
if message.get("platform") not in (None, platform) and info_platform is None:
|
||||
return False
|
||||
if info_platform not in (None, platform):
|
||||
return False
|
||||
if predicate is None:
|
||||
return True
|
||||
return await _invoke_callable(predicate, message, prefer_thread=False)
|
||||
|
||||
def decorator(func: MessageHandler) -> MessageHandler:
|
||||
self.add_route(combined_predicate, func, name=name, message_type=message_type)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def on_event(
|
||||
self,
|
||||
event_type: str | Iterable[str],
|
||||
*,
|
||||
name: str | None = None,
|
||||
) -> Callable[[MessageHandler], MessageHandler]:
|
||||
"""装饰器,基于 message 或 message_info.additional_config 中的 event_type 匹配。"""
|
||||
|
||||
allowed = {event_type} if isinstance(event_type, str) else set(event_type)
|
||||
|
||||
async def predicate(message: MessageEnvelope) -> bool:
|
||||
current = (
|
||||
message.get("event_type")
|
||||
or message.get("message_info", {})
|
||||
.get("additional_config", {})
|
||||
.get("event_type")
|
||||
)
|
||||
return current in allowed
|
||||
|
||||
def decorator(func: MessageHandler) -> MessageHandler:
|
||||
self.add_route(predicate, func, name=name, event_types=allowed)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def set_batch_handler(self, handler: BatchHandler) -> None:
|
||||
self._batch_handler = handler
|
||||
|
||||
@@ -71,14 +154,20 @@ class MessageRuntime:
|
||||
def register_error_hook(self, hook: ErrorHook) -> None:
|
||||
self._error_hooks.append(hook)
|
||||
|
||||
def register_middleware(self, middleware: Middleware) -> None:
|
||||
"""注册洋葱模型中间件,围绕处理器执行。"""
|
||||
|
||||
self._middlewares.append(middleware)
|
||||
|
||||
async def handle_message(self, message: MessageEnvelope) -> MessageEnvelope | None:
|
||||
await self._run_hooks(self._before_hooks, message)
|
||||
try:
|
||||
route = await self._match_route(message)
|
||||
if route is None:
|
||||
return None
|
||||
result = await _maybe_await(route.handler(message))
|
||||
except Exception as exc: # pragma: no cover - tested indirectly
|
||||
handler = self._wrap_with_middlewares(route.handler)
|
||||
result = await handler(message)
|
||||
except Exception as exc:
|
||||
await self._run_error_hooks(message, exc)
|
||||
raise MessageProcessingError(message, exc) from exc
|
||||
await self._run_hooks(self._after_hooks, message)
|
||||
@@ -89,7 +178,7 @@ class MessageRuntime:
|
||||
if not batch:
|
||||
return []
|
||||
if self._batch_handler is not None:
|
||||
result = await _maybe_await(self._batch_handler(batch))
|
||||
result = await _invoke_callable(self._batch_handler, batch, prefer_thread=True)
|
||||
return result or []
|
||||
responses: list[MessageEnvelope] = []
|
||||
for message in batch:
|
||||
@@ -99,21 +188,61 @@ class MessageRuntime:
|
||||
return responses
|
||||
|
||||
async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None:
|
||||
candidates: list[MessageRoute] = []
|
||||
message_type = _extract_segment_type(message)
|
||||
event_type = (
|
||||
message.get("event_type")
|
||||
or message.get("message_info", {})
|
||||
.get("additional_config", {})
|
||||
.get("event_type")
|
||||
)
|
||||
with self._lock:
|
||||
routes = list(self._routes)
|
||||
for route in routes:
|
||||
should_handle = await _maybe_await(route.predicate(message))
|
||||
if event_type and event_type in self._event_routes:
|
||||
candidates.extend(self._event_routes[event_type])
|
||||
if message_type and message_type in self._type_routes:
|
||||
candidates.extend(self._type_routes[message_type])
|
||||
candidates.extend(self._routes)
|
||||
|
||||
seen: set[int] = set()
|
||||
for route in candidates:
|
||||
rid = id(route)
|
||||
if rid in seen:
|
||||
continue
|
||||
seen.add(rid)
|
||||
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
|
||||
if should_handle:
|
||||
return route
|
||||
return None
|
||||
|
||||
async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None:
|
||||
for hook in hooks:
|
||||
await _maybe_await(hook(message))
|
||||
coro_list = [self._call_hook(hook, message) for hook in hooks]
|
||||
if coro_list:
|
||||
await asyncio.gather(*coro_list)
|
||||
|
||||
async def _call_hook(self, hook: Hook, message: MessageEnvelope) -> None:
|
||||
await _invoke_callable(hook, message, prefer_thread=True)
|
||||
|
||||
async def _run_error_hooks(self, message: MessageEnvelope, exc: BaseException) -> None:
|
||||
for hook in self._error_hooks:
|
||||
await _maybe_await(hook(message, exc))
|
||||
coros = [self._call_error_hook(hook, message, exc) for hook in self._error_hooks]
|
||||
if coros:
|
||||
await asyncio.gather(*coros)
|
||||
|
||||
async def _call_error_hook(self, hook: ErrorHook, message: MessageEnvelope, exc: BaseException) -> None:
|
||||
await _invoke_callable(hook, message, exc, prefer_thread=True)
|
||||
|
||||
def _wrap_with_middlewares(self, handler: MessageHandler) -> MiddlewareCallable:
|
||||
async def base_handler(message: MessageEnvelope) -> MessageEnvelope | None:
|
||||
return await _invoke_callable(handler, message, prefer_thread=True)
|
||||
|
||||
wrapped: MiddlewareCallable = base_handler
|
||||
for middleware in reversed(self._middlewares):
|
||||
current = wrapped
|
||||
|
||||
async def wrapper(msg: MessageEnvelope, mw=middleware, nxt=current) -> MessageEnvelope | None:
|
||||
return await _invoke_callable(mw, msg, nxt, prefer_thread=False)
|
||||
|
||||
wrapped = wrapper
|
||||
return wrapped
|
||||
|
||||
|
||||
async def _maybe_await(result):
|
||||
@@ -122,6 +251,32 @@ async def _maybe_await(result):
|
||||
return result
|
||||
|
||||
|
||||
async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False):
|
||||
"""Support sync/async callables with optional thread offloading."""
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args)
|
||||
if prefer_thread:
|
||||
result = await asyncio.to_thread(func, *args)
|
||||
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||
return await result
|
||||
return result
|
||||
result = func(*args)
|
||||
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||
return await result
|
||||
return result
|
||||
|
||||
|
||||
def _extract_segment_type(message: MessageEnvelope) -> str | None:
|
||||
seg = message.get("message_segment") or message.get("message_chain")
|
||||
if isinstance(seg, dict):
|
||||
return seg.get("type")
|
||||
if isinstance(seg, list) and seg:
|
||||
first = seg[0]
|
||||
if isinstance(first, dict):
|
||||
return first.get("type")
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BatchHandler",
|
||||
"Hook",
|
||||
@@ -129,5 +284,6 @@ __all__ = [
|
||||
"MessageProcessingError",
|
||||
"MessageRoute",
|
||||
"MessageRuntime",
|
||||
"Middleware",
|
||||
"Predicate",
|
||||
]
|
||||
|
||||
@@ -3,160 +3,91 @@ from __future__ import annotations
|
||||
from typing import Any, Dict, List, Literal, NotRequired, TypedDict
|
||||
|
||||
MessageDirection = Literal["incoming", "outgoing"]
|
||||
Role = Literal["user", "assistant", "system", "tool", "platform"]
|
||||
ContentType = Literal[
|
||||
"text",
|
||||
"image",
|
||||
"audio",
|
||||
"file",
|
||||
"video",
|
||||
"event",
|
||||
"command",
|
||||
"system",
|
||||
]
|
||||
|
||||
EventType = Literal[
|
||||
"message_created",
|
||||
"message_updated",
|
||||
"message_deleted",
|
||||
"member_join",
|
||||
"member_leave",
|
||||
"typing",
|
||||
"reaction_add",
|
||||
"reaction_remove",
|
||||
]
|
||||
# ----------------------------
|
||||
# maim_message 风格的 TypedDict
|
||||
# ----------------------------
|
||||
|
||||
|
||||
class TextContent(TypedDict, total=False):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
markdown: NotRequired[bool]
|
||||
entities: NotRequired[List[Dict[str, Any]]]
|
||||
class SegPayload(TypedDict, total=False):
|
||||
"""
|
||||
对齐 maim_message.Seg 的片段定义,使用纯 dict 便于 JSON 传输。
|
||||
"""
|
||||
|
||||
type: str
|
||||
data: str | List["SegPayload"]
|
||||
translated_data: NotRequired[str | List["SegPayload"]]
|
||||
|
||||
|
||||
class ImageContent(TypedDict, total=False):
|
||||
type: Literal["image"]
|
||||
url: str
|
||||
mime_type: NotRequired[str]
|
||||
width: NotRequired[int]
|
||||
height: NotRequired[int]
|
||||
file_id: NotRequired[str]
|
||||
class UserInfoPayload(TypedDict, total=False):
|
||||
platform: NotRequired[str]
|
||||
user_id: NotRequired[str]
|
||||
user_nickname: NotRequired[str]
|
||||
user_cardname: NotRequired[str]
|
||||
user_avatar: NotRequired[str]
|
||||
|
||||
|
||||
class FileContent(TypedDict, total=False):
|
||||
type: Literal["file"]
|
||||
url: str
|
||||
mime_type: NotRequired[str]
|
||||
file_name: NotRequired[str]
|
||||
file_size: NotRequired[int]
|
||||
file_id: NotRequired[str]
|
||||
class GroupInfoPayload(TypedDict, total=False):
|
||||
platform: NotRequired[str]
|
||||
group_id: NotRequired[str]
|
||||
group_name: NotRequired[str]
|
||||
|
||||
|
||||
class AudioContent(TypedDict, total=False):
|
||||
type: Literal["audio"]
|
||||
url: str
|
||||
mime_type: NotRequired[str]
|
||||
duration_ms: NotRequired[int]
|
||||
file_id: NotRequired[str]
|
||||
class FormatInfoPayload(TypedDict, total=False):
|
||||
content_format: NotRequired[List[str]]
|
||||
accept_format: NotRequired[List[str]]
|
||||
|
||||
|
||||
class VideoContent(TypedDict, total=False):
|
||||
type: Literal["video"]
|
||||
url: str
|
||||
mime_type: NotRequired[str]
|
||||
duration_ms: NotRequired[int]
|
||||
width: NotRequired[int]
|
||||
height: NotRequired[int]
|
||||
file_id: NotRequired[str]
|
||||
class TemplateInfoPayload(TypedDict, total=False):
|
||||
template_items: NotRequired[Dict[str, str]]
|
||||
template_name: NotRequired[Dict[str, str]]
|
||||
template_default: NotRequired[bool]
|
||||
|
||||
|
||||
class EventContent(TypedDict):
|
||||
type: Literal["event"]
|
||||
event_type: EventType
|
||||
raw: Dict[str, Any]
|
||||
class MessageInfoPayload(TypedDict, total=False):
|
||||
platform: NotRequired[str]
|
||||
message_id: NotRequired[str]
|
||||
time: NotRequired[float]
|
||||
group_info: NotRequired[GroupInfoPayload]
|
||||
user_info: NotRequired[UserInfoPayload]
|
||||
format_info: NotRequired[FormatInfoPayload]
|
||||
template_info: NotRequired[TemplateInfoPayload]
|
||||
additional_config: NotRequired[Dict[str, Any]]
|
||||
|
||||
|
||||
class CommandContent(TypedDict, total=False):
|
||||
type: Literal["command"]
|
||||
name: str
|
||||
args: Dict[str, Any]
|
||||
|
||||
|
||||
class SystemContent(TypedDict):
|
||||
type: Literal["system"]
|
||||
text: str
|
||||
|
||||
|
||||
Content = (
|
||||
TextContent
|
||||
| ImageContent
|
||||
| FileContent
|
||||
| AudioContent
|
||||
| VideoContent
|
||||
| EventContent
|
||||
| CommandContent
|
||||
| SystemContent
|
||||
)
|
||||
|
||||
|
||||
class SenderInfo(TypedDict, total=False):
|
||||
user_id: str
|
||||
role: Role
|
||||
display_name: NotRequired[str]
|
||||
avatar_url: NotRequired[str]
|
||||
raw: NotRequired[Dict[str, Any]]
|
||||
|
||||
|
||||
class ChannelInfo(TypedDict, total=False):
|
||||
channel_id: str
|
||||
channel_type: Literal[
|
||||
"private",
|
||||
"group",
|
||||
"supergroup",
|
||||
"channel",
|
||||
"dm",
|
||||
"room",
|
||||
"thread",
|
||||
]
|
||||
title: NotRequired[str]
|
||||
workspace_id: NotRequired[str]
|
||||
raw: NotRequired[Dict[str, Any]]
|
||||
# ----------------------------
|
||||
# MessageEnvelope
|
||||
# ----------------------------
|
||||
|
||||
|
||||
class MessageEnvelope(TypedDict, total=False):
|
||||
id: str
|
||||
direction: MessageDirection
|
||||
platform: str
|
||||
timestamp_ms: int
|
||||
channel: ChannelInfo
|
||||
sender: SenderInfo
|
||||
content: Content
|
||||
conversation_id: str
|
||||
thread_id: NotRequired[str]
|
||||
reply_to_message_id: NotRequired[str]
|
||||
correlation_id: NotRequired[str]
|
||||
is_edited: NotRequired[bool]
|
||||
is_ephemeral: NotRequired[bool]
|
||||
raw_platform_message: NotRequired[Dict[str, Any]]
|
||||
metadata: NotRequired[Dict[str, Any]]
|
||||
schema_version: NotRequired[int]
|
||||
"""
|
||||
mofox-bus 传输层统一使用的消息信封。
|
||||
|
||||
- 采用 maim_message 风格:message_info + message_segment。
|
||||
"""
|
||||
|
||||
direction: MessageDirection
|
||||
message_info: MessageInfoPayload
|
||||
message_segment: SegPayload | List[SegPayload]
|
||||
raw_message: NotRequired[Any]
|
||||
raw_bytes: NotRequired[bytes]
|
||||
message_chain: NotRequired[List[SegPayload]] # seglist 的直观别名
|
||||
platform: NotRequired[str] # 快捷访问,等价于 message_info.platform
|
||||
message_id: NotRequired[str] # 快捷访问,等价于 message_info.message_id
|
||||
timestamp_ms: NotRequired[int]
|
||||
correlation_id: NotRequired[str]
|
||||
schema_version: NotRequired[int]
|
||||
metadata: NotRequired[Dict[str, Any]]
|
||||
|
||||
__all__ = [
|
||||
"AudioContent",
|
||||
"ChannelInfo",
|
||||
"CommandContent",
|
||||
"Content",
|
||||
"ContentType",
|
||||
"EventContent",
|
||||
"EventType",
|
||||
"FileContent",
|
||||
"ImageContent",
|
||||
# maim_message style payloads
|
||||
"SegPayload",
|
||||
"UserInfoPayload",
|
||||
"GroupInfoPayload",
|
||||
"FormatInfoPayload",
|
||||
"TemplateInfoPayload",
|
||||
"MessageInfoPayload",
|
||||
# legacy content style
|
||||
"MessageDirection",
|
||||
"MessageEnvelope",
|
||||
"Role",
|
||||
"SenderInfo",
|
||||
"SystemContent",
|
||||
"TextContent",
|
||||
"VideoContent",
|
||||
]
|
||||
|
||||
@@ -12,7 +12,7 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from mofox_bus import AdapterBase as MoFoxAdapterBase, CoreMessageSink, MessageEnvelope
|
||||
from mofox_bus import AdapterBase as MoFoxAdapterBase, CoreSink, MessageEnvelope
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
@@ -47,7 +47,7 @@ class BaseAdapter(MoFoxAdapterBase, ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
core_sink: CoreMessageSink,
|
||||
core_sink: CoreSink,
|
||||
plugin: Optional[BasePlugin] = None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -227,7 +227,7 @@ class BaseAdapter(MoFoxAdapterBase, ABC):
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def from_platform_message(self, raw: Any) -> MessageEnvelope:
|
||||
async def from_platform_message(self, raw: Any) -> MessageEnvelope:
|
||||
"""
|
||||
将平台原始消息转换为 MessageEnvelope
|
||||
|
||||
|
||||
@@ -7,130 +7,152 @@ Adapter 管理器
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import importlib
|
||||
import multiprocessing as mp
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.plugin_system.base.base_adapter import BaseAdapter
|
||||
|
||||
from mofox_bus import ProcessCoreSinkServer
|
||||
from src.common.core_sink import get_core_sink
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("adapter_manager")
|
||||
|
||||
|
||||
class AdapterProcess:
|
||||
"""适配器子进程包装器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
adapter_name: str,
|
||||
entry_path: Path,
|
||||
python_executable: Optional[str] = None,
|
||||
):
|
||||
self.adapter_name = adapter_name
|
||||
self.entry_path = entry_path
|
||||
self.python_executable = python_executable or sys.executable
|
||||
self.process: Optional[subprocess.Popen] = None
|
||||
self._monitor_task: Optional[asyncio.Task] = None
|
||||
|
||||
def _load_class(module_name: str, class_name: str):
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
def _adapter_process_entry(
|
||||
adapter_path: tuple[str, str],
|
||||
plugin_info: dict | None,
|
||||
incoming_queue: mp.Queue,
|
||||
outgoing_queue: mp.Queue,
|
||||
):
|
||||
import asyncio
|
||||
import contextlib
|
||||
from mofox_bus import ProcessCoreSink
|
||||
|
||||
async def _run() -> None:
|
||||
adapter_cls = _load_class(*adapter_path)
|
||||
plugin_instance = None
|
||||
if plugin_info:
|
||||
plugin_cls = _load_class(plugin_info["module"], plugin_info["class"])
|
||||
plugin_instance = plugin_cls(plugin_info["plugin_dir"], plugin_info["metadata"])
|
||||
core_sink = ProcessCoreSink(to_core_queue=incoming_queue, from_core_queue=outgoing_queue)
|
||||
adapter = adapter_cls(core_sink, plugin=plugin_instance)
|
||||
await adapter.start()
|
||||
try:
|
||||
while not getattr(core_sink, "_closed", False):
|
||||
await asyncio.sleep(0.2)
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
await adapter.stop()
|
||||
with contextlib.suppress(Exception):
|
||||
await core_sink.close()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
|
||||
class AdapterProcess:
|
||||
"""适配器子进程包装器,负责适配器子进程的启动和生命周期管理"""
|
||||
|
||||
def __init__(self, adapter: "BaseAdapter", core_sink) -> None:
|
||||
self.adapter = adapter
|
||||
self.adapter_name = adapter.adapter_name
|
||||
self.process: mp.Process | None = None
|
||||
self._ctx = mp.get_context("spawn")
|
||||
self._incoming_queue: mp.Queue = self._ctx.Queue()
|
||||
self._outgoing_queue: mp.Queue = self._ctx.Queue()
|
||||
self._bridge: ProcessCoreSinkServer | None = None
|
||||
self._core_sink = core_sink
|
||||
self._adapter_path: tuple[str, str] = (adapter.__class__.__module__, adapter.__class__.__name__)
|
||||
self._plugin_info = self._extract_plugin_info(adapter)
|
||||
self._outgoing_handler = None
|
||||
|
||||
@staticmethod
|
||||
def _extract_plugin_info(adapter: "BaseAdapter") -> dict | None:
|
||||
plugin = getattr(adapter, "plugin", None)
|
||||
if plugin is None:
|
||||
return None
|
||||
return {
|
||||
"module": plugin.__class__.__module__,
|
||||
"class": plugin.__class__.__name__,
|
||||
"plugin_dir": getattr(plugin, "plugin_dir", ""),
|
||||
"metadata": getattr(plugin, "plugin_meta", None),
|
||||
}
|
||||
|
||||
def _make_outgoing_handler(self):
|
||||
async def _handler(envelope):
|
||||
if self._bridge:
|
||||
await self._bridge.push_outgoing(envelope)
|
||||
return _handler
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""启动适配器子进程"""
|
||||
try:
|
||||
logger.info(f"启动适配器子进程: {self.adapter_name}")
|
||||
logger.debug(f"Python: {self.python_executable}")
|
||||
logger.debug(f"Entry: {self.entry_path}")
|
||||
|
||||
# 启动子进程
|
||||
self.process = subprocess.Popen(
|
||||
[self.python_executable, str(self.entry_path)],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
self._bridge = ProcessCoreSinkServer(
|
||||
incoming_queue=self._incoming_queue,
|
||||
outgoing_queue=self._outgoing_queue,
|
||||
core_handler=self._core_sink.send,
|
||||
name=self.adapter_name,
|
||||
)
|
||||
|
||||
# 启动监控任务
|
||||
self._monitor_task = asyncio.create_task(self._monitor_process())
|
||||
|
||||
logger.info(f"适配器 {self.adapter_name} 子进程已启动 (PID: {self.process.pid})")
|
||||
self._bridge.start()
|
||||
if hasattr(self._core_sink, "set_outgoing_handler"):
|
||||
self._outgoing_handler = self._make_outgoing_handler()
|
||||
try:
|
||||
self._core_sink.set_outgoing_handler(self._outgoing_handler)
|
||||
except Exception:
|
||||
logger.exception("Failed to register outgoing bridge for %s", self.adapter_name)
|
||||
self.process = self._ctx.Process(
|
||||
target=_adapter_process_entry,
|
||||
args=(self._adapter_path, self._plugin_info, self._incoming_queue, self._outgoing_queue),
|
||||
name=f"{self.adapter_name}-proc",
|
||||
)
|
||||
self.process.start()
|
||||
logger.info(f"启动适配器子进程 {self.adapter_name} (PID: {self.process.pid})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动适配器 {self.adapter_name} 子进程失败: {e}", exc_info=True)
|
||||
logger.error(f"启动适配器子进程 {self.adapter_name} 失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止适配器子进程"""
|
||||
if not self.process:
|
||||
return
|
||||
|
||||
logger.info(f"停止适配器子进程: {self.adapter_name} (PID: {self.process.pid})")
|
||||
|
||||
try:
|
||||
# 取消监控任务
|
||||
if self._monitor_task and not self._monitor_task.done():
|
||||
self._monitor_task.cancel()
|
||||
remover = getattr(self._core_sink, "remove_outgoing_handler", None)
|
||||
if callable(remover) and self._outgoing_handler:
|
||||
try:
|
||||
await self._monitor_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 终止进程
|
||||
self.process.terminate()
|
||||
|
||||
# 等待进程退出(最多等待5秒)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.to_thread(self.process.wait),
|
||||
timeout=5.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"适配器 {self.adapter_name} 未能在5秒内退出,强制终止")
|
||||
self.process.kill()
|
||||
await asyncio.to_thread(self.process.wait)
|
||||
|
||||
logger.info(f"适配器 {self.adapter_name} 子进程已停止")
|
||||
|
||||
remover(self._outgoing_handler)
|
||||
except Exception:
|
||||
logger.exception(f"移除 {self.adapter_name} 的 outgoing bridge 失败")
|
||||
if self._bridge:
|
||||
await self._bridge.close()
|
||||
if self.process.is_alive():
|
||||
self.process.join(timeout=5.0)
|
||||
if self.process.is_alive():
|
||||
logger.warning(f"适配器 {self.adapter_name} 未能及时停止,强制终止中")
|
||||
self.process.terminate()
|
||||
self.process.join()
|
||||
except Exception as e:
|
||||
logger.error(f"停止适配器 {self.adapter_name} 子进程时出错: {e}", exc_info=True)
|
||||
logger.error(f"停止适配器子进程 {self.adapter_name} 时发生错误: {e}", exc_info=True)
|
||||
finally:
|
||||
self.process = None
|
||||
|
||||
async def _monitor_process(self) -> None:
|
||||
"""监控子进程状态"""
|
||||
if not self.process:
|
||||
return
|
||||
|
||||
try:
|
||||
# 在后台线程中等待进程退出
|
||||
return_code = await asyncio.to_thread(self.process.wait)
|
||||
|
||||
if return_code != 0:
|
||||
logger.error(
|
||||
f"适配器 {self.adapter_name} 子进程异常退出 (返回码: {return_code})"
|
||||
)
|
||||
|
||||
# 读取 stderr 输出
|
||||
if self.process.stderr:
|
||||
stderr = self.process.stderr.read()
|
||||
if stderr:
|
||||
logger.error(f"错误输出:\n{stderr}")
|
||||
else:
|
||||
logger.info(f"适配器 {self.adapter_name} 子进程正常退出")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"监控适配器 {self.adapter_name} 子进程时出错: {e}", exc_info=True)
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""检查进程是否正在运行"""
|
||||
"""适配器是否正在运行"""
|
||||
if not self.process:
|
||||
return False
|
||||
return self.process.poll() is None
|
||||
|
||||
return self.process.is_alive()
|
||||
|
||||
class AdapterManager:
|
||||
"""适配器管理器"""
|
||||
@@ -176,20 +198,17 @@ class AdapterManager:
|
||||
else:
|
||||
return await self._start_adapter_in_process(adapter)
|
||||
|
||||
async def _start_adapter_subprocess(self, adapter: BaseAdapter) -> bool:
|
||||
"""在子进程中启动适配器"""
|
||||
adapter_name = adapter.adapter_name
|
||||
|
||||
# 获取子进程入口脚本
|
||||
entry_path = adapter.get_subprocess_entry_path()
|
||||
if not entry_path:
|
||||
logger.error(
|
||||
f"适配器 {adapter_name} 配置为子进程运行,但未提供有效的入口脚本"
|
||||
)
|
||||
async def _start_adapter_subprocess(self, adapter: BaseAdapter) -> bool:
|
||||
"""启动适配器子进程"""
|
||||
adapter_name = adapter.adapter_name
|
||||
try:
|
||||
core_sink = get_core_sink()
|
||||
except Exception as e:
|
||||
logger.error(f"无法获取 core_sink,启动适配器子进程 {adapter_name} 失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
# 创建并启动子进程
|
||||
adapter_process = AdapterProcess(adapter_name, entry_path)
|
||||
adapter_process = AdapterProcess(adapter, core_sink)
|
||||
success = await adapter_process.start()
|
||||
|
||||
if success:
|
||||
|
||||
427
src/plugins/built_in/NEW_napcat_adapter/README.md
Normal file
427
src/plugins/built_in/NEW_napcat_adapter/README.md
Normal file
@@ -0,0 +1,427 @@
|
||||
# NEW_napcat_adapter
|
||||
|
||||
基于 mofox-bus v2.x 的 Napcat 适配器(使用 BaseAdapter 架构)
|
||||
|
||||
## 🏗️ 架构设计
|
||||
|
||||
本插件采用 **BaseAdapter 继承模式** 重写,完全抛弃旧版 maim_message 库,改用 mofox-bus 的 TypedDict 数据结构。
|
||||
|
||||
### 核心组件
|
||||
- **NapcatAdapter**: 继承自 `mofox_bus.AdapterBase`,负责 OneBot 11 协议与 MessageEnvelope 的双向转换
|
||||
- **WebSocketAdapterOptions**: 自动管理 WebSocket 连接,提供 incoming_parser 和 outgoing_encoder
|
||||
- **CoreMessageSink**: 通过 `InProcessCoreSink` 将消息递送到核心系统
|
||||
- **Handlers**: 独立的消息处理器,分为 to_core(接收)和 to_napcat(发送)两个方向
|
||||
|
||||
## 📁 项目结构
|
||||
|
||||
```
|
||||
NEW_napcat_adapter/
|
||||
├── plugin.py # ✅ 主插件文件(BaseAdapter实现)
|
||||
├── _manifest.json # 插件清单
|
||||
│
|
||||
└── src/
|
||||
├── event_models.py # ✅ OneBot事件类型常量
|
||||
├── common/
|
||||
│ └── core_sink.py # ✅ 全局CoreSink访问点
|
||||
│
|
||||
├── utils/
|
||||
│ ├── utils.py # ⏳ 工具函数(待实现)
|
||||
│ ├── qq_emoji_list.py # ⏳ QQ表情映射(待实现)
|
||||
│ ├── video_handler.py # ⏳ 视频处理(待实现)
|
||||
│ └── message_chunker.py # ⏳ 消息切片(待实现)
|
||||
│
|
||||
├── websocket/
|
||||
│ └── (无需单独实现,使用WebSocketAdapterOptions)
|
||||
│
|
||||
├── database/
|
||||
│ └── database.py # ⏳ 数据库模型(待实现)
|
||||
│
|
||||
└── handlers/
|
||||
├── to_core/ # Napcat → MessageEnvelope 方向
|
||||
│ ├── message_handler.py # ⏳ 消息处理(部分完成)
|
||||
│ ├── notice_handler.py # ⏳ 通知处理(待完成)
|
||||
│ └── meta_event_handler.py # ⏳ 元事件(待完成)
|
||||
│
|
||||
└── to_napcat/ # MessageEnvelope → Napcat API 方向
|
||||
└── send_handler.py # ⏳ 发送处理(部分完成)
|
||||
```
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 使用方式
|
||||
|
||||
1. **配置文件**: 在 `config/plugins/NEW_napcat_adapter.toml` 中配置 WebSocket URL 和其他参数
|
||||
2. **启动插件**: 插件自动在系统启动时加载
|
||||
3. **WebSocket连接**: 自动连接到 Napcat OneBot 11 服务器
|
||||
|
||||
## 🔑 核心数据结构
|
||||
|
||||
### MessageEnvelope (mofox-bus v2.x)
|
||||
|
||||
```python
|
||||
from mofox_bus import MessageEnvelope, SegPayload, MessageInfoPayload
|
||||
|
||||
# 创建消息信封
|
||||
envelope: MessageEnvelope = {
|
||||
"direction": "input",
|
||||
"message_info": {
|
||||
"message_type": "group",
|
||||
"message_id": "12345",
|
||||
"self_id": "bot_qq",
|
||||
"user_info": {
|
||||
"user_id": "sender_qq",
|
||||
"user_name": "发送者",
|
||||
"user_displayname": "昵称"
|
||||
},
|
||||
"group_info": {
|
||||
"group_id": "group_id",
|
||||
"group_name": "群名"
|
||||
},
|
||||
"to_me": False
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "seglist",
|
||||
"data": [
|
||||
{"type": "text", "data": "hello"},
|
||||
{"type": "image", "data": "base64_data"}
|
||||
]
|
||||
},
|
||||
"raw_message": "hello[图片]",
|
||||
"platform": "napcat",
|
||||
"message_id": "12345",
|
||||
"timestamp_ms": 1234567890
|
||||
}
|
||||
```
|
||||
|
||||
### BaseAdapter 核心方法
|
||||
|
||||
```python
|
||||
class NapcatAdapter(BaseAdapter):
|
||||
async def from_platform_message(self, message: dict[str, Any]) -> MessageEnvelope | None:
|
||||
"""将 OneBot 11 事件转换为 MessageEnvelope"""
|
||||
# 路由到对应的 Handler
|
||||
|
||||
async def _send_platform_message(self, envelope: MessageEnvelope) -> dict[str, Any]:
|
||||
"""将 MessageEnvelope 转换为 OneBot 11 API 调用"""
|
||||
# 调用 SendHandler 处理
|
||||
```
|
||||
|
||||
## 📝 实现进度
|
||||
|
||||
### ✅ 已完成的核心架构
|
||||
|
||||
1. **BaseAdapter 实现** (plugin.py)
|
||||
- ✅ WebSocket 自动连接管理
|
||||
- ✅ from_platform_message() 事件路由
|
||||
- ✅ _send_platform_message() 消息发送
|
||||
- ✅ API 响应池机制(echo-based request-response)
|
||||
- ✅ CoreSink 集成
|
||||
|
||||
2. **Handler 基础结构**
|
||||
- ✅ MessageHandler 骨架(text、image、at 基本实现)
|
||||
- ✅ NoticeHandler 骨架
|
||||
- ✅ MetaEventHandler 骨架
|
||||
- ✅ SendHandler 骨架(基本类型转换)
|
||||
|
||||
3. **辅助组件**
|
||||
- ✅ event_models.py(事件类型常量)
|
||||
- ✅ core_sink.py(全局 CoreSink 访问)
|
||||
- ✅ 配置 Schema 定义
|
||||
|
||||
### ⏳ 部分完成的功能
|
||||
|
||||
4. **消息类型处理** (MessageHandler)
|
||||
- ✅ 基础消息类型:text, image, at
|
||||
- ❌ 高级消息类型:face, reply, forward, video, json, file, rps, dice, shake
|
||||
|
||||
5. **发送处理** (SendHandler)
|
||||
- ✅ 基础 SegPayload 转换:text, image
|
||||
- ❌ 高级 Seg 类型:emoji, voice, voiceurl, music, videourl, file, command
|
||||
|
||||
### ❌ 待实现的功能
|
||||
|
||||
6. **通知事件处理** (NoticeHandler)
|
||||
- ❌ 戳一戳事件
|
||||
- ❌ 表情回应事件
|
||||
- ❌ 撤回事件
|
||||
- ❌ 禁言事件
|
||||
|
||||
7. **工具函数** (utils.py)
|
||||
- ❌ get_group_info
|
||||
- ❌ get_member_info
|
||||
- ❌ get_image_base64
|
||||
- ❌ get_message_detail
|
||||
- ❌ get_record_detail
|
||||
|
||||
8. **权限系统**
|
||||
- ❌ check_allow_to_chat()
|
||||
- ❌ 群组黑名单/白名单
|
||||
- ❌ 私聊黑名单/白名单
|
||||
- ❌ QQ机器人检测
|
||||
|
||||
9. **其他组件**
|
||||
- ❌ 视频处理器
|
||||
- ❌ 消息切片器
|
||||
- ❌ 数据库模型
|
||||
- ❌ QQ 表情映射表
|
||||
|
||||
## 📋 下一步工作
|
||||
|
||||
### 优先级 1:完善消息处理(参考旧版 recv_handler/message_handler.py)
|
||||
|
||||
1. **完整实现 MessageHandler.handle_raw_message()**
|
||||
- [ ] face(表情)消息段
|
||||
- [ ] reply(回复)消息段
|
||||
- [ ] forward(转发)消息段解析
|
||||
- [ ] video(视频)消息段
|
||||
- [ ] json(JSON卡片)消息段
|
||||
- [ ] file(文件)消息段
|
||||
- [ ] rps/dice/shake(特殊消息)
|
||||
|
||||
2. **实现工具函数**(参考旧版 utils.py)
|
||||
- [ ] `get_group_info()` - 获取群组信息
|
||||
- [ ] `get_member_info()` - 获取成员信息
|
||||
- [ ] `get_image_base64()` - 下载图片并转Base64
|
||||
- [ ] `get_message_detail()` - 获取消息详情
|
||||
- [ ] `get_record_detail()` - 获取语音详情
|
||||
|
||||
3. **实现权限检查**
|
||||
- [ ] `check_allow_to_chat()` - 检查是否允许聊天
|
||||
- [ ] 群组白名单/黑名单逻辑
|
||||
- [ ] 私聊白名单/黑名单逻辑
|
||||
- [ ] QQ机器人检测(ban_qq_bot)
|
||||
|
||||
### 优先级 2:完善发送处理(参考旧版 send_handler.py)
|
||||
|
||||
4. **完整实现 SendHandler._convert_seg_to_onebot()**
|
||||
- [ ] emoji(表情回应)命令
|
||||
- [ ] voice(语音)消息段
|
||||
- [ ] voiceurl(语音URL)消息段
|
||||
- [ ] music(音乐卡片)消息段
|
||||
- [ ] videourl(视频URL)消息段
|
||||
- [ ] file(文件)消息段
|
||||
- [ ] command(命令)消息段
|
||||
|
||||
5. **实现命令处理**
|
||||
- [ ] GROUP_BAN(禁言)
|
||||
- [ ] GROUP_KICK(踢人)
|
||||
- [ ] SEND_POKE(戳一戳)
|
||||
- [ ] DELETE_MSG(撤回消息)
|
||||
- [ ] GROUP_WHOLE_BAN(全员禁言)
|
||||
- [ ] SET_GROUP_CARD(设置群名片)
|
||||
- [ ] SET_GROUP_ADMIN(设置管理员)
|
||||
|
||||
### 优先级 3:补全其他组件(参考旧版对应文件)
|
||||
|
||||
6. **NoticeHandler 实现**
|
||||
- [ ] 戳一戳通知(notify.poke)
|
||||
- [ ] 表情回应通知(notice.group_emoji_like)
|
||||
- [ ] 消息撤回通知(notice.group_recall)
|
||||
- [ ] 禁言通知(notice.group_ban)
|
||||
|
||||
7. **辅助组件**
|
||||
- [ ] `qq_emoji_list.py` - QQ表情ID映射表
|
||||
- [ ] `video_handler.py` - 视频处理(ffmpeg封面提取)
|
||||
- [ ] `message_chunker.py` - 消息分块与重组
|
||||
- [ ] `database.py` - 数据库模型(如有需要)
|
||||
|
||||
### 优先级 4:测试与优化
|
||||
|
||||
8. **功能测试**
|
||||
- [ ] 文本消息收发
|
||||
- [ ] 图片消息收发
|
||||
- [ ] @消息处理
|
||||
- [ ] 表情/语音/视频消息
|
||||
- [ ] 转发消息解析
|
||||
- [ ] 所有命令功能
|
||||
- [ ] 通知事件处理
|
||||
|
||||
9. **性能优化**
|
||||
- [ ] 消息处理并发性能
|
||||
- [ ] API响应池性能
|
||||
- [ ] 内存占用优化
|
||||
|
||||
## 🔍 关键实现细节
|
||||
|
||||
### 1. MessageEnvelope vs 旧版 MessageBase
|
||||
|
||||
**不再使用 Seg dataclass**,全部使用 TypedDict:
|
||||
|
||||
```python
|
||||
# ❌ 旧版(maim_message)
|
||||
from mofox_bus import Seg, MessageBase
|
||||
|
||||
seg = Seg(type="text", data="hello")
|
||||
message = MessageBase(message_info=info, message_segment=seg)
|
||||
|
||||
# ✅ 新版(mofox-bus v2.x)
|
||||
from mofox_bus import SegPayload, MessageEnvelope
|
||||
|
||||
seg_payload: SegPayload = {"type": "text", "data": "hello"}
|
||||
envelope: MessageEnvelope = {
|
||||
"direction": "input",
|
||||
"message_info": {...},
|
||||
"message_segment": seg_payload,
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Handler 架构模式
|
||||
|
||||
**接收方向** (to_core):
|
||||
```python
|
||||
class MessageHandler:
|
||||
def __init__(self, adapter: "NapcatAdapter"):
|
||||
self.adapter = adapter
|
||||
|
||||
async def handle_raw_message(self, data: dict[str, Any]) -> MessageEnvelope:
|
||||
# 1. 解析 OneBot 11 数据
|
||||
# 2. 构建 message_info(MessageInfoPayload)
|
||||
# 3. 转换消息段为 SegPayload
|
||||
# 4. 返回完整的 MessageEnvelope
|
||||
```
|
||||
|
||||
**发送方向** (to_napcat):
|
||||
```python
|
||||
class SendHandler:
|
||||
def __init__(self, adapter: "NapcatAdapter"):
|
||||
self.adapter = adapter
|
||||
|
||||
async def handle_message(self, envelope: MessageEnvelope) -> dict[str, Any]:
|
||||
# 1. 从 envelope 提取 message_segment
|
||||
# 2. 递归转换 SegPayload → OneBot 格式
|
||||
# 3. 调用 adapter.send_napcat_api() 发送
|
||||
```
|
||||
|
||||
### 3. API 调用模式(响应池)
|
||||
|
||||
```python
|
||||
# 在 NapcatAdapter 中
|
||||
async def send_napcat_api(self, action: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||
# 1. 生成唯一 echo
|
||||
echo = f"{action}_{uuid.uuid4()}"
|
||||
|
||||
# 2. 创建 Future 等待响应
|
||||
future = asyncio.Future()
|
||||
self._response_pool[echo] = future
|
||||
|
||||
# 3. 发送请求(通过 WebSocket)
|
||||
await self._send_request({"action": action, "params": params, "echo": echo})
|
||||
|
||||
# 4. 等待响应(带超时)
|
||||
try:
|
||||
result = await asyncio.wait_for(future, timeout=10.0)
|
||||
return result
|
||||
finally:
|
||||
self._response_pool.pop(echo, None)
|
||||
|
||||
# 响应回来时(在 incoming_parser 中)
|
||||
def _handle_api_response(data: dict[str, Any]):
|
||||
echo = data.get("echo")
|
||||
if echo in adapter._response_pool:
|
||||
adapter._response_pool[echo].set_result(data)
|
||||
```
|
||||
|
||||
### 4. 类型提示技巧
|
||||
|
||||
处理 TypedDict 的严格类型检查:
|
||||
|
||||
```python
|
||||
# 使用 type: ignore 标注(编译时是 TypedDict,运行时是 dict)
|
||||
envelope: MessageEnvelope = {
|
||||
"direction": "input",
|
||||
...
|
||||
} # type: ignore[typeddict-item]
|
||||
|
||||
# 或在函数签名中使用 dict[str, Any]
|
||||
async def from_platform_message(self, message: dict[str, Any]) -> MessageEnvelope | None:
|
||||
...
|
||||
return envelope # type: ignore[return-value]
|
||||
```
|
||||
|
||||
## 🔍 测试检查清单
|
||||
|
||||
- [ ] 文本消息接收/发送
|
||||
- [ ] 图片消息接收/发送
|
||||
- [ ] 语音消息接收/发送
|
||||
- [ ] 视频消息接收/发送
|
||||
- [ ] @消息接收/发送
|
||||
- [ ] 回复消息接收/发送
|
||||
- [ ] 转发消息接收
|
||||
- [ ] JSON消息接收
|
||||
- [ ] 文件消息接收/发送
|
||||
- [ ] 禁言命令
|
||||
- [ ] 踢人命令
|
||||
- [ ] 戳一戳命令
|
||||
- [ ] 表情回应命令
|
||||
- [ ] 通知事件处理
|
||||
- [ ] 元事件处理
|
||||
|
||||
## 📚 参考资料
|
||||
|
||||
- **mofox-bus 文档**: 查看 `mofox_bus/types.py` 了解 TypedDict 定义
|
||||
- **BaseAdapter 示例**: 参考 `docs/mofox_bus_demo_adapter.py`
|
||||
- **旧版实现**: `src/plugins/built_in/napcat_adapter_plugin/` (仅参考逻辑)
|
||||
- **OneBot 11 协议**: [OneBot 11 标准](https://github.com/botuniverse/onebot-11)
|
||||
|
||||
## ⚠️ 重要注意事项
|
||||
|
||||
1. **完全抛弃旧版数据结构**
|
||||
- ❌ 不再使用 `Seg` dataclass
|
||||
- ❌ 不再使用 `MessageBase` 类
|
||||
- ✅ 全部使用 `SegPayload`(TypedDict)
|
||||
- ✅ 全部使用 `MessageEnvelope`(TypedDict)
|
||||
|
||||
2. **BaseAdapter 生命周期**
|
||||
- `__init__()` 中初始化同步资源
|
||||
- `start()` 中执行异步初始化(WebSocket连接自动建立)
|
||||
- `stop()` 中清理资源(WebSocket自动断开)
|
||||
|
||||
3. **WebSocketAdapterOptions 自动管理**
|
||||
- 无需手动管理 WebSocket 连接
|
||||
- incoming_parser 自动解析接收数据
|
||||
- outgoing_encoder 自动编码发送数据
|
||||
- 重连机制由基类处理
|
||||
|
||||
4. **CoreSink 依赖注入**
|
||||
- 必须在插件加载后调用 `set_core_sink(sink)`
|
||||
- 通过 `get_core_sink()` 全局访问
|
||||
- 用于将消息递送到核心系统
|
||||
|
||||
5. **类型安全与灵活性平衡**
|
||||
- TypedDict 在编译时提供类型检查
|
||||
- 运行时仍是普通 dict,可灵活操作
|
||||
- 必要时使用 `type: ignore` 抑制误报
|
||||
|
||||
6. **参考旧版但不照搬**
|
||||
- 旧版逻辑流程可参考
|
||||
- 数据结构需完全重写
|
||||
- API调用模式已改变(响应池)
|
||||
|
||||
## 📊 预估工作量
|
||||
|
||||
- ✅ 核心架构: **已完成** (BaseAdapter + Handlers 骨架)
|
||||
- ⏳ 消息处理完善: **4-6 小时** (所有消息类型 + 工具函数)
|
||||
- ⏳ 发送处理完善: **3-4 小时** (所有 Seg 类型 + 命令)
|
||||
- ⏳ 通知事件处理: **2-3 小时** (poke/emoji_like/recall/ban)
|
||||
- ⏳ 测试调试: **2-4 小时** (全流程测试)
|
||||
- **总剩余时间: 11-17 小时**
|
||||
|
||||
## ✅ 完成标准
|
||||
|
||||
当以下条件全部满足时,重写完成:
|
||||
|
||||
1. ✅ BaseAdapter 架构实现完成
|
||||
2. ⏳ 所有 OneBot 11 消息类型支持
|
||||
3. ⏳ 所有发送消息段类型支持
|
||||
4. ⏳ 所有通知事件正确处理
|
||||
5. ⏳ 权限系统集成完成
|
||||
6. ⏳ 与旧版功能完全对等
|
||||
7. ⏳ 所有测试用例通过
|
||||
|
||||
---
|
||||
|
||||
**最后更新**: 2025-11-23
|
||||
**架构状态**: ✅ 核心架构完成
|
||||
**实现状态**: ⏳ 消息处理部分完成,需完善细节
|
||||
**预计完成**: 根据优先级,核心功能预计 1-2 个工作日
|
||||
16
src/plugins/built_in/NEW_napcat_adapter/__init__.py
Normal file
16
src/plugins/built_in/NEW_napcat_adapter/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from src.plugin_system.base.plugin_metadata import PluginMetadata
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="napcat_plugin",
|
||||
description="基于OneBot 11协议的NapCat QQ协议插件,提供完整的QQ机器人API接口,使用现有adapter连接",
|
||||
usage="该插件提供 `napcat_tool` tool。",
|
||||
version="1.0.0",
|
||||
author="Windpicker_owo",
|
||||
license="GPL-v3.0-or-later",
|
||||
repository_url="https://github.com/Windpicker-owo",
|
||||
keywords=["qq", "bot", "napcat", "onebot", "api", "websocket"],
|
||||
categories=["protocol"],
|
||||
extra={
|
||||
"is_built_in": False,
|
||||
},
|
||||
)
|
||||
330
src/plugins/built_in/NEW_napcat_adapter/plugin.py
Normal file
330
src/plugins/built_in/NEW_napcat_adapter/plugin.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
Napcat 适配器(基于 MoFox-Bus 完全重写版)
|
||||
|
||||
核心流程:
|
||||
1. Napcat WebSocket 连接 → 接收 OneBot 格式消息
|
||||
2. from_platform_message: OneBot dict → MessageEnvelope
|
||||
3. CoreSink → 推送到 MoFox-Bot 核心
|
||||
4. 核心回复 → _send_platform_message: MessageEnvelope → OneBot API 调用
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Any, ClassVar, Dict, List, Optional
|
||||
|
||||
import orjson
|
||||
import websockets
|
||||
|
||||
from mofox_bus import CoreMessageSink, MessageEnvelope, WebSocketAdapterOptions
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import register_plugin
|
||||
from src.plugin_system.base import BaseAdapter, BasePlugin
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from .src.handlers.to_core.message_handler import MessageHandler
|
||||
from .src.handlers.to_core.notice_handler import NoticeHandler
|
||||
from .src.handlers.to_core.meta_event_handler import MetaEventHandler
|
||||
from .src.handlers.to_napcat.send_handler import SendHandler
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
class NapcatAdapter(BaseAdapter):
|
||||
"""Napcat 适配器 - 完全基于 mofox-bus 架构"""
|
||||
|
||||
adapter_name = "napcat_adapter"
|
||||
adapter_version = "2.0.0"
|
||||
adapter_author = "MoFox Team"
|
||||
adapter_description = "基于 MoFox-Bus 的 Napcat/OneBot 11 适配器"
|
||||
platform = "qq"
|
||||
|
||||
run_in_subprocess = False
|
||||
subprocess_entry = None
|
||||
|
||||
def __init__(self, core_sink: CoreMessageSink, plugin: Optional[BasePlugin] = None):
|
||||
"""初始化 Napcat 适配器"""
|
||||
# 从插件配置读取 WebSocket URL
|
||||
if plugin:
|
||||
mode = config_api.get_plugin_config(plugin.config, "napcat_server.mode", "reverse")
|
||||
host = config_api.get_plugin_config(plugin.config, "napcat_server.host", "localhost")
|
||||
port = config_api.get_plugin_config(plugin.config, "napcat_server.port", 8095)
|
||||
url = config_api.get_plugin_config(plugin.config, "napcat_server.url", "")
|
||||
access_token = config_api.get_plugin_config(plugin.config, "napcat_server.access_token", "")
|
||||
|
||||
if mode == "forward" and url:
|
||||
ws_url = url
|
||||
else:
|
||||
ws_url = f"ws://{host}:{port}"
|
||||
|
||||
headers = {}
|
||||
if access_token:
|
||||
headers["Authorization"] = f"Bearer {access_token}"
|
||||
else:
|
||||
ws_url = "ws://127.0.0.1:8095"
|
||||
headers = {}
|
||||
|
||||
# 配置 WebSocket 传输
|
||||
transport = WebSocketAdapterOptions(
|
||||
url=ws_url,
|
||||
headers=headers if headers else None,
|
||||
incoming_parser=self._parse_napcat_message,
|
||||
outgoing_encoder=self._encode_napcat_response,
|
||||
)
|
||||
|
||||
super().__init__(core_sink, plugin=plugin, transport=transport)
|
||||
|
||||
# 初始化处理器
|
||||
self.message_handler = MessageHandler(self)
|
||||
self.notice_handler = NoticeHandler(self)
|
||||
self.meta_event_handler = MetaEventHandler(self)
|
||||
self.send_handler = SendHandler(self)
|
||||
|
||||
# 响应池:用于存储等待的 API 响应
|
||||
self._response_pool: Dict[str, asyncio.Future] = {}
|
||||
self._response_timeout = 30.0
|
||||
|
||||
# WebSocket 连接(用于发送 API 请求)
|
||||
# 注意:_ws 继承自 BaseAdapter,是 WebSocketLike 协议类型
|
||||
self._napcat_ws = None # 可选的额外连接引用
|
||||
|
||||
async def on_adapter_loaded(self) -> None:
|
||||
"""适配器加载时的初始化"""
|
||||
logger.info("Napcat 适配器正在启动...")
|
||||
|
||||
# 设置处理器配置
|
||||
if self.plugin:
|
||||
self.message_handler.set_plugin_config(self.plugin.config)
|
||||
self.notice_handler.set_plugin_config(self.plugin.config)
|
||||
self.meta_event_handler.set_plugin_config(self.plugin.config)
|
||||
self.send_handler.set_plugin_config(self.plugin.config)
|
||||
|
||||
logger.info("Napcat 适配器已加载")
|
||||
|
||||
async def on_adapter_unloaded(self) -> None:
|
||||
"""适配器卸载时的清理"""
|
||||
logger.info("Napcat 适配器正在关闭...")
|
||||
|
||||
# 清理响应池
|
||||
for future in self._response_pool.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._response_pool.clear()
|
||||
|
||||
logger.info("Napcat 适配器已关闭")
|
||||
|
||||
def _parse_napcat_message(self, raw: str | bytes) -> Any:
|
||||
"""解析 Napcat/OneBot 消息"""
|
||||
try:
|
||||
if isinstance(raw, bytes):
|
||||
data = orjson.loads(raw)
|
||||
else:
|
||||
data = orjson.loads(raw)
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Napcat 消息失败: {e}")
|
||||
raise
|
||||
|
||||
def _encode_napcat_response(self, envelope: MessageEnvelope) -> bytes:
|
||||
"""编码响应消息为 Napcat 格式(暂未使用,通过 API 调用发送)"""
|
||||
return orjson.dumps(envelope)
|
||||
|
||||
async def from_platform_message(self, raw: Dict[str, Any]) -> MessageEnvelope: # type: ignore[override]
|
||||
"""
|
||||
将 Napcat/OneBot 原始消息转换为 MessageEnvelope
|
||||
|
||||
这是核心转换方法,处理:
|
||||
- message 事件 → 消息
|
||||
- notice 事件 → 通知(戳一戳、表情回复等)
|
||||
- meta_event 事件 → 元事件(心跳、生命周期)
|
||||
- API 响应 → 存入响应池
|
||||
"""
|
||||
post_type = raw.get("post_type")
|
||||
|
||||
# API 响应(没有 post_type,有 echo)
|
||||
if post_type is None and "echo" in raw:
|
||||
echo = raw.get("echo")
|
||||
if echo and echo in self._response_pool:
|
||||
future = self._response_pool[echo]
|
||||
if not future.done():
|
||||
future.set_result(raw)
|
||||
# API 响应不需要转换为 MessageEnvelope,返回空信封
|
||||
return self._create_empty_envelope()
|
||||
|
||||
# 消息事件
|
||||
if post_type == "message":
|
||||
return await self.message_handler.handle_raw_message(raw) # type: ignore[return-value]
|
||||
|
||||
# 通知事件
|
||||
elif post_type == "notice":
|
||||
return await self.notice_handler.handle_notice(raw) # type: ignore[return-value]
|
||||
|
||||
# 元事件
|
||||
elif post_type == "meta_event":
|
||||
return await self.meta_event_handler.handle_meta_event(raw) # type: ignore[return-value]
|
||||
|
||||
# 未知事件类型
|
||||
else:
|
||||
logger.warning(f"未知的事件类型: {post_type}")
|
||||
return self._create_empty_envelope() # type: ignore[return-value]
|
||||
|
||||
async def _send_platform_message(self, envelope: MessageEnvelope) -> None: # type: ignore[override]
|
||||
"""
|
||||
将 MessageEnvelope 转换并发送到 Napcat
|
||||
|
||||
这里不直接通过 WebSocket 发送 envelope,
|
||||
而是调用 Napcat API(send_group_msg, send_private_msg 等)
|
||||
"""
|
||||
await self.send_handler.handle_message(envelope)
|
||||
|
||||
def _create_empty_envelope(self) -> MessageEnvelope: # type: ignore[return]
|
||||
"""创建一个空的消息信封(用于不需要处理的事件)"""
|
||||
import time
|
||||
return {
|
||||
"direction": "incoming",
|
||||
"message_info": {
|
||||
"platform": self.platform,
|
||||
"message_id": str(uuid.uuid4()),
|
||||
"time": time.time(),
|
||||
},
|
||||
"message_segment": {"type": "text", "data": "[系统事件]"},
|
||||
"timestamp_ms": int(time.time() * 1000),
|
||||
}
|
||||
|
||||
async def send_napcat_api(self, action: str, params: Dict[str, Any], timeout: float = 30.0) -> Dict[str, Any]:
|
||||
"""
|
||||
发送 Napcat API 请求并等待响应
|
||||
|
||||
Args:
|
||||
action: API 动作名称(如 send_group_msg)
|
||||
params: API 参数
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
API 响应数据
|
||||
"""
|
||||
if not self._ws:
|
||||
raise RuntimeError("WebSocket 连接未建立")
|
||||
|
||||
# 生成唯一的 echo ID
|
||||
echo = str(uuid.uuid4())
|
||||
|
||||
# 创建 Future 用于等待响应
|
||||
future = asyncio.Future()
|
||||
self._response_pool[echo] = future
|
||||
|
||||
# 构造请求
|
||||
request = orjson.dumps({
|
||||
"action": action,
|
||||
"params": params,
|
||||
"echo": echo,
|
||||
})
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
await self._ws.send(request)
|
||||
|
||||
# 等待响应
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"API 请求超时: {action}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"API 请求失败: {action}, 错误: {e}")
|
||||
raise
|
||||
finally:
|
||||
# 清理响应池
|
||||
self._response_pool.pop(echo, None)
|
||||
|
||||
def get_ws_connection(self):
|
||||
"""获取 WebSocket 连接(用于发送 API 请求)"""
|
||||
if not self._ws:
|
||||
raise RuntimeError("WebSocket 连接未建立")
|
||||
return self._ws
|
||||
|
||||
|
||||
@register_plugin
|
||||
class NapcatAdapterPlugin(BasePlugin):
|
||||
"""Napcat 适配器插件"""
|
||||
|
||||
plugin_name = "napcat_adapter_plugin"
|
||||
enable_plugin = True
|
||||
plugin_version = "2.0.0"
|
||||
plugin_author = "MoFox Team"
|
||||
plugin_description = "Napcat/OneBot 11 适配器(基于 MoFox-Bus 重写)"
|
||||
|
||||
# 配置 Schema
|
||||
config_schema: ClassVar[dict] = {
|
||||
"plugin": {
|
||||
"name": {"type": str, "default": "napcat_adapter_plugin"},
|
||||
"version": {"type": str, "default": "2.0.0"},
|
||||
"enabled": {"type": bool, "default": True},
|
||||
},
|
||||
"napcat_server": {
|
||||
"mode": {
|
||||
"type": str,
|
||||
"default": "reverse",
|
||||
"description": "连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)",
|
||||
},
|
||||
"host": {"type": str, "default": "localhost"},
|
||||
"port": {"type": int, "default": 8095},
|
||||
"url": {"type": str, "default": "", "description": "正向连接时的完整URL"},
|
||||
"access_token": {"type": str, "default": ""},
|
||||
},
|
||||
"features": {
|
||||
"group_list_type": {"type": str, "default": "blacklist"},
|
||||
"group_list": {"type": list, "default": []},
|
||||
"private_list_type": {"type": str, "default": "blacklist"},
|
||||
"private_list": {"type": list, "default": []},
|
||||
"ban_user_id": {"type": list, "default": []},
|
||||
"ban_qq_bot": {"type": bool, "default": False},
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, plugin_dir: str = "", metadata: Any = None):
|
||||
# 如果没有提供参数,创建一个默认的元数据
|
||||
if metadata is None:
|
||||
from src.plugin_system.base.plugin_metadata import PluginMetadata
|
||||
metadata = PluginMetadata(
|
||||
name=self.plugin_name,
|
||||
version=self.plugin_version,
|
||||
author=self.plugin_author,
|
||||
description=self.plugin_description,
|
||||
usage="",
|
||||
dependencies=[],
|
||||
python_dependencies=[],
|
||||
)
|
||||
|
||||
if not plugin_dir:
|
||||
from pathlib import Path
|
||||
plugin_dir = str(Path(__file__).parent)
|
||||
|
||||
super().__init__(plugin_dir, metadata)
|
||||
self._adapter: Optional[NapcatAdapter] = None
|
||||
|
||||
async def on_plugin_loaded(self):
|
||||
"""插件加载时启动适配器"""
|
||||
logger.info("Napcat 适配器插件正在加载...")
|
||||
|
||||
# 获取核心 Sink
|
||||
from src.common.core_sink import get_core_sink
|
||||
core_sink = get_core_sink()
|
||||
|
||||
# 创建并启动适配器
|
||||
self._adapter = NapcatAdapter(core_sink, plugin=self)
|
||||
await self._adapter.start()
|
||||
|
||||
logger.info("Napcat 适配器插件已加载")
|
||||
|
||||
async def on_plugin_unloaded(self):
|
||||
"""插件卸载时停止适配器"""
|
||||
if self._adapter:
|
||||
await self._adapter.stop()
|
||||
logger.info("Napcat 适配器插件已卸载")
|
||||
|
||||
def get_plugin_components(self) -> list:
|
||||
"""返回适配器组件"""
|
||||
return [(NapcatAdapter.get_adapter_info(), NapcatAdapter)]
|
||||
1
src/plugins/built_in/NEW_napcat_adapter/src/__init__.py
Normal file
1
src/plugins/built_in/NEW_napcat_adapter/src/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""工具模块"""
|
||||
310
src/plugins/built_in/NEW_napcat_adapter/src/event_models.py
Normal file
310
src/plugins/built_in/NEW_napcat_adapter/src/event_models.py
Normal file
@@ -0,0 +1,310 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MetaEventType:
|
||||
lifecycle = "lifecycle" # 生命周期
|
||||
|
||||
class Lifecycle:
|
||||
connect = "connect" # 生命周期 - WebSocket 连接成功
|
||||
|
||||
heartbeat = "heartbeat" # 心跳
|
||||
|
||||
|
||||
class MessageType: # 接受消息大类
|
||||
private = "private" # 私聊消息
|
||||
|
||||
class Private:
|
||||
friend = "friend" # 私聊消息 - 好友
|
||||
group = "group" # 私聊消息 - 群临时
|
||||
group_self = "group_self" # 私聊消息 - 群中自身发送
|
||||
other = "other" # 私聊消息 - 其他
|
||||
|
||||
group = "group" # 群聊消息
|
||||
|
||||
class Group:
|
||||
normal = "normal" # 群聊消息 - 普通
|
||||
anonymous = "anonymous" # 群聊消息 - 匿名消息
|
||||
notice = "notice" # 群聊消息 - 系统提示
|
||||
|
||||
|
||||
class NoticeType: # 通知事件
|
||||
friend_recall = "friend_recall" # 私聊消息撤回
|
||||
group_recall = "group_recall" # 群聊消息撤回
|
||||
notify = "notify"
|
||||
group_ban = "group_ban" # 群禁言
|
||||
group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复
|
||||
group_upload = "group_upload" # 群文件上传
|
||||
|
||||
class Notify:
|
||||
poke = "poke" # 戳一戳
|
||||
input_status = "input_status" # 正在输入
|
||||
|
||||
class GroupBan:
|
||||
ban = "ban" # 禁言
|
||||
lift_ban = "lift_ban" # 解除禁言
|
||||
|
||||
|
||||
class RealMessageType: # 实际消息分类
|
||||
text = "text" # 纯文本
|
||||
face = "face" # qq表情
|
||||
image = "image" # 图片
|
||||
record = "record" # 语音
|
||||
video = "video" # 视频
|
||||
at = "at" # @某人
|
||||
rps = "rps" # 猜拳魔法表情
|
||||
dice = "dice" # 骰子
|
||||
shake = "shake" # 私聊窗口抖动(只收)
|
||||
poke = "poke" # 群聊戳一戳
|
||||
share = "share" # 链接分享(json形式)
|
||||
reply = "reply" # 回复消息
|
||||
forward = "forward" # 转发消息
|
||||
node = "node" # 转发消息节点
|
||||
json = "json" # json消息
|
||||
file = "file" # 文件
|
||||
|
||||
|
||||
class MessageSentType:
|
||||
private = "private"
|
||||
|
||||
class Private:
|
||||
friend = "friend"
|
||||
group = "group"
|
||||
|
||||
group = "group"
|
||||
|
||||
class Group:
|
||||
normal = "normal"
|
||||
|
||||
|
||||
class CommandType(Enum):
|
||||
"""命令类型"""
|
||||
|
||||
GROUP_BAN = "set_group_ban" # 禁言用户
|
||||
GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
|
||||
GROUP_KICK = "set_group_kick" # 踢出群聊
|
||||
SEND_POKE = "send_poke" # 戳一戳
|
||||
DELETE_MSG = "delete_msg" # 撤回消息
|
||||
AI_VOICE_SEND = "ai_voice_send" # AI语音发送
|
||||
SET_EMOJI_LIKE = "set_msg_emoji_like" # 设置表情回应
|
||||
SEND_AT_MESSAGE = "send_at_message" # 发送@消息
|
||||
SEND_LIKE = "send_like" # 发送点赞
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
# 支持的消息格式
|
||||
ACCEPT_FORMAT = [
|
||||
"text",
|
||||
"image",
|
||||
"emoji",
|
||||
"reply",
|
||||
"voice",
|
||||
"command",
|
||||
"voiceurl",
|
||||
"music",
|
||||
"videourl",
|
||||
"file",
|
||||
]
|
||||
|
||||
# 插件名称
|
||||
PLUGIN_NAME = "NEW_napcat_adapter"
|
||||
|
||||
# QQ表情映射表
|
||||
QQ_FACE = {
|
||||
"0": "[表情:惊讶]",
|
||||
"1": "[表情:撇嘴]",
|
||||
"2": "[表情:色]",
|
||||
"3": "[表情:发呆]",
|
||||
"4": "[表情:得意]",
|
||||
"5": "[表情:流泪]",
|
||||
"6": "[表情:害羞]",
|
||||
"7": "[表情:闭嘴]",
|
||||
"8": "[表情:睡]",
|
||||
"9": "[表情:大哭]",
|
||||
"10": "[表情:尴尬]",
|
||||
"11": "[表情:发怒]",
|
||||
"12": "[表情:调皮]",
|
||||
"13": "[表情:呲牙]",
|
||||
"14": "[表情:微笑]",
|
||||
"15": "[表情:难过]",
|
||||
"16": "[表情:酷]",
|
||||
"18": "[表情:抓狂]",
|
||||
"19": "[表情:吐]",
|
||||
"20": "[表情:偷笑]",
|
||||
"21": "[表情:可爱]",
|
||||
"22": "[表情:白眼]",
|
||||
"23": "[表情:傲慢]",
|
||||
"24": "[表情:饥饿]",
|
||||
"25": "[表情:困]",
|
||||
"26": "[表情:惊恐]",
|
||||
"27": "[表情:流汗]",
|
||||
"28": "[表情:憨笑]",
|
||||
"29": "[表情:悠闲]",
|
||||
"30": "[表情:奋斗]",
|
||||
"31": "[表情:咒骂]",
|
||||
"32": "[表情:疑问]",
|
||||
"33": "[表情:嘘]",
|
||||
"34": "[表情:晕]",
|
||||
"35": "[表情:折磨]",
|
||||
"36": "[表情:衰]",
|
||||
"37": "[表情:骷髅]",
|
||||
"38": "[表情:敲打]",
|
||||
"39": "[表情:再见]",
|
||||
"41": "[表情:发抖]",
|
||||
"42": "[表情:爱情]",
|
||||
"43": "[表情:跳跳]",
|
||||
"46": "[表情:猪头]",
|
||||
"49": "[表情:拥抱]",
|
||||
"53": "[表情:蛋糕]",
|
||||
"56": "[表情:刀]",
|
||||
"59": "[表情:便便]",
|
||||
"60": "[表情:咖啡]",
|
||||
"63": "[表情:玫瑰]",
|
||||
"64": "[表情:凋谢]",
|
||||
"66": "[表情:爱心]",
|
||||
"67": "[表情:心碎]",
|
||||
"74": "[表情:太阳]",
|
||||
"75": "[表情:月亮]",
|
||||
"76": "[表情:赞]",
|
||||
"77": "[表情:踩]",
|
||||
"78": "[表情:握手]",
|
||||
"79": "[表情:胜利]",
|
||||
"85": "[表情:飞吻]",
|
||||
"86": "[表情:怄火]",
|
||||
"89": "[表情:西瓜]",
|
||||
"96": "[表情:冷汗]",
|
||||
"97": "[表情:擦汗]",
|
||||
"98": "[表情:抠鼻]",
|
||||
"99": "[表情:鼓掌]",
|
||||
"100": "[表情:糗大了]",
|
||||
"101": "[表情:坏笑]",
|
||||
"102": "[表情:左哼哼]",
|
||||
"103": "[表情:右哼哼]",
|
||||
"104": "[表情:哈欠]",
|
||||
"105": "[表情:鄙视]",
|
||||
"106": "[表情:委屈]",
|
||||
"107": "[表情:快哭了]",
|
||||
"108": "[表情:阴险]",
|
||||
"109": "[表情:左亲亲]",
|
||||
"110": "[表情:吓]",
|
||||
"111": "[表情:可怜]",
|
||||
"112": "[表情:菜刀]",
|
||||
"114": "[表情:篮球]",
|
||||
"116": "[表情:示爱]",
|
||||
"118": "[表情:抱拳]",
|
||||
"119": "[表情:勾引]",
|
||||
"120": "[表情:拳头]",
|
||||
"121": "[表情:差劲]",
|
||||
"123": "[表情:NO]",
|
||||
"124": "[表情:OK]",
|
||||
"125": "[表情:转圈]",
|
||||
"129": "[表情:挥手]",
|
||||
"137": "[表情:鞭炮]",
|
||||
"144": "[表情:喝彩]",
|
||||
"146": "[表情:爆筋]",
|
||||
"147": "[表情:棒棒糖]",
|
||||
"169": "[表情:手枪]",
|
||||
"171": "[表情:茶]",
|
||||
"172": "[表情:眨眼睛]",
|
||||
"173": "[表情:泪奔]",
|
||||
"174": "[表情:无奈]",
|
||||
"175": "[表情:卖萌]",
|
||||
"176": "[表情:小纠结]",
|
||||
"177": "[表情:喷血]",
|
||||
"178": "[表情:斜眼笑]",
|
||||
"179": "[表情:doge]",
|
||||
"181": "[表情:戳一戳]",
|
||||
"182": "[表情:笑哭]",
|
||||
"183": "[表情:我最美]",
|
||||
"185": "[表情:羊驼]",
|
||||
"187": "[表情:幽灵]",
|
||||
"201": "[表情:点赞]",
|
||||
"212": "[表情:托腮]",
|
||||
"262": "[表情:脑阔疼]",
|
||||
"263": "[表情:沧桑]",
|
||||
"264": "[表情:捂脸]",
|
||||
"265": "[表情:辣眼睛]",
|
||||
"266": "[表情:哦哟]",
|
||||
"267": "[表情:头秃]",
|
||||
"268": "[表情:问号脸]",
|
||||
"269": "[表情:暗中观察]",
|
||||
"270": "[表情:emm]",
|
||||
"271": "[表情:吃瓜]",
|
||||
"272": "[表情:呵呵哒]",
|
||||
"273": "[表情:我酸了]",
|
||||
"277": "[表情:滑稽狗头]",
|
||||
"281": "[表情:翻白眼]",
|
||||
"282": "[表情:敬礼]",
|
||||
"283": "[表情:狂笑]",
|
||||
"284": "[表情:面无表情]",
|
||||
"285": "[表情:摸鱼]",
|
||||
"286": "[表情:魔鬼笑]",
|
||||
"287": "[表情:哦]",
|
||||
"289": "[表情:睁眼]",
|
||||
"293": "[表情:摸锦鲤]",
|
||||
"294": "[表情:期待]",
|
||||
"295": "[表情:拿到红包]",
|
||||
"297": "[表情:拜谢]",
|
||||
"298": "[表情:元宝]",
|
||||
"299": "[表情:牛啊]",
|
||||
"300": "[表情:胖三斤]",
|
||||
"302": "[表情:左拜年]",
|
||||
"303": "[表情:右拜年]",
|
||||
"305": "[表情:右亲亲]",
|
||||
"306": "[表情:牛气冲天]",
|
||||
"307": "[表情:喵喵]",
|
||||
"311": "[表情:打call]",
|
||||
"312": "[表情:变形]",
|
||||
"314": "[表情:仔细分析]",
|
||||
"317": "[表情:菜汪]",
|
||||
"318": "[表情:崇拜]",
|
||||
"319": "[表情:比心]",
|
||||
"320": "[表情:庆祝]",
|
||||
"323": "[表情:嫌弃]",
|
||||
"324": "[表情:吃糖]",
|
||||
"325": "[表情:惊吓]",
|
||||
"326": "[表情:生气]",
|
||||
"332": "[表情:举牌牌]",
|
||||
"333": "[表情:烟花]",
|
||||
"334": "[表情:虎虎生威]",
|
||||
"336": "[表情:豹富]",
|
||||
"337": "[表情:花朵脸]",
|
||||
"338": "[表情:我想开了]",
|
||||
"339": "[表情:舔屏]",
|
||||
"341": "[表情:打招呼]",
|
||||
"342": "[表情:酸Q]",
|
||||
"343": "[表情:我方了]",
|
||||
"344": "[表情:大怨种]",
|
||||
"345": "[表情:红包多多]",
|
||||
"346": "[表情:你真棒棒]",
|
||||
"347": "[表情:大展宏兔]",
|
||||
"349": "[表情:坚强]",
|
||||
"350": "[表情:贴贴]",
|
||||
"351": "[表情:敲敲]",
|
||||
"352": "[表情:咦]",
|
||||
"353": "[表情:拜托]",
|
||||
"354": "[表情:尊嘟假嘟]",
|
||||
"355": "[表情:耶]",
|
||||
"356": "[表情:666]",
|
||||
"357": "[表情:裂开]",
|
||||
"392": "[表情:龙年快乐]",
|
||||
"393": "[表情:新年中龙]",
|
||||
"394": "[表情:新年大龙]",
|
||||
"395": "[表情:略略略]",
|
||||
"396": "[表情:龙年快乐]",
|
||||
"424": "[表情:按钮]",
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MetaEventType",
|
||||
"MessageType",
|
||||
"NoticeType",
|
||||
"RealMessageType",
|
||||
"MessageSentType",
|
||||
"CommandType",
|
||||
"ACCEPT_FORMAT",
|
||||
"PLUGIN_NAME",
|
||||
"QQ_FACE",
|
||||
]
|
||||
@@ -0,0 +1 @@
|
||||
"""处理器模块"""
|
||||
@@ -0,0 +1 @@
|
||||
"""接收方向处理器"""
|
||||
@@ -0,0 +1,126 @@
|
||||
"""消息处理器 - 将 Napcat OneBot 消息转换为 MessageEnvelope"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...plugin import NapcatAdapter
|
||||
|
||||
logger = get_logger("napcat_adapter.message_handler")
|
||||
|
||||
|
||||
class MessageHandler:
|
||||
"""处理来自 Napcat 的消息事件"""
|
||||
|
||||
def __init__(self, adapter: "NapcatAdapter"):
|
||||
self.adapter = adapter
|
||||
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = config
|
||||
|
||||
async def handle_raw_message(self, raw: Dict[str, Any]):
|
||||
"""
|
||||
处理原始消息并转换为 MessageEnvelope
|
||||
|
||||
Args:
|
||||
raw: OneBot 原始消息数据
|
||||
|
||||
Returns:
|
||||
MessageEnvelope (dict)
|
||||
"""
|
||||
from mofox_bus import MessageEnvelope, SegPayload, MessageInfoPayload, UserInfoPayload, GroupInfoPayload
|
||||
|
||||
message_type = raw.get("message_type")
|
||||
message_id = str(raw.get("message_id", ""))
|
||||
message_time = time.time()
|
||||
|
||||
# 构造用户信息
|
||||
sender_info = raw.get("sender", {})
|
||||
user_info: UserInfoPayload = {
|
||||
"platform": "qq",
|
||||
"user_id": str(sender_info.get("user_id", "")),
|
||||
"user_nickname": sender_info.get("nickname", ""),
|
||||
"user_cardname": sender_info.get("card", ""),
|
||||
"user_avatar": sender_info.get("avatar", ""),
|
||||
}
|
||||
|
||||
# 构造群组信息(如果是群消息)
|
||||
group_info: Optional[GroupInfoPayload] = None
|
||||
if message_type == "group":
|
||||
group_id = raw.get("group_id")
|
||||
if group_id:
|
||||
group_info = {
|
||||
"platform": "qq",
|
||||
"group_id": str(group_id),
|
||||
"group_name": "", # 可以通过 API 获取
|
||||
}
|
||||
|
||||
# 解析消息段
|
||||
message_segments = raw.get("message", [])
|
||||
seg_list: List[SegPayload] = []
|
||||
|
||||
for seg in message_segments:
|
||||
seg_type = seg.get("type", "")
|
||||
seg_data = seg.get("data", {})
|
||||
|
||||
# 转换为 SegPayload
|
||||
if seg_type == "text":
|
||||
seg_list.append({
|
||||
"type": "text",
|
||||
"data": seg_data.get("text", "")
|
||||
})
|
||||
elif seg_type == "image":
|
||||
# 这里需要下载图片并转换为 base64(简化版本)
|
||||
seg_list.append({
|
||||
"type": "image",
|
||||
"data": seg_data.get("url", "") # 实际应该转换为 base64
|
||||
})
|
||||
elif seg_type == "at":
|
||||
seg_list.append({
|
||||
"type": "at",
|
||||
"data": f"{seg_data.get('qq', '')}"
|
||||
})
|
||||
# 其他消息类型...
|
||||
|
||||
# 构造 MessageInfoPayload
|
||||
message_info = {
|
||||
"platform": "qq",
|
||||
"message_id": message_id,
|
||||
"time": message_time,
|
||||
"user_info": user_info,
|
||||
"format_info": {
|
||||
"content_format": ["text", "image"], # 根据实际消息类型设置
|
||||
"accept_format": ["text", "image", "emoji", "voice"],
|
||||
},
|
||||
}
|
||||
|
||||
# 添加群组信息(如果存在)
|
||||
if group_info:
|
||||
message_info["group_info"] = group_info
|
||||
|
||||
# 构造 MessageEnvelope
|
||||
envelope = {
|
||||
"direction": "incoming",
|
||||
"message_info": message_info,
|
||||
"message_segment": {"type": "seglist", "data": seg_list} if len(seg_list) > 1 else (seg_list[0] if seg_list else {"type": "text", "data": ""}),
|
||||
"raw_message": raw.get("raw_message", ""),
|
||||
"platform": "qq",
|
||||
"message_id": message_id,
|
||||
"timestamp_ms": int(message_time * 1000),
|
||||
}
|
||||
|
||||
return envelope
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
"""元事件处理器"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...plugin import NapcatAdapter
|
||||
|
||||
logger = get_logger("napcat_adapter.meta_event_handler")
|
||||
|
||||
|
||||
class MetaEventHandler:
|
||||
"""处理 Napcat 元事件(心跳、生命周期)"""
|
||||
|
||||
def __init__(self, adapter: "NapcatAdapter"):
|
||||
self.adapter = adapter
|
||||
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = config
|
||||
|
||||
async def handle_meta_event(self, raw: Dict[str, Any]):
|
||||
"""处理元事件"""
|
||||
# 简化版本:返回一个空的 MessageEnvelope
|
||||
import time
|
||||
import uuid
|
||||
|
||||
return {
|
||||
"direction": "incoming",
|
||||
"message_info": {
|
||||
"platform": "qq",
|
||||
"message_id": str(uuid.uuid4()),
|
||||
"time": time.time(),
|
||||
},
|
||||
"message_segment": {"type": "text", "data": "[元事件]"},
|
||||
"timestamp_ms": int(time.time() * 1000),
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
"""通知事件处理器"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...plugin import NapcatAdapter
|
||||
|
||||
logger = get_logger("napcat_adapter.notice_handler")
|
||||
|
||||
|
||||
class NoticeHandler:
|
||||
"""处理 Napcat 通知事件(戳一戳、表情回复等)"""
|
||||
|
||||
def __init__(self, adapter: "NapcatAdapter"):
|
||||
self.adapter = adapter
|
||||
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = config
|
||||
|
||||
async def handle_notice(self, raw: Dict[str, Any]):
|
||||
"""处理通知事件"""
|
||||
# 简化版本:返回一个空的 MessageEnvelope
|
||||
import time
|
||||
import uuid
|
||||
|
||||
return {
|
||||
"direction": "incoming",
|
||||
"message_info": {
|
||||
"platform": "qq",
|
||||
"message_id": str(uuid.uuid4()),
|
||||
"time": time.time(),
|
||||
},
|
||||
"message_segment": {"type": "text", "data": "[通知事件]"},
|
||||
"timestamp_ms": int(time.time() * 1000),
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
"""发送方向处理器"""
|
||||
@@ -0,0 +1,77 @@
|
||||
"""发送处理器 - 将 MessageEnvelope 转换并发送到 Napcat"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...plugin import NapcatAdapter
|
||||
|
||||
logger = get_logger("napcat_adapter.send_handler")
|
||||
|
||||
|
||||
class SendHandler:
|
||||
"""处理向 Napcat 发送消息"""
|
||||
|
||||
def __init__(self, adapter: "NapcatAdapter"):
|
||||
self.adapter = adapter
|
||||
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = config
|
||||
|
||||
async def handle_message(self, envelope) -> None:
|
||||
"""
|
||||
处理发送消息
|
||||
|
||||
将 MessageEnvelope 转换为 OneBot API 调用
|
||||
"""
|
||||
message_info = envelope.get("message_info", {})
|
||||
message_segment = envelope.get("message_segment", {})
|
||||
|
||||
# 获取群组和用户信息
|
||||
group_info = message_info.get("group_info")
|
||||
user_info = message_info.get("user_info")
|
||||
|
||||
# 构造消息内容
|
||||
message = self._convert_seg_to_onebot(message_segment)
|
||||
|
||||
# 发送消息
|
||||
if group_info:
|
||||
# 发送群消息
|
||||
group_id = group_info.get("group_id")
|
||||
if group_id:
|
||||
await self.adapter.send_napcat_api("send_group_msg", {
|
||||
"group_id": int(group_id),
|
||||
"message": message,
|
||||
})
|
||||
elif user_info:
|
||||
# 发送私聊消息
|
||||
user_id = user_info.get("user_id")
|
||||
if user_id:
|
||||
await self.adapter.send_napcat_api("send_private_msg", {
|
||||
"user_id": int(user_id),
|
||||
"message": message,
|
||||
})
|
||||
|
||||
def _convert_seg_to_onebot(self, seg: Dict[str, Any]) -> list:
|
||||
"""将 SegPayload 转换为 OneBot 消息格式"""
|
||||
seg_type = seg.get("type", "")
|
||||
seg_data = seg.get("data", "")
|
||||
|
||||
if seg_type == "text":
|
||||
return [{"type": "text", "data": {"text": seg_data}}]
|
||||
elif seg_type == "image":
|
||||
return [{"type": "image", "data": {"file": f"base64://{seg_data}"}}]
|
||||
elif seg_type == "seglist":
|
||||
# 递归处理列表
|
||||
result = []
|
||||
for sub_seg in seg_data:
|
||||
result.extend(self._convert_seg_to_onebot(sub_seg))
|
||||
return result
|
||||
else:
|
||||
# 默认作为文本
|
||||
return [{"type": "text", "data": {"text": str(seg_data)}}]
|
||||
350
src/plugins/built_in/NEW_napcat_adapter/stream_router.py
Normal file
350
src/plugins/built_in/NEW_napcat_adapter/stream_router.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
按聊天流分配消费者的消息路由系统
|
||||
|
||||
核心思想:
|
||||
- 为每个活跃的聊天流(stream_id)创建独立的消息队列和消费者协程
|
||||
- 同一聊天流的消息由同一个 worker 处理,保证顺序性
|
||||
- 不同聊天流的消息并发处理,提高吞吐量
|
||||
- 动态管理流的生命周期,自动清理不活跃的流
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("stream_router")
|
||||
|
||||
|
||||
class StreamConsumer:
|
||||
"""单个聊天流的消息消费者
|
||||
|
||||
维护独立的消息队列和处理协程
|
||||
"""
|
||||
|
||||
def __init__(self, stream_id: str, queue_maxsize: int = 100):
|
||||
self.stream_id = stream_id
|
||||
self.queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
|
||||
self.worker_task: Optional[asyncio.Task] = None
|
||||
self.last_active_time = time.time()
|
||||
self.is_running = False
|
||||
|
||||
# 性能统计
|
||||
self.stats = {
|
||||
"total_messages": 0,
|
||||
"total_processing_time": 0.0,
|
||||
"queue_overflow_count": 0,
|
||||
}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动消费者"""
|
||||
if not self.is_running:
|
||||
self.is_running = True
|
||||
self.worker_task = asyncio.create_task(self._process_loop())
|
||||
logger.debug(f"Stream Consumer 启动: {self.stream_id}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止消费者"""
|
||||
self.is_running = False
|
||||
if self.worker_task:
|
||||
self.worker_task.cancel()
|
||||
try:
|
||||
await self.worker_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.debug(f"Stream Consumer 停止: {self.stream_id}")
|
||||
|
||||
async def enqueue(self, message: dict) -> None:
|
||||
"""将消息加入队列"""
|
||||
self.last_active_time = time.time()
|
||||
|
||||
try:
|
||||
# 使用 put_nowait 避免阻塞路由器
|
||||
self.queue.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
self.stats["queue_overflow_count"] += 1
|
||||
logger.warning(
|
||||
f"Stream {self.stream_id} 队列已满 "
|
||||
f"({self.queue.qsize()}/{self.queue.maxsize}),"
|
||||
)
|
||||
|
||||
try:
|
||||
self.queue.get_nowait()
|
||||
self.queue.put_nowait(message)
|
||||
logger.debug(f"Stream {self.stream_id} 丢弃最旧消息,添加新消息")
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
|
||||
async def _process_loop(self) -> None:
|
||||
"""消息处理循环"""
|
||||
# 延迟导入,避免循环依赖
|
||||
from .recv_handler.message_handler import message_handler
|
||||
from .recv_handler.meta_event_handler import meta_event_handler
|
||||
from .recv_handler.notice_handler import notice_handler
|
||||
|
||||
logger.info(f"Stream {self.stream_id} 处理循环启动")
|
||||
|
||||
try:
|
||||
while self.is_running:
|
||||
try:
|
||||
# 等待消息,1秒超时
|
||||
message = await asyncio.wait_for(
|
||||
self.queue.get(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# 处理消息
|
||||
post_type = message.get("post_type")
|
||||
if post_type == "message":
|
||||
await message_handler.handle_raw_message(message)
|
||||
elif post_type == "meta_event":
|
||||
await meta_event_handler.handle_meta_event(message)
|
||||
elif post_type == "notice":
|
||||
await notice_handler.handle_notice(message)
|
||||
else:
|
||||
logger.warning(f"未知的 post_type: {post_type}")
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_messages"] += 1
|
||||
self.stats["total_processing_time"] += processing_time
|
||||
self.last_active_time = time.time()
|
||||
self.queue.task_done()
|
||||
|
||||
# 性能监控(每100条消息输出一次)
|
||||
if self.stats["total_messages"] % 100 == 0:
|
||||
avg_time = self.stats["total_processing_time"] / self.stats["total_messages"]
|
||||
logger.info(
|
||||
f"Stream {self.stream_id[:30]}... 统计: "
|
||||
f"消息数={self.stats['total_messages']}, "
|
||||
f"平均耗时={avg_time:.3f}秒, "
|
||||
f"队列长度={self.queue.qsize()}"
|
||||
)
|
||||
|
||||
# 动态延迟:队列空时短暂休眠
|
||||
if self.queue.qsize() == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时是正常的,继续循环
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Stream {self.stream_id} 处理循环被取消")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Stream {self.stream_id} 处理消息时出错: {e}", exc_info=True)
|
||||
# 继续处理下一条消息
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
finally:
|
||||
logger.info(f"Stream {self.stream_id} 处理循环结束")
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""获取性能统计"""
|
||||
avg_time = (
|
||||
self.stats["total_processing_time"] / self.stats["total_messages"]
|
||||
if self.stats["total_messages"] > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
return {
|
||||
"stream_id": self.stream_id,
|
||||
"queue_size": self.queue.qsize(),
|
||||
"total_messages": self.stats["total_messages"],
|
||||
"avg_processing_time": avg_time,
|
||||
"queue_overflow_count": self.stats["queue_overflow_count"],
|
||||
"last_active_time": self.last_active_time,
|
||||
}
|
||||
|
||||
|
||||
class StreamRouter:
|
||||
"""流路由器
|
||||
|
||||
负责将消息路由到对应的聊天流队列
|
||||
动态管理聊天流的生命周期
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_streams: int = 500,
|
||||
stream_timeout: int = 600,
|
||||
stream_queue_size: int = 100,
|
||||
cleanup_interval: int = 60,
|
||||
):
|
||||
self.streams: Dict[str, StreamConsumer] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
self.max_streams = max_streams
|
||||
self.stream_timeout = stream_timeout
|
||||
self.stream_queue_size = stream_queue_size
|
||||
self.cleanup_interval = cleanup_interval
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
self.is_running = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动路由器"""
|
||||
if not self.is_running:
|
||||
self.is_running = True
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info(
|
||||
f"StreamRouter 已启动 - "
|
||||
f"最大流数: {self.max_streams}, "
|
||||
f"超时: {self.stream_timeout}秒, "
|
||||
f"队列大小: {self.stream_queue_size}"
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止路由器"""
|
||||
self.is_running = False
|
||||
|
||||
if self.cleanup_task:
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
await self.cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 停止所有流消费者
|
||||
logger.info(f"正在停止 {len(self.streams)} 个流消费者...")
|
||||
for consumer in self.streams.values():
|
||||
await consumer.stop()
|
||||
|
||||
self.streams.clear()
|
||||
logger.info("StreamRouter 已停止")
|
||||
|
||||
async def route_message(self, message: dict) -> None:
|
||||
"""路由消息到对应的流"""
|
||||
stream_id = self._extract_stream_id(message)
|
||||
|
||||
# 快速路径:流已存在
|
||||
if stream_id in self.streams:
|
||||
await self.streams[stream_id].enqueue(message)
|
||||
return
|
||||
|
||||
# 慢路径:需要创建新流
|
||||
async with self.lock:
|
||||
# 双重检查
|
||||
if stream_id not in self.streams:
|
||||
# 检查流数量限制
|
||||
if len(self.streams) >= self.max_streams:
|
||||
logger.warning(
|
||||
f"达到最大流数量限制 ({self.max_streams}),"
|
||||
f"尝试清理不活跃的流..."
|
||||
)
|
||||
await self._cleanup_inactive_streams()
|
||||
|
||||
# 清理后仍然超限,记录警告但继续创建
|
||||
if len(self.streams) >= self.max_streams:
|
||||
logger.error(
|
||||
f"清理后仍达到最大流数量 ({len(self.streams)}/{self.max_streams})!"
|
||||
)
|
||||
|
||||
# 创建新流
|
||||
consumer = StreamConsumer(stream_id, self.stream_queue_size)
|
||||
self.streams[stream_id] = consumer
|
||||
await consumer.start()
|
||||
logger.info(f"创建新的 Stream Consumer: {stream_id} (总流数: {len(self.streams)})")
|
||||
|
||||
await self.streams[stream_id].enqueue(message)
|
||||
|
||||
def _extract_stream_id(self, message: dict) -> str:
|
||||
"""从消息中提取 stream_id
|
||||
|
||||
返回格式: platform:id:type
|
||||
例如: qq:123456:group 或 qq:789012:private
|
||||
"""
|
||||
post_type = message.get("post_type")
|
||||
|
||||
# 非消息类型,使用默认流(避免创建过多流)
|
||||
if post_type not in ["message", "notice"]:
|
||||
return "system:meta_event"
|
||||
|
||||
# 消息类型
|
||||
if post_type == "message":
|
||||
message_type = message.get("message_type")
|
||||
if message_type == "group":
|
||||
group_id = message.get("group_id")
|
||||
return f"qq:{group_id}:group"
|
||||
elif message_type == "private":
|
||||
user_id = message.get("user_id")
|
||||
return f"qq:{user_id}:private"
|
||||
|
||||
# notice 类型
|
||||
elif post_type == "notice":
|
||||
group_id = message.get("group_id")
|
||||
if group_id:
|
||||
return f"qq:{group_id}:group"
|
||||
user_id = message.get("user_id")
|
||||
if user_id:
|
||||
return f"qq:{user_id}:private"
|
||||
|
||||
# 未知类型,使用通用流
|
||||
return "unknown:unknown"
|
||||
|
||||
async def _cleanup_inactive_streams(self) -> None:
|
||||
"""清理不活跃的流"""
|
||||
current_time = time.time()
|
||||
to_remove = []
|
||||
|
||||
for stream_id, consumer in self.streams.items():
|
||||
if current_time - consumer.last_active_time > self.stream_timeout:
|
||||
to_remove.append(stream_id)
|
||||
|
||||
for stream_id in to_remove:
|
||||
await self.streams[stream_id].stop()
|
||||
del self.streams[stream_id]
|
||||
logger.debug(f"清理不活跃的流: {stream_id}")
|
||||
|
||||
if to_remove:
|
||||
logger.info(
|
||||
f"清理了 {len(to_remove)} 个不活跃的流 "
|
||||
f"(当前活跃流: {len(self.streams)}/{self.max_streams})"
|
||||
)
|
||||
|
||||
async def _cleanup_loop(self) -> None:
|
||||
"""定期清理循环"""
|
||||
logger.info(f"清理循环已启动,间隔: {self.cleanup_interval}秒")
|
||||
try:
|
||||
while self.is_running:
|
||||
await asyncio.sleep(self.cleanup_interval)
|
||||
await self._cleanup_inactive_streams()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("清理循环已停止")
|
||||
|
||||
def get_all_stats(self) -> list[dict]:
|
||||
"""获取所有流的统计信息"""
|
||||
return [consumer.get_stats() for consumer in self.streams.values()]
|
||||
|
||||
def get_summary(self) -> dict:
|
||||
"""获取路由器摘要"""
|
||||
total_messages = sum(c.stats["total_messages"] for c in self.streams.values())
|
||||
total_queue_size = sum(c.queue.qsize() for c in self.streams.values())
|
||||
total_overflows = sum(c.stats["queue_overflow_count"] for c in self.streams.values())
|
||||
|
||||
# 计算平均队列长度
|
||||
avg_queue_size = total_queue_size / len(self.streams) if self.streams else 0
|
||||
|
||||
# 找出最繁忙的流
|
||||
busiest_stream = None
|
||||
if self.streams:
|
||||
busiest_stream = max(
|
||||
self.streams.values(),
|
||||
key=lambda c: c.stats["total_messages"]
|
||||
).stream_id
|
||||
|
||||
return {
|
||||
"total_streams": len(self.streams),
|
||||
"max_streams": self.max_streams,
|
||||
"total_messages_processed": total_messages,
|
||||
"total_queue_size": total_queue_size,
|
||||
"avg_queue_size": avg_queue_size,
|
||||
"total_queue_overflows": total_overflows,
|
||||
"busiest_stream": busiest_stream,
|
||||
}
|
||||
|
||||
|
||||
# 全局路由器实例
|
||||
stream_router = StreamRouter()
|
||||
@@ -236,8 +236,6 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
def enable_plugin(self) -> bool:
|
||||
"""通过配置文件动态控制插件启用状态"""
|
||||
# 如果已经通过配置加载了状态,使用配置中的值
|
||||
if hasattr(self, "_is_enabled"):
|
||||
return self._is_enabled
|
||||
# 否则使用默认值(禁用状态)
|
||||
return False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user