重构并增强Napcat适配器的功能
- 更新了`BaseAdapter`以简化子进程处理。 - 对`AdapterManager`进行了重构,以便根据适配器的`run_in_subprocess`属性来管理适配器。 - 增强了`NapcatAdapter`,以利用新的`CoreSinkManager`实现更优的进程管理。 - 在`utils.py`中实现了针对群组和成员信息的缓存机制。 - 改进了`message_handler.py`中的消息处理,以支持各种消息类型和格式。 - 已将插件配置版本更新至7.8.3。
This commit is contained in:
222
docs/message_runtime_architecture.md
Normal file
222
docs/message_runtime_architecture.md
Normal file
@@ -0,0 +1,222 @@
|
||||
# MoFox Bot 消息运行时架构 (MessageRuntime)
|
||||
|
||||
本文档描述了 MoFox Bot 使用 `mofox_bus.MessageRuntime` 简化消息处理链条的架构设计。
|
||||
|
||||
## 架构概述
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ CoreSinkManager │
|
||||
│ ┌─────────────────────────────────────────────────────────────────────┐│
|
||||
│ │ MessageRuntime ││
|
||||
│ │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ││
|
||||
│ │ │ before_hook │→ │ Routes │→ │ after_hook │ ││
|
||||
│ │ │ (预处理/过滤) │ │ (消息路由) │ │ (后处理) │ ││
|
||||
│ │ └──────────────┘ └──────────────┘ └──────────────┘ ││
|
||||
│ │ ↓ ↓ ↓ ││
|
||||
│ │ ┌──────────────────────────────────────────────────────────────┐ ││
|
||||
│ │ │ error_hook (错误处理) │ ││
|
||||
│ │ └──────────────────────────────────────────────────────────────┘ ││
|
||||
│ └─────────────────────────────────────────────────────────────────────┘│
|
||||
│ │
|
||||
│ ┌──────────────────────┐ ┌──────────────────────────────────────┐ │
|
||||
│ │ InProcessCoreSink │ │ ProcessCoreSinkServer (子进程适配器) │ │
|
||||
│ │ (同进程适配器) │ │ │ │
|
||||
│ └──────────────────────┘ └──────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
↑ ↑
|
||||
│ │
|
||||
┌─────────────────────┐ ┌─────────────────────┐
|
||||
│ 同进程适配器 │ │ 子进程适配器 │
|
||||
│ (run_in_subprocess │ │ (run_in_subprocess │
|
||||
│ = False) │ │ = True) │
|
||||
└─────────────────────┘ └─────────────────────┘
|
||||
```
|
||||
|
||||
## 核心组件
|
||||
|
||||
### 1. CoreSinkManager (`src/common/core_sink_manager.py`)
|
||||
|
||||
统一管理 CoreSink 双实例和 MessageRuntime:
|
||||
|
||||
```python
|
||||
from src.common.core_sink_manager import get_core_sink_manager, get_message_runtime
|
||||
|
||||
# 获取管理器
|
||||
manager = get_core_sink_manager()
|
||||
|
||||
# 获取 MessageRuntime
|
||||
runtime = get_message_runtime()
|
||||
|
||||
# 发送消息到适配器
|
||||
await manager.send_outgoing(envelope)
|
||||
```
|
||||
|
||||
### 2. MessageRuntime
|
||||
|
||||
`MessageRuntime` 是 mofox_bus 提供的消息路由核心,支持:
|
||||
|
||||
- **消息路由**:通过 `add_route()` 或 `@on_message` 装饰器按消息类型路由
|
||||
- **钩子机制**:`before_hook`(前置处理)、`after_hook`(后置处理)、`error_hook`(错误处理)
|
||||
- **中间件**:洋葱模型的中间件机制
|
||||
- **批量处理**:支持 `handle_batch()` 批量处理消息
|
||||
|
||||
### 3. MessageHandler (`src/chat/message_receive/message_handler.py`)
|
||||
|
||||
将消息处理逻辑注册为 MessageRuntime 的路由和钩子:
|
||||
|
||||
```python
|
||||
class MessageHandler:
|
||||
def register_handlers(self, runtime: MessageRuntime) -> None:
|
||||
# 注册前置钩子
|
||||
runtime.register_before_hook(self._before_hook)
|
||||
|
||||
# 注册后置钩子
|
||||
runtime.register_after_hook(self._after_hook)
|
||||
|
||||
# 注册错误钩子
|
||||
runtime.register_error_hook(self._error_hook)
|
||||
|
||||
# 注册适配器响应处理器
|
||||
runtime.add_route(
|
||||
predicate=_is_adapter_response,
|
||||
handler=self._handle_adapter_response_route,
|
||||
name="adapter_response_handler",
|
||||
message_type="adapter_response",
|
||||
)
|
||||
|
||||
# 注册默认消息处理器
|
||||
runtime.add_route(
|
||||
predicate=lambda _: True,
|
||||
handler=self._handle_normal_message,
|
||||
name="default_message_handler",
|
||||
)
|
||||
```
|
||||
|
||||
## 消息流向
|
||||
|
||||
### 接收消息
|
||||
|
||||
```
|
||||
适配器 → InProcessCoreSink/ProcessCoreSinkServer → CoreSinkManager._dispatch_to_runtime()
|
||||
→ MessageRuntime.handle_message()
|
||||
→ before_hook (预处理、过滤)
|
||||
→ 匹配路由 (adapter_response / normal_message)
|
||||
→ 执行处理器
|
||||
→ after_hook (后处理)
|
||||
```
|
||||
|
||||
### 发送消息
|
||||
|
||||
```
|
||||
消息发送请求 → CoreSinkManager.send_outgoing()
|
||||
→ InProcessCoreSink.push_outgoing()
|
||||
→ ProcessCoreSinkServer.push_outgoing()
|
||||
→ 适配器
|
||||
```
|
||||
|
||||
## 钩子功能
|
||||
|
||||
### before_hook
|
||||
|
||||
在消息路由之前执行,用于:
|
||||
- 标准化 ID 为字符串
|
||||
- 检查 echo 消息(自身消息上报)
|
||||
- 通过抛出 `UserWarning` 跳过消息处理
|
||||
|
||||
### after_hook
|
||||
|
||||
在消息处理完成后执行,用于:
|
||||
- 清理工作
|
||||
- 日志记录
|
||||
|
||||
### error_hook
|
||||
|
||||
在处理过程中出现异常时执行,用于:
|
||||
- 区分预期的流程控制(UserWarning)和真正的错误
|
||||
- 统一异常日志记录
|
||||
|
||||
## 路由优先级
|
||||
|
||||
1. **明确指定 message_type 的路由**(优先级最高)
|
||||
2. **事件路由**(基于 event_type)
|
||||
3. **通用路由**(无 message_type 限制)
|
||||
|
||||
## 扩展消息处理
|
||||
|
||||
### 注册自定义处理器
|
||||
|
||||
```python
|
||||
from src.common.core_sink_manager import get_message_runtime
|
||||
from mofox_bus import MessageEnvelope
|
||||
|
||||
runtime = get_message_runtime()
|
||||
|
||||
# 使用装饰器
|
||||
@runtime.on_message(message_type="image")
|
||||
async def handle_image(envelope: MessageEnvelope):
|
||||
# 处理图片消息
|
||||
pass
|
||||
|
||||
# 或使用 add_route
|
||||
runtime.add_route(
|
||||
predicate=lambda env: env.get("platform") == "qq",
|
||||
handler=my_handler,
|
||||
name="qq_handler",
|
||||
)
|
||||
```
|
||||
|
||||
### 注册钩子
|
||||
|
||||
```python
|
||||
runtime = get_message_runtime()
|
||||
|
||||
# 前置钩子
|
||||
async def my_before_hook(envelope: MessageEnvelope) -> None:
|
||||
# 预处理逻辑
|
||||
pass
|
||||
|
||||
runtime.register_before_hook(my_before_hook)
|
||||
|
||||
# 错误钩子
|
||||
async def my_error_hook(envelope: MessageEnvelope, exc: BaseException) -> None:
|
||||
# 错误处理逻辑
|
||||
pass
|
||||
|
||||
runtime.register_error_hook(my_error_hook)
|
||||
```
|
||||
|
||||
## 初始化流程
|
||||
|
||||
在 `MainSystem.initialize()` 中:
|
||||
|
||||
1. 初始化 `CoreSinkManager`(包含 `MessageRuntime`)
|
||||
2. 获取 `MessageHandler` 并设置 `CoreSinkManager` 引用
|
||||
3. 调用 `MessageHandler.register_handlers()` 向 `MessageRuntime` 注册处理器和钩子
|
||||
4. 初始化其他组件
|
||||
|
||||
```python
|
||||
async def initialize(self) -> None:
|
||||
# 初始化 CoreSinkManager(包含 MessageRuntime)
|
||||
self.core_sink_manager = await initialize_core_sink_manager()
|
||||
|
||||
# 获取 MessageHandler 并向 MessageRuntime 注册处理器
|
||||
self.message_handler = get_message_handler()
|
||||
self.message_handler.set_core_sink_manager(self.core_sink_manager)
|
||||
self.message_handler.register_handlers(self.core_sink_manager.runtime)
|
||||
```
|
||||
|
||||
## 优势
|
||||
|
||||
1. **简化消息处理链**:不再需要手动管理处理流程,使用声明式路由
|
||||
2. **更好的可扩展性**:通过 `add_route()` 或装饰器轻松添加新的处理器
|
||||
3. **统一的错误处理**:通过 `error_hook` 集中处理异常
|
||||
4. **支持中间件**:可以添加洋葱模型的中间件
|
||||
5. **更清晰的代码结构**:处理逻辑按类型分离
|
||||
|
||||
## 参考
|
||||
|
||||
- `packages/mofox-bus/src/mofox_bus/runtime.py` - MessageRuntime 实现
|
||||
- `src/common/core_sink_manager.py` - CoreSinkManager 实现
|
||||
- `src/chat/message_receive/message_handler.py` - MessageHandler 实现
|
||||
- `docs/mofox_bus.md` - MoFox Bus 消息库说明
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
MoFox Bus 是 MoFox Bot 自研的统一消息中台,替换第三方 `maim_message`,将核心与各平台适配器之间的通信抽象成可拓展、可热插拔的组件。该库完全异步、面向高吞吐,覆盖消息建模、序列化、传输层、运行时路由、适配器工具等多个层面。
|
||||
|
||||
> 现在已拆分为独立 pip 包:在项目根目录执行 `pip install -e ./packages/mofox-bus` 即可安装到当前 Python 环境。
|
||||
|
||||
---
|
||||
|
||||
## 1. 设计目标
|
||||
@@ -14,7 +16,7 @@ MoFox Bus 是 MoFox Bot 自研的统一消息中台,替换第三方 `maim_mess
|
||||
|
||||
---
|
||||
|
||||
## 2. 包结构概览(`src/mofox_bus/`)
|
||||
## 2. 包结构概览(`packages/mofox-bus/src/mofox_bus/`)
|
||||
|
||||
| 模块 | 主要职责 |
|
||||
| --- | --- |
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
@@ -17,8 +16,6 @@ from typing import Any, Dict, Optional
|
||||
import orjson
|
||||
import websockets
|
||||
|
||||
# 追加 src 目录,便于直接运行示例
|
||||
sys.path.append(str(Path(__file__).resolve().parents[1] / "src"))
|
||||
|
||||
from mofox_bus import (
|
||||
AdapterBase,
|
||||
|
||||
@@ -78,6 +78,7 @@ dependencies = [
|
||||
"inkfox>=0.1.1",
|
||||
"rjieba>=0.1.13",
|
||||
"fastmcp>=2.13.0",
|
||||
"mofox-bus",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
sqlalchemy
|
||||
aiosqlite
|
||||
aiofiles
|
||||
aiomysql
|
||||
|
||||
@@ -121,8 +121,6 @@ class ChatterManager:
|
||||
|
||||
chatter_class = self.get_chatter_class(chat_type)
|
||||
if not chatter_class:
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
all_chatter_class = self.get_chatter_class(ChatType.ALL)
|
||||
if all_chatter_class:
|
||||
chatter_class = all_chatter_class
|
||||
|
||||
@@ -8,9 +8,10 @@ import random
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats
|
||||
from src.common.logger import get_logger
|
||||
@@ -21,7 +22,7 @@ from .distribution_manager import stream_loop_manager
|
||||
from .global_notice_manager import NoticeScope, global_notice_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("message_manager")
|
||||
|
||||
@@ -39,6 +40,8 @@ class MessageManager:
|
||||
|
||||
# 初始化chatter manager
|
||||
self.action_manager = ChatterActionManager()
|
||||
# 延迟导入ChatterManager以避免循环导入
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
self.chatter_manager = ChatterManager(self.action_manager)
|
||||
|
||||
# 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context
|
||||
|
||||
@@ -1,488 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from mofox_bus.runtime import MessageRuntime
|
||||
from mofox_bus import MessageEnvelope
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.prompt import global_prompt_manager
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
|
||||
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
|
||||
# 配置主程序日志格式
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||
"""检查消息是否包含过滤词
|
||||
|
||||
Args:
|
||||
text: 待检查的文本
|
||||
chat: 聊天对象
|
||||
userinfo: 用户信息
|
||||
|
||||
Returns:
|
||||
bool: 是否包含过滤词
|
||||
"""
|
||||
for word in global_config.message_receive.ban_words:
|
||||
if word in text:
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式
|
||||
|
||||
Args:
|
||||
text: 待检查的文本
|
||||
chat: 聊天对象
|
||||
userinfo: 用户信息
|
||||
|
||||
Returns:
|
||||
bool: 是否匹配过滤正则
|
||||
"""
|
||||
for pattern in global_config.message_receive.ban_msgs_regex:
|
||||
if re.search(pattern, text):
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
runtime = MessageRuntime() # 获取mofox-bus运行时环境
|
||||
|
||||
class ChatBot:
|
||||
def __init__(self):
|
||||
self.bot = None # bot 实例引用
|
||||
self._started = False
|
||||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||
# 启动消息管理器
|
||||
self._message_manager_started = False
|
||||
|
||||
async def _ensure_started(self):
|
||||
"""确保所有任务已启动"""
|
||||
if not self._started:
|
||||
logger.debug("确保ChatBot所有任务已启动")
|
||||
|
||||
# 启动消息管理器
|
||||
if not self._message_manager_started:
|
||||
await message_manager.start()
|
||||
self._message_manager_started = True
|
||||
logger.info("消息管理器已启动")
|
||||
|
||||
self._started = True
|
||||
|
||||
async def _process_plus_commands(self, message: DatabaseMessages, chat: ChatStream):
|
||||
"""独立处理PlusCommand系统"""
|
||||
try:
|
||||
text = message.processed_plain_text or ""
|
||||
|
||||
# 获取配置的命令前缀
|
||||
from src.config.config import global_config
|
||||
|
||||
prefixes = global_config.command.command_prefixes
|
||||
|
||||
# 检查是否以任何前缀开头
|
||||
matched_prefix = None
|
||||
for prefix in prefixes:
|
||||
if text.startswith(prefix):
|
||||
matched_prefix = prefix
|
||||
break
|
||||
|
||||
if not matched_prefix:
|
||||
return False, None, True # 不是命令,继续处理
|
||||
|
||||
# 移除前缀
|
||||
command_part = text[len(matched_prefix) :].strip()
|
||||
|
||||
# 分离命令名和参数
|
||||
parts = command_part.split(None, 1)
|
||||
if not parts:
|
||||
return False, None, True # 没有命令名,继续处理
|
||||
|
||||
command_word = parts[0].lower()
|
||||
args_text = parts[1] if len(parts) > 1 else ""
|
||||
|
||||
# 查找匹配的PlusCommand
|
||||
plus_command_registry = component_registry.get_plus_command_registry()
|
||||
matching_commands = []
|
||||
|
||||
for plus_command_name, plus_command_class in plus_command_registry.items():
|
||||
plus_command_info = component_registry.get_registered_plus_command_info(plus_command_name)
|
||||
if not plus_command_info:
|
||||
continue
|
||||
|
||||
# 检查命令名是否匹配(命令名和别名)
|
||||
all_commands = [plus_command_name.lower()] + [
|
||||
alias.lower() for alias in plus_command_info.command_aliases
|
||||
]
|
||||
if command_word in all_commands:
|
||||
matching_commands.append((plus_command_class, plus_command_info, plus_command_name))
|
||||
|
||||
if not matching_commands:
|
||||
return False, None, True # 没有找到匹配的PlusCommand,继续处理
|
||||
|
||||
# 如果有多个匹配,按优先级排序
|
||||
if len(matching_commands) > 1:
|
||||
matching_commands.sort(key=lambda x: x[1].priority, reverse=True)
|
||||
logger.warning(
|
||||
f"文本 '{text}' 匹配到多个PlusCommand: {[cmd[2] for cmd in matching_commands]},使用优先级最高的"
|
||||
)
|
||||
|
||||
plus_command_class, plus_command_info, plus_command_name = matching_commands[0]
|
||||
|
||||
# 检查命令是否被禁用
|
||||
if (
|
||||
chat
|
||||
and chat.stream_id
|
||||
and plus_command_name
|
||||
in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
|
||||
):
|
||||
logger.info("用户禁用的PlusCommand,跳过处理")
|
||||
return False, None, True
|
||||
|
||||
message.is_command = True
|
||||
|
||||
# 获取插件配置
|
||||
plugin_config = component_registry.get_plugin_config(plus_command_name)
|
||||
|
||||
# 创建PlusCommand实例
|
||||
plus_command_instance = plus_command_class(message, plugin_config)
|
||||
|
||||
# 为插件实例设置 chat_stream 运行时属性
|
||||
setattr(plus_command_instance, "chat_stream", chat)
|
||||
|
||||
try:
|
||||
# 检查聊天类型限制
|
||||
if not plus_command_instance.is_chat_type_allowed():
|
||||
is_group = chat.group_info is not None
|
||||
logger.info(
|
||||
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
||||
)
|
||||
return False, None, True # 跳过此命令,继续处理其他消息
|
||||
|
||||
# 设置参数
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
|
||||
command_args = CommandArgs(args_text)
|
||||
plus_command_instance.args = command_args
|
||||
|
||||
# 执行命令
|
||||
success, response, intercept_message = await plus_command_instance.execute(command_args)
|
||||
|
||||
# 记录命令执行结果
|
||||
if success:
|
||||
logger.info(f"PlusCommand执行成功: {plus_command_class.__name__} (拦截: {intercept_message})")
|
||||
else:
|
||||
logger.warning(f"PlusCommand执行失败: {plus_command_class.__name__} - {response}")
|
||||
|
||||
# 根据命令的拦截设置决定是否继续处理消息
|
||||
return True, response, not intercept_message # 找到命令,根据intercept_message决定是否继续
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行PlusCommand时出错: {plus_command_class.__name__} - {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
await plus_command_instance.send_text(f"命令执行出错: {e!s}")
|
||||
except Exception as send_error:
|
||||
logger.error(f"发送错误消息失败: {send_error}")
|
||||
|
||||
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
||||
return True, str(e), False # 出错时继续处理消息
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理PlusCommand时出错: {e}")
|
||||
return False, None, True # 出错时继续处理消息
|
||||
|
||||
async def _process_commands_with_new_system(self, message: DatabaseMessages, chat: ChatStream):
|
||||
# sourcery skip: use-named-expression
|
||||
"""使用新插件系统处理命令"""
|
||||
try:
|
||||
text = message.processed_plain_text or ""
|
||||
|
||||
# 使用新的组件注册中心查找命令
|
||||
command_result = component_registry.find_command_by_text(text)
|
||||
if command_result:
|
||||
command_class, matched_groups, command_info = command_result
|
||||
plugin_name = command_info.plugin_name
|
||||
command_name = command_info.name
|
||||
if (
|
||||
chat
|
||||
and chat.stream_id
|
||||
and command_name
|
||||
in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
|
||||
):
|
||||
logger.info("用户禁用的命令,跳过处理")
|
||||
return False, None, True
|
||||
|
||||
message.is_command = True
|
||||
|
||||
# 获取插件配置
|
||||
plugin_config = component_registry.get_plugin_config(plugin_name)
|
||||
|
||||
# 创建命令实例
|
||||
command_instance: BaseCommand = command_class(message, plugin_config)
|
||||
command_instance.set_matched_groups(matched_groups)
|
||||
|
||||
# 为插件实例设置 chat_stream 运行时属性
|
||||
setattr(command_instance, "chat_stream", chat)
|
||||
|
||||
try:
|
||||
# 检查聊天类型限制
|
||||
if not command_instance.is_chat_type_allowed():
|
||||
is_group = chat.group_info is not None
|
||||
logger.info(
|
||||
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
||||
)
|
||||
return False, None, True # 跳过此命令,继续处理其他消息
|
||||
|
||||
# 执行命令
|
||||
success, response, intercept_message = await command_instance.execute()
|
||||
|
||||
# 记录命令执行结果
|
||||
if success:
|
||||
logger.info(f"命令执行成功: {command_class.__name__} (拦截: {intercept_message})")
|
||||
else:
|
||||
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
|
||||
|
||||
# 根据命令的拦截设置决定是否继续处理消息
|
||||
return True, response, not intercept_message # 找到命令,根据intercept_message决定是否继续
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
await command_instance.send_text(f"命令执行出错: {e!s}")
|
||||
except Exception as send_error:
|
||||
logger.error(f"发送错误消息失败: {send_error}")
|
||||
|
||||
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
||||
return True, str(e), False # 出错时继续处理消息
|
||||
|
||||
# 没有找到命令,继续处理消息
|
||||
return False, None, True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时出错: {e}")
|
||||
return False, None, True # 出错时继续处理消息
|
||||
|
||||
|
||||
async def _handle_adapter_response_from_dict(self, seg_data: dict | None):
|
||||
"""处理适配器命令响应(从字典数据)"""
|
||||
try:
|
||||
from src.plugin_system.apis.send_api import put_adapter_response
|
||||
|
||||
if isinstance(seg_data, dict):
|
||||
request_id = seg_data.get("request_id")
|
||||
response_data = seg_data.get("response")
|
||||
else:
|
||||
request_id = None
|
||||
response_data = None
|
||||
|
||||
if request_id and response_data:
|
||||
logger.info(f"[DEBUG bot.py] 收到适配器响应,request_id={request_id}")
|
||||
put_adapter_response(request_id, response_data)
|
||||
else:
|
||||
logger.warning(f"适配器响应消息格式不正确: request_id={request_id}, response_data={response_data}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理适配器响应时出错: {e}")
|
||||
|
||||
@runtime.on_message
|
||||
async def message_process(self, envelope: MessageEnvelope) -> None:
|
||||
"""处理转化后的统一格式消息"""
|
||||
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
|
||||
message_info = envelope.get("message_info")
|
||||
if not isinstance(message_info, dict):
|
||||
logger.debug(
|
||||
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
|
||||
", ".join(envelope.keys()),
|
||||
)
|
||||
return
|
||||
|
||||
if message_info.get("group_info") is not None:
|
||||
message_info["group_info"]["group_id"] = str( # type: ignore
|
||||
message_info["group_info"]["group_id"] # type: ignore
|
||||
)
|
||||
if message_info.get("user_info") is not None:
|
||||
message_info["user_info"]["user_id"] = str( # type: ignore
|
||||
message_info["user_info"]["user_id"] # type: ignore
|
||||
)
|
||||
|
||||
# 优先处理adapter_response消息(在echo检查之前!)
|
||||
message_segment = envelope.get("message_segment")
|
||||
if message_segment and isinstance(message_segment, dict):
|
||||
if message_segment.get("type") == "adapter_response":
|
||||
logger.info("[DEBUG bot.py message_process] 检测到adapter_response,立即处理")
|
||||
await self._handle_adapter_response_from_dict(message_segment.get("data"))
|
||||
return
|
||||
|
||||
# 先提取基础信息检查是否是自身消息上报
|
||||
from mofox_bus import BaseMessageInfo
|
||||
temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {}))
|
||||
if temp_message_info.additional_config:
|
||||
sent_message = temp_message_info.additional_config.get("echo", False)
|
||||
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题
|
||||
# 直接使用消息字典更新,不再需要创建 MessageRecv
|
||||
await MessageStorage.update_message(message_data)
|
||||
return
|
||||
|
||||
message_segment = envelope.get("message_segment")
|
||||
group_info = temp_message_info.group_info
|
||||
user_info = temp_message_info.user_info
|
||||
|
||||
# 获取或创建聊天流
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=temp_message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
# 使用新的消息处理器直接生成 DatabaseMessages
|
||||
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||
message = await process_message_from_dict(
|
||||
message_dict=envelope,
|
||||
stream_id=chat.stream_id,
|
||||
platform=chat.platform
|
||||
)
|
||||
|
||||
# 填充聊天流时间信息
|
||||
message.chat_info.create_time = chat.create_time
|
||||
message.chat_info.last_active_time = chat.last_active_time
|
||||
|
||||
# 注册消息到聊天管理器
|
||||
get_chat_manager().register_message(message)
|
||||
|
||||
# 检测是否提及机器人
|
||||
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
|
||||
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
user_nickname = message.user_info.user_nickname if message.user_info else "未知用户"
|
||||
logger.info(
|
||||
f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m"
|
||||
)
|
||||
|
||||
# 在此添加硬编码过滤,防止回复图片处理失败的消息
|
||||
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
|
||||
processed_text = message.processed_plain_text or ""
|
||||
if any(keyword in processed_text for keyword in failure_keywords):
|
||||
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
|
||||
return
|
||||
|
||||
# 过滤检查
|
||||
# DatabaseMessages 使用 display_message 作为原始消息表示
|
||||
raw_text = message.display_message or message.processed_plain_text or ""
|
||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||
raw_text,
|
||||
chat,
|
||||
user_info, # type: ignore
|
||||
):
|
||||
return
|
||||
|
||||
# 命令处理 - 首先尝试PlusCommand独立处理
|
||||
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat)
|
||||
|
||||
# 如果是PlusCommand且不需要继续处理,则直接返回
|
||||
if is_plus_command and not plus_continue_process:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
logger.info(f"PlusCommand处理完成,跳过后续消息处理: {plus_cmd_result}")
|
||||
return
|
||||
|
||||
# 如果不是PlusCommand,尝试传统的BaseCommand处理
|
||||
if not is_plus_command:
|
||||
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat)
|
||||
|
||||
# 如果是命令且不需要继续处理,则直接返回
|
||||
if is_command and not continue_process:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
return
|
||||
|
||||
result = await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message=message)
|
||||
if result and not result.all_continue_process():
|
||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
|
||||
|
||||
# TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
# 这个功能需要在 adapter 层通过 additional_config 传递
|
||||
template_group_name = None
|
||||
|
||||
async def preprocess():
|
||||
# message 已经是 DatabaseMessages,直接使用
|
||||
group_info = chat.group_info
|
||||
|
||||
# 先交给消息管理器处理,计算兴趣度等衍生数据
|
||||
try:
|
||||
# 在将消息添加到管理器之前进行最终的静默检查
|
||||
should_process_in_manager = True
|
||||
if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list:
|
||||
# 检查消息是否为图片或表情包
|
||||
is_image_or_emoji = message.is_picid or message.is_emoji
|
||||
if not message.is_mentioned and not is_image_or_emoji:
|
||||
logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理")
|
||||
should_process_in_manager = False
|
||||
elif is_image_or_emoji:
|
||||
logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理")
|
||||
should_process_in_manager = False
|
||||
|
||||
if should_process_in_manager:
|
||||
await message_manager.add_message(chat.stream_id, message)
|
||||
logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"消息添加到消息管理器失败: {e}")
|
||||
|
||||
# 存储消息到数据库,只进行一次写入
|
||||
try:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
except Exception as e:
|
||||
logger.error(f"存储消息到数据库失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 情绪系统更新 - 在消息存储后触发情绪更新
|
||||
try:
|
||||
if global_config.mood.enable_mood:
|
||||
# 获取兴趣度用于情绪更新
|
||||
interest_rate = message.interest_value
|
||||
if interest_rate is None:
|
||||
interest_rate = 0.0
|
||||
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
|
||||
|
||||
# 获取当前聊天的情绪对象并更新情绪状态
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||
await chat_mood.update_mood_by_message(message, interest_rate)
|
||||
logger.debug("情绪状态更新完成")
|
||||
except Exception as e:
|
||||
logger.error(f"更新情绪状态失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
if template_group_name:
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
await preprocess()
|
||||
else:
|
||||
await preprocess()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预处理消息失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# 创建全局ChatBot实例
|
||||
chat_bot = ChatBot()
|
||||
@@ -2,11 +2,11 @@ import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
|
||||
from mofox_bus import GroupInfo, UserInfo
|
||||
from rich.traceback import install
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseGroupInfo,DatabaseUserInfo
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
@@ -30,8 +30,8 @@ class ChatStream:
|
||||
self,
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo | None = None,
|
||||
group_info: GroupInfo | None = None,
|
||||
user_info: DatabaseUserInfo | None = None,
|
||||
group_info: DatabaseGroupInfo | None = None,
|
||||
data: dict | None = None,
|
||||
):
|
||||
self.stream_id = stream_id
|
||||
@@ -77,8 +77,8 @@ class ChatStream:
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "ChatStream":
|
||||
"""从字典创建实例"""
|
||||
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
|
||||
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
|
||||
user_info = DatabaseUserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
|
||||
group_info = DatabaseGroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
|
||||
|
||||
instance = cls(
|
||||
stream_id=data["stream_id"],
|
||||
@@ -369,7 +369,7 @@ class ChatManager:
|
||||
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||
|
||||
@staticmethod
|
||||
def _generate_stream_id(platform: str, user_info: UserInfo | None, group_info: GroupInfo | None = None) -> str:
|
||||
def _generate_stream_id(platform: str, user_info: DatabaseUserInfo | None, group_info: DatabaseGroupInfo | None = None) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
if not user_info and not group_info:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
@@ -392,7 +392,7 @@ class ChatManager:
|
||||
return hashlib.sha256(key.encode()).hexdigest()
|
||||
|
||||
async def get_or_create_stream(
|
||||
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
||||
self, platform: str, user_info: DatabaseUserInfo, group_info: DatabaseGroupInfo | None = None
|
||||
) -> ChatStream:
|
||||
"""获取或创建聊天流 - 优化版本使用缓存机制"""
|
||||
try:
|
||||
@@ -483,7 +483,7 @@ class ChatManager:
|
||||
return stream
|
||||
|
||||
def get_stream_by_info(
|
||||
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
||||
self, platform: str, user_info: DatabaseUserInfo, group_info: DatabaseGroupInfo | None = None
|
||||
) -> ChatStream | None:
|
||||
"""通过信息获取聊天流"""
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import time
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import urllib3
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.self_voice_cache import consume_self_voice_text
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
from src.chat.utils.utils_voice import get_voice_text
|
||||
@@ -14,6 +13,9 @@ from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
|
||||
@@ -1,39 +1,566 @@
|
||||
import os
|
||||
import traceback
|
||||
"""
|
||||
统一消息处理器 (Message Handler)
|
||||
|
||||
利用 mofox_bus.MessageRuntime 的路由功能,简化消息处理链条:
|
||||
|
||||
1. 使用 @runtime.on_message() 装饰器注册按消息类型路由的处理器
|
||||
2. 使用 before_hook 进行消息预处理(ID标准化、过滤等)
|
||||
3. 使用 after_hook 进行消息后处理(存储、情绪更新等)
|
||||
4. 使用 error_hook 统一处理异常
|
||||
|
||||
消息流向:
|
||||
适配器 → CoreSinkManager → MessageRuntime
|
||||
↓
|
||||
[before_hook] 消息预处理、过滤
|
||||
↓
|
||||
[on_message] 按类型路由处理(命令、普通消息等)
|
||||
↓
|
||||
[after_hook] 存储、情绪更新等
|
||||
↓
|
||||
回复生成 → CoreSinkManager.send_outgoing() → 适配器
|
||||
|
||||
重构说明(2025-11):
|
||||
- 移除手动的消息处理链,改用 MessageRuntime 路由
|
||||
- MessageHandler 变成处理器注册器,在初始化时注册各种处理器
|
||||
- 利用 runtime 的钩子机制简化前置/后置处理
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from mofox_bus import MessageEnvelope, MessageRuntime
|
||||
|
||||
from mofox_bus.runtime import MessageRuntime
|
||||
from mofox_bus import MessageEnvelope
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.prompt import global_prompt_manager
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.common.data_models.database_data_model import DatabaseGroupInfo, DatabaseUserInfo, DatabaseMessages
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
|
||||
runtime = MessageRuntime()
|
||||
if TYPE_CHECKING:
|
||||
from src.common.core_sink_manager import CoreSinkManager
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
logger = get_logger("message_handler")
|
||||
|
||||
# 项目根目录
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
|
||||
def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
|
||||
"""检查消息是否包含过滤词"""
|
||||
for word in global_config.message_receive.ban_words:
|
||||
if word in text:
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _check_ban_regex(text: str, chat: "ChatStream", userinfo) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式"""
|
||||
for pattern in global_config.message_receive.ban_msgs_regex:
|
||||
if re.search(pattern, text):
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
# 配置主程序日志格式
|
||||
logger = get_logger("chat")
|
||||
|
||||
class MessageHandler:
|
||||
"""
|
||||
统一消息处理器
|
||||
|
||||
利用 MessageRuntime 的路由功能,将消息处理逻辑注册为路由和钩子。
|
||||
|
||||
架构说明:
|
||||
- 在 register_handlers() 中向 MessageRuntime 注册各种处理器
|
||||
- 使用 @runtime.on_message(message_type=...) 按消息类型路由
|
||||
- 使用 before_hook 进行消息预处理
|
||||
- 使用 after_hook 进行消息后处理
|
||||
- 使用 error_hook 统一处理异常
|
||||
|
||||
主要功能:
|
||||
1. 消息预处理:ID标准化、过滤检查
|
||||
2. 适配器响应处理:处理 adapter_response 类型消息
|
||||
3. 命令处理:PlusCommand 和 BaseCommand
|
||||
4. 普通消息处理:触发事件、存储、情绪更新
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._started = False
|
||||
self._message_manager_started = False
|
||||
self._core_sink_manager: CoreSinkManager | None = None
|
||||
self._shutting_down = False
|
||||
self._runtime: MessageRuntime | None = None
|
||||
|
||||
async def preprocess(self, chat: ChatStream, message: DatabaseMessages):
|
||||
# message 已经是 DatabaseMessages,直接使用
|
||||
group_info = chat.group_info
|
||||
def set_core_sink_manager(self, manager: "CoreSinkManager") -> None:
|
||||
"""设置 CoreSinkManager 引用"""
|
||||
self._core_sink_manager = manager
|
||||
|
||||
# 先交给消息管理器处理
|
||||
def register_handlers(self, runtime: MessageRuntime) -> None:
|
||||
"""
|
||||
向 MessageRuntime 注册消息处理器和钩子
|
||||
|
||||
这是核心方法,在系统初始化时调用,将所有处理逻辑注册到 runtime。
|
||||
|
||||
Args:
|
||||
runtime: MessageRuntime 实例
|
||||
"""
|
||||
self._runtime = runtime
|
||||
|
||||
# 注册前置钩子:消息预处理和过滤
|
||||
runtime.register_before_hook(self._before_hook)
|
||||
|
||||
# 注册后置钩子:存储、情绪更新等
|
||||
runtime.register_after_hook(self._after_hook)
|
||||
|
||||
# 注册错误钩子:统一异常处理
|
||||
runtime.register_error_hook(self._error_hook)
|
||||
|
||||
# 注册适配器响应处理器(最高优先级)
|
||||
def _is_adapter_response(env: MessageEnvelope) -> bool:
|
||||
segment = env.get("message_segment")
|
||||
if isinstance(segment, dict):
|
||||
return segment.get("type") == "adapter_response"
|
||||
return False
|
||||
|
||||
runtime.add_route(
|
||||
predicate=_is_adapter_response,
|
||||
handler=self._handle_adapter_response_route,
|
||||
name="adapter_response_handler",
|
||||
message_type="adapter_response",
|
||||
)
|
||||
|
||||
# 注册默认消息处理器(处理所有其他消息)
|
||||
runtime.add_route(
|
||||
predicate=lambda _: True, # 匹配所有消息
|
||||
handler=self._handle_normal_message,
|
||||
name="default_message_handler",
|
||||
)
|
||||
|
||||
logger.info("MessageHandler 已向 MessageRuntime 注册处理器和钩子")
|
||||
|
||||
async def ensure_started(self) -> None:
|
||||
"""确保所有依赖任务已启动"""
|
||||
if not self._started:
|
||||
logger.debug("确保 MessageHandler 所有任务已启动")
|
||||
|
||||
# 启动消息管理器
|
||||
if not self._message_manager_started:
|
||||
await message_manager.start()
|
||||
self._message_manager_started = True
|
||||
logger.info("消息管理器已启动")
|
||||
|
||||
self._started = True
|
||||
|
||||
async def _before_hook(self, envelope: MessageEnvelope) -> None:
|
||||
"""
|
||||
前置钩子:消息预处理
|
||||
|
||||
1. 标准化 ID 为字符串
|
||||
2. 检查是否为 echo 消息(自身发送的消息上报)
|
||||
3. 附加预处理数据到 envelope(chat_stream, message 等)
|
||||
"""
|
||||
if self._shutting_down:
|
||||
raise UserWarning("系统正在关闭,拒绝处理消息")
|
||||
|
||||
# 确保依赖服务已启动
|
||||
await self.ensure_started()
|
||||
|
||||
# 提取消息信息
|
||||
message_info = envelope.get("message_info")
|
||||
if not isinstance(message_info, dict):
|
||||
logger.debug(
|
||||
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
|
||||
", ".join(envelope.keys()),
|
||||
)
|
||||
raise UserWarning("消息缺少 message_info")
|
||||
|
||||
# 标准化 ID 为字符串
|
||||
if message_info.get("group_info") is not None:
|
||||
message_info["group_info"]["group_id"] = str( # type: ignore
|
||||
message_info["group_info"]["group_id"] # type: ignore
|
||||
)
|
||||
if message_info.get("user_info") is not None:
|
||||
message_info["user_info"]["user_id"] = str( # type: ignore
|
||||
message_info["user_info"]["user_id"] # type: ignore
|
||||
)
|
||||
|
||||
# 处理自身消息上报(echo)
|
||||
additional_config = message_info.get("additional_config", {})
|
||||
if additional_config and isinstance(additional_config, dict):
|
||||
sent_message = additional_config.get("echo", False)
|
||||
if sent_message:
|
||||
# 更新消息ID
|
||||
await MessageStorage.update_message(dict(envelope))
|
||||
raise UserWarning("Echo 消息已处理")
|
||||
|
||||
async def _after_hook(self, envelope: MessageEnvelope) -> None:
|
||||
"""
|
||||
后置钩子:消息后处理
|
||||
|
||||
在消息处理完成后执行的清理工作
|
||||
"""
|
||||
# 后置处理逻辑(如有需要)
|
||||
pass
|
||||
|
||||
async def _error_hook(self, envelope: MessageEnvelope, exc: BaseException) -> None:
|
||||
"""
|
||||
错误钩子:统一异常处理
|
||||
"""
|
||||
if isinstance(exc, UserWarning):
|
||||
# UserWarning 是预期的流程控制,只记录 debug 日志
|
||||
logger.debug(f"消息处理流程控制: {exc}")
|
||||
else:
|
||||
message_id = envelope.get("message_info", {}).get("message_id", "UNKNOWN")
|
||||
logger.error(f"处理消息 {message_id} 时出错: {exc}", exc_info=True)
|
||||
|
||||
async def _handle_adapter_response_route(self, envelope: MessageEnvelope) -> MessageEnvelope | None:
|
||||
"""
|
||||
处理适配器响应消息的路由处理器
|
||||
"""
|
||||
message_segment = envelope.get("message_segment")
|
||||
if message_segment and isinstance(message_segment, dict):
|
||||
seg_data = message_segment.get("data")
|
||||
if isinstance(seg_data, dict):
|
||||
await self._handle_adapter_response(seg_data)
|
||||
return None
|
||||
|
||||
async def _handle_normal_message(self, envelope: MessageEnvelope) -> MessageEnvelope | None:
|
||||
"""
|
||||
默认消息处理器:处理普通消息
|
||||
|
||||
1. 获取或创建聊天流
|
||||
2. 转换为 DatabaseMessages
|
||||
3. 过滤检查
|
||||
4. 命令处理
|
||||
5. 触发事件、存储、情绪更新
|
||||
"""
|
||||
try:
|
||||
# 在将消息添加到管理器之前进行最终的静默检查
|
||||
message_info = envelope.get("message_info")
|
||||
if not isinstance(message_info, dict):
|
||||
return None
|
||||
|
||||
# 获取用户和群组信息
|
||||
group_info = message_info.get("group_info")
|
||||
user_info = message_info.get("user_info")
|
||||
|
||||
# 获取或创建聊天流
|
||||
platform = message_info.get("platform", "unknown")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=platform,
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
# 将消息信封转换为 DatabaseMessages
|
||||
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||
message = await process_message_from_dict(
|
||||
message_dict=envelope,
|
||||
stream_id=chat.stream_id,
|
||||
platform=chat.platform
|
||||
)
|
||||
|
||||
# 填充聊天流时间信息
|
||||
message.chat_info.create_time = chat.create_time
|
||||
message.chat_info.last_active_time = chat.last_active_time
|
||||
|
||||
# 注册消息到聊天管理器
|
||||
get_chat_manager().register_message(message)
|
||||
|
||||
# 检测是否提及机器人
|
||||
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
|
||||
# 打印接收日志
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
user_nickname = message.user_info.user_nickname if message.user_info else "未知用户"
|
||||
logger.info(f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m")
|
||||
|
||||
# 硬编码过滤
|
||||
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
|
||||
processed_text = message.processed_plain_text or ""
|
||||
if any(keyword in processed_text for keyword in failure_keywords):
|
||||
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
|
||||
return None
|
||||
|
||||
# 过滤检查
|
||||
raw_text = message.display_message or message.processed_plain_text or ""
|
||||
if _check_ban_words(processed_text, chat, user_info) or _check_ban_regex(
|
||||
raw_text, chat, user_info
|
||||
):
|
||||
return None
|
||||
|
||||
# 处理命令和后续流程
|
||||
await self._process_commands(message, chat)
|
||||
|
||||
except UserWarning as uw:
|
||||
logger.info(str(uw))
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
return None
|
||||
|
||||
# 保留旧的 process_message 方法用于向后兼容
|
||||
async def process_message(self, envelope: MessageEnvelope) -> None:
|
||||
"""
|
||||
处理接收到的消息信封(向后兼容)
|
||||
|
||||
注意:此方法已被 MessageRuntime 路由取代。
|
||||
如果直接调用此方法,它会委托给 runtime.handle_message()。
|
||||
|
||||
Args:
|
||||
envelope: 消息信封(来自适配器)
|
||||
"""
|
||||
if self._runtime:
|
||||
await self._runtime.handle_message(envelope)
|
||||
else:
|
||||
# 如果 runtime 未设置,使用旧的处理流程
|
||||
await self._handle_normal_message(envelope)
|
||||
|
||||
async def _process_commands(self, message: DatabaseMessages, chat: "ChatStream") -> None:
|
||||
"""处理命令和继续消息流程"""
|
||||
try:
|
||||
# 首先尝试 PlusCommand
|
||||
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat)
|
||||
|
||||
if is_plus_command and not plus_continue_process:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
logger.info(f"PlusCommand处理完成,跳过后续消息处理: {plus_cmd_result}")
|
||||
return
|
||||
|
||||
# 如果不是 PlusCommand,尝试传统 BaseCommand
|
||||
if not is_plus_command:
|
||||
is_command, cmd_result, continue_process = await self._process_base_commands(message, chat)
|
||||
|
||||
if is_command and not continue_process:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
return
|
||||
|
||||
# 触发消息事件
|
||||
result = await event_manager.trigger_event(
|
||||
EventType.ON_MESSAGE,
|
||||
permission_group="SYSTEM",
|
||||
message=message
|
||||
)
|
||||
if result and not result.all_continue_process():
|
||||
raise UserWarning(
|
||||
f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理"
|
||||
)
|
||||
|
||||
# 预处理消息
|
||||
await self._preprocess_message(message, chat)
|
||||
|
||||
except UserWarning as uw:
|
||||
logger.info(str(uw))
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _process_plus_commands(
|
||||
self,
|
||||
message: DatabaseMessages,
|
||||
chat: "ChatStream"
|
||||
) -> tuple[bool, Any, bool]:
|
||||
"""处理 PlusCommand 系统"""
|
||||
try:
|
||||
text = message.processed_plain_text or ""
|
||||
|
||||
# 获取配置的命令前缀
|
||||
prefixes = global_config.command.command_prefixes
|
||||
|
||||
# 检查是否以任何前缀开头
|
||||
matched_prefix = None
|
||||
for prefix in prefixes:
|
||||
if text.startswith(prefix):
|
||||
matched_prefix = prefix
|
||||
break
|
||||
|
||||
if not matched_prefix:
|
||||
return False, None, True
|
||||
|
||||
# 移除前缀
|
||||
command_part = text[len(matched_prefix):].strip()
|
||||
|
||||
# 分离命令名和参数
|
||||
parts = command_part.split(None, 1)
|
||||
if not parts:
|
||||
return False, None, True
|
||||
|
||||
command_word = parts[0].lower()
|
||||
args_text = parts[1] if len(parts) > 1 else ""
|
||||
|
||||
# 查找匹配的 PlusCommand
|
||||
plus_command_registry = component_registry.get_plus_command_registry()
|
||||
matching_commands = []
|
||||
|
||||
for plus_command_name, plus_command_class in plus_command_registry.items():
|
||||
plus_command_info = component_registry.get_registered_plus_command_info(plus_command_name)
|
||||
if not plus_command_info:
|
||||
continue
|
||||
|
||||
all_commands = [plus_command_name.lower()] + [
|
||||
alias.lower() for alias in plus_command_info.command_aliases
|
||||
]
|
||||
if command_word in all_commands:
|
||||
matching_commands.append((plus_command_class, plus_command_info, plus_command_name))
|
||||
|
||||
if not matching_commands:
|
||||
return False, None, True
|
||||
|
||||
# 按优先级排序
|
||||
if len(matching_commands) > 1:
|
||||
matching_commands.sort(key=lambda x: x[1].priority, reverse=True)
|
||||
|
||||
plus_command_class, plus_command_info, plus_command_name = matching_commands[0]
|
||||
|
||||
# 检查是否被禁用
|
||||
if (
|
||||
chat
|
||||
and chat.stream_id
|
||||
and plus_command_name in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
|
||||
):
|
||||
logger.info("用户禁用的PlusCommand,跳过处理")
|
||||
return False, None, True
|
||||
|
||||
message.is_command = True
|
||||
|
||||
# 获取插件配置
|
||||
plugin_config = component_registry.get_plugin_config(plus_command_name)
|
||||
|
||||
# 创建实例
|
||||
plus_command_instance = plus_command_class(message, plugin_config)
|
||||
setattr(plus_command_instance, "chat_stream", chat)
|
||||
|
||||
try:
|
||||
if not plus_command_instance.is_chat_type_allowed():
|
||||
is_group = chat.group_info is not None
|
||||
logger.info(
|
||||
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
||||
)
|
||||
return False, None, True
|
||||
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
command_args = CommandArgs(args_text)
|
||||
plus_command_instance.args = command_args
|
||||
|
||||
success, response, intercept_message = await plus_command_instance.execute(command_args)
|
||||
|
||||
if success:
|
||||
logger.info(f"PlusCommand执行成功: {plus_command_class.__name__} (拦截: {intercept_message})")
|
||||
else:
|
||||
logger.warning(f"PlusCommand执行失败: {plus_command_class.__name__} - {response}")
|
||||
|
||||
return True, response, not intercept_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行PlusCommand时出错: {plus_command_class.__name__} - {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
await plus_command_instance.send_text(f"命令执行出错: {e!s}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return True, str(e), False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理PlusCommand时出错: {e}")
|
||||
return False, None, True
|
||||
|
||||
async def _process_base_commands(
|
||||
self,
|
||||
message: DatabaseMessages,
|
||||
chat: "ChatStream"
|
||||
) -> tuple[bool, Any, bool]:
|
||||
"""处理传统 BaseCommand 系统"""
|
||||
try:
|
||||
text = message.processed_plain_text or ""
|
||||
|
||||
command_result = component_registry.find_command_by_text(text)
|
||||
if command_result:
|
||||
command_class, matched_groups, command_info = command_result
|
||||
plugin_name = command_info.plugin_name
|
||||
command_name = command_info.name
|
||||
|
||||
if (
|
||||
chat
|
||||
and chat.stream_id
|
||||
and command_name in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
|
||||
):
|
||||
logger.info("用户禁用的命令,跳过处理")
|
||||
return False, None, True
|
||||
|
||||
message.is_command = True
|
||||
|
||||
plugin_config = component_registry.get_plugin_config(plugin_name)
|
||||
command_instance: BaseCommand = command_class(message, plugin_config)
|
||||
command_instance.set_matched_groups(matched_groups)
|
||||
setattr(command_instance, "chat_stream", chat)
|
||||
|
||||
try:
|
||||
if not command_instance.is_chat_type_allowed():
|
||||
is_group = chat.group_info is not None
|
||||
logger.info(
|
||||
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
||||
)
|
||||
return False, None, True
|
||||
|
||||
success, response, intercept_message = await command_instance.execute()
|
||||
|
||||
if success:
|
||||
logger.info(f"命令执行成功: {command_class.__name__} (拦截: {intercept_message})")
|
||||
else:
|
||||
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
|
||||
|
||||
return True, response, not intercept_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
await command_instance.send_text(f"命令执行出错: {e!s}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return True, str(e), False
|
||||
|
||||
return False, None, True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时出错: {e}")
|
||||
return False, None, True
|
||||
|
||||
async def _preprocess_message(self, message: DatabaseMessages, chat: "ChatStream") -> None:
|
||||
"""预处理消息:存储、情绪更新等"""
|
||||
try:
|
||||
group_info = chat.group_info
|
||||
|
||||
# 检查是否需要处理消息
|
||||
should_process_in_manager = True
|
||||
if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list:
|
||||
# 检查消息是否为图片或表情包
|
||||
is_image_or_emoji = message.is_picid or message.is_emoji
|
||||
if not message.is_mentioned and not is_image_or_emoji:
|
||||
logger.debug(f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理")
|
||||
logger.debug(
|
||||
f"群组 {group_info.group_id} 在静默列表中,且消息不是@、回复或图片/表情包,跳过消息管理器处理"
|
||||
)
|
||||
should_process_in_manager = False
|
||||
elif is_image_or_emoji:
|
||||
logger.debug(f"群组 {group_info.group_id} 在静默列表中,但消息是图片/表情包,静默处理")
|
||||
@@ -43,68 +570,81 @@ class MessageHandler:
|
||||
await message_manager.add_message(chat.stream_id, message)
|
||||
logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"消息添加到消息管理器失败: {e}")
|
||||
# 存储消息
|
||||
try:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
except Exception as e:
|
||||
logger.error(f"存储消息到数据库失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 情绪系统更新
|
||||
try:
|
||||
if global_config.mood.enable_mood:
|
||||
interest_rate = message.interest_value or 0.0
|
||||
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
|
||||
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||
await chat_mood.update_mood_by_message(message, interest_rate)
|
||||
logger.debug("情绪状态更新完成")
|
||||
except Exception as e:
|
||||
logger.error(f"更新情绪状态失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 存储消息到数据库,只进行一次写入
|
||||
try:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
except Exception as e:
|
||||
logger.error(f"存储消息到数据库失败: {e}")
|
||||
logger.error(f"预处理消息失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 情绪系统更新 - 在消息存储后触发情绪更新
|
||||
async def _handle_adapter_response(self, seg_data: dict | None) -> None:
|
||||
"""处理适配器命令响应"""
|
||||
try:
|
||||
if global_config.mood.enable_mood:
|
||||
# 获取兴趣度用于情绪更新
|
||||
interest_rate = message.interest_value
|
||||
if interest_rate is None:
|
||||
interest_rate = 0.0
|
||||
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
|
||||
from src.plugin_system.apis.send_api import put_adapter_response
|
||||
|
||||
if isinstance(seg_data, dict):
|
||||
request_id = seg_data.get("request_id")
|
||||
response_data = seg_data.get("response")
|
||||
else:
|
||||
request_id = None
|
||||
response_data = None
|
||||
|
||||
if request_id and response_data:
|
||||
logger.debug(f"收到适配器响应,request_id={request_id}")
|
||||
put_adapter_response(request_id, response_data)
|
||||
else:
|
||||
logger.warning(
|
||||
f"适配器响应消息格式不正确: request_id={request_id}, response_data={response_data}"
|
||||
)
|
||||
|
||||
# 获取当前聊天的情绪对象并更新情绪状态
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||
await chat_mood.update_mood_by_message(message, interest_rate)
|
||||
logger.debug("情绪状态更新完成")
|
||||
except Exception as e:
|
||||
logger.error(f"更新情绪状态失败: {e}")
|
||||
traceback.print_exc()
|
||||
logger.error(f"处理适配器响应时出错: {e}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""关闭消息处理器"""
|
||||
self._shutting_down = True
|
||||
logger.info("MessageHandler 正在关闭...")
|
||||
|
||||
|
||||
async def handle_message(self, envelope: MessageEnvelope):
|
||||
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
|
||||
message_info = envelope.get("message_info")
|
||||
if not isinstance(message_info, dict):
|
||||
logger.debug(
|
||||
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
|
||||
", ".join(envelope.keys()),
|
||||
)
|
||||
return
|
||||
# 全局单例
|
||||
_message_handler: MessageHandler | None = None
|
||||
|
||||
if message_info.get("group_info") is not None:
|
||||
message_info["group_info"]["group_id"] = str( # type: ignore
|
||||
message_info["group_info"]["group_id"] # type: ignore
|
||||
)
|
||||
if message_info.get("user_info") is not None:
|
||||
message_info["user_info"]["user_id"] = str( # type: ignore
|
||||
message_info["user_info"]["user_id"] # type: ignore
|
||||
)
|
||||
|
||||
group_info = message_info.get("group_info")
|
||||
user_info = message_info.get("user_info")
|
||||
def get_message_handler() -> MessageHandler:
|
||||
"""获取 MessageHandler 单例"""
|
||||
global _message_handler
|
||||
if _message_handler is None:
|
||||
_message_handler = MessageHandler()
|
||||
return _message_handler
|
||||
|
||||
chat_stream = await get_chat_manager().get_or_create_stream(
|
||||
platform=envelope["platform"], # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
# 生成 DatabaseMessages
|
||||
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||
message = await process_message_from_dict(
|
||||
message_dict=envelope,
|
||||
stream_id=chat_stream.stream_id,
|
||||
platform=chat_stream.platform
|
||||
)
|
||||
async def shutdown_message_handler() -> None:
|
||||
"""关闭 MessageHandler"""
|
||||
global _message_handler
|
||||
if _message_handler:
|
||||
await _message_handler.shutdown()
|
||||
_message_handler = None
|
||||
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MessageHandler",
|
||||
"get_message_handler",
|
||||
"shutdown_message_handler",
|
||||
]
|
||||
|
||||
@@ -1,13 +1,26 @@
|
||||
"""
|
||||
统一消息发送器
|
||||
|
||||
重构说明(2025-11):
|
||||
- 使用 CoreSinkManager 发送消息,而不是直接通过 WS 连接
|
||||
- MessageServer 仅作为与旧适配器的兼容层
|
||||
- 所有发送的消息都通过 CoreSinkManager.send_outgoing() 路由到适配器
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from mofox_bus import MessageEnvelope
|
||||
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.utils import calculate_typing_time, truncate_message
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message.api import get_global_api
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -15,12 +28,30 @@ logger = get_logger("sender")
|
||||
|
||||
|
||||
async def send_message(message: MessageSending, show_log=True) -> bool:
|
||||
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||
"""
|
||||
合并后的消息发送函数
|
||||
|
||||
重构后使用 CoreSinkManager 发送消息,而不是直接调用 MessageServer
|
||||
|
||||
Args:
|
||||
message: 要发送的消息
|
||||
show_log: 是否显示日志
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
message_preview = truncate_message(message.processed_plain_text, max_length=120)
|
||||
|
||||
try:
|
||||
# 直接调用API发送消息
|
||||
await get_global_api().send_message(message)
|
||||
# 将 MessageSending 转换为 MessageEnvelope
|
||||
envelope = _message_sending_to_envelope(message)
|
||||
|
||||
# 通过 CoreSinkManager 发送
|
||||
from src.common.core_sink_manager import get_core_sink_manager
|
||||
|
||||
manager = get_core_sink_manager()
|
||||
await manager.send_outgoing(envelope)
|
||||
|
||||
if show_log:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
|
||||
@@ -44,7 +75,67 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {e!s}")
|
||||
traceback.print_exc()
|
||||
raise e # 重新抛出其他异常
|
||||
raise e
|
||||
|
||||
|
||||
def _message_sending_to_envelope(message: MessageSending) -> MessageEnvelope:
|
||||
"""
|
||||
将 MessageSending 转换为 MessageEnvelope
|
||||
|
||||
Args:
|
||||
message: MessageSending 对象
|
||||
|
||||
Returns:
|
||||
MessageEnvelope: 消息信封
|
||||
"""
|
||||
# 构建消息信息
|
||||
message_info: dict[str, Any] = {
|
||||
"message_id": message.message_info.message_id,
|
||||
"time": message.message_info.time or time.time(),
|
||||
"platform": message.message_info.platform,
|
||||
"user_info": {
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"user_nickname": message.message_info.user_info.user_nickname,
|
||||
"platform": message.message_info.user_info.platform,
|
||||
} if message.message_info.user_info else None,
|
||||
}
|
||||
|
||||
# 添加群组信息(如果有)
|
||||
if message.chat_stream and message.chat_stream.group_info:
|
||||
message_info["group_info"] = {
|
||||
"group_id": message.chat_stream.group_info.group_id,
|
||||
"group_name": message.chat_stream.group_info.group_name,
|
||||
"platform": message.chat_stream.group_info.group_platform,
|
||||
}
|
||||
|
||||
# 构建消息段
|
||||
message_segment: dict[str, Any]
|
||||
if message.message_segment:
|
||||
message_segment = {
|
||||
"type": message.message_segment.type,
|
||||
"data": message.message_segment.data,
|
||||
}
|
||||
else:
|
||||
# 默认为文本消息
|
||||
message_segment = {
|
||||
"type": "text",
|
||||
"data": message.processed_plain_text or "",
|
||||
}
|
||||
|
||||
# 添加回复信息(如果有)
|
||||
if message.reply_to:
|
||||
message_segment["reply_to"] = message.reply_to
|
||||
|
||||
# 构建消息信封
|
||||
envelope = cast(MessageEnvelope, {
|
||||
"id": str(uuid.uuid4()),
|
||||
"direction": "outgoing",
|
||||
"platform": message.message_info.platform,
|
||||
"message_info": message_info,
|
||||
"message_segment": message_segment,
|
||||
})
|
||||
|
||||
return envelope
|
||||
|
||||
|
||||
class HeartFCSender:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
@@ -14,6 +14,9 @@ from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.base.component_types import ActionInfo, ComponentType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
|
||||
@@ -48,7 +51,7 @@ class ChatterActionManager:
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
chat_stream: ChatStream,
|
||||
chat_stream: "ChatStream",
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
action_message: DatabaseMessages | None = None,
|
||||
@@ -476,7 +479,7 @@ class ChatterActionManager:
|
||||
|
||||
async def _send_and_store_reply(
|
||||
self,
|
||||
chat_stream: ChatStream,
|
||||
chat_stream: "ChatStream",
|
||||
response_set,
|
||||
loop_start_time,
|
||||
action_message,
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
"""
|
||||
从 src.main 导出 core_sink 的辅助函数
|
||||
|
||||
由于 src.main 中实际使用的是 InProcessCoreSink,
|
||||
我们需要创建一个全局访问点
|
||||
"""
|
||||
|
||||
from mofox_bus import CoreSink, InProcessCoreSink
|
||||
|
||||
_global_core_sink: CoreSink | None = None
|
||||
|
||||
|
||||
def set_core_sink(sink: CoreSink) -> None:
|
||||
"""设置全局 core sink"""
|
||||
global _global_core_sink
|
||||
_global_core_sink = sink
|
||||
|
||||
|
||||
def get_core_sink() -> CoreSink:
|
||||
"""获取全局 core sink"""
|
||||
global _global_core_sink
|
||||
if _global_core_sink is None:
|
||||
raise RuntimeError("Core sink 尚未初始化")
|
||||
return _global_core_sink
|
||||
|
||||
|
||||
async def push_outgoing(envelope) -> None:
|
||||
"""将消息推送到 core sink 的 outgoing 通道"""
|
||||
sink = get_core_sink()
|
||||
push = getattr(sink, "push_outgoing", None)
|
||||
if push is None:
|
||||
raise RuntimeError("当前 core sink 不支持 push_outgoing 方法")
|
||||
await push(envelope)
|
||||
|
||||
__all__ = ["set_core_sink", "get_core_sink", "push_outgoing"]
|
||||
401
src/common/core_sink_manager.py
Normal file
401
src/common/core_sink_manager.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""
|
||||
CoreSink 统一管理器
|
||||
|
||||
负责管理 InProcessCoreSink 和 ProcessCoreSink 双实例,
|
||||
提供统一的消息收发接口,自动维护与适配器子进程的通信管道。
|
||||
|
||||
核心职责:
|
||||
1. 创建和管理 InProcessCoreSink(进程内消息)和 ProcessCoreSink(跨进程消息)
|
||||
2. 自动维护 ProcessCoreSink 与子进程的通信管道
|
||||
3. 使用 MessageRuntime 进行消息路由和处理
|
||||
4. 提供统一的消息发送接口
|
||||
|
||||
架构说明(2025-11 重构):
|
||||
- 集成 mofox_bus.MessageRuntime 作为消息路由中心
|
||||
- 使用 @runtime.on_message() 装饰器注册消息处理器
|
||||
- 利用 before_hook/after_hook/error_hook 处理前置/后置/错误逻辑
|
||||
- 简化消息处理链条,提高可扩展性
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import multiprocessing as mp
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
from mofox_bus import (
|
||||
InProcessCoreSink,
|
||||
MessageEnvelope,
|
||||
MessageRuntime,
|
||||
ProcessCoreSinkServer,
|
||||
)
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chat.message_receive.message_handler import MessageHandler
|
||||
|
||||
logger = get_logger("core_sink_manager")
|
||||
|
||||
|
||||
# 消息处理器类型
|
||||
MessageHandlerCallback = Callable[[MessageEnvelope], Awaitable[None]]
|
||||
|
||||
|
||||
class CoreSinkManager:
|
||||
"""
|
||||
CoreSink 统一管理器
|
||||
|
||||
管理 InProcessCoreSink 和 ProcessCoreSinkServer 双实例,
|
||||
集成 MessageRuntime 提供统一的消息路由和收发接口。
|
||||
|
||||
架构说明:
|
||||
- InProcessCoreSink: 用于同进程内的适配器(run_in_subprocess=False)
|
||||
- ProcessCoreSinkServer: 用于管理与子进程适配器的通信
|
||||
- MessageRuntime: 统一消息路由,支持 @on_message 装饰器和钩子机制
|
||||
|
||||
消息流向:
|
||||
1. 适配器(同进程)→ InProcessCoreSink → MessageRuntime.handle_message() → 注册的处理器
|
||||
2. 适配器(子进程)→ ProcessCoreSinkServer → MessageRuntime.handle_message() → 注册的处理器
|
||||
3. 核心回复 → CoreSinkManager.send_outgoing → 适配器
|
||||
|
||||
使用 MessageRuntime 的优势:
|
||||
- 支持 @runtime.on_message(message_type="xxx") 按消息类型路由
|
||||
- 支持 before_hook/after_hook/error_hook 统一处理流程
|
||||
- 支持中间件机制(洋葱模型)
|
||||
- 自动处理同步/异步处理器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# MessageRuntime 实例
|
||||
self._runtime: MessageRuntime = MessageRuntime()
|
||||
|
||||
# InProcessCoreSink 实例(用于同进程适配器)
|
||||
self._in_process_sink: InProcessCoreSink | None = None
|
||||
|
||||
# 子进程通信管理
|
||||
# key: adapter_name, value: (ProcessCoreSinkServer, incoming_queue, outgoing_queue)
|
||||
self._process_sinks: Dict[str, tuple[ProcessCoreSinkServer, mp.Queue, mp.Queue]] = {}
|
||||
|
||||
# multiprocessing context
|
||||
self._mp_ctx = mp.get_context("spawn")
|
||||
|
||||
# 运行状态
|
||||
self._running = False
|
||||
self._initialized = False
|
||||
|
||||
@property
|
||||
def runtime(self) -> MessageRuntime:
|
||||
"""
|
||||
获取 MessageRuntime 实例
|
||||
|
||||
外部模块可以通过此属性注册消息处理器、钩子等:
|
||||
|
||||
```python
|
||||
manager = get_core_sink_manager()
|
||||
|
||||
# 注册消息处理器
|
||||
@manager.runtime.on_message(message_type="text")
|
||||
async def handle_text(envelope: MessageEnvelope):
|
||||
...
|
||||
|
||||
# 注册前置钩子
|
||||
manager.runtime.register_before_hook(my_before_hook)
|
||||
```
|
||||
|
||||
Returns:
|
||||
MessageRuntime 实例
|
||||
"""
|
||||
return self._runtime
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
初始化 CoreSink 管理器
|
||||
|
||||
创建 InProcessCoreSink,将收到的消息交给 MessageRuntime 处理。
|
||||
"""
|
||||
if self._initialized:
|
||||
logger.warning("CoreSinkManager 已经初始化,跳过重复初始化")
|
||||
return
|
||||
|
||||
logger.info("正在初始化 CoreSink 管理器...")
|
||||
|
||||
# 创建 InProcessCoreSink,使用 MessageRuntime 作为消息处理入口
|
||||
self._in_process_sink = InProcessCoreSink(self._dispatch_to_runtime)
|
||||
|
||||
self._running = True
|
||||
self._initialized = True
|
||||
|
||||
logger.info("CoreSink 管理器初始化完成(已集成 MessageRuntime)")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""关闭 CoreSink 管理器"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
logger.info("正在关闭 CoreSink 管理器...")
|
||||
self._running = False
|
||||
|
||||
# 关闭所有 ProcessCoreSinkServer
|
||||
for adapter_name, (server, _, _) in list(self._process_sinks.items()):
|
||||
try:
|
||||
await server.close()
|
||||
logger.info(f"已关闭适配器 {adapter_name} 的 ProcessCoreSinkServer")
|
||||
except Exception as e:
|
||||
logger.error(f"关闭适配器 {adapter_name} 的 ProcessCoreSinkServer 时出错: {e}")
|
||||
|
||||
self._process_sinks.clear()
|
||||
|
||||
# 关闭 InProcessCoreSink
|
||||
if self._in_process_sink:
|
||||
await self._in_process_sink.close()
|
||||
self._in_process_sink = None
|
||||
|
||||
self._initialized = False
|
||||
logger.info("CoreSink 管理器已关闭")
|
||||
|
||||
def get_in_process_sink(self) -> InProcessCoreSink:
|
||||
"""
|
||||
获取 InProcessCoreSink 实例
|
||||
|
||||
用于同进程运行的适配器
|
||||
|
||||
Returns:
|
||||
InProcessCoreSink 实例
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果管理器未初始化
|
||||
"""
|
||||
if self._in_process_sink is None:
|
||||
raise RuntimeError("CoreSinkManager 未初始化,请先调用 initialize()")
|
||||
return self._in_process_sink
|
||||
|
||||
def create_process_sink_queues(self, adapter_name: str) -> tuple[mp.Queue, mp.Queue]:
|
||||
"""
|
||||
为子进程适配器创建通信队列
|
||||
|
||||
创建 incoming 和 outgoing 队列对,用于与子进程适配器通信。
|
||||
同时创建 ProcessCoreSinkServer 来处理消息转发。
|
||||
|
||||
Args:
|
||||
adapter_name: 适配器名称
|
||||
|
||||
Returns:
|
||||
(to_core_queue, from_core_queue) 元组
|
||||
- to_core_queue: 子进程发送到核心的队列
|
||||
- from_core_queue: 核心发送到子进程的队列
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果管理器未初始化
|
||||
"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("CoreSinkManager 未初始化,请先调用 initialize()")
|
||||
|
||||
if adapter_name in self._process_sinks:
|
||||
logger.warning(f"适配器 {adapter_name} 的队列已存在,将被覆盖")
|
||||
# 先关闭旧的
|
||||
old_server, _, _ = self._process_sinks[adapter_name]
|
||||
asyncio.create_task(old_server.close())
|
||||
|
||||
# 创建通信队列
|
||||
incoming_queue = self._mp_ctx.Queue() # 子进程 → 核心
|
||||
outgoing_queue = self._mp_ctx.Queue() # 核心 → 子进程
|
||||
|
||||
# 创建 ProcessCoreSinkServer,使用 MessageRuntime 处理消息
|
||||
server = ProcessCoreSinkServer(
|
||||
incoming_queue=incoming_queue,
|
||||
outgoing_queue=outgoing_queue,
|
||||
core_handler=self._dispatch_to_runtime,
|
||||
name=adapter_name,
|
||||
)
|
||||
|
||||
# 启动服务器
|
||||
server.start()
|
||||
|
||||
# 存储引用
|
||||
self._process_sinks[adapter_name] = (server, incoming_queue, outgoing_queue)
|
||||
|
||||
logger.info(f"为适配器 {adapter_name} 创建了 ProcessCoreSink 通信队列")
|
||||
|
||||
return incoming_queue, outgoing_queue
|
||||
|
||||
def remove_process_sink(self, adapter_name: str) -> None:
|
||||
"""
|
||||
移除子进程适配器的通信队列
|
||||
|
||||
Args:
|
||||
adapter_name: 适配器名称
|
||||
"""
|
||||
if adapter_name not in self._process_sinks:
|
||||
logger.warning(f"适配器 {adapter_name} 的队列不存在")
|
||||
return
|
||||
|
||||
server, _, _ = self._process_sinks.pop(adapter_name)
|
||||
asyncio.create_task(server.close())
|
||||
logger.info(f"已移除适配器 {adapter_name} 的 ProcessCoreSink 通信队列")
|
||||
|
||||
async def send_outgoing(
|
||||
self,
|
||||
envelope: MessageEnvelope,
|
||||
platform: str | None = None,
|
||||
adapter_name: str | None = None
|
||||
) -> None:
|
||||
"""
|
||||
发送消息到适配器
|
||||
|
||||
根据 platform 或 adapter_name 路由到正确的适配器。
|
||||
|
||||
Args:
|
||||
envelope: 消息信封
|
||||
platform: 目标平台(可选)
|
||||
adapter_name: 目标适配器名称(可选)
|
||||
|
||||
路由规则:
|
||||
1. 如果指定了 adapter_name,直接发送到该适配器
|
||||
2. 如果指定了 platform,发送到所有匹配平台的适配器
|
||||
3. 如果都没指定,从 envelope 中提取 platform 并广播
|
||||
"""
|
||||
# 从 envelope 中获取 platform
|
||||
if platform is None:
|
||||
platform = envelope.get("platform") or envelope.get("message_info", {}).get("platform")
|
||||
|
||||
# 发送到 InProcessCoreSink(会自动广播到所有注册的 outgoing handler)
|
||||
if self._in_process_sink:
|
||||
await self._in_process_sink.push_outgoing(envelope)
|
||||
|
||||
# 发送到所有 ProcessCoreSinkServer
|
||||
for name, (server, _, _) in self._process_sinks.items():
|
||||
if adapter_name and name != adapter_name:
|
||||
continue
|
||||
try:
|
||||
await server.push_outgoing(envelope)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息到适配器 {name} 失败: {e}")
|
||||
|
||||
async def _dispatch_to_runtime(self, envelope: MessageEnvelope) -> None:
|
||||
"""
|
||||
将消息分发给 MessageRuntime 处理
|
||||
|
||||
这是内部方法,由 InProcessCoreSink 和 ProcessCoreSinkServer 调用。
|
||||
所有从适配器接收到的消息都会经过这里,然后交给 MessageRuntime 路由。
|
||||
|
||||
Args:
|
||||
envelope: 消息信封
|
||||
"""
|
||||
if not self._running:
|
||||
logger.warning("CoreSinkManager 未运行,忽略接收到的消息")
|
||||
return
|
||||
|
||||
try:
|
||||
# 使用 MessageRuntime 处理消息
|
||||
await self._runtime.handle_message(envelope)
|
||||
except Exception as e:
|
||||
logger.error(f"MessageRuntime 处理消息时出错: {e}", exc_info=True)
|
||||
|
||||
|
||||
# 全局单例
|
||||
_core_sink_manager: CoreSinkManager | None = None
|
||||
|
||||
|
||||
def get_core_sink_manager() -> CoreSinkManager:
|
||||
"""获取 CoreSinkManager 单例"""
|
||||
global _core_sink_manager
|
||||
if _core_sink_manager is None:
|
||||
_core_sink_manager = CoreSinkManager()
|
||||
return _core_sink_manager
|
||||
|
||||
|
||||
def get_message_runtime() -> MessageRuntime:
|
||||
"""
|
||||
获取全局 MessageRuntime 实例
|
||||
|
||||
这是获取 MessageRuntime 的推荐方式,用于注册消息处理器、钩子等:
|
||||
|
||||
```python
|
||||
from src.common.core_sink_manager import get_message_runtime
|
||||
|
||||
runtime = get_message_runtime()
|
||||
|
||||
@runtime.on_message(message_type="text")
|
||||
async def handle_text(envelope: MessageEnvelope):
|
||||
...
|
||||
```
|
||||
|
||||
Returns:
|
||||
MessageRuntime 实例
|
||||
"""
|
||||
return get_core_sink_manager().runtime
|
||||
|
||||
|
||||
async def initialize_core_sink_manager() -> CoreSinkManager:
|
||||
"""
|
||||
初始化 CoreSinkManager 单例
|
||||
|
||||
Returns:
|
||||
初始化后的 CoreSinkManager 实例
|
||||
"""
|
||||
manager = get_core_sink_manager()
|
||||
await manager.initialize()
|
||||
return manager
|
||||
|
||||
|
||||
async def shutdown_core_sink_manager() -> None:
|
||||
"""关闭 CoreSinkManager 单例"""
|
||||
global _core_sink_manager
|
||||
if _core_sink_manager:
|
||||
await _core_sink_manager.shutdown()
|
||||
_core_sink_manager = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 向后兼容的 API
|
||||
# ============================================================================
|
||||
|
||||
def get_core_sink() -> InProcessCoreSink:
|
||||
"""
|
||||
获取 InProcessCoreSink 实例(向后兼容)
|
||||
|
||||
这是旧版 API,推荐使用 get_core_sink_manager().get_in_process_sink()
|
||||
|
||||
Returns:
|
||||
InProcessCoreSink 实例
|
||||
"""
|
||||
return get_core_sink_manager().get_in_process_sink()
|
||||
|
||||
|
||||
def set_core_sink(sink: Any) -> None:
|
||||
"""
|
||||
设置 CoreSink(向后兼容,现已弃用)
|
||||
|
||||
新架构中 CoreSink 由 CoreSinkManager 统一管理,不再支持外部设置。
|
||||
此函数保留仅为兼容旧代码,调用会记录警告日志。
|
||||
"""
|
||||
logger.warning(
|
||||
"set_core_sink() 已弃用,CoreSink 现由 CoreSinkManager 统一管理。"
|
||||
"请使用 initialize_core_sink_manager() 初始化。"
|
||||
)
|
||||
|
||||
|
||||
async def push_outgoing(envelope: MessageEnvelope) -> None:
|
||||
"""
|
||||
将消息推送到所有适配器(向后兼容)
|
||||
|
||||
Args:
|
||||
envelope: 消息信封
|
||||
"""
|
||||
manager = get_core_sink_manager()
|
||||
await manager.send_outgoing(envelope)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CoreSinkManager",
|
||||
"get_core_sink_manager",
|
||||
"get_message_runtime",
|
||||
"initialize_core_sink_manager",
|
||||
"shutdown_core_sink_manager",
|
||||
# 向后兼容
|
||||
"get_core_sink",
|
||||
"set_core_sink",
|
||||
"push_outgoing",
|
||||
]
|
||||
@@ -16,6 +16,24 @@ class DatabaseUserInfo(BaseDataModel):
|
||||
user_nickname: str = field(default_factory=str) # 用户昵称
|
||||
user_cardname: str | None = None # 用户备注名或群名片,可为空
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "DatabaseUserInfo":
|
||||
"""从字典创建实例"""
|
||||
return cls(
|
||||
platform=data.get("platform", ""),
|
||||
user_id=data.get("user_id", ""),
|
||||
user_nickname=data.get("user_nickname", ""),
|
||||
user_cardname=data.get("user_cardname"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""将实例转换为字典"""
|
||||
return {
|
||||
"platform": self.platform,
|
||||
"user_id": self.user_id,
|
||||
"user_nickname": self.user_nickname,
|
||||
"user_cardname": self.user_cardname,
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class DatabaseGroupInfo(BaseDataModel):
|
||||
@@ -26,7 +44,23 @@ class DatabaseGroupInfo(BaseDataModel):
|
||||
group_name: str = field(default_factory=str) # 群组名称
|
||||
group_platform: str | None = None # 群组所在平台,可为空
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "DatabaseGroupInfo":
|
||||
"""从字典创建实例"""
|
||||
return cls(
|
||||
group_id=data.get("group_id", ""),
|
||||
group_name=data.get("group_name", ""),
|
||||
group_platform=data.get("group_platform"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""将实例转换为字典"""
|
||||
return {
|
||||
"group_id": self.group_id,
|
||||
"group_name": self.group_name,
|
||||
"group_platform": self.group_platform,
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class DatabaseChatInfo(BaseDataModel):
|
||||
"""
|
||||
|
||||
87
src/main.py
87
src/main.py
@@ -6,17 +6,21 @@ import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable, Coroutine
|
||||
from functools import partial
|
||||
from random import choices
|
||||
from typing import Any
|
||||
|
||||
from mofox_bus import InProcessCoreSink, MessageEnvelope
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from chat.message_receive.message_handler import get_message_handler, shutdown_message_handler
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||
from src.common.core_sink_manager import (
|
||||
CoreSinkManager,
|
||||
get_core_sink_manager,
|
||||
initialize_core_sink_manager,
|
||||
shutdown_core_sink_manager,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 全局背景任务集合
|
||||
@@ -55,28 +59,15 @@ EGG_PHRASES: list[tuple[str, int]] = [
|
||||
]
|
||||
|
||||
|
||||
def _task_done_callback(task: asyncio.Task, message_id: str, start_time: float) -> None:
|
||||
"""后台任务完成时的回调函数"""
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
try:
|
||||
task.result() # 如果任务有异常,这里会重新抛出
|
||||
logger.debug(f"消息 {message_id} 的后台任务 (ID: {id(task)}) 已成功完成, 耗时: {duration:.2f}s")
|
||||
except asyncio.CancelledError:
|
||||
logger.warning(f"消息 {message_id} 的后台任务 (ID: {id(task)}) 被取消, 耗时: {duration:.2f}s")
|
||||
except Exception:
|
||||
logger.error(f"处理消息 {message_id} 的后台任务 (ID: {id(task)}) 出现未捕获的异常, 耗时: {duration:.2f}s:")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
class MainSystem:
|
||||
"""主系统类,负责协调所有组件"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.individuality: Individuality = get_individuality()
|
||||
|
||||
# 创建核心消息接收器
|
||||
self.core_sink: InProcessCoreSink = InProcessCoreSink(self._message_process_wrapper)
|
||||
# CoreSinkManager 和 MessageHandler 将在 initialize() 中创建
|
||||
self.core_sink_manager: CoreSinkManager | None = None
|
||||
self.message_handler = None
|
||||
|
||||
# 使用服务器
|
||||
self.server: Server = get_global_server()
|
||||
@@ -163,10 +154,11 @@ class MainSystem:
|
||||
continue
|
||||
|
||||
try:
|
||||
from src.plugin_system.base.component_types import ComponentType as CT
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
component_class = component_registry.get_component_class(
|
||||
calc_name, ComponentType.INTEREST_CALCULATOR
|
||||
calc_name, CT.INTEREST_CALCULATOR
|
||||
)
|
||||
|
||||
if not component_class:
|
||||
@@ -299,6 +291,18 @@ class MainSystem:
|
||||
except Exception as e:
|
||||
logger.error(f"准备停止适配器管理器时出错: {e}")
|
||||
|
||||
# 停止 CoreSinkManager
|
||||
try:
|
||||
cleanup_tasks.append(("CoreSinkManager", shutdown_core_sink_manager()))
|
||||
except Exception as e:
|
||||
logger.error(f"准备停止 CoreSinkManager 时出错: {e}")
|
||||
|
||||
# 停止 MessageHandler
|
||||
try:
|
||||
cleanup_tasks.append(("MessageHandler", shutdown_message_handler()))
|
||||
except Exception as e:
|
||||
logger.error(f"准备停止 MessageHandler 时出错: {e}")
|
||||
|
||||
# 并行执行所有清理任务
|
||||
if cleanup_tasks:
|
||||
logger.info(f"开始并行执行 {len(cleanup_tasks)} 个清理任务...")
|
||||
@@ -352,27 +356,6 @@ class MainSystem:
|
||||
except Exception as e:
|
||||
logger.error(f"同步清理资源时出错: {e}")
|
||||
|
||||
async def _message_process_wrapper(self, envelope: MessageEnvelope) -> None:
|
||||
"""并行处理消息的包装器"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
message_id = envelope.get("message_info", {}).get("message_id", "UNKNOWN")
|
||||
# 检查系统是否正在关闭
|
||||
if self._shutting_down:
|
||||
logger.warning(f"系统正在关闭,拒绝处理消息 {message_id}")
|
||||
return
|
||||
|
||||
# 创建后台任务
|
||||
task = asyncio.create_task(chat_bot.message_process(envelope))
|
||||
logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})")
|
||||
|
||||
# 添加一个回调函数,当任务完成时,它会被调用
|
||||
task.add_done_callback(partial(_task_done_callback, message_id=message_id, start_time=start_time))
|
||||
except Exception:
|
||||
logger.error("在创建消息处理任务时发生严重错误:")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化系统组件"""
|
||||
# 检查必要的配置
|
||||
@@ -382,6 +365,18 @@ class MainSystem:
|
||||
|
||||
logger.info(f"正在唤醒{global_config.bot.nickname}......")
|
||||
|
||||
# 初始化 CoreSinkManager(包含 MessageRuntime)
|
||||
logger.info("正在初始化 CoreSinkManager...")
|
||||
self.core_sink_manager = await initialize_core_sink_manager()
|
||||
|
||||
# 获取 MessageHandler 并向 MessageRuntime 注册处理器
|
||||
self.message_handler = get_message_handler()
|
||||
self.message_handler.set_core_sink_manager(self.core_sink_manager)
|
||||
|
||||
# 向 MessageRuntime 注册消息处理器和钩子
|
||||
self.message_handler.register_handlers(self.core_sink_manager.runtime)
|
||||
logger.info("CoreSinkManager 和 MessageHandler 初始化完成(使用 MessageRuntime 路由)")
|
||||
|
||||
# 初始化组件
|
||||
await self._init_components()
|
||||
|
||||
@@ -453,7 +448,11 @@ MoFox_Bot(第三方修改版)
|
||||
logger.error(f"统一调度器初始化失败: {e}")
|
||||
|
||||
# 设置核心消息接收器到插件管理器
|
||||
plugin_manager.set_core_sink(self.core_sink)
|
||||
# 使用 CoreSinkManager 的 InProcessCoreSink
|
||||
if self.core_sink_manager:
|
||||
plugin_manager.set_core_sink(self.core_sink_manager.get_in_process_sink())
|
||||
else:
|
||||
logger.error("CoreSinkManager 未初始化,无法设置核心消息接收器")
|
||||
|
||||
# 加载所有插件
|
||||
plugin_manager.load_all_plugins()
|
||||
@@ -505,8 +504,8 @@ MoFox_Bot(第三方修改版)
|
||||
except Exception as e:
|
||||
logger.error(f"LPMM知识库初始化失败: {e}")
|
||||
|
||||
# 消息接收器已经在 __init__ 中创建,无需再次注册
|
||||
logger.info("核心消息接收器已就绪")
|
||||
# 消息接收器已在 initialize() 中通过 CoreSinkManager 创建
|
||||
logger.info("核心消息接收器已就绪(通过 CoreSinkManager)")
|
||||
|
||||
# 启动消息重组器
|
||||
try:
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
"""
|
||||
MoFox 内部通用消息总线实现。
|
||||
|
||||
该模块导出 TypedDict 消息模型、序列化工具、传输层封装以及适配器辅助工具,
|
||||
供核心进程与各类平台适配器共享。
|
||||
"""
|
||||
|
||||
from . import codec, types
|
||||
from .adapter_utils import (
|
||||
AdapterTransportOptions,
|
||||
AdapterBase,
|
||||
BatchDispatcher,
|
||||
CoreSink,
|
||||
CoreMessageSink,
|
||||
HttpAdapterOptions,
|
||||
InProcessCoreSink,
|
||||
ProcessCoreSink,
|
||||
ProcessCoreSinkServer,
|
||||
WebSocketLike,
|
||||
WebSocketAdapterOptions,
|
||||
)
|
||||
from .api import MessageClient, MessageServer
|
||||
from .codec import dumps_message, dumps_messages, loads_message, loads_messages
|
||||
from .builder import MessageBuilder
|
||||
from .router import RouteConfig, Router, TargetConfig
|
||||
from .runtime import MessageProcessingError, MessageRoute, MessageRuntime, Middleware
|
||||
from .types import (
|
||||
FormatInfoPayload,
|
||||
GroupInfoPayload,
|
||||
MessageDirection,
|
||||
MessageEnvelope,
|
||||
MessageInfoPayload,
|
||||
SegPayload,
|
||||
|
||||
TemplateInfoPayload,
|
||||
UserInfoPayload,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# TypedDict model
|
||||
"MessageDirection",
|
||||
"MessageEnvelope",
|
||||
"SegPayload",
|
||||
"UserInfoPayload",
|
||||
"GroupInfoPayload",
|
||||
"FormatInfoPayload",
|
||||
"TemplateInfoPayload",
|
||||
"MessageInfoPayload",
|
||||
# Codec helpers
|
||||
"codec",
|
||||
"dumps_message",
|
||||
"dumps_messages",
|
||||
"loads_message",
|
||||
"loads_messages",
|
||||
"MessageBuilder",
|
||||
# Runtime / routing
|
||||
"MessageRoute",
|
||||
"MessageRuntime",
|
||||
"MessageProcessingError",
|
||||
"Middleware",
|
||||
# Server/client/router
|
||||
"MessageServer",
|
||||
"MessageClient",
|
||||
"Router",
|
||||
"RouteConfig",
|
||||
"TargetConfig",
|
||||
# Adapter helpers
|
||||
"AdapterTransportOptions",
|
||||
"AdapterBase",
|
||||
"BatchDispatcher",
|
||||
"CoreSink",
|
||||
"CoreMessageSink",
|
||||
"InProcessCoreSink",
|
||||
"ProcessCoreSink",
|
||||
"ProcessCoreSinkServer",
|
||||
"WebSocketLike",
|
||||
"WebSocketAdapterOptions",
|
||||
"HttpAdapterOptions",
|
||||
]
|
||||
@@ -1,485 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Protocol
|
||||
|
||||
import orjson
|
||||
from aiohttp import web as aiohttp_web
|
||||
import websockets
|
||||
|
||||
from .types import MessageEnvelope
|
||||
|
||||
logger = logging.getLogger("mofox_bus.adapter")
|
||||
|
||||
|
||||
OutgoingHandler = Callable[[MessageEnvelope], Awaitable[None]]
|
||||
|
||||
|
||||
class CoreMessageSink(Protocol):
|
||||
async def send(self, message: MessageEnvelope) -> None: ...
|
||||
|
||||
async def send_many(self, messages: list[MessageEnvelope]) -> None: ... # pragma: no cover - optional
|
||||
|
||||
|
||||
class CoreSink(CoreMessageSink, Protocol):
|
||||
"""
|
||||
双向 CoreSink 协议:
|
||||
- send/send_many: 适配器 → 核心(incoming)
|
||||
- push_outgoing: 核心 → 适配器(outgoing)
|
||||
"""
|
||||
|
||||
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None: ...
|
||||
|
||||
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None: ...
|
||||
|
||||
async def push_outgoing(self, envelope: MessageEnvelope) -> None: ...
|
||||
|
||||
async def close(self) -> None: ... # pragma: no cover - lifecycle hook
|
||||
|
||||
|
||||
class WebSocketLike(Protocol):
|
||||
def __aiter__(self) -> AsyncIterator[str | bytes]: ...
|
||||
|
||||
@property
|
||||
def closed(self) -> bool: ...
|
||||
|
||||
async def send(self, data: str | bytes) -> None: ...
|
||||
|
||||
async def close(self) -> None: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketAdapterOptions:
|
||||
url: str
|
||||
headers: dict[str, str] | None = None
|
||||
incoming_parser: Callable[[str | bytes], Any] | None = None
|
||||
outgoing_encoder: Callable[[MessageEnvelope], str | bytes] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpAdapterOptions:
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8089
|
||||
path: str = "/adapter/messages"
|
||||
app: aiohttp_web.Application | None = None
|
||||
|
||||
|
||||
AdapterTransportOptions = WebSocketAdapterOptions | HttpAdapterOptions | None
|
||||
|
||||
|
||||
class AdapterBase:
|
||||
"""
|
||||
适配器基类:负责平台原始消息与 MessageEnvelope 之间的互转。
|
||||
子类需要实现平台入站解析与出站发送逻辑。
|
||||
"""
|
||||
|
||||
platform: str = "unknown"
|
||||
|
||||
def __init__(self, core_sink: CoreSink, transport: AdapterTransportOptions = None):
|
||||
"""
|
||||
Args:
|
||||
core_sink: 核心消息入口,通常是 InProcessCoreSink 或自定义客户端。
|
||||
transport: 传入 WebSocketAdapterOptions / HttpAdapterOptions 即可自动管理监听逻辑。
|
||||
"""
|
||||
self.core_sink = core_sink
|
||||
self._transport_config = transport
|
||||
self._ws: WebSocketLike | None = None
|
||||
self._ws_task: asyncio.Task | None = None
|
||||
self._http_runner: aiohttp_web.AppRunner | None = None
|
||||
self._http_site: aiohttp_web.BaseSite | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动适配器的传输层监听(如果配置了传输选项)。"""
|
||||
if hasattr(self.core_sink, "set_outgoing_handler"):
|
||||
try:
|
||||
self.core_sink.set_outgoing_handler(self._on_outgoing_from_core)
|
||||
except Exception:
|
||||
logger.exception("注册 outgoing 处理程序到核心接收器失败")
|
||||
if isinstance(self._transport_config, WebSocketAdapterOptions):
|
||||
await self._start_ws_transport(self._transport_config)
|
||||
elif isinstance(self._transport_config, HttpAdapterOptions):
|
||||
await self._start_http_transport(self._transport_config)
|
||||
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止适配器的传输层监听(如果配置了传输选项)。"""
|
||||
remove = getattr(self.core_sink, "remove_outgoing_handler", None)
|
||||
if callable(remove):
|
||||
try:
|
||||
remove(self._on_outgoing_from_core)
|
||||
except Exception:
|
||||
logger.exception("从核心接收器分离 outgoing 处理程序失败")
|
||||
elif hasattr(self.core_sink, "set_outgoing_handler"):
|
||||
try:
|
||||
self.core_sink.set_outgoing_handler(None) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
logger.exception("从核心接收器分离 outgoing 处理程序失败")
|
||||
if self._ws_task:
|
||||
self._ws_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._ws_task
|
||||
self._ws_task = None
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._http_site:
|
||||
await self._http_site.stop()
|
||||
self._http_site = None
|
||||
if self._http_runner:
|
||||
await self._http_runner.cleanup()
|
||||
self._http_runner = None
|
||||
|
||||
async def on_platform_message(self, raw: Any) -> None:
|
||||
"""处理平台下发的单条消息并交给核心。"""
|
||||
envelope = await self.from_platform_message(raw)
|
||||
await self.core_sink.send(envelope)
|
||||
|
||||
async def on_platform_messages(self, raw_messages: list[Any]) -> None:
|
||||
"""批量推送入口,内部自动批量或逐条送入核心。"""
|
||||
envelopes = [await self.from_platform_message(raw) for raw in raw_messages]
|
||||
await _send_many(self.core_sink, envelopes)
|
||||
|
||||
async def send_to_platform(self, envelope: MessageEnvelope) -> None:
|
||||
"""核心生成单条消息时调用,由子类或自动传输层发送。"""
|
||||
await self._send_platform_message(envelope)
|
||||
|
||||
async def send_batch_to_platform(self, envelopes: list[MessageEnvelope]) -> None:
|
||||
"""默认串行发送整批消息,子类可根据平台特性重写。"""
|
||||
for env in envelopes:
|
||||
await self._send_platform_message(env)
|
||||
|
||||
async def _on_outgoing_from_core(self, envelope: MessageEnvelope) -> None:
|
||||
"""核心生成 outgoing envelope 时的内部处理逻辑"""
|
||||
platform = envelope.get("platform") or envelope.get("message_info", {}).get("platform")
|
||||
if platform and platform != getattr(self, "platform", None):
|
||||
return
|
||||
await self._send_platform_message(envelope)
|
||||
|
||||
async def from_platform_message(self, raw: Any) -> MessageEnvelope:
|
||||
"""子类必须实现:将平台原始结构转换为统一 MessageEnvelope。"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _send_platform_message(self, envelope: MessageEnvelope) -> None:
|
||||
"""子类必须实现:把 MessageEnvelope 转为平台格式并发送出去。"""
|
||||
if isinstance(self._transport_config, WebSocketAdapterOptions):
|
||||
await self._send_via_ws(envelope)
|
||||
return
|
||||
raise NotImplementedError
|
||||
|
||||
async def _start_ws_transport(self, options: WebSocketAdapterOptions) -> None:
|
||||
self._ws = await websockets.connect(options.url, extra_headers=options.headers)
|
||||
self._ws_task = asyncio.create_task(self._ws_listen_loop(options))
|
||||
|
||||
async def _ws_listen_loop(self, options: WebSocketAdapterOptions) -> None:
|
||||
assert self._ws is not None
|
||||
parser = options.incoming_parser or self._default_ws_parser
|
||||
try:
|
||||
async for raw in self._ws:
|
||||
payload = parser(raw)
|
||||
await self.on_platform_message(payload)
|
||||
finally:
|
||||
pass
|
||||
|
||||
async def _send_via_ws(self, envelope: MessageEnvelope) -> None:
|
||||
if self._ws is None or self._ws.closed:
|
||||
raise RuntimeError("WebSocket transport is not active")
|
||||
encoder = None
|
||||
if isinstance(self._transport_config, WebSocketAdapterOptions):
|
||||
encoder = self._transport_config.outgoing_encoder
|
||||
data = encoder(envelope) if encoder else self._default_ws_encoder(envelope)
|
||||
await self._ws.send(data)
|
||||
|
||||
async def _start_http_transport(self, options: HttpAdapterOptions) -> None:
|
||||
app = options.app or aiohttp_web.Application()
|
||||
app.add_routes([aiohttp_web.post(options.path, self._handle_http_request)])
|
||||
self._http_runner = aiohttp_web.AppRunner(app)
|
||||
await self._http_runner.setup()
|
||||
self._http_site = aiohttp_web.TCPSite(self._http_runner, options.host, options.port)
|
||||
await self._http_site.start()
|
||||
|
||||
async def _handle_http_request(self, request: aiohttp_web.Request) -> aiohttp_web.Response:
|
||||
raw = await request.read()
|
||||
data = orjson.loads(raw) if raw else {}
|
||||
if isinstance(data, list):
|
||||
await self.on_platform_messages(data)
|
||||
else:
|
||||
await self.on_platform_message(data)
|
||||
return aiohttp_web.json_response({"status": "ok"})
|
||||
|
||||
@staticmethod
|
||||
def _default_ws_parser(raw: str | bytes) -> Any:
|
||||
data = orjson.loads(raw)
|
||||
if isinstance(data, dict) and data.get("type") == "message" and "payload" in data:
|
||||
return data["payload"]
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _default_ws_encoder(envelope: MessageEnvelope) -> bytes:
|
||||
return orjson.dumps({"type": "send", "payload": envelope})
|
||||
|
||||
|
||||
class InProcessCoreSink(CoreSink):
|
||||
"""
|
||||
进程内核心消息 sink,实现 CoreSink 协议。
|
||||
"""
|
||||
|
||||
def __init__(self, handler: Callable[[MessageEnvelope], Awaitable[None]]):
|
||||
self._handler = handler
|
||||
self._outgoing_handlers: set[OutgoingHandler] = set()
|
||||
|
||||
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None:
|
||||
if handler is None:
|
||||
return
|
||||
self._outgoing_handlers.add(handler)
|
||||
|
||||
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None:
|
||||
self._outgoing_handlers.discard(handler)
|
||||
|
||||
async def send(self, message: MessageEnvelope) -> None:
|
||||
await self._handler(message)
|
||||
|
||||
async def send_many(self, messages: list[MessageEnvelope]) -> None:
|
||||
for message in messages:
|
||||
await self._handler(message)
|
||||
|
||||
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
|
||||
if not self._outgoing_handlers:
|
||||
logger.debug("Outgoing envelope dropped: no handler registered")
|
||||
return
|
||||
for callback in list(self._outgoing_handlers):
|
||||
await callback(envelope)
|
||||
|
||||
async def close(self) -> None: # pragma: no cover - symmetry
|
||||
self._outgoing_handlers.clear()
|
||||
|
||||
|
||||
class ProcessCoreSink(CoreSink):
|
||||
"""
|
||||
进程间核心消息 sink,实现 CoreSink 协议,使用 multiprocessing.Queue 初始化
|
||||
"""
|
||||
|
||||
_CONTROL_STOP = {"__core_sink_control__": "stop"}
|
||||
|
||||
def __init__(self, *, to_core_queue: mp.Queue, from_core_queue: mp.Queue) -> None:
|
||||
self._to_core_queue = to_core_queue
|
||||
self._from_core_queue = from_core_queue
|
||||
self._outgoing_handler: OutgoingHandler | None = None
|
||||
self._closed = False
|
||||
self._listener_task: asyncio.Task | None = None
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
def set_outgoing_handler(self, handler: OutgoingHandler | None) -> None:
|
||||
self._outgoing_handler = handler
|
||||
if handler is not None and (self._listener_task is None or self._listener_task.done()):
|
||||
self._listener_task = self._loop.create_task(self._listen_from_core())
|
||||
|
||||
def remove_outgoing_handler(self, handler: OutgoingHandler) -> None:
|
||||
if self._outgoing_handler is handler:
|
||||
self._outgoing_handler = None
|
||||
if self._listener_task and not self._listener_task.done():
|
||||
self._listener_task.cancel()
|
||||
|
||||
async def send(self, message: MessageEnvelope) -> None:
|
||||
await asyncio.to_thread(self._to_core_queue.put, {"kind": "incoming", "payload": message})
|
||||
|
||||
async def send_many(self, messages: list[MessageEnvelope]) -> None:
|
||||
for message in messages:
|
||||
await self.send(message)
|
||||
|
||||
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
|
||||
logger.debug("ProcessCoreSink.push_outgoing 在子进程中调用; 被忽略")
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
await asyncio.to_thread(self._from_core_queue.put, self._CONTROL_STOP)
|
||||
if self._listener_task:
|
||||
self._listener_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._listener_task
|
||||
self._listener_task = None
|
||||
|
||||
async def _listen_from_core(self) -> None:
|
||||
while not self._closed:
|
||||
try:
|
||||
item = await asyncio.to_thread(self._from_core_queue.get)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
if item == self._CONTROL_STOP:
|
||||
break
|
||||
if isinstance(item, dict) and item.get("kind") == "outgoing":
|
||||
envelope = item.get("payload")
|
||||
if self._outgoing_handler:
|
||||
try:
|
||||
await self._outgoing_handler(envelope)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("处理 ProcessCoreSink 中的 outgoing 信封失败")
|
||||
else:
|
||||
logger.debug(f"ProcessCoreSink 接受到未知负载: {item}")
|
||||
|
||||
|
||||
class ProcessCoreSinkServer:
|
||||
"""
|
||||
进程间核心消息 sink 服务器,实现 CoreSink 协议,使用 multiprocessing.Queue 初始化。
|
||||
- 将传入的 incoming 消息转发给指定的 handler
|
||||
- 将接收到的 outgoing 消息放入 outgoing 队列
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
incoming_queue: mp.Queue,
|
||||
outgoing_queue: mp.Queue,
|
||||
core_handler: Callable[[MessageEnvelope], Awaitable[None]],
|
||||
name: str | None = None,
|
||||
) -> None:
|
||||
self._incoming_queue = incoming_queue
|
||||
self._outgoing_queue = outgoing_queue
|
||||
self._core_handler = core_handler
|
||||
self._task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
self._name = name or "adapter"
|
||||
|
||||
def start(self) -> None:
|
||||
if self._task is None or self._task.done():
|
||||
self._task = asyncio.create_task(self._consume_incoming())
|
||||
|
||||
async def _consume_incoming(self) -> None:
|
||||
while not self._closed:
|
||||
try:
|
||||
item = await asyncio.to_thread(self._incoming_queue.get)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
if isinstance(item, dict) and item.get("__core_sink_control__") == "stop":
|
||||
break
|
||||
if isinstance(item, dict) and item.get("kind") == "incoming":
|
||||
envelope = item.get("payload")
|
||||
try:
|
||||
await self._core_handler(envelope)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception(f"处理来自 {self._name} 的 incoming 信封时失败")
|
||||
else:
|
||||
logger.debug(f"ProcessCoreSinkServer 忽略来自 {self._name} 的未知负载: {item}")
|
||||
|
||||
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
|
||||
await asyncio.to_thread(self._outgoing_queue.put, {"kind": "outgoing", "payload": envelope})
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
await asyncio.to_thread(self._incoming_queue.put, {"__core_sink_control__": "stop"})
|
||||
await asyncio.to_thread(self._outgoing_queue.put, ProcessCoreSink._CONTROL_STOP)
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
self._task = None
|
||||
|
||||
async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) -> None:
|
||||
send_many = getattr(sink, "send_many", None)
|
||||
if callable(send_many):
|
||||
await send_many(envelopes)
|
||||
return
|
||||
for env in envelopes:
|
||||
await sink.send(env)
|
||||
|
||||
|
||||
class BatchDispatcher:
|
||||
"""
|
||||
批量消息分发器,负责将消息批量发送到核心 sink。
|
||||
"""
|
||||
|
||||
_STOP = object()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sink: CoreMessageSink,
|
||||
*,
|
||||
max_batch_size: int = 50,
|
||||
flush_interval: float = 0.2,
|
||||
) -> None:
|
||||
self._sink = sink
|
||||
self._max_batch_size = max_batch_size
|
||||
self._flush_interval = flush_interval
|
||||
self._queue: asyncio.Queue[MessageEnvelope | object] = asyncio.Queue()
|
||||
self._worker: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
|
||||
async def add(self, message: MessageEnvelope) -> None:
|
||||
if self._closed:
|
||||
raise RuntimeError("Dispatcher closed")
|
||||
self._ensure_worker()
|
||||
await self._queue.put(message)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
self._ensure_worker()
|
||||
await self._queue.put(self._STOP)
|
||||
if self._worker:
|
||||
await self._worker
|
||||
self._worker = None
|
||||
|
||||
def _ensure_worker(self) -> None:
|
||||
if self._worker is not None and not self._worker.done():
|
||||
return
|
||||
self._worker = asyncio.create_task(self._worker_loop())
|
||||
|
||||
async def _worker_loop(self) -> None:
|
||||
buffer: list[MessageEnvelope] = []
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
item = await asyncio.wait_for(self._queue.get(), timeout=self._flush_interval)
|
||||
except asyncio.TimeoutError:
|
||||
item = None
|
||||
|
||||
if item is self._STOP:
|
||||
await self._flush_buffer(buffer)
|
||||
return
|
||||
if item is not None:
|
||||
buffer.append(item) # type: ignore[arg-type]
|
||||
|
||||
while len(buffer) < self._max_batch_size:
|
||||
try:
|
||||
item = self._queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
if item is self._STOP:
|
||||
await self._flush_buffer(buffer)
|
||||
return
|
||||
buffer.append(item) # type: ignore[arg-type]
|
||||
|
||||
if buffer and (len(buffer) >= self._max_batch_size or item is None):
|
||||
await self._flush_buffer(buffer)
|
||||
except asyncio.CancelledError: # pragma: no cover - worker cancellation
|
||||
if buffer:
|
||||
await self._flush_buffer(buffer)
|
||||
|
||||
async def _flush_buffer(self, buffer: list[MessageEnvelope]) -> None:
|
||||
if not buffer:
|
||||
return
|
||||
payload = list(buffer)
|
||||
buffer.clear()
|
||||
await _send_many(self._sink, payload)
|
||||
|
||||
__all__ = [
|
||||
"AdapterTransportOptions",
|
||||
"AdapterBase",
|
||||
"BatchDispatcher",
|
||||
"CoreSink",
|
||||
"CoreMessageSink",
|
||||
"HttpAdapterOptions",
|
||||
"InProcessCoreSink",
|
||||
"ProcessCoreSink",
|
||||
"ProcessCoreSinkServer",
|
||||
"WebSocketLike",
|
||||
"WebSocketAdapterOptions",
|
||||
]
|
||||
@@ -1,515 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any, Awaitable, Callable, Dict, Literal, Optional
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
|
||||
|
||||
MessagePayload = Dict[str, Any]
|
||||
MessageHandler = Callable[[MessagePayload], Awaitable[None] | None]
|
||||
DisconnectCallback = Callable[[str, str], Awaitable[None] | None]
|
||||
|
||||
|
||||
def _attach_raw_bytes(payload: Any, raw_bytes: bytes) -> Any:
|
||||
"""
|
||||
将原始字节数据附加到消息负载中
|
||||
|
||||
Args:
|
||||
payload: 消息负载
|
||||
raw_bytes: 原始字节数据
|
||||
|
||||
Returns:
|
||||
附加了原始数据的消息负载
|
||||
"""
|
||||
if isinstance(payload, dict):
|
||||
payload.setdefault("raw_bytes", raw_bytes)
|
||||
elif isinstance(payload, list):
|
||||
for item in payload:
|
||||
if isinstance(item, dict):
|
||||
item.setdefault("raw_bytes", raw_bytes)
|
||||
return payload
|
||||
|
||||
|
||||
def _encode_for_ws_send(message: Any, *, use_raw_bytes: bool = False) -> tuple[str | bytes, bool]:
|
||||
"""
|
||||
编码消息用于 WebSocket 发送
|
||||
|
||||
Args:
|
||||
message: 要发送的消息
|
||||
use_raw_bytes: 是否使用原始字节数据
|
||||
|
||||
Returns:
|
||||
(编码后的数据, 是否为二进制格式)
|
||||
"""
|
||||
if isinstance(message, (bytes, bytearray)):
|
||||
return bytes(message), True
|
||||
if use_raw_bytes and isinstance(message, dict):
|
||||
raw = message.get("raw_bytes")
|
||||
if isinstance(raw, (bytes, bytearray)):
|
||||
return bytes(raw), True
|
||||
payload = message
|
||||
if isinstance(payload, dict) and "raw_bytes" in payload and not use_raw_bytes:
|
||||
payload = {k: v for k, v in payload.items() if k != "raw_bytes"}
|
||||
data = orjson.dumps(payload)
|
||||
if use_raw_bytes:
|
||||
return data, True
|
||||
return data.decode("utf-8"), False
|
||||
|
||||
|
||||
class BaseMessageHandler:
|
||||
"""基础消息处理器,提供消息处理和任务管理功能"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.message_handlers: list[MessageHandler] = []
|
||||
self.background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
def register_message_handler(self, handler: MessageHandler) -> None:
|
||||
"""
|
||||
注册消息处理器
|
||||
|
||||
Args:
|
||||
handler: 消息处理函数
|
||||
"""
|
||||
if handler not in self.message_handlers:
|
||||
self.message_handlers.append(handler)
|
||||
|
||||
async def process_message(self, message: MessagePayload) -> None:
|
||||
"""
|
||||
处理单条消息,并发执行所有注册的处理器
|
||||
|
||||
Args:
|
||||
message: 消息负载
|
||||
"""
|
||||
tasks: list[asyncio.Task] = []
|
||||
for handler in self.message_handlers:
|
||||
try:
|
||||
result = handler(message)
|
||||
if asyncio.iscoroutine(result):
|
||||
task = asyncio.create_task(result)
|
||||
tasks.append(task)
|
||||
self.background_tasks.add(task)
|
||||
task.add_done_callback(self.background_tasks.discard)
|
||||
except Exception: # pragma: no cover - logging only
|
||||
logging.getLogger("mofox_bus.server").exception("消息处理失败")
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
class MessageServer(BaseMessageHandler):
|
||||
"""
|
||||
WebSocket 消息服务器,支持与 FastAPI 应用共享事件循环。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 18000,
|
||||
*,
|
||||
enable_token: bool = False,
|
||||
app: FastAPI | None = None,
|
||||
path: str = "/ws",
|
||||
ssl_certfile: str | None = None,
|
||||
ssl_keyfile: str | None = None,
|
||||
mode: Literal["ws", "tcp"] = "ws",
|
||||
custom_logger: logging.Logger | None = None,
|
||||
enable_custom_uvicorn_logger: bool = False,
|
||||
queue_maxsize: int = 1000,
|
||||
worker_count: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if mode != "ws":
|
||||
raise NotImplementedError("Only WebSocket mode is supported in mofox_bus")
|
||||
if custom_logger:
|
||||
logging.getLogger("mofox_bus.server").handlers = custom_logger.handlers
|
||||
self.host = host
|
||||
self.port = port
|
||||
self._app = app or FastAPI()
|
||||
self._own_app = app is None
|
||||
self._path = path
|
||||
self._ssl_certfile = ssl_certfile
|
||||
self._ssl_keyfile = ssl_keyfile
|
||||
self._enable_token = enable_token
|
||||
self._valid_tokens: set[str] = set()
|
||||
self._connections: set[WebSocket] = set()
|
||||
self._platform_connections: dict[str, WebSocket] = {}
|
||||
self._conn_lock = asyncio.Lock()
|
||||
self._server: uvicorn.Server | None = None
|
||||
self._running = False
|
||||
self._message_queue: asyncio.Queue[MessagePayload] = asyncio.Queue(maxsize=queue_maxsize)
|
||||
self._worker_count = max(1, worker_count)
|
||||
self._worker_tasks: list[asyncio.Task] = []
|
||||
self._setup_routes()
|
||||
|
||||
def _setup_routes(self) -> None:
|
||||
@_self_websocket(self._app, self._path)
|
||||
async def websocket_endpoint(websocket: WebSocket) -> None:
|
||||
platform = websocket.headers.get("platform", "unknown")
|
||||
token = websocket.headers.get("authorization") or websocket.headers.get("Authorization")
|
||||
if self._enable_token and not await self.verify_token(token):
|
||||
await websocket.close(code=1008, reason="invalid token")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
await self._register_connection(websocket, platform)
|
||||
try:
|
||||
while True:
|
||||
msg = await websocket.receive()
|
||||
if msg["type"] == "websocket.receive":
|
||||
raw_bytes = msg.get("bytes")
|
||||
if raw_bytes is None and msg.get("text") is not None:
|
||||
raw_bytes = msg["text"].encode("utf-8")
|
||||
if not raw_bytes:
|
||||
continue
|
||||
try:
|
||||
payload = orjson.loads(raw_bytes)
|
||||
except orjson.JSONDecodeError:
|
||||
logging.getLogger("mofox_bus.server").warning("Invalid JSON payload")
|
||||
continue
|
||||
payload = _attach_raw_bytes(payload, raw_bytes)
|
||||
if isinstance(payload, list):
|
||||
for item in payload:
|
||||
await self._enqueue_message(item)
|
||||
else:
|
||||
await self._enqueue_message(payload)
|
||||
elif msg["type"] == "websocket.disconnect":
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
await self._remove_connection(websocket, platform)
|
||||
|
||||
async def _enqueue_message(self, payload: MessagePayload) -> None:
|
||||
if not self._worker_tasks:
|
||||
self._start_workers()
|
||||
try:
|
||||
self._message_queue.put_nowait(payload)
|
||||
except asyncio.QueueFull:
|
||||
logging.getLogger("mofox_bus.server").warning("Message queue full, dropping message")
|
||||
|
||||
def _start_workers(self) -> None:
|
||||
if self._worker_tasks:
|
||||
return
|
||||
self._running = True
|
||||
for _ in range(self._worker_count):
|
||||
task = asyncio.create_task(self._consumer_worker())
|
||||
self._worker_tasks.append(task)
|
||||
|
||||
async def _stop_workers(self) -> None:
|
||||
if not self._worker_tasks:
|
||||
return
|
||||
self._running = False
|
||||
for task in self._worker_tasks:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await asyncio.gather(*self._worker_tasks, return_exceptions=True)
|
||||
self._worker_tasks.clear()
|
||||
while not self._message_queue.empty():
|
||||
with contextlib.suppress(asyncio.QueueEmpty):
|
||||
self._message_queue.get_nowait()
|
||||
self._message_queue.task_done()
|
||||
|
||||
async def _consumer_worker(self) -> None:
|
||||
while self._running:
|
||||
try:
|
||||
payload = await self._message_queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
try:
|
||||
await self.process_message(payload)
|
||||
except Exception: # pragma: no cover - best effort logging
|
||||
logging.getLogger("mofox_bus.server").exception("Error processing message")
|
||||
finally:
|
||||
self._message_queue.task_done()
|
||||
|
||||
async def verify_token(self, token: str | None) -> bool:
|
||||
if not self._enable_token:
|
||||
return True
|
||||
return token in self._valid_tokens
|
||||
|
||||
def add_valid_token(self, token: str) -> None:
|
||||
self._valid_tokens.add(token)
|
||||
|
||||
def remove_valid_token(self, token: str) -> None:
|
||||
self._valid_tokens.discard(token)
|
||||
|
||||
async def _register_connection(self, websocket: WebSocket, platform: str) -> None:
|
||||
async with self._conn_lock:
|
||||
self._connections.add(websocket)
|
||||
if platform:
|
||||
previous = self._platform_connections.get(platform)
|
||||
if previous and previous.client_state.name != "DISCONNECTED":
|
||||
await previous.close(code=1000, reason="replaced")
|
||||
self._platform_connections[platform] = websocket
|
||||
|
||||
async def _remove_connection(self, websocket: WebSocket, platform: str) -> None:
|
||||
async with self._conn_lock:
|
||||
self._connections.discard(websocket)
|
||||
if platform and self._platform_connections.get(platform) is websocket:
|
||||
del self._platform_connections[platform]
|
||||
|
||||
async def broadcast_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> None:
|
||||
payload: MessagePayload | bytes = message
|
||||
data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes)
|
||||
async with self._conn_lock:
|
||||
targets = list(self._connections)
|
||||
for ws in targets:
|
||||
if is_binary:
|
||||
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
|
||||
else:
|
||||
await ws.send_text(data if isinstance(data, str) else data.decode("utf-8"))
|
||||
|
||||
async def broadcast_to_platform(
|
||||
self, platform: str, message: MessagePayload | bytes, *, use_raw_bytes: bool = False
|
||||
) -> None:
|
||||
ws = self._platform_connections.get(platform)
|
||||
if ws is None:
|
||||
raise RuntimeError(f"平台 {platform} 没有活跃的连接")
|
||||
payload: MessagePayload | bytes = message
|
||||
data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes)
|
||||
if is_binary:
|
||||
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
|
||||
else:
|
||||
await ws.send_text(data if isinstance(data, str) else data.decode("utf-8"))
|
||||
|
||||
async def send_message(
|
||||
self, message: MessagePayload, *, prefer_raw_bytes: bool = False
|
||||
) -> None:
|
||||
platform = message.get("message_info", {}).get("platform")
|
||||
if not platform:
|
||||
raise ValueError("message_info.platform is required to route the message")
|
||||
await self.broadcast_to_platform(platform, message, use_raw_bytes=prefer_raw_bytes)
|
||||
|
||||
def run_sync(self) -> None:
|
||||
if not self._own_app:
|
||||
return
|
||||
asyncio.run(self.run())
|
||||
|
||||
async def run(self) -> None:
|
||||
self._start_workers()
|
||||
if not self._own_app:
|
||||
return
|
||||
config = uvicorn.Config(
|
||||
self._app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
ssl_certfile=self._ssl_certfile,
|
||||
ssl_keyfile=self._ssl_keyfile,
|
||||
log_config=None,
|
||||
access_log=False,
|
||||
)
|
||||
self._server = uvicorn.Server(config)
|
||||
try:
|
||||
await self._server.serve()
|
||||
except asyncio.CancelledError: # pragma: no cover - shutdown path
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
await self._stop_workers()
|
||||
if self._server:
|
||||
self._server.should_exit = True
|
||||
await self._server.shutdown()
|
||||
self._server = None
|
||||
async with self._conn_lock:
|
||||
targets = list(self._connections)
|
||||
self._connections.clear()
|
||||
self._platform_connections.clear()
|
||||
for ws in targets:
|
||||
try:
|
||||
await ws.close(code=1001, reason="server shutting down")
|
||||
except Exception: # pragma: no cover - best effort
|
||||
pass
|
||||
for task in list(self.background_tasks):
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
if self.background_tasks:
|
||||
await asyncio.gather(*self.background_tasks, return_exceptions=True)
|
||||
self.background_tasks.clear()
|
||||
|
||||
|
||||
class MessageClient(BaseMessageHandler):
|
||||
"""
|
||||
WebSocket 消息客户端,实现双向传输。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode: Literal["ws", "tcp"] = "ws",
|
||||
*,
|
||||
reconnect_interval: float = 5.0,
|
||||
logger: logging.Logger | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if mode != "ws":
|
||||
raise NotImplementedError("Only WebSocket mode is supported in mofox_bus")
|
||||
self._mode = mode
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._url: str = ""
|
||||
self._platform: str = ""
|
||||
self._token: str | None = None
|
||||
self._ssl_verify: str | None = None
|
||||
self._closed = False
|
||||
self._on_disconnect: DisconnectCallback | None = None
|
||||
self._reconnect_interval = reconnect_interval
|
||||
self._logger = logger or logging.getLogger("mofox_bus.client")
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
*,
|
||||
url: str,
|
||||
platform: str,
|
||||
token: str | None = None,
|
||||
ssl_verify: str | None = None,
|
||||
) -> None:
|
||||
self._url = url
|
||||
self._platform = platform
|
||||
self._token = token
|
||||
self._ssl_verify = ssl_verify
|
||||
self._closed = False
|
||||
await self._establish_connection()
|
||||
|
||||
def set_disconnect_callback(self, callback: DisconnectCallback) -> None:
|
||||
self._on_disconnect = callback
|
||||
|
||||
async def _establish_connection(self) -> None:
|
||||
if self._session is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
headers = {"platform": self._platform}
|
||||
if self._token:
|
||||
headers["authorization"] = self._token
|
||||
ssl_context = None
|
||||
if self._ssl_verify:
|
||||
ssl_context = ssl.create_default_context(cafile=self._ssl_verify)
|
||||
self._ws = await self._session.ws_connect(self._url, headers=headers, ssl=ssl_context)
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def _connect_once(self) -> None:
|
||||
await self._establish_connection()
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
assert self._ws is not None
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
|
||||
raw_bytes = msg.data if isinstance(msg.data, (bytes, bytearray)) else msg.data.encode("utf-8")
|
||||
try:
|
||||
payload = orjson.loads(raw_bytes)
|
||||
except orjson.JSONDecodeError:
|
||||
logging.getLogger("mofox_bus.client").warning("Invalid JSON payload")
|
||||
continue
|
||||
payload = _attach_raw_bytes(payload, raw_bytes)
|
||||
if isinstance(payload, list):
|
||||
for item in payload:
|
||||
await self.process_message(item)
|
||||
else:
|
||||
await self.process_message(payload)
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
break
|
||||
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
||||
pass
|
||||
finally:
|
||||
if not self._closed:
|
||||
await self._notify_disconnect("websocket disconnected")
|
||||
await self._reconnect()
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
async def run(self) -> None:
|
||||
self._closed = False
|
||||
while not self._closed:
|
||||
if self._receive_task is None:
|
||||
await self._establish_connection()
|
||||
task = self._receive_task
|
||||
if task is None:
|
||||
break
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
||||
raise
|
||||
|
||||
async def send_message(self, message: MessagePayload | bytes, *, use_raw_bytes: bool = False) -> bool:
|
||||
ws = await self._ensure_ws()
|
||||
data, is_binary = _encode_for_ws_send(message, use_raw_bytes=use_raw_bytes)
|
||||
if is_binary:
|
||||
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
|
||||
else:
|
||||
await ws.send_str(data if isinstance(data, str) else data.decode("utf-8"))
|
||||
return True
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
return self._ws is not None and not self._ws.closed
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._closed = True
|
||||
if self._receive_task and not self._receive_task.done():
|
||||
self._receive_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._receive_task
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def _notify_disconnect(self, reason: str) -> None:
|
||||
if self._on_disconnect is None:
|
||||
return
|
||||
try:
|
||||
result = self._on_disconnect(self._platform, reason)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception: # pragma: no cover - best effort notification
|
||||
logging.getLogger("mofox_bus.client").exception("Disconnect callback failed")
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
self._logger.info(f"WebSocket 连接断开, 正在 {self._reconnect_interval:.1f} 秒后重试")
|
||||
await asyncio.sleep(self._reconnect_interval)
|
||||
await self._connect_once()
|
||||
|
||||
async def _ensure_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def _ensure_ws(self) -> aiohttp.ClientWebSocketResponse:
|
||||
if self._ws is None or self._ws.closed:
|
||||
await self._connect_once()
|
||||
assert self._ws is not None
|
||||
return self._ws
|
||||
|
||||
async def __aenter__(self) -> "MessageClient":
|
||||
if not self._url or not self._platform:
|
||||
raise RuntimeError("connect() must be called before using MessageClient as a context manager")
|
||||
await self._ensure_session()
|
||||
await self._ensure_ws()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
await self.stop()
|
||||
|
||||
|
||||
def _self_websocket(app: FastAPI, path: str):
|
||||
"""
|
||||
装饰器工厂,兼容 FastAPI websocket 路由的声明方式。
|
||||
FastAPI 不允许直接重复注册同一路径,因此这里封装一个可复用的装饰器。
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
app.add_api_websocket_route(path, func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
__all__ = ["BaseMessageHandler", "MessageClient", "MessageServer"]
|
||||
@@ -1,111 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .types import GroupInfoPayload, MessageEnvelope, MessageInfoPayload, SegPayload, UserInfoPayload
|
||||
|
||||
|
||||
class MessageBuilder:
|
||||
"""
|
||||
流式构建 MessageEnvelope 的助手工具,提供类型安全的构建方法。
|
||||
|
||||
使用示例:
|
||||
msg = (
|
||||
MessageBuilder()
|
||||
.text("Hello")
|
||||
.image("http://example.com/1.png")
|
||||
.to_user("123", platform="qq")
|
||||
.build()
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._direction: str = "outgoing"
|
||||
self._message_info: MessageInfoPayload = {}
|
||||
self._segments: List[SegPayload] = []
|
||||
self._metadata: Dict[str, Any] | None = None
|
||||
self._timestamp_ms: int | None = None
|
||||
self._message_id: str | None = None
|
||||
|
||||
def direction(self, value: str) -> "MessageBuilder":
|
||||
self._direction = value
|
||||
return self
|
||||
|
||||
def message_id(self, value: str) -> "MessageBuilder":
|
||||
self._message_id = value
|
||||
return self
|
||||
|
||||
def timestamp_ms(self, value: int | None = None) -> "MessageBuilder":
|
||||
self._timestamp_ms = value or int(time.time() * 1000)
|
||||
return self
|
||||
|
||||
def metadata(self, value: Dict[str, Any]) -> "MessageBuilder":
|
||||
self._metadata = value
|
||||
return self
|
||||
|
||||
def platform(self, value: str) -> "MessageBuilder":
|
||||
self._message_info["platform"] = value
|
||||
return self
|
||||
|
||||
def from_user(self, user_id: str, *, platform: str | None = None, nickname: str | None = None) -> "MessageBuilder":
|
||||
if platform:
|
||||
self.platform(platform)
|
||||
user_info: UserInfoPayload = {"user_id": user_id}
|
||||
if nickname:
|
||||
user_info["user_nickname"] = nickname
|
||||
self._message_info["user_info"] = user_info
|
||||
return self
|
||||
|
||||
def from_group(self, group_id: str, *, platform: str | None = None, name: str | None = None) -> "MessageBuilder":
|
||||
if platform:
|
||||
self.platform(platform)
|
||||
group_info: GroupInfoPayload = {"group_id": group_id}
|
||||
if name:
|
||||
group_info["group_name"] = name
|
||||
self._message_info["group_info"] = group_info
|
||||
return self
|
||||
|
||||
def seg(self, type_: str, data: Any) -> "MessageBuilder":
|
||||
self._segments.append({"type": type_, "data": data})
|
||||
return self
|
||||
|
||||
def text(self, content: str) -> "MessageBuilder":
|
||||
return self.seg("text", content)
|
||||
|
||||
def image(self, url: str) -> "MessageBuilder":
|
||||
return self.seg("image", url)
|
||||
|
||||
def reply(self, target_message_id: str) -> "MessageBuilder":
|
||||
return self.seg("reply", target_message_id)
|
||||
|
||||
def raw_segment(self, segment: SegPayload) -> "MessageBuilder":
|
||||
self._segments.append(segment)
|
||||
return self
|
||||
|
||||
def build(self) -> MessageEnvelope:
|
||||
"""构建最终的消息信封"""
|
||||
# 设置 message_info 默认值
|
||||
if not self._segments:
|
||||
raise ValueError("需要至少添加一个消息段才能构建消息")
|
||||
if self._message_id is None:
|
||||
self._message_id = str(uuid.uuid4())
|
||||
info = dict(self._message_info)
|
||||
info.setdefault("message_id", self._message_id)
|
||||
info.setdefault("time", time.time())
|
||||
|
||||
segments = [seg.copy() if isinstance(seg, dict) else seg for seg in self._segments]
|
||||
envelope: MessageEnvelope = {
|
||||
"direction": self._direction, # type: ignore[assignment]
|
||||
"message_info": info,
|
||||
"message_segment": segments[0] if len(segments) == 1 else list(segments),
|
||||
}
|
||||
if self._metadata is not None:
|
||||
envelope["metadata"] = self._metadata
|
||||
if self._timestamp_ms is not None:
|
||||
envelope["timestamp_ms"] = self._timestamp_ms
|
||||
return envelope
|
||||
|
||||
|
||||
__all__ = ["MessageBuilder"]
|
||||
@@ -1,94 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json as _stdlib_json
|
||||
from typing import Any, Dict, Iterable, List
|
||||
|
||||
try:
|
||||
import orjson as _json_impl
|
||||
except Exception: # pragma: no cover - fallback when orjson is unavailable
|
||||
_json_impl = None
|
||||
|
||||
from .types import MessageEnvelope
|
||||
|
||||
DEFAULT_SCHEMA_VERSION = 1
|
||||
|
||||
|
||||
def _dumps(obj: Any) -> bytes:
|
||||
if _json_impl is not None:
|
||||
return _json_impl.dumps(obj)
|
||||
return _stdlib_json.dumps(obj, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
|
||||
|
||||
|
||||
def _loads(data: bytes) -> Dict[str, Any]:
|
||||
if _json_impl is not None:
|
||||
return _json_impl.loads(data)
|
||||
return _stdlib_json.loads(data.decode("utf-8"))
|
||||
|
||||
|
||||
def dumps_message(msg: MessageEnvelope) -> bytes:
|
||||
"""
|
||||
将单条消息序列化为 JSON bytes。
|
||||
"""
|
||||
sanitized = _strip_raw_bytes(msg)
|
||||
if "schema_version" not in sanitized:
|
||||
sanitized["schema_version"] = DEFAULT_SCHEMA_VERSION
|
||||
return _dumps(sanitized)
|
||||
|
||||
def dumps_messages(messages: Iterable[MessageEnvelope]) -> bytes:
|
||||
"""
|
||||
将批量消息序列化为 JSON bytes。
|
||||
"""
|
||||
payload = {
|
||||
"schema_version": DEFAULT_SCHEMA_VERSION,
|
||||
"items": [_strip_raw_bytes(msg) for msg in messages],
|
||||
}
|
||||
return _dumps(payload)
|
||||
|
||||
def loads_message(data: bytes | str) -> MessageEnvelope:
|
||||
"""
|
||||
反序列化单条消息。
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode("utf-8")
|
||||
obj = _loads(data)
|
||||
return _upgrade_schema_if_needed(obj)
|
||||
|
||||
|
||||
def loads_messages(data: bytes | str) -> List[MessageEnvelope]:
|
||||
"""
|
||||
反序列化批量消息。
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode("utf-8")
|
||||
obj = _loads(data)
|
||||
version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION)
|
||||
if version != DEFAULT_SCHEMA_VERSION:
|
||||
raise ValueError(f"不支持的 schema_version={version}")
|
||||
return [_upgrade_schema_if_needed(item) for item in obj.get("items", [])]
|
||||
|
||||
|
||||
def _upgrade_schema_if_needed(obj: Dict[str, Any]) -> MessageEnvelope:
|
||||
"""
|
||||
针对未来的 schema 版本演进预留兼容入口。
|
||||
"""
|
||||
version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION)
|
||||
if version == DEFAULT_SCHEMA_VERSION:
|
||||
return obj # type: ignore[return-value]
|
||||
raise ValueError(f"不支持的 schema_version={version}")
|
||||
|
||||
|
||||
|
||||
def _strip_raw_bytes(msg: MessageEnvelope) -> MessageEnvelope:
|
||||
if isinstance(msg, dict) and "raw_bytes" in msg:
|
||||
new_msg = dict(msg)
|
||||
new_msg.pop("raw_bytes", None)
|
||||
return new_msg # type: ignore[return-value]
|
||||
return msg
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_SCHEMA_VERSION",
|
||||
"dumps_message",
|
||||
"dumps_messages",
|
||||
"loads_message",
|
||||
"loads_messages",
|
||||
]
|
||||
@@ -1,265 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from .api import MessageClient
|
||||
from .types import MessageEnvelope
|
||||
|
||||
logger = logging.getLogger("mofox_bus.router")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TargetConfig:
|
||||
"""路由目标配置,包含连接信息和认证配置"""
|
||||
url: str
|
||||
token: str | None = None
|
||||
ssl_verify: str | None = None
|
||||
|
||||
def to_dict(self) -> Dict[str, str | None]:
|
||||
"""转换为字典格式"""
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, str | None]) -> "TargetConfig":
|
||||
"""从字典创建配置对象"""
|
||||
return cls(
|
||||
url=data.get("url", ""),
|
||||
token=data.get("token"),
|
||||
ssl_verify=data.get("ssl_verify"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RouteConfig:
|
||||
"""路由配置,包含多个平台的路由目标"""
|
||||
route_config: Dict[str, TargetConfig]
|
||||
|
||||
def to_dict(self) -> Dict[str, Dict[str, str | None]]:
|
||||
"""转换为字典格式"""
|
||||
return {"route_config": {k: v.to_dict() for k, v in self.route_config.items()}}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Dict[str, str | None]]) -> "RouteConfig":
|
||||
"""从字典创建路由配置对象"""
|
||||
cfg = {
|
||||
platform: TargetConfig.from_dict(target)
|
||||
for platform, target in data.get("route_config", {}).items()
|
||||
}
|
||||
return cls(route_config=cfg)
|
||||
|
||||
|
||||
class Router:
|
||||
"""消息路由器,负责管理多个平台的消息客户端连接"""
|
||||
|
||||
def __init__(self, config: RouteConfig, custom_logger: logging.Logger | None = None) -> None:
|
||||
"""
|
||||
初始化路由器
|
||||
|
||||
Args:
|
||||
config: 路由配置
|
||||
custom_logger: 自定义日志记录器
|
||||
"""
|
||||
if custom_logger:
|
||||
logger.handlers = custom_logger.handlers
|
||||
self.config = config
|
||||
self.clients: Dict[str, MessageClient] = {}
|
||||
self.handlers: list[Callable[[Dict], None]] = []
|
||||
self._running = False
|
||||
self._client_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._stop_event: asyncio.Event | None = None
|
||||
|
||||
async def connect(self, platform: str) -> None:
|
||||
"""
|
||||
连接到指定平台
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
|
||||
Raises:
|
||||
ValueError: 未知平台
|
||||
NotImplementedError: 不支持的模式
|
||||
"""
|
||||
if platform not in self.config.route_config:
|
||||
raise ValueError(f"未知平台: {platform}")
|
||||
target = self.config.route_config[platform]
|
||||
mode = "tcp" if target.url.startswith(("tcp://", "tcps://")) else "ws"
|
||||
if mode != "ws":
|
||||
raise NotImplementedError("TCP 模式暂未实现")
|
||||
client = MessageClient(mode="ws")
|
||||
client.set_disconnect_callback(self._handle_client_disconnect)
|
||||
await client.connect(
|
||||
url=target.url,
|
||||
platform=platform,
|
||||
token=target.token,
|
||||
ssl_verify=target.ssl_verify,
|
||||
)
|
||||
for handler in self.handlers:
|
||||
client.register_message_handler(handler)
|
||||
self.clients[platform] = client
|
||||
if self._running:
|
||||
self._start_client_task(platform, client)
|
||||
|
||||
def register_class_handler(self, handler: Callable[[Dict], None]) -> None:
|
||||
self.handlers.append(handler)
|
||||
for client in self.clients.values():
|
||||
client.register_message_handler(handler)
|
||||
|
||||
async def run(self) -> None:
|
||||
"""启动路由器,连接所有配置的平台并开始运行"""
|
||||
self._running = True
|
||||
self._stop_event = asyncio.Event()
|
||||
for platform in self.config.route_config:
|
||||
if platform not in self.clients:
|
||||
await self.connect(platform)
|
||||
for platform, client in self.clients.items():
|
||||
if platform not in self._client_tasks:
|
||||
self._start_client_task(platform, client)
|
||||
try:
|
||||
await self._stop_event.wait()
|
||||
except asyncio.CancelledError: # pragma: no cover
|
||||
raise
|
||||
|
||||
async def remove_platform(self, platform: str) -> None:
|
||||
"""
|
||||
移除指定平台的连接
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
"""
|
||||
if platform in self._client_tasks:
|
||||
task = self._client_tasks.pop(platform)
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
client = self.clients.pop(platform, None)
|
||||
if client:
|
||||
await client.stop()
|
||||
|
||||
async def _handle_client_disconnect(self, platform: str, reason: str) -> None:
|
||||
"""
|
||||
处理客户端断开连接
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
reason: 断开原因
|
||||
"""
|
||||
logger.info(f"平台 {platform} 的客户端断开连接: {reason} (客户端将自动重连)")
|
||||
task = self._client_tasks.get(platform)
|
||||
if task is not None and not task.done():
|
||||
return
|
||||
client = self.clients.get(platform)
|
||||
if client and self._running:
|
||||
self._start_client_task(platform, client)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止路由器,关闭所有连接"""
|
||||
self._running = False
|
||||
if self._stop_event:
|
||||
self._stop_event.set()
|
||||
for platform in list(self.clients.keys()):
|
||||
await self.remove_platform(platform)
|
||||
self.clients.clear()
|
||||
|
||||
def _start_client_task(self, platform: str, client: MessageClient) -> None:
|
||||
"""
|
||||
启动客户端任务
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
client: 消息客户端
|
||||
"""
|
||||
task = asyncio.create_task(client.run())
|
||||
task.add_done_callback(lambda t, plat=platform: asyncio.create_task(self._restart_if_needed(plat, t)))
|
||||
self._client_tasks[platform] = task
|
||||
|
||||
async def _restart_if_needed(self, platform: str, task: asyncio.Task) -> None:
|
||||
"""
|
||||
必要时重启客户端任务
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
task: 已完成的任务
|
||||
"""
|
||||
if not self._running:
|
||||
return
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc:
|
||||
logger.warning(f"平台 {platform} 的客户端任务异常结束: {exc}")
|
||||
client = self.clients.get(platform)
|
||||
if client:
|
||||
self._start_client_task(platform, client)
|
||||
|
||||
def get_target_url(self, message: MessageEnvelope) -> Optional[str]:
|
||||
"""
|
||||
根据消息获取目标 URL
|
||||
|
||||
Args:
|
||||
message: 消息信封
|
||||
|
||||
Returns:
|
||||
目标 URL 或 None
|
||||
"""
|
||||
platform = message.get("message_info", {}).get("platform")
|
||||
if not platform:
|
||||
return None
|
||||
target = self.config.route_config.get(platform)
|
||||
return target.url if target else None
|
||||
|
||||
async def send_message(self, message: MessageEnvelope):
|
||||
"""
|
||||
发送消息到指定平台
|
||||
|
||||
Args:
|
||||
message: 消息信封
|
||||
|
||||
Raises:
|
||||
ValueError: 缺少平台信息
|
||||
RuntimeError: 未找到对应平台的客户端
|
||||
"""
|
||||
platform = message.get("message_info", {}).get("platform")
|
||||
if not platform:
|
||||
raise ValueError("消息中缺少必需的 message_info.platform 字段")
|
||||
client = self.clients.get(platform)
|
||||
if client is None:
|
||||
raise RuntimeError(f"平台 {platform} 没有已连接的客户端")
|
||||
return await client.send_message(message)
|
||||
|
||||
async def update_config(self, config_data: Dict[str, Dict[str, str | None]]) -> None:
|
||||
"""
|
||||
更新路由配置
|
||||
|
||||
Args:
|
||||
config_data: 新的配置数据
|
||||
"""
|
||||
new_config = RouteConfig.from_dict(config_data)
|
||||
await self._adjust_connections(new_config)
|
||||
self.config = new_config
|
||||
|
||||
async def _adjust_connections(self, new_config: RouteConfig) -> None:
|
||||
"""
|
||||
调整连接以匹配新配置
|
||||
|
||||
Args:
|
||||
new_config: 新的路由配置
|
||||
"""
|
||||
current = set(self.config.route_config.keys())
|
||||
updated = set(new_config.route_config.keys())
|
||||
# 移除不再存在的平台
|
||||
for platform in current - updated:
|
||||
await self.remove_platform(platform)
|
||||
# 添加或更新平台
|
||||
for platform in updated:
|
||||
if platform not in current:
|
||||
await self.connect(platform)
|
||||
else:
|
||||
old = self.config.route_config[platform]
|
||||
new = new_config.route_config[platform]
|
||||
if old.url != new.url or old.token != new.token:
|
||||
await self.remove_platform(platform)
|
||||
await self.connect(platform)
|
||||
@@ -1,470 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import threading
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Dict, Iterable, List, Protocol
|
||||
|
||||
from .types import MessageEnvelope
|
||||
|
||||
Hook = Callable[[MessageEnvelope], Awaitable[None] | None]
|
||||
ErrorHook = Callable[[MessageEnvelope, BaseException], Awaitable[None] | None]
|
||||
Predicate = Callable[[MessageEnvelope], bool | Awaitable[bool]]
|
||||
MessageHandler = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None] | MessageEnvelope | None]
|
||||
BatchHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None] | List[MessageEnvelope] | None]
|
||||
MiddlewareCallable = Callable[[MessageEnvelope], Awaitable[MessageEnvelope | None]]
|
||||
|
||||
|
||||
class Middleware(Protocol):
|
||||
async def __call__(self, message: MessageEnvelope, handler: MiddlewareCallable) -> MessageEnvelope | None: ...
|
||||
|
||||
|
||||
class MessageProcessingError(RuntimeError):
|
||||
"""封装处理链路中发生的异常。"""
|
||||
|
||||
def __init__(self, message: MessageEnvelope, original: BaseException):
|
||||
detail = message.get("id", "<unknown>")
|
||||
super().__init__(f"处理消息 {detail} 时出错: {original}")
|
||||
self.message_envelope = message
|
||||
self.original = original
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageRoute:
|
||||
"""消息路由配置,包含匹配条件和处理函数"""
|
||||
predicate: Predicate
|
||||
handler: MessageHandler
|
||||
name: str | None = None
|
||||
message_type: str | None = None
|
||||
message_types: set[str] | None = None # 支持多个消息类型
|
||||
event_types: set[str] | None = None
|
||||
|
||||
|
||||
class MessageRuntime:
|
||||
"""
|
||||
消息运行时环境,负责调度消息路由、执行前后处理钩子以及批量处理消息
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._routes: list[MessageRoute] = []
|
||||
self._before_hooks: list[Hook] = []
|
||||
self._after_hooks: list[Hook] = []
|
||||
self._error_hooks: list[ErrorHook] = []
|
||||
self._batch_handler: BatchHandler | None = None
|
||||
self._lock = threading.RLock()
|
||||
self._middlewares: list[Middleware] = []
|
||||
self._type_routes: Dict[str, list[MessageRoute]] = {}
|
||||
self._event_routes: Dict[str, list[MessageRoute]] = {}
|
||||
# 用于检测同一类型的重复注册
|
||||
self._explicit_type_handlers: Dict[str, str] = {} # message_type -> handler_name
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
predicate: Predicate,
|
||||
handler: MessageHandler,
|
||||
name: str | None = None,
|
||||
*,
|
||||
message_type: str | list[str] | None = None,
|
||||
event_types: Iterable[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
添加消息路由
|
||||
|
||||
Args:
|
||||
predicate: 路由匹配条件
|
||||
handler: 消息处理函数
|
||||
name: 路由名称(可选)
|
||||
message_type: 消息类型,可以是字符串或字符串列表(可选)
|
||||
event_types: 事件类型列表(可选)
|
||||
"""
|
||||
with self._lock:
|
||||
# 处理 message_type 参数,支持字符串或列表
|
||||
message_types_set: set[str] | None = None
|
||||
single_message_type: str | None = None
|
||||
|
||||
if message_type is not None:
|
||||
if isinstance(message_type, str):
|
||||
message_types_set = {message_type}
|
||||
single_message_type = message_type
|
||||
elif isinstance(message_type, list):
|
||||
message_types_set = set(message_type)
|
||||
if len(message_types_set) == 1:
|
||||
single_message_type = next(iter(message_types_set))
|
||||
else:
|
||||
raise TypeError(f"message_type must be str or list[str], got {type(message_type)}")
|
||||
|
||||
# 检测重复注册:如果明确指定了某个类型,不允许重复
|
||||
handler_name = name or getattr(handler, "__name__", str(handler))
|
||||
for msg_type in message_types_set:
|
||||
if msg_type in self._explicit_type_handlers:
|
||||
existing_handler = self._explicit_type_handlers[msg_type]
|
||||
raise ValueError(
|
||||
f"消息类型 '{msg_type}' 已被处理器 '{existing_handler}' 明确注册,"
|
||||
f"不能再由 '{handler_name}' 注册。同一消息类型只能有一个明确的处理器。"
|
||||
)
|
||||
self._explicit_type_handlers[msg_type] = handler_name
|
||||
|
||||
route = MessageRoute(
|
||||
predicate=predicate,
|
||||
handler=handler,
|
||||
name=name,
|
||||
message_type=single_message_type,
|
||||
message_types=message_types_set,
|
||||
event_types=set(event_types) if event_types is not None else None,
|
||||
)
|
||||
self._routes.append(route)
|
||||
|
||||
# 为每个消息类型建立索引
|
||||
if message_types_set:
|
||||
for msg_type in message_types_set:
|
||||
self._type_routes.setdefault(msg_type, []).append(route)
|
||||
|
||||
if route.event_types:
|
||||
for et in route.event_types:
|
||||
self._event_routes.setdefault(et, []).append(route)
|
||||
|
||||
def route(self, predicate: Predicate, name: str | None = None) -> Callable[[MessageHandler], MessageHandler]:
|
||||
"""装饰器写法,便于在核心逻辑中声明式注册。
|
||||
|
||||
支持普通函数和类方法。对于类方法,会在实例创建时自动绑定并注册路由。
|
||||
"""
|
||||
|
||||
def decorator(func: MessageHandler) -> MessageHandler:
|
||||
# Support decorating instance methods: defer binding until the object is created.
|
||||
if _looks_like_method(func):
|
||||
return _InstanceMethodRoute(
|
||||
runtime=self,
|
||||
func=func,
|
||||
predicate=predicate,
|
||||
name=name,
|
||||
message_type=None,
|
||||
)
|
||||
|
||||
self.add_route(predicate, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def on_message(
|
||||
self,
|
||||
func: MessageHandler | None = None,
|
||||
*,
|
||||
message_type: str | list[str] | None = None,
|
||||
platform: str | None = None,
|
||||
predicate: Predicate | None = None,
|
||||
name: str | None = None,
|
||||
) -> Callable[[MessageHandler], MessageHandler] | MessageHandler:
|
||||
"""Sugar decorator with optional Seg.type/platform predicate matching.
|
||||
|
||||
Args:
|
||||
func: 被装饰的函数
|
||||
message_type: 消息类型,可以是单个字符串或字符串列表
|
||||
platform: 平台名称
|
||||
predicate: 自定义匹配条件
|
||||
name: 路由名称
|
||||
|
||||
Usages:
|
||||
- @runtime.on_message(...)
|
||||
- @runtime.on_message
|
||||
- @runtime.on_message(message_type="text")
|
||||
- @runtime.on_message(message_type=["text", "image"])
|
||||
|
||||
If the target looks like an instance method (first arg is self), it will be
|
||||
auto-bound to the instance and registered when the object is constructed.
|
||||
"""
|
||||
# 将 message_type 转换为集合以便统一处理
|
||||
message_types_set: set[str] | None = None
|
||||
if message_type is not None:
|
||||
if isinstance(message_type, str):
|
||||
message_types_set = {message_type}
|
||||
elif isinstance(message_type, list):
|
||||
message_types_set = set(message_type)
|
||||
else:
|
||||
raise TypeError(f"message_type must be str or list[str], got {type(message_type)}")
|
||||
|
||||
async def combined_predicate(message: MessageEnvelope) -> bool:
|
||||
if message_types_set is not None:
|
||||
extracted_type = _extract_segment_type(message)
|
||||
if extracted_type not in message_types_set:
|
||||
return False
|
||||
if platform is not None:
|
||||
info_platform = message.get("message_info", {}).get("platform")
|
||||
if message.get("platform") not in (None, platform) and info_platform is None:
|
||||
return False
|
||||
if info_platform not in (None, platform):
|
||||
return False
|
||||
if predicate is None:
|
||||
return True
|
||||
return await _invoke_callable(predicate, message, prefer_thread=False)
|
||||
|
||||
def decorator(func: MessageHandler) -> MessageHandler:
|
||||
# Support decorating instance methods: defer binding until the object is created.
|
||||
if _looks_like_method(func):
|
||||
return _InstanceMethodRoute(
|
||||
runtime=self,
|
||||
func=func,
|
||||
predicate=combined_predicate,
|
||||
name=name,
|
||||
message_type=message_type,
|
||||
)
|
||||
|
||||
self.add_route(combined_predicate, func, name=name, message_type=message_type)
|
||||
return func
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
return decorator
|
||||
|
||||
|
||||
|
||||
def set_batch_handler(self, handler: BatchHandler) -> None:
|
||||
self._batch_handler = handler
|
||||
|
||||
def register_before_hook(self, hook: Hook) -> None:
|
||||
self._before_hooks.append(hook)
|
||||
|
||||
def register_after_hook(self, hook: Hook) -> None:
|
||||
self._after_hooks.append(hook)
|
||||
|
||||
def register_error_hook(self, hook: ErrorHook) -> None:
|
||||
self._error_hooks.append(hook)
|
||||
|
||||
def register_middleware(self, middleware: Middleware) -> None:
|
||||
"""注册洋葱模型中间件,围绕处理器执行。"""
|
||||
|
||||
self._middlewares.append(middleware)
|
||||
|
||||
async def handle_message(self, message: MessageEnvelope) -> MessageEnvelope | None:
|
||||
await self._run_hooks(self._before_hooks, message)
|
||||
try:
|
||||
route = await self._match_route(message)
|
||||
if route is None:
|
||||
return None
|
||||
handler = self._wrap_with_middlewares(route.handler)
|
||||
result = await handler(message)
|
||||
except Exception as exc:
|
||||
await self._run_error_hooks(message, exc)
|
||||
raise MessageProcessingError(message, exc) from exc
|
||||
await self._run_hooks(self._after_hooks, message)
|
||||
return result
|
||||
|
||||
async def handle_batch(self, messages: Iterable[MessageEnvelope]) -> List[MessageEnvelope]:
|
||||
batch = list(messages)
|
||||
if not batch:
|
||||
return []
|
||||
if self._batch_handler is not None:
|
||||
result = await _invoke_callable(self._batch_handler, batch, prefer_thread=True)
|
||||
return result or []
|
||||
responses: list[MessageEnvelope] = []
|
||||
for message in batch:
|
||||
response = await self.handle_message(message)
|
||||
if response is not None:
|
||||
responses.append(response)
|
||||
return responses
|
||||
|
||||
async def _match_route(self, message: MessageEnvelope) -> MessageRoute | None:
|
||||
"""匹配消息路由,优先匹配明确指定了消息类型的处理器"""
|
||||
message_type = _extract_segment_type(message)
|
||||
event_type = (
|
||||
message.get("event_type")
|
||||
or message.get("message_info", {})
|
||||
.get("additional_config", {})
|
||||
.get("event_type")
|
||||
)
|
||||
|
||||
# 分为两层候选:优先级和普通
|
||||
priority_candidates: list[MessageRoute] = [] # 明确指定了消息类型的
|
||||
normal_candidates: list[MessageRoute] = [] # 没有指定或通配的
|
||||
|
||||
with self._lock:
|
||||
# 事件路由(优先级最高)
|
||||
if event_type and event_type in self._event_routes:
|
||||
priority_candidates.extend(self._event_routes[event_type])
|
||||
|
||||
# 消息类型路由(明确指定的有优先级)
|
||||
if message_type and message_type in self._type_routes:
|
||||
priority_candidates.extend(self._type_routes[message_type])
|
||||
|
||||
# 通用路由(没有明确指定类型的)
|
||||
for route in self._routes:
|
||||
# 如果路由没有指定 message_types,则是通用路由
|
||||
if route.message_types is None and route.event_types is None:
|
||||
normal_candidates.append(route)
|
||||
|
||||
# 先尝试优先级候选
|
||||
seen: set[int] = set()
|
||||
for route in priority_candidates:
|
||||
rid = id(route)
|
||||
if rid in seen:
|
||||
continue
|
||||
seen.add(rid)
|
||||
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
|
||||
if should_handle:
|
||||
return route
|
||||
|
||||
# 如果没有匹配到优先级候选,再尝试普通候选
|
||||
for route in normal_candidates:
|
||||
rid = id(route)
|
||||
if rid in seen:
|
||||
continue
|
||||
seen.add(rid)
|
||||
should_handle = await _invoke_callable(route.predicate, message, prefer_thread=False)
|
||||
if should_handle:
|
||||
return route
|
||||
|
||||
return None
|
||||
|
||||
async def _run_hooks(self, hooks: Iterable[Hook], message: MessageEnvelope) -> None:
|
||||
coro_list = [self._call_hook(hook, message) for hook in hooks]
|
||||
if coro_list:
|
||||
await asyncio.gather(*coro_list)
|
||||
|
||||
async def _call_hook(self, hook: Hook, message: MessageEnvelope) -> None:
|
||||
await _invoke_callable(hook, message, prefer_thread=True)
|
||||
|
||||
async def _run_error_hooks(self, message: MessageEnvelope, exc: BaseException) -> None:
|
||||
coros = [self._call_error_hook(hook, message, exc) for hook in self._error_hooks]
|
||||
if coros:
|
||||
await asyncio.gather(*coros)
|
||||
|
||||
async def _call_error_hook(self, hook: ErrorHook, message: MessageEnvelope, exc: BaseException) -> None:
|
||||
await _invoke_callable(hook, message, exc, prefer_thread=True)
|
||||
|
||||
def _wrap_with_middlewares(self, handler: MessageHandler) -> MiddlewareCallable:
|
||||
async def base_handler(message: MessageEnvelope) -> MessageEnvelope | None:
|
||||
return await _invoke_callable(handler, message, prefer_thread=True)
|
||||
|
||||
wrapped: MiddlewareCallable = base_handler
|
||||
for middleware in reversed(self._middlewares):
|
||||
current = wrapped
|
||||
|
||||
async def wrapper(msg: MessageEnvelope, mw=middleware, nxt=current) -> MessageEnvelope | None:
|
||||
return await _invoke_callable(mw, msg, nxt, prefer_thread=False)
|
||||
|
||||
wrapped = wrapper
|
||||
return wrapped
|
||||
|
||||
|
||||
async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False):
|
||||
"""支持 sync/async 调用,并可选择在线程中执行。
|
||||
|
||||
自动处理普通函数、类方法和绑定方法。
|
||||
"""
|
||||
# 如果是绑定方法(bound method),直接使用,不需要额外处理
|
||||
# 因为绑定方法已经包含了 self 参数
|
||||
if inspect.ismethod(func):
|
||||
# 绑定方法可以直接调用,args 中不应包含 self
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args)
|
||||
if prefer_thread:
|
||||
result = await asyncio.to_thread(func, *args)
|
||||
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||
return await result
|
||||
return result
|
||||
result = func(*args)
|
||||
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||
return await result
|
||||
return result
|
||||
|
||||
# 对于普通函数(未绑定的),按原有逻辑处理
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args)
|
||||
if prefer_thread:
|
||||
result = await asyncio.to_thread(func, *args)
|
||||
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||
return await result
|
||||
return result
|
||||
result = func(*args)
|
||||
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||
return await result
|
||||
return result
|
||||
|
||||
|
||||
def _extract_segment_type(message: MessageEnvelope) -> str | None:
|
||||
seg = message.get("message_segment") or message.get("message_chain")
|
||||
if isinstance(seg, dict):
|
||||
return seg.get("type")
|
||||
if isinstance(seg, list) and seg:
|
||||
first = seg[0]
|
||||
if isinstance(first, dict):
|
||||
return first.get("type")
|
||||
return None
|
||||
|
||||
|
||||
def _looks_like_method(func: Callable[..., object]) -> bool:
|
||||
"""Return True if callable signature suggests an instance method (first arg named self)."""
|
||||
if inspect.ismethod(func):
|
||||
return True
|
||||
if not inspect.isfunction(func):
|
||||
return False
|
||||
params = inspect.signature(func).parameters
|
||||
if not params:
|
||||
return False
|
||||
first = next(iter(params.values()))
|
||||
return first.name == "self"
|
||||
|
||||
|
||||
class _InstanceMethodRoute:
|
||||
"""Descriptor that binds decorated instance methods and registers routes per-instance."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runtime: MessageRuntime,
|
||||
func: MessageHandler,
|
||||
predicate: Predicate,
|
||||
name: str | None,
|
||||
message_type: str | None,
|
||||
) -> None:
|
||||
self._runtime = runtime
|
||||
self._func = func
|
||||
self._predicate = predicate
|
||||
self._name = name
|
||||
self._message_type = message_type
|
||||
self._owner: type | None = None
|
||||
self._registered_instances: weakref.WeakSet[object] = weakref.WeakSet()
|
||||
|
||||
def __set_name__(self, owner: type, name: str) -> None:
|
||||
self._owner = owner
|
||||
registry: list[_InstanceMethodRoute] | None = getattr(owner, "_mofox_instance_routes", None)
|
||||
if registry is None:
|
||||
registry = []
|
||||
setattr(owner, "_mofox_instance_routes", registry)
|
||||
original_init = owner.__init__
|
||||
|
||||
@functools.wraps(original_init)
|
||||
def wrapped_init(inst, *args, **kwargs):
|
||||
original_init(inst, *args, **kwargs)
|
||||
for descriptor in getattr(inst.__class__, "_mofox_instance_routes", []):
|
||||
descriptor._register_instance(inst)
|
||||
|
||||
owner.__init__ = wrapped_init # type: ignore[assignment]
|
||||
registry.append(self)
|
||||
|
||||
def _register_instance(self, instance: object) -> None:
|
||||
if instance in self._registered_instances:
|
||||
return
|
||||
owner = self._owner or instance.__class__
|
||||
bound = self._func.__get__(instance, owner) # type: ignore[arg-type]
|
||||
self._runtime.add_route(self._predicate, bound, name=self._name, message_type=self._message_type)
|
||||
self._registered_instances.add(instance)
|
||||
|
||||
def __get__(self, instance: object | None, owner: type | None = None):
|
||||
if instance is None:
|
||||
return self._func
|
||||
self._register_instance(instance)
|
||||
return self._func.__get__(instance, owner) # type: ignore[arg-type]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BatchHandler",
|
||||
"Hook",
|
||||
"MessageHandler",
|
||||
"MessageProcessingError",
|
||||
"MessageRoute",
|
||||
"MessageRuntime",
|
||||
"Middleware",
|
||||
"Predicate",
|
||||
]
|
||||
@@ -1,10 +0,0 @@
|
||||
"""
|
||||
传输层封装,提供 HTTP / WebSocket server & client。
|
||||
"""
|
||||
|
||||
from .http_client import HttpMessageClient
|
||||
from .http_server import HttpMessageServer
|
||||
from .ws_client import WsMessageClient
|
||||
from .ws_server import WsMessageServer
|
||||
|
||||
__all__ = ["HttpMessageClient", "HttpMessageServer", "WsMessageClient", "WsMessageServer"]
|
||||
@@ -1,67 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Iterable, List, Sequence
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..codec import dumps_messages, loads_messages
|
||||
from ..types import MessageEnvelope
|
||||
|
||||
|
||||
class HttpMessageClient:
|
||||
"""
|
||||
面向消息批量传输的 HTTP 客户端封装
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
*,
|
||||
session: aiohttp.ClientSession | None = None,
|
||||
timeout: aiohttp.ClientTimeout | None = None,
|
||||
) -> None:
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._session = session
|
||||
self._timeout = timeout
|
||||
self._owns_session = session is None
|
||||
self._logger = logging.getLogger("mofox_bus.http_client")
|
||||
|
||||
async def send_messages(
|
||||
self,
|
||||
messages: Sequence[MessageEnvelope],
|
||||
*,
|
||||
expect_reply: bool = False,
|
||||
path: str = "/messages",
|
||||
) -> List[MessageEnvelope] | None:
|
||||
if not messages:
|
||||
return []
|
||||
session = await self._ensure_session()
|
||||
url = f"{self._base_url}{path}"
|
||||
payload = dumps_messages(messages)
|
||||
self._logger.debug(f"正在发送 {len(messages)} 条消息 -> {url}")
|
||||
async with session.post(url, data=payload, timeout=self._timeout) as resp:
|
||||
resp.raise_for_status()
|
||||
if not expect_reply:
|
||||
return None
|
||||
raw = await resp.read()
|
||||
replies = loads_messages(raw)
|
||||
self._logger.debug(f"接收到 {len(replies)} 条回复消息")
|
||||
return replies
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._owns_session and self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def _ensure_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def __aenter__(self) -> "HttpMessageClient":
|
||||
await self._ensure_session()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
await self.close()
|
||||
@@ -1,53 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Awaitable, Callable, List
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ..codec import dumps_messages, loads_messages
|
||||
from ..types import MessageEnvelope
|
||||
|
||||
MessageHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelope] | None]]
|
||||
|
||||
|
||||
class HttpMessageServer:
|
||||
"""
|
||||
轻量级 HTTP 消息入口,可独立运行,也可挂载到现有 FastAPI / aiohttp 应用下
|
||||
"""
|
||||
|
||||
def __init__(self, handler: MessageHandler, *, path: str = "/messages") -> None:
|
||||
self._handler = handler
|
||||
self._app = web.Application()
|
||||
self._path = path
|
||||
self._app.add_routes([web.post(path, self._handle_messages)])
|
||||
self._logger = logging.getLogger("mofox_bus.http_server")
|
||||
|
||||
async def _handle_messages(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
raw = await request.read()
|
||||
envelopes = loads_messages(raw)
|
||||
self._logger.debug(f"接收到 {len(envelopes)} 条消息")
|
||||
except Exception as exc: # pragma: no cover - network errors are integration tested
|
||||
self._logger.exception(f"解析请求失败: {exc}")
|
||||
raise web.HTTPBadRequest(reason=f"无效的负载: {exc}") from exc
|
||||
|
||||
result = await self._handler(envelopes)
|
||||
if result is None:
|
||||
return web.Response(status=200, text="ok")
|
||||
payload = dumps_messages(result)
|
||||
return web.Response(status=200, body=payload, content_type="application/json")
|
||||
|
||||
def make_app(self) -> web.Application:
|
||||
"""
|
||||
返回 aiohttp Application,可被外部 server(gunicorn/uvicorn)直接使用。
|
||||
"""
|
||||
|
||||
return self._app
|
||||
|
||||
def add_to_app(self, app: web.Application) -> None:
|
||||
"""
|
||||
将消息路由注册到给定的 aiohttp app,方便与既有服务整合。
|
||||
"""
|
||||
|
||||
app.router.add_post(self._path, self._handle_messages)
|
||||
@@ -1,108 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Awaitable, Callable, Iterable, List, Sequence
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..codec import dumps_messages, loads_messages
|
||||
from ..types import MessageEnvelope
|
||||
|
||||
IncomingHandler = Callable[[MessageEnvelope], Awaitable[None]]
|
||||
|
||||
|
||||
class WsMessageClient:
|
||||
"""
|
||||
管理 WebSocket 连接,提供 send/receive API,并在后台读取消息
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
handler: IncomingHandler | None = None,
|
||||
session: aiohttp.ClientSession | None = None,
|
||||
reconnect_interval: float = 5.0,
|
||||
) -> None:
|
||||
self._url = url
|
||||
self._handler = handler
|
||||
self._session = session
|
||||
self._reconnect_interval = reconnect_interval
|
||||
self._owns_session = session is None
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
self._logger = logging.getLogger("mofox_bus.ws_client")
|
||||
|
||||
async def connect(self) -> None:
|
||||
await self._ensure_session()
|
||||
await self._connect_once()
|
||||
|
||||
async def _connect_once(self) -> None:
|
||||
assert self._session is not None
|
||||
self._ws = await self._session.ws_connect(self._url)
|
||||
self._logger.info(f"已连接到 {self._url}")
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def send_messages(self, messages: Sequence[MessageEnvelope]) -> None:
|
||||
if not messages:
|
||||
return
|
||||
ws = await self._ensure_ws()
|
||||
payload = dumps_messages(messages)
|
||||
await ws.send_bytes(payload)
|
||||
|
||||
async def send_message(self, message: MessageEnvelope) -> None:
|
||||
await self.send_messages([message])
|
||||
|
||||
async def close(self) -> None:
|
||||
self._closed = True
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._owns_session and self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
assert self._ws is not None
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
if msg.type in (aiohttp.WSMsgType.BINARY, aiohttp.WSMsgType.TEXT):
|
||||
envelopes = loads_messages(msg.data)
|
||||
for env in envelopes:
|
||||
if self._handler is not None:
|
||||
await self._handler(env)
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
self._logger.warning(f"WebSocket 错误: {msg.data}")
|
||||
break
|
||||
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
||||
return
|
||||
finally:
|
||||
if not self._closed:
|
||||
await self._reconnect()
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
self._logger.info(f"WebSocket 断开, 正在 {self._reconnect_interval:.1f} 秒后重试")
|
||||
await asyncio.sleep(self._reconnect_interval)
|
||||
await self._connect_once()
|
||||
|
||||
async def _ensure_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def _ensure_ws(self) -> aiohttp.ClientWebSocketResponse:
|
||||
if self._ws is None or self._ws.closed:
|
||||
await self._connect_once()
|
||||
assert self._ws is not None
|
||||
return self._ws
|
||||
|
||||
async def __aenter__(self) -> "WsMessageClient":
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
await self.close()
|
||||
@@ -1,66 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Awaitable, Callable, Iterable, List, Set
|
||||
|
||||
from aiohttp import WSMsgType, web
|
||||
|
||||
from ..codec import dumps_messages, loads_messages
|
||||
from ..types import MessageEnvelope
|
||||
|
||||
WsMessageHandler = Callable[[MessageEnvelope], Awaitable[None]]
|
||||
|
||||
|
||||
class WsMessageServer:
|
||||
"""
|
||||
封装 WebSocket 服务端逻辑,负责接收消息并广播响应
|
||||
"""
|
||||
|
||||
def __init__(self, handler: WsMessageHandler, *, path: str = "/ws") -> None:
|
||||
self._handler = handler
|
||||
self._app = web.Application()
|
||||
self._path = path
|
||||
self._app.add_routes([web.get(path, self._handle_ws)])
|
||||
self._connections: Set[web.WebSocketResponse] = set()
|
||||
self._lock = asyncio.Lock()
|
||||
self._logger = logging.getLogger("mofox_bus.ws_server")
|
||||
|
||||
async def _handle_ws(self, request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
self._logger.info(f"WebSocket 连接打开: {request.remote}")
|
||||
|
||||
async with self._track_connection(ws):
|
||||
async for message in ws:
|
||||
if message.type in (WSMsgType.BINARY, WSMsgType.TEXT):
|
||||
envelopes = loads_messages(message.data)
|
||||
for env in envelopes:
|
||||
await self._handler(env)
|
||||
elif message.type == WSMsgType.ERROR:
|
||||
self._logger.warning(f"WebSocket 连接错误: {ws.exception()}")
|
||||
break
|
||||
|
||||
self._logger.info(f"WebSocket 连接关闭: {request.remote}")
|
||||
return ws
|
||||
|
||||
@asynccontextmanager
|
||||
async def _track_connection(self, ws: web.WebSocketResponse):
|
||||
async with self._lock:
|
||||
self._connections.add(ws)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
async with self._lock:
|
||||
self._connections.discard(ws)
|
||||
|
||||
async def broadcast(self, messages: Iterable[MessageEnvelope]) -> None:
|
||||
payload = dumps_messages(list(messages))
|
||||
async with self._lock:
|
||||
targets = list(self._connections)
|
||||
for ws in targets:
|
||||
await ws.send_bytes(payload)
|
||||
|
||||
def make_app(self) -> web.Application:
|
||||
return self._app
|
||||
@@ -1,93 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Literal, NotRequired, TypedDict, Required
|
||||
|
||||
MessageDirection = Literal["incoming", "outgoing"]
|
||||
|
||||
# ----------------------------
|
||||
# maim_message 风格的 TypedDict
|
||||
# ----------------------------
|
||||
|
||||
|
||||
class SegPayload(TypedDict, total=False):
|
||||
"""
|
||||
对齐 maim_message.Seg 的片段定义,使用纯 dict 便于 JSON 传输。
|
||||
"""
|
||||
|
||||
type: Required[str]
|
||||
data: Required[str | List["SegPayload"]]
|
||||
translated_data: NotRequired[str | List["SegPayload"]]
|
||||
|
||||
|
||||
class UserInfoPayload(TypedDict, total=False):
|
||||
platform: NotRequired[str]
|
||||
user_id: Required[str]
|
||||
user_nickname: NotRequired[str]
|
||||
user_cardname: NotRequired[str]
|
||||
user_avatar: NotRequired[str]
|
||||
|
||||
|
||||
class GroupInfoPayload(TypedDict, total=False):
|
||||
platform: NotRequired[str]
|
||||
group_id: Required[str]
|
||||
group_name: NotRequired[str]
|
||||
|
||||
|
||||
class FormatInfoPayload(TypedDict, total=False):
|
||||
content_format: NotRequired[List[str]]
|
||||
accept_format: NotRequired[List[str]]
|
||||
|
||||
|
||||
class TemplateInfoPayload(TypedDict, total=False):
|
||||
template_items: NotRequired[Dict[str, str]]
|
||||
template_name: NotRequired[Dict[str, str]]
|
||||
template_default: NotRequired[bool]
|
||||
|
||||
|
||||
class MessageInfoPayload(TypedDict, total=False):
|
||||
platform: Required[str]
|
||||
message_id: Required[str]
|
||||
time: NotRequired[float]
|
||||
group_info: NotRequired[GroupInfoPayload]
|
||||
user_info: NotRequired[UserInfoPayload]
|
||||
format_info: NotRequired[FormatInfoPayload]
|
||||
template_info: NotRequired[TemplateInfoPayload]
|
||||
additional_config: NotRequired[Dict[str, Any]]
|
||||
|
||||
# ----------------------------
|
||||
# MessageEnvelope
|
||||
# ----------------------------
|
||||
|
||||
|
||||
class MessageEnvelope(TypedDict, total=False):
|
||||
"""
|
||||
mofox-bus 传输层统一使用的消息信封。
|
||||
|
||||
- 采用 maim_message 风格:message_info + message_segment。
|
||||
"""
|
||||
|
||||
direction: MessageDirection
|
||||
message_info: Required[MessageInfoPayload]
|
||||
message_segment: Required[SegPayload] | List[SegPayload]
|
||||
raw_message: NotRequired[Any]
|
||||
raw_bytes: NotRequired[bytes]
|
||||
message_chain: NotRequired[List[SegPayload]] # seglist 的直观别名
|
||||
platform: NotRequired[str] # 快捷访问,等价于 message_info.platform
|
||||
message_id: NotRequired[str] # 快捷访问,等价于 message_info.message_id
|
||||
timestamp_ms: NotRequired[int]
|
||||
correlation_id: NotRequired[str]
|
||||
schema_version: NotRequired[int]
|
||||
metadata: NotRequired[Dict[str, Any]]
|
||||
|
||||
__all__ = [
|
||||
# maim_message style payloads
|
||||
"SegPayload",
|
||||
"UserInfoPayload",
|
||||
"GroupInfoPayload",
|
||||
"FormatInfoPayload",
|
||||
"TemplateInfoPayload",
|
||||
"MessageInfoPayload",
|
||||
# legacy content style
|
||||
"MessageDirection",
|
||||
"MessageEnvelope",
|
||||
]
|
||||
@@ -62,8 +62,8 @@ class BaseAdapter(MoFoxAdapterBase, ABC):
|
||||
self._config: Dict[str, Any] = {}
|
||||
self._health_check_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
# 标记是否在子进程中运行(由核心管理器传入 ProcessCoreSink 时自动生效)
|
||||
self._is_subprocess = isinstance(core_sink, ProcessCoreSink)
|
||||
# 标记是否在子进程中运行
|
||||
self._is_subprocess = False
|
||||
|
||||
@classmethod
|
||||
def from_process_queues(
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
Adapter 管理器
|
||||
|
||||
负责管理所有注册的适配器,支持子进程自动启动和生命周期管理。
|
||||
|
||||
重构说明(2025-11):
|
||||
- 使用 CoreSinkManager 统一管理 InProcessCoreSink 和 ProcessCoreSink
|
||||
- 根据适配器的 run_in_subprocess 属性自动选择 CoreSink 类型
|
||||
- 子进程适配器通过 CoreSinkManager 的通信队列与主进程交互
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -15,7 +20,6 @@ if TYPE_CHECKING:
|
||||
from src.plugin_system.base.base_adapter import BaseAdapter
|
||||
|
||||
from mofox_bus import ProcessCoreSinkServer
|
||||
from src.common.core_sink import get_core_sink
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("adapter_manager")
|
||||
@@ -34,6 +38,11 @@ def _adapter_process_entry(
|
||||
incoming_queue: mp.Queue,
|
||||
outgoing_queue: mp.Queue,
|
||||
):
|
||||
"""
|
||||
子进程适配器入口函数
|
||||
|
||||
在子进程中运行,创建 ProcessCoreSink 与主进程通信
|
||||
"""
|
||||
import asyncio
|
||||
import contextlib
|
||||
from mofox_bus import ProcessCoreSink
|
||||
@@ -44,9 +53,14 @@ def _adapter_process_entry(
|
||||
if plugin_info:
|
||||
plugin_cls = _load_class(plugin_info["module"], plugin_info["class"])
|
||||
plugin_instance = plugin_cls(plugin_info["plugin_dir"], plugin_info["metadata"])
|
||||
|
||||
# 创建 ProcessCoreSink 用于与主进程通信
|
||||
core_sink = ProcessCoreSink(to_core_queue=incoming_queue, from_core_queue=outgoing_queue)
|
||||
|
||||
# 创建并启动适配器
|
||||
adapter = adapter_cls(core_sink, plugin=plugin_instance)
|
||||
await adapter.start()
|
||||
|
||||
try:
|
||||
while not getattr(core_sink, "_closed", False):
|
||||
await asyncio.sleep(0.2)
|
||||
@@ -61,21 +75,23 @@ def _adapter_process_entry(
|
||||
|
||||
|
||||
class AdapterProcess:
|
||||
"""适配器子进程封装:管理子进程的生命周期与通信桥接"""
|
||||
"""
|
||||
适配器子进程封装:管理子进程的生命周期与通信桥接
|
||||
|
||||
使用 CoreSinkManager 创建通信队列,自动维护与子进程的消息通道
|
||||
"""
|
||||
|
||||
def __init__(self, adapter_cls: "type[BaseAdapter]", plugin, core_sink) -> None:
|
||||
def __init__(self, adapter_cls: "type[BaseAdapter]", plugin) -> None:
|
||||
self.adapter_cls = adapter_cls
|
||||
self.adapter_name = adapter_cls.adapter_name
|
||||
self.plugin = plugin
|
||||
self.process: mp.Process | None = None
|
||||
self._ctx = mp.get_context("spawn")
|
||||
self._incoming_queue: mp.Queue = self._ctx.Queue()
|
||||
self._outgoing_queue: mp.Queue = self._ctx.Queue()
|
||||
self._incoming_queue: mp.Queue | None = None
|
||||
self._outgoing_queue: mp.Queue | None = None
|
||||
self._bridge: ProcessCoreSinkServer | None = None
|
||||
self._core_sink = core_sink
|
||||
self._adapter_path: tuple[str, str] = (adapter_cls.__module__, adapter_cls.__name__)
|
||||
self._plugin_info = self._extract_plugin_info(plugin)
|
||||
self._outgoing_handler = None
|
||||
|
||||
@staticmethod
|
||||
def _extract_plugin_info(plugin) -> dict | None:
|
||||
@@ -88,37 +104,28 @@ class AdapterProcess:
|
||||
"metadata": getattr(plugin, "plugin_meta", None),
|
||||
}
|
||||
|
||||
def _make_outgoing_handler(self):
|
||||
async def _handler(envelope):
|
||||
if self._bridge:
|
||||
await self._bridge.push_outgoing(envelope)
|
||||
return _handler
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""启动适配器子进程"""
|
||||
try:
|
||||
logger.info(f"启动适配器子进程: {self.adapter_name}")
|
||||
self._bridge = ProcessCoreSinkServer(
|
||||
incoming_queue=self._incoming_queue,
|
||||
outgoing_queue=self._outgoing_queue,
|
||||
core_handler=self._core_sink.send,
|
||||
name=self.adapter_name,
|
||||
)
|
||||
self._bridge.start()
|
||||
if hasattr(self._core_sink, "set_outgoing_handler"):
|
||||
self._outgoing_handler = self._make_outgoing_handler()
|
||||
try:
|
||||
self._core_sink.set_outgoing_handler(self._outgoing_handler)
|
||||
except Exception:
|
||||
logger.exception("Failed to register outgoing bridge for %s", self.adapter_name)
|
||||
|
||||
# 从 CoreSinkManager 获取通信队列
|
||||
from src.common.core_sink_manager import get_core_sink_manager
|
||||
|
||||
manager = get_core_sink_manager()
|
||||
self._incoming_queue, self._outgoing_queue = manager.create_process_sink_queues(self.adapter_name)
|
||||
|
||||
# 启动子进程
|
||||
self.process = self._ctx.Process(
|
||||
target=_adapter_process_entry,
|
||||
args=(self._adapter_path, self._plugin_info, self._incoming_queue, self._outgoing_queue),
|
||||
name=f"{self.adapter_name}-proc",
|
||||
)
|
||||
self.process.start()
|
||||
|
||||
logger.info(f"启动适配器子进程 {self.adapter_name} (PID: {self.process.pid})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动适配器子进程 {self.adapter_name} 失败: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -127,26 +134,31 @@ class AdapterProcess:
|
||||
"""停止适配器子进程"""
|
||||
if not self.process:
|
||||
return
|
||||
|
||||
logger.info(f"停止适配器子进程: {self.adapter_name} (PID: {self.process.pid})")
|
||||
|
||||
try:
|
||||
remover = getattr(self._core_sink, "remove_outgoing_handler", None)
|
||||
if callable(remover) and self._outgoing_handler:
|
||||
try:
|
||||
remover(self._outgoing_handler)
|
||||
except Exception:
|
||||
logger.exception(f"移除 {self.adapter_name} 的 outgoing bridge 失败")
|
||||
if self._bridge:
|
||||
await self._bridge.close()
|
||||
# 从 CoreSinkManager 移除通信队列
|
||||
from src.common.core_sink_manager import get_core_sink_manager
|
||||
|
||||
manager = get_core_sink_manager()
|
||||
manager.remove_process_sink(self.adapter_name)
|
||||
|
||||
# 等待子进程结束
|
||||
if self.process.is_alive():
|
||||
self.process.join(timeout=5.0)
|
||||
|
||||
if self.process.is_alive():
|
||||
logger.warning(f"适配器 {self.adapter_name} 未能及时停止,强制终止中")
|
||||
self.process.terminate()
|
||||
self.process.join()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"停止适配器子进程 {self.adapter_name} 时发生错误: {e}", exc_info=True)
|
||||
finally:
|
||||
self.process = None
|
||||
self._incoming_queue = None
|
||||
self._outgoing_queue = None
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""适配器是否正在运行"""
|
||||
@@ -155,7 +167,13 @@ class AdapterProcess:
|
||||
return self.process.is_alive()
|
||||
|
||||
class AdapterManager:
|
||||
"""适配器管理器"""
|
||||
"""
|
||||
适配器管理器
|
||||
|
||||
负责管理所有注册的适配器,根据 run_in_subprocess 属性自动选择:
|
||||
- run_in_subprocess=True: 在子进程中运行,使用 ProcessCoreSink
|
||||
- run_in_subprocess=False: 在主进程中运行,使用 InProcessCoreSink
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 注册信息:name -> (adapter class, plugin instance | None)
|
||||
@@ -178,14 +196,26 @@ class AdapterManager:
|
||||
|
||||
self._adapter_defs[adapter_name] = (adapter_cls, plugin)
|
||||
adapter_version = getattr(adapter_cls, 'adapter_version', 'unknown')
|
||||
logger.info(f"注册适配器: {adapter_name} v{adapter_version}")
|
||||
run_in_subprocess = getattr(adapter_cls, 'run_in_subprocess', False)
|
||||
|
||||
logger.info(
|
||||
f"注册适配器: {adapter_name} v{adapter_version} "
|
||||
f"(子进程: {'是' if run_in_subprocess else '否'})"
|
||||
)
|
||||
|
||||
async def start_adapter(self, adapter_name: str) -> bool:
|
||||
"""启动指定适配器"""
|
||||
"""
|
||||
启动指定适配器
|
||||
|
||||
根据适配器的 run_in_subprocess 属性自动选择:
|
||||
- True: 创建子进程,使用 ProcessCoreSink
|
||||
- False: 在当前进程,使用 InProcessCoreSink
|
||||
"""
|
||||
definition = self._adapter_defs.get(adapter_name)
|
||||
if not definition:
|
||||
logger.error(f"适配器 {adapter_name} 未注册")
|
||||
return False
|
||||
|
||||
adapter_cls, plugin = definition
|
||||
run_in_subprocess = getattr(adapter_cls, "run_in_subprocess", False)
|
||||
|
||||
@@ -193,15 +223,14 @@ class AdapterManager:
|
||||
return await self._start_adapter_subprocess(adapter_name, adapter_cls, plugin)
|
||||
return await self._start_adapter_in_process(adapter_name, adapter_cls, plugin)
|
||||
|
||||
async def _start_adapter_subprocess(self, adapter_name: str, adapter_cls: type[BaseAdapter], plugin) -> bool:
|
||||
"""在子进程中启动适配器"""
|
||||
try:
|
||||
core_sink = get_core_sink()
|
||||
except Exception as e:
|
||||
logger.error(f"无法获取 core_sink,启动子进程 {adapter_name} 失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
adapter_process = AdapterProcess(adapter_cls, plugin, core_sink)
|
||||
async def _start_adapter_subprocess(
|
||||
self,
|
||||
adapter_name: str,
|
||||
adapter_cls: type[BaseAdapter],
|
||||
plugin
|
||||
) -> bool:
|
||||
"""在子进程中启动适配器(使用 ProcessCoreSink)"""
|
||||
adapter_process = AdapterProcess(adapter_cls, plugin)
|
||||
success = await adapter_process.start()
|
||||
|
||||
if success:
|
||||
@@ -209,15 +238,25 @@ class AdapterManager:
|
||||
|
||||
return success
|
||||
|
||||
async def _start_adapter_in_process(self, adapter_name: str, adapter_cls: type[BaseAdapter], plugin) -> bool:
|
||||
"""在当前进程中启动适配器"""
|
||||
async def _start_adapter_in_process(
|
||||
self,
|
||||
adapter_name: str,
|
||||
adapter_cls: type[BaseAdapter],
|
||||
plugin
|
||||
) -> bool:
|
||||
"""在当前进程中启动适配器(使用 InProcessCoreSink)"""
|
||||
try:
|
||||
core_sink = get_core_sink()
|
||||
# 从 CoreSinkManager 获取 InProcessCoreSink
|
||||
from src.common.core_sink_manager import get_core_sink_manager
|
||||
|
||||
core_sink = get_core_sink_manager().get_in_process_sink()
|
||||
adapter = adapter_cls(core_sink, plugin=plugin) # type: ignore[call-arg]
|
||||
await adapter.start()
|
||||
|
||||
self._in_process_adapters[adapter_name] = adapter
|
||||
logger.info(f"适配器 {adapter_name} 已在当前进程启动")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动适配器 {adapter_name} 失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@@ -17,12 +17,13 @@ from typing import Any, ClassVar, Dict, List, Optional
|
||||
import orjson
|
||||
import websockets
|
||||
|
||||
from mofox_bus import CoreMessageSink, MessageEnvelope, WebSocketAdapterOptions
|
||||
from mofox_bus import CoreSink, MessageEnvelope, WebSocketAdapterOptions
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import register_plugin
|
||||
from src.plugin_system.base import BaseAdapter, BasePlugin
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from .src.handlers import utils as handler_utils
|
||||
from .src.handlers.to_core.message_handler import MessageHandler
|
||||
from .src.handlers.to_core.notice_handler import NoticeHandler
|
||||
from .src.handlers.to_core.meta_event_handler import MetaEventHandler
|
||||
@@ -41,22 +42,16 @@ class NapcatAdapter(BaseAdapter):
|
||||
platform = "qq"
|
||||
|
||||
run_in_subprocess = False
|
||||
subprocess_entry = None
|
||||
|
||||
def __init__(self, core_sink: CoreMessageSink, plugin: Optional[BasePlugin] = None):
|
||||
def __init__(self, core_sink: CoreSink, plugin: Optional[BasePlugin] = None):
|
||||
"""初始化 Napcat 适配器"""
|
||||
# 从插件配置读取 WebSocket URL
|
||||
if plugin:
|
||||
mode = config_api.get_plugin_config(plugin.config, "napcat_server.mode", "reverse")
|
||||
host = config_api.get_plugin_config(plugin.config, "napcat_server.host", "localhost")
|
||||
port = config_api.get_plugin_config(plugin.config, "napcat_server.port", 8095)
|
||||
url = config_api.get_plugin_config(plugin.config, "napcat_server.url", "")
|
||||
access_token = config_api.get_plugin_config(plugin.config, "napcat_server.access_token", "")
|
||||
|
||||
if mode == "forward" and url:
|
||||
ws_url = url
|
||||
else:
|
||||
ws_url = f"ws://{host}:{port}"
|
||||
|
||||
ws_url = f"ws://{host}:{port}"
|
||||
|
||||
headers = {}
|
||||
if access_token:
|
||||
@@ -69,8 +64,6 @@ class NapcatAdapter(BaseAdapter):
|
||||
transport = WebSocketAdapterOptions(
|
||||
url=ws_url,
|
||||
headers=headers if headers else None,
|
||||
incoming_parser=self._parse_napcat_message,
|
||||
outgoing_encoder=self._encode_napcat_response,
|
||||
)
|
||||
|
||||
super().__init__(core_sink, plugin=plugin, transport=transport)
|
||||
@@ -89,6 +82,9 @@ class NapcatAdapter(BaseAdapter):
|
||||
# 注意:_ws 继承自 BaseAdapter,是 WebSocketLike 协议类型
|
||||
self._napcat_ws = None # 可选的额外连接引用
|
||||
|
||||
# 注册 utils 内部使用的适配器实例,便于工具方法自动获取 WS
|
||||
handler_utils.register_adapter(self)
|
||||
|
||||
async def on_adapter_loaded(self) -> None:
|
||||
"""适配器加载时的初始化"""
|
||||
logger.info("Napcat 适配器正在启动...")
|
||||
@@ -114,22 +110,6 @@ class NapcatAdapter(BaseAdapter):
|
||||
|
||||
logger.info("Napcat 适配器已关闭")
|
||||
|
||||
def _parse_napcat_message(self, raw: str | bytes) -> Any:
|
||||
"""解析 Napcat/OneBot 消息"""
|
||||
try:
|
||||
if isinstance(raw, bytes):
|
||||
data = orjson.loads(raw)
|
||||
else:
|
||||
data = orjson.loads(raw)
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Napcat 消息失败: {e}")
|
||||
raise
|
||||
|
||||
def _encode_napcat_response(self, envelope: MessageEnvelope) -> bytes:
|
||||
"""编码响应消息为 Napcat 格式(暂未使用,通过 API 调用发送)"""
|
||||
return orjson.dumps(envelope)
|
||||
|
||||
async def from_platform_message(self, raw: Dict[str, Any]) -> MessageEnvelope: # type: ignore[override]
|
||||
"""
|
||||
将 Napcat/OneBot 原始消息转换为 MessageEnvelope
|
||||
@@ -178,20 +158,6 @@ class NapcatAdapter(BaseAdapter):
|
||||
"""
|
||||
await self.send_handler.handle_message(envelope)
|
||||
|
||||
def _create_empty_envelope(self) -> MessageEnvelope: # type: ignore[return]
|
||||
"""创建一个空的消息信封(用于不需要处理的事件)"""
|
||||
import time
|
||||
return {
|
||||
"direction": "incoming",
|
||||
"message_info": {
|
||||
"platform": self.platform,
|
||||
"message_id": str(uuid.uuid4()),
|
||||
"time": time.time(),
|
||||
},
|
||||
"message_segment": {"type": "text", "data": "[系统事件]"},
|
||||
"timestamp_ms": int(time.time() * 1000),
|
||||
}
|
||||
|
||||
async def send_napcat_api(self, action: str, params: Dict[str, Any], timeout: float = 30.0) -> Dict[str, Any]:
|
||||
"""
|
||||
发送 Napcat API 请求并等待响应
|
||||
@@ -260,18 +226,12 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
config_schema: ClassVar[dict] = {
|
||||
"plugin": {
|
||||
"name": {"type": str, "default": "napcat_adapter_plugin"},
|
||||
"version": {"type": str, "default": "2.0.0"},
|
||||
"version": {"type": str, "default": "1.0.0"},
|
||||
"enabled": {"type": bool, "default": True},
|
||||
},
|
||||
"napcat_server": {
|
||||
"mode": {
|
||||
"type": str,
|
||||
"default": "reverse",
|
||||
"description": "连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)",
|
||||
},
|
||||
"host": {"type": str, "default": "localhost"},
|
||||
"port": {"type": int, "default": 8095},
|
||||
"url": {"type": str, "default": "", "description": "正向连接时的完整URL"},
|
||||
"access_token": {"type": str, "default": ""},
|
||||
},
|
||||
"features": {
|
||||
@@ -284,34 +244,18 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, plugin_dir: str = "", metadata: Any = None):
|
||||
# 如果没有提供参数,创建一个默认的元数据
|
||||
if metadata is None:
|
||||
from src.plugin_system.base.plugin_metadata import PluginMetadata
|
||||
metadata = PluginMetadata(
|
||||
name=self.plugin_name,
|
||||
version=self.plugin_version,
|
||||
author=self.plugin_author,
|
||||
description=self.plugin_description,
|
||||
usage="",
|
||||
dependencies=[],
|
||||
python_dependencies=[],
|
||||
)
|
||||
|
||||
if not plugin_dir:
|
||||
from pathlib import Path
|
||||
plugin_dir = str(Path(__file__).parent)
|
||||
|
||||
super().__init__(plugin_dir, metadata)
|
||||
def __init__(self):
|
||||
self._adapter: Optional[NapcatAdapter] = None
|
||||
|
||||
async def on_plugin_loaded(self):
|
||||
"""插件加载时启动适配器"""
|
||||
logger.info("Napcat 适配器插件正在加载...")
|
||||
|
||||
# 获取核心 Sink
|
||||
from src.common.core_sink import get_core_sink
|
||||
core_sink = get_core_sink()
|
||||
# 从 CoreSinkManager 获取 InProcessCoreSink
|
||||
from src.common.core_sink_manager import get_core_sink_manager
|
||||
|
||||
core_sink_manager = get_core_sink_manager()
|
||||
core_sink = core_sink_manager.get_in_process_sink()
|
||||
|
||||
# 创建并启动适配器
|
||||
self._adapter = NapcatAdapter(core_sink, plugin=self)
|
||||
|
||||
@@ -5,11 +5,28 @@ from __future__ import annotations
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from mofox_bus import MessageBuilder
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
from mofox_bus import (
|
||||
MessageEnvelope,
|
||||
SegPayload,
|
||||
MessageInfoPayload,
|
||||
UserInfoPayload,
|
||||
GroupInfoPayload,
|
||||
)
|
||||
|
||||
from ...event_models import ACCEPT_FORMAT, QQ_FACE
|
||||
from ..utils import (
|
||||
get_group_info,
|
||||
get_image_base64,
|
||||
get_self_info,
|
||||
get_member_info,
|
||||
get_message_detail,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...plugin import NapcatAdapter
|
||||
from ....plugin import NapcatAdapter
|
||||
|
||||
logger = get_logger("napcat_adapter.message_handler")
|
||||
|
||||
@@ -28,99 +45,190 @@ class MessageHandler:
|
||||
async def handle_raw_message(self, raw: Dict[str, Any]):
|
||||
"""
|
||||
处理原始消息并转换为 MessageEnvelope
|
||||
|
||||
|
||||
Args:
|
||||
raw: OneBot 原始消息数据
|
||||
|
||||
|
||||
Returns:
|
||||
MessageEnvelope (dict)
|
||||
"""
|
||||
from mofox_bus import MessageEnvelope, SegPayload, MessageInfoPayload, UserInfoPayload, GroupInfoPayload
|
||||
|
||||
message_type = raw.get("message_type")
|
||||
message_id = str(raw.get("message_id", ""))
|
||||
message_time = time.time()
|
||||
|
||||
|
||||
msg_builder = MessageBuilder()
|
||||
|
||||
# 构造用户信息
|
||||
sender_info = raw.get("sender", {})
|
||||
user_info: UserInfoPayload = {
|
||||
"platform": "qq",
|
||||
"user_id": str(sender_info.get("user_id", "")),
|
||||
"user_nickname": sender_info.get("nickname", ""),
|
||||
"user_cardname": sender_info.get("card", ""),
|
||||
"user_avatar": sender_info.get("avatar", ""),
|
||||
}
|
||||
|
||||
(
|
||||
msg_builder.direction("incoming")
|
||||
.message_id(message_id)
|
||||
.timestamp_ms(int(message_time * 1000))
|
||||
.from_user(
|
||||
user_id=str(sender_info.get("user_id", "")),
|
||||
platform="qq",
|
||||
nickname=sender_info.get("nickname", ""),
|
||||
cardname=sender_info.get("card", ""),
|
||||
user_avatar=sender_info.get("avatar", ""),
|
||||
)
|
||||
)
|
||||
|
||||
# 构造群组信息(如果是群消息)
|
||||
group_info: Optional[GroupInfoPayload] = None
|
||||
if message_type == "group":
|
||||
group_id = raw.get("group_id")
|
||||
if group_id:
|
||||
group_info = {
|
||||
"platform": "qq",
|
||||
"group_id": str(group_id),
|
||||
"group_name": "", # 可以通过 API 获取
|
||||
}
|
||||
fetched_group_info = await get_group_info(group_id)
|
||||
(
|
||||
msg_builder.from_group(
|
||||
group_id=str(group_id),
|
||||
platform="qq",
|
||||
name=(
|
||||
fetched_group_info.get("group_name", "")
|
||||
if fetched_group_info
|
||||
else raw.get("group_name", "")
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# 解析消息段
|
||||
message_segments = raw.get("message", [])
|
||||
seg_list: List[SegPayload] = []
|
||||
|
||||
for seg in message_segments:
|
||||
seg_type = seg.get("type", "")
|
||||
seg_data = seg.get("data", {})
|
||||
|
||||
# 转换为 SegPayload
|
||||
if seg_type == "text":
|
||||
seg_list.append({
|
||||
"type": "text",
|
||||
"data": seg_data.get("text", "")
|
||||
})
|
||||
elif seg_type == "image":
|
||||
# 这里需要下载图片并转换为 base64(简化版本)
|
||||
seg_list.append({
|
||||
"type": "image",
|
||||
"data": seg_data.get("url", "") # 实际应该转换为 base64
|
||||
})
|
||||
elif seg_type == "at":
|
||||
seg_list.append({
|
||||
"type": "at",
|
||||
"data": f"{seg_data.get('qq', '')}"
|
||||
})
|
||||
# 其他消息类型...
|
||||
|
||||
# 构造 MessageInfoPayload
|
||||
message_info = {
|
||||
"platform": "qq",
|
||||
"message_id": message_id,
|
||||
"time": message_time,
|
||||
"user_info": user_info,
|
||||
"format_info": {
|
||||
"content_format": ["text", "image"], # 根据实际消息类型设置
|
||||
"accept_format": ["text", "image", "emoji", "voice"],
|
||||
},
|
||||
}
|
||||
|
||||
# 添加群组信息(如果存在)
|
||||
if group_info:
|
||||
message_info["group_info"] = group_info
|
||||
for segment in message_segments:
|
||||
seg_message = await self.handle_single_segment(segment, raw)
|
||||
if seg_message:
|
||||
seg_list.append(seg_message)
|
||||
|
||||
# 构造 MessageEnvelope
|
||||
envelope = {
|
||||
"direction": "incoming",
|
||||
"message_info": message_info,
|
||||
"message_segment": {"type": "seglist", "data": seg_list} if len(seg_list) > 1 else (seg_list[0] if seg_list else {"type": "text", "data": ""}),
|
||||
"raw_message": raw.get("raw_message", ""),
|
||||
"platform": "qq",
|
||||
"message_id": message_id,
|
||||
"timestamp_ms": int(message_time * 1000),
|
||||
}
|
||||
msg_builder.format_info(
|
||||
content_format=[seg["type"] for seg in seg_list],
|
||||
accept_format=ACCEPT_FORMAT,
|
||||
)
|
||||
|
||||
return envelope
|
||||
return msg_builder.build()
|
||||
|
||||
async def handle_single_segment(
|
||||
self, segment: dict, raw_message: dict, in_reply: bool = False
|
||||
) -> SegPayload | List[SegPayload] | None:
|
||||
"""
|
||||
处理单一消息段并转换为 MessageEnvelope
|
||||
|
||||
Args:
|
||||
segment: 单一原始消息段
|
||||
raw_message: 完整的原始消息数据
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Returns:
|
||||
SegPayload | List[SegPayload] | None
|
||||
"""
|
||||
seg_type = segment.get("type")
|
||||
seg_data: dict = segment.get("data", {})
|
||||
match seg_type:
|
||||
case "text":
|
||||
return {"type": "text", "data": seg_data.get("text", "")}
|
||||
case "image":
|
||||
image_sub_type = seg_data.get("sub_type")
|
||||
try:
|
||||
image_base64 = await get_image_base64(seg_data.get("url", ""))
|
||||
except Exception as e:
|
||||
logger.error(f"图片消息处理失败: {str(e)}")
|
||||
return None
|
||||
if image_sub_type == 0:
|
||||
"""这部分认为是图片"""
|
||||
return {"type": "image", "data": image_base64}
|
||||
elif image_sub_type not in [4, 9]:
|
||||
"""这部分认为是表情包"""
|
||||
return {"type": "emoji", "data": image_base64}
|
||||
else:
|
||||
logger.warning(f"不支持的图片子类型:{image_sub_type}")
|
||||
return None
|
||||
case "face":
|
||||
message_data: dict = segment.get("data", {})
|
||||
face_raw_id: str = str(message_data.get("id"))
|
||||
if face_raw_id in QQ_FACE:
|
||||
face_content: str = QQ_FACE.get(face_raw_id, "[未知表情]")
|
||||
return {"type": "text", "data": face_content}
|
||||
else:
|
||||
logger.warning(f"不支持的表情:{face_raw_id}")
|
||||
return None
|
||||
case "at":
|
||||
if seg_data:
|
||||
qq_id = seg_data.get("qq")
|
||||
self_id = raw_message.get("self_id")
|
||||
group_id = raw_message.get("group_id")
|
||||
if str(self_id) == str(qq_id):
|
||||
logger.debug("机器人被at")
|
||||
self_info = await get_self_info()
|
||||
if self_info:
|
||||
# 返回包含昵称和用户ID的at格式,便于后续处理
|
||||
return {
|
||||
"type": "at",
|
||||
"data": f"{self_info.get('nickname')}:{self_info.get('user_id')}",
|
||||
}
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
if qq_id and group_id:
|
||||
member_info = await get_member_info(
|
||||
group_id=group_id, user_id=qq_id
|
||||
)
|
||||
if member_info:
|
||||
# 返回包含昵称和用户ID的at格式,便于后续处理
|
||||
return {
|
||||
"type": "at",
|
||||
"data": f"{member_info.get('nickname')}:{member_info.get('user_id')}",
|
||||
}
|
||||
else:
|
||||
return None
|
||||
case "emoji":
|
||||
seg_data = segment.get("id", "")
|
||||
case "reply":
|
||||
if not in_reply:
|
||||
message_id = None
|
||||
if seg_data:
|
||||
message_id = seg_data.get("id")
|
||||
else:
|
||||
return None
|
||||
message_detail = await get_message_detail(message_id)
|
||||
if not message_detail:
|
||||
logger.warning("获取被引用的消息详情失败")
|
||||
return None
|
||||
reply_message = await self.handle_single_segment(
|
||||
message_detail, raw_message, in_reply=True
|
||||
)
|
||||
if reply_message is None:
|
||||
reply_message = [
|
||||
{"type": "text", "data": "[无法获取被引用的消息]"}
|
||||
]
|
||||
sender_info: dict = message_detail.get("sender", {})
|
||||
sender_nickname: str = sender_info.get("nickname", "")
|
||||
sender_id = sender_info.get("user_id")
|
||||
seg_message: List[SegPayload] = []
|
||||
if not sender_nickname:
|
||||
logger.warning("无法获取被引用的人的昵称,返回默认值")
|
||||
seg_message.append(
|
||||
{
|
||||
"type": "text",
|
||||
"data": f"[回复<未知用户>:{reply_message}],说:",
|
||||
}
|
||||
)
|
||||
else:
|
||||
if sender_id:
|
||||
seg_message.append(
|
||||
{
|
||||
"type": "text",
|
||||
"data": f"[回复<{sender_nickname}({sender_id})>:{reply_message}],说:",
|
||||
}
|
||||
)
|
||||
else:
|
||||
seg_message.append(
|
||||
{
|
||||
"type": "text",
|
||||
"data": f"[回复<{sender_nickname}>:{reply_message}],说:",
|
||||
}
|
||||
)
|
||||
return seg_message
|
||||
case "voice":
|
||||
seg_data = segment.get("url", "")
|
||||
case _:
|
||||
logger.warning(f"Unsupported segment type: {seg_type}")
|
||||
|
||||
361
src/plugins/built_in/NEW_napcat_adapter/src/handlers/utils.py
Normal file
361
src/plugins/built_in/NEW_napcat_adapter/src/handlers/utils.py
Normal file
@@ -0,0 +1,361 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import ssl
|
||||
import time
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
|
||||
import orjson
|
||||
import urllib3
|
||||
from PIL import Image
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...plugin import NapcatAdapter
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
# 简单的缓存实现,通过 JSON 文件实现磁盘一价存储
|
||||
_CACHE_FILE = Path(__file__).resolve().parent / "napcat_cache.json"
|
||||
_CACHE_LOCK = asyncio.Lock()
|
||||
_CACHE: Dict[str, Dict[str, Dict[str, Any]]] = {
|
||||
"group_info": {},
|
||||
"group_detail_info": {},
|
||||
"member_info": {},
|
||||
"stranger_info": {},
|
||||
"self_info": {},
|
||||
}
|
||||
|
||||
# 各类信息的 TTL 缓存过期时间设置
|
||||
GROUP_INFO_TTL = 300 # 5 min
|
||||
GROUP_DETAIL_TTL = 300
|
||||
MEMBER_INFO_TTL = 180
|
||||
STRANGER_INFO_TTL = 300
|
||||
SELF_INFO_TTL = 300
|
||||
|
||||
_adapter_ref: weakref.ReferenceType["NapcatAdapter"] | None = None
|
||||
|
||||
|
||||
def register_adapter(adapter: "NapcatAdapter") -> None:
|
||||
"""注册 NapcatAdapter 实例,以便 utils 模块可以获取 WebSocket"""
|
||||
global _adapter_ref
|
||||
_adapter_ref = weakref.ref(adapter)
|
||||
logger.debug("Napcat adapter registered in utils for websocket access")
|
||||
|
||||
|
||||
def _load_cache_from_disk() -> None:
|
||||
if not _CACHE_FILE.exists():
|
||||
return
|
||||
try:
|
||||
data = orjson.loads(_CACHE_FILE.read_bytes())
|
||||
if isinstance(data, dict):
|
||||
for key, section in _CACHE.items():
|
||||
cached_section = data.get(key)
|
||||
if isinstance(cached_section, dict):
|
||||
section.update(cached_section)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load napcat cache: {e}")
|
||||
|
||||
|
||||
def _save_cache_to_disk_locked() -> None:
|
||||
"""重要提示:不要在持有 _CACHE_LOCK 时调用此函数"""
|
||||
_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_CACHE_FILE.write_bytes(orjson.dumps(_CACHE))
|
||||
|
||||
|
||||
async def _get_cached(section: str, key: str, ttl: int) -> Any | None:
|
||||
now = time.time()
|
||||
async with _CACHE_LOCK:
|
||||
entry = _CACHE.get(section, {}).get(key)
|
||||
if not entry:
|
||||
return None
|
||||
ts = entry.get("ts", 0)
|
||||
if ts and now - ts <= ttl:
|
||||
return entry.get("data")
|
||||
_CACHE.get(section, {}).pop(key, None)
|
||||
try:
|
||||
_save_cache_to_disk_locked()
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
async def _set_cached(section: str, key: str, data: Any) -> None:
|
||||
async with _CACHE_LOCK:
|
||||
_CACHE.setdefault(section, {})[key] = {"data": data, "ts": time.time()}
|
||||
try:
|
||||
_save_cache_to_disk_locked()
|
||||
except Exception:
|
||||
logger.debug("Write napcat cache failed", exc_info=True)
|
||||
|
||||
|
||||
def _get_adapter(adapter: "NapcatAdapter | None" = None) -> "NapcatAdapter":
|
||||
target = adapter
|
||||
if target is None and _adapter_ref:
|
||||
target = _adapter_ref()
|
||||
if target is None:
|
||||
raise RuntimeError("NapcatAdapter 未注册,请确保已调用 utils.register_adapter 注册")
|
||||
return target
|
||||
|
||||
|
||||
async def _call_adapter_api(
|
||||
action: str,
|
||||
params: Dict[str, Any],
|
||||
adapter: "NapcatAdapter | None" = None,
|
||||
timeout: float = 30.0,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""统一通过 adapter 发送和接收 API 调用"""
|
||||
try:
|
||||
target = _get_adapter(adapter)
|
||||
# 确保 WS 已连接
|
||||
target.get_ws_connection()
|
||||
except Exception as e: # pragma: no cover - 难以在单元测试中查看
|
||||
logger.error(f"WebSocket 未准备好,无法调用 API: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
return await target.send_napcat_api(action, params, timeout=timeout)
|
||||
except Exception as e:
|
||||
logger.error(f"{action} 调用失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# 加载缓存到内存一次,避免在每次调用缓存时重复加载
|
||||
_load_cache_from_disk()
|
||||
|
||||
|
||||
class SSLAdapter(urllib3.PoolManager):
|
||||
def __init__(self, *args, **kwargs):
|
||||
context = ssl.create_default_context()
|
||||
context.set_ciphers("DEFAULT@SECLEVEL=1")
|
||||
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
kwargs["ssl_context"] = context
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
async def get_group_info(
|
||||
group_id: int,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
force_refresh: bool = False,
|
||||
adapter: "NapcatAdapter | None" = None,
|
||||
) -> dict | None:
|
||||
"""
|
||||
获取群组基本信息
|
||||
|
||||
返回值可能是None,需要调用方检查空值
|
||||
"""
|
||||
logger.debug("获取群组基本信息中")
|
||||
cache_key = str(group_id)
|
||||
if use_cache and not force_refresh:
|
||||
cached = await _get_cached("group_info", cache_key, GROUP_INFO_TTL)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
socket_response = await _call_adapter_api(
|
||||
"get_group_info",
|
||||
{"group_id": group_id},
|
||||
adapter=adapter,
|
||||
)
|
||||
data = socket_response.get("data") if socket_response else None
|
||||
if data is not None and use_cache:
|
||||
await _set_cached("group_info", cache_key, data)
|
||||
return data
|
||||
|
||||
|
||||
async def get_group_detail_info(
|
||||
group_id: int,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
force_refresh: bool = False,
|
||||
adapter: "NapcatAdapter | None" = None,
|
||||
) -> dict | None:
|
||||
"""
|
||||
获取群组详细信息
|
||||
|
||||
返回值可能是None,需要调用方检查空值
|
||||
"""
|
||||
logger.debug("获取群组详细信息中")
|
||||
cache_key = str(group_id)
|
||||
if use_cache and not force_refresh:
|
||||
cached = await _get_cached("group_detail_info", cache_key, GROUP_DETAIL_TTL)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
socket_response = await _call_adapter_api(
|
||||
"get_group_detail_info",
|
||||
{"group_id": group_id},
|
||||
adapter=adapter,
|
||||
)
|
||||
data = socket_response.get("data") if socket_response else None
|
||||
if data is not None and use_cache:
|
||||
await _set_cached("group_detail_info", cache_key, data)
|
||||
return data
|
||||
|
||||
|
||||
async def get_member_info(
|
||||
group_id: int,
|
||||
user_id: int,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
force_refresh: bool = False,
|
||||
adapter: "NapcatAdapter | None" = None,
|
||||
) -> dict | None:
|
||||
"""
|
||||
获取群组成员信息
|
||||
|
||||
返回值可能是None,需要调用方检查空值
|
||||
"""
|
||||
logger.debug("获取群组成员信息中")
|
||||
cache_key = f"{group_id}:{user_id}"
|
||||
if use_cache and not force_refresh:
|
||||
cached = await _get_cached("member_info", cache_key, MEMBER_INFO_TTL)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
socket_response = await _call_adapter_api(
|
||||
"get_group_member_info",
|
||||
{"group_id": group_id, "user_id": user_id, "no_cache": True},
|
||||
adapter=adapter,
|
||||
)
|
||||
data = socket_response.get("data") if socket_response else None
|
||||
if data is not None and use_cache:
|
||||
await _set_cached("member_info", cache_key, data)
|
||||
return data
|
||||
|
||||
|
||||
async def get_image_base64(url: str) -> str:
|
||||
# sourcery skip: raise-specific-error
|
||||
"""下载图片/视频并返回Base64"""
|
||||
logger.debug(f"下载图片: {url}")
|
||||
http = SSLAdapter()
|
||||
try:
|
||||
response = http.request("GET", url, timeout=10)
|
||||
if response.status != 200:
|
||||
raise Exception(f"HTTP Error: {response.status}")
|
||||
image_bytes = response.data
|
||||
return base64.b64encode(image_bytes).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"图片下载失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def convert_image_to_gif(image_base64: str) -> str:
|
||||
# sourcery skip: extract-method
|
||||
"""
|
||||
将Base64编码的图片转换为GIF格式
|
||||
Parameters:
|
||||
image_base64: str: Base64编码的图片数据
|
||||
Returns:
|
||||
str: Base64编码的GIF图片数据
|
||||
"""
|
||||
logger.debug("转换图片为GIF格式")
|
||||
try:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
output_buffer = io.BytesIO()
|
||||
image.save(output_buffer, format="GIF")
|
||||
output_buffer.seek(0)
|
||||
return base64.b64encode(output_buffer.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换为GIF失败: {str(e)}")
|
||||
return image_base64
|
||||
|
||||
|
||||
async def get_self_info(
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
force_refresh: bool = False,
|
||||
adapter: "NapcatAdapter | None" = None,
|
||||
) -> dict | None:
|
||||
"""
|
||||
获取机器人信息
|
||||
"""
|
||||
logger.debug("获取机器人信息中")
|
||||
cache_key = "self"
|
||||
if use_cache and not force_refresh:
|
||||
cached = await _get_cached("self_info", cache_key, SELF_INFO_TTL)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
response = await _call_adapter_api("get_login_info", {}, adapter=adapter)
|
||||
data = response.get("data") if response else None
|
||||
if data is not None and use_cache:
|
||||
await _set_cached("self_info", cache_key, data)
|
||||
return data
|
||||
|
||||
|
||||
def get_image_format(raw_data: str) -> str:
|
||||
"""
|
||||
从Base64编码的数据中确定图片的格式类型
|
||||
Parameters:
|
||||
raw_data: str: Base64编码的图片数据
|
||||
Returns:
|
||||
format: str: 图片的格式类型,如 'jpeg', 'png', 'gif'等
|
||||
"""
|
||||
image_bytes = base64.b64decode(raw_data)
|
||||
return Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
|
||||
|
||||
async def get_stranger_info(
|
||||
user_id: int,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
force_refresh: bool = False,
|
||||
adapter: "NapcatAdapter | None" = None,
|
||||
) -> dict | None:
|
||||
"""
|
||||
获取陌生人信息
|
||||
"""
|
||||
logger.debug("获取陌生人信息中")
|
||||
cache_key = str(user_id)
|
||||
if use_cache and not force_refresh:
|
||||
cached = await _get_cached("stranger_info", cache_key, STRANGER_INFO_TTL)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
response = await _call_adapter_api("get_stranger_info", {"user_id": user_id}, adapter=adapter)
|
||||
data = response.get("data") if response else None
|
||||
if data is not None and use_cache:
|
||||
await _set_cached("stranger_info", cache_key, data)
|
||||
return data
|
||||
|
||||
|
||||
async def get_message_detail(
|
||||
message_id: Union[str, int],
|
||||
*,
|
||||
adapter: "NapcatAdapter | None" = None,
|
||||
) -> dict | None:
|
||||
"""
|
||||
获取消息详情,仅作为参考
|
||||
"""
|
||||
logger.debug("获取消息详情中")
|
||||
response = await _call_adapter_api(
|
||||
"get_msg",
|
||||
{"message_id": message_id},
|
||||
adapter=adapter,
|
||||
timeout=30,
|
||||
)
|
||||
return response.get("data") if response else None
|
||||
|
||||
|
||||
async def get_record_detail(
|
||||
file: str,
|
||||
file_id: Optional[str] = None,
|
||||
*,
|
||||
adapter: "NapcatAdapter | None" = None,
|
||||
) -> dict | None:
|
||||
"""
|
||||
获取语音信息详情
|
||||
"""
|
||||
logger.debug("获取语音信息详情中")
|
||||
response = await _call_adapter_api(
|
||||
"get_record",
|
||||
{"file": file, "file_id": file_id, "out_format": "wav"},
|
||||
adapter=adapter,
|
||||
timeout=30,
|
||||
)
|
||||
return response.get("data") if response else None
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
import time
|
||||
import websockets as Server
|
||||
import uuid
|
||||
from mofox_bus import (
|
||||
from maim_message import (
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
Seg,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[inner]
|
||||
version = "7.8.2"
|
||||
version = "7.8.3"
|
||||
|
||||
#----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
|
||||
#如果你想要修改配置文件,请递增version的值
|
||||
|
||||
Reference in New Issue
Block a user