From 36fce6ca98a42e291bb73fb91738954bc1b89ff1 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Mon, 24 Nov 2025 13:24:55 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=B8=A6=E6=9C=89?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E5=A4=84=E7=90=86=E5=92=8C=E8=B7=AF=E7=94=B1?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E7=9A=84NEW=5Fnapcat=5Fadapter=E6=8F=92?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为NEW_napcat_adapter插件实现了核心模块,包括消息处理、事件处理和路由。 - 创建了MessageHandler、MetaEventHandler和NoticeHandler来处理收到的消息和事件。 - 开发了SendHandler,用于向Napcat发送回消息。 引入了StreamRouter来管理多个聊天流,确保消息的顺序和高效处理。 - 增加了对各种消息类型和格式的支持,包括文本、图像和通知。 - 建立了一个用于监控和调试的日志系统。 --- docs/mofox_bus.md | 21 +- src/common/core_sink.py | 35 ++ src/common/message/envelope_converter.py | 492 ++++++++++-------- src/mofox_bus/__init__.py | 61 +-- src/mofox_bus/adapter_utils.py | 305 +++++++++-- src/mofox_bus/api.py | 217 ++++++-- src/mofox_bus/builder.py | 110 ++++ src/mofox_bus/codec.py | 23 +- src/mofox_bus/message_models.py | 189 ------- src/mofox_bus/router.py | 74 +-- src/mofox_bus/runtime.py | 190 ++++++- src/mofox_bus/types.py | 199 +++---- src/plugin_system/base/base_adapter.py | 6 +- src/plugin_system/core/adapter_manager.py | 219 ++++---- .../built_in/NEW_napcat_adapter/README.md | 427 +++++++++++++++ .../built_in/NEW_napcat_adapter/__init__.py | 16 + .../built_in/NEW_napcat_adapter/plugin.py | 330 ++++++++++++ .../NEW_napcat_adapter/src/__init__.py | 1 + .../NEW_napcat_adapter/src/event_models.py | 310 +++++++++++ .../src/handlers/__init__.py | 1 + .../src/handlers/to_core/__init__.py | 1 + .../src/handlers/to_core/message_handler.py | 126 +++++ .../handlers/to_core/meta_event_handler.py | 41 ++ .../src/handlers/to_core/notice_handler.py | 41 ++ .../src/handlers/to_napcat/__init__.py | 1 + .../src/handlers/to_napcat/send_handler.py | 77 +++ .../NEW_napcat_adapter/stream_router.py | 350 +++++++++++++ .../built_in/napcat_adapter_plugin/plugin.py | 2 - 28 files changed, 3041 insertions(+), 824 deletions(-) create mode 100644 src/common/core_sink.py create mode 100644 src/mofox_bus/builder.py delete mode 100644 src/mofox_bus/message_models.py create mode 100644 src/plugins/built_in/NEW_napcat_adapter/README.md create mode 100644 src/plugins/built_in/NEW_napcat_adapter/__init__.py create mode 100644 src/plugins/built_in/NEW_napcat_adapter/plugin.py create mode 100644 src/plugins/built_in/NEW_napcat_adapter/src/__init__.py create mode 100644 src/plugins/built_in/NEW_napcat_adapter/src/event_models.py create mode 100644 src/plugins/built_in/NEW_napcat_adapter/src/handlers/__init__.py create mode 100644 src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/__init__.py create mode 100644 src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/message_handler.py create mode 100644 src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/meta_event_handler.py create mode 100644 src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/notice_handler.py create mode 100644 src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_napcat/__init__.py create mode 100644 src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_napcat/send_handler.py create mode 100644 src/plugins/built_in/NEW_napcat_adapter/stream_router.py diff --git a/docs/mofox_bus.md b/docs/mofox_bus.md index aaa4a3951..d7dc68427 100644 --- a/docs/mofox_bus.md +++ b/docs/mofox_bus.md @@ -32,11 +32,11 @@ MoFox Bus 是 MoFox Bot 自研的统一消息中台,替换第三方 `maim_mess ## 3. 消息模型 -### 3.1 Envelope TypedDict(`types.py`) +### 3.1 Envelope TypedDict��`types.py`�� -- `MessageEnvelope`:核心字段包括 `id`、`direction`、`platform`、`timestamp_ms`、`channel`、`sender`、`content` 等,一律使用毫秒时间戳,保留 `raw_platform_message` 与 `metadata` 便于调试 / 扩展。 -- `Content` 联合类型支持文本、图片、音频、文件、视频、事件、命令、系统消息,后续可扩展更多 literal。 -- `SenderInfo` / `ChannelInfo` / `MessageDirection` / `Role` 等均以 `Literal` 控制取值,方便 IDE 静态检查。 +- `MessageEnvelope` ��ȫ��Ƶ� maim_message �ṹ�����ĵ������� `message_info` + `message_segment` (SegPayload)��`direction`��`schema_version` �� raw �����ֶβ��������ˣ���Ժ����� `channel`��`sender`��`content` �� v0 �ֶΪ��ѡ�� +- `SegPayload` / `MessageInfoPayload` / `UserInfoPayload` / `GroupInfoPayload` / `FormatInfoPayload` / `TemplateInfoPayload` �� maim_message dataclass �Դ�TypedDict ��Ӧ���ʺ�ֱ�� JSON ���� +- `Content` / `SenderInfo` / `ChannelInfo` �Ȳ�Ȼ�����ڣ����ܻ��� IDE ע�⣬Ҳ�Ƕ� v0 content ģ�͵Ļ�֧ ### 3.2 dataclass 消息段(`message_models.py`) @@ -62,15 +62,14 @@ TypedDict 更适合网络传输和依赖注入;dataclass 版 MessageBase 则 ## 5. 运行时调度(`runtime.py`) - `MessageRuntime`: - - `add_route(predicate, handler)` 或 `@runtime.route(...)` 装饰器注册消息处理器。 - - `register_before_hook` / `register_after_hook` / `register_error_hook` 注入监控、埋点、Trace。 - - `set_batch_handler` 支持一次处理整批消息(例如批量落库)。 -- `MessageProcessingError` 在 handler 抛出异常时封装上下文,便于日志追踪。 + - `add_route(predicate, handler)` 和 `@runtime.route(...)` 装饰器注册消息处理器 + - `register_before_hook` / `register_after_hook` / `register_error_hook` 注册前置、后置、Trace 处理 + - `set_batch_handler` 支持一次处理一批消息(可用于 batch IO 优化) +- `MessageProcessingError` 在 handler 抛出异常时包装原因,方便日志追踪。 运行时内部使用 `RLock` 保护路由表,适合多协程并发读写,`_maybe_await` 自动兼容同步/异步 handler。 --- - ## 6. 传输层封装(`transport/`) ### 6.1 HTTP @@ -126,9 +125,9 @@ from mofox_bus.transport import HttpMessageServer runtime = MessageRuntime() -@runtime.route(lambda env: env["content"]["type"] == "text") +@runtime.route(lambda env: (env.get("message_segment") or {}).get("type") == "text") async def handle_text(env: types.MessageEnvelope): - print("收到文本:", env["content"]["text"]) + print("收到文本", env["message_segment"]["data"]) async def http_handler(messages: list[types.MessageEnvelope]): await runtime.handle_batch(messages) diff --git a/src/common/core_sink.py b/src/common/core_sink.py new file mode 100644 index 000000000..261adf397 --- /dev/null +++ b/src/common/core_sink.py @@ -0,0 +1,35 @@ +""" +从 src.main 导出 core_sink 的辅助函数 + +由于 src.main 中实际使用的是 InProcessCoreSink, +我们需要创建一个全局访问点 +""" + +from mofox_bus import CoreSink, InProcessCoreSink + +_global_core_sink: CoreSink | None = None + + +def set_core_sink(sink: CoreSink) -> None: + """设置全局 core sink""" + global _global_core_sink + _global_core_sink = sink + + +def get_core_sink() -> CoreSink: + """获取全局 core sink""" + global _global_core_sink + if _global_core_sink is None: + raise RuntimeError("Core sink 尚未初始化") + return _global_core_sink + + +async def push_outgoing(envelope) -> None: + """将消息推送到 core sink 的 outgoing 通道""" + sink = get_core_sink() + push = getattr(sink, "push_outgoing", None) + if push is None: + raise RuntimeError("当前 core sink 不支持 push_outgoing 方法") + await push(envelope) + +__all__ = ["set_core_sink", "get_core_sink", "push_outgoing"] diff --git a/src/common/message/envelope_converter.py b/src/common/message/envelope_converter.py index 2e59e031b..8434565f7 100644 --- a/src/common/message/envelope_converter.py +++ b/src/common/message/envelope_converter.py @@ -1,15 +1,22 @@ """ -MessageEnvelope 转换器 +MessageEnvelope converter between mofox_bus schema and internal message structures. -将 mofox_bus 的 MessageEnvelope 转换为 MoFox Bot 内部使用的消息格式。 -""" +- 优先处理 maim_message 风格的 message_info + message_segment。 +- 兼容旧版 content/sender/channel 结构,方便逐步迁移。 +""" from __future__ import annotations -import time from typing import Any, Dict, List, Optional -from mofox_bus import MessageEnvelope, BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, UserInfo +from mofox_bus import ( + BaseMessageInfo, + MessageBase, + MessageEnvelope, + Seg, + UserInfo, + GroupInfo, +) from src.common.logger import get_logger @@ -17,151 +24,221 @@ logger = get_logger("envelope_converter") class EnvelopeConverter: - """MessageEnvelope 到内部消息格式的转换器""" + """MessageEnvelope <-> MessageBase converter.""" @staticmethod def to_message_base(envelope: MessageEnvelope) -> MessageBase: """ - 将 MessageEnvelope 转换为 MessageBase - - Args: - envelope: 统一的消息信封 - - Returns: - MessageBase: 内部消息格式 + Convert MessageEnvelope to MessageBase. """ try: - # 提取基本信息 - platform = envelope["platform"] - channel = envelope["channel"] - sender = envelope["sender"] - content = envelope["content"] - - # 创建 UserInfo - user_info = UserInfo( - user_id=sender["user_id"], - user_nickname=sender.get("display_name", sender["user_id"]), - user_avatar=sender.get("avatar_url"), - ) - - # 创建 GroupInfo (如果是群组消息) - group_info: Optional[GroupInfo] = None - if channel["channel_type"] in ("group", "supergroup", "room"): - group_info = GroupInfo( - group_id=channel["channel_id"], - group_name=channel.get("title", channel["channel_id"]), - ) - - # 创建 BaseMessageInfo - message_info = BaseMessageInfo( - platform=platform, - chat_type="group" if group_info else "private", - message_id=envelope["id"], - user_info=user_info, - group_info=group_info, - timestamp=envelope["timestamp_ms"] / 1000.0, # 转换为秒 - ) - - # 转换 Content 为 Seg 列表 - segments = EnvelopeConverter._content_to_segments(content) - - # 创建 MessageBase - message_base = MessageBase( + # 优先使用 maim_message 样式字段 + info_payload = envelope.get("message_info") or {} + seg_payload = envelope.get("message_segment") or envelope.get("message_chain") + + if info_payload: + message_info = BaseMessageInfo.from_dict(info_payload) + else: + message_info = EnvelopeConverter._build_info_from_legacy(envelope) + + if seg_payload is None: + seg_list = EnvelopeConverter._content_to_segments(envelope.get("content")) + seg_payload = seg_list + + message_segment = EnvelopeConverter._ensure_seg(seg_payload) + raw_message = envelope.get("raw_message") or envelope.get("raw_platform_message") + + return MessageBase( message_info=message_info, - message=segments, + message_segment=message_segment, + raw_message=raw_message, ) - - # 保存原始 envelope 到 raw 字段 - if hasattr(message_base, "raw"): - message_base.raw = envelope - - return message_base - except Exception as e: logger.error(f"转换 MessageEnvelope 失败: {e}", exc_info=True) raise @staticmethod - def _content_to_segments(content: Dict[str, Any]) -> List[Seg]: + def _build_info_from_legacy(envelope: MessageEnvelope) -> BaseMessageInfo: + """将 legacy 字段映射为 BaseMessageInfo。""" + platform = envelope.get("platform") + channel = envelope.get("channel") or {} + sender = envelope.get("sender") or {} + + message_id = envelope.get("id") or envelope.get("message_id") + timestamp_ms = envelope.get("timestamp_ms") + time_value = (timestamp_ms / 1000.0) if timestamp_ms is not None else None + + group_info: Optional[GroupInfo] = None + channel_type = channel.get("channel_type") + if channel_type in ("group", "supergroup", "room"): + group_info = GroupInfo( + platform=platform, + group_id=channel.get("channel_id"), + group_name=channel.get("title"), + ) + + user_info: Optional[UserInfo] = None + if sender: + user_info = UserInfo( + platform=platform, + user_id=str(sender.get("user_id")) if sender.get("user_id") is not None else None, + user_nickname=sender.get("display_name") or sender.get("user_nickname"), + user_avatar=sender.get("avatar_url"), + ) + + return BaseMessageInfo( + platform=platform, + message_id=message_id, + time=time_value, + group_info=group_info, + user_info=user_info, + additional_config=envelope.get("metadata"), + ) + + @staticmethod + def _ensure_seg(payload: Any) -> Seg: + """将任意 payload 转为 Seg dataclass。""" + if isinstance(payload, Seg): + return payload + if isinstance(payload, list): + # 直接传入 Seg 列表或 seglist data + return Seg(type="seglist", data=[EnvelopeConverter._ensure_seg(item) for item in payload]) + if isinstance(payload, dict): + seg_type = payload.get("type") or "text" + data = payload.get("data") + if seg_type == "seglist" and isinstance(data, list): + data = [EnvelopeConverter._ensure_seg(item) for item in data] + return Seg(type=seg_type, data=data) + # 兜底:转成文本片段 + return Seg(type="text", data="" if payload is None else str(payload)) + + @staticmethod + def _flatten_segments(seg: Seg) -> List[Seg]: + """将 Seg/seglist 打平成列表,便于旧 content 转换。""" + if seg.type == "seglist" and isinstance(seg.data, list): + return [item if isinstance(item, Seg) else EnvelopeConverter._ensure_seg(item) for item in seg.data] + return [seg] + + @staticmethod + def _content_to_segments(content: Any) -> List[Seg]: """ - 将 Content 转换为 Seg 列表 - - Args: - content: 消息内容 - - Returns: - List[Seg]: 消息段列表 + Convert legacy Content (type/data/metadata) to a flat list of Seg. """ segments: List[Seg] = [] - content_type = content.get("type") - - if content_type == "text": - # 文本消息 - text = content.get("text", "") - segments.append(Seg.text(text)) - - elif content_type == "image": - # 图片消息 - url = content.get("url", "") - file_id = content.get("file_id") - segments.append(Seg.image(url if url else file_id)) - - elif content_type == "audio": - # 音频消息 - url = content.get("url", "") - file_id = content.get("file_id") - segments.append(Seg.record(url if url else file_id)) - - elif content_type == "video": - # 视频消息 - url = content.get("url", "") - file_id = content.get("file_id") - segments.append(Seg.video(url if url else file_id)) - - elif content_type == "file": - # 文件消息 - url = content.get("url", "") - file_name = content.get("file_name", "file") - # 使用 text 表示文件(或者可以自定义一个 file seg type) - segments.append(Seg.text(f"[文件: {file_name}]")) - - elif content_type == "command": - # 命令消息 - name = content.get("name", "") - args = content.get("args", {}) - # 重构为文本格式 - cmd_text = f"/{name}" - if args: - cmd_text += " " + " ".join(f"{k}={v}" for k, v in args.items()) - segments.append(Seg.text(cmd_text)) - - elif content_type == "event": - # 事件消息 - 转换为文本表示 - event_type = content.get("event_type", "unknown") - segments.append(Seg.text(f"[事件: {event_type}]")) - - elif content_type == "system": - # 系统消息 - text = content.get("text", "") - segments.append(Seg.text(f"[系统] {text}")) - - else: - # 未知类型 - 转换为文本 + + def _walk(node: Any) -> None: + if node is None: + return + if isinstance(node, list): + for item in node: + _walk(item) + return + if not isinstance(node, dict): + logger.warning("未知的 content 节点类型: %s", type(node)) + return + + content_type = node.get("type") + data = node.get("data") + metadata = node.get("metadata") or {} + + if content_type == "collection": + items = data if isinstance(data, list) else node.get("items", []) + for item in items: + _walk(item) + return + + if content_type in ("text", "at"): + subtype = metadata.get("subtype") or ("at" if content_type == "at" else None) + text = "" if data is None else str(data) + if subtype in ("at", "mention"): + user_info = metadata.get("user") or {} + seg_data: Dict[str, Any] = { + "user_id": user_info.get("id") or user_info.get("user_id"), + "user_name": user_info.get("name") or user_info.get("display_name"), + "text": text, + "raw": user_info.get("raw") or user_info if user_info else None, + } + if any(v is not None for v in seg_data.values()): + segments.append(Seg(type="at", data=seg_data)) + else: + segments.append(Seg(type="at", data=text)) + else: + segments.append(Seg(type="text", data=text)) + return + + if content_type == "image": + url = "" + if isinstance(data, dict): + url = data.get("url") or data.get("file") or data.get("file_id") or "" + elif data is not None: + url = str(data) + segments.append(Seg(type="image", data=url)) + return + + if content_type == "audio": + url = "" + if isinstance(data, dict): + url = data.get("url") or data.get("file") or data.get("file_id") or "" + elif data is not None: + url = str(data) + segments.append(Seg(type="record", data=url)) + return + + if content_type == "video": + url = "" + if isinstance(data, dict): + url = data.get("url") or data.get("file") or data.get("file_id") or "" + elif data is not None: + url = str(data) + segments.append(Seg(type="video", data=url)) + return + + if content_type == "file": + file_name = "" + if isinstance(data, dict): + file_name = data.get("file_name") or data.get("name") or "" + text = file_name or "[file]" + segments.append(Seg(type="text", data=text)) + return + + if content_type == "command": + name = "" + args: Dict[str, Any] = {} + if isinstance(data, dict): + name = data.get("name", "") + args = data.get("args", {}) or {} + else: + name = str(data or "") + cmd_text = f"/{name}" if name else "/command" + if args: + cmd_text += " " + " ".join(f"{k}={v}" for k, v in args.items()) + segments.append(Seg(type="text", data=cmd_text)) + return + + if content_type == "event": + event_type = "" + if isinstance(data, dict): + event_type = data.get("event_type", "") + else: + event_type = str(data or "") + segments.append(Seg(type="text", data=f"[事件: {event_type or 'unknown'}]")) + return + + if content_type == "system": + text = "" if data is None else str(data) + segments.append(Seg(type="text", data=f"[系统] {text}")) + return + logger.warning(f"未知的消息类型: {content_type}") - segments.append(Seg.text(f"[未知消息类型: {content_type}]")) - + segments.append(Seg(type="text", data=f"[未知消息类型: {content_type}]")) + + _walk(content) return segments @staticmethod def to_legacy_dict(envelope: MessageEnvelope) -> Dict[str, Any]: """ - 将 MessageEnvelope 转换为旧版字典格式(用于向后兼容) - - Args: - envelope: 统一的消息信封 - - Returns: - Dict[str, Any]: 旧版消息字典 + Convert MessageEnvelope to legacy dict for backward compatibility. """ message_base = EnvelopeConverter.to_message_base(envelope) return message_base.to_dict() @@ -169,61 +246,45 @@ class EnvelopeConverter: @staticmethod def from_message_base(message: MessageBase, direction: str = "outgoing") -> MessageEnvelope: """ - 将 MessageBase 转换为 MessageEnvelope (反向转换) - - Args: - message: 内部消息格式 - direction: 消息方向 ("incoming" 或 "outgoing") - - Returns: - MessageEnvelope: 统一的消息信封 + Convert MessageBase to MessageEnvelope (maim_message style preferred). """ try: - message_info = message.message_info - user_info = message_info.user_info - group_info = message_info.group_info - - # 创建 SenderInfo - sender = { - "user_id": user_info.user_id, - "role": "assistant" if direction == "outgoing" else "user", - } - if user_info.user_nickname: - sender["display_name"] = user_info.user_nickname - if user_info.user_avatar: - sender["avatar_url"] = user_info.user_avatar - - # 创建 ChannelInfo - if group_info: - channel = { - "channel_id": group_info.group_id, - "channel_type": "group", - } - if group_info.group_name: - channel["title"] = group_info.group_name - else: - channel = { - "channel_id": user_info.user_id, - "channel_type": "private", - } - - # 转换 segments 为 Content - content = EnvelopeConverter._segments_to_content(message.message) - - # 创建 MessageEnvelope + info_dict = message.message_info.to_dict() + seg_dict = message.message_segment.to_dict() + envelope: MessageEnvelope = { - "id": message_info.message_id, "direction": direction, - "platform": message_info.platform, - "timestamp_ms": int(message_info.timestamp * 1000), - "channel": channel, - "sender": sender, - "content": content, - "conversation_id": group_info.group_id if group_info else user_info.user_id, + "message_info": info_dict, + "message_segment": seg_dict, + "platform": info_dict.get("platform"), + "message_id": info_dict.get("message_id"), + "schema_version": 1, } - + + if message.message_info.time is not None: + envelope["timestamp_ms"] = int(message.message_info.time * 1000) + if message.raw_message is not None: + envelope["raw_message"] = message.raw_message + + # legacy 补充,方便老代码继续工作 + segments = EnvelopeConverter._flatten_segments(message.message_segment) + envelope["content"] = EnvelopeConverter._segments_to_content(segments) + if message.message_info.user_info: + envelope["sender"] = { + "user_id": message.message_info.user_info.user_id, + "role": "assistant" if direction == "outgoing" else "user", + "display_name": message.message_info.user_info.user_nickname, + "avatar_url": getattr(message.message_info.user_info, "user_avatar", None), + } + if message.message_info.group_info: + envelope["channel"] = { + "channel_id": message.message_info.group_info.group_id, + "channel_type": "group", + "title": message.message_info.group_info.group_name, + } + return envelope - + except Exception as e: logger.error(f"转换 MessageBase 失败: {e}", exc_info=True) raise @@ -231,45 +292,50 @@ class EnvelopeConverter: @staticmethod def _segments_to_content(segments: List[Seg]) -> Dict[str, Any]: """ - 将 Seg 列表转换为 Content - - Args: - segments: 消息段列表 - - Returns: - Dict[str, Any]: 消息内容 + Convert Seg list to legacy Content (type/data/metadata). """ if not segments: - return {"type": "text", "text": ""} - - # 简化处理:如果有多个段,合并为文本 + return {"type": "text", "data": ""} + + def _seg_to_content(seg: Seg) -> Dict[str, Any]: + data = seg.data + + if seg.type == "text": + return {"type": "text", "data": data} + + if seg.type == "at": + content: Dict[str, Any] = {"type": "text", "data": ""} + metadata: Dict[str, Any] = {"subtype": "at"} + if isinstance(data, dict): + content["data"] = data.get("text", "") + user = { + "id": data.get("user_id"), + "name": data.get("user_name"), + "raw": data.get("raw"), + } + if any(v is not None for v in user.values()): + metadata["user"] = user + else: + content["data"] = data + if metadata: + content["metadata"] = metadata + return content + + if seg.type == "image": + return {"type": "image", "data": data} + + if seg.type in ("record", "voice", "audio"): + return {"type": "audio", "data": data} + + if seg.type == "video": + return {"type": "video", "data": data} + + return {"type": seg.type, "data": data} + if len(segments) == 1: - seg = segments[0] - - if seg.type == "text": - return {"type": "text", "text": seg.data.get("text", "")} - elif seg.type == "image": - return {"type": "image", "url": seg.data.get("file", "")} - elif seg.type == "record": - return {"type": "audio", "url": seg.data.get("file", "")} - elif seg.type == "video": - return {"type": "video", "url": seg.data.get("file", "")} - - # 多个段或未知类型 - 合并为文本 - text_parts = [] - for seg in segments: - if seg.type == "text": - text_parts.append(seg.data.get("text", "")) - elif seg.type == "image": - text_parts.append("[图片]") - elif seg.type == "record": - text_parts.append("[语音]") - elif seg.type == "video": - text_parts.append("[视频]") - else: - text_parts.append(f"[{seg.type}]") - - return {"type": "text", "text": "".join(text_parts)} + return _seg_to_content(segments[0]) + + return {"type": "collection", "data": [_seg_to_content(seg) for seg in segments]} __all__ = ["EnvelopeConverter"] diff --git a/src/mofox_bus/__init__.py b/src/mofox_bus/__init__.py index 85a29fa55..b0868c14b 100644 --- a/src/mofox_bus/__init__.py +++ b/src/mofox_bus/__init__.py @@ -10,72 +10,54 @@ from .adapter_utils import ( AdapterTransportOptions, AdapterBase, BatchDispatcher, + CoreSink, CoreMessageSink, HttpAdapterOptions, InProcessCoreSink, + ProcessCoreSink, + ProcessCoreSinkServer, WebSocketLike, WebSocketAdapterOptions, ) from .api import MessageClient, MessageServer from .codec import dumps_message, dumps_messages, loads_message, loads_messages -from .message_models import BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, TemplateInfo, UserInfo +from .builder import MessageBuilder from .router import RouteConfig, Router, TargetConfig -from .runtime import MessageProcessingError, MessageRoute, MessageRuntime +from .runtime import MessageProcessingError, MessageRoute, MessageRuntime, Middleware from .types import ( - AudioContent, - ChannelInfo, - CommandContent, - Content, - ContentType, - EventContent, - EventType, - FileContent, - ImageContent, + FormatInfoPayload, + GroupInfoPayload, MessageDirection, MessageEnvelope, - Role, - SenderInfo, - SystemContent, - TextContent, - VideoContent, + MessageInfoPayload, + SegPayload, + + TemplateInfoPayload, + UserInfoPayload, ) __all__ = [ # TypedDict model - "AudioContent", - "ChannelInfo", - "CommandContent", - "Content", - "ContentType", - "EventContent", - "EventType", - "FileContent", - "ImageContent", "MessageDirection", "MessageEnvelope", - "Role", - "SenderInfo", - "SystemContent", - "TextContent", - "VideoContent", + "SegPayload", + "UserInfoPayload", + "GroupInfoPayload", + "FormatInfoPayload", + "TemplateInfoPayload", + "MessageInfoPayload", # Codec helpers "codec", "dumps_message", "dumps_messages", "loads_message", "loads_messages", + "MessageBuilder", # Runtime / routing "MessageRoute", "MessageRuntime", "MessageProcessingError", - # Message dataclasses - "Seg", - "GroupInfo", - "UserInfo", - "FormatInfo", - "TemplateInfo", - "BaseMessageInfo", - "MessageBase", + "Middleware", # Server/client/router "MessageServer", "MessageClient", @@ -86,8 +68,11 @@ __all__ = [ "AdapterTransportOptions", "AdapterBase", "BatchDispatcher", + "CoreSink", "CoreMessageSink", "InProcessCoreSink", + "ProcessCoreSink", + "ProcessCoreSinkServer", "WebSocketLike", "WebSocketAdapterOptions", "HttpAdapterOptions", diff --git a/src/mofox_bus/adapter_utils.py b/src/mofox_bus/adapter_utils.py index 11c278ca5..e56e96577 100644 --- a/src/mofox_bus/adapter_utils.py +++ b/src/mofox_bus/adapter_utils.py @@ -2,6 +2,8 @@ from __future__ import annotations import asyncio import contextlib +import logging +import multiprocessing as mp from dataclasses import dataclass from typing import Any, AsyncIterator, Awaitable, Callable, Protocol @@ -11,6 +13,11 @@ import websockets from .types import MessageEnvelope +logger = logging.getLogger("mofox_bus.adapter") + + +OutgoingHandler = Callable[[MessageEnvelope], Awaitable[None]] + class CoreMessageSink(Protocol): async def send(self, message: MessageEnvelope) -> None: ... @@ -18,6 +25,22 @@ class CoreMessageSink(Protocol): async def send_many(self, messages: list[MessageEnvelope]) -> None: ... # pragma: no cover - optional +class CoreSink(CoreMessageSink, Protocol): + """ + 双向 CoreSink 协议: + - send/send_many: 适配器 → 核心(incoming) + - push_outgoing: 核心 → 适配器(outgoing) + """ + + def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None: ... + + def remove_outgoing_handler(self, handler: OutgoingHandler) -> None: ... + + async def push_outgoing(self, envelope: MessageEnvelope) -> None: ... + + async def close(self) -> None: ... # pragma: no cover - lifecycle hook + + class WebSocketLike(Protocol): def __aiter__(self) -> AsyncIterator[str | bytes]: ... @@ -56,7 +79,7 @@ class AdapterBase: platform: str = "unknown" - def __init__(self, core_sink: CoreMessageSink, transport: AdapterTransportOptions = None): + def __init__(self, core_sink: CoreSink, transport: AdapterTransportOptions = None): """ Args: core_sink: 核心消息入口,通常是 InProcessCoreSink 或自定义客户端。 @@ -70,14 +93,31 @@ class AdapterBase: self._http_site: aiohttp_web.BaseSite | None = None async def start(self) -> None: - """根据配置自动启动 WS/HTTP 监听。""" + """启动适配器的传输层监听(如果配置了传输选项)。""" + if hasattr(self.core_sink, "set_outgoing_handler"): + try: + self.core_sink.set_outgoing_handler(self._on_outgoing_from_core) + except Exception: + logger.exception("Failed to register outgoing handler on core sink") if isinstance(self._transport_config, WebSocketAdapterOptions): await self._start_ws_transport(self._transport_config) elif isinstance(self._transport_config, HttpAdapterOptions): await self._start_http_transport(self._transport_config) + async def stop(self) -> None: - """停止自动管理的传输层。""" + """停止适配器的传输层监听(如果配置了传输选项)。""" + remove = getattr(self.core_sink, "remove_outgoing_handler", None) + if callable(remove): + try: + remove(self._on_outgoing_from_core) + except Exception: + logger.exception("Failed to detach outgoing handler on core sink") + elif hasattr(self.core_sink, "set_outgoing_handler"): + try: + self.core_sink.set_outgoing_handler(None) # type: ignore[arg-type] + except Exception: + logger.exception("Failed to detach outgoing handler on core sink") if self._ws_task: self._ws_task.cancel() with contextlib.suppress(asyncio.CancelledError): @@ -95,12 +135,12 @@ class AdapterBase: async def on_platform_message(self, raw: Any) -> None: """处理平台下发的单条消息并交给核心。""" - envelope = self.from_platform_message(raw) + envelope = await _maybe_await(self.from_platform_message(raw)) await self.core_sink.send(envelope) async def on_platform_messages(self, raw_messages: list[Any]) -> None: """批量推送入口,内部自动批量或逐条送入核心。""" - envelopes = [self.from_platform_message(raw) for raw in raw_messages] + envelopes = [await _maybe_await(self.from_platform_message(raw)) for raw in raw_messages] await _send_many(self.core_sink, envelopes) async def send_to_platform(self, envelope: MessageEnvelope) -> None: @@ -112,7 +152,14 @@ class AdapterBase: for env in envelopes: await self._send_platform_message(env) - def from_platform_message(self, raw: Any) -> MessageEnvelope: + async def _on_outgoing_from_core(self, envelope: MessageEnvelope) -> None: + """核心生成 outgoing envelope 时的内部处理逻辑""" + platform = envelope.get("platform") or envelope.get("message_info", {}).get("platform") + if platform and platform != getattr(self, "platform", None): + return + await self._send_platform_message(envelope) + + def from_platform_message(self, raw: Any) -> MessageEnvelope | Awaitable[MessageEnvelope]: """子类必须实现:将平台原始结构转换为统一 MessageEnvelope。""" raise NotImplementedError @@ -175,13 +222,22 @@ class AdapterBase: return orjson.dumps({"type": "send", "payload": envelope}) -class InProcessCoreSink: +class InProcessCoreSink(CoreSink): """ - 简单的进程内 sink,实现 CoreMessageSink 协议。 + 进程内核心消息 sink,实现 CoreSink 协议。 """ def __init__(self, handler: Callable[[MessageEnvelope], Awaitable[None]]): self._handler = handler + self._outgoing_handlers: set[OutgoingHandler] = set() + + def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None: + if handler is None: + return + self._outgoing_handlers.add(handler) + + def remove_outgoing_handler(self, handler: OutgoingHandler) -> None: + self._outgoing_handlers.discard(handler) async def send(self, message: MessageEnvelope) -> None: await self._handler(message) @@ -190,6 +246,140 @@ class InProcessCoreSink: for message in messages: await self._handler(message) + async def push_outgoing(self, envelope: MessageEnvelope) -> None: + if not self._outgoing_handlers: + logger.debug("Outgoing envelope dropped: no handler registered") + return + for callback in list(self._outgoing_handlers): + await callback(envelope) + + async def close(self) -> None: # pragma: no cover - symmetry + self._outgoing_handlers.clear() + + +class ProcessCoreSink(CoreSink): + """ + 进程间核心消息 sink,实现 CoreSink 协议,使用 multiprocessing.Queue 初始化 + """ + + _CONTROL_STOP = {"__core_sink_control__": "stop"} + + def __init__(self, *, to_core_queue: mp.Queue, from_core_queue: mp.Queue) -> None: + self._to_core_queue = to_core_queue + self._from_core_queue = from_core_queue + self._outgoing_handler: OutgoingHandler | None = None + self._closed = False + self._listener_task: asyncio.Task | None = None + self._loop = asyncio.get_event_loop() + + def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None: + self._outgoing_handler = handler + if handler is not None and (self._listener_task is None or self._listener_task.done()): + self._listener_task = self._loop.create_task(self._listen_from_core()) + + def remove_outgoing_handler(self, handler: OutgoingHandler) -> None: + if self._outgoing_handler is handler: + self._outgoing_handler = None + if self._listener_task and not self._listener_task.done(): + self._listener_task.cancel() + + async def send(self, message: MessageEnvelope) -> None: + await asyncio.to_thread(self._to_core_queue.put, {"kind": "incoming", "payload": message}) + + async def send_many(self, messages: list[MessageEnvelope]) -> None: + for message in messages: + await self.send(message) + + async def push_outgoing(self, envelope: MessageEnvelope) -> None: + logger.debug("ProcessCoreSink.push_outgoing called in child; ignored") + + async def close(self) -> None: + if self._closed: + return + self._closed = True + await asyncio.to_thread(self._from_core_queue.put, self._CONTROL_STOP) + if self._listener_task: + self._listener_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._listener_task + self._listener_task = None + + async def _listen_from_core(self) -> None: + while not self._closed: + try: + item = await asyncio.to_thread(self._from_core_queue.get) + except asyncio.CancelledError: + break + if item == self._CONTROL_STOP: + break + if isinstance(item, dict) and item.get("kind") == "outgoing": + envelope = item.get("payload") + if self._outgoing_handler: + try: + await self._outgoing_handler(envelope) + except Exception: # pragma: no cover + logger.exception("Failed to handle outgoing envelope in ProcessCoreSink") + else: + logger.debug("ProcessCoreSink received unknown payload: %r", item) + + +class ProcessCoreSinkServer: + """ + 进程间核心消息 sink 服务器,实现 CoreSink 协议,使用 multiprocessing.Queue 初始化。 + - 将传入的 incoming 消息转发给指定的 handler + - 将接收到的 outgoing 消息放入 outgoing 队列 + """ + + def __init__( + self, + *, + incoming_queue: mp.Queue, + outgoing_queue: mp.Queue, + core_handler: Callable[[MessageEnvelope], Awaitable[None]], + name: str | None = None, + ) -> None: + self._incoming_queue = incoming_queue + self._outgoing_queue = outgoing_queue + self._core_handler = core_handler + self._task: asyncio.Task | None = None + self._closed = False + self._name = name or "adapter" + + def start(self) -> None: + if self._task is None or self._task.done(): + self._task = asyncio.create_task(self._consume_incoming()) + + async def _consume_incoming(self) -> None: + while not self._closed: + try: + item = await asyncio.to_thread(self._incoming_queue.get) + except asyncio.CancelledError: + break + if isinstance(item, dict) and item.get("__core_sink_control__") == "stop": + break + if isinstance(item, dict) and item.get("kind") == "incoming": + envelope = item.get("payload") + try: + await self._core_handler(envelope) + except Exception: # pragma: no cover + logger.exception("Failed to dispatch incoming envelope from %s", self._name) + else: + logger.debug("ProcessCoreSinkServer ignored unknown payload from %s: %r", self._name, item) + + async def push_outgoing(self, envelope: MessageEnvelope) -> None: + await asyncio.to_thread(self._outgoing_queue.put, {"kind": "outgoing", "payload": envelope}) + + async def close(self) -> None: + if self._closed: + return + self._closed = True + await asyncio.to_thread(self._incoming_queue.put, {"__core_sink_control__": "stop"}) + await asyncio.to_thread(self._outgoing_queue.put, ProcessCoreSink._CONTROL_STOP) + if self._task: + self._task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._task + self._task = None async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) -> None: send_many = getattr(sink, "send_many", None) @@ -200,11 +390,19 @@ async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) -> await sink.send(env) +async def _maybe_await(result: Any) -> Any: + if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future): + return await result + return result + + class BatchDispatcher: """ - 将 send 操作合并为批量发送,适合网络 IO 密集场景。 + 批量消息分发器,负责将消息批量发送到核心 sink。 """ + _STOP = object() + def __init__( self, sink: CoreMessageSink, @@ -215,56 +413,79 @@ class BatchDispatcher: self._sink = sink self._max_batch_size = max_batch_size self._flush_interval = flush_interval - self._buffer: list[MessageEnvelope] = [] - self._lock = asyncio.Lock() - self._flush_task: asyncio.Task | None = None + self._queue: asyncio.Queue[MessageEnvelope | object] = asyncio.Queue() + self._worker: asyncio.Task | None = None self._closed = False async def add(self, message: MessageEnvelope) -> None: - async with self._lock: - if self._closed: - raise RuntimeError("Dispatcher closed") - self._buffer.append(message) - self._ensure_timer() - if len(self._buffer) >= self._max_batch_size: - await self._flush_locked() + if self._closed: + raise RuntimeError("Dispatcher closed") + self._ensure_worker() + await self._queue.put(message) async def close(self) -> None: - async with self._lock: - self._closed = True - await self._flush_locked() - if self._flush_task: - self._flush_task.cancel() - self._flush_task = None - - def _ensure_timer(self) -> None: - if self._flush_task is not None and not self._flush_task.done(): + if self._closed: return - loop = asyncio.get_running_loop() - self._flush_task = loop.create_task(self._flush_loop()) + self._closed = True + self._ensure_worker() + await self._queue.put(self._STOP) + if self._worker: + await self._worker + self._worker = None - async def _flush_loop(self) -> None: + def _ensure_worker(self) -> None: + if self._worker is not None and not self._worker.done(): + return + self._worker = asyncio.create_task(self._worker_loop()) + + async def _worker_loop(self) -> None: + buffer: list[MessageEnvelope] = [] try: - await asyncio.sleep(self._flush_interval) - async with self._lock: - await self._flush_locked() - except asyncio.CancelledError: # pragma: no cover - timer cancellation - pass + while True: + try: + item = await asyncio.wait_for(self._queue.get(), timeout=self._flush_interval) + except asyncio.TimeoutError: + item = None - async def _flush_locked(self) -> None: - if not self._buffer: + if item is self._STOP: + await self._flush_buffer(buffer) + return + if item is not None: + buffer.append(item) # type: ignore[arg-type] + + while len(buffer) < self._max_batch_size: + try: + item = self._queue.get_nowait() + except asyncio.QueueEmpty: + break + if item is self._STOP: + await self._flush_buffer(buffer) + return + buffer.append(item) # type: ignore[arg-type] + + if buffer and (len(buffer) >= self._max_batch_size or item is None): + await self._flush_buffer(buffer) + except asyncio.CancelledError: # pragma: no cover - worker cancellation + if buffer: + await self._flush_buffer(buffer) + + async def _flush_buffer(self, buffer: list[MessageEnvelope]) -> None: + if not buffer: return - payload = list(self._buffer) - self._buffer.clear() - await self._sink.send_many(payload) - + payload = list(buffer) + buffer.clear() + await _send_many(self._sink, payload) __all__ = [ "AdapterTransportOptions", "AdapterBase", "BatchDispatcher", + "CoreSink", "CoreMessageSink", "HttpAdapterOptions", "InProcessCoreSink", + "ProcessCoreSink", + "ProcessCoreSinkServer", + "WebSocketLike", "WebSocketAdapterOptions", ] diff --git a/src/mofox_bus/api.py b/src/mofox_bus/api.py index 11ae1d4ef..8058cc783 100644 --- a/src/mofox_bus/api.py +++ b/src/mofox_bus/api.py @@ -11,10 +11,36 @@ import orjson import uvicorn from fastapi import FastAPI, WebSocket, WebSocketDisconnect -from .message_models import MessageBase MessagePayload = Dict[str, Any] MessageHandler = Callable[[MessagePayload], Awaitable[None] | None] +DisconnectCallback = Callable[[str, str], Awaitable[None] | None] + + +def _attach_raw_bytes(payload: Any, raw_bytes: bytes) -> Any: + if isinstance(payload, dict): + payload.setdefault("raw_bytes", raw_bytes) + elif isinstance(payload, list): + for item in payload: + if isinstance(item, dict): + item.setdefault("raw_bytes", raw_bytes) + return payload + + +def _encode_for_ws_send(message: Any, *, use_raw_bytes: bool = False) -> tuple[str | bytes, bool]: + if isinstance(message, (bytes, bytearray)): + return bytes(message), True + if use_raw_bytes and isinstance(message, dict): + raw = message.get("raw_bytes") + if isinstance(raw, (bytes, bytearray)): + return bytes(raw), True + payload = message + if isinstance(payload, dict) and "raw_bytes" in payload and not use_raw_bytes: + payload = {k: v for k, v in payload.items() if k != "raw_bytes"} + data = orjson.dumps(payload) + if use_raw_bytes: + return data, True + return data.decode("utf-8"), False class BaseMessageHandler: @@ -60,6 +86,8 @@ class MessageServer(BaseMessageHandler): mode: Literal["ws", "tcp"] = "ws", custom_logger: logging.Logger | None = None, enable_custom_uvicorn_logger: bool = False, + queue_maxsize: int = 1000, + worker_count: int = 1, ) -> None: super().__init__() if mode != "ws": @@ -80,6 +108,9 @@ class MessageServer(BaseMessageHandler): self._conn_lock = asyncio.Lock() self._server: uvicorn.Server | None = None self._running = False + self._message_queue: asyncio.Queue[MessagePayload] = asyncio.Queue(maxsize=queue_maxsize) + self._worker_count = max(1, worker_count) + self._worker_tasks: list[asyncio.Task] = [] self._setup_routes() def _setup_routes(self) -> None: @@ -97,21 +128,22 @@ class MessageServer(BaseMessageHandler): while True: msg = await websocket.receive() if msg["type"] == "websocket.receive": - data = msg.get("text") - if data is None and msg.get("bytes") is not None: - data = msg["bytes"].decode("utf-8") - if not data: + raw_bytes = msg.get("bytes") + if raw_bytes is None and msg.get("text") is not None: + raw_bytes = msg["text"].encode("utf-8") + if not raw_bytes: continue try: - payload = orjson.loads(data) + payload = orjson.loads(raw_bytes) except orjson.JSONDecodeError: logging.getLogger("mofox_bus.server").warning("Invalid JSON payload") continue + payload = _attach_raw_bytes(payload, raw_bytes) if isinstance(payload, list): for item in payload: - await self.process_message(item) + await self._enqueue_message(item) else: - await self.process_message(payload) + await self._enqueue_message(payload) elif msg["type"] == "websocket.disconnect": break except WebSocketDisconnect: @@ -119,6 +151,49 @@ class MessageServer(BaseMessageHandler): finally: await self._remove_connection(websocket, platform) + async def _enqueue_message(self, payload: MessagePayload) -> None: + if not self._worker_tasks: + self._start_workers() + try: + self._message_queue.put_nowait(payload) + except asyncio.QueueFull: + logging.getLogger("mofox_bus.server").warning("Message queue full, dropping message") + + def _start_workers(self) -> None: + if self._worker_tasks: + return + self._running = True + for _ in range(self._worker_count): + task = asyncio.create_task(self._consumer_worker()) + self._worker_tasks.append(task) + + async def _stop_workers(self) -> None: + if not self._worker_tasks: + return + self._running = False + for task in self._worker_tasks: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await asyncio.gather(*self._worker_tasks, return_exceptions=True) + self._worker_tasks.clear() + while not self._message_queue.empty(): + with contextlib.suppress(asyncio.QueueEmpty): + self._message_queue.get_nowait() + self._message_queue.task_done() + + async def _consumer_worker(self) -> None: + while self._running: + try: + payload = await self._message_queue.get() + except asyncio.CancelledError: + break + try: + await self.process_message(payload) + except Exception: # pragma: no cover - best effort logging + logging.getLogger("mofox_bus.server").exception("Error processing message") + finally: + self._message_queue.task_done() + async def verify_token(self, token: str | None) -> bool: if not self._enable_token: return True @@ -145,33 +220,45 @@ class MessageServer(BaseMessageHandler): if platform and self._platform_connections.get(platform) is websocket: del self._platform_connections[platform] - async def broadcast_message(self, message: MessagePayload) -> None: - data = orjson.dumps(message).decode("utf-8") + async def broadcast_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> None: + payload: MessagePayload | bytes = message + data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes) async with self._conn_lock: targets = list(self._connections) for ws in targets: - await ws.send_text(data) + if is_binary: + await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8")) + else: + await ws.send_text(data if isinstance(data, str) else data.decode("utf-8")) - async def broadcast_to_platform(self, platform: str, message: MessagePayload) -> None: + async def broadcast_to_platform( + self, platform: str, message: MessagePayload | bytes, *, use_raw_bytes: bool = False + ) -> None: ws = self._platform_connections.get(platform) if ws is None: raise RuntimeError(f"No active connection for platform {platform}") - await ws.send_text(orjson.dumps(message).decode("utf-8")) + payload: MessagePayload | bytes = message.to_dict() if isinstance(message, MessageBase) else message + data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes) + if is_binary: + await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8")) + else: + await ws.send_text(data if isinstance(data, str) else data.decode("utf-8")) - async def send_message(self, message: MessageBase | MessagePayload) -> None: - payload = message.to_dict() if isinstance(message, MessageBase) else message - platform = payload.get("message_info", {}).get("platform") + async def send_message( + self, message: MessagePayload, *, prefer_raw_bytes: bool = False + ) -> None: + platform = message.get("message_info", {}).get("platform") if not platform: raise ValueError("message_info.platform is required to route the message") - await self.broadcast_to_platform(platform, payload) - + await self.broadcast_to_platform(platform, message, use_raw_bytes=prefer_raw_bytes) + def run_sync(self) -> None: if not self._own_app: return asyncio.run(self.run()) async def run(self) -> None: - self._running = True + self._start_workers() if not self._own_app: return config = uvicorn.Config( @@ -191,6 +278,7 @@ class MessageServer(BaseMessageHandler): async def stop(self) -> None: self._running = False + await self._stop_workers() if self._server: self._server.should_exit = True await self._server.shutdown() @@ -217,7 +305,13 @@ class MessageClient(BaseMessageHandler): WebSocket 消息客户端,实现双向传输。 """ - def __init__(self, mode: Literal["ws", "tcp"] = "ws") -> None: + def __init__( + self, + mode: Literal["ws", "tcp"] = "ws", + *, + reconnect_interval: float = 5.0, + logger: logging.Logger | None = None, + ) -> None: super().__init__() if mode != "ws": raise NotImplementedError("Only WebSocket mode is supported in mofox_bus") @@ -230,6 +324,9 @@ class MessageClient(BaseMessageHandler): self._token: str | None = None self._ssl_verify: str | None = None self._closed = False + self._on_disconnect: DisconnectCallback | None = None + self._reconnect_interval = reconnect_interval + self._logger = logger or logging.getLogger("mofox_bus.client") async def connect( self, @@ -243,8 +340,12 @@ class MessageClient(BaseMessageHandler): self._platform = platform self._token = token self._ssl_verify = ssl_verify + self._closed = False await self._establish_connection() + def set_disconnect_callback(self, callback: DisconnectCallback) -> None: + self._on_disconnect = callback + async def _establish_connection(self) -> None: if self._session is None: self._session = aiohttp.ClientSession() @@ -257,17 +358,21 @@ class MessageClient(BaseMessageHandler): self._ws = await self._session.ws_connect(self._url, headers=headers, ssl=ssl_context) self._receive_task = asyncio.create_task(self._receive_loop()) + async def _connect_once(self) -> None: + await self._establish_connection() + async def _receive_loop(self) -> None: assert self._ws is not None try: async for msg in self._ws: if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY): - data = msg.data if isinstance(msg.data, str) else msg.data.decode("utf-8") + raw_bytes = msg.data if isinstance(msg.data, (bytes, bytearray)) else msg.data.encode("utf-8") try: - payload = orjson.loads(data) + payload = orjson.loads(raw_bytes) except orjson.JSONDecodeError: logging.getLogger("mofox_bus.client").warning("Invalid JSON payload") continue + payload = _attach_raw_bytes(payload, raw_bytes) if isinstance(payload, list): for item in payload: await self.process_message(item) @@ -278,23 +383,33 @@ class MessageClient(BaseMessageHandler): except asyncio.CancelledError: # pragma: no cover - cancellation path pass finally: + if not self._closed: + await self._notify_disconnect("websocket disconnected") + await self._reconnect() if self._ws: await self._ws.close() self._ws = None async def run(self) -> None: - if self._receive_task is None: - await self._establish_connection() - try: - if self._receive_task: - await self._receive_task - except asyncio.CancelledError: # pragma: no cover - cancellation path - pass + self._closed = False + while not self._closed: + if self._receive_task is None: + await self._establish_connection() + task = self._receive_task + if task is None: + break + try: + await task + except asyncio.CancelledError: # pragma: no cover - cancellation path + raise - async def send_message(self, message: MessagePayload) -> bool: - if self._ws is None or self._ws.closed: - raise RuntimeError("WebSocket connection is not established") - await self._ws.send_str(orjson.dumps(message).decode("utf-8")) + async def send_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> bool: + ws = await self._ensure_ws() + data, is_binary = _encode_for_ws_send(message, use_raw_bytes=use_raw_bytes) + if is_binary: + await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8")) + else: + await ws.send_str(data if isinstance(data, str) else data.decode("utf-8")) return True def is_connected(self) -> bool: @@ -313,6 +428,42 @@ class MessageClient(BaseMessageHandler): await self._session.close() self._session = None + async def _notify_disconnect(self, reason: str) -> None: + if self._on_disconnect is None: + return + try: + result = self._on_disconnect(self._platform, reason) + if asyncio.iscoroutine(result): + await result + except Exception: # pragma: no cover - best effort notification + logging.getLogger("mofox_bus.client").exception("Disconnect callback failed") + + async def _reconnect(self) -> None: + self._logger.info("WebSocket disconnected, retrying in %.1fs", self._reconnect_interval) + await asyncio.sleep(self._reconnect_interval) + await self._connect_once() + + async def _ensure_session(self) -> aiohttp.ClientSession: + if self._session is None: + self._session = aiohttp.ClientSession() + return self._session + + async def _ensure_ws(self) -> aiohttp.ClientWebSocketResponse: + if self._ws is None or self._ws.closed: + await self._connect_once() + assert self._ws is not None + return self._ws + + async def __aenter__(self) -> "MessageClient": + if not self._url or not self._platform: + raise RuntimeError("connect() must be called before using MessageClient as a context manager") + await self._ensure_session() + await self._ensure_ws() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.stop() + def _self_websocket(app: FastAPI, path: str): """ diff --git a/src/mofox_bus/builder.py b/src/mofox_bus/builder.py new file mode 100644 index 000000000..240f2a195 --- /dev/null +++ b/src/mofox_bus/builder.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import time +import uuid +from typing import Any, Dict, List + +from .types import GroupInfoPayload, MessageEnvelope, MessageInfoPayload, SegPayload, UserInfoPayload + + +class MessageBuilder: + """ + Fluent helper to build MessageEnvelope safely with type hints. + + Example: + msg = ( + MessageBuilder() + .text("Hello") + .image("http://example.com/1.png") + .to_user("123", platform="qq") + .build() + ) + """ + + def __init__(self) -> None: + self._direction: str = "outgoing" + self._message_info: MessageInfoPayload = {} + self._segments: List[SegPayload] = [] + self._metadata: Dict[str, Any] | None = None + self._timestamp_ms: int | None = None + self._message_id: str | None = None + + def direction(self, value: str) -> "MessageBuilder": + self._direction = value + return self + + def message_id(self, value: str) -> "MessageBuilder": + self._message_id = value + return self + + def timestamp_ms(self, value: int | None = None) -> "MessageBuilder": + self._timestamp_ms = value or int(time.time() * 1000) + return self + + def metadata(self, value: Dict[str, Any]) -> "MessageBuilder": + self._metadata = value + return self + + def platform(self, value: str) -> "MessageBuilder": + self._message_info["platform"] = value + return self + + def from_user(self, user_id: str, *, platform: str | None = None, nickname: str | None = None) -> "MessageBuilder": + if platform: + self.platform(platform) + user_info: UserInfoPayload = {"user_id": user_id} + if nickname: + user_info["user_nickname"] = nickname + self._message_info["user_info"] = user_info + return self + + def from_group(self, group_id: str, *, platform: str | None = None, name: str | None = None) -> "MessageBuilder": + if platform: + self.platform(platform) + group_info: GroupInfoPayload = {"group_id": group_id} + if name: + group_info["group_name"] = name + self._message_info["group_info"] = group_info + return self + + def seg(self, type_: str, data: Any) -> "MessageBuilder": + self._segments.append({"type": type_, "data": data}) + return self + + def text(self, content: str) -> "MessageBuilder": + return self.seg("text", content) + + def image(self, url: str) -> "MessageBuilder": + return self.seg("image", url) + + def reply(self, target_message_id: str) -> "MessageBuilder": + return self.seg("reply", target_message_id) + + def raw_segment(self, segment: SegPayload) -> "MessageBuilder": + self._segments.append(segment) + return self + + def build(self) -> MessageEnvelope: + # message_info defaults + if not self._segments: + raise ValueError("message_segment is required, add at least one segment before build()") + if self._message_id is None: + self._message_id = str(uuid.uuid4()) + info = dict(self._message_info) + info.setdefault("message_id", self._message_id) + info.setdefault("time", time.time()) + + segments = [seg.copy() if isinstance(seg, dict) else seg for seg in self._segments] + envelope: MessageEnvelope = { + "direction": self._direction, # type: ignore[assignment] + "message_info": info, + "message_segment": segments[0] if len(segments) == 1 else list(segments), + } + if self._metadata is not None: + envelope["metadata"] = self._metadata + if self._timestamp_ms is not None: + envelope["timestamp_ms"] = self._timestamp_ms + return envelope + + +__all__ = ["MessageBuilder"] diff --git a/src/mofox_bus/codec.py b/src/mofox_bus/codec.py index 6f8d23cc6..f6430c824 100644 --- a/src/mofox_bus/codec.py +++ b/src/mofox_bus/codec.py @@ -27,24 +27,23 @@ def _loads(data: bytes) -> Dict[str, Any]: def dumps_message(msg: MessageEnvelope) -> bytes: """ - 将单条 MessageEnvelope 序列化为 JSON bytes。 + 将单条消息序列化为 JSON bytes。 """ - if "schema_version" not in msg: - msg["schema_version"] = DEFAULT_SCHEMA_VERSION - return _dumps(msg) - + sanitized = _strip_raw_bytes(msg) + if "schema_version" not in sanitized: + sanitized["schema_version"] = DEFAULT_SCHEMA_VERSION + return _dumps(sanitized) def dumps_messages(messages: Iterable[MessageEnvelope]) -> bytes: """ - 将多条消息批量序列化,以提升吞吐。 + 将批量消息序列化为 JSON bytes。 """ payload = { "schema_version": DEFAULT_SCHEMA_VERSION, - "items": list(messages), + "items": [_strip_raw_bytes(msg) for msg in messages], } return _dumps(payload) - def loads_message(data: bytes | str) -> MessageEnvelope: """ 反序列化单条消息。 @@ -78,6 +77,14 @@ def _upgrade_schema_if_needed(obj: Dict[str, Any]) -> MessageEnvelope: raise ValueError(f"Unsupported schema_version={version}") + +def _strip_raw_bytes(msg: MessageEnvelope) -> MessageEnvelope: + if isinstance(msg, dict) and "raw_bytes" in msg: + new_msg = dict(msg) + new_msg.pop("raw_bytes", None) + return new_msg # type: ignore[return-value] + return msg + __all__ = [ "DEFAULT_SCHEMA_VERSION", "dumps_message", diff --git a/src/mofox_bus/message_models.py b/src/mofox_bus/message_models.py deleted file mode 100644 index ad408b63f..000000000 --- a/src/mofox_bus/message_models.py +++ /dev/null @@ -1,189 +0,0 @@ -from __future__ import annotations - -from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Optional - - -@dataclass -class Seg: - """ - 消息段,表示一段文本/图片/结构化内容。 - """ - - type: str - data: Any - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Seg": - seg_type = data.get("type") - seg_data = data.get("data") - if seg_type == "seglist" and isinstance(seg_data, list): - seg_data = [Seg.from_dict(item) for item in seg_data] - return cls(type=seg_type, data=seg_data) - - def to_dict(self) -> Dict[str, Any]: - if self.type == "seglist" and isinstance(self.data, list): - payload = [seg.to_dict() if isinstance(seg, Seg) else seg for seg in self.data] - else: - payload = self.data - return {"type": self.type, "data": payload} - - -@dataclass -class GroupInfo: - platform: Optional[str] = None - group_id: Optional[str] = None - group_name: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - return {k: v for k, v in asdict(self).items() if v is not None} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> Optional["GroupInfo"]: - if not data or data.get("group_id") is None: - return None - return cls( - platform=data.get("platform"), - group_id=data.get("group_id"), - group_name=data.get("group_name"), - ) - - -@dataclass -class UserInfo: - platform: Optional[str] = None - user_id: Optional[str] = None - user_nickname: Optional[str] = None - user_cardname: Optional[str] = None - user_avatar: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - return {k: v for k, v in asdict(self).items() if v is not None} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "UserInfo": - return cls( - platform=data.get("platform"), - user_id=data.get("user_id"), - user_nickname=data.get("user_nickname"), - user_cardname=data.get("user_cardname"), - user_avatar=data.get("user_avatar"), - ) - - -@dataclass -class FormatInfo: - content_format: Optional[List[str]] = None - accept_format: Optional[List[str]] = None - - def to_dict(self) -> Dict[str, Any]: - return {k: v for k, v in asdict(self).items() if v is not None} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> Optional["FormatInfo"]: - if not data: - return None - return cls( - content_format=data.get("content_format"), - accept_format=data.get("accept_format"), - ) - - -@dataclass -class TemplateInfo: - template_items: Optional[Dict[str, str]] = None - template_name: Optional[Dict[str, str]] = None - template_default: bool = True - - def to_dict(self) -> Dict[str, Any]: - return {k: v for k, v in asdict(self).items() if v is not None} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> Optional["TemplateInfo"]: - if not data: - return None - return cls( - template_items=data.get("template_items"), - template_name=data.get("template_name"), - template_default=data.get("template_default", True), - ) - - -@dataclass -class BaseMessageInfo: - platform: Optional[str] = None - message_id: Optional[str] = None - time: Optional[float] = None - group_info: Optional[GroupInfo] = None - user_info: Optional[UserInfo] = None - format_info: Optional[FormatInfo] = None - template_info: Optional[TemplateInfo] = None - additional_config: Optional[Dict[str, Any]] = None - - def to_dict(self) -> Dict[str, Any]: - result: Dict[str, Any] = {} - if self.platform is not None: - result["platform"] = self.platform - if self.message_id is not None: - result["message_id"] = self.message_id - if self.time is not None: - result["time"] = self.time - if self.additional_config is not None: - result["additional_config"] = self.additional_config - if self.group_info is not None: - result["group_info"] = self.group_info.to_dict() - if self.user_info is not None: - result["user_info"] = self.user_info.to_dict() - if self.format_info is not None: - result["format_info"] = self.format_info.to_dict() - if self.template_info is not None: - result["template_info"] = self.template_info.to_dict() - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "BaseMessageInfo": - return cls( - platform=data.get("platform"), - message_id=data.get("message_id"), - time=data.get("time"), - additional_config=data.get("additional_config"), - group_info=GroupInfo.from_dict(data.get("group_info", {})), - user_info=UserInfo.from_dict(data.get("user_info", {})), - format_info=FormatInfo.from_dict(data.get("format_info", {})), - template_info=TemplateInfo.from_dict(data.get("template_info", {})), - ) - - -@dataclass -class MessageBase: - message_info: BaseMessageInfo - message_segment: Seg - raw_message: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - payload: Dict[str, Any] = { - "message_info": self.message_info.to_dict(), - "message_segment": self.message_segment.to_dict(), - } - if self.raw_message is not None: - payload["raw_message"] = self.raw_message - return payload - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "MessageBase": - return cls( - message_info=BaseMessageInfo.from_dict(data.get("message_info", {})), - message_segment=Seg.from_dict(data.get("message_segment", {})), - raw_message=data.get("raw_message"), - ) - - -__all__ = [ - "BaseMessageInfo", - "FormatInfo", - "GroupInfo", - "MessageBase", - "Seg", - "TemplateInfo", - "UserInfo", -] diff --git a/src/mofox_bus/router.py b/src/mofox_bus/router.py index 46e29de27..6be2d1408 100644 --- a/src/mofox_bus/router.py +++ b/src/mofox_bus/router.py @@ -7,7 +7,7 @@ from dataclasses import asdict, dataclass from typing import Callable, Dict, Optional from .api import MessageClient -from .message_models import MessageBase +from .types import MessageEnvelope logger = logging.getLogger("mofox_bus.router") @@ -55,7 +55,7 @@ class Router: self.handlers: list[Callable[[Dict], None]] = [] self._running = False self._client_tasks: Dict[str, asyncio.Task] = {} - self._monitor_task: asyncio.Task | None = None + self._stop_event: asyncio.Event | None = None async def connect(self, platform: str) -> None: if platform not in self.config.route_config: @@ -65,6 +65,7 @@ class Router: if mode != "ws": raise NotImplementedError("TCP mode is not implemented yet") client = MessageClient(mode="ws") + client.set_disconnect_callback(self._handle_client_disconnect) await client.connect( url=target.url, platform=platform, @@ -75,7 +76,7 @@ class Router: client.register_message_handler(handler) self.clients[platform] = client if self._running: - self._client_tasks[platform] = asyncio.create_task(client.run()) + self._start_client_task(platform, client) def register_class_handler(self, handler: Callable[[Dict], None]) -> None: self.handlers.append(handler) @@ -84,36 +85,18 @@ class Router: async def run(self) -> None: self._running = True + self._stop_event = asyncio.Event() for platform in self.config.route_config: if platform not in self.clients: await self.connect(platform) for platform, client in self.clients.items(): if platform not in self._client_tasks: - self._client_tasks[platform] = asyncio.create_task(client.run()) - self._monitor_task = asyncio.create_task(self._monitor_connections()) + self._start_client_task(platform, client) try: - while self._running: - await asyncio.sleep(1) + await self._stop_event.wait() except asyncio.CancelledError: # pragma: no cover raise - async def _monitor_connections(self) -> None: - await asyncio.sleep(3) - while self._running: - for platform in list(self.clients.keys()): - client = self.clients.get(platform) - if client is None: - continue - if not client.is_connected(): - logger.info("Detected disconnect from %s, attempting reconnect", platform) - await self._reconnect_platform(platform) - await asyncio.sleep(5) - - async def _reconnect_platform(self, platform: str) -> None: - await self.remove_platform(platform) - if platform in self.config.route_config: - await self.connect(platform) - async def remove_platform(self, platform: str) -> None: if platform in self._client_tasks: task = self._client_tasks.pop(platform) @@ -124,32 +107,55 @@ class Router: if client: await client.stop() + async def _handle_client_disconnect(self, platform: str, reason: str) -> None: + logger.info("Client for %s disconnected: %s (auto-reconnect handled by client)", platform, reason) + task = self._client_tasks.get(platform) + if task is not None and not task.done(): + return + client = self.clients.get(platform) + if client and self._running: + self._start_client_task(platform, client) + async def stop(self) -> None: self._running = False - if self._monitor_task: - self._monitor_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._monitor_task - self._monitor_task = None + if self._stop_event: + self._stop_event.set() for platform in list(self.clients.keys()): await self.remove_platform(platform) self.clients.clear() - def get_target_url(self, message: MessageBase) -> Optional[str]: - platform = message.message_info.platform + def _start_client_task(self, platform: str, client: MessageClient) -> None: + task = asyncio.create_task(client.run()) + task.add_done_callback(lambda t, plat=platform: asyncio.create_task(self._restart_if_needed(plat, t))) + self._client_tasks[platform] = task + + async def _restart_if_needed(self, platform: str, task: asyncio.Task) -> None: + if not self._running: + return + if task.cancelled(): + return + exc = task.exception() + if exc: + logger.warning("Client task for %s ended with exception: %s", platform, exc) + client = self.clients.get(platform) + if client: + self._start_client_task(platform, client) + + def get_target_url(self, message: MessageEnvelope) -> Optional[str]: + platform = message.get("message_info", {}).get("platform") if not platform: return None target = self.config.route_config.get(platform) return target.url if target else None - async def send_message(self, message: MessageBase): - platform = message.message_info.platform + async def send_message(self, message: MessageEnvelope): + platform = message.get("message_info", {}).get("platform") if not platform: raise ValueError("message_info.platform is required") client = self.clients.get(platform) if client is None: raise RuntimeError(f"No client connected for platform {platform}") - return await client.send_message(message.to_dict()) + return await client.send_message(message) async def update_config(self, config_data: Dict[str, Dict[str, str | None]]) -> None: new_config = RouteConfig.from_dict(config_data) diff --git a/src/mofox_bus/runtime.py b/src/mofox_bus/runtime.py index 1d41a445b..ad46a8780 100644 --- a/src/mofox_bus/runtime.py +++ b/src/mofox_bus/runtime.py @@ -1,9 +1,10 @@ from __future__ import annotations import asyncio +import inspect import threading from dataclasses import dataclass -from typing import Awaitable, Callable, Iterable, List +from typing import Awaitable, Callable, Dict, Iterable, List, Protocol from .types import MessageEnvelope @@ -12,6 +13,11 @@ ErrorHook = Callable[[MessageEnvelope, BaseException], Awaitable[None] | None] Predicate = Callable[[MessageEnvelope], bool | Awaitable[bool]] MessageHandler = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None] | MessageEnvelope | None] BatchHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None] | List[MessageEnvelope] | None] +MiddlewareCallable = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None]] + + +class Middleware(Protocol): + async def __call__(self, message: MessageEnvelope, handler: MiddlewareCallable) -> MessageEnvelope | None: ... class MessageProcessingError(RuntimeError): @@ -19,7 +25,7 @@ class MessageProcessingError(RuntimeError): def __init__(self, message: MessageEnvelope, original: BaseException): detail = message.get("id", "") - super().__init__(f"Failed to handle message {detail}: {original}") # pragma: no cover - str repr only + super().__init__(f"Failed to handle message {detail}: {original}") self.message_envelope = message self.original = original @@ -29,6 +35,8 @@ class MessageRoute: predicate: Predicate handler: MessageHandler name: str | None = None + message_type: str | None = None + event_types: set[str] | None = None class MessageRuntime: @@ -43,15 +51,36 @@ class MessageRuntime: self._error_hooks: list[ErrorHook] = [] self._batch_handler: BatchHandler | None = None self._lock = threading.RLock() + self._middlewares: list[Middleware] = [] + self._type_routes: Dict[str, list[MessageRoute]] = {} + self._event_routes: Dict[str, list[MessageRoute]] = {} - def add_route(self, predicate: Predicate, handler: MessageHandler, name: str | None = None) -> None: + def add_route( + self, + predicate: Predicate, + handler: MessageHandler, + name: str | None = None, + *, + message_type: str | None = None, + event_types: Iterable[str] | None = None, + ) -> None: with self._lock: - self._routes.append(MessageRoute(predicate=predicate, handler=handler, name=name)) + route = MessageRoute( + predicate=predicate, + handler=handler, + name=name, + message_type=message_type, + event_types=set(event_types) if event_types is not None else None, + ) + self._routes.append(route) + if message_type: + self._type_routes.setdefault(message_type, []).append(route) + if route.event_types: + for et in route.event_types: + self._event_routes.setdefault(et, []).append(route) def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]: - """ - 装饰器写法,便于在核心逻辑中声明式注册。 - """ + """装饰器写法,便于在核心逻辑中声明式注册。""" def decorator(func: MessageHandler) -> MessageHandler: self.add_route(predicate, func, name=name) @@ -59,6 +88,60 @@ class MessageRuntime: return decorator + def on_message( + self, + *, + message_type: str | None = None, + platform: str | None = None, + predicate: Predicate | None = None, + name: str | None = None, + ) -> Callable[[MessageHandler], MessageHandler]: + """Sugar 装饰器,基于 Seg.type/platform 及可选额外谓词匹配。""" + + async def combined_predicate(message: MessageEnvelope) -> bool: + if message_type is not None and _extract_segment_type(message) != message_type: + return False + if platform is not None: + info_platform = message.get("message_info", {}).get("platform") + if message.get("platform") not in (None, platform) and info_platform is None: + return False + if info_platform not in (None, platform): + return False + if predicate is None: + return True + return await _invoke_callable(predicate, message, prefer_thread=False) + + def decorator(func: MessageHandler) -> MessageHandler: + self.add_route(combined_predicate, func, name=name, message_type=message_type) + return func + + return decorator + + def on_event( + self, + event_type: str | Iterable[str], + *, + name: str | None = None, + ) -> Callable[[MessageHandler], MessageHandler]: + """装饰器,基于 message 或 message_info.additional_config 中的 event_type 匹配。""" + + allowed = {event_type} if isinstance(event_type, str) else set(event_type) + + async def predicate(message: MessageEnvelope) -> bool: + current = ( + message.get("event_type") + or message.get("message_info", {}) + .get("additional_config", {}) + .get("event_type") + ) + return current in allowed + + def decorator(func: MessageHandler) -> MessageHandler: + self.add_route(predicate, func, name=name, event_types=allowed) + return func + + return decorator + def set_batch_handler(self, handler: BatchHandler) -> None: self._batch_handler = handler @@ -71,14 +154,20 @@ class MessageRuntime: def register_error_hook(self, hook: ErrorHook) -> None: self._error_hooks.append(hook) + def register_middleware(self, middleware: Middleware) -> None: + """注册洋葱模型中间件,围绕处理器执行。""" + + self._middlewares.append(middleware) + async def handle_message(self, message: MessageEnvelope) -> MessageEnvelope | None: await self._run_hooks(self._before_hooks, message) try: route = await self._match_route(message) if route is None: return None - result = await _maybe_await(route.handler(message)) - except Exception as exc: # pragma: no cover - tested indirectly + handler = self._wrap_with_middlewares(route.handler) + result = await handler(message) + except Exception as exc: await self._run_error_hooks(message, exc) raise MessageProcessingError(message, exc) from exc await self._run_hooks(self._after_hooks, message) @@ -89,7 +178,7 @@ class MessageRuntime: if not batch: return [] if self._batch_handler is not None: - result = await _maybe_await(self._batch_handler(batch)) + result = await _invoke_callable(self._batch_handler, batch, prefer_thread=True) return result or [] responses: list[MessageEnvelope] = [] for message in batch: @@ -99,21 +188,61 @@ class MessageRuntime: return responses async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None: + candidates: list[MessageRoute] = [] + message_type = _extract_segment_type(message) + event_type = ( + message.get("event_type") + or message.get("message_info", {}) + .get("additional_config", {}) + .get("event_type") + ) with self._lock: - routes = list(self._routes) - for route in routes: - should_handle = await _maybe_await(route.predicate(message)) + if event_type and event_type in self._event_routes: + candidates.extend(self._event_routes[event_type]) + if message_type and message_type in self._type_routes: + candidates.extend(self._type_routes[message_type]) + candidates.extend(self._routes) + + seen: set[int] = set() + for route in candidates: + rid = id(route) + if rid in seen: + continue + seen.add(rid) + should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False) if should_handle: return route return None async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None: - for hook in hooks: - await _maybe_await(hook(message)) + coro_list = [self._call_hook(hook, message) for hook in hooks] + if coro_list: + await asyncio.gather(*coro_list) + + async def _call_hook(self, hook: Hook, message: MessageEnvelope) -> None: + await _invoke_callable(hook, message, prefer_thread=True) async def _run_error_hooks(self, message: MessageEnvelope, exc: BaseException) -> None: - for hook in self._error_hooks: - await _maybe_await(hook(message, exc)) + coros = [self._call_error_hook(hook, message, exc) for hook in self._error_hooks] + if coros: + await asyncio.gather(*coros) + + async def _call_error_hook(self, hook: ErrorHook, message: MessageEnvelope, exc: BaseException) -> None: + await _invoke_callable(hook, message, exc, prefer_thread=True) + + def _wrap_with_middlewares(self, handler: MessageHandler) -> MiddlewareCallable: + async def base_handler(message: MessageEnvelope) -> MessageEnvelope | None: + return await _invoke_callable(handler, message, prefer_thread=True) + + wrapped: MiddlewareCallable = base_handler + for middleware in reversed(self._middlewares): + current = wrapped + + async def wrapper(msg: MessageEnvelope, mw=middleware, nxt=current) -> MessageEnvelope | None: + return await _invoke_callable(mw, msg, nxt, prefer_thread=False) + + wrapped = wrapper + return wrapped async def _maybe_await(result): @@ -122,6 +251,32 @@ async def _maybe_await(result): return result +async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False): + """Support sync/async callables with optional thread offloading.""" + if inspect.iscoroutinefunction(func): + return await func(*args) + if prefer_thread: + result = await asyncio.to_thread(func, *args) + if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future): + return await result + return result + result = func(*args) + if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future): + return await result + return result + + +def _extract_segment_type(message: MessageEnvelope) -> str | None: + seg = message.get("message_segment") or message.get("message_chain") + if isinstance(seg, dict): + return seg.get("type") + if isinstance(seg, list) and seg: + first = seg[0] + if isinstance(first, dict): + return first.get("type") + return None + + __all__ = [ "BatchHandler", "Hook", @@ -129,5 +284,6 @@ __all__ = [ "MessageProcessingError", "MessageRoute", "MessageRuntime", + "Middleware", "Predicate", ] diff --git a/src/mofox_bus/types.py b/src/mofox_bus/types.py index f4a9eebae..9ff517629 100644 --- a/src/mofox_bus/types.py +++ b/src/mofox_bus/types.py @@ -3,160 +3,91 @@ from __future__ import annotations from typing import Any, Dict, List, Literal, NotRequired, TypedDict MessageDirection = Literal["incoming", "outgoing"] -Role = Literal["user", "assistant", "system", "tool", "platform"] -ContentType = Literal[ - "text", - "image", - "audio", - "file", - "video", - "event", - "command", - "system", -] -EventType = Literal[ - "message_created", - "message_updated", - "message_deleted", - "member_join", - "member_leave", - "typing", - "reaction_add", - "reaction_remove", -] +# ---------------------------- +# maim_message 风格的 TypedDict +# ---------------------------- -class TextContent(TypedDict, total=False): - type: Literal["text"] - text: str - markdown: NotRequired[bool] - entities: NotRequired[List[Dict[str, Any]]] +class SegPayload(TypedDict, total=False): + """ + 对齐 maim_message.Seg 的片段定义,使用纯 dict 便于 JSON 传输。 + """ + + type: str + data: str | List["SegPayload"] + translated_data: NotRequired[str | List["SegPayload"]] -class ImageContent(TypedDict, total=False): - type: Literal["image"] - url: str - mime_type: NotRequired[str] - width: NotRequired[int] - height: NotRequired[int] - file_id: NotRequired[str] +class UserInfoPayload(TypedDict, total=False): + platform: NotRequired[str] + user_id: NotRequired[str] + user_nickname: NotRequired[str] + user_cardname: NotRequired[str] + user_avatar: NotRequired[str] -class FileContent(TypedDict, total=False): - type: Literal["file"] - url: str - mime_type: NotRequired[str] - file_name: NotRequired[str] - file_size: NotRequired[int] - file_id: NotRequired[str] +class GroupInfoPayload(TypedDict, total=False): + platform: NotRequired[str] + group_id: NotRequired[str] + group_name: NotRequired[str] -class AudioContent(TypedDict, total=False): - type: Literal["audio"] - url: str - mime_type: NotRequired[str] - duration_ms: NotRequired[int] - file_id: NotRequired[str] +class FormatInfoPayload(TypedDict, total=False): + content_format: NotRequired[List[str]] + accept_format: NotRequired[List[str]] -class VideoContent(TypedDict, total=False): - type: Literal["video"] - url: str - mime_type: NotRequired[str] - duration_ms: NotRequired[int] - width: NotRequired[int] - height: NotRequired[int] - file_id: NotRequired[str] +class TemplateInfoPayload(TypedDict, total=False): + template_items: NotRequired[Dict[str, str]] + template_name: NotRequired[Dict[str, str]] + template_default: NotRequired[bool] -class EventContent(TypedDict): - type: Literal["event"] - event_type: EventType - raw: Dict[str, Any] +class MessageInfoPayload(TypedDict, total=False): + platform: NotRequired[str] + message_id: NotRequired[str] + time: NotRequired[float] + group_info: NotRequired[GroupInfoPayload] + user_info: NotRequired[UserInfoPayload] + format_info: NotRequired[FormatInfoPayload] + template_info: NotRequired[TemplateInfoPayload] + additional_config: NotRequired[Dict[str, Any]] - -class CommandContent(TypedDict, total=False): - type: Literal["command"] - name: str - args: Dict[str, Any] - - -class SystemContent(TypedDict): - type: Literal["system"] - text: str - - -Content = ( - TextContent - | ImageContent - | FileContent - | AudioContent - | VideoContent - | EventContent - | CommandContent - | SystemContent -) - - -class SenderInfo(TypedDict, total=False): - user_id: str - role: Role - display_name: NotRequired[str] - avatar_url: NotRequired[str] - raw: NotRequired[Dict[str, Any]] - - -class ChannelInfo(TypedDict, total=False): - channel_id: str - channel_type: Literal[ - "private", - "group", - "supergroup", - "channel", - "dm", - "room", - "thread", - ] - title: NotRequired[str] - workspace_id: NotRequired[str] - raw: NotRequired[Dict[str, Any]] +# ---------------------------- +# MessageEnvelope +# ---------------------------- class MessageEnvelope(TypedDict, total=False): - id: str - direction: MessageDirection - platform: str - timestamp_ms: int - channel: ChannelInfo - sender: SenderInfo - content: Content - conversation_id: str - thread_id: NotRequired[str] - reply_to_message_id: NotRequired[str] - correlation_id: NotRequired[str] - is_edited: NotRequired[bool] - is_ephemeral: NotRequired[bool] - raw_platform_message: NotRequired[Dict[str, Any]] - metadata: NotRequired[Dict[str, Any]] - schema_version: NotRequired[int] + """ + mofox-bus 传输层统一使用的消息信封。 + - 采用 maim_message 风格:message_info + message_segment。 + """ + + direction: MessageDirection + message_info: MessageInfoPayload + message_segment: SegPayload | List[SegPayload] + raw_message: NotRequired[Any] + raw_bytes: NotRequired[bytes] + message_chain: NotRequired[List[SegPayload]] # seglist 的直观别名 + platform: NotRequired[str] # 快捷访问,等价于 message_info.platform + message_id: NotRequired[str] # 快捷访问,等价于 message_info.message_id + timestamp_ms: NotRequired[int] + correlation_id: NotRequired[str] + schema_version: NotRequired[int] + metadata: NotRequired[Dict[str, Any]] __all__ = [ - "AudioContent", - "ChannelInfo", - "CommandContent", - "Content", - "ContentType", - "EventContent", - "EventType", - "FileContent", - "ImageContent", + # maim_message style payloads + "SegPayload", + "UserInfoPayload", + "GroupInfoPayload", + "FormatInfoPayload", + "TemplateInfoPayload", + "MessageInfoPayload", + # legacy content style "MessageDirection", "MessageEnvelope", - "Role", - "SenderInfo", - "SystemContent", - "TextContent", - "VideoContent", ] diff --git a/src/plugin_system/base/base_adapter.py b/src/plugin_system/base/base_adapter.py index b320bf809..bf03112bb 100644 --- a/src/plugin_system/base/base_adapter.py +++ b/src/plugin_system/base/base_adapter.py @@ -12,7 +12,7 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Optional -from mofox_bus import AdapterBase as MoFoxAdapterBase, CoreMessageSink, MessageEnvelope +from mofox_bus import AdapterBase as MoFoxAdapterBase, CoreSink, MessageEnvelope if TYPE_CHECKING: from src.plugin_system.base.base_plugin import BasePlugin @@ -47,7 +47,7 @@ class BaseAdapter(MoFoxAdapterBase, ABC): def __init__( self, - core_sink: CoreMessageSink, + core_sink: CoreSink, plugin: Optional[BasePlugin] = None, **kwargs ): @@ -227,7 +227,7 @@ class BaseAdapter(MoFoxAdapterBase, ABC): ) @abstractmethod - def from_platform_message(self, raw: Any) -> MessageEnvelope: + async def from_platform_message(self, raw: Any) -> MessageEnvelope: """ 将平台原始消息转换为 MessageEnvelope diff --git a/src/plugin_system/core/adapter_manager.py b/src/plugin_system/core/adapter_manager.py index f3ee98a23..0345298df 100644 --- a/src/plugin_system/core/adapter_manager.py +++ b/src/plugin_system/core/adapter_manager.py @@ -7,130 +7,152 @@ Adapter 管理器 from __future__ import annotations import asyncio -import subprocess -import sys -from pathlib import Path +import importlib +import multiprocessing as mp from typing import TYPE_CHECKING, Dict, Optional if TYPE_CHECKING: from src.plugin_system.base.base_adapter import BaseAdapter +from mofox_bus import ProcessCoreSinkServer +from src.common.core_sink import get_core_sink from src.common.logger import get_logger logger = get_logger("adapter_manager") -class AdapterProcess: - """适配器子进程包装器""" - def __init__( - self, - adapter_name: str, - entry_path: Path, - python_executable: Optional[str] = None, - ): - self.adapter_name = adapter_name - self.entry_path = entry_path - self.python_executable = python_executable or sys.executable - self.process: Optional[subprocess.Popen] = None - self._monitor_task: Optional[asyncio.Task] = None + +def _load_class(module_name: str, class_name: str): + module = importlib.import_module(module_name) + return getattr(module, class_name) + + +def _adapter_process_entry( + adapter_path: tuple[str, str], + plugin_info: dict | None, + incoming_queue: mp.Queue, + outgoing_queue: mp.Queue, +): + import asyncio + import contextlib + from mofox_bus import ProcessCoreSink + + async def _run() -> None: + adapter_cls = _load_class(*adapter_path) + plugin_instance = None + if plugin_info: + plugin_cls = _load_class(plugin_info["module"], plugin_info["class"]) + plugin_instance = plugin_cls(plugin_info["plugin_dir"], plugin_info["metadata"]) + core_sink = ProcessCoreSink(to_core_queue=incoming_queue, from_core_queue=outgoing_queue) + adapter = adapter_cls(core_sink, plugin=plugin_instance) + await adapter.start() + try: + while not getattr(core_sink, "_closed", False): + await asyncio.sleep(0.2) + finally: + with contextlib.suppress(Exception): + await adapter.stop() + with contextlib.suppress(Exception): + await core_sink.close() + + asyncio.run(_run()) + + + +class AdapterProcess: + """适配器子进程包装器,负责适配器子进程的启动和生命周期管理""" + + def __init__(self, adapter: "BaseAdapter", core_sink) -> None: + self.adapter = adapter + self.adapter_name = adapter.adapter_name + self.process: mp.Process | None = None + self._ctx = mp.get_context("spawn") + self._incoming_queue: mp.Queue = self._ctx.Queue() + self._outgoing_queue: mp.Queue = self._ctx.Queue() + self._bridge: ProcessCoreSinkServer | None = None + self._core_sink = core_sink + self._adapter_path: tuple[str, str] = (adapter.__class__.__module__, adapter.__class__.__name__) + self._plugin_info = self._extract_plugin_info(adapter) + self._outgoing_handler = None + + @staticmethod + def _extract_plugin_info(adapter: "BaseAdapter") -> dict | None: + plugin = getattr(adapter, "plugin", None) + if plugin is None: + return None + return { + "module": plugin.__class__.__module__, + "class": plugin.__class__.__name__, + "plugin_dir": getattr(plugin, "plugin_dir", ""), + "metadata": getattr(plugin, "plugin_meta", None), + } + + def _make_outgoing_handler(self): + async def _handler(envelope): + if self._bridge: + await self._bridge.push_outgoing(envelope) + return _handler async def start(self) -> bool: """启动适配器子进程""" try: logger.info(f"启动适配器子进程: {self.adapter_name}") - logger.debug(f"Python: {self.python_executable}") - logger.debug(f"Entry: {self.entry_path}") - - # 启动子进程 - self.process = subprocess.Popen( - [self.python_executable, str(self.entry_path)], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - bufsize=1, + self._bridge = ProcessCoreSinkServer( + incoming_queue=self._incoming_queue, + outgoing_queue=self._outgoing_queue, + core_handler=self._core_sink.send, + name=self.adapter_name, ) - - # 启动监控任务 - self._monitor_task = asyncio.create_task(self._monitor_process()) - - logger.info(f"适配器 {self.adapter_name} 子进程已启动 (PID: {self.process.pid})") + self._bridge.start() + if hasattr(self._core_sink, "set_outgoing_handler"): + self._outgoing_handler = self._make_outgoing_handler() + try: + self._core_sink.set_outgoing_handler(self._outgoing_handler) + except Exception: + logger.exception("Failed to register outgoing bridge for %s", self.adapter_name) + self.process = self._ctx.Process( + target=_adapter_process_entry, + args=(self._adapter_path, self._plugin_info, self._incoming_queue, self._outgoing_queue), + name=f"{self.adapter_name}-proc", + ) + self.process.start() + logger.info(f"启动适配器子进程 {self.adapter_name} (PID: {self.process.pid})") return True - except Exception as e: - logger.error(f"启动适配器 {self.adapter_name} 子进程失败: {e}", exc_info=True) + logger.error(f"启动适配器子进程 {self.adapter_name} 失败: {e}", exc_info=True) return False async def stop(self) -> None: """停止适配器子进程""" if not self.process: return - logger.info(f"停止适配器子进程: {self.adapter_name} (PID: {self.process.pid})") - try: - # 取消监控任务 - if self._monitor_task and not self._monitor_task.done(): - self._monitor_task.cancel() + remover = getattr(self._core_sink, "remove_outgoing_handler", None) + if callable(remover) and self._outgoing_handler: try: - await self._monitor_task - except asyncio.CancelledError: - pass - - # 终止进程 - self.process.terminate() - - # 等待进程退出(最多等待5秒) - try: - await asyncio.wait_for( - asyncio.to_thread(self.process.wait), - timeout=5.0 - ) - except asyncio.TimeoutError: - logger.warning(f"适配器 {self.adapter_name} 未能在5秒内退出,强制终止") - self.process.kill() - await asyncio.to_thread(self.process.wait) - - logger.info(f"适配器 {self.adapter_name} 子进程已停止") - + remover(self._outgoing_handler) + except Exception: + logger.exception(f"移除 {self.adapter_name} 的 outgoing bridge 失败") + if self._bridge: + await self._bridge.close() + if self.process.is_alive(): + self.process.join(timeout=5.0) + if self.process.is_alive(): + logger.warning(f"适配器 {self.adapter_name} 未能及时停止,强制终止中") + self.process.terminate() + self.process.join() except Exception as e: - logger.error(f"停止适配器 {self.adapter_name} 子进程时出错: {e}", exc_info=True) + logger.error(f"停止适配器子进程 {self.adapter_name} 时发生错误: {e}", exc_info=True) finally: self.process = None - async def _monitor_process(self) -> None: - """监控子进程状态""" - if not self.process: - return - - try: - # 在后台线程中等待进程退出 - return_code = await asyncio.to_thread(self.process.wait) - - if return_code != 0: - logger.error( - f"适配器 {self.adapter_name} 子进程异常退出 (返回码: {return_code})" - ) - - # 读取 stderr 输出 - if self.process.stderr: - stderr = self.process.stderr.read() - if stderr: - logger.error(f"错误输出:\n{stderr}") - else: - logger.info(f"适配器 {self.adapter_name} 子进程正常退出") - - except asyncio.CancelledError: - pass - except Exception as e: - logger.error(f"监控适配器 {self.adapter_name} 子进程时出错: {e}", exc_info=True) - def is_running(self) -> bool: - """检查进程是否正在运行""" + """适配器是否正在运行""" if not self.process: return False - return self.process.poll() is None - + return self.process.is_alive() class AdapterManager: """适配器管理器""" @@ -176,20 +198,17 @@ class AdapterManager: else: return await self._start_adapter_in_process(adapter) - async def _start_adapter_subprocess(self, adapter: BaseAdapter) -> bool: - """在子进程中启动适配器""" - adapter_name = adapter.adapter_name - # 获取子进程入口脚本 - entry_path = adapter.get_subprocess_entry_path() - if not entry_path: - logger.error( - f"适配器 {adapter_name} 配置为子进程运行,但未提供有效的入口脚本" - ) + async def _start_adapter_subprocess(self, adapter: BaseAdapter) -> bool: + """启动适配器子进程""" + adapter_name = adapter.adapter_name + try: + core_sink = get_core_sink() + except Exception as e: + logger.error(f"无法获取 core_sink,启动适配器子进程 {adapter_name} 失败: {e}", exc_info=True) return False - # 创建并启动子进程 - adapter_process = AdapterProcess(adapter_name, entry_path) + adapter_process = AdapterProcess(adapter, core_sink) success = await adapter_process.start() if success: diff --git a/src/plugins/built_in/NEW_napcat_adapter/README.md b/src/plugins/built_in/NEW_napcat_adapter/README.md new file mode 100644 index 000000000..22ded558a --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/README.md @@ -0,0 +1,427 @@ +# NEW_napcat_adapter + +基于 mofox-bus v2.x 的 Napcat 适配器(使用 BaseAdapter 架构) + +## 🏗️ 架构设计 + +本插件采用 **BaseAdapter 继承模式** 重写,完全抛弃旧版 maim_message 库,改用 mofox-bus 的 TypedDict 数据结构。 + +### 核心组件 +- **NapcatAdapter**: 继承自 `mofox_bus.AdapterBase`,负责 OneBot 11 协议与 MessageEnvelope 的双向转换 +- **WebSocketAdapterOptions**: 自动管理 WebSocket 连接,提供 incoming_parser 和 outgoing_encoder +- **CoreMessageSink**: 通过 `InProcessCoreSink` 将消息递送到核心系统 +- **Handlers**: 独立的消息处理器,分为 to_core(接收)和 to_napcat(发送)两个方向 + +## 📁 项目结构 + +``` +NEW_napcat_adapter/ +├── plugin.py # ✅ 主插件文件(BaseAdapter实现) +├── _manifest.json # 插件清单 +│ +└── src/ + ├── event_models.py # ✅ OneBot事件类型常量 + ├── common/ + │ └── core_sink.py # ✅ 全局CoreSink访问点 + │ + ├── utils/ + │ ├── utils.py # ⏳ 工具函数(待实现) + │ ├── qq_emoji_list.py # ⏳ QQ表情映射(待实现) + │ ├── video_handler.py # ⏳ 视频处理(待实现) + │ └── message_chunker.py # ⏳ 消息切片(待实现) + │ + ├── websocket/ + │ └── (无需单独实现,使用WebSocketAdapterOptions) + │ + ├── database/ + │ └── database.py # ⏳ 数据库模型(待实现) + │ + └── handlers/ + ├── to_core/ # Napcat → MessageEnvelope 方向 + │ ├── message_handler.py # ⏳ 消息处理(部分完成) + │ ├── notice_handler.py # ⏳ 通知处理(待完成) + │ └── meta_event_handler.py # ⏳ 元事件(待完成) + │ + └── to_napcat/ # MessageEnvelope → Napcat API 方向 + └── send_handler.py # ⏳ 发送处理(部分完成) +``` + +## 🚀 快速开始 + +### 使用方式 + +1. **配置文件**: 在 `config/plugins/NEW_napcat_adapter.toml` 中配置 WebSocket URL 和其他参数 +2. **启动插件**: 插件自动在系统启动时加载 +3. **WebSocket连接**: 自动连接到 Napcat OneBot 11 服务器 + +## 🔑 核心数据结构 + +### MessageEnvelope (mofox-bus v2.x) + +```python +from mofox_bus import MessageEnvelope, SegPayload, MessageInfoPayload + +# 创建消息信封 +envelope: MessageEnvelope = { + "direction": "input", + "message_info": { + "message_type": "group", + "message_id": "12345", + "self_id": "bot_qq", + "user_info": { + "user_id": "sender_qq", + "user_name": "发送者", + "user_displayname": "昵称" + }, + "group_info": { + "group_id": "group_id", + "group_name": "群名" + }, + "to_me": False + }, + "message_segment": { + "type": "seglist", + "data": [ + {"type": "text", "data": "hello"}, + {"type": "image", "data": "base64_data"} + ] + }, + "raw_message": "hello[图片]", + "platform": "napcat", + "message_id": "12345", + "timestamp_ms": 1234567890 +} +``` + +### BaseAdapter 核心方法 + +```python +class NapcatAdapter(BaseAdapter): + async def from_platform_message(self, message: dict[str, Any]) -> MessageEnvelope | None: + """将 OneBot 11 事件转换为 MessageEnvelope""" + # 路由到对应的 Handler + + async def _send_platform_message(self, envelope: MessageEnvelope) -> dict[str, Any]: + """将 MessageEnvelope 转换为 OneBot 11 API 调用""" + # 调用 SendHandler 处理 +``` + +## 📝 实现进度 + +### ✅ 已完成的核心架构 + +1. **BaseAdapter 实现** (plugin.py) + - ✅ WebSocket 自动连接管理 + - ✅ from_platform_message() 事件路由 + - ✅ _send_platform_message() 消息发送 + - ✅ API 响应池机制(echo-based request-response) + - ✅ CoreSink 集成 + +2. **Handler 基础结构** + - ✅ MessageHandler 骨架(text、image、at 基本实现) + - ✅ NoticeHandler 骨架 + - ✅ MetaEventHandler 骨架 + - ✅ SendHandler 骨架(基本类型转换) + +3. **辅助组件** + - ✅ event_models.py(事件类型常量) + - ✅ core_sink.py(全局 CoreSink 访问) + - ✅ 配置 Schema 定义 + +### ⏳ 部分完成的功能 + +4. **消息类型处理** (MessageHandler) + - ✅ 基础消息类型:text, image, at + - ❌ 高级消息类型:face, reply, forward, video, json, file, rps, dice, shake + +5. **发送处理** (SendHandler) + - ✅ 基础 SegPayload 转换:text, image + - ❌ 高级 Seg 类型:emoji, voice, voiceurl, music, videourl, file, command + +### ❌ 待实现的功能 + +6. **通知事件处理** (NoticeHandler) + - ❌ 戳一戳事件 + - ❌ 表情回应事件 + - ❌ 撤回事件 + - ❌ 禁言事件 + +7. **工具函数** (utils.py) + - ❌ get_group_info + - ❌ get_member_info + - ❌ get_image_base64 + - ❌ get_message_detail + - ❌ get_record_detail + +8. **权限系统** + - ❌ check_allow_to_chat() + - ❌ 群组黑名单/白名单 + - ❌ 私聊黑名单/白名单 + - ❌ QQ机器人检测 + +9. **其他组件** + - ❌ 视频处理器 + - ❌ 消息切片器 + - ❌ 数据库模型 + - ❌ QQ 表情映射表 + +## 📋 下一步工作 + +### 优先级 1:完善消息处理(参考旧版 recv_handler/message_handler.py) + +1. **完整实现 MessageHandler.handle_raw_message()** + - [ ] face(表情)消息段 + - [ ] reply(回复)消息段 + - [ ] forward(转发)消息段解析 + - [ ] video(视频)消息段 + - [ ] json(JSON卡片)消息段 + - [ ] file(文件)消息段 + - [ ] rps/dice/shake(特殊消息) + +2. **实现工具函数**(参考旧版 utils.py) + - [ ] `get_group_info()` - 获取群组信息 + - [ ] `get_member_info()` - 获取成员信息 + - [ ] `get_image_base64()` - 下载图片并转Base64 + - [ ] `get_message_detail()` - 获取消息详情 + - [ ] `get_record_detail()` - 获取语音详情 + +3. **实现权限检查** + - [ ] `check_allow_to_chat()` - 检查是否允许聊天 + - [ ] 群组白名单/黑名单逻辑 + - [ ] 私聊白名单/黑名单逻辑 + - [ ] QQ机器人检测(ban_qq_bot) + +### 优先级 2:完善发送处理(参考旧版 send_handler.py) + +4. **完整实现 SendHandler._convert_seg_to_onebot()** + - [ ] emoji(表情回应)命令 + - [ ] voice(语音)消息段 + - [ ] voiceurl(语音URL)消息段 + - [ ] music(音乐卡片)消息段 + - [ ] videourl(视频URL)消息段 + - [ ] file(文件)消息段 + - [ ] command(命令)消息段 + +5. **实现命令处理** + - [ ] GROUP_BAN(禁言) + - [ ] GROUP_KICK(踢人) + - [ ] SEND_POKE(戳一戳) + - [ ] DELETE_MSG(撤回消息) + - [ ] GROUP_WHOLE_BAN(全员禁言) + - [ ] SET_GROUP_CARD(设置群名片) + - [ ] SET_GROUP_ADMIN(设置管理员) + +### 优先级 3:补全其他组件(参考旧版对应文件) + +6. **NoticeHandler 实现** + - [ ] 戳一戳通知(notify.poke) + - [ ] 表情回应通知(notice.group_emoji_like) + - [ ] 消息撤回通知(notice.group_recall) + - [ ] 禁言通知(notice.group_ban) + +7. **辅助组件** + - [ ] `qq_emoji_list.py` - QQ表情ID映射表 + - [ ] `video_handler.py` - 视频处理(ffmpeg封面提取) + - [ ] `message_chunker.py` - 消息分块与重组 + - [ ] `database.py` - 数据库模型(如有需要) + +### 优先级 4:测试与优化 + +8. **功能测试** + - [ ] 文本消息收发 + - [ ] 图片消息收发 + - [ ] @消息处理 + - [ ] 表情/语音/视频消息 + - [ ] 转发消息解析 + - [ ] 所有命令功能 + - [ ] 通知事件处理 + +9. **性能优化** + - [ ] 消息处理并发性能 + - [ ] API响应池性能 + - [ ] 内存占用优化 + +## 🔍 关键实现细节 + +### 1. MessageEnvelope vs 旧版 MessageBase + +**不再使用 Seg dataclass**,全部使用 TypedDict: + +```python +# ❌ 旧版(maim_message) +from mofox_bus import Seg, MessageBase + +seg = Seg(type="text", data="hello") +message = MessageBase(message_info=info, message_segment=seg) + +# ✅ 新版(mofox-bus v2.x) +from mofox_bus import SegPayload, MessageEnvelope + +seg_payload: SegPayload = {"type": "text", "data": "hello"} +envelope: MessageEnvelope = { + "direction": "input", + "message_info": {...}, + "message_segment": seg_payload, + ... +} +``` + +### 2. Handler 架构模式 + +**接收方向** (to_core): +```python +class MessageHandler: + def __init__(self, adapter: "NapcatAdapter"): + self.adapter = adapter + + async def handle_raw_message(self, data: dict[str, Any]) -> MessageEnvelope: + # 1. 解析 OneBot 11 数据 + # 2. 构建 message_info(MessageInfoPayload) + # 3. 转换消息段为 SegPayload + # 4. 返回完整的 MessageEnvelope +``` + +**发送方向** (to_napcat): +```python +class SendHandler: + def __init__(self, adapter: "NapcatAdapter"): + self.adapter = adapter + + async def handle_message(self, envelope: MessageEnvelope) -> dict[str, Any]: + # 1. 从 envelope 提取 message_segment + # 2. 递归转换 SegPayload → OneBot 格式 + # 3. 调用 adapter.send_napcat_api() 发送 +``` + +### 3. API 调用模式(响应池) + +```python +# 在 NapcatAdapter 中 +async def send_napcat_api(self, action: str, params: dict[str, Any]) -> dict[str, Any]: + # 1. 生成唯一 echo + echo = f"{action}_{uuid.uuid4()}" + + # 2. 创建 Future 等待响应 + future = asyncio.Future() + self._response_pool[echo] = future + + # 3. 发送请求(通过 WebSocket) + await self._send_request({"action": action, "params": params, "echo": echo}) + + # 4. 等待响应(带超时) + try: + result = await asyncio.wait_for(future, timeout=10.0) + return result + finally: + self._response_pool.pop(echo, None) + +# 响应回来时(在 incoming_parser 中) +def _handle_api_response(data: dict[str, Any]): + echo = data.get("echo") + if echo in adapter._response_pool: + adapter._response_pool[echo].set_result(data) +``` + +### 4. 类型提示技巧 + +处理 TypedDict 的严格类型检查: + +```python +# 使用 type: ignore 标注(编译时是 TypedDict,运行时是 dict) +envelope: MessageEnvelope = { + "direction": "input", + ... +} # type: ignore[typeddict-item] + +# 或在函数签名中使用 dict[str, Any] +async def from_platform_message(self, message: dict[str, Any]) -> MessageEnvelope | None: + ... + return envelope # type: ignore[return-value] +``` + +## 🔍 测试检查清单 + +- [ ] 文本消息接收/发送 +- [ ] 图片消息接收/发送 +- [ ] 语音消息接收/发送 +- [ ] 视频消息接收/发送 +- [ ] @消息接收/发送 +- [ ] 回复消息接收/发送 +- [ ] 转发消息接收 +- [ ] JSON消息接收 +- [ ] 文件消息接收/发送 +- [ ] 禁言命令 +- [ ] 踢人命令 +- [ ] 戳一戳命令 +- [ ] 表情回应命令 +- [ ] 通知事件处理 +- [ ] 元事件处理 + +## 📚 参考资料 + +- **mofox-bus 文档**: 查看 `mofox_bus/types.py` 了解 TypedDict 定义 +- **BaseAdapter 示例**: 参考 `docs/mofox_bus_demo_adapter.py` +- **旧版实现**: `src/plugins/built_in/napcat_adapter_plugin/` (仅参考逻辑) +- **OneBot 11 协议**: [OneBot 11 标准](https://github.com/botuniverse/onebot-11) + +## ⚠️ 重要注意事项 + +1. **完全抛弃旧版数据结构** + - ❌ 不再使用 `Seg` dataclass + - ❌ 不再使用 `MessageBase` 类 + - ✅ 全部使用 `SegPayload`(TypedDict) + - ✅ 全部使用 `MessageEnvelope`(TypedDict) + +2. **BaseAdapter 生命周期** + - `__init__()` 中初始化同步资源 + - `start()` 中执行异步初始化(WebSocket连接自动建立) + - `stop()` 中清理资源(WebSocket自动断开) + +3. **WebSocketAdapterOptions 自动管理** + - 无需手动管理 WebSocket 连接 + - incoming_parser 自动解析接收数据 + - outgoing_encoder 自动编码发送数据 + - 重连机制由基类处理 + +4. **CoreSink 依赖注入** + - 必须在插件加载后调用 `set_core_sink(sink)` + - 通过 `get_core_sink()` 全局访问 + - 用于将消息递送到核心系统 + +5. **类型安全与灵活性平衡** + - TypedDict 在编译时提供类型检查 + - 运行时仍是普通 dict,可灵活操作 + - 必要时使用 `type: ignore` 抑制误报 + +6. **参考旧版但不照搬** + - 旧版逻辑流程可参考 + - 数据结构需完全重写 + - API调用模式已改变(响应池) + +## 📊 预估工作量 + +- ✅ 核心架构: **已完成** (BaseAdapter + Handlers 骨架) +- ⏳ 消息处理完善: **4-6 小时** (所有消息类型 + 工具函数) +- ⏳ 发送处理完善: **3-4 小时** (所有 Seg 类型 + 命令) +- ⏳ 通知事件处理: **2-3 小时** (poke/emoji_like/recall/ban) +- ⏳ 测试调试: **2-4 小时** (全流程测试) +- **总剩余时间: 11-17 小时** + +## ✅ 完成标准 + +当以下条件全部满足时,重写完成: + +1. ✅ BaseAdapter 架构实现完成 +2. ⏳ 所有 OneBot 11 消息类型支持 +3. ⏳ 所有发送消息段类型支持 +4. ⏳ 所有通知事件正确处理 +5. ⏳ 权限系统集成完成 +6. ⏳ 与旧版功能完全对等 +7. ⏳ 所有测试用例通过 + +--- + +**最后更新**: 2025-11-23 +**架构状态**: ✅ 核心架构完成 +**实现状态**: ⏳ 消息处理部分完成,需完善细节 +**预计完成**: 根据优先级,核心功能预计 1-2 个工作日 diff --git a/src/plugins/built_in/NEW_napcat_adapter/__init__.py b/src/plugins/built_in/NEW_napcat_adapter/__init__.py new file mode 100644 index 000000000..d7f61b2f2 --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/__init__.py @@ -0,0 +1,16 @@ +from src.plugin_system.base.plugin_metadata import PluginMetadata + +__plugin_meta__ = PluginMetadata( + name="napcat_plugin", + description="基于OneBot 11协议的NapCat QQ协议插件,提供完整的QQ机器人API接口,使用现有adapter连接", + usage="该插件提供 `napcat_tool` tool。", + version="1.0.0", + author="Windpicker_owo", + license="GPL-v3.0-or-later", + repository_url="https://github.com/Windpicker-owo", + keywords=["qq", "bot", "napcat", "onebot", "api", "websocket"], + categories=["protocol"], + extra={ + "is_built_in": False, + }, +) diff --git a/src/plugins/built_in/NEW_napcat_adapter/plugin.py b/src/plugins/built_in/NEW_napcat_adapter/plugin.py new file mode 100644 index 000000000..aeda7e44e --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/plugin.py @@ -0,0 +1,330 @@ +""" +Napcat 适配器(基于 MoFox-Bus 完全重写版) + +核心流程: +1. Napcat WebSocket 连接 → 接收 OneBot 格式消息 +2. from_platform_message: OneBot dict → MessageEnvelope +3. CoreSink → 推送到 MoFox-Bot 核心 +4. 核心回复 → _send_platform_message: MessageEnvelope → OneBot API 调用 +""" + +from __future__ import annotations + +import asyncio +import uuid +from typing import Any, ClassVar, Dict, List, Optional + +import orjson +import websockets + +from mofox_bus import CoreMessageSink, MessageEnvelope, WebSocketAdapterOptions +from src.common.logger import get_logger +from src.plugin_system import register_plugin +from src.plugin_system.base import BaseAdapter, BasePlugin +from src.plugin_system.apis import config_api + +from .src.handlers.to_core.message_handler import MessageHandler +from .src.handlers.to_core.notice_handler import NoticeHandler +from .src.handlers.to_core.meta_event_handler import MetaEventHandler +from .src.handlers.to_napcat.send_handler import SendHandler + +logger = get_logger("napcat_adapter") + + +class NapcatAdapter(BaseAdapter): + """Napcat 适配器 - 完全基于 mofox-bus 架构""" + + adapter_name = "napcat_adapter" + adapter_version = "2.0.0" + adapter_author = "MoFox Team" + adapter_description = "基于 MoFox-Bus 的 Napcat/OneBot 11 适配器" + platform = "qq" + + run_in_subprocess = False + subprocess_entry = None + + def __init__(self, core_sink: CoreMessageSink, plugin: Optional[BasePlugin] = None): + """初始化 Napcat 适配器""" + # 从插件配置读取 WebSocket URL + if plugin: + mode = config_api.get_plugin_config(plugin.config, "napcat_server.mode", "reverse") + host = config_api.get_plugin_config(plugin.config, "napcat_server.host", "localhost") + port = config_api.get_plugin_config(plugin.config, "napcat_server.port", 8095) + url = config_api.get_plugin_config(plugin.config, "napcat_server.url", "") + access_token = config_api.get_plugin_config(plugin.config, "napcat_server.access_token", "") + + if mode == "forward" and url: + ws_url = url + else: + ws_url = f"ws://{host}:{port}" + + headers = {} + if access_token: + headers["Authorization"] = f"Bearer {access_token}" + else: + ws_url = "ws://127.0.0.1:8095" + headers = {} + + # 配置 WebSocket 传输 + transport = WebSocketAdapterOptions( + url=ws_url, + headers=headers if headers else None, + incoming_parser=self._parse_napcat_message, + outgoing_encoder=self._encode_napcat_response, + ) + + super().__init__(core_sink, plugin=plugin, transport=transport) + + # 初始化处理器 + self.message_handler = MessageHandler(self) + self.notice_handler = NoticeHandler(self) + self.meta_event_handler = MetaEventHandler(self) + self.send_handler = SendHandler(self) + + # 响应池:用于存储等待的 API 响应 + self._response_pool: Dict[str, asyncio.Future] = {} + self._response_timeout = 30.0 + + # WebSocket 连接(用于发送 API 请求) + # 注意:_ws 继承自 BaseAdapter,是 WebSocketLike 协议类型 + self._napcat_ws = None # 可选的额外连接引用 + + async def on_adapter_loaded(self) -> None: + """适配器加载时的初始化""" + logger.info("Napcat 适配器正在启动...") + + # 设置处理器配置 + if self.plugin: + self.message_handler.set_plugin_config(self.plugin.config) + self.notice_handler.set_plugin_config(self.plugin.config) + self.meta_event_handler.set_plugin_config(self.plugin.config) + self.send_handler.set_plugin_config(self.plugin.config) + + logger.info("Napcat 适配器已加载") + + async def on_adapter_unloaded(self) -> None: + """适配器卸载时的清理""" + logger.info("Napcat 适配器正在关闭...") + + # 清理响应池 + for future in self._response_pool.values(): + if not future.done(): + future.cancel() + self._response_pool.clear() + + logger.info("Napcat 适配器已关闭") + + def _parse_napcat_message(self, raw: str | bytes) -> Any: + """解析 Napcat/OneBot 消息""" + try: + if isinstance(raw, bytes): + data = orjson.loads(raw) + else: + data = orjson.loads(raw) + return data + except Exception as e: + logger.error(f"解析 Napcat 消息失败: {e}") + raise + + def _encode_napcat_response(self, envelope: MessageEnvelope) -> bytes: + """编码响应消息为 Napcat 格式(暂未使用,通过 API 调用发送)""" + return orjson.dumps(envelope) + + async def from_platform_message(self, raw: Dict[str, Any]) -> MessageEnvelope: # type: ignore[override] + """ + 将 Napcat/OneBot 原始消息转换为 MessageEnvelope + + 这是核心转换方法,处理: + - message 事件 → 消息 + - notice 事件 → 通知(戳一戳、表情回复等) + - meta_event 事件 → 元事件(心跳、生命周期) + - API 响应 → 存入响应池 + """ + post_type = raw.get("post_type") + + # API 响应(没有 post_type,有 echo) + if post_type is None and "echo" in raw: + echo = raw.get("echo") + if echo and echo in self._response_pool: + future = self._response_pool[echo] + if not future.done(): + future.set_result(raw) + # API 响应不需要转换为 MessageEnvelope,返回空信封 + return self._create_empty_envelope() + + # 消息事件 + if post_type == "message": + return await self.message_handler.handle_raw_message(raw) # type: ignore[return-value] + + # 通知事件 + elif post_type == "notice": + return await self.notice_handler.handle_notice(raw) # type: ignore[return-value] + + # 元事件 + elif post_type == "meta_event": + return await self.meta_event_handler.handle_meta_event(raw) # type: ignore[return-value] + + # 未知事件类型 + else: + logger.warning(f"未知的事件类型: {post_type}") + return self._create_empty_envelope() # type: ignore[return-value] + + async def _send_platform_message(self, envelope: MessageEnvelope) -> None: # type: ignore[override] + """ + 将 MessageEnvelope 转换并发送到 Napcat + + 这里不直接通过 WebSocket 发送 envelope, + 而是调用 Napcat API(send_group_msg, send_private_msg 等) + """ + await self.send_handler.handle_message(envelope) + + def _create_empty_envelope(self) -> MessageEnvelope: # type: ignore[return] + """创建一个空的消息信封(用于不需要处理的事件)""" + import time + return { + "direction": "incoming", + "message_info": { + "platform": self.platform, + "message_id": str(uuid.uuid4()), + "time": time.time(), + }, + "message_segment": {"type": "text", "data": "[系统事件]"}, + "timestamp_ms": int(time.time() * 1000), + } + + async def send_napcat_api(self, action: str, params: Dict[str, Any], timeout: float = 30.0) -> Dict[str, Any]: + """ + 发送 Napcat API 请求并等待响应 + + Args: + action: API 动作名称(如 send_group_msg) + params: API 参数 + timeout: 超时时间(秒) + + Returns: + API 响应数据 + """ + if not self._ws: + raise RuntimeError("WebSocket 连接未建立") + + # 生成唯一的 echo ID + echo = str(uuid.uuid4()) + + # 创建 Future 用于等待响应 + future = asyncio.Future() + self._response_pool[echo] = future + + # 构造请求 + request = orjson.dumps({ + "action": action, + "params": params, + "echo": echo, + }) + + try: + # 发送请求 + await self._ws.send(request) + + # 等待响应 + response = await asyncio.wait_for(future, timeout=timeout) + return response + + except asyncio.TimeoutError: + logger.error(f"API 请求超时: {action}") + raise + except Exception as e: + logger.error(f"API 请求失败: {action}, 错误: {e}") + raise + finally: + # 清理响应池 + self._response_pool.pop(echo, None) + + def get_ws_connection(self): + """获取 WebSocket 连接(用于发送 API 请求)""" + if not self._ws: + raise RuntimeError("WebSocket 连接未建立") + return self._ws + + +@register_plugin +class NapcatAdapterPlugin(BasePlugin): + """Napcat 适配器插件""" + + plugin_name = "napcat_adapter_plugin" + enable_plugin = True + plugin_version = "2.0.0" + plugin_author = "MoFox Team" + plugin_description = "Napcat/OneBot 11 适配器(基于 MoFox-Bus 重写)" + + # 配置 Schema + config_schema: ClassVar[dict] = { + "plugin": { + "name": {"type": str, "default": "napcat_adapter_plugin"}, + "version": {"type": str, "default": "2.0.0"}, + "enabled": {"type": bool, "default": True}, + }, + "napcat_server": { + "mode": { + "type": str, + "default": "reverse", + "description": "连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", + }, + "host": {"type": str, "default": "localhost"}, + "port": {"type": int, "default": 8095}, + "url": {"type": str, "default": "", "description": "正向连接时的完整URL"}, + "access_token": {"type": str, "default": ""}, + }, + "features": { + "group_list_type": {"type": str, "default": "blacklist"}, + "group_list": {"type": list, "default": []}, + "private_list_type": {"type": str, "default": "blacklist"}, + "private_list": {"type": list, "default": []}, + "ban_user_id": {"type": list, "default": []}, + "ban_qq_bot": {"type": bool, "default": False}, + }, + } + + def __init__(self, plugin_dir: str = "", metadata: Any = None): + # 如果没有提供参数,创建一个默认的元数据 + if metadata is None: + from src.plugin_system.base.plugin_metadata import PluginMetadata + metadata = PluginMetadata( + name=self.plugin_name, + version=self.plugin_version, + author=self.plugin_author, + description=self.plugin_description, + usage="", + dependencies=[], + python_dependencies=[], + ) + + if not plugin_dir: + from pathlib import Path + plugin_dir = str(Path(__file__).parent) + + super().__init__(plugin_dir, metadata) + self._adapter: Optional[NapcatAdapter] = None + + async def on_plugin_loaded(self): + """插件加载时启动适配器""" + logger.info("Napcat 适配器插件正在加载...") + + # 获取核心 Sink + from src.common.core_sink import get_core_sink + core_sink = get_core_sink() + + # 创建并启动适配器 + self._adapter = NapcatAdapter(core_sink, plugin=self) + await self._adapter.start() + + logger.info("Napcat 适配器插件已加载") + + async def on_plugin_unloaded(self): + """插件卸载时停止适配器""" + if self._adapter: + await self._adapter.stop() + logger.info("Napcat 适配器插件已卸载") + + def get_plugin_components(self) -> list: + """返回适配器组件""" + return [(NapcatAdapter.get_adapter_info(), NapcatAdapter)] diff --git a/src/plugins/built_in/NEW_napcat_adapter/src/__init__.py b/src/plugins/built_in/NEW_napcat_adapter/src/__init__.py new file mode 100644 index 000000000..e7b923f49 --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/src/__init__.py @@ -0,0 +1 @@ +"""工具模块""" diff --git a/src/plugins/built_in/NEW_napcat_adapter/src/event_models.py b/src/plugins/built_in/NEW_napcat_adapter/src/event_models.py new file mode 100644 index 000000000..ef5e2317e --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/src/event_models.py @@ -0,0 +1,310 @@ +from enum import Enum + + +class MetaEventType: + lifecycle = "lifecycle" # 生命周期 + + class Lifecycle: + connect = "connect" # 生命周期 - WebSocket 连接成功 + + heartbeat = "heartbeat" # 心跳 + + +class MessageType: # 接受消息大类 + private = "private" # 私聊消息 + + class Private: + friend = "friend" # 私聊消息 - 好友 + group = "group" # 私聊消息 - 群临时 + group_self = "group_self" # 私聊消息 - 群中自身发送 + other = "other" # 私聊消息 - 其他 + + group = "group" # 群聊消息 + + class Group: + normal = "normal" # 群聊消息 - 普通 + anonymous = "anonymous" # 群聊消息 - 匿名消息 + notice = "notice" # 群聊消息 - 系统提示 + + +class NoticeType: # 通知事件 + friend_recall = "friend_recall" # 私聊消息撤回 + group_recall = "group_recall" # 群聊消息撤回 + notify = "notify" + group_ban = "group_ban" # 群禁言 + group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复 + group_upload = "group_upload" # 群文件上传 + + class Notify: + poke = "poke" # 戳一戳 + input_status = "input_status" # 正在输入 + + class GroupBan: + ban = "ban" # 禁言 + lift_ban = "lift_ban" # 解除禁言 + + +class RealMessageType: # 实际消息分类 + text = "text" # 纯文本 + face = "face" # qq表情 + image = "image" # 图片 + record = "record" # 语音 + video = "video" # 视频 + at = "at" # @某人 + rps = "rps" # 猜拳魔法表情 + dice = "dice" # 骰子 + shake = "shake" # 私聊窗口抖动(只收) + poke = "poke" # 群聊戳一戳 + share = "share" # 链接分享(json形式) + reply = "reply" # 回复消息 + forward = "forward" # 转发消息 + node = "node" # 转发消息节点 + json = "json" # json消息 + file = "file" # 文件 + + +class MessageSentType: + private = "private" + + class Private: + friend = "friend" + group = "group" + + group = "group" + + class Group: + normal = "normal" + + +class CommandType(Enum): + """命令类型""" + + GROUP_BAN = "set_group_ban" # 禁言用户 + GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言 + GROUP_KICK = "set_group_kick" # 踢出群聊 + SEND_POKE = "send_poke" # 戳一戳 + DELETE_MSG = "delete_msg" # 撤回消息 + AI_VOICE_SEND = "ai_voice_send" # AI语音发送 + SET_EMOJI_LIKE = "set_msg_emoji_like" # 设置表情回应 + SEND_AT_MESSAGE = "send_at_message" # 发送@消息 + SEND_LIKE = "send_like" # 发送点赞 + + def __str__(self) -> str: + return self.value + + +# 支持的消息格式 +ACCEPT_FORMAT = [ + "text", + "image", + "emoji", + "reply", + "voice", + "command", + "voiceurl", + "music", + "videourl", + "file", +] + +# 插件名称 +PLUGIN_NAME = "NEW_napcat_adapter" + +# QQ表情映射表 +QQ_FACE = { + "0": "[表情:惊讶]", + "1": "[表情:撇嘴]", + "2": "[表情:色]", + "3": "[表情:发呆]", + "4": "[表情:得意]", + "5": "[表情:流泪]", + "6": "[表情:害羞]", + "7": "[表情:闭嘴]", + "8": "[表情:睡]", + "9": "[表情:大哭]", + "10": "[表情:尴尬]", + "11": "[表情:发怒]", + "12": "[表情:调皮]", + "13": "[表情:呲牙]", + "14": "[表情:微笑]", + "15": "[表情:难过]", + "16": "[表情:酷]", + "18": "[表情:抓狂]", + "19": "[表情:吐]", + "20": "[表情:偷笑]", + "21": "[表情:可爱]", + "22": "[表情:白眼]", + "23": "[表情:傲慢]", + "24": "[表情:饥饿]", + "25": "[表情:困]", + "26": "[表情:惊恐]", + "27": "[表情:流汗]", + "28": "[表情:憨笑]", + "29": "[表情:悠闲]", + "30": "[表情:奋斗]", + "31": "[表情:咒骂]", + "32": "[表情:疑问]", + "33": "[表情:嘘]", + "34": "[表情:晕]", + "35": "[表情:折磨]", + "36": "[表情:衰]", + "37": "[表情:骷髅]", + "38": "[表情:敲打]", + "39": "[表情:再见]", + "41": "[表情:发抖]", + "42": "[表情:爱情]", + "43": "[表情:跳跳]", + "46": "[表情:猪头]", + "49": "[表情:拥抱]", + "53": "[表情:蛋糕]", + "56": "[表情:刀]", + "59": "[表情:便便]", + "60": "[表情:咖啡]", + "63": "[表情:玫瑰]", + "64": "[表情:凋谢]", + "66": "[表情:爱心]", + "67": "[表情:心碎]", + "74": "[表情:太阳]", + "75": "[表情:月亮]", + "76": "[表情:赞]", + "77": "[表情:踩]", + "78": "[表情:握手]", + "79": "[表情:胜利]", + "85": "[表情:飞吻]", + "86": "[表情:怄火]", + "89": "[表情:西瓜]", + "96": "[表情:冷汗]", + "97": "[表情:擦汗]", + "98": "[表情:抠鼻]", + "99": "[表情:鼓掌]", + "100": "[表情:糗大了]", + "101": "[表情:坏笑]", + "102": "[表情:左哼哼]", + "103": "[表情:右哼哼]", + "104": "[表情:哈欠]", + "105": "[表情:鄙视]", + "106": "[表情:委屈]", + "107": "[表情:快哭了]", + "108": "[表情:阴险]", + "109": "[表情:左亲亲]", + "110": "[表情:吓]", + "111": "[表情:可怜]", + "112": "[表情:菜刀]", + "114": "[表情:篮球]", + "116": "[表情:示爱]", + "118": "[表情:抱拳]", + "119": "[表情:勾引]", + "120": "[表情:拳头]", + "121": "[表情:差劲]", + "123": "[表情:NO]", + "124": "[表情:OK]", + "125": "[表情:转圈]", + "129": "[表情:挥手]", + "137": "[表情:鞭炮]", + "144": "[表情:喝彩]", + "146": "[表情:爆筋]", + "147": "[表情:棒棒糖]", + "169": "[表情:手枪]", + "171": "[表情:茶]", + "172": "[表情:眨眼睛]", + "173": "[表情:泪奔]", + "174": "[表情:无奈]", + "175": "[表情:卖萌]", + "176": "[表情:小纠结]", + "177": "[表情:喷血]", + "178": "[表情:斜眼笑]", + "179": "[表情:doge]", + "181": "[表情:戳一戳]", + "182": "[表情:笑哭]", + "183": "[表情:我最美]", + "185": "[表情:羊驼]", + "187": "[表情:幽灵]", + "201": "[表情:点赞]", + "212": "[表情:托腮]", + "262": "[表情:脑阔疼]", + "263": "[表情:沧桑]", + "264": "[表情:捂脸]", + "265": "[表情:辣眼睛]", + "266": "[表情:哦哟]", + "267": "[表情:头秃]", + "268": "[表情:问号脸]", + "269": "[表情:暗中观察]", + "270": "[表情:emm]", + "271": "[表情:吃瓜]", + "272": "[表情:呵呵哒]", + "273": "[表情:我酸了]", + "277": "[表情:滑稽狗头]", + "281": "[表情:翻白眼]", + "282": "[表情:敬礼]", + "283": "[表情:狂笑]", + "284": "[表情:面无表情]", + "285": "[表情:摸鱼]", + "286": "[表情:魔鬼笑]", + "287": "[表情:哦]", + "289": "[表情:睁眼]", + "293": "[表情:摸锦鲤]", + "294": "[表情:期待]", + "295": "[表情:拿到红包]", + "297": "[表情:拜谢]", + "298": "[表情:元宝]", + "299": "[表情:牛啊]", + "300": "[表情:胖三斤]", + "302": "[表情:左拜年]", + "303": "[表情:右拜年]", + "305": "[表情:右亲亲]", + "306": "[表情:牛气冲天]", + "307": "[表情:喵喵]", + "311": "[表情:打call]", + "312": "[表情:变形]", + "314": "[表情:仔细分析]", + "317": "[表情:菜汪]", + "318": "[表情:崇拜]", + "319": "[表情:比心]", + "320": "[表情:庆祝]", + "323": "[表情:嫌弃]", + "324": "[表情:吃糖]", + "325": "[表情:惊吓]", + "326": "[表情:生气]", + "332": "[表情:举牌牌]", + "333": "[表情:烟花]", + "334": "[表情:虎虎生威]", + "336": "[表情:豹富]", + "337": "[表情:花朵脸]", + "338": "[表情:我想开了]", + "339": "[表情:舔屏]", + "341": "[表情:打招呼]", + "342": "[表情:酸Q]", + "343": "[表情:我方了]", + "344": "[表情:大怨种]", + "345": "[表情:红包多多]", + "346": "[表情:你真棒棒]", + "347": "[表情:大展宏兔]", + "349": "[表情:坚强]", + "350": "[表情:贴贴]", + "351": "[表情:敲敲]", + "352": "[表情:咦]", + "353": "[表情:拜托]", + "354": "[表情:尊嘟假嘟]", + "355": "[表情:耶]", + "356": "[表情:666]", + "357": "[表情:裂开]", + "392": "[表情:龙年快乐]", + "393": "[表情:新年中龙]", + "394": "[表情:新年大龙]", + "395": "[表情:略略略]", + "396": "[表情:龙年快乐]", + "424": "[表情:按钮]", +} + + +__all__ = [ + "MetaEventType", + "MessageType", + "NoticeType", + "RealMessageType", + "MessageSentType", + "CommandType", + "ACCEPT_FORMAT", + "PLUGIN_NAME", + "QQ_FACE", +] diff --git a/src/plugins/built_in/NEW_napcat_adapter/src/handlers/__init__.py b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/__init__.py new file mode 100644 index 000000000..23a68615d --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/__init__.py @@ -0,0 +1 @@ +"""处理器模块""" diff --git a/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/__init__.py b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/__init__.py new file mode 100644 index 000000000..dc931068c --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/__init__.py @@ -0,0 +1 @@ +"""接收方向处理器""" diff --git a/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/message_handler.py b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/message_handler.py new file mode 100644 index 000000000..fc9655b0f --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/message_handler.py @@ -0,0 +1,126 @@ +"""消息处理器 - 将 Napcat OneBot 消息转换为 MessageEnvelope""" + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from src.common.logger import get_logger +from src.plugin_system.apis import config_api + +if TYPE_CHECKING: + from ...plugin import NapcatAdapter + +logger = get_logger("napcat_adapter.message_handler") + + +class MessageHandler: + """处理来自 Napcat 的消息事件""" + + def __init__(self, adapter: "NapcatAdapter"): + self.adapter = adapter + self.plugin_config: Optional[Dict[str, Any]] = None + + def set_plugin_config(self, config: Dict[str, Any]) -> None: + """设置插件配置""" + self.plugin_config = config + + async def handle_raw_message(self, raw: Dict[str, Any]): + """ + 处理原始消息并转换为 MessageEnvelope + + Args: + raw: OneBot 原始消息数据 + + Returns: + MessageEnvelope (dict) + """ + from mofox_bus import MessageEnvelope, SegPayload, MessageInfoPayload, UserInfoPayload, GroupInfoPayload + + message_type = raw.get("message_type") + message_id = str(raw.get("message_id", "")) + message_time = time.time() + + # 构造用户信息 + sender_info = raw.get("sender", {}) + user_info: UserInfoPayload = { + "platform": "qq", + "user_id": str(sender_info.get("user_id", "")), + "user_nickname": sender_info.get("nickname", ""), + "user_cardname": sender_info.get("card", ""), + "user_avatar": sender_info.get("avatar", ""), + } + + # 构造群组信息(如果是群消息) + group_info: Optional[GroupInfoPayload] = None + if message_type == "group": + group_id = raw.get("group_id") + if group_id: + group_info = { + "platform": "qq", + "group_id": str(group_id), + "group_name": "", # 可以通过 API 获取 + } + + # 解析消息段 + message_segments = raw.get("message", []) + seg_list: List[SegPayload] = [] + + for seg in message_segments: + seg_type = seg.get("type", "") + seg_data = seg.get("data", {}) + + # 转换为 SegPayload + if seg_type == "text": + seg_list.append({ + "type": "text", + "data": seg_data.get("text", "") + }) + elif seg_type == "image": + # 这里需要下载图片并转换为 base64(简化版本) + seg_list.append({ + "type": "image", + "data": seg_data.get("url", "") # 实际应该转换为 base64 + }) + elif seg_type == "at": + seg_list.append({ + "type": "at", + "data": f"{seg_data.get('qq', '')}" + }) + # 其他消息类型... + + # 构造 MessageInfoPayload + message_info = { + "platform": "qq", + "message_id": message_id, + "time": message_time, + "user_info": user_info, + "format_info": { + "content_format": ["text", "image"], # 根据实际消息类型设置 + "accept_format": ["text", "image", "emoji", "voice"], + }, + } + + # 添加群组信息(如果存在) + if group_info: + message_info["group_info"] = group_info + + # 构造 MessageEnvelope + envelope = { + "direction": "incoming", + "message_info": message_info, + "message_segment": {"type": "seglist", "data": seg_list} if len(seg_list) > 1 else (seg_list[0] if seg_list else {"type": "text", "data": ""}), + "raw_message": raw.get("raw_message", ""), + "platform": "qq", + "message_id": message_id, + "timestamp_ms": int(message_time * 1000), + } + + return envelope + + + + + + + diff --git a/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/meta_event_handler.py b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/meta_event_handler.py new file mode 100644 index 000000000..d4b0c2b35 --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/meta_event_handler.py @@ -0,0 +1,41 @@ +"""元事件处理器""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional + +from src.common.logger import get_logger + +if TYPE_CHECKING: + from ...plugin import NapcatAdapter + +logger = get_logger("napcat_adapter.meta_event_handler") + + +class MetaEventHandler: + """处理 Napcat 元事件(心跳、生命周期)""" + + def __init__(self, adapter: "NapcatAdapter"): + self.adapter = adapter + self.plugin_config: Optional[Dict[str, Any]] = None + + def set_plugin_config(self, config: Dict[str, Any]) -> None: + """设置插件配置""" + self.plugin_config = config + + async def handle_meta_event(self, raw: Dict[str, Any]): + """处理元事件""" + # 简化版本:返回一个空的 MessageEnvelope + import time + import uuid + + return { + "direction": "incoming", + "message_info": { + "platform": "qq", + "message_id": str(uuid.uuid4()), + "time": time.time(), + }, + "message_segment": {"type": "text", "data": "[元事件]"}, + "timestamp_ms": int(time.time() * 1000), + } diff --git a/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/notice_handler.py b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/notice_handler.py new file mode 100644 index 000000000..d9244fc0a --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_core/notice_handler.py @@ -0,0 +1,41 @@ +"""通知事件处理器""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional + +from src.common.logger import get_logger + +if TYPE_CHECKING: + from ...plugin import NapcatAdapter + +logger = get_logger("napcat_adapter.notice_handler") + + +class NoticeHandler: + """处理 Napcat 通知事件(戳一戳、表情回复等)""" + + def __init__(self, adapter: "NapcatAdapter"): + self.adapter = adapter + self.plugin_config: Optional[Dict[str, Any]] = None + + def set_plugin_config(self, config: Dict[str, Any]) -> None: + """设置插件配置""" + self.plugin_config = config + + async def handle_notice(self, raw: Dict[str, Any]): + """处理通知事件""" + # 简化版本:返回一个空的 MessageEnvelope + import time + import uuid + + return { + "direction": "incoming", + "message_info": { + "platform": "qq", + "message_id": str(uuid.uuid4()), + "time": time.time(), + }, + "message_segment": {"type": "text", "data": "[通知事件]"}, + "timestamp_ms": int(time.time() * 1000), + } diff --git a/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_napcat/__init__.py b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_napcat/__init__.py new file mode 100644 index 000000000..86f824a55 --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_napcat/__init__.py @@ -0,0 +1 @@ +"""发送方向处理器""" diff --git a/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_napcat/send_handler.py b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_napcat/send_handler.py new file mode 100644 index 000000000..1606950fe --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/src/handlers/to_napcat/send_handler.py @@ -0,0 +1,77 @@ +"""发送处理器 - 将 MessageEnvelope 转换并发送到 Napcat""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional + +from src.common.logger import get_logger + +if TYPE_CHECKING: + from ...plugin import NapcatAdapter + +logger = get_logger("napcat_adapter.send_handler") + + +class SendHandler: + """处理向 Napcat 发送消息""" + + def __init__(self, adapter: "NapcatAdapter"): + self.adapter = adapter + self.plugin_config: Optional[Dict[str, Any]] = None + + def set_plugin_config(self, config: Dict[str, Any]) -> None: + """设置插件配置""" + self.plugin_config = config + + async def handle_message(self, envelope) -> None: + """ + 处理发送消息 + + 将 MessageEnvelope 转换为 OneBot API 调用 + """ + message_info = envelope.get("message_info", {}) + message_segment = envelope.get("message_segment", {}) + + # 获取群组和用户信息 + group_info = message_info.get("group_info") + user_info = message_info.get("user_info") + + # 构造消息内容 + message = self._convert_seg_to_onebot(message_segment) + + # 发送消息 + if group_info: + # 发送群消息 + group_id = group_info.get("group_id") + if group_id: + await self.adapter.send_napcat_api("send_group_msg", { + "group_id": int(group_id), + "message": message, + }) + elif user_info: + # 发送私聊消息 + user_id = user_info.get("user_id") + if user_id: + await self.adapter.send_napcat_api("send_private_msg", { + "user_id": int(user_id), + "message": message, + }) + + def _convert_seg_to_onebot(self, seg: Dict[str, Any]) -> list: + """将 SegPayload 转换为 OneBot 消息格式""" + seg_type = seg.get("type", "") + seg_data = seg.get("data", "") + + if seg_type == "text": + return [{"type": "text", "data": {"text": seg_data}}] + elif seg_type == "image": + return [{"type": "image", "data": {"file": f"base64://{seg_data}"}}] + elif seg_type == "seglist": + # 递归处理列表 + result = [] + for sub_seg in seg_data: + result.extend(self._convert_seg_to_onebot(sub_seg)) + return result + else: + # 默认作为文本 + return [{"type": "text", "data": {"text": str(seg_data)}}] diff --git a/src/plugins/built_in/NEW_napcat_adapter/stream_router.py b/src/plugins/built_in/NEW_napcat_adapter/stream_router.py new file mode 100644 index 000000000..70169ad0f --- /dev/null +++ b/src/plugins/built_in/NEW_napcat_adapter/stream_router.py @@ -0,0 +1,350 @@ +""" +按聊天流分配消费者的消息路由系统 + +核心思想: +- 为每个活跃的聊天流(stream_id)创建独立的消息队列和消费者协程 +- 同一聊天流的消息由同一个 worker 处理,保证顺序性 +- 不同聊天流的消息并发处理,提高吞吐量 +- 动态管理流的生命周期,自动清理不活跃的流 +""" + +import asyncio +import time +from typing import Dict, Optional + +from src.common.logger import get_logger + +logger = get_logger("stream_router") + + +class StreamConsumer: + """单个聊天流的消息消费者 + + 维护独立的消息队列和处理协程 + """ + + def __init__(self, stream_id: str, queue_maxsize: int = 100): + self.stream_id = stream_id + self.queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize) + self.worker_task: Optional[asyncio.Task] = None + self.last_active_time = time.time() + self.is_running = False + + # 性能统计 + self.stats = { + "total_messages": 0, + "total_processing_time": 0.0, + "queue_overflow_count": 0, + } + + async def start(self) -> None: + """启动消费者""" + if not self.is_running: + self.is_running = True + self.worker_task = asyncio.create_task(self._process_loop()) + logger.debug(f"Stream Consumer 启动: {self.stream_id}") + + async def stop(self) -> None: + """停止消费者""" + self.is_running = False + if self.worker_task: + self.worker_task.cancel() + try: + await self.worker_task + except asyncio.CancelledError: + pass + logger.debug(f"Stream Consumer 停止: {self.stream_id}") + + async def enqueue(self, message: dict) -> None: + """将消息加入队列""" + self.last_active_time = time.time() + + try: + # 使用 put_nowait 避免阻塞路由器 + self.queue.put_nowait(message) + except asyncio.QueueFull: + self.stats["queue_overflow_count"] += 1 + logger.warning( + f"Stream {self.stream_id} 队列已满 " + f"({self.queue.qsize()}/{self.queue.maxsize})," + ) + + try: + self.queue.get_nowait() + self.queue.put_nowait(message) + logger.debug(f"Stream {self.stream_id} 丢弃最旧消息,添加新消息") + except asyncio.QueueEmpty: + pass + + async def _process_loop(self) -> None: + """消息处理循环""" + # 延迟导入,避免循环依赖 + from .recv_handler.message_handler import message_handler + from .recv_handler.meta_event_handler import meta_event_handler + from .recv_handler.notice_handler import notice_handler + + logger.info(f"Stream {self.stream_id} 处理循环启动") + + try: + while self.is_running: + try: + # 等待消息,1秒超时 + message = await asyncio.wait_for( + self.queue.get(), + timeout=1.0 + ) + + start_time = time.time() + + # 处理消息 + post_type = message.get("post_type") + if post_type == "message": + await message_handler.handle_raw_message(message) + elif post_type == "meta_event": + await meta_event_handler.handle_meta_event(message) + elif post_type == "notice": + await notice_handler.handle_notice(message) + else: + logger.warning(f"未知的 post_type: {post_type}") + + processing_time = time.time() - start_time + + # 更新统计 + self.stats["total_messages"] += 1 + self.stats["total_processing_time"] += processing_time + self.last_active_time = time.time() + self.queue.task_done() + + # 性能监控(每100条消息输出一次) + if self.stats["total_messages"] % 100 == 0: + avg_time = self.stats["total_processing_time"] / self.stats["total_messages"] + logger.info( + f"Stream {self.stream_id[:30]}... 统计: " + f"消息数={self.stats['total_messages']}, " + f"平均耗时={avg_time:.3f}秒, " + f"队列长度={self.queue.qsize()}" + ) + + # 动态延迟:队列空时短暂休眠 + if self.queue.qsize() == 0: + await asyncio.sleep(0.01) + + except asyncio.TimeoutError: + # 超时是正常的,继续循环 + continue + except asyncio.CancelledError: + logger.info(f"Stream {self.stream_id} 处理循环被取消") + break + except Exception as e: + logger.error(f"Stream {self.stream_id} 处理消息时出错: {e}", exc_info=True) + # 继续处理下一条消息 + await asyncio.sleep(0.1) + + finally: + logger.info(f"Stream {self.stream_id} 处理循环结束") + + def get_stats(self) -> dict: + """获取性能统计""" + avg_time = ( + self.stats["total_processing_time"] / self.stats["total_messages"] + if self.stats["total_messages"] > 0 + else 0 + ) + + return { + "stream_id": self.stream_id, + "queue_size": self.queue.qsize(), + "total_messages": self.stats["total_messages"], + "avg_processing_time": avg_time, + "queue_overflow_count": self.stats["queue_overflow_count"], + "last_active_time": self.last_active_time, + } + + +class StreamRouter: + """流路由器 + + 负责将消息路由到对应的聊天流队列 + 动态管理聊天流的生命周期 + """ + + def __init__( + self, + max_streams: int = 500, + stream_timeout: int = 600, + stream_queue_size: int = 100, + cleanup_interval: int = 60, + ): + self.streams: Dict[str, StreamConsumer] = {} + self.lock = asyncio.Lock() + self.max_streams = max_streams + self.stream_timeout = stream_timeout + self.stream_queue_size = stream_queue_size + self.cleanup_interval = cleanup_interval + self.cleanup_task: Optional[asyncio.Task] = None + self.is_running = False + + async def start(self) -> None: + """启动路由器""" + if not self.is_running: + self.is_running = True + self.cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info( + f"StreamRouter 已启动 - " + f"最大流数: {self.max_streams}, " + f"超时: {self.stream_timeout}秒, " + f"队列大小: {self.stream_queue_size}" + ) + + async def stop(self) -> None: + """停止路由器""" + self.is_running = False + + if self.cleanup_task: + self.cleanup_task.cancel() + try: + await self.cleanup_task + except asyncio.CancelledError: + pass + + # 停止所有流消费者 + logger.info(f"正在停止 {len(self.streams)} 个流消费者...") + for consumer in self.streams.values(): + await consumer.stop() + + self.streams.clear() + logger.info("StreamRouter 已停止") + + async def route_message(self, message: dict) -> None: + """路由消息到对应的流""" + stream_id = self._extract_stream_id(message) + + # 快速路径:流已存在 + if stream_id in self.streams: + await self.streams[stream_id].enqueue(message) + return + + # 慢路径:需要创建新流 + async with self.lock: + # 双重检查 + if stream_id not in self.streams: + # 检查流数量限制 + if len(self.streams) >= self.max_streams: + logger.warning( + f"达到最大流数量限制 ({self.max_streams})," + f"尝试清理不活跃的流..." + ) + await self._cleanup_inactive_streams() + + # 清理后仍然超限,记录警告但继续创建 + if len(self.streams) >= self.max_streams: + logger.error( + f"清理后仍达到最大流数量 ({len(self.streams)}/{self.max_streams})!" + ) + + # 创建新流 + consumer = StreamConsumer(stream_id, self.stream_queue_size) + self.streams[stream_id] = consumer + await consumer.start() + logger.info(f"创建新的 Stream Consumer: {stream_id} (总流数: {len(self.streams)})") + + await self.streams[stream_id].enqueue(message) + + def _extract_stream_id(self, message: dict) -> str: + """从消息中提取 stream_id + + 返回格式: platform:id:type + 例如: qq:123456:group 或 qq:789012:private + """ + post_type = message.get("post_type") + + # 非消息类型,使用默认流(避免创建过多流) + if post_type not in ["message", "notice"]: + return "system:meta_event" + + # 消息类型 + if post_type == "message": + message_type = message.get("message_type") + if message_type == "group": + group_id = message.get("group_id") + return f"qq:{group_id}:group" + elif message_type == "private": + user_id = message.get("user_id") + return f"qq:{user_id}:private" + + # notice 类型 + elif post_type == "notice": + group_id = message.get("group_id") + if group_id: + return f"qq:{group_id}:group" + user_id = message.get("user_id") + if user_id: + return f"qq:{user_id}:private" + + # 未知类型,使用通用流 + return "unknown:unknown" + + async def _cleanup_inactive_streams(self) -> None: + """清理不活跃的流""" + current_time = time.time() + to_remove = [] + + for stream_id, consumer in self.streams.items(): + if current_time - consumer.last_active_time > self.stream_timeout: + to_remove.append(stream_id) + + for stream_id in to_remove: + await self.streams[stream_id].stop() + del self.streams[stream_id] + logger.debug(f"清理不活跃的流: {stream_id}") + + if to_remove: + logger.info( + f"清理了 {len(to_remove)} 个不活跃的流 " + f"(当前活跃流: {len(self.streams)}/{self.max_streams})" + ) + + async def _cleanup_loop(self) -> None: + """定期清理循环""" + logger.info(f"清理循环已启动,间隔: {self.cleanup_interval}秒") + try: + while self.is_running: + await asyncio.sleep(self.cleanup_interval) + await self._cleanup_inactive_streams() + except asyncio.CancelledError: + logger.info("清理循环已停止") + + def get_all_stats(self) -> list[dict]: + """获取所有流的统计信息""" + return [consumer.get_stats() for consumer in self.streams.values()] + + def get_summary(self) -> dict: + """获取路由器摘要""" + total_messages = sum(c.stats["total_messages"] for c in self.streams.values()) + total_queue_size = sum(c.queue.qsize() for c in self.streams.values()) + total_overflows = sum(c.stats["queue_overflow_count"] for c in self.streams.values()) + + # 计算平均队列长度 + avg_queue_size = total_queue_size / len(self.streams) if self.streams else 0 + + # 找出最繁忙的流 + busiest_stream = None + if self.streams: + busiest_stream = max( + self.streams.values(), + key=lambda c: c.stats["total_messages"] + ).stream_id + + return { + "total_streams": len(self.streams), + "max_streams": self.max_streams, + "total_messages_processed": total_messages, + "total_queue_size": total_queue_size, + "avg_queue_size": avg_queue_size, + "total_queue_overflows": total_overflows, + "busiest_stream": busiest_stream, + } + + +# 全局路由器实例 +stream_router = StreamRouter() diff --git a/src/plugins/built_in/napcat_adapter_plugin/plugin.py b/src/plugins/built_in/napcat_adapter_plugin/plugin.py index e75b08110..921f4619a 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/plugin.py +++ b/src/plugins/built_in/napcat_adapter_plugin/plugin.py @@ -236,8 +236,6 @@ class NapcatAdapterPlugin(BasePlugin): def enable_plugin(self) -> bool: """通过配置文件动态控制插件启用状态""" # 如果已经通过配置加载了状态,使用配置中的值 - if hasattr(self, "_is_enabled"): - return self._is_enabled # 否则使用默认值(禁用状态) return False