Merge branch 'master' of https://github.com/MoFox-Studio/MoFox_Bot
2
.gitignore
vendored
@@ -325,6 +325,8 @@ run_pet.bat
|
||||
!/plugins/set_emoji_like
|
||||
!/plugins/permission_example
|
||||
!/plugins/hello_world_plugin
|
||||
!/plugins/take_picture_plugin
|
||||
!/plugins/napcat_adapter_plugin
|
||||
!/plugins/echo_example
|
||||
|
||||
config.toml
|
||||
|
||||
4
bot.py
@@ -31,7 +31,6 @@ from src.manager.async_task_manager import async_task_manager # noqa
|
||||
from src.config.config import global_config # noqa
|
||||
from src.common.database.database import initialize_sql_database # noqa
|
||||
from src.common.database.sqlalchemy_models import initialize_database as init_db # noqa
|
||||
from src.common.tool_history import wrap_tool_executor #noqa
|
||||
|
||||
logger = get_logger("main")
|
||||
|
||||
@@ -240,8 +239,7 @@ class MaiBotMain(BaseMain):
|
||||
self.setup_timezone()
|
||||
self.check_and_confirm_eula()
|
||||
self.initialize_database()
|
||||
# 初始化工具历史记录
|
||||
wrap_tool_executor()
|
||||
|
||||
return self.create_main_system()
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,68 @@
|
||||
# Changelog
|
||||
|
||||
## [0.10.0-alpha] - 2025-8-28
|
||||
|
||||
> **MoFox-Bot 0.10.0-alpha 版本发布!**
|
||||
>
|
||||
> 本次更新带来了多项核心功能增强和系统优化。
|
||||
>
|
||||
> 在**新功能**方面,我们引入了**持久化回复跟踪**以避免重复回复,并为LLM请求实现了**模型故障转移机制**,提高了系统的健壮性。插件系统也得到了增强,增加了**事件触发和订阅的白名单机制**,并为事件处理器添加了**异步锁和并行执行**支持。此外,新版本还实现了对**说说中图片的识别与理解**,并引入了**弹性睡眠与睡前通知机制**,使机器人的作息更加智能化。
|
||||
>
|
||||
> 在**修复**方面,我们解决了`enable`配置的bug,修复了event权限问题,并处理了专注模式下艾特不回复、模型信息不存在、cookie获取失败等多个问题,提升了系统的稳定性。
|
||||
>
|
||||
> 在**重构**方面,我们移除了`changelog_config`并更新了模型配置模板,重构了工具缓存机制和LLM请求重试逻辑。同时,我们将项目从`MaiMbot-Pro-Max`正式更名为`MoFox_Bot`,并对内存、权限、配置等多个模块进行了优化。
|
||||
>
|
||||
> 总的来说,0.11.0版本在功能、稳定性和代码质量上都有了显著提升。
|
||||
|
||||
### 新功能
|
||||
- **maizone**: 引入持久化回复跟踪以避免重复回复
|
||||
- **llm**: 为LLM请求实现模型故障转移机制
|
||||
- **plugin-system**: 添加事件触发和订阅的白名单机制
|
||||
- **plugin**: 为事件处理器添加异步锁和并行执行支持
|
||||
- **maizone**: 实现对说说中图片的识别与理解
|
||||
- **sleep**: 实现睡眠唤醒与重新入睡机制
|
||||
- **core**: 实现HFC及睡眠状态的持久化
|
||||
- **schedule**: 引入弹性睡眠与睡前通知机制
|
||||
- **monthly_plan**: 增加月度计划数量上限并自动清理
|
||||
- **command**: 添加PlusCommand增强命令系统
|
||||
- **expression**: 重构表达学习配置,引入基于规则的结构化定义
|
||||
- **chat**: 实现睡眠压力和失眠系统
|
||||
- **schedule**: 重构日程与月度计划管理模块
|
||||
- **core**: 集成统一向量数据库服务并重构相关模块
|
||||
- **tool_system**: 实现工具的声明式缓存
|
||||
- **plugin_system**: 增加工具执行日志记录
|
||||
- **maizone**: 新增QQ空间互通组功能,根据聊天上下文生成说说
|
||||
- **monthly_plan**: 增强月度计划系统,引入状态管理和智能抽取
|
||||
|
||||
### 修复
|
||||
- 修复`enable`配置
|
||||
- 修复event权限,现在每个component都拥有`plugin_name`属性
|
||||
- 修复专注模式下艾特不回复的问题
|
||||
- 修复模型信息不存在时引发的属性错误
|
||||
- 修复`maizone_refactored`获取cookie时响应为空导致的错误
|
||||
- 修复关键词非列表形式时导致的解析错误
|
||||
- 修复即时记忆的 orjson 编码与解码问题
|
||||
- 修复空回复检测,同时修复`tool_call`
|
||||
- 处理截断消息时`message`为`None`的情况
|
||||
- 修复`get_remaining`的起始索引
|
||||
- 修复回复自己评论的问题
|
||||
- 修复`enable`配置
|
||||
|
||||
### 重构
|
||||
- **config**: 移除`changelog_config`并更新模型配置模板
|
||||
- **cache**: 重构工具缓存机制并优化LLM请求重试逻辑
|
||||
- **core**: 移除工具历史记录管理器并将缓存集成到工具执行器中
|
||||
- 重构权限检查和装饰器用法
|
||||
- **core**: 将项目从`MaiMbot-Pro-Max`重命名为`MoFox_Bot`
|
||||
- **memory**: 重构向量记忆清理逻辑以提高稳定性
|
||||
- **llm_models**: 移除官方Gemini客户端并改用`aiohttp`实现
|
||||
- **config**: 整合搜索服务配置并移除废弃选项
|
||||
- **config**: 将反截断设置移至模型配置
|
||||
- **video**: 重构视频分析,增加抽帧模式和间隔配置
|
||||
|
||||
从这里开始都是第三方改版的更新记录!!!!!
|
||||
========================================================
|
||||
|
||||
## [0.10.0] - 2025-7-1
|
||||
### 主要功能更改
|
||||
- 工具系统重构,现在合并到了插件系统中
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
# Changelog
|
||||
|
||||
## [1.0.3] - 2025-3-31
|
||||
### Added
|
||||
- 新增了心流相关配置项:
|
||||
- `heartflow` 配置项,用于控制心流功能
|
||||
|
||||
### Removed
|
||||
- 移除了 `response` 配置项中的 `model_r1_probability` 和 `model_v3_probability` 选项
|
||||
- 移除了次级推理模型相关配置
|
||||
|
||||
## [1.0.1] - 2025-3-30
|
||||
### Added
|
||||
- 增加了流式输出控制项 `stream`
|
||||
- 修复 `LLM_Request` 不会自动为 `payload` 增加流式输出标志的问题
|
||||
|
||||
## [1.0.0] - 2025-3-30
|
||||
### Added
|
||||
- 修复了错误的版本命名
|
||||
- 杀掉了所有无关文件
|
||||
|
||||
## [0.0.11] - 2025-3-12
|
||||
### Added
|
||||
- 新增了 `schedule` 配置项,用于配置日程表生成功能
|
||||
- 新增了 `response_splitter` 配置项,用于控制回复分割
|
||||
- 新增了 `experimental` 配置项,用于实验性功能开关
|
||||
- 新增了 `llm_observation` 和 `llm_sub_heartflow` 模型配置
|
||||
- 新增了 `llm_heartflow` 模型配置
|
||||
- 在 `personality` 配置项中新增了 `prompt_schedule_gen` 参数
|
||||
|
||||
### Changed
|
||||
- 优化了模型配置的组织结构
|
||||
- 调整了部分配置项的默认值
|
||||
- 调整了配置项的顺序,将 `groups` 配置项移到了更靠前的位置
|
||||
- 在 `message` 配置项中:
|
||||
- 新增了 `model_max_output_length` 参数
|
||||
- 在 `willing` 配置项中新增了 `emoji_response_penalty` 参数
|
||||
- 将 `personality` 配置项中的 `prompt_schedule` 重命名为 `prompt_schedule_gen`
|
||||
|
||||
### Removed
|
||||
- 移除了 `min_text_length` 配置项
|
||||
- 移除了 `cq_code` 配置项
|
||||
- 移除了 `others` 配置项(其功能已整合到 `experimental` 中)
|
||||
|
||||
## [0.0.5] - 2025-3-11
|
||||
### Added
|
||||
- 新增了 `alias_names` 配置项,用于指定麦麦的别名。
|
||||
|
||||
## [0.0.4] - 2025-3-9
|
||||
### Added
|
||||
- 新增了 `memory_ban_words` 配置项,用于指定不希望记忆的词汇。
|
||||
|
Before Width: | Height: | Size: 4.1 KiB After Width: | Height: | Size: 4.1 KiB |
|
Before Width: | Height: | Size: 11 KiB After Width: | Height: | Size: 11 KiB |
|
Before Width: | Height: | Size: 21 KiB After Width: | Height: | Size: 21 KiB |
|
Before Width: | Height: | Size: 4.9 KiB After Width: | Height: | Size: 4.9 KiB |
124
docs/deployment_guide.md
Normal file
@@ -0,0 +1,124 @@
|
||||
# MoFox_Bot 部署指南
|
||||
|
||||
欢迎使用 MoFox_Bot!本指南将引导您完成在 Windows 环境下部署 MoFox_Bot 的全部过程。
|
||||
|
||||
## 1. 系统要求
|
||||
|
||||
- **操作系统**: Windows 10 或 Windows 11
|
||||
- **Python**: 版本 >= 3.10
|
||||
- **Git**: 用于克隆项目仓库
|
||||
- **uv**: 推荐的 Python 包管理器 (版本 >= 0.1.0)
|
||||
|
||||
## 2. 部署步骤
|
||||
|
||||
### 第一步:获取必要的文件
|
||||
|
||||
首先,创建一个用于存放 MoFox_Bot 相关文件的文件夹,并通过 `git` 克隆 MoFox_Bot 主程序和 Napcat 适配器。
|
||||
|
||||
```shell
|
||||
mkdir MoFox_Bot_Deployment
|
||||
cd MoFox_Bot_Deployment
|
||||
git clone hhttps://github.com/MoFox-Studio/MoFox_Bot.git
|
||||
git clone https://github.com/MoFox-Studio/Napcat-Adapter.git
|
||||
```
|
||||
|
||||
### 第二步:环境配置
|
||||
|
||||
我们推荐使用 `uv` 来管理 Python 环境和依赖,因为它提供了更快的安装速度和更好的依赖管理体验。
|
||||
|
||||
**安装 uv:**
|
||||
|
||||
```shell
|
||||
pip install uv
|
||||
```
|
||||
|
||||
### 第三步:依赖安装
|
||||
|
||||
**1. 安装 MoFox_Bot 依赖:**
|
||||
|
||||
进入 `mmc` 文件夹,创建虚拟环境并安装依赖。
|
||||
|
||||
```shell
|
||||
cd mmc
|
||||
uv venv
|
||||
uv pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple --upgrade
|
||||
```
|
||||
|
||||
**2. 安装 Napcat-Adapter 依赖:**
|
||||
|
||||
回到上一级目录,进入 `Napcat-Adapter` 文件夹,创建虚拟环境并安装依赖。
|
||||
|
||||
```shell
|
||||
cd ..
|
||||
cd Napcat-Adapter
|
||||
uv venv
|
||||
uv pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple --upgrade
|
||||
```
|
||||
|
||||
### 第四步:配置 MoFox_Bot 和 Adapter
|
||||
|
||||
**1. MoFox_Bot 配置:**
|
||||
|
||||
- 在 `mmc` 文件夹中,将 `template/bot_config_template.toml` 复制到 `config/bot_config.toml`。
|
||||
- 将 `template/model_config_template.toml` 复制到 `config/model_config.toml`。
|
||||
- 根据 [模型配置指南](guides/model_configuration_guide.md) 和 `bot_config.toml` 文件中的注释,填写您的 API Key 和其他相关配置。
|
||||
|
||||
**2. Napcat-Adapter 配置:**
|
||||
|
||||
- 在 `Napcat-Adapter` 文件夹中,将 `template/template_config.toml` 复制到根目录并改名为 `config.toml`。
|
||||
- 打开 `config.toml` 文件,配置 `[Napcat_Server]` 和 `[MaiBot_Server]` 字段。
|
||||
- `[Napcat_Server]` 的 `port` 应与 Napcat 设置的反向代理 URL 中的端口相同。
|
||||
- `[MaiBot_Server]` 的 `port` 应与 MoFox_Bot 的 `bot_config.toml` 中设置的端口相同。
|
||||
|
||||
### 第五步:运行
|
||||
|
||||
**1. 启动 Napcat:**
|
||||
|
||||
请参考 [NapCatQQ 文档](https://napcat-qq.github.io/) 进行部署和启动。
|
||||
|
||||
**2. 启动 MoFox_Bot:**
|
||||
|
||||
进入 `mmc` 文件夹,使用 `uv` 运行。
|
||||
|
||||
```shell
|
||||
cd mmc
|
||||
uv run python bot.py
|
||||
```
|
||||
|
||||
**3. 启动 Napcat-Adapter:**
|
||||
|
||||
打开一个新的终端窗口,进入 `Napcat-Adapter` 文件夹,使用 `uv` 运行。
|
||||
|
||||
```shell
|
||||
cd Napcat-Adapter
|
||||
uv run python main.py
|
||||
```
|
||||
|
||||
至此,MoFox_Bot 已成功部署并运行。
|
||||
|
||||
## 3. 详细配置说明
|
||||
|
||||
### `bot_config.toml`
|
||||
|
||||
这是 MoFox_Bot 的主配置文件,包含了机器人昵称、主人QQ、命令前缀、数据库设置等。请根据文件内的注释进行详细配置。
|
||||
|
||||
### `model_config.toml`
|
||||
|
||||
此文件用于配置 AI 模型和 API 服务提供商。详细配置方法请参考 [模型配置指南](guides/model_configuration_guide.md)。
|
||||
|
||||
### 插件配置
|
||||
|
||||
每个插件都有独立的配置文件,位于 `mmc/config/plugins/` 目录下。插件的配置由其 `config_schema` 自动生成。详细信息请参考 [插件配置完整指南](plugins/configuration-guide.md)。
|
||||
|
||||
## 4. 故障排除
|
||||
|
||||
- **依赖安装失败**:
|
||||
- 尝试更换 PyPI 镜像源。
|
||||
- 检查网络连接。
|
||||
- **API 调用失败**:
|
||||
- 检查 `model_config.toml` 中的 API Key 和 `base_url` 是否正确。
|
||||
- **无法连接到 Napcat**:
|
||||
- 检查 Napcat 是否正常运行。
|
||||
- 确认 `Napcat-Adapter` 的 `config.toml` 中 `[Napcat_Server]` 的 `port` 是否与 Napcat 设置的端口一致。
|
||||
|
||||
如果遇到其他问题,请查看 `logs/` 目录下的日志文件以获取详细的错误信息。
|
||||
@@ -43,12 +43,11 @@ retry_interval = 10 # 重试间隔(秒)
|
||||
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
|
||||
| `base_url` | ✅ | API服务的基础URL | - |
|
||||
| `api_key` | ✅ | API密钥,请替换为实际密钥 | - |
|
||||
| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式,现在支持不良好) | `openai` |
|
||||
| `client_type` | ❌ | 客户端类型:`openai`、`gemini` 或 `aiohttp_gemini` | `openai` |
|
||||
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
|
||||
| `timeout` | ❌ | API请求超时时间(秒) | 30 |
|
||||
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
|
||||
|
||||
**请注意,对于`client_type`为`gemini`的模型,`base_url`字段无效。**
|
||||
### 2.3 支持的服务商示例
|
||||
|
||||
#### DeepSeek
|
||||
@@ -73,9 +72,9 @@ client_type = "openai"
|
||||
```toml
|
||||
[[api_providers]]
|
||||
name = "Google"
|
||||
base_url = "https://api.google.com/v1"
|
||||
base_url = "https://generativelanguage.googleapis.com/v1beta" # 在MoFox-Bot中, 使用aiohttp_gemini客户端的提供商可以自定义base_url
|
||||
api_key = "your-google-api-key"
|
||||
client_type = "gemini" # 注意:Gemini需要使用特殊客户端
|
||||
client_type = "aiohttp_gemini" # 注意:Gemini需要使用特殊客户端
|
||||
```
|
||||
|
||||
## 3. 模型配置
|
||||
@@ -118,11 +117,11 @@ enable_thinking = false # 禁用思考
|
||||
|
||||
比如上面就是参考SiliconFlow的文档配置配置的`Qwen3`禁用思考参数。
|
||||
|
||||

|
||||

|
||||
|
||||
以豆包文档为另一个例子
|
||||
|
||||

|
||||

|
||||
|
||||
得到豆包`"doubao-seed-1-6-250615"`的禁用思考配置方法为
|
||||
```toml
|
||||
@@ -133,7 +132,6 @@ thinking = {type = "disabled"} # 禁用思考
|
||||
```
|
||||
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。
|
||||
|
||||
**请注意,对于`client_type`为`gemini`的模型,此字段无效。**
|
||||
### 3.3 配置参数说明
|
||||
|
||||
| 参数 | 必填 | 说明 |
|
||||
@@ -145,6 +143,7 @@ thinking = {type = "disabled"} # 禁用思考
|
||||
| `price_out` | ❌ | 输出价格(元/M token),用于成本统计 |
|
||||
| `force_stream_mode` | ❌ | 是否强制使用流式输出 |
|
||||
| `extra_params` | ❌ | 额外的模型参数配置 |
|
||||
| `anti_truncation` | ❌ | 是否启用反截断功能 |
|
||||
|
||||
## 4. 模型任务配置
|
||||
|
||||
@@ -184,7 +183,7 @@ max_tokens = 800
|
||||
```
|
||||
|
||||
### planner - 决策模型
|
||||
负责决定MaiBot该做什么:
|
||||
负责决定MoFox_Bot该做什么:
|
||||
```toml
|
||||
[model_task_config.planner]
|
||||
model_list = ["siliconflow-deepseek-v3"]
|
||||
@@ -193,7 +192,7 @@ max_tokens = 800
|
||||
```
|
||||
|
||||
### emotion - 情绪模型
|
||||
负责MaiBot的情绪变化:
|
||||
负责MoFox_Bot的情绪变化:
|
||||
```toml
|
||||
[model_task_config.emotion]
|
||||
model_list = ["siliconflow-deepseek-v3"]
|
||||
@@ -262,6 +261,44 @@ temperature = 0.7
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
### schedule_generator - 日程生成模型
|
||||
```toml
|
||||
[model_task_config.schedule_generator]
|
||||
model_list = ["deepseek-v3"]
|
||||
temperature = 0.5
|
||||
max_tokens = 1024
|
||||
```
|
||||
|
||||
### monthly_plan_generator - 月度计划生成模型
|
||||
```toml
|
||||
[model_task_config.monthly_plan_generator]
|
||||
model_list = ["deepseek-v3"]
|
||||
temperature = 0.7
|
||||
max_tokens = 1024
|
||||
```
|
||||
|
||||
### emoji_vlm - 表情包VLM模型
|
||||
```toml
|
||||
[model_task_config.emoji_vlm]
|
||||
model_list = ["qwen-vl-max"]
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
### anti_injection - 反注入模型
|
||||
```toml
|
||||
[model_task_config.anti_injection]
|
||||
model_list = ["deepseek-v3"]
|
||||
temperature = 0.1
|
||||
max_tokens = 512
|
||||
```
|
||||
|
||||
### utils_video - 视频分析模型
|
||||
```toml
|
||||
[model_task_config.utils_video]
|
||||
model_list = ["qwen-vl-max"]
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
## 5. 配置建议
|
||||
|
||||
### 5.1 Temperature 参数选择
|
||||
@@ -276,7 +313,7 @@ max_tokens = 800
|
||||
|
||||
| 任务类型 | 推荐模型类型 | 示例 |
|
||||
|----------|--------------|------|
|
||||
| 高精度任务 | 大模型 | DeepSeek-V3, GPT-4 |
|
||||
| 高精度任务 | 大模型 | DeepSeek-V3, GPT-5,Gemini-2.5-Pro |
|
||||
| 高频率任务 | 小模型 | Qwen3-8B |
|
||||
| 多模态任务 | 专用模型 | Qwen2.5-VL, SenseVoice |
|
||||
| 工具调用 | 支持Function Call的模型 | Qwen3-14B |
|
||||
@@ -285,7 +322,6 @@ max_tokens = 800
|
||||
|
||||
1. **分层使用**:核心功能使用高质量模型,辅助功能使用经济模型
|
||||
2. **合理配置max_tokens**:根据实际需求设置,避免浪费
|
||||
3. **选择免费模型**:对于测试环境,优先使用price为0的模型
|
||||
|
||||
## 6. 配置验证
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
# 💻 Command组件详解
|
||||
|
||||
## 📖 什么是Command
|
||||
|
||||
Command是直接响应用户明确指令的组件,与Action不同,Command是**被动触发**的,当用户输入特定格式的命令时立即执行。
|
||||
|
||||
Command通过正则表达式匹配用户输入,提供确定性的功能服务。
|
||||
|
||||
### 🎯 Command的特点
|
||||
|
||||
- 🎯 **确定性执行**:匹配到命令立即执行,无随机性
|
||||
- ⚡ **即时响应**:用户主动触发,快速响应
|
||||
- 🔍 **正则匹配**:通过正则表达式精确匹配用户输入
|
||||
- 🛑 **拦截控制**:可以控制是否阻止消息继续处理
|
||||
- 📝 **参数解析**:支持从用户输入中提取参数
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ Command组件的基本结构
|
||||
|
||||
首先,Command组件需要继承自`BaseCommand`类,并实现必要的方法。
|
||||
|
||||
```python
|
||||
class ExampleCommand(BaseCommand):
|
||||
command_name = "example" # 命令名称,作为唯一标识符
|
||||
command_description = "这是一个示例命令" # 命令描述
|
||||
command_pattern = r"" # 命令匹配的正则表达式
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
"""
|
||||
执行Command的主要逻辑
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, bool]:
|
||||
- 第一个bool表示是否成功执行
|
||||
- 第二个str是执行结果消息
|
||||
- 第三个bool表示是否需要阻止消息继续处理
|
||||
"""
|
||||
# ---- 执行命令的逻辑 ----
|
||||
return True, "执行成功", False
|
||||
```
|
||||
**`command_pattern`**: 该Command匹配的正则表达式,用于精确匹配用户输入。
|
||||
|
||||
请注意:如果希望能获取到命令中的参数,请在正则表达式中使用有命名的捕获组,例如`(?P<param_name>pattern)`。
|
||||
|
||||
这样在匹配时,内部实现可以使用`re.match.groupdict()`方法获取到所有捕获组的参数,并以字典的形式存储在`self.matched_groups`中。
|
||||
|
||||
### 匹配样例
|
||||
假设我们有一个命令`/example param1=value1 param2=value2`,对应的正则表达式可以是:
|
||||
|
||||
```python
|
||||
class ExampleCommand(BaseCommand):
|
||||
command_name = "example"
|
||||
command_description = "这是一个示例命令"
|
||||
command_pattern = r"/example (?P<param1>\w+) (?P<param2>\w+)"
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
# 获取匹配的参数
|
||||
param1 = self.matched_groups.get("param1")
|
||||
param2 = self.matched_groups.get("param2")
|
||||
|
||||
# 执行逻辑
|
||||
return True, f"参数1: {param1}, 参数2: {param2}", False
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Command 内置方法说明
|
||||
```python
|
||||
class BaseCommand:
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,使用嵌套键访问"""
|
||||
|
||||
async def send_text(self, content: str, reply_to: str = "") -> bool:
|
||||
"""发送回复消息"""
|
||||
|
||||
async def send_type(self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = "") -> bool:
|
||||
"""发送指定类型的回复消息到当前聊天环境"""
|
||||
|
||||
async def send_command(self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True) -> bool:
|
||||
"""发送命令消息"""
|
||||
|
||||
async def send_emoji(self, emoji_base64: str) -> bool:
|
||||
"""发送表情包"""
|
||||
|
||||
async def send_image(self, image_base64: str) -> bool:
|
||||
"""发送图片"""
|
||||
```
|
||||
具体参数与用法参见`BaseCommand`基类的定义。
|
||||
@@ -9,7 +9,7 @@
|
||||
## 组件功能详解
|
||||
|
||||
- [🧱 Action组件详解](action-components.md) - 掌握最核心的Action组件
|
||||
- [💻 Command组件详解](command-components.md) - 学习直接响应命令的组件
|
||||
- [💻 Command组件详解](PLUS_COMMAND_GUIDE.md) - 学习直接响应命令的组件
|
||||
- [🔧 Tool组件详解](tool-components.md) - 了解如何扩展信息获取能力
|
||||
- [⚙️ 配置文件系统指南](configuration-guide.md) - 学会使用自动生成的插件配置文件
|
||||
- [📄 Manifest系统指南](manifest-guide.md) - 了解插件元数据管理和配置架构
|
||||
|
||||
@@ -90,7 +90,7 @@ class HelloWorldPlugin(BasePlugin):
|
||||
|
||||
在日志中你应该能看到插件被加载的信息。虽然插件还没有任何功能,但它已经成功运行了!
|
||||
|
||||

|
||||

|
||||
|
||||
### 5. 添加第一个功能:问候Action
|
||||
|
||||
@@ -180,7 +180,7 @@ MoFox_Bot可能会选择使用你的问候Action,发送回复:
|
||||
嗨!很开心见到你!😊
|
||||
```
|
||||
|
||||

|
||||

|
||||
|
||||
> **💡 小提示**:MoFox_Bot会智能地决定什么时候使用它。如果没有立即看到效果,多试几次不同的消息。
|
||||
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
# 自动化工具缓存系统使用指南
|
||||
|
||||
为了提升性能并减少不必要的重复计算或API调用,MMC内置了一套强大且易于使用的自动化工具缓存系统。该系统同时支持传统的**精确缓存**和先进的**语义缓存**。工具开发者无需编写任何手动缓存逻辑,只需在工具类中设置几个属性,即可轻松启用和配置缓存行为。
|
||||
|
||||
## 核心概念
|
||||
|
||||
- **精确缓存 (KV Cache)**: 当一个工具被调用时,系统会根据工具名称和所有参数生成一个唯一的键。只有当**下一次调用的工具名和所有参数与之前完全一致**时,才会命中缓存。
|
||||
- **语义缓存 (Vector Cache)**: 它不要求参数完全一致,而是理解参数的**语义和意图**。例如,`"查询深圳今天的天气"` 和 `"今天深圳天气怎么样"` 这两个不同的查询,在语义上是高度相似的。如果启用了语义缓存,第二个查询就能成功命中由第一个查询产生的缓存结果。
|
||||
|
||||
## 如何为你的工具启用缓存
|
||||
|
||||
为你的工具(必须继承自 `BaseTool`)启用缓存非常简单,只需在你的工具类定义中添加以下一个或多个属性即可:
|
||||
|
||||
### 1. `enable_cache: bool`
|
||||
|
||||
这是启用缓存的总开关。
|
||||
|
||||
- **类型**: `bool`
|
||||
- **默认值**: `False`
|
||||
- **作用**: 设置为 `True` 即可为该工具启用缓存功能。如果为 `False`,后续的所有缓存配置都将无效。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
class MyAwesomeTool(BaseTool):
|
||||
# ... 其他定义 ...
|
||||
enable_cache: bool = True
|
||||
```
|
||||
|
||||
### 2. `cache_ttl: int`
|
||||
|
||||
设置缓存的生存时间(Time-To-Live)。
|
||||
|
||||
- **类型**: `int`
|
||||
- **单位**: 秒
|
||||
- **默认值**: `3600` (1小时)
|
||||
- **作用**: 定义缓存条目在被视为过期之前可以存活多长时间。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
class MyLongTermCacheTool(BaseTool):
|
||||
# ... 其他定义 ...
|
||||
enable_cache: bool = True
|
||||
cache_ttl: int = 86400 # 缓存24小时
|
||||
```
|
||||
|
||||
### 3. `semantic_cache_query_key: Optional[str]`
|
||||
|
||||
启用语义缓存的关键。
|
||||
|
||||
- **类型**: `Optional[str]`
|
||||
- **默认值**: `None`
|
||||
- **作用**:
|
||||
- 将此属性的值设置为你工具的某个**参数的名称**(字符串)。
|
||||
- 自动化缓存系统在工作时,会提取该参数的值,将其转换为向量,并进行语义相似度搜索。
|
||||
- 如果该值为 `None`,则此工具**仅使用精确缓存**。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
class WebSurfingTool(BaseTool):
|
||||
name: str = "web_search"
|
||||
parameters = [
|
||||
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
|
||||
# ... 其他参数 ...
|
||||
]
|
||||
|
||||
# --- 缓存配置 ---
|
||||
enable_cache: bool = True
|
||||
cache_ttl: int = 7200 # 缓存2小时
|
||||
semantic_cache_query_key: str = "query" # <-- 关键!
|
||||
```
|
||||
在上面的例子中,`web_search` 工具的 `"query"` 参数值(例如,用户输入的搜索词)将被用于语义缓存搜索。
|
||||
|
||||
## 完整示例
|
||||
|
||||
假设我们有一个调用外部API来获取股票价格的工具。由于股价在短时间内相对稳定,且查询意图可能相似(如 "苹果股价" vs "AAPL股价"),因此非常适合使用缓存。
|
||||
|
||||
```python
|
||||
# in your_plugin/tools/stock_checker.py
|
||||
|
||||
from src.plugin_system import BaseTool, ToolParamType
|
||||
|
||||
class StockCheckerTool(BaseTool):
|
||||
"""
|
||||
一个用于查询股票价格的工具。
|
||||
"""
|
||||
name: str = "get_stock_price"
|
||||
description: str = "获取指定公司或股票代码的最新价格。"
|
||||
available_for_llm: bool = True
|
||||
parameters = [
|
||||
("symbol", ToolParamType.STRING, "公司名称或股票代码 (e.g., 'AAPL', '苹果')", True, None),
|
||||
]
|
||||
|
||||
# --- 缓存配置 ---
|
||||
# 1. 开启缓存
|
||||
enable_cache: bool = True
|
||||
# 2. 股价信息缓存10分钟
|
||||
cache_ttl: int = 600
|
||||
# 3. 使用 "symbol" 参数进行语义搜索
|
||||
semantic_cache_query_key: str = "symbol"
|
||||
# --------------------
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
symbol = function_args.get("symbol")
|
||||
|
||||
# ... 这里是你调用外部API获取股票价格的逻辑 ...
|
||||
# price = await some_stock_api.get_price(symbol)
|
||||
price = 123.45 # 示例价格
|
||||
|
||||
return {
|
||||
"type": "stock_price_result",
|
||||
"content": f"{symbol} 的当前价格是 ${price}"
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
通过以上简单的三行配置,`StockCheckerTool` 现在就拥有了强大的自动化缓存能力:
|
||||
|
||||
- 当用户查询 `"苹果"` 时,工具会执行并缓存结果。
|
||||
- 在接下来的10分钟内,如果再次查询 `"苹果"`,将直接从精确缓存返回结果。
|
||||
- 更智能的是,如果另一个用户查询 `"AAPL"`,语义缓存系统会识别出 `"AAPL"` 和 `"苹果"` 在语义上高度相关,大概率也会直接返回缓存的结果,而无需再次调用API。
|
||||
|
||||
---
|
||||
|
||||
现在,你可以专注于实现工具的核心逻辑,把缓存的复杂性交给MMC的自动化系统来处理。
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
## 📖 什么是工具
|
||||
|
||||
工具是MoFox_Bot的信息获取能力扩展组件。如果说Action组件功能五花八门,可以拓展麦麦能做的事情,那么Tool就是在某个过程中拓宽了麦麦能够获得的信息量。
|
||||
工具是MoFox_Bot的信息获取能力扩展组件。如果说Action组件功能五花八门,可以拓展麦麦能做的事情,那么Tool就是在某个过程中拓宽了MoFox_Bot能够获得的信息量。
|
||||
|
||||
### 🎯 工具的特点
|
||||
|
||||
@@ -191,8 +191,7 @@ class WeatherTool(BaseTool):
|
||||
name = "weather_query" # 清晰表达功能
|
||||
name = "knowledge_search" # 描述性强
|
||||
name = "stock_price_check" # 功能明确
|
||||
```
|
||||
#### ❌ 避免的命名
|
||||
```#### ❌ 避免的命名
|
||||
```python
|
||||
name = "tool1" # 无意义
|
||||
name = "wq" # 过于简短
|
||||
@@ -244,3 +243,130 @@ def _format_result(self, data):
|
||||
def _format_result(self, data):
|
||||
return str(data) # 直接返回原始数据
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
# 自动化工具缓存系统使用指南
|
||||
|
||||
为了提升性能并减少不必要的重复计算或API调用,MMC内置了一套强大且易于使用的自动化工具缓存系统。该系统同时支持传统的**精确缓存**和先进的**语义缓存**。工具开发者无需编写任何手动缓存逻辑,只需在工具类中设置几个属性,即可轻松启用和配置缓存行为。
|
||||
|
||||
## 核心概念
|
||||
|
||||
- **精确缓存 (KV Cache)**: 当一个工具被调用时,系统会根据工具名称和所有参数生成一个唯一的键。只有当**下一次调用的工具名和所有参数与之前完全一致**时,才会命中缓存。
|
||||
- **语义缓存 (Vector Cache)**: 它不要求参数完全一致,而是理解参数的**语义和意图**。例如,`"查询深圳今天的天气"` 和 `"今天深圳天气怎么样"` 这两个不同的查询,在语义上是高度相似的。如果启用了语义缓存,第二个查询就能成功命中由第一个查询产生的缓存结果。
|
||||
|
||||
## 如何为你的工具启用缓存
|
||||
|
||||
为你的工具(必须继承自 `BaseTool`)启用缓存非常简单,只需在你的工具类定义中添加以下一个或多个属性即可:
|
||||
|
||||
### 1. `enable_cache: bool`
|
||||
|
||||
这是启用缓存的总开关。
|
||||
|
||||
- **类型**: `bool`
|
||||
- **默认值**: `False`
|
||||
- **作用**: 设置为 `True` 即可为该工具启用缓存功能。如果为 `False`,后续的所有缓存配置都将无效。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
class MyAwesomeTool(BaseTool):
|
||||
# ... 其他定义 ...
|
||||
enable_cache: bool = True
|
||||
```
|
||||
|
||||
### 2. `cache_ttl: int`
|
||||
|
||||
设置缓存的生存时间(Time-To-Live)。
|
||||
|
||||
- **类型**: `int`
|
||||
- **单位**: 秒
|
||||
- **默认值**: `3600` (1小时)
|
||||
- **作用**: 定义缓存条目在被视为过期之前可以存活多长时间。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
class MyLongTermCacheTool(BaseTool):
|
||||
# ... 其他定义 ...
|
||||
enable_cache: bool = True
|
||||
cache_ttl: int = 86400 # 缓存24小时
|
||||
```
|
||||
|
||||
### 3. `semantic_cache_query_key: Optional[str]`
|
||||
|
||||
启用语义缓存的关键。
|
||||
|
||||
- **类型**: `Optional[str]`
|
||||
- **默认值**: `None`
|
||||
- **作用**:
|
||||
- 将此属性的值设置为你工具的某个**参数的名称**(字符串)。
|
||||
- 自动化缓存系统在工作时,会提取该参数的值,将其转换为向量,并进行语义相似度搜索。
|
||||
- 如果该值为 `None`,则此工具**仅使用精确缓存**。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
class WebSurfingTool(BaseTool):
|
||||
name: str = "web_search"
|
||||
parameters = [
|
||||
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
|
||||
# ... 其他参数 ...
|
||||
]
|
||||
|
||||
# --- 缓存配置 ---
|
||||
enable_cache: bool = True
|
||||
cache_ttl: int = 7200 # 缓存2小时
|
||||
semantic_cache_query_key: str = "query" # <-- 关键!
|
||||
```
|
||||
在上面的例子中,`web_search` 工具的 `"query"` 参数值(例如,用户输入的搜索词)将被用于语义缓存搜索。
|
||||
|
||||
## 完整示例
|
||||
|
||||
假设我们有一个调用外部API来获取股票价格的工具。由于股价在短时间内相对稳定,且查询意图可能相似(如 "苹果股价" vs "AAPL股价"),因此非常适合使用缓存。
|
||||
|
||||
```python
|
||||
# in your_plugin/tools/stock_checker.py
|
||||
|
||||
from src.plugin_system import BaseTool, ToolParamType
|
||||
|
||||
class StockCheckerTool(BaseTool):
|
||||
"""
|
||||
一个用于查询股票价格的工具。
|
||||
"""
|
||||
name: str = "get_stock_price"
|
||||
description: str = "获取指定公司或股票代码的最新价格。"
|
||||
available_for_llm: bool = True
|
||||
parameters = [
|
||||
("symbol", ToolParamType.STRING, "公司名称或股票代码 (e.g., 'AAPL', '苹果')", True, None),
|
||||
]
|
||||
|
||||
# --- 缓存配置 ---
|
||||
# 1. 开启缓存
|
||||
enable_cache: bool = True
|
||||
# 2. 股价信息缓存10分钟
|
||||
cache_ttl: int = 600
|
||||
# 3. 使用 "symbol" 参数进行语义搜索
|
||||
semantic_cache_query_key: str = "symbol"
|
||||
# --------------------
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
symbol = function_args.get("symbol")
|
||||
|
||||
# ... 这里是你调用外部API获取股票价格的逻辑 ...
|
||||
# price = await some_stock_api.get_price(symbol)
|
||||
price = 123.45 # 示例价格
|
||||
|
||||
return {
|
||||
"type": "stock_price_result",
|
||||
"content": f"{symbol} 的当前价格是 ${price}"
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
通过以上简单的三行配置,`StockCheckerTool` 现在就拥有了强大的自动化缓存能力:
|
||||
|
||||
- 当用户查询 `"苹果"` 时,工具会执行并缓存结果。
|
||||
- 在接下来的10分钟内,如果再次查询 `"苹果"`,将直接从精确缓存返回结果。
|
||||
- 更智能的是,如果另一个用户查询 `"AAPL"`,语义缓存系统会识别出 `"AAPL"` 和 `"苹果"` 在语义上高度相关,大概率也会直接返回缓存的结果,而无需再次调用API。
|
||||
|
||||
---
|
||||
|
||||
现在,你可以专注于实现工具的核心逻辑,把缓存的复杂性交给MMC的自动化系统来处理。
|
||||
@@ -1,121 +0,0 @@
|
||||
# “月层计划”系统架构设计文档
|
||||
|
||||
## 1. 系统概述与目标
|
||||
|
||||
本系统旨在为MoFox_Bot引入一个动态的、由大型语言模型(LLM)驱动的“月层计划”机制。其核心目标是取代静态、预设的任务模板,转而利用LLM在程序启动时自动生成符合Bot人设的、具有时效性的月度计划。这些计划将被存储、管理,并在构建每日日程时被动态抽取和使用,从而极大地丰富日程内容的个性和多样性。
|
||||
|
||||
---
|
||||
|
||||
## 2. 核心设计原则
|
||||
|
||||
- **动态性与智能化:** 所有计划内容均由LLM实时生成,确保其独特性和创造性。
|
||||
- **人设一致性:** 计划的生成将严格围绕Bot的核心人设进行,强化角色形象。
|
||||
- **持久化与可管理:** 生成的计划将被存入专用数据库表,便于管理和追溯。
|
||||
- **消耗性与随机性:** 计划在使用后有一定几率被消耗(删除),模拟真实世界中计划的完成与迭代。
|
||||
|
||||
---
|
||||
|
||||
## 3. 系统核心流程规划
|
||||
|
||||
本系统包含两大核心流程:**启动时的计划生成流程**和**日程构建时的计划使用流程**。
|
||||
|
||||
### 3.1 流程一:启动时计划生成
|
||||
|
||||
此流程在每次程序启动时触发,负责填充当月的计划池。
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[程序启动] --> B{检查当月计划池};
|
||||
B -- 计划数量低于阈值 --> C[构建LLM Prompt];
|
||||
C -- prompt包含Bot人设、月份等信息 --> D[调用LLM服务];
|
||||
D -- LLM返回多个计划文本 --> E[解析并格式化计划];
|
||||
E -- 逐条处理 --> F[存入`monthly_plans`数据库表];
|
||||
F --> G[完成启动任务];
|
||||
B -- 计划数量充足 --> G;
|
||||
```
|
||||
|
||||
### 3.2 流程二:日程构建时计划使用
|
||||
|
||||
此流程在构建每日日程的提示词(Prompt)时触发。
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
H[构建日程Prompt] --> I{查询数据库};
|
||||
I -- 读取当月未使用的计划 --> J[随机抽取N个计划];
|
||||
J --> K[将计划文本嵌入日程Prompt];
|
||||
K --> L{随机数判断};
|
||||
L -- 概率命中 --> M[将已抽取的计划标记为删除];
|
||||
M --> N[完成Prompt构建];
|
||||
L -- 概率未命中 --> N;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 数据库模型设计
|
||||
|
||||
为支撑本系统,需要新增一个数据库表。
|
||||
|
||||
**表名:** `monthly_plans`
|
||||
|
||||
| 字段名 | 类型 | 描述 |
|
||||
| :--- | :--- | :--- |
|
||||
| `id` | Integer | 主键,自增。 |
|
||||
| `plan_text` | Text | 由LLM生成的计划内容原文。 |
|
||||
| `target_month` | String(7) | 计划所属的月份,格式为 "YYYY-MM"。 |
|
||||
| `is_deleted` | Boolean | 软删除标记,默认为 `false`。 |
|
||||
| `created_at` | DateTime | 记录创建时间。 |
|
||||
|
||||
---
|
||||
|
||||
## 5. 详细模块规划
|
||||
|
||||
### 5.1 LLM Prompt生成模块
|
||||
|
||||
- **职责:** 构建高质量的Prompt以引导LLM生成符合要求的计划。
|
||||
- **输入:** Bot人设描述、当前月份、期望生成的计划数量。
|
||||
- **输出:** 一个结构化的Prompt字符串。
|
||||
- **Prompt示例:**
|
||||
```
|
||||
你是一个[此处填入Bot人设描述,例如:活泼开朗、偶尔有些小迷糊的虚拟助手]。
|
||||
请为即将到来的[YYYY年MM月]设计[N]个符合你身份的月度计划或目标。
|
||||
|
||||
要求:
|
||||
1. 每个计划都是独立的、积极向上的。
|
||||
2. 语言风格要自然、口语化,符合你的性格。
|
||||
3. 每个计划用一句话或两句话简短描述。
|
||||
4. 以JSON格式返回,格式为:{"plans": ["计划一", "计划二", ...]}
|
||||
```
|
||||
|
||||
### 5.2 数据库交互模块
|
||||
|
||||
- **职责:** 提供对 `monthly_plans` 表的增、删、改、查接口。
|
||||
- **规划函数列表:**
|
||||
- `add_new_plans(plans: list[str], month: str)`: 批量添加新生成的计划。
|
||||
- `get_active_plans_for_month(month: str) -> list`: 获取指定月份所有未被删除的计划。
|
||||
- `soft_delete_plans(plan_ids: list[int])`: 将指定ID的计划标记为软删除。
|
||||
|
||||
### 5.3 配置项规划
|
||||
|
||||
需要在主配置文件 `config/bot_config.toml` 中添加以下配置项,以控制系统行为。
|
||||
|
||||
```toml
|
||||
# ----------------------------------------------------------------
|
||||
# 月层计划系统设置 (Monthly Plan System Settings)
|
||||
# ----------------------------------------------------------------
|
||||
[monthly_plan_system]
|
||||
|
||||
# 是否启用本功能
|
||||
enable = true
|
||||
|
||||
# 启动时,如果当月计划少于此数量,则触发LLM生成
|
||||
generation_threshold = 10
|
||||
|
||||
# 每次调用LLM期望生成的计划数量
|
||||
plans_per_generation = 5
|
||||
|
||||
# 计划被使用后,被删除的概率 (0.0 到 1.0)
|
||||
deletion_probability_on_use = 0.5
|
||||
```
|
||||
|
||||
---
|
||||
**文档结束。** 本文档纯粹为架构规划,旨在提供清晰的设计思路和开发指引,不包含任何实现代码。
|
||||
@@ -1,53 +0,0 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "Hello World 示例插件 (Hello World Plugin)",
|
||||
"version": "1.0.0",
|
||||
"description": "我的第一个MaiCore插件,包含问候功能和时间查询等基础示例",
|
||||
"author": {
|
||||
"name": "MaiBot开发团队",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.8.0"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"keywords": ["demo", "example", "hello", "greeting", "tutorial"],
|
||||
"categories": ["Examples", "Tutorial"],
|
||||
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": false,
|
||||
"plugin_type": "example",
|
||||
"components": [
|
||||
{
|
||||
"type": "action",
|
||||
"name": "hello_greeting",
|
||||
"description": "向用户发送问候消息"
|
||||
},
|
||||
{
|
||||
"type": "action",
|
||||
"name": "bye_greeting",
|
||||
"description": "向用户发送告别消息",
|
||||
"activation_modes": ["keyword"],
|
||||
"keywords": ["再见", "bye", "88", "拜拜"]
|
||||
},
|
||||
{
|
||||
"type": "command",
|
||||
"name": "time",
|
||||
"description": "查询当前时间",
|
||||
"pattern": "/time"
|
||||
}
|
||||
],
|
||||
"features": [
|
||||
"问候和告别功能",
|
||||
"时间查询命令",
|
||||
"配置文件示例",
|
||||
"新手教程代码"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,282 +0,0 @@
|
||||
from typing import List, Tuple, Type, Any, Optional
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
BaseAction,
|
||||
BaseCommand,
|
||||
BaseTool,
|
||||
ComponentInfo,
|
||||
ActionActivationType,
|
||||
ConfigField,
|
||||
ToolParamType
|
||||
)
|
||||
|
||||
from src.plugin_system.apis import send_api
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GetGroupListCommand(BaseCommand):
|
||||
"""获取群列表命令"""
|
||||
|
||||
command_name = "get_groups"
|
||||
command_description = "获取机器人加入的群列表"
|
||||
command_pattern = r"^/get_groups$"
|
||||
command_help = "获取机器人加入的群列表"
|
||||
command_examples = ["/get_groups"]
|
||||
intercept_message = True
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, str, bool]:
|
||||
try:
|
||||
# 调用适配器命令API
|
||||
domain = "user.qzone.qq.com"
|
||||
response = await send_api.adapter_command_to_stream(
|
||||
action="get_cookies",
|
||||
platform="qq",
|
||||
params={"domain": domain},
|
||||
timeout=40.0,
|
||||
storage_message=False
|
||||
)
|
||||
text = str(response)
|
||||
await self.send_text(text)
|
||||
return True, "获取群列表成功", True
|
||||
|
||||
except Exception as e:
|
||||
await self.send_text(f"获取群列表失败: {str(e)}")
|
||||
return False, "获取群列表失败", True
|
||||
|
||||
class CompareNumbersTool(BaseTool):
|
||||
"""比较两个数大小的工具"""
|
||||
|
||||
name = "compare_numbers"
|
||||
description = "使用工具 比较两个数的大小,返回较大的数"
|
||||
parameters = [
|
||||
("num1", ToolParamType.FLOAT, "第一个数字", True, None),
|
||||
("num2", ToolParamType.FLOAT, "第二个数字", True, None),
|
||||
]
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行比较两个数的大小
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
num1: int | float = function_args.get("num1") # type: ignore
|
||||
num2: int | float = function_args.get("num2") # type: ignore
|
||||
|
||||
try:
|
||||
if num1 > num2:
|
||||
result = f"{num1} 大于 {num2}"
|
||||
elif num1 < num2:
|
||||
result = f"{num1} 小于 {num2}"
|
||||
else:
|
||||
result = f"{num1} 等于 {num2}"
|
||||
|
||||
return {"name": self.name, "content": result}
|
||||
except Exception as e:
|
||||
return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"}
|
||||
|
||||
|
||||
# ===== Action组件 =====
|
||||
class HelloAction(BaseAction):
|
||||
"""问候Action - 简单的问候动作"""
|
||||
|
||||
# === 基本信息(必须填写)===
|
||||
action_name = "hello_greeting"
|
||||
action_description = "向用户发送问候消息"
|
||||
activation_type = ActionActivationType.ALWAYS # 始终激活
|
||||
|
||||
# === 功能描述(必须填写)===
|
||||
action_parameters = {"greeting_message": "要发送的问候消息"}
|
||||
action_require = ["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"]
|
||||
associated_types = ["text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行问候动作 - 这是核心功能"""
|
||||
# 发送问候消息
|
||||
greeting_message = self.action_data.get("greeting_message", "")
|
||||
base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊")
|
||||
message = base_message + greeting_message
|
||||
await self.send_text(message)
|
||||
|
||||
return True, "发送了问候消息"
|
||||
|
||||
|
||||
class ByeAction(BaseAction):
|
||||
"""告别Action - 只在用户说再见时激活"""
|
||||
|
||||
action_name = "bye_greeting"
|
||||
action_description = "向用户发送告别消息"
|
||||
|
||||
# 使用关键词激活
|
||||
activation_type = ActionActivationType.KEYWORD
|
||||
|
||||
# 关键词设置
|
||||
activation_keywords = ["再见", "bye", "88", "拜拜"]
|
||||
keyword_case_sensitive = False
|
||||
|
||||
action_parameters = {"bye_message": "要发送的告别消息"}
|
||||
action_require = [
|
||||
"用户要告别时使用",
|
||||
"当有人要离开时使用",
|
||||
"当有人和你说再见时使用",
|
||||
]
|
||||
associated_types = ["text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
bye_message = self.action_data.get("bye_message", "")
|
||||
|
||||
message = f"再见!期待下次聊天!👋{bye_message}"
|
||||
await self.send_text(message)
|
||||
return True, "发送了告别消息"
|
||||
|
||||
|
||||
class TimeCommand(BaseCommand):
|
||||
"""时间查询Command - 响应/time命令"""
|
||||
|
||||
command_name = "time"
|
||||
command_description = "获取当前时间"
|
||||
|
||||
# === 命令设置(必须填写)===
|
||||
command_pattern = r"^/time$" # 精确匹配 "/time" 命令
|
||||
chat_type_allow = ChatType.GROUP # 仅在群聊中可用
|
||||
|
||||
async def execute(self) -> Tuple[bool, str, bool]:
|
||||
"""执行时间查询"""
|
||||
import datetime
|
||||
|
||||
# 获取当前时间
|
||||
time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
|
||||
now = datetime.datetime.now()
|
||||
time_str = now.strftime(time_format)
|
||||
|
||||
# 发送时间信息
|
||||
message = f"⏰ 当前时间:{time_str}"
|
||||
await self.send_text(message)
|
||||
|
||||
return True, f"显示了当前x时间: {time_str}", True
|
||||
|
||||
|
||||
# class PrintMessage(BaseEventHandler):
|
||||
# """打印消息事件处理器 - 处理打印消息事件"""
|
||||
#
|
||||
# event_type = EventType.ON_MESSAGE
|
||||
# handler_name = "print_message_handler"
|
||||
# handler_description = "打印接收到的消息"
|
||||
#
|
||||
# async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]:
|
||||
# """执行打印消息事件处理"""
|
||||
# # 打印接收到的消息
|
||||
#
|
||||
# if self.get_config("print_message.enabled", False):
|
||||
# print(f"接收到消息: {message.raw_message}")
|
||||
# return True, True, "消息已打印1"
|
||||
|
||||
|
||||
# ===== 插件注册 =====
|
||||
|
||||
|
||||
@register_plugin
|
||||
class HelloWorldPlugin(BasePlugin):
|
||||
"""Hello World插件 - 你的第一个MaiCore插件"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "hello_world_plugin" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
python_dependencies: List[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {"plugin": "插件基本信息", "greeting": "问候功能配置", "time": "时间查询配置"}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="hello_world_plugin", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
},
|
||||
"greeting": {
|
||||
"message": ConfigField(
|
||||
type=list, default=["嗨!很开心见到你!😊", "Ciallo~(∠・ω< )⌒★"], description="默认问候消息"
|
||||
),
|
||||
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
|
||||
},
|
||||
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
|
||||
"print_message": {"enabled": ConfigField(type=bool, default=True, description="是否启用打印")},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
return [
|
||||
(HelloAction.get_action_info(), HelloAction),
|
||||
(CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具
|
||||
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
|
||||
(TimeCommand.get_command_info(), TimeCommand), # 现在只能在群聊中使用
|
||||
(GetGroupListCommand.get_command_info(), GetGroupListCommand), # 添加获取群列表命令
|
||||
(PrivateInfoCommand.get_command_info(), PrivateInfoCommand), # 私聊专用命令
|
||||
(GroupOnlyAction.get_action_info(), GroupOnlyAction), # 群聊专用动作
|
||||
# (PrintMessage.get_handler_info(), PrintMessage),
|
||||
]
|
||||
|
||||
|
||||
# @register_plugin
|
||||
# class HelloWorldEventPlugin(BaseEPlugin):
|
||||
# """Hello World事件插件 - 处理问候和告别事件"""
|
||||
|
||||
# plugin_name = "hello_world_event_plugin"
|
||||
# enable_plugin = False
|
||||
# dependencies = []
|
||||
# python_dependencies = []
|
||||
# config_file_name = "event_config.toml"
|
||||
|
||||
# config_schema = {
|
||||
# "plugin": {
|
||||
# "name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"),
|
||||
# "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||
# "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
# },
|
||||
# }
|
||||
|
||||
# def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
# return [(PrintMessage.get_handler_info(), PrintMessage)]
|
||||
|
||||
# 添加一个新的私聊专用命令
|
||||
class PrivateInfoCommand(BaseCommand):
|
||||
command_name = "private_info"
|
||||
command_description = "获取私聊信息"
|
||||
command_pattern = r"^/私聊信息$"
|
||||
chat_type_allow = ChatType.PRIVATE # 仅在私聊中可用
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
"""执行私聊信息命令"""
|
||||
try:
|
||||
await self.send_text("这是一个只能在私聊中使用的命令!")
|
||||
return True, "私聊信息命令执行成功", False
|
||||
except Exception as e:
|
||||
logger.error(f"私聊信息命令执行失败: {e}")
|
||||
return False, f"命令执行失败: {e}", False
|
||||
|
||||
# 添加一个新的仅群聊可用的Action
|
||||
class GroupOnlyAction(BaseAction):
|
||||
action_name = "group_only_test"
|
||||
action_description = "群聊专用测试动作"
|
||||
chat_type_allow = ChatType.GROUP # 仅在群聊中可用
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行群聊专用测试动作"""
|
||||
try:
|
||||
await self.send_text("这是一个只能在群聊中执行的动作!")
|
||||
return True, "群聊专用动作执行成功"
|
||||
except Exception as e:
|
||||
logger.error(f"群聊专用动作执行失败: {e}")
|
||||
return False, f"动作执行失败: {e}"
|
||||
279
plugins/napcat_adapter_plugin/.gitignore
vendored
Normal file
@@ -0,0 +1,279 @@
|
||||
|
||||
log/
|
||||
logs/
|
||||
out/
|
||||
|
||||
.env
|
||||
.env.*
|
||||
.cursor
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
uv.lock
|
||||
llm_statistics.txt
|
||||
mongodb
|
||||
napcat
|
||||
run_dev.bat
|
||||
elua.confirmed
|
||||
# C extensions
|
||||
*.so
|
||||
/results
|
||||
config_backup/
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# jieba
|
||||
jieba.cache
|
||||
|
||||
# .vscode
|
||||
!.vscode/settings.json
|
||||
|
||||
# direnv
|
||||
/.direnv
|
||||
|
||||
# JetBrains
|
||||
.idea
|
||||
*.iml
|
||||
*.ipr
|
||||
|
||||
# PyEnv
|
||||
# If using PyEnv and configured to use a specific Python version locally
|
||||
# a .local-version file will be created in the root of the project to specify the version.
|
||||
.python-version
|
||||
|
||||
OtherRes.txt
|
||||
|
||||
/eula.confirmed
|
||||
/privacy.confirmed
|
||||
|
||||
logs
|
||||
|
||||
.ruff_cache
|
||||
|
||||
.vscode
|
||||
|
||||
/config/*
|
||||
config/old/bot_config_20250405_212257.toml
|
||||
temp/
|
||||
|
||||
# General
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
# Windows thumbnail cache files
|
||||
Thumbs.db
|
||||
Thumbs.db:encryptable
|
||||
ehthumbs.db
|
||||
ehthumbs_vista.db
|
||||
|
||||
# Dump file
|
||||
*.stackdump
|
||||
|
||||
# Folder config file
|
||||
[Dd]esktop.ini
|
||||
|
||||
# Recycle Bin used on file shares
|
||||
$RECYCLE.BIN/
|
||||
|
||||
# Windows Installer files
|
||||
*.cab
|
||||
*.msi
|
||||
*.msix
|
||||
*.msm
|
||||
*.msp
|
||||
|
||||
# Windows shortcuts
|
||||
*.lnk
|
||||
|
||||
config.toml
|
||||
feature.toml
|
||||
config.toml.back
|
||||
test
|
||||
data/NapcatAdapter.db
|
||||
data/NapcatAdapter.db-shm
|
||||
data/NapcatAdapter.db-wal
|
||||
1
plugins/napcat_adapter_plugin/CONSTS.py
Normal file
@@ -0,0 +1 @@
|
||||
PLUGIN_NAME = "napcat_adapter"
|
||||
42
plugins/napcat_adapter_plugin/_manifest.json
Normal file
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "napcat_plugin",
|
||||
"version": "1.0.0",
|
||||
"description": "基于OneBot 11协议的NapCat QQ协议插件,提供完整的QQ机器人API接口,使用现有adapter连接",
|
||||
"author": {
|
||||
"name": "Windpicker_owo",
|
||||
"url": "https://github.com/Windpicker-owo"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.10.0",
|
||||
"max_version": "0.10.0"
|
||||
},
|
||||
"homepage_url": "https://github.com/Windpicker-owo/InternetSearchPlugin",
|
||||
"repository_url": "https://github.com/Windpicker-owo/InternetSearchPlugin",
|
||||
"keywords": ["qq", "bot", "napcat", "onebot", "api", "websocket"],
|
||||
"categories": ["protocol"],
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": false,
|
||||
"components": [
|
||||
{
|
||||
"type": "tool",
|
||||
"name": "napcat_tool",
|
||||
"description": "NapCat QQ协议综合工具,提供消息发送、群管理、好友管理、文件操作等完整功能"
|
||||
}
|
||||
],
|
||||
"features": [
|
||||
"消息发送与接收",
|
||||
"群管理功能",
|
||||
"好友管理功能",
|
||||
"文件上传下载",
|
||||
"AI语音功能",
|
||||
"群签到与戳一戳",
|
||||
"现有adapter连接"
|
||||
]
|
||||
}
|
||||
}
|
||||
6
plugins/napcat_adapter_plugin/event_handlers.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType, ConfigField, BaseAction, ActionActivationType
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
69
plugins/napcat_adapter_plugin/event_types.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from enum import Enum
|
||||
|
||||
class NapcatEvent(Enum):
|
||||
# napcat插件事件枚举类
|
||||
class ON_RECEIVED(Enum):
|
||||
"""
|
||||
该分类下均为消息接受事件,只能由napcat_plugin触发
|
||||
"""
|
||||
TEXT = "napcat_on_received_text" # 接收到文本消息
|
||||
FACE = "napcat_on_received_face" # 接收到表情消息
|
||||
REPLY = "napcat_on_received_reply" # 接收到回复消息
|
||||
IMAGE = "napcat_on_received_image" # 接收到图像消息
|
||||
RECORD = "napcat_on_received_record" # 接收到语音消息
|
||||
VIDEO = "napcat_on_received_video" # 接收到视频消息
|
||||
AT = "napcat_on_received_at" # 接收到at消息
|
||||
DICE = "napcat_on_received_dice" # 接收到骰子消息
|
||||
SHAKE = "napcat_on_received_shake" # 接收到屏幕抖动消息
|
||||
JSON = "napcat_on_received_json" # 接收到JSON消息
|
||||
RPS = "napcat_on_received_rps" # 接收到魔法猜拳消息
|
||||
FRIEND_INPUT = "napcat_on_friend_input" # 好友正在输入
|
||||
|
||||
class ACCOUNT(Enum):
|
||||
"""
|
||||
该分类是对账户相关的操作,只能由外部触发,napcat_plugin负责处理
|
||||
"""
|
||||
SET_PROFILE = "napcat_set_qq_profile" # 设置账号信息
|
||||
GET_ONLINE_CLIENTS = "napcat_get_online_clients" # 获取当前账号在线客户端列表
|
||||
SET_ONLINE_STATUS = "napcat_set_online_status" # 设置在线状态
|
||||
GET_FRIENDS_WITH_CATEGORY = "napcat_get_friends_with_category" # 获取好友分组列表
|
||||
SET_AVATAR = "napcat_set_qq_avatar" # 设置头像
|
||||
SEND_LIKE = "napcat_send_like" # 点赞
|
||||
SET_FRIEND_ADD_REQUEST = "napcat_set_friend_add_request" # 处理好友请求
|
||||
SET_SELF_LONGNICK = "napcat_set_self_longnick" # 设置个性签名
|
||||
GET_LOGIN_INFO = "napcat_get_login_info" # 获取登录号信息
|
||||
GET_RECENT_CONTACT = "napcat_get_recent_contact" # 最近消息列表
|
||||
GET_STRANGER_INFO = "napcat_get_stranger_info" # 获取(指定)账号信息
|
||||
GET_FRIEND_LIST = "napcat_get_friend_list" # 获取好友列表
|
||||
GET_PROFILE_LIKE = "napcat_get_profile_like" # 获取点赞列表
|
||||
DELETE_FRIEND = "napcat_delete_friend" # 删除好友
|
||||
GET_USER_STATUS = "napcat_get_user_status" # 获取用户状态
|
||||
GET_STATUS = "napcat_get_status" # 获取状态
|
||||
GET_MINI_APP_ARK = "napcat_get_mini_app_ark" # 获取小程序卡片
|
||||
SET_DIY_ONLINE_STATUS = "napcat_set_diy_online_status" # 设置自定义在线状态
|
||||
|
||||
class MESSAGE(Enum):
|
||||
"""
|
||||
该分类是对信息相关的操作,只能由外部触发,napcat_plugin负责处理
|
||||
"""
|
||||
SEND_GROUP_POKE = "napcat_send_group_poke" # 发送群聊戳一戳
|
||||
SEND_PRIVATE_MSG = "napcat_send_private_msg" # 发送私聊消息
|
||||
SEND_POKE = "napcat_send_friend_poke" # 发送戳一戳
|
||||
DELETE_MSG = "napcat_delete_msg" # 撤回消息
|
||||
GET_GROUP_MSG_HISTORY = "napcat_get_group_msg_history" # 获取群历史消息
|
||||
GET_MSG = "napcat_get_msg" # 获取消息详情
|
||||
GET_FORWARD_MSG = "napcat_get_forward_msg" # 获取合并转发消息
|
||||
SET_MSG_EMOJI_LIKE = "napcat_set_msg_emoji_like" # 贴表情
|
||||
GET_FRIEND_MSG_HISTORY = "napcat_get_friend_msg_history" # 获取好友历史消息
|
||||
FETCH_EMOJI_LIKE = "napcat_fetch_emoji_like" # 获取贴表情详情
|
||||
SEND_FORWARF_MSG = "napcat_send_forward_msg" # 发送合并转发消息
|
||||
GET_RECOED = "napcat_get_record" # 获取语音消息详情
|
||||
SEND_GROUP_AI_RECORD = "napcat_send_group_ai_record" # 发送群AI语音
|
||||
|
||||
class GROUP(Enum):
|
||||
"""
|
||||
该分类是对群聊相关的操作,只能由外部触发,napcat_plugin负责处理
|
||||
"""
|
||||
|
||||
|
||||
|
||||
131
plugins/napcat_adapter_plugin/plugin.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import json
|
||||
import websockets as Server
|
||||
from . import event_types,CONSTS
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType, ConfigField, BaseAction, ActionActivationType
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
# 添加当前目录到Python路径,这样可以识别src包
|
||||
current_dir = Path(__file__).parent
|
||||
sys.path.insert(0, str(current_dir))
|
||||
|
||||
from .src.recv_handler.message_handler import message_handler
|
||||
from .src.recv_handler.meta_event_handler import meta_event_handler
|
||||
from .src.recv_handler.notice_handler import notice_handler
|
||||
from .src.recv_handler.message_sending import message_send_instance
|
||||
from .src.send_handler import send_handler
|
||||
from .src.config import global_config
|
||||
from .src.config.features_config import features_manager
|
||||
from .src.config.migrate_features import auto_migrate_features
|
||||
from .src.mmc_com_layer import mmc_start_com, mmc_stop_com, router
|
||||
from .src.response_pool import put_response, check_timeout_response
|
||||
from .src.websocket_manager import websocket_manager
|
||||
|
||||
message_queue = asyncio.Queue()
|
||||
|
||||
class LauchNapcatAdapterHandler(BaseEventHandler):
|
||||
"""自动启动Adapter"""
|
||||
|
||||
handler_name: str = "launch_napcat_adapter_handler"
|
||||
handler_description: str = "自动启动napcat adapter"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [EventType.ON_START]
|
||||
|
||||
async def message_recv(self, server_connection: Server.ServerConnection):
|
||||
await message_handler.set_server_connection(server_connection)
|
||||
asyncio.create_task(notice_handler.set_server_connection(server_connection))
|
||||
await send_handler.set_server_connection(server_connection)
|
||||
async for raw_message in server_connection:
|
||||
logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
|
||||
decoded_raw_message: dict = json.loads(raw_message)
|
||||
post_type = decoded_raw_message.get("post_type")
|
||||
if post_type in ["meta_event", "message", "notice"]:
|
||||
await message_queue.put(decoded_raw_message)
|
||||
elif post_type is None:
|
||||
await put_response(decoded_raw_message)
|
||||
|
||||
async def message_process(self):
|
||||
while True:
|
||||
message = await message_queue.get()
|
||||
post_type = message.get("post_type")
|
||||
if post_type == "message":
|
||||
await message_handler.handle_raw_message(message)
|
||||
elif post_type == "meta_event":
|
||||
await meta_event_handler.handle_meta_event(message)
|
||||
elif post_type == "notice":
|
||||
await notice_handler.handle_notice(message)
|
||||
else:
|
||||
logger.warning(f"未知的post_type: {post_type}")
|
||||
message_queue.task_done()
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
async def napcat_server(self):
|
||||
"""启动 Napcat WebSocket 连接(支持正向和反向连接)"""
|
||||
mode = global_config.napcat_server.mode
|
||||
logger.info(f"正在启动 adapter,连接模式: {mode}")
|
||||
|
||||
try:
|
||||
await websocket_manager.start_connection(self.message_recv)
|
||||
except Exception as e:
|
||||
logger.error(f"启动 WebSocket 连接失败: {e}")
|
||||
raise
|
||||
|
||||
async def execute(self, kwargs):
|
||||
# 执行功能配置迁移(如果需要)
|
||||
logger.info("检查功能配置迁移...")
|
||||
auto_migrate_features()
|
||||
|
||||
# 初始化功能管理器
|
||||
logger.info("正在初始化功能管理器...")
|
||||
features_manager.load_config()
|
||||
await features_manager.start_file_watcher(check_interval=2.0)
|
||||
logger.info("功能管理器初始化完成")
|
||||
logger.info("开始启动Napcat Adapter")
|
||||
message_send_instance.maibot_router = router
|
||||
# 创建单独的异步任务,防止阻塞主线程
|
||||
asyncio.create_task(self.napcat_server())
|
||||
asyncio.create_task(mmc_start_com())
|
||||
asyncio.create_task(self.message_process())
|
||||
asyncio.create_task(check_timeout_response())
|
||||
|
||||
@register_plugin
|
||||
class NapcatAdapterPlugin(BasePlugin):
|
||||
plugin_name = CONSTS.PLUGIN_NAME
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
python_dependencies: List[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {"plugin": "插件基本信息"}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="napcat_adapter_plugin", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
for e in event_types.NapcatEvent.ON_RECEIVED:
|
||||
event_manager.register_event(e ,allowed_triggers=[self.plugin_name])
|
||||
|
||||
def get_plugin_components(self):
|
||||
components = []
|
||||
components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler))
|
||||
return components
|
||||
47
plugins/napcat_adapter_plugin/pyproject.toml
Normal file
@@ -0,0 +1,47 @@
|
||||
[project]
|
||||
name = "MaiBotNapcatAdapter"
|
||||
version = "0.4.8"
|
||||
description = "A MaiBot adapter for Napcat"
|
||||
dependencies = [
|
||||
"ruff>=0.12.9",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
include = ["*.py"]
|
||||
|
||||
# 行长度设置
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff.lint]
|
||||
fixable = ["ALL"]
|
||||
unfixable = []
|
||||
|
||||
# 启用的规则
|
||||
select = [
|
||||
"E", # pycodestyle 错误
|
||||
"F", # pyflakes
|
||||
"B", # flake8-bugbear
|
||||
]
|
||||
|
||||
ignore = ["E711","E501"]
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
indent-style = "space"
|
||||
|
||||
|
||||
# 使用双引号表示字符串
|
||||
quote-style = "double"
|
||||
|
||||
# 尊重魔法尾随逗号
|
||||
# 例如:
|
||||
# items = [
|
||||
# "apple",
|
||||
# "banana",
|
||||
# "cherry",
|
||||
# ]
|
||||
skip-magic-trailing-comma = false
|
||||
|
||||
# 自动检测合适的换行符
|
||||
line-ending = "auto"
|
||||
31
plugins/napcat_adapter_plugin/src/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from enum import Enum
|
||||
import tomlkit
|
||||
import os
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
class CommandType(Enum):
|
||||
"""命令类型"""
|
||||
|
||||
GROUP_BAN = "set_group_ban" # 禁言用户
|
||||
GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
|
||||
GROUP_KICK = "set_group_kick" # 踢出群聊
|
||||
SEND_POKE = "send_poke" # 戳一戳
|
||||
DELETE_MSG = "delete_msg" # 撤回消息
|
||||
AI_VOICE_SEND = "send_group_ai_record" # 发送群AI语音
|
||||
SET_EMOJI_LIKE = "set_emoji_like" # 设置表情回应
|
||||
SEND_AT_MESSAGE = "send_at_message" # 艾特用户并发送消息
|
||||
SEND_LIKE = "send_like" # 点赞
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
pyproject_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "pyproject.toml"
|
||||
)
|
||||
toml_data = tomlkit.parse(open(pyproject_path, "r", encoding="utf-8").read())
|
||||
project_data = toml_data.get("project", {})
|
||||
version = project_data.get("version", "unknown")
|
||||
logger.info(f"版本\n\nMaiBot-Napcat-Adapter 版本: {version}\n喜欢的话点个star喵~\n")
|
||||
5
plugins/napcat_adapter_plugin/src/config/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .config import global_config
|
||||
|
||||
__all__ = [
|
||||
"global_config",
|
||||
]
|
||||
148
plugins/napcat_adapter_plugin/src/config/config.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
import tomlkit
|
||||
import shutil
|
||||
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from rich.traceback import install
|
||||
|
||||
from .config_base import ConfigBase
|
||||
from .official_configs import (
|
||||
DebugConfig,
|
||||
MaiBotServerConfig,
|
||||
NapcatServerConfig,
|
||||
NicknameConfig,
|
||||
VoiceConfig,
|
||||
)
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
TEMPLATE_DIR = "plugins/napcat_adapter_plugin/template"
|
||||
CONFIG_DIR = "plugins/napcat_adapter_plugin/config"
|
||||
OLD_CONFIG_DIR = "plugins/napcat_adapter_plugin/config/old"
|
||||
|
||||
|
||||
def ensure_config_directories():
|
||||
"""确保配置目录存在"""
|
||||
os.makedirs(CONFIG_DIR, exist_ok=True)
|
||||
os.makedirs(OLD_CONFIG_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def update_config():
|
||||
"""更新配置文件,统一使用 config/old 目录进行备份"""
|
||||
# 确保目录存在
|
||||
ensure_config_directories()
|
||||
|
||||
# 定义文件路径
|
||||
template_path = f"{TEMPLATE_DIR}/template_config.toml"
|
||||
config_path = f"{CONFIG_DIR}/config.toml"
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(config_path):
|
||||
logger.info("主配置文件不存在,从模板创建新配置")
|
||||
shutil.copy2(template_path, config_path)
|
||||
logger.info(f"已创建新配置文件: {config_path}")
|
||||
logger.info("程序将退出,请检查配置文件后重启")
|
||||
|
||||
# 读取配置文件和模板文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version")
|
||||
new_version = new_config["inner"].get("version")
|
||||
if old_version and new_version and old_version == new_version:
|
||||
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
else:
|
||||
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||
|
||||
# 创建备份文件
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = os.path.join(OLD_CONFIG_DIR, f"config.toml.bak.{timestamp}")
|
||||
|
||||
# 备份旧配置文件
|
||||
shutil.copy2(config_path, backup_path)
|
||||
logger.info(f"已备份旧配置文件到: {backup_path}")
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(template_path, config_path)
|
||||
logger.info(f"已创建新配置文件: {config_path}")
|
||||
|
||||
def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
|
||||
"""将source字典的值更新到target字典中(如果target中存在相同的键)"""
|
||||
for key, value in source.items():
|
||||
# 跳过version字段的更新
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
|
||||
update_dict(target[key], value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
if isinstance(value, list):
|
||||
# 如果是空数组,确保它保持为空数组
|
||||
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
|
||||
# 将旧配置的值更新到新配置中
|
||||
logger.info("开始合并新旧配置...")
|
||||
update_dict(new_config, old_config)
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(new_config))
|
||||
logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config(ConfigBase):
|
||||
"""总配置类"""
|
||||
|
||||
nickname: NicknameConfig
|
||||
napcat_server: NapcatServerConfig
|
||||
maibot_server: MaiBotServerConfig
|
||||
voice: VoiceConfig
|
||||
debug: DebugConfig
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Config:
|
||||
"""
|
||||
加载配置文件
|
||||
:param config_path: 配置文件路径
|
||||
:return: Config对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建Config对象
|
||||
try:
|
||||
return Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.critical("配置文件解析失败")
|
||||
raise e
|
||||
|
||||
|
||||
# 更新配置
|
||||
update_config()
|
||||
|
||||
logger.info("正在品鉴配置文件...")
|
||||
global_config = load_config(config_path=f"{CONFIG_DIR}/config.toml")
|
||||
logger.info("非常的新鲜,非常的美味!")
|
||||
136
plugins/napcat_adapter_plugin/src/config/config_base.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from dataclasses import dataclass, fields, MISSING
|
||||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union
|
||||
|
||||
T = TypeVar("T", bound="ConfigBase")
|
||||
|
||||
TOML_DICT_TYPE = {
|
||||
int,
|
||||
float,
|
||||
str,
|
||||
bool,
|
||||
list,
|
||||
dict,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigBase:
|
||||
"""配置类的基类"""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
|
||||
"""从字典加载配置字段"""
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
|
||||
|
||||
init_args: Dict[str, Any] = {}
|
||||
|
||||
for f in fields(cls):
|
||||
field_name = f.name
|
||||
field_type = f.type
|
||||
if field_name.startswith("_"):
|
||||
# 跳过以 _ 开头的字段
|
||||
continue
|
||||
|
||||
if field_name not in data:
|
||||
if f.default is not MISSING or f.default_factory is not MISSING:
|
||||
# 跳过未提供且有默认值/默认构造方法的字段
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Missing required field: '{field_name}'")
|
||||
|
||||
value = data[field_name]
|
||||
try:
|
||||
init_args[field_name] = cls._convert_field(value, field_type)
|
||||
except TypeError as e:
|
||||
raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e
|
||||
|
||||
return cls(**init_args)
|
||||
|
||||
@classmethod
|
||||
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
|
||||
"""
|
||||
转换字段值为指定类型
|
||||
|
||||
1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法
|
||||
2. 对于泛型集合类型(list, set, tuple),递归转换每个元素
|
||||
3. 对于基础类型(int, str, float, bool),直接转换
|
||||
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
|
||||
"""
|
||||
# 如果是嵌套的 dataclass,递归调用 from_dict 方法
|
||||
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
|
||||
return field_type.from_dict(value)
|
||||
|
||||
field_origin_type = get_origin(field_type)
|
||||
field_args_type = get_args(field_type)
|
||||
|
||||
# 处理泛型集合类型(list, set, tuple)
|
||||
if field_origin_type in {list, set, tuple}:
|
||||
# 检查提供的value是否为list
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
if field_origin_type is list:
|
||||
return [cls._convert_field(item, field_args_type[0]) for item in value]
|
||||
if field_origin_type is set:
|
||||
return {cls._convert_field(item, field_args_type[0]) for item in value}
|
||||
if field_origin_type is tuple:
|
||||
# 检查提供的value长度是否与类型参数一致
|
||||
if len(value) != len(field_args_type):
|
||||
raise TypeError(
|
||||
f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}"
|
||||
)
|
||||
return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type))
|
||||
|
||||
if field_origin_type is dict:
|
||||
# 检查提供的value是否为dict
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
# 检查字典的键值类型
|
||||
if len(field_args_type) != 2:
|
||||
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
|
||||
key_type, value_type = field_args_type
|
||||
|
||||
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
|
||||
|
||||
# 处理Optional类型
|
||||
if field_origin_type is Union: # assert get_origin(Optional[Any]) is Union
|
||||
if value is None:
|
||||
return None
|
||||
# 如果有数据,检查实际类型
|
||||
if type(value) not in field_args_type:
|
||||
raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}")
|
||||
return cls._convert_field(value, field_args_type[0])
|
||||
|
||||
# 处理int, str, float, bool等基础类型
|
||||
if field_origin_type is None:
|
||||
if isinstance(value, field_type):
|
||||
return field_type(value)
|
||||
else:
|
||||
raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
# 处理Literal类型
|
||||
if field_origin_type is Literal:
|
||||
# 获取Literal的允许值
|
||||
allowed_values = get_args(field_type)
|
||||
if value in allowed_values:
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
|
||||
|
||||
# 处理其他类型
|
||||
if field_type is Any:
|
||||
return value
|
||||
|
||||
# 其他类型直接转换
|
||||
try:
|
||||
return field_type(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e
|
||||
|
||||
def __str__(self):
|
||||
"""返回配置类的字符串表示"""
|
||||
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"
|
||||
146
plugins/napcat_adapter_plugin/src/config/config_utils.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
配置文件工具模块
|
||||
提供统一的配置文件生成和管理功能
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
def ensure_config_directories():
|
||||
"""确保配置目录存在"""
|
||||
os.makedirs("config", exist_ok=True)
|
||||
os.makedirs("config/old", exist_ok=True)
|
||||
|
||||
|
||||
def create_config_from_template(
|
||||
config_path: str,
|
||||
template_path: str,
|
||||
config_name: str = "配置文件",
|
||||
should_exit: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
从模板创建配置文件的统一函数
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
template_path: 模板文件路径
|
||||
config_name: 配置文件名称(用于日志显示)
|
||||
should_exit: 创建后是否退出程序
|
||||
|
||||
Returns:
|
||||
bool: 是否成功创建配置文件
|
||||
"""
|
||||
try:
|
||||
# 确保配置目录存在
|
||||
ensure_config_directories()
|
||||
|
||||
config_path_obj = Path(config_path)
|
||||
template_path_obj = Path(template_path)
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if config_path_obj.exists():
|
||||
return False # 配置文件已存在,无需创建
|
||||
|
||||
logger.info(f"{config_name}不存在,从模板创建新配置")
|
||||
|
||||
# 检查模板文件是否存在
|
||||
if not template_path_obj.exists():
|
||||
logger.error(f"模板文件不存在: {template_path}")
|
||||
if should_exit:
|
||||
logger.critical("无法创建配置文件,程序退出")
|
||||
quit(1)
|
||||
return False
|
||||
|
||||
# 确保配置文件目录存在
|
||||
config_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(template_path_obj, config_path_obj)
|
||||
logger.info(f"已创建新{config_name}: {config_path}")
|
||||
|
||||
if should_exit:
|
||||
logger.info("程序将退出,请检查配置文件后重启")
|
||||
quit(0)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建{config_name}失败: {e}")
|
||||
if should_exit:
|
||||
logger.critical("无法创建配置文件,程序退出")
|
||||
quit(1)
|
||||
return False
|
||||
|
||||
|
||||
def create_default_config_dict(default_values: dict, config_path: str, config_name: str = "配置文件") -> bool:
|
||||
"""
|
||||
创建默认配置文件(使用字典数据)
|
||||
|
||||
Args:
|
||||
default_values: 默认配置值字典
|
||||
config_path: 配置文件路径
|
||||
config_name: 配置文件名称(用于日志显示)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功创建配置文件
|
||||
"""
|
||||
try:
|
||||
import tomlkit
|
||||
|
||||
config_path_obj = Path(config_path)
|
||||
|
||||
# 确保配置文件目录存在
|
||||
config_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 写入默认配置
|
||||
with open(config_path_obj, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(default_values, f)
|
||||
|
||||
logger.info(f"已创建默认{config_name}: {config_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建默认{config_name}失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def backup_config_file(config_path: str, backup_dir: str = "config/old") -> Optional[str]:
|
||||
"""
|
||||
备份配置文件
|
||||
|
||||
Args:
|
||||
config_path: 要备份的配置文件路径
|
||||
backup_dir: 备份目录
|
||||
|
||||
Returns:
|
||||
Optional[str]: 备份文件路径,失败时返回None
|
||||
"""
|
||||
try:
|
||||
config_path_obj = Path(config_path)
|
||||
if not config_path_obj.exists():
|
||||
return None
|
||||
|
||||
# 确保备份目录存在
|
||||
backup_dir_obj = Path(backup_dir)
|
||||
backup_dir_obj.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建备份文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_filename = f"{config_path_obj.stem}.toml.bak.{timestamp}"
|
||||
backup_path = backup_dir_obj / backup_filename
|
||||
|
||||
# 备份文件
|
||||
shutil.copy2(config_path_obj, backup_path)
|
||||
logger.info(f"已备份配置文件到: {backup_path}")
|
||||
|
||||
return str(backup_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"备份配置文件失败: {e}")
|
||||
return None
|
||||
359
plugins/napcat_adapter_plugin/src/config/features_config.py
Normal file
@@ -0,0 +1,359 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
from pathlib import Path
|
||||
import tomlkit
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from .config_base import ConfigBase
|
||||
from .config_utils import create_config_from_template, create_default_config_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeaturesConfig(ConfigBase):
|
||||
"""功能配置类"""
|
||||
|
||||
group_list_type: Literal["whitelist", "blacklist"] = "whitelist"
|
||||
"""群聊列表类型 白名单/黑名单"""
|
||||
|
||||
group_list: list[int] = field(default_factory=list)
|
||||
"""群聊列表"""
|
||||
|
||||
private_list_type: Literal["whitelist", "blacklist"] = "whitelist"
|
||||
"""私聊列表类型 白名单/黑名单"""
|
||||
|
||||
private_list: list[int] = field(default_factory=list)
|
||||
"""私聊列表"""
|
||||
|
||||
ban_user_id: list[int] = field(default_factory=list)
|
||||
"""被封禁的用户ID列表,封禁后将无法与其进行交互"""
|
||||
|
||||
ban_qq_bot: bool = False
|
||||
"""是否屏蔽QQ官方机器人,若为True,则所有QQ官方机器人将无法与MaiMCore进行交互"""
|
||||
|
||||
enable_poke: bool = True
|
||||
"""是否启用戳一戳功能"""
|
||||
|
||||
ignore_non_self_poke: bool = False
|
||||
"""是否无视不是针对自己的戳一戳"""
|
||||
|
||||
poke_debounce_seconds: int = 3
|
||||
"""戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"""
|
||||
|
||||
enable_reply_at: bool = True
|
||||
"""是否启用引用回复时艾特用户的功能"""
|
||||
|
||||
reply_at_rate: float = 0.5
|
||||
"""引用回复时艾特用户的几率 (0.0 ~ 1.0)"""
|
||||
|
||||
enable_video_analysis: bool = True
|
||||
"""是否启用视频识别功能"""
|
||||
|
||||
max_video_size_mb: int = 100
|
||||
"""视频文件最大大小限制(MB)"""
|
||||
|
||||
download_timeout: int = 60
|
||||
"""视频下载超时时间(秒)"""
|
||||
|
||||
supported_formats: list[str] = field(default_factory=lambda: ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"])
|
||||
"""支持的视频格式"""
|
||||
|
||||
# 消息缓冲配置
|
||||
enable_message_buffer: bool = True
|
||||
"""是否启用消息缓冲合并功能"""
|
||||
|
||||
message_buffer_enable_group: bool = True
|
||||
"""是否启用群消息缓冲合并"""
|
||||
|
||||
message_buffer_enable_private: bool = True
|
||||
"""是否启用私聊消息缓冲合并"""
|
||||
|
||||
message_buffer_interval: float = 3.0
|
||||
"""消息合并间隔时间(秒),在此时间内的连续消息将被合并"""
|
||||
|
||||
message_buffer_initial_delay: float = 0.5
|
||||
"""消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"""
|
||||
|
||||
message_buffer_max_components: int = 50
|
||||
"""单个会话最大缓冲消息组件数量,超过此数量将强制合并"""
|
||||
|
||||
message_buffer_block_prefixes: list[str] = field(default_factory=lambda: ["/", "!", "!", ".", "。", "#", "%"])
|
||||
"""消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"""
|
||||
|
||||
|
||||
class FeaturesManager:
|
||||
"""功能管理器,支持热重载"""
|
||||
|
||||
def __init__(self, config_path: str = "plugins/napcat_adapter_plugin/config/features.toml"):
|
||||
self.config_path = Path(config_path)
|
||||
self.config: Optional[FeaturesConfig] = None
|
||||
self._file_watcher_task: Optional[asyncio.Task] = None
|
||||
self._last_modified: Optional[float] = None
|
||||
self._callbacks: list = []
|
||||
|
||||
def add_reload_callback(self, callback):
|
||||
"""添加配置重载回调函数"""
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def remove_reload_callback(self, callback):
|
||||
"""移除配置重载回调函数"""
|
||||
if callback in self._callbacks:
|
||||
self._callbacks.remove(callback)
|
||||
|
||||
async def _notify_callbacks(self):
|
||||
"""通知所有回调函数配置已重载"""
|
||||
for callback in self._callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(self.config)
|
||||
else:
|
||||
callback(self.config)
|
||||
except Exception as e:
|
||||
logger.error(f"配置重载回调执行失败: {e}")
|
||||
|
||||
def load_config(self) -> FeaturesConfig:
|
||||
"""加载功能配置文件"""
|
||||
try:
|
||||
# 检查配置文件是否存在,如果不存在则创建并退出程序
|
||||
if not self.config_path.exists():
|
||||
logger.info(f"功能配置文件不存在: {self.config_path}")
|
||||
self._create_default_config()
|
||||
# 配置文件创建后程序应该退出,让用户检查配置
|
||||
logger.info("程序将退出,请检查功能配置文件后重启")
|
||||
quit(0)
|
||||
|
||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
self.config = FeaturesConfig.from_dict(config_data)
|
||||
self._last_modified = self.config_path.stat().st_mtime
|
||||
logger.info(f"功能配置加载成功: {self.config_path}")
|
||||
return self.config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"功能配置加载失败: {e}")
|
||||
logger.critical("无法加载功能配置文件,程序退出")
|
||||
quit(1)
|
||||
|
||||
def _create_default_config(self):
|
||||
"""创建默认功能配置文件"""
|
||||
template_path = "template/features_template.toml"
|
||||
|
||||
# 尝试从模板创建配置文件
|
||||
if create_config_from_template(
|
||||
str(self.config_path),
|
||||
template_path,
|
||||
"功能配置文件",
|
||||
should_exit=False # 不在这里退出,由调用方决定
|
||||
):
|
||||
return
|
||||
|
||||
# 如果模板文件不存在,创建基本配置
|
||||
logger.info("模板文件不存在,创建基本功能配置")
|
||||
default_config = {
|
||||
"group_list_type": "whitelist",
|
||||
"group_list": [],
|
||||
"private_list_type": "whitelist",
|
||||
"private_list": [],
|
||||
"ban_user_id": [],
|
||||
"ban_qq_bot": False,
|
||||
"enable_poke": True,
|
||||
"ignore_non_self_poke": False,
|
||||
"poke_debounce_seconds": 3,
|
||||
"enable_reply_at": True,
|
||||
"reply_at_rate": 0.5,
|
||||
"enable_video_analysis": True,
|
||||
"max_video_size_mb": 100,
|
||||
"download_timeout": 60,
|
||||
"supported_formats": ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"],
|
||||
# 消息缓冲配置
|
||||
"enable_message_buffer": True,
|
||||
"message_buffer_enable_group": True,
|
||||
"message_buffer_enable_private": True,
|
||||
"message_buffer_interval": 3.0,
|
||||
"message_buffer_initial_delay": 0.5,
|
||||
"message_buffer_max_components": 50,
|
||||
"message_buffer_block_prefixes": ["/", "!", "!", ".", "。", "#", "%"]
|
||||
}
|
||||
|
||||
if not create_default_config_dict(default_config, str(self.config_path), "功能配置文件"):
|
||||
logger.critical("无法创建功能配置文件")
|
||||
quit(1)
|
||||
|
||||
async def reload_config(self) -> bool:
|
||||
"""重新加载配置文件"""
|
||||
try:
|
||||
if not self.config_path.exists():
|
||||
logger.warning(f"功能配置文件不存在,无法重载: {self.config_path}")
|
||||
return False
|
||||
|
||||
current_modified = self.config_path.stat().st_mtime
|
||||
if self._last_modified and current_modified <= self._last_modified:
|
||||
return False # 文件未修改
|
||||
|
||||
old_config = self.config
|
||||
new_config = self.load_config()
|
||||
|
||||
# 检查配置是否真的发生了变化
|
||||
if old_config and self._configs_equal(old_config, new_config):
|
||||
return False
|
||||
|
||||
logger.info("功能配置已重载")
|
||||
await self._notify_callbacks()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"功能配置重载失败: {e}")
|
||||
return False
|
||||
|
||||
def _configs_equal(self, config1: FeaturesConfig, config2: FeaturesConfig) -> bool:
|
||||
"""比较两个配置是否相等"""
|
||||
return (
|
||||
config1.group_list_type == config2.group_list_type and
|
||||
set(config1.group_list) == set(config2.group_list) and
|
||||
config1.private_list_type == config2.private_list_type and
|
||||
set(config1.private_list) == set(config2.private_list) and
|
||||
set(config1.ban_user_id) == set(config2.ban_user_id) and
|
||||
config1.ban_qq_bot == config2.ban_qq_bot and
|
||||
config1.enable_poke == config2.enable_poke and
|
||||
config1.ignore_non_self_poke == config2.ignore_non_self_poke and
|
||||
config1.poke_debounce_seconds == config2.poke_debounce_seconds and
|
||||
config1.enable_reply_at == config2.enable_reply_at and
|
||||
config1.reply_at_rate == config2.reply_at_rate and
|
||||
config1.enable_video_analysis == config2.enable_video_analysis and
|
||||
config1.max_video_size_mb == config2.max_video_size_mb and
|
||||
config1.download_timeout == config2.download_timeout and
|
||||
set(config1.supported_formats) == set(config2.supported_formats) and
|
||||
# 消息缓冲配置比较
|
||||
config1.enable_message_buffer == config2.enable_message_buffer and
|
||||
config1.message_buffer_enable_group == config2.message_buffer_enable_group and
|
||||
config1.message_buffer_enable_private == config2.message_buffer_enable_private and
|
||||
config1.message_buffer_interval == config2.message_buffer_interval and
|
||||
config1.message_buffer_initial_delay == config2.message_buffer_initial_delay and
|
||||
config1.message_buffer_max_components == config2.message_buffer_max_components and
|
||||
set(config1.message_buffer_block_prefixes) == set(config2.message_buffer_block_prefixes)
|
||||
)
|
||||
|
||||
async def start_file_watcher(self, check_interval: float = 1.0):
|
||||
"""启动文件监控,定期检查配置文件变化"""
|
||||
if self._file_watcher_task and not self._file_watcher_task.done():
|
||||
logger.warning("文件监控已在运行")
|
||||
return
|
||||
|
||||
self._file_watcher_task = asyncio.create_task(
|
||||
self._file_watcher_loop(check_interval)
|
||||
)
|
||||
logger.info(f"功能配置文件监控已启动,检查间隔: {check_interval}秒")
|
||||
|
||||
async def stop_file_watcher(self):
|
||||
"""停止文件监控"""
|
||||
if self._file_watcher_task and not self._file_watcher_task.done():
|
||||
self._file_watcher_task.cancel()
|
||||
try:
|
||||
await self._file_watcher_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("功能配置文件监控已停止")
|
||||
|
||||
async def _file_watcher_loop(self, check_interval: float):
|
||||
"""文件监控循环"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(check_interval)
|
||||
await self.reload_config()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"文件监控循环出错: {e}")
|
||||
await asyncio.sleep(check_interval)
|
||||
|
||||
def get_config(self) -> FeaturesConfig:
|
||||
"""获取当前功能配置"""
|
||||
if self.config is None:
|
||||
return self.load_config()
|
||||
return self.config
|
||||
|
||||
def is_group_allowed(self, group_id: int) -> bool:
|
||||
"""检查群聊是否被允许"""
|
||||
config = self.get_config()
|
||||
if config.group_list_type == "whitelist":
|
||||
return group_id in config.group_list
|
||||
else: # blacklist
|
||||
return group_id not in config.group_list
|
||||
|
||||
def is_private_allowed(self, user_id: int) -> bool:
|
||||
"""检查私聊是否被允许"""
|
||||
config = self.get_config()
|
||||
if config.private_list_type == "whitelist":
|
||||
return user_id in config.private_list
|
||||
else: # blacklist
|
||||
return user_id not in config.private_list
|
||||
|
||||
def is_user_banned(self, user_id: int) -> bool:
|
||||
"""检查用户是否被全局禁止"""
|
||||
config = self.get_config()
|
||||
return user_id in config.ban_user_id
|
||||
|
||||
def is_qq_bot_banned(self) -> bool:
|
||||
"""检查是否禁止QQ官方机器人"""
|
||||
config = self.get_config()
|
||||
return config.ban_qq_bot
|
||||
|
||||
def is_poke_enabled(self) -> bool:
|
||||
"""检查戳一戳功能是否启用"""
|
||||
config = self.get_config()
|
||||
return config.enable_poke
|
||||
|
||||
def is_non_self_poke_ignored(self) -> bool:
|
||||
"""检查是否忽略非自己戳一戳"""
|
||||
config = self.get_config()
|
||||
return config.ignore_non_self_poke
|
||||
|
||||
def is_message_buffer_enabled(self) -> bool:
|
||||
"""检查消息缓冲功能是否启用"""
|
||||
config = self.get_config()
|
||||
return config.enable_message_buffer
|
||||
|
||||
def is_message_buffer_group_enabled(self) -> bool:
|
||||
"""检查群消息缓冲是否启用"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_enable_group
|
||||
|
||||
def is_message_buffer_private_enabled(self) -> bool:
|
||||
"""检查私聊消息缓冲是否启用"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_enable_private
|
||||
|
||||
def get_message_buffer_interval(self) -> float:
|
||||
"""获取消息缓冲间隔时间"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_interval
|
||||
|
||||
def get_message_buffer_initial_delay(self) -> float:
|
||||
"""获取消息缓冲初始延迟"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_initial_delay
|
||||
|
||||
def get_message_buffer_max_components(self) -> int:
|
||||
"""获取消息缓冲最大组件数量"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_max_components
|
||||
|
||||
def is_message_buffer_group_enabled(self) -> bool:
|
||||
"""检查是否启用群聊消息缓冲"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_enable_group
|
||||
|
||||
def is_message_buffer_private_enabled(self) -> bool:
|
||||
"""检查是否启用私聊消息缓冲"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_enable_private
|
||||
|
||||
def get_message_buffer_block_prefixes(self) -> list[str]:
|
||||
"""获取消息缓冲屏蔽前缀列表"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_block_prefixes
|
||||
|
||||
|
||||
# 全局功能管理器实例
|
||||
features_manager = FeaturesManager()
|
||||
194
plugins/napcat_adapter_plugin/src/config/migrate_features.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
功能配置迁移脚本
|
||||
用于将旧的配置文件中的聊天、权限、视频处理等设置迁移到新的独立功能配置文件
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import tomlkit
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
def migrate_features_from_config(old_config_path: str = "plugins/napcat_adapter_plugin/config/config.toml",
|
||||
new_features_path: str = "plugins/napcat_adapter_plugin/config/features.toml",
|
||||
template_path: str = "plugins/napcat_adapter_plugin/template/features_template.toml"):
|
||||
"""
|
||||
从旧配置文件迁移功能设置到新的功能配置文件
|
||||
|
||||
Args:
|
||||
old_config_path: 旧配置文件路径
|
||||
new_features_path: 新功能配置文件路径
|
||||
template_path: 功能配置模板路径
|
||||
"""
|
||||
try:
|
||||
# 检查旧配置文件是否存在
|
||||
if not os.path.exists(old_config_path):
|
||||
logger.warning(f"旧配置文件不存在: {old_config_path}")
|
||||
return False
|
||||
|
||||
# 读取旧配置文件
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
|
||||
# 检查是否有chat配置段和video配置段
|
||||
chat_config = old_config.get("chat", {})
|
||||
video_config = old_config.get("video", {})
|
||||
|
||||
# 检查是否有权限相关配置
|
||||
permission_keys = ["group_list_type", "group_list", "private_list_type",
|
||||
"private_list", "ban_user_id", "ban_qq_bot",
|
||||
"enable_poke", "ignore_non_self_poke", "poke_debounce_seconds"]
|
||||
video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"]
|
||||
|
||||
has_permission_config = any(key in chat_config for key in permission_keys)
|
||||
has_video_config = any(key in video_config for key in video_keys)
|
||||
|
||||
if not has_permission_config and not has_video_config:
|
||||
logger.info("旧配置文件中没有找到功能相关配置,无需迁移")
|
||||
return False
|
||||
|
||||
# 确保新功能配置目录存在
|
||||
new_features_dir = Path(new_features_path).parent
|
||||
new_features_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 如果新功能配置文件已存在,先备份
|
||||
if os.path.exists(new_features_path):
|
||||
backup_path = f"{new_features_path}.backup"
|
||||
shutil.copy2(new_features_path, backup_path)
|
||||
logger.info(f"已备份现有功能配置文件到: {backup_path}")
|
||||
|
||||
# 创建新的功能配置
|
||||
new_features_config = {
|
||||
"group_list_type": chat_config.get("group_list_type", "whitelist"),
|
||||
"group_list": chat_config.get("group_list", []),
|
||||
"private_list_type": chat_config.get("private_list_type", "whitelist"),
|
||||
"private_list": chat_config.get("private_list", []),
|
||||
"ban_user_id": chat_config.get("ban_user_id", []),
|
||||
"ban_qq_bot": chat_config.get("ban_qq_bot", False),
|
||||
"enable_poke": chat_config.get("enable_poke", True),
|
||||
"ignore_non_self_poke": chat_config.get("ignore_non_self_poke", False),
|
||||
"poke_debounce_seconds": chat_config.get("poke_debounce_seconds", 3),
|
||||
"enable_video_analysis": video_config.get("enable_video_analysis", True),
|
||||
"max_video_size_mb": video_config.get("max_video_size_mb", 100),
|
||||
"download_timeout": video_config.get("download_timeout", 60),
|
||||
"supported_formats": video_config.get("supported_formats", ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"])
|
||||
}
|
||||
|
||||
# 写入新的功能配置文件
|
||||
with open(new_features_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(new_features_config, f)
|
||||
|
||||
logger.info(f"功能配置已成功迁移到: {new_features_path}")
|
||||
|
||||
# 显示迁移的配置内容
|
||||
logger.info("迁移的配置内容:")
|
||||
for key, value in new_features_config.items():
|
||||
logger.info(f" {key}: {value}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"功能配置迁移失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def remove_features_from_old_config(config_path: str = "plugins/napcat_adapter_plugin/config/config.toml"):
|
||||
"""
|
||||
从旧配置文件中移除功能相关配置,并将旧配置移动到 config/old/ 目录
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(config_path):
|
||||
logger.warning(f"配置文件不存在: {config_path}")
|
||||
return False
|
||||
|
||||
# 确保 config/old 目录存在
|
||||
old_config_dir = "plugins/napcat_adapter_plugin/config/old"
|
||||
os.makedirs(old_config_dir, exist_ok=True)
|
||||
|
||||
# 备份原配置文件到 config/old 目录
|
||||
old_config_path = os.path.join(old_config_dir, "config_with_features.toml")
|
||||
shutil.copy2(config_path, old_config_path)
|
||||
logger.info(f"已备份包含功能配置的原文件到: {old_config_path}")
|
||||
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
# 移除chat段中的功能相关配置
|
||||
removed_keys = []
|
||||
if "chat" in config:
|
||||
chat_config = config["chat"]
|
||||
permission_keys = ["group_list_type", "group_list", "private_list_type",
|
||||
"private_list", "ban_user_id", "ban_qq_bot",
|
||||
"enable_poke", "ignore_non_self_poke", "poke_debounce_seconds"]
|
||||
|
||||
for key in permission_keys:
|
||||
if key in chat_config:
|
||||
del chat_config[key]
|
||||
removed_keys.append(key)
|
||||
|
||||
if removed_keys:
|
||||
logger.info(f"已从chat配置段中移除功能相关配置: {removed_keys}")
|
||||
|
||||
# 移除video段中的配置
|
||||
if "video" in config:
|
||||
video_config = config["video"]
|
||||
video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"]
|
||||
|
||||
video_removed_keys = []
|
||||
for key in video_keys:
|
||||
if key in video_config:
|
||||
del video_config[key]
|
||||
video_removed_keys.append(key)
|
||||
|
||||
if video_removed_keys:
|
||||
logger.info(f"已从video配置段中移除配置: {video_removed_keys}")
|
||||
removed_keys.extend(video_removed_keys)
|
||||
|
||||
# 如果video段为空,则删除整个段
|
||||
if not video_config:
|
||||
del config["video"]
|
||||
logger.info("已删除空的video配置段")
|
||||
|
||||
if removed_keys:
|
||||
logger.info(f"总共移除的配置项: {removed_keys}")
|
||||
|
||||
# 写回配置文件
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(config))
|
||||
|
||||
logger.info(f"已更新配置文件: {config_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"移除功能配置失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def auto_migrate_features():
|
||||
"""
|
||||
自动执行功能配置迁移
|
||||
"""
|
||||
logger.info("开始自动功能配置迁移...")
|
||||
|
||||
# 执行迁移
|
||||
if migrate_features_from_config():
|
||||
logger.info("功能配置迁移成功")
|
||||
|
||||
# 询问是否要从旧配置文件中移除功能配置
|
||||
logger.info("功能配置已迁移到独立文件,建议从主配置文件中移除相关配置")
|
||||
# 在实际使用中,这里可以添加用户确认逻辑
|
||||
# 为了自动化,这里直接执行移除
|
||||
remove_features_from_old_config()
|
||||
|
||||
else:
|
||||
logger.info("功能配置迁移跳过或失败")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
auto_migrate_features()
|
||||
67
plugins/napcat_adapter_plugin/src/config/official_configs.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
from .config_base import ConfigBase
|
||||
|
||||
"""
|
||||
须知:
|
||||
1. 本文件中记录了所有的配置项
|
||||
2. 所有新增的class都需要继承自ConfigBase
|
||||
3. 所有新增的class都应在config.py中的Config类中添加字段
|
||||
4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default
|
||||
"""
|
||||
|
||||
ADAPTER_PLATFORM = "qq"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NicknameConfig(ConfigBase):
|
||||
nickname: str
|
||||
"""机器人昵称"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class NapcatServerConfig(ConfigBase):
|
||||
mode: Literal["reverse", "forward"] = "reverse"
|
||||
"""连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)"""
|
||||
|
||||
host: str = "localhost"
|
||||
"""主机地址"""
|
||||
|
||||
port: int = 8095
|
||||
"""端口号"""
|
||||
|
||||
url: str = ""
|
||||
"""正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws"""
|
||||
|
||||
access_token: str = ""
|
||||
"""WebSocket 连接的访问令牌,用于身份验证"""
|
||||
|
||||
heartbeat_interval: int = 30
|
||||
"""心跳间隔时间,单位为秒"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaiBotServerConfig(ConfigBase):
|
||||
platform_name: str = field(default=ADAPTER_PLATFORM, init=False)
|
||||
"""平台名称,“qq”"""
|
||||
|
||||
host: str = "localhost"
|
||||
"""MaiMCore的主机地址"""
|
||||
|
||||
port: int = 8000
|
||||
"""MaiMCore的端口号"""
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
use_tts: bool = False
|
||||
"""是否启用TTS功能"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugConfig(ConfigBase):
|
||||
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
||||
"""日志级别,默认为INFO"""
|
||||
163
plugins/napcat_adapter_plugin/src/database.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import os
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
"""
|
||||
表记录的方式:
|
||||
| group_id | user_id | lift_time |
|
||||
|----------|---------|-----------|
|
||||
|
||||
其中使用 user_id == 0 表示群全体禁言
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BanUser:
|
||||
"""
|
||||
程序处理使用的实例
|
||||
"""
|
||||
|
||||
user_id: int
|
||||
group_id: int
|
||||
lift_time: Optional[int] = Field(default=-1)
|
||||
|
||||
|
||||
class DB_BanUser(SQLModel, table=True):
|
||||
"""
|
||||
表示数据库中的用户禁言记录。
|
||||
使用双重主键
|
||||
"""
|
||||
|
||||
user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID
|
||||
group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID
|
||||
lift_time: Optional[int] # 禁言解除的时间(时间戳)
|
||||
|
||||
|
||||
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
|
||||
"""
|
||||
检查两个 BanUser 对象是否相同。
|
||||
"""
|
||||
return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""
|
||||
数据库管理类,负责与数据库交互。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在
|
||||
DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db")
|
||||
self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL
|
||||
self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎
|
||||
self._ensure_database() # 确保数据库和表已创建
|
||||
|
||||
def _ensure_database(self) -> None:
|
||||
"""
|
||||
确保数据库和表已创建。
|
||||
"""
|
||||
logger.info("确保数据库文件和表已创建...")
|
||||
SQLModel.metadata.create_all(self.engine)
|
||||
logger.info("数据库和表已创建或已存在")
|
||||
|
||||
def update_ban_record(self, ban_list: List[BanUser]) -> None:
|
||||
# sourcery skip: class-extract-method
|
||||
"""
|
||||
更新禁言列表到数据库。
|
||||
支持在不存在时创建新记录,对于多余的项目自动删除。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
all_records = session.exec(select(DB_BanUser)).all()
|
||||
for ban_user in ban_list:
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
|
||||
)
|
||||
if existing_record := session.exec(statement).first():
|
||||
if existing_record.lift_time == ban_user.lift_time:
|
||||
logger.debug(f"禁言记录未变更: {existing_record}")
|
||||
continue
|
||||
# 更新现有记录的 lift_time
|
||||
existing_record.lift_time = ban_user.lift_time
|
||||
session.add(existing_record)
|
||||
logger.debug(f"更新禁言记录: {existing_record}")
|
||||
else:
|
||||
# 创建新记录
|
||||
db_record = DB_BanUser(
|
||||
user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
|
||||
)
|
||||
session.add(db_record)
|
||||
logger.debug(f"创建新禁言记录: {ban_user}")
|
||||
# 删除不在 ban_list 中的记录
|
||||
for db_record in all_records:
|
||||
record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
|
||||
if not any(is_identical(record, ban_user) for ban_user in ban_list):
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
|
||||
)
|
||||
if ban_record := session.exec(statement).first():
|
||||
session.delete(ban_record)
|
||||
|
||||
logger.debug(f"删除禁言记录: {ban_record}")
|
||||
else:
|
||||
logger.info(f"未找到禁言记录: {ban_record}")
|
||||
|
||||
|
||||
logger.info("禁言记录已更新")
|
||||
|
||||
def get_ban_records(self) -> List[BanUser]:
|
||||
"""
|
||||
读取所有禁言记录。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
statement = select(DB_BanUser)
|
||||
records = session.exec(statement).all()
|
||||
return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
|
||||
|
||||
def create_ban_record(self, ban_record: BanUser) -> None:
|
||||
"""
|
||||
为特定群组中的用户创建禁言记录。
|
||||
一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。
|
||||
其同时还是简化版的更新方式。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
# 检查记录是否已存在
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id
|
||||
)
|
||||
existing_record = session.exec(statement).first()
|
||||
if existing_record:
|
||||
# 如果记录已存在,更新 lift_time
|
||||
existing_record.lift_time = ban_record.lift_time
|
||||
session.add(existing_record)
|
||||
logger.debug(f"更新禁言记录: {ban_record}")
|
||||
else:
|
||||
# 如果记录不存在,创建新记录
|
||||
db_record = DB_BanUser(
|
||||
user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
|
||||
)
|
||||
session.add(db_record)
|
||||
logger.debug(f"创建新禁言记录: {ban_record}")
|
||||
|
||||
|
||||
def delete_ban_record(self, ban_record: BanUser):
|
||||
"""
|
||||
删除特定用户在特定群组中的禁言记录。
|
||||
一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。
|
||||
"""
|
||||
user_id = ban_record.user_id
|
||||
group_id = ban_record.group_id
|
||||
with Session(self.engine) as session:
|
||||
statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
|
||||
if ban_record := session.exec(statement).first():
|
||||
session.delete(ban_record)
|
||||
|
||||
logger.debug(f"删除禁言记录: {ban_record}")
|
||||
else:
|
||||
logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}")
|
||||
|
||||
|
||||
db_manager = DatabaseManager()
|
||||
320
plugins/napcat_adapter_plugin/src/message_buffer.py
Normal file
@@ -0,0 +1,320 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
from .config.features_config import features_manager
|
||||
from .recv_handler import RealMessageType
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextMessage:
|
||||
"""文本消息"""
|
||||
text: str
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BufferedSession:
|
||||
"""缓冲会话数据"""
|
||||
session_id: str
|
||||
messages: List[TextMessage] = field(default_factory=list)
|
||||
timer_task: Optional[asyncio.Task] = None
|
||||
delay_task: Optional[asyncio.Task] = None
|
||||
original_event: Any = None
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class SimpleMessageBuffer:
|
||||
|
||||
def __init__(self, merge_callback=None):
|
||||
"""
|
||||
初始化消息缓冲器
|
||||
|
||||
Args:
|
||||
merge_callback: 消息合并后的回调函数,接收(session_id, merged_text, original_event)参数
|
||||
"""
|
||||
self.buffer_pool: Dict[str, BufferedSession] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
self.merge_callback = merge_callback
|
||||
self._shutdown = False
|
||||
|
||||
def get_session_id(self, event_data: Dict[str, Any]) -> str:
|
||||
"""根据事件数据生成会话ID"""
|
||||
message_type = event_data.get("message_type", "unknown")
|
||||
user_id = event_data.get("user_id", "unknown")
|
||||
|
||||
if message_type == "private":
|
||||
return f"private_{user_id}"
|
||||
elif message_type == "group":
|
||||
group_id = event_data.get("group_id", "unknown")
|
||||
return f"group_{group_id}_{user_id}"
|
||||
else:
|
||||
return f"{message_type}_{user_id}"
|
||||
|
||||
def extract_text_from_message(self, message: List[Dict[str, Any]]) -> Optional[str]:
|
||||
"""从OneBot消息中提取纯文本,如果包含非文本内容则返回None"""
|
||||
text_parts = []
|
||||
has_non_text = False
|
||||
|
||||
logger.debug(f"正在提取消息文本,消息段数量: {len(message)}")
|
||||
|
||||
for msg_seg in message:
|
||||
msg_type = msg_seg.get("type", "")
|
||||
logger.debug(f"处理消息段类型: {msg_type}")
|
||||
|
||||
if msg_type == RealMessageType.text:
|
||||
text = msg_seg.get("data", {}).get("text", "").strip()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
logger.debug(f"提取到文本: {text[:50]}...")
|
||||
else:
|
||||
# 发现非文本消息段,标记为包含非文本内容
|
||||
has_non_text = True
|
||||
logger.debug(f"发现非文本消息段: {msg_type},跳过缓冲")
|
||||
|
||||
# 如果包含非文本内容,则不进行缓冲
|
||||
if has_non_text:
|
||||
logger.debug("消息包含非文本内容,不进行缓冲")
|
||||
return None
|
||||
|
||||
if text_parts:
|
||||
combined_text = " ".join(text_parts).strip()
|
||||
logger.debug(f"成功提取纯文本: {combined_text[:50]}...")
|
||||
return combined_text
|
||||
|
||||
logger.debug("没有找到有效的文本内容")
|
||||
return None
|
||||
|
||||
def should_skip_message(self, text: str) -> bool:
|
||||
"""判断消息是否应该跳过缓冲"""
|
||||
if not text or not text.strip():
|
||||
return True
|
||||
|
||||
# 检查屏蔽前缀
|
||||
config = features_manager.get_config()
|
||||
block_prefixes = tuple(config.message_buffer_block_prefixes)
|
||||
|
||||
text = text.strip()
|
||||
if text.startswith(block_prefixes):
|
||||
logger.debug(f"消息以屏蔽前缀开头,跳过缓冲: {text[:20]}...")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def add_text_message(self, event_data: Dict[str, Any], message: List[Dict[str, Any]],
|
||||
original_event: Any = None) -> bool:
|
||||
"""
|
||||
添加文本消息到缓冲区
|
||||
|
||||
Args:
|
||||
event_data: 事件数据
|
||||
message: OneBot消息数组
|
||||
original_event: 原始事件对象
|
||||
|
||||
Returns:
|
||||
是否成功添加到缓冲区
|
||||
"""
|
||||
if self._shutdown:
|
||||
return False
|
||||
|
||||
config = features_manager.get_config()
|
||||
if not config.enable_message_buffer:
|
||||
return False
|
||||
|
||||
# 检查是否启用对应类型的缓冲
|
||||
message_type = event_data.get("message_type", "")
|
||||
if message_type == "group" and not config.message_buffer_enable_group:
|
||||
return False
|
||||
elif message_type == "private" and not config.message_buffer_enable_private:
|
||||
return False
|
||||
|
||||
# 提取文本
|
||||
text = self.extract_text_from_message(message)
|
||||
if not text:
|
||||
return False
|
||||
|
||||
# 检查是否应该跳过
|
||||
if self.should_skip_message(text):
|
||||
return False
|
||||
|
||||
session_id = self.get_session_id(event_data)
|
||||
|
||||
async with self.lock:
|
||||
# 获取或创建会话
|
||||
if session_id not in self.buffer_pool:
|
||||
self.buffer_pool[session_id] = BufferedSession(
|
||||
session_id=session_id,
|
||||
original_event=original_event
|
||||
)
|
||||
|
||||
session = self.buffer_pool[session_id]
|
||||
|
||||
# 检查是否超过最大组件数量
|
||||
if len(session.messages) >= config.message_buffer_max_components:
|
||||
logger.info(f"会话 {session_id} 消息数量达到上限,强制合并")
|
||||
asyncio.create_task(self._force_merge_session(session_id))
|
||||
self.buffer_pool[session_id] = BufferedSession(
|
||||
session_id=session_id,
|
||||
original_event=original_event
|
||||
)
|
||||
session = self.buffer_pool[session_id]
|
||||
|
||||
# 添加文本消息
|
||||
session.messages.append(TextMessage(text=text))
|
||||
session.original_event = original_event # 更新事件
|
||||
|
||||
# 取消之前的定时器
|
||||
await self._cancel_session_timers(session)
|
||||
|
||||
# 设置新的延迟任务
|
||||
session.delay_task = asyncio.create_task(
|
||||
self._wait_and_start_merge(session_id)
|
||||
)
|
||||
|
||||
logger.debug(f"文本消息已添加到缓冲器 {session_id}: {text[:50]}...")
|
||||
return True
|
||||
|
||||
async def _cancel_session_timers(self, session: BufferedSession):
|
||||
"""取消会话的所有定时器"""
|
||||
for task_name in ['timer_task', 'delay_task']:
|
||||
task = getattr(session, task_name)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
setattr(session, task_name, None)
|
||||
|
||||
async def _wait_and_start_merge(self, session_id: str):
|
||||
"""等待初始延迟后开始合并定时器"""
|
||||
config = features_manager.get_config()
|
||||
await asyncio.sleep(config.message_buffer_initial_delay)
|
||||
|
||||
async with self.lock:
|
||||
session = self.buffer_pool.get(session_id)
|
||||
if session and session.messages:
|
||||
# 取消旧的定时器
|
||||
if session.timer_task and not session.timer_task.done():
|
||||
session.timer_task.cancel()
|
||||
try:
|
||||
await session.timer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 设置合并定时器
|
||||
session.timer_task = asyncio.create_task(
|
||||
self._wait_and_merge(session_id)
|
||||
)
|
||||
|
||||
async def _wait_and_merge(self, session_id: str):
|
||||
"""等待合并间隔后执行合并"""
|
||||
config = features_manager.get_config()
|
||||
await asyncio.sleep(config.message_buffer_interval)
|
||||
await self._merge_session(session_id)
|
||||
|
||||
async def _force_merge_session(self, session_id: str):
|
||||
"""强制合并会话(不等待定时器)"""
|
||||
await self._merge_session(session_id, force=True)
|
||||
|
||||
async def _merge_session(self, session_id: str, force: bool = False):
|
||||
"""合并会话中的消息"""
|
||||
async with self.lock:
|
||||
session = self.buffer_pool.get(session_id)
|
||||
if not session or not session.messages:
|
||||
self.buffer_pool.pop(session_id, None)
|
||||
return
|
||||
|
||||
try:
|
||||
# 合并文本消息
|
||||
text_parts = []
|
||||
for msg in session.messages:
|
||||
if msg.text.strip():
|
||||
text_parts.append(msg.text.strip())
|
||||
|
||||
if not text_parts:
|
||||
self.buffer_pool.pop(session_id, None)
|
||||
return
|
||||
|
||||
merged_text = ",".join(text_parts) # 使用中文逗号连接
|
||||
message_count = len(session.messages)
|
||||
|
||||
logger.info(f"合并会话 {session_id} 的 {message_count} 条文本消息: {merged_text[:100]}...")
|
||||
|
||||
# 调用回调函数
|
||||
if self.merge_callback:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(self.merge_callback):
|
||||
await self.merge_callback(session_id, merged_text, session.original_event)
|
||||
else:
|
||||
self.merge_callback(session_id, merged_text, session.original_event)
|
||||
except Exception as e:
|
||||
logger.error(f"消息合并回调执行失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"合并会话 {session_id} 时出错: {e}")
|
||||
finally:
|
||||
# 清理会话
|
||||
await self._cancel_session_timers(session)
|
||||
self.buffer_pool.pop(session_id, None)
|
||||
|
||||
async def flush_session(self, session_id: str):
|
||||
"""强制刷新指定会话的缓冲区"""
|
||||
await self._force_merge_session(session_id)
|
||||
|
||||
async def flush_all(self):
|
||||
"""强制刷新所有会话的缓冲区"""
|
||||
session_ids = list(self.buffer_pool.keys())
|
||||
for session_id in session_ids:
|
||||
await self._force_merge_session(session_id)
|
||||
|
||||
async def get_buffer_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓冲区统计信息"""
|
||||
async with self.lock:
|
||||
stats = {
|
||||
"total_sessions": len(self.buffer_pool),
|
||||
"sessions": {}
|
||||
}
|
||||
|
||||
for session_id, session in self.buffer_pool.items():
|
||||
stats["sessions"][session_id] = {
|
||||
"message_count": len(session.messages),
|
||||
"created_at": session.created_at,
|
||||
"age": time.time() - session.created_at
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
async def clear_expired_sessions(self, max_age: float = 300.0):
|
||||
"""清理过期的会话"""
|
||||
current_time = time.time()
|
||||
expired_sessions = []
|
||||
|
||||
async with self.lock:
|
||||
for session_id, session in self.buffer_pool.items():
|
||||
if current_time - session.created_at > max_age:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
logger.info(f"清理过期会话: {session_id}")
|
||||
await self._force_merge_session(session_id)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭消息缓冲器"""
|
||||
self._shutdown = True
|
||||
logger.info("正在关闭简化消息缓冲器...")
|
||||
|
||||
# 刷新所有缓冲区
|
||||
await self.flush_all()
|
||||
|
||||
# 确保所有任务都被取消
|
||||
async with self.lock:
|
||||
for session in list(self.buffer_pool.values()):
|
||||
await self._cancel_session_timers(session)
|
||||
self.buffer_pool.clear()
|
||||
|
||||
logger.info("简化消息缓冲器已关闭")
|
||||
26
plugins/napcat_adapter_plugin/src/mmc_com_layer.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from maim_message import Router, RouteConfig, TargetConfig
|
||||
from .config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from .send_handler import send_handler
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
route_config = RouteConfig(
|
||||
route_config={
|
||||
global_config.maibot_server.platform_name: TargetConfig(
|
||||
url=f"ws://{global_config.maibot_server.host}:{global_config.maibot_server.port}/ws",
|
||||
token=None,
|
||||
)
|
||||
}
|
||||
)
|
||||
router = Router(route_config)
|
||||
|
||||
|
||||
async def mmc_start_com():
|
||||
logger.info("正在连接MaiBot")
|
||||
router.register_class_handler(send_handler.handle_message)
|
||||
await router.run()
|
||||
|
||||
|
||||
async def mmc_stop_com():
|
||||
await router.stop()
|
||||
89
plugins/napcat_adapter_plugin/src/recv_handler/__init__.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MetaEventType:
|
||||
lifecycle = "lifecycle" # 生命周期
|
||||
|
||||
class Lifecycle:
|
||||
connect = "connect" # 生命周期 - WebSocket 连接成功
|
||||
|
||||
heartbeat = "heartbeat" # 心跳
|
||||
|
||||
|
||||
class MessageType: # 接受消息大类
|
||||
private = "private" # 私聊消息
|
||||
|
||||
class Private:
|
||||
friend = "friend" # 私聊消息 - 好友
|
||||
group = "group" # 私聊消息 - 群临时
|
||||
group_self = "group_self" # 私聊消息 - 群中自身发送
|
||||
other = "other" # 私聊消息 - 其他
|
||||
|
||||
group = "group" # 群聊消息
|
||||
|
||||
class Group:
|
||||
normal = "normal" # 群聊消息 - 普通
|
||||
anonymous = "anonymous" # 群聊消息 - 匿名消息
|
||||
notice = "notice" # 群聊消息 - 系统提示
|
||||
|
||||
|
||||
class NoticeType: # 通知事件
|
||||
friend_recall = "friend_recall" # 私聊消息撤回
|
||||
group_recall = "group_recall" # 群聊消息撤回
|
||||
notify = "notify"
|
||||
group_ban = "group_ban" # 群禁言
|
||||
|
||||
class Notify:
|
||||
poke = "poke" # 戳一戳
|
||||
input_status = "input_status" # 正在输入
|
||||
|
||||
class GroupBan:
|
||||
ban = "ban" # 禁言
|
||||
lift_ban = "lift_ban" # 解除禁言
|
||||
|
||||
|
||||
class RealMessageType: # 实际消息分类
|
||||
text = "text" # 纯文本
|
||||
face = "face" # qq表情
|
||||
image = "image" # 图片
|
||||
record = "record" # 语音
|
||||
video = "video" # 视频
|
||||
at = "at" # @某人
|
||||
rps = "rps" # 猜拳魔法表情
|
||||
dice = "dice" # 骰子
|
||||
shake = "shake" # 私聊窗口抖动(只收)
|
||||
poke = "poke" # 群聊戳一戳
|
||||
share = "share" # 链接分享(json形式)
|
||||
reply = "reply" # 回复消息
|
||||
forward = "forward" # 转发消息
|
||||
node = "node" # 转发消息节点
|
||||
json = "json" # json消息
|
||||
|
||||
|
||||
class MessageSentType:
|
||||
private = "private"
|
||||
|
||||
class Private:
|
||||
friend = "friend"
|
||||
group = "group"
|
||||
|
||||
group = "group"
|
||||
|
||||
class Group:
|
||||
normal = "normal"
|
||||
|
||||
|
||||
class CommandType(Enum):
|
||||
"""命令类型"""
|
||||
|
||||
GROUP_BAN = "set_group_ban" # 禁言用户
|
||||
GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
|
||||
GROUP_KICK = "set_group_kick" # 踢出群聊
|
||||
SEND_POKE = "send_poke" # 戳一戳
|
||||
DELETE_MSG = "delete_msg" # 撤回消息
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
ACCEPT_FORMAT = ["text", "image", "emoji", "reply", "voice", "command", "voiceurl", "music", "videourl", "file"]
|
||||
@@ -0,0 +1,942 @@
|
||||
from ...event_types import NapcatEvent
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.common.logger import get_logger
|
||||
from ...CONSTS import PLUGIN_NAME
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
from ..config import global_config
|
||||
from ..config.features_config import features_manager
|
||||
from ..message_buffer import SimpleMessageBuffer
|
||||
from ..utils import (
|
||||
get_group_info,
|
||||
get_member_info,
|
||||
get_image_base64,
|
||||
get_record_detail,
|
||||
get_self_info,
|
||||
get_message_detail,
|
||||
)
|
||||
from .qq_emoji_list import qq_face
|
||||
from .message_sending import message_send_instance
|
||||
from . import RealMessageType, MessageType, ACCEPT_FORMAT
|
||||
from ..video_handler import get_video_downloader
|
||||
from ..websocket_manager import websocket_manager
|
||||
|
||||
import time
|
||||
import json
|
||||
import websockets as Server
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
import uuid
|
||||
|
||||
from maim_message import (
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
Seg,
|
||||
BaseMessageInfo,
|
||||
MessageBase,
|
||||
TemplateInfo,
|
||||
FormatInfo,
|
||||
)
|
||||
|
||||
|
||||
from ..response_pool import get_response
|
||||
|
||||
|
||||
class MessageHandler:
|
||||
def __init__(self):
|
||||
self.server_connection: Server.ServerConnection = None
|
||||
self.bot_id_list: Dict[int, bool] = {}
|
||||
# 初始化简化消息缓冲器,传入回调函数
|
||||
self.message_buffer = SimpleMessageBuffer(merge_callback=self._send_buffered_message)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭消息处理器,清理资源"""
|
||||
if self.message_buffer:
|
||||
await self.message_buffer.shutdown()
|
||||
|
||||
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
"""设置Napcat连接"""
|
||||
self.server_connection = server_connection
|
||||
|
||||
def get_server_connection(self) -> Server.ServerConnection:
|
||||
"""获取当前的服务器连接"""
|
||||
# 优先使用直接设置的连接,否则从 websocket_manager 获取
|
||||
if self.server_connection:
|
||||
return self.server_connection
|
||||
return websocket_manager.get_connection()
|
||||
|
||||
async def check_allow_to_chat(
|
||||
self,
|
||||
user_id: int,
|
||||
group_id: Optional[int] = None,
|
||||
ignore_bot: Optional[bool] = False,
|
||||
ignore_global_list: Optional[bool] = False,
|
||||
) -> bool:
|
||||
# sourcery skip: hoist-statement-from-if, merge-else-if-into-elif
|
||||
"""
|
||||
检查是否允许聊天
|
||||
Parameters:
|
||||
user_id: int: 用户ID
|
||||
group_id: int: 群ID
|
||||
ignore_bot: bool: 是否忽略机器人检查
|
||||
ignore_global_list: bool: 是否忽略全局黑名单检查
|
||||
Returns:
|
||||
bool: 是否允许聊天
|
||||
"""
|
||||
logger.debug(f"群聊id: {group_id}, 用户id: {user_id}")
|
||||
logger.debug("开始检查聊天白名单/黑名单")
|
||||
|
||||
# 使用新的权限管理器检查权限
|
||||
if group_id:
|
||||
if not features_manager.is_group_allowed(group_id):
|
||||
logger.warning("群聊不在聊天权限范围内,消息被丢弃")
|
||||
return False
|
||||
else:
|
||||
if not features_manager.is_private_allowed(user_id):
|
||||
logger.warning("私聊不在聊天权限范围内,消息被丢弃")
|
||||
return False
|
||||
|
||||
# 检查全局禁止名单
|
||||
if not ignore_global_list and features_manager.is_user_banned(user_id):
|
||||
logger.warning("用户在全局黑名单中,消息被丢弃")
|
||||
return False
|
||||
|
||||
# 检查QQ官方机器人
|
||||
if features_manager.is_qq_bot_banned() and group_id and not ignore_bot:
|
||||
logger.debug("开始判断是否为机器人")
|
||||
member_info = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
if member_info:
|
||||
is_bot = member_info.get("is_robot")
|
||||
if is_bot is None:
|
||||
logger.warning("无法获取用户是否为机器人,默认为不是但是不进行更新")
|
||||
else:
|
||||
if is_bot:
|
||||
logger.warning("QQ官方机器人消息拦截已启用,消息被丢弃,新机器人加入拦截名单")
|
||||
self.bot_id_list[user_id] = True
|
||||
return False
|
||||
else:
|
||||
self.bot_id_list[user_id] = False
|
||||
|
||||
return True
|
||||
|
||||
async def handle_raw_message(self, raw_message: dict) -> None:
|
||||
# sourcery skip: low-code-quality, remove-unreachable-code
|
||||
"""
|
||||
从Napcat接受的原始消息处理
|
||||
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
"""
|
||||
message_type: str = raw_message.get("message_type")
|
||||
message_id: int = raw_message.get("message_id")
|
||||
# message_time: int = raw_message.get("time")
|
||||
message_time: float = time.time() # 应可乐要求,现在是float了
|
||||
|
||||
template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用
|
||||
format_info: FormatInfo = FormatInfo(
|
||||
content_format=["text", "image", "emoji", "voice"],
|
||||
accept_format=ACCEPT_FORMAT,
|
||||
) # 格式化信息
|
||||
if message_type == MessageType.private:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
if sub_type == MessageType.Private.friend:
|
||||
sender_info: dict = raw_message.get("sender")
|
||||
|
||||
if not await self.check_allow_to_chat(sender_info.get("user_id"), None):
|
||||
return None
|
||||
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=sender_info.get("nickname"),
|
||||
user_cardname=sender_info.get("card"),
|
||||
)
|
||||
|
||||
# 不存在群信息
|
||||
group_info: GroupInfo = None
|
||||
elif sub_type == MessageType.Private.group:
|
||||
"""
|
||||
本部分暂时不做支持,先放着
|
||||
"""
|
||||
logger.warning("群临时消息类型不支持")
|
||||
return None
|
||||
|
||||
sender_info: dict = raw_message.get("sender")
|
||||
|
||||
# 由于临时会话中,Napcat默认不发送成员昵称,所以需要单独获取
|
||||
fetched_member_info: dict = await get_member_info(
|
||||
self.get_server_connection(),
|
||||
raw_message.get("group_id"),
|
||||
sender_info.get("user_id"),
|
||||
)
|
||||
nickname = fetched_member_info.get("nickname") if fetched_member_info else None
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=nickname,
|
||||
user_cardname=None,
|
||||
)
|
||||
|
||||
# -------------------这里需要群信息吗?-------------------
|
||||
|
||||
# 获取群聊相关信息,在此单独处理group_name,因为默认发送的消息中没有
|
||||
fetched_group_info: dict = await get_group_info(self.get_server_connection(), raw_message.get("group_id"))
|
||||
group_name = ""
|
||||
if fetched_group_info.get("group_name"):
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
|
||||
group_info: GroupInfo = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=raw_message.get("group_id"),
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(f"私聊消息类型 {sub_type} 不支持")
|
||||
return None
|
||||
elif message_type == MessageType.group:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
if sub_type == MessageType.Group.normal:
|
||||
sender_info: dict = raw_message.get("sender")
|
||||
|
||||
if not await self.check_allow_to_chat(sender_info.get("user_id"), raw_message.get("group_id")):
|
||||
return None
|
||||
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=sender_info.get("nickname"),
|
||||
user_cardname=sender_info.get("card"),
|
||||
)
|
||||
|
||||
# 获取群聊相关信息,在此单独处理group_name,因为默认发送的消息中没有
|
||||
fetched_group_info = await get_group_info(self.get_server_connection(), raw_message.get("group_id"))
|
||||
group_name: str = None
|
||||
if fetched_group_info:
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
|
||||
group_info: GroupInfo = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=raw_message.get("group_id"),
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(f"群聊消息类型 {sub_type} 不支持")
|
||||
return None
|
||||
|
||||
additional_config: dict = {}
|
||||
if global_config.voice.use_tts:
|
||||
additional_config["allow_tts"] = True
|
||||
|
||||
# 消息信息
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
message_id=message_id,
|
||||
time=message_time,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
template_info=template_info,
|
||||
format_info=format_info,
|
||||
additional_config=additional_config,
|
||||
)
|
||||
|
||||
# 处理实际信息
|
||||
if not raw_message.get("message"):
|
||||
logger.warning("原始消息内容为空")
|
||||
return None
|
||||
|
||||
# 获取Seg列表
|
||||
seg_message: List[Seg] = await self.handle_real_message(raw_message)
|
||||
if not seg_message:
|
||||
logger.warning("处理后消息内容为空")
|
||||
return None
|
||||
|
||||
# 检查是否需要使用消息缓冲
|
||||
if features_manager.is_message_buffer_enabled():
|
||||
# 检查消息类型是否启用缓冲
|
||||
message_type = raw_message.get("message_type")
|
||||
should_use_buffer = False
|
||||
|
||||
if message_type == "group" and features_manager.is_message_buffer_group_enabled():
|
||||
should_use_buffer = True
|
||||
elif message_type == "private" and features_manager.is_message_buffer_private_enabled():
|
||||
should_use_buffer = True
|
||||
|
||||
if should_use_buffer:
|
||||
logger.debug(f"尝试缓冲消息,消息类型: {message_type}, 用户: {user_info.user_id}")
|
||||
logger.debug(f"原始消息段: {raw_message.get('message', [])}")
|
||||
|
||||
# 尝试添加到缓冲器
|
||||
buffered = await self.message_buffer.add_text_message(
|
||||
event_data={
|
||||
"message_type": message_type,
|
||||
"user_id": user_info.user_id,
|
||||
"group_id": group_info.group_id if group_info else None,
|
||||
},
|
||||
message=raw_message.get("message", []),
|
||||
original_event={
|
||||
"message_info": message_info,
|
||||
"raw_message": raw_message
|
||||
}
|
||||
)
|
||||
|
||||
if buffered:
|
||||
logger.info(f"✅ 文本消息已成功缓冲: {user_info.user_id}")
|
||||
return None # 缓冲成功,不立即发送
|
||||
# 如果缓冲失败(消息包含非文本元素),走正常处理流程
|
||||
logger.info(f"❌ 消息缓冲失败,包含非文本元素,走正常处理流程: {user_info.user_id}")
|
||||
# 缓冲失败时继续执行后面的正常处理流程,不要直接返回
|
||||
|
||||
logger.debug(f"准备发送消息到MaiBot,消息段数量: {len(seg_message)}")
|
||||
for i, seg in enumerate(seg_message):
|
||||
logger.debug(f"消息段 {i}: type={seg.type}, data={str(seg.data)[:100]}...")
|
||||
|
||||
submit_seg: Seg = Seg(
|
||||
type="seglist",
|
||||
data=seg_message,
|
||||
)
|
||||
# MessageBase创建
|
||||
message_base: MessageBase = MessageBase(
|
||||
message_info=message_info,
|
||||
message_segment=submit_seg,
|
||||
raw_message=raw_message.get("raw_message"),
|
||||
)
|
||||
|
||||
logger.info("发送到Maibot处理信息")
|
||||
await message_send_instance.message_send(message_base)
|
||||
|
||||
async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None:
|
||||
# sourcery skip: low-code-quality
|
||||
"""
|
||||
处理实际消息
|
||||
Parameters:
|
||||
real_message: dict: 实际消息
|
||||
Returns:
|
||||
seg_message: list[Seg]: 处理后的消息段列表
|
||||
"""
|
||||
real_message: list = raw_message.get("message")
|
||||
if not real_message:
|
||||
return None
|
||||
seg_message: List[Seg] = []
|
||||
for sub_message in real_message:
|
||||
sub_message: dict
|
||||
sub_message_type = sub_message.get("type")
|
||||
match sub_message_type:
|
||||
case RealMessageType.text:
|
||||
ret_seg = await self.handle_text_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.TEXT,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("text处理失败")
|
||||
case RealMessageType.face:
|
||||
ret_seg = await self.handle_face_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FACE,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("face处理失败或不支持")
|
||||
case RealMessageType.reply:
|
||||
if not in_reply:
|
||||
ret_seg = await self.handle_reply_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.REPLY,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
|
||||
seg_message += ret_seg
|
||||
else:
|
||||
logger.warning("reply处理失败")
|
||||
case RealMessageType.image:
|
||||
logger.debug(f"开始处理图片消息段")
|
||||
ret_seg = await self.handle_image_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.IMAGE,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
|
||||
seg_message.append(ret_seg)
|
||||
logger.debug(f"图片处理成功,添加到消息段")
|
||||
else:
|
||||
logger.warning("image处理失败")
|
||||
logger.debug(f"图片消息段处理完成")
|
||||
case RealMessageType.record:
|
||||
ret_seg = await self.handle_record_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.RECORD,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
|
||||
seg_message.clear()
|
||||
seg_message.append(ret_seg)
|
||||
break # 使得消息只有record消息
|
||||
else:
|
||||
logger.warning("record处理失败或不支持")
|
||||
case RealMessageType.video:
|
||||
ret_seg = await self.handle_video_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.VIDEO,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("video处理失败")
|
||||
case RealMessageType.at:
|
||||
ret_seg = await self.handle_at_message(
|
||||
sub_message,
|
||||
raw_message.get("self_id"),
|
||||
raw_message.get("group_id"),
|
||||
)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.AT,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("at处理失败")
|
||||
case RealMessageType.rps:
|
||||
ret_seg = await self.handle_rps_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.RPS,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("rps处理失败")
|
||||
case RealMessageType.dice:
|
||||
ret_seg = await self.handle_dice_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.DICE,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("dice处理失败")
|
||||
case RealMessageType.shake:
|
||||
ret_seg = await self.handle_shake_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.SHAKE,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("shake处理失败")
|
||||
case RealMessageType.share:
|
||||
print("\n\n哦哦哦噢噢噢哦哦你收到了一个超级无敌SHARE消息,快速速把你刚刚收到的消息截图发到MoFox-Bot群里!!!!\n\n")
|
||||
logger.warning("暂时不支持链接解析")
|
||||
case RealMessageType.forward:
|
||||
messages = await self._get_forward_message(sub_message)
|
||||
if not messages:
|
||||
logger.warning("转发消息内容为空或获取失败")
|
||||
return None
|
||||
ret_seg = await self.handle_forward_message(messages)
|
||||
if ret_seg:
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("转发消息处理失败")
|
||||
case RealMessageType.node:
|
||||
print("\n\n哦哦哦噢噢噢哦哦你收到了一个超级无敌NODE消息,快速速把你刚刚收到的消息截图发到MoFox-Bot群里!!!!\n\n")
|
||||
logger.warning("不支持转发消息节点解析")
|
||||
case RealMessageType.json:
|
||||
ret_seg = await self.handle_json_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.JSON,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("json处理失败")
|
||||
case _:
|
||||
logger.warning(f"未知消息类型: {sub_message_type}")
|
||||
|
||||
logger.debug(f"handle_real_message完成,处理了{len(real_message)}个消息段,生成了{len(seg_message)}个seg")
|
||||
return seg_message
|
||||
|
||||
async def handle_text_message(self, raw_message: dict) -> Seg:
|
||||
"""
|
||||
处理纯文本信息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
plain_text: str = message_data.get("text")
|
||||
return Seg(type="text", data=plain_text)
|
||||
|
||||
async def handle_face_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
处理表情消息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
face_raw_id: str = str(message_data.get("id"))
|
||||
if face_raw_id in qq_face:
|
||||
face_content: str = qq_face.get(face_raw_id)
|
||||
return Seg(type="text", data=face_content)
|
||||
else:
|
||||
logger.warning(f"不支持的表情:{face_raw_id}")
|
||||
return None
|
||||
|
||||
async def handle_image_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
处理图片消息与表情包消息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
image_sub_type = message_data.get("sub_type")
|
||||
try:
|
||||
logger.debug(f"开始下载图片: {message_data.get('url')}")
|
||||
image_base64 = await get_image_base64(message_data.get("url"))
|
||||
logger.debug(f"图片下载成功,大小: {len(image_base64)} 字符")
|
||||
except Exception as e:
|
||||
logger.error(f"图片消息处理失败: {str(e)}")
|
||||
return None
|
||||
if image_sub_type == 0:
|
||||
"""这部分认为是图片"""
|
||||
return Seg(type="image", data=image_base64)
|
||||
elif image_sub_type not in [4, 9]:
|
||||
"""这部分认为是表情包"""
|
||||
return Seg(type="emoji", data=image_base64)
|
||||
else:
|
||||
logger.warning(f"不支持的图片子类型:{image_sub_type}")
|
||||
return None
|
||||
|
||||
async def handle_at_message(self, raw_message: dict, self_id: int, group_id: int) -> Seg | None:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
处理at消息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
self_id: int: 机器人QQ号
|
||||
group_id: int: 群号
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
if message_data:
|
||||
qq_id = message_data.get("qq")
|
||||
if str(self_id) == str(qq_id):
|
||||
logger.debug("机器人被at")
|
||||
self_info: dict = await get_self_info(self.get_server_connection())
|
||||
if self_info:
|
||||
return Seg(type="text", data=f"@<{self_info.get('nickname')}:{self_info.get('user_id')}>")
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
member_info: dict = await get_member_info(self.get_server_connection(), group_id=group_id, user_id=qq_id)
|
||||
if member_info:
|
||||
return Seg(type="text", data=f"@<{member_info.get('nickname')}:{member_info.get('user_id')}>")
|
||||
else:
|
||||
return None
|
||||
|
||||
async def handle_record_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
处理语音消息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
file: str = message_data.get("file")
|
||||
if not file:
|
||||
logger.warning("语音消息缺少文件信息")
|
||||
return None
|
||||
try:
|
||||
record_detail = await get_record_detail(self.get_server_connection(), file)
|
||||
if not record_detail:
|
||||
logger.warning("获取语音消息详情失败")
|
||||
return None
|
||||
audio_base64: str = record_detail.get("base64")
|
||||
except Exception as e:
|
||||
logger.error(f"语音消息处理失败: {str(e)}")
|
||||
return None
|
||||
if not audio_base64:
|
||||
logger.error("语音消息处理失败,未获取到音频数据")
|
||||
return None
|
||||
return Seg(type="voice", data=audio_base64)
|
||||
|
||||
async def handle_video_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
处理视频消息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
|
||||
# 添加详细的调试信息
|
||||
logger.debug(f"视频消息原始数据: {raw_message}")
|
||||
logger.debug(f"视频消息数据: {message_data}")
|
||||
|
||||
# QQ视频消息可能包含url或filePath字段
|
||||
video_url = message_data.get("url")
|
||||
file_path = message_data.get("filePath") or message_data.get("file_path")
|
||||
|
||||
logger.info(f"视频URL: {video_url}")
|
||||
logger.info(f"视频文件路径: {file_path}")
|
||||
|
||||
# 优先使用本地文件路径,其次使用URL
|
||||
video_source = file_path if file_path else video_url
|
||||
|
||||
if not video_source:
|
||||
logger.warning("视频消息缺少URL或文件路径信息")
|
||||
logger.warning(f"完整消息数据: {message_data}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 检查是否为本地文件路径
|
||||
if file_path and Path(file_path).exists():
|
||||
logger.info(f"使用本地视频文件: {file_path}")
|
||||
# 直接读取本地文件
|
||||
with open(file_path, "rb") as f:
|
||||
video_data = f.read()
|
||||
|
||||
# 将视频数据编码为base64用于传输
|
||||
video_base64 = base64.b64encode(video_data).decode('utf-8')
|
||||
logger.info(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB")
|
||||
|
||||
# 返回包含详细信息的字典格式
|
||||
return Seg(type="video", data={
|
||||
"base64": video_base64,
|
||||
"filename": Path(file_path).name,
|
||||
"size_mb": len(video_data) / (1024 * 1024)
|
||||
})
|
||||
|
||||
elif video_url:
|
||||
logger.info(f"使用视频URL下载: {video_url}")
|
||||
# 使用video_handler下载视频
|
||||
video_downloader = get_video_downloader()
|
||||
download_result = await video_downloader.download_video(video_url)
|
||||
|
||||
if not download_result["success"]:
|
||||
logger.warning(f"视频下载失败: {download_result.get('error', '未知错误')}")
|
||||
logger.warning(f"失败的URL: {video_url}")
|
||||
return None
|
||||
|
||||
# 将视频数据编码为base64用于传输
|
||||
video_base64 = base64.b64encode(download_result["data"]).decode('utf-8')
|
||||
logger.info(f"视频下载成功,大小: {len(download_result['data']) / (1024 * 1024):.2f} MB")
|
||||
|
||||
# 返回包含详细信息的字典格式
|
||||
return Seg(type="video", data={
|
||||
"base64": video_base64,
|
||||
"filename": download_result.get("filename", "video.mp4"),
|
||||
"size_mb": len(download_result["data"]) / (1024 * 1024),
|
||||
"url": video_url
|
||||
})
|
||||
|
||||
else:
|
||||
logger.warning("既没有有效的本地文件路径,也没有有效的视频URL")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"视频消息处理失败: {str(e)}")
|
||||
logger.error(f"视频源: {video_source}")
|
||||
return None
|
||||
|
||||
async def handle_reply_message(self, raw_message: dict) -> List[Seg] | None:
|
||||
# sourcery skip: move-assign-in-block, use-named-expression
|
||||
"""
|
||||
处理回复消息
|
||||
|
||||
"""
|
||||
raw_message_data: dict = raw_message.get("data")
|
||||
message_id: int = None
|
||||
if raw_message_data:
|
||||
message_id = raw_message_data.get("id")
|
||||
else:
|
||||
return None
|
||||
message_detail: dict = await get_message_detail(self.get_server_connection(), message_id)
|
||||
if not message_detail:
|
||||
logger.warning("获取被引用的消息详情失败")
|
||||
return None
|
||||
reply_message = await self.handle_real_message(message_detail, in_reply=True)
|
||||
if reply_message is None:
|
||||
reply_message = [Seg(type="text", data="(获取发言内容失败)")]
|
||||
sender_info: dict = message_detail.get("sender")
|
||||
sender_nickname: str = sender_info.get("nickname")
|
||||
sender_id: str = sender_info.get("user_id")
|
||||
seg_message: List[Seg] = []
|
||||
if not sender_nickname:
|
||||
logger.warning("无法获取被引用的人的昵称,返回默认值")
|
||||
seg_message.append(Seg(type="text", data="[回复 未知用户:"))
|
||||
else:
|
||||
seg_message.append(Seg(type="text", data=f"[回复<{sender_nickname}:{sender_id}>:"))
|
||||
seg_message += reply_message
|
||||
seg_message.append(Seg(type="text", data="],说:"))
|
||||
return seg_message
|
||||
|
||||
async def handle_forward_message(self, message_list: list) -> Seg | None:
|
||||
"""
|
||||
递归处理转发消息,并按照动态方式确定图片处理方式
|
||||
Parameters:
|
||||
message_list: list: 转发消息列表
|
||||
"""
|
||||
handled_message, image_count = await self._handle_forward_message(
|
||||
message_list, 0
|
||||
)
|
||||
handled_message: Seg
|
||||
image_count: int
|
||||
if not handled_message:
|
||||
return None
|
||||
|
||||
processed_message: Seg
|
||||
if image_count < 5 and image_count > 0:
|
||||
# 处理图片数量小于5的情况,此时解析图片为base64
|
||||
logger.info("图片数量小于5,开始解析图片为base64")
|
||||
processed_message = await self._recursive_parse_image_seg(
|
||||
handled_message, True
|
||||
)
|
||||
elif image_count > 0:
|
||||
logger.info("图片数量大于等于5,开始解析图片为占位符")
|
||||
# 处理图片数量大于等于5的情况,此时解析图片为占位符
|
||||
processed_message = await self._recursive_parse_image_seg(
|
||||
handled_message, False
|
||||
)
|
||||
else:
|
||||
# 处理没有图片的情况,此时直接返回
|
||||
logger.info("没有图片,直接返回")
|
||||
processed_message = handled_message
|
||||
|
||||
# 添加转发消息提示
|
||||
forward_hint = Seg(type="text", data="这是一条转发消息:\n")
|
||||
return Seg(type="seglist", data=[forward_hint, processed_message])
|
||||
|
||||
async def handle_dice_message(self, raw_message: dict) -> Seg:
|
||||
message_data: dict = raw_message.get("data",{})
|
||||
res = message_data.get("result","")
|
||||
return Seg(type="text", data=f"[扔了一个骰子,点数是{res}]")
|
||||
|
||||
async def handle_shake_message(self, raw_message: dict) -> Seg:
|
||||
return Seg(type="text", data="[向你发送了窗口抖动,现在你的屏幕猛烈地震了一下!]")
|
||||
|
||||
async def handle_json_message(self, raw_message: dict) -> Seg:
|
||||
message_data: str = raw_message.get("data","").get("data","")
|
||||
res = json.loads(message_data)
|
||||
return Seg(type="json", data=res)
|
||||
|
||||
async def handle_rps_message(self, raw_message: dict) -> Seg:
|
||||
message_data: dict = raw_message.get("data",{})
|
||||
res = message_data.get("result","")
|
||||
if res == "1":
|
||||
shape = "布"
|
||||
elif res == "2":
|
||||
shape = "剪刀"
|
||||
else:
|
||||
shape = "石头"
|
||||
return Seg(type="text", data=f"[发送了一个魔法猜拳表情,结果是:{shape}]")
|
||||
|
||||
async def _recursive_parse_image_seg(self, seg_data: Seg, to_image: bool) -> Seg:
|
||||
# sourcery skip: merge-else-if-into-elif
|
||||
if to_image:
|
||||
if seg_data.type == "seglist":
|
||||
new_seg_list = []
|
||||
for i_seg in seg_data.data:
|
||||
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
|
||||
new_seg_list.append(parsed_seg)
|
||||
return Seg(type="seglist", data=new_seg_list)
|
||||
elif seg_data.type == "image":
|
||||
image_url = seg_data.data
|
||||
try:
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"图片处理失败: {str(e)}")
|
||||
return Seg(type="text", data="[图片]")
|
||||
return Seg(type="image", data=encoded_image)
|
||||
elif seg_data.type == "emoji":
|
||||
image_url = seg_data.data
|
||||
try:
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"图片处理失败: {str(e)}")
|
||||
return Seg(type="text", data="[表情包]")
|
||||
return Seg(type="emoji", data=encoded_image)
|
||||
else:
|
||||
logger.info(f"不处理类型: {seg_data.type}")
|
||||
return seg_data
|
||||
else:
|
||||
if seg_data.type == "seglist":
|
||||
new_seg_list = []
|
||||
for i_seg in seg_data.data:
|
||||
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
|
||||
new_seg_list.append(parsed_seg)
|
||||
return Seg(type="seglist", data=new_seg_list)
|
||||
elif seg_data.type == "image":
|
||||
return Seg(type="text", data="[图片]")
|
||||
elif seg_data.type == "emoji":
|
||||
return Seg(type="text", data="[动画表情]")
|
||||
else:
|
||||
logger.info(f"不处理类型: {seg_data.type}")
|
||||
return seg_data
|
||||
|
||||
async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[Seg, int] | Tuple[None, int]:
|
||||
# sourcery skip: low-code-quality
|
||||
"""
|
||||
递归处理实际转发消息
|
||||
Parameters:
|
||||
message_list: list: 转发消息列表,首层对应messages字段,后面对应content字段
|
||||
layer: int: 当前层级
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
image_count: int: 图片数量
|
||||
"""
|
||||
seg_list: List[Seg] = []
|
||||
image_count = 0
|
||||
if message_list is None:
|
||||
return None, 0
|
||||
for sub_message in message_list:
|
||||
sub_message: dict
|
||||
sender_info: dict = sub_message.get("sender")
|
||||
user_nickname: str = sender_info.get("nickname", "QQ用户")
|
||||
user_nickname_str = f"【{user_nickname}】:"
|
||||
break_seg = Seg(type="text", data="\n")
|
||||
message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message")
|
||||
if not message_of_sub_message_list:
|
||||
logger.warning("转发消息内容为空")
|
||||
continue
|
||||
message_of_sub_message = message_of_sub_message_list[0]
|
||||
if message_of_sub_message.get("type") == RealMessageType.forward:
|
||||
if layer >= 3:
|
||||
full_seg_data = Seg(
|
||||
type="text",
|
||||
data=("--" * layer) + f"【{user_nickname}】:【转发消息】\n",
|
||||
)
|
||||
else:
|
||||
sub_message_data = message_of_sub_message.get("data")
|
||||
if not sub_message_data:
|
||||
continue
|
||||
contents = sub_message_data.get("content")
|
||||
seg_data, count = await self._handle_forward_message(contents, layer + 1)
|
||||
image_count += count
|
||||
head_tip = Seg(
|
||||
type="text",
|
||||
data=("--" * layer) + f"【{user_nickname}】: 合并转发消息内容:\n",
|
||||
)
|
||||
full_seg_data = Seg(type="seglist", data=[head_tip, seg_data])
|
||||
seg_list.append(full_seg_data)
|
||||
elif message_of_sub_message.get("type") == RealMessageType.text:
|
||||
sub_message_data = message_of_sub_message.get("data")
|
||||
if not sub_message_data:
|
||||
continue
|
||||
text_message = sub_message_data.get("text")
|
||||
seg_data = Seg(type="text", data=text_message)
|
||||
data_list: List[Any] = []
|
||||
if layer > 0:
|
||||
data_list = [
|
||||
Seg(type="text", data=("--" * layer) + user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
else:
|
||||
data_list = [
|
||||
Seg(type="text", data=user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
seg_list.append(Seg(type="seglist", data=data_list))
|
||||
elif message_of_sub_message.get("type") == RealMessageType.image:
|
||||
image_count += 1
|
||||
image_data = message_of_sub_message.get("data")
|
||||
sub_type = image_data.get("sub_type")
|
||||
image_url = image_data.get("url")
|
||||
data_list: List[Any] = []
|
||||
if sub_type == 0:
|
||||
seg_data = Seg(type="image", data=image_url)
|
||||
else:
|
||||
seg_data = Seg(type="emoji", data=image_url)
|
||||
if layer > 0:
|
||||
data_list = [
|
||||
Seg(type="text", data=("--" * layer) + user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
else:
|
||||
data_list = [
|
||||
Seg(type="text", data=user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
full_seg_data = Seg(type="seglist", data=data_list)
|
||||
seg_list.append(full_seg_data)
|
||||
return Seg(type="seglist", data=seg_list), image_count
|
||||
|
||||
async def _get_forward_message(self, raw_message: dict) -> Dict[str, Any] | None:
|
||||
forward_message_data: Dict = raw_message.get("data")
|
||||
if not forward_message_data:
|
||||
logger.warning("转发消息内容为空")
|
||||
return None
|
||||
forward_message_id = forward_message_data.get("id")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
{
|
||||
"action": "get_forward_msg",
|
||||
"params": {"message_id": forward_message_id},
|
||||
"echo": request_uuid,
|
||||
}
|
||||
)
|
||||
try:
|
||||
connection = self.get_server_connection()
|
||||
if not connection:
|
||||
logger.error("没有可用的 WebSocket 连接")
|
||||
return None
|
||||
await connection.send(payload)
|
||||
response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error("获取转发消息超时")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取转发消息失败: {str(e)}")
|
||||
return None
|
||||
logger.debug(
|
||||
f"转发消息原始格式:{json.dumps(response)[:80]}..."
|
||||
if len(json.dumps(response)) > 80
|
||||
else json.dumps(response)
|
||||
)
|
||||
response_data: Dict = response.get("data")
|
||||
if not response_data:
|
||||
logger.warning("转发消息内容为空或获取失败")
|
||||
return None
|
||||
return response_data.get("messages")
|
||||
|
||||
async def _send_buffered_message(self, session_id: str, merged_text: str, original_event: Dict[str, Any]):
|
||||
"""发送缓冲的合并消息"""
|
||||
try:
|
||||
# 从原始事件数据中提取信息
|
||||
message_info = original_event.get("message_info")
|
||||
raw_message = original_event.get("raw_message")
|
||||
|
||||
if not message_info or not raw_message:
|
||||
logger.error("缓冲消息缺少必要信息")
|
||||
return
|
||||
|
||||
# 创建合并后的消息段 - 将合并的文本转换为Seg格式
|
||||
from maim_message import Seg
|
||||
merged_seg = Seg(type="text", data=merged_text)
|
||||
submit_seg = Seg(type="seglist", data=[merged_seg])
|
||||
|
||||
# 创建新的消息ID
|
||||
import time
|
||||
new_message_id = f"buffered-{message_info.message_id}-{int(time.time() * 1000)}"
|
||||
|
||||
# 更新消息信息
|
||||
from maim_message import BaseMessageInfo, MessageBase
|
||||
buffered_message_info = BaseMessageInfo(
|
||||
platform=message_info.platform,
|
||||
message_id=new_message_id,
|
||||
time=time.time(),
|
||||
user_info=message_info.user_info,
|
||||
group_info=message_info.group_info,
|
||||
template_info=message_info.template_info,
|
||||
format_info=message_info.format_info,
|
||||
additional_config=message_info.additional_config,
|
||||
)
|
||||
|
||||
# 创建MessageBase
|
||||
message_base = MessageBase(
|
||||
message_info=buffered_message_info,
|
||||
message_segment=submit_seg,
|
||||
raw_message=raw_message.get("raw_message", ""),
|
||||
)
|
||||
|
||||
logger.info(f"发送缓冲合并消息到Maibot处理: {session_id}")
|
||||
await message_send_instance.message_send(message_base)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送缓冲消息失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
message_handler = MessageHandler()
|
||||
@@ -0,0 +1,32 @@
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from maim_message import MessageBase, Router
|
||||
|
||||
|
||||
class MessageSending:
|
||||
"""
|
||||
负责把消息发送到麦麦
|
||||
"""
|
||||
|
||||
maibot_router: Router = None
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def message_send(self, message_base: MessageBase) -> bool:
|
||||
"""
|
||||
发送消息
|
||||
Parameters:
|
||||
message_base: MessageBase: 消息基类,包含发送目标和消息内容等信息
|
||||
"""
|
||||
try:
|
||||
send_status = await self.maibot_router.send_message(message_base)
|
||||
if not send_status:
|
||||
raise RuntimeError("可能是路由未正确配置或连接异常")
|
||||
return send_status
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {str(e)}")
|
||||
logger.error("请检查与MaiBot之间的连接")
|
||||
|
||||
|
||||
message_send_instance = MessageSending()
|
||||
@@ -0,0 +1,50 @@
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from ..config import global_config
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from . import MetaEventType
|
||||
|
||||
|
||||
class MetaEventHandler:
|
||||
"""
|
||||
处理Meta事件
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.interval = global_config.napcat_server.heartbeat_interval
|
||||
self._interval_checking = False
|
||||
|
||||
async def handle_meta_event(self, message: dict) -> None:
|
||||
event_type = message.get("meta_event_type")
|
||||
if event_type == MetaEventType.lifecycle:
|
||||
sub_type = message.get("sub_type")
|
||||
if sub_type == MetaEventType.Lifecycle.connect:
|
||||
self_id = message.get("self_id")
|
||||
self.last_heart_beat = time.time()
|
||||
logger.info(f"Bot {self_id} 连接成功")
|
||||
asyncio.create_task(self.check_heartbeat(self_id))
|
||||
elif event_type == MetaEventType.heartbeat:
|
||||
if message["status"].get("online") and message["status"].get("good"):
|
||||
if not self._interval_checking:
|
||||
asyncio.create_task(self.check_heartbeat())
|
||||
self.last_heart_beat = time.time()
|
||||
self.interval = message.get("interval") / 1000
|
||||
else:
|
||||
self_id = message.get("self_id")
|
||||
logger.warning(f"Bot {self_id} Napcat 端异常!")
|
||||
|
||||
async def check_heartbeat(self, id: int) -> None:
|
||||
self._interval_checking = True
|
||||
while True:
|
||||
now_time = time.time()
|
||||
if now_time - self.last_heart_beat > self.interval * 2:
|
||||
logger.error(f"Bot {id} 可能发生了连接断开,被下线,或者Napcat卡死!")
|
||||
break
|
||||
else:
|
||||
logger.debug("心跳正常")
|
||||
await asyncio.sleep(self.interval)
|
||||
|
||||
|
||||
meta_event_handler = MetaEventHandler()
|
||||
556
plugins/napcat_adapter_plugin/src/recv_handler/notice_handler.py
Normal file
@@ -0,0 +1,556 @@
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
import websockets as Server
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
from ..config import global_config
|
||||
from ..config.features_config import features_manager
|
||||
from ..database import BanUser, db_manager, is_identical
|
||||
from . import NoticeType, ACCEPT_FORMAT
|
||||
from .message_sending import message_send_instance
|
||||
from .message_handler import message_handler
|
||||
from maim_message import FormatInfo, UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase
|
||||
from ..websocket_manager import websocket_manager
|
||||
|
||||
from ..utils import (
|
||||
get_group_info,
|
||||
get_member_info,
|
||||
get_self_info,
|
||||
get_stranger_info,
|
||||
read_ban_list,
|
||||
)
|
||||
|
||||
from ...CONSTS import PLUGIN_NAME
|
||||
|
||||
notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=100)
|
||||
unsuccessful_notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=3)
|
||||
|
||||
|
||||
class NoticeHandler:
|
||||
banned_list: list[BanUser] = [] # 当前仍在禁言中的用户列表
|
||||
lifted_list: list[BanUser] = [] # 已经自然解除禁言
|
||||
|
||||
def __init__(self):
|
||||
self.server_connection: Server.ServerConnection | None = None
|
||||
self.last_poke_time: float = 0.0 # 记录最后一次针对机器人的戳一戳时间
|
||||
|
||||
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
"""设置Napcat连接"""
|
||||
self.server_connection = server_connection
|
||||
|
||||
while self.server_connection.state != Server.State.OPEN:
|
||||
await asyncio.sleep(0.5)
|
||||
self.banned_list, self.lifted_list = await read_ban_list(self.server_connection)
|
||||
|
||||
asyncio.create_task(self.auto_lift_detect())
|
||||
asyncio.create_task(self.send_notice())
|
||||
asyncio.create_task(self.handle_natural_lift())
|
||||
|
||||
def get_server_connection(self) -> Server.ServerConnection:
|
||||
"""获取当前的服务器连接"""
|
||||
# 优先使用直接设置的连接,否则从 websocket_manager 获取
|
||||
if self.server_connection:
|
||||
return self.server_connection
|
||||
return websocket_manager.get_connection()
|
||||
|
||||
def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None:
|
||||
"""
|
||||
将用户禁言记录添加到self.banned_list中
|
||||
如果是全体禁言,则user_id为0
|
||||
"""
|
||||
if user_id is None:
|
||||
user_id = 0 # 使用0表示全体禁言
|
||||
lift_time = -1
|
||||
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time)
|
||||
for record in self.banned_list:
|
||||
if is_identical(record, ban_record):
|
||||
self.banned_list.remove(record)
|
||||
self.banned_list.append(ban_record)
|
||||
db_manager.create_ban_record(ban_record) # 作为更新
|
||||
return
|
||||
self.banned_list.append(ban_record)
|
||||
db_manager.create_ban_record(ban_record) # 添加到数据库
|
||||
|
||||
def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None:
|
||||
"""
|
||||
从self.lifted_group_list中移除已经解除全体禁言的群
|
||||
"""
|
||||
if user_id is None:
|
||||
user_id = 0 # 使用0表示全体禁言
|
||||
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1)
|
||||
self.lifted_list.append(ban_record)
|
||||
db_manager.delete_ban_record(ban_record) # 删除数据库中的记录
|
||||
|
||||
async def handle_notice(self, raw_message: dict) -> None:
|
||||
notice_type = raw_message.get("notice_type")
|
||||
# message_time: int = raw_message.get("time")
|
||||
message_time: float = time.time() # 应可乐要求,现在是float了
|
||||
|
||||
group_id = raw_message.get("group_id")
|
||||
user_id = raw_message.get("user_id")
|
||||
target_id = raw_message.get("target_id")
|
||||
|
||||
handled_message: Seg = None
|
||||
user_info: UserInfo = None
|
||||
system_notice: bool = False
|
||||
|
||||
match notice_type:
|
||||
case NoticeType.friend_recall:
|
||||
logger.info("好友撤回一条消息")
|
||||
logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
|
||||
logger.warning("暂时不支持撤回消息处理")
|
||||
case NoticeType.group_recall:
|
||||
logger.info("群内用户撤回一条消息")
|
||||
logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
|
||||
logger.warning("暂时不支持撤回消息处理")
|
||||
case NoticeType.notify:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
match sub_type:
|
||||
case NoticeType.Notify.poke:
|
||||
if features_manager.is_poke_enabled() and await message_handler.check_allow_to_chat(
|
||||
user_id, group_id, False, False
|
||||
):
|
||||
logger.info("处理戳一戳消息")
|
||||
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
|
||||
else:
|
||||
logger.warning("戳一戳消息被禁用,取消戳一戳处理")
|
||||
case NoticeType.Notify.input_status:
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from ...event_types import NapcatEvent
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT,plugin_name=PLUGIN_NAME)
|
||||
case _:
|
||||
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
|
||||
case NoticeType.group_ban:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
match sub_type:
|
||||
case NoticeType.GroupBan.ban:
|
||||
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
|
||||
return None
|
||||
logger.info("处理群禁言")
|
||||
handled_message, user_info = await self.handle_ban_notify(raw_message, group_id)
|
||||
system_notice = True
|
||||
case NoticeType.GroupBan.lift_ban:
|
||||
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
|
||||
return None
|
||||
logger.info("处理解除群禁言")
|
||||
handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id)
|
||||
system_notice = True
|
||||
case _:
|
||||
logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}")
|
||||
case _:
|
||||
logger.warning(f"不支持的notice类型: {notice_type}")
|
||||
return None
|
||||
if not handled_message or not user_info:
|
||||
logger.warning("notice处理失败或不支持")
|
||||
return None
|
||||
|
||||
group_info: GroupInfo = None
|
||||
if group_id:
|
||||
fetched_group_info = await get_group_info(self.get_server_connection(), group_id)
|
||||
group_name: str = None
|
||||
if fetched_group_info:
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
else:
|
||||
logger.warning("无法获取notice消息所在群的名称")
|
||||
group_info = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=group_id,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
message_id="notice",
|
||||
time=message_time,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
template_info=None,
|
||||
format_info=FormatInfo(
|
||||
content_format=["text", "notify"],
|
||||
accept_format=ACCEPT_FORMAT,
|
||||
),
|
||||
additional_config={"target_id": target_id}, # 在这里塞了一个target_id,方便mmc那边知道被戳的人是谁
|
||||
)
|
||||
|
||||
message_base: MessageBase = MessageBase(
|
||||
message_info=message_info,
|
||||
message_segment=handled_message,
|
||||
raw_message=json.dumps(raw_message),
|
||||
)
|
||||
|
||||
if system_notice:
|
||||
await self.put_notice(message_base)
|
||||
else:
|
||||
logger.info("发送到Maibot处理通知信息")
|
||||
await message_send_instance.message_send(message_base)
|
||||
|
||||
async def handle_poke_notify(
|
||||
self, raw_message: dict, group_id: int, user_id: int
|
||||
) -> Tuple[Seg | None, UserInfo | None]:
|
||||
# sourcery skip: merge-comparisons, merge-duplicate-blocks, remove-redundant-if, remove-unnecessary-else, swap-if-else-branches
|
||||
self_info: dict = await get_self_info(self.get_server_connection())
|
||||
|
||||
if not self_info:
|
||||
logger.error("自身信息获取失败")
|
||||
return None, None
|
||||
|
||||
self_id = raw_message.get("self_id")
|
||||
target_id = raw_message.get("target_id")
|
||||
|
||||
# 防抖检查:如果是针对机器人的戳一戳,检查防抖时间
|
||||
if self_id == target_id:
|
||||
current_time = time.time()
|
||||
debounce_seconds = features_manager.get_config().poke_debounce_seconds
|
||||
|
||||
if self.last_poke_time > 0:
|
||||
time_diff = current_time - self.last_poke_time
|
||||
if time_diff < debounce_seconds:
|
||||
logger.info(f"戳一戳防抖:用户 {user_id} 的戳一戳被忽略(距离上次戳一戳 {time_diff:.2f} 秒)")
|
||||
return None, None
|
||||
|
||||
# 记录这次戳一戳的时间
|
||||
self.last_poke_time = current_time
|
||||
|
||||
target_name: str = None
|
||||
raw_info: list = raw_message.get("raw_info")
|
||||
|
||||
if group_id:
|
||||
user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
else:
|
||||
user_qq_info: dict = await get_stranger_info(self.get_server_connection(), user_id)
|
||||
if user_qq_info:
|
||||
user_name = user_qq_info.get("nickname")
|
||||
user_cardname = user_qq_info.get("card")
|
||||
else:
|
||||
user_name = "QQ用户"
|
||||
user_cardname = "QQ用户"
|
||||
logger.info("无法获取戳一戳对方的用户昵称")
|
||||
|
||||
# 计算Seg
|
||||
if self_id == target_id:
|
||||
display_name = ""
|
||||
target_name = self_info.get("nickname")
|
||||
|
||||
elif self_id == user_id:
|
||||
# 让ada不发送麦麦戳别人的消息
|
||||
return None, None
|
||||
|
||||
else:
|
||||
# 如果配置为忽略不是针对自己的戳一戳,则直接返回None
|
||||
if features_manager.is_non_self_poke_ignored():
|
||||
logger.info("忽略不是针对自己的戳一戳消息")
|
||||
return None, None
|
||||
|
||||
# 老实说这一步判定没啥意义,毕竟私聊是没有其他人之间的戳一戳,但是感觉可以有这个判定来强限制群聊环境
|
||||
if group_id:
|
||||
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, target_id)
|
||||
if fetched_member_info:
|
||||
target_name = fetched_member_info.get("nickname")
|
||||
else:
|
||||
target_name = "QQ用户"
|
||||
logger.info("无法获取被戳一戳方的用户昵称")
|
||||
display_name = user_name
|
||||
else:
|
||||
return None, None
|
||||
|
||||
first_txt: str = "戳了戳"
|
||||
second_txt: str = ""
|
||||
try:
|
||||
first_txt = raw_info[2].get("txt", "戳了戳")
|
||||
second_txt = raw_info[4].get("txt", "")
|
||||
except Exception as e:
|
||||
logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本")
|
||||
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=user_id,
|
||||
user_nickname=user_name,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
|
||||
seg_data: Seg = Seg(
|
||||
type="text",
|
||||
data=f"{display_name}{first_txt}{target_name}{second_txt}(这是QQ的一个功能,用于提及某人,但没那么明显)",
|
||||
)
|
||||
return seg_data, user_info
|
||||
|
||||
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理禁言通知")
|
||||
return None, None
|
||||
|
||||
# 计算user_info
|
||||
operator_id = raw_message.get("operator_id")
|
||||
operator_nickname: str = None
|
||||
operator_cardname: str = None
|
||||
|
||||
member_info: dict = await get_member_info(self.get_server_connection(), group_id, operator_id)
|
||||
if member_info:
|
||||
operator_nickname = member_info.get("nickname")
|
||||
operator_cardname = member_info.get("card")
|
||||
else:
|
||||
logger.warning("无法获取禁言执行者的昵称,消息可能会无效")
|
||||
operator_nickname = "QQ用户"
|
||||
|
||||
operator_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=operator_id,
|
||||
user_nickname=operator_nickname,
|
||||
user_cardname=operator_cardname,
|
||||
)
|
||||
|
||||
# 计算Seg
|
||||
user_id = raw_message.get("user_id")
|
||||
banned_user_info: UserInfo = None
|
||||
user_nickname: str = "QQ用户"
|
||||
user_cardname: str = None
|
||||
sub_type: str = None
|
||||
|
||||
duration = raw_message.get("duration")
|
||||
if duration is None:
|
||||
logger.error("禁言时长不能为空,无法处理禁言通知")
|
||||
return None, None
|
||||
|
||||
if user_id == 0: # 为全体禁言
|
||||
sub_type: str = "whole_ban"
|
||||
self._ban_operation(group_id)
|
||||
else: # 为单人禁言
|
||||
# 获取被禁言人的信息
|
||||
sub_type: str = "ban"
|
||||
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
if fetched_member_info:
|
||||
user_nickname = fetched_member_info.get("nickname")
|
||||
user_cardname = fetched_member_info.get("card")
|
||||
banned_user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
self._ban_operation(group_id, user_id, int(time.time() + duration))
|
||||
|
||||
seg_data: Seg = Seg(
|
||||
type="notify",
|
||||
data={
|
||||
"sub_type": sub_type,
|
||||
"duration": duration,
|
||||
"banned_user_info": banned_user_info.to_dict() if banned_user_info else None,
|
||||
},
|
||||
)
|
||||
|
||||
return seg_data, operator_info
|
||||
|
||||
async def handle_lift_ban_notify(
|
||||
self, raw_message: dict, group_id: int
|
||||
) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理解除禁言通知")
|
||||
return None, None
|
||||
|
||||
# 计算user_info
|
||||
operator_id = raw_message.get("operator_id")
|
||||
operator_nickname: str = None
|
||||
operator_cardname: str = None
|
||||
|
||||
member_info: dict = await get_member_info(self.get_server_connection(), group_id, operator_id)
|
||||
if member_info:
|
||||
operator_nickname = member_info.get("nickname")
|
||||
operator_cardname = member_info.get("card")
|
||||
else:
|
||||
logger.warning("无法获取解除禁言执行者的昵称,消息可能会无效")
|
||||
operator_nickname = "QQ用户"
|
||||
|
||||
operator_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=operator_id,
|
||||
user_nickname=operator_nickname,
|
||||
user_cardname=operator_cardname,
|
||||
)
|
||||
|
||||
# 计算Seg
|
||||
sub_type: str = None
|
||||
user_nickname: str = "QQ用户"
|
||||
user_cardname: str = None
|
||||
lifted_user_info: UserInfo = None
|
||||
|
||||
user_id = raw_message.get("user_id")
|
||||
if user_id == 0: # 全体禁言解除
|
||||
sub_type = "whole_lift_ban"
|
||||
self._lift_operation(group_id)
|
||||
else: # 单人禁言解除
|
||||
sub_type = "lift_ban"
|
||||
# 获取被解除禁言人的信息
|
||||
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
if fetched_member_info:
|
||||
user_nickname = fetched_member_info.get("nickname")
|
||||
user_cardname = fetched_member_info.get("card")
|
||||
else:
|
||||
logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效")
|
||||
lifted_user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
self._lift_operation(group_id, user_id)
|
||||
|
||||
seg_data: Seg = Seg(
|
||||
type="notify",
|
||||
data={
|
||||
"sub_type": sub_type,
|
||||
"lifted_user_info": lifted_user_info.to_dict() if lifted_user_info else None,
|
||||
},
|
||||
)
|
||||
return seg_data, operator_info
|
||||
|
||||
async def put_notice(self, message_base: MessageBase) -> None:
|
||||
"""
|
||||
将处理后的通知消息放入通知队列
|
||||
"""
|
||||
if notice_queue.full() or unsuccessful_notice_queue.full():
|
||||
logger.warning("通知队列已满,可能是多次发送失败,消息丢弃")
|
||||
else:
|
||||
await notice_queue.put(message_base)
|
||||
|
||||
async def handle_natural_lift(self) -> None:
|
||||
while True:
|
||||
if len(self.lifted_list) != 0:
|
||||
lift_record = self.lifted_list.pop()
|
||||
group_id = lift_record.group_id
|
||||
user_id = lift_record.user_id
|
||||
|
||||
db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录
|
||||
|
||||
seg_message: Seg = await self.natural_lift(group_id, user_id)
|
||||
|
||||
fetched_group_info = await get_group_info(self.get_server_connection(), group_id)
|
||||
group_name: str = None
|
||||
if fetched_group_info:
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
else:
|
||||
logger.warning("无法获取notice消息所在群的名称")
|
||||
group_info = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=group_id,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
message_id="notice",
|
||||
time=time.time(),
|
||||
user_info=None, # 自然解除禁言没有操作者
|
||||
group_info=group_info,
|
||||
template_info=None,
|
||||
format_info=None,
|
||||
)
|
||||
|
||||
message_base: MessageBase = MessageBase(
|
||||
message_info=message_info,
|
||||
message_segment=seg_message,
|
||||
raw_message=json.dumps(
|
||||
{
|
||||
"post_type": "notice",
|
||||
"notice_type": "group_ban",
|
||||
"sub_type": "lift_ban",
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"operator_id": None, # 自然解除禁言没有操作者
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
await self.put_notice(message_base)
|
||||
await asyncio.sleep(0.5) # 确保队列处理间隔
|
||||
else:
|
||||
await asyncio.sleep(5) # 每5秒检查一次
|
||||
|
||||
async def natural_lift(self, group_id: int, user_id: int) -> Seg | None:
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理解除禁言通知")
|
||||
return None
|
||||
|
||||
if user_id == 0: # 理论上永远不会触发
|
||||
return Seg(
|
||||
type="notify",
|
||||
data={
|
||||
"sub_type": "whole_lift_ban",
|
||||
"lifted_user_info": None,
|
||||
},
|
||||
)
|
||||
|
||||
user_nickname: str = "QQ用户"
|
||||
user_cardname: str = None
|
||||
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
if fetched_member_info:
|
||||
user_nickname = fetched_member_info.get("nickname")
|
||||
user_cardname = fetched_member_info.get("card")
|
||||
|
||||
lifted_user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
|
||||
return Seg(
|
||||
type="notify",
|
||||
data={
|
||||
"sub_type": "lift_ban",
|
||||
"lifted_user_info": lifted_user_info.to_dict(),
|
||||
},
|
||||
)
|
||||
|
||||
async def auto_lift_detect(self) -> None:
|
||||
while True:
|
||||
if len(self.banned_list) == 0:
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
for ban_record in self.banned_list:
|
||||
if ban_record.user_id == 0 or ban_record.lift_time == -1:
|
||||
continue
|
||||
if ban_record.lift_time <= int(time.time()):
|
||||
# 触发自然解除禁言
|
||||
logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除")
|
||||
self.lifted_list.append(ban_record)
|
||||
self.banned_list.remove(ban_record)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def send_notice(self) -> None:
|
||||
"""
|
||||
发送通知消息到Napcat
|
||||
"""
|
||||
while True:
|
||||
if not unsuccessful_notice_queue.empty():
|
||||
to_be_send: MessageBase = await unsuccessful_notice_queue.get()
|
||||
try:
|
||||
send_status = await message_send_instance.message_send(to_be_send)
|
||||
if send_status:
|
||||
unsuccessful_notice_queue.task_done()
|
||||
else:
|
||||
await unsuccessful_notice_queue.put(to_be_send)
|
||||
except Exception as e:
|
||||
logger.error(f"发送通知消息失败: {str(e)}")
|
||||
await unsuccessful_notice_queue.put(to_be_send)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
to_be_send: MessageBase = await notice_queue.get()
|
||||
try:
|
||||
send_status = await message_send_instance.message_send(to_be_send)
|
||||
if send_status:
|
||||
notice_queue.task_done()
|
||||
else:
|
||||
await unsuccessful_notice_queue.put(to_be_send)
|
||||
except Exception as e:
|
||||
logger.error(f"发送通知消息失败: {str(e)}")
|
||||
await unsuccessful_notice_queue.put(to_be_send)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
|
||||
|
||||
notice_handler = NoticeHandler()
|
||||
250
plugins/napcat_adapter_plugin/src/recv_handler/qq_emoji_list.py
Normal file
@@ -0,0 +1,250 @@
|
||||
qq_face: dict = {
|
||||
"0": "[表情:惊讶]",
|
||||
"1": "[表情:撇嘴]",
|
||||
"2": "[表情:色]",
|
||||
"3": "[表情:发呆]",
|
||||
"4": "[表情:得意]",
|
||||
"5": "[表情:流泪]",
|
||||
"6": "[表情:害羞]",
|
||||
"7": "[表情:闭嘴]",
|
||||
"8": "[表情:睡]",
|
||||
"9": "[表情:大哭]",
|
||||
"10": "[表情:尴尬]",
|
||||
"11": "[表情:发怒]",
|
||||
"12": "[表情:调皮]",
|
||||
"13": "[表情:呲牙]",
|
||||
"14": "[表情:微笑]",
|
||||
"15": "[表情:难过]",
|
||||
"16": "[表情:酷]",
|
||||
"18": "[表情:抓狂]",
|
||||
"19": "[表情:吐]",
|
||||
"20": "[表情:偷笑]",
|
||||
"21": "[表情:可爱]",
|
||||
"22": "[表情:白眼]",
|
||||
"23": "[表情:傲慢]",
|
||||
"24": "[表情:饥饿]",
|
||||
"25": "[表情:困]",
|
||||
"26": "[表情:惊恐]",
|
||||
"27": "[表情:流汗]",
|
||||
"28": "[表情:憨笑]",
|
||||
"29": "[表情:悠闲]",
|
||||
"30": "[表情:奋斗]",
|
||||
"31": "[表情:咒骂]",
|
||||
"32": "[表情:疑问]",
|
||||
"33": "[表情: 嘘]",
|
||||
"34": "[表情:晕]",
|
||||
"35": "[表情:折磨]",
|
||||
"36": "[表情:衰]",
|
||||
"37": "[表情:骷髅]",
|
||||
"38": "[表情:敲打]",
|
||||
"39": "[表情:再见]",
|
||||
"41": "[表情:发抖]",
|
||||
"42": "[表情:爱情]",
|
||||
"43": "[表情:跳跳]",
|
||||
"46": "[表情:猪头]",
|
||||
"49": "[表情:拥抱]",
|
||||
"53": "[表情:蛋糕]",
|
||||
"56": "[表情:刀]",
|
||||
"59": "[表情:便便]",
|
||||
"60": "[表情:咖啡]",
|
||||
"63": "[表情:玫瑰]",
|
||||
"64": "[表情:凋谢]",
|
||||
"66": "[表情:爱心]",
|
||||
"67": "[表情:心碎]",
|
||||
"74": "[表情:太阳]",
|
||||
"75": "[表情:月亮]",
|
||||
"76": "[表情:赞]",
|
||||
"77": "[表情:踩]",
|
||||
"78": "[表情:握手]",
|
||||
"79": "[表情:胜利]",
|
||||
"85": "[表情:飞吻]",
|
||||
"86": "[表情:怄火]",
|
||||
"89": "[表情:西瓜]",
|
||||
"96": "[表情:冷汗]",
|
||||
"97": "[表情:擦汗]",
|
||||
"98": "[表情:抠鼻]",
|
||||
"99": "[表情:鼓掌]",
|
||||
"100": "[表情:糗大了]",
|
||||
"101": "[表情:坏笑]",
|
||||
"102": "[表情:左哼哼]",
|
||||
"103": "[表情:右哼哼]",
|
||||
"104": "[表情:哈欠]",
|
||||
"105": "[表情:鄙视]",
|
||||
"106": "[表情:委屈]",
|
||||
"107": "[表情:快哭了]",
|
||||
"108": "[表情:阴险]",
|
||||
"109": "[表情:左亲亲]",
|
||||
"110": "[表情:吓]",
|
||||
"111": "[表情:可怜]",
|
||||
"112": "[表情:菜刀]",
|
||||
"114": "[表情:篮球]",
|
||||
"116": "[表情:示爱]",
|
||||
"118": "[表情:抱拳]",
|
||||
"119": "[表情:勾引]",
|
||||
"120": "[表情:拳头]",
|
||||
"121": "[表情:差劲]",
|
||||
"123": "[表情:NO]",
|
||||
"124": "[表情:OK]",
|
||||
"125": "[表情:转圈]",
|
||||
"129": "[表情:挥手]",
|
||||
"137": "[表情:鞭炮]",
|
||||
"144": "[表情:喝彩]",
|
||||
"146": "[表情:爆筋]",
|
||||
"147": "[表情:棒棒糖]",
|
||||
"169": "[表情:手枪]",
|
||||
"171": "[表情:茶]",
|
||||
"172": "[表情:眨眼睛]",
|
||||
"173": "[表情:泪奔]",
|
||||
"174": "[表情:无奈]",
|
||||
"175": "[表情:卖萌]",
|
||||
"176": "[表情:小纠结]",
|
||||
"177": "[表情:喷血]",
|
||||
"178": "[表情:斜眼笑]",
|
||||
"179": "[表情:doge]",
|
||||
"181": "[表情:戳一戳]",
|
||||
"182": "[表情:笑哭]",
|
||||
"183": "[表情:我最美]",
|
||||
"185": "[表情:羊驼]",
|
||||
"187": "[表情:幽灵]",
|
||||
"201": "[表情:点赞]",
|
||||
"212": "[表情:托腮]",
|
||||
"262": "[表情:脑阔疼]",
|
||||
"263": "[表情:沧桑]",
|
||||
"264": "[表情:捂脸]",
|
||||
"265": "[表情:辣眼睛]",
|
||||
"266": "[表情:哦哟]",
|
||||
"267": "[表情:头秃]",
|
||||
"268": "[表情:问号脸]",
|
||||
"269": "[表情:暗中观察]",
|
||||
"270": "[表情:emm]",
|
||||
"271": "[表情:吃 瓜]",
|
||||
"272": "[表情:呵呵哒]",
|
||||
"273": "[表情:我酸了]",
|
||||
"277": "[表情:汪汪]",
|
||||
"281": "[表情:无眼笑]",
|
||||
"282": "[表情:敬礼]",
|
||||
"283": "[表情:狂笑]",
|
||||
"284": "[表情:面无表情]",
|
||||
"285": "[表情:摸鱼]",
|
||||
"286": "[表情:魔鬼笑]",
|
||||
"287": "[表情:哦]",
|
||||
"289": "[表情:睁眼]",
|
||||
"293": "[表情:摸锦鲤]",
|
||||
"294": "[表情:期待]",
|
||||
"295": "[表情:拿到红包]",
|
||||
"297": "[表情:拜谢]",
|
||||
"298": "[表情:元宝]",
|
||||
"299": "[表情:牛啊]",
|
||||
"300": "[表情:胖三斤]",
|
||||
"302": "[表情:左拜年]",
|
||||
"303": "[表情:右拜年]",
|
||||
"305": "[表情:右亲亲]",
|
||||
"306": "[表情:牛气冲天]",
|
||||
"307": "[表情:喵喵]",
|
||||
"311": "[表情:打call]",
|
||||
"312": "[表情:变形]",
|
||||
"314": "[表情:仔细分析]",
|
||||
"317": "[表情:菜汪]",
|
||||
"318": "[表情:崇拜]",
|
||||
"319": "[表情: 比心]",
|
||||
"320": "[表情:庆祝]",
|
||||
"323": "[表情:嫌弃]",
|
||||
"324": "[表情:吃糖]",
|
||||
"325": "[表情:惊吓]",
|
||||
"326": "[表情:生气]",
|
||||
"332": "[表情:举牌牌]",
|
||||
"333": "[表情:烟花]",
|
||||
"334": "[表情:虎虎生威]",
|
||||
"336": "[表情:豹富]",
|
||||
"337": "[表情:花朵脸]",
|
||||
"338": "[表情:我想开了]",
|
||||
"339": "[表情:舔屏]",
|
||||
"341": "[表情:打招呼]",
|
||||
"342": "[表情:酸Q]",
|
||||
"343": "[表情:我方了]",
|
||||
"344": "[表情:大怨种]",
|
||||
"345": "[表情:红包多多]",
|
||||
"346": "[表情:你真棒棒]",
|
||||
"347": "[表情:大展宏兔]",
|
||||
"349": "[表情:坚强]",
|
||||
"350": "[表情:贴贴]",
|
||||
"351": "[表情:敲敲]",
|
||||
"352": "[表情:咦]",
|
||||
"353": "[表情:拜托]",
|
||||
"354": "[表情:尊嘟假嘟]",
|
||||
"355": "[表情:耶]",
|
||||
"356": "[表情:666]",
|
||||
"357": "[表情:裂开]",
|
||||
"392": "[表情:龙年 快乐]",
|
||||
"393": "[表情:新年中龙]",
|
||||
"394": "[表情:新年大龙]",
|
||||
"395": "[表情:略略略]",
|
||||
"😊": "[表情:嘿嘿]",
|
||||
"😌": "[表情:羞涩]",
|
||||
"😚": "[ 表情:亲亲]",
|
||||
"😓": "[表情:汗]",
|
||||
"😰": "[表情:紧张]",
|
||||
"😝": "[表情:吐舌]",
|
||||
"😁": "[表情:呲牙]",
|
||||
"😜": "[表情:淘气]",
|
||||
"☺": "[表情:可爱]",
|
||||
"😍": "[表情:花痴]",
|
||||
"😔": "[表情:失落]",
|
||||
"😄": "[表情:高兴]",
|
||||
"😏": "[表情:哼哼]",
|
||||
"😒": "[表情:不屑]",
|
||||
"😳": "[表情:瞪眼]",
|
||||
"😘": "[表情:飞吻]",
|
||||
"😭": "[表情:大哭]",
|
||||
"😱": "[表情:害怕]",
|
||||
"😂": "[表情:激动]",
|
||||
"💪": "[表情:肌肉]",
|
||||
"👊": "[表情:拳头]",
|
||||
"👍": "[表情 :厉害]",
|
||||
"👏": "[表情:鼓掌]",
|
||||
"👎": "[表情:鄙视]",
|
||||
"🙏": "[表情:合十]",
|
||||
"👌": "[表情:好的]",
|
||||
"👆": "[表情:向上]",
|
||||
"👀": "[表情:眼睛]",
|
||||
"🍜": "[表情:拉面]",
|
||||
"🍧": "[表情:刨冰]",
|
||||
"🍞": "[表情:面包]",
|
||||
"🍺": "[表情:啤酒]",
|
||||
"🍻": "[表情:干杯]",
|
||||
"☕": "[表情:咖啡]",
|
||||
"🍎": "[表情:苹果]",
|
||||
"🍓": "[表情:草莓]",
|
||||
"🍉": "[表情:西瓜]",
|
||||
"🚬": "[表情:吸烟]",
|
||||
"🌹": "[表情:玫瑰]",
|
||||
"🎉": "[表情:庆祝]",
|
||||
"💝": "[表情:礼物]",
|
||||
"💣": "[表情:炸弹]",
|
||||
"✨": "[表情:闪光]",
|
||||
"💨": "[表情:吹气]",
|
||||
"💦": "[表情:水]",
|
||||
"🔥": "[表情:火]",
|
||||
"💤": "[表情:睡觉]",
|
||||
"💩": "[表情:便便]",
|
||||
"💉": "[表情:打针]",
|
||||
"📫": "[表情:邮箱]",
|
||||
"🐎": "[表情:骑马]",
|
||||
"👧": "[表情:女孩]",
|
||||
"👦": "[表情:男孩]",
|
||||
"🐵": "[表情:猴]",
|
||||
"🐷": "[表情:猪]",
|
||||
"🐮": "[表情:牛]",
|
||||
"🐔": "[表情:公鸡]",
|
||||
"🐸": "[表情:青蛙]",
|
||||
"👻": "[表情:幽灵]",
|
||||
"🐛": "[表情:虫]",
|
||||
"🐶": "[表情:狗]",
|
||||
"🐳": "[表情:鲸鱼]",
|
||||
"👢": "[表情:靴子]",
|
||||
"☀": "[表情:晴天]",
|
||||
"❔": "[表情:问号]",
|
||||
"🔫": "[表情:手枪]",
|
||||
"💓": "[表情:爱 心]",
|
||||
"🏪": "[表情:便利店]",
|
||||
}
|
||||
45
plugins/napcat_adapter_plugin/src/response_pool.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict
|
||||
from .config import global_config
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
response_dict: Dict = {}
|
||||
response_time_dict: Dict = {}
|
||||
|
||||
|
||||
async def get_response(request_id: str, timeout: int = 10) -> dict:
|
||||
response = await asyncio.wait_for(_get_response(request_id), timeout)
|
||||
_ = response_time_dict.pop(request_id)
|
||||
logger.info(f"响应信息id: {request_id} 已从响应字典中取出")
|
||||
return response
|
||||
|
||||
async def _get_response(request_id: str) -> dict:
|
||||
"""
|
||||
内部使用的获取响应函数,主要用于在需要时获取响应
|
||||
"""
|
||||
while request_id not in response_dict:
|
||||
await asyncio.sleep(0.2)
|
||||
return response_dict.pop(request_id)
|
||||
|
||||
async def put_response(response: dict):
|
||||
echo_id = response.get("echo")
|
||||
now_time = time.time()
|
||||
response_dict[echo_id] = response
|
||||
response_time_dict[echo_id] = now_time
|
||||
logger.info(f"响应信息id: {echo_id} 已存入响应字典")
|
||||
|
||||
|
||||
async def check_timeout_response() -> None:
|
||||
while True:
|
||||
cleaned_message_count: int = 0
|
||||
now_time = time.time()
|
||||
for echo_id, response_time in list(response_time_dict.items()):
|
||||
if now_time - response_time > global_config.napcat_server.heartbeat_interval:
|
||||
cleaned_message_count += 1
|
||||
response_dict.pop(echo_id)
|
||||
response_time_dict.pop(echo_id)
|
||||
logger.warning(f"响应消息 {echo_id} 超时,已删除")
|
||||
logger.info(f"已删除 {cleaned_message_count} 条超时响应消息")
|
||||
await asyncio.sleep(global_config.napcat_server.heartbeat_interval)
|
||||
711
plugins/napcat_adapter_plugin/src/send_handler.py
Normal file
@@ -0,0 +1,711 @@
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import websockets as Server
|
||||
import uuid
|
||||
import asyncio
|
||||
from maim_message import (
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
Seg,
|
||||
BaseMessageInfo,
|
||||
MessageBase,
|
||||
)
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
|
||||
from . import CommandType
|
||||
from .config import global_config
|
||||
from .response_pool import get_response
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from .utils import get_image_format, convert_image_to_gif
|
||||
from .recv_handler.message_sending import message_send_instance
|
||||
from .websocket_manager import websocket_manager
|
||||
from .config.features_config import features_manager
|
||||
|
||||
|
||||
class SendHandler:
|
||||
def __init__(self):
|
||||
self.server_connection: Optional[Server.ServerConnection] = None
|
||||
|
||||
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
"""设置Napcat连接"""
|
||||
self.server_connection = server_connection
|
||||
|
||||
def get_server_connection(self) -> Optional[Server.ServerConnection]:
|
||||
"""获取当前的服务器连接"""
|
||||
# 优先使用直接设置的连接,否则从 websocket_manager 获取
|
||||
if self.server_connection:
|
||||
return self.server_connection
|
||||
return websocket_manager.get_connection()
|
||||
|
||||
async def handle_message(self, raw_message_base_dict: dict) -> None:
|
||||
raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_dict)
|
||||
message_segment: Seg = raw_message_base.message_segment
|
||||
logger.info("接收到来自MaiBot的消息,处理中")
|
||||
if message_segment.type == "command":
|
||||
logger.info("处理命令")
|
||||
return await self.send_command(raw_message_base)
|
||||
elif message_segment.type == "adapter_command":
|
||||
logger.info("处理适配器命令")
|
||||
return await self.handle_adapter_command(raw_message_base)
|
||||
else:
|
||||
logger.info("处理普通消息")
|
||||
return await self.send_normal_message(raw_message_base)
|
||||
|
||||
async def send_normal_message(self, raw_message_base: MessageBase) -> None:
|
||||
"""
|
||||
处理普通消息发送
|
||||
"""
|
||||
logger.info("处理普通信息中")
|
||||
message_info: BaseMessageInfo = raw_message_base.message_info
|
||||
message_segment: Seg = raw_message_base.message_segment
|
||||
group_info: Optional[GroupInfo] = message_info.group_info
|
||||
user_info: Optional[UserInfo] = message_info.user_info
|
||||
target_id: Optional[int] = None
|
||||
action: Optional[str] = None
|
||||
id_name: Optional[str] = None
|
||||
processed_message: list = []
|
||||
try:
|
||||
if user_info:
|
||||
processed_message = await self.handle_seg_recursive(
|
||||
message_segment, user_info
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时发生错误: {e}")
|
||||
return
|
||||
|
||||
if not processed_message:
|
||||
logger.critical("现在暂时不支持解析此回复!")
|
||||
return None
|
||||
|
||||
if group_info and user_info:
|
||||
logger.debug("发送群聊消息")
|
||||
target_id = int(group_info.group_id) if group_info.group_id else None
|
||||
action = "send_group_msg"
|
||||
id_name = "group_id"
|
||||
elif user_info:
|
||||
logger.debug("发送私聊消息")
|
||||
target_id = int(user_info.user_id) if user_info.user_id else None
|
||||
action = "send_private_msg"
|
||||
id_name = "user_id"
|
||||
else:
|
||||
logger.error("无法识别的消息类型")
|
||||
return
|
||||
logger.info("尝试发送到napcat")
|
||||
response = await self.send_message_to_napcat(
|
||||
action,
|
||||
{
|
||||
id_name: target_id,
|
||||
"message": processed_message,
|
||||
},
|
||||
)
|
||||
if response.get("status") == "ok":
|
||||
logger.info("消息发送成功")
|
||||
qq_message_id = response.get("data", {}).get("message_id")
|
||||
await self.message_sent_back(raw_message_base, qq_message_id)
|
||||
else:
|
||||
logger.warning(f"消息发送失败,napcat返回:{str(response)}")
|
||||
|
||||
async def send_command(self, raw_message_base: MessageBase) -> None:
|
||||
"""
|
||||
处理命令类
|
||||
"""
|
||||
logger.info("处理命令中")
|
||||
message_info: BaseMessageInfo = raw_message_base.message_info
|
||||
message_segment: Seg = raw_message_base.message_segment
|
||||
group_info: Optional[GroupInfo] = message_info.group_info
|
||||
seg_data: Dict[str, Any] = (
|
||||
message_segment.data
|
||||
if isinstance(message_segment.data, dict)
|
||||
else {}
|
||||
)
|
||||
command_name: Optional[str] = seg_data.get("name")
|
||||
try:
|
||||
args = seg_data.get("args", {})
|
||||
if not isinstance(args, dict):
|
||||
args = {}
|
||||
|
||||
match command_name:
|
||||
case CommandType.GROUP_BAN.name:
|
||||
command, args_dict = self.handle_ban_command(args, group_info)
|
||||
case CommandType.GROUP_WHOLE_BAN.name:
|
||||
command, args_dict = self.handle_whole_ban_command(
|
||||
args, group_info
|
||||
)
|
||||
case CommandType.GROUP_KICK.name:
|
||||
command, args_dict = self.handle_kick_command(args, group_info)
|
||||
case CommandType.SEND_POKE.name:
|
||||
command, args_dict = self.handle_poke_command(args, group_info)
|
||||
case CommandType.DELETE_MSG.name:
|
||||
command, args_dict = self.delete_msg_command(args)
|
||||
case CommandType.AI_VOICE_SEND.name:
|
||||
command, args_dict = self.handle_ai_voice_send_command(
|
||||
args, group_info
|
||||
)
|
||||
case CommandType.SET_EMOJI_LIKE.name:
|
||||
command, args_dict = self.handle_set_emoji_like_command(args)
|
||||
case CommandType.SEND_AT_MESSAGE.name:
|
||||
command, args_dict = self.handle_at_message_command(
|
||||
args, group_info
|
||||
)
|
||||
case CommandType.SEND_LIKE.name:
|
||||
command, args_dict = self.handle_send_like_command(args)
|
||||
case _:
|
||||
logger.error(f"未知命令: {command_name}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时发生错误: {e}")
|
||||
return None
|
||||
|
||||
if not command or not args_dict:
|
||||
logger.error("命令或参数缺失")
|
||||
return None
|
||||
|
||||
response = await self.send_message_to_napcat(command, args_dict)
|
||||
if response.get("status") == "ok":
|
||||
logger.info(f"命令 {command_name} 执行成功")
|
||||
else:
|
||||
logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}")
|
||||
|
||||
async def handle_adapter_command(self, raw_message_base: MessageBase) -> None:
|
||||
"""
|
||||
处理适配器命令类 - 用于直接向Napcat发送命令并返回结果
|
||||
"""
|
||||
logger.info("处理适配器命令中")
|
||||
message_info: BaseMessageInfo = raw_message_base.message_info
|
||||
message_segment: Seg = raw_message_base.message_segment
|
||||
seg_data: Dict[str, Any] = (
|
||||
message_segment.data
|
||||
if isinstance(message_segment.data, dict)
|
||||
else {}
|
||||
)
|
||||
|
||||
try:
|
||||
action = seg_data.get("action")
|
||||
params = seg_data.get("params", {})
|
||||
request_id = seg_data.get("request_id")
|
||||
|
||||
if not action:
|
||||
logger.error("适配器命令缺少action参数")
|
||||
await self.send_adapter_command_response(
|
||||
raw_message_base,
|
||||
{"status": "error", "message": "缺少action参数"},
|
||||
request_id
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"执行适配器命令: {action}")
|
||||
|
||||
# 直接向Napcat发送命令并获取响应
|
||||
response_task = asyncio.create_task(self.send_message_to_napcat(action, params))
|
||||
response = await response_task
|
||||
|
||||
# 发送响应回MaiBot
|
||||
await self.send_adapter_command_response(raw_message_base, response, request_id)
|
||||
|
||||
if response.get("status") == "ok":
|
||||
logger.info(f"适配器命令 {action} 执行成功")
|
||||
else:
|
||||
logger.warning(f"适配器命令 {action} 执行失败,napcat返回:{str(response)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理适配器命令时发生错误: {e}")
|
||||
error_response = {"status": "error", "message": str(e)}
|
||||
await self.send_adapter_command_response(
|
||||
raw_message_base,
|
||||
error_response,
|
||||
seg_data.get("request_id")
|
||||
)
|
||||
|
||||
def get_level(self, seg_data: Seg) -> int:
|
||||
if seg_data.type == "seglist":
|
||||
return 1 + max(self.get_level(seg) for seg in seg_data.data)
|
||||
else:
|
||||
return 1
|
||||
|
||||
async def handle_seg_recursive(self, seg_data: Seg, user_info: UserInfo) -> list:
|
||||
payload: list = []
|
||||
if seg_data.type == "seglist":
|
||||
# level = self.get_level(seg_data) # 给以后可能的多层嵌套做准备,此处不使用
|
||||
if not seg_data.data:
|
||||
return []
|
||||
for seg in seg_data.data:
|
||||
payload = await self.process_message_by_type(seg, payload, user_info)
|
||||
else:
|
||||
payload = await self.process_message_by_type(seg_data, payload, user_info)
|
||||
return payload
|
||||
|
||||
async def process_message_by_type(
|
||||
self, seg: Seg, payload: list, user_info: UserInfo
|
||||
) -> list:
|
||||
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
|
||||
new_payload = payload
|
||||
if seg.type == "reply":
|
||||
target_id = seg.data
|
||||
if target_id == "notice":
|
||||
return payload
|
||||
new_payload = self.build_payload(
|
||||
payload,
|
||||
await self.handle_reply_message(
|
||||
target_id if isinstance(target_id, str) else "", user_info
|
||||
),
|
||||
True,
|
||||
)
|
||||
elif seg.type == "text":
|
||||
text = seg.data
|
||||
if not text:
|
||||
return payload
|
||||
new_payload = self.build_payload(
|
||||
payload,
|
||||
self.handle_text_message(text if isinstance(text, str) else ""),
|
||||
False,
|
||||
)
|
||||
elif seg.type == "face":
|
||||
logger.warning("MaiBot 发送了qq原生表情,暂时不支持")
|
||||
elif seg.type == "image":
|
||||
image = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_image_message(image), False)
|
||||
elif seg.type == "emoji":
|
||||
emoji = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_emoji_message(emoji), False)
|
||||
elif seg.type == "voice":
|
||||
voice = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_voice_message(voice), False)
|
||||
elif seg.type == "voiceurl":
|
||||
voice_url = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_voiceurl_message(voice_url), False)
|
||||
elif seg.type == "music":
|
||||
song_id = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_music_message(song_id), False)
|
||||
elif seg.type == "videourl":
|
||||
video_url = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_videourl_message(video_url), False)
|
||||
elif seg.type == "file":
|
||||
file_path = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_file_message(file_path), False)
|
||||
return new_payload
|
||||
|
||||
def build_payload(
|
||||
self, payload: list, addon: dict | list, is_reply: bool = False
|
||||
) -> list:
|
||||
# sourcery skip: for-append-to-extend, merge-list-append, simplify-generator
|
||||
"""构建发送的消息体"""
|
||||
if is_reply:
|
||||
temp_list = []
|
||||
if isinstance(addon, list):
|
||||
temp_list.extend(addon)
|
||||
else:
|
||||
temp_list.append(addon)
|
||||
for i in payload:
|
||||
if i.get("type") == "reply":
|
||||
logger.debug("检测到多个回复,使用最新的回复")
|
||||
continue
|
||||
temp_list.append(i)
|
||||
return temp_list
|
||||
else:
|
||||
if isinstance(addon, list):
|
||||
payload.extend(addon)
|
||||
else:
|
||||
payload.append(addon)
|
||||
return payload
|
||||
|
||||
async def handle_reply_message(self, id: str, user_info: UserInfo) -> dict | list:
|
||||
"""处理回复消息"""
|
||||
reply_seg = {"type": "reply", "data": {"id": id}}
|
||||
|
||||
# 获取功能配置
|
||||
ft_config = features_manager.get_config()
|
||||
|
||||
# 检查是否启用引用艾特功能
|
||||
if not ft_config.enable_reply_at:
|
||||
return reply_seg
|
||||
|
||||
try:
|
||||
# 尝试通过 message_id 获取消息详情
|
||||
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)})
|
||||
|
||||
replied_user_id = None
|
||||
if msg_info_response and msg_info_response.get("status") == "ok":
|
||||
sender_info = msg_info_response.get("data", {}).get("sender")
|
||||
if sender_info:
|
||||
replied_user_id = sender_info.get("user_id")
|
||||
|
||||
# 如果没有获取到被回复者的ID,则直接返回,不进行@
|
||||
if not replied_user_id:
|
||||
logger.warning(f"无法获取消息 {id} 的发送者信息,跳过 @")
|
||||
return reply_seg
|
||||
|
||||
# 根据概率决定是否艾特用户
|
||||
if random.random() < ft_config.reply_at_rate:
|
||||
at_seg = {"type": "at", "data": {"qq": str(replied_user_id)}}
|
||||
# 在艾特后面添加一个空格
|
||||
text_seg = {"type": "text", "data": {"text": " "}}
|
||||
return [reply_seg, at_seg, text_seg]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理引用回复并尝试@时出错: {e}")
|
||||
# 出现异常时,只发送普通的回复,避免程序崩溃
|
||||
return reply_seg
|
||||
|
||||
return reply_seg
|
||||
|
||||
def handle_text_message(self, message: str) -> dict:
|
||||
"""处理文本消息"""
|
||||
return {"type": "text", "data": {"text": message}}
|
||||
|
||||
def handle_image_message(self, encoded_image: str) -> dict:
|
||||
"""处理图片消息"""
|
||||
return {
|
||||
"type": "image",
|
||||
"data": {
|
||||
"file": f"base64://{encoded_image}",
|
||||
"subtype": 0,
|
||||
},
|
||||
} # base64 编码的图片
|
||||
|
||||
def handle_emoji_message(self, encoded_emoji: str) -> dict:
|
||||
"""处理表情消息"""
|
||||
encoded_image = encoded_emoji
|
||||
image_format = get_image_format(encoded_emoji)
|
||||
if image_format != "gif":
|
||||
encoded_image = convert_image_to_gif(encoded_emoji)
|
||||
return {
|
||||
"type": "image",
|
||||
"data": {
|
||||
"file": f"base64://{encoded_image}",
|
||||
"subtype": 1,
|
||||
"summary": "[动画表情]",
|
||||
},
|
||||
}
|
||||
|
||||
def handle_voice_message(self, encoded_voice: str) -> dict:
|
||||
"""处理语音消息"""
|
||||
if not global_config.voice.use_tts:
|
||||
logger.warning("未启用语音消息处理")
|
||||
return {}
|
||||
if not encoded_voice:
|
||||
return {}
|
||||
return {
|
||||
"type": "record",
|
||||
"data": {"file": f"base64://{encoded_voice}"},
|
||||
}
|
||||
|
||||
def handle_voiceurl_message(self, voice_url: str) -> dict:
|
||||
"""处理语音链接消息"""
|
||||
return {
|
||||
"type": "record",
|
||||
"data": {"file": voice_url},
|
||||
}
|
||||
|
||||
def handle_music_message(self, song_id: str) -> dict:
|
||||
"""处理音乐消息"""
|
||||
return {
|
||||
"type": "music",
|
||||
"data": {"type": "163", "id": song_id},
|
||||
}
|
||||
def handle_videourl_message(self, video_url: str) -> dict:
|
||||
"""处理视频链接消息"""
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {"file": video_url},
|
||||
}
|
||||
|
||||
def handle_file_message(self, file_path: str) -> dict:
|
||||
"""处理文件消息"""
|
||||
return {
|
||||
"type": "file",
|
||||
"data": {"file": f"file://{file_path}"},
|
||||
}
|
||||
|
||||
def delete_msg_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理删除消息命令"""
|
||||
return "delete_msg", {"message_id": args["message_id"]}
|
||||
|
||||
def handle_ban_command(
|
||||
self, args: Dict[str, Any], group_info: GroupInfo
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理封禁命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
group_info (GroupInfo): 群聊信息(对应目标群聊)
|
||||
|
||||
Returns:
|
||||
Tuple[CommandType, Dict[str, Any]]
|
||||
"""
|
||||
duration: int = int(args["duration"])
|
||||
user_id: int = int(args["qq_id"])
|
||||
group_id: int = int(group_info.group_id)
|
||||
if duration < 0:
|
||||
raise ValueError("封禁时间必须大于等于0")
|
||||
if not user_id or not group_id:
|
||||
raise ValueError("封禁命令缺少必要参数")
|
||||
if duration > 2592000:
|
||||
raise ValueError("封禁时间不能超过30天")
|
||||
return (
|
||||
CommandType.GROUP_BAN.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"duration": duration,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理全体禁言命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
group_info (GroupInfo): 群聊信息(对应目标群聊)
|
||||
|
||||
Returns:
|
||||
Tuple[CommandType, Dict[str, Any]]
|
||||
"""
|
||||
enable = args["enable"]
|
||||
assert isinstance(enable, bool), "enable参数必须是布尔值"
|
||||
group_id: int = int(group_info.group_id)
|
||||
if group_id <= 0:
|
||||
raise ValueError("群组ID无效")
|
||||
return (
|
||||
CommandType.GROUP_WHOLE_BAN.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"enable": enable,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理群成员踢出命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
group_info (GroupInfo): 群聊信息(对应目标群聊)
|
||||
|
||||
Returns:
|
||||
Tuple[CommandType, Dict[str, Any]]
|
||||
"""
|
||||
user_id: int = int(args["qq_id"])
|
||||
group_id: int = int(group_info.group_id)
|
||||
if group_id <= 0:
|
||||
raise ValueError("群组ID无效")
|
||||
if user_id <= 0:
|
||||
raise ValueError("用户ID无效")
|
||||
return (
|
||||
CommandType.GROUP_KICK.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"reject_add_request": False, # 不拒绝加群请求
|
||||
},
|
||||
)
|
||||
|
||||
def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理戳一戳命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
group_info (GroupInfo): 群聊信息(对应目标群聊)
|
||||
|
||||
Returns:
|
||||
Tuple[CommandType, Dict[str, Any]]
|
||||
"""
|
||||
user_id: int = int(args["qq_id"])
|
||||
if group_info is None:
|
||||
group_id = None
|
||||
else:
|
||||
group_id: int = int(group_info.group_id)
|
||||
if group_id <= 0:
|
||||
raise ValueError("群组ID无效")
|
||||
if user_id <= 0:
|
||||
raise ValueError("用户ID无效")
|
||||
return (
|
||||
CommandType.SEND_POKE.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理设置表情回应命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[CommandType, Dict[str, Any]]
|
||||
"""
|
||||
try:
|
||||
message_id = int(args["message_id"])
|
||||
emoji_id = int(args["emoji_id"])
|
||||
set_like = str(args["set"])
|
||||
except:
|
||||
raise ValueError("缺少必需参数: message_id 或 emoji_id")
|
||||
|
||||
return (
|
||||
CommandType.SET_EMOJI_LIKE.value,
|
||||
{
|
||||
"message_id": message_id,
|
||||
"emoji_id": emoji_id,
|
||||
"set": set_like
|
||||
},
|
||||
)
|
||||
|
||||
def handle_send_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
处理发送点赞命令的逻辑。
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
|
||||
Returns:
|
||||
Tuple[CommandType, Dict[str, Any]]
|
||||
"""
|
||||
try:
|
||||
user_id: int = int(args["qq_id"])
|
||||
times: int = int(args["times"])
|
||||
except (KeyError, ValueError):
|
||||
raise ValueError("缺少必需参数: qq_id 或 times")
|
||||
|
||||
return (
|
||||
CommandType.SEND_LIKE.value,
|
||||
{
|
||||
"user_id": user_id,
|
||||
"times": times
|
||||
},
|
||||
)
|
||||
|
||||
def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
处理AI语音发送命令的逻辑。
|
||||
并返回 NapCat 兼容的 (action, params) 元组。
|
||||
"""
|
||||
if not group_info or not group_info.group_id:
|
||||
raise ValueError("AI语音发送命令必须在群聊上下文中使用")
|
||||
if not args:
|
||||
raise ValueError("AI语音发送命令缺少参数")
|
||||
|
||||
group_id: int = int(group_info.group_id)
|
||||
character_id = args.get("character")
|
||||
text_content = args.get("text")
|
||||
|
||||
if not character_id or not text_content:
|
||||
raise ValueError(f"AI语音发送命令参数不完整: character='{character_id}', text='{text_content}'")
|
||||
|
||||
return (
|
||||
CommandType.AI_VOICE_SEND.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"text": text_content,
|
||||
"character": character_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def send_message_to_napcat(self, action: str, params: dict) -> dict:
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
|
||||
|
||||
# 获取当前连接
|
||||
connection = self.get_server_connection()
|
||||
if not connection:
|
||||
logger.error("没有可用的 Napcat 连接")
|
||||
return {"status": "error", "message": "no connection"}
|
||||
|
||||
try:
|
||||
await connection.send(payload)
|
||||
response = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error("发送消息超时,未收到响应")
|
||||
return {"status": "error", "message": "timeout"}
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
return response
|
||||
|
||||
async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None:
|
||||
# 修改 additional_config,添加 echo 字段
|
||||
if message_base.message_info.additional_config is None:
|
||||
message_base.message_info.additional_config = {}
|
||||
|
||||
message_base.message_info.additional_config["echo"] = True
|
||||
|
||||
# 获取原始的 mmc_message_id
|
||||
mmc_message_id = message_base.message_info.message_id
|
||||
|
||||
# 修改 message_segment 为 notify 类型
|
||||
message_base.message_segment = Seg(
|
||||
type="notify", data={"sub_type": "echo", "echo": mmc_message_id, "actual_id": qq_message_id}
|
||||
)
|
||||
await message_send_instance.message_send(message_base)
|
||||
logger.debug("已回送消息ID")
|
||||
return
|
||||
|
||||
async def send_adapter_command_response(
|
||||
self, original_message: MessageBase, response_data: dict, request_id: str
|
||||
) -> None:
|
||||
"""
|
||||
发送适配器命令响应回MaiBot
|
||||
|
||||
Args:
|
||||
original_message: 原始消息
|
||||
response_data: 响应数据
|
||||
request_id: 请求ID
|
||||
"""
|
||||
try:
|
||||
# 修改 additional_config,添加 echo 字段
|
||||
if original_message.message_info.additional_config is None:
|
||||
original_message.message_info.additional_config = {}
|
||||
|
||||
original_message.message_info.additional_config["echo"] = True
|
||||
|
||||
# 修改 message_segment 为 adapter_response 类型
|
||||
original_message.message_segment = Seg(
|
||||
type="adapter_response",
|
||||
data={
|
||||
"request_id": request_id,
|
||||
"response": response_data,
|
||||
"timestamp": int(time.time() * 1000)
|
||||
}
|
||||
)
|
||||
|
||||
await message_send_instance.message_send(original_message)
|
||||
logger.debug(f"已发送适配器命令响应,request_id: {request_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送适配器命令响应时出错: {e}")
|
||||
|
||||
def handle_at_message_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理艾特并发送消息命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典, 包含 qq_id 和 text
|
||||
group_info (GroupInfo): 群聊信息
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict[str, Any]]: (action, params)
|
||||
"""
|
||||
at_user_id = args.get("qq_id")
|
||||
text = args.get("text")
|
||||
|
||||
if not at_user_id or not text:
|
||||
raise ValueError("艾特消息命令缺少 qq_id 或 text 参数")
|
||||
|
||||
if not group_info:
|
||||
raise ValueError("艾特消息命令必须在群聊上下文中使用")
|
||||
|
||||
message_payload = [
|
||||
{"type": "at", "data": {"qq": str(at_user_id)}},
|
||||
{"type": "text", "data": {"text": " " + str(text)}},
|
||||
]
|
||||
|
||||
return (
|
||||
"send_group_msg",
|
||||
{
|
||||
"group_id": group_info.group_id,
|
||||
"message": message_payload,
|
||||
},
|
||||
)
|
||||
|
||||
send_handler = SendHandler()
|
||||
311
plugins/napcat_adapter_plugin/src/utils.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import websockets as Server
|
||||
import json
|
||||
import base64
|
||||
import uuid
|
||||
import urllib3
|
||||
import ssl
|
||||
import io
|
||||
|
||||
from .database import BanUser, db_manager
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from .response_pool import get_response
|
||||
|
||||
from PIL import Image
|
||||
from typing import Union, List, Tuple, Optional
|
||||
|
||||
|
||||
class SSLAdapter(urllib3.PoolManager):
|
||||
def __init__(self, *args, **kwargs):
|
||||
context = ssl.create_default_context()
|
||||
context.set_ciphers("DEFAULT@SECLEVEL=1")
|
||||
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
kwargs["ssl_context"] = context
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
|
||||
"""
|
||||
获取群相关信息
|
||||
|
||||
返回值需要处理可能为空的情况
|
||||
"""
|
||||
logger.debug("获取群聊信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
socket_response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error(f"获取群信息超时,群号: {group_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取群信息失败: {e}")
|
||||
return None
|
||||
logger.debug(socket_response)
|
||||
return socket_response.get("data")
|
||||
|
||||
|
||||
async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
|
||||
"""
|
||||
获取群详细信息
|
||||
|
||||
返回值需要处理可能为空的情况
|
||||
"""
|
||||
logger.debug("获取群详细信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
socket_response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error(f"获取群详细信息超时,群号: {group_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取群详细信息失败: {e}")
|
||||
return None
|
||||
logger.debug(socket_response)
|
||||
return socket_response.get("data")
|
||||
|
||||
|
||||
async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict | None:
|
||||
"""
|
||||
获取群成员信息
|
||||
|
||||
返回值需要处理可能为空的情况
|
||||
"""
|
||||
logger.debug("获取群成员信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
{
|
||||
"action": "get_group_member_info",
|
||||
"params": {"group_id": group_id, "user_id": user_id, "no_cache": True},
|
||||
"echo": request_uuid,
|
||||
}
|
||||
)
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
socket_response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error(f"获取成员信息超时,群号: {group_id}, 用户ID: {user_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取成员信息失败: {e}")
|
||||
return None
|
||||
logger.debug(socket_response)
|
||||
return socket_response.get("data")
|
||||
|
||||
|
||||
async def get_image_base64(url: str) -> str:
|
||||
# sourcery skip: raise-specific-error
|
||||
"""获取图片/表情包的Base64"""
|
||||
logger.debug(f"下载图片: {url}")
|
||||
http = SSLAdapter()
|
||||
try:
|
||||
response = http.request("GET", url, timeout=10)
|
||||
if response.status != 200:
|
||||
raise Exception(f"HTTP Error: {response.status}")
|
||||
image_bytes = response.data
|
||||
return base64.b64encode(image_bytes).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"图片下载失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def convert_image_to_gif(image_base64: str) -> str:
|
||||
# sourcery skip: extract-method
|
||||
"""
|
||||
将Base64编码的图片转换为GIF格式
|
||||
Parameters:
|
||||
image_base64: str: Base64编码的图片数据
|
||||
Returns:
|
||||
str: Base64编码的GIF图片数据
|
||||
"""
|
||||
logger.debug("转换图片为GIF格式")
|
||||
try:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
output_buffer = io.BytesIO()
|
||||
image.save(output_buffer, format="GIF")
|
||||
output_buffer.seek(0)
|
||||
return base64.b64encode(output_buffer.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换为GIF失败: {str(e)}")
|
||||
return image_base64
|
||||
|
||||
|
||||
async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
|
||||
"""
|
||||
获取自身信息
|
||||
Parameters:
|
||||
websocket: WebSocket连接对象
|
||||
Returns:
|
||||
data: dict: 返回的自身信息
|
||||
"""
|
||||
logger.debug("获取自身信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error("获取自身信息超时")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取自身信息失败: {e}")
|
||||
return None
|
||||
logger.debug(response)
|
||||
return response.get("data")
|
||||
|
||||
|
||||
def get_image_format(raw_data: str) -> str:
|
||||
"""
|
||||
从Base64编码的数据中确定图片的格式。
|
||||
Parameters:
|
||||
raw_data: str: Base64编码的图片数据。
|
||||
Returns:
|
||||
format: str: 图片的格式(例如 'jpeg', 'png', 'gif')。
|
||||
"""
|
||||
image_bytes = base64.b64decode(raw_data)
|
||||
return Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
|
||||
|
||||
async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> dict | None:
|
||||
"""
|
||||
获取陌生人信息
|
||||
Parameters:
|
||||
websocket: WebSocket连接对象
|
||||
user_id: 用户ID
|
||||
Returns:
|
||||
dict: 返回的陌生人信息
|
||||
"""
|
||||
logger.debug("获取陌生人信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error(f"获取陌生人信息超时,用户ID: {user_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取陌生人信息失败: {e}")
|
||||
return None
|
||||
logger.debug(response)
|
||||
return response.get("data")
|
||||
|
||||
|
||||
async def get_message_detail(websocket: Server.ServerConnection, message_id: Union[str, int]) -> dict | None:
|
||||
"""
|
||||
获取消息详情,可能为空
|
||||
Parameters:
|
||||
websocket: WebSocket连接对象
|
||||
message_id: 消息ID
|
||||
Returns:
|
||||
dict: 返回的消息详情
|
||||
"""
|
||||
logger.debug("获取消息详情中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
|
||||
except TimeoutError:
|
||||
logger.error(f"获取消息详情超时,消息ID: {message_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取消息详情失败: {e}")
|
||||
return None
|
||||
logger.debug(response)
|
||||
return response.get("data")
|
||||
|
||||
|
||||
async def get_record_detail(
|
||||
websocket: Server.ServerConnection, file: str, file_id: Optional[str] = None
|
||||
) -> dict | None:
|
||||
"""
|
||||
获取语音消息内容
|
||||
Parameters:
|
||||
websocket: WebSocket连接对象
|
||||
file: 文件名
|
||||
file_id: 文件ID
|
||||
Returns:
|
||||
dict: 返回的语音消息详情
|
||||
"""
|
||||
logger.debug("获取语音消息详情中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
{
|
||||
"action": "get_record",
|
||||
"params": {"file": file, "file_id": file_id, "out_format": "wav"},
|
||||
"echo": request_uuid,
|
||||
}
|
||||
)
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
|
||||
except TimeoutError:
|
||||
logger.error(f"获取语音消息详情超时,文件: {file}, 文件ID: {file_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取语音消息详情失败: {e}")
|
||||
return None
|
||||
logger.debug(f"{str(response)[:200]}...") # 防止语音的超长base64编码导致日志过长
|
||||
return response.get("data")
|
||||
|
||||
|
||||
async def read_ban_list(
|
||||
websocket: Server.ServerConnection,
|
||||
) -> Tuple[List[BanUser], List[BanUser]]:
|
||||
"""
|
||||
从根目录下的data文件夹中的文件读取禁言列表。
|
||||
同时自动更新已经失效禁言
|
||||
Returns:
|
||||
Tuple[
|
||||
一个仍在禁言中的用户的BanUser列表,
|
||||
一个已经自然解除禁言的用户的BanUser列表,
|
||||
一个仍在全体禁言中的群的BanUser列表,
|
||||
一个已经自然解除全体禁言的群的BanUser列表,
|
||||
]
|
||||
"""
|
||||
try:
|
||||
ban_list = db_manager.get_ban_records()
|
||||
lifted_list: List[BanUser] = []
|
||||
logger.info("已经读取禁言列表")
|
||||
for ban_record in ban_list:
|
||||
if ban_record.user_id == 0:
|
||||
fetched_group_info = await get_group_info(websocket, ban_record.group_id)
|
||||
if fetched_group_info is None:
|
||||
logger.warning(f"无法获取群信息,群号: {ban_record.group_id},默认禁言解除")
|
||||
lifted_list.append(ban_record)
|
||||
ban_list.remove(ban_record)
|
||||
continue
|
||||
group_all_shut: int = fetched_group_info.get("group_all_shut")
|
||||
if group_all_shut == 0:
|
||||
lifted_list.append(ban_record)
|
||||
ban_list.remove(ban_record)
|
||||
continue
|
||||
else:
|
||||
fetched_member_info = await get_member_info(websocket, ban_record.group_id, ban_record.user_id)
|
||||
if fetched_member_info is None:
|
||||
logger.warning(
|
||||
f"无法获取群成员信息,用户ID: {ban_record.user_id}, 群号: {ban_record.group_id},默认禁言解除"
|
||||
)
|
||||
lifted_list.append(ban_record)
|
||||
ban_list.remove(ban_record)
|
||||
continue
|
||||
lift_ban_time: int = fetched_member_info.get("shut_up_timestamp")
|
||||
if lift_ban_time == 0:
|
||||
lifted_list.append(ban_record)
|
||||
ban_list.remove(ban_record)
|
||||
else:
|
||||
ban_record.lift_time = lift_ban_time
|
||||
db_manager.update_ban_record(ban_list)
|
||||
return ban_list, lifted_list
|
||||
except Exception as e:
|
||||
logger.error(f"读取禁言列表失败: {e}")
|
||||
return [], []
|
||||
|
||||
|
||||
def save_ban_record(list: List[BanUser]):
|
||||
return db_manager.update_ban_record(list)
|
||||
192
plugins/napcat_adapter_plugin/src/video_handler.py
Normal file
@@ -0,0 +1,192 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
视频下载和处理模块
|
||||
用于从QQ消息中下载视频并转发给Bot进行分析
|
||||
"""
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("video_handler")
|
||||
|
||||
|
||||
class VideoDownloader:
|
||||
def __init__(self, max_size_mb: int = 100, download_timeout: int = 60):
|
||||
self.max_size_mb = max_size_mb
|
||||
self.download_timeout = download_timeout
|
||||
self.supported_formats = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}
|
||||
|
||||
def is_video_url(self, url: str) -> bool:
|
||||
"""检查URL是否为视频文件"""
|
||||
try:
|
||||
# QQ视频URL可能没有扩展名,所以先检查Content-Type
|
||||
# 对于QQ视频,我们先假设是视频,稍后通过Content-Type验证
|
||||
|
||||
# 检查URL中是否包含视频相关的关键字
|
||||
video_keywords = ['video', 'mp4', 'avi', 'mov', 'mkv', 'flv', 'wmv', 'webm', 'm4v']
|
||||
url_lower = url.lower()
|
||||
|
||||
# 如果URL包含视频关键字,认为是视频
|
||||
if any(keyword in url_lower for keyword in video_keywords):
|
||||
return True
|
||||
|
||||
# 检查文件扩展名(传统方法)
|
||||
path = Path(url.split('?')[0]) # 移除查询参数
|
||||
if path.suffix.lower() in self.supported_formats:
|
||||
return True
|
||||
|
||||
# 对于QQ等特殊平台,URL可能没有扩展名
|
||||
# 我们允许这些URL通过,稍后通过HTTP头Content-Type验证
|
||||
qq_domains = ['qpic.cn', 'gtimg.cn', 'qq.com', 'tencent.com']
|
||||
if any(domain in url_lower for domain in qq_domains):
|
||||
return True
|
||||
|
||||
return False
|
||||
except:
|
||||
# 如果解析失败,默认允许尝试下载(稍后验证)
|
||||
return True
|
||||
|
||||
def check_file_size(self, content_length: Optional[str]) -> bool:
|
||||
"""检查文件大小是否在允许范围内"""
|
||||
if content_length is None:
|
||||
return True # 无法获取大小时允许下载
|
||||
|
||||
try:
|
||||
size_bytes = int(content_length)
|
||||
size_mb = size_bytes / (1024 * 1024)
|
||||
return size_mb <= self.max_size_mb
|
||||
except:
|
||||
return True
|
||||
|
||||
async def download_video(self, url: str, filename: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
下载视频文件
|
||||
|
||||
Args:
|
||||
url: 视频URL
|
||||
filename: 可选的文件名
|
||||
|
||||
Returns:
|
||||
dict: 下载结果,包含success、data、filename、error等字段
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始下载视频: {url}")
|
||||
|
||||
# 检查URL格式
|
||||
if not self.is_video_url(url):
|
||||
logger.warning(f"URL格式检查失败: {url}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": "不支持的视频格式",
|
||||
"url": url
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# 先发送HEAD请求检查文件大小
|
||||
try:
|
||||
async with session.head(url, timeout=aiohttp.ClientTimeout(total=10)) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"HEAD请求失败,状态码: {response.status}")
|
||||
else:
|
||||
content_length = response.headers.get('Content-Length')
|
||||
if not self.check_file_size(content_length):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"视频文件过大,超过{self.max_size_mb}MB限制",
|
||||
"url": url
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"HEAD请求失败: {e},继续尝试下载")
|
||||
|
||||
# 下载文件
|
||||
async with session.get(url, timeout=aiohttp.ClientTimeout(total=self.download_timeout)) as response:
|
||||
if response.status != 200:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"下载失败,HTTP状态码: {response.status}",
|
||||
"url": url
|
||||
}
|
||||
|
||||
# 检查Content-Type是否为视频
|
||||
content_type = response.headers.get('Content-Type', '').lower()
|
||||
if content_type:
|
||||
# 检查是否为视频类型
|
||||
video_mime_types = [
|
||||
'video/', 'application/octet-stream',
|
||||
'application/x-msvideo', 'video/x-msvideo'
|
||||
]
|
||||
is_video_content = any(mime in content_type for mime in video_mime_types)
|
||||
|
||||
if not is_video_content:
|
||||
logger.warning(f"Content-Type不是视频格式: {content_type}")
|
||||
# 如果不是明确的视频类型,但可能是QQ的特殊格式,继续尝试
|
||||
if 'text/' in content_type or 'application/json' in content_type:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"URL返回的不是视频内容,Content-Type: {content_type}",
|
||||
"url": url
|
||||
}
|
||||
|
||||
# 再次检查Content-Length
|
||||
content_length = response.headers.get('Content-Length')
|
||||
if not self.check_file_size(content_length):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"视频文件过大,超过{self.max_size_mb}MB限制",
|
||||
"url": url
|
||||
}
|
||||
|
||||
# 读取文件内容
|
||||
video_data = await response.read()
|
||||
|
||||
# 检查实际文件大小
|
||||
actual_size_mb = len(video_data) / (1024 * 1024)
|
||||
if actual_size_mb > self.max_size_mb:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"视频文件过大,实际大小: {actual_size_mb:.2f}MB",
|
||||
"url": url
|
||||
}
|
||||
|
||||
# 确定文件名
|
||||
if filename is None:
|
||||
filename = Path(url.split('?')[0]).name
|
||||
if not filename or '.' not in filename:
|
||||
filename = "video.mp4"
|
||||
|
||||
logger.info(f"视频下载成功: {filename}, 大小: {actual_size_mb:.2f}MB")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": video_data,
|
||||
"filename": filename,
|
||||
"size_mb": actual_size_mb,
|
||||
"url": url
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "下载超时",
|
||||
"url": url
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"下载视频时出错: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"url": url
|
||||
}
|
||||
|
||||
# 全局实例
|
||||
_video_downloader = None
|
||||
|
||||
def get_video_downloader(max_size_mb: int = 100, download_timeout: int = 60) -> VideoDownloader:
|
||||
"""获取视频下载器实例"""
|
||||
global _video_downloader
|
||||
if _video_downloader is None:
|
||||
_video_downloader = VideoDownloader(max_size_mb, download_timeout)
|
||||
return _video_downloader
|
||||
158
plugins/napcat_adapter_plugin/src/websocket_manager.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import asyncio
|
||||
import websockets as Server
|
||||
from typing import Optional, Callable, Any
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from .config import global_config
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""WebSocket 连接管理器,支持正向和反向连接"""
|
||||
|
||||
def __init__(self):
|
||||
self.connection: Optional[Server.ServerConnection] = None
|
||||
self.server: Optional[Server.WebSocketServer] = None
|
||||
self.is_running = False
|
||||
self.reconnect_interval = 5 # 重连间隔(秒)
|
||||
self.max_reconnect_attempts = 10 # 最大重连次数
|
||||
|
||||
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
|
||||
"""根据配置启动 WebSocket 连接"""
|
||||
mode = global_config.napcat_server.mode
|
||||
|
||||
if mode == "reverse":
|
||||
await self._start_reverse_connection(message_handler)
|
||||
elif mode == "forward":
|
||||
await self._start_forward_connection(message_handler)
|
||||
else:
|
||||
raise ValueError(f"不支持的连接模式: {mode}")
|
||||
|
||||
async def _start_reverse_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
|
||||
"""启动反向连接(作为服务器)"""
|
||||
host = global_config.napcat_server.host
|
||||
port = global_config.napcat_server.port
|
||||
|
||||
logger.info(f"正在启动反向连接模式,监听地址: ws://{host}:{port}")
|
||||
|
||||
async def handle_client(websocket, path=None):
|
||||
self.connection = websocket
|
||||
logger.info(f"Napcat 客户端已连接: {websocket.remote_address}")
|
||||
try:
|
||||
await message_handler(websocket)
|
||||
except Exception as e:
|
||||
logger.error(f"处理客户端连接时出错: {e}")
|
||||
finally:
|
||||
self.connection = None
|
||||
logger.info("Napcat 客户端已断开连接")
|
||||
|
||||
self.server = await Server.serve(
|
||||
handle_client,
|
||||
host,
|
||||
port,
|
||||
max_size=2**26
|
||||
)
|
||||
self.is_running = True
|
||||
logger.info(f"反向连接服务器已启动,监听地址: ws://{host}:{port}")
|
||||
|
||||
# 保持服务器运行
|
||||
await self.server.serve_forever()
|
||||
|
||||
async def _start_forward_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
|
||||
"""启动正向连接(作为客户端)"""
|
||||
url = self._get_forward_url()
|
||||
logger.info(f"正在启动正向连接模式,目标地址: {url}")
|
||||
|
||||
reconnect_count = 0
|
||||
|
||||
while reconnect_count < self.max_reconnect_attempts:
|
||||
try:
|
||||
logger.info(f"尝试连接到 Napcat 服务器: {url}")
|
||||
|
||||
# 准备连接参数
|
||||
connect_kwargs = {"max_size": 2**26}
|
||||
|
||||
# 如果配置了访问令牌,添加到请求头
|
||||
if global_config.napcat_server.access_token:
|
||||
connect_kwargs["additional_headers"] = {
|
||||
"Authorization": f"Bearer {global_config.napcat_server.access_token}"
|
||||
}
|
||||
logger.info("已添加访问令牌到连接请求头")
|
||||
|
||||
async with Server.connect(url, **connect_kwargs) as websocket:
|
||||
self.connection = websocket
|
||||
self.is_running = True
|
||||
reconnect_count = 0 # 重置重连计数
|
||||
|
||||
logger.info(f"成功连接到 Napcat 服务器: {url}")
|
||||
|
||||
try:
|
||||
await message_handler(websocket)
|
||||
except Server.exceptions.ConnectionClosed:
|
||||
logger.warning("与 Napcat 服务器的连接已断开")
|
||||
except Exception as e:
|
||||
logger.error(f"处理正向连接时出错: {e}")
|
||||
finally:
|
||||
self.connection = None
|
||||
self.is_running = False
|
||||
|
||||
except (Server.exceptions.ConnectionClosed, Server.exceptions.InvalidMessage, OSError, ConnectionRefusedError) as e:
|
||||
reconnect_count += 1
|
||||
logger.warning(f"连接失败 ({reconnect_count}/{self.max_reconnect_attempts}): {e}")
|
||||
|
||||
if reconnect_count < self.max_reconnect_attempts:
|
||||
logger.info(f"将在 {self.reconnect_interval} 秒后重试连接...")
|
||||
await asyncio.sleep(self.reconnect_interval)
|
||||
else:
|
||||
logger.error("已达到最大重连次数,停止重连")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"正向连接时发生未知错误: {e}")
|
||||
raise
|
||||
|
||||
def _get_forward_url(self) -> str:
|
||||
"""获取正向连接的 URL"""
|
||||
config = global_config.napcat_server
|
||||
|
||||
# 如果配置了完整的 URL,直接使用
|
||||
if config.url:
|
||||
return config.url
|
||||
|
||||
# 否则根据 host 和 port 构建 URL
|
||||
host = config.host
|
||||
port = config.port
|
||||
return f"ws://{host}:{port}"
|
||||
|
||||
async def stop_connection(self) -> None:
|
||||
"""停止 WebSocket 连接"""
|
||||
self.is_running = False
|
||||
|
||||
if self.connection:
|
||||
try:
|
||||
await self.connection.close()
|
||||
logger.info("WebSocket 连接已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"关闭 WebSocket 连接时出错: {e}")
|
||||
finally:
|
||||
self.connection = None
|
||||
|
||||
if self.server:
|
||||
try:
|
||||
self.server.close()
|
||||
await self.server.wait_closed()
|
||||
logger.info("WebSocket 服务器已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"关闭 WebSocket 服务器时出错: {e}")
|
||||
finally:
|
||||
self.server = None
|
||||
|
||||
def get_connection(self) -> Optional[Server.ServerConnection]:
|
||||
"""获取当前的 WebSocket 连接"""
|
||||
return self.connection
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""检查是否已连接"""
|
||||
return self.connection is not None and self.is_running
|
||||
|
||||
|
||||
# 全局 WebSocket 管理器实例
|
||||
websocket_manager = WebSocketManager()
|
||||
@@ -0,0 +1,43 @@
|
||||
# 权限配置文件
|
||||
# 此文件用于管理群聊和私聊的黑白名单设置,以及聊天相关功能
|
||||
# 支持热重载,修改后会自动生效
|
||||
|
||||
# 群聊权限设置
|
||||
group_list_type = "whitelist" # 群聊列表类型:whitelist(白名单)或 blacklist(黑名单)
|
||||
group_list = [] # 群聊ID列表
|
||||
# 当 group_list_type 为 whitelist 时,只有列表中的群聊可以使用机器人
|
||||
# 当 group_list_type 为 blacklist 时,列表中的群聊无法使用机器人
|
||||
# 示例:group_list = [123456789, 987654321]
|
||||
|
||||
# 私聊权限设置
|
||||
private_list_type = "whitelist" # 私聊列表类型:whitelist(白名单)或 blacklist(黑名单)
|
||||
private_list = [] # 用户ID列表
|
||||
# 当 private_list_type 为 whitelist 时,只有列表中的用户可以私聊机器人
|
||||
# 当 private_list_type 为 blacklist 时,列表中的用户无法私聊机器人
|
||||
# 示例:private_list = [123456789, 987654321]
|
||||
|
||||
# 全局禁止设置
|
||||
ban_user_id = [] # 全局禁止用户ID列表,这些用户无法在任何地方使用机器人
|
||||
ban_qq_bot = false # 是否屏蔽QQ官方机器人消息
|
||||
|
||||
# 聊天功能设置
|
||||
enable_poke = true # 是否启用戳一戳功能
|
||||
ignore_non_self_poke = false # 是否无视不是针对自己的戳一戳
|
||||
poke_debounce_seconds = 3 # 戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略
|
||||
enable_reply_at = true # 是否启用引用回复时艾特用户的功能
|
||||
reply_at_rate = 0.5 # 引用回复时艾特用户的几率 (0.0 ~ 1.0)
|
||||
|
||||
# 视频处理设置
|
||||
enable_video_analysis = true # 是否启用视频识别功能
|
||||
max_video_size_mb = 100 # 视频文件最大大小限制(MB)
|
||||
download_timeout = 60 # 视频下载超时时间(秒)
|
||||
supported_formats = ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"] # 支持的视频格式
|
||||
|
||||
# 消息缓冲设置
|
||||
enable_message_buffer = true # 是否启用消息缓冲合并功能
|
||||
message_buffer_enable_group = true # 是否启用群聊消息缓冲合并
|
||||
message_buffer_enable_private = true # 是否启用私聊消息缓冲合并
|
||||
message_buffer_interval = 3.0 # 消息合并间隔时间(秒),在此时间内的连续消息将被合并
|
||||
message_buffer_initial_delay = 0.5 # 消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并
|
||||
message_buffer_max_components = 50 # 单个会话最大缓冲消息组件数量,超过此数量将强制合并
|
||||
message_buffer_block_prefixes = ["/", "!", "!", ".", "。", "#", "%"] # 消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲
|
||||
25
plugins/napcat_adapter_plugin/template/template_config.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[inner]
|
||||
version = "0.2.0" # 版本号
|
||||
# 请勿修改版本号,除非你知道自己在做什么
|
||||
|
||||
[nickname] # 现在没用
|
||||
nickname = ""
|
||||
|
||||
[napcat_server] # Napcat连接的ws服务设置
|
||||
mode = "reverse" # 连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)
|
||||
host = "localhost" # 主机地址
|
||||
port = 8095 # 端口号
|
||||
url = "" # 正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)
|
||||
access_token = "" # WebSocket 连接的访问令牌,用于身份验证(可选)
|
||||
heartbeat_interval = 30 # 心跳间隔时间(按秒计)
|
||||
|
||||
[maibot_server] # 连接麦麦的ws服务设置
|
||||
host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段
|
||||
port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段
|
||||
|
||||
[voice] # 发送语音设置
|
||||
use_tts = false # 是否使用tts语音(请确保你配置了tts并有对应的adapter)
|
||||
|
||||
[debug]
|
||||
level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
|
||||
89
plugins/napcat_adapter_plugin/todo.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# TODO List:
|
||||
|
||||
[x] logger使用主程序的
|
||||
[ ] 使用插件系统的config系统
|
||||
[ ] 接收从napcat传递的所有信息
|
||||
[ ] 优化架构,各模块解耦,暴露关键方法用于提供接口
|
||||
[ ] 单独一个模块负责与主程序通信
|
||||
[ ] 使用event系统完善接口api
|
||||
|
||||
|
||||
---
|
||||
Event分为两种,一种是对外输出的event,由napcat插件自主触发并传递参数,另一种是接收外界输入的event,由外部插件触发并向napcat传递参数
|
||||
|
||||
|
||||
## 例如,
|
||||
|
||||
### 对外输出的event:
|
||||
|
||||
napcat_on_received_text -> (message_seg: Seg) 接受到qq的文字消息,会向handler传递一个Seg
|
||||
napcat_on_received_face -> (message_seg: Seg) 接受到qq的表情消息,会向handler传递一个Seg
|
||||
napcat_on_received_reply -> (message_seg: Seg) 接受到qq的回复消息,会向handler传递一个Seg
|
||||
napcat_on_received_image -> (message_seg: Seg) 接受到qq的图片消息,会向handler传递一个Seg
|
||||
napcat_on_received_image -> (message_seg: Seg) 接受到qq的图片消息,会向handler传递一个Seg
|
||||
napcat_on_received_record -> (message_seg: Seg) 接受到qq的语音消息,会向handler传递一个Seg
|
||||
napcat_on_received_rps -> (message_seg: Seg) 接受到qq的猜拳魔法表情,会向handler传递一个Seg
|
||||
napcat_on_received_friend_invitation -> (user_id: str) 接受到qq的好友请求,会向handler传递一个user_id
|
||||
...
|
||||
|
||||
此类event不接受外部插件的触发,只能由napcat插件统一触发。
|
||||
|
||||
外部插件需要编写handler并订阅此类事件。
|
||||
```python
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
|
||||
class MyEventHandler(BaseEventHandler):
|
||||
handler_name = "my_handler"
|
||||
handler_description = "我的自定义事件处理器"
|
||||
weight = 10 # 权重,越大越先执行
|
||||
intercept_message = False # 是否拦截消息
|
||||
init_subscribe = ["napcat_on_received_text"] # 初始订阅的事件
|
||||
|
||||
async def execute(self, params: dict) -> HandlerResult:
|
||||
"""处理事件"""
|
||||
try:
|
||||
message = params.get("message_seg")
|
||||
print(f"收到消息: {message.data}")
|
||||
|
||||
# 业务逻辑处理
|
||||
# ...
|
||||
|
||||
return HandlerResult(
|
||||
success=True,
|
||||
continue_process=True, # 是否继续让其他处理器处理
|
||||
message="处理成功",
|
||||
handler_name=self.handler_name
|
||||
)
|
||||
except Exception as e:
|
||||
return HandlerResult(
|
||||
success=False,
|
||||
continue_process=True,
|
||||
message=f"处理失败: {str(e)}",
|
||||
handler_name=self.handler_name
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
### 接收外界输入的event:
|
||||
|
||||
napcat_kick_group <- (user_id, group_id) 踢出某个群组中的某个用户
|
||||
napcat_mute_user <- (user_id, group_id, time) 禁言某个群组中的某个用户
|
||||
napcat_unmute_user <- (user_id, group_id) 取消禁言某个群组中的某个用户
|
||||
napcat_mute_group <- (user_id, group_id) 禁言某个群组
|
||||
napcat_unmute_group <- (user_id, group_id) 取消禁言某个群组
|
||||
napcat_add_friend <- (user_id) 向某个用户发出好友请求
|
||||
napcat_accept_friend <- (user_id) 接收某个用户的好友请求
|
||||
napcat_reject_friend <- (user_id) 拒绝某个用户的好友请求
|
||||
...
|
||||
此类事件只由外部插件触发并传递参数,由napcat完成请求任务。
|
||||
|
||||
外部插件需要触发此类的event并传递正确的参数。
|
||||
|
||||
```python
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
# 触发事件
|
||||
await event_manager.trigger_event("napcat_accept_friend", user_id = 1234123)
|
||||
```
|
||||
|
||||
@@ -65,3 +65,6 @@ asyncio
|
||||
tavily-python
|
||||
google-generativeai
|
||||
lunar_python
|
||||
|
||||
python-multipart
|
||||
aiofiles
|
||||
@@ -10,7 +10,6 @@ from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.plugin_system.base.component_types import ChatMode
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from src.plugin_system.apis import message_api
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
from .hfc_context import HfcContext
|
||||
from .energy_manager import EnergyManager
|
||||
|
||||
@@ -7,33 +7,11 @@ from contextlib import asynccontextmanager
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.tool_history import ToolHistoryManager
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("prompt_build")
|
||||
|
||||
# 创建工具历史管理器实例
|
||||
tool_history_manager = ToolHistoryManager()
|
||||
|
||||
def get_tool_history_prompt(message_id: Optional[str] = None) -> str:
|
||||
"""获取工具历史提示词
|
||||
|
||||
Args:
|
||||
message_id: 会话ID, 用于只获取当前会话的历史
|
||||
|
||||
Returns:
|
||||
格式化的工具历史提示词
|
||||
"""
|
||||
from src.config.config import global_config
|
||||
|
||||
if not global_config.tool.history.enable_prompt_history:
|
||||
return ""
|
||||
|
||||
return tool_history_manager.get_recent_history_prompt(
|
||||
chat_id=message_id
|
||||
)
|
||||
|
||||
class PromptContext:
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
@@ -49,7 +27,7 @@ class PromptContext:
|
||||
@_current_context.setter
|
||||
def _current_context(self, value: Optional[str]):
|
||||
"""设置当前协程的上下文ID"""
|
||||
self._current_context_var.set(value)
|
||||
self._current_context_var.set(value) # type: ignore
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
@@ -73,7 +51,7 @@ class PromptContext:
|
||||
# 保存当前协程的上下文值,不影响其他协程
|
||||
previous_context = self._current_context
|
||||
# 设置当前协程的新上下文
|
||||
token = self._current_context_var.set(context_id) if context_id else None
|
||||
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
|
||||
else:
|
||||
# 如果没有提供新上下文,保持当前上下文不变
|
||||
previous_context = self._current_context
|
||||
@@ -111,6 +89,7 @@ class PromptContext:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
if target_context := context_id or self._current_context:
|
||||
if prompt.name:
|
||||
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
|
||||
|
||||
|
||||
@@ -153,40 +132,15 @@ class PromptManager:
|
||||
|
||||
def add_prompt(self, name: str, fstr: str) -> "Prompt":
|
||||
prompt = Prompt(fstr, name=name)
|
||||
if prompt.name:
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||
# 获取当前提示词
|
||||
prompt = await self.get_prompt_async(name)
|
||||
# 获取当前会话ID
|
||||
message_id = self._context._current_context
|
||||
|
||||
# 获取工具历史提示词
|
||||
tool_history = ""
|
||||
if name in ['action_prompt', 'replyer_prompt', 'planner_prompt', 'tool_executor_prompt']:
|
||||
tool_history = get_tool_history_prompt(message_id)
|
||||
|
||||
# 获取基本格式化结果
|
||||
result = prompt.format(**kwargs)
|
||||
|
||||
# 如果有工具历史,插入到适当位置
|
||||
if tool_history:
|
||||
# 查找合适的插入点
|
||||
# 在人格信息和身份块之后,但在主要内容之前
|
||||
identity_end = result.find("```\n现在,你说:")
|
||||
if identity_end == -1:
|
||||
# 如果找不到特定标记,尝试在第一个段落后插入
|
||||
first_double_newline = result.find("\n\n")
|
||||
if first_double_newline != -1:
|
||||
# 在第一个双换行后插入
|
||||
result = f"{result[:first_double_newline + 2]}{tool_history}\n{result[first_double_newline + 2:]}"
|
||||
else:
|
||||
# 如果找不到合适的位置,添加到开头
|
||||
result = f"{tool_history}\n\n{result}"
|
||||
else:
|
||||
# 在找到的位置插入
|
||||
result = f"{result[:identity_end]}\n{tool_history}\n{result[identity_end:]}"
|
||||
return result
|
||||
|
||||
|
||||
@@ -195,6 +149,11 @@ global_prompt_manager = PromptManager()
|
||||
|
||||
|
||||
class Prompt(str):
|
||||
template: str
|
||||
name: Optional[str]
|
||||
args: List[str]
|
||||
_args: List[Any]
|
||||
_kwargs: Dict[str, Any]
|
||||
# 临时标记,作为类常量
|
||||
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
|
||||
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
|
||||
@@ -215,7 +174,7 @@ class Prompt(str):
|
||||
"""将临时标记还原为实际的花括号字符"""
|
||||
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
|
||||
|
||||
def __new__(cls, fstr, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs):
|
||||
def __new__(cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs):
|
||||
# 如果传入的是元组,转换为列表
|
||||
if isinstance(args, tuple):
|
||||
args = list(args)
|
||||
@@ -251,7 +210,7 @@ class Prompt(str):
|
||||
|
||||
@classmethod
|
||||
async def create_async(
|
||||
cls, fstr, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs
|
||||
cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs
|
||||
):
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = cls(fstr, name, args, **kwargs)
|
||||
@@ -260,7 +219,9 @@ class Prompt(str):
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def _format_template(cls, template, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
|
||||
def _format_template(cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None) -> str:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
# 预处理模板中的转义花括号
|
||||
processed_template = cls._process_escaped_braces(template)
|
||||
|
||||
|
||||
1
src/chat/utils/rust-video/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
/target
|
||||
610
src/chat/utils/rust-video/Cargo.lock
generated
Normal file
@@ -0,0 +1,610 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "android-tzdata"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
|
||||
|
||||
[[package]]
|
||||
name = "android_system_properties"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ae563653d1938f79b1ab1b5e668c87c76a9930414574a6583a7b7e11a8e6192"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"anstyle-parse",
|
||||
"anstyle-query",
|
||||
"anstyle-wincon",
|
||||
"colorchoice",
|
||||
"is_terminal_polyfill",
|
||||
"utf8parse",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle"
|
||||
version = "1.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd"
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-parse"
|
||||
version = "0.2.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2"
|
||||
dependencies = [
|
||||
"utf8parse",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-query"
|
||||
version = "1.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e231f6134f61b71076a3eab506c379d4f36122f2af15a9ff04415ea4c3339e2"
|
||||
dependencies = [
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-wincon"
|
||||
version = "3.0.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3e0633414522a32ffaac8ac6cc8f748e090c5717661fddeea04219e2344f5f2a"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"once_cell_polyfill",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.99"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.34"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42bc4aea80032b7bf409b0bc7ccad88853858911b7713a8062fdc0623867bedc"
|
||||
dependencies = [
|
||||
"shlex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9"
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
|
||||
dependencies = [
|
||||
"android-tzdata",
|
||||
"iana-time-zone",
|
||||
"js-sys",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"wasm-bindgen",
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.46"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2c5e4fcf9c21d2e544ca1ee9d8552de13019a42aa7dbf32747fa7aaf1df76e57"
|
||||
dependencies = [
|
||||
"clap_builder",
|
||||
"clap_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_builder"
|
||||
version = "4.5.46"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fecb53a0e6fcfb055f686001bc2e2592fa527efaf38dbe81a6a9563562e57d41"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
"clap_lex",
|
||||
"strsim",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_derive"
|
||||
version = "4.5.45"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "14cb31bb0a7d536caef2639baa7fad459e15c3144efefa6dbd1c84562c4739f6"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_lex"
|
||||
version = "0.7.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation-sys"
|
||||
version = "0.8.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
|
||||
dependencies = [
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.9.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.63"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8"
|
||||
dependencies = [
|
||||
"android_system_properties",
|
||||
"core-foundation-sys",
|
||||
"iana-time-zone-haiku",
|
||||
"js-sys",
|
||||
"log",
|
||||
"wasm-bindgen",
|
||||
"windows-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone-haiku"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "is_terminal_polyfill"
|
||||
version = "1.70.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
version = "0.3.77"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.175"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||
|
||||
[[package]]
|
||||
name = "once_cell_polyfill"
|
||||
version = "1.70.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.101"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
|
||||
dependencies = [
|
||||
"either",
|
||||
"rayon-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
|
||||
dependencies = [
|
||||
"crossbeam-deque",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-video"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"chrono",
|
||||
"clap",
|
||||
"rayon",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.219"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.219"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.143"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shlex"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.106"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
|
||||
|
||||
[[package]]
|
||||
name = "utf8parse"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen"
|
||||
version = "0.2.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"once_cell",
|
||||
"rustversion",
|
||||
"wasm-bindgen-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-backend"
|
||||
version = "0.2.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6"
|
||||
dependencies = [
|
||||
"bumpalo",
|
||||
"log",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro"
|
||||
version = "0.2.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"wasm-bindgen-macro-support",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro-support"
|
||||
version = "0.2.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-shared"
|
||||
version = "0.2.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-core"
|
||||
version = "0.61.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3"
|
||||
dependencies = [
|
||||
"windows-implement",
|
||||
"windows-interface",
|
||||
"windows-link",
|
||||
"windows-result",
|
||||
"windows-strings",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-implement"
|
||||
version = "0.60.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-interface"
|
||||
version = "0.59.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-link"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a"
|
||||
|
||||
[[package]]
|
||||
name = "windows-result"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-strings"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.60.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.53.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
"windows_aarch64_gnullvm",
|
||||
"windows_aarch64_msvc",
|
||||
"windows_i686_gnu",
|
||||
"windows_i686_gnullvm",
|
||||
"windows_i686_msvc",
|
||||
"windows_x86_64_gnu",
|
||||
"windows_x86_64_gnullvm",
|
||||
"windows_x86_64_msvc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnullvm"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
|
||||
24
src/chat/utils/rust-video/Cargo.toml
Normal file
@@ -0,0 +1,24 @@
|
||||
[package]
|
||||
name = "rust-video"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["VideoAnalysis Team"]
|
||||
description = "Ultra-fast video keyframe extraction tool in Rust"
|
||||
license = "GPL-3.0"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
clap = { version = "4.0", features = ["derive"] }
|
||||
rayon = "1.11"
|
||||
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
lto = true
|
||||
codegen-units = 1
|
||||
panic = "abort"
|
||||
strip = true
|
||||
221
src/chat/utils/rust-video/README.md
Normal file
@@ -0,0 +1,221 @@
|
||||
# 🎯 Rust Video Keyframe Extraction API
|
||||
|
||||
高性能视频关键帧提取API服务,基于Rust后端 + Python FastAPI。
|
||||
|
||||
## 📁 项目结构
|
||||
|
||||
```
|
||||
rust-video/
|
||||
├── outputs/ # 关键帧输出目录
|
||||
├── src/ # Rust源码
|
||||
│ └── main.rs
|
||||
├── target/ # Rust编译文件
|
||||
├── api_server.py # 🚀 主API服务器 (整合版)
|
||||
├── start_server.py # 生产启动脚本
|
||||
├── config.py # 配置管理
|
||||
├── config.toml # 配置文件
|
||||
├── Cargo.toml # Rust项目配置
|
||||
├── Cargo.lock # Rust依赖锁定
|
||||
├── .gitignore # Git忽略文件
|
||||
└── README.md # 项目文档
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 安装依赖
|
||||
```bash
|
||||
pip install fastapi uvicorn python-multipart aiofiles
|
||||
```
|
||||
|
||||
### 2. 启动服务
|
||||
```bash
|
||||
# 开发模式
|
||||
python api_server.py
|
||||
|
||||
# 生产模式
|
||||
python start_server.py --mode prod --port 8050
|
||||
```
|
||||
|
||||
### 3. 访问API
|
||||
- **服务地址**: http://localhost:8050
|
||||
- **API文档**: http://localhost:8050/docs
|
||||
- **健康检查**: http://localhost:8050/health
|
||||
- **性能指标**: http://localhost:8050/metrics
|
||||
|
||||
## API使用方法
|
||||
|
||||
### 主要端点
|
||||
|
||||
#### 1. 提取关键帧 (JSON响应)
|
||||
```http
|
||||
POST /extract-keyframes
|
||||
Content-Type: multipart/form-data
|
||||
|
||||
- video: 视频文件 (.mp4, .avi, .mov, .mkv)
|
||||
- scene_threshold: 场景变化阈值 (0.1-1.0, 默认0.3)
|
||||
- max_frames: 最大关键帧数 (1-200, 默认50)
|
||||
- resize_width: 调整宽度 (可选, 100-1920)
|
||||
- time_interval: 时间间隔秒数 (可选, 0.1-60.0)
|
||||
```
|
||||
|
||||
#### 2. 提取关键帧 (ZIP下载)
|
||||
```http
|
||||
POST /extract-keyframes-zip
|
||||
Content-Type: multipart/form-data
|
||||
|
||||
参数同上,返回包含所有关键帧的ZIP文件
|
||||
```
|
||||
|
||||
#### 3. 健康检查
|
||||
```http
|
||||
GET /health
|
||||
```
|
||||
|
||||
#### 4. 性能指标
|
||||
```http
|
||||
GET /metrics
|
||||
```
|
||||
|
||||
### Python客户端示例
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
# 上传视频并提取关键帧
|
||||
files = {'video': open('video.mp4', 'rb')}
|
||||
data = {
|
||||
'scene_threshold': 0.3,
|
||||
'max_frames': 50,
|
||||
'resize_width': 800
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
'http://localhost:8050/extract-keyframes',
|
||||
files=files,
|
||||
data=data
|
||||
)
|
||||
|
||||
result = response.json()
|
||||
print(f"提取了 {result['keyframe_count']} 个关键帧")
|
||||
print(f"处理时间: {result['performance']['total_api_time']:.2f}秒")
|
||||
```
|
||||
|
||||
### JavaScript客户端示例
|
||||
|
||||
```javascript
|
||||
const formData = new FormData();
|
||||
formData.append('video', videoFile);
|
||||
formData.append('scene_threshold', '0.3');
|
||||
formData.append('max_frames', '50');
|
||||
|
||||
fetch('http://localhost:8050/extract-keyframes', {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log(`提取了 ${data.keyframe_count} 个关键帧`);
|
||||
console.log(`处理时间: ${data.performance.total_api_time}秒`);
|
||||
});
|
||||
```
|
||||
|
||||
### cURL示例
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8050/extract-keyframes" \
|
||||
-H "accept: application/json" \
|
||||
-H "Content-Type: multipart/form-data" \
|
||||
-F "video=@video.mp4" \
|
||||
-F "scene_threshold=0.3" \
|
||||
-F "max_frames=50"
|
||||
```
|
||||
|
||||
## ⚙️ 配置
|
||||
|
||||
编辑 `config.toml` 文件:
|
||||
|
||||
```toml
|
||||
[server]
|
||||
host = "0.0.0.0"
|
||||
port = 8050
|
||||
debug = false
|
||||
|
||||
[processing]
|
||||
default_scene_threshold = 0.3
|
||||
default_max_frames = 50
|
||||
timeout_seconds = 300
|
||||
|
||||
[performance]
|
||||
async_workers = 4
|
||||
max_file_size_mb = 500
|
||||
```
|
||||
|
||||
## 性能特性
|
||||
|
||||
- **异步I/O**: 文件上传/下载异步处理
|
||||
- **多线程处理**: 视频处理在独立线程池
|
||||
- **内存优化**: 流式处理,减少内存占用
|
||||
- **智能清理**: 自动临时文件管理
|
||||
- **性能监控**: 实时处理时间和吞吐量统计
|
||||
|
||||
总之就是非常快()
|
||||
|
||||
## 响应格式
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"processing_time": 4.5,
|
||||
"output_directory": "/tmp/output_xxx",
|
||||
"keyframe_count": 15,
|
||||
"keyframes": [
|
||||
"/tmp/output_xxx/frame_001.jpg",
|
||||
"/tmp/output_xxx/frame_002.jpg"
|
||||
],
|
||||
"performance": {
|
||||
"file_size_mb": 209.7,
|
||||
"upload_time": 0.23,
|
||||
"processing_time": 4.5,
|
||||
"total_api_time": 4.73,
|
||||
"upload_speed_mbps": 912.2
|
||||
},
|
||||
"rust_output": "处理完成",
|
||||
"command": "rust-video input.mp4 output/ --scene-threshold 0.3 --max-frames 50"
|
||||
}
|
||||
```
|
||||
|
||||
## 故障排除
|
||||
|
||||
### 常见问题
|
||||
|
||||
1. **Rust binary not found**
|
||||
```bash
|
||||
cargo build # 重新构建Rust项目
|
||||
```
|
||||
|
||||
2. **端口被占用**
|
||||
```bash
|
||||
# 修改config.toml中的端口号
|
||||
port = 8051
|
||||
```
|
||||
|
||||
3. **内存不足**
|
||||
```bash
|
||||
# 减少max_frames或resize_width参数
|
||||
```
|
||||
|
||||
### 日志查看
|
||||
|
||||
服务启动时会显示详细的状态信息,包括:
|
||||
- Rust二进制文件位置
|
||||
- 配置加载状态
|
||||
- 服务监听地址
|
||||
|
||||
## 集成支持
|
||||
|
||||
本API设计为独立服务,可轻松集成到任何项目中:
|
||||
|
||||
- **AI Bot项目**: 通过HTTP API调用
|
||||
- **Web应用**: 直接前端调用或后端代理
|
||||
- **移动应用**: REST API标准接口
|
||||
- **批处理脚本**: Python/Shell脚本调用
|
||||
472
src/chat/utils/rust-video/api_server.py
Normal file
@@ -0,0 +1,472 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Rust Video Keyframe Extraction API Server
|
||||
高性能视频关键帧提取API服务
|
||||
|
||||
功能:
|
||||
- 视频上传和关键帧提取
|
||||
- 异步多线程处理
|
||||
- 性能监控和健康检查
|
||||
- 自动资源清理
|
||||
|
||||
启动: python api_server.py
|
||||
地址: http://localhost:8050
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import subprocess
|
||||
import tempfile
|
||||
import zipfile
|
||||
import shutil
|
||||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# 导入配置管理
|
||||
from config import config
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ============================================================================
|
||||
# 内置视频处理器 (整合版)
|
||||
# ============================================================================
|
||||
|
||||
class VideoKeyframeExtractor:
|
||||
"""整合的视频关键帧提取器"""
|
||||
|
||||
def __init__(self, rust_binary_path: Optional[str] = None):
|
||||
self.rust_binary_path = rust_binary_path or self._find_rust_binary()
|
||||
if not self.rust_binary_path or not Path(self.rust_binary_path).exists():
|
||||
raise FileNotFoundError(f"Rust binary not found: {self.rust_binary_path}")
|
||||
|
||||
def _find_rust_binary(self) -> str:
|
||||
"""查找Rust二进制文件"""
|
||||
possible_paths = [
|
||||
"./target/debug/rust-video.exe",
|
||||
"./target/release/rust-video.exe",
|
||||
"./target/debug/rust-video",
|
||||
"./target/release/rust-video"
|
||||
]
|
||||
|
||||
for path in possible_paths:
|
||||
if Path(path).exists():
|
||||
return str(Path(path).absolute())
|
||||
|
||||
# 尝试构建
|
||||
try:
|
||||
subprocess.run(["cargo", "build"], check=True, capture_output=True)
|
||||
for path in possible_paths:
|
||||
if Path(path).exists():
|
||||
return str(Path(path).absolute())
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
raise FileNotFoundError("Rust binary not found and build failed")
|
||||
|
||||
def process_video(
|
||||
self,
|
||||
video_path: str,
|
||||
output_dir: str = "outputs",
|
||||
scene_threshold: float = 0.3,
|
||||
max_frames: int = 50,
|
||||
resize_width: Optional[int] = None,
|
||||
time_interval: Optional[float] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""处理视频提取关键帧"""
|
||||
|
||||
video_path = Path(video_path)
|
||||
if not video_path.exists():
|
||||
raise FileNotFoundError(f"Video file not found: {video_path}")
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 构建命令
|
||||
cmd = [self.rust_binary_path, str(video_path), str(output_dir)]
|
||||
cmd.extend(["--scene-threshold", str(scene_threshold)])
|
||||
cmd.extend(["--max-frames", str(max_frames)])
|
||||
|
||||
if resize_width:
|
||||
cmd.extend(["--resize-width", str(resize_width)])
|
||||
if time_interval:
|
||||
cmd.extend(["--time-interval", str(time_interval)])
|
||||
|
||||
# 执行处理
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
timeout=300 # 5分钟超时
|
||||
)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
# 解析输出
|
||||
output_files = list(output_dir.glob("*.jpg"))
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"processing_time": processing_time,
|
||||
"output_directory": str(output_dir),
|
||||
"keyframe_count": len(output_files),
|
||||
"keyframes": [str(f) for f in output_files],
|
||||
"rust_output": result.stdout,
|
||||
"command": " ".join(cmd)
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
raise HTTPException(status_code=408, detail="Video processing timeout")
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Video processing failed: {e.stderr}"
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# 异步处理器 (整合版)
|
||||
# ============================================================================
|
||||
|
||||
class AsyncVideoProcessor:
|
||||
"""高性能异步视频处理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.extractor = VideoKeyframeExtractor()
|
||||
|
||||
async def process_video_async(
|
||||
self,
|
||||
upload_file: UploadFile,
|
||||
processing_params: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""异步视频处理主流程"""
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 异步保存上传文件
|
||||
upload_start = time.time()
|
||||
temp_fd, temp_path_str = tempfile.mkstemp(suffix='.mp4')
|
||||
temp_path = Path(temp_path_str)
|
||||
|
||||
try:
|
||||
os.close(temp_fd)
|
||||
|
||||
# 异步读取并保存文件
|
||||
content = await upload_file.read()
|
||||
with open(temp_path, 'wb') as f:
|
||||
f.write(content)
|
||||
|
||||
upload_time = time.time() - upload_start
|
||||
file_size = len(content)
|
||||
|
||||
# 2. 多线程处理视频
|
||||
process_start = time.time()
|
||||
temp_output_dir = tempfile.mkdtemp()
|
||||
output_path = Path(temp_output_dir)
|
||||
|
||||
try:
|
||||
# 在线程池中异步处理
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
self._process_video_sync,
|
||||
str(temp_path),
|
||||
str(output_path),
|
||||
processing_params
|
||||
)
|
||||
|
||||
process_time = time.time() - process_start
|
||||
total_time = time.time() - start_time
|
||||
|
||||
# 添加性能指标
|
||||
result.update({
|
||||
'performance': {
|
||||
'file_size_mb': file_size / (1024 * 1024),
|
||||
'upload_time': upload_time,
|
||||
'processing_time': process_time,
|
||||
'total_api_time': total_time,
|
||||
'upload_speed_mbps': (file_size / (1024 * 1024)) / upload_time if upload_time > 0 else 0
|
||||
}
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
# 清理输出目录
|
||||
try:
|
||||
shutil.rmtree(temp_output_dir, ignore_errors=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup output directory: {e}")
|
||||
|
||||
finally:
|
||||
# 清理临时文件
|
||||
try:
|
||||
if temp_path.exists():
|
||||
temp_path.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup temp file: {e}")
|
||||
|
||||
def _process_video_sync(self, video_path: str, output_dir: str, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""在线程池中同步处理视频"""
|
||||
return self.extractor.process_video(
|
||||
video_path=video_path,
|
||||
output_dir=output_dir,
|
||||
**params
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# FastAPI 应用初始化
|
||||
# ============================================================================
|
||||
|
||||
app = FastAPI(
|
||||
title="Rust Video Keyframe API",
|
||||
description="高性能视频关键帧提取API服务",
|
||||
version="2.0.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc"
|
||||
)
|
||||
|
||||
# CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 全局处理器实例
|
||||
video_processor = AsyncVideoProcessor()
|
||||
|
||||
# 简单的统计
|
||||
stats = {
|
||||
"total_requests": 0,
|
||||
"processing_times": [],
|
||||
"start_time": datetime.now()
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# API 路由
|
||||
# ============================================================================
|
||||
|
||||
@app.get("/", response_class=JSONResponse)
|
||||
async def root():
|
||||
"""API根路径"""
|
||||
return {
|
||||
"message": "Rust Video Keyframe Extraction API",
|
||||
"version": "2.0.0",
|
||||
"status": "ready",
|
||||
"docs": "/docs",
|
||||
"health": "/health",
|
||||
"metrics": "/metrics"
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查端点"""
|
||||
try:
|
||||
# 检查Rust二进制
|
||||
rust_binary = video_processor.extractor.rust_binary_path
|
||||
rust_status = "ok" if Path(rust_binary).exists() else "missing"
|
||||
|
||||
return {
|
||||
"status": rust_status,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"version": "2.0.0",
|
||||
"rust_binary": rust_binary
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=503, detail=f"Health check failed: {str(e)}")
|
||||
|
||||
@app.get("/metrics")
|
||||
async def get_metrics():
|
||||
"""获取性能指标"""
|
||||
avg_time = sum(stats["processing_times"]) / len(stats["processing_times"]) if stats["processing_times"] else 0
|
||||
uptime = (datetime.now() - stats["start_time"]).total_seconds()
|
||||
|
||||
return {
|
||||
"total_requests": stats["total_requests"],
|
||||
"average_processing_time": avg_time,
|
||||
"last_24h_requests": stats["total_requests"], # 简化版本
|
||||
"system_info": {
|
||||
"uptime_seconds": uptime,
|
||||
"memory_usage": "N/A", # 可以扩展
|
||||
"cpu_usage": "N/A"
|
||||
}
|
||||
}
|
||||
|
||||
@app.post("/extract-keyframes")
|
||||
async def extract_keyframes(
|
||||
video: UploadFile = File(..., description="视频文件"),
|
||||
scene_threshold: float = Form(0.3, description="场景变化阈值"),
|
||||
max_frames: int = Form(50, description="最大关键帧数量"),
|
||||
resize_width: Optional[int] = Form(None, description="调整宽度"),
|
||||
time_interval: Optional[float] = Form(None, description="时间间隔")
|
||||
):
|
||||
"""提取视频关键帧 (主要API端点)"""
|
||||
|
||||
# 参数验证
|
||||
if not video.filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
|
||||
raise HTTPException(status_code=400, detail="不支持的视频格式")
|
||||
|
||||
# 更新统计
|
||||
stats["total_requests"] += 1
|
||||
|
||||
try:
|
||||
# 构建处理参数
|
||||
params = {
|
||||
"scene_threshold": scene_threshold,
|
||||
"max_frames": max_frames
|
||||
}
|
||||
if resize_width:
|
||||
params["resize_width"] = resize_width
|
||||
if time_interval:
|
||||
params["time_interval"] = time_interval
|
||||
|
||||
# 异步处理
|
||||
start_time = time.time()
|
||||
result = await video_processor.process_video_async(video, params)
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
# 更新统计
|
||||
stats["processing_times"].append(processing_time)
|
||||
if len(stats["processing_times"]) > 100: # 保持最近100次记录
|
||||
stats["processing_times"] = stats["processing_times"][-100:]
|
||||
|
||||
return JSONResponse(content=result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Processing failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
|
||||
|
||||
@app.post("/extract-keyframes-zip")
|
||||
async def extract_keyframes_zip(
|
||||
video: UploadFile = File(...),
|
||||
scene_threshold: float = Form(0.3),
|
||||
max_frames: int = Form(50),
|
||||
resize_width: Optional[int] = Form(None),
|
||||
time_interval: Optional[float] = Form(None)
|
||||
):
|
||||
"""提取关键帧并返回ZIP文件"""
|
||||
|
||||
# 验证文件类型
|
||||
if not video.filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
|
||||
raise HTTPException(status_code=400, detail="不支持的视频格式")
|
||||
|
||||
# 创建临时目录
|
||||
temp_input_fd, temp_input_path = tempfile.mkstemp(suffix='.mp4')
|
||||
temp_output_dir = tempfile.mkdtemp()
|
||||
|
||||
try:
|
||||
os.close(temp_input_fd)
|
||||
|
||||
# 保存上传的视频
|
||||
content = await video.read()
|
||||
with open(temp_input_path, 'wb') as f:
|
||||
f.write(content)
|
||||
|
||||
# 处理参数
|
||||
params = {
|
||||
"scene_threshold": scene_threshold,
|
||||
"max_frames": max_frames
|
||||
}
|
||||
if resize_width:
|
||||
params["resize_width"] = resize_width
|
||||
if time_interval:
|
||||
params["time_interval"] = time_interval
|
||||
|
||||
# 处理视频
|
||||
result = video_processor.extractor.process_video(
|
||||
video_path=temp_input_path,
|
||||
output_dir=temp_output_dir,
|
||||
**params
|
||||
)
|
||||
|
||||
# 创建ZIP文件
|
||||
zip_fd, zip_path = tempfile.mkstemp(suffix='.zip')
|
||||
os.close(zip_fd)
|
||||
|
||||
with zipfile.ZipFile(zip_path, 'w') as zip_file:
|
||||
# 添加关键帧图片
|
||||
for keyframe_path in result.get("keyframes", []):
|
||||
if Path(keyframe_path).exists():
|
||||
zip_file.write(keyframe_path, Path(keyframe_path).name)
|
||||
|
||||
# 添加处理信息
|
||||
info_content = json.dumps(result, indent=2, ensure_ascii=False)
|
||||
zip_file.writestr("processing_info.json", info_content)
|
||||
|
||||
# 返回ZIP文件
|
||||
return FileResponse(
|
||||
zip_path,
|
||||
media_type='application/zip',
|
||||
filename=f"keyframes_{video.filename}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
|
||||
|
||||
finally:
|
||||
# 清理临时文件
|
||||
for path in [temp_input_path, temp_output_dir]:
|
||||
try:
|
||||
if Path(path).is_file():
|
||||
Path(path).unlink()
|
||||
elif Path(path).is_dir():
|
||||
shutil.rmtree(path, ignore_errors=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ============================================================================
|
||||
# 应用启动
|
||||
# ============================================================================
|
||||
|
||||
def main():
|
||||
"""启动API服务器"""
|
||||
|
||||
# 获取配置
|
||||
server_config = config.get('server')
|
||||
host = server_config.get('host', '0.0.0.0')
|
||||
port = server_config.get('port', 8050)
|
||||
|
||||
print(f"""
|
||||
Rust Video Keyframe Extraction API
|
||||
=====================================
|
||||
地址: http://{host}:{port}
|
||||
文档: http://{host}:{port}/docs
|
||||
健康检查: http://{host}:{port}/health
|
||||
性能指标: http://{host}:{port}/metrics
|
||||
=====================================
|
||||
""")
|
||||
|
||||
# 检查Rust二进制
|
||||
try:
|
||||
rust_binary = video_processor.extractor.rust_binary_path
|
||||
print(f"✓ Rust binary: {rust_binary}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Rust binary check failed: {e}")
|
||||
|
||||
# 启动服务器
|
||||
uvicorn.run(
|
||||
"api_server:app",
|
||||
host=host,
|
||||
port=port,
|
||||
reload=False, # 生产环境关闭热重载
|
||||
access_log=True
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
115
src/chat/utils/rust-video/config.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
配置管理模块
|
||||
处理 config.toml 文件的读取和管理
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
try:
|
||||
import toml
|
||||
except ImportError:
|
||||
print("⚠️ 需要安装 toml: pip install toml")
|
||||
# 提供基础配置作为后备
|
||||
toml = None
|
||||
|
||||
class ConfigManager:
|
||||
"""配置管理器"""
|
||||
|
||||
def __init__(self, config_file: str = "config.toml"):
|
||||
self.config_file = Path(config_file)
|
||||
self._config = self._load_config()
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""加载配置文件"""
|
||||
if toml is None or not self.config_file.exists():
|
||||
return self._get_default_config()
|
||||
|
||||
try:
|
||||
with open(self.config_file, 'r', encoding='utf-8') as f:
|
||||
return toml.load(f)
|
||||
except Exception as e:
|
||||
print(f"⚠️ 配置文件读取失败: {e}")
|
||||
return self._get_default_config()
|
||||
|
||||
def _get_default_config(self) -> Dict[str, Any]:
|
||||
"""默认配置"""
|
||||
return {
|
||||
"server": {
|
||||
"host": "0.0.0.0",
|
||||
"port": 8000,
|
||||
"workers": 1,
|
||||
"reload": False,
|
||||
"log_level": "info"
|
||||
},
|
||||
"api": {
|
||||
"title": "Video Keyframe Extraction API",
|
||||
"description": "高性能视频关键帧提取服务",
|
||||
"version": "1.0.0",
|
||||
"max_file_size": "100MB"
|
||||
},
|
||||
"processing": {
|
||||
"default_threshold": 0.3,
|
||||
"default_output_format": "png",
|
||||
"max_frames": 10000,
|
||||
"temp_dir": "temp",
|
||||
"upload_dir": "uploads",
|
||||
"output_dir": "outputs"
|
||||
},
|
||||
"rust": {
|
||||
"executable_name": "video_keyframe_extractor",
|
||||
"executable_path": "target/release"
|
||||
},
|
||||
"ffmpeg": {
|
||||
"auto_detect": True,
|
||||
"custom_path": "",
|
||||
"timeout": 300
|
||||
},
|
||||
"storage": {
|
||||
"cleanup_interval": 3600,
|
||||
"max_storage_size": "10GB",
|
||||
"result_retention_days": 7
|
||||
},
|
||||
"monitoring": {
|
||||
"enable_metrics": True,
|
||||
"enable_logging": True,
|
||||
"log_file": "logs/api.log",
|
||||
"max_log_size": "100MB"
|
||||
},
|
||||
"security": {
|
||||
"allowed_origins": ["*"],
|
||||
"max_concurrent_tasks": 10,
|
||||
"rate_limit_per_minute": 60
|
||||
},
|
||||
"development": {
|
||||
"debug": False,
|
||||
"auto_reload": False,
|
||||
"cors_enabled": True
|
||||
}
|
||||
}
|
||||
|
||||
def get(self, section: str, key: str = None, default=None):
|
||||
"""获取配置值"""
|
||||
if key is None:
|
||||
return self._config.get(section, default)
|
||||
return self._config.get(section, {}).get(key, default)
|
||||
|
||||
def get_server_config(self):
|
||||
"""获取服务器配置"""
|
||||
return self.get("server")
|
||||
|
||||
def get_api_config(self):
|
||||
"""获取API配置"""
|
||||
return self.get("api")
|
||||
|
||||
def get_processing_config(self):
|
||||
"""获取处理配置"""
|
||||
return self.get("processing")
|
||||
|
||||
def reload(self):
|
||||
"""重新加载配置"""
|
||||
self._config = self._load_config()
|
||||
|
||||
# 全局配置实例
|
||||
config = ConfigManager()
|
||||
70
src/chat/utils/rust-video/config.toml
Normal file
@@ -0,0 +1,70 @@
|
||||
# 🔧 Video Keyframe Extraction API 配置文件
|
||||
|
||||
[server]
|
||||
# 服务器配置
|
||||
host = "0.0.0.0"
|
||||
port = 8050
|
||||
workers = 1
|
||||
reload = false
|
||||
log_level = "info"
|
||||
|
||||
[api]
|
||||
# API 基础配置
|
||||
title = "Video Keyframe Extraction API"
|
||||
description = "视频关键帧提取服务"
|
||||
version = "1.0.0"
|
||||
max_file_size = "100MB" # 最大文件大小
|
||||
|
||||
[processing]
|
||||
# 视频处理配置
|
||||
default_threshold = 0.3
|
||||
default_output_format = "png"
|
||||
max_frames = 10000
|
||||
temp_dir = "temp"
|
||||
upload_dir = "uploads"
|
||||
output_dir = "outputs"
|
||||
|
||||
[rust]
|
||||
# Rust 程序配置
|
||||
executable_name = "video_keyframe_extractor"
|
||||
executable_path = "target/release" # 相对路径,自动检测
|
||||
|
||||
[ffmpeg]
|
||||
# FFmpeg 配置
|
||||
auto_detect = true
|
||||
custom_path = "" # 留空则自动检测
|
||||
timeout = 300 # 秒
|
||||
|
||||
[performance]
|
||||
# 性能优化配置
|
||||
async_workers = 4 # 异步文件处理工作线程数
|
||||
upload_chunk_size = 8192 # 上传块大小 (字节)
|
||||
max_concurrent_uploads = 10 # 最大并发上传数
|
||||
compression_level = 1 # ZIP 压缩级别 (0-9, 1=快速)
|
||||
stream_chunk_size = 8192 # 流式响应块大小
|
||||
enable_performance_metrics = true # 启用性能监控
|
||||
|
||||
[storage]
|
||||
# 存储配置
|
||||
cleanup_interval = 3600 # 清理间隔(秒)
|
||||
max_storage_size = "10GB"
|
||||
result_retention_days = 7
|
||||
|
||||
[monitoring]
|
||||
# 监控配置
|
||||
enable_metrics = true
|
||||
enable_logging = true
|
||||
log_file = "logs/api.log"
|
||||
max_log_size = "100MB"
|
||||
|
||||
[security]
|
||||
# 安全配置
|
||||
allowed_origins = ["*"]
|
||||
max_concurrent_tasks = 10
|
||||
rate_limit_per_minute = 60
|
||||
|
||||
[development]
|
||||
# 开发环境配置
|
||||
debug = false
|
||||
auto_reload = false
|
||||
cors_enabled = true
|
||||
710
src/chat/utils/rust-video/src/main.rs
Normal file
@@ -0,0 +1,710 @@
|
||||
//! # Rust Video Keyframe Extractor
|
||||
//!
|
||||
//! Ultra-fast video keyframe extraction tool with SIMD optimization.
|
||||
//!
|
||||
//! ## Features
|
||||
//! - AVX2/SSE2 SIMD optimization for maximum performance
|
||||
//! - Memory-efficient streaming processing with FFmpeg
|
||||
//! - Multi-threaded parallel processing
|
||||
//! - Release-optimized for production use
|
||||
//!
|
||||
//! ## Performance
|
||||
//! - 150+ FPS processing speed
|
||||
//! - Real-time video analysis capability
|
||||
//! - Minimal memory footprint
|
||||
//!
|
||||
//! ## Usage
|
||||
//! ```bash
|
||||
//! # Single video processing
|
||||
//! rust-video --input video.mp4 --output ./keyframes --threshold 2.0
|
||||
//!
|
||||
//! # Benchmark mode
|
||||
//! rust-video --benchmark --input video.mp4 --output ./results
|
||||
//! ```
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::prelude::*;
|
||||
use clap::Parser;
|
||||
use rayon::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::io::{BufReader, Read};
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::time::Instant;
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
use std::arch::x86_64::*;
|
||||
|
||||
/// Ultra-fast video keyframe extraction tool
|
||||
#[derive(Parser)]
|
||||
#[command(name = "rust-video")]
|
||||
#[command(version = "0.1.0")]
|
||||
#[command(about = "Ultra-fast video keyframe extraction with SIMD optimization")]
|
||||
#[command(long_about = None)]
|
||||
struct Args {
|
||||
/// Input video file path
|
||||
#[arg(short, long, help = "Path to the input video file")]
|
||||
input: Option<PathBuf>,
|
||||
|
||||
/// Output directory for keyframes and results
|
||||
#[arg(short, long, default_value = "./output", help = "Output directory")]
|
||||
output: PathBuf,
|
||||
|
||||
/// Change threshold for keyframe detection (higher = fewer keyframes)
|
||||
#[arg(short, long, default_value = "2.0", help = "Keyframe detection threshold")]
|
||||
threshold: f64,
|
||||
|
||||
/// Number of parallel threads (0 = auto-detect)
|
||||
#[arg(short = 'j', long, default_value = "0", help = "Number of threads")]
|
||||
threads: usize,
|
||||
|
||||
/// Maximum number of keyframes to save (0 = save all)
|
||||
#[arg(short, long, default_value = "50", help = "Maximum keyframes to save")]
|
||||
max_save: usize,
|
||||
|
||||
/// Run performance benchmark suite
|
||||
#[arg(long, help = "Run comprehensive benchmark tests")]
|
||||
benchmark: bool,
|
||||
|
||||
/// Maximum frames to process (0 = process all frames)
|
||||
#[arg(long, default_value = "0", help = "Limit number of frames to process")]
|
||||
max_frames: usize,
|
||||
|
||||
/// FFmpeg executable path
|
||||
#[arg(long, default_value = "ffmpeg", help = "Path to FFmpeg executable")]
|
||||
ffmpeg_path: PathBuf,
|
||||
|
||||
/// Enable SIMD optimizations (AVX2/SSE2)
|
||||
#[arg(long, default_value = "true", help = "Enable SIMD optimizations")]
|
||||
use_simd: bool,
|
||||
|
||||
/// Processing block size for cache optimization
|
||||
#[arg(long, default_value = "8192", help = "Block size for processing")]
|
||||
block_size: usize,
|
||||
|
||||
/// Verbose output
|
||||
#[arg(short, long, help = "Enable verbose output")]
|
||||
verbose: bool,
|
||||
}
|
||||
|
||||
/// Video frame representation optimized for SIMD processing
|
||||
#[derive(Debug, Clone)]
|
||||
struct VideoFrame {
|
||||
frame_number: usize,
|
||||
width: usize,
|
||||
height: usize,
|
||||
data: Vec<u8>, // Grayscale data, aligned for SIMD
|
||||
}
|
||||
|
||||
impl VideoFrame {
|
||||
/// Create a new video frame with SIMD-aligned data
|
||||
fn new(frame_number: usize, width: usize, height: usize, mut data: Vec<u8>) -> Self {
|
||||
// Ensure data length is multiple of 32 for AVX2 processing
|
||||
let remainder = data.len() % 32;
|
||||
if remainder != 0 {
|
||||
data.resize(data.len() + (32 - remainder), 0);
|
||||
}
|
||||
|
||||
Self {
|
||||
frame_number,
|
||||
width,
|
||||
height,
|
||||
data,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate frame difference using parallel SIMD processing
|
||||
fn calculate_difference_parallel_simd(&self, other: &VideoFrame, block_size: usize, use_simd: bool) -> f64 {
|
||||
if self.width != other.width || self.height != other.height {
|
||||
return f64::MAX;
|
||||
}
|
||||
|
||||
let total_pixels = self.width * self.height;
|
||||
let num_blocks = (total_pixels + block_size - 1) / block_size;
|
||||
|
||||
let total_diff: u64 = (0..num_blocks)
|
||||
.into_par_iter()
|
||||
.map(|block_idx| {
|
||||
let start = block_idx * block_size;
|
||||
let end = ((block_idx + 1) * block_size).min(total_pixels);
|
||||
let block_len = end - start;
|
||||
|
||||
if use_simd {
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
unsafe {
|
||||
if std::arch::is_x86_feature_detected!("avx2") {
|
||||
return self.calculate_difference_avx2_block(&other.data, start, block_len);
|
||||
} else if std::arch::is_x86_feature_detected!("sse2") {
|
||||
return self.calculate_difference_sse2_block(&other.data, start, block_len);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback scalar implementation
|
||||
self.data[start..end]
|
||||
.iter()
|
||||
.zip(other.data[start..end].iter())
|
||||
.map(|(a, b)| (*a as i32 - *b as i32).abs() as u64)
|
||||
.sum()
|
||||
})
|
||||
.sum();
|
||||
|
||||
total_diff as f64 / total_pixels as f64
|
||||
}
|
||||
|
||||
/// Standard frame difference calculation (non-SIMD)
|
||||
fn calculate_difference_standard(&self, other: &VideoFrame) -> f64 {
|
||||
if self.width != other.width || self.height != other.height {
|
||||
return f64::MAX;
|
||||
}
|
||||
|
||||
let len = self.width * self.height;
|
||||
let total_diff: u64 = self.data[..len]
|
||||
.iter()
|
||||
.zip(other.data[..len].iter())
|
||||
.map(|(a, b)| (*a as i32 - *b as i32).abs() as u64)
|
||||
.sum();
|
||||
|
||||
total_diff as f64 / len as f64
|
||||
}
|
||||
|
||||
/// AVX2 optimized block processing
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2")]
|
||||
unsafe fn calculate_difference_avx2_block(&self, other_data: &[u8], start: usize, len: usize) -> u64 {
|
||||
let mut total_diff = 0u64;
|
||||
let chunks = len / 32;
|
||||
|
||||
for i in 0..chunks {
|
||||
let offset = start + i * 32;
|
||||
|
||||
let a = _mm256_loadu_si256(self.data.as_ptr().add(offset) as *const __m256i);
|
||||
let b = _mm256_loadu_si256(other_data.as_ptr().add(offset) as *const __m256i);
|
||||
|
||||
let diff = _mm256_sad_epu8(a, b);
|
||||
let result = _mm256_extract_epi64(diff, 0) as u64 +
|
||||
_mm256_extract_epi64(diff, 1) as u64 +
|
||||
_mm256_extract_epi64(diff, 2) as u64 +
|
||||
_mm256_extract_epi64(diff, 3) as u64;
|
||||
|
||||
total_diff += result;
|
||||
}
|
||||
|
||||
// Process remaining bytes
|
||||
for i in (start + chunks * 32)..(start + len) {
|
||||
total_diff += (self.data[i] as i32 - other_data[i] as i32).abs() as u64;
|
||||
}
|
||||
|
||||
total_diff
|
||||
}
|
||||
|
||||
/// SSE2 optimized block processing
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "sse2")]
|
||||
unsafe fn calculate_difference_sse2_block(&self, other_data: &[u8], start: usize, len: usize) -> u64 {
|
||||
let mut total_diff = 0u64;
|
||||
let chunks = len / 16;
|
||||
|
||||
for i in 0..chunks {
|
||||
let offset = start + i * 16;
|
||||
|
||||
let a = _mm_loadu_si128(self.data.as_ptr().add(offset) as *const __m128i);
|
||||
let b = _mm_loadu_si128(other_data.as_ptr().add(offset) as *const __m128i);
|
||||
|
||||
let diff = _mm_sad_epu8(a, b);
|
||||
let result = _mm_extract_epi64(diff, 0) as u64 + _mm_extract_epi64(diff, 1) as u64;
|
||||
|
||||
total_diff += result;
|
||||
}
|
||||
|
||||
// Process remaining bytes
|
||||
for i in (start + chunks * 16)..(start + len) {
|
||||
total_diff += (self.data[i] as i32 - other_data[i] as i32).abs() as u64;
|
||||
}
|
||||
|
||||
total_diff
|
||||
}
|
||||
}
|
||||
|
||||
/// Performance measurement results
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct PerformanceResult {
|
||||
test_name: String,
|
||||
video_file: String,
|
||||
total_time_ms: f64,
|
||||
frame_extraction_time_ms: f64,
|
||||
keyframe_analysis_time_ms: f64,
|
||||
total_frames: usize,
|
||||
keyframes_extracted: usize,
|
||||
keyframe_ratio: f64,
|
||||
processing_fps: f64,
|
||||
threshold: f64,
|
||||
optimization_type: String,
|
||||
simd_enabled: bool,
|
||||
threads_used: usize,
|
||||
timestamp: String,
|
||||
}
|
||||
|
||||
/// Extract video frames using FFmpeg memory streaming
|
||||
fn extract_frames_memory_stream(
|
||||
video_path: &PathBuf,
|
||||
ffmpeg_path: &PathBuf,
|
||||
max_frames: usize,
|
||||
verbose: bool,
|
||||
) -> Result<(Vec<VideoFrame>, usize, usize)> {
|
||||
if verbose {
|
||||
println!("🎬 Extracting frames using FFmpeg memory streaming...");
|
||||
println!("📁 Video: {}", video_path.display());
|
||||
}
|
||||
|
||||
// Get video information
|
||||
let probe_output = Command::new(ffmpeg_path)
|
||||
.args(["-i", video_path.to_str().unwrap(), "-hide_banner"])
|
||||
.output()
|
||||
.context("Failed to probe video with FFmpeg")?;
|
||||
|
||||
let probe_info = String::from_utf8_lossy(&probe_output.stderr);
|
||||
let (width, height) = parse_video_dimensions(&probe_info)
|
||||
.ok_or_else(|| anyhow::anyhow!("Cannot parse video dimensions"))?;
|
||||
|
||||
if verbose {
|
||||
println!("📐 Video dimensions: {}x{}", width, height);
|
||||
}
|
||||
|
||||
// Build optimized FFmpeg command
|
||||
let mut cmd = Command::new(ffmpeg_path);
|
||||
cmd.args([
|
||||
"-i", video_path.to_str().unwrap(),
|
||||
"-f", "rawvideo",
|
||||
"-pix_fmt", "gray",
|
||||
"-an", // No audio
|
||||
"-threads", "0", // Auto-detect threads
|
||||
"-preset", "ultrafast", // Fastest preset
|
||||
]);
|
||||
|
||||
if max_frames > 0 {
|
||||
cmd.args(["-frames:v", &max_frames.to_string()]);
|
||||
}
|
||||
|
||||
cmd.args(["-"]).stdout(Stdio::piped()).stderr(Stdio::null());
|
||||
|
||||
let start_time = Instant::now();
|
||||
let mut child = cmd.spawn().context("Failed to spawn FFmpeg process")?;
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
let mut reader = BufReader::with_capacity(1024 * 1024, stdout); // 1MB buffer
|
||||
|
||||
let frame_size = width * height;
|
||||
let mut frames = Vec::new();
|
||||
let mut frame_count = 0;
|
||||
let mut frame_buffer = vec![0u8; frame_size];
|
||||
|
||||
if verbose {
|
||||
println!("📦 Frame size: {} bytes", frame_size);
|
||||
}
|
||||
|
||||
// Stream frame data directly into memory
|
||||
loop {
|
||||
match reader.read_exact(&mut frame_buffer) {
|
||||
Ok(()) => {
|
||||
frames.push(VideoFrame::new(
|
||||
frame_count,
|
||||
width,
|
||||
height,
|
||||
frame_buffer.clone(),
|
||||
));
|
||||
frame_count += 1;
|
||||
|
||||
if verbose && frame_count % 200 == 0 {
|
||||
print!("\r⚡ Frames processed: {}", frame_count);
|
||||
}
|
||||
|
||||
if max_frames > 0 && frame_count >= max_frames {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break, // End of stream
|
||||
}
|
||||
}
|
||||
|
||||
let _ = child.wait();
|
||||
|
||||
if verbose {
|
||||
println!("\r✅ Frame extraction complete: {} frames in {:.2}s",
|
||||
frame_count, start_time.elapsed().as_secs_f64());
|
||||
}
|
||||
|
||||
Ok((frames, width, height))
|
||||
}
|
||||
|
||||
/// Parse video dimensions from FFmpeg probe output
|
||||
fn parse_video_dimensions(probe_info: &str) -> Option<(usize, usize)> {
|
||||
for line in probe_info.lines() {
|
||||
if line.contains("Video:") && line.contains("x") {
|
||||
for part in line.split_whitespace() {
|
||||
if let Some(x_pos) = part.find('x') {
|
||||
let width_str = &part[..x_pos];
|
||||
let height_part = &part[x_pos + 1..];
|
||||
let height_str = height_part.split(',').next().unwrap_or(height_part);
|
||||
|
||||
if let (Ok(width), Ok(height)) = (width_str.parse::<usize>(), height_str.parse::<usize>()) {
|
||||
return Some((width, height));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Extract keyframes using optimized algorithms
|
||||
fn extract_keyframes_optimized(
|
||||
frames: &[VideoFrame],
|
||||
threshold: f64,
|
||||
use_simd: bool,
|
||||
block_size: usize,
|
||||
verbose: bool,
|
||||
) -> Result<Vec<usize>> {
|
||||
if frames.len() < 2 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let optimization_name = if use_simd { "SIMD+Parallel" } else { "Standard Parallel" };
|
||||
if verbose {
|
||||
println!("🚀 Keyframe analysis (threshold: {}, optimization: {})", threshold, optimization_name);
|
||||
}
|
||||
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Parallel computation of frame differences
|
||||
let differences: Vec<f64> = frames
|
||||
.par_windows(2)
|
||||
.map(|pair| {
|
||||
if use_simd {
|
||||
pair[0].calculate_difference_parallel_simd(&pair[1], block_size, true)
|
||||
} else {
|
||||
pair[0].calculate_difference_standard(&pair[1])
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Find keyframes based on threshold
|
||||
let keyframe_indices: Vec<usize> = differences
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, &diff)| {
|
||||
if diff > threshold {
|
||||
Some(i + 1)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if verbose {
|
||||
println!("⚡ Analysis complete in {:.2}s", start_time.elapsed().as_secs_f64());
|
||||
println!("🎯 Found {} keyframes", keyframe_indices.len());
|
||||
}
|
||||
|
||||
Ok(keyframe_indices)
|
||||
}
|
||||
|
||||
/// Save keyframes as JPEG images using FFmpeg
|
||||
fn save_keyframes_optimized(
|
||||
video_path: &PathBuf,
|
||||
keyframe_indices: &[usize],
|
||||
output_dir: &PathBuf,
|
||||
ffmpeg_path: &PathBuf,
|
||||
max_save: usize,
|
||||
verbose: bool,
|
||||
) -> Result<usize> {
|
||||
if keyframe_indices.is_empty() {
|
||||
if verbose {
|
||||
println!("⚠️ No keyframes to save");
|
||||
}
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
if verbose {
|
||||
println!("💾 Saving keyframes...");
|
||||
}
|
||||
|
||||
fs::create_dir_all(output_dir).context("Failed to create output directory")?;
|
||||
|
||||
let save_count = keyframe_indices.len().min(max_save);
|
||||
let mut saved = 0;
|
||||
|
||||
for (i, &frame_idx) in keyframe_indices.iter().take(save_count).enumerate() {
|
||||
let output_path = output_dir.join(format!("keyframe_{:03}.jpg", i + 1));
|
||||
let timestamp = frame_idx as f64 / 30.0; // Assume 30 FPS
|
||||
|
||||
let output = Command::new(ffmpeg_path)
|
||||
.args([
|
||||
"-i", video_path.to_str().unwrap(),
|
||||
"-ss", ×tamp.to_string(),
|
||||
"-vframes", "1",
|
||||
"-q:v", "2", // High quality
|
||||
"-y",
|
||||
output_path.to_str().unwrap(),
|
||||
])
|
||||
.output()
|
||||
.context("Failed to extract keyframe with FFmpeg")?;
|
||||
|
||||
if output.status.success() {
|
||||
saved += 1;
|
||||
if verbose && (saved % 10 == 0 || saved == save_count) {
|
||||
print!("\r💾 Saved: {}/{} keyframes", saved, save_count);
|
||||
}
|
||||
} else if verbose {
|
||||
eprintln!("⚠️ Failed to save keyframe {}", frame_idx);
|
||||
}
|
||||
}
|
||||
|
||||
if verbose {
|
||||
println!("\r✅ Keyframe saving complete: {}/{}", saved, save_count);
|
||||
}
|
||||
|
||||
Ok(saved)
|
||||
}
|
||||
|
||||
/// Run performance test
|
||||
fn run_performance_test(
|
||||
video_path: &PathBuf,
|
||||
threshold: f64,
|
||||
test_name: &str,
|
||||
ffmpeg_path: &PathBuf,
|
||||
max_frames: usize,
|
||||
use_simd: bool,
|
||||
block_size: usize,
|
||||
verbose: bool,
|
||||
) -> Result<PerformanceResult> {
|
||||
if verbose {
|
||||
println!("\n{}", "=".repeat(60));
|
||||
println!("⚡ Running test: {}", test_name);
|
||||
println!("{}", "=".repeat(60));
|
||||
}
|
||||
|
||||
let total_start = Instant::now();
|
||||
|
||||
// Frame extraction
|
||||
let extraction_start = Instant::now();
|
||||
let (frames, _width, _height) = extract_frames_memory_stream(video_path, ffmpeg_path, max_frames, verbose)?;
|
||||
let extraction_time = extraction_start.elapsed().as_secs_f64() * 1000.0;
|
||||
|
||||
// Keyframe analysis
|
||||
let analysis_start = Instant::now();
|
||||
let keyframe_indices = extract_keyframes_optimized(&frames, threshold, use_simd, block_size, verbose)?;
|
||||
let analysis_time = analysis_start.elapsed().as_secs_f64() * 1000.0;
|
||||
|
||||
let total_time = total_start.elapsed().as_secs_f64() * 1000.0;
|
||||
|
||||
let optimization_type = if use_simd {
|
||||
format!("SIMD+Parallel(block:{})", block_size)
|
||||
} else {
|
||||
"Standard Parallel".to_string()
|
||||
};
|
||||
|
||||
let result = PerformanceResult {
|
||||
test_name: test_name.to_string(),
|
||||
video_file: video_path.file_name().unwrap().to_string_lossy().to_string(),
|
||||
total_time_ms: total_time,
|
||||
frame_extraction_time_ms: extraction_time,
|
||||
keyframe_analysis_time_ms: analysis_time,
|
||||
total_frames: frames.len(),
|
||||
keyframes_extracted: keyframe_indices.len(),
|
||||
keyframe_ratio: keyframe_indices.len() as f64 / frames.len() as f64 * 100.0,
|
||||
processing_fps: frames.len() as f64 / (total_time / 1000.0),
|
||||
threshold,
|
||||
optimization_type,
|
||||
simd_enabled: use_simd,
|
||||
threads_used: rayon::current_num_threads(),
|
||||
timestamp: Local::now().format("%Y-%m-%d %H:%M:%S").to_string(),
|
||||
};
|
||||
|
||||
if verbose {
|
||||
println!("\n⚡ Test Results:");
|
||||
println!(" 🕐 Total time: {:.2}ms ({:.2}s)", result.total_time_ms, result.total_time_ms / 1000.0);
|
||||
println!(" 📥 Extraction: {:.2}ms ({:.1}%)", result.frame_extraction_time_ms,
|
||||
result.frame_extraction_time_ms / result.total_time_ms * 100.0);
|
||||
println!(" 🧮 Analysis: {:.2}ms ({:.1}%)", result.keyframe_analysis_time_ms,
|
||||
result.keyframe_analysis_time_ms / result.total_time_ms * 100.0);
|
||||
println!(" 📊 Frames: {}", result.total_frames);
|
||||
println!(" 🎯 Keyframes: {}", result.keyframes_extracted);
|
||||
println!(" 🚀 Speed: {:.1} FPS", result.processing_fps);
|
||||
println!(" ⚙️ Optimization: {}", result.optimization_type);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Run comprehensive benchmark suite
|
||||
fn run_benchmark_suite(video_path: &PathBuf, output_dir: &PathBuf, ffmpeg_path: &PathBuf, args: &Args) -> Result<()> {
|
||||
println!("🚀 Rust Video Keyframe Extractor - Benchmark Suite");
|
||||
println!("🕐 Time: {}", Local::now().format("%Y-%m-%d %H:%M:%S"));
|
||||
println!("🎬 Video: {}", video_path.display());
|
||||
println!("🧵 Threads: {}", rayon::current_num_threads());
|
||||
|
||||
// CPU feature detection
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
println!("🔧 CPU Features:");
|
||||
if std::arch::is_x86_feature_detected!("avx2") {
|
||||
println!(" ✅ AVX2 supported");
|
||||
} else if std::arch::is_x86_feature_detected!("sse2") {
|
||||
println!(" ✅ SSE2 supported");
|
||||
} else {
|
||||
println!(" ⚠️ Scalar only");
|
||||
}
|
||||
}
|
||||
|
||||
let test_configs = vec![
|
||||
("Standard Parallel", false, 8192),
|
||||
("SIMD 8K blocks", true, 8192),
|
||||
("SIMD 16K blocks", true, 16384),
|
||||
("SIMD 32K blocks", true, 32768),
|
||||
];
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
for (test_name, use_simd, block_size) in test_configs {
|
||||
match run_performance_test(
|
||||
video_path,
|
||||
args.threshold,
|
||||
test_name,
|
||||
ffmpeg_path,
|
||||
1000, // Test with 1000 frames
|
||||
use_simd,
|
||||
block_size,
|
||||
args.verbose,
|
||||
) {
|
||||
Ok(result) => results.push(result),
|
||||
Err(e) => println!("❌ Test failed {}: {:?}", test_name, e),
|
||||
}
|
||||
}
|
||||
|
||||
// Performance comparison table
|
||||
println!("\n{}", "=".repeat(120));
|
||||
println!("🏆 Benchmark Results");
|
||||
println!("{}", "=".repeat(120));
|
||||
|
||||
println!("{:<20} {:<15} {:<12} {:<12} {:<12} {:<8} {:<8} {:<12} {:<20}",
|
||||
"Test", "Total(ms)", "Extract(ms)", "Analyze(ms)", "Speed(FPS)", "Frames", "Keyframes", "Threads", "Optimization");
|
||||
println!("{}", "-".repeat(120));
|
||||
|
||||
for result in &results {
|
||||
println!("{:<20} {:<15.1} {:<12.1} {:<12.1} {:<12.1} {:<8} {:<8} {:<12} {:<20}",
|
||||
result.test_name,
|
||||
result.total_time_ms,
|
||||
result.frame_extraction_time_ms,
|
||||
result.keyframe_analysis_time_ms,
|
||||
result.processing_fps,
|
||||
result.total_frames,
|
||||
result.keyframes_extracted,
|
||||
result.threads_used,
|
||||
result.optimization_type);
|
||||
}
|
||||
|
||||
// Find best performance
|
||||
if let Some(best_result) = results.iter().max_by(|a, b| a.processing_fps.partial_cmp(&b.processing_fps).unwrap()) {
|
||||
println!("\n🏆 Best Performance: {}", best_result.test_name);
|
||||
println!(" ⚡ Speed: {:.1} FPS", best_result.processing_fps);
|
||||
println!(" 🕐 Time: {:.2}s", best_result.total_time_ms / 1000.0);
|
||||
println!(" 🧮 Analysis: {:.2}s", best_result.keyframe_analysis_time_ms / 1000.0);
|
||||
println!(" ⚙️ Tech: {}", best_result.optimization_type);
|
||||
}
|
||||
|
||||
// Save detailed results
|
||||
fs::create_dir_all(output_dir).context("Failed to create output directory")?;
|
||||
let timestamp = Local::now().format("%Y%m%d_%H%M%S").to_string();
|
||||
let results_file = output_dir.join(format!("benchmark_results_{}.json", timestamp));
|
||||
|
||||
let json_results = serde_json::to_string_pretty(&results)?;
|
||||
fs::write(&results_file, json_results)?;
|
||||
|
||||
println!("\n📄 Detailed results saved to: {}", results_file.display());
|
||||
println!("{}", "=".repeat(120));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
// Setup thread pool
|
||||
if args.threads > 0 {
|
||||
rayon::ThreadPoolBuilder::new()
|
||||
.num_threads(args.threads)
|
||||
.build_global()
|
||||
.context("Failed to set thread pool")?;
|
||||
}
|
||||
|
||||
println!("🚀 Rust Video Keyframe Extractor v0.1.0");
|
||||
println!("🧵 Threads: {}", rayon::current_num_threads());
|
||||
|
||||
// Verify FFmpeg availability
|
||||
if !args.ffmpeg_path.exists() && args.ffmpeg_path.to_str() == Some("ffmpeg") {
|
||||
// Try to find ffmpeg in PATH
|
||||
if Command::new("ffmpeg").arg("-version").output().is_err() {
|
||||
anyhow::bail!("FFmpeg not found. Please install FFmpeg or specify path with --ffmpeg-path");
|
||||
}
|
||||
} else if !args.ffmpeg_path.exists() {
|
||||
anyhow::bail!("FFmpeg not found at: {}", args.ffmpeg_path.display());
|
||||
}
|
||||
|
||||
if args.benchmark {
|
||||
// Benchmark mode
|
||||
let video_path = args.input.clone()
|
||||
.ok_or_else(|| anyhow::anyhow!("Benchmark requires input video file --input <path>"))?;
|
||||
|
||||
if !video_path.exists() {
|
||||
anyhow::bail!("Video file not found: {}", video_path.display());
|
||||
}
|
||||
|
||||
run_benchmark_suite(&video_path, &args.output, &args.ffmpeg_path, &args)?;
|
||||
} else {
|
||||
// Single processing mode
|
||||
let video_path = args.input
|
||||
.ok_or_else(|| anyhow::anyhow!("Please specify input video file --input <path>"))?;
|
||||
|
||||
if !video_path.exists() {
|
||||
anyhow::bail!("Video file not found: {}", video_path.display());
|
||||
}
|
||||
|
||||
// Run single keyframe extraction
|
||||
let result = run_performance_test(
|
||||
&video_path,
|
||||
args.threshold,
|
||||
"Single Processing",
|
||||
&args.ffmpeg_path,
|
||||
args.max_frames,
|
||||
args.use_simd,
|
||||
args.block_size,
|
||||
args.verbose,
|
||||
)?;
|
||||
|
||||
// Extract and save keyframes
|
||||
let (frames, _, _) = extract_frames_memory_stream(&video_path, &args.ffmpeg_path, args.max_frames, args.verbose)?;
|
||||
let keyframe_indices = extract_keyframes_optimized(&frames, args.threshold, args.use_simd, args.block_size, args.verbose)?;
|
||||
let saved_count = save_keyframes_optimized(&video_path, &keyframe_indices, &args.output, &args.ffmpeg_path, args.max_save, args.verbose)?;
|
||||
|
||||
println!("\n✅ Processing Complete!");
|
||||
println!("🎯 Keyframes extracted: {}", result.keyframes_extracted);
|
||||
println!("💾 Keyframes saved: {}", saved_count);
|
||||
println!("⚡ Processing speed: {:.1} FPS", result.processing_fps);
|
||||
println!("📁 Output directory: {}", args.output.display());
|
||||
|
||||
// Save processing report
|
||||
let timestamp = Local::now().format("%Y%m%d_%H%M%S").to_string();
|
||||
let report_file = args.output.join(format!("processing_report_{}.json", timestamp));
|
||||
let json_result = serde_json::to_string_pretty(&result)?;
|
||||
fs::write(&report_file, json_result)?;
|
||||
|
||||
if args.verbose {
|
||||
println!("📄 Processing report saved to: {}", report_file.display());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
219
src/chat/utils/rust-video/start_server.py
Normal file
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
启动脚本
|
||||
|
||||
支持开发模式和生产模式启动
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from config import config
|
||||
|
||||
|
||||
def check_rust_executable():
|
||||
"""检查 Rust 可执行文件是否存在"""
|
||||
rust_config = config.get("rust")
|
||||
executable_name = rust_config.get("executable_name", "video_keyframe_extractor")
|
||||
executable_path = rust_config.get("executable_path", "target/release")
|
||||
|
||||
possible_paths = [
|
||||
f"./{executable_path}/{executable_name}.exe",
|
||||
f"./{executable_path}/{executable_name}",
|
||||
f"./{executable_name}.exe",
|
||||
f"./{executable_name}"
|
||||
]
|
||||
|
||||
for path in possible_paths:
|
||||
if Path(path).exists():
|
||||
print(f"✓ Found Rust executable: {path}")
|
||||
return str(Path(path).absolute())
|
||||
|
||||
print("⚠ Warning: Rust executable not found")
|
||||
print("Please compile first: cargo build --release")
|
||||
return None
|
||||
|
||||
|
||||
def check_dependencies():
|
||||
"""检查 Python 依赖"""
|
||||
try:
|
||||
import fastapi
|
||||
import uvicorn
|
||||
print("✓ FastAPI dependencies available")
|
||||
return True
|
||||
except ImportError as e:
|
||||
print(f"✗ Missing dependencies: {e}")
|
||||
print("Please install: pip install -r requirements.txt")
|
||||
return False
|
||||
|
||||
|
||||
def install_dependencies():
|
||||
"""安装依赖"""
|
||||
print("Installing dependencies...")
|
||||
try:
|
||||
subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
|
||||
check=True)
|
||||
print("✓ Dependencies installed successfully")
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"✗ Failed to install dependencies: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def start_development_server(host="127.0.0.1", port=8050, reload=True):
|
||||
"""启动开发服务器"""
|
||||
print(f" Starting development server on http://{host}:{port}")
|
||||
print(f" API docs: http://{host}:{port}/docs")
|
||||
print(f" Health check: http://{host}:{port}/health")
|
||||
|
||||
try:
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
"api_server:app",
|
||||
host=host,
|
||||
port=port,
|
||||
reload=reload,
|
||||
log_level="info"
|
||||
)
|
||||
except ImportError:
|
||||
print("uvicorn not found, trying with subprocess...")
|
||||
subprocess.run([
|
||||
sys.executable, "-m", "uvicorn",
|
||||
"api_server:app",
|
||||
"--host", host,
|
||||
"--port", str(port),
|
||||
"--reload" if reload else ""
|
||||
])
|
||||
|
||||
|
||||
def start_production_server(host="0.0.0.0", port=8000, workers=4):
|
||||
"""启动生产服务器"""
|
||||
print(f"🚀 Starting production server on http://{host}:{port}")
|
||||
print(f"Workers: {workers}")
|
||||
|
||||
subprocess.run([
|
||||
sys.executable, "-m", "uvicorn",
|
||||
"api_server:app",
|
||||
"--host", host,
|
||||
"--port", str(port),
|
||||
"--workers", str(workers),
|
||||
"--log-level", "warning"
|
||||
])
|
||||
|
||||
|
||||
def create_systemd_service():
|
||||
"""创建 systemd 服务文件"""
|
||||
current_dir = Path.cwd()
|
||||
python_path = sys.executable
|
||||
|
||||
service_content = f"""[Unit]
|
||||
Description=Video Keyframe Extraction API Server
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=exec
|
||||
User=www-data
|
||||
WorkingDirectory={current_dir}
|
||||
Environment=PATH=/usr/bin:/usr/local/bin
|
||||
ExecStart={python_path} -m uvicorn api_server:app --host 0.0.0.0 --port 8000 --workers 4
|
||||
Restart=always
|
||||
RestartSec=10
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
"""
|
||||
|
||||
service_file = Path("/etc/systemd/system/video-keyframe-api.service")
|
||||
|
||||
try:
|
||||
with open(service_file, 'w') as f:
|
||||
f.write(service_content)
|
||||
|
||||
print(f"✓ Systemd service created: {service_file}")
|
||||
print("To enable and start:")
|
||||
print(" sudo systemctl enable video-keyframe-api")
|
||||
print(" sudo systemctl start video-keyframe-api")
|
||||
|
||||
except PermissionError:
|
||||
print("✗ Permission denied. Please run with sudo for systemd service creation")
|
||||
|
||||
# 创建本地副本
|
||||
local_service = Path("./video-keyframe-api.service")
|
||||
with open(local_service, 'w') as f:
|
||||
f.write(service_content)
|
||||
|
||||
print(f"✓ Service file created locally: {local_service}")
|
||||
print(f"To install: sudo cp {local_service} /etc/systemd/system/")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Video Keyframe Extraction API Server")
|
||||
|
||||
# 从配置文件获取默认值
|
||||
server_config = config.get_server_config()
|
||||
|
||||
parser.add_argument("--mode", choices=["dev", "prod", "install"], default="dev",
|
||||
help="运行模式: dev (开发), prod (生产), install (安装依赖)")
|
||||
parser.add_argument("--host", default=server_config.get("host", "127.0.0.1"), help="绑定主机")
|
||||
parser.add_argument("--port", type=int, default=server_config.get("port", 8000), help="端口号")
|
||||
parser.add_argument("--workers", type=int, default=server_config.get("workers", 4), help="生产模式工作进程数")
|
||||
parser.add_argument("--no-reload", action="store_true", help="禁用自动重载")
|
||||
parser.add_argument("--check", action="store_true", help="仅检查环境")
|
||||
parser.add_argument("--create-service", action="store_true", help="创建 systemd 服务")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("=== Video Keyframe Extraction API Server ===")
|
||||
|
||||
# 检查环境
|
||||
rust_exe = check_rust_executable()
|
||||
deps_ok = check_dependencies()
|
||||
|
||||
if args.check:
|
||||
print("\n=== Environment Check ===")
|
||||
print(f"Rust executable: {'✓' if rust_exe else '✗'}")
|
||||
print(f"Python dependencies: {'✓' if deps_ok else '✗'}")
|
||||
return
|
||||
|
||||
if args.create_service:
|
||||
create_systemd_service()
|
||||
return
|
||||
|
||||
# 安装模式
|
||||
if args.mode == "install":
|
||||
if not deps_ok:
|
||||
install_dependencies()
|
||||
else:
|
||||
print("✓ Dependencies already installed")
|
||||
return
|
||||
|
||||
# 检查必要条件
|
||||
if not rust_exe:
|
||||
print("✗ Cannot start without Rust executable")
|
||||
print("Please run: cargo build --release")
|
||||
sys.exit(1)
|
||||
|
||||
if not deps_ok:
|
||||
print("Installing missing dependencies...")
|
||||
if not install_dependencies():
|
||||
sys.exit(1)
|
||||
|
||||
# 启动服务器
|
||||
if args.mode == "dev":
|
||||
start_development_server(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
reload=not args.no_reload
|
||||
)
|
||||
elif args.mode == "prod":
|
||||
start_production_server(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
workers=args.workers
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -14,6 +14,7 @@ from src.common.vector_db import vector_db_service
|
||||
|
||||
logger = get_logger("cache_manager")
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""
|
||||
一个支持分层和语义缓存的通用工具缓存管理器。
|
||||
@@ -21,6 +22,7 @@ class CacheManager:
|
||||
L1缓存: 内存字典 (KV) + FAISS (Vector)。
|
||||
L2缓存: 数据库 (KV) + ChromaDB (Vector)。
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
@@ -32,7 +34,7 @@ class CacheManager:
|
||||
"""
|
||||
初始化缓存管理器。
|
||||
"""
|
||||
if not hasattr(self, '_initialized'):
|
||||
if not hasattr(self, "_initialized"):
|
||||
self.default_ttl = default_ttl
|
||||
self.semantic_cache_collection_name = "semantic_cache"
|
||||
|
||||
@@ -79,7 +81,7 @@ class CacheManager:
|
||||
logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)")
|
||||
return None
|
||||
|
||||
return embedding_array.astype('float32')
|
||||
return embedding_array.astype("float32")
|
||||
else:
|
||||
logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}")
|
||||
return None
|
||||
@@ -104,12 +106,18 @@ class CacheManager:
|
||||
logger.warning(f"无法获取文件信息: {tool_file_path},错误: {e}")
|
||||
|
||||
try:
|
||||
sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
|
||||
sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
|
||||
except TypeError:
|
||||
sorted_args = repr(sorted(function_args.items()))
|
||||
return f"{tool_name}::{sorted_args}::{file_hash}"
|
||||
|
||||
async def get(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], semantic_query: Optional[str] = None) -> Optional[Any]:
|
||||
async def get(
|
||||
self,
|
||||
tool_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
tool_file_path: Union[str, Path],
|
||||
semantic_query: Optional[str] = None,
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。
|
||||
"""
|
||||
@@ -136,12 +144,12 @@ class CacheManager:
|
||||
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
|
||||
validated_embedding = self._validate_embedding(embedding_vector)
|
||||
if validated_embedding is not None:
|
||||
query_embedding = np.array([validated_embedding], dtype='float32')
|
||||
query_embedding = np.array([validated_embedding], dtype="float32")
|
||||
|
||||
# 步骤 2a: L1 语义缓存 (FAISS)
|
||||
if query_embedding is not None and self.l1_vector_index.ntotal > 0:
|
||||
faiss.normalize_L2(query_embedding)
|
||||
distances, indices = self.l1_vector_index.search(query_embedding, 1)
|
||||
distances, indices = self.l1_vector_index.search(query_embedding, 1) # type: ignore
|
||||
if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似
|
||||
hit_index = indices[0][0]
|
||||
l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
|
||||
@@ -151,10 +159,7 @@ class CacheManager:
|
||||
|
||||
# 步骤 2b: L2 精确缓存 (数据库)
|
||||
cache_results_obj = await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="get",
|
||||
filters={"cache_key": key},
|
||||
single_result=True
|
||||
model_class=CacheEntries, query_type="get", filters={"cache_key": key}, single_result=True
|
||||
)
|
||||
|
||||
if cache_results_obj:
|
||||
@@ -172,8 +177,8 @@ class CacheManager:
|
||||
filters={"cache_key": key},
|
||||
data={
|
||||
"last_accessed": time.time(),
|
||||
"access_count": getattr(cache_results_obj, "access_count", 0) + 1
|
||||
}
|
||||
"access_count": getattr(cache_results_obj, "access_count", 0) + 1,
|
||||
},
|
||||
)
|
||||
|
||||
# 回填 L1
|
||||
@@ -181,11 +186,7 @@ class CacheManager:
|
||||
return data
|
||||
else:
|
||||
# 删除过期的缓存条目
|
||||
await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="delete",
|
||||
filters={"cache_key": key}
|
||||
)
|
||||
await db_query(model_class=CacheEntries, query_type="delete", filters={"cache_key": key})
|
||||
|
||||
# 步骤 2c: L2 语义缓存 (VectorDB Service)
|
||||
if query_embedding is not None:
|
||||
@@ -193,14 +194,16 @@ class CacheManager:
|
||||
results = vector_db_service.query(
|
||||
collection_name=self.semantic_cache_collection_name,
|
||||
query_embeddings=query_embedding.tolist(),
|
||||
n_results=1
|
||||
n_results=1,
|
||||
)
|
||||
if results and results.get("ids") and results["ids"][0]:
|
||||
distance = (
|
||||
results["distances"][0][0] if results.get("distances") and results["distances"][0] else "N/A"
|
||||
)
|
||||
if results and results.get('ids') and results['ids'][0]:
|
||||
distance = results['distances'][0][0] if results.get('distances') and results['distances'][0] else 'N/A'
|
||||
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
|
||||
|
||||
if distance != 'N/A' and distance < 0.75:
|
||||
l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0]
|
||||
if distance != "N/A" and distance < 0.75:
|
||||
l2_hit_key = results["ids"][0][0] if isinstance(results["ids"][0], list) else results["ids"][0]
|
||||
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
|
||||
|
||||
# 从数据库获取缓存数据
|
||||
@@ -208,7 +211,7 @@ class CacheManager:
|
||||
model_class=CacheEntries,
|
||||
query_type="get",
|
||||
filters={"cache_key": l2_hit_key},
|
||||
single_result=True
|
||||
single_result=True,
|
||||
)
|
||||
|
||||
if semantic_cache_results_obj:
|
||||
@@ -235,7 +238,15 @@ class CacheManager:
|
||||
logger.debug(f"缓存未命中: {key}")
|
||||
return None
|
||||
|
||||
async def set(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None):
|
||||
async def set(
|
||||
self,
|
||||
tool_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
tool_file_path: Union[str, Path],
|
||||
data: Any,
|
||||
ttl: Optional[int] = None,
|
||||
semantic_query: Optional[str] = None,
|
||||
):
|
||||
"""将结果存入所有缓存层。"""
|
||||
if ttl is None:
|
||||
ttl = self.default_ttl
|
||||
@@ -251,20 +262,15 @@ class CacheManager:
|
||||
# 写入 L2 (数据库)
|
||||
cache_data = {
|
||||
"cache_key": key,
|
||||
"cache_value": orjson.dumps(data).decode('utf-8'),
|
||||
"cache_value": orjson.dumps(data).decode("utf-8"),
|
||||
"expires_at": expires_at,
|
||||
"tool_name": tool_name,
|
||||
"created_at": time.time(),
|
||||
"last_accessed": time.time(),
|
||||
"access_count": 1
|
||||
"access_count": 1,
|
||||
}
|
||||
|
||||
await db_save(
|
||||
model_class=CacheEntries,
|
||||
data=cache_data,
|
||||
key_field="cache_key",
|
||||
key_value=key
|
||||
)
|
||||
await db_save(model_class=CacheEntries, data=cache_data, key_field="cache_key", key_value=key)
|
||||
|
||||
# 写入语义缓存
|
||||
if semantic_query and self.embedding_model:
|
||||
@@ -274,7 +280,7 @@ class CacheManager:
|
||||
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
|
||||
validated_embedding = self._validate_embedding(embedding_vector)
|
||||
if validated_embedding is not None:
|
||||
embedding = np.array([validated_embedding], dtype='float32')
|
||||
embedding = np.array([validated_embedding], dtype="float32")
|
||||
|
||||
# 写入 L1 Vector
|
||||
new_id = self.l1_vector_index.ntotal
|
||||
@@ -286,7 +292,7 @@ class CacheManager:
|
||||
vector_db_service.add(
|
||||
collection_name=self.semantic_cache_collection_name,
|
||||
embeddings=embedding.tolist(),
|
||||
ids=[key]
|
||||
ids=[key],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"语义缓存写入失败: {e}")
|
||||
@@ -306,7 +312,7 @@ class CacheManager:
|
||||
await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="delete",
|
||||
filters={} # 删除所有记录
|
||||
filters={}, # 删除所有记录
|
||||
)
|
||||
|
||||
# 清空 VectorDB
|
||||
@@ -338,14 +344,12 @@ class CacheManager:
|
||||
del self.l1_kv_cache[key]
|
||||
|
||||
# 清理L2过期条目
|
||||
await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="delete",
|
||||
filters={"expires_at": {"$lt": current_time}}
|
||||
)
|
||||
await db_query(model_class=CacheEntries, query_type="delete", filters={"expires_at": {"$lt": current_time}})
|
||||
|
||||
if expired_keys:
|
||||
logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目")
|
||||
|
||||
|
||||
# 全局实例
|
||||
tool_cache = CacheManager()
|
||||
|
||||
|
||||
@@ -1,405 +0,0 @@
|
||||
"""工具执行历史记录模块"""
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import json
|
||||
from pathlib import Path
|
||||
import inspect
|
||||
|
||||
from .logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
logger = get_logger("tool_history")
|
||||
|
||||
class ToolHistoryManager:
|
||||
"""工具执行历史记录管理器"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self._history: List[Dict[str, Any]] = []
|
||||
self._initialized = True
|
||||
self._data_dir = Path("data/tool_history")
|
||||
self._data_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._history_file = self._data_dir / "tool_history.jsonl"
|
||||
self._load_history()
|
||||
|
||||
def _save_history(self):
|
||||
"""保存所有历史记录到文件"""
|
||||
try:
|
||||
with self._history_file.open("w", encoding="utf-8") as f:
|
||||
for record in self._history:
|
||||
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"保存工具调用记录失败: {e}")
|
||||
|
||||
def _save_record(self, record: Dict[str, Any]):
|
||||
"""保存单条记录到文件"""
|
||||
try:
|
||||
with self._history_file.open("a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"保存工具调用记录失败: {e}")
|
||||
|
||||
def _clean_expired_records(self):
|
||||
"""清理已过期的记录"""
|
||||
original_count = len(self._history)
|
||||
self._history = [record for record in self._history if record.get("ttl_count", 0) < record.get("ttl", 5)]
|
||||
cleaned_count = original_count - len(self._history)
|
||||
|
||||
if cleaned_count > 0:
|
||||
logger.info(f"清理了 {cleaned_count} 条过期的工具历史记录,剩余 {len(self._history)} 条")
|
||||
self._save_history()
|
||||
else:
|
||||
logger.debug("没有需要清理的过期工具历史记录")
|
||||
|
||||
def record_tool_call(self,
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
result: Any,
|
||||
execution_time: float,
|
||||
status: str,
|
||||
chat_id: Optional[str] = None,
|
||||
ttl: int = 5):
|
||||
"""记录工具调用
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
args: 工具调用参数
|
||||
result: 工具返回结果
|
||||
execution_time: 执行时间(秒)
|
||||
status: 执行状态("completed"或"error")
|
||||
chat_id: 聊天ID,与ChatManager中的chat_id对应,用于标识群聊或私聊会话
|
||||
ttl: 该记录的生命周期值,插入提示词多少次后删除,默认为5
|
||||
"""
|
||||
# 检查是否启用历史记录且ttl大于0
|
||||
if not global_config.tool.history.enable_history or ttl <= 0:
|
||||
return
|
||||
|
||||
# 先清理过期记录
|
||||
self._clean_expired_records()
|
||||
|
||||
try:
|
||||
# 创建记录
|
||||
record = {
|
||||
"tool_name": tool_name,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"arguments": self._sanitize_args(args),
|
||||
"result": self._sanitize_result(result),
|
||||
"execution_time": execution_time,
|
||||
"status": status,
|
||||
"chat_id": chat_id,
|
||||
"ttl": ttl,
|
||||
"ttl_count": 0
|
||||
}
|
||||
|
||||
# 添加到内存中的历史记录
|
||||
self._history.append(record)
|
||||
|
||||
# 保存到文件
|
||||
self._save_record(record)
|
||||
|
||||
if status == "completed":
|
||||
logger.info(f"工具 {tool_name} 调用完成,耗时:{execution_time:.2f}s")
|
||||
else:
|
||||
logger.error(f"工具 {tool_name} 调用失败:{result}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录工具调用时发生错误: {e}")
|
||||
|
||||
def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""清理参数中的敏感信息"""
|
||||
sensitive_keys = ['api_key', 'token', 'password', 'secret']
|
||||
sanitized = args.copy()
|
||||
|
||||
def _sanitize_value(value):
|
||||
if isinstance(value, dict):
|
||||
return {k: '***' if k.lower() in sensitive_keys else _sanitize_value(v)
|
||||
for k, v in value.items()}
|
||||
return value
|
||||
|
||||
return {k: '***' if k.lower() in sensitive_keys else _sanitize_value(v)
|
||||
for k, v in sanitized.items()}
|
||||
|
||||
def _sanitize_result(self, result: Any) -> Any:
|
||||
"""清理结果中的敏感信息"""
|
||||
if isinstance(result, dict):
|
||||
return self._sanitize_args(result)
|
||||
return result
|
||||
|
||||
def _load_history(self):
|
||||
"""加载历史记录文件"""
|
||||
try:
|
||||
if self._history_file.exists():
|
||||
self._history = []
|
||||
with self._history_file.open("r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
try:
|
||||
record = json.loads(line)
|
||||
if record.get("ttl_count", 0) < record.get("ttl", 5): # 只加载未过期的记录
|
||||
self._history.append(record)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
logger.info(f"成功加载了 {len(self._history)} 条历史记录")
|
||||
except Exception as e:
|
||||
logger.error(f"加载历史记录失败: {e}")
|
||||
|
||||
def query_history(self,
|
||||
tool_names: Optional[List[str]] = None,
|
||||
start_time: Optional[Union[datetime, str]] = None,
|
||||
end_time: Optional[Union[datetime, str]] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
status: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""查询工具调用历史
|
||||
|
||||
Args:
|
||||
tool_names: 工具名称列表,为空则查询所有工具
|
||||
start_time: 开始时间,可以是datetime对象或ISO格式字符串
|
||||
end_time: 结束时间,可以是datetime对象或ISO格式字符串
|
||||
chat_id: 聊天ID,与ChatManager中的chat_id对应,用于查询特定群聊或私聊的历史记录
|
||||
limit: 返回记录数量限制
|
||||
status: 执行状态筛选("completed"或"error")
|
||||
|
||||
Returns:
|
||||
符合条件的历史记录列表
|
||||
"""
|
||||
# 先清理过期记录
|
||||
self._clean_expired_records()
|
||||
def _parse_time(time_str: Optional[Union[datetime, str]]) -> Optional[datetime]:
|
||||
if isinstance(time_str, datetime):
|
||||
return time_str
|
||||
elif isinstance(time_str, str):
|
||||
return datetime.fromisoformat(time_str)
|
||||
return None
|
||||
|
||||
filtered_history = self._history
|
||||
|
||||
# 按工具名筛选
|
||||
if tool_names:
|
||||
filtered_history = [
|
||||
record for record in filtered_history
|
||||
if record["tool_name"] in tool_names
|
||||
]
|
||||
|
||||
# 按时间范围筛选
|
||||
start_dt = _parse_time(start_time)
|
||||
end_dt = _parse_time(end_time)
|
||||
|
||||
if start_dt:
|
||||
filtered_history = [
|
||||
record for record in filtered_history
|
||||
if datetime.fromisoformat(record["timestamp"]) >= start_dt
|
||||
]
|
||||
|
||||
if end_dt:
|
||||
filtered_history = [
|
||||
record for record in filtered_history
|
||||
if datetime.fromisoformat(record["timestamp"]) <= end_dt
|
||||
]
|
||||
|
||||
# 按聊天ID筛选
|
||||
if chat_id:
|
||||
filtered_history = [
|
||||
record for record in filtered_history
|
||||
if record.get("chat_id") == chat_id
|
||||
]
|
||||
|
||||
# 按状态筛选
|
||||
if status:
|
||||
filtered_history = [
|
||||
record for record in filtered_history
|
||||
if record["status"] == status
|
||||
]
|
||||
|
||||
# 应用数量限制
|
||||
if limit:
|
||||
filtered_history = filtered_history[-limit:]
|
||||
|
||||
return filtered_history
|
||||
|
||||
def get_recent_history_prompt(self,
|
||||
limit: Optional[int] = None,
|
||||
chat_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
获取最近工具调用历史的提示词
|
||||
|
||||
Args:
|
||||
limit: 返回的历史记录数量,如果不提供则使用配置中的max_history
|
||||
chat_id: 会话ID,用于只获取当前会话的历史
|
||||
|
||||
Returns:
|
||||
格式化的历史记录提示词
|
||||
"""
|
||||
# 检查是否启用历史记录
|
||||
if not global_config.tool.history.enable_history:
|
||||
return ""
|
||||
|
||||
# 使用配置中的最大历史记录数
|
||||
if limit is None:
|
||||
limit = global_config.tool.history.max_history
|
||||
|
||||
recent_history = self.query_history(
|
||||
chat_id=chat_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
if not recent_history:
|
||||
return ""
|
||||
|
||||
prompt = "\n工具执行历史:\n"
|
||||
needs_save = False
|
||||
updated_history = []
|
||||
|
||||
for record in recent_history:
|
||||
# 增加ttl计数
|
||||
record["ttl_count"] = record.get("ttl_count", 0) + 1
|
||||
needs_save = True
|
||||
|
||||
# 如果未超过ttl,则添加到提示词中
|
||||
if record["ttl_count"] < record.get("ttl", 5):
|
||||
# 提取结果中的name和content
|
||||
result = record['result']
|
||||
if isinstance(result, dict):
|
||||
name = result.get('name', record['tool_name'])
|
||||
content = result.get('content', str(result))
|
||||
else:
|
||||
name = record['tool_name']
|
||||
content = str(result)
|
||||
|
||||
# 格式化内容,去除多余空白和换行
|
||||
content = content.strip().replace('\n', ' ')
|
||||
|
||||
# 如果内容太长则截断
|
||||
if len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
|
||||
prompt += f"{name}: \n{content}\n\n"
|
||||
updated_history.append(record)
|
||||
|
||||
# 更新历史记录并保存
|
||||
if needs_save:
|
||||
self._history = updated_history
|
||||
self._save_history()
|
||||
|
||||
return prompt
|
||||
|
||||
def clear_history(self):
|
||||
"""清除历史记录"""
|
||||
self._history.clear()
|
||||
self._save_history()
|
||||
logger.info("工具调用历史记录已清除")
|
||||
|
||||
|
||||
def wrap_tool_executor():
|
||||
"""
|
||||
包装工具执行器以添加历史记录和缓存功能
|
||||
这个函数应该在系统启动时被调用一次
|
||||
"""
|
||||
from src.plugin_system.core.tool_use import ToolExecutor
|
||||
from src.plugin_system.apis.tool_api import get_tool_instance
|
||||
original_execute = ToolExecutor.execute_tool_call
|
||||
history_manager = ToolHistoryManager()
|
||||
|
||||
async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
|
||||
start_time = time.time()
|
||||
|
||||
# 确保我们有 tool_instance
|
||||
if not tool_instance:
|
||||
tool_instance = get_tool_instance(tool_call.func_name)
|
||||
|
||||
# 如果没有 tool_instance,就无法进行缓存检查,直接执行
|
||||
if not tool_instance:
|
||||
result = await original_execute(self, tool_call, None)
|
||||
execution_time = time.time() - start_time
|
||||
history_manager.record_tool_call(
|
||||
tool_name=tool_call.func_name,
|
||||
args=tool_call.args,
|
||||
result=result,
|
||||
execution_time=execution_time,
|
||||
status="completed",
|
||||
chat_id=getattr(self, 'chat_id', None),
|
||||
ttl=5 # Default TTL
|
||||
)
|
||||
return result
|
||||
|
||||
# 新的缓存逻辑
|
||||
if tool_instance.enable_cache:
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
cached_result = await tool_cache.get(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=tool_call.args,
|
||||
tool_file_path=tool_file_path,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
if cached_result:
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||
return cached_result
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
|
||||
|
||||
try:
|
||||
result = await original_execute(self, tool_call, tool_instance)
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# 缓存结果
|
||||
if tool_instance.enable_cache:
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
await tool_cache.set(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=tool_call.args,
|
||||
tool_file_path=tool_file_path,
|
||||
data=result,
|
||||
ttl=tool_instance.cache_ttl,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
|
||||
|
||||
# 记录成功的调用
|
||||
history_manager.record_tool_call(
|
||||
tool_name=tool_call.func_name,
|
||||
args=tool_call.args,
|
||||
result=result,
|
||||
execution_time=execution_time,
|
||||
status="completed",
|
||||
chat_id=getattr(self, 'chat_id', None),
|
||||
ttl=tool_instance.history_ttl
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
# 记录失败的调用
|
||||
history_manager.record_tool_call(
|
||||
tool_name=tool_call.func_name,
|
||||
args=tool_call.args,
|
||||
result=str(e),
|
||||
execution_time=execution_time,
|
||||
status="error",
|
||||
chat_id=getattr(self, 'chat_id', None),
|
||||
ttl=tool_instance.history_ttl
|
||||
)
|
||||
raise
|
||||
|
||||
# 替换原始方法
|
||||
ToolExecutor.execute_tool_call = wrapped_execute_tool_call
|
||||
@@ -621,6 +621,7 @@ class WakeUpSystemConfig(ValidatedConfigBase):
|
||||
decay_interval: float = Field(default=30.0, ge=1.0, description="唤醒度衰减间隔(秒)")
|
||||
angry_duration: float = Field(default=300.0, ge=10.0, description="愤怒状态持续时间(秒)")
|
||||
angry_prompt: str = Field(default="你被人吵醒了非常生气,说话带着怒气", description="被吵醒后的愤怒提示词")
|
||||
re_sleep_delay_minutes: int = Field(default=5, ge=1, description="被唤醒后,如果多久没有新消息则尝试重新入睡(分钟)")
|
||||
|
||||
# --- 失眠机制相关参数 ---
|
||||
enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统")
|
||||
|
||||
@@ -5,7 +5,7 @@ import random
|
||||
|
||||
from enum import Enum
|
||||
from rich.traceback import install
|
||||
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine
|
||||
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
@@ -283,19 +283,26 @@ class LLMRequest:
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
raise_when_empty: bool = True,
|
||||
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
|
||||
"""执行单次请求"""
|
||||
# 模型选择和请求准备
|
||||
start_time = time.time()
|
||||
model_info, api_provider, client = self._select_model()
|
||||
model_name = model_info.name
|
||||
"""
|
||||
执行单次请求,并在模型失败时按顺序切换到下一个可用模型。
|
||||
"""
|
||||
failed_models = set()
|
||||
last_exception: Optional[Exception] = None
|
||||
|
||||
model_scheduler = self._model_scheduler(failed_models)
|
||||
|
||||
for model_info, api_provider, client in model_scheduler:
|
||||
start_time = time.time()
|
||||
model_name = model_info.name
|
||||
logger.debug(f"正在尝试使用模型: {model_name}") # 你不许刷屏
|
||||
|
||||
try:
|
||||
# 检查是否启用反截断
|
||||
use_anti_truncation = getattr(api_provider, "anti_truncation", False)
|
||||
|
||||
processed_prompt = prompt
|
||||
if use_anti_truncation:
|
||||
processed_prompt += self.anti_truncation_instruction
|
||||
logger.info(f"{api_provider} '{self.task_name}' 已启用反截断功能")
|
||||
logger.info(f"'{model_name}' for task '{self.task_name}' 已启用反截断功能")
|
||||
|
||||
processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
|
||||
|
||||
@@ -304,13 +311,12 @@ class LLMRequest:
|
||||
messages = [message_builder.build()]
|
||||
tool_built = self._build_tool_options(tools)
|
||||
|
||||
# 空回复重试逻辑
|
||||
# 针对当前模型的空回复/截断重试逻辑
|
||||
empty_retry_count = 0
|
||||
max_empty_retry = api_provider.max_retry
|
||||
empty_retry_interval = api_provider.retry_interval
|
||||
|
||||
while empty_retry_count <= max_empty_retry:
|
||||
try:
|
||||
response = await self._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
@@ -321,93 +327,86 @@ class LLMRequest:
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
content = response.content or ""
|
||||
reasoning_content = response.reasoning_content or ""
|
||||
tool_calls = response.tool_calls
|
||||
# 从内容中提取<think>标签的推理内容(向后兼容)
|
||||
|
||||
if not reasoning_content and content:
|
||||
content, extracted_reasoning = self._extract_reasoning(content)
|
||||
reasoning_content = extracted_reasoning
|
||||
|
||||
is_empty_reply = False
|
||||
is_empty_reply = not tool_calls and (not content or content.strip() == "")
|
||||
is_truncated = False
|
||||
# 检测是否为空回复或截断
|
||||
if not tool_calls:
|
||||
is_empty_reply = not content or content.strip() == ""
|
||||
is_truncated = False
|
||||
|
||||
if use_anti_truncation:
|
||||
if content.endswith("[done]"):
|
||||
content = content[:-6].strip()
|
||||
logger.debug("检测到并已移除 [done] 标记")
|
||||
else:
|
||||
is_truncated = True
|
||||
logger.warning("未检测到 [done] 标记,判定为截断")
|
||||
|
||||
if is_empty_reply or is_truncated:
|
||||
if empty_retry_count < max_empty_retry:
|
||||
empty_retry_count += 1
|
||||
reason = "空回复" if is_empty_reply else "截断"
|
||||
logger.warning(f"检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成")
|
||||
|
||||
if empty_retry_interval > 0:
|
||||
await asyncio.sleep(empty_retry_interval)
|
||||
|
||||
model_info, api_provider, client = self._select_model()
|
||||
continue
|
||||
else:
|
||||
# 已达到最大重试次数,但仍然是空回复或截断
|
||||
reason = "空回复" if is_empty_reply else "截断"
|
||||
# 抛出异常,由外层重试逻辑或最终的异常处理器捕获
|
||||
raise RuntimeError(f"经过 {max_empty_retry + 1} 次尝试后仍然是{reason}的回复")
|
||||
|
||||
# 记录使用情况
|
||||
if usage := response.usage:
|
||||
llm_usage_recorder.record_usage_to_database(
|
||||
model_info=model_info,
|
||||
model_usage=usage,
|
||||
time_cost=time.time() - start_time,
|
||||
user_id="system",
|
||||
request_type=self.request_type,
|
||||
endpoint="/chat/completions",
|
||||
)
|
||||
|
||||
# 处理空回复
|
||||
if not content and not tool_calls:
|
||||
if raise_when_empty:
|
||||
raise RuntimeError(f"经过 {empty_retry_count} 次重试后仍然生成空回复")
|
||||
content = "生成的响应为空,请检查模型配置或输入内容是否正确"
|
||||
elif empty_retry_count > 0:
|
||||
logger.info(f"经过 {empty_retry_count} 次重试后成功生成回复")
|
||||
|
||||
return content, (reasoning_content, model_info.name, tool_calls)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"请求执行失败: {e}")
|
||||
if raise_when_empty:
|
||||
# 在非并发模式下,如果第一次尝试就失败,则直接抛出异常
|
||||
if empty_retry_count == 0:
|
||||
raise
|
||||
|
||||
# 如果在重试过程中失败,则继续重试
|
||||
empty_retry_count += 1
|
||||
if empty_retry_count <= max_empty_retry:
|
||||
logger.warning(f"请求失败,将在 {empty_retry_interval} 秒后进行第 {empty_retry_count}/{max_empty_retry} 次重试...")
|
||||
reason = "空回复" if is_empty_reply else "截断"
|
||||
logger.warning(f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成...")
|
||||
if empty_retry_interval > 0:
|
||||
await asyncio.sleep(empty_retry_interval)
|
||||
continue
|
||||
continue # 继续使用当前模型重试
|
||||
else:
|
||||
logger.error(f"经过 {max_empty_retry} 次重试后仍然失败")
|
||||
raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复") from e
|
||||
else:
|
||||
# 在并发模式下,单个请求的失败不应中断整个并发流程,
|
||||
# 而是将异常返回给调用者(即 execute_concurrently)进行统一处理
|
||||
raise # 重新抛出异常,由 execute_concurrently 中的 gather 捕获
|
||||
# 当前模型重试次数用尽,跳出内层循环,触发外层循环切换模型
|
||||
reason = "空回复" if is_empty_reply else "截断"
|
||||
logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。")
|
||||
raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数")
|
||||
|
||||
# 重试失败
|
||||
# 成功获取响应
|
||||
if usage := response.usage:
|
||||
llm_usage_recorder.record_usage_to_database(
|
||||
model_info=model_info, model_usage=usage, time_cost=time.time() - start_time,
|
||||
user_id="system", request_type=self.request_type, endpoint="/chat/completions",
|
||||
)
|
||||
|
||||
if not content and not tool_calls:
|
||||
if raise_when_empty:
|
||||
raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复")
|
||||
return "生成的响应为空,请检查模型配置或输入内容是否正确", ("", model_name, None)
|
||||
raise RuntimeError("生成空回复")
|
||||
content = "生成的响应为空"
|
||||
|
||||
logger.debug(f"模型 '{model_name}' 成功生成回复。") # 你也不许刷屏
|
||||
return content, (reasoning_content, model_name, tool_calls)
|
||||
|
||||
except RespNotOkException as e:
|
||||
if e.status_code in [401, 403]:
|
||||
logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将尝试下一个模型。")
|
||||
failed_models.add(model_name)
|
||||
last_exception = e
|
||||
continue # 切换到下一个模型
|
||||
else:
|
||||
logger.error(f"模型 '{model_name}' 请求失败,HTTP状态码: {e.status_code}")
|
||||
if raise_when_empty:
|
||||
raise
|
||||
# 对于其他HTTP错误,直接抛出,不再尝试其他模型
|
||||
return f"请求失败: {e}", ("", model_name, None)
|
||||
|
||||
except RuntimeError as e:
|
||||
# 捕获所有重试失败(包括空回复和网络问题)
|
||||
logger.error(f"模型 '{model_name}' 在所有重试后仍然失败: {e},将尝试下一个模型。")
|
||||
failed_models.add(model_name)
|
||||
last_exception = e
|
||||
continue # 切换到下一个模型
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"使用模型 '{model_name}' 时发生未知异常: {e}")
|
||||
failed_models.add(model_name)
|
||||
last_exception = e
|
||||
continue # 切换到下一个模型
|
||||
|
||||
# 所有模型都尝试失败
|
||||
logger.error("所有可用模型都已尝试失败。")
|
||||
if raise_when_empty:
|
||||
if last_exception:
|
||||
raise RuntimeError("所有模型都请求失败") from last_exception
|
||||
raise RuntimeError("所有模型都请求失败,且没有具体的异常信息")
|
||||
|
||||
return "所有模型都请求失败", ("", "unknown", None)
|
||||
|
||||
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
|
||||
"""获取嵌入向量
|
||||
@@ -446,9 +445,24 @@ class LLMRequest:
|
||||
|
||||
return embedding, model_info.name
|
||||
|
||||
def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]:
|
||||
"""
|
||||
一个模型调度器,按顺序提供模型,并跳过已失败的模型。
|
||||
"""
|
||||
for model_name in self.model_for_task.model_list:
|
||||
if model_name in failed_models:
|
||||
continue
|
||||
|
||||
model_info = model_config.get_model_info(model_name)
|
||||
api_provider = model_config.get_provider(model_info.api_provider)
|
||||
force_new_client = (self.request_type == "embedding")
|
||||
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
|
||||
|
||||
yield model_info, api_provider, client
|
||||
|
||||
def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
|
||||
"""
|
||||
根据总tokens和惩罚值选择的模型
|
||||
根据总tokens和惩罚值选择的模型 (负载均衡)
|
||||
"""
|
||||
least_used_model_name = min(
|
||||
self.model_usage,
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from datetime import datetime
|
||||
from typing import Optional, Type
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
from src.common.tool_history import ToolHistoryManager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_api")
|
||||
@@ -34,109 +32,3 @@ def get_llm_available_tool_definitions():
|
||||
|
||||
llm_available_tools = component_registry.get_llm_available_tools()
|
||||
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
|
||||
|
||||
def get_tool_history(
|
||||
tool_names: Optional[List[str]] = None,
|
||||
start_time: Optional[Union[datetime, str]] = None,
|
||||
end_time: Optional[Union[datetime, str]] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
status: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取工具调用历史记录
|
||||
|
||||
Args:
|
||||
tool_names: 工具名称列表,为空则查询所有工具
|
||||
start_time: 开始时间,可以是datetime对象或ISO格式字符串
|
||||
end_time: 结束时间,可以是datetime对象或ISO格式字符串
|
||||
chat_id: 会话ID,用于筛选特定会话的调用
|
||||
limit: 返回记录数量限制
|
||||
status: 执行状态筛选("completed"或"error")
|
||||
|
||||
Returns:
|
||||
List[Dict]: 工具调用记录列表,每条记录包含以下字段:
|
||||
- tool_name: 工具名称
|
||||
- timestamp: 调用时间
|
||||
- arguments: 调用参数
|
||||
- result: 调用结果
|
||||
- execution_time: 执行时间
|
||||
- status: 执行状态
|
||||
- chat_id: 会话ID
|
||||
"""
|
||||
history_manager = ToolHistoryManager()
|
||||
return history_manager.query_history(
|
||||
tool_names=tool_names,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
chat_id=chat_id,
|
||||
limit=limit,
|
||||
status=status
|
||||
)
|
||||
|
||||
|
||||
def get_tool_history_text(
|
||||
tool_names: Optional[List[str]] = None,
|
||||
start_time: Optional[Union[datetime, str]] = None,
|
||||
end_time: Optional[Union[datetime, str]] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
status: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
获取工具调用历史记录的文本格式
|
||||
|
||||
Args:
|
||||
tool_names: 工具名称列表,为空则查询所有工具
|
||||
start_time: 开始时间,可以是datetime对象或ISO格式字符串
|
||||
end_time: 结束时间,可以是datetime对象或ISO格式字符串
|
||||
chat_id: 会话ID,用于筛选特定会话的调用
|
||||
limit: 返回记录数量限制
|
||||
status: 执行状态筛选("completed"或"error")
|
||||
|
||||
Returns:
|
||||
str: 格式化的工具调用历史记录文本
|
||||
"""
|
||||
history = get_tool_history(
|
||||
tool_names=tool_names,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
chat_id=chat_id,
|
||||
limit=limit,
|
||||
status=status
|
||||
)
|
||||
|
||||
if not history:
|
||||
return "没有找到工具调用记录"
|
||||
|
||||
text = "工具调用历史记录:\n"
|
||||
for record in history:
|
||||
# 提取结果中的name和content
|
||||
result = record['result']
|
||||
if isinstance(result, dict):
|
||||
name = result.get('name', record['tool_name'])
|
||||
content = result.get('content', str(result))
|
||||
else:
|
||||
name = record['tool_name']
|
||||
content = str(result)
|
||||
|
||||
# 格式化内容
|
||||
content = content.strip().replace('\n', ' ')
|
||||
if len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
|
||||
# 格式化时间
|
||||
timestamp = datetime.fromisoformat(record['timestamp']).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
text += f"[{timestamp}] {name}\n"
|
||||
text += f"结果: {content}\n\n"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def clear_tool_history() -> None:
|
||||
"""
|
||||
清除所有工具调用历史记录
|
||||
"""
|
||||
history_manager = ToolHistoryManager()
|
||||
history_manager.clear_history()
|
||||
@@ -119,7 +119,7 @@ class BaseEvent:
|
||||
for i, result in enumerate(results):
|
||||
subscriber = sorted_subscribers[i]
|
||||
handler_name = subscriber.handler_name if hasattr(subscriber, 'handler_name') else subscriber.__class__.__name__
|
||||
|
||||
if result:
|
||||
if isinstance(result, Exception):
|
||||
# 处理执行异常
|
||||
logger.error(f"事件处理器 {handler_name} 执行失败: {result}")
|
||||
|
||||
@@ -26,7 +26,6 @@ class BaseEventHandler(ABC):
|
||||
|
||||
def __init__(self):
|
||||
self.log_prefix = "[EventHandler]"
|
||||
self.plugin_name = ""
|
||||
"""对应插件名"""
|
||||
self.plugin_config: Optional[Dict] = None
|
||||
"""插件配置字典"""
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Type, Tuple, Union, TYPE_CHECKING
|
||||
from typing import List, Type, Tuple, Union
|
||||
from .plugin_base import PluginBase
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Tuple, Optional, List
|
||||
from typing import Tuple, Optional, List
|
||||
import re
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -167,6 +167,7 @@ class ComponentRegistry:
|
||||
logger.error(f"注册失败: {action_name} 不是有效的Action")
|
||||
return False
|
||||
|
||||
action_class.plugin_name = action_info.plugin_name
|
||||
self._action_registry[action_name] = action_class
|
||||
|
||||
# 如果启用,添加到默认动作集
|
||||
@@ -184,6 +185,7 @@ class ComponentRegistry:
|
||||
logger.error(f"注册失败: {command_name} 不是有效的Command")
|
||||
return False
|
||||
|
||||
command_class.plugin_name = command_info.plugin_name
|
||||
self._command_registry[command_name] = command_class
|
||||
|
||||
# 如果启用了且有匹配模式
|
||||
@@ -213,6 +215,7 @@ class ComponentRegistry:
|
||||
if not hasattr(self, '_plus_command_registry'):
|
||||
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
|
||||
|
||||
plus_command_class.plugin_name = plus_command_info.plugin_name
|
||||
self._plus_command_registry[plus_command_name] = plus_command_class
|
||||
|
||||
logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
|
||||
@@ -222,6 +225,7 @@ class ComponentRegistry:
|
||||
"""注册Tool组件到Tool特定注册表"""
|
||||
tool_name = tool_info.name
|
||||
|
||||
tool_class.plugin_name = tool_info.plugin_name
|
||||
self._tool_registry[tool_name] = tool_class
|
||||
|
||||
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
|
||||
@@ -246,6 +250,7 @@ class ComponentRegistry:
|
||||
logger.warning(f"EventHandler组件 {handler_name} 未启用")
|
||||
return True # 未启用,但是也是注册成功
|
||||
|
||||
handler_class.plugin_name = handler_info.plugin_name
|
||||
# 使用EventManager进行事件处理器注册
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
return event_manager.register_event_handler(handler_class)
|
||||
|
||||
@@ -7,8 +7,10 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
import inspect
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
@@ -184,21 +186,65 @@ class ToolExecutor:
|
||||
return tool_results, used_tools
|
||||
|
||||
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: use-assigned-variable
|
||||
"""执行单个工具调用
|
||||
"""执行单个工具调用,并处理缓存"""
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用对象
|
||||
function_args = tool_call.args or {}
|
||||
tool_instance = tool_instance or get_tool_instance(tool_call.func_name)
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 工具调用结果,如果失败则返回None
|
||||
"""
|
||||
# 如果工具不存在或未启用缓存,则直接执行
|
||||
if not tool_instance or not tool_instance.enable_cache:
|
||||
return await self._original_execute_tool_call(tool_call, tool_instance)
|
||||
|
||||
# --- 缓存逻辑开始 ---
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
cached_result = await tool_cache.get(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=function_args,
|
||||
tool_file_path=tool_file_path,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
if cached_result:
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||
return cached_result
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
|
||||
|
||||
# 缓存未命中,执行原始工具调用
|
||||
result = await self._original_execute_tool_call(tool_call, tool_instance)
|
||||
|
||||
# 将结果存入缓存
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
await tool_cache.set(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=function_args,
|
||||
tool_file_path=tool_file_path,
|
||||
data=result,
|
||||
ttl=tool_instance.cache_ttl,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
|
||||
# --- 缓存逻辑结束 ---
|
||||
|
||||
return result
|
||||
|
||||
async def _original_execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
|
||||
"""执行单个工具调用的原始逻辑"""
|
||||
try:
|
||||
function_name = tool_call.func_name
|
||||
function_args = tool_call.args or {}
|
||||
logger.info(f"🤖 {self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
|
||||
logger.info(f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = tool_instance or get_tool_instance(function_name)
|
||||
if not tool_instance:
|
||||
|
||||
@@ -24,6 +24,7 @@ from .services.qzone_service import QZoneService
|
||||
from .services.scheduler_service import SchedulerService
|
||||
from .services.monitor_service import MonitorService
|
||||
from .services.cookie_service import CookieService
|
||||
from .services.reply_tracker_service import ReplyTrackerService
|
||||
from .services.manager import register_service
|
||||
|
||||
logger = get_logger("MaiZone.Plugin")
|
||||
@@ -99,11 +100,13 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
content_service = ContentService(self.get_config)
|
||||
image_service = ImageService(self.get_config)
|
||||
cookie_service = CookieService(self.get_config)
|
||||
reply_tracker_service = ReplyTrackerService()
|
||||
qzone_service = QZoneService(self.get_config, content_service, image_service, cookie_service)
|
||||
scheduler_service = SchedulerService(self.get_config, qzone_service)
|
||||
monitor_service = MonitorService(self.get_config, qzone_service)
|
||||
|
||||
register_service("qzone", qzone_service)
|
||||
register_service("reply_tracker", reply_tracker_service)
|
||||
register_service("get_config", self.get_config)
|
||||
|
||||
# 保存服务引用以便后续启动
|
||||
|
||||
@@ -9,12 +9,9 @@ import datetime
|
||||
import base64
|
||||
import aiohttp
|
||||
from src.common.logger import get_logger
|
||||
import base64
|
||||
import aiohttp
|
||||
import imghdr
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import llm_api, config_api, generator_api, person_api
|
||||
from src.plugin_system.apis import llm_api, config_api, generator_api
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from maim_message import UserInfo
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
@@ -27,6 +27,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
from .content_service import ContentService
|
||||
from .image_service import ImageService
|
||||
from .cookie_service import CookieService
|
||||
from .reply_tracker_service import ReplyTrackerService
|
||||
|
||||
logger = get_logger("MaiZone.QZoneService")
|
||||
|
||||
@@ -55,6 +56,7 @@ class QZoneService:
|
||||
self.content_service = content_service
|
||||
self.image_service = image_service
|
||||
self.cookie_service = cookie_service
|
||||
self.reply_tracker = ReplyTrackerService()
|
||||
|
||||
# --- Public Methods (High-Level Business Logic) ---
|
||||
|
||||
@@ -154,7 +156,8 @@ class QZoneService:
|
||||
# --- 第一步: 单独处理自己说说的评论 ---
|
||||
if self.get_config("monitor.enable_auto_reply", False):
|
||||
try:
|
||||
own_feeds = await api_client["list_feeds"](qq_account, 5) # 获取自己最近5条说说
|
||||
# 传入新参数,表明正在检查自己的说说
|
||||
own_feeds = await api_client["list_feeds"](qq_account, 5, is_monitoring_own_feeds=True)
|
||||
if own_feeds:
|
||||
logger.info(f"获取到自己 {len(own_feeds)} 条说说,检查评论...")
|
||||
for feed in own_feeds:
|
||||
@@ -248,42 +251,83 @@ class QZoneService:
|
||||
content = feed.get("content", "")
|
||||
fid = feed.get("tid", "")
|
||||
|
||||
if not comments:
|
||||
if not comments or not fid:
|
||||
return
|
||||
|
||||
# 筛选出未被自己回复过的评论
|
||||
if not comments:
|
||||
# 1. 将评论分为用户评论和自己的回复
|
||||
user_comments = [c for c in comments if str(c.get('qq_account')) != str(qq_account)]
|
||||
my_replies = [c for c in comments if str(c.get('qq_account')) == str(qq_account)]
|
||||
|
||||
if not user_comments:
|
||||
return
|
||||
|
||||
# 找到所有我已经回复过的评论的ID
|
||||
replied_to_tids = {
|
||||
c['parent_tid'] for c in comments
|
||||
if c.get('parent_tid') and str(c.get('qq_account')) == str(qq_account)
|
||||
}
|
||||
# 2. 验证已记录的回复是否仍然存在,清理已删除的回复记录
|
||||
await self._validate_and_cleanup_reply_records(fid, my_replies)
|
||||
|
||||
# 找出所有非我发出且我未回复过的评论
|
||||
comments_to_reply = [
|
||||
c for c in comments
|
||||
if str(c.get('qq_account')) != str(qq_account) and c.get('comment_tid') not in replied_to_tids
|
||||
]
|
||||
# 3. 使用验证后的持久化记录来筛选未回复的评论
|
||||
comments_to_reply = []
|
||||
for comment in user_comments:
|
||||
comment_tid = comment.get('comment_tid')
|
||||
if not comment_tid:
|
||||
continue
|
||||
|
||||
# 检查是否已经在持久化记录中标记为已回复
|
||||
if not self.reply_tracker.has_replied(fid, comment_tid):
|
||||
comments_to_reply.append(comment)
|
||||
|
||||
if not comments_to_reply:
|
||||
logger.debug(f"说说 {fid} 下的所有评论都已回复过")
|
||||
return
|
||||
|
||||
logger.info(f"发现自己说说下的 {len(comments_to_reply)} 条新评论,准备回复...")
|
||||
for comment in comments_to_reply:
|
||||
comment_tid = comment.get("comment_tid")
|
||||
nickname = comment.get("nickname", "")
|
||||
comment_content = comment.get("content", "")
|
||||
|
||||
try:
|
||||
reply_content = await self.content_service.generate_comment_reply(
|
||||
content, comment.get("content", ""), comment.get("nickname", "")
|
||||
content, comment_content, nickname
|
||||
)
|
||||
if reply_content:
|
||||
success = await api_client["reply"](
|
||||
fid, qq_account, comment.get("nickname", ""), reply_content, comment.get("comment_tid")
|
||||
fid, qq_account, nickname, reply_content, comment_tid
|
||||
)
|
||||
if success:
|
||||
logger.info(f"成功回复'{comment.get('nickname', '')}'的评论: '{reply_content}'")
|
||||
# 标记为已回复
|
||||
self.reply_tracker.mark_as_replied(fid, comment_tid)
|
||||
logger.info(f"成功回复'{nickname}'的评论: '{reply_content}'")
|
||||
else:
|
||||
logger.error(f"回复'{comment.get('nickname', '')}'的评论失败")
|
||||
logger.error(f"回复'{nickname}'的评论失败")
|
||||
await asyncio.sleep(random.uniform(10, 20))
|
||||
else:
|
||||
logger.warning(f"生成回复内容失败,跳过回复'{nickname}'的评论")
|
||||
except Exception as e:
|
||||
logger.error(f"回复'{nickname}'的评论时发生异常: {e}", exc_info=True)
|
||||
|
||||
async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: List[Dict]):
|
||||
"""验证并清理已删除的回复记录"""
|
||||
# 获取当前记录中该说说的所有已回复评论ID
|
||||
recorded_replied_comments = self.reply_tracker.get_replied_comments(fid)
|
||||
|
||||
if not recorded_replied_comments:
|
||||
return
|
||||
|
||||
# 从API返回的我的回复中提取parent_tid(即被回复的评论ID)
|
||||
current_replied_comments = set()
|
||||
for reply in my_replies:
|
||||
parent_tid = reply.get('parent_tid')
|
||||
if parent_tid:
|
||||
current_replied_comments.add(parent_tid)
|
||||
|
||||
# 找出记录中有但实际已不存在的回复
|
||||
deleted_replies = recorded_replied_comments - current_replied_comments
|
||||
|
||||
if deleted_replies:
|
||||
logger.info(f"检测到 {len(deleted_replies)} 个回复已被删除,清理记录...")
|
||||
for comment_tid in deleted_replies:
|
||||
self.reply_tracker.remove_reply_record(fid, comment_tid)
|
||||
logger.debug(f"已清理删除的回复记录: feed_id={fid}, comment_id={comment_tid}")
|
||||
|
||||
async def _process_single_feed(self, feed: Dict, api_client: Dict, target_qq: str, target_name: str):
|
||||
"""处理单条说说,决定是否评论和点赞"""
|
||||
@@ -641,7 +685,7 @@ class QZoneService:
|
||||
logger.error(f"上传图片 {index+1} 异常: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _list_feeds(t_qq: str, num: int) -> List[Dict]:
|
||||
async def _list_feeds(t_qq: str, num: int, is_monitoring_own_feeds: bool = False) -> List[Dict]:
|
||||
"""获取指定用户说说列表"""
|
||||
try:
|
||||
params = {
|
||||
@@ -667,10 +711,14 @@ class QZoneService:
|
||||
feeds_list = []
|
||||
my_name = json_data.get("logininfo", {}).get("name", "")
|
||||
for msg in json_data.get("msglist", []):
|
||||
# 只有在处理好友说说时,才检查是否已评论并跳过
|
||||
if not is_monitoring_own_feeds:
|
||||
is_commented = any(
|
||||
c.get("name") == my_name for c in msg.get("commentlist", []) if isinstance(c, dict)
|
||||
)
|
||||
if not is_commented:
|
||||
if is_commented:
|
||||
continue
|
||||
|
||||
images = [pic['url1'] for pic in msg.get('pictotal', []) if 'url1' in pic]
|
||||
|
||||
comments = []
|
||||
|
||||
@@ -0,0 +1,195 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
评论回复跟踪服务
|
||||
负责记录和管理已回复过的评论ID,避免重复回复
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Set, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("MaiZone.ReplyTrackerService")
|
||||
|
||||
|
||||
class ReplyTrackerService:
|
||||
"""
|
||||
评论回复跟踪服务
|
||||
使用本地JSON文件持久化存储已回复的评论ID
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 数据存储路径
|
||||
self.data_dir = Path(__file__).resolve().parent.parent / "data"
|
||||
self.data_dir.mkdir(exist_ok=True)
|
||||
self.reply_record_file = self.data_dir / "replied_comments.json"
|
||||
|
||||
# 内存中的已回复评论记录
|
||||
# 格式: {feed_id: {comment_id: timestamp, ...}, ...}
|
||||
self.replied_comments: Dict[str, Dict[str, float]] = {}
|
||||
|
||||
# 数据清理配置
|
||||
self.max_record_days = 30 # 保留30天的记录
|
||||
|
||||
# 加载已有数据
|
||||
self._load_data()
|
||||
|
||||
def _load_data(self):
|
||||
"""从文件加载已回复评论数据"""
|
||||
try:
|
||||
if self.reply_record_file.exists():
|
||||
with open(self.reply_record_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
self.replied_comments = data
|
||||
logger.info(f"已加载 {len(self.replied_comments)} 条说说的回复记录")
|
||||
else:
|
||||
logger.info("未找到回复记录文件,将创建新的记录")
|
||||
except Exception as e:
|
||||
logger.error(f"加载回复记录失败: {e}")
|
||||
self.replied_comments = {}
|
||||
|
||||
def _save_data(self):
|
||||
"""保存已回复评论数据到文件"""
|
||||
try:
|
||||
# 清理过期数据
|
||||
self._cleanup_old_records()
|
||||
|
||||
with open(self.reply_record_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.replied_comments, f, ensure_ascii=False, indent=2)
|
||||
logger.debug("回复记录已保存")
|
||||
except Exception as e:
|
||||
logger.error(f"保存回复记录失败: {e}")
|
||||
|
||||
def _cleanup_old_records(self):
|
||||
"""清理超过保留期限的记录"""
|
||||
current_time = time.time()
|
||||
cutoff_time = current_time - (self.max_record_days * 24 * 60 * 60)
|
||||
|
||||
feeds_to_remove = []
|
||||
total_removed = 0
|
||||
|
||||
for feed_id, comments in self.replied_comments.items():
|
||||
comments_to_remove = []
|
||||
|
||||
for comment_id, timestamp in comments.items():
|
||||
if timestamp < cutoff_time:
|
||||
comments_to_remove.append(comment_id)
|
||||
|
||||
# 移除过期的评论记录
|
||||
for comment_id in comments_to_remove:
|
||||
del comments[comment_id]
|
||||
total_removed += 1
|
||||
|
||||
# 如果该说说下没有任何记录了,标记删除整个说说记录
|
||||
if not comments:
|
||||
feeds_to_remove.append(feed_id)
|
||||
|
||||
# 移除空的说说记录
|
||||
for feed_id in feeds_to_remove:
|
||||
del self.replied_comments[feed_id]
|
||||
|
||||
if total_removed > 0:
|
||||
logger.info(f"清理了 {total_removed} 条过期的回复记录")
|
||||
|
||||
def has_replied(self, feed_id: str, comment_id: str) -> bool:
|
||||
"""
|
||||
检查是否已经回复过指定的评论
|
||||
|
||||
Args:
|
||||
feed_id: 说说ID
|
||||
comment_id: 评论ID
|
||||
|
||||
Returns:
|
||||
bool: 如果已回复过返回True,否则返回False
|
||||
"""
|
||||
if not feed_id or not comment_id:
|
||||
return False
|
||||
|
||||
return (feed_id in self.replied_comments and
|
||||
comment_id in self.replied_comments[feed_id])
|
||||
|
||||
def mark_as_replied(self, feed_id: str, comment_id: str):
|
||||
"""
|
||||
标记指定评论为已回复
|
||||
|
||||
Args:
|
||||
feed_id: 说说ID
|
||||
comment_id: 评论ID
|
||||
"""
|
||||
if not feed_id or not comment_id:
|
||||
logger.warning("feed_id 或 comment_id 为空,无法标记为已回复")
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
if feed_id not in self.replied_comments:
|
||||
self.replied_comments[feed_id] = {}
|
||||
|
||||
self.replied_comments[feed_id][comment_id] = current_time
|
||||
|
||||
# 保存到文件
|
||||
self._save_data()
|
||||
|
||||
logger.info(f"已标记评论为已回复: feed_id={feed_id}, comment_id={comment_id}")
|
||||
|
||||
def get_replied_comments(self, feed_id: str) -> Set[str]:
|
||||
"""
|
||||
获取指定说说下所有已回复的评论ID
|
||||
|
||||
Args:
|
||||
feed_id: 说说ID
|
||||
|
||||
Returns:
|
||||
Set[str]: 已回复的评论ID集合
|
||||
"""
|
||||
if feed_id in self.replied_comments:
|
||||
return set(self.replied_comments[feed_id].keys())
|
||||
return set()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取回复记录统计信息
|
||||
|
||||
Returns:
|
||||
Dict: 包含统计信息的字典
|
||||
"""
|
||||
total_feeds = len(self.replied_comments)
|
||||
total_replies = sum(len(comments) for comments in self.replied_comments.values())
|
||||
|
||||
return {
|
||||
"total_feeds_with_replies": total_feeds,
|
||||
"total_replied_comments": total_replies,
|
||||
"data_file": str(self.reply_record_file),
|
||||
"max_record_days": self.max_record_days
|
||||
}
|
||||
|
||||
def remove_reply_record(self, feed_id: str, comment_id: str):
|
||||
"""
|
||||
移除指定评论的回复记录
|
||||
|
||||
Args:
|
||||
feed_id: 说说ID
|
||||
comment_id: 评论ID
|
||||
"""
|
||||
if feed_id in self.replied_comments and comment_id in self.replied_comments[feed_id]:
|
||||
del self.replied_comments[feed_id][comment_id]
|
||||
|
||||
# 如果该说说下没有任何回复记录了,删除整个说说记录
|
||||
if not self.replied_comments[feed_id]:
|
||||
del self.replied_comments[feed_id]
|
||||
|
||||
self._save_data()
|
||||
logger.debug(f"已移除回复记录: feed_id={feed_id}, comment_id={comment_id}")
|
||||
|
||||
def remove_feed_records(self, feed_id: str):
|
||||
"""
|
||||
移除指定说说的所有回复记录
|
||||
|
||||
Args:
|
||||
feed_id: 说说ID
|
||||
"""
|
||||
if feed_id in self.replied_comments:
|
||||
del self.replied_comments[feed_id]
|
||||
self._save_data()
|
||||
logger.info(f"已移除说说 {feed_id} 的所有回复记录")
|
||||
@@ -16,7 +16,7 @@ from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.apis.logging_api import get_logger
|
||||
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.plugin_system.utils.permission_decorators import require_permission, require_master, PermissionChecker
|
||||
from src.plugin_system.utils.permission_decorators import require_permission
|
||||
|
||||
|
||||
logger = get_logger("Permission")
|
||||
|
||||
@@ -411,7 +411,6 @@ class ScheduleManager:
|
||||
通过关键词匹配、唤醒度、睡眠压力等综合判断是否处于休眠时间。
|
||||
新增弹性睡眠机制,允许在压力低时延迟入睡,并在入睡前发送通知。
|
||||
"""
|
||||
from src.chat.chat_loop.wakeup_manager import WakeUpManager
|
||||
# --- 基础检查 ---
|
||||
if not global_config.schedule.enable_is_sleep:
|
||||
return False
|
||||
|
||||
@@ -21,7 +21,7 @@ max_retry = 2
|
||||
timeout = 30
|
||||
retry_interval = 10
|
||||
|
||||
[[api_providers]] # 特殊:Google的Gimini使用特殊API,与OpenAI格式不兼容,需要配置client为"gemini"
|
||||
[[api_providers]] # 特殊:Google的Gimini使用特殊API,与OpenAI格式不兼容,需要配置client为"aiohttp_gemini"
|
||||
name = "Google"
|
||||
base_url = "https://api.google.com/v1"
|
||||
api_key = "your-google-api-key-1"
|
||||
|
||||