feat: 添加带有消息处理和路由功能的NEW_napcat_adapter插件
- 为NEW_napcat_adapter插件实现了核心模块,包括消息处理、事件处理和路由。 - 创建了MessageHandler、MetaEventHandler和NoticeHandler来处理收到的消息和事件。 - 开发了SendHandler,用于向Napcat发送回消息。 引入了StreamRouter来管理多个聊天流,确保消息的顺序和高效处理。 - 增加了对各种消息类型和格式的支持,包括文本、图像和通知。 - 建立了一个用于监控和调试的日志系统。
This commit is contained in:
@@ -32,11 +32,11 @@ MoFox Bus 是 MoFox Bot 自研的统一消息中台,替换第三方 `maim_mess
|
|||||||
|
|
||||||
## 3. 消息模型
|
## 3. 消息模型
|
||||||
|
|
||||||
### 3.1 Envelope TypedDict(`types.py`)
|
### 3.1 Envelope TypedDict<EFBFBD><EFBFBD>`types.py`<EFBFBD><EFBFBD>
|
||||||
|
|
||||||
- `MessageEnvelope`:核心字段包括 `id`、`direction`、`platform`、`timestamp_ms`、`channel`、`sender`、`content` 等,一律使用毫秒时间戳,保留 `raw_platform_message` 与 `metadata` 便于调试 / 扩展。
|
- `MessageEnvelope` <20><>ȫ<EFBFBD><C8AB>Ƶ<EFBFBD> maim_message <20>ṹ<EFBFBD><E1B9B9><EFBFBD><EFBFBD><EFBFBD>ĵ<EFBFBD><C4B5><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD> `message_info` + `message_segment` (SegPayload)<29><>`direction`<EFBFBD><EFBFBD>`schema_version` <20><> raw <20><><EFBFBD><EFBFBD><EFBFBD>ֶβ<D6B6><CEB2><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ˣ<EFBFBD><CBA3><EFBFBD>Ժ<EFBFBD><D4BA><EFBFBD><EFBFBD><EFBFBD> `channel`<EFBFBD><EFBFBD>`sender`<EFBFBD><EFBFBD>`content` <EFBFBD><EFBFBD> v0 <20>ֶΪ<D6B6><CEAA>ѡ<EFBFBD><D1A1>
|
||||||
- `Content` 联合类型支持文本、图片、音频、文件、视频、事件、命令、系统消息,后续可扩展更多 literal。
|
- `SegPayload` / `MessageInfoPayload` / `UserInfoPayload` / `GroupInfoPayload` / `FormatInfoPayload` / `TemplateInfoPayload` <20><> maim_message dataclass <20>Դ<EFBFBD>TypedDict <20><>Ӧ<EFBFBD><D3A6><EFBFBD>ʺ<EFBFBD>ֱ<EFBFBD><D6B1> JSON <20><><EFBFBD><EFBFBD>
|
||||||
- `SenderInfo` / `ChannelInfo` / `MessageDirection` / `Role` 等均以 `Literal` 控制取值,方便 IDE 静态检查。
|
- `Content` / `SenderInfo` / `ChannelInfo` <EFBFBD>Ȳ<EFBFBD>Ȼ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ڣ<EFBFBD><EFBFBD><EFBFBD><EFBFBD>ܻ<EFBFBD><EFBFBD><EFBFBD> IDE ע<>⣬Ҳ<E2A3AC>Ƕ<EFBFBD> v0 content ģ<>͵Ļ<CDB5>֧
|
||||||
|
|
||||||
### 3.2 dataclass 消息段(`message_models.py`)
|
### 3.2 dataclass 消息段(`message_models.py`)
|
||||||
|
|
||||||
@@ -62,15 +62,14 @@ TypedDict 更适合网络传输和依赖注入;dataclass 版 MessageBase 则
|
|||||||
## 5. 运行时调度(`runtime.py`)
|
## 5. 运行时调度(`runtime.py`)
|
||||||
|
|
||||||
- `MessageRuntime`:
|
- `MessageRuntime`:
|
||||||
- `add_route(predicate, handler)` 或 `@runtime.route(...)` 装饰器注册消息处理器。
|
- `add_route(predicate, handler)` 和 `@runtime.route(...)` 装饰器注册消息处理器
|
||||||
- `register_before_hook` / `register_after_hook` / `register_error_hook` 注入监控、埋点、Trace。
|
- `register_before_hook` / `register_after_hook` / `register_error_hook` 注册前置、后置、Trace 处理
|
||||||
- `set_batch_handler` 支持一次处理整批消息(例如批量落库)。
|
- `set_batch_handler` 支持一次处理一批消息(可用于 batch IO 优化)
|
||||||
- `MessageProcessingError` 在 handler 抛出异常时封装上下文,便于日志追踪。
|
- `MessageProcessingError` 在 handler 抛出异常时包装原因,方便日志追踪。
|
||||||
|
|
||||||
运行时内部使用 `RLock` 保护路由表,适合多协程并发读写,`_maybe_await` 自动兼容同步/异步 handler。
|
运行时内部使用 `RLock` 保护路由表,适合多协程并发读写,`_maybe_await` 自动兼容同步/异步 handler。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 6. 传输层封装(`transport/`)
|
## 6. 传输层封装(`transport/`)
|
||||||
|
|
||||||
### 6.1 HTTP
|
### 6.1 HTTP
|
||||||
@@ -126,9 +125,9 @@ from mofox_bus.transport import HttpMessageServer
|
|||||||
|
|
||||||
runtime = MessageRuntime()
|
runtime = MessageRuntime()
|
||||||
|
|
||||||
@runtime.route(lambda env: env["content"]["type"] == "text")
|
@runtime.route(lambda env: (env.get("message_segment") or {}).get("type") == "text")
|
||||||
async def handle_text(env: types.MessageEnvelope):
|
async def handle_text(env: types.MessageEnvelope):
|
||||||
print("收到文本:", env["content"]["text"])
|
print("收到文本", env["message_segment"]["data"])
|
||||||
|
|
||||||
async def http_handler(messages: list[types.MessageEnvelope]):
|
async def http_handler(messages: list[types.MessageEnvelope]):
|
||||||
await runtime.handle_batch(messages)
|
await runtime.handle_batch(messages)
|
||||||
|
|||||||
35
src/common/core_sink.py
Normal file
35
src/common/core_sink.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
从 src.main 导出 core_sink 的辅助函数
|
||||||
|
|
||||||
|
由于 src.main 中实际使用的是 InProcessCoreSink,
|
||||||
|
我们需要创建一个全局访问点
|
||||||
|
"""
|
||||||
|
|
||||||
|
from mofox_bus import CoreSink, InProcessCoreSink
|
||||||
|
|
||||||
|
_global_core_sink: CoreSink | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_core_sink(sink: CoreSink) -> None:
|
||||||
|
"""设置全局 core sink"""
|
||||||
|
global _global_core_sink
|
||||||
|
_global_core_sink = sink
|
||||||
|
|
||||||
|
|
||||||
|
def get_core_sink() -> CoreSink:
|
||||||
|
"""获取全局 core sink"""
|
||||||
|
global _global_core_sink
|
||||||
|
if _global_core_sink is None:
|
||||||
|
raise RuntimeError("Core sink 尚未初始化")
|
||||||
|
return _global_core_sink
|
||||||
|
|
||||||
|
|
||||||
|
async def push_outgoing(envelope) -> None:
|
||||||
|
"""将消息推送到 core sink 的 outgoing 通道"""
|
||||||
|
sink = get_core_sink()
|
||||||
|
push = getattr(sink, "push_outgoing", None)
|
||||||
|
if push is None:
|
||||||
|
raise RuntimeError("当前 core sink 不支持 push_outgoing 方法")
|
||||||
|
await push(envelope)
|
||||||
|
|
||||||
|
__all__ = ["set_core_sink", "get_core_sink", "push_outgoing"]
|
||||||
@@ -1,15 +1,22 @@
|
|||||||
"""
|
"""
|
||||||
MessageEnvelope 转换器
|
MessageEnvelope converter between mofox_bus schema and internal message structures.
|
||||||
|
|
||||||
将 mofox_bus 的 MessageEnvelope 转换为 MoFox Bot 内部使用的消息格式。
|
- 优先处理 maim_message 风格的 message_info + message_segment。
|
||||||
"""
|
- 兼容旧版 content/sender/channel 结构,方便逐步迁移。
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from mofox_bus import MessageEnvelope, BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, UserInfo
|
from mofox_bus import (
|
||||||
|
BaseMessageInfo,
|
||||||
|
MessageBase,
|
||||||
|
MessageEnvelope,
|
||||||
|
Seg,
|
||||||
|
UserInfo,
|
||||||
|
GroupInfo,
|
||||||
|
)
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -17,151 +24,221 @@ logger = get_logger("envelope_converter")
|
|||||||
|
|
||||||
|
|
||||||
class EnvelopeConverter:
|
class EnvelopeConverter:
|
||||||
"""MessageEnvelope 到内部消息格式的转换器"""
|
"""MessageEnvelope <-> MessageBase converter."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_message_base(envelope: MessageEnvelope) -> MessageBase:
|
def to_message_base(envelope: MessageEnvelope) -> MessageBase:
|
||||||
"""
|
"""
|
||||||
将 MessageEnvelope 转换为 MessageBase
|
Convert MessageEnvelope to MessageBase.
|
||||||
|
|
||||||
Args:
|
|
||||||
envelope: 统一的消息信封
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
MessageBase: 内部消息格式
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 提取基本信息
|
# 优先使用 maim_message 样式字段
|
||||||
platform = envelope["platform"]
|
info_payload = envelope.get("message_info") or {}
|
||||||
channel = envelope["channel"]
|
seg_payload = envelope.get("message_segment") or envelope.get("message_chain")
|
||||||
sender = envelope["sender"]
|
|
||||||
content = envelope["content"]
|
if info_payload:
|
||||||
|
message_info = BaseMessageInfo.from_dict(info_payload)
|
||||||
# 创建 UserInfo
|
else:
|
||||||
user_info = UserInfo(
|
message_info = EnvelopeConverter._build_info_from_legacy(envelope)
|
||||||
user_id=sender["user_id"],
|
|
||||||
user_nickname=sender.get("display_name", sender["user_id"]),
|
if seg_payload is None:
|
||||||
user_avatar=sender.get("avatar_url"),
|
seg_list = EnvelopeConverter._content_to_segments(envelope.get("content"))
|
||||||
)
|
seg_payload = seg_list
|
||||||
|
|
||||||
# 创建 GroupInfo (如果是群组消息)
|
message_segment = EnvelopeConverter._ensure_seg(seg_payload)
|
||||||
group_info: Optional[GroupInfo] = None
|
raw_message = envelope.get("raw_message") or envelope.get("raw_platform_message")
|
||||||
if channel["channel_type"] in ("group", "supergroup", "room"):
|
|
||||||
group_info = GroupInfo(
|
return MessageBase(
|
||||||
group_id=channel["channel_id"],
|
|
||||||
group_name=channel.get("title", channel["channel_id"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 BaseMessageInfo
|
|
||||||
message_info = BaseMessageInfo(
|
|
||||||
platform=platform,
|
|
||||||
chat_type="group" if group_info else "private",
|
|
||||||
message_id=envelope["id"],
|
|
||||||
user_info=user_info,
|
|
||||||
group_info=group_info,
|
|
||||||
timestamp=envelope["timestamp_ms"] / 1000.0, # 转换为秒
|
|
||||||
)
|
|
||||||
|
|
||||||
# 转换 Content 为 Seg 列表
|
|
||||||
segments = EnvelopeConverter._content_to_segments(content)
|
|
||||||
|
|
||||||
# 创建 MessageBase
|
|
||||||
message_base = MessageBase(
|
|
||||||
message_info=message_info,
|
message_info=message_info,
|
||||||
message=segments,
|
message_segment=message_segment,
|
||||||
|
raw_message=raw_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存原始 envelope 到 raw 字段
|
|
||||||
if hasattr(message_base, "raw"):
|
|
||||||
message_base.raw = envelope
|
|
||||||
|
|
||||||
return message_base
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"转换 MessageEnvelope 失败: {e}", exc_info=True)
|
logger.error(f"转换 MessageEnvelope 失败: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _content_to_segments(content: Dict[str, Any]) -> List[Seg]:
|
def _build_info_from_legacy(envelope: MessageEnvelope) -> BaseMessageInfo:
|
||||||
|
"""将 legacy 字段映射为 BaseMessageInfo。"""
|
||||||
|
platform = envelope.get("platform")
|
||||||
|
channel = envelope.get("channel") or {}
|
||||||
|
sender = envelope.get("sender") or {}
|
||||||
|
|
||||||
|
message_id = envelope.get("id") or envelope.get("message_id")
|
||||||
|
timestamp_ms = envelope.get("timestamp_ms")
|
||||||
|
time_value = (timestamp_ms / 1000.0) if timestamp_ms is not None else None
|
||||||
|
|
||||||
|
group_info: Optional[GroupInfo] = None
|
||||||
|
channel_type = channel.get("channel_type")
|
||||||
|
if channel_type in ("group", "supergroup", "room"):
|
||||||
|
group_info = GroupInfo(
|
||||||
|
platform=platform,
|
||||||
|
group_id=channel.get("channel_id"),
|
||||||
|
group_name=channel.get("title"),
|
||||||
|
)
|
||||||
|
|
||||||
|
user_info: Optional[UserInfo] = None
|
||||||
|
if sender:
|
||||||
|
user_info = UserInfo(
|
||||||
|
platform=platform,
|
||||||
|
user_id=str(sender.get("user_id")) if sender.get("user_id") is not None else None,
|
||||||
|
user_nickname=sender.get("display_name") or sender.get("user_nickname"),
|
||||||
|
user_avatar=sender.get("avatar_url"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return BaseMessageInfo(
|
||||||
|
platform=platform,
|
||||||
|
message_id=message_id,
|
||||||
|
time=time_value,
|
||||||
|
group_info=group_info,
|
||||||
|
user_info=user_info,
|
||||||
|
additional_config=envelope.get("metadata"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ensure_seg(payload: Any) -> Seg:
|
||||||
|
"""将任意 payload 转为 Seg dataclass。"""
|
||||||
|
if isinstance(payload, Seg):
|
||||||
|
return payload
|
||||||
|
if isinstance(payload, list):
|
||||||
|
# 直接传入 Seg 列表或 seglist data
|
||||||
|
return Seg(type="seglist", data=[EnvelopeConverter._ensure_seg(item) for item in payload])
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
seg_type = payload.get("type") or "text"
|
||||||
|
data = payload.get("data")
|
||||||
|
if seg_type == "seglist" and isinstance(data, list):
|
||||||
|
data = [EnvelopeConverter._ensure_seg(item) for item in data]
|
||||||
|
return Seg(type=seg_type, data=data)
|
||||||
|
# 兜底:转成文本片段
|
||||||
|
return Seg(type="text", data="" if payload is None else str(payload))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _flatten_segments(seg: Seg) -> List[Seg]:
|
||||||
|
"""将 Seg/seglist 打平成列表,便于旧 content 转换。"""
|
||||||
|
if seg.type == "seglist" and isinstance(seg.data, list):
|
||||||
|
return [item if isinstance(item, Seg) else EnvelopeConverter._ensure_seg(item) for item in seg.data]
|
||||||
|
return [seg]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _content_to_segments(content: Any) -> List[Seg]:
|
||||||
"""
|
"""
|
||||||
将 Content 转换为 Seg 列表
|
Convert legacy Content (type/data/metadata) to a flat list of Seg.
|
||||||
|
|
||||||
Args:
|
|
||||||
content: 消息内容
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Seg]: 消息段列表
|
|
||||||
"""
|
"""
|
||||||
segments: List[Seg] = []
|
segments: List[Seg] = []
|
||||||
content_type = content.get("type")
|
|
||||||
|
def _walk(node: Any) -> None:
|
||||||
if content_type == "text":
|
if node is None:
|
||||||
# 文本消息
|
return
|
||||||
text = content.get("text", "")
|
if isinstance(node, list):
|
||||||
segments.append(Seg.text(text))
|
for item in node:
|
||||||
|
_walk(item)
|
||||||
elif content_type == "image":
|
return
|
||||||
# 图片消息
|
if not isinstance(node, dict):
|
||||||
url = content.get("url", "")
|
logger.warning("未知的 content 节点类型: %s", type(node))
|
||||||
file_id = content.get("file_id")
|
return
|
||||||
segments.append(Seg.image(url if url else file_id))
|
|
||||||
|
content_type = node.get("type")
|
||||||
elif content_type == "audio":
|
data = node.get("data")
|
||||||
# 音频消息
|
metadata = node.get("metadata") or {}
|
||||||
url = content.get("url", "")
|
|
||||||
file_id = content.get("file_id")
|
if content_type == "collection":
|
||||||
segments.append(Seg.record(url if url else file_id))
|
items = data if isinstance(data, list) else node.get("items", [])
|
||||||
|
for item in items:
|
||||||
elif content_type == "video":
|
_walk(item)
|
||||||
# 视频消息
|
return
|
||||||
url = content.get("url", "")
|
|
||||||
file_id = content.get("file_id")
|
if content_type in ("text", "at"):
|
||||||
segments.append(Seg.video(url if url else file_id))
|
subtype = metadata.get("subtype") or ("at" if content_type == "at" else None)
|
||||||
|
text = "" if data is None else str(data)
|
||||||
elif content_type == "file":
|
if subtype in ("at", "mention"):
|
||||||
# 文件消息
|
user_info = metadata.get("user") or {}
|
||||||
url = content.get("url", "")
|
seg_data: Dict[str, Any] = {
|
||||||
file_name = content.get("file_name", "file")
|
"user_id": user_info.get("id") or user_info.get("user_id"),
|
||||||
# 使用 text 表示文件(或者可以自定义一个 file seg type)
|
"user_name": user_info.get("name") or user_info.get("display_name"),
|
||||||
segments.append(Seg.text(f"[文件: {file_name}]"))
|
"text": text,
|
||||||
|
"raw": user_info.get("raw") or user_info if user_info else None,
|
||||||
elif content_type == "command":
|
}
|
||||||
# 命令消息
|
if any(v is not None for v in seg_data.values()):
|
||||||
name = content.get("name", "")
|
segments.append(Seg(type="at", data=seg_data))
|
||||||
args = content.get("args", {})
|
else:
|
||||||
# 重构为文本格式
|
segments.append(Seg(type="at", data=text))
|
||||||
cmd_text = f"/{name}"
|
else:
|
||||||
if args:
|
segments.append(Seg(type="text", data=text))
|
||||||
cmd_text += " " + " ".join(f"{k}={v}" for k, v in args.items())
|
return
|
||||||
segments.append(Seg.text(cmd_text))
|
|
||||||
|
if content_type == "image":
|
||||||
elif content_type == "event":
|
url = ""
|
||||||
# 事件消息 - 转换为文本表示
|
if isinstance(data, dict):
|
||||||
event_type = content.get("event_type", "unknown")
|
url = data.get("url") or data.get("file") or data.get("file_id") or ""
|
||||||
segments.append(Seg.text(f"[事件: {event_type}]"))
|
elif data is not None:
|
||||||
|
url = str(data)
|
||||||
elif content_type == "system":
|
segments.append(Seg(type="image", data=url))
|
||||||
# 系统消息
|
return
|
||||||
text = content.get("text", "")
|
|
||||||
segments.append(Seg.text(f"[系统] {text}"))
|
if content_type == "audio":
|
||||||
|
url = ""
|
||||||
else:
|
if isinstance(data, dict):
|
||||||
# 未知类型 - 转换为文本
|
url = data.get("url") or data.get("file") or data.get("file_id") or ""
|
||||||
|
elif data is not None:
|
||||||
|
url = str(data)
|
||||||
|
segments.append(Seg(type="record", data=url))
|
||||||
|
return
|
||||||
|
|
||||||
|
if content_type == "video":
|
||||||
|
url = ""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
url = data.get("url") or data.get("file") or data.get("file_id") or ""
|
||||||
|
elif data is not None:
|
||||||
|
url = str(data)
|
||||||
|
segments.append(Seg(type="video", data=url))
|
||||||
|
return
|
||||||
|
|
||||||
|
if content_type == "file":
|
||||||
|
file_name = ""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
file_name = data.get("file_name") or data.get("name") or ""
|
||||||
|
text = file_name or "[file]"
|
||||||
|
segments.append(Seg(type="text", data=text))
|
||||||
|
return
|
||||||
|
|
||||||
|
if content_type == "command":
|
||||||
|
name = ""
|
||||||
|
args: Dict[str, Any] = {}
|
||||||
|
if isinstance(data, dict):
|
||||||
|
name = data.get("name", "")
|
||||||
|
args = data.get("args", {}) or {}
|
||||||
|
else:
|
||||||
|
name = str(data or "")
|
||||||
|
cmd_text = f"/{name}" if name else "/command"
|
||||||
|
if args:
|
||||||
|
cmd_text += " " + " ".join(f"{k}={v}" for k, v in args.items())
|
||||||
|
segments.append(Seg(type="text", data=cmd_text))
|
||||||
|
return
|
||||||
|
|
||||||
|
if content_type == "event":
|
||||||
|
event_type = ""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
event_type = data.get("event_type", "")
|
||||||
|
else:
|
||||||
|
event_type = str(data or "")
|
||||||
|
segments.append(Seg(type="text", data=f"[事件: {event_type or 'unknown'}]"))
|
||||||
|
return
|
||||||
|
|
||||||
|
if content_type == "system":
|
||||||
|
text = "" if data is None else str(data)
|
||||||
|
segments.append(Seg(type="text", data=f"[系统] {text}"))
|
||||||
|
return
|
||||||
|
|
||||||
logger.warning(f"未知的消息类型: {content_type}")
|
logger.warning(f"未知的消息类型: {content_type}")
|
||||||
segments.append(Seg.text(f"[未知消息类型: {content_type}]"))
|
segments.append(Seg(type="text", data=f"[未知消息类型: {content_type}]"))
|
||||||
|
|
||||||
|
_walk(content)
|
||||||
return segments
|
return segments
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_legacy_dict(envelope: MessageEnvelope) -> Dict[str, Any]:
|
def to_legacy_dict(envelope: MessageEnvelope) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
将 MessageEnvelope 转换为旧版字典格式(用于向后兼容)
|
Convert MessageEnvelope to legacy dict for backward compatibility.
|
||||||
|
|
||||||
Args:
|
|
||||||
envelope: 统一的消息信封
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: 旧版消息字典
|
|
||||||
"""
|
"""
|
||||||
message_base = EnvelopeConverter.to_message_base(envelope)
|
message_base = EnvelopeConverter.to_message_base(envelope)
|
||||||
return message_base.to_dict()
|
return message_base.to_dict()
|
||||||
@@ -169,61 +246,45 @@ class EnvelopeConverter:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def from_message_base(message: MessageBase, direction: str = "outgoing") -> MessageEnvelope:
|
def from_message_base(message: MessageBase, direction: str = "outgoing") -> MessageEnvelope:
|
||||||
"""
|
"""
|
||||||
将 MessageBase 转换为 MessageEnvelope (反向转换)
|
Convert MessageBase to MessageEnvelope (maim_message style preferred).
|
||||||
|
|
||||||
Args:
|
|
||||||
message: 内部消息格式
|
|
||||||
direction: 消息方向 ("incoming" 或 "outgoing")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
MessageEnvelope: 统一的消息信封
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
message_info = message.message_info
|
info_dict = message.message_info.to_dict()
|
||||||
user_info = message_info.user_info
|
seg_dict = message.message_segment.to_dict()
|
||||||
group_info = message_info.group_info
|
|
||||||
|
|
||||||
# 创建 SenderInfo
|
|
||||||
sender = {
|
|
||||||
"user_id": user_info.user_id,
|
|
||||||
"role": "assistant" if direction == "outgoing" else "user",
|
|
||||||
}
|
|
||||||
if user_info.user_nickname:
|
|
||||||
sender["display_name"] = user_info.user_nickname
|
|
||||||
if user_info.user_avatar:
|
|
||||||
sender["avatar_url"] = user_info.user_avatar
|
|
||||||
|
|
||||||
# 创建 ChannelInfo
|
|
||||||
if group_info:
|
|
||||||
channel = {
|
|
||||||
"channel_id": group_info.group_id,
|
|
||||||
"channel_type": "group",
|
|
||||||
}
|
|
||||||
if group_info.group_name:
|
|
||||||
channel["title"] = group_info.group_name
|
|
||||||
else:
|
|
||||||
channel = {
|
|
||||||
"channel_id": user_info.user_id,
|
|
||||||
"channel_type": "private",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 转换 segments 为 Content
|
|
||||||
content = EnvelopeConverter._segments_to_content(message.message)
|
|
||||||
|
|
||||||
# 创建 MessageEnvelope
|
|
||||||
envelope: MessageEnvelope = {
|
envelope: MessageEnvelope = {
|
||||||
"id": message_info.message_id,
|
|
||||||
"direction": direction,
|
"direction": direction,
|
||||||
"platform": message_info.platform,
|
"message_info": info_dict,
|
||||||
"timestamp_ms": int(message_info.timestamp * 1000),
|
"message_segment": seg_dict,
|
||||||
"channel": channel,
|
"platform": info_dict.get("platform"),
|
||||||
"sender": sender,
|
"message_id": info_dict.get("message_id"),
|
||||||
"content": content,
|
"schema_version": 1,
|
||||||
"conversation_id": group_info.group_id if group_info else user_info.user_id,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if message.message_info.time is not None:
|
||||||
|
envelope["timestamp_ms"] = int(message.message_info.time * 1000)
|
||||||
|
if message.raw_message is not None:
|
||||||
|
envelope["raw_message"] = message.raw_message
|
||||||
|
|
||||||
|
# legacy 补充,方便老代码继续工作
|
||||||
|
segments = EnvelopeConverter._flatten_segments(message.message_segment)
|
||||||
|
envelope["content"] = EnvelopeConverter._segments_to_content(segments)
|
||||||
|
if message.message_info.user_info:
|
||||||
|
envelope["sender"] = {
|
||||||
|
"user_id": message.message_info.user_info.user_id,
|
||||||
|
"role": "assistant" if direction == "outgoing" else "user",
|
||||||
|
"display_name": message.message_info.user_info.user_nickname,
|
||||||
|
"avatar_url": getattr(message.message_info.user_info, "user_avatar", None),
|
||||||
|
}
|
||||||
|
if message.message_info.group_info:
|
||||||
|
envelope["channel"] = {
|
||||||
|
"channel_id": message.message_info.group_info.group_id,
|
||||||
|
"channel_type": "group",
|
||||||
|
"title": message.message_info.group_info.group_name,
|
||||||
|
}
|
||||||
|
|
||||||
return envelope
|
return envelope
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"转换 MessageBase 失败: {e}", exc_info=True)
|
logger.error(f"转换 MessageBase 失败: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -231,45 +292,50 @@ class EnvelopeConverter:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _segments_to_content(segments: List[Seg]) -> Dict[str, Any]:
|
def _segments_to_content(segments: List[Seg]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
将 Seg 列表转换为 Content
|
Convert Seg list to legacy Content (type/data/metadata).
|
||||||
|
|
||||||
Args:
|
|
||||||
segments: 消息段列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: 消息内容
|
|
||||||
"""
|
"""
|
||||||
if not segments:
|
if not segments:
|
||||||
return {"type": "text", "text": ""}
|
return {"type": "text", "data": ""}
|
||||||
|
|
||||||
# 简化处理:如果有多个段,合并为文本
|
def _seg_to_content(seg: Seg) -> Dict[str, Any]:
|
||||||
|
data = seg.data
|
||||||
|
|
||||||
|
if seg.type == "text":
|
||||||
|
return {"type": "text", "data": data}
|
||||||
|
|
||||||
|
if seg.type == "at":
|
||||||
|
content: Dict[str, Any] = {"type": "text", "data": ""}
|
||||||
|
metadata: Dict[str, Any] = {"subtype": "at"}
|
||||||
|
if isinstance(data, dict):
|
||||||
|
content["data"] = data.get("text", "")
|
||||||
|
user = {
|
||||||
|
"id": data.get("user_id"),
|
||||||
|
"name": data.get("user_name"),
|
||||||
|
"raw": data.get("raw"),
|
||||||
|
}
|
||||||
|
if any(v is not None for v in user.values()):
|
||||||
|
metadata["user"] = user
|
||||||
|
else:
|
||||||
|
content["data"] = data
|
||||||
|
if metadata:
|
||||||
|
content["metadata"] = metadata
|
||||||
|
return content
|
||||||
|
|
||||||
|
if seg.type == "image":
|
||||||
|
return {"type": "image", "data": data}
|
||||||
|
|
||||||
|
if seg.type in ("record", "voice", "audio"):
|
||||||
|
return {"type": "audio", "data": data}
|
||||||
|
|
||||||
|
if seg.type == "video":
|
||||||
|
return {"type": "video", "data": data}
|
||||||
|
|
||||||
|
return {"type": seg.type, "data": data}
|
||||||
|
|
||||||
if len(segments) == 1:
|
if len(segments) == 1:
|
||||||
seg = segments[0]
|
return _seg_to_content(segments[0])
|
||||||
|
|
||||||
if seg.type == "text":
|
return {"type": "collection", "data": [_seg_to_content(seg) for seg in segments]}
|
||||||
return {"type": "text", "text": seg.data.get("text", "")}
|
|
||||||
elif seg.type == "image":
|
|
||||||
return {"type": "image", "url": seg.data.get("file", "")}
|
|
||||||
elif seg.type == "record":
|
|
||||||
return {"type": "audio", "url": seg.data.get("file", "")}
|
|
||||||
elif seg.type == "video":
|
|
||||||
return {"type": "video", "url": seg.data.get("file", "")}
|
|
||||||
|
|
||||||
# 多个段或未知类型 - 合并为文本
|
|
||||||
text_parts = []
|
|
||||||
for seg in segments:
|
|
||||||
if seg.type == "text":
|
|
||||||
text_parts.append(seg.data.get("text", ""))
|
|
||||||
elif seg.type == "image":
|
|
||||||
text_parts.append("[图片]")
|
|
||||||
elif seg.type == "record":
|
|
||||||
text_parts.append("[语音]")
|
|
||||||
elif seg.type == "video":
|
|
||||||
text_parts.append("[视频]")
|
|
||||||
else:
|
|
||||||
text_parts.append(f"[{seg.type}]")
|
|
||||||
|
|
||||||
return {"type": "text", "text": "".join(text_parts)}
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["EnvelopeConverter"]
|
__all__ = ["EnvelopeConverter"]
|
||||||
|
|||||||
@@ -10,72 +10,54 @@ from .adapter_utils import (
|
|||||||
AdapterTransportOptions,
|
AdapterTransportOptions,
|
||||||
AdapterBase,
|
AdapterBase,
|
||||||
BatchDispatcher,
|
BatchDispatcher,
|
||||||
|
CoreSink,
|
||||||
CoreMessageSink,
|
CoreMessageSink,
|
||||||
HttpAdapterOptions,
|
HttpAdapterOptions,
|
||||||
InProcessCoreSink,
|
InProcessCoreSink,
|
||||||
|
ProcessCoreSink,
|
||||||
|
ProcessCoreSinkServer,
|
||||||
WebSocketLike,
|
WebSocketLike,
|
||||||
WebSocketAdapterOptions,
|
WebSocketAdapterOptions,
|
||||||
)
|
)
|
||||||
from .api import MessageClient, MessageServer
|
from .api import MessageClient, MessageServer
|
||||||
from .codec import dumps_message, dumps_messages, loads_message, loads_messages
|
from .codec import dumps_message, dumps_messages, loads_message, loads_messages
|
||||||
from .message_models import BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, TemplateInfo, UserInfo
|
from .builder import MessageBuilder
|
||||||
from .router import RouteConfig, Router, TargetConfig
|
from .router import RouteConfig, Router, TargetConfig
|
||||||
from .runtime import MessageProcessingError, MessageRoute, MessageRuntime
|
from .runtime import MessageProcessingError, MessageRoute, MessageRuntime, Middleware
|
||||||
from .types import (
|
from .types import (
|
||||||
AudioContent,
|
FormatInfoPayload,
|
||||||
ChannelInfo,
|
GroupInfoPayload,
|
||||||
CommandContent,
|
|
||||||
Content,
|
|
||||||
ContentType,
|
|
||||||
EventContent,
|
|
||||||
EventType,
|
|
||||||
FileContent,
|
|
||||||
ImageContent,
|
|
||||||
MessageDirection,
|
MessageDirection,
|
||||||
MessageEnvelope,
|
MessageEnvelope,
|
||||||
Role,
|
MessageInfoPayload,
|
||||||
SenderInfo,
|
SegPayload,
|
||||||
SystemContent,
|
|
||||||
TextContent,
|
TemplateInfoPayload,
|
||||||
VideoContent,
|
UserInfoPayload,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# TypedDict model
|
# TypedDict model
|
||||||
"AudioContent",
|
|
||||||
"ChannelInfo",
|
|
||||||
"CommandContent",
|
|
||||||
"Content",
|
|
||||||
"ContentType",
|
|
||||||
"EventContent",
|
|
||||||
"EventType",
|
|
||||||
"FileContent",
|
|
||||||
"ImageContent",
|
|
||||||
"MessageDirection",
|
"MessageDirection",
|
||||||
"MessageEnvelope",
|
"MessageEnvelope",
|
||||||
"Role",
|
"SegPayload",
|
||||||
"SenderInfo",
|
"UserInfoPayload",
|
||||||
"SystemContent",
|
"GroupInfoPayload",
|
||||||
"TextContent",
|
"FormatInfoPayload",
|
||||||
"VideoContent",
|
"TemplateInfoPayload",
|
||||||
|
"MessageInfoPayload",
|
||||||
# Codec helpers
|
# Codec helpers
|
||||||
"codec",
|
"codec",
|
||||||
"dumps_message",
|
"dumps_message",
|
||||||
"dumps_messages",
|
"dumps_messages",
|
||||||
"loads_message",
|
"loads_message",
|
||||||
"loads_messages",
|
"loads_messages",
|
||||||
|
"MessageBuilder",
|
||||||
# Runtime / routing
|
# Runtime / routing
|
||||||
"MessageRoute",
|
"MessageRoute",
|
||||||
"MessageRuntime",
|
"MessageRuntime",
|
||||||
"MessageProcessingError",
|
"MessageProcessingError",
|
||||||
# Message dataclasses
|
"Middleware",
|
||||||
"Seg",
|
|
||||||
"GroupInfo",
|
|
||||||
"UserInfo",
|
|
||||||
"FormatInfo",
|
|
||||||
"TemplateInfo",
|
|
||||||
"BaseMessageInfo",
|
|
||||||
"MessageBase",
|
|
||||||
# Server/client/router
|
# Server/client/router
|
||||||
"MessageServer",
|
"MessageServer",
|
||||||
"MessageClient",
|
"MessageClient",
|
||||||
@@ -86,8 +68,11 @@ __all__ = [
|
|||||||
"AdapterTransportOptions",
|
"AdapterTransportOptions",
|
||||||
"AdapterBase",
|
"AdapterBase",
|
||||||
"BatchDispatcher",
|
"BatchDispatcher",
|
||||||
|
"CoreSink",
|
||||||
"CoreMessageSink",
|
"CoreMessageSink",
|
||||||
"InProcessCoreSink",
|
"InProcessCoreSink",
|
||||||
|
"ProcessCoreSink",
|
||||||
|
"ProcessCoreSinkServer",
|
||||||
"WebSocketLike",
|
"WebSocketLike",
|
||||||
"WebSocketAdapterOptions",
|
"WebSocketAdapterOptions",
|
||||||
"HttpAdapterOptions",
|
"HttpAdapterOptions",
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import multiprocessing as mp
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, AsyncIterator, Awaitable, Callable, Protocol
|
from typing import Any, AsyncIterator, Awaitable, Callable, Protocol
|
||||||
|
|
||||||
@@ -11,6 +13,11 @@ import websockets
|
|||||||
|
|
||||||
from .types import MessageEnvelope
|
from .types import MessageEnvelope
|
||||||
|
|
||||||
|
logger = logging.getLogger("mofox_bus.adapter")
|
||||||
|
|
||||||
|
|
||||||
|
OutgoingHandler = Callable[[MessageEnvelope], Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
class CoreMessageSink(Protocol):
|
class CoreMessageSink(Protocol):
|
||||||
async def send(self, message: MessageEnvelope) -> None: ...
|
async def send(self, message: MessageEnvelope) -> None: ...
|
||||||
@@ -18,6 +25,22 @@ class CoreMessageSink(Protocol):
|
|||||||
async def send_many(self, messages: list[MessageEnvelope]) -> None: ... # pragma: no cover - optional
|
async def send_many(self, messages: list[MessageEnvelope]) -> None: ... # pragma: no cover - optional
|
||||||
|
|
||||||
|
|
||||||
|
class CoreSink(CoreMessageSink, Protocol):
|
||||||
|
"""
|
||||||
|
双向 CoreSink 协议:
|
||||||
|
- send/send_many: 适配器 → 核心(incoming)
|
||||||
|
- push_outgoing: 核心 → 适配器(outgoing)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None: ...
|
||||||
|
|
||||||
|
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None: ...
|
||||||
|
|
||||||
|
async def push_outgoing(self, envelope: MessageEnvelope) -> None: ...
|
||||||
|
|
||||||
|
async def close(self) -> None: ... # pragma: no cover - lifecycle hook
|
||||||
|
|
||||||
|
|
||||||
class WebSocketLike(Protocol):
|
class WebSocketLike(Protocol):
|
||||||
def __aiter__(self) -> AsyncIterator[str | bytes]: ...
|
def __aiter__(self) -> AsyncIterator[str | bytes]: ...
|
||||||
|
|
||||||
@@ -56,7 +79,7 @@ class AdapterBase:
|
|||||||
|
|
||||||
platform: str = "unknown"
|
platform: str = "unknown"
|
||||||
|
|
||||||
def __init__(self, core_sink: CoreMessageSink, transport: AdapterTransportOptions = None):
|
def __init__(self, core_sink: CoreSink, transport: AdapterTransportOptions = None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
core_sink: 核心消息入口,通常是 InProcessCoreSink 或自定义客户端。
|
core_sink: 核心消息入口,通常是 InProcessCoreSink 或自定义客户端。
|
||||||
@@ -70,14 +93,31 @@ class AdapterBase:
|
|||||||
self._http_site: aiohttp_web.BaseSite | None = None
|
self._http_site: aiohttp_web.BaseSite | None = None
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""根据配置自动启动 WS/HTTP 监听。"""
|
"""启动适配器的传输层监听(如果配置了传输选项)。"""
|
||||||
|
if hasattr(self.core_sink, "set_outgoing_handler"):
|
||||||
|
try:
|
||||||
|
self.core_sink.set_outgoing_handler(self._on_outgoing_from_core)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to register outgoing handler on core sink")
|
||||||
if isinstance(self._transport_config, WebSocketAdapterOptions):
|
if isinstance(self._transport_config, WebSocketAdapterOptions):
|
||||||
await self._start_ws_transport(self._transport_config)
|
await self._start_ws_transport(self._transport_config)
|
||||||
elif isinstance(self._transport_config, HttpAdapterOptions):
|
elif isinstance(self._transport_config, HttpAdapterOptions):
|
||||||
await self._start_http_transport(self._transport_config)
|
await self._start_http_transport(self._transport_config)
|
||||||
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""停止自动管理的传输层。"""
|
"""停止适配器的传输层监听(如果配置了传输选项)。"""
|
||||||
|
remove = getattr(self.core_sink, "remove_outgoing_handler", None)
|
||||||
|
if callable(remove):
|
||||||
|
try:
|
||||||
|
remove(self._on_outgoing_from_core)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to detach outgoing handler on core sink")
|
||||||
|
elif hasattr(self.core_sink, "set_outgoing_handler"):
|
||||||
|
try:
|
||||||
|
self.core_sink.set_outgoing_handler(None) # type: ignore[arg-type]
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to detach outgoing handler on core sink")
|
||||||
if self._ws_task:
|
if self._ws_task:
|
||||||
self._ws_task.cancel()
|
self._ws_task.cancel()
|
||||||
with contextlib.suppress(asyncio.CancelledError):
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
@@ -95,12 +135,12 @@ class AdapterBase:
|
|||||||
|
|
||||||
async def on_platform_message(self, raw: Any) -> None:
|
async def on_platform_message(self, raw: Any) -> None:
|
||||||
"""处理平台下发的单条消息并交给核心。"""
|
"""处理平台下发的单条消息并交给核心。"""
|
||||||
envelope = self.from_platform_message(raw)
|
envelope = await _maybe_await(self.from_platform_message(raw))
|
||||||
await self.core_sink.send(envelope)
|
await self.core_sink.send(envelope)
|
||||||
|
|
||||||
async def on_platform_messages(self, raw_messages: list[Any]) -> None:
|
async def on_platform_messages(self, raw_messages: list[Any]) -> None:
|
||||||
"""批量推送入口,内部自动批量或逐条送入核心。"""
|
"""批量推送入口,内部自动批量或逐条送入核心。"""
|
||||||
envelopes = [self.from_platform_message(raw) for raw in raw_messages]
|
envelopes = [await _maybe_await(self.from_platform_message(raw)) for raw in raw_messages]
|
||||||
await _send_many(self.core_sink, envelopes)
|
await _send_many(self.core_sink, envelopes)
|
||||||
|
|
||||||
async def send_to_platform(self, envelope: MessageEnvelope) -> None:
|
async def send_to_platform(self, envelope: MessageEnvelope) -> None:
|
||||||
@@ -112,7 +152,14 @@ class AdapterBase:
|
|||||||
for env in envelopes:
|
for env in envelopes:
|
||||||
await self._send_platform_message(env)
|
await self._send_platform_message(env)
|
||||||
|
|
||||||
def from_platform_message(self, raw: Any) -> MessageEnvelope:
|
async def _on_outgoing_from_core(self, envelope: MessageEnvelope) -> None:
|
||||||
|
"""核心生成 outgoing envelope 时的内部处理逻辑"""
|
||||||
|
platform = envelope.get("platform") or envelope.get("message_info", {}).get("platform")
|
||||||
|
if platform and platform != getattr(self, "platform", None):
|
||||||
|
return
|
||||||
|
await self._send_platform_message(envelope)
|
||||||
|
|
||||||
|
def from_platform_message(self, raw: Any) -> MessageEnvelope | Awaitable[MessageEnvelope]:
|
||||||
"""子类必须实现:将平台原始结构转换为统一 MessageEnvelope。"""
|
"""子类必须实现:将平台原始结构转换为统一 MessageEnvelope。"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -175,13 +222,22 @@ class AdapterBase:
|
|||||||
return orjson.dumps({"type": "send", "payload": envelope})
|
return orjson.dumps({"type": "send", "payload": envelope})
|
||||||
|
|
||||||
|
|
||||||
class InProcessCoreSink:
|
class InProcessCoreSink(CoreSink):
|
||||||
"""
|
"""
|
||||||
简单的进程内 sink,实现 CoreMessageSink 协议。
|
进程内核心消息 sink,实现 CoreSink 协议。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, handler: Callable[[MessageEnvelope], Awaitable[None]]):
|
def __init__(self, handler: Callable[[MessageEnvelope], Awaitable[None]]):
|
||||||
self._handler = handler
|
self._handler = handler
|
||||||
|
self._outgoing_handlers: set[OutgoingHandler] = set()
|
||||||
|
|
||||||
|
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None:
|
||||||
|
if handler is None:
|
||||||
|
return
|
||||||
|
self._outgoing_handlers.add(handler)
|
||||||
|
|
||||||
|
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None:
|
||||||
|
self._outgoing_handlers.discard(handler)
|
||||||
|
|
||||||
async def send(self, message: MessageEnvelope) -> None:
|
async def send(self, message: MessageEnvelope) -> None:
|
||||||
await self._handler(message)
|
await self._handler(message)
|
||||||
@@ -190,6 +246,140 @@ class InProcessCoreSink:
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
await self._handler(message)
|
await self._handler(message)
|
||||||
|
|
||||||
|
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
|
||||||
|
if not self._outgoing_handlers:
|
||||||
|
logger.debug("Outgoing envelope dropped: no handler registered")
|
||||||
|
return
|
||||||
|
for callback in list(self._outgoing_handlers):
|
||||||
|
await callback(envelope)
|
||||||
|
|
||||||
|
async def close(self) -> None: # pragma: no cover - symmetry
|
||||||
|
self._outgoing_handlers.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessCoreSink(CoreSink):
|
||||||
|
"""
|
||||||
|
进程间核心消息 sink,实现 CoreSink 协议,使用 multiprocessing.Queue 初始化
|
||||||
|
"""
|
||||||
|
|
||||||
|
_CONTROL_STOP = {"__core_sink_control__": "stop"}
|
||||||
|
|
||||||
|
def __init__(self, *, to_core_queue: mp.Queue, from_core_queue: mp.Queue) -> None:
|
||||||
|
self._to_core_queue = to_core_queue
|
||||||
|
self._from_core_queue = from_core_queue
|
||||||
|
self._outgoing_handler: OutgoingHandler | None = None
|
||||||
|
self._closed = False
|
||||||
|
self._listener_task: asyncio.Task | None = None
|
||||||
|
self._loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None:
|
||||||
|
self._outgoing_handler = handler
|
||||||
|
if handler is not None and (self._listener_task is None or self._listener_task.done()):
|
||||||
|
self._listener_task = self._loop.create_task(self._listen_from_core())
|
||||||
|
|
||||||
|
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None:
|
||||||
|
if self._outgoing_handler is handler:
|
||||||
|
self._outgoing_handler = None
|
||||||
|
if self._listener_task and not self._listener_task.done():
|
||||||
|
self._listener_task.cancel()
|
||||||
|
|
||||||
|
async def send(self, message: MessageEnvelope) -> None:
|
||||||
|
await asyncio.to_thread(self._to_core_queue.put, {"kind": "incoming", "payload": message})
|
||||||
|
|
||||||
|
async def send_many(self, messages: list[MessageEnvelope]) -> None:
|
||||||
|
for message in messages:
|
||||||
|
await self.send(message)
|
||||||
|
|
||||||
|
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
|
||||||
|
logger.debug("ProcessCoreSink.push_outgoing called in child; ignored")
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._closed:
|
||||||
|
return
|
||||||
|
self._closed = True
|
||||||
|
await asyncio.to_thread(self._from_core_queue.put, self._CONTROL_STOP)
|
||||||
|
if self._listener_task:
|
||||||
|
self._listener_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await self._listener_task
|
||||||
|
self._listener_task = None
|
||||||
|
|
||||||
|
async def _listen_from_core(self) -> None:
|
||||||
|
while not self._closed:
|
||||||
|
try:
|
||||||
|
item = await asyncio.to_thread(self._from_core_queue.get)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
if item == self._CONTROL_STOP:
|
||||||
|
break
|
||||||
|
if isinstance(item, dict) and item.get("kind") == "outgoing":
|
||||||
|
envelope = item.get("payload")
|
||||||
|
if self._outgoing_handler:
|
||||||
|
try:
|
||||||
|
await self._outgoing_handler(envelope)
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
logger.exception("Failed to handle outgoing envelope in ProcessCoreSink")
|
||||||
|
else:
|
||||||
|
logger.debug("ProcessCoreSink received unknown payload: %r", item)
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessCoreSinkServer:
|
||||||
|
"""
|
||||||
|
进程间核心消息 sink 服务器,实现 CoreSink 协议,使用 multiprocessing.Queue 初始化。
|
||||||
|
- 将传入的 incoming 消息转发给指定的 handler
|
||||||
|
- 将接收到的 outgoing 消息放入 outgoing 队列
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
incoming_queue: mp.Queue,
|
||||||
|
outgoing_queue: mp.Queue,
|
||||||
|
core_handler: Callable[[MessageEnvelope], Awaitable[None]],
|
||||||
|
name: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._incoming_queue = incoming_queue
|
||||||
|
self._outgoing_queue = outgoing_queue
|
||||||
|
self._core_handler = core_handler
|
||||||
|
self._task: asyncio.Task | None = None
|
||||||
|
self._closed = False
|
||||||
|
self._name = name or "adapter"
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
if self._task is None or self._task.done():
|
||||||
|
self._task = asyncio.create_task(self._consume_incoming())
|
||||||
|
|
||||||
|
async def _consume_incoming(self) -> None:
|
||||||
|
while not self._closed:
|
||||||
|
try:
|
||||||
|
item = await asyncio.to_thread(self._incoming_queue.get)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
if isinstance(item, dict) and item.get("__core_sink_control__") == "stop":
|
||||||
|
break
|
||||||
|
if isinstance(item, dict) and item.get("kind") == "incoming":
|
||||||
|
envelope = item.get("payload")
|
||||||
|
try:
|
||||||
|
await self._core_handler(envelope)
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
logger.exception("Failed to dispatch incoming envelope from %s", self._name)
|
||||||
|
else:
|
||||||
|
logger.debug("ProcessCoreSinkServer ignored unknown payload from %s: %r", self._name, item)
|
||||||
|
|
||||||
|
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
|
||||||
|
await asyncio.to_thread(self._outgoing_queue.put, {"kind": "outgoing", "payload": envelope})
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._closed:
|
||||||
|
return
|
||||||
|
self._closed = True
|
||||||
|
await asyncio.to_thread(self._incoming_queue.put, {"__core_sink_control__": "stop"})
|
||||||
|
await asyncio.to_thread(self._outgoing_queue.put, ProcessCoreSink._CONTROL_STOP)
|
||||||
|
if self._task:
|
||||||
|
self._task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await self._task
|
||||||
|
self._task = None
|
||||||
|
|
||||||
async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) -> None:
|
async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) -> None:
|
||||||
send_many = getattr(sink, "send_many", None)
|
send_many = getattr(sink, "send_many", None)
|
||||||
@@ -200,11 +390,19 @@ async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) ->
|
|||||||
await sink.send(env)
|
await sink.send(env)
|
||||||
|
|
||||||
|
|
||||||
|
async def _maybe_await(result: Any) -> Any:
|
||||||
|
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||||
|
return await result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class BatchDispatcher:
|
class BatchDispatcher:
|
||||||
"""
|
"""
|
||||||
将 send 操作合并为批量发送,适合网络 IO 密集场景。
|
批量消息分发器,负责将消息批量发送到核心 sink。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_STOP = object()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
sink: CoreMessageSink,
|
sink: CoreMessageSink,
|
||||||
@@ -215,56 +413,79 @@ class BatchDispatcher:
|
|||||||
self._sink = sink
|
self._sink = sink
|
||||||
self._max_batch_size = max_batch_size
|
self._max_batch_size = max_batch_size
|
||||||
self._flush_interval = flush_interval
|
self._flush_interval = flush_interval
|
||||||
self._buffer: list[MessageEnvelope] = []
|
self._queue: asyncio.Queue[MessageEnvelope | object] = asyncio.Queue()
|
||||||
self._lock = asyncio.Lock()
|
self._worker: asyncio.Task | None = None
|
||||||
self._flush_task: asyncio.Task | None = None
|
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
|
||||||
async def add(self, message: MessageEnvelope) -> None:
|
async def add(self, message: MessageEnvelope) -> None:
|
||||||
async with self._lock:
|
if self._closed:
|
||||||
if self._closed:
|
raise RuntimeError("Dispatcher closed")
|
||||||
raise RuntimeError("Dispatcher closed")
|
self._ensure_worker()
|
||||||
self._buffer.append(message)
|
await self._queue.put(message)
|
||||||
self._ensure_timer()
|
|
||||||
if len(self._buffer) >= self._max_batch_size:
|
|
||||||
await self._flush_locked()
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
async with self._lock:
|
if self._closed:
|
||||||
self._closed = True
|
|
||||||
await self._flush_locked()
|
|
||||||
if self._flush_task:
|
|
||||||
self._flush_task.cancel()
|
|
||||||
self._flush_task = None
|
|
||||||
|
|
||||||
def _ensure_timer(self) -> None:
|
|
||||||
if self._flush_task is not None and not self._flush_task.done():
|
|
||||||
return
|
return
|
||||||
loop = asyncio.get_running_loop()
|
self._closed = True
|
||||||
self._flush_task = loop.create_task(self._flush_loop())
|
self._ensure_worker()
|
||||||
|
await self._queue.put(self._STOP)
|
||||||
|
if self._worker:
|
||||||
|
await self._worker
|
||||||
|
self._worker = None
|
||||||
|
|
||||||
async def _flush_loop(self) -> None:
|
def _ensure_worker(self) -> None:
|
||||||
|
if self._worker is not None and not self._worker.done():
|
||||||
|
return
|
||||||
|
self._worker = asyncio.create_task(self._worker_loop())
|
||||||
|
|
||||||
|
async def _worker_loop(self) -> None:
|
||||||
|
buffer: list[MessageEnvelope] = []
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(self._flush_interval)
|
while True:
|
||||||
async with self._lock:
|
try:
|
||||||
await self._flush_locked()
|
item = await asyncio.wait_for(self._queue.get(), timeout=self._flush_interval)
|
||||||
except asyncio.CancelledError: # pragma: no cover - timer cancellation
|
except asyncio.TimeoutError:
|
||||||
pass
|
item = None
|
||||||
|
|
||||||
async def _flush_locked(self) -> None:
|
if item is self._STOP:
|
||||||
if not self._buffer:
|
await self._flush_buffer(buffer)
|
||||||
|
return
|
||||||
|
if item is not None:
|
||||||
|
buffer.append(item) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
while len(buffer) < self._max_batch_size:
|
||||||
|
try:
|
||||||
|
item = self._queue.get_nowait()
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
|
if item is self._STOP:
|
||||||
|
await self._flush_buffer(buffer)
|
||||||
|
return
|
||||||
|
buffer.append(item) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
if buffer and (len(buffer) >= self._max_batch_size or item is None):
|
||||||
|
await self._flush_buffer(buffer)
|
||||||
|
except asyncio.CancelledError: # pragma: no cover - worker cancellation
|
||||||
|
if buffer:
|
||||||
|
await self._flush_buffer(buffer)
|
||||||
|
|
||||||
|
async def _flush_buffer(self, buffer: list[MessageEnvelope]) -> None:
|
||||||
|
if not buffer:
|
||||||
return
|
return
|
||||||
payload = list(self._buffer)
|
payload = list(buffer)
|
||||||
self._buffer.clear()
|
buffer.clear()
|
||||||
await self._sink.send_many(payload)
|
await _send_many(self._sink, payload)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AdapterTransportOptions",
|
"AdapterTransportOptions",
|
||||||
"AdapterBase",
|
"AdapterBase",
|
||||||
"BatchDispatcher",
|
"BatchDispatcher",
|
||||||
|
"CoreSink",
|
||||||
"CoreMessageSink",
|
"CoreMessageSink",
|
||||||
"HttpAdapterOptions",
|
"HttpAdapterOptions",
|
||||||
"InProcessCoreSink",
|
"InProcessCoreSink",
|
||||||
|
"ProcessCoreSink",
|
||||||
|
"ProcessCoreSinkServer",
|
||||||
|
"WebSocketLike",
|
||||||
"WebSocketAdapterOptions",
|
"WebSocketAdapterOptions",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -11,10 +11,36 @@ import orjson
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
|
|
||||||
from .message_models import MessageBase
|
|
||||||
|
|
||||||
MessagePayload = Dict[str, Any]
|
MessagePayload = Dict[str, Any]
|
||||||
MessageHandler = Callable[[MessagePayload], Awaitable[None] | None]
|
MessageHandler = Callable[[MessagePayload], Awaitable[None] | None]
|
||||||
|
DisconnectCallback = Callable[[str, str], Awaitable[None] | None]
|
||||||
|
|
||||||
|
|
||||||
|
def _attach_raw_bytes(payload: Any, raw_bytes: bytes) -> Any:
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
payload.setdefault("raw_bytes", raw_bytes)
|
||||||
|
elif isinstance(payload, list):
|
||||||
|
for item in payload:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
item.setdefault("raw_bytes", raw_bytes)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_for_ws_send(message: Any, *, use_raw_bytes: bool = False) -> tuple[str | bytes, bool]:
|
||||||
|
if isinstance(message, (bytes, bytearray)):
|
||||||
|
return bytes(message), True
|
||||||
|
if use_raw_bytes and isinstance(message, dict):
|
||||||
|
raw = message.get("raw_bytes")
|
||||||
|
if isinstance(raw, (bytes, bytearray)):
|
||||||
|
return bytes(raw), True
|
||||||
|
payload = message
|
||||||
|
if isinstance(payload, dict) and "raw_bytes" in payload and not use_raw_bytes:
|
||||||
|
payload = {k: v for k, v in payload.items() if k != "raw_bytes"}
|
||||||
|
data = orjson.dumps(payload)
|
||||||
|
if use_raw_bytes:
|
||||||
|
return data, True
|
||||||
|
return data.decode("utf-8"), False
|
||||||
|
|
||||||
|
|
||||||
class BaseMessageHandler:
|
class BaseMessageHandler:
|
||||||
@@ -60,6 +86,8 @@ class MessageServer(BaseMessageHandler):
|
|||||||
mode: Literal["ws", "tcp"] = "ws",
|
mode: Literal["ws", "tcp"] = "ws",
|
||||||
custom_logger: logging.Logger | None = None,
|
custom_logger: logging.Logger | None = None,
|
||||||
enable_custom_uvicorn_logger: bool = False,
|
enable_custom_uvicorn_logger: bool = False,
|
||||||
|
queue_maxsize: int = 1000,
|
||||||
|
worker_count: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if mode != "ws":
|
if mode != "ws":
|
||||||
@@ -80,6 +108,9 @@ class MessageServer(BaseMessageHandler):
|
|||||||
self._conn_lock = asyncio.Lock()
|
self._conn_lock = asyncio.Lock()
|
||||||
self._server: uvicorn.Server | None = None
|
self._server: uvicorn.Server | None = None
|
||||||
self._running = False
|
self._running = False
|
||||||
|
self._message_queue: asyncio.Queue[MessagePayload] = asyncio.Queue(maxsize=queue_maxsize)
|
||||||
|
self._worker_count = max(1, worker_count)
|
||||||
|
self._worker_tasks: list[asyncio.Task] = []
|
||||||
self._setup_routes()
|
self._setup_routes()
|
||||||
|
|
||||||
def _setup_routes(self) -> None:
|
def _setup_routes(self) -> None:
|
||||||
@@ -97,21 +128,22 @@ class MessageServer(BaseMessageHandler):
|
|||||||
while True:
|
while True:
|
||||||
msg = await websocket.receive()
|
msg = await websocket.receive()
|
||||||
if msg["type"] == "websocket.receive":
|
if msg["type"] == "websocket.receive":
|
||||||
data = msg.get("text")
|
raw_bytes = msg.get("bytes")
|
||||||
if data is None and msg.get("bytes") is not None:
|
if raw_bytes is None and msg.get("text") is not None:
|
||||||
data = msg["bytes"].decode("utf-8")
|
raw_bytes = msg["text"].encode("utf-8")
|
||||||
if not data:
|
if not raw_bytes:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
payload = orjson.loads(data)
|
payload = orjson.loads(raw_bytes)
|
||||||
except orjson.JSONDecodeError:
|
except orjson.JSONDecodeError:
|
||||||
logging.getLogger("mofox_bus.server").warning("Invalid JSON payload")
|
logging.getLogger("mofox_bus.server").warning("Invalid JSON payload")
|
||||||
continue
|
continue
|
||||||
|
payload = _attach_raw_bytes(payload, raw_bytes)
|
||||||
if isinstance(payload, list):
|
if isinstance(payload, list):
|
||||||
for item in payload:
|
for item in payload:
|
||||||
await self.process_message(item)
|
await self._enqueue_message(item)
|
||||||
else:
|
else:
|
||||||
await self.process_message(payload)
|
await self._enqueue_message(payload)
|
||||||
elif msg["type"] == "websocket.disconnect":
|
elif msg["type"] == "websocket.disconnect":
|
||||||
break
|
break
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
@@ -119,6 +151,49 @@ class MessageServer(BaseMessageHandler):
|
|||||||
finally:
|
finally:
|
||||||
await self._remove_connection(websocket, platform)
|
await self._remove_connection(websocket, platform)
|
||||||
|
|
||||||
|
async def _enqueue_message(self, payload: MessagePayload) -> None:
|
||||||
|
if not self._worker_tasks:
|
||||||
|
self._start_workers()
|
||||||
|
try:
|
||||||
|
self._message_queue.put_nowait(payload)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
logging.getLogger("mofox_bus.server").warning("Message queue full, dropping message")
|
||||||
|
|
||||||
|
def _start_workers(self) -> None:
|
||||||
|
if self._worker_tasks:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
for _ in range(self._worker_count):
|
||||||
|
task = asyncio.create_task(self._consumer_worker())
|
||||||
|
self._worker_tasks.append(task)
|
||||||
|
|
||||||
|
async def _stop_workers(self) -> None:
|
||||||
|
if not self._worker_tasks:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
for task in self._worker_tasks:
|
||||||
|
task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await asyncio.gather(*self._worker_tasks, return_exceptions=True)
|
||||||
|
self._worker_tasks.clear()
|
||||||
|
while not self._message_queue.empty():
|
||||||
|
with contextlib.suppress(asyncio.QueueEmpty):
|
||||||
|
self._message_queue.get_nowait()
|
||||||
|
self._message_queue.task_done()
|
||||||
|
|
||||||
|
async def _consumer_worker(self) -> None:
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
payload = await self._message_queue.get()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
await self.process_message(payload)
|
||||||
|
except Exception: # pragma: no cover - best effort logging
|
||||||
|
logging.getLogger("mofox_bus.server").exception("Error processing message")
|
||||||
|
finally:
|
||||||
|
self._message_queue.task_done()
|
||||||
|
|
||||||
async def verify_token(self, token: str | None) -> bool:
|
async def verify_token(self, token: str | None) -> bool:
|
||||||
if not self._enable_token:
|
if not self._enable_token:
|
||||||
return True
|
return True
|
||||||
@@ -145,33 +220,45 @@ class MessageServer(BaseMessageHandler):
|
|||||||
if platform and self._platform_connections.get(platform) is websocket:
|
if platform and self._platform_connections.get(platform) is websocket:
|
||||||
del self._platform_connections[platform]
|
del self._platform_connections[platform]
|
||||||
|
|
||||||
async def broadcast_message(self, message: MessagePayload) -> None:
|
async def broadcast_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> None:
|
||||||
data = orjson.dumps(message).decode("utf-8")
|
payload: MessagePayload | bytes = message
|
||||||
|
data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes)
|
||||||
async with self._conn_lock:
|
async with self._conn_lock:
|
||||||
targets = list(self._connections)
|
targets = list(self._connections)
|
||||||
for ws in targets:
|
for ws in targets:
|
||||||
await ws.send_text(data)
|
if is_binary:
|
||||||
|
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
|
||||||
|
else:
|
||||||
|
await ws.send_text(data if isinstance(data, str) else data.decode("utf-8"))
|
||||||
|
|
||||||
async def broadcast_to_platform(self, platform: str, message: MessagePayload) -> None:
|
async def broadcast_to_platform(
|
||||||
|
self, platform: str, message: MessagePayload | bytes, *, use_raw_bytes: bool = False
|
||||||
|
) -> None:
|
||||||
ws = self._platform_connections.get(platform)
|
ws = self._platform_connections.get(platform)
|
||||||
if ws is None:
|
if ws is None:
|
||||||
raise RuntimeError(f"No active connection for platform {platform}")
|
raise RuntimeError(f"No active connection for platform {platform}")
|
||||||
await ws.send_text(orjson.dumps(message).decode("utf-8"))
|
payload: MessagePayload | bytes = message.to_dict() if isinstance(message, MessageBase) else message
|
||||||
|
data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes)
|
||||||
|
if is_binary:
|
||||||
|
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
|
||||||
|
else:
|
||||||
|
await ws.send_text(data if isinstance(data, str) else data.decode("utf-8"))
|
||||||
|
|
||||||
async def send_message(self, message: MessageBase | MessagePayload) -> None:
|
async def send_message(
|
||||||
payload = message.to_dict() if isinstance(message, MessageBase) else message
|
self, message: MessagePayload, *, prefer_raw_bytes: bool = False
|
||||||
platform = payload.get("message_info", {}).get("platform")
|
) -> None:
|
||||||
|
platform = message.get("message_info", {}).get("platform")
|
||||||
if not platform:
|
if not platform:
|
||||||
raise ValueError("message_info.platform is required to route the message")
|
raise ValueError("message_info.platform is required to route the message")
|
||||||
await self.broadcast_to_platform(platform, payload)
|
await self.broadcast_to_platform(platform, message, use_raw_bytes=prefer_raw_bytes)
|
||||||
|
|
||||||
def run_sync(self) -> None:
|
def run_sync(self) -> None:
|
||||||
if not self._own_app:
|
if not self._own_app:
|
||||||
return
|
return
|
||||||
asyncio.run(self.run())
|
asyncio.run(self.run())
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
self._running = True
|
self._start_workers()
|
||||||
if not self._own_app:
|
if not self._own_app:
|
||||||
return
|
return
|
||||||
config = uvicorn.Config(
|
config = uvicorn.Config(
|
||||||
@@ -191,6 +278,7 @@ class MessageServer(BaseMessageHandler):
|
|||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
self._running = False
|
self._running = False
|
||||||
|
await self._stop_workers()
|
||||||
if self._server:
|
if self._server:
|
||||||
self._server.should_exit = True
|
self._server.should_exit = True
|
||||||
await self._server.shutdown()
|
await self._server.shutdown()
|
||||||
@@ -217,7 +305,13 @@ class MessageClient(BaseMessageHandler):
|
|||||||
WebSocket 消息客户端,实现双向传输。
|
WebSocket 消息客户端,实现双向传输。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, mode: Literal["ws", "tcp"] = "ws") -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
mode: Literal["ws", "tcp"] = "ws",
|
||||||
|
*,
|
||||||
|
reconnect_interval: float = 5.0,
|
||||||
|
logger: logging.Logger | None = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if mode != "ws":
|
if mode != "ws":
|
||||||
raise NotImplementedError("Only WebSocket mode is supported in mofox_bus")
|
raise NotImplementedError("Only WebSocket mode is supported in mofox_bus")
|
||||||
@@ -230,6 +324,9 @@ class MessageClient(BaseMessageHandler):
|
|||||||
self._token: str | None = None
|
self._token: str | None = None
|
||||||
self._ssl_verify: str | None = None
|
self._ssl_verify: str | None = None
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
self._on_disconnect: DisconnectCallback | None = None
|
||||||
|
self._reconnect_interval = reconnect_interval
|
||||||
|
self._logger = logger or logging.getLogger("mofox_bus.client")
|
||||||
|
|
||||||
async def connect(
|
async def connect(
|
||||||
self,
|
self,
|
||||||
@@ -243,8 +340,12 @@ class MessageClient(BaseMessageHandler):
|
|||||||
self._platform = platform
|
self._platform = platform
|
||||||
self._token = token
|
self._token = token
|
||||||
self._ssl_verify = ssl_verify
|
self._ssl_verify = ssl_verify
|
||||||
|
self._closed = False
|
||||||
await self._establish_connection()
|
await self._establish_connection()
|
||||||
|
|
||||||
|
def set_disconnect_callback(self, callback: DisconnectCallback) -> None:
|
||||||
|
self._on_disconnect = callback
|
||||||
|
|
||||||
async def _establish_connection(self) -> None:
|
async def _establish_connection(self) -> None:
|
||||||
if self._session is None:
|
if self._session is None:
|
||||||
self._session = aiohttp.ClientSession()
|
self._session = aiohttp.ClientSession()
|
||||||
@@ -257,17 +358,21 @@ class MessageClient(BaseMessageHandler):
|
|||||||
self._ws = await self._session.ws_connect(self._url, headers=headers, ssl=ssl_context)
|
self._ws = await self._session.ws_connect(self._url, headers=headers, ssl=ssl_context)
|
||||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||||
|
|
||||||
|
async def _connect_once(self) -> None:
|
||||||
|
await self._establish_connection()
|
||||||
|
|
||||||
async def _receive_loop(self) -> None:
|
async def _receive_loop(self) -> None:
|
||||||
assert self._ws is not None
|
assert self._ws is not None
|
||||||
try:
|
try:
|
||||||
async for msg in self._ws:
|
async for msg in self._ws:
|
||||||
if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
|
if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
|
||||||
data = msg.data if isinstance(msg.data, str) else msg.data.decode("utf-8")
|
raw_bytes = msg.data if isinstance(msg.data, (bytes, bytearray)) else msg.data.encode("utf-8")
|
||||||
try:
|
try:
|
||||||
payload = orjson.loads(data)
|
payload = orjson.loads(raw_bytes)
|
||||||
except orjson.JSONDecodeError:
|
except orjson.JSONDecodeError:
|
||||||
logging.getLogger("mofox_bus.client").warning("Invalid JSON payload")
|
logging.getLogger("mofox_bus.client").warning("Invalid JSON payload")
|
||||||
continue
|
continue
|
||||||
|
payload = _attach_raw_bytes(payload, raw_bytes)
|
||||||
if isinstance(payload, list):
|
if isinstance(payload, list):
|
||||||
for item in payload:
|
for item in payload:
|
||||||
await self.process_message(item)
|
await self.process_message(item)
|
||||||
@@ -278,23 +383,33 @@ class MessageClient(BaseMessageHandler):
|
|||||||
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
|
if not self._closed:
|
||||||
|
await self._notify_disconnect("websocket disconnected")
|
||||||
|
await self._reconnect()
|
||||||
if self._ws:
|
if self._ws:
|
||||||
await self._ws.close()
|
await self._ws.close()
|
||||||
self._ws = None
|
self._ws = None
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
if self._receive_task is None:
|
self._closed = False
|
||||||
await self._establish_connection()
|
while not self._closed:
|
||||||
try:
|
if self._receive_task is None:
|
||||||
if self._receive_task:
|
await self._establish_connection()
|
||||||
await self._receive_task
|
task = self._receive_task
|
||||||
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
if task is None:
|
||||||
pass
|
break
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
||||||
|
raise
|
||||||
|
|
||||||
async def send_message(self, message: MessagePayload) -> bool:
|
async def send_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> bool:
|
||||||
if self._ws is None or self._ws.closed:
|
ws = await self._ensure_ws()
|
||||||
raise RuntimeError("WebSocket connection is not established")
|
data, is_binary = _encode_for_ws_send(message, use_raw_bytes=use_raw_bytes)
|
||||||
await self._ws.send_str(orjson.dumps(message).decode("utf-8"))
|
if is_binary:
|
||||||
|
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
|
||||||
|
else:
|
||||||
|
await ws.send_str(data if isinstance(data, str) else data.decode("utf-8"))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
@@ -313,6 +428,42 @@ class MessageClient(BaseMessageHandler):
|
|||||||
await self._session.close()
|
await self._session.close()
|
||||||
self._session = None
|
self._session = None
|
||||||
|
|
||||||
|
async def _notify_disconnect(self, reason: str) -> None:
|
||||||
|
if self._on_disconnect is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
result = self._on_disconnect(self._platform, reason)
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
await result
|
||||||
|
except Exception: # pragma: no cover - best effort notification
|
||||||
|
logging.getLogger("mofox_bus.client").exception("Disconnect callback failed")
|
||||||
|
|
||||||
|
async def _reconnect(self) -> None:
|
||||||
|
self._logger.info("WebSocket disconnected, retrying in %.1fs", self._reconnect_interval)
|
||||||
|
await asyncio.sleep(self._reconnect_interval)
|
||||||
|
await self._connect_once()
|
||||||
|
|
||||||
|
async def _ensure_session(self) -> aiohttp.ClientSession:
|
||||||
|
if self._session is None:
|
||||||
|
self._session = aiohttp.ClientSession()
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
async def _ensure_ws(self) -> aiohttp.ClientWebSocketResponse:
|
||||||
|
if self._ws is None or self._ws.closed:
|
||||||
|
await self._connect_once()
|
||||||
|
assert self._ws is not None
|
||||||
|
return self._ws
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "MessageClient":
|
||||||
|
if not self._url or not self._platform:
|
||||||
|
raise RuntimeError("connect() must be called before using MessageClient as a context manager")
|
||||||
|
await self._ensure_session()
|
||||||
|
await self._ensure_ws()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||||
|
await self.stop()
|
||||||
|
|
||||||
|
|
||||||
def _self_websocket(app: FastAPI, path: str):
|
def _self_websocket(app: FastAPI, path: str):
|
||||||
"""
|
"""
|
||||||
|
|||||||
110
src/mofox_bus/builder.py
Normal file
110
src/mofox_bus/builder.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from .types import GroupInfoPayload, MessageEnvelope, MessageInfoPayload, SegPayload, UserInfoPayload
|
||||||
|
|
||||||
|
|
||||||
|
class MessageBuilder:
|
||||||
|
"""
|
||||||
|
Fluent helper to build MessageEnvelope safely with type hints.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
msg = (
|
||||||
|
MessageBuilder()
|
||||||
|
.text("Hello")
|
||||||
|
.image("http://example.com/1.png")
|
||||||
|
.to_user("123", platform="qq")
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._direction: str = "outgoing"
|
||||||
|
self._message_info: MessageInfoPayload = {}
|
||||||
|
self._segments: List[SegPayload] = []
|
||||||
|
self._metadata: Dict[str, Any] | None = None
|
||||||
|
self._timestamp_ms: int | None = None
|
||||||
|
self._message_id: str | None = None
|
||||||
|
|
||||||
|
def direction(self, value: str) -> "MessageBuilder":
|
||||||
|
self._direction = value
|
||||||
|
return self
|
||||||
|
|
||||||
|
def message_id(self, value: str) -> "MessageBuilder":
|
||||||
|
self._message_id = value
|
||||||
|
return self
|
||||||
|
|
||||||
|
def timestamp_ms(self, value: int | None = None) -> "MessageBuilder":
|
||||||
|
self._timestamp_ms = value or int(time.time() * 1000)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def metadata(self, value: Dict[str, Any]) -> "MessageBuilder":
|
||||||
|
self._metadata = value
|
||||||
|
return self
|
||||||
|
|
||||||
|
def platform(self, value: str) -> "MessageBuilder":
|
||||||
|
self._message_info["platform"] = value
|
||||||
|
return self
|
||||||
|
|
||||||
|
def from_user(self, user_id: str, *, platform: str | None = None, nickname: str | None = None) -> "MessageBuilder":
|
||||||
|
if platform:
|
||||||
|
self.platform(platform)
|
||||||
|
user_info: UserInfoPayload = {"user_id": user_id}
|
||||||
|
if nickname:
|
||||||
|
user_info["user_nickname"] = nickname
|
||||||
|
self._message_info["user_info"] = user_info
|
||||||
|
return self
|
||||||
|
|
||||||
|
def from_group(self, group_id: str, *, platform: str | None = None, name: str | None = None) -> "MessageBuilder":
|
||||||
|
if platform:
|
||||||
|
self.platform(platform)
|
||||||
|
group_info: GroupInfoPayload = {"group_id": group_id}
|
||||||
|
if name:
|
||||||
|
group_info["group_name"] = name
|
||||||
|
self._message_info["group_info"] = group_info
|
||||||
|
return self
|
||||||
|
|
||||||
|
def seg(self, type_: str, data: Any) -> "MessageBuilder":
|
||||||
|
self._segments.append({"type": type_, "data": data})
|
||||||
|
return self
|
||||||
|
|
||||||
|
def text(self, content: str) -> "MessageBuilder":
|
||||||
|
return self.seg("text", content)
|
||||||
|
|
||||||
|
def image(self, url: str) -> "MessageBuilder":
|
||||||
|
return self.seg("image", url)
|
||||||
|
|
||||||
|
def reply(self, target_message_id: str) -> "MessageBuilder":
|
||||||
|
return self.seg("reply", target_message_id)
|
||||||
|
|
||||||
|
def raw_segment(self, segment: SegPayload) -> "MessageBuilder":
|
||||||
|
self._segments.append(segment)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def build(self) -> MessageEnvelope:
|
||||||
|
# message_info defaults
|
||||||
|
if not self._segments:
|
||||||
|
raise ValueError("message_segment is required, add at least one segment before build()")
|
||||||
|
if self._message_id is None:
|
||||||
|
self._message_id = str(uuid.uuid4())
|
||||||
|
info = dict(self._message_info)
|
||||||
|
info.setdefault("message_id", self._message_id)
|
||||||
|
info.setdefault("time", time.time())
|
||||||
|
|
||||||
|
segments = [seg.copy() if isinstance(seg, dict) else seg for seg in self._segments]
|
||||||
|
envelope: MessageEnvelope = {
|
||||||
|
"direction": self._direction, # type: ignore[assignment]
|
||||||
|
"message_info": info,
|
||||||
|
"message_segment": segments[0] if len(segments) == 1 else list(segments),
|
||||||
|
}
|
||||||
|
if self._metadata is not None:
|
||||||
|
envelope["metadata"] = self._metadata
|
||||||
|
if self._timestamp_ms is not None:
|
||||||
|
envelope["timestamp_ms"] = self._timestamp_ms
|
||||||
|
return envelope
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["MessageBuilder"]
|
||||||
@@ -27,24 +27,23 @@ def _loads(data: bytes) -> Dict[str, Any]:
|
|||||||
|
|
||||||
def dumps_message(msg: MessageEnvelope) -> bytes:
|
def dumps_message(msg: MessageEnvelope) -> bytes:
|
||||||
"""
|
"""
|
||||||
将单条 MessageEnvelope 序列化为 JSON bytes。
|
将单条消息序列化为 JSON bytes。
|
||||||
"""
|
"""
|
||||||
if "schema_version" not in msg:
|
sanitized = _strip_raw_bytes(msg)
|
||||||
msg["schema_version"] = DEFAULT_SCHEMA_VERSION
|
if "schema_version" not in sanitized:
|
||||||
return _dumps(msg)
|
sanitized["schema_version"] = DEFAULT_SCHEMA_VERSION
|
||||||
|
return _dumps(sanitized)
|
||||||
|
|
||||||
def dumps_messages(messages: Iterable[MessageEnvelope]) -> bytes:
|
def dumps_messages(messages: Iterable[MessageEnvelope]) -> bytes:
|
||||||
"""
|
"""
|
||||||
将多条消息批量序列化,以提升吞吐。
|
将批量消息序列化为 JSON bytes。
|
||||||
"""
|
"""
|
||||||
payload = {
|
payload = {
|
||||||
"schema_version": DEFAULT_SCHEMA_VERSION,
|
"schema_version": DEFAULT_SCHEMA_VERSION,
|
||||||
"items": list(messages),
|
"items": [_strip_raw_bytes(msg) for msg in messages],
|
||||||
}
|
}
|
||||||
return _dumps(payload)
|
return _dumps(payload)
|
||||||
|
|
||||||
|
|
||||||
def loads_message(data: bytes | str) -> MessageEnvelope:
|
def loads_message(data: bytes | str) -> MessageEnvelope:
|
||||||
"""
|
"""
|
||||||
反序列化单条消息。
|
反序列化单条消息。
|
||||||
@@ -78,6 +77,14 @@ def _upgrade_schema_if_needed(obj: Dict[str, Any]) -> MessageEnvelope:
|
|||||||
raise ValueError(f"Unsupported schema_version={version}")
|
raise ValueError(f"Unsupported schema_version={version}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_raw_bytes(msg: MessageEnvelope) -> MessageEnvelope:
|
||||||
|
if isinstance(msg, dict) and "raw_bytes" in msg:
|
||||||
|
new_msg = dict(msg)
|
||||||
|
new_msg.pop("raw_bytes", None)
|
||||||
|
return new_msg # type: ignore[return-value]
|
||||||
|
return msg
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DEFAULT_SCHEMA_VERSION",
|
"DEFAULT_SCHEMA_VERSION",
|
||||||
"dumps_message",
|
"dumps_message",
|
||||||
|
|||||||
@@ -1,189 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Seg:
|
|
||||||
"""
|
|
||||||
消息段,表示一段文本/图片/结构化内容。
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: str
|
|
||||||
data: Any
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "Seg":
|
|
||||||
seg_type = data.get("type")
|
|
||||||
seg_data = data.get("data")
|
|
||||||
if seg_type == "seglist" and isinstance(seg_data, list):
|
|
||||||
seg_data = [Seg.from_dict(item) for item in seg_data]
|
|
||||||
return cls(type=seg_type, data=seg_data)
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
if self.type == "seglist" and isinstance(self.data, list):
|
|
||||||
payload = [seg.to_dict() if isinstance(seg, Seg) else seg for seg in self.data]
|
|
||||||
else:
|
|
||||||
payload = self.data
|
|
||||||
return {"type": self.type, "data": payload}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GroupInfo:
|
|
||||||
platform: Optional[str] = None
|
|
||||||
group_id: Optional[str] = None
|
|
||||||
group_name: Optional[str] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> Optional["GroupInfo"]:
|
|
||||||
if not data or data.get("group_id") is None:
|
|
||||||
return None
|
|
||||||
return cls(
|
|
||||||
platform=data.get("platform"),
|
|
||||||
group_id=data.get("group_id"),
|
|
||||||
group_name=data.get("group_name"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class UserInfo:
|
|
||||||
platform: Optional[str] = None
|
|
||||||
user_id: Optional[str] = None
|
|
||||||
user_nickname: Optional[str] = None
|
|
||||||
user_cardname: Optional[str] = None
|
|
||||||
user_avatar: Optional[str] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "UserInfo":
|
|
||||||
return cls(
|
|
||||||
platform=data.get("platform"),
|
|
||||||
user_id=data.get("user_id"),
|
|
||||||
user_nickname=data.get("user_nickname"),
|
|
||||||
user_cardname=data.get("user_cardname"),
|
|
||||||
user_avatar=data.get("user_avatar"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FormatInfo:
|
|
||||||
content_format: Optional[List[str]] = None
|
|
||||||
accept_format: Optional[List[str]] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> Optional["FormatInfo"]:
|
|
||||||
if not data:
|
|
||||||
return None
|
|
||||||
return cls(
|
|
||||||
content_format=data.get("content_format"),
|
|
||||||
accept_format=data.get("accept_format"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TemplateInfo:
|
|
||||||
template_items: Optional[Dict[str, str]] = None
|
|
||||||
template_name: Optional[Dict[str, str]] = None
|
|
||||||
template_default: bool = True
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> Optional["TemplateInfo"]:
|
|
||||||
if not data:
|
|
||||||
return None
|
|
||||||
return cls(
|
|
||||||
template_items=data.get("template_items"),
|
|
||||||
template_name=data.get("template_name"),
|
|
||||||
template_default=data.get("template_default", True),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BaseMessageInfo:
|
|
||||||
platform: Optional[str] = None
|
|
||||||
message_id: Optional[str] = None
|
|
||||||
time: Optional[float] = None
|
|
||||||
group_info: Optional[GroupInfo] = None
|
|
||||||
user_info: Optional[UserInfo] = None
|
|
||||||
format_info: Optional[FormatInfo] = None
|
|
||||||
template_info: Optional[TemplateInfo] = None
|
|
||||||
additional_config: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
result: Dict[str, Any] = {}
|
|
||||||
if self.platform is not None:
|
|
||||||
result["platform"] = self.platform
|
|
||||||
if self.message_id is not None:
|
|
||||||
result["message_id"] = self.message_id
|
|
||||||
if self.time is not None:
|
|
||||||
result["time"] = self.time
|
|
||||||
if self.additional_config is not None:
|
|
||||||
result["additional_config"] = self.additional_config
|
|
||||||
if self.group_info is not None:
|
|
||||||
result["group_info"] = self.group_info.to_dict()
|
|
||||||
if self.user_info is not None:
|
|
||||||
result["user_info"] = self.user_info.to_dict()
|
|
||||||
if self.format_info is not None:
|
|
||||||
result["format_info"] = self.format_info.to_dict()
|
|
||||||
if self.template_info is not None:
|
|
||||||
result["template_info"] = self.template_info.to_dict()
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "BaseMessageInfo":
|
|
||||||
return cls(
|
|
||||||
platform=data.get("platform"),
|
|
||||||
message_id=data.get("message_id"),
|
|
||||||
time=data.get("time"),
|
|
||||||
additional_config=data.get("additional_config"),
|
|
||||||
group_info=GroupInfo.from_dict(data.get("group_info", {})),
|
|
||||||
user_info=UserInfo.from_dict(data.get("user_info", {})),
|
|
||||||
format_info=FormatInfo.from_dict(data.get("format_info", {})),
|
|
||||||
template_info=TemplateInfo.from_dict(data.get("template_info", {})),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MessageBase:
|
|
||||||
message_info: BaseMessageInfo
|
|
||||||
message_segment: Seg
|
|
||||||
raw_message: Optional[str] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
payload: Dict[str, Any] = {
|
|
||||||
"message_info": self.message_info.to_dict(),
|
|
||||||
"message_segment": self.message_segment.to_dict(),
|
|
||||||
}
|
|
||||||
if self.raw_message is not None:
|
|
||||||
payload["raw_message"] = self.raw_message
|
|
||||||
return payload
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "MessageBase":
|
|
||||||
return cls(
|
|
||||||
message_info=BaseMessageInfo.from_dict(data.get("message_info", {})),
|
|
||||||
message_segment=Seg.from_dict(data.get("message_segment", {})),
|
|
||||||
raw_message=data.get("raw_message"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BaseMessageInfo",
|
|
||||||
"FormatInfo",
|
|
||||||
"GroupInfo",
|
|
||||||
"MessageBase",
|
|
||||||
"Seg",
|
|
||||||
"TemplateInfo",
|
|
||||||
"UserInfo",
|
|
||||||
]
|
|
||||||
@@ -7,7 +7,7 @@ from dataclasses import asdict, dataclass
|
|||||||
from typing import Callable, Dict, Optional
|
from typing import Callable, Dict, Optional
|
||||||
|
|
||||||
from .api import MessageClient
|
from .api import MessageClient
|
||||||
from .message_models import MessageBase
|
from .types import MessageEnvelope
|
||||||
|
|
||||||
logger = logging.getLogger("mofox_bus.router")
|
logger = logging.getLogger("mofox_bus.router")
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ class Router:
|
|||||||
self.handlers: list[Callable[[Dict], None]] = []
|
self.handlers: list[Callable[[Dict], None]] = []
|
||||||
self._running = False
|
self._running = False
|
||||||
self._client_tasks: Dict[str, asyncio.Task] = {}
|
self._client_tasks: Dict[str, asyncio.Task] = {}
|
||||||
self._monitor_task: asyncio.Task | None = None
|
self._stop_event: asyncio.Event | None = None
|
||||||
|
|
||||||
async def connect(self, platform: str) -> None:
|
async def connect(self, platform: str) -> None:
|
||||||
if platform not in self.config.route_config:
|
if platform not in self.config.route_config:
|
||||||
@@ -65,6 +65,7 @@ class Router:
|
|||||||
if mode != "ws":
|
if mode != "ws":
|
||||||
raise NotImplementedError("TCP mode is not implemented yet")
|
raise NotImplementedError("TCP mode is not implemented yet")
|
||||||
client = MessageClient(mode="ws")
|
client = MessageClient(mode="ws")
|
||||||
|
client.set_disconnect_callback(self._handle_client_disconnect)
|
||||||
await client.connect(
|
await client.connect(
|
||||||
url=target.url,
|
url=target.url,
|
||||||
platform=platform,
|
platform=platform,
|
||||||
@@ -75,7 +76,7 @@ class Router:
|
|||||||
client.register_message_handler(handler)
|
client.register_message_handler(handler)
|
||||||
self.clients[platform] = client
|
self.clients[platform] = client
|
||||||
if self._running:
|
if self._running:
|
||||||
self._client_tasks[platform] = asyncio.create_task(client.run())
|
self._start_client_task(platform, client)
|
||||||
|
|
||||||
def register_class_handler(self, handler: Callable[[Dict], None]) -> None:
|
def register_class_handler(self, handler: Callable[[Dict], None]) -> None:
|
||||||
self.handlers.append(handler)
|
self.handlers.append(handler)
|
||||||
@@ -84,36 +85,18 @@ class Router:
|
|||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
self._running = True
|
self._running = True
|
||||||
|
self._stop_event = asyncio.Event()
|
||||||
for platform in self.config.route_config:
|
for platform in self.config.route_config:
|
||||||
if platform not in self.clients:
|
if platform not in self.clients:
|
||||||
await self.connect(platform)
|
await self.connect(platform)
|
||||||
for platform, client in self.clients.items():
|
for platform, client in self.clients.items():
|
||||||
if platform not in self._client_tasks:
|
if platform not in self._client_tasks:
|
||||||
self._client_tasks[platform] = asyncio.create_task(client.run())
|
self._start_client_task(platform, client)
|
||||||
self._monitor_task = asyncio.create_task(self._monitor_connections())
|
|
||||||
try:
|
try:
|
||||||
while self._running:
|
await self._stop_event.wait()
|
||||||
await asyncio.sleep(1)
|
|
||||||
except asyncio.CancelledError: # pragma: no cover
|
except asyncio.CancelledError: # pragma: no cover
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _monitor_connections(self) -> None:
|
|
||||||
await asyncio.sleep(3)
|
|
||||||
while self._running:
|
|
||||||
for platform in list(self.clients.keys()):
|
|
||||||
client = self.clients.get(platform)
|
|
||||||
if client is None:
|
|
||||||
continue
|
|
||||||
if not client.is_connected():
|
|
||||||
logger.info("Detected disconnect from %s, attempting reconnect", platform)
|
|
||||||
await self._reconnect_platform(platform)
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
|
|
||||||
async def _reconnect_platform(self, platform: str) -> None:
|
|
||||||
await self.remove_platform(platform)
|
|
||||||
if platform in self.config.route_config:
|
|
||||||
await self.connect(platform)
|
|
||||||
|
|
||||||
async def remove_platform(self, platform: str) -> None:
|
async def remove_platform(self, platform: str) -> None:
|
||||||
if platform in self._client_tasks:
|
if platform in self._client_tasks:
|
||||||
task = self._client_tasks.pop(platform)
|
task = self._client_tasks.pop(platform)
|
||||||
@@ -124,32 +107,55 @@ class Router:
|
|||||||
if client:
|
if client:
|
||||||
await client.stop()
|
await client.stop()
|
||||||
|
|
||||||
|
async def _handle_client_disconnect(self, platform: str, reason: str) -> None:
|
||||||
|
logger.info("Client for %s disconnected: %s (auto-reconnect handled by client)", platform, reason)
|
||||||
|
task = self._client_tasks.get(platform)
|
||||||
|
if task is not None and not task.done():
|
||||||
|
return
|
||||||
|
client = self.clients.get(platform)
|
||||||
|
if client and self._running:
|
||||||
|
self._start_client_task(platform, client)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
self._running = False
|
self._running = False
|
||||||
if self._monitor_task:
|
if self._stop_event:
|
||||||
self._monitor_task.cancel()
|
self._stop_event.set()
|
||||||
with contextlib.suppress(asyncio.CancelledError):
|
|
||||||
await self._monitor_task
|
|
||||||
self._monitor_task = None
|
|
||||||
for platform in list(self.clients.keys()):
|
for platform in list(self.clients.keys()):
|
||||||
await self.remove_platform(platform)
|
await self.remove_platform(platform)
|
||||||
self.clients.clear()
|
self.clients.clear()
|
||||||
|
|
||||||
def get_target_url(self, message: MessageBase) -> Optional[str]:
|
def _start_client_task(self, platform: str, client: MessageClient) -> None:
|
||||||
platform = message.message_info.platform
|
task = asyncio.create_task(client.run())
|
||||||
|
task.add_done_callback(lambda t, plat=platform: asyncio.create_task(self._restart_if_needed(plat, t)))
|
||||||
|
self._client_tasks[platform] = task
|
||||||
|
|
||||||
|
async def _restart_if_needed(self, platform: str, task: asyncio.Task) -> None:
|
||||||
|
if not self._running:
|
||||||
|
return
|
||||||
|
if task.cancelled():
|
||||||
|
return
|
||||||
|
exc = task.exception()
|
||||||
|
if exc:
|
||||||
|
logger.warning("Client task for %s ended with exception: %s", platform, exc)
|
||||||
|
client = self.clients.get(platform)
|
||||||
|
if client:
|
||||||
|
self._start_client_task(platform, client)
|
||||||
|
|
||||||
|
def get_target_url(self, message: MessageEnvelope) -> Optional[str]:
|
||||||
|
platform = message.get("message_info", {}).get("platform")
|
||||||
if not platform:
|
if not platform:
|
||||||
return None
|
return None
|
||||||
target = self.config.route_config.get(platform)
|
target = self.config.route_config.get(platform)
|
||||||
return target.url if target else None
|
return target.url if target else None
|
||||||
|
|
||||||
async def send_message(self, message: MessageBase):
|
async def send_message(self, message: MessageEnvelope):
|
||||||
platform = message.message_info.platform
|
platform = message.get("message_info", {}).get("platform")
|
||||||
if not platform:
|
if not platform:
|
||||||
raise ValueError("message_info.platform is required")
|
raise ValueError("message_info.platform is required")
|
||||||
client = self.clients.get(platform)
|
client = self.clients.get(platform)
|
||||||
if client is None:
|
if client is None:
|
||||||
raise RuntimeError(f"No client connected for platform {platform}")
|
raise RuntimeError(f"No client connected for platform {platform}")
|
||||||
return await client.send_message(message.to_dict())
|
return await client.send_message(message)
|
||||||
|
|
||||||
async def update_config(self, config_data: Dict[str, Dict[str, str | None]]) -> None:
|
async def update_config(self, config_data: Dict[str, Dict[str, str | None]]) -> None:
|
||||||
new_config = RouteConfig.from_dict(config_data)
|
new_config = RouteConfig.from_dict(config_data)
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
import threading
|
import threading
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Awaitable, Callable, Iterable, List
|
from typing import Awaitable, Callable, Dict, Iterable, List, Protocol
|
||||||
|
|
||||||
from .types import MessageEnvelope
|
from .types import MessageEnvelope
|
||||||
|
|
||||||
@@ -12,6 +13,11 @@ ErrorHook = Callable[[MessageEnvelope, BaseException], Awaitable[None] | None]
|
|||||||
Predicate = Callable[[MessageEnvelope], bool | Awaitable[bool]]
|
Predicate = Callable[[MessageEnvelope], bool | Awaitable[bool]]
|
||||||
MessageHandler = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None] | MessageEnvelope | None]
|
MessageHandler = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None] | MessageEnvelope | None]
|
||||||
BatchHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None] | List[MessageEnvelope] | None]
|
BatchHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None] | List[MessageEnvelope] | None]
|
||||||
|
MiddlewareCallable = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None]]
|
||||||
|
|
||||||
|
|
||||||
|
class Middleware(Protocol):
|
||||||
|
async def __call__(self, message: MessageEnvelope, handler: MiddlewareCallable) -> MessageEnvelope | None: ...
|
||||||
|
|
||||||
|
|
||||||
class MessageProcessingError(RuntimeError):
|
class MessageProcessingError(RuntimeError):
|
||||||
@@ -19,7 +25,7 @@ class MessageProcessingError(RuntimeError):
|
|||||||
|
|
||||||
def __init__(self, message: MessageEnvelope, original: BaseException):
|
def __init__(self, message: MessageEnvelope, original: BaseException):
|
||||||
detail = message.get("id", "<unknown>")
|
detail = message.get("id", "<unknown>")
|
||||||
super().__init__(f"Failed to handle message {detail}: {original}") # pragma: no cover - str repr only
|
super().__init__(f"Failed to handle message {detail}: {original}")
|
||||||
self.message_envelope = message
|
self.message_envelope = message
|
||||||
self.original = original
|
self.original = original
|
||||||
|
|
||||||
@@ -29,6 +35,8 @@ class MessageRoute:
|
|||||||
predicate: Predicate
|
predicate: Predicate
|
||||||
handler: MessageHandler
|
handler: MessageHandler
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
|
message_type: str | None = None
|
||||||
|
event_types: set[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class MessageRuntime:
|
class MessageRuntime:
|
||||||
@@ -43,15 +51,36 @@ class MessageRuntime:
|
|||||||
self._error_hooks: list[ErrorHook] = []
|
self._error_hooks: list[ErrorHook] = []
|
||||||
self._batch_handler: BatchHandler | None = None
|
self._batch_handler: BatchHandler | None = None
|
||||||
self._lock = threading.RLock()
|
self._lock = threading.RLock()
|
||||||
|
self._middlewares: list[Middleware] = []
|
||||||
|
self._type_routes: Dict[str, list[MessageRoute]] = {}
|
||||||
|
self._event_routes: Dict[str, list[MessageRoute]] = {}
|
||||||
|
|
||||||
def add_route(self, predicate: Predicate, handler: MessageHandler, name: str | None = None) -> None:
|
def add_route(
|
||||||
|
self,
|
||||||
|
predicate: Predicate,
|
||||||
|
handler: MessageHandler,
|
||||||
|
name: str | None = None,
|
||||||
|
*,
|
||||||
|
message_type: str | None = None,
|
||||||
|
event_types: Iterable[str] | None = None,
|
||||||
|
) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._routes.append(MessageRoute(predicate=predicate, handler=handler, name=name))
|
route = MessageRoute(
|
||||||
|
predicate=predicate,
|
||||||
|
handler=handler,
|
||||||
|
name=name,
|
||||||
|
message_type=message_type,
|
||||||
|
event_types=set(event_types) if event_types is not None else None,
|
||||||
|
)
|
||||||
|
self._routes.append(route)
|
||||||
|
if message_type:
|
||||||
|
self._type_routes.setdefault(message_type, []).append(route)
|
||||||
|
if route.event_types:
|
||||||
|
for et in route.event_types:
|
||||||
|
self._event_routes.setdefault(et, []).append(route)
|
||||||
|
|
||||||
def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]:
|
def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]:
|
||||||
"""
|
"""装饰器写法,便于在核心逻辑中声明式注册。"""
|
||||||
装饰器写法,便于在核心逻辑中声明式注册。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(func: MessageHandler) -> MessageHandler:
|
def decorator(func: MessageHandler) -> MessageHandler:
|
||||||
self.add_route(predicate, func, name=name)
|
self.add_route(predicate, func, name=name)
|
||||||
@@ -59,6 +88,60 @@ class MessageRuntime:
|
|||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def on_message(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
message_type: str | None = None,
|
||||||
|
platform: str | None = None,
|
||||||
|
predicate: Predicate | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
) -> Callable[[MessageHandler], MessageHandler]:
|
||||||
|
"""Sugar 装饰器,基于 Seg.type/platform 及可选额外谓词匹配。"""
|
||||||
|
|
||||||
|
async def combined_predicate(message: MessageEnvelope) -> bool:
|
||||||
|
if message_type is not None and _extract_segment_type(message) != message_type:
|
||||||
|
return False
|
||||||
|
if platform is not None:
|
||||||
|
info_platform = message.get("message_info", {}).get("platform")
|
||||||
|
if message.get("platform") not in (None, platform) and info_platform is None:
|
||||||
|
return False
|
||||||
|
if info_platform not in (None, platform):
|
||||||
|
return False
|
||||||
|
if predicate is None:
|
||||||
|
return True
|
||||||
|
return await _invoke_callable(predicate, message, prefer_thread=False)
|
||||||
|
|
||||||
|
def decorator(func: MessageHandler) -> MessageHandler:
|
||||||
|
self.add_route(combined_predicate, func, name=name, message_type=message_type)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def on_event(
|
||||||
|
self,
|
||||||
|
event_type: str | Iterable[str],
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
) -> Callable[[MessageHandler], MessageHandler]:
|
||||||
|
"""装饰器,基于 message 或 message_info.additional_config 中的 event_type 匹配。"""
|
||||||
|
|
||||||
|
allowed = {event_type} if isinstance(event_type, str) else set(event_type)
|
||||||
|
|
||||||
|
async def predicate(message: MessageEnvelope) -> bool:
|
||||||
|
current = (
|
||||||
|
message.get("event_type")
|
||||||
|
or message.get("message_info", {})
|
||||||
|
.get("additional_config", {})
|
||||||
|
.get("event_type")
|
||||||
|
)
|
||||||
|
return current in allowed
|
||||||
|
|
||||||
|
def decorator(func: MessageHandler) -> MessageHandler:
|
||||||
|
self.add_route(predicate, func, name=name, event_types=allowed)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
def set_batch_handler(self, handler: BatchHandler) -> None:
|
def set_batch_handler(self, handler: BatchHandler) -> None:
|
||||||
self._batch_handler = handler
|
self._batch_handler = handler
|
||||||
|
|
||||||
@@ -71,14 +154,20 @@ class MessageRuntime:
|
|||||||
def register_error_hook(self, hook: ErrorHook) -> None:
|
def register_error_hook(self, hook: ErrorHook) -> None:
|
||||||
self._error_hooks.append(hook)
|
self._error_hooks.append(hook)
|
||||||
|
|
||||||
|
def register_middleware(self, middleware: Middleware) -> None:
|
||||||
|
"""注册洋葱模型中间件,围绕处理器执行。"""
|
||||||
|
|
||||||
|
self._middlewares.append(middleware)
|
||||||
|
|
||||||
async def handle_message(self, message: MessageEnvelope) -> MessageEnvelope | None:
|
async def handle_message(self, message: MessageEnvelope) -> MessageEnvelope | None:
|
||||||
await self._run_hooks(self._before_hooks, message)
|
await self._run_hooks(self._before_hooks, message)
|
||||||
try:
|
try:
|
||||||
route = await self._match_route(message)
|
route = await self._match_route(message)
|
||||||
if route is None:
|
if route is None:
|
||||||
return None
|
return None
|
||||||
result = await _maybe_await(route.handler(message))
|
handler = self._wrap_with_middlewares(route.handler)
|
||||||
except Exception as exc: # pragma: no cover - tested indirectly
|
result = await handler(message)
|
||||||
|
except Exception as exc:
|
||||||
await self._run_error_hooks(message, exc)
|
await self._run_error_hooks(message, exc)
|
||||||
raise MessageProcessingError(message, exc) from exc
|
raise MessageProcessingError(message, exc) from exc
|
||||||
await self._run_hooks(self._after_hooks, message)
|
await self._run_hooks(self._after_hooks, message)
|
||||||
@@ -89,7 +178,7 @@ class MessageRuntime:
|
|||||||
if not batch:
|
if not batch:
|
||||||
return []
|
return []
|
||||||
if self._batch_handler is not None:
|
if self._batch_handler is not None:
|
||||||
result = await _maybe_await(self._batch_handler(batch))
|
result = await _invoke_callable(self._batch_handler, batch, prefer_thread=True)
|
||||||
return result or []
|
return result or []
|
||||||
responses: list[MessageEnvelope] = []
|
responses: list[MessageEnvelope] = []
|
||||||
for message in batch:
|
for message in batch:
|
||||||
@@ -99,21 +188,61 @@ class MessageRuntime:
|
|||||||
return responses
|
return responses
|
||||||
|
|
||||||
async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None:
|
async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None:
|
||||||
|
candidates: list[MessageRoute] = []
|
||||||
|
message_type = _extract_segment_type(message)
|
||||||
|
event_type = (
|
||||||
|
message.get("event_type")
|
||||||
|
or message.get("message_info", {})
|
||||||
|
.get("additional_config", {})
|
||||||
|
.get("event_type")
|
||||||
|
)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
routes = list(self._routes)
|
if event_type and event_type in self._event_routes:
|
||||||
for route in routes:
|
candidates.extend(self._event_routes[event_type])
|
||||||
should_handle = await _maybe_await(route.predicate(message))
|
if message_type and message_type in self._type_routes:
|
||||||
|
candidates.extend(self._type_routes[message_type])
|
||||||
|
candidates.extend(self._routes)
|
||||||
|
|
||||||
|
seen: set[int] = set()
|
||||||
|
for route in candidates:
|
||||||
|
rid = id(route)
|
||||||
|
if rid in seen:
|
||||||
|
continue
|
||||||
|
seen.add(rid)
|
||||||
|
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
|
||||||
if should_handle:
|
if should_handle:
|
||||||
return route
|
return route
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None:
|
async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None:
|
||||||
for hook in hooks:
|
coro_list = [self._call_hook(hook, message) for hook in hooks]
|
||||||
await _maybe_await(hook(message))
|
if coro_list:
|
||||||
|
await asyncio.gather(*coro_list)
|
||||||
|
|
||||||
|
async def _call_hook(self, hook: Hook, message: MessageEnvelope) -> None:
|
||||||
|
await _invoke_callable(hook, message, prefer_thread=True)
|
||||||
|
|
||||||
async def _run_error_hooks(self, message: MessageEnvelope, exc: BaseException) -> None:
|
async def _run_error_hooks(self, message: MessageEnvelope, exc: BaseException) -> None:
|
||||||
for hook in self._error_hooks:
|
coros = [self._call_error_hook(hook, message, exc) for hook in self._error_hooks]
|
||||||
await _maybe_await(hook(message, exc))
|
if coros:
|
||||||
|
await asyncio.gather(*coros)
|
||||||
|
|
||||||
|
async def _call_error_hook(self, hook: ErrorHook, message: MessageEnvelope, exc: BaseException) -> None:
|
||||||
|
await _invoke_callable(hook, message, exc, prefer_thread=True)
|
||||||
|
|
||||||
|
def _wrap_with_middlewares(self, handler: MessageHandler) -> MiddlewareCallable:
|
||||||
|
async def base_handler(message: MessageEnvelope) -> MessageEnvelope | None:
|
||||||
|
return await _invoke_callable(handler, message, prefer_thread=True)
|
||||||
|
|
||||||
|
wrapped: MiddlewareCallable = base_handler
|
||||||
|
for middleware in reversed(self._middlewares):
|
||||||
|
current = wrapped
|
||||||
|
|
||||||
|
async def wrapper(msg: MessageEnvelope, mw=middleware, nxt=current) -> MessageEnvelope | None:
|
||||||
|
return await _invoke_callable(mw, msg, nxt, prefer_thread=False)
|
||||||
|
|
||||||
|
wrapped = wrapper
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
async def _maybe_await(result):
|
async def _maybe_await(result):
|
||||||
@@ -122,6 +251,32 @@ async def _maybe_await(result):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False):
|
||||||
|
"""Support sync/async callables with optional thread offloading."""
|
||||||
|
if inspect.iscoroutinefunction(func):
|
||||||
|
return await func(*args)
|
||||||
|
if prefer_thread:
|
||||||
|
result = await asyncio.to_thread(func, *args)
|
||||||
|
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||||
|
return await result
|
||||||
|
return result
|
||||||
|
result = func(*args)
|
||||||
|
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||||
|
return await result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_segment_type(message: MessageEnvelope) -> str | None:
|
||||||
|
seg = message.get("message_segment") or message.get("message_chain")
|
||||||
|
if isinstance(seg, dict):
|
||||||
|
return seg.get("type")
|
||||||
|
if isinstance(seg, list) and seg:
|
||||||
|
first = seg[0]
|
||||||
|
if isinstance(first, dict):
|
||||||
|
return first.get("type")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BatchHandler",
|
"BatchHandler",
|
||||||
"Hook",
|
"Hook",
|
||||||
@@ -129,5 +284,6 @@ __all__ = [
|
|||||||
"MessageProcessingError",
|
"MessageProcessingError",
|
||||||
"MessageRoute",
|
"MessageRoute",
|
||||||
"MessageRuntime",
|
"MessageRuntime",
|
||||||
|
"Middleware",
|
||||||
"Predicate",
|
"Predicate",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -3,160 +3,91 @@ from __future__ import annotations
|
|||||||
from typing import Any, Dict, List, Literal, NotRequired, TypedDict
|
from typing import Any, Dict, List, Literal, NotRequired, TypedDict
|
||||||
|
|
||||||
MessageDirection = Literal["incoming", "outgoing"]
|
MessageDirection = Literal["incoming", "outgoing"]
|
||||||
Role = Literal["user", "assistant", "system", "tool", "platform"]
|
|
||||||
ContentType = Literal[
|
|
||||||
"text",
|
|
||||||
"image",
|
|
||||||
"audio",
|
|
||||||
"file",
|
|
||||||
"video",
|
|
||||||
"event",
|
|
||||||
"command",
|
|
||||||
"system",
|
|
||||||
]
|
|
||||||
|
|
||||||
EventType = Literal[
|
# ----------------------------
|
||||||
"message_created",
|
# maim_message 风格的 TypedDict
|
||||||
"message_updated",
|
# ----------------------------
|
||||||
"message_deleted",
|
|
||||||
"member_join",
|
|
||||||
"member_leave",
|
|
||||||
"typing",
|
|
||||||
"reaction_add",
|
|
||||||
"reaction_remove",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class TextContent(TypedDict, total=False):
|
class SegPayload(TypedDict, total=False):
|
||||||
type: Literal["text"]
|
"""
|
||||||
text: str
|
对齐 maim_message.Seg 的片段定义,使用纯 dict 便于 JSON 传输。
|
||||||
markdown: NotRequired[bool]
|
"""
|
||||||
entities: NotRequired[List[Dict[str, Any]]]
|
|
||||||
|
type: str
|
||||||
|
data: str | List["SegPayload"]
|
||||||
|
translated_data: NotRequired[str | List["SegPayload"]]
|
||||||
|
|
||||||
|
|
||||||
class ImageContent(TypedDict, total=False):
|
class UserInfoPayload(TypedDict, total=False):
|
||||||
type: Literal["image"]
|
platform: NotRequired[str]
|
||||||
url: str
|
user_id: NotRequired[str]
|
||||||
mime_type: NotRequired[str]
|
user_nickname: NotRequired[str]
|
||||||
width: NotRequired[int]
|
user_cardname: NotRequired[str]
|
||||||
height: NotRequired[int]
|
user_avatar: NotRequired[str]
|
||||||
file_id: NotRequired[str]
|
|
||||||
|
|
||||||
|
|
||||||
class FileContent(TypedDict, total=False):
|
class GroupInfoPayload(TypedDict, total=False):
|
||||||
type: Literal["file"]
|
platform: NotRequired[str]
|
||||||
url: str
|
group_id: NotRequired[str]
|
||||||
mime_type: NotRequired[str]
|
group_name: NotRequired[str]
|
||||||
file_name: NotRequired[str]
|
|
||||||
file_size: NotRequired[int]
|
|
||||||
file_id: NotRequired[str]
|
|
||||||
|
|
||||||
|
|
||||||
class AudioContent(TypedDict, total=False):
|
class FormatInfoPayload(TypedDict, total=False):
|
||||||
type: Literal["audio"]
|
content_format: NotRequired[List[str]]
|
||||||
url: str
|
accept_format: NotRequired[List[str]]
|
||||||
mime_type: NotRequired[str]
|
|
||||||
duration_ms: NotRequired[int]
|
|
||||||
file_id: NotRequired[str]
|
|
||||||
|
|
||||||
|
|
||||||
class VideoContent(TypedDict, total=False):
|
class TemplateInfoPayload(TypedDict, total=False):
|
||||||
type: Literal["video"]
|
template_items: NotRequired[Dict[str, str]]
|
||||||
url: str
|
template_name: NotRequired[Dict[str, str]]
|
||||||
mime_type: NotRequired[str]
|
template_default: NotRequired[bool]
|
||||||
duration_ms: NotRequired[int]
|
|
||||||
width: NotRequired[int]
|
|
||||||
height: NotRequired[int]
|
|
||||||
file_id: NotRequired[str]
|
|
||||||
|
|
||||||
|
|
||||||
class EventContent(TypedDict):
|
class MessageInfoPayload(TypedDict, total=False):
|
||||||
type: Literal["event"]
|
platform: NotRequired[str]
|
||||||
event_type: EventType
|
message_id: NotRequired[str]
|
||||||
raw: Dict[str, Any]
|
time: NotRequired[float]
|
||||||
|
group_info: NotRequired[GroupInfoPayload]
|
||||||
|
user_info: NotRequired[UserInfoPayload]
|
||||||
|
format_info: NotRequired[FormatInfoPayload]
|
||||||
|
template_info: NotRequired[TemplateInfoPayload]
|
||||||
|
additional_config: NotRequired[Dict[str, Any]]
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
class CommandContent(TypedDict, total=False):
|
# MessageEnvelope
|
||||||
type: Literal["command"]
|
# ----------------------------
|
||||||
name: str
|
|
||||||
args: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class SystemContent(TypedDict):
|
|
||||||
type: Literal["system"]
|
|
||||||
text: str
|
|
||||||
|
|
||||||
|
|
||||||
Content = (
|
|
||||||
TextContent
|
|
||||||
| ImageContent
|
|
||||||
| FileContent
|
|
||||||
| AudioContent
|
|
||||||
| VideoContent
|
|
||||||
| EventContent
|
|
||||||
| CommandContent
|
|
||||||
| SystemContent
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SenderInfo(TypedDict, total=False):
|
|
||||||
user_id: str
|
|
||||||
role: Role
|
|
||||||
display_name: NotRequired[str]
|
|
||||||
avatar_url: NotRequired[str]
|
|
||||||
raw: NotRequired[Dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelInfo(TypedDict, total=False):
|
|
||||||
channel_id: str
|
|
||||||
channel_type: Literal[
|
|
||||||
"private",
|
|
||||||
"group",
|
|
||||||
"supergroup",
|
|
||||||
"channel",
|
|
||||||
"dm",
|
|
||||||
"room",
|
|
||||||
"thread",
|
|
||||||
]
|
|
||||||
title: NotRequired[str]
|
|
||||||
workspace_id: NotRequired[str]
|
|
||||||
raw: NotRequired[Dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class MessageEnvelope(TypedDict, total=False):
|
class MessageEnvelope(TypedDict, total=False):
|
||||||
id: str
|
"""
|
||||||
direction: MessageDirection
|
mofox-bus 传输层统一使用的消息信封。
|
||||||
platform: str
|
|
||||||
timestamp_ms: int
|
|
||||||
channel: ChannelInfo
|
|
||||||
sender: SenderInfo
|
|
||||||
content: Content
|
|
||||||
conversation_id: str
|
|
||||||
thread_id: NotRequired[str]
|
|
||||||
reply_to_message_id: NotRequired[str]
|
|
||||||
correlation_id: NotRequired[str]
|
|
||||||
is_edited: NotRequired[bool]
|
|
||||||
is_ephemeral: NotRequired[bool]
|
|
||||||
raw_platform_message: NotRequired[Dict[str, Any]]
|
|
||||||
metadata: NotRequired[Dict[str, Any]]
|
|
||||||
schema_version: NotRequired[int]
|
|
||||||
|
|
||||||
|
- 采用 maim_message 风格:message_info + message_segment。
|
||||||
|
"""
|
||||||
|
|
||||||
|
direction: MessageDirection
|
||||||
|
message_info: MessageInfoPayload
|
||||||
|
message_segment: SegPayload | List[SegPayload]
|
||||||
|
raw_message: NotRequired[Any]
|
||||||
|
raw_bytes: NotRequired[bytes]
|
||||||
|
message_chain: NotRequired[List[SegPayload]] # seglist 的直观别名
|
||||||
|
platform: NotRequired[str] # 快捷访问,等价于 message_info.platform
|
||||||
|
message_id: NotRequired[str] # 快捷访问,等价于 message_info.message_id
|
||||||
|
timestamp_ms: NotRequired[int]
|
||||||
|
correlation_id: NotRequired[str]
|
||||||
|
schema_version: NotRequired[int]
|
||||||
|
metadata: NotRequired[Dict[str, Any]]
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AudioContent",
|
# maim_message style payloads
|
||||||
"ChannelInfo",
|
"SegPayload",
|
||||||
"CommandContent",
|
"UserInfoPayload",
|
||||||
"Content",
|
"GroupInfoPayload",
|
||||||
"ContentType",
|
"FormatInfoPayload",
|
||||||
"EventContent",
|
"TemplateInfoPayload",
|
||||||
"EventType",
|
"MessageInfoPayload",
|
||||||
"FileContent",
|
# legacy content style
|
||||||
"ImageContent",
|
|
||||||
"MessageDirection",
|
"MessageDirection",
|
||||||
"MessageEnvelope",
|
"MessageEnvelope",
|
||||||
"Role",
|
|
||||||
"SenderInfo",
|
|
||||||
"SystemContent",
|
|
||||||
"TextContent",
|
|
||||||
"VideoContent",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from abc import ABC, abstractmethod
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
from mofox_bus import AdapterBase as MoFoxAdapterBase, CoreMessageSink, MessageEnvelope
|
from mofox_bus import AdapterBase as MoFoxAdapterBase, CoreSink, MessageEnvelope
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.plugin_system.base.base_plugin import BasePlugin
|
from src.plugin_system.base.base_plugin import BasePlugin
|
||||||
@@ -47,7 +47,7 @@ class BaseAdapter(MoFoxAdapterBase, ABC):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
core_sink: CoreMessageSink,
|
core_sink: CoreSink,
|
||||||
plugin: Optional[BasePlugin] = None,
|
plugin: Optional[BasePlugin] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@@ -227,7 +227,7 @@ class BaseAdapter(MoFoxAdapterBase, ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def from_platform_message(self, raw: Any) -> MessageEnvelope:
|
async def from_platform_message(self, raw: Any) -> MessageEnvelope:
|
||||||
"""
|
"""
|
||||||
将平台原始消息转换为 MessageEnvelope
|
将平台原始消息转换为 MessageEnvelope
|
||||||
|
|
||||||
|
|||||||
@@ -7,130 +7,152 @@ Adapter 管理器
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import subprocess
|
import importlib
|
||||||
import sys
|
import multiprocessing as mp
|
||||||
from pathlib import Path
|
|
||||||
from typing import TYPE_CHECKING, Dict, Optional
|
from typing import TYPE_CHECKING, Dict, Optional
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.plugin_system.base.base_adapter import BaseAdapter
|
from src.plugin_system.base.base_adapter import BaseAdapter
|
||||||
|
|
||||||
|
from mofox_bus import ProcessCoreSinkServer
|
||||||
|
from src.common.core_sink import get_core_sink
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("adapter_manager")
|
logger = get_logger("adapter_manager")
|
||||||
|
|
||||||
|
|
||||||
class AdapterProcess:
|
|
||||||
"""适配器子进程包装器"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
def _load_class(module_name: str, class_name: str):
|
||||||
adapter_name: str,
|
module = importlib.import_module(module_name)
|
||||||
entry_path: Path,
|
return getattr(module, class_name)
|
||||||
python_executable: Optional[str] = None,
|
|
||||||
):
|
|
||||||
self.adapter_name = adapter_name
|
def _adapter_process_entry(
|
||||||
self.entry_path = entry_path
|
adapter_path: tuple[str, str],
|
||||||
self.python_executable = python_executable or sys.executable
|
plugin_info: dict | None,
|
||||||
self.process: Optional[subprocess.Popen] = None
|
incoming_queue: mp.Queue,
|
||||||
self._monitor_task: Optional[asyncio.Task] = None
|
outgoing_queue: mp.Queue,
|
||||||
|
):
|
||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
from mofox_bus import ProcessCoreSink
|
||||||
|
|
||||||
|
async def _run() -> None:
|
||||||
|
adapter_cls = _load_class(*adapter_path)
|
||||||
|
plugin_instance = None
|
||||||
|
if plugin_info:
|
||||||
|
plugin_cls = _load_class(plugin_info["module"], plugin_info["class"])
|
||||||
|
plugin_instance = plugin_cls(plugin_info["plugin_dir"], plugin_info["metadata"])
|
||||||
|
core_sink = ProcessCoreSink(to_core_queue=incoming_queue, from_core_queue=outgoing_queue)
|
||||||
|
adapter = adapter_cls(core_sink, plugin=plugin_instance)
|
||||||
|
await adapter.start()
|
||||||
|
try:
|
||||||
|
while not getattr(core_sink, "_closed", False):
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
finally:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await adapter.stop()
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await core_sink.close()
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class AdapterProcess:
|
||||||
|
"""适配器子进程包装器,负责适配器子进程的启动和生命周期管理"""
|
||||||
|
|
||||||
|
def __init__(self, adapter: "BaseAdapter", core_sink) -> None:
|
||||||
|
self.adapter = adapter
|
||||||
|
self.adapter_name = adapter.adapter_name
|
||||||
|
self.process: mp.Process | None = None
|
||||||
|
self._ctx = mp.get_context("spawn")
|
||||||
|
self._incoming_queue: mp.Queue = self._ctx.Queue()
|
||||||
|
self._outgoing_queue: mp.Queue = self._ctx.Queue()
|
||||||
|
self._bridge: ProcessCoreSinkServer | None = None
|
||||||
|
self._core_sink = core_sink
|
||||||
|
self._adapter_path: tuple[str, str] = (adapter.__class__.__module__, adapter.__class__.__name__)
|
||||||
|
self._plugin_info = self._extract_plugin_info(adapter)
|
||||||
|
self._outgoing_handler = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_plugin_info(adapter: "BaseAdapter") -> dict | None:
|
||||||
|
plugin = getattr(adapter, "plugin", None)
|
||||||
|
if plugin is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"module": plugin.__class__.__module__,
|
||||||
|
"class": plugin.__class__.__name__,
|
||||||
|
"plugin_dir": getattr(plugin, "plugin_dir", ""),
|
||||||
|
"metadata": getattr(plugin, "plugin_meta", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _make_outgoing_handler(self):
|
||||||
|
async def _handler(envelope):
|
||||||
|
if self._bridge:
|
||||||
|
await self._bridge.push_outgoing(envelope)
|
||||||
|
return _handler
|
||||||
|
|
||||||
async def start(self) -> bool:
|
async def start(self) -> bool:
|
||||||
"""启动适配器子进程"""
|
"""启动适配器子进程"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"启动适配器子进程: {self.adapter_name}")
|
logger.info(f"启动适配器子进程: {self.adapter_name}")
|
||||||
logger.debug(f"Python: {self.python_executable}")
|
self._bridge = ProcessCoreSinkServer(
|
||||||
logger.debug(f"Entry: {self.entry_path}")
|
incoming_queue=self._incoming_queue,
|
||||||
|
outgoing_queue=self._outgoing_queue,
|
||||||
# 启动子进程
|
core_handler=self._core_sink.send,
|
||||||
self.process = subprocess.Popen(
|
name=self.adapter_name,
|
||||||
[self.python_executable, str(self.entry_path)],
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
text=True,
|
|
||||||
bufsize=1,
|
|
||||||
)
|
)
|
||||||
|
self._bridge.start()
|
||||||
# 启动监控任务
|
if hasattr(self._core_sink, "set_outgoing_handler"):
|
||||||
self._monitor_task = asyncio.create_task(self._monitor_process())
|
self._outgoing_handler = self._make_outgoing_handler()
|
||||||
|
try:
|
||||||
logger.info(f"适配器 {self.adapter_name} 子进程已启动 (PID: {self.process.pid})")
|
self._core_sink.set_outgoing_handler(self._outgoing_handler)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to register outgoing bridge for %s", self.adapter_name)
|
||||||
|
self.process = self._ctx.Process(
|
||||||
|
target=_adapter_process_entry,
|
||||||
|
args=(self._adapter_path, self._plugin_info, self._incoming_queue, self._outgoing_queue),
|
||||||
|
name=f"{self.adapter_name}-proc",
|
||||||
|
)
|
||||||
|
self.process.start()
|
||||||
|
logger.info(f"启动适配器子进程 {self.adapter_name} (PID: {self.process.pid})")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"启动适配器 {self.adapter_name} 子进程失败: {e}", exc_info=True)
|
logger.error(f"启动适配器子进程 {self.adapter_name} 失败: {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""停止适配器子进程"""
|
"""停止适配器子进程"""
|
||||||
if not self.process:
|
if not self.process:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"停止适配器子进程: {self.adapter_name} (PID: {self.process.pid})")
|
logger.info(f"停止适配器子进程: {self.adapter_name} (PID: {self.process.pid})")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 取消监控任务
|
remover = getattr(self._core_sink, "remove_outgoing_handler", None)
|
||||||
if self._monitor_task and not self._monitor_task.done():
|
if callable(remover) and self._outgoing_handler:
|
||||||
self._monitor_task.cancel()
|
|
||||||
try:
|
try:
|
||||||
await self._monitor_task
|
remover(self._outgoing_handler)
|
||||||
except asyncio.CancelledError:
|
except Exception:
|
||||||
pass
|
logger.exception(f"移除 {self.adapter_name} 的 outgoing bridge 失败")
|
||||||
|
if self._bridge:
|
||||||
# 终止进程
|
await self._bridge.close()
|
||||||
self.process.terminate()
|
if self.process.is_alive():
|
||||||
|
self.process.join(timeout=5.0)
|
||||||
# 等待进程退出(最多等待5秒)
|
if self.process.is_alive():
|
||||||
try:
|
logger.warning(f"适配器 {self.adapter_name} 未能及时停止,强制终止中")
|
||||||
await asyncio.wait_for(
|
self.process.terminate()
|
||||||
asyncio.to_thread(self.process.wait),
|
self.process.join()
|
||||||
timeout=5.0
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning(f"适配器 {self.adapter_name} 未能在5秒内退出,强制终止")
|
|
||||||
self.process.kill()
|
|
||||||
await asyncio.to_thread(self.process.wait)
|
|
||||||
|
|
||||||
logger.info(f"适配器 {self.adapter_name} 子进程已停止")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"停止适配器 {self.adapter_name} 子进程时出错: {e}", exc_info=True)
|
logger.error(f"停止适配器子进程 {self.adapter_name} 时发生错误: {e}", exc_info=True)
|
||||||
finally:
|
finally:
|
||||||
self.process = None
|
self.process = None
|
||||||
|
|
||||||
async def _monitor_process(self) -> None:
|
|
||||||
"""监控子进程状态"""
|
|
||||||
if not self.process:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 在后台线程中等待进程退出
|
|
||||||
return_code = await asyncio.to_thread(self.process.wait)
|
|
||||||
|
|
||||||
if return_code != 0:
|
|
||||||
logger.error(
|
|
||||||
f"适配器 {self.adapter_name} 子进程异常退出 (返回码: {return_code})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 读取 stderr 输出
|
|
||||||
if self.process.stderr:
|
|
||||||
stderr = self.process.stderr.read()
|
|
||||||
if stderr:
|
|
||||||
logger.error(f"错误输出:\n{stderr}")
|
|
||||||
else:
|
|
||||||
logger.info(f"适配器 {self.adapter_name} 子进程正常退出")
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"监控适配器 {self.adapter_name} 子进程时出错: {e}", exc_info=True)
|
|
||||||
|
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
"""检查进程是否正在运行"""
|
"""适配器是否正在运行"""
|
||||||
if not self.process:
|
if not self.process:
|
||||||
return False
|
return False
|
||||||
return self.process.poll() is None
|
return self.process.is_alive()
|
||||||
|
|
||||||
|
|
||||||
class AdapterManager:
|
class AdapterManager:
|
||||||
"""适配器管理器"""
|
"""适配器管理器"""
|
||||||
@@ -176,20 +198,17 @@ class AdapterManager:
|
|||||||
else:
|
else:
|
||||||
return await self._start_adapter_in_process(adapter)
|
return await self._start_adapter_in_process(adapter)
|
||||||
|
|
||||||
async def _start_adapter_subprocess(self, adapter: BaseAdapter) -> bool:
|
|
||||||
"""在子进程中启动适配器"""
|
|
||||||
adapter_name = adapter.adapter_name
|
|
||||||
|
|
||||||
# 获取子进程入口脚本
|
async def _start_adapter_subprocess(self, adapter: BaseAdapter) -> bool:
|
||||||
entry_path = adapter.get_subprocess_entry_path()
|
"""启动适配器子进程"""
|
||||||
if not entry_path:
|
adapter_name = adapter.adapter_name
|
||||||
logger.error(
|
try:
|
||||||
f"适配器 {adapter_name} 配置为子进程运行,但未提供有效的入口脚本"
|
core_sink = get_core_sink()
|
||||||
)
|
except Exception as e:
|
||||||
|
logger.error(f"无法获取 core_sink,启动适配器子进程 {adapter_name} 失败: {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 创建并启动子进程
|
adapter_process = AdapterProcess(adapter, core_sink)
|
||||||
adapter_process = AdapterProcess(adapter_name, entry_path)
|
|
||||||
success = await adapter_process.start()
|
success = await adapter_process.start()
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
|
|||||||
427
src/plugins/built_in/NEW_napcat_adapter/README.md
Normal file
427
src/plugins/built_in/NEW_napcat_adapter/README.md
Normal file
@@ -0,0 +1,427 @@
|
|||||||
|
# NEW_napcat_adapter
|
||||||
|
|
||||||
|
基于 mofox-bus v2.x 的 Napcat 适配器(使用 BaseAdapter 架构)
|
||||||
|
|
||||||
|
## 🏗️ 架构设计
|
||||||
|
|
||||||
|
本插件采用 **BaseAdapter 继承模式** 重写,完全抛弃旧版 maim_message 库,改用 mofox-bus 的 TypedDict 数据结构。
|
||||||
|
|
||||||
|
### 核心组件
|
||||||
|
- **NapcatAdapter**: 继承自 `mofox_bus.AdapterBase`,负责 OneBot 11 协议与 MessageEnvelope 的双向转换
|
||||||
|
- **WebSocketAdapterOptions**: 自动管理 WebSocket 连接,提供 incoming_parser 和 outgoing_encoder
|
||||||
|
- **CoreMessageSink**: 通过 `InProcessCoreSink` 将消息递送到核心系统
|
||||||
|
- **Handlers**: 独立的消息处理器,分为 to_core(接收)和 to_napcat(发送)两个方向
|
||||||
|
|
||||||
|
## 📁 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
NEW_napcat_adapter/
|
||||||
|
├── plugin.py # ✅ 主插件文件(BaseAdapter实现)
|
||||||
|
├── _manifest.json # 插件清单
|
||||||
|
│
|
||||||
|
└── src/
|
||||||
|
├── event_models.py # ✅ OneBot事件类型常量
|
||||||
|
├── common/
|
||||||
|
│ └── core_sink.py # ✅ 全局CoreSink访问点
|
||||||
|
│
|
||||||
|
├── utils/
|
||||||
|
│ ├── utils.py # ⏳ 工具函数(待实现)
|
||||||
|
│ ├── qq_emoji_list.py # ⏳ QQ表情映射(待实现)
|
||||||
|
│ ├── video_handler.py # ⏳ 视频处理(待实现)
|
||||||
|
│ └── message_chunker.py # ⏳ 消息切片(待实现)
|
||||||
|
│
|
||||||
|
├── websocket/
|
||||||
|
│ └── (无需单独实现,使用WebSocketAdapterOptions)
|
||||||
|
│
|
||||||
|
├── database/
|
||||||
|
│ └── database.py # ⏳ 数据库模型(待实现)
|
||||||
|
│
|
||||||
|
└── handlers/
|
||||||
|
├── to_core/ # Napcat → MessageEnvelope 方向
|
||||||
|
│ ├── message_handler.py # ⏳ 消息处理(部分完成)
|
||||||
|
│ ├── notice_handler.py # ⏳ 通知处理(待完成)
|
||||||
|
│ └── meta_event_handler.py # ⏳ 元事件(待完成)
|
||||||
|
│
|
||||||
|
└── to_napcat/ # MessageEnvelope → Napcat API 方向
|
||||||
|
└── send_handler.py # ⏳ 发送处理(部分完成)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🚀 快速开始
|
||||||
|
|
||||||
|
### 使用方式
|
||||||
|
|
||||||
|
1. **配置文件**: 在 `config/plugins/NEW_napcat_adapter.toml` 中配置 WebSocket URL 和其他参数
|
||||||
|
2. **启动插件**: 插件自动在系统启动时加载
|
||||||
|
3. **WebSocket连接**: 自动连接到 Napcat OneBot 11 服务器
|
||||||
|
|
||||||
|
## 🔑 核心数据结构
|
||||||
|
|
||||||
|
### MessageEnvelope (mofox-bus v2.x)
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mofox_bus import MessageEnvelope, SegPayload, MessageInfoPayload
|
||||||
|
|
||||||
|
# 创建消息信封
|
||||||
|
envelope: MessageEnvelope = {
|
||||||
|
"direction": "input",
|
||||||
|
"message_info": {
|
||||||
|
"message_type": "group",
|
||||||
|
"message_id": "12345",
|
||||||
|
"self_id": "bot_qq",
|
||||||
|
"user_info": {
|
||||||
|
"user_id": "sender_qq",
|
||||||
|
"user_name": "发送者",
|
||||||
|
"user_displayname": "昵称"
|
||||||
|
},
|
||||||
|
"group_info": {
|
||||||
|
"group_id": "group_id",
|
||||||
|
"group_name": "群名"
|
||||||
|
},
|
||||||
|
"to_me": False
|
||||||
|
},
|
||||||
|
"message_segment": {
|
||||||
|
"type": "seglist",
|
||||||
|
"data": [
|
||||||
|
{"type": "text", "data": "hello"},
|
||||||
|
{"type": "image", "data": "base64_data"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"raw_message": "hello[图片]",
|
||||||
|
"platform": "napcat",
|
||||||
|
"message_id": "12345",
|
||||||
|
"timestamp_ms": 1234567890
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### BaseAdapter 核心方法
|
||||||
|
|
||||||
|
```python
|
||||||
|
class NapcatAdapter(BaseAdapter):
|
||||||
|
async def from_platform_message(self, message: dict[str, Any]) -> MessageEnvelope | None:
|
||||||
|
"""将 OneBot 11 事件转换为 MessageEnvelope"""
|
||||||
|
# 路由到对应的 Handler
|
||||||
|
|
||||||
|
async def _send_platform_message(self, envelope: MessageEnvelope) -> dict[str, Any]:
|
||||||
|
"""将 MessageEnvelope 转换为 OneBot 11 API 调用"""
|
||||||
|
# 调用 SendHandler 处理
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📝 实现进度
|
||||||
|
|
||||||
|
### ✅ 已完成的核心架构
|
||||||
|
|
||||||
|
1. **BaseAdapter 实现** (plugin.py)
|
||||||
|
- ✅ WebSocket 自动连接管理
|
||||||
|
- ✅ from_platform_message() 事件路由
|
||||||
|
- ✅ _send_platform_message() 消息发送
|
||||||
|
- ✅ API 响应池机制(echo-based request-response)
|
||||||
|
- ✅ CoreSink 集成
|
||||||
|
|
||||||
|
2. **Handler 基础结构**
|
||||||
|
- ✅ MessageHandler 骨架(text、image、at 基本实现)
|
||||||
|
- ✅ NoticeHandler 骨架
|
||||||
|
- ✅ MetaEventHandler 骨架
|
||||||
|
- ✅ SendHandler 骨架(基本类型转换)
|
||||||
|
|
||||||
|
3. **辅助组件**
|
||||||
|
- ✅ event_models.py(事件类型常量)
|
||||||
|
- ✅ core_sink.py(全局 CoreSink 访问)
|
||||||
|
- ✅ 配置 Schema 定义
|
||||||
|
|
||||||
|
### ⏳ 部分完成的功能
|
||||||
|
|
||||||
|
4. **消息类型处理** (MessageHandler)
|
||||||
|
- ✅ 基础消息类型:text, image, at
|
||||||
|
- ❌ 高级消息类型:face, reply, forward, video, json, file, rps, dice, shake
|
||||||
|
|
||||||
|
5. **发送处理** (SendHandler)
|
||||||
|
- ✅ 基础 SegPayload 转换:text, image
|
||||||
|
- ❌ 高级 Seg 类型:emoji, voice, voiceurl, music, videourl, file, command
|
||||||
|
|
||||||
|
### ❌ 待实现的功能
|
||||||
|
|
||||||
|
6. **通知事件处理** (NoticeHandler)
|
||||||
|
- ❌ 戳一戳事件
|
||||||
|
- ❌ 表情回应事件
|
||||||
|
- ❌ 撤回事件
|
||||||
|
- ❌ 禁言事件
|
||||||
|
|
||||||
|
7. **工具函数** (utils.py)
|
||||||
|
- ❌ get_group_info
|
||||||
|
- ❌ get_member_info
|
||||||
|
- ❌ get_image_base64
|
||||||
|
- ❌ get_message_detail
|
||||||
|
- ❌ get_record_detail
|
||||||
|
|
||||||
|
8. **权限系统**
|
||||||
|
- ❌ check_allow_to_chat()
|
||||||
|
- ❌ 群组黑名单/白名单
|
||||||
|
- ❌ 私聊黑名单/白名单
|
||||||
|
- ❌ QQ机器人检测
|
||||||
|
|
||||||
|
9. **其他组件**
|
||||||
|
- ❌ 视频处理器
|
||||||
|
- ❌ 消息切片器
|
||||||
|
- ❌ 数据库模型
|
||||||
|
- ❌ QQ 表情映射表
|
||||||
|
|
||||||
|
## 📋 下一步工作
|
||||||
|
|
||||||
|
### 优先级 1:完善消息处理(参考旧版 recv_handler/message_handler.py)
|
||||||
|
|
||||||
|
1. **完整实现 MessageHandler.handle_raw_message()**
|
||||||
|
- [ ] face(表情)消息段
|
||||||
|
- [ ] reply(回复)消息段
|
||||||
|
- [ ] forward(转发)消息段解析
|
||||||
|
- [ ] video(视频)消息段
|
||||||
|
- [ ] json(JSON卡片)消息段
|
||||||
|
- [ ] file(文件)消息段
|
||||||
|
- [ ] rps/dice/shake(特殊消息)
|
||||||
|
|
||||||
|
2. **实现工具函数**(参考旧版 utils.py)
|
||||||
|
- [ ] `get_group_info()` - 获取群组信息
|
||||||
|
- [ ] `get_member_info()` - 获取成员信息
|
||||||
|
- [ ] `get_image_base64()` - 下载图片并转Base64
|
||||||
|
- [ ] `get_message_detail()` - 获取消息详情
|
||||||
|
- [ ] `get_record_detail()` - 获取语音详情
|
||||||
|
|
||||||
|
3. **实现权限检查**
|
||||||
|
- [ ] `check_allow_to_chat()` - 检查是否允许聊天
|
||||||
|
- [ ] 群组白名单/黑名单逻辑
|
||||||
|
- [ ] 私聊白名单/黑名单逻辑
|
||||||
|
- [ ] QQ机器人检测(ban_qq_bot)
|
||||||
|
|
||||||
|
### 优先级 2:完善发送处理(参考旧版 send_handler.py)
|
||||||
|
|
||||||
|
4. **完整实现 SendHandler._convert_seg_to_onebot()**
|
||||||
|
- [ ] emoji(表情回应)命令
|
||||||
|
- [ ] voice(语音)消息段
|
||||||
|
- [ ] voiceurl(语音URL)消息段
|
||||||
|
- [ ] music(音乐卡片)消息段
|
||||||
|
- [ ] videourl(视频URL)消息段
|
||||||
|
- [ ] file(文件)消息段
|
||||||
|
- [ ] command(命令)消息段
|
||||||
|
|
||||||
|
5. **实现命令处理**
|
||||||
|
- [ ] GROUP_BAN(禁言)
|
||||||
|
- [ ] GROUP_KICK(踢人)
|
||||||
|
- [ ] SEND_POKE(戳一戳)
|
||||||
|
- [ ] DELETE_MSG(撤回消息)
|
||||||
|
- [ ] GROUP_WHOLE_BAN(全员禁言)
|
||||||
|
- [ ] SET_GROUP_CARD(设置群名片)
|
||||||
|
- [ ] SET_GROUP_ADMIN(设置管理员)
|
||||||
|
|
||||||
|
### 优先级 3:补全其他组件(参考旧版对应文件)
|
||||||
|
|
||||||
|
6. **NoticeHandler 实现**
|
||||||
|
- [ ] 戳一戳通知(notify.poke)
|
||||||
|
- [ ] 表情回应通知(notice.group_emoji_like)
|
||||||
|
- [ ] 消息撤回通知(notice.group_recall)
|
||||||
|
- [ ] 禁言通知(notice.group_ban)
|
||||||
|
|
||||||
|
7. **辅助组件**
|
||||||
|
- [ ] `qq_emoji_list.py` - QQ表情ID映射表
|
||||||
|
- [ ] `video_handler.py` - 视频处理(ffmpeg封面提取)
|
||||||
|
- [ ] `message_chunker.py` - 消息分块与重组
|
||||||
|
- [ ] `database.py` - 数据库模型(如有需要)
|
||||||
|
|
||||||
|
### 优先级 4:测试与优化
|
||||||
|
|
||||||
|
8. **功能测试**
|
||||||
|
- [ ] 文本消息收发
|
||||||
|
- [ ] 图片消息收发
|
||||||
|
- [ ] @消息处理
|
||||||
|
- [ ] 表情/语音/视频消息
|
||||||
|
- [ ] 转发消息解析
|
||||||
|
- [ ] 所有命令功能
|
||||||
|
- [ ] 通知事件处理
|
||||||
|
|
||||||
|
9. **性能优化**
|
||||||
|
- [ ] 消息处理并发性能
|
||||||
|
- [ ] API响应池性能
|
||||||
|
- [ ] 内存占用优化
|
||||||
|
|
||||||
|
## 🔍 关键实现细节
|
||||||
|
|
||||||
|
### 1. MessageEnvelope vs 旧版 MessageBase
|
||||||
|
|
||||||
|
**不再使用 Seg dataclass**,全部使用 TypedDict:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ❌ 旧版(maim_message)
|
||||||
|
from mofox_bus import Seg, MessageBase
|
||||||
|
|
||||||
|
seg = Seg(type="text", data="hello")
|
||||||
|
message = MessageBase(message_info=info, message_segment=seg)
|
||||||
|
|
||||||
|
# ✅ 新版(mofox-bus v2.x)
|
||||||
|
from mofox_bus import SegPayload, MessageEnvelope
|
||||||
|
|
||||||
|
seg_payload: SegPayload = {"type": "text", "data": "hello"}
|
||||||
|
envelope: MessageEnvelope = {
|
||||||
|
"direction": "input",
|
||||||
|
"message_info": {...},
|
||||||
|
"message_segment": seg_payload,
|
||||||
|
...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Handler 架构模式
|
||||||
|
|
||||||
|
**接收方向** (to_core):
|
||||||
|
```python
|
||||||
|
class MessageHandler:
|
||||||
|
def __init__(self, adapter: "NapcatAdapter"):
|
||||||
|
self.adapter = adapter
|
||||||
|
|
||||||
|
async def handle_raw_message(self, data: dict[str, Any]) -> MessageEnvelope:
|
||||||
|
# 1. 解析 OneBot 11 数据
|
||||||
|
# 2. 构建 message_info(MessageInfoPayload)
|
||||||
|
# 3. 转换消息段为 SegPayload
|
||||||
|
# 4. 返回完整的 MessageEnvelope
|
||||||
|
```
|
||||||
|
|
||||||
|
**发送方向** (to_napcat):
|
||||||
|
```python
|
||||||
|
class SendHandler:
|
||||||
|
def __init__(self, adapter: "NapcatAdapter"):
|
||||||
|
self.adapter = adapter
|
||||||
|
|
||||||
|
async def handle_message(self, envelope: MessageEnvelope) -> dict[str, Any]:
|
||||||
|
# 1. 从 envelope 提取 message_segment
|
||||||
|
# 2. 递归转换 SegPayload → OneBot 格式
|
||||||
|
# 3. 调用 adapter.send_napcat_api() 发送
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. API 调用模式(响应池)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 在 NapcatAdapter 中
|
||||||
|
async def send_napcat_api(self, action: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
# 1. 生成唯一 echo
|
||||||
|
echo = f"{action}_{uuid.uuid4()}"
|
||||||
|
|
||||||
|
# 2. 创建 Future 等待响应
|
||||||
|
future = asyncio.Future()
|
||||||
|
self._response_pool[echo] = future
|
||||||
|
|
||||||
|
# 3. 发送请求(通过 WebSocket)
|
||||||
|
await self._send_request({"action": action, "params": params, "echo": echo})
|
||||||
|
|
||||||
|
# 4. 等待响应(带超时)
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(future, timeout=10.0)
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
self._response_pool.pop(echo, None)
|
||||||
|
|
||||||
|
# 响应回来时(在 incoming_parser 中)
|
||||||
|
def _handle_api_response(data: dict[str, Any]):
|
||||||
|
echo = data.get("echo")
|
||||||
|
if echo in adapter._response_pool:
|
||||||
|
adapter._response_pool[echo].set_result(data)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 类型提示技巧
|
||||||
|
|
||||||
|
处理 TypedDict 的严格类型检查:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 使用 type: ignore 标注(编译时是 TypedDict,运行时是 dict)
|
||||||
|
envelope: MessageEnvelope = {
|
||||||
|
"direction": "input",
|
||||||
|
...
|
||||||
|
} # type: ignore[typeddict-item]
|
||||||
|
|
||||||
|
# 或在函数签名中使用 dict[str, Any]
|
||||||
|
async def from_platform_message(self, message: dict[str, Any]) -> MessageEnvelope | None:
|
||||||
|
...
|
||||||
|
return envelope # type: ignore[return-value]
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔍 测试检查清单
|
||||||
|
|
||||||
|
- [ ] 文本消息接收/发送
|
||||||
|
- [ ] 图片消息接收/发送
|
||||||
|
- [ ] 语音消息接收/发送
|
||||||
|
- [ ] 视频消息接收/发送
|
||||||
|
- [ ] @消息接收/发送
|
||||||
|
- [ ] 回复消息接收/发送
|
||||||
|
- [ ] 转发消息接收
|
||||||
|
- [ ] JSON消息接收
|
||||||
|
- [ ] 文件消息接收/发送
|
||||||
|
- [ ] 禁言命令
|
||||||
|
- [ ] 踢人命令
|
||||||
|
- [ ] 戳一戳命令
|
||||||
|
- [ ] 表情回应命令
|
||||||
|
- [ ] 通知事件处理
|
||||||
|
- [ ] 元事件处理
|
||||||
|
|
||||||
|
## 📚 参考资料
|
||||||
|
|
||||||
|
- **mofox-bus 文档**: 查看 `mofox_bus/types.py` 了解 TypedDict 定义
|
||||||
|
- **BaseAdapter 示例**: 参考 `docs/mofox_bus_demo_adapter.py`
|
||||||
|
- **旧版实现**: `src/plugins/built_in/napcat_adapter_plugin/` (仅参考逻辑)
|
||||||
|
- **OneBot 11 协议**: [OneBot 11 标准](https://github.com/botuniverse/onebot-11)
|
||||||
|
|
||||||
|
## ⚠️ 重要注意事项
|
||||||
|
|
||||||
|
1. **完全抛弃旧版数据结构**
|
||||||
|
- ❌ 不再使用 `Seg` dataclass
|
||||||
|
- ❌ 不再使用 `MessageBase` 类
|
||||||
|
- ✅ 全部使用 `SegPayload`(TypedDict)
|
||||||
|
- ✅ 全部使用 `MessageEnvelope`(TypedDict)
|
||||||
|
|
||||||
|
2. **BaseAdapter 生命周期**
|
||||||
|
- `__init__()` 中初始化同步资源
|
||||||
|
- `start()` 中执行异步初始化(WebSocket连接自动建立)
|
||||||
|
- `stop()` 中清理资源(WebSocket自动断开)
|
||||||
|
|
||||||
|
3. **WebSocketAdapterOptions 自动管理**
|
||||||
|
- 无需手动管理 WebSocket 连接
|
||||||
|
- incoming_parser 自动解析接收数据
|
||||||
|
- outgoing_encoder 自动编码发送数据
|
||||||
|
- 重连机制由基类处理
|
||||||
|
|
||||||
|
4. **CoreSink 依赖注入**
|
||||||
|
- 必须在插件加载后调用 `set_core_sink(sink)`
|
||||||
|
- 通过 `get_core_sink()` 全局访问
|
||||||
|
- 用于将消息递送到核心系统
|
||||||
|
|
||||||
|
5. **类型安全与灵活性平衡**
|
||||||
|
- TypedDict 在编译时提供类型检查
|
||||||
|
- 运行时仍是普通 dict,可灵活操作
|
||||||
|
- 必要时使用 `type: ignore` 抑制误报
|
||||||
|
|
||||||
|
6. **参考旧版但不照搬**
|
||||||
|
- 旧版逻辑流程可参考
|
||||||
|
- 数据结构需完全重写
|
||||||
|
- API调用模式已改变(响应池)
|
||||||
|
|
||||||
|
## 📊 预估工作量
|
||||||
|
|
||||||
|
- ✅ 核心架构: **已完成** (BaseAdapter + Handlers 骨架)
|
||||||
|
- ⏳ 消息处理完善: **4-6 小时** (所有消息类型 + 工具函数)
|
||||||
|
- ⏳ 发送处理完善: **3-4 小时** (所有 Seg 类型 + 命令)
|
||||||
|
- ⏳ 通知事件处理: **2-3 小时** (poke/emoji_like/recall/ban)
|
||||||
|
- ⏳ 测试调试: **2-4 小时** (全流程测试)
|
||||||
|
- **总剩余时间: 11-17 小时**
|
||||||
|
|
||||||
|
## ✅ 完成标准
|
||||||
|
|
||||||
|
当以下条件全部满足时,重写完成:
|
||||||
|
|
||||||
|
1. ✅ BaseAdapter 架构实现完成
|
||||||
|
2. ⏳ 所有 OneBot 11 消息类型支持
|
||||||
|
3. ⏳ 所有发送消息段类型支持
|
||||||
|
4. ⏳ 所有通知事件正确处理
|
||||||
|
5. ⏳ 权限系统集成完成
|
||||||
|
6. ⏳ 与旧版功能完全对等
|
||||||
|
7. ⏳ 所有测试用例通过
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**最后更新**: 2025-11-23
|
||||||
|
**架构状态**: ✅ 核心架构完成
|
||||||
|
**实现状态**: ⏳ 消息处理部分完成,需完善细节
|
||||||
|
**预计完成**: 根据优先级,核心功能预计 1-2 个工作日
|
||||||
16
src/plugins/built_in/NEW_napcat_adapter/__init__.py
Normal file
16
src/plugins/built_in/NEW_napcat_adapter/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from src.plugin_system.base.plugin_metadata import PluginMetadata
|
||||||
|
|
||||||
|
__plugin_meta__ = PluginMetadata(
|
||||||
|
name="napcat_plugin",
|
||||||
|
description="基于OneBot 11协议的NapCat QQ协议插件,提供完整的QQ机器人API接口,使用现有adapter连接",
|
||||||
|
usage="该插件提供 `napcat_tool` tool。",
|
||||||
|
version="1.0.0",
|
||||||
|
author="Windpicker_owo",
|
||||||
|
license="GPL-v3.0-or-later",
|
||||||
|
repository_url="https://github.com/Windpicker-owo",
|
||||||
|
keywords=["qq", "bot", "napcat", "onebot", "api", "websocket"],
|
||||||
|
categories=["protocol"],
|
||||||
|
extra={
|
||||||
|
"is_built_in": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
330
src/plugins/built_in/NEW_napcat_adapter/plugin.py
Normal file
330
src/plugins/built_in/NEW_napcat_adapter/plugin.py
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
"""
|
||||||
|
Napcat 适配器(基于 MoFox-Bus 完全重写版)
|
||||||
|
|
||||||
|
核心流程:
|
||||||
|
1. Napcat WebSocket 连接 → 接收 OneBot 格式消息
|
||||||
|
2. from_platform_message: OneBot dict → MessageEnvelope
|
||||||
|
3. CoreSink → 推送到 MoFox-Bot 核心
|
||||||
|
4. 核心回复 → _send_platform_message: MessageEnvelope → OneBot API 调用
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from typing import Any, ClassVar, Dict, List, Optional
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from mofox_bus import CoreMessageSink, MessageEnvelope, WebSocketAdapterOptions
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.plugin_system import register_plugin
|
||||||
|
from src.plugin_system.base import BaseAdapter, BasePlugin
|
||||||
|
from src.plugin_system.apis import config_api
|
||||||
|
|
||||||
|
from .src.handlers.to_core.message_handler import MessageHandler
|
||||||
|
from .src.handlers.to_core.notice_handler import NoticeHandler
|
||||||
|
from .src.handlers.to_core.meta_event_handler import MetaEventHandler
|
||||||
|
from .src.handlers.to_napcat.send_handler import SendHandler
|
||||||
|
|
||||||
|
logger = get_logger("napcat_adapter")
|
||||||
|
|
||||||
|
|
||||||
|
class NapcatAdapter(BaseAdapter):
|
||||||
|
"""Napcat 适配器 - 完全基于 mofox-bus 架构"""
|
||||||
|
|
||||||
|
adapter_name = "napcat_adapter"
|
||||||
|
adapter_version = "2.0.0"
|
||||||
|
adapter_author = "MoFox Team"
|
||||||
|
adapter_description = "基于 MoFox-Bus 的 Napcat/OneBot 11 适配器"
|
||||||
|
platform = "qq"
|
||||||
|
|
||||||
|
run_in_subprocess = False
|
||||||
|
subprocess_entry = None
|
||||||
|
|
||||||
|
def __init__(self, core_sink: CoreMessageSink, plugin: Optional[BasePlugin] = None):
|
||||||
|
"""初始化 Napcat 适配器"""
|
||||||
|
# 从插件配置读取 WebSocket URL
|
||||||
|
if plugin:
|
||||||
|
mode = config_api.get_plugin_config(plugin.config, "napcat_server.mode", "reverse")
|
||||||
|
host = config_api.get_plugin_config(plugin.config, "napcat_server.host", "localhost")
|
||||||
|
port = config_api.get_plugin_config(plugin.config, "napcat_server.port", 8095)
|
||||||
|
url = config_api.get_plugin_config(plugin.config, "napcat_server.url", "")
|
||||||
|
access_token = config_api.get_plugin_config(plugin.config, "napcat_server.access_token", "")
|
||||||
|
|
||||||
|
if mode == "forward" and url:
|
||||||
|
ws_url = url
|
||||||
|
else:
|
||||||
|
ws_url = f"ws://{host}:{port}"
|
||||||
|
|
||||||
|
headers = {}
|
||||||
|
if access_token:
|
||||||
|
headers["Authorization"] = f"Bearer {access_token}"
|
||||||
|
else:
|
||||||
|
ws_url = "ws://127.0.0.1:8095"
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
# 配置 WebSocket 传输
|
||||||
|
transport = WebSocketAdapterOptions(
|
||||||
|
url=ws_url,
|
||||||
|
headers=headers if headers else None,
|
||||||
|
incoming_parser=self._parse_napcat_message,
|
||||||
|
outgoing_encoder=self._encode_napcat_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(core_sink, plugin=plugin, transport=transport)
|
||||||
|
|
||||||
|
# 初始化处理器
|
||||||
|
self.message_handler = MessageHandler(self)
|
||||||
|
self.notice_handler = NoticeHandler(self)
|
||||||
|
self.meta_event_handler = MetaEventHandler(self)
|
||||||
|
self.send_handler = SendHandler(self)
|
||||||
|
|
||||||
|
# 响应池:用于存储等待的 API 响应
|
||||||
|
self._response_pool: Dict[str, asyncio.Future] = {}
|
||||||
|
self._response_timeout = 30.0
|
||||||
|
|
||||||
|
# WebSocket 连接(用于发送 API 请求)
|
||||||
|
# 注意:_ws 继承自 BaseAdapter,是 WebSocketLike 协议类型
|
||||||
|
self._napcat_ws = None # 可选的额外连接引用
|
||||||
|
|
||||||
|
async def on_adapter_loaded(self) -> None:
|
||||||
|
"""适配器加载时的初始化"""
|
||||||
|
logger.info("Napcat 适配器正在启动...")
|
||||||
|
|
||||||
|
# 设置处理器配置
|
||||||
|
if self.plugin:
|
||||||
|
self.message_handler.set_plugin_config(self.plugin.config)
|
||||||
|
self.notice_handler.set_plugin_config(self.plugin.config)
|
||||||
|
self.meta_event_handler.set_plugin_config(self.plugin.config)
|
||||||
|
self.send_handler.set_plugin_config(self.plugin.config)
|
||||||
|
|
||||||
|
logger.info("Napcat 适配器已加载")
|
||||||
|
|
||||||
|
async def on_adapter_unloaded(self) -> None:
|
||||||
|
"""适配器卸载时的清理"""
|
||||||
|
logger.info("Napcat 适配器正在关闭...")
|
||||||
|
|
||||||
|
# 清理响应池
|
||||||
|
for future in self._response_pool.values():
|
||||||
|
if not future.done():
|
||||||
|
future.cancel()
|
||||||
|
self._response_pool.clear()
|
||||||
|
|
||||||
|
logger.info("Napcat 适配器已关闭")
|
||||||
|
|
||||||
|
def _parse_napcat_message(self, raw: str | bytes) -> Any:
|
||||||
|
"""解析 Napcat/OneBot 消息"""
|
||||||
|
try:
|
||||||
|
if isinstance(raw, bytes):
|
||||||
|
data = orjson.loads(raw)
|
||||||
|
else:
|
||||||
|
data = orjson.loads(raw)
|
||||||
|
return data
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析 Napcat 消息失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _encode_napcat_response(self, envelope: MessageEnvelope) -> bytes:
|
||||||
|
"""编码响应消息为 Napcat 格式(暂未使用,通过 API 调用发送)"""
|
||||||
|
return orjson.dumps(envelope)
|
||||||
|
|
||||||
|
async def from_platform_message(self, raw: Dict[str, Any]) -> MessageEnvelope: # type: ignore[override]
|
||||||
|
"""
|
||||||
|
将 Napcat/OneBot 原始消息转换为 MessageEnvelope
|
||||||
|
|
||||||
|
这是核心转换方法,处理:
|
||||||
|
- message 事件 → 消息
|
||||||
|
- notice 事件 → 通知(戳一戳、表情回复等)
|
||||||
|
- meta_event 事件 → 元事件(心跳、生命周期)
|
||||||
|
- API 响应 → 存入响应池
|
||||||
|
"""
|
||||||
|
post_type = raw.get("post_type")
|
||||||
|
|
||||||
|
# API 响应(没有 post_type,有 echo)
|
||||||
|
if post_type is None and "echo" in raw:
|
||||||
|
echo = raw.get("echo")
|
||||||
|
if echo and echo in self._response_pool:
|
||||||
|
future = self._response_pool[echo]
|
||||||
|
if not future.done():
|
||||||
|
future.set_result(raw)
|
||||||
|
# API 响应不需要转换为 MessageEnvelope,返回空信封
|
||||||
|
return self._create_empty_envelope()
|
||||||
|
|
||||||
|
# 消息事件
|
||||||
|
if post_type == "message":
|
||||||
|
return await self.message_handler.handle_raw_message(raw) # type: ignore[return-value]
|
||||||
|
|
||||||
|
# 通知事件
|
||||||
|
elif post_type == "notice":
|
||||||
|
return await self.notice_handler.handle_notice(raw) # type: ignore[return-value]
|
||||||
|
|
||||||
|
# 元事件
|
||||||
|
elif post_type == "meta_event":
|
||||||
|
return await self.meta_event_handler.handle_meta_event(raw) # type: ignore[return-value]
|
||||||
|
|
||||||
|
# 未知事件类型
|
||||||
|
else:
|
||||||
|
logger.warning(f"未知的事件类型: {post_type}")
|
||||||
|
return self._create_empty_envelope() # type: ignore[return-value]
|
||||||
|
|
||||||
|
async def _send_platform_message(self, envelope: MessageEnvelope) -> None: # type: ignore[override]
|
||||||
|
"""
|
||||||
|
将 MessageEnvelope 转换并发送到 Napcat
|
||||||
|
|
||||||
|
这里不直接通过 WebSocket 发送 envelope,
|
||||||
|
而是调用 Napcat API(send_group_msg, send_private_msg 等)
|
||||||
|
"""
|
||||||
|
await self.send_handler.handle_message(envelope)
|
||||||
|
|
||||||
|
def _create_empty_envelope(self) -> MessageEnvelope: # type: ignore[return]
|
||||||
|
"""创建一个空的消息信封(用于不需要处理的事件)"""
|
||||||
|
import time
|
||||||
|
return {
|
||||||
|
"direction": "incoming",
|
||||||
|
"message_info": {
|
||||||
|
"platform": self.platform,
|
||||||
|
"message_id": str(uuid.uuid4()),
|
||||||
|
"time": time.time(),
|
||||||
|
},
|
||||||
|
"message_segment": {"type": "text", "data": "[系统事件]"},
|
||||||
|
"timestamp_ms": int(time.time() * 1000),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def send_napcat_api(self, action: str, params: Dict[str, Any], timeout: float = 30.0) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
发送 Napcat API 请求并等待响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: API 动作名称(如 send_group_msg)
|
||||||
|
params: API 参数
|
||||||
|
timeout: 超时时间(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API 响应数据
|
||||||
|
"""
|
||||||
|
if not self._ws:
|
||||||
|
raise RuntimeError("WebSocket 连接未建立")
|
||||||
|
|
||||||
|
# 生成唯一的 echo ID
|
||||||
|
echo = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# 创建 Future 用于等待响应
|
||||||
|
future = asyncio.Future()
|
||||||
|
self._response_pool[echo] = future
|
||||||
|
|
||||||
|
# 构造请求
|
||||||
|
request = orjson.dumps({
|
||||||
|
"action": action,
|
||||||
|
"params": params,
|
||||||
|
"echo": echo,
|
||||||
|
})
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 发送请求
|
||||||
|
await self._ws.send(request)
|
||||||
|
|
||||||
|
# 等待响应
|
||||||
|
response = await asyncio.wait_for(future, timeout=timeout)
|
||||||
|
return response
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"API 请求超时: {action}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"API 请求失败: {action}, 错误: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# 清理响应池
|
||||||
|
self._response_pool.pop(echo, None)
|
||||||
|
|
||||||
|
def get_ws_connection(self):
|
||||||
|
"""获取 WebSocket 连接(用于发送 API 请求)"""
|
||||||
|
if not self._ws:
|
||||||
|
raise RuntimeError("WebSocket 连接未建立")
|
||||||
|
return self._ws
|
||||||
|
|
||||||
|
|
||||||
|
@register_plugin
|
||||||
|
class NapcatAdapterPlugin(BasePlugin):
|
||||||
|
"""Napcat 适配器插件"""
|
||||||
|
|
||||||
|
plugin_name = "napcat_adapter_plugin"
|
||||||
|
enable_plugin = True
|
||||||
|
plugin_version = "2.0.0"
|
||||||
|
plugin_author = "MoFox Team"
|
||||||
|
plugin_description = "Napcat/OneBot 11 适配器(基于 MoFox-Bus 重写)"
|
||||||
|
|
||||||
|
# 配置 Schema
|
||||||
|
config_schema: ClassVar[dict] = {
|
||||||
|
"plugin": {
|
||||||
|
"name": {"type": str, "default": "napcat_adapter_plugin"},
|
||||||
|
"version": {"type": str, "default": "2.0.0"},
|
||||||
|
"enabled": {"type": bool, "default": True},
|
||||||
|
},
|
||||||
|
"napcat_server": {
|
||||||
|
"mode": {
|
||||||
|
"type": str,
|
||||||
|
"default": "reverse",
|
||||||
|
"description": "连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)",
|
||||||
|
},
|
||||||
|
"host": {"type": str, "default": "localhost"},
|
||||||
|
"port": {"type": int, "default": 8095},
|
||||||
|
"url": {"type": str, "default": "", "description": "正向连接时的完整URL"},
|
||||||
|
"access_token": {"type": str, "default": ""},
|
||||||
|
},
|
||||||
|
"features": {
|
||||||
|
"group_list_type": {"type": str, "default": "blacklist"},
|
||||||
|
"group_list": {"type": list, "default": []},
|
||||||
|
"private_list_type": {"type": str, "default": "blacklist"},
|
||||||
|
"private_list": {"type": list, "default": []},
|
||||||
|
"ban_user_id": {"type": list, "default": []},
|
||||||
|
"ban_qq_bot": {"type": bool, "default": False},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, plugin_dir: str = "", metadata: Any = None):
|
||||||
|
# 如果没有提供参数,创建一个默认的元数据
|
||||||
|
if metadata is None:
|
||||||
|
from src.plugin_system.base.plugin_metadata import PluginMetadata
|
||||||
|
metadata = PluginMetadata(
|
||||||
|
name=self.plugin_name,
|
||||||
|
version=self.plugin_version,
|
||||||
|
author=self.plugin_author,
|
||||||
|
description=self.plugin_description,
|
||||||
|
usage="",
|
||||||
|
dependencies=[],
|
||||||
|
python_dependencies=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
if not plugin_dir:
|
||||||
|
from pathlib import Path
|
||||||
|
plugin_dir = str(Path(__file__).parent)
|
||||||
|
|
||||||
|
super().__init__(plugin_dir, metadata)
|
||||||
|
self._adapter: Optional[NapcatAdapter] = None
|
||||||
|
|
||||||
|
async def on_plugin_loaded(self):
|
||||||
|
"""插件加载时启动适配器"""
|
||||||
|
logger.info("Napcat 适配器插件正在加载...")
|
||||||
|
|
||||||
|
# 获取核心 Sink
|
||||||
|
from src.common.core_sink import get_core_sink
|
||||||
|
core_sink = get_core_sink()
|
||||||
|
|
||||||
|
# 创建并启动适配器
|
||||||
|
self._adapter = NapcatAdapter(core_sink, plugin=self)
|
||||||
|
await self._adapter.start()
|
||||||
|
|
||||||
|
logger.info("Napcat 适配器插件已加载")
|
||||||
|
|
||||||
|
async def on_plugin_unloaded(self):
|
||||||
|
"""插件卸载时停止适配器"""
|
||||||
|
if self._adapter:
|
||||||
|
await self._adapter.stop()
|
||||||
|
logger.info("Napcat 适配器插件已卸载")
|
||||||
|
|
||||||
|
def get_plugin_components(self) -> list:
|
||||||
|
"""返回适配器组件"""
|
||||||
|
return [(NapcatAdapter.get_adapter_info(), NapcatAdapter)]
|
||||||
1
src/plugins/built_in/NEW_napcat_adapter/src/__init__.py
Normal file
1
src/plugins/built_in/NEW_napcat_adapter/src/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""工具模块"""
|
||||||
310
src/plugins/built_in/NEW_napcat_adapter/src/event_models.py
Normal file
310
src/plugins/built_in/NEW_napcat_adapter/src/event_models.py
Normal file
@@ -0,0 +1,310 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class MetaEventType:
|
||||||
|
lifecycle = "lifecycle" # 生命周期
|
||||||
|
|
||||||
|
class Lifecycle:
|
||||||
|
connect = "connect" # 生命周期 - WebSocket 连接成功
|
||||||
|
|
||||||
|
heartbeat = "heartbeat" # 心跳
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType: # 接受消息大类
|
||||||
|
private = "private" # 私聊消息
|
||||||
|
|
||||||
|
class Private:
|
||||||
|
friend = "friend" # 私聊消息 - 好友
|
||||||
|
group = "group" # 私聊消息 - 群临时
|
||||||
|
group_self = "group_self" # 私聊消息 - 群中自身发送
|
||||||
|
other = "other" # 私聊消息 - 其他
|
||||||
|
|
||||||
|
group = "group" # 群聊消息
|
||||||
|
|
||||||
|
class Group:
|
||||||
|
normal = "normal" # 群聊消息 - 普通
|
||||||
|
anonymous = "anonymous" # 群聊消息 - 匿名消息
|
||||||
|
notice = "notice" # 群聊消息 - 系统提示
|
||||||
|
|
||||||
|
|
||||||
|
class NoticeType: # 通知事件
|
||||||
|
friend_recall = "friend_recall" # 私聊消息撤回
|
||||||
|
group_recall = "group_recall" # 群聊消息撤回
|
||||||
|
notify = "notify"
|
||||||
|
group_ban = "group_ban" # 群禁言
|
||||||
|
group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复
|
||||||
|
group_upload = "group_upload" # 群文件上传
|
||||||
|
|
||||||
|
class Notify:
|
||||||
|
poke = "poke" # 戳一戳
|
||||||
|
input_status = "input_status" # 正在输入
|
||||||
|
|
||||||
|
class GroupBan:
|
||||||
|
ban = "ban" # 禁言
|
||||||
|
lift_ban = "lift_ban" # 解除禁言
|
||||||
|
|
||||||
|
|
||||||
|
class RealMessageType: # 实际消息分类
|
||||||
|
text = "text" # 纯文本
|
||||||
|
face = "face" # qq表情
|
||||||
|
image = "image" # 图片
|
||||||
|
record = "record" # 语音
|
||||||
|
video = "video" # 视频
|
||||||
|
at = "at" # @某人
|
||||||
|
rps = "rps" # 猜拳魔法表情
|
||||||
|
dice = "dice" # 骰子
|
||||||
|
shake = "shake" # 私聊窗口抖动(只收)
|
||||||
|
poke = "poke" # 群聊戳一戳
|
||||||
|
share = "share" # 链接分享(json形式)
|
||||||
|
reply = "reply" # 回复消息
|
||||||
|
forward = "forward" # 转发消息
|
||||||
|
node = "node" # 转发消息节点
|
||||||
|
json = "json" # json消息
|
||||||
|
file = "file" # 文件
|
||||||
|
|
||||||
|
|
||||||
|
class MessageSentType:
|
||||||
|
private = "private"
|
||||||
|
|
||||||
|
class Private:
|
||||||
|
friend = "friend"
|
||||||
|
group = "group"
|
||||||
|
|
||||||
|
group = "group"
|
||||||
|
|
||||||
|
class Group:
|
||||||
|
normal = "normal"
|
||||||
|
|
||||||
|
|
||||||
|
class CommandType(Enum):
|
||||||
|
"""命令类型"""
|
||||||
|
|
||||||
|
GROUP_BAN = "set_group_ban" # 禁言用户
|
||||||
|
GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
|
||||||
|
GROUP_KICK = "set_group_kick" # 踢出群聊
|
||||||
|
SEND_POKE = "send_poke" # 戳一戳
|
||||||
|
DELETE_MSG = "delete_msg" # 撤回消息
|
||||||
|
AI_VOICE_SEND = "ai_voice_send" # AI语音发送
|
||||||
|
SET_EMOJI_LIKE = "set_msg_emoji_like" # 设置表情回应
|
||||||
|
SEND_AT_MESSAGE = "send_at_message" # 发送@消息
|
||||||
|
SEND_LIKE = "send_like" # 发送点赞
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
# 支持的消息格式
|
||||||
|
ACCEPT_FORMAT = [
|
||||||
|
"text",
|
||||||
|
"image",
|
||||||
|
"emoji",
|
||||||
|
"reply",
|
||||||
|
"voice",
|
||||||
|
"command",
|
||||||
|
"voiceurl",
|
||||||
|
"music",
|
||||||
|
"videourl",
|
||||||
|
"file",
|
||||||
|
]
|
||||||
|
|
||||||
|
# 插件名称
|
||||||
|
PLUGIN_NAME = "NEW_napcat_adapter"
|
||||||
|
|
||||||
|
# QQ表情映射表
|
||||||
|
QQ_FACE = {
|
||||||
|
"0": "[表情:惊讶]",
|
||||||
|
"1": "[表情:撇嘴]",
|
||||||
|
"2": "[表情:色]",
|
||||||
|
"3": "[表情:发呆]",
|
||||||
|
"4": "[表情:得意]",
|
||||||
|
"5": "[表情:流泪]",
|
||||||
|
"6": "[表情:害羞]",
|
||||||
|
"7": "[表情:闭嘴]",
|
||||||
|
"8": "[表情:睡]",
|
||||||
|
"9": "[表情:大哭]",
|
||||||
|
"10": "[表情:尴尬]",
|
||||||
|
"11": "[表情:发怒]",
|
||||||
|
"12": "[表情:调皮]",
|
||||||
|
"13": "[表情:呲牙]",
|
||||||
|
"14": "[表情:微笑]",
|
||||||
|
"15": "[表情:难过]",
|
||||||
|
"16": "[表情:酷]",
|
||||||
|
"18": "[表情:抓狂]",
|
||||||
|
"19": "[表情:吐]",
|
||||||
|
"20": "[表情:偷笑]",
|
||||||
|
"21": "[表情:可爱]",
|
||||||
|
"22": "[表情:白眼]",
|
||||||
|
"23": "[表情:傲慢]",
|
||||||
|
"24": "[表情:饥饿]",
|
||||||
|
"25": "[表情:困]",
|
||||||
|
"26": "[表情:惊恐]",
|
||||||
|
"27": "[表情:流汗]",
|
||||||
|
"28": "[表情:憨笑]",
|
||||||
|
"29": "[表情:悠闲]",
|
||||||
|
"30": "[表情:奋斗]",
|
||||||
|
"31": "[表情:咒骂]",
|
||||||
|
"32": "[表情:疑问]",
|
||||||
|
"33": "[表情:嘘]",
|
||||||
|
"34": "[表情:晕]",
|
||||||
|
"35": "[表情:折磨]",
|
||||||
|
"36": "[表情:衰]",
|
||||||
|
"37": "[表情:骷髅]",
|
||||||
|
"38": "[表情:敲打]",
|
||||||
|
"39": "[表情:再见]",
|
||||||
|
"41": "[表情:发抖]",
|
||||||
|
"42": "[表情:爱情]",
|
||||||
|
"43": "[表情:跳跳]",
|
||||||
|
"46": "[表情:猪头]",
|
||||||
|
"49": "[表情:拥抱]",
|
||||||
|
"53": "[表情:蛋糕]",
|
||||||
|
"56": "[表情:刀]",
|
||||||
|
"59": "[表情:便便]",
|
||||||
|
"60": "[表情:咖啡]",
|
||||||
|
"63": "[表情:玫瑰]",
|
||||||
|
"64": "[表情:凋谢]",
|
||||||
|
"66": "[表情:爱心]",
|
||||||
|
"67": "[表情:心碎]",
|
||||||
|
"74": "[表情:太阳]",
|
||||||
|
"75": "[表情:月亮]",
|
||||||
|
"76": "[表情:赞]",
|
||||||
|
"77": "[表情:踩]",
|
||||||
|
"78": "[表情:握手]",
|
||||||
|
"79": "[表情:胜利]",
|
||||||
|
"85": "[表情:飞吻]",
|
||||||
|
"86": "[表情:怄火]",
|
||||||
|
"89": "[表情:西瓜]",
|
||||||
|
"96": "[表情:冷汗]",
|
||||||
|
"97": "[表情:擦汗]",
|
||||||
|
"98": "[表情:抠鼻]",
|
||||||
|
"99": "[表情:鼓掌]",
|
||||||
|
"100": "[表情:糗大了]",
|
||||||
|
"101": "[表情:坏笑]",
|
||||||
|
"102": "[表情:左哼哼]",
|
||||||
|
"103": "[表情:右哼哼]",
|
||||||
|
"104": "[表情:哈欠]",
|
||||||
|
"105": "[表情:鄙视]",
|
||||||
|
"106": "[表情:委屈]",
|
||||||
|
"107": "[表情:快哭了]",
|
||||||
|
"108": "[表情:阴险]",
|
||||||
|
"109": "[表情:左亲亲]",
|
||||||
|
"110": "[表情:吓]",
|
||||||
|
"111": "[表情:可怜]",
|
||||||
|
"112": "[表情:菜刀]",
|
||||||
|
"114": "[表情:篮球]",
|
||||||
|
"116": "[表情:示爱]",
|
||||||
|
"118": "[表情:抱拳]",
|
||||||
|
"119": "[表情:勾引]",
|
||||||
|
"120": "[表情:拳头]",
|
||||||
|
"121": "[表情:差劲]",
|
||||||
|
"123": "[表情:NO]",
|
||||||
|
"124": "[表情:OK]",
|
||||||
|
"125": "[表情:转圈]",
|
||||||
|
"129": "[表情:挥手]",
|
||||||
|
"137": "[表情:鞭炮]",
|
||||||
|
"144": "[表情:喝彩]",
|
||||||
|
"146": "[表情:爆筋]",
|
||||||
|
"147": "[表情:棒棒糖]",
|
||||||
|
"169": "[表情:手枪]",
|
||||||
|
"171": "[表情:茶]",
|
||||||
|
"172": "[表情:眨眼睛]",
|
||||||
|
"173": "[表情:泪奔]",
|
||||||
|
"174": "[表情:无奈]",
|
||||||
|
"175": "[表情:卖萌]",
|
||||||
|
"176": "[表情:小纠结]",
|
||||||
|
"177": "[表情:喷血]",
|
||||||
|
"178": "[表情:斜眼笑]",
|
||||||
|
"179": "[表情:doge]",
|
||||||
|
"181": "[表情:戳一戳]",
|
||||||
|
"182": "[表情:笑哭]",
|
||||||
|
"183": "[表情:我最美]",
|
||||||
|
"185": "[表情:羊驼]",
|
||||||
|
"187": "[表情:幽灵]",
|
||||||
|
"201": "[表情:点赞]",
|
||||||
|
"212": "[表情:托腮]",
|
||||||
|
"262": "[表情:脑阔疼]",
|
||||||
|
"263": "[表情:沧桑]",
|
||||||
|
"264": "[表情:捂脸]",
|
||||||
|
"265": "[表情:辣眼睛]",
|
||||||
|
"266": "[表情:哦哟]",
|
||||||
|
"267": "[表情:头秃]",
|
||||||
|
"268": "[表情:问号脸]",
|
||||||
|
"269": "[表情:暗中观察]",
|
||||||
|
"270": "[表情:emm]",
|
||||||
|
"271": "[表情:吃瓜]",
|
||||||
|
"272": "[表情:呵呵哒]",
|
||||||
|
"273": "[表情:我酸了]",
|
||||||
|
"277": "[表情:滑稽狗头]",
|
||||||
|
"281": "[表情:翻白眼]",
|
||||||
|
"282": "[表情:敬礼]",
|
||||||
|
"283": "[表情:狂笑]",
|
||||||
|
"284": "[表情:面无表情]",
|
||||||
|
"285": "[表情:摸鱼]",
|
||||||
|
"286": "[表情:魔鬼笑]",
|
||||||
|
"287": "[表情:哦]",
|
||||||
|
"289": "[表情:睁眼]",
|
||||||
|
"293": "[表情:摸锦鲤]",
|
||||||
|
"294": "[表情:期待]",
|
||||||
|
"295": "[表情:拿到红包]",
|
||||||
|
"297": "[表情:拜谢]",
|
||||||
|
"298": "[表情:元宝]",
|
||||||
|
"299": "[表情:牛啊]",
|
||||||
|
"300": "[表情:胖三斤]",
|
||||||
|
"302": "[表情:左拜年]",
|
||||||
|
"303": "[表情:右拜年]",
|
||||||
|
"305": "[表情:右亲亲]",
|
||||||
|
"306": "[表情:牛气冲天]",
|
||||||
|
"307": "[表情:喵喵]",
|
||||||
|
"311": "[表情:打call]",
|
||||||
|
"312": "[表情:变形]",
|
||||||
|
"314": "[表情:仔细分析]",
|
||||||
|
"317": "[表情:菜汪]",
|
||||||
|
"318": "[表情:崇拜]",
|
||||||
|
"319": "[表情:比心]",
|
||||||
|
"320": "[表情:庆祝]",
|
||||||
|
"323": "[表情:嫌弃]",
|
||||||
|
"324": "[表情:吃糖]",
|
||||||
|
"325": "[表情:惊吓]",
|
||||||
|
"326": "[表情:生气]",
|
||||||
|
"332": "[表情:举牌牌]",
|
||||||
|
"333": "[表情:烟花]",
|
||||||
|
"334": "[表情:虎虎生威]",
|
||||||
|
"336": "[表情:豹富]",
|
||||||
|
"337": "[表情:花朵脸]",
|
||||||
|
"338": "[表情:我想开了]",
|
||||||
|
"339": "[表情:舔屏]",
|
||||||
|
"341": "[表情:打招呼]",
|
||||||
|
"342": "[表情:酸Q]",
|
||||||
|
"343": "[表情:我方了]",
|
||||||
|
"344": "[表情:大怨种]",
|
||||||
|
"345": "[表情:红包多多]",
|
||||||
|
"346": "[表情:你真棒棒]",
|
||||||
|
"347": "[表情:大展宏兔]",
|
||||||
|
"349": "[表情:坚强]",
|
||||||
|
"350": "[表情:贴贴]",
|
||||||
|
"351": "[表情:敲敲]",
|
||||||
|
"352": "[表情:咦]",
|
||||||
|
"353": "[表情:拜托]",
|
||||||
|
"354": "[表情:尊嘟假嘟]",
|
||||||
|
"355": "[表情:耶]",
|
||||||
|
"356": "[表情:666]",
|
||||||
|
"357": "[表情:裂开]",
|
||||||
|
"392": "[表情:龙年快乐]",
|
||||||
|
"393": "[表情:新年中龙]",
|
||||||
|
"394": "[表情:新年大龙]",
|
||||||
|
"395": "[表情:略略略]",
|
||||||
|
"396": "[表情:龙年快乐]",
|
||||||
|
"424": "[表情:按钮]",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MetaEventType",
|
||||||
|
"MessageType",
|
||||||
|
"NoticeType",
|
||||||
|
"RealMessageType",
|
||||||
|
"MessageSentType",
|
||||||
|
"CommandType",
|
||||||
|
"ACCEPT_FORMAT",
|
||||||
|
"PLUGIN_NAME",
|
||||||
|
"QQ_FACE",
|
||||||
|
]
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""处理器模块"""
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""接收方向处理器"""
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
"""消息处理器 - 将 Napcat OneBot 消息转换为 MessageEnvelope"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.plugin_system.apis import config_api
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...plugin import NapcatAdapter
|
||||||
|
|
||||||
|
logger = get_logger("napcat_adapter.message_handler")
|
||||||
|
|
||||||
|
|
||||||
|
class MessageHandler:
|
||||||
|
"""处理来自 Napcat 的消息事件"""
|
||||||
|
|
||||||
|
def __init__(self, adapter: "NapcatAdapter"):
|
||||||
|
self.adapter = adapter
|
||||||
|
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||||
|
"""设置插件配置"""
|
||||||
|
self.plugin_config = config
|
||||||
|
|
||||||
|
async def handle_raw_message(self, raw: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
处理原始消息并转换为 MessageEnvelope
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw: OneBot 原始消息数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MessageEnvelope (dict)
|
||||||
|
"""
|
||||||
|
from mofox_bus import MessageEnvelope, SegPayload, MessageInfoPayload, UserInfoPayload, GroupInfoPayload
|
||||||
|
|
||||||
|
message_type = raw.get("message_type")
|
||||||
|
message_id = str(raw.get("message_id", ""))
|
||||||
|
message_time = time.time()
|
||||||
|
|
||||||
|
# 构造用户信息
|
||||||
|
sender_info = raw.get("sender", {})
|
||||||
|
user_info: UserInfoPayload = {
|
||||||
|
"platform": "qq",
|
||||||
|
"user_id": str(sender_info.get("user_id", "")),
|
||||||
|
"user_nickname": sender_info.get("nickname", ""),
|
||||||
|
"user_cardname": sender_info.get("card", ""),
|
||||||
|
"user_avatar": sender_info.get("avatar", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构造群组信息(如果是群消息)
|
||||||
|
group_info: Optional[GroupInfoPayload] = None
|
||||||
|
if message_type == "group":
|
||||||
|
group_id = raw.get("group_id")
|
||||||
|
if group_id:
|
||||||
|
group_info = {
|
||||||
|
"platform": "qq",
|
||||||
|
"group_id": str(group_id),
|
||||||
|
"group_name": "", # 可以通过 API 获取
|
||||||
|
}
|
||||||
|
|
||||||
|
# 解析消息段
|
||||||
|
message_segments = raw.get("message", [])
|
||||||
|
seg_list: List[SegPayload] = []
|
||||||
|
|
||||||
|
for seg in message_segments:
|
||||||
|
seg_type = seg.get("type", "")
|
||||||
|
seg_data = seg.get("data", {})
|
||||||
|
|
||||||
|
# 转换为 SegPayload
|
||||||
|
if seg_type == "text":
|
||||||
|
seg_list.append({
|
||||||
|
"type": "text",
|
||||||
|
"data": seg_data.get("text", "")
|
||||||
|
})
|
||||||
|
elif seg_type == "image":
|
||||||
|
# 这里需要下载图片并转换为 base64(简化版本)
|
||||||
|
seg_list.append({
|
||||||
|
"type": "image",
|
||||||
|
"data": seg_data.get("url", "") # 实际应该转换为 base64
|
||||||
|
})
|
||||||
|
elif seg_type == "at":
|
||||||
|
seg_list.append({
|
||||||
|
"type": "at",
|
||||||
|
"data": f"{seg_data.get('qq', '')}"
|
||||||
|
})
|
||||||
|
# 其他消息类型...
|
||||||
|
|
||||||
|
# 构造 MessageInfoPayload
|
||||||
|
message_info = {
|
||||||
|
"platform": "qq",
|
||||||
|
"message_id": message_id,
|
||||||
|
"time": message_time,
|
||||||
|
"user_info": user_info,
|
||||||
|
"format_info": {
|
||||||
|
"content_format": ["text", "image"], # 根据实际消息类型设置
|
||||||
|
"accept_format": ["text", "image", "emoji", "voice"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加群组信息(如果存在)
|
||||||
|
if group_info:
|
||||||
|
message_info["group_info"] = group_info
|
||||||
|
|
||||||
|
# 构造 MessageEnvelope
|
||||||
|
envelope = {
|
||||||
|
"direction": "incoming",
|
||||||
|
"message_info": message_info,
|
||||||
|
"message_segment": {"type": "seglist", "data": seg_list} if len(seg_list) > 1 else (seg_list[0] if seg_list else {"type": "text", "data": ""}),
|
||||||
|
"raw_message": raw.get("raw_message", ""),
|
||||||
|
"platform": "qq",
|
||||||
|
"message_id": message_id,
|
||||||
|
"timestamp_ms": int(message_time * 1000),
|
||||||
|
}
|
||||||
|
|
||||||
|
return envelope
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
"""元事件处理器"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...plugin import NapcatAdapter
|
||||||
|
|
||||||
|
logger = get_logger("napcat_adapter.meta_event_handler")
|
||||||
|
|
||||||
|
|
||||||
|
class MetaEventHandler:
|
||||||
|
"""处理 Napcat 元事件(心跳、生命周期)"""
|
||||||
|
|
||||||
|
def __init__(self, adapter: "NapcatAdapter"):
|
||||||
|
self.adapter = adapter
|
||||||
|
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||||
|
"""设置插件配置"""
|
||||||
|
self.plugin_config = config
|
||||||
|
|
||||||
|
async def handle_meta_event(self, raw: Dict[str, Any]):
|
||||||
|
"""处理元事件"""
|
||||||
|
# 简化版本:返回一个空的 MessageEnvelope
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
return {
|
||||||
|
"direction": "incoming",
|
||||||
|
"message_info": {
|
||||||
|
"platform": "qq",
|
||||||
|
"message_id": str(uuid.uuid4()),
|
||||||
|
"time": time.time(),
|
||||||
|
},
|
||||||
|
"message_segment": {"type": "text", "data": "[元事件]"},
|
||||||
|
"timestamp_ms": int(time.time() * 1000),
|
||||||
|
}
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
"""通知事件处理器"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...plugin import NapcatAdapter
|
||||||
|
|
||||||
|
logger = get_logger("napcat_adapter.notice_handler")
|
||||||
|
|
||||||
|
|
||||||
|
class NoticeHandler:
|
||||||
|
"""处理 Napcat 通知事件(戳一戳、表情回复等)"""
|
||||||
|
|
||||||
|
def __init__(self, adapter: "NapcatAdapter"):
|
||||||
|
self.adapter = adapter
|
||||||
|
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||||
|
"""设置插件配置"""
|
||||||
|
self.plugin_config = config
|
||||||
|
|
||||||
|
async def handle_notice(self, raw: Dict[str, Any]):
|
||||||
|
"""处理通知事件"""
|
||||||
|
# 简化版本:返回一个空的 MessageEnvelope
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
return {
|
||||||
|
"direction": "incoming",
|
||||||
|
"message_info": {
|
||||||
|
"platform": "qq",
|
||||||
|
"message_id": str(uuid.uuid4()),
|
||||||
|
"time": time.time(),
|
||||||
|
},
|
||||||
|
"message_segment": {"type": "text", "data": "[通知事件]"},
|
||||||
|
"timestamp_ms": int(time.time() * 1000),
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""发送方向处理器"""
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
"""发送处理器 - 将 MessageEnvelope 转换并发送到 Napcat"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...plugin import NapcatAdapter
|
||||||
|
|
||||||
|
logger = get_logger("napcat_adapter.send_handler")
|
||||||
|
|
||||||
|
|
||||||
|
class SendHandler:
|
||||||
|
"""处理向 Napcat 发送消息"""
|
||||||
|
|
||||||
|
def __init__(self, adapter: "NapcatAdapter"):
|
||||||
|
self.adapter = adapter
|
||||||
|
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||||
|
"""设置插件配置"""
|
||||||
|
self.plugin_config = config
|
||||||
|
|
||||||
|
async def handle_message(self, envelope) -> None:
|
||||||
|
"""
|
||||||
|
处理发送消息
|
||||||
|
|
||||||
|
将 MessageEnvelope 转换为 OneBot API 调用
|
||||||
|
"""
|
||||||
|
message_info = envelope.get("message_info", {})
|
||||||
|
message_segment = envelope.get("message_segment", {})
|
||||||
|
|
||||||
|
# 获取群组和用户信息
|
||||||
|
group_info = message_info.get("group_info")
|
||||||
|
user_info = message_info.get("user_info")
|
||||||
|
|
||||||
|
# 构造消息内容
|
||||||
|
message = self._convert_seg_to_onebot(message_segment)
|
||||||
|
|
||||||
|
# 发送消息
|
||||||
|
if group_info:
|
||||||
|
# 发送群消息
|
||||||
|
group_id = group_info.get("group_id")
|
||||||
|
if group_id:
|
||||||
|
await self.adapter.send_napcat_api("send_group_msg", {
|
||||||
|
"group_id": int(group_id),
|
||||||
|
"message": message,
|
||||||
|
})
|
||||||
|
elif user_info:
|
||||||
|
# 发送私聊消息
|
||||||
|
user_id = user_info.get("user_id")
|
||||||
|
if user_id:
|
||||||
|
await self.adapter.send_napcat_api("send_private_msg", {
|
||||||
|
"user_id": int(user_id),
|
||||||
|
"message": message,
|
||||||
|
})
|
||||||
|
|
||||||
|
def _convert_seg_to_onebot(self, seg: Dict[str, Any]) -> list:
|
||||||
|
"""将 SegPayload 转换为 OneBot 消息格式"""
|
||||||
|
seg_type = seg.get("type", "")
|
||||||
|
seg_data = seg.get("data", "")
|
||||||
|
|
||||||
|
if seg_type == "text":
|
||||||
|
return [{"type": "text", "data": {"text": seg_data}}]
|
||||||
|
elif seg_type == "image":
|
||||||
|
return [{"type": "image", "data": {"file": f"base64://{seg_data}"}}]
|
||||||
|
elif seg_type == "seglist":
|
||||||
|
# 递归处理列表
|
||||||
|
result = []
|
||||||
|
for sub_seg in seg_data:
|
||||||
|
result.extend(self._convert_seg_to_onebot(sub_seg))
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
# 默认作为文本
|
||||||
|
return [{"type": "text", "data": {"text": str(seg_data)}}]
|
||||||
350
src/plugins/built_in/NEW_napcat_adapter/stream_router.py
Normal file
350
src/plugins/built_in/NEW_napcat_adapter/stream_router.py
Normal file
@@ -0,0 +1,350 @@
|
|||||||
|
"""
|
||||||
|
按聊天流分配消费者的消息路由系统
|
||||||
|
|
||||||
|
核心思想:
|
||||||
|
- 为每个活跃的聊天流(stream_id)创建独立的消息队列和消费者协程
|
||||||
|
- 同一聊天流的消息由同一个 worker 处理,保证顺序性
|
||||||
|
- 不同聊天流的消息并发处理,提高吞吐量
|
||||||
|
- 动态管理流的生命周期,自动清理不活跃的流
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("stream_router")
|
||||||
|
|
||||||
|
|
||||||
|
class StreamConsumer:
|
||||||
|
"""单个聊天流的消息消费者
|
||||||
|
|
||||||
|
维护独立的消息队列和处理协程
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, stream_id: str, queue_maxsize: int = 100):
|
||||||
|
self.stream_id = stream_id
|
||||||
|
self.queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
|
||||||
|
self.worker_task: Optional[asyncio.Task] = None
|
||||||
|
self.last_active_time = time.time()
|
||||||
|
self.is_running = False
|
||||||
|
|
||||||
|
# 性能统计
|
||||||
|
self.stats = {
|
||||||
|
"total_messages": 0,
|
||||||
|
"total_processing_time": 0.0,
|
||||||
|
"queue_overflow_count": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""启动消费者"""
|
||||||
|
if not self.is_running:
|
||||||
|
self.is_running = True
|
||||||
|
self.worker_task = asyncio.create_task(self._process_loop())
|
||||||
|
logger.debug(f"Stream Consumer 启动: {self.stream_id}")
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""停止消费者"""
|
||||||
|
self.is_running = False
|
||||||
|
if self.worker_task:
|
||||||
|
self.worker_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.worker_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
logger.debug(f"Stream Consumer 停止: {self.stream_id}")
|
||||||
|
|
||||||
|
async def enqueue(self, message: dict) -> None:
|
||||||
|
"""将消息加入队列"""
|
||||||
|
self.last_active_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用 put_nowait 避免阻塞路由器
|
||||||
|
self.queue.put_nowait(message)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
self.stats["queue_overflow_count"] += 1
|
||||||
|
logger.warning(
|
||||||
|
f"Stream {self.stream_id} 队列已满 "
|
||||||
|
f"({self.queue.qsize()}/{self.queue.maxsize}),"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.queue.get_nowait()
|
||||||
|
self.queue.put_nowait(message)
|
||||||
|
logger.debug(f"Stream {self.stream_id} 丢弃最旧消息,添加新消息")
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _process_loop(self) -> None:
|
||||||
|
"""消息处理循环"""
|
||||||
|
# 延迟导入,避免循环依赖
|
||||||
|
from .recv_handler.message_handler import message_handler
|
||||||
|
from .recv_handler.meta_event_handler import meta_event_handler
|
||||||
|
from .recv_handler.notice_handler import notice_handler
|
||||||
|
|
||||||
|
logger.info(f"Stream {self.stream_id} 处理循环启动")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while self.is_running:
|
||||||
|
try:
|
||||||
|
# 等待消息,1秒超时
|
||||||
|
message = await asyncio.wait_for(
|
||||||
|
self.queue.get(),
|
||||||
|
timeout=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 处理消息
|
||||||
|
post_type = message.get("post_type")
|
||||||
|
if post_type == "message":
|
||||||
|
await message_handler.handle_raw_message(message)
|
||||||
|
elif post_type == "meta_event":
|
||||||
|
await meta_event_handler.handle_meta_event(message)
|
||||||
|
elif post_type == "notice":
|
||||||
|
await notice_handler.handle_notice(message)
|
||||||
|
else:
|
||||||
|
logger.warning(f"未知的 post_type: {post_type}")
|
||||||
|
|
||||||
|
processing_time = time.time() - start_time
|
||||||
|
|
||||||
|
# 更新统计
|
||||||
|
self.stats["total_messages"] += 1
|
||||||
|
self.stats["total_processing_time"] += processing_time
|
||||||
|
self.last_active_time = time.time()
|
||||||
|
self.queue.task_done()
|
||||||
|
|
||||||
|
# 性能监控(每100条消息输出一次)
|
||||||
|
if self.stats["total_messages"] % 100 == 0:
|
||||||
|
avg_time = self.stats["total_processing_time"] / self.stats["total_messages"]
|
||||||
|
logger.info(
|
||||||
|
f"Stream {self.stream_id[:30]}... 统计: "
|
||||||
|
f"消息数={self.stats['total_messages']}, "
|
||||||
|
f"平均耗时={avg_time:.3f}秒, "
|
||||||
|
f"队列长度={self.queue.qsize()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 动态延迟:队列空时短暂休眠
|
||||||
|
if self.queue.qsize() == 0:
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# 超时是正常的,继续循环
|
||||||
|
continue
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info(f"Stream {self.stream_id} 处理循环被取消")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Stream {self.stream_id} 处理消息时出错: {e}", exc_info=True)
|
||||||
|
# 继续处理下一条消息
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
logger.info(f"Stream {self.stream_id} 处理循环结束")
|
||||||
|
|
||||||
|
def get_stats(self) -> dict:
|
||||||
|
"""获取性能统计"""
|
||||||
|
avg_time = (
|
||||||
|
self.stats["total_processing_time"] / self.stats["total_messages"]
|
||||||
|
if self.stats["total_messages"] > 0
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"stream_id": self.stream_id,
|
||||||
|
"queue_size": self.queue.qsize(),
|
||||||
|
"total_messages": self.stats["total_messages"],
|
||||||
|
"avg_processing_time": avg_time,
|
||||||
|
"queue_overflow_count": self.stats["queue_overflow_count"],
|
||||||
|
"last_active_time": self.last_active_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class StreamRouter:
|
||||||
|
"""流路由器
|
||||||
|
|
||||||
|
负责将消息路由到对应的聊天流队列
|
||||||
|
动态管理聊天流的生命周期
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_streams: int = 500,
|
||||||
|
stream_timeout: int = 600,
|
||||||
|
stream_queue_size: int = 100,
|
||||||
|
cleanup_interval: int = 60,
|
||||||
|
):
|
||||||
|
self.streams: Dict[str, StreamConsumer] = {}
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
self.max_streams = max_streams
|
||||||
|
self.stream_timeout = stream_timeout
|
||||||
|
self.stream_queue_size = stream_queue_size
|
||||||
|
self.cleanup_interval = cleanup_interval
|
||||||
|
self.cleanup_task: Optional[asyncio.Task] = None
|
||||||
|
self.is_running = False
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""启动路由器"""
|
||||||
|
if not self.is_running:
|
||||||
|
self.is_running = True
|
||||||
|
self.cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||||
|
logger.info(
|
||||||
|
f"StreamRouter 已启动 - "
|
||||||
|
f"最大流数: {self.max_streams}, "
|
||||||
|
f"超时: {self.stream_timeout}秒, "
|
||||||
|
f"队列大小: {self.stream_queue_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""停止路由器"""
|
||||||
|
self.is_running = False
|
||||||
|
|
||||||
|
if self.cleanup_task:
|
||||||
|
self.cleanup_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.cleanup_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 停止所有流消费者
|
||||||
|
logger.info(f"正在停止 {len(self.streams)} 个流消费者...")
|
||||||
|
for consumer in self.streams.values():
|
||||||
|
await consumer.stop()
|
||||||
|
|
||||||
|
self.streams.clear()
|
||||||
|
logger.info("StreamRouter 已停止")
|
||||||
|
|
||||||
|
async def route_message(self, message: dict) -> None:
|
||||||
|
"""路由消息到对应的流"""
|
||||||
|
stream_id = self._extract_stream_id(message)
|
||||||
|
|
||||||
|
# 快速路径:流已存在
|
||||||
|
if stream_id in self.streams:
|
||||||
|
await self.streams[stream_id].enqueue(message)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 慢路径:需要创建新流
|
||||||
|
async with self.lock:
|
||||||
|
# 双重检查
|
||||||
|
if stream_id not in self.streams:
|
||||||
|
# 检查流数量限制
|
||||||
|
if len(self.streams) >= self.max_streams:
|
||||||
|
logger.warning(
|
||||||
|
f"达到最大流数量限制 ({self.max_streams}),"
|
||||||
|
f"尝试清理不活跃的流..."
|
||||||
|
)
|
||||||
|
await self._cleanup_inactive_streams()
|
||||||
|
|
||||||
|
# 清理后仍然超限,记录警告但继续创建
|
||||||
|
if len(self.streams) >= self.max_streams:
|
||||||
|
logger.error(
|
||||||
|
f"清理后仍达到最大流数量 ({len(self.streams)}/{self.max_streams})!"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建新流
|
||||||
|
consumer = StreamConsumer(stream_id, self.stream_queue_size)
|
||||||
|
self.streams[stream_id] = consumer
|
||||||
|
await consumer.start()
|
||||||
|
logger.info(f"创建新的 Stream Consumer: {stream_id} (总流数: {len(self.streams)})")
|
||||||
|
|
||||||
|
await self.streams[stream_id].enqueue(message)
|
||||||
|
|
||||||
|
def _extract_stream_id(self, message: dict) -> str:
|
||||||
|
"""从消息中提取 stream_id
|
||||||
|
|
||||||
|
返回格式: platform:id:type
|
||||||
|
例如: qq:123456:group 或 qq:789012:private
|
||||||
|
"""
|
||||||
|
post_type = message.get("post_type")
|
||||||
|
|
||||||
|
# 非消息类型,使用默认流(避免创建过多流)
|
||||||
|
if post_type not in ["message", "notice"]:
|
||||||
|
return "system:meta_event"
|
||||||
|
|
||||||
|
# 消息类型
|
||||||
|
if post_type == "message":
|
||||||
|
message_type = message.get("message_type")
|
||||||
|
if message_type == "group":
|
||||||
|
group_id = message.get("group_id")
|
||||||
|
return f"qq:{group_id}:group"
|
||||||
|
elif message_type == "private":
|
||||||
|
user_id = message.get("user_id")
|
||||||
|
return f"qq:{user_id}:private"
|
||||||
|
|
||||||
|
# notice 类型
|
||||||
|
elif post_type == "notice":
|
||||||
|
group_id = message.get("group_id")
|
||||||
|
if group_id:
|
||||||
|
return f"qq:{group_id}:group"
|
||||||
|
user_id = message.get("user_id")
|
||||||
|
if user_id:
|
||||||
|
return f"qq:{user_id}:private"
|
||||||
|
|
||||||
|
# 未知类型,使用通用流
|
||||||
|
return "unknown:unknown"
|
||||||
|
|
||||||
|
async def _cleanup_inactive_streams(self) -> None:
|
||||||
|
"""清理不活跃的流"""
|
||||||
|
current_time = time.time()
|
||||||
|
to_remove = []
|
||||||
|
|
||||||
|
for stream_id, consumer in self.streams.items():
|
||||||
|
if current_time - consumer.last_active_time > self.stream_timeout:
|
||||||
|
to_remove.append(stream_id)
|
||||||
|
|
||||||
|
for stream_id in to_remove:
|
||||||
|
await self.streams[stream_id].stop()
|
||||||
|
del self.streams[stream_id]
|
||||||
|
logger.debug(f"清理不活跃的流: {stream_id}")
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
logger.info(
|
||||||
|
f"清理了 {len(to_remove)} 个不活跃的流 "
|
||||||
|
f"(当前活跃流: {len(self.streams)}/{self.max_streams})"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _cleanup_loop(self) -> None:
|
||||||
|
"""定期清理循环"""
|
||||||
|
logger.info(f"清理循环已启动,间隔: {self.cleanup_interval}秒")
|
||||||
|
try:
|
||||||
|
while self.is_running:
|
||||||
|
await asyncio.sleep(self.cleanup_interval)
|
||||||
|
await self._cleanup_inactive_streams()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("清理循环已停止")
|
||||||
|
|
||||||
|
def get_all_stats(self) -> list[dict]:
|
||||||
|
"""获取所有流的统计信息"""
|
||||||
|
return [consumer.get_stats() for consumer in self.streams.values()]
|
||||||
|
|
||||||
|
def get_summary(self) -> dict:
|
||||||
|
"""获取路由器摘要"""
|
||||||
|
total_messages = sum(c.stats["total_messages"] for c in self.streams.values())
|
||||||
|
total_queue_size = sum(c.queue.qsize() for c in self.streams.values())
|
||||||
|
total_overflows = sum(c.stats["queue_overflow_count"] for c in self.streams.values())
|
||||||
|
|
||||||
|
# 计算平均队列长度
|
||||||
|
avg_queue_size = total_queue_size / len(self.streams) if self.streams else 0
|
||||||
|
|
||||||
|
# 找出最繁忙的流
|
||||||
|
busiest_stream = None
|
||||||
|
if self.streams:
|
||||||
|
busiest_stream = max(
|
||||||
|
self.streams.values(),
|
||||||
|
key=lambda c: c.stats["total_messages"]
|
||||||
|
).stream_id
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_streams": len(self.streams),
|
||||||
|
"max_streams": self.max_streams,
|
||||||
|
"total_messages_processed": total_messages,
|
||||||
|
"total_queue_size": total_queue_size,
|
||||||
|
"avg_queue_size": avg_queue_size,
|
||||||
|
"total_queue_overflows": total_overflows,
|
||||||
|
"busiest_stream": busiest_stream,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局路由器实例
|
||||||
|
stream_router = StreamRouter()
|
||||||
@@ -236,8 +236,6 @@ class NapcatAdapterPlugin(BasePlugin):
|
|||||||
def enable_plugin(self) -> bool:
|
def enable_plugin(self) -> bool:
|
||||||
"""通过配置文件动态控制插件启用状态"""
|
"""通过配置文件动态控制插件启用状态"""
|
||||||
# 如果已经通过配置加载了状态,使用配置中的值
|
# 如果已经通过配置加载了状态,使用配置中的值
|
||||||
if hasattr(self, "_is_enabled"):
|
|
||||||
return self._is_enabled
|
|
||||||
# 否则使用默认值(禁用状态)
|
# 否则使用默认值(禁用状态)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user