This commit is contained in:
雅诺狐
2025-08-29 13:48:01 +08:00
86 changed files with 9073 additions and 1578 deletions

2
.gitignore vendored
View File

@@ -325,6 +325,8 @@ run_pet.bat
!/plugins/set_emoji_like
!/plugins/permission_example
!/plugins/hello_world_plugin
!/plugins/take_picture_plugin
!/plugins/napcat_adapter_plugin
!/plugins/echo_example
config.toml

4
bot.py
View File

@@ -31,7 +31,6 @@ from src.manager.async_task_manager import async_task_manager # noqa
from src.config.config import global_config # noqa
from src.common.database.database import initialize_sql_database # noqa
from src.common.database.sqlalchemy_models import initialize_database as init_db # noqa
from src.common.tool_history import wrap_tool_executor #noqa
logger = get_logger("main")
@@ -240,8 +239,7 @@ class MaiBotMain(BaseMain):
self.setup_timezone()
self.check_and_confirm_eula()
self.initialize_database()
# 初始化工具历史记录
wrap_tool_executor()
return self.create_main_system()

View File

@@ -1,5 +1,68 @@
# Changelog
## [0.10.0-alpha] - 2025-8-28
> **MoFox-Bot 0.10.0-alpha 版本发布!**
>
> 本次更新带来了多项核心功能增强和系统优化。
>
> 在**新功能**方面,我们引入了**持久化回复跟踪**以避免重复回复并为LLM请求实现了**模型故障转移机制**,提高了系统的健壮性。插件系统也得到了增强,增加了**事件触发和订阅的白名单机制**,并为事件处理器添加了**异步锁和并行执行**支持。此外,新版本还实现了对**说说中图片的识别与理解**,并引入了**弹性睡眠与睡前通知机制**,使机器人的作息更加智能化。
>
> 在**修复**方面,我们解决了`enable`配置的bug,修复了event权限问题,并处理了专注模式下艾特不回复、模型信息不存在、cookie获取失败等多个问题,提升了系统的稳定性。
>
> 在**重构**方面,我们移除了`changelog_config`并更新了模型配置模板,重构了工具缓存机制和LLM请求重试逻辑。同时,我们将项目从`MaiMbot-Pro-Max`正式更名为`MoFox_Bot`,并对内存、权限、配置等多个模块进行了优化。
>
> 总的来说,0.10.0-alpha版本在功能、稳定性和代码质量上都有了显著提升。
### 新功能
- **maizone**: 引入持久化回复跟踪以避免重复回复
- **llm**: 为LLM请求实现模型故障转移机制
- **plugin-system**: 添加事件触发和订阅的白名单机制
- **plugin**: 为事件处理器添加异步锁和并行执行支持
- **maizone**: 实现对说说中图片的识别与理解
- **sleep**: 实现睡眠唤醒与重新入睡机制
- **core**: 实现HFC及睡眠状态的持久化
- **schedule**: 引入弹性睡眠与睡前通知机制
- **monthly_plan**: 增加月度计划数量上限并自动清理
- **command**: 添加PlusCommand增强命令系统
- **expression**: 重构表达学习配置,引入基于规则的结构化定义
- **chat**: 实现睡眠压力和失眠系统
- **schedule**: 重构日程与月度计划管理模块
- **core**: 集成统一向量数据库服务并重构相关模块
- **tool_system**: 实现工具的声明式缓存
- **plugin_system**: 增加工具执行日志记录
- **maizone**: 新增QQ空间互通组功能根据聊天上下文生成说说
- **monthly_plan**: 增强月度计划系统,引入状态管理和智能抽取
### 修复
- 修复`enable`配置
- 修复event权限现在每个component都拥有`plugin_name`属性
- 修复专注模式下艾特不回复的问题
- 修复模型信息不存在时引发的属性错误
- 修复`maizone_refactored`获取cookie时响应为空导致的错误
- 修复关键词非列表形式时导致的解析错误
- 修复即时记忆的 orjson 编码与解码问题
- 修复空回复检测,同时修复`tool_call`
- 处理截断消息时`message``None`的情况
- 修复`get_remaining`的起始索引
- 修复回复自己评论的问题
- 修复`enable`配置
### 重构
- **config**: 移除`changelog_config`并更新模型配置模板
- **cache**: 重构工具缓存机制并优化LLM请求重试逻辑
- **core**: 移除工具历史记录管理器并将缓存集成到工具执行器中
- 重构权限检查和装饰器用法
- **core**: 将项目从`MaiMbot-Pro-Max`重命名为`MoFox_Bot`
- **memory**: 重构向量记忆清理逻辑以提高稳定性
- **llm_models**: 移除官方Gemini客户端并改用`aiohttp`实现
- **config**: 整合搜索服务配置并移除废弃选项
- **config**: 将反截断设置移至模型配置
- **video**: 重构视频分析,增加抽帧模式和间隔配置
从这里开始都是第三方改版的更新记录!!!!!
========================================================
## [0.10.0] - 2025-7-1
### 主要功能更改
- 工具系统重构,现在合并到了插件系统中

View File

@@ -1,51 +0,0 @@
# Changelog
## [1.0.3] - 2025-3-31
### Added
- 新增了心流相关配置项:
- `heartflow` 配置项,用于控制心流功能
### Removed
- 移除了 `response` 配置项中的 `model_r1_probability``model_v3_probability` 选项
- 移除了次级推理模型相关配置
## [1.0.1] - 2025-3-30
### Added
- 增加了流式输出控制项 `stream`
- 修复 `LLM_Request` 不会自动为 `payload` 增加流式输出标志的问题
## [1.0.0] - 2025-3-30
### Added
- 修复了错误的版本命名
- 杀掉了所有无关文件
## [0.0.11] - 2025-3-12
### Added
- 新增了 `schedule` 配置项,用于配置日程表生成功能
- 新增了 `response_splitter` 配置项,用于控制回复分割
- 新增了 `experimental` 配置项,用于实验性功能开关
- 新增了 `llm_observation``llm_sub_heartflow` 模型配置
- 新增了 `llm_heartflow` 模型配置
-`personality` 配置项中新增了 `prompt_schedule_gen` 参数
### Changed
- 优化了模型配置的组织结构
- 调整了部分配置项的默认值
- 调整了配置项的顺序,将 `groups` 配置项移到了更靠前的位置
-`message` 配置项中:
- 新增了 `model_max_output_length` 参数
-`willing` 配置项中新增了 `emoji_response_penalty` 参数
-`personality` 配置项中的 `prompt_schedule` 重命名为 `prompt_schedule_gen`
### Removed
- 移除了 `min_text_length` 配置项
- 移除了 `cq_code` 配置项
- 移除了 `others` 配置项(其功能已整合到 `experimental` 中)
## [0.0.5] - 2025-3-11
### Added
- 新增了 `alias_names` 配置项,用于指定麦麦的别名。
## [0.0.4] - 2025-3-9
### Added
- 新增了 `memory_ban_words` 配置项,用于指定不希望记忆的词汇。

View File

Before

Width:  |  Height:  |  Size: 4.1 KiB

After

Width:  |  Height:  |  Size: 4.1 KiB

View File

Before

Width:  |  Height:  |  Size: 11 KiB

After

Width:  |  Height:  |  Size: 11 KiB

View File

Before

Width:  |  Height:  |  Size: 21 KiB

After

Width:  |  Height:  |  Size: 21 KiB

View File

Before

Width:  |  Height:  |  Size: 4.9 KiB

After

Width:  |  Height:  |  Size: 4.9 KiB

124
docs/deployment_guide.md Normal file
View File

@@ -0,0 +1,124 @@
# MoFox_Bot 部署指南
欢迎使用 MoFox_Bot!本指南将引导您完成在 Windows 环境下部署 MoFox_Bot 的全部过程。
## 1. 系统要求
- **操作系统**: Windows 10 或 Windows 11
- **Python**: 版本 >= 3.10
- **Git**: 用于克隆项目仓库
- **uv**: 推荐的 Python 包管理器 (版本 >= 0.1.0)
## 2. 部署步骤
### 第一步:获取必要的文件
首先,创建一个用于存放 MoFox_Bot 相关文件的文件夹,并通过 `git` 克隆 MoFox_Bot 主程序和 Napcat 适配器。
```shell
mkdir MoFox_Bot_Deployment
cd MoFox_Bot_Deployment
git clone https://github.com/MoFox-Studio/MoFox_Bot.git
git clone https://github.com/MoFox-Studio/Napcat-Adapter.git
```
### 第二步:环境配置
我们推荐使用 `uv` 来管理 Python 环境和依赖,因为它提供了更快的安装速度和更好的依赖管理体验。
**安装 uv:**
```shell
pip install uv
```
### 第三步:依赖安装
**1. 安装 MoFox_Bot 依赖:**
进入 `mmc` 文件夹,创建虚拟环境并安装依赖。
```shell
cd mmc
uv venv
uv pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple --upgrade
```
**2. 安装 Napcat-Adapter 依赖:**
回到上一级目录,进入 `Napcat-Adapter` 文件夹,创建虚拟环境并安装依赖。
```shell
cd ..
cd Napcat-Adapter
uv venv
uv pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple --upgrade
```
### 第四步:配置 MoFox_Bot 和 Adapter
**1. MoFox_Bot 配置:**
-`mmc` 文件夹中,将 `template/bot_config_template.toml` 复制到 `config/bot_config.toml`
-`template/model_config_template.toml` 复制到 `config/model_config.toml`
- 根据 [模型配置指南](guides/model_configuration_guide.md) 和 `bot_config.toml` 文件中的注释,填写您的 API Key 和其他相关配置。
**2. Napcat-Adapter 配置:**
-`Napcat-Adapter` 文件夹中,将 `template/template_config.toml` 复制到根目录并改名为 `config.toml`
- 打开 `config.toml` 文件,配置 `[Napcat_Server]``[MaiBot_Server]` 字段。
- `[Napcat_Server]``port` 应与 Napcat 设置的反向代理 URL 中的端口相同。
- `[MaiBot_Server]``port` 应与 MoFox_Bot 的 `bot_config.toml` 中设置的端口相同。
### 第五步:运行
**1. 启动 Napcat:**
请参考 [NapCatQQ 文档](https://napcat-qq.github.io/) 进行部署和启动。
**2. 启动 MoFox_Bot:**
进入 `mmc` 文件夹,使用 `uv` 运行。
```shell
cd mmc
uv run python bot.py
```
**3. 启动 Napcat-Adapter:**
打开一个新的终端窗口,进入 `Napcat-Adapter` 文件夹,使用 `uv` 运行。
```shell
cd Napcat-Adapter
uv run python main.py
```
至此MoFox_Bot 已成功部署并运行。
## 3. 详细配置说明
### `bot_config.toml`
这是 MoFox_Bot 的主配置文件包含了机器人昵称、主人QQ、命令前缀、数据库设置等。请根据文件内的注释进行详细配置。
### `model_config.toml`
此文件用于配置 AI 模型和 API 服务提供商。详细配置方法请参考 [模型配置指南](guides/model_configuration_guide.md)。
### 插件配置
每个插件都有独立的配置文件,位于 `mmc/config/plugins/` 目录下。插件的配置由其 `config_schema` 自动生成。详细信息请参考 [插件配置完整指南](plugins/configuration-guide.md)。
## 4. 故障排除
- **依赖安装失败**:
- 尝试更换 PyPI 镜像源。
- 检查网络连接。
- **API 调用失败**:
- 检查 `model_config.toml` 中的 API Key 和 `base_url` 是否正确。
- **无法连接到 Napcat**:
- 检查 Napcat 是否正常运行。
- 确认 `Napcat-Adapter``config.toml``[Napcat_Server]``port` 是否与 Napcat 设置的端口一致。
如果遇到其他问题,请查看 `logs/` 目录下的日志文件以获取详细的错误信息。

View File

@@ -43,12 +43,11 @@ retry_interval = 10 # 重试间隔(秒)
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
| `base_url` | ✅ | API服务的基础URL | - |
| `api_key` | ✅ | API密钥请替换为实际密钥 | - |
| `client_type` | ❌ | 客户端类型:`openai`OpenAI格式`gemini`Gemini格式现在支持不良好 | `openai` |
| `client_type` | ❌ | 客户端类型:`openai``gemini``aiohttp_gemini` | `openai` |
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
| `timeout` | ❌ | API请求超时时间 | 30 |
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
**请注意,对于`client_type``gemini`的模型,`base_url`字段无效。**
### 2.3 支持的服务商示例
#### DeepSeek
@@ -73,9 +72,9 @@ client_type = "openai"
```toml
[[api_providers]]
name = "Google"
base_url = "https://api.google.com/v1"
base_url = "https://generativelanguage.googleapis.com/v1beta" # 在MoFox-Bot中, 使用aiohttp_gemini客户端的提供商可以自定义base_url
api_key = "your-google-api-key"
client_type = "gemini" # 注意Gemini需要使用特殊客户端
client_type = "aiohttp_gemini" # 注意Gemini需要使用特殊客户端
```
## 3. 模型配置
@@ -118,11 +117,11 @@ enable_thinking = false # 禁用思考
比如上面就是参考SiliconFlow的文档配置配置的`Qwen3`禁用思考参数。
![SiliconFlow文档截图](image-1.png)
![SiliconFlow文档截图](../assets/image-1.png)
以豆包文档为另一个例子
![豆包文档截图](image.png)
![豆包文档截图](../assets/image.png)
得到豆包`"doubao-seed-1-6-250615"`的禁用思考配置方法为
```toml
@@ -133,7 +132,6 @@ thinking = {type = "disabled"} # 禁用思考
```
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构具体内容取决于API服务商的要求。
**请注意,对于`client_type``gemini`的模型,此字段无效。**
### 3.3 配置参数说明
| 参数 | 必填 | 说明 |
@@ -145,6 +143,7 @@ thinking = {type = "disabled"} # 禁用思考
| `price_out` | ❌ | 输出价格(元/M token用于成本统计 |
| `force_stream_mode` | ❌ | 是否强制使用流式输出 |
| `extra_params` | ❌ | 额外的模型参数配置 |
| `anti_truncation` | ❌ | 是否启用反截断功能 |
## 4. 模型任务配置
@@ -184,7 +183,7 @@ max_tokens = 800
```
### planner - 决策模型
负责决定MaiBot该做什么
负责决定MoFox_Bot该做什么
```toml
[model_task_config.planner]
model_list = ["siliconflow-deepseek-v3"]
@@ -193,7 +192,7 @@ max_tokens = 800
```
### emotion - 情绪模型
负责MaiBot的情绪变化
负责MoFox_Bot的情绪变化
```toml
[model_task_config.emotion]
model_list = ["siliconflow-deepseek-v3"]
@@ -262,6 +261,44 @@ temperature = 0.7
max_tokens = 800
```
### schedule_generator - 日程生成模型
```toml
[model_task_config.schedule_generator]
model_list = ["deepseek-v3"]
temperature = 0.5
max_tokens = 1024
```
### monthly_plan_generator - 月度计划生成模型
```toml
[model_task_config.monthly_plan_generator]
model_list = ["deepseek-v3"]
temperature = 0.7
max_tokens = 1024
```
### emoji_vlm - 表情包VLM模型
```toml
[model_task_config.emoji_vlm]
model_list = ["qwen-vl-max"]
max_tokens = 800
```
### anti_injection - 反注入模型
```toml
[model_task_config.anti_injection]
model_list = ["deepseek-v3"]
temperature = 0.1
max_tokens = 512
```
### utils_video - 视频分析模型
```toml
[model_task_config.utils_video]
model_list = ["qwen-vl-max"]
max_tokens = 800
```
## 5. 配置建议
### 5.1 Temperature 参数选择
@@ -276,7 +313,7 @@ max_tokens = 800
| 任务类型 | 推荐模型类型 | 示例 |
|----------|--------------|------|
| 高精度任务 | 大模型 | DeepSeek-V3, GPT-4 |
| 高精度任务 | 大模型 | DeepSeek-V3, GPT-5,Gemini-2.5-Pro |
| 高频率任务 | 小模型 | Qwen3-8B |
| 多模态任务 | 专用模型 | Qwen2.5-VL, SenseVoice |
| 工具调用 | 支持Function Call的模型 | Qwen3-14B |
@@ -285,7 +322,6 @@ max_tokens = 800
1. **分层使用**:核心功能使用高质量模型,辅助功能使用经济模型
2. **合理配置max_tokens**:根据实际需求设置,避免浪费
3. **选择免费模型**对于测试环境优先使用price为0的模型
## 6. 配置验证

View File

@@ -1,89 +0,0 @@
# 💻 Command组件详解
## 📖 什么是Command
Command是直接响应用户明确指令的组件与Action不同Command是**被动触发**的,当用户输入特定格式的命令时立即执行。
Command通过正则表达式匹配用户输入提供确定性的功能服务。
### 🎯 Command的特点
- 🎯 **确定性执行**:匹配到命令立即执行,无随机性
-**即时响应**:用户主动触发,快速响应
- 🔍 **正则匹配**:通过正则表达式精确匹配用户输入
- 🛑 **拦截控制**:可以控制是否阻止消息继续处理
- 📝 **参数解析**:支持从用户输入中提取参数
---
## 🛠️ Command组件的基本结构
首先Command组件需要继承自`BaseCommand`类,并实现必要的方法。
```python
class ExampleCommand(BaseCommand):
command_name = "example" # 命令名称,作为唯一标识符
command_description = "这是一个示例命令" # 命令描述
command_pattern = r"" # 命令匹配的正则表达式
async def execute(self) -> Tuple[bool, Optional[str], bool]:
"""
执行Command的主要逻辑
Returns:
Tuple[bool, str, bool]:
- 第一个bool表示是否成功执行
- 第二个str是执行结果消息
- 第三个bool表示是否需要阻止消息继续处理
"""
# ---- 执行命令的逻辑 ----
return True, "执行成功", False
```
**`command_pattern`**: 该Command匹配的正则表达式用于精确匹配用户输入。
请注意:如果希望能获取到命令中的参数,请在正则表达式中使用有命名的捕获组,例如`(?P<param_name>pattern)`
这样在匹配时,内部实现可以使用`re.match.groupdict()`方法获取到所有捕获组的参数,并以字典的形式存储在`self.matched_groups`中。
### 匹配样例
假设我们有一个命令`/example param1=value1 param2=value2`,对应的正则表达式可以是:
```python
class ExampleCommand(BaseCommand):
command_name = "example"
command_description = "这是一个示例命令"
command_pattern = r"/example (?P<param1>\w+) (?P<param2>\w+)"
async def execute(self) -> Tuple[bool, Optional[str], bool]:
# 获取匹配的参数
param1 = self.matched_groups.get("param1")
param2 = self.matched_groups.get("param2")
# 执行逻辑
return True, f"参数1: {param1}, 参数2: {param2}", False
```
---
## Command 内置方法说明
```python
class BaseCommand:
def get_config(self, key: str, default=None):
"""获取插件配置值,使用嵌套键访问"""
async def send_text(self, content: str, reply_to: str = "") -> bool:
"""发送回复消息"""
async def send_type(self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = "") -> bool:
"""发送指定类型的回复消息到当前聊天环境"""
async def send_command(self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True) -> bool:
"""发送命令消息"""
async def send_emoji(self, emoji_base64: str) -> bool:
"""发送表情包"""
async def send_image(self, image_base64: str) -> bool:
"""发送图片"""
```
具体参数与用法参见`BaseCommand`基类的定义。

View File

@@ -9,7 +9,7 @@
## 组件功能详解
- [🧱 Action组件详解](action-components.md) - 掌握最核心的Action组件
- [💻 Command组件详解](command-components.md) - 学习直接响应命令的组件
- [💻 Command组件详解](PLUS_COMMAND_GUIDE.md) - 学习直接响应命令的组件
- [🔧 Tool组件详解](tool-components.md) - 了解如何扩展信息获取能力
- [⚙️ 配置文件系统指南](configuration-guide.md) - 学会使用自动生成的插件配置文件
- [📄 Manifest系统指南](manifest-guide.md) - 了解插件元数据管理和配置架构

View File

@@ -90,7 +90,7 @@ class HelloWorldPlugin(BasePlugin):
在日志中你应该能看到插件被加载的信息。虽然插件还没有任何功能,但它已经成功运行了!
![1750326700269](image/quick-start/1750326700269.png)
![1750326700269](../assets/1750326700269.png)
### 5. 添加第一个功能问候Action
@@ -180,7 +180,7 @@ MoFox_Bot可能会选择使用你的问候Action发送回复
嗨!很开心见到你!😊
```
![1750332508760](image/quick-start/1750332508760.png)
![1750332508760](../assets/1750332508760.png)
> **💡 小提示**MoFox_Bot会智能地决定什么时候使用它。如果没有立即看到效果多试几次不同的消息。

View File

@@ -1,124 +0,0 @@
# 自动化工具缓存系统使用指南
为了提升性能并减少不必要的重复计算或API调用MMC内置了一套强大且易于使用的自动化工具缓存系统。该系统同时支持传统的**精确缓存**和先进的**语义缓存**。工具开发者无需编写任何手动缓存逻辑,只需在工具类中设置几个属性,即可轻松启用和配置缓存行为。
## 核心概念
- **精确缓存 (KV Cache)**: 当一个工具被调用时,系统会根据工具名称和所有参数生成一个唯一的键。只有当**下一次调用的工具名和所有参数与之前完全一致**时,才会命中缓存。
- **语义缓存 (Vector Cache)**: 它不要求参数完全一致,而是理解参数的**语义和意图**。例如,`"查询深圳今天的天气"``"今天深圳天气怎么样"` 这两个不同的查询,在语义上是高度相似的。如果启用了语义缓存,第二个查询就能成功命中由第一个查询产生的缓存结果。
## 如何为你的工具启用缓存
为你的工具(必须继承自 `BaseTool`)启用缓存非常简单,只需在你的工具类定义中添加以下一个或多个属性即可:
### 1. `enable_cache: bool`
这是启用缓存的总开关。
- **类型**: `bool`
- **默认值**: `False`
- **作用**: 设置为 `True` 即可为该工具启用缓存功能。如果为 `False`,后续的所有缓存配置都将无效。
**示例**:
```python
class MyAwesomeTool(BaseTool):
# ... 其他定义 ...
enable_cache: bool = True
```
### 2. `cache_ttl: int`
设置缓存的生存时间Time-To-Live
- **类型**: `int`
- **单位**: 秒
- **默认值**: `3600` (1小时)
- **作用**: 定义缓存条目在被视为过期之前可以存活多长时间。
**示例**:
```python
class MyLongTermCacheTool(BaseTool):
# ... 其他定义 ...
enable_cache: bool = True
cache_ttl: int = 86400 # 缓存24小时
```
### 3. `semantic_cache_query_key: Optional[str]`
启用语义缓存的关键。
- **类型**: `Optional[str]`
- **默认值**: `None`
- **作用**:
- 将此属性的值设置为你工具的某个**参数的名称**(字符串)。
- 自动化缓存系统在工作时,会提取该参数的值,将其转换为向量,并进行语义相似度搜索。
- 如果该值为 `None`,则此工具**仅使用精确缓存**。
**示例**:
```python
class WebSurfingTool(BaseTool):
name: str = "web_search"
parameters = [
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
# ... 其他参数 ...
]
# --- 缓存配置 ---
enable_cache: bool = True
cache_ttl: int = 7200 # 缓存2小时
semantic_cache_query_key: str = "query" # <-- 关键!
```
在上面的例子中,`web_search` 工具的 `"query"` 参数值(例如,用户输入的搜索词)将被用于语义缓存搜索。
## 完整示例
假设我们有一个调用外部API来获取股票价格的工具。由于股价在短时间内相对稳定且查询意图可能相似如 "苹果股价" vs "AAPL股价"),因此非常适合使用缓存。
```python
# in your_plugin/tools/stock_checker.py
from src.plugin_system import BaseTool, ToolParamType
class StockCheckerTool(BaseTool):
"""
一个用于查询股票价格的工具。
"""
name: str = "get_stock_price"
description: str = "获取指定公司或股票代码的最新价格。"
available_for_llm: bool = True
parameters = [
("symbol", ToolParamType.STRING, "公司名称或股票代码 (e.g., 'AAPL', '苹果')", True, None),
]
# --- 缓存配置 ---
# 1. 开启缓存
enable_cache: bool = True
# 2. 股价信息缓存10分钟
cache_ttl: int = 600
# 3. 使用 "symbol" 参数进行语义搜索
semantic_cache_query_key: str = "symbol"
# --------------------
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
symbol = function_args.get("symbol")
# ... 这里是你调用外部API获取股票价格的逻辑 ...
# price = await some_stock_api.get_price(symbol)
price = 123.45 # 示例价格
return {
"type": "stock_price_result",
"content": f"{symbol} 的当前价格是 ${price}"
}
```
通过以上简单的三行配置,`StockCheckerTool` 现在就拥有了强大的自动化缓存能力:
- 当用户查询 `"苹果"` 时,工具会执行并缓存结果。
- 在接下来的10分钟内如果再次查询 `"苹果"`,将直接从精确缓存返回结果。
- 更智能的是,如果另一个用户查询 `"AAPL"`,语义缓存系统会识别出 `"AAPL"``"苹果"` 在语义上高度相关大概率也会直接返回缓存的结果而无需再次调用API。
---
现在你可以专注于实现工具的核心逻辑把缓存的复杂性交给MMC的自动化系统来处理。

View File

@@ -2,7 +2,7 @@
## 📖 什么是工具
工具是MoFox_Bot的信息获取能力扩展组件。如果说Action组件功能五花八门可以拓展麦麦能做的事情那么Tool就是在某个过程中拓宽了麦麦能够获得的信息量。
工具是MoFox_Bot的信息获取能力扩展组件。如果说Action组件功能五花八门可以拓展麦麦能做的事情那么Tool就是在某个过程中拓宽了MoFox_Bot能够获得的信息量。
### 🎯 工具的特点
@@ -191,8 +191,7 @@ class WeatherTool(BaseTool):
name = "weather_query" # 清晰表达功能
name = "knowledge_search" # 描述性强
name = "stock_price_check" # 功能明确
```
#### ❌ 避免的命名
```#### ❌ 避免的命名
```python
name = "tool1" # 无意义
name = "wq" # 过于简短
@@ -244,3 +243,130 @@ def _format_result(self, data):
def _format_result(self, data):
return str(data) # 直接返回原始数据
```
---
# 自动化工具缓存系统使用指南
为了提升性能并减少不必要的重复计算或API调用MMC内置了一套强大且易于使用的自动化工具缓存系统。该系统同时支持传统的**精确缓存**和先进的**语义缓存**。工具开发者无需编写任何手动缓存逻辑,只需在工具类中设置几个属性,即可轻松启用和配置缓存行为。
## 核心概念
- **精确缓存 (KV Cache)**: 当一个工具被调用时,系统会根据工具名称和所有参数生成一个唯一的键。只有当**下一次调用的工具名和所有参数与之前完全一致**时,才会命中缓存。
- **语义缓存 (Vector Cache)**: 它不要求参数完全一致,而是理解参数的**语义和意图**。例如,`"查询深圳今天的天气"``"今天深圳天气怎么样"` 这两个不同的查询,在语义上是高度相似的。如果启用了语义缓存,第二个查询就能成功命中由第一个查询产生的缓存结果。
## 如何为你的工具启用缓存
为你的工具(必须继承自 `BaseTool`)启用缓存非常简单,只需在你的工具类定义中添加以下一个或多个属性即可:
### 1. `enable_cache: bool`
这是启用缓存的总开关。
- **类型**: `bool`
- **默认值**: `False`
- **作用**: 设置为 `True` 即可为该工具启用缓存功能。如果为 `False`,后续的所有缓存配置都将无效。
**示例**:
```python
class MyAwesomeTool(BaseTool):
# ... 其他定义 ...
enable_cache: bool = True
```
### 2. `cache_ttl: int`
设置缓存的生存时间Time-To-Live
- **类型**: `int`
- **单位**: 秒
- **默认值**: `3600` (1小时)
- **作用**: 定义缓存条目在被视为过期之前可以存活多长时间。
**示例**:
```python
class MyLongTermCacheTool(BaseTool):
# ... 其他定义 ...
enable_cache: bool = True
cache_ttl: int = 86400 # 缓存24小时
```
### 3. `semantic_cache_query_key: Optional[str]`
启用语义缓存的关键。
- **类型**: `Optional[str]`
- **默认值**: `None`
- **作用**:
- 将此属性的值设置为你工具的某个**参数的名称**(字符串)。
- 自动化缓存系统在工作时,会提取该参数的值,将其转换为向量,并进行语义相似度搜索。
- 如果该值为 `None`,则此工具**仅使用精确缓存**。
**示例**:
```python
class WebSurfingTool(BaseTool):
name: str = "web_search"
parameters = [
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
# ... 其他参数 ...
]
# --- 缓存配置 ---
enable_cache: bool = True
cache_ttl: int = 7200 # 缓存2小时
semantic_cache_query_key: str = "query" # <-- 关键
```
在上面的例子中,`web_search` 工具的 `"query"` 参数值(例如,用户输入的搜索词)将被用于语义缓存搜索。
## 完整示例
假设我们有一个调用外部API来获取股票价格的工具。由于股价在短时间内相对稳定且查询意图可能相似如 "苹果股价" vs "AAPL股价"),因此非常适合使用缓存。
```python
# in your_plugin/tools/stock_checker.py
from src.plugin_system import BaseTool, ToolParamType
class StockCheckerTool(BaseTool):
"""
一个用于查询股票价格的工具。
"""
name: str = "get_stock_price"
description: str = "获取指定公司或股票代码的最新价格。"
available_for_llm: bool = True
parameters = [
("symbol", ToolParamType.STRING, "公司名称或股票代码 (e.g., 'AAPL', '苹果')", True, None),
]
# --- 缓存配置 ---
# 1. 开启缓存
enable_cache: bool = True
# 2. 股价信息缓存10分钟
cache_ttl: int = 600
# 3. 使用 "symbol" 参数进行语义搜索
semantic_cache_query_key: str = "symbol"
# --------------------
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
symbol = function_args.get("symbol")
# ... 这里是你调用外部API获取股票价格的逻辑 ...
# price = await some_stock_api.get_price(symbol)
price = 123.45 # 示例价格
return {
"type": "stock_price_result",
"content": f"{symbol} 的当前价格是 ${price}"
}
```
通过以上简单的三行配置,`StockCheckerTool` 现在就拥有了强大的自动化缓存能力:
- 当用户查询 `"苹果"` 时,工具会执行并缓存结果。
- 在接下来的10分钟内如果再次查询 `"苹果"`,将直接从精确缓存返回结果。
- 更智能的是,如果另一个用户查询 `"AAPL"`,语义缓存系统会识别出 `"AAPL"``"苹果"` 在语义上高度相关大概率也会直接返回缓存的结果而无需再次调用API。
---
现在你可以专注于实现工具的核心逻辑把缓存的复杂性交给MMC的自动化系统来处理。

View File

@@ -1,121 +0,0 @@
# “月层计划”系统架构设计文档
## 1. 系统概述与目标
本系统旨在为MoFox_Bot引入一个动态的、由大型语言模型LLM驱动的“月层计划”机制。其核心目标是取代静态、预设的任务模板转而利用LLM在程序启动时自动生成符合Bot人设的、具有时效性的月度计划。这些计划将被存储、管理并在构建每日日程时被动态抽取和使用从而极大地丰富日程内容的个性和多样性。
---
## 2. 核心设计原则
- **动态性与智能化:** 所有计划内容均由LLM实时生成确保其独特性和创造性。
- **人设一致性:** 计划的生成将严格围绕Bot的核心人设进行强化角色形象。
- **持久化与可管理:** 生成的计划将被存入专用数据库表,便于管理和追溯。
- **消耗性与随机性:** 计划在使用后有一定几率被消耗(删除),模拟真实世界中计划的完成与迭代。
---
## 3. 系统核心流程规划
本系统包含两大核心流程:**启动时的计划生成流程**和**日程构建时的计划使用流程**。
### 3.1 流程一:启动时计划生成
此流程在每次程序启动时触发,负责填充当月的计划池。
```mermaid
graph TD
A[程序启动] --> B{检查当月计划池};
B -- 计划数量低于阈值 --> C[构建LLM Prompt];
C -- prompt包含Bot人设、月份等信息 --> D[调用LLM服务];
D -- LLM返回多个计划文本 --> E[解析并格式化计划];
E -- 逐条处理 --> F[存入`monthly_plans`数据库表];
F --> G[完成启动任务];
B -- 计划数量充足 --> G;
```
### 3.2 流程二:日程构建时计划使用
此流程在构建每日日程的提示词Prompt时触发。
```mermaid
graph TD
H[构建日程Prompt] --> I{查询数据库};
I -- 读取当月未使用的计划 --> J[随机抽取N个计划];
J --> K[将计划文本嵌入日程Prompt];
K --> L{随机数判断};
L -- 概率命中 --> M[将已抽取的计划标记为删除];
M --> N[完成Prompt构建];
L -- 概率未命中 --> N;
```
---
## 4. 数据库模型设计
为支撑本系统,需要新增一个数据库表。
**表名:** `monthly_plans`
| 字段名 | 类型 | 描述 |
| :--- | :--- | :--- |
| `id` | Integer | 主键,自增。 |
| `plan_text` | Text | 由LLM生成的计划内容原文。 |
| `target_month` | String(7) | 计划所属的月份,格式为 "YYYY-MM"。 |
| `is_deleted` | Boolean | 软删除标记,默认为 `false`。 |
| `created_at` | DateTime | 记录创建时间。 |
---
## 5. 详细模块规划
### 5.1 LLM Prompt生成模块
- **职责:** 构建高质量的Prompt以引导LLM生成符合要求的计划。
- **输入:** Bot人设描述、当前月份、期望生成的计划数量。
- **输出:** 一个结构化的Prompt字符串。
- **Prompt示例:**
```
你是一个[此处填入Bot人设描述例如活泼开朗、偶尔有些小迷糊的虚拟助手]。
请为即将到来的[YYYY年MM月]设计[N]个符合你身份的月度计划或目标。
要求:
1. 每个计划都是独立的、积极向上的。
2. 语言风格要自然、口语化,符合你的性格。
3. 每个计划用一句话或两句话简短描述。
4. 以JSON格式返回格式为{"plans": ["计划一", "计划二", ...]}
```
### 5.2 数据库交互模块
- **职责:** 提供对 `monthly_plans` 表的增、删、改、查接口。
- **规划函数列表:**
- `add_new_plans(plans: list[str], month: str)`: 批量添加新生成的计划。
- `get_active_plans_for_month(month: str) -> list`: 获取指定月份所有未被删除的计划。
- `soft_delete_plans(plan_ids: list[int])`: 将指定ID的计划标记为软删除。
### 5.3 配置项规划
需要在主配置文件 `config/bot_config.toml` 中添加以下配置项,以控制系统行为。
```toml
# ----------------------------------------------------------------
# 月层计划系统设置 (Monthly Plan System Settings)
# ----------------------------------------------------------------
[monthly_plan_system]
# 是否启用本功能
enable = true
# 启动时如果当月计划少于此数量则触发LLM生成
generation_threshold = 10
# 每次调用LLM期望生成的计划数量
plans_per_generation = 5
# 计划被使用后,被删除的概率 (0.0 到 1.0)
deletion_probability_on_use = 0.5
```
---
**文档结束。** 本文档纯粹为架构规划,旨在提供清晰的设计思路和开发指引,不包含任何实现代码。

View File

@@ -1,53 +0,0 @@
{
"manifest_version": 1,
"name": "Hello World 示例插件 (Hello World Plugin)",
"version": "1.0.0",
"description": "我的第一个MaiCore插件包含问候功能和时间查询等基础示例",
"author": {
"name": "MaiBot开发团队",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.8.0"
},
"homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot",
"keywords": ["demo", "example", "hello", "greeting", "tutorial"],
"categories": ["Examples", "Tutorial"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"plugin_type": "example",
"components": [
{
"type": "action",
"name": "hello_greeting",
"description": "向用户发送问候消息"
},
{
"type": "action",
"name": "bye_greeting",
"description": "向用户发送告别消息",
"activation_modes": ["keyword"],
"keywords": ["再见", "bye", "88", "拜拜"]
},
{
"type": "command",
"name": "time",
"description": "查询当前时间",
"pattern": "/time"
}
],
"features": [
"问候和告别功能",
"时间查询命令",
"配置文件示例",
"新手教程代码"
]
}
}

View File

@@ -1,282 +0,0 @@
from typing import List, Tuple, Type, Any, Optional
from src.plugin_system import (
BasePlugin,
register_plugin,
BaseAction,
BaseCommand,
BaseTool,
ComponentInfo,
ActionActivationType,
ConfigField,
ToolParamType
)
from src.plugin_system.apis import send_api
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ChatType
logger = get_logger(__name__)
class GetGroupListCommand(BaseCommand):
    """Slash command `/get_groups`.

    NOTE(review): despite the name and description ("get the bot's group
    list"), the implementation actually sends a `get_cookies` adapter
    command for the qzone domain and echoes the raw response back to the
    chat — looks like leftover debug code; confirm the intended action.
    """

    # Unique command identifier.
    command_name = "get_groups"
    command_description = "获取机器人加入的群列表"
    # Exact-match pattern: only the bare "/get_groups" message triggers it.
    command_pattern = r"^/get_groups$"
    command_help = "获取机器人加入的群列表"
    command_examples = ["/get_groups"]
    # Stop further processing of the triggering message.
    intercept_message = True

    async def execute(self) -> Tuple[bool, str, bool]:
        """Run the command.

        Returns:
            Tuple[bool, str, bool]: (success, result message, intercept flag).
        """
        try:
            # Call the adapter command API.
            # NOTE(review): action is "get_cookies", not a group-list query.
            domain = "user.qzone.qq.com"
            response = await send_api.adapter_command_to_stream(
                action="get_cookies",
                platform="qq",
                params={"domain": domain},
                timeout=40.0,
                storage_message=False
            )
            # Echo the raw adapter response into the current chat.
            text = str(response)
            await self.send_text(text)
            return True, "获取群列表成功", True
        except Exception as e:
            # Best-effort error report to the chat; still intercept the message.
            await self.send_text(f"获取群列表失败: {str(e)}")
            return False, "获取群列表失败", True
class CompareNumbersTool(BaseTool):
    """LLM-callable tool that compares two numbers and reports the result."""

    name = "compare_numbers"
    description = "使用工具 比较两个数的大小,返回较大的数"
    parameters = [
        ("num1", ToolParamType.FLOAT, "第一个数字", True, None),
        ("num2", ToolParamType.FLOAT, "第二个数字", True, None),
    ]

    async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
        """Compare the two numeric arguments.

        Args:
            function_args: tool arguments; expected keys "num1" and "num2".

        Returns:
            dict: {"name": ..., "content": ...} with a textual verdict, or an
            error message if the comparison raised (e.g. a missing argument).
        """
        first: int | float = function_args.get("num1")  # type: ignore
        second: int | float = function_args.get("num2")  # type: ignore
        try:
            # Branches reordered (equality first); outcome text is unchanged.
            if first == second:
                verdict = f"{first} 等于 {second}"
            elif first > second:
                verdict = f"{first} 大于 {second}"
            else:
                verdict = f"{first} 小于 {second}"
            return {"name": self.name, "content": verdict}
        except Exception as e:
            return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"}
# ===== Action组件 =====
class HelloAction(BaseAction):
    """Greeting action: sends a friendly hello message."""

    # === Basic info (required) ===
    action_name = "hello_greeting"
    action_description = "向用户发送问候消息"
    activation_type = ActionActivationType.ALWAYS  # always eligible for selection

    # === Behavior description (required) ===
    action_parameters = {"greeting_message": "要发送的问候消息"}
    action_require = ["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"]
    associated_types = ["text"]

    async def execute(self) -> Tuple[bool, str]:
        """Send the greeting: configured base text plus the planner-supplied suffix."""
        suffix = self.action_data.get("greeting_message", "")
        prefix = self.get_config("greeting.message", "嗨!很开心见到你!😊")
        await self.send_text(prefix + suffix)
        return True, "发送了问候消息"
class ByeAction(BaseAction):
    """Farewell action: fires only when the user says goodbye."""

    action_name = "bye_greeting"
    action_description = "向用户发送告别消息"

    # Keyword-triggered activation.
    activation_type = ActionActivationType.KEYWORD
    activation_keywords = ["再见", "bye", "88", "拜拜"]
    keyword_case_sensitive = False

    action_parameters = {"bye_message": "要发送的告别消息"}
    action_require = [
        "用户要告别时使用",
        "当有人要离开时使用",
        "当有人和你说再见时使用",
    ]
    associated_types = ["text"]

    async def execute(self) -> Tuple[bool, str]:
        """Send the farewell text with the optional planner-supplied suffix."""
        extra = self.action_data.get("bye_message", "")
        await self.send_text(f"再见!期待下次聊天!👋{extra}")
        return True, "发送了告别消息"
class TimeCommand(BaseCommand):
    """Time query command: responds to `/time` with the current time."""

    command_name = "time"
    command_description = "获取当前时间"

    # === Command settings (required) ===
    command_pattern = r"^/time$"  # exact match for "/time"
    chat_type_allow = ChatType.GROUP  # group chats only

    async def execute(self) -> Tuple[bool, str, bool]:
        """Format the current local time and send it to the chat.

        Returns:
            Tuple[bool, str, bool]: (success, result message, intercept flag).
        """
        import datetime

        # Display format comes from plugin config, with a sensible default.
        time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S")  # type: ignore
        now = datetime.datetime.now()
        time_str = now.strftime(time_format)

        # Send the time back to the chat.
        message = f"⏰ 当前时间:{time_str}"
        await self.send_text(message)
        # Fix: result message had a stray "x" ("显示了当前x时间").
        return True, f"显示了当前时间: {time_str}", True
# class PrintMessage(BaseEventHandler):
# """打印消息事件处理器 - 处理打印消息事件"""
#
# event_type = EventType.ON_MESSAGE
# handler_name = "print_message_handler"
# handler_description = "打印接收到的消息"
#
# async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]:
# """执行打印消息事件处理"""
# # 打印接收到的消息
#
# if self.get_config("print_message.enabled", False):
# print(f"接收到消息: {message.raw_message}")
# return True, True, "消息已打印1"
# ===== 插件注册 =====
@register_plugin
class HelloWorldPlugin(BasePlugin):
    """Hello World plugin — the tutorial "first plugin" bundling the example components."""

    # Plugin basic info
    plugin_name: str = "hello_world_plugin"  # internal identifier
    enable_plugin: bool = True
    dependencies: List[str] = []  # other plugins this one depends on
    python_dependencies: List[str] = []  # required Python packages
    config_file_name: str = "config.toml"  # generated config file name

    # Human-readable descriptions for each config section.
    config_section_descriptions = {"plugin": "插件基本信息", "greeting": "问候功能配置", "time": "时间查询配置"}

    # Config schema: the plugin config file is auto-generated from this.
    # NOTE(review): "greeting.message" defaults to a *list*, but HelloAction
    # concatenates the config value with a string — confirm which type the
    # config system actually hands back.
    config_schema: dict = {
        "plugin": {
            "name": ConfigField(type=str, default="hello_world_plugin", description="插件名称"),
            "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
            "enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
        },
        "greeting": {
            "message": ConfigField(
                type=list, default=["嗨!很开心见到你!😊", "Ciallo(∠・ω< )⌒★"], description="默认问候消息"
            ),
            "enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
        },
        "time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
        "print_message": {"enabled": ConfigField(type=bool, default=True, description="是否启用打印")},
    }

    def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
        """Return every (component info, component class) pair this plugin registers.

        PrivateInfoCommand and GroupOnlyAction are defined later in the file;
        that is fine because this method is only called after module import.
        """
        return [
            (HelloAction.get_action_info(), HelloAction),
            (CompareNumbersTool.get_tool_info(), CompareNumbersTool),  # number-comparison tool
            (ByeAction.get_action_info(), ByeAction),  # farewell action
            (TimeCommand.get_command_info(), TimeCommand),  # group-chat-only command
            (GetGroupListCommand.get_command_info(), GetGroupListCommand),  # adapter-command demo
            (PrivateInfoCommand.get_command_info(), PrivateInfoCommand),  # private-chat-only command
            (GroupOnlyAction.get_action_info(), GroupOnlyAction),  # group-chat-only action
            # (PrintMessage.get_handler_info(), PrintMessage),
        ]
# @register_plugin
# class HelloWorldEventPlugin(BaseEPlugin):
# """Hello World事件插件 - 处理问候和告别事件"""
# plugin_name = "hello_world_event_plugin"
# enable_plugin = False
# dependencies = []
# python_dependencies = []
# config_file_name = "event_config.toml"
# config_schema = {
# "plugin": {
# "name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"),
# "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
# "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
# },
# }
# def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
# return [(PrintMessage.get_handler_info(), PrintMessage)]
# 添加一个新的私聊专用命令
class PrivateInfoCommand(BaseCommand):
    """Private-chat-only demo command: responds to `/私聊信息`."""

    command_name = "private_info"
    command_description = "获取私聊信息"
    command_pattern = r"^/私聊信息$"
    chat_type_allow = ChatType.PRIVATE  # available in private chats only

    async def execute(self) -> Tuple[bool, Optional[str], bool]:
        """Send the private-chat notice; report success or the failure reason."""
        try:
            await self.send_text("这是一个只能在私聊中使用的命令!")
        except Exception as e:
            logger.error(f"私聊信息命令执行失败: {e}")
            return False, f"命令执行失败: {e}", False
        return True, "私聊信息命令执行成功", False
# 添加一个新的仅群聊可用的Action
class GroupOnlyAction(BaseAction):
    """Group-chat-only demo action."""

    action_name = "group_only_test"
    action_description = "群聊专用测试动作"
    chat_type_allow = ChatType.GROUP  # available in group chats only

    async def execute(self) -> Tuple[bool, str]:
        """Send the group-chat notice; report success or the failure reason."""
        try:
            await self.send_text("这是一个只能在群聊中执行的动作!")
        except Exception as e:
            logger.error(f"群聊专用动作执行失败: {e}")
            return False, f"动作执行失败: {e}"
        return True, "群聊专用动作执行成功"

279
plugins/napcat_adapter_plugin/.gitignore vendored Normal file
View File

@@ -0,0 +1,279 @@
log/
logs/
out/
.env
.env.*
.cursor
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
uv.lock
llm_statistics.txt
mongodb
napcat
run_dev.bat
elua.confirmed
# C extensions
*.so
/results
config_backup/
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# PyPI configuration file
.pypirc
# jieba
jieba.cache
# .vscode
!.vscode/settings.json
# direnv
/.direnv
# JetBrains
.idea
*.iml
*.ipr
# PyEnv
# If using PyEnv and configured to use a specific Python version locally
# a .local-version file will be created in the root of the project to specify the version.
.python-version
OtherRes.txt
/eula.confirmed
/privacy.confirmed
logs
.ruff_cache
.vscode
/config/*
config/old/bot_config_20250405_212257.toml
temp/
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk
config.toml
feature.toml
config.toml.back
test
data/NapcatAdapter.db
data/NapcatAdapter.db-shm
data/NapcatAdapter.db-wal

View File

@@ -0,0 +1 @@
PLUGIN_NAME = "napcat_adapter"

View File

@@ -0,0 +1,42 @@
{
"manifest_version": 1,
"name": "napcat_plugin",
"version": "1.0.0",
"description": "基于OneBot 11协议的NapCat QQ协议插件提供完整的QQ机器人API接口使用现有adapter连接",
"author": {
"name": "Windpicker_owo",
"url": "https://github.com/Windpicker-owo"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.10.0",
"max_version": "0.10.0"
},
"homepage_url": "https://github.com/Windpicker-owo/InternetSearchPlugin",
"repository_url": "https://github.com/Windpicker-owo/InternetSearchPlugin",
"keywords": ["qq", "bot", "napcat", "onebot", "api", "websocket"],
"categories": ["protocol"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"components": [
{
"type": "tool",
"name": "napcat_tool",
"description": "NapCat QQ协议综合工具提供消息发送、群管理、好友管理、文件操作等完整功能"
}
],
"features": [
"消息发送与接收",
"群管理功能",
"好友管理功能",
"文件上传下载",
"AI语音功能",
"群签到与戳一戳",
"现有adapter连接"
]
}
}

View File

@@ -0,0 +1,6 @@
from typing import List, Tuple
from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType, ConfigField, BaseAction, ActionActivationType
from src.plugin_system.base.base_event import HandlerResult
from src.plugin_system.core.event_manager import event_manager

View File

@@ -0,0 +1,69 @@
from enum import Enum
class NapcatEvent(Enum):
# napcat插件事件枚举类
    class ON_RECEIVED(Enum):
        """
        Inbound message events; only napcat_plugin may trigger these.
        """
        TEXT = "napcat_on_received_text"  # plain text message received
        FACE = "napcat_on_received_face"  # face/emoji message received
        REPLY = "napcat_on_received_reply"  # reply (quote) message received
        IMAGE = "napcat_on_received_image"  # image message received
        RECORD = "napcat_on_received_record"  # voice message received
        VIDEO = "napcat_on_received_video"  # video message received
        AT = "napcat_on_received_at"  # @-mention message received
        DICE = "napcat_on_received_dice"  # dice message received
        SHAKE = "napcat_on_received_shake"  # screen-shake message received
        JSON = "napcat_on_received_json"  # JSON card message received
        RPS = "napcat_on_received_rps"  # rock-paper-scissors message received
        FRIEND_INPUT = "napcat_on_friend_input"  # a friend is typing
    class ACCOUNT(Enum):
        """
        Account operations; triggered externally, handled by napcat_plugin.
        """
        SET_PROFILE = "napcat_set_qq_profile"  # set account profile info
        GET_ONLINE_CLIENTS = "napcat_get_online_clients"  # list this account's online clients
        SET_ONLINE_STATUS = "napcat_set_online_status"  # set online status
        GET_FRIENDS_WITH_CATEGORY = "napcat_get_friends_with_category"  # friend list grouped by category
        SET_AVATAR = "napcat_set_qq_avatar"  # set avatar
        SEND_LIKE = "napcat_send_like"  # send profile likes
        SET_FRIEND_ADD_REQUEST = "napcat_set_friend_add_request"  # accept/reject a friend request
        SET_SELF_LONGNICK = "napcat_set_self_longnick"  # set personal signature
        GET_LOGIN_INFO = "napcat_get_login_info"  # get logged-in account info
        GET_RECENT_CONTACT = "napcat_get_recent_contact"  # recent conversation list
        GET_STRANGER_INFO = "napcat_get_stranger_info"  # info for a given account
        GET_FRIEND_LIST = "napcat_get_friend_list"  # friend list
        GET_PROFILE_LIKE = "napcat_get_profile_like"  # profile-like list
        DELETE_FRIEND = "napcat_delete_friend"  # delete a friend
        GET_USER_STATUS = "napcat_get_user_status"  # a user's status
        GET_STATUS = "napcat_get_status"  # adapter/account status
        GET_MINI_APP_ARK = "napcat_get_mini_app_ark"  # mini-program ark card
        SET_DIY_ONLINE_STATUS = "napcat_set_diy_online_status"  # set a custom online status
    class MESSAGE(Enum):
        """
        Message operations; triggered externally, handled by napcat_plugin.
        """
        SEND_GROUP_POKE = "napcat_send_group_poke"  # send a group poke
        SEND_PRIVATE_MSG = "napcat_send_private_msg"  # send a private message
        SEND_POKE = "napcat_send_friend_poke"  # send a friend poke
        DELETE_MSG = "napcat_delete_msg"  # recall a message
        GET_GROUP_MSG_HISTORY = "napcat_get_group_msg_history"  # fetch group message history
        GET_MSG = "napcat_get_msg"  # fetch message details
        GET_FORWARD_MSG = "napcat_get_forward_msg"  # fetch a merged-forward message
        SET_MSG_EMOJI_LIKE = "napcat_set_msg_emoji_like"  # react to a message with an emoji
        GET_FRIEND_MSG_HISTORY = "napcat_get_friend_msg_history"  # fetch friend message history
        FETCH_EMOJI_LIKE = "napcat_fetch_emoji_like"  # fetch emoji-reaction details
        # NOTE(review): member name typo — should be SEND_FORWARD_MSG (the value is correct);
        # renaming would break external subscribers, so it is only flagged here.
        SEND_FORWARF_MSG = "napcat_send_forward_msg"  # send a merged-forward message
        # NOTE(review): member name typo — should be GET_RECORD (the value is correct).
        GET_RECOED = "napcat_get_record"  # fetch voice-message details
        SEND_GROUP_AI_RECORD = "napcat_send_group_ai_record"  # send group AI voice
class GROUP(Enum):
"""
该分类是对群聊相关的操作只能由外部触发napcat_plugin负责处理
"""

View File

@@ -0,0 +1,131 @@
import sys
import asyncio
import json
import websockets as Server
from . import event_types,CONSTS
from typing import List, Tuple
from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType, ConfigField, BaseAction, ActionActivationType
from src.plugin_system.base.base_event import HandlerResult
from src.plugin_system.core.event_manager import event_manager
from pathlib import Path
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
# 添加当前目录到Python路径这样可以识别src包
current_dir = Path(__file__).parent
sys.path.insert(0, str(current_dir))
from .src.recv_handler.message_handler import message_handler
from .src.recv_handler.meta_event_handler import meta_event_handler
from .src.recv_handler.notice_handler import notice_handler
from .src.recv_handler.message_sending import message_send_instance
from .src.send_handler import send_handler
from .src.config import global_config
from .src.config.features_config import features_manager
from .src.config.migrate_features import auto_migrate_features
from .src.mmc_com_layer import mmc_start_com, mmc_stop_com, router
from .src.response_pool import put_response, check_timeout_response
from .src.websocket_manager import websocket_manager
message_queue = asyncio.Queue()
class LauchNapcatAdapterHandler(BaseEventHandler):
    """Event handler that boots the Napcat adapter when the host emits ON_START.

    NOTE(review): the class name is misspelled ("Lauch" -> "Launch"); renaming it
    would also require updating NapcatAdapterPlugin.get_plugin_components below.
    """
    handler_name: str = "launch_napcat_adapter_handler"
    handler_description: str = "自动启动napcat adapter"
    weight: int = 100
    intercept_message: bool = False
    # Subscribe so execute() fires once when the host application starts.
    init_subscribe = [EventType.ON_START]
    async def message_recv(self, server_connection: Server.ServerConnection):
        """Read raw OneBot frames from the websocket and route them.

        Frames carrying a known post_type are queued for message_process();
        frames without a post_type are API responses and go to the response pool.
        """
        await message_handler.set_server_connection(server_connection)
        # Notice-handler setup runs concurrently so it cannot block this loop.
        asyncio.create_task(notice_handler.set_server_connection(server_connection))
        await send_handler.set_server_connection(server_connection)
        async for raw_message in server_connection:
            # Truncate very long frames so debug logs stay readable.
            logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
            decoded_raw_message: dict = json.loads(raw_message)
            post_type = decoded_raw_message.get("post_type")
            if post_type in ["meta_event", "message", "notice"]:
                await message_queue.put(decoded_raw_message)
            elif post_type is None:
                # No post_type: this is a response to an API call we issued.
                await put_response(decoded_raw_message)
    async def message_process(self):
        """Consume the shared queue forever and dispatch by post_type."""
        while True:
            message = await message_queue.get()
            post_type = message.get("post_type")
            if post_type == "message":
                await message_handler.handle_raw_message(message)
            elif post_type == "meta_event":
                await meta_event_handler.handle_meta_event(message)
            elif post_type == "notice":
                await notice_handler.handle_notice(message)
            else:
                logger.warning(f"未知的post_type: {post_type}")
            message_queue.task_done()
            # Brief pause to yield control between messages.
            await asyncio.sleep(0.05)
    async def napcat_server(self):
        """Start the Napcat WebSocket connection (supports forward and reverse modes)."""
        mode = global_config.napcat_server.mode
        logger.info(f"正在启动 adapter连接模式: {mode}")
        try:
            await websocket_manager.start_connection(self.message_recv)
        except Exception as e:
            logger.error(f"启动 WebSocket 连接失败: {e}")
            raise
    async def execute(self, kwargs):
        """ON_START entry point: migrate/load feature config, then launch the
        adapter's long-running tasks.

        NOTE(review): ``kwargs`` is a positional parameter (not ``**kwargs``) —
        confirm this matches the BaseEventHandler.execute signature.
        """
        # Run the feature-config migration if an old layout is present.
        logger.info("检查功能配置迁移...")
        auto_migrate_features()
        # Initialize the hot-reloadable feature manager.
        logger.info("正在初始化功能管理器...")
        features_manager.load_config()
        await features_manager.start_file_watcher(check_interval=2.0)
        logger.info("功能管理器初始化完成")
        logger.info("开始启动Napcat Adapter")
        message_send_instance.maibot_router = router
        # Each long-running loop gets its own task so execute() returns without
        # blocking the host's event dispatch.
        asyncio.create_task(self.napcat_server())
        asyncio.create_task(mmc_start_com())
        asyncio.create_task(self.message_process())
        asyncio.create_task(check_timeout_response())
@register_plugin
class NapcatAdapterPlugin(BasePlugin):
    """Plugin wrapper: registers the Napcat receive events and exposes the
    launch handler as this plugin's only component."""

    plugin_name = CONSTS.PLUGIN_NAME
    enable_plugin: bool = True
    dependencies: List[str] = []  # other plugins this one depends on
    python_dependencies: List[str] = []  # required Python packages
    config_file_name: str = "config.toml"  # name of the generated config file

    # Descriptions for each section of the generated config file.
    config_section_descriptions = {"plugin": "插件基本信息"}

    # Schema describing the config file's fields and defaults.
    config_schema: dict = {
        "plugin": {
            "name": ConfigField(type=str, default="napcat_adapter_plugin", description="插件名称"),
            "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
            "enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
        }
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Only this plugin is allowed to trigger the "received" events it defines.
        for received_event in event_types.NapcatEvent.ON_RECEIVED:
            event_manager.register_event(received_event, allowed_triggers=[self.plugin_name])

    def get_plugin_components(self):
        """Return (component_info, component_class) pairs provided by this plugin."""
        return [(LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler)]

View File

@@ -0,0 +1,47 @@
[project]
name = "MaiBotNapcatAdapter"
version = "0.4.8"
description = "A MaiBot adapter for Napcat"
dependencies = [
"ruff>=0.12.9",
]
[tool.ruff]
include = ["*.py"]
# 行长度设置
line-length = 120
[tool.ruff.lint]
fixable = ["ALL"]
unfixable = []
# 启用的规则
select = [
"E", # pycodestyle 错误
"F", # pyflakes
"B", # flake8-bugbear
]
ignore = ["E711","E501"]
[tool.ruff.format]
docstring-code-format = true
indent-style = "space"
# 使用双引号表示字符串
quote-style = "double"
# 尊重魔法尾随逗号
# 例如:
# items = [
# "apple",
# "banana",
# "cherry",
# ]
skip-magic-trailing-comma = false
# 自动检测合适的换行符
line-ending = "auto"

View File

@@ -0,0 +1,31 @@
from enum import Enum
import tomlkit
import os
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
class CommandType(Enum):
    """OneBot action names used by the adapter's send handler."""
    GROUP_BAN = "set_group_ban"  # mute a group member
    GROUP_WHOLE_BAN = "set_group_whole_ban"  # mute the entire group
    GROUP_KICK = "set_group_kick"  # kick a member from the group
    SEND_POKE = "send_poke"  # poke a user
    DELETE_MSG = "delete_msg"  # recall a message
    AI_VOICE_SEND = "send_group_ai_record"  # send group AI voice
    SET_EMOJI_LIKE = "set_emoji_like"  # react to a message with an emoji
    SEND_AT_MESSAGE = "send_at_message"  # @ a user and send a message
    SEND_LIKE = "send_like"  # send profile "likes"
    def __str__(self) -> str:
        # The value is the raw OneBot action name.
        return self.value
# Resolve the adapter version from the package's pyproject.toml so logs always
# report the shipped version.
pyproject_path = os.path.join(
    os.path.dirname(os.path.dirname(__file__)), "pyproject.toml"
)
# Use a context manager so the file handle is closed promptly (the original
# open(...).read() leaked the handle until garbage collection).
with open(pyproject_path, "r", encoding="utf-8") as _pyproject_file:
    toml_data = tomlkit.parse(_pyproject_file.read())
project_data = toml_data.get("project", {})
version = project_data.get("version", "unknown")
logger.info(f"版本\n\nMaiBot-Napcat-Adapter 版本: {version}\n喜欢的话点个star喵~\n")

View File

@@ -0,0 +1,5 @@
from .config import global_config
__all__ = [
"global_config",
]

View File

@@ -0,0 +1,148 @@
import os
from dataclasses import dataclass
from datetime import datetime
import tomlkit
import shutil
from tomlkit import TOMLDocument
from tomlkit.items import Table
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from rich.traceback import install
from .config_base import ConfigBase
from .official_configs import (
DebugConfig,
MaiBotServerConfig,
NapcatServerConfig,
NicknameConfig,
VoiceConfig,
)
install(extra_lines=3)
TEMPLATE_DIR = "plugins/napcat_adapter_plugin/template"
CONFIG_DIR = "plugins/napcat_adapter_plugin/config"
OLD_CONFIG_DIR = "plugins/napcat_adapter_plugin/config/old"
def ensure_config_directories():
    """Create the config directory and its backup subdirectory if missing."""
    for directory in (CONFIG_DIR, OLD_CONFIG_DIR):
        os.makedirs(directory, exist_ok=True)
def update_config():
    """Synchronize config.toml with the bundled template.

    Creates the config from the template when missing, skips the update when
    the ``[inner].version`` values match, and otherwise backs the old file up
    into config/old and merges the old values into a fresh template copy.
    """
    # Make sure config/ and config/old/ exist before touching any files.
    ensure_config_directories()
    # Paths of the template and the live configuration file.
    template_path = f"{TEMPLATE_DIR}/template_config.toml"
    config_path = f"{CONFIG_DIR}/config.toml"
    # First-run case: copy the template into place.
    if not os.path.exists(config_path):
        logger.info("主配置文件不存在,从模板创建新配置")
        shutil.copy2(template_path, config_path)
        logger.info(f"已创建新配置文件: {config_path}")
        # NOTE(review): the log claims the program will exit, but no exit
        # happens here — execution falls through and re-reads the freshly
        # copied file (whose version then matches the template's).
        logger.info("程序将退出,请检查配置文件后重启")
    # Load both documents so versions can be compared and values merged.
    with open(config_path, "r", encoding="utf-8") as f:
        old_config = tomlkit.load(f)
    with open(template_path, "r", encoding="utf-8") as f:
        new_config = tomlkit.load(f)
    # Identical [inner].version on both sides means there is nothing to migrate.
    if old_config and "inner" in old_config and "inner" in new_config:
        old_version = old_config["inner"].get("version")
        new_version = new_config["inner"].get("version")
        if old_version and new_version and old_version == new_version:
            logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
            return
        else:
            logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
    else:
        logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
    # Back up the current file into config/old with a timestamp suffix.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = os.path.join(OLD_CONFIG_DIR, f"config.toml.bak.{timestamp}")
    shutil.copy2(config_path, backup_path)
    logger.info(f"已备份旧配置文件到: {backup_path}")
    # Start from a clean template copy; old values are merged back in below.
    shutil.copy2(template_path, config_path)
    logger.info(f"已创建新配置文件: {config_path}")
    def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
        """Recursively copy values from *source* into *target* for keys that
        already exist in *target*; ``version`` is deliberately left untouched."""
        for key, value in source.items():
            # Never overwrite the template's version marker.
            if key == "version":
                continue
            if key in target:
                if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
                    update_dict(target[key], value)
                else:
                    try:
                        # Arrays need special handling so empty arrays stay empty.
                        if isinstance(value, list):
                            target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
                        else:
                            # Other values go through tomlkit.item for proper TOML typing.
                            target[key] = tomlkit.item(value)
                    except (TypeError, ValueError):
                        # Fall back to raw assignment when conversion fails.
                        target[key] = value
    # Merge old values into the new document.
    logger.info("开始合并新旧配置...")
    update_dict(new_config, old_config)
    # Persist the merged document (tomlkit preserves comments and formatting).
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(new_config))
    logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
@dataclass
class Config(ConfigBase):
    """Top-level adapter configuration; one field per config-file section."""
    nickname: NicknameConfig
    napcat_server: NapcatServerConfig
    maibot_server: MaiBotServerConfig
    voice: VoiceConfig
    debug: DebugConfig
def load_config(config_path: str) -> Config:
    """Parse the TOML file at *config_path* into a Config object.

    :param config_path: path to the configuration file
    :return: the populated Config instance
    :raises Exception: re-raised when the document cannot be converted
    """
    # Parse the raw TOML document first.
    with open(config_path, "r", encoding="utf-8") as f:
        raw_document = tomlkit.load(f)
    # Then convert it into the typed Config dataclass.
    try:
        return Config.from_dict(raw_document)
    except Exception as e:
        logger.critical("配置文件解析失败")
        raise e
# Run the template migration before the configuration is parsed.
update_config()
logger.info("正在品鉴配置文件...")
# Module-level singleton: importing this module loads the config exactly once.
global_config = load_config(config_path=f"{CONFIG_DIR}/config.toml")
logger.info("非常的新鲜,非常的美味!")

View File

@@ -0,0 +1,136 @@
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union
# Type variable for from_dict: binds to the concrete ConfigBase subclass.
T = TypeVar("T", bound="ConfigBase")
# Python types that map directly to TOML value types.
# NOTE(review): not referenced anywhere in this module — confirm external
# users before removing.
TOML_DICT_TYPE = {
    int,
    float,
    str,
    bool,
    list,
    dict,
}
@dataclass
class ConfigBase:
    """Base class for config dataclasses populated from parsed TOML dictionaries."""

    @classmethod
    def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
        """Build an instance of *cls* from a plain dictionary.

        Fields whose names start with ``_`` are skipped; a field absent from
        *data* falls back to its dataclass default, or raises ``ValueError``
        when it has none. Conversion failures raise ``TypeError`` (wrong type)
        or ``RuntimeError`` (unexpected error), both naming the field.
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
        init_args: Dict[str, Any] = {}
        for f in fields(cls):
            field_name = f.name
            field_type = f.type
            if field_name.startswith("_"):
                # Private fields are never populated from config data.
                continue
            if field_name not in data:
                if f.default is not MISSING or f.default_factory is not MISSING:
                    # Absent but defaulted: let the dataclass default apply.
                    continue
                else:
                    raise ValueError(f"Missing required field: '{field_name}'")
            value = data[field_name]
            try:
                init_args[field_name] = cls._convert_field(value, field_type)
            except TypeError as e:
                raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e
            except Exception as e:
                raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e
        return cls(**init_args)

    @classmethod
    def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
        """Convert *value* to *field_type*.

        1. Nested ConfigBase subclasses recurse through ``from_dict``.
        2. Generic containers (list, set, tuple, dict) convert element-wise.
        3. Optional/Union converts with the member type that matches the value.
        4. Basic and other types are constructed directly; TOML integers are
           widened into float fields, since writers commonly write ``1`` there.
        """
        # Nested dataclass: recurse.
        if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
            return field_type.from_dict(value)
        field_origin_type = get_origin(field_type)
        field_args_type = get_args(field_type)
        # Generic containers — all arrive from TOML as lists.
        if field_origin_type in {list, set, tuple}:
            if not isinstance(value, list):
                # BUGFIX: message previously read "Expected an list".
                raise TypeError(f"Expected a list for {field_type.__name__}, got {type(value).__name__}")
            if field_origin_type is list:
                return [cls._convert_field(item, field_args_type[0]) for item in value]
            if field_origin_type is set:
                return {cls._convert_field(item, field_args_type[0]) for item in value}
            if field_origin_type is tuple:
                # A tuple's arity must match its type arguments exactly.
                if len(value) != len(field_args_type):
                    raise TypeError(
                        f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}"
                    )
                return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type))
        if field_origin_type is dict:
            if not isinstance(value, dict):
                raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
            if len(field_args_type) != 2:
                raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
            key_type, value_type = field_args_type
            return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
        # Optional / Union (note: get_origin(Optional[X]) is Union).
        if field_origin_type is Union:
            if value is None:
                return None
            if type(value) not in field_args_type:
                raise TypeError(f"Expected {field_args_type} for {field_type}, got {type(value).__name__}")
            # BUGFIX: convert with the member type that actually matches the
            # value; the previous code always used the first type argument,
            # which broke e.g. Union[int, str] fed a str.
            return cls._convert_field(value, type(value))
        # Any accepts the value untouched.
        # BUGFIX: this check must run before the plain-type branch below —
        # isinstance(value, Any) raises TypeError, so Any-typed fields crashed.
        if field_type is Any:
            return value
        # Plain (non-generic) types such as int, str, float, bool.
        if field_origin_type is None:
            if isinstance(value, field_type):
                return field_type(value)
            # Widen TOML integers into float fields (bool excluded: it is an
            # int subclass but semantically distinct).
            if field_type is float and isinstance(value, int) and not isinstance(value, bool):
                return float(value)
            raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}")
        # Literal: the value must be one of the allowed constants.
        if field_origin_type is Literal:
            allowed_values = get_args(field_type)
            if value in allowed_values:
                return value
            else:
                raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
        # Fallback: attempt a direct construction.
        try:
            return field_type(value)
        except (ValueError, TypeError) as e:
            raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e

    def __str__(self):
        """Return a repr-style string listing every field and its value."""
        return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"

View File

@@ -0,0 +1,146 @@
"""
配置文件工具模块
提供统一的配置文件生成和管理功能
"""
import os
import shutil
from pathlib import Path
from datetime import datetime
from typing import Optional
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
def ensure_config_directories():
    """Create the config directory and its backup subdirectory if missing."""
    for directory in ("config", "config/old"):
        os.makedirs(directory, exist_ok=True)
def create_config_from_template(
    config_path: str,
    template_path: str,
    config_name: str = "配置文件",
    should_exit: bool = True
) -> bool:
    """Create a config file from its template when it does not yet exist.

    Args:
        config_path: destination configuration file path
        template_path: template file to copy from
        config_name: human-readable name used in log messages
        should_exit: terminate the process after creating (or failing to
            create) the file, so the user reviews it before a restart

    Returns:
        bool: True when a new config file was created, False otherwise
    """
    try:
        # Make sure config/ and config/old/ exist.
        ensure_config_directories()
        config_path_obj = Path(config_path)
        template_path_obj = Path(template_path)
        # Nothing to do when the config is already in place.
        if config_path_obj.exists():
            return False  # config already exists, no creation needed
        logger.info(f"{config_name}不存在,从模板创建新配置")
        # A missing template is fatal when should_exit is set.
        if not template_path_obj.exists():
            logger.error(f"模板文件不存在: {template_path}")
            if should_exit:
                logger.critical("无法创建配置文件,程序退出")
                quit(1)
            return False
        # Create the destination directory, then copy the template over.
        config_path_obj.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(template_path_obj, config_path_obj)
        logger.info(f"已创建新{config_name}: {config_path}")
        if should_exit:
            logger.info("程序将退出,请检查配置文件后重启")
            quit(0)
        return True
    except Exception as e:
        logger.error(f"创建{config_name}失败: {e}")
        if should_exit:
            logger.critical("无法创建配置文件,程序退出")
            quit(1)
        return False
def create_default_config_dict(default_values: dict, config_path: str, config_name: str = "配置文件") -> bool:
    """Write *default_values* to *config_path* as a TOML document.

    Args:
        default_values: default configuration values
        config_path: destination file path
        config_name: human-readable name used in log messages

    Returns:
        bool: True on success, False when the write fails
    """
    try:
        import tomlkit
        target = Path(config_path)
        # Make sure the parent directory exists before writing.
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "w", encoding="utf-8") as f:
            tomlkit.dump(default_values, f)
        logger.info(f"已创建默认{config_name}: {config_path}")
        return True
    except Exception as e:
        logger.error(f"创建默认{config_name}失败: {e}")
        return False
def backup_config_file(config_path: str, backup_dir: str = "config/old") -> Optional[str]:
    """Copy *config_path* into *backup_dir* under a timestamped name.

    Args:
        config_path: configuration file to back up
        backup_dir: directory receiving the backup

    Returns:
        Optional[str]: the backup file path, or None when the source is
        missing or the copy fails.
    """
    try:
        source = Path(config_path)
        if not source.exists():
            return None
        # Make sure the backup directory exists.
        destination_dir = Path(backup_dir)
        destination_dir.mkdir(parents=True, exist_ok=True)
        # Timestamped file name keeps successive backups distinct.
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_path = destination_dir / f"{source.stem}.toml.bak.{stamp}"
        shutil.copy2(source, backup_path)
        logger.info(f"已备份配置文件到: {backup_path}")
        return str(backup_path)
    except Exception as e:
        logger.error(f"备份配置文件失败: {e}")
        return None

View File

@@ -0,0 +1,359 @@
import asyncio
from dataclasses import dataclass, field
from typing import Literal, Optional
from pathlib import Path
import tomlkit
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from .config_base import ConfigBase
from .config_utils import create_config_from_template, create_default_config_dict
@dataclass
class FeaturesConfig(ConfigBase):
    """Feature switches for the Napcat adapter: chat permissions, poke,
    video analysis and message merge-buffering."""

    # --- chat permission lists ---
    group_list_type: Literal["whitelist", "blacklist"] = "whitelist"  # how group_list is interpreted
    group_list: list[int] = field(default_factory=list)  # group ids on the list
    private_list_type: Literal["whitelist", "blacklist"] = "whitelist"  # how private_list is interpreted
    private_list: list[int] = field(default_factory=list)  # user ids on the list
    ban_user_id: list[int] = field(default_factory=list)  # banned users; no interaction at all
    ban_qq_bot: bool = False  # when True, official QQ bots cannot interact with MaiMCore
    # --- poke ---
    enable_poke: bool = True  # enable the poke feature
    ignore_non_self_poke: bool = False  # ignore pokes not aimed at the bot
    poke_debounce_seconds: int = 3  # within this window a second poke at the bot is dropped
    # --- reply-at ---
    enable_reply_at: bool = True  # @ the quoted user when replying
    reply_at_rate: float = 0.5  # probability (0.0 ~ 1.0) of @-ing on reply
    # --- video analysis ---
    enable_video_analysis: bool = True  # enable video recognition
    max_video_size_mb: int = 100  # maximum video file size (MB)
    download_timeout: int = 60  # video download timeout (seconds)
    supported_formats: list[str] = field(default_factory=lambda: ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"])  # accepted video formats
    # --- message merge-buffering ---
    enable_message_buffer: bool = True  # master switch for merge-buffering
    message_buffer_enable_group: bool = True  # buffer group messages
    message_buffer_enable_private: bool = True  # buffer private messages
    message_buffer_interval: float = 3.0  # merge window (seconds) for consecutive messages
    message_buffer_initial_delay: float = 0.5  # wait after the first message before merging
    message_buffer_max_components: int = 50  # force a merge beyond this many buffered components
    # NOTE(review): the empty-string entries below look like mis-encoded
    # full-width characters — confirm; "" as a prefix matches every message.
    message_buffer_block_prefixes: list[str] = field(default_factory=lambda: ["/", "!", "", ".", "", "#", "%"])  # prefixes exempt from buffering
class FeaturesManager:
"""功能管理器,支持热重载"""
def __init__(self, config_path: str = "plugins/napcat_adapter_plugin/config/features.toml"):
self.config_path = Path(config_path)
self.config: Optional[FeaturesConfig] = None
self._file_watcher_task: Optional[asyncio.Task] = None
self._last_modified: Optional[float] = None
self._callbacks: list = []
def add_reload_callback(self, callback):
"""添加配置重载回调函数"""
self._callbacks.append(callback)
def remove_reload_callback(self, callback):
"""移除配置重载回调函数"""
if callback in self._callbacks:
self._callbacks.remove(callback)
async def _notify_callbacks(self):
"""通知所有回调函数配置已重载"""
for callback in self._callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(self.config)
else:
callback(self.config)
except Exception as e:
logger.error(f"配置重载回调执行失败: {e}")
def load_config(self) -> FeaturesConfig:
"""加载功能配置文件"""
try:
# 检查配置文件是否存在,如果不存在则创建并退出程序
if not self.config_path.exists():
logger.info(f"功能配置文件不存在: {self.config_path}")
self._create_default_config()
# 配置文件创建后程序应该退出,让用户检查配置
logger.info("程序将退出,请检查功能配置文件后重启")
quit(0)
with open(self.config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
self.config = FeaturesConfig.from_dict(config_data)
self._last_modified = self.config_path.stat().st_mtime
logger.info(f"功能配置加载成功: {self.config_path}")
return self.config
except Exception as e:
logger.error(f"功能配置加载失败: {e}")
logger.critical("无法加载功能配置文件,程序退出")
quit(1)
def _create_default_config(self):
"""创建默认功能配置文件"""
template_path = "template/features_template.toml"
# 尝试从模板创建配置文件
if create_config_from_template(
str(self.config_path),
template_path,
"功能配置文件",
should_exit=False # 不在这里退出,由调用方决定
):
return
# 如果模板文件不存在,创建基本配置
logger.info("模板文件不存在,创建基本功能配置")
default_config = {
"group_list_type": "whitelist",
"group_list": [],
"private_list_type": "whitelist",
"private_list": [],
"ban_user_id": [],
"ban_qq_bot": False,
"enable_poke": True,
"ignore_non_self_poke": False,
"poke_debounce_seconds": 3,
"enable_reply_at": True,
"reply_at_rate": 0.5,
"enable_video_analysis": True,
"max_video_size_mb": 100,
"download_timeout": 60,
"supported_formats": ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"],
# 消息缓冲配置
"enable_message_buffer": True,
"message_buffer_enable_group": True,
"message_buffer_enable_private": True,
"message_buffer_interval": 3.0,
"message_buffer_initial_delay": 0.5,
"message_buffer_max_components": 50,
"message_buffer_block_prefixes": ["/", "!", "", ".", "", "#", "%"]
}
if not create_default_config_dict(default_config, str(self.config_path), "功能配置文件"):
logger.critical("无法创建功能配置文件")
quit(1)
async def reload_config(self) -> bool:
"""重新加载配置文件"""
try:
if not self.config_path.exists():
logger.warning(f"功能配置文件不存在,无法重载: {self.config_path}")
return False
current_modified = self.config_path.stat().st_mtime
if self._last_modified and current_modified <= self._last_modified:
return False # 文件未修改
old_config = self.config
new_config = self.load_config()
# 检查配置是否真的发生了变化
if old_config and self._configs_equal(old_config, new_config):
return False
logger.info("功能配置已重载")
await self._notify_callbacks()
return True
except Exception as e:
logger.error(f"功能配置重载失败: {e}")
return False
def _configs_equal(self, config1: FeaturesConfig, config2: FeaturesConfig) -> bool:
"""比较两个配置是否相等"""
return (
config1.group_list_type == config2.group_list_type and
set(config1.group_list) == set(config2.group_list) and
config1.private_list_type == config2.private_list_type and
set(config1.private_list) == set(config2.private_list) and
set(config1.ban_user_id) == set(config2.ban_user_id) and
config1.ban_qq_bot == config2.ban_qq_bot and
config1.enable_poke == config2.enable_poke and
config1.ignore_non_self_poke == config2.ignore_non_self_poke and
config1.poke_debounce_seconds == config2.poke_debounce_seconds and
config1.enable_reply_at == config2.enable_reply_at and
config1.reply_at_rate == config2.reply_at_rate and
config1.enable_video_analysis == config2.enable_video_analysis and
config1.max_video_size_mb == config2.max_video_size_mb and
config1.download_timeout == config2.download_timeout and
set(config1.supported_formats) == set(config2.supported_formats) and
# 消息缓冲配置比较
config1.enable_message_buffer == config2.enable_message_buffer and
config1.message_buffer_enable_group == config2.message_buffer_enable_group and
config1.message_buffer_enable_private == config2.message_buffer_enable_private and
config1.message_buffer_interval == config2.message_buffer_interval and
config1.message_buffer_initial_delay == config2.message_buffer_initial_delay and
config1.message_buffer_max_components == config2.message_buffer_max_components and
set(config1.message_buffer_block_prefixes) == set(config2.message_buffer_block_prefixes)
)
async def start_file_watcher(self, check_interval: float = 1.0):
"""启动文件监控,定期检查配置文件变化"""
if self._file_watcher_task and not self._file_watcher_task.done():
logger.warning("文件监控已在运行")
return
self._file_watcher_task = asyncio.create_task(
self._file_watcher_loop(check_interval)
)
logger.info(f"功能配置文件监控已启动,检查间隔: {check_interval}")
async def stop_file_watcher(self):
"""停止文件监控"""
if self._file_watcher_task and not self._file_watcher_task.done():
self._file_watcher_task.cancel()
try:
await self._file_watcher_task
except asyncio.CancelledError:
pass
logger.info("功能配置文件监控已停止")
async def _file_watcher_loop(self, check_interval: float):
"""文件监控循环"""
while True:
try:
await asyncio.sleep(check_interval)
await self.reload_config()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"文件监控循环出错: {e}")
await asyncio.sleep(check_interval)
def get_config(self) -> FeaturesConfig:
"""获取当前功能配置"""
if self.config is None:
return self.load_config()
return self.config
def is_group_allowed(self, group_id: int) -> bool:
"""检查群聊是否被允许"""
config = self.get_config()
if config.group_list_type == "whitelist":
return group_id in config.group_list
else: # blacklist
return group_id not in config.group_list
def is_private_allowed(self, user_id: int) -> bool:
"""检查私聊是否被允许"""
config = self.get_config()
if config.private_list_type == "whitelist":
return user_id in config.private_list
else: # blacklist
return user_id not in config.private_list
def is_user_banned(self, user_id: int) -> bool:
"""检查用户是否被全局禁止"""
config = self.get_config()
return user_id in config.ban_user_id
def is_qq_bot_banned(self) -> bool:
"""检查是否禁止QQ官方机器人"""
config = self.get_config()
return config.ban_qq_bot
def is_poke_enabled(self) -> bool:
"""检查戳一戳功能是否启用"""
config = self.get_config()
return config.enable_poke
def is_non_self_poke_ignored(self) -> bool:
"""检查是否忽略非自己戳一戳"""
config = self.get_config()
return config.ignore_non_self_poke
def is_message_buffer_enabled(self) -> bool:
"""检查消息缓冲功能是否启用"""
config = self.get_config()
return config.enable_message_buffer
def is_message_buffer_group_enabled(self) -> bool:
"""检查群消息缓冲是否启用"""
config = self.get_config()
return config.message_buffer_enable_group
def is_message_buffer_private_enabled(self) -> bool:
"""检查私聊消息缓冲是否启用"""
config = self.get_config()
return config.message_buffer_enable_private
def get_message_buffer_interval(self) -> float:
"""获取消息缓冲间隔时间"""
config = self.get_config()
return config.message_buffer_interval
def get_message_buffer_initial_delay(self) -> float:
"""获取消息缓冲初始延迟"""
config = self.get_config()
return config.message_buffer_initial_delay
def get_message_buffer_max_components(self) -> int:
"""获取消息缓冲最大组件数量"""
config = self.get_config()
return config.message_buffer_max_components
def is_message_buffer_group_enabled(self) -> bool:
    """Check whether group-chat message buffering is enabled."""
    # NOTE(review): duplicate definition -- an identical method with the same
    # name is declared earlier in this class; Python silently lets this later
    # one win. The bodies are identical so behavior is unaffected, but one of
    # the two should be deleted.
    config = self.get_config()
    return config.message_buffer_enable_group
def is_message_buffer_private_enabled(self) -> bool:
    """Check whether private-chat message buffering is enabled."""
    # NOTE(review): duplicate definition -- an identical method with the same
    # name is declared earlier in this class; this later one overrides it.
    # Remove one of the two copies.
    config = self.get_config()
    return config.message_buffer_enable_private
def get_message_buffer_block_prefixes(self) -> list[str]:
    """Prefixes that exempt a message from buffering (e.g. command prefixes)."""
    return self.get_config().message_buffer_block_prefixes
# Global feature-manager singleton shared across the adapter.
features_manager = FeaturesManager()

View File

@@ -0,0 +1,194 @@
"""
功能配置迁移脚本
用于将旧的配置文件中的聊天、权限、视频处理等设置迁移到新的独立功能配置文件
"""
import os
import shutil
from pathlib import Path
import tomlkit
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
def migrate_features_from_config(old_config_path: str = "plugins/napcat_adapter_plugin/config/config.toml",
                                 new_features_path: str = "plugins/napcat_adapter_plugin/config/features.toml",
                                 template_path: str = "plugins/napcat_adapter_plugin/template/features_template.toml"):
    """
    Migrate feature settings from the legacy config file into the standalone
    features config file.

    Args:
        old_config_path: path of the legacy config file to read.
        new_features_path: path of the new standalone features config file.
        template_path: path of the features template.
            NOTE(review): this parameter is currently unused by the body --
            confirm whether the template was meant to seed defaults.

    Returns:
        bool: True if settings were migrated; False if there was nothing to
        migrate, the source file is missing, or an error occurred.
    """
    try:
        # Nothing to migrate if the legacy file does not exist.
        if not os.path.exists(old_config_path):
            logger.warning(f"旧配置文件不存在: {old_config_path}")
            return False
        # Parse the legacy TOML config.
        with open(old_config_path, "r", encoding="utf-8") as f:
            old_config = tomlkit.load(f)
        # Feature settings may live in the [chat] and [video] sections.
        chat_config = old_config.get("chat", {})
        video_config = old_config.get("video", {})
        # Keys whose presence indicates there is something to migrate.
        permission_keys = ["group_list_type", "group_list", "private_list_type",
                           "private_list", "ban_user_id", "ban_qq_bot",
                           "enable_poke", "ignore_non_self_poke", "poke_debounce_seconds"]
        video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"]
        has_permission_config = any(key in chat_config for key in permission_keys)
        has_video_config = any(key in video_config for key in video_keys)
        if not has_permission_config and not has_video_config:
            logger.info("旧配置文件中没有找到功能相关配置,无需迁移")
            return False
        # Ensure the destination directory exists.
        new_features_dir = Path(new_features_path).parent
        new_features_dir.mkdir(parents=True, exist_ok=True)
        # Back up any pre-existing features file before overwriting it.
        if os.path.exists(new_features_path):
            backup_path = f"{new_features_path}.backup"
            shutil.copy2(new_features_path, backup_path)
            logger.info(f"已备份现有功能配置文件到: {backup_path}")
        # Build the new flat features config, falling back to defaults for any
        # key absent from the legacy sections.
        new_features_config = {
            "group_list_type": chat_config.get("group_list_type", "whitelist"),
            "group_list": chat_config.get("group_list", []),
            "private_list_type": chat_config.get("private_list_type", "whitelist"),
            "private_list": chat_config.get("private_list", []),
            "ban_user_id": chat_config.get("ban_user_id", []),
            "ban_qq_bot": chat_config.get("ban_qq_bot", False),
            "enable_poke": chat_config.get("enable_poke", True),
            "ignore_non_self_poke": chat_config.get("ignore_non_self_poke", False),
            "poke_debounce_seconds": chat_config.get("poke_debounce_seconds", 3),
            "enable_video_analysis": video_config.get("enable_video_analysis", True),
            "max_video_size_mb": video_config.get("max_video_size_mb", 100),
            "download_timeout": video_config.get("download_timeout", 60),
            "supported_formats": video_config.get("supported_formats", ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"])
        }
        # Write the standalone features file.
        with open(new_features_path, "w", encoding="utf-8") as f:
            tomlkit.dump(new_features_config, f)
        logger.info(f"功能配置已成功迁移到: {new_features_path}")
        # Log the migrated values for operator visibility.
        logger.info("迁移的配置内容:")
        for key, value in new_features_config.items():
            logger.info(f"  {key}: {value}")
        return True
    except Exception as e:
        logger.error(f"功能配置迁移失败: {e}")
        return False
def remove_features_from_old_config(config_path: str = "plugins/napcat_adapter_plugin/config/config.toml"):
    """
    Remove feature-related settings from the legacy config file, first backing
    the original file up under config/old/.

    Args:
        config_path: path of the legacy config file to rewrite in place.

    Returns:
        bool: True on success, False if the file is missing or an error occurred.
    """
    try:
        if not os.path.exists(config_path):
            logger.warning(f"配置文件不存在: {config_path}")
            return False
        # Make sure the config/old backup directory exists.
        old_config_dir = "plugins/napcat_adapter_plugin/config/old"
        os.makedirs(old_config_dir, exist_ok=True)
        # Keep a copy of the original (feature-bearing) file before editing.
        old_config_path = os.path.join(old_config_dir, "config_with_features.toml")
        shutil.copy2(config_path, old_config_path)
        logger.info(f"已备份包含功能配置的原文件到: {old_config_path}")
        # Parse the config.
        with open(config_path, "r", encoding="utf-8") as f:
            config = tomlkit.load(f)
        # Strip the permission-related keys from the [chat] section.
        removed_keys = []
        if "chat" in config:
            chat_config = config["chat"]
            permission_keys = ["group_list_type", "group_list", "private_list_type",
                               "private_list", "ban_user_id", "ban_qq_bot",
                               "enable_poke", "ignore_non_self_poke", "poke_debounce_seconds"]
            for key in permission_keys:
                if key in chat_config:
                    del chat_config[key]
                    removed_keys.append(key)
            if removed_keys:
                logger.info(f"已从chat配置段中移除功能相关配置: {removed_keys}")
        # Strip the [video] section's migrated keys.
        if "video" in config:
            video_config = config["video"]
            video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"]
            video_removed_keys = []
            for key in video_keys:
                if key in video_config:
                    del video_config[key]
                    video_removed_keys.append(key)
            if video_removed_keys:
                logger.info(f"已从video配置段中移除配置: {video_removed_keys}")
                removed_keys.extend(video_removed_keys)
            # Drop the whole [video] table if it ended up empty.
            if not video_config:
                del config["video"]
                logger.info("已删除空的video配置段")
        if removed_keys:
            logger.info(f"总共移除的配置项: {removed_keys}")
        # Write the stripped config back in place.
        with open(config_path, "w", encoding="utf-8") as f:
            f.write(tomlkit.dumps(config))
        logger.info(f"已更新配置文件: {config_path}")
        return True
    except Exception as e:
        logger.error(f"移除功能配置失败: {e}")
        return False
def auto_migrate_features():
    """Run the feature-config migration end to end, then clean the old file."""
    logger.info("开始自动功能配置迁移...")
    # Guard clause: bail out early when nothing was migrated.
    if not migrate_features_from_config():
        logger.info("功能配置迁移跳过或失败")
        return
    logger.info("功能配置迁移成功")
    # Settings now live in the standalone file; strip them from the main config.
    logger.info("功能配置已迁移到独立文件,建议从主配置文件中移除相关配置")
    remove_features_from_old_config()
if __name__ == "__main__":
    # Allow running the migration as a standalone script.
    auto_migrate_features()

View File

@@ -0,0 +1,67 @@
from dataclasses import dataclass, field
from typing import Literal
from .config_base import ConfigBase
"""
须知:
1. 本文件中记录了所有的配置项
2. 所有新增的class都需要继承自ConfigBase
3. 所有新增的class都应在config.py中的Config类中添加字段
4. 对于新增的字段若为可选项则应在其后添加field()并设置default_factory或default
"""
ADAPTER_PLATFORM = "qq"
@dataclass
class NicknameConfig(ConfigBase):
    """[nickname] section: the bot's display name."""

    nickname: str
    """机器人昵称"""
@dataclass
class NapcatServerConfig(ConfigBase):
    """[napcat_server] section: how the adapter connects to Napcat."""

    mode: Literal["reverse", "forward"] = "reverse"
    """连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)"""
    host: str = "localhost"
    """主机地址"""
    port: int = 8095
    """端口号"""
    url: str = ""
    """正向连接时的完整WebSocket URL如 ws://localhost:8080/ws"""
    access_token: str = ""
    """WebSocket 连接的访问令牌,用于身份验证"""
    heartbeat_interval: int = 30
    """心跳间隔时间,单位为秒"""
@dataclass
class MaiBotServerConfig(ConfigBase):
    """[maibot_server] section: where the MaiBot core is reachable."""

    # init=False: the platform name is fixed to the adapter's constant.
    platform_name: str = field(default=ADAPTER_PLATFORM, init=False)
    """平台名称“qq”"""
    host: str = "localhost"
    """MaiMCore的主机地址"""
    port: int = 8000
    """MaiMCore的端口号"""
@dataclass
class VoiceConfig(ConfigBase):
    """[voice] section: text-to-speech toggle."""

    use_tts: bool = False
    """是否启用TTS功能"""
@dataclass
class DebugConfig(ConfigBase):
    """[debug] section: logging verbosity."""

    level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
    """日志级别默认为INFO"""

View File

@@ -0,0 +1,163 @@
import os
from typing import Optional, List
from dataclasses import dataclass
from sqlmodel import Field, Session, SQLModel, create_engine, select
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
"""
表记录的方式:
| group_id | user_id | lift_time |
|----------|---------|-----------|
其中使用 user_id == 0 表示群全体禁言
"""
@dataclass
class BanUser:
    """
    In-memory representation of one mute record.

    user_id == 0 denotes a whole-group mute (see the module-level note);
    lift_time is the epoch timestamp when the mute lifts, -1 when unknown.
    """

    user_id: int
    group_id: int
    # BUG FIX: this previously used sqlmodel's Field(default=-1). In a plain
    # @dataclass (not a SQLModel subclass) that makes the *FieldInfo object*
    # the default value, so BanUser(u, g).lift_time was not -1. A plain
    # default gives the intended value.
    lift_time: Optional[int] = -1
class DB_BanUser(SQLModel, table=True):
    """
    Database row for one group-mute record.

    Composite primary key (user_id, group_id); user_id == 0 denotes a
    whole-group mute per the module-level note.
    """

    user_id: int = Field(index=True, primary_key=True)  # muted user's QQ id
    group_id: int = Field(index=True, primary_key=True)  # group where the mute applies
    lift_time: Optional[int]  # epoch timestamp when the mute lifts; presumably -1 means none scheduled -- confirm
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
    """True when both records refer to the same (user, group) pair."""
    return (obj1.user_id, obj1.group_id) == (obj2.user_id, obj2.group_id)
class DatabaseManager:
    """
    Owns the adapter's SQLite database and all access to the mute-record table.

    The database file lives under the plugin's data/ directory and is created
    on first use; a single module-level instance is expected.
    """

    def __init__(self):
        # Ensure the data directory exists before SQLite creates the file.
        os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True)
        DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db")
        self.sqlite_url = f"sqlite:///{DATABASE_FILE}"  # SQLite database URL
        self.engine = create_engine(self.sqlite_url, echo=False)
        self._ensure_database()

    def _ensure_database(self) -> None:
        """Create the database file and tables if they do not exist yet."""
        logger.info("确保数据库文件和表已创建...")
        SQLModel.metadata.create_all(self.engine)
        logger.info("数据库和表已创建或已存在")

    def update_ban_record(self, ban_list: List[BanUser]) -> None:
        """
        Synchronize the table with ban_list: create missing rows, update rows
        whose lift_time changed, and delete rows absent from ban_list.
        """
        with Session(self.engine) as session:
            all_records = session.exec(select(DB_BanUser)).all()
            for ban_user in ban_list:
                statement = select(DB_BanUser).where(
                    DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
                )
                if existing_record := session.exec(statement).first():
                    if existing_record.lift_time == ban_user.lift_time:
                        logger.debug(f"禁言记录未变更: {existing_record}")
                        continue
                    # Update the existing row's lift_time.
                    existing_record.lift_time = ban_user.lift_time
                    session.add(existing_record)
                    logger.debug(f"更新禁言记录: {existing_record}")
                else:
                    # Create a new row.
                    db_record = DB_BanUser(
                        user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
                    )
                    session.add(db_record)
                    logger.debug(f"创建新禁言记录: {ban_user}")
            # Delete rows no longer present in ban_list.
            for db_record in all_records:
                record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
                if not any(is_identical(record, ban_user) for ban_user in ban_list):
                    statement = select(DB_BanUser).where(
                        DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
                    )
                    if ban_record := session.exec(statement).first():
                        session.delete(ban_record)
                        logger.debug(f"删除禁言记录: {ban_record}")
                    else:
                        # BUG FIX: previously logged the walrus variable, which
                        # is None in this branch; log the record we looked for.
                        logger.info(f"未找到禁言记录: {record}")
            # BUG FIX: the session was previously closed without committing,
            # so every insert/update/delete above was silently discarded when
            # the `with Session(...)` block exited.
            session.commit()
        logger.info("禁言记录已更新")

    def get_ban_records(self) -> List[BanUser]:
        """Read all mute records as plain BanUser instances."""
        with Session(self.engine) as session:
            statement = select(DB_BanUser)
            records = session.exec(statement).all()
            return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]

    def create_ban_record(self, ban_record: BanUser) -> None:
        """
        Create (or update in place) the mute record for one user in one group.
        A simplified alternative to update_ban_record for single-row changes.
        """
        with Session(self.engine) as session:
            # Upsert: update lift_time if the row exists, otherwise insert.
            statement = select(DB_BanUser).where(
                DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id
            )
            existing_record = session.exec(statement).first()
            if existing_record:
                existing_record.lift_time = ban_record.lift_time
                session.add(existing_record)
                logger.debug(f"更新禁言记录: {ban_record}")
            else:
                db_record = DB_BanUser(
                    user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
                )
                session.add(db_record)
                logger.debug(f"创建新禁言记录: {ban_record}")
            # BUG FIX: persist the change; without commit the session rollback
            # on close dropped the insert/update.
            session.commit()

    def delete_ban_record(self, ban_record: BanUser):
        """Delete the mute record for one user in one group, if present."""
        user_id = ban_record.user_id
        group_id = ban_record.group_id
        with Session(self.engine) as session:
            statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
            # Renamed from `ban_record` to stop shadowing the parameter.
            if db_record := session.exec(statement).first():
                session.delete(db_record)
                # BUG FIX: persist the deletion before the session closes.
                session.commit()
                logger.debug(f"删除禁言记录: {db_record}")
            else:
                logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}")
# Module-level singleton; note the database is opened/created at import time.
db_manager = DatabaseManager()

View File

@@ -0,0 +1,320 @@
import asyncio
import time
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from .config.features_config import features_manager
from .recv_handler import RealMessageType
@dataclass
class TextMessage:
    """A single buffered plain-text message."""

    text: str  # extracted plain text
    timestamp: float = field(default_factory=time.time)  # receipt time (epoch seconds)
@dataclass
class BufferedSession:
    """Mutable state for one buffering session (one chat peer)."""

    session_id: str
    messages: List[TextMessage] = field(default_factory=list)  # pending texts, oldest first
    timer_task: Optional[asyncio.Task] = None  # merge-interval timer
    delay_task: Optional[asyncio.Task] = None  # initial-delay timer
    original_event: Any = None  # most recent originating event, passed to the merge callback
    created_at: float = field(default_factory=time.time)  # session creation time, used for expiry
class SimpleMessageBuffer:
    """
    Per-session coalescer for consecutive plain-text messages.

    Text-only messages are held per session; after an initial delay plus a
    quiet interval they are merged into one string and handed to
    ``merge_callback``. Messages containing any non-text segment bypass the
    buffer entirely (the caller then processes them normally).
    """

    def __init__(self, merge_callback=None):
        """
        Initialize the message buffer.

        Args:
            merge_callback: callable invoked after a merge with
                (session_id, merged_text, original_event); may be sync or async.
        """
        self.buffer_pool: Dict[str, BufferedSession] = {}  # session_id -> live buffer
        self.lock = asyncio.Lock()  # guards buffer_pool and per-session timers
        self.merge_callback = merge_callback
        self._shutdown = False  # once set, add_text_message refuses new input

    def get_session_id(self, event_data: Dict[str, Any]) -> str:
        """Derive a stable session id from the event's chat type and ids."""
        message_type = event_data.get("message_type", "unknown")
        user_id = event_data.get("user_id", "unknown")
        if message_type == "private":
            return f"private_{user_id}"
        elif message_type == "group":
            group_id = event_data.get("group_id", "unknown")
            # Group sessions are keyed per (group, user) pair.
            return f"group_{group_id}_{user_id}"
        else:
            return f"{message_type}_{user_id}"

    def extract_text_from_message(self, message: List[Dict[str, Any]]) -> Optional[str]:
        """Extract plain text from a OneBot message; return None if any segment is non-text."""
        text_parts = []
        has_non_text = False
        logger.debug(f"正在提取消息文本,消息段数量: {len(message)}")
        for msg_seg in message:
            msg_type = msg_seg.get("type", "")
            logger.debug(f"处理消息段类型: {msg_type}")
            if msg_type == RealMessageType.text:
                text = msg_seg.get("data", {}).get("text", "").strip()
                if text:
                    text_parts.append(text)
                    logger.debug(f"提取到文本: {text[:50]}...")
            else:
                # Any non-text segment disqualifies the whole message.
                has_non_text = True
                logger.debug(f"发现非文本消息段: {msg_type},跳过缓冲")
        if has_non_text:
            logger.debug("消息包含非文本内容,不进行缓冲")
            return None
        if text_parts:
            combined_text = " ".join(text_parts).strip()
            logger.debug(f"成功提取纯文本: {combined_text[:50]}...")
            return combined_text
        logger.debug("没有找到有效的文本内容")
        return None

    def should_skip_message(self, text: str) -> bool:
        """Whether the text should bypass buffering (empty or blocked prefix)."""
        if not text or not text.strip():
            return True
        # Messages starting with a configured block prefix (e.g. commands)
        # are never buffered.
        config = features_manager.get_config()
        block_prefixes = tuple(config.message_buffer_block_prefixes)
        text = text.strip()
        if text.startswith(block_prefixes):
            logger.debug(f"消息以屏蔽前缀开头,跳过缓冲: {text[:20]}...")
            return True
        return False

    async def add_text_message(self, event_data: Dict[str, Any], message: List[Dict[str, Any]],
                               original_event: Any = None) -> bool:
        """
        Add a text message to the buffer.

        Args:
            event_data: event metadata used to derive the session id.
            message: OneBot message segment array.
            original_event: original event object, forwarded to the callback.

        Returns:
            True if the message was buffered; False if buffering is disabled,
            the message contains non-text content, or it matched a skip rule.
        """
        if self._shutdown:
            return False
        config = features_manager.get_config()
        if not config.enable_message_buffer:
            return False
        # Per-chat-type enable switches.
        message_type = event_data.get("message_type", "")
        if message_type == "group" and not config.message_buffer_enable_group:
            return False
        elif message_type == "private" and not config.message_buffer_enable_private:
            return False
        # Only pure-text messages are buffered.
        text = self.extract_text_from_message(message)
        if not text:
            return False
        if self.should_skip_message(text):
            return False
        session_id = self.get_session_id(event_data)
        async with self.lock:
            # Get or create the session.
            if session_id not in self.buffer_pool:
                self.buffer_pool[session_id] = BufferedSession(
                    session_id=session_id,
                    original_event=original_event
                )
            session = self.buffer_pool[session_id]
            # At the message cap: force-merge the full session and start a
            # fresh one for this message.
            if len(session.messages) >= config.message_buffer_max_components:
                logger.info(f"会话 {session_id} 消息数量达到上限,强制合并")
                asyncio.create_task(self._force_merge_session(session_id))
                self.buffer_pool[session_id] = BufferedSession(
                    session_id=session_id,
                    original_event=original_event
                )
                session = self.buffer_pool[session_id]
            # Append the text and remember the newest event.
            session.messages.append(TextMessage(text=text))
            session.original_event = original_event
            # Restart the timers: every new message resets the countdown.
            await self._cancel_session_timers(session)
            session.delay_task = asyncio.create_task(
                self._wait_and_start_merge(session_id)
            )
            logger.debug(f"文本消息已添加到缓冲器 {session_id}: {text[:50]}...")
            return True

    async def _cancel_session_timers(self, session: BufferedSession):
        """Cancel and await both of the session's pending timer tasks."""
        for task_name in ['timer_task', 'delay_task']:
            task = getattr(session, task_name)
            if task and not task.done():
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
            setattr(session, task_name, None)

    async def _wait_and_start_merge(self, session_id: str):
        """After the initial delay, arm the merge-interval timer."""
        config = features_manager.get_config()
        await asyncio.sleep(config.message_buffer_initial_delay)
        async with self.lock:
            session = self.buffer_pool.get(session_id)
            if session and session.messages:
                # Replace any previous merge timer.
                if session.timer_task and not session.timer_task.done():
                    session.timer_task.cancel()
                    try:
                        await session.timer_task
                    except asyncio.CancelledError:
                        pass
                session.timer_task = asyncio.create_task(
                    self._wait_and_merge(session_id)
                )

    async def _wait_and_merge(self, session_id: str):
        """Sleep for the merge interval, then merge the session."""
        config = features_manager.get_config()
        await asyncio.sleep(config.message_buffer_interval)
        await self._merge_session(session_id)

    async def _force_merge_session(self, session_id: str):
        """Merge a session immediately, without waiting for its timers."""
        await self._merge_session(session_id, force=True)

    async def _merge_session(self, session_id: str, force: bool = False):
        """Merge the session's buffered texts and invoke the callback."""
        async with self.lock:
            session = self.buffer_pool.get(session_id)
            if not session or not session.messages:
                self.buffer_pool.pop(session_id, None)
                return
            try:
                # Collect the non-empty text parts in arrival order.
                text_parts = []
                for msg in session.messages:
                    if msg.text.strip():
                        text_parts.append(msg.text.strip())
                if not text_parts:
                    self.buffer_pool.pop(session_id, None)
                    return
                # NOTE(review): joins with the empty string, but the original
                # comment claimed a Chinese-comma separator -- confirm which
                # separator is actually intended.
                merged_text = "".join(text_parts)
                message_count = len(session.messages)
                logger.info(f"合并会话 {session_id}{message_count} 条文本消息: {merged_text[:100]}...")
                # Hand the merged text to the registered callback (sync or async).
                if self.merge_callback:
                    try:
                        if asyncio.iscoroutinefunction(self.merge_callback):
                            await self.merge_callback(session_id, merged_text, session.original_event)
                        else:
                            self.merge_callback(session_id, merged_text, session.original_event)
                    except Exception as e:
                        logger.error(f"消息合并回调执行失败: {e}")
            except Exception as e:
                logger.error(f"合并会话 {session_id} 时出错: {e}")
            finally:
                # Tear the session down.
                # NOTE(review): when this runs inside the timer task itself
                # (via _wait_and_merge), _cancel_session_timers cancels and
                # awaits the *currently running* task -- verify this cannot
                # raise or deadlock.
                await self._cancel_session_timers(session)
                self.buffer_pool.pop(session_id, None)

    async def flush_session(self, session_id: str):
        """Force-merge one session's buffer immediately."""
        await self._force_merge_session(session_id)

    async def flush_all(self):
        """Force-merge every session's buffer immediately."""
        session_ids = list(self.buffer_pool.keys())
        for session_id in session_ids:
            await self._force_merge_session(session_id)

    async def get_buffer_stats(self) -> Dict[str, Any]:
        """Snapshot of session counts and per-session ages, for diagnostics."""
        async with self.lock:
            stats = {
                "total_sessions": len(self.buffer_pool),
                "sessions": {}
            }
            for session_id, session in self.buffer_pool.items():
                stats["sessions"][session_id] = {
                    "message_count": len(session.messages),
                    "created_at": session.created_at,
                    "age": time.time() - session.created_at
                }
            return stats

    async def clear_expired_sessions(self, max_age: float = 300.0):
        """Force-merge sessions older than max_age seconds."""
        current_time = time.time()
        expired_sessions = []
        async with self.lock:
            for session_id, session in self.buffer_pool.items():
                if current_time - session.created_at > max_age:
                    expired_sessions.append(session_id)
        for session_id in expired_sessions:
            logger.info(f"清理过期会话: {session_id}")
            await self._force_merge_session(session_id)

    async def shutdown(self):
        """Flush everything, cancel all timers, and refuse further input."""
        self._shutdown = True
        logger.info("正在关闭简化消息缓冲器...")
        # Deliver whatever is still buffered.
        await self.flush_all()
        # Defensively cancel any remaining per-session tasks.
        async with self.lock:
            for session in list(self.buffer_pool.values()):
                await self._cancel_session_timers(session)
            self.buffer_pool.clear()
        logger.info("简化消息缓冲器已关闭")

View File

@@ -0,0 +1,26 @@
from maim_message import Router, RouteConfig, TargetConfig
from .config import global_config
from src.common.logger import get_logger
from .send_handler import send_handler
logger = get_logger("napcat_adapter")
# Static route table: a single target pointing at the MaiBot core WebSocket
# endpoint, built from the adapter's [maibot_server] config.
route_config = RouteConfig(
    route_config={
        global_config.maibot_server.platform_name: TargetConfig(
            url=f"ws://{global_config.maibot_server.host}:{global_config.maibot_server.port}/ws",
            token=None,  # no auth token is configured for the core link
        )
    }
)
# Module-level router shared by the start/stop helpers below.
router = Router(route_config)
async def mmc_start_com():
    """Register the outbound handler and run the router (returns when it stops)."""
    logger.info("正在连接MaiBot")
    router.register_class_handler(send_handler.handle_message)
    await router.run()
async def mmc_stop_com():
    """Stop the router's connection to the MaiBot core."""
    await router.stop()

View File

@@ -0,0 +1,89 @@
from enum import Enum
class MetaEventType:
    """OneBot meta-event type strings."""

    lifecycle = "lifecycle"  # lifecycle meta event

    class Lifecycle:
        connect = "connect"  # lifecycle sub-type: WebSocket connected

    heartbeat = "heartbeat"  # heartbeat meta event
class MessageType:
    """Top-level incoming message categories and their sub-types."""

    private = "private"  # private message

    class Private:
        friend = "friend"  # from a friend
        group = "group"  # group temporary session
        group_self = "group_self"  # sent by the bot itself inside a group
        other = "other"  # other private message

    group = "group"  # group message

    class Group:
        normal = "normal"  # normal group message
        anonymous = "anonymous"  # anonymous group message
        notice = "notice"  # system notice inside a group
class NoticeType:
    """Notice-event types and their sub-types."""

    friend_recall = "friend_recall"  # private message recalled
    group_recall = "group_recall"  # group message recalled
    notify = "notify"
    group_ban = "group_ban"  # group mute

    class Notify:
        poke = "poke"  # poke ("戳一戳")
        input_status = "input_status"  # peer is typing

    class GroupBan:
        ban = "ban"  # mute applied
        lift_ban = "lift_ban"  # mute lifted
class RealMessageType:
    """Message-segment ("real message") type strings."""

    text = "text"  # plain text
    face = "face"  # QQ emoticon
    image = "image"  # image
    record = "record"  # voice
    video = "video"  # video
    at = "at"  # @ mention
    rps = "rps"  # rock-paper-scissors magic face
    dice = "dice"  # dice
    shake = "shake"  # private-chat window shake (receive only)
    poke = "poke"  # group poke
    share = "share"  # link share (JSON form)
    reply = "reply"  # reply to a message
    forward = "forward"  # forwarded message
    node = "node"  # forwarded-message node
    json = "json"  # JSON message
class MessageSentType:
    """Categories for messages the bot itself sent."""

    private = "private"

    class Private:
        friend = "friend"
        group = "group"

    group = "group"

    class Group:
        normal = "normal"
class CommandType(Enum):
    """Adapter-side command identifiers sent to Napcat."""

    GROUP_BAN = "set_group_ban"  # mute a member
    GROUP_WHOLE_BAN = "set_group_whole_ban"  # whole-group mute
    GROUP_KICK = "set_group_kick"  # kick a member
    SEND_POKE = "send_poke"  # poke
    DELETE_MSG = "delete_msg"  # recall a message

    def __str__(self) -> str:
        # Stringify to the raw API action name.
        return self.value
# Message segment formats this adapter accepts from the MaiBot core.
ACCEPT_FORMAT = ["text", "image", "emoji", "reply", "voice", "command", "voiceurl", "music", "videourl", "file"]

View File

@@ -0,0 +1,942 @@
from ...event_types import NapcatEvent
from src.plugin_system.core.event_manager import event_manager
from src.common.logger import get_logger
from ...CONSTS import PLUGIN_NAME
logger = get_logger("napcat_adapter")
from ..config import global_config
from ..config.features_config import features_manager
from ..message_buffer import SimpleMessageBuffer
from ..utils import (
get_group_info,
get_member_info,
get_image_base64,
get_record_detail,
get_self_info,
get_message_detail,
)
from .qq_emoji_list import qq_face
from .message_sending import message_send_instance
from . import RealMessageType, MessageType, ACCEPT_FORMAT
from ..video_handler import get_video_downloader
from ..websocket_manager import websocket_manager
import time
import json
import websockets as Server
import base64
from pathlib import Path
from typing import List, Tuple, Optional, Dict, Any
import uuid
from maim_message import (
UserInfo,
GroupInfo,
Seg,
BaseMessageInfo,
MessageBase,
TemplateInfo,
FormatInfo,
)
from ..response_pool import get_response
class MessageHandler:
def __init__(self):
self.server_connection: Server.ServerConnection = None
self.bot_id_list: Dict[int, bool] = {}
# 初始化简化消息缓冲器,传入回调函数
self.message_buffer = SimpleMessageBuffer(merge_callback=self._send_buffered_message)
async def shutdown(self):
"""关闭消息处理器,清理资源"""
if self.message_buffer:
await self.message_buffer.shutdown()
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
"""设置Napcat连接"""
self.server_connection = server_connection
def get_server_connection(self) -> Server.ServerConnection:
"""获取当前的服务器连接"""
# 优先使用直接设置的连接,否则从 websocket_manager 获取
if self.server_connection:
return self.server_connection
return websocket_manager.get_connection()
async def check_allow_to_chat(
self,
user_id: int,
group_id: Optional[int] = None,
ignore_bot: Optional[bool] = False,
ignore_global_list: Optional[bool] = False,
) -> bool:
# sourcery skip: hoist-statement-from-if, merge-else-if-into-elif
"""
检查是否允许聊天
Parameters:
user_id: int: 用户ID
group_id: int: 群ID
ignore_bot: bool: 是否忽略机器人检查
ignore_global_list: bool: 是否忽略全局黑名单检查
Returns:
bool: 是否允许聊天
"""
logger.debug(f"群聊id: {group_id}, 用户id: {user_id}")
logger.debug("开始检查聊天白名单/黑名单")
# 使用新的权限管理器检查权限
if group_id:
if not features_manager.is_group_allowed(group_id):
logger.warning("群聊不在聊天权限范围内,消息被丢弃")
return False
else:
if not features_manager.is_private_allowed(user_id):
logger.warning("私聊不在聊天权限范围内,消息被丢弃")
return False
# 检查全局禁止名单
if not ignore_global_list and features_manager.is_user_banned(user_id):
logger.warning("用户在全局黑名单中,消息被丢弃")
return False
# 检查QQ官方机器人
if features_manager.is_qq_bot_banned() and group_id and not ignore_bot:
logger.debug("开始判断是否为机器人")
member_info = await get_member_info(self.get_server_connection(), group_id, user_id)
if member_info:
is_bot = member_info.get("is_robot")
if is_bot is None:
logger.warning("无法获取用户是否为机器人,默认为不是但是不进行更新")
else:
if is_bot:
logger.warning("QQ官方机器人消息拦截已启用消息被丢弃新机器人加入拦截名单")
self.bot_id_list[user_id] = True
return False
else:
self.bot_id_list[user_id] = False
return True
async def handle_raw_message(self, raw_message: dict) -> None:
# sourcery skip: low-code-quality, remove-unreachable-code
"""
从Napcat接受的原始消息处理
Parameters:
raw_message: dict: 原始消息
"""
message_type: str = raw_message.get("message_type")
message_id: int = raw_message.get("message_id")
# message_time: int = raw_message.get("time")
message_time: float = time.time() # 应可乐要求现在是float了
template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用
format_info: FormatInfo = FormatInfo(
content_format=["text", "image", "emoji", "voice"],
accept_format=ACCEPT_FORMAT,
) # 格式化信息
if message_type == MessageType.private:
sub_type = raw_message.get("sub_type")
if sub_type == MessageType.Private.friend:
sender_info: dict = raw_message.get("sender")
if not await self.check_allow_to_chat(sender_info.get("user_id"), None):
return None
# 发送者用户信息
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=sender_info.get("user_id"),
user_nickname=sender_info.get("nickname"),
user_cardname=sender_info.get("card"),
)
# 不存在群信息
group_info: GroupInfo = None
elif sub_type == MessageType.Private.group:
"""
本部分暂时不做支持,先放着
"""
logger.warning("群临时消息类型不支持")
return None
sender_info: dict = raw_message.get("sender")
# 由于临时会话中Napcat默认不发送成员昵称所以需要单独获取
fetched_member_info: dict = await get_member_info(
self.get_server_connection(),
raw_message.get("group_id"),
sender_info.get("user_id"),
)
nickname = fetched_member_info.get("nickname") if fetched_member_info else None
# 发送者用户信息
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=sender_info.get("user_id"),
user_nickname=nickname,
user_cardname=None,
)
# -------------------这里需要群信息吗?-------------------
# 获取群聊相关信息在此单独处理group_name因为默认发送的消息中没有
fetched_group_info: dict = await get_group_info(self.get_server_connection(), raw_message.get("group_id"))
group_name = ""
if fetched_group_info.get("group_name"):
group_name = fetched_group_info.get("group_name")
group_info: GroupInfo = GroupInfo(
platform=global_config.maibot_server.platform_name,
group_id=raw_message.get("group_id"),
group_name=group_name,
)
else:
logger.warning(f"私聊消息类型 {sub_type} 不支持")
return None
elif message_type == MessageType.group:
sub_type = raw_message.get("sub_type")
if sub_type == MessageType.Group.normal:
sender_info: dict = raw_message.get("sender")
if not await self.check_allow_to_chat(sender_info.get("user_id"), raw_message.get("group_id")):
return None
# 发送者用户信息
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=sender_info.get("user_id"),
user_nickname=sender_info.get("nickname"),
user_cardname=sender_info.get("card"),
)
# 获取群聊相关信息在此单独处理group_name因为默认发送的消息中没有
fetched_group_info = await get_group_info(self.get_server_connection(), raw_message.get("group_id"))
group_name: str = None
if fetched_group_info:
group_name = fetched_group_info.get("group_name")
group_info: GroupInfo = GroupInfo(
platform=global_config.maibot_server.platform_name,
group_id=raw_message.get("group_id"),
group_name=group_name,
)
else:
logger.warning(f"群聊消息类型 {sub_type} 不支持")
return None
additional_config: dict = {}
if global_config.voice.use_tts:
additional_config["allow_tts"] = True
# 消息信息
message_info: BaseMessageInfo = BaseMessageInfo(
platform=global_config.maibot_server.platform_name,
message_id=message_id,
time=message_time,
user_info=user_info,
group_info=group_info,
template_info=template_info,
format_info=format_info,
additional_config=additional_config,
)
# 处理实际信息
if not raw_message.get("message"):
logger.warning("原始消息内容为空")
return None
# 获取Seg列表
seg_message: List[Seg] = await self.handle_real_message(raw_message)
if not seg_message:
logger.warning("处理后消息内容为空")
return None
# 检查是否需要使用消息缓冲
if features_manager.is_message_buffer_enabled():
# 检查消息类型是否启用缓冲
message_type = raw_message.get("message_type")
should_use_buffer = False
if message_type == "group" and features_manager.is_message_buffer_group_enabled():
should_use_buffer = True
elif message_type == "private" and features_manager.is_message_buffer_private_enabled():
should_use_buffer = True
if should_use_buffer:
logger.debug(f"尝试缓冲消息,消息类型: {message_type}, 用户: {user_info.user_id}")
logger.debug(f"原始消息段: {raw_message.get('message', [])}")
# 尝试添加到缓冲器
buffered = await self.message_buffer.add_text_message(
event_data={
"message_type": message_type,
"user_id": user_info.user_id,
"group_id": group_info.group_id if group_info else None,
},
message=raw_message.get("message", []),
original_event={
"message_info": message_info,
"raw_message": raw_message
}
)
if buffered:
logger.info(f"✅ 文本消息已成功缓冲: {user_info.user_id}")
return None # 缓冲成功,不立即发送
# 如果缓冲失败(消息包含非文本元素),走正常处理流程
logger.info(f"❌ 消息缓冲失败,包含非文本元素,走正常处理流程: {user_info.user_id}")
# 缓冲失败时继续执行后面的正常处理流程,不要直接返回
logger.debug(f"准备发送消息到MaiBot消息段数量: {len(seg_message)}")
for i, seg in enumerate(seg_message):
logger.debug(f"消息段 {i}: type={seg.type}, data={str(seg.data)[:100]}...")
submit_seg: Seg = Seg(
type="seglist",
data=seg_message,
)
# MessageBase创建
message_base: MessageBase = MessageBase(
message_info=message_info,
message_segment=submit_seg,
raw_message=raw_message.get("raw_message"),
)
logger.info("发送到Maibot处理信息")
await message_send_instance.message_send(message_base)
async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None:
# sourcery skip: low-code-quality
"""
处理实际消息
Parameters:
real_message: dict: 实际消息
Returns:
seg_message: list[Seg]: 处理后的消息段列表
"""
real_message: list = raw_message.get("message")
if not real_message:
return None
seg_message: List[Seg] = []
for sub_message in real_message:
sub_message: dict
sub_message_type = sub_message.get("type")
match sub_message_type:
case RealMessageType.text:
ret_seg = await self.handle_text_message(sub_message)
if ret_seg:
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.TEXT,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
seg_message.append(ret_seg)
else:
logger.warning("text处理失败")
case RealMessageType.face:
ret_seg = await self.handle_face_message(sub_message)
if ret_seg:
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FACE,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
seg_message.append(ret_seg)
else:
logger.warning("face处理失败或不支持")
case RealMessageType.reply:
if not in_reply:
ret_seg = await self.handle_reply_message(sub_message)
if ret_seg:
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.REPLY,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
seg_message += ret_seg
else:
logger.warning("reply处理失败")
case RealMessageType.image:
logger.debug(f"开始处理图片消息段")
ret_seg = await self.handle_image_message(sub_message)
if ret_seg:
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.IMAGE,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
seg_message.append(ret_seg)
logger.debug(f"图片处理成功,添加到消息段")
else:
logger.warning("image处理失败")
logger.debug(f"图片消息段处理完成")
case RealMessageType.record:
ret_seg = await self.handle_record_message(sub_message)
if ret_seg:
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.RECORD,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
seg_message.clear()
seg_message.append(ret_seg)
break # 使得消息只有record消息
else:
logger.warning("record处理失败或不支持")
case RealMessageType.video:
ret_seg = await self.handle_video_message(sub_message)
if ret_seg:
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.VIDEO,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
seg_message.append(ret_seg)
else:
logger.warning("video处理失败")
case RealMessageType.at:
ret_seg = await self.handle_at_message(
sub_message,
raw_message.get("self_id"),
raw_message.get("group_id"),
)
if ret_seg:
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.AT,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
seg_message.append(ret_seg)
else:
logger.warning("at处理失败")
case RealMessageType.rps:
ret_seg = await self.handle_rps_message(sub_message)
if ret_seg:
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.RPS,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
seg_message.append(ret_seg)
else:
logger.warning("rps处理失败")
case RealMessageType.dice:
ret_seg = await self.handle_dice_message(sub_message)
if ret_seg:
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.DICE,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
seg_message.append(ret_seg)
else:
logger.warning("dice处理失败")
case RealMessageType.shake:
ret_seg = await self.handle_shake_message(sub_message)
if ret_seg:
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.SHAKE,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
seg_message.append(ret_seg)
else:
logger.warning("shake处理失败")
case RealMessageType.share:
print("\n\n哦哦哦噢噢噢哦哦你收到了一个超级无敌SHARE消息快速速把你刚刚收到的消息截图发到MoFox-Bot群里\n\n")
logger.warning("暂时不支持链接解析")
case RealMessageType.forward:
messages = await self._get_forward_message(sub_message)
if not messages:
logger.warning("转发消息内容为空或获取失败")
return None
ret_seg = await self.handle_forward_message(messages)
if ret_seg:
seg_message.append(ret_seg)
else:
logger.warning("转发消息处理失败")
case RealMessageType.node:
print("\n\n哦哦哦噢噢噢哦哦你收到了一个超级无敌NODE消息快速速把你刚刚收到的消息截图发到MoFox-Bot群里\n\n")
logger.warning("不支持转发消息节点解析")
case RealMessageType.json:
ret_seg = await self.handle_json_message(sub_message)
if ret_seg:
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.JSON,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
seg_message.append(ret_seg)
else:
logger.warning("json处理失败")
case _:
logger.warning(f"未知消息类型: {sub_message_type}")
logger.debug(f"handle_real_message完成处理了{len(real_message)}个消息段,生成了{len(seg_message)}个seg")
return seg_message
async def handle_text_message(self, raw_message: dict) -> Seg:
"""
处理纯文本信息
Parameters:
raw_message: dict: 原始消息
Returns:
seg_data: Seg: 处理后的消息段
"""
message_data: dict = raw_message.get("data")
plain_text: str = message_data.get("text")
return Seg(type="text", data=plain_text)
async def handle_face_message(self, raw_message: dict) -> Seg | None:
"""
处理表情消息
Parameters:
raw_message: dict: 原始消息
Returns:
seg_data: Seg: 处理后的消息段
"""
message_data: dict = raw_message.get("data")
face_raw_id: str = str(message_data.get("id"))
if face_raw_id in qq_face:
face_content: str = qq_face.get(face_raw_id)
return Seg(type="text", data=face_content)
else:
logger.warning(f"不支持的表情:{face_raw_id}")
return None
    async def handle_image_message(self, raw_message: dict) -> Seg | None:
        """
        Handle picture and sticker (emoji) messages.

        Parameters:
            raw_message: dict: raw Napcat message
        Returns:
            Seg | None: "image" Seg for ordinary pictures, "emoji" Seg for
            stickers, or None on download failure / unsupported sub_type.
        """
        message_data: dict = raw_message.get("data")
        # sub_type 0 is treated as a normal picture; 4 and 9 are rejected;
        # everything else is treated as a sticker — presumably Napcat's
        # sub_type convention, TODO confirm.
        image_sub_type = message_data.get("sub_type")
        try:
            logger.debug(f"开始下载图片: {message_data.get('url')}")
            # Download the image and get its base64 representation.
            image_base64 = await get_image_base64(message_data.get("url"))
            logger.debug(f"图片下载成功,大小: {len(image_base64)} 字符")
        except Exception as e:
            logger.error(f"图片消息处理失败: {str(e)}")
            return None
        if image_sub_type == 0:
            """这部分认为是图片"""
            return Seg(type="image", data=image_base64)
        elif image_sub_type not in [4, 9]:
            """这部分认为是表情包"""
            return Seg(type="emoji", data=image_base64)
        else:
            logger.warning(f"不支持的图片子类型:{image_sub_type}")
            return None
    async def handle_at_message(self, raw_message: dict, self_id: int, group_id: int) -> Seg | None:
        # sourcery skip: use-named-expression
        """
        Handle an @ (mention) message.

        Parameters:
            raw_message: dict: raw Napcat message
            self_id: int: the bot's own QQ id
            group_id: int: the group the message came from
        Returns:
            Seg | None: a text Seg of the form "@<nickname:user_id>", or
            None when the data segment is missing or the member/self info
            cannot be fetched.
        """
        message_data: dict = raw_message.get("data")
        if message_data:
            qq_id = message_data.get("qq")
            # Compare as strings since Napcat may deliver the id as str or int.
            if str(self_id) == str(qq_id):
                logger.debug("机器人被at")
                self_info: dict = await get_self_info(self.get_server_connection())
                if self_info:
                    return Seg(type="text", data=f"@<{self_info.get('nickname')}:{self_info.get('user_id')}>")
                else:
                    return None
            else:
                member_info: dict = await get_member_info(self.get_server_connection(), group_id=group_id, user_id=qq_id)
                if member_info:
                    return Seg(type="text", data=f"@<{member_info.get('nickname')}:{member_info.get('user_id')}>")
                else:
                    return None
        # NOTE: falls through and implicitly returns None when "data" is missing.
async def handle_record_message(self, raw_message: dict) -> Seg | None:
"""
处理语音消息
Parameters:
raw_message: dict: 原始消息
Returns:
seg_data: Seg: 处理后的消息段
"""
message_data: dict = raw_message.get("data")
file: str = message_data.get("file")
if not file:
logger.warning("语音消息缺少文件信息")
return None
try:
record_detail = await get_record_detail(self.get_server_connection(), file)
if not record_detail:
logger.warning("获取语音消息详情失败")
return None
audio_base64: str = record_detail.get("base64")
except Exception as e:
logger.error(f"语音消息处理失败: {str(e)}")
return None
if not audio_base64:
logger.error("语音消息处理失败,未获取到音频数据")
return None
return Seg(type="voice", data=audio_base64)
    async def handle_video_message(self, raw_message: dict) -> Seg | None:
        """
        Handle a video message.

        Prefers a local file path (read directly from disk) and falls back
        to downloading from the message's URL.  On success returns a
        "video" Seg whose data dict contains the base64 payload, filename
        and size in MB; returns None on any failure.

        Parameters:
            raw_message: dict: raw Napcat message
        Returns:
            Seg | None: processed video segment, or None
        """
        message_data: dict = raw_message.get("data")
        # Verbose debug output — video payload shapes vary between clients.
        logger.debug(f"视频消息原始数据: {raw_message}")
        logger.debug(f"视频消息数据: {message_data}")
        # QQ video messages may carry either "url" or a file-path field
        # (observed under both "filePath" and "file_path").
        video_url = message_data.get("url")
        file_path = message_data.get("filePath") or message_data.get("file_path")
        logger.info(f"视频URL: {video_url}")
        logger.info(f"视频文件路径: {file_path}")
        # Prefer the local file path; fall back to the URL.
        video_source = file_path if file_path else video_url
        if not video_source:
            logger.warning("视频消息缺少URL或文件路径信息")
            logger.warning(f"完整消息数据: {message_data}")
            return None
        try:
            # Local-file branch: only taken when the path actually exists.
            if file_path and Path(file_path).exists():
                logger.info(f"使用本地视频文件: {file_path}")
                # Read the file directly from disk.
                with open(file_path, "rb") as f:
                    video_data = f.read()
                # Base64-encode the video bytes for transport.
                video_base64 = base64.b64encode(video_data).decode('utf-8')
                logger.info(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB")
                # Return a dict payload with the details downstream needs.
                return Seg(type="video", data={
                    "base64": video_base64,
                    "filename": Path(file_path).name,
                    "size_mb": len(video_data) / (1024 * 1024)
                })
            elif video_url:
                logger.info(f"使用视频URL下载: {video_url}")
                # Download via the shared video_handler downloader.
                video_downloader = get_video_downloader()
                download_result = await video_downloader.download_video(video_url)
                if not download_result["success"]:
                    logger.warning(f"视频下载失败: {download_result.get('error', '未知错误')}")
                    logger.warning(f"失败的URL: {video_url}")
                    return None
                # Base64-encode the downloaded bytes for transport.
                video_base64 = base64.b64encode(download_result["data"]).decode('utf-8')
                logger.info(f"视频下载成功,大小: {len(download_result['data']) / (1024 * 1024):.2f} MB")
                # Return a dict payload with the details downstream needs.
                return Seg(type="video", data={
                    "base64": video_base64,
                    "filename": download_result.get("filename", "video.mp4"),
                    "size_mb": len(download_result["data"]) / (1024 * 1024),
                    "url": video_url
                })
            else:
                logger.warning("既没有有效的本地文件路径也没有有效的视频URL")
                return None
        except Exception as e:
            logger.error(f"视频消息处理失败: {str(e)}")
            logger.error(f"视频源: {video_source}")
            return None
    async def handle_reply_message(self, raw_message: dict) -> List[Seg] | None:
        # sourcery skip: move-assign-in-block, use-named-expression
        """
        Handle a reply (quote) message.

        Fetches the quoted message by id, re-parses it through
        handle_real_message, and wraps it as
        "[回复<nickname:id> ...quoted segs... ],说:" text segments.
        Returns None when the data segment is missing or the quoted
        message cannot be fetched.
        """
        raw_message_data: dict = raw_message.get("data")
        message_id: int = None
        if raw_message_data:
            message_id = raw_message_data.get("id")
        else:
            return None
        message_detail: dict = await get_message_detail(self.get_server_connection(), message_id)
        if not message_detail:
            logger.warning("获取被引用的消息详情失败")
            return None
        # in_reply=True lets the recursive parse know it is inside a quote.
        reply_message = await self.handle_real_message(message_detail, in_reply=True)
        if reply_message is None:
            reply_message = [Seg(type="text", data="(获取发言内容失败)")]
        # NOTE(review): assumes message_detail always carries a "sender"
        # dict — a missing key would raise AttributeError; confirm upstream.
        sender_info: dict = message_detail.get("sender")
        sender_nickname: str = sender_info.get("nickname")
        sender_id: str = sender_info.get("user_id")
        seg_message: List[Seg] = []
        if not sender_nickname:
            logger.warning("无法获取被引用的人的昵称,返回默认值")
            seg_message.append(Seg(type="text", data="[回复 未知用户:"))
        else:
            seg_message.append(Seg(type="text", data=f"[回复<{sender_nickname}:{sender_id}>"))
        seg_message += reply_message
        seg_message.append(Seg(type="text", data="],说:"))
        return seg_message
async def handle_forward_message(self, message_list: list) -> Seg | None:
"""
递归处理转发消息,并按照动态方式确定图片处理方式
Parameters:
message_list: list: 转发消息列表
"""
handled_message, image_count = await self._handle_forward_message(
message_list, 0
)
handled_message: Seg
image_count: int
if not handled_message:
return None
processed_message: Seg
if image_count < 5 and image_count > 0:
# 处理图片数量小于5的情况此时解析图片为base64
logger.info("图片数量小于5开始解析图片为base64")
processed_message = await self._recursive_parse_image_seg(
handled_message, True
)
elif image_count > 0:
logger.info("图片数量大于等于5开始解析图片为占位符")
# 处理图片数量大于等于5的情况此时解析图片为占位符
processed_message = await self._recursive_parse_image_seg(
handled_message, False
)
else:
# 处理没有图片的情况,此时直接返回
logger.info("没有图片,直接返回")
processed_message = handled_message
# 添加转发消息提示
forward_hint = Seg(type="text", data="这是一条转发消息:\n")
return Seg(type="seglist", data=[forward_hint, processed_message])
async def handle_dice_message(self, raw_message: dict) -> Seg:
message_data: dict = raw_message.get("data",{})
res = message_data.get("result","")
return Seg(type="text", data=f"[扔了一个骰子,点数是{res}]")
async def handle_shake_message(self, raw_message: dict) -> Seg:
return Seg(type="text", data="[向你发送了窗口抖动,现在你的屏幕猛烈地震了一下!]")
async def handle_json_message(self, raw_message: dict) -> Seg:
message_data: str = raw_message.get("data","").get("data","")
res = json.loads(message_data)
return Seg(type="json", data=res)
async def handle_rps_message(self, raw_message: dict) -> Seg:
message_data: dict = raw_message.get("data",{})
res = message_data.get("result","")
if res == "1":
shape = ""
elif res == "2":
shape = "剪刀"
else:
shape = "石头"
return Seg(type="text", data=f"[发送了一个魔法猜拳表情,结果是:{shape}]")
async def _recursive_parse_image_seg(self, seg_data: Seg, to_image: bool) -> Seg:
# sourcery skip: merge-else-if-into-elif
if to_image:
if seg_data.type == "seglist":
new_seg_list = []
for i_seg in seg_data.data:
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
new_seg_list.append(parsed_seg)
return Seg(type="seglist", data=new_seg_list)
elif seg_data.type == "image":
image_url = seg_data.data
try:
encoded_image = await get_image_base64(image_url)
except Exception as e:
logger.error(f"图片处理失败: {str(e)}")
return Seg(type="text", data="[图片]")
return Seg(type="image", data=encoded_image)
elif seg_data.type == "emoji":
image_url = seg_data.data
try:
encoded_image = await get_image_base64(image_url)
except Exception as e:
logger.error(f"图片处理失败: {str(e)}")
return Seg(type="text", data="[表情包]")
return Seg(type="emoji", data=encoded_image)
else:
logger.info(f"不处理类型: {seg_data.type}")
return seg_data
else:
if seg_data.type == "seglist":
new_seg_list = []
for i_seg in seg_data.data:
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
new_seg_list.append(parsed_seg)
return Seg(type="seglist", data=new_seg_list)
elif seg_data.type == "image":
return Seg(type="text", data="[图片]")
elif seg_data.type == "emoji":
return Seg(type="text", data="[动画表情]")
else:
logger.info(f"不处理类型: {seg_data.type}")
return seg_data
    async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[Seg, int] | Tuple[None, int]:
        # sourcery skip: low-code-quality
        """
        Recursively flatten the body of a forwarded message.

        Parameters:
            message_list: list: forwarded sub-messages — the top level comes
                from the "messages" field, nested levels from "content"
            layer: int: current nesting depth (0 = top level)
        Returns:
            Tuple[Seg, int]: the flattened seglist and the total image count,
            or (None, 0) when message_list is None.
        """
        seg_list: List[Seg] = []
        image_count = 0
        if message_list is None:
            return None, 0
        for sub_message in message_list:
            sub_message: dict
            sender_info: dict = sub_message.get("sender")
            user_nickname: str = sender_info.get("nickname", "QQ用户")
            user_nickname_str = f"{user_nickname}】:"
            break_seg = Seg(type="text", data="\n")
            # Only the first segment of each sub-message is inspected; the
            # rest (if any) are ignored by design of this flattener.
            message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message")
            if not message_of_sub_message_list:
                logger.warning("转发消息内容为空")
                continue
            message_of_sub_message = message_of_sub_message_list[0]
            if message_of_sub_message.get("type") == RealMessageType.forward:
                # Depth limit: beyond 3 levels nested forwards collapse to
                # a plain "【转发消息】" marker instead of recursing.
                if layer >= 3:
                    full_seg_data = Seg(
                        type="text",
                        data=("--" * layer) + f"{user_nickname}】:【转发消息】\n",
                    )
                else:
                    sub_message_data = message_of_sub_message.get("data")
                    if not sub_message_data:
                        continue
                    contents = sub_message_data.get("content")
                    seg_data, count = await self._handle_forward_message(contents, layer + 1)
                    image_count += count
                    head_tip = Seg(
                        type="text",
                        data=("--" * layer) + f"{user_nickname}】: 合并转发消息内容:\n",
                    )
                    full_seg_data = Seg(type="seglist", data=[head_tip, seg_data])
                seg_list.append(full_seg_data)
            elif message_of_sub_message.get("type") == RealMessageType.text:
                sub_message_data = message_of_sub_message.get("data")
                if not sub_message_data:
                    continue
                text_message = sub_message_data.get("text")
                seg_data = Seg(type="text", data=text_message)
                data_list: List[Any] = []
                # "--" per nesting level indents nested content visually.
                if layer > 0:
                    data_list = [
                        Seg(type="text", data=("--" * layer) + user_nickname_str),
                        seg_data,
                        break_seg,
                    ]
                else:
                    data_list = [
                        Seg(type="text", data=user_nickname_str),
                        seg_data,
                        break_seg,
                    ]
                seg_list.append(Seg(type="seglist", data=data_list))
            elif message_of_sub_message.get("type") == RealMessageType.image:
                image_count += 1
                image_data = message_of_sub_message.get("data")
                sub_type = image_data.get("sub_type")
                image_url = image_data.get("url")
                data_list: List[Any] = []
                # sub_type 0 = picture, everything else treated as sticker —
                # matches handle_image_message's convention.
                if sub_type == 0:
                    seg_data = Seg(type="image", data=image_url)
                else:
                    seg_data = Seg(type="emoji", data=image_url)
                if layer > 0:
                    data_list = [
                        Seg(type="text", data=("--" * layer) + user_nickname_str),
                        seg_data,
                        break_seg,
                    ]
                else:
                    data_list = [
                        Seg(type="text", data=user_nickname_str),
                        seg_data,
                        break_seg,
                    ]
                full_seg_data = Seg(type="seglist", data=data_list)
                seg_list.append(full_seg_data)
        return Seg(type="seglist", data=seg_list), image_count
    async def _get_forward_message(self, raw_message: dict) -> Dict[str, Any] | None:
        """
        Fetch the full content of a forwarded message from Napcat.

        Sends a get_forward_msg action over the websocket, correlating the
        response via a random echo uuid, and returns the "messages" list
        from the response data (None on any failure).
        """
        forward_message_data: Dict = raw_message.get("data")
        if not forward_message_data:
            logger.warning("转发消息内容为空")
            return None
        forward_message_id = forward_message_data.get("id")
        # Unique echo id so get_response can match the reply to this request.
        request_uuid = str(uuid.uuid4())
        payload = json.dumps(
            {
                "action": "get_forward_msg",
                "params": {"message_id": forward_message_id},
                "echo": request_uuid,
            }
        )
        try:
            connection = self.get_server_connection()
            if not connection:
                logger.error("没有可用的 WebSocket 连接")
                return None
            await connection.send(payload)
            response: dict = await get_response(request_uuid)
        except TimeoutError:
            logger.error("获取转发消息超时")
            return None
        except Exception as e:
            logger.error(f"获取转发消息失败: {str(e)}")
            return None
        # Truncate the log line; forwarded payloads can be very large.
        logger.debug(
            f"转发消息原始格式:{json.dumps(response)[:80]}..."
            if len(json.dumps(response)) > 80
            else json.dumps(response)
        )
        response_data: Dict = response.get("data")
        if not response_data:
            logger.warning("转发消息内容为空或获取失败")
            return None
        return response_data.get("messages")
    async def _send_buffered_message(self, session_id: str, merged_text: str, original_event: Dict[str, Any]):
        """
        Send a buffered, merged text message on to MaiBot.

        Rebuilds a MessageBase from the original event's metadata, replacing
        the segment payload with the merged text and minting a new
        "buffered-..." message id so it is distinguishable downstream.
        """
        try:
            # Pull the metadata captured when the original event was buffered.
            message_info = original_event.get("message_info")
            raw_message = original_event.get("raw_message")
            if not message_info or not raw_message:
                logger.error("缓冲消息缺少必要信息")
                return
            # Wrap the merged text as a single-element seglist Seg.
            from maim_message import Seg
            merged_seg = Seg(type="text", data=merged_text)
            submit_seg = Seg(type="seglist", data=[merged_seg])
            # New message id derived from the original id plus a ms timestamp.
            import time
            new_message_id = f"buffered-{message_info.message_id}-{int(time.time() * 1000)}"
            # Clone the message info with the new id and current time.
            from maim_message import BaseMessageInfo, MessageBase
            buffered_message_info = BaseMessageInfo(
                platform=message_info.platform,
                message_id=new_message_id,
                time=time.time(),
                user_info=message_info.user_info,
                group_info=message_info.group_info,
                template_info=message_info.template_info,
                format_info=message_info.format_info,
                additional_config=message_info.additional_config,
            )
            # Assemble the outgoing MessageBase.
            message_base = MessageBase(
                message_info=buffered_message_info,
                message_segment=submit_seg,
                raw_message=raw_message.get("raw_message", ""),
            )
            logger.info(f"发送缓冲合并消息到Maibot处理: {session_id}")
            await message_send_instance.message_send(message_base)
        except Exception as e:
            logger.error(f"发送缓冲消息失败: {e}", exc_info=True)
message_handler = MessageHandler()

View File

@@ -0,0 +1,32 @@
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from maim_message import MessageBase, Router
class MessageSending:
    """
    Forwards assembled MessageBase objects to MaiBot via the shared router.
    """

    # Shared Router instance; expected to be assigned during startup,
    # before the first message_send call.
    maibot_router: Router = None

    def __init__(self):
        pass

    async def message_send(self, message_base: MessageBase) -> bool:
        """
        Send a message to MaiBot.

        Parameters:
            message_base: MessageBase: target and content of the message
        Returns:
            bool: True on success, False on failure.  Bug fix: the
            exception path previously fell off the end and returned None
            despite the declared bool return type; it now returns False
            (backward-compatible — both are falsy).
        """
        try:
            send_status = await self.maibot_router.send_message(message_base)
            if not send_status:
                raise RuntimeError("可能是路由未正确配置或连接异常")
            return send_status
        except Exception as e:
            logger.error(f"发送消息失败: {str(e)}")
            logger.error("请检查与MaiBot之间的连接")
            return False
message_send_instance = MessageSending()

View File

@@ -0,0 +1,50 @@
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from ..config import global_config
import time
import asyncio
from . import MetaEventType
class MetaEventHandler:
    """
    Handles OneBot meta events (lifecycle and heartbeat) and runs a
    watchdog that detects missed heartbeats.
    """

    def __init__(self):
        # Expected heartbeat interval (seconds); refined from each
        # heartbeat event's own "interval" field (delivered in ms).
        self.interval = global_config.napcat_server.heartbeat_interval
        self._interval_checking = False
        # Timestamp of the most recent connect/heartbeat event.  Bug fix:
        # initialized here so the watchdog never reads an unset attribute.
        self.last_heart_beat = 0.0

    async def handle_meta_event(self, message: dict) -> None:
        """Dispatch a meta event dict by its meta_event_type."""
        event_type = message.get("meta_event_type")
        self_id = message.get("self_id")
        if event_type == MetaEventType.lifecycle:
            if message.get("sub_type") == MetaEventType.Lifecycle.connect:
                self.last_heart_beat = time.time()
                logger.info(f"Bot {self_id} 连接成功")
                # Avoid spawning duplicate watchdogs on reconnect.
                if not self._interval_checking:
                    asyncio.create_task(self.check_heartbeat(self_id))
        elif event_type == MetaEventType.heartbeat:
            if message["status"].get("online") and message["status"].get("good"):
                # Record the heartbeat *before* starting the watchdog so it
                # never compares against a stale timestamp.
                self.last_heart_beat = time.time()
                self.interval = message.get("interval") / 1000
                if not self._interval_checking:
                    # Bug fix: this call previously omitted the required
                    # `id` argument and raised TypeError.
                    asyncio.create_task(self.check_heartbeat(self_id))
            else:
                logger.warning(f"Bot {self_id} Napcat 端异常!")

    async def check_heartbeat(self, id: int) -> None:
        """Watchdog loop: log an error and exit if heartbeats stall."""
        self._interval_checking = True
        try:
            while True:
                if time.time() - self.last_heart_beat > self.interval * 2:
                    logger.error(f"Bot {id} 可能发生了连接断开被下线或者Napcat卡死")
                    break
                logger.debug("心跳正常")
                await asyncio.sleep(self.interval)
        finally:
            # Bug fix: reset the flag so a later event can restart the
            # watchdog after it has stopped.
            self._interval_checking = False
meta_event_handler = MetaEventHandler()

View File

@@ -0,0 +1,556 @@
import time
import json
import asyncio
import websockets as Server
from typing import Tuple, Optional
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from ..config import global_config
from ..config.features_config import features_manager
from ..database import BanUser, db_manager, is_identical
from . import NoticeType, ACCEPT_FORMAT
from .message_sending import message_send_instance
from .message_handler import message_handler
from maim_message import FormatInfo, UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase
from ..websocket_manager import websocket_manager
from ..utils import (
get_group_info,
get_member_info,
get_self_info,
get_stranger_info,
read_ban_list,
)
from ...CONSTS import PLUGIN_NAME
notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=100)
unsuccessful_notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=3)
class NoticeHandler:
banned_list: list[BanUser] = [] # 当前仍在禁言中的用户列表
lifted_list: list[BanUser] = [] # 已经自然解除禁言
def __init__(self):
self.server_connection: Server.ServerConnection | None = None
self.last_poke_time: float = 0.0 # 记录最后一次针对机器人的戳一戳时间
    async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
        """Attach the Napcat websocket connection and start background tasks.

        Waits until the connection reaches OPEN, loads the persisted ban
        lists, then spawns the auto-lift detector, the notice sender and
        the natural-lift handler as long-lived tasks.
        """
        self.server_connection = server_connection
        while self.server_connection.state != Server.State.OPEN:
            await asyncio.sleep(0.5)
        self.banned_list, self.lifted_list = await read_ban_list(self.server_connection)
        asyncio.create_task(self.auto_lift_detect())
        asyncio.create_task(self.send_notice())
        asyncio.create_task(self.handle_natural_lift())
def get_server_connection(self) -> Server.ServerConnection:
"""获取当前的服务器连接"""
# 优先使用直接设置的连接,否则从 websocket_manager 获取
if self.server_connection:
return self.server_connection
return websocket_manager.get_connection()
    def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None:
        """
        Record a mute in self.banned_list and persist it.

        A whole-group mute is represented by user_id == 0 with
        lift_time == -1 (no automatic expiry).  An identical existing
        record is replaced so its lift_time is refreshed.
        """
        if user_id is None:
            user_id = 0  # 0 stands for a whole-group mute
            lift_time = -1
        ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time)
        for record in self.banned_list:
            if is_identical(record, ban_record):
                # Mutating during iteration is safe here: we return at once.
                self.banned_list.remove(record)
                self.banned_list.append(ban_record)
                db_manager.create_ban_record(ban_record)  # acts as an update
                return
        self.banned_list.append(ban_record)
        db_manager.create_ban_record(ban_record)  # persist the new record
    def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None:
        """
        Record that a mute was lifted: append the record to
        self.lifted_list and delete it from the database
        (user_id defaults to 0, i.e. a whole-group mute).

        NOTE(review): the original docstring claimed this *removes*
        entries from a lifted-group list, but the code appends —
        presumably feeding the natural-lift pipeline; confirm against
        handle_natural_lift.
        """
        if user_id is None:
            user_id = 0  # 0 stands for a whole-group mute
        ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1)
        self.lifted_list.append(ban_record)
        db_manager.delete_ban_record(ban_record)  # remove the persisted record
async def handle_notice(self, raw_message: dict) -> None:
notice_type = raw_message.get("notice_type")
# message_time: int = raw_message.get("time")
message_time: float = time.time() # 应可乐要求现在是float了
group_id = raw_message.get("group_id")
user_id = raw_message.get("user_id")
target_id = raw_message.get("target_id")
handled_message: Seg = None
user_info: UserInfo = None
system_notice: bool = False
match notice_type:
case NoticeType.friend_recall:
logger.info("好友撤回一条消息")
logger.info(f"撤回消息ID{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
logger.warning("暂时不支持撤回消息处理")
case NoticeType.group_recall:
logger.info("群内用户撤回一条消息")
logger.info(f"撤回消息ID{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
logger.warning("暂时不支持撤回消息处理")
case NoticeType.notify:
sub_type = raw_message.get("sub_type")
match sub_type:
case NoticeType.Notify.poke:
if features_manager.is_poke_enabled() and await message_handler.check_allow_to_chat(
user_id, group_id, False, False
):
logger.info("处理戳一戳消息")
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
else:
logger.warning("戳一戳消息被禁用,取消戳一戳处理")
case NoticeType.Notify.input_status:
from src.plugin_system.core.event_manager import event_manager
from ...event_types import NapcatEvent
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT,plugin_name=PLUGIN_NAME)
case _:
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
case NoticeType.group_ban:
sub_type = raw_message.get("sub_type")
match sub_type:
case NoticeType.GroupBan.ban:
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
return None
logger.info("处理群禁言")
handled_message, user_info = await self.handle_ban_notify(raw_message, group_id)
system_notice = True
case NoticeType.GroupBan.lift_ban:
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
return None
logger.info("处理解除群禁言")
handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id)
system_notice = True
case _:
logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}")
case _:
logger.warning(f"不支持的notice类型: {notice_type}")
return None
if not handled_message or not user_info:
logger.warning("notice处理失败或不支持")
return None
group_info: GroupInfo = None
if group_id:
fetched_group_info = await get_group_info(self.get_server_connection(), group_id)
group_name: str = None
if fetched_group_info:
group_name = fetched_group_info.get("group_name")
else:
logger.warning("无法获取notice消息所在群的名称")
group_info = GroupInfo(
platform=global_config.maibot_server.platform_name,
group_id=group_id,
group_name=group_name,
)
message_info: BaseMessageInfo = BaseMessageInfo(
platform=global_config.maibot_server.platform_name,
message_id="notice",
time=message_time,
user_info=user_info,
group_info=group_info,
template_info=None,
format_info=FormatInfo(
content_format=["text", "notify"],
accept_format=ACCEPT_FORMAT,
),
additional_config={"target_id": target_id}, # 在这里塞了一个target_id方便mmc那边知道被戳的人是谁
)
message_base: MessageBase = MessageBase(
message_info=message_info,
message_segment=handled_message,
raw_message=json.dumps(raw_message),
)
if system_notice:
await self.put_notice(message_base)
else:
logger.info("发送到Maibot处理通知信息")
await message_send_instance.message_send(message_base)
async def handle_poke_notify(
self, raw_message: dict, group_id: int, user_id: int
) -> Tuple[Seg | None, UserInfo | None]:
# sourcery skip: merge-comparisons, merge-duplicate-blocks, remove-redundant-if, remove-unnecessary-else, swap-if-else-branches
self_info: dict = await get_self_info(self.get_server_connection())
if not self_info:
logger.error("自身信息获取失败")
return None, None
self_id = raw_message.get("self_id")
target_id = raw_message.get("target_id")
# 防抖检查:如果是针对机器人的戳一戳,检查防抖时间
if self_id == target_id:
current_time = time.time()
debounce_seconds = features_manager.get_config().poke_debounce_seconds
if self.last_poke_time > 0:
time_diff = current_time - self.last_poke_time
if time_diff < debounce_seconds:
logger.info(f"戳一戳防抖:用户 {user_id} 的戳一戳被忽略(距离上次戳一戳 {time_diff:.2f} 秒)")
return None, None
# 记录这次戳一戳的时间
self.last_poke_time = current_time
target_name: str = None
raw_info: list = raw_message.get("raw_info")
if group_id:
user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
else:
user_qq_info: dict = await get_stranger_info(self.get_server_connection(), user_id)
if user_qq_info:
user_name = user_qq_info.get("nickname")
user_cardname = user_qq_info.get("card")
else:
user_name = "QQ用户"
user_cardname = "QQ用户"
logger.info("无法获取戳一戳对方的用户昵称")
# 计算Seg
if self_id == target_id:
display_name = ""
target_name = self_info.get("nickname")
elif self_id == user_id:
# 让ada不发送麦麦戳别人的消息
return None, None
else:
# 如果配置为忽略不是针对自己的戳一戳则直接返回None
if features_manager.is_non_self_poke_ignored():
logger.info("忽略不是针对自己的戳一戳消息")
return None, None
# 老实说这一步判定没啥意义,毕竟私聊是没有其他人之间的戳一戳,但是感觉可以有这个判定来强限制群聊环境
if group_id:
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, target_id)
if fetched_member_info:
target_name = fetched_member_info.get("nickname")
else:
target_name = "QQ用户"
logger.info("无法获取被戳一戳方的用户昵称")
display_name = user_name
else:
return None, None
first_txt: str = "戳了戳"
second_txt: str = ""
try:
first_txt = raw_info[2].get("txt", "戳了戳")
second_txt = raw_info[4].get("txt", "")
except Exception as e:
logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本")
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=user_id,
user_nickname=user_name,
user_cardname=user_cardname,
)
seg_data: Seg = Seg(
type="text",
data=f"{display_name}{first_txt}{target_name}{second_txt}这是QQ的一个功能用于提及某人但没那么明显",
)
return seg_data, user_info
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
if not group_id:
logger.error("群ID不能为空无法处理禁言通知")
return None, None
# 计算user_info
operator_id = raw_message.get("operator_id")
operator_nickname: str = None
operator_cardname: str = None
member_info: dict = await get_member_info(self.get_server_connection(), group_id, operator_id)
if member_info:
operator_nickname = member_info.get("nickname")
operator_cardname = member_info.get("card")
else:
logger.warning("无法获取禁言执行者的昵称,消息可能会无效")
operator_nickname = "QQ用户"
operator_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=operator_id,
user_nickname=operator_nickname,
user_cardname=operator_cardname,
)
# 计算Seg
user_id = raw_message.get("user_id")
banned_user_info: UserInfo = None
user_nickname: str = "QQ用户"
user_cardname: str = None
sub_type: str = None
duration = raw_message.get("duration")
if duration is None:
logger.error("禁言时长不能为空,无法处理禁言通知")
return None, None
if user_id == 0: # 为全体禁言
sub_type: str = "whole_ban"
self._ban_operation(group_id)
else: # 为单人禁言
# 获取被禁言人的信息
sub_type: str = "ban"
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
if fetched_member_info:
user_nickname = fetched_member_info.get("nickname")
user_cardname = fetched_member_info.get("card")
banned_user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
)
self._ban_operation(group_id, user_id, int(time.time() + duration))
seg_data: Seg = Seg(
type="notify",
data={
"sub_type": sub_type,
"duration": duration,
"banned_user_info": banned_user_info.to_dict() if banned_user_info else None,
},
)
return seg_data, operator_info
async def handle_lift_ban_notify(
self, raw_message: dict, group_id: int
) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
if not group_id:
logger.error("群ID不能为空无法处理解除禁言通知")
return None, None
# 计算user_info
operator_id = raw_message.get("operator_id")
operator_nickname: str = None
operator_cardname: str = None
member_info: dict = await get_member_info(self.get_server_connection(), group_id, operator_id)
if member_info:
operator_nickname = member_info.get("nickname")
operator_cardname = member_info.get("card")
else:
logger.warning("无法获取解除禁言执行者的昵称,消息可能会无效")
operator_nickname = "QQ用户"
operator_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=operator_id,
user_nickname=operator_nickname,
user_cardname=operator_cardname,
)
# 计算Seg
sub_type: str = None
user_nickname: str = "QQ用户"
user_cardname: str = None
lifted_user_info: UserInfo = None
user_id = raw_message.get("user_id")
if user_id == 0: # 全体禁言解除
sub_type = "whole_lift_ban"
self._lift_operation(group_id)
else: # 单人禁言解除
sub_type = "lift_ban"
# 获取被解除禁言人的信息
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
if fetched_member_info:
user_nickname = fetched_member_info.get("nickname")
user_cardname = fetched_member_info.get("card")
else:
logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效")
lifted_user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
)
self._lift_operation(group_id, user_id)
seg_data: Seg = Seg(
type="notify",
data={
"sub_type": sub_type,
"lifted_user_info": lifted_user_info.to_dict() if lifted_user_info else None,
},
)
return seg_data, operator_info
    async def put_notice(self, message_base: MessageBase) -> None:
        """
        Queue a processed notice message for delivery.

        The message is DROPPED (with a warning) when either the main
        notice queue or the retry queue is full — a full retry queue
        signals repeated send failures.
        """
        if notice_queue.full() or unsuccessful_notice_queue.full():
            logger.warning("通知队列已满,可能是多次发送失败,消息丢弃")
        else:
            await notice_queue.put(message_base)
    async def handle_natural_lift(self) -> None:
        """
        Background loop draining self.lifted_list.

        For each lifted ban record, deletes the persisted record, builds a
        synthetic "lift_ban" notice (operator is None since nobody
        actively lifted it) and queues it via put_notice.  Polls every 5s
        when idle; throttles to one notice per 0.5s when busy.
        """
        while True:
            if len(self.lifted_list) != 0:
                lift_record = self.lifted_list.pop()
                group_id = lift_record.group_id
                user_id = lift_record.user_id
                db_manager.delete_ban_record(lift_record)  # remove persisted mute record
                seg_message: Seg = await self.natural_lift(group_id, user_id)
                fetched_group_info = await get_group_info(self.get_server_connection(), group_id)
                group_name: str = None
                if fetched_group_info:
                    group_name = fetched_group_info.get("group_name")
                else:
                    logger.warning("无法获取notice消息所在群的名称")
                group_info = GroupInfo(
                    platform=global_config.maibot_server.platform_name,
                    group_id=group_id,
                    group_name=group_name,
                )
                message_info: BaseMessageInfo = BaseMessageInfo(
                    platform=global_config.maibot_server.platform_name,
                    message_id="notice",
                    time=time.time(),
                    user_info=None,  # a natural lift has no operator
                    group_info=group_info,
                    template_info=None,
                    format_info=None,
                )
                message_base: MessageBase = MessageBase(
                    message_info=message_info,
                    message_segment=seg_message,
                    raw_message=json.dumps(
                        {
                            "post_type": "notice",
                            "notice_type": "group_ban",
                            "sub_type": "lift_ban",
                            "group_id": group_id,
                            "user_id": user_id,
                            "operator_id": None,  # a natural lift has no operator
                        }
                    ),
                )
                await self.put_notice(message_base)
                await asyncio.sleep(0.5)  # pace queue processing
            else:
                await asyncio.sleep(5)  # idle poll every 5 seconds
    async def natural_lift(self, group_id: int, user_id: int) -> Seg | None:
        """
        Build the notify Seg for a mute that expired on its own.

        Returns a "lift_ban" notify Seg carrying the unmuted user's info,
        a "whole_lift_ban" Seg for user_id == 0 (defensive — whole-group
        mutes have lift_time == -1 and never expire automatically), or
        None when group_id is missing.
        """
        if not group_id:
            logger.error("群ID不能为空无法处理解除禁言通知")
            return None
        if user_id == 0:  # defensively handled; should be unreachable
            return Seg(
                type="notify",
                data={
                    "sub_type": "whole_lift_ban",
                    "lifted_user_info": None,
                },
            )
        user_nickname: str = "QQ用户"
        user_cardname: str = None
        fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
        if fetched_member_info:
            user_nickname = fetched_member_info.get("nickname")
            user_cardname = fetched_member_info.get("card")
        lifted_user_info: UserInfo = UserInfo(
            platform=global_config.maibot_server.platform_name,
            user_id=user_id,
            user_nickname=user_nickname,
            user_cardname=user_cardname,
        )
        return Seg(
            type="notify",
            data={
                "sub_type": "lift_ban",
                "lifted_user_info": lifted_user_info.to_dict(),
            },
        )
async def auto_lift_detect(self) -> None:
while True:
if len(self.banned_list) == 0:
await asyncio.sleep(5)
continue
for ban_record in self.banned_list:
if ban_record.user_id == 0 or ban_record.lift_time == -1:
continue
if ban_record.lift_time <= int(time.time()):
# 触发自然解除禁言
logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除")
self.lifted_list.append(ban_record)
self.banned_list.remove(ban_record)
await asyncio.sleep(5)
    async def send_notice(self) -> None:
        """
        Background loop delivering queued notice messages to MaiBot.

        Retries messages from unsuccessful_notice_queue first (with a 1s
        back-off after each attempt), then blocks on notice_queue for the
        next fresh notice.  Failed sends are pushed onto the retry queue.
        """
        while True:
            if not unsuccessful_notice_queue.empty():
                to_be_send: MessageBase = await unsuccessful_notice_queue.get()
                try:
                    send_status = await message_send_instance.message_send(to_be_send)
                    if send_status:
                        unsuccessful_notice_queue.task_done()
                    else:
                        # Re-queue for another retry attempt.
                        await unsuccessful_notice_queue.put(to_be_send)
                except Exception as e:
                    logger.error(f"发送通知消息失败: {str(e)}")
                    await unsuccessful_notice_queue.put(to_be_send)
                await asyncio.sleep(1)
                continue
            # No retries pending — block until a fresh notice arrives.
            to_be_send: MessageBase = await notice_queue.get()
            try:
                send_status = await message_send_instance.message_send(to_be_send)
                if send_status:
                    notice_queue.task_done()
                else:
                    await unsuccessful_notice_queue.put(to_be_send)
            except Exception as e:
                logger.error(f"发送通知消息失败: {str(e)}")
                await unsuccessful_notice_queue.put(to_be_send)
                await asyncio.sleep(1)
notice_handler = NoticeHandler()

View File

@@ -0,0 +1,250 @@
qq_face: dict = {
"0": "[表情:惊讶]",
"1": "[表情:撇嘴]",
"2": "[表情:色]",
"3": "[表情:发呆]",
"4": "[表情:得意]",
"5": "[表情:流泪]",
"6": "[表情:害羞]",
"7": "[表情:闭嘴]",
"8": "[表情:睡]",
"9": "[表情:大哭]",
"10": "[表情:尴尬]",
"11": "[表情:发怒]",
"12": "[表情:调皮]",
"13": "[表情:呲牙]",
"14": "[表情:微笑]",
"15": "[表情:难过]",
"16": "[表情:酷]",
"18": "[表情:抓狂]",
"19": "[表情:吐]",
"20": "[表情:偷笑]",
"21": "[表情:可爱]",
"22": "[表情:白眼]",
"23": "[表情:傲慢]",
"24": "[表情:饥饿]",
"25": "[表情:困]",
"26": "[表情:惊恐]",
"27": "[表情:流汗]",
"28": "[表情:憨笑]",
"29": "[表情:悠闲]",
"30": "[表情:奋斗]",
"31": "[表情:咒骂]",
"32": "[表情:疑问]",
"33": "[表情: 嘘]",
"34": "[表情:晕]",
"35": "[表情:折磨]",
"36": "[表情:衰]",
"37": "[表情:骷髅]",
"38": "[表情:敲打]",
"39": "[表情:再见]",
"41": "[表情:发抖]",
"42": "[表情:爱情]",
"43": "[表情:跳跳]",
"46": "[表情:猪头]",
"49": "[表情:拥抱]",
"53": "[表情:蛋糕]",
"56": "[表情:刀]",
"59": "[表情:便便]",
"60": "[表情:咖啡]",
"63": "[表情:玫瑰]",
"64": "[表情:凋谢]",
"66": "[表情:爱心]",
"67": "[表情:心碎]",
"74": "[表情:太阳]",
"75": "[表情:月亮]",
"76": "[表情:赞]",
"77": "[表情:踩]",
"78": "[表情:握手]",
"79": "[表情:胜利]",
"85": "[表情:飞吻]",
"86": "[表情:怄火]",
"89": "[表情:西瓜]",
"96": "[表情:冷汗]",
"97": "[表情:擦汗]",
"98": "[表情:抠鼻]",
"99": "[表情:鼓掌]",
"100": "[表情:糗大了]",
"101": "[表情:坏笑]",
"102": "[表情:左哼哼]",
"103": "[表情:右哼哼]",
"104": "[表情:哈欠]",
"105": "[表情:鄙视]",
"106": "[表情:委屈]",
"107": "[表情:快哭了]",
"108": "[表情:阴险]",
"109": "[表情:左亲亲]",
"110": "[表情:吓]",
"111": "[表情:可怜]",
"112": "[表情:菜刀]",
"114": "[表情:篮球]",
"116": "[表情:示爱]",
"118": "[表情:抱拳]",
"119": "[表情:勾引]",
"120": "[表情:拳头]",
"121": "[表情:差劲]",
"123": "[表情NO]",
"124": "[表情OK]",
"125": "[表情:转圈]",
"129": "[表情:挥手]",
"137": "[表情:鞭炮]",
"144": "[表情:喝彩]",
"146": "[表情:爆筋]",
"147": "[表情:棒棒糖]",
"169": "[表情:手枪]",
"171": "[表情:茶]",
"172": "[表情:眨眼睛]",
"173": "[表情:泪奔]",
"174": "[表情:无奈]",
"175": "[表情:卖萌]",
"176": "[表情:小纠结]",
"177": "[表情:喷血]",
"178": "[表情:斜眼笑]",
"179": "[表情doge]",
"181": "[表情:戳一戳]",
"182": "[表情:笑哭]",
"183": "[表情:我最美]",
"185": "[表情:羊驼]",
"187": "[表情:幽灵]",
"201": "[表情:点赞]",
"212": "[表情:托腮]",
"262": "[表情:脑阔疼]",
"263": "[表情:沧桑]",
"264": "[表情:捂脸]",
"265": "[表情:辣眼睛]",
"266": "[表情:哦哟]",
"267": "[表情:头秃]",
"268": "[表情:问号脸]",
"269": "[表情:暗中观察]",
"270": "[表情emm]",
"271": "[表情:吃 瓜]",
"272": "[表情:呵呵哒]",
"273": "[表情:我酸了]",
"277": "[表情:汪汪]",
"281": "[表情:无眼笑]",
"282": "[表情:敬礼]",
"283": "[表情:狂笑]",
"284": "[表情:面无表情]",
"285": "[表情:摸鱼]",
"286": "[表情:魔鬼笑]",
"287": "[表情:哦]",
"289": "[表情:睁眼]",
"293": "[表情:摸锦鲤]",
"294": "[表情:期待]",
"295": "[表情:拿到红包]",
"297": "[表情:拜谢]",
"298": "[表情:元宝]",
"299": "[表情:牛啊]",
"300": "[表情:胖三斤]",
"302": "[表情:左拜年]",
"303": "[表情:右拜年]",
"305": "[表情:右亲亲]",
"306": "[表情:牛气冲天]",
"307": "[表情:喵喵]",
"311": "[表情打call]",
"312": "[表情:变形]",
"314": "[表情:仔细分析]",
"317": "[表情:菜汪]",
"318": "[表情:崇拜]",
"319": "[表情: 比心]",
"320": "[表情:庆祝]",
"323": "[表情:嫌弃]",
"324": "[表情:吃糖]",
"325": "[表情:惊吓]",
"326": "[表情:生气]",
"332": "[表情:举牌牌]",
"333": "[表情:烟花]",
"334": "[表情:虎虎生威]",
"336": "[表情:豹富]",
"337": "[表情:花朵脸]",
"338": "[表情:我想开了]",
"339": "[表情:舔屏]",
"341": "[表情:打招呼]",
"342": "[表情酸Q]",
"343": "[表情:我方了]",
"344": "[表情:大怨种]",
"345": "[表情:红包多多]",
"346": "[表情:你真棒棒]",
"347": "[表情:大展宏兔]",
"349": "[表情:坚强]",
"350": "[表情:贴贴]",
"351": "[表情:敲敲]",
"352": "[表情:咦]",
"353": "[表情:拜托]",
"354": "[表情:尊嘟假嘟]",
"355": "[表情:耶]",
"356": "[表情666]",
"357": "[表情:裂开]",
"392": "[表情:龙年 快乐]",
"393": "[表情:新年中龙]",
"394": "[表情:新年大龙]",
"395": "[表情:略略略]",
"😊": "[表情:嘿嘿]",
"😌": "[表情:羞涩]",
"😚": "[ 表情:亲亲]",
"😓": "[表情:汗]",
"😰": "[表情:紧张]",
"😝": "[表情:吐舌]",
"😁": "[表情:呲牙]",
"😜": "[表情:淘气]",
"": "[表情:可爱]",
"😍": "[表情:花痴]",
"😔": "[表情:失落]",
"😄": "[表情:高兴]",
"😏": "[表情:哼哼]",
"😒": "[表情:不屑]",
"😳": "[表情:瞪眼]",
"😘": "[表情:飞吻]",
"😭": "[表情:大哭]",
"😱": "[表情:害怕]",
"😂": "[表情:激动]",
"💪": "[表情:肌肉]",
"👊": "[表情:拳头]",
"👍": "[表情 :厉害]",
"👏": "[表情:鼓掌]",
"👎": "[表情:鄙视]",
"🙏": "[表情:合十]",
"👌": "[表情:好的]",
"👆": "[表情:向上]",
"👀": "[表情:眼睛]",
"🍜": "[表情:拉面]",
"🍧": "[表情:刨冰]",
"🍞": "[表情:面包]",
"🍺": "[表情:啤酒]",
"🍻": "[表情:干杯]",
"": "[表情:咖啡]",
"🍎": "[表情:苹果]",
"🍓": "[表情:草莓]",
"🍉": "[表情:西瓜]",
"🚬": "[表情:吸烟]",
"🌹": "[表情:玫瑰]",
"🎉": "[表情:庆祝]",
"💝": "[表情:礼物]",
"💣": "[表情:炸弹]",
"": "[表情:闪光]",
"💨": "[表情:吹气]",
"💦": "[表情:水]",
"🔥": "[表情:火]",
"💤": "[表情:睡觉]",
"💩": "[表情:便便]",
"💉": "[表情:打针]",
"📫": "[表情:邮箱]",
"🐎": "[表情:骑马]",
"👧": "[表情:女孩]",
"👦": "[表情:男孩]",
"🐵": "[表情:猴]",
"🐷": "[表情:猪]",
"🐮": "[表情:牛]",
"🐔": "[表情:公鸡]",
"🐸": "[表情:青蛙]",
"👻": "[表情:幽灵]",
"🐛": "[表情:虫]",
"🐶": "[表情:狗]",
"🐳": "[表情:鲸鱼]",
"👢": "[表情:靴子]",
"": "[表情:晴天]",
"": "[表情:问号]",
"🔫": "[表情:手枪]",
"💓": "[表情:爱 心]",
"🏪": "[表情:便利店]",
}

View File

@@ -0,0 +1,45 @@
import asyncio
import time
from typing import Dict
from .config import global_config
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
response_dict: Dict = {}
response_time_dict: Dict = {}
async def get_response(request_id: str, timeout: int = 10) -> dict:
    """Wait up to *timeout* seconds for the response matching *request_id*.

    Raises asyncio.TimeoutError when nothing arrives in time (the stale
    timestamp entry is then reclaimed by check_timeout_response).
    """
    response = await asyncio.wait_for(_get_response(request_id), timeout)
    # Pop with a default: check_timeout_response may have purged the
    # timestamp entry concurrently; pop() without a default would raise
    # KeyError here even though the response itself was retrieved fine.
    response_time_dict.pop(request_id, None)
    logger.info(f"响应信息id: {request_id} 已从响应字典中取出")
    return response
async def _get_response(request_id: str) -> dict:
    """Poll the shared response dict until *request_id* appears, then pop it."""
    while True:
        if request_id in response_dict:
            return response_dict.pop(request_id)
        await asyncio.sleep(0.2)
async def put_response(response: dict):
    """Store a Napcat response keyed by its echo id, recording arrival time."""
    echo_id = response.get("echo")
    response_dict[echo_id] = response
    response_time_dict[echo_id] = time.time()
    logger.info(f"响应信息id: {echo_id} 已存入响应字典")
async def check_timeout_response() -> None:
    """Periodically purge responses that were never collected by a caller.

    A response older than the configured heartbeat interval is considered
    abandoned (e.g. the requester already timed out) and is dropped.
    """
    while True:
        cleaned_message_count: int = 0
        now_time = time.time()
        for echo_id, response_time in list(response_time_dict.items()):
            if now_time - response_time > global_config.napcat_server.heartbeat_interval:
                cleaned_message_count += 1
                # Pop with defaults: get_response may collect this entry
                # between the items() snapshot above and the removal here,
                # in which case a bare pop() would raise KeyError.
                response_dict.pop(echo_id, None)
                response_time_dict.pop(echo_id, None)
                logger.warning(f"响应消息 {echo_id} 超时,已删除")
        # Only report when something was actually cleaned, instead of
        # logging "deleted 0" every heartbeat interval.
        if cleaned_message_count:
            logger.info(f"已删除 {cleaned_message_count} 条超时响应消息")
        await asyncio.sleep(global_config.napcat_server.heartbeat_interval)

View File

@@ -0,0 +1,711 @@
import json
import time
import random
import websockets as Server
import uuid
import asyncio
from maim_message import (
UserInfo,
GroupInfo,
Seg,
BaseMessageInfo,
MessageBase,
)
from typing import Dict, Any, Tuple, Optional
from . import CommandType
from .config import global_config
from .response_pool import get_response
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from .utils import get_image_format, convert_image_to_gif
from .recv_handler.message_sending import message_send_instance
from .websocket_manager import websocket_manager
from .config.features_config import features_manager
class SendHandler:
    def __init__(self):
        # Explicitly-set connection; when present it takes priority over the
        # one held by websocket_manager (see get_server_connection).
        self.server_connection: Optional[Server.ServerConnection] = None

    async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
        """Set the Napcat connection to use for all outgoing traffic."""
        self.server_connection = server_connection

    def get_server_connection(self) -> Optional[Server.ServerConnection]:
        """Return the active server connection, or None when none is available."""
        # Prefer the directly-set connection; fall back to websocket_manager.
        if self.server_connection:
            return self.server_connection
        return websocket_manager.get_connection()
async def handle_message(self, raw_message_base_dict: dict) -> None:
raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_dict)
message_segment: Seg = raw_message_base.message_segment
logger.info("接收到来自MaiBot的消息处理中")
if message_segment.type == "command":
logger.info("处理命令")
return await self.send_command(raw_message_base)
elif message_segment.type == "adapter_command":
logger.info("处理适配器命令")
return await self.handle_adapter_command(raw_message_base)
else:
logger.info("处理普通消息")
return await self.send_normal_message(raw_message_base)
    async def send_normal_message(self, raw_message_base: MessageBase) -> None:
        """Send an ordinary (non-command) MaiBot message to Napcat.

        Flattens the Seg tree into Napcat message parts, picks the group or
        private send action, sends, and echoes the assigned message id back.
        """
        logger.info("处理普通信息中")
        message_info: BaseMessageInfo = raw_message_base.message_info
        message_segment: Seg = raw_message_base.message_segment
        group_info: Optional[GroupInfo] = message_info.group_info
        user_info: Optional[UserInfo] = message_info.user_info
        target_id: Optional[int] = None
        action: Optional[str] = None
        id_name: Optional[str] = None
        processed_message: list = []
        try:
            # Convert the MaiBot Seg tree into Napcat's message-part list.
            if user_info:
                processed_message = await self.handle_seg_recursive(
                    message_segment, user_info
                )
        except Exception as e:
            logger.error(f"处理消息时发生错误: {e}")
            return
        if not processed_message:
            logger.critical("现在暂时不支持解析此回复!")
            return None
        # Group context wins over private: when both group_info and user_info
        # are present the message goes to the group.
        if group_info and user_info:
            logger.debug("发送群聊消息")
            target_id = int(group_info.group_id) if group_info.group_id else None
            action = "send_group_msg"
            id_name = "group_id"
        elif user_info:
            logger.debug("发送私聊消息")
            target_id = int(user_info.user_id) if user_info.user_id else None
            action = "send_private_msg"
            id_name = "user_id"
        else:
            logger.error("无法识别的消息类型")
            return
        logger.info("尝试发送到napcat")
        response = await self.send_message_to_napcat(
            action,
            {
                id_name: target_id,
                "message": processed_message,
            },
        )
        if response.get("status") == "ok":
            logger.info("消息发送成功")
            qq_message_id = response.get("data", {}).get("message_id")
            # Report the QQ-side message id back so MaiBot can correlate ids.
            await self.message_sent_back(raw_message_base, qq_message_id)
        else:
            logger.warning(f"消息发送失败napcat返回{str(response)}")
    async def send_command(self, raw_message_base: MessageBase) -> None:
        """Translate a MaiBot "command" Seg into a Napcat action and send it.

        Each known command name maps to a handler that builds the
        (action, params) pair; handler errors are logged, never raised.
        """
        logger.info("处理命令中")
        message_info: BaseMessageInfo = raw_message_base.message_info
        message_segment: Seg = raw_message_base.message_segment
        group_info: Optional[GroupInfo] = message_info.group_info
        seg_data: Dict[str, Any] = (
            message_segment.data
            if isinstance(message_segment.data, dict)
            else {}
        )
        command_name: Optional[str] = seg_data.get("name")
        try:
            args = seg_data.get("args", {})
            # Guard against a non-dict args payload from upstream.
            if not isinstance(args, dict):
                args = {}
            match command_name:
                case CommandType.GROUP_BAN.name:
                    command, args_dict = self.handle_ban_command(args, group_info)
                case CommandType.GROUP_WHOLE_BAN.name:
                    command, args_dict = self.handle_whole_ban_command(
                        args, group_info
                    )
                case CommandType.GROUP_KICK.name:
                    command, args_dict = self.handle_kick_command(args, group_info)
                case CommandType.SEND_POKE.name:
                    command, args_dict = self.handle_poke_command(args, group_info)
                case CommandType.DELETE_MSG.name:
                    command, args_dict = self.delete_msg_command(args)
                case CommandType.AI_VOICE_SEND.name:
                    command, args_dict = self.handle_ai_voice_send_command(
                        args, group_info
                    )
                case CommandType.SET_EMOJI_LIKE.name:
                    command, args_dict = self.handle_set_emoji_like_command(args)
                case CommandType.SEND_AT_MESSAGE.name:
                    command, args_dict = self.handle_at_message_command(
                        args, group_info
                    )
                case CommandType.SEND_LIKE.name:
                    command, args_dict = self.handle_send_like_command(args)
                case _:
                    # Unknown commands are dropped after logging.
                    logger.error(f"未知命令: {command_name}")
                    return
        except Exception as e:
            logger.error(f"处理命令时发生错误: {e}")
            return None
        if not command or not args_dict:
            logger.error("命令或参数缺失")
            return None
        response = await self.send_message_to_napcat(command, args_dict)
        if response.get("status") == "ok":
            logger.info(f"命令 {command_name} 执行成功")
        else:
            logger.warning(f"命令 {command_name} 执行失败napcat返回{str(response)}")
async def handle_adapter_command(self, raw_message_base: MessageBase) -> None:
"""
处理适配器命令类 - 用于直接向Napcat发送命令并返回结果
"""
logger.info("处理适配器命令中")
message_info: BaseMessageInfo = raw_message_base.message_info
message_segment: Seg = raw_message_base.message_segment
seg_data: Dict[str, Any] = (
message_segment.data
if isinstance(message_segment.data, dict)
else {}
)
try:
action = seg_data.get("action")
params = seg_data.get("params", {})
request_id = seg_data.get("request_id")
if not action:
logger.error("适配器命令缺少action参数")
await self.send_adapter_command_response(
raw_message_base,
{"status": "error", "message": "缺少action参数"},
request_id
)
return
logger.info(f"执行适配器命令: {action}")
# 直接向Napcat发送命令并获取响应
response_task = asyncio.create_task(self.send_message_to_napcat(action, params))
response = await response_task
# 发送响应回MaiBot
await self.send_adapter_command_response(raw_message_base, response, request_id)
if response.get("status") == "ok":
logger.info(f"适配器命令 {action} 执行成功")
else:
logger.warning(f"适配器命令 {action} 执行失败napcat返回{str(response)}")
except Exception as e:
logger.error(f"处理适配器命令时发生错误: {e}")
error_response = {"status": "error", "message": str(e)}
await self.send_adapter_command_response(
raw_message_base,
error_response,
seg_data.get("request_id")
)
def get_level(self, seg_data: Seg) -> int:
if seg_data.type == "seglist":
return 1 + max(self.get_level(seg) for seg in seg_data.data)
else:
return 1
async def handle_seg_recursive(self, seg_data: Seg, user_info: UserInfo) -> list:
payload: list = []
if seg_data.type == "seglist":
# level = self.get_level(seg_data) # 给以后可能的多层嵌套做准备,此处不使用
if not seg_data.data:
return []
for seg in seg_data.data:
payload = await self.process_message_by_type(seg, payload, user_info)
else:
payload = await self.process_message_by_type(seg_data, payload, user_info)
return payload
    async def process_message_by_type(
        self, seg: Seg, payload: list, user_info: UserInfo
    ) -> list:
        # sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
        """Append the Napcat representation of one Seg to *payload*.

        Returns the (possibly rebuilt) payload; unsupported seg types and
        empty reply/text segs leave the payload unchanged.
        """
        new_payload = payload
        if seg.type == "reply":
            target_id = seg.data
            # Notice-sourced replies carry no real message id to quote.
            if target_id == "notice":
                return payload
            new_payload = self.build_payload(
                payload,
                await self.handle_reply_message(
                    target_id if isinstance(target_id, str) else "", user_info
                ),
                True,
            )
        elif seg.type == "text":
            text = seg.data
            if not text:
                return payload
            new_payload = self.build_payload(
                payload,
                self.handle_text_message(text if isinstance(text, str) else ""),
                False,
            )
        elif seg.type == "face":
            # Native QQ faces are not supported on the send path yet.
            logger.warning("MaiBot 发送了qq原生表情暂时不支持")
        elif seg.type == "image":
            image = seg.data
            new_payload = self.build_payload(payload, self.handle_image_message(image), False)
        elif seg.type == "emoji":
            emoji = seg.data
            new_payload = self.build_payload(payload, self.handle_emoji_message(emoji), False)
        elif seg.type == "voice":
            voice = seg.data
            new_payload = self.build_payload(payload, self.handle_voice_message(voice), False)
        elif seg.type == "voiceurl":
            voice_url = seg.data
            new_payload = self.build_payload(payload, self.handle_voiceurl_message(voice_url), False)
        elif seg.type == "music":
            song_id = seg.data
            new_payload = self.build_payload(payload, self.handle_music_message(song_id), False)
        elif seg.type == "videourl":
            video_url = seg.data
            new_payload = self.build_payload(payload, self.handle_videourl_message(video_url), False)
        elif seg.type == "file":
            file_path = seg.data
            new_payload = self.build_payload(payload, self.handle_file_message(file_path), False)
        return new_payload
def build_payload(
self, payload: list, addon: dict | list, is_reply: bool = False
) -> list:
# sourcery skip: for-append-to-extend, merge-list-append, simplify-generator
"""构建发送的消息体"""
if is_reply:
temp_list = []
if isinstance(addon, list):
temp_list.extend(addon)
else:
temp_list.append(addon)
for i in payload:
if i.get("type") == "reply":
logger.debug("检测到多个回复,使用最新的回复")
continue
temp_list.append(i)
return temp_list
else:
if isinstance(addon, list):
payload.extend(addon)
else:
payload.append(addon)
return payload
    async def handle_reply_message(self, id: str, user_info: UserInfo) -> dict | list:
        """Build the reply segment quoting message *id*.

        When the reply-at feature is enabled, may probabilistically add an
        @-mention of the quoted message's sender plus a spacer segment.
        """
        reply_seg = {"type": "reply", "data": {"id": id}}
        # Feature toggles governing quote-with-@ behaviour.
        ft_config = features_manager.get_config()
        if not ft_config.enable_reply_at:
            return reply_seg
        try:
            # Look up the quoted message to learn who sent it.
            msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)})
            replied_user_id = None
            if msg_info_response and msg_info_response.get("status") == "ok":
                sender_info = msg_info_response.get("data", {}).get("sender")
                if sender_info:
                    replied_user_id = sender_info.get("user_id")
            # Without the sender id we cannot @; fall back to a plain reply.
            if not replied_user_id:
                logger.warning(f"无法获取消息 {id} 的发送者信息,跳过 @")
                return reply_seg
            # Mention with the configured probability; the trailing space
            # keeps the @ visually separated from the reply text.
            if random.random() < ft_config.reply_at_rate:
                at_seg = {"type": "at", "data": {"qq": str(replied_user_id)}}
                text_seg = {"type": "text", "data": {"text": " "}}
                return [reply_seg, at_seg, text_seg]
        except Exception as e:
            logger.error(f"处理引用回复并尝试@时出错: {e}")
            # Degrade to a plain reply rather than crash the send path.
            return reply_seg
        return reply_seg
    def handle_text_message(self, message: str) -> dict:
        """Build a Napcat text segment."""
        return {"type": "text", "data": {"text": message}}

    def handle_image_message(self, encoded_image: str) -> dict:
        """Build an image segment from base64 data (subtype 0 = picture)."""
        return {
            "type": "image",
            "data": {
                "file": f"base64://{encoded_image}",
                "subtype": 0,
            },
        }  # base64-encoded image

    def handle_emoji_message(self, encoded_emoji: str) -> dict:
        """Build a sticker segment (subtype 1); non-GIF input is converted to GIF first."""
        encoded_image = encoded_emoji
        image_format = get_image_format(encoded_emoji)
        if image_format != "gif":
            encoded_image = convert_image_to_gif(encoded_emoji)
        return {
            "type": "image",
            "data": {
                "file": f"base64://{encoded_image}",
                "subtype": 1,
                "summary": "[动画表情]",
            },
        }

    def handle_voice_message(self, encoded_voice: str) -> dict:
        """Build a voice record segment from base64 data.

        Returns {} (which sends nothing) when TTS is disabled or the
        payload is empty.
        """
        if not global_config.voice.use_tts:
            logger.warning("未启用语音消息处理")
            return {}
        if not encoded_voice:
            return {}
        return {
            "type": "record",
            "data": {"file": f"base64://{encoded_voice}"},
        }

    def handle_voiceurl_message(self, voice_url: str) -> dict:
        """Build a voice record segment Napcat downloads from a URL."""
        return {
            "type": "record",
            "data": {"file": voice_url},
        }

    def handle_music_message(self, song_id: str) -> dict:
        """Build a music-card segment (type "163" = NetEase Cloud Music)."""
        return {
            "type": "music",
            "data": {"type": "163", "id": song_id},
        }

    def handle_videourl_message(self, video_url: str) -> dict:
        """Build a video segment Napcat downloads from a URL."""
        return {
            "type": "video",
            "data": {"file": video_url},
        }

    def handle_file_message(self, file_path: str) -> dict:
        """Build a file segment from a local path (file:// URI)."""
        return {
            "type": "file",
            "data": {"file": f"file://{file_path}"},
        }
def delete_msg_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""处理删除消息命令"""
return "delete_msg", {"message_id": args["message_id"]}
def handle_ban_command(
self, args: Dict[str, Any], group_info: GroupInfo
) -> Tuple[str, Dict[str, Any]]:
"""处理封禁命令
Args:
args (Dict[str, Any]): 参数字典
group_info (GroupInfo): 群聊信息(对应目标群聊)
Returns:
Tuple[CommandType, Dict[str, Any]]
"""
duration: int = int(args["duration"])
user_id: int = int(args["qq_id"])
group_id: int = int(group_info.group_id)
if duration < 0:
raise ValueError("封禁时间必须大于等于0")
if not user_id or not group_id:
raise ValueError("封禁命令缺少必要参数")
if duration > 2592000:
raise ValueError("封禁时间不能超过30天")
return (
CommandType.GROUP_BAN.value,
{
"group_id": group_id,
"user_id": user_id,
"duration": duration,
},
)
def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理全体禁言命令
Args:
args (Dict[str, Any]): 参数字典
group_info (GroupInfo): 群聊信息(对应目标群聊)
Returns:
Tuple[CommandType, Dict[str, Any]]
"""
enable = args["enable"]
assert isinstance(enable, bool), "enable参数必须是布尔值"
group_id: int = int(group_info.group_id)
if group_id <= 0:
raise ValueError("群组ID无效")
return (
CommandType.GROUP_WHOLE_BAN.value,
{
"group_id": group_id,
"enable": enable,
},
)
def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理群成员踢出命令
Args:
args (Dict[str, Any]): 参数字典
group_info (GroupInfo): 群聊信息(对应目标群聊)
Returns:
Tuple[CommandType, Dict[str, Any]]
"""
user_id: int = int(args["qq_id"])
group_id: int = int(group_info.group_id)
if group_id <= 0:
raise ValueError("群组ID无效")
if user_id <= 0:
raise ValueError("用户ID无效")
return (
CommandType.GROUP_KICK.value,
{
"group_id": group_id,
"user_id": user_id,
"reject_add_request": False, # 不拒绝加群请求
},
)
def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理戳一戳命令
Args:
args (Dict[str, Any]): 参数字典
group_info (GroupInfo): 群聊信息(对应目标群聊)
Returns:
Tuple[CommandType, Dict[str, Any]]
"""
user_id: int = int(args["qq_id"])
if group_info is None:
group_id = None
else:
group_id: int = int(group_info.group_id)
if group_id <= 0:
raise ValueError("群组ID无效")
if user_id <= 0:
raise ValueError("用户ID无效")
return (
CommandType.SEND_POKE.value,
{
"group_id": group_id,
"user_id": user_id,
},
)
def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""处理设置表情回应命令
Args:
args (Dict[str, Any]): 参数字典
Returns:
Tuple[CommandType, Dict[str, Any]]
"""
try:
message_id = int(args["message_id"])
emoji_id = int(args["emoji_id"])
set_like = str(args["set"])
except:
raise ValueError("缺少必需参数: message_id 或 emoji_id")
return (
CommandType.SET_EMOJI_LIKE.value,
{
"message_id": message_id,
"emoji_id": emoji_id,
"set": set_like
},
)
def handle_send_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""
处理发送点赞命令的逻辑。
Args:
args (Dict[str, Any]): 参数字典
Returns:
Tuple[CommandType, Dict[str, Any]]
"""
try:
user_id: int = int(args["qq_id"])
times: int = int(args["times"])
except (KeyError, ValueError):
raise ValueError("缺少必需参数: qq_id 或 times")
return (
CommandType.SEND_LIKE.value,
{
"user_id": user_id,
"times": times
},
)
def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""
处理AI语音发送命令的逻辑。
并返回 NapCat 兼容的 (action, params) 元组。
"""
if not group_info or not group_info.group_id:
raise ValueError("AI语音发送命令必须在群聊上下文中使用")
if not args:
raise ValueError("AI语音发送命令缺少参数")
group_id: int = int(group_info.group_id)
character_id = args.get("character")
text_content = args.get("text")
if not character_id or not text_content:
raise ValueError(f"AI语音发送命令参数不完整: character='{character_id}', text='{text_content}'")
return (
CommandType.AI_VOICE_SEND.value,
{
"group_id": group_id,
"text": text_content,
"character": character_id,
},
)
async def send_message_to_napcat(self, action: str, params: dict) -> dict:
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
# 获取当前连接
connection = self.get_server_connection()
if not connection:
logger.error("没有可用的 Napcat 连接")
return {"status": "error", "message": "no connection"}
try:
await connection.send(payload)
response = await get_response(request_uuid)
except TimeoutError:
logger.error("发送消息超时,未收到响应")
return {"status": "error", "message": "timeout"}
except Exception as e:
logger.error(f"发送消息失败: {e}")
return {"status": "error", "message": str(e)}
return response
async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None:
# 修改 additional_config添加 echo 字段
if message_base.message_info.additional_config is None:
message_base.message_info.additional_config = {}
message_base.message_info.additional_config["echo"] = True
# 获取原始的 mmc_message_id
mmc_message_id = message_base.message_info.message_id
# 修改 message_segment 为 notify 类型
message_base.message_segment = Seg(
type="notify", data={"sub_type": "echo", "echo": mmc_message_id, "actual_id": qq_message_id}
)
await message_send_instance.message_send(message_base)
logger.debug("已回送消息ID")
return
    async def send_adapter_command_response(
        self, original_message: MessageBase, response_data: dict, request_id: str
    ) -> None:
        """Send an adapter-command result back to MaiBot.

        Args:
            original_message: the message that carried the command; it is
                reused as the envelope for the response.
            response_data: Napcat's response (or a synthesized error dict).
            request_id: correlation id supplied by the requester.
        """
        try:
            # Mark as echo so MaiBot does not process this as fresh input.
            if original_message.message_info.additional_config is None:
                original_message.message_info.additional_config = {}
            original_message.message_info.additional_config["echo"] = True
            # Replace the segment with an adapter_response carrying the result
            # and a millisecond timestamp.
            original_message.message_segment = Seg(
                type="adapter_response",
                data={
                    "request_id": request_id,
                    "response": response_data,
                    "timestamp": int(time.time() * 1000)
                }
            )
            await message_send_instance.message_send(original_message)
            logger.debug(f"已发送适配器命令响应request_id: {request_id}")
        except Exception as e:
            logger.error(f"发送适配器命令响应时出错: {e}")
def handle_at_message_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理艾特并发送消息命令
Args:
args (Dict[str, Any]): 参数字典, 包含 qq_id 和 text
group_info (GroupInfo): 群聊信息
Returns:
Tuple[str, Dict[str, Any]]: (action, params)
"""
at_user_id = args.get("qq_id")
text = args.get("text")
if not at_user_id or not text:
raise ValueError("艾特消息命令缺少 qq_id 或 text 参数")
if not group_info:
raise ValueError("艾特消息命令必须在群聊上下文中使用")
message_payload = [
{"type": "at", "data": {"qq": str(at_user_id)}},
{"type": "text", "data": {"text": " " + str(text)}},
]
return (
"send_group_msg",
{
"group_id": group_info.group_id,
"message": message_payload,
},
)
send_handler = SendHandler()

View File

@@ -0,0 +1,311 @@
import asyncio
import base64
import io
import json
import ssl
import uuid
from typing import Union, List, Tuple, Optional

import urllib3
import websockets as Server
from PIL import Image

from src.common.logger import get_logger
from .database import BanUser, db_manager
from .response_pool import get_response

logger = get_logger("napcat_adapter")
class SSLAdapter(urllib3.PoolManager):
    """PoolManager with a relaxed TLS cipher policy.

    SECLEVEL=1 lowers OpenSSL's security level (presumably because some
    image CDNs still negotiate legacy ciphers — confirm against the servers
    actually contacted) while keeping the protocol floor at TLS 1.2.
    """

    def __init__(self, *args, **kwargs):
        context = ssl.create_default_context()
        context.set_ciphers("DEFAULT@SECLEVEL=1")
        context.minimum_version = ssl.TLSVersion.TLSv1_2
        # Force all pooled connections to use this context.
        kwargs["ssl_context"] = context
        super().__init__(*args, **kwargs)
async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
    """Fetch basic group info from Napcat; returns None on timeout/failure."""
    logger.debug("获取群聊信息中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
    try:
        await websocket.send(payload)
        socket_response: dict = await get_response(request_uuid)
    except (TimeoutError, asyncio.TimeoutError):
        # get_response raises asyncio.TimeoutError, which before Python 3.11
        # is not the builtin TimeoutError the original caught alone.
        logger.error(f"获取群信息超时,群号: {group_id}")
        return None
    except Exception as e:
        logger.error(f"获取群信息失败: {e}")
        return None
    logger.debug(socket_response)
    return socket_response.get("data")
async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
    """Fetch detailed group info from Napcat; returns None on timeout/failure."""
    logger.debug("获取群详细信息中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid})
    try:
        await websocket.send(payload)
        socket_response: dict = await get_response(request_uuid)
    except (TimeoutError, asyncio.TimeoutError):
        # Catch both spellings: asyncio.TimeoutError != TimeoutError pre-3.11.
        logger.error(f"获取群详细信息超时,群号: {group_id}")
        return None
    except Exception as e:
        logger.error(f"获取群详细信息失败: {e}")
        return None
    logger.debug(socket_response)
    return socket_response.get("data")
async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict | None:
    """Fetch one group member's info (cache bypassed); returns None on failure."""
    logger.debug("获取群成员信息中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps(
        {
            "action": "get_group_member_info",
            "params": {"group_id": group_id, "user_id": user_id, "no_cache": True},
            "echo": request_uuid,
        }
    )
    try:
        await websocket.send(payload)
        socket_response: dict = await get_response(request_uuid)
    except (TimeoutError, asyncio.TimeoutError):
        # Catch both spellings: asyncio.TimeoutError != TimeoutError pre-3.11.
        logger.error(f"获取成员信息超时,群号: {group_id}, 用户ID: {user_id}")
        return None
    except Exception as e:
        logger.error(f"获取成员信息失败: {e}")
        return None
    logger.debug(socket_response)
    return socket_response.get("data")
async def get_image_base64(url: str) -> str:
    # sourcery skip: raise-specific-error
    """Download an image/sticker and return its Base64 encoding.

    Raises on HTTP errors or download failure (after logging).
    """
    logger.debug(f"下载图片: {url}")
    http = SSLAdapter()
    loop = asyncio.get_running_loop()
    try:
        # urllib3 is synchronous; run it in the default executor so the
        # blocking download does not stall the event loop this coroutine
        # runs on (the original called it inline inside an async def).
        response = await loop.run_in_executor(None, lambda: http.request("GET", url, timeout=10))
        if response.status != 200:
            raise Exception(f"HTTP Error: {response.status}")
        image_bytes = response.data
        return base64.b64encode(image_bytes).decode("utf-8")
    except Exception as e:
        logger.error(f"图片下载失败: {str(e)}")
        raise
def convert_image_to_gif(image_base64: str) -> str:
    # sourcery skip: extract-method
    """Convert a Base64-encoded image to GIF format.

    Parameters:
        image_base64: Base64-encoded image data.
    Returns:
        Base64-encoded GIF data; on failure the input is returned unchanged.

    NOTE(review): Image.save here writes only the current frame, so an
    animated source (e.g. animated WebP) would be flattened to one frame —
    confirm inputs are static images.
    """
    logger.debug("转换图片为GIF格式")
    try:
        image_bytes = base64.b64decode(image_base64)
        image = Image.open(io.BytesIO(image_bytes))
        output_buffer = io.BytesIO()
        image.save(output_buffer, format="GIF")
        output_buffer.seek(0)
        return base64.b64encode(output_buffer.read()).decode("utf-8")
    except Exception as e:
        # Best effort: fall back to the original data rather than fail.
        logger.error(f"图片转换为GIF失败: {str(e)}")
        return image_base64
async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
    """Fetch the logged-in account's own info; returns None on failure.

    Parameters:
        websocket: the WebSocket connection to Napcat.
    """
    logger.debug("获取自身信息中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
    try:
        await websocket.send(payload)
        response: dict = await get_response(request_uuid)
    except (TimeoutError, asyncio.TimeoutError):
        # Catch both spellings: asyncio.TimeoutError != TimeoutError pre-3.11.
        logger.error("获取自身信息超时")
        return None
    except Exception as e:
        logger.error(f"获取自身信息失败: {e}")
        return None
    logger.debug(response)
    return response.get("data")
def get_image_format(raw_data: str) -> str:
    """Determine the image format from Base64-encoded data.

    Parameters:
        raw_data: Base64-encoded image data.
    Returns:
        Lower-cased format name (e.g. 'jpeg', 'png', 'gif'), or "" when
        Pillow cannot name the format.
    """
    image_bytes = base64.b64decode(raw_data)
    fmt = Image.open(io.BytesIO(image_bytes)).format
    # Image.format can be None; the original crashed with AttributeError
    # on .lower() in that case. Callers compare against "gif", so "" safely
    # routes unknown formats through the GIF-conversion path.
    return fmt.lower() if fmt else ""
async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> dict | None:
    """Fetch a stranger's profile info; returns None on failure.

    Parameters:
        websocket: the WebSocket connection to Napcat.
        user_id: target QQ user id.
    """
    logger.debug("获取陌生人信息中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid})
    try:
        await websocket.send(payload)
        response: dict = await get_response(request_uuid)
    except (TimeoutError, asyncio.TimeoutError):
        # Catch both spellings: asyncio.TimeoutError != TimeoutError pre-3.11.
        logger.error(f"获取陌生人信息超时用户ID: {user_id}")
        return None
    except Exception as e:
        logger.error(f"获取陌生人信息失败: {e}")
        return None
    logger.debug(response)
    return response.get("data")
async def get_message_detail(websocket: Server.ServerConnection, message_id: Union[str, int]) -> dict | None:
    """Fetch one message's details; may be None.

    Parameters:
        websocket: the WebSocket connection to Napcat.
        message_id: id of the message to fetch.
    """
    logger.debug("获取消息详情中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid})
    try:
        await websocket.send(payload)
        response: dict = await get_response(request_uuid, 30)  # extended timeout: 30 s
    except (TimeoutError, asyncio.TimeoutError):
        # Catch both spellings: asyncio.TimeoutError != TimeoutError pre-3.11.
        logger.error(f"获取消息详情超时消息ID: {message_id}")
        return None
    except Exception as e:
        logger.error(f"获取消息详情失败: {e}")
        return None
    logger.debug(response)
    return response.get("data")
async def get_record_detail(
    websocket: Server.ServerConnection, file: str, file_id: Optional[str] = None
) -> dict | None:
    """Fetch a voice message's content (converted to wav); None on failure.

    Parameters:
        websocket: the WebSocket connection to Napcat.
        file: record file name.
        file_id: optional record file id.
    """
    logger.debug("获取语音消息详情中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps(
        {
            "action": "get_record",
            "params": {"file": file, "file_id": file_id, "out_format": "wav"},
            "echo": request_uuid,
        }
    )
    try:
        await websocket.send(payload)
        response: dict = await get_response(request_uuid, 30)  # extended timeout: 30 s
    except (TimeoutError, asyncio.TimeoutError):
        # Catch both spellings: asyncio.TimeoutError != TimeoutError pre-3.11.
        logger.error(f"获取语音消息详情超时,文件: {file}, 文件ID: {file_id}")
        return None
    except Exception as e:
        logger.error(f"获取语音消息详情失败: {e}")
        return None
    # Truncate: the record payload carries a very long base64 body.
    logger.debug(f"{str(response)[:200]}...")
    return response.get("data")
async def read_ban_list(
    websocket: Server.ServerConnection,
) -> Tuple[List[BanUser], List[BanUser]]:
    """Load ban records from the database and refresh them against live state.

    Whole-group records (user_id == 0) are checked against the group's
    all-mute flag; member records against the member's shut-up timestamp.
    The database is updated with the records still in effect.

    Returns:
        Tuple of (records still banned, records whose ban has lifted).
        Both lists are empty on any error. (The original docstring
        advertised four lists; only two are produced.)
    """
    try:
        ban_list = db_manager.get_ban_records()
        still_banned: List[BanUser] = []
        lifted_list: List[BanUser] = []
        logger.info("已经读取禁言列表")
        # Build fresh lists instead of calling ban_list.remove() inside the
        # loop: the original mutated the list while iterating it, skipping
        # the record that followed every removal.
        for ban_record in ban_list:
            if ban_record.user_id == 0:
                # Whole-group mute record.
                fetched_group_info = await get_group_info(websocket, ban_record.group_id)
                if fetched_group_info is None:
                    logger.warning(f"无法获取群信息,群号: {ban_record.group_id},默认禁言解除")
                    lifted_list.append(ban_record)
                    continue
                if fetched_group_info.get("group_all_shut") == 0:
                    lifted_list.append(ban_record)
                else:
                    still_banned.append(ban_record)
            else:
                # Individual member mute record.
                fetched_member_info = await get_member_info(websocket, ban_record.group_id, ban_record.user_id)
                if fetched_member_info is None:
                    logger.warning(
                        f"无法获取群成员信息用户ID: {ban_record.user_id}, 群号: {ban_record.group_id},默认禁言解除"
                    )
                    lifted_list.append(ban_record)
                    continue
                lift_ban_time: int = fetched_member_info.get("shut_up_timestamp")
                if lift_ban_time == 0:
                    lifted_list.append(ban_record)
                else:
                    # Still muted: sync the lift time from the live value.
                    ban_record.lift_time = lift_ban_time
                    still_banned.append(ban_record)
        db_manager.update_ban_record(still_banned)
        return still_banned, lifted_list
    except Exception as e:
        logger.error(f"读取禁言列表失败: {e}")
        return [], []
def save_ban_record(records: List[BanUser]):
    """Persist *records* as the current ban list.

    The parameter was renamed from ``list``, which shadowed the builtin.
    """
    return db_manager.update_ban_record(records)

View File

@@ -0,0 +1,192 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
视频下载和处理模块
用于从QQ消息中下载视频并转发给Bot进行分析
"""
import aiohttp
import asyncio
from pathlib import Path
from typing import Optional, Dict, Any
from src.common.logger import get_logger
logger = get_logger("video_handler")
class VideoDownloader:
    """Best-effort downloader for video files referenced in QQ messages.

    URL and extension checks are deliberately permissive because QQ CDN links
    often lack file extensions; the Content-Type header and the actual payload
    size are re-validated during download.
    """

    def __init__(self, max_size_mb: int = 100, download_timeout: int = 60):
        self.max_size_mb = max_size_mb  # maximum allowed file size in MB
        self.download_timeout = download_timeout  # GET timeout in seconds
        # File extensions recognized as video.
        self.supported_formats = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}

    def is_video_url(self, url: str) -> bool:
        """Heuristically decide whether *url* points at a video file.

        Accepts URLs containing video-related keywords, known video file
        extensions, or known QQ/Tencent CDN domains (which frequently omit
        extensions). Falls back to True on parse errors so the download step
        can perform the real validation.
        """
        try:
            # Keyword check first — QQ video URLs may carry no extension at all.
            video_keywords = ['video', 'mp4', 'avi', 'mov', 'mkv', 'flv', 'wmv', 'webm', 'm4v']
            url_lower = url.lower()
            if any(keyword in url_lower for keyword in video_keywords):
                return True
            # Traditional extension check (query string stripped first).
            path = Path(url.split('?')[0])
            if path.suffix.lower() in self.supported_formats:
                return True
            # QQ/Tencent CDN URLs often lack extensions; let them through and
            # rely on the Content-Type check during download.
            qq_domains = ['qpic.cn', 'gtimg.cn', 'qq.com', 'tencent.com']
            if any(domain in url_lower for domain in qq_domains):
                return True
            return False
        except Exception:
            # BUG FIX: narrowed from a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit. Parse failures default to True so
            # the download step can validate instead.
            return True

    def check_file_size(self, content_length: Optional[str]) -> bool:
        """Return True if the Content-Length header is within the size limit.

        A missing or unparsable header is allowed (returns True) — the actual
        payload size is re-checked after the download completes.
        """
        if content_length is None:
            return True
        try:
            size_bytes = int(content_length)
            size_mb = size_bytes / (1024 * 1024)
            return size_mb <= self.max_size_mb
        except (TypeError, ValueError):
            # BUG FIX: narrowed from a bare `except:`; a malformed header is
            # tolerated and the real size is verified post-download.
            return True

    async def download_video(self, url: str, filename: Optional[str] = None) -> Dict[str, Any]:
        """Download a video file with size and content-type validation.

        Args:
            url: video URL.
            filename: optional file name override; derived from the URL otherwise.

        Returns:
            dict with keys ``success`` and ``url``; on success also
            ``data`` (bytes), ``filename`` and ``size_mb``, on failure ``error``.
        """
        try:
            logger.info(f"开始下载视频: {url}")
            # Cheap URL-shape check before any network traffic.
            if not self.is_video_url(url):
                logger.warning(f"URL格式检查失败: {url}")
                return {
                    "success": False,
                    "error": "不支持的视频格式",
                    "url": url
                }
            async with aiohttp.ClientSession() as session:
                # Probe with HEAD first so oversized files are rejected before transfer.
                try:
                    async with session.head(url, timeout=aiohttp.ClientTimeout(total=10)) as response:
                        if response.status != 200:
                            logger.warning(f"HEAD请求失败状态码: {response.status}")
                        else:
                            content_length = response.headers.get('Content-Length')
                            if not self.check_file_size(content_length):
                                return {
                                    "success": False,
                                    "error": f"视频文件过大,超过{self.max_size_mb}MB限制",
                                    "url": url
                                }
                except Exception as e:
                    # HEAD failures are non-fatal — some CDNs reject HEAD; fall through to GET.
                    logger.warning(f"HEAD请求失败: {e},继续尝试下载")
                # Actual download.
                async with session.get(url, timeout=aiohttp.ClientTimeout(total=self.download_timeout)) as response:
                    if response.status != 200:
                        return {
                            "success": False,
                            "error": f"下载失败HTTP状态码: {response.status}",
                            "url": url
                        }
                    # Validate Content-Type where available.
                    content_type = response.headers.get('Content-Type', '').lower()
                    if content_type:
                        video_mime_types = [
                            'video/', 'application/octet-stream',
                            'application/x-msvideo', 'video/x-msvideo'
                        ]
                        is_video_content = any(mime in content_type for mime in video_mime_types)
                        if not is_video_content:
                            logger.warning(f"Content-Type不是视频格式: {content_type}")
                            # Unknown types may still be QQ's special formats; only
                            # reject clearly textual responses (e.g. error pages).
                            if 'text/' in content_type or 'application/json' in content_type:
                                return {
                                    "success": False,
                                    "error": f"URL返回的不是视频内容Content-Type: {content_type}",
                                    "url": url
                                }
                    # Re-check Content-Length on the GET response.
                    content_length = response.headers.get('Content-Length')
                    if not self.check_file_size(content_length):
                        return {
                            "success": False,
                            "error": f"视频文件过大,超过{self.max_size_mb}MB限制",
                            "url": url
                        }
                    # Read the payload into memory.
                    video_data = await response.read()
                    # Final guard on the real size (headers can be absent or wrong).
                    actual_size_mb = len(video_data) / (1024 * 1024)
                    if actual_size_mb > self.max_size_mb:
                        return {
                            "success": False,
                            "error": f"视频文件过大,实际大小: {actual_size_mb:.2f}MB",
                            "url": url
                        }
                    # Derive a file name when none was supplied.
                    if filename is None:
                        filename = Path(url.split('?')[0]).name
                        if not filename or '.' not in filename:
                            filename = "video.mp4"
                    # BUG FIX: the success log lost the file name placeholder.
                    logger.info(f"视频下载成功: {filename}, 大小: {actual_size_mb:.2f}MB")
                    return {
                        "success": True,
                        "data": video_data,
                        "filename": filename,
                        "size_mb": actual_size_mb,
                        "url": url
                    }
        except asyncio.TimeoutError:
            return {
                "success": False,
                "error": "下载超时",
                "url": url
            }
        except Exception as e:
            logger.error(f"下载视频时出错: {e}")
            return {
                "success": False,
                "error": str(e),
                "url": url
            }
# Module-level singleton instance, created lazily on first use.
_video_downloader = None


def get_video_downloader(max_size_mb: int = 100, download_timeout: int = 60) -> VideoDownloader:
    """Return the process-wide VideoDownloader singleton.

    The size/timeout arguments are honored only on the first call, when the
    instance is created; subsequent calls return the existing instance
    unchanged.
    """
    global _video_downloader
    if _video_downloader is not None:
        return _video_downloader
    _video_downloader = VideoDownloader(max_size_mb, download_timeout)
    return _video_downloader

View File

@@ -0,0 +1,158 @@
import asyncio
import websockets as Server
from typing import Optional, Callable, Any
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from .config import global_config
class WebSocketManager:
    """WebSocket connection manager supporting forward (client) and reverse (server) modes."""
    def __init__(self):
        # Active connection to the Napcat endpoint (None when disconnected).
        self.connection: Optional[Server.ServerConnection] = None
        # Listening server object; only used in reverse mode.
        self.server: Optional[Server.WebSocketServer] = None
        self.is_running = False
        self.reconnect_interval = 5  # seconds to wait between reconnect attempts (forward mode)
        self.max_reconnect_attempts = 10  # give up after this many consecutive failures
    async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
        """Start the WebSocket connection in the mode selected by the configuration.

        Args:
            message_handler: coroutine invoked with the established connection;
                expected to run the message loop until the connection closes.

        Raises:
            ValueError: if the configured mode is neither "reverse" nor "forward".
        """
        mode = global_config.napcat_server.mode
        if mode == "reverse":
            await self._start_reverse_connection(message_handler)
        elif mode == "forward":
            await self._start_forward_connection(message_handler)
        else:
            raise ValueError(f"不支持的连接模式: {mode}")
    async def _start_reverse_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
        """Reverse mode: listen as a WebSocket server and wait for Napcat to connect."""
        host = global_config.napcat_server.host
        port = global_config.napcat_server.port
        logger.info(f"正在启动反向连接模式,监听地址: ws://{host}:{port}")
        # NOTE(review): `path` is accepted but unused — presumably kept for
        # compatibility with older websockets handler signatures; confirm.
        async def handle_client(websocket, path=None):
            self.connection = websocket
            logger.info(f"Napcat 客户端已连接: {websocket.remote_address}")
            try:
                await message_handler(websocket)
            except Exception as e:
                logger.error(f"处理客户端连接时出错: {e}")
            finally:
                # Always clear the stored connection so is_connected() reflects reality.
                self.connection = None
                logger.info("Napcat 客户端已断开连接")
        self.server = await Server.serve(
            handle_client,
            host,
            port,
            max_size=2**26  # allow frames up to 64 MiB (large media payloads)
        )
        self.is_running = True
        logger.info(f"反向连接服务器已启动,监听地址: ws://{host}:{port}")
        # Block here so the server keeps running for the lifetime of the task.
        await self.server.serve_forever()
    async def _start_forward_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
        """Forward mode: connect out to the Napcat server as a client, with retries."""
        url = self._get_forward_url()
        logger.info(f"正在启动正向连接模式,目标地址: {url}")
        reconnect_count = 0
        while reconnect_count < self.max_reconnect_attempts:
            try:
                logger.info(f"尝试连接到 Napcat 服务器: {url}")
                # Build connection keyword arguments.
                connect_kwargs = {"max_size": 2**26}
                # Attach the access token as a Bearer header when configured.
                if global_config.napcat_server.access_token:
                    connect_kwargs["additional_headers"] = {
                        "Authorization": f"Bearer {global_config.napcat_server.access_token}"
                    }
                    logger.info("已添加访问令牌到连接请求头")
                async with Server.connect(url, **connect_kwargs) as websocket:
                    self.connection = websocket
                    self.is_running = True
                    reconnect_count = 0  # reset the retry budget after a successful connect
                    logger.info(f"成功连接到 Napcat 服务器: {url}")
                    try:
                        await message_handler(websocket)
                    except Server.exceptions.ConnectionClosed:
                        logger.warning("与 Napcat 服务器的连接已断开")
                    except Exception as e:
                        logger.error(f"处理正向连接时出错: {e}")
                    finally:
                        self.connection = None
                        self.is_running = False
                # NOTE(review): after a clean disconnect the while-loop reconnects
                # with a reset counter, so a flapping link retries indefinitely —
                # presumably intentional; confirm.
            except (Server.exceptions.ConnectionClosed, Server.exceptions.InvalidMessage, OSError, ConnectionRefusedError) as e:
                reconnect_count += 1
                logger.warning(f"连接失败 ({reconnect_count}/{self.max_reconnect_attempts}): {e}")
                if reconnect_count < self.max_reconnect_attempts:
                    logger.info(f"将在 {self.reconnect_interval} 秒后重试连接...")
                    await asyncio.sleep(self.reconnect_interval)
                else:
                    logger.error("已达到最大重连次数,停止重连")
                    raise
            except Exception as e:
                # Unknown errors are not retried; propagate to the caller.
                logger.error(f"正向连接时发生未知错误: {e}")
                raise
    def _get_forward_url(self) -> str:
        """Return the URL to connect to in forward mode."""
        config = global_config.napcat_server
        # An explicitly configured full URL takes precedence.
        if config.url:
            return config.url
        # Otherwise build the URL from host and port.
        host = config.host
        port = config.port
        return f"ws://{host}:{port}"
    async def stop_connection(self) -> None:
        """Close the active connection and/or the listening server, ignoring close errors."""
        self.is_running = False
        if self.connection:
            try:
                await self.connection.close()
                logger.info("WebSocket 连接已关闭")
            except Exception as e:
                logger.error(f"关闭 WebSocket 连接时出错: {e}")
            finally:
                self.connection = None
        if self.server:
            try:
                self.server.close()
                await self.server.wait_closed()
                logger.info("WebSocket 服务器已关闭")
            except Exception as e:
                logger.error(f"关闭 WebSocket 服务器时出错: {e}")
            finally:
                self.server = None
    def get_connection(self) -> Optional[Server.ServerConnection]:
        """Return the current WebSocket connection, or None when disconnected."""
        return self.connection
    def is_connected(self) -> bool:
        """Return True when a connection object exists and the manager is running."""
        return self.connection is not None and self.is_running
# Global WebSocket manager instance shared by the adapter.
websocket_manager = WebSocketManager()

View File

@@ -0,0 +1,43 @@
# 权限配置文件
# 此文件用于管理群聊和私聊的黑白名单设置,以及聊天相关功能
# 支持热重载,修改后会自动生效
# 群聊权限设置
group_list_type = "whitelist" # 群聊列表类型whitelist白名单或 blacklist黑名单
group_list = [] # 群聊ID列表
# 当 group_list_type 为 whitelist 时,只有列表中的群聊可以使用机器人
# 当 group_list_type 为 blacklist 时,列表中的群聊无法使用机器人
# 示例group_list = [123456789, 987654321]
# 私聊权限设置
private_list_type = "whitelist" # 私聊列表类型whitelist白名单或 blacklist黑名单
private_list = [] # 用户ID列表
# 当 private_list_type 为 whitelist 时,只有列表中的用户可以私聊机器人
# 当 private_list_type 为 blacklist 时,列表中的用户无法私聊机器人
# 示例private_list = [123456789, 987654321]
# 全局禁止设置
ban_user_id = [] # 全局禁止用户ID列表这些用户无法在任何地方使用机器人
ban_qq_bot = false # 是否屏蔽QQ官方机器人消息
# 聊天功能设置
enable_poke = true # 是否启用戳一戳功能
ignore_non_self_poke = false # 是否无视不是针对自己的戳一戳
poke_debounce_seconds = 3 # 戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略
enable_reply_at = true # 是否启用引用回复时艾特用户的功能
reply_at_rate = 0.5 # 引用回复时艾特用户的几率 (0.0 ~ 1.0)
# 视频处理设置
enable_video_analysis = true # 是否启用视频识别功能
max_video_size_mb = 100 # 视频文件最大大小限制MB
download_timeout = 60 # 视频下载超时时间(秒)
supported_formats = ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"] # 支持的视频格式
# 消息缓冲设置
enable_message_buffer = true # 是否启用消息缓冲合并功能
message_buffer_enable_group = true # 是否启用群聊消息缓冲合并
message_buffer_enable_private = true # 是否启用私聊消息缓冲合并
message_buffer_interval = 3.0 # 消息合并间隔时间(秒),在此时间内的连续消息将被合并
message_buffer_initial_delay = 0.5 # 消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并
message_buffer_max_components = 50 # 单个会话最大缓冲消息组件数量,超过此数量将强制合并
message_buffer_block_prefixes = ["/", "!", "!", ".", "。", "#", "%"] # 消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲

View File

@@ -0,0 +1,25 @@
[inner]
version = "0.2.0" # 版本号
# 请勿修改版本号,除非你知道自己在做什么
[nickname] # 现在没用
nickname = ""
[napcat_server] # Napcat连接的ws服务设置
mode = "reverse" # 连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)
host = "localhost" # 主机地址
port = 8095 # 端口号
url = "" # 正向连接时的完整WebSocket URL如 ws://localhost:8080/ws (仅在forward模式下使用)
access_token = "" # WebSocket 连接的访问令牌,用于身份验证(可选)
heartbeat_interval = 30 # 心跳间隔时间(按秒计)
[maibot_server] # 连接麦麦的ws服务设置
host = "localhost" # 麦麦在.env文件中设置的主机地址即HOST字段
port = 8000 # 麦麦在.env文件中设置的端口即PORT字段
[voice] # 发送语音设置
use_tts = false # 是否使用tts语音请确保你配置了tts并有对应的adapter
[debug]
level = "INFO" # 日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL

View File

@@ -0,0 +1,89 @@
# TODO List:
[x] logger使用主程序的
[ ] 使用插件系统的config系统
[ ] 接收从napcat传递的所有信息
[ ] 优化架构,各模块解耦,暴露关键方法用于提供接口
[ ] 单独一个模块负责与主程序通信
[ ] 使用event系统完善接口api
---
Event分为两种一种是对外输出的event由napcat插件自主触发并传递参数另一种是接收外界输入的event由外部插件触发并向napcat传递参数
## 例如,
### 对外输出的event
napcat_on_received_text -> (message_seg: Seg) 接受到qq的文字消息,会向handler传递一个Seg
napcat_on_received_face -> (message_seg: Seg) 接受到qq的表情消息,会向handler传递一个Seg
napcat_on_received_reply -> (message_seg: Seg) 接受到qq的回复消息,会向handler传递一个Seg
napcat_on_received_image -> (message_seg: Seg) 接受到qq的图片消息,会向handler传递一个Seg
napcat_on_received_record -> (message_seg: Seg) 接受到qq的语音消息,会向handler传递一个Seg
napcat_on_received_rps -> (message_seg: Seg) 接受到qq的猜拳魔法表情,会向handler传递一个Seg
napcat_on_received_friend_invitation -> (user_id: str) 接受到qq的好友请求,会向handler传递一个user_id
...
此类event不接受外部插件的触发只能由napcat插件统一触发。
外部插件需要编写handler并订阅此类事件。
```python
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.base.base_event import BaseEventHandler, HandlerResult
class MyEventHandler(BaseEventHandler):
handler_name = "my_handler"
handler_description = "我的自定义事件处理器"
weight = 10 # 权重,越大越先执行
intercept_message = False # 是否拦截消息
init_subscribe = ["napcat_on_received_text"] # 初始订阅的事件
async def execute(self, params: dict) -> HandlerResult:
"""处理事件"""
try:
message = params.get("message_seg")
print(f"收到消息: {message.data}")
# 业务逻辑处理
# ...
return HandlerResult(
success=True,
continue_process=True, # 是否继续让其他处理器处理
message="处理成功",
handler_name=self.handler_name
)
except Exception as e:
return HandlerResult(
success=False,
continue_process=True,
message=f"处理失败: {str(e)}",
handler_name=self.handler_name
)
```
### 接收外界输入的event
napcat_kick_group <- (user_id, group_id) 踢出某个群组中的某个用户
napcat_mute_user <- (user_id, group_id, time) 禁言某个群组中的某个用户
napcat_unmute_user <- (user_id, group_id) 取消禁言某个群组中的某个用户
napcat_mute_group <- (user_id, group_id) 禁言某个群组
napcat_unmute_group <- (user_id, group_id) 取消禁言某个群组
napcat_add_friend <- (user_id) 向某个用户发出好友请求
napcat_accept_friend <- (user_id) 接收某个用户的好友请求
napcat_reject_friend <- (user_id) 拒绝某个用户的好友请求
...
此类事件只由外部插件触发并传递参数由napcat完成请求任务
外部插件需要触发此类的event并传递正确的参数
```python
from src.plugin_system.core.event_manager import event_manager
# 触发事件
await event_manager.trigger_event("napcat_accept_friend", user_id = 1234123)
```

View File

@@ -65,3 +65,6 @@ asyncio
tavily-python
google-generativeai
lunar_python
python-multipart
aiofiles

View File

@@ -10,7 +10,6 @@ from src.chat.express.expression_learner import expression_learner_manager
from src.plugin_system.base.component_types import ChatMode
from src.schedule.schedule_manager import schedule_manager
from src.plugin_system.apis import message_api
from src.mood.mood_manager import mood_manager
from .hfc_context import HfcContext
from .energy_manager import EnergyManager

View File

@@ -7,33 +7,11 @@ from contextlib import asynccontextmanager
from typing import Dict, Any, Optional, List, Union
from src.common.logger import get_logger
from src.common.tool_history import ToolHistoryManager
install(extra_lines=3)
logger = get_logger("prompt_build")
# 创建工具历史管理器实例
tool_history_manager = ToolHistoryManager()
def get_tool_history_prompt(message_id: Optional[str] = None) -> str:
    """Build the tool-history prompt section.

    Args:
        message_id: chat/session id used to restrict history to the current conversation.

    Returns:
        The formatted tool-history prompt, or "" when the feature is disabled.
    """
    # Local import — presumably to avoid a circular import at module load; confirm.
    from src.config.config import global_config
    if not global_config.tool.history.enable_prompt_history:
        return ""
    return tool_history_manager.get_recent_history_prompt(
        chat_id=message_id
    )
class PromptContext:
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
@@ -49,7 +27,7 @@ class PromptContext:
@_current_context.setter
def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value)
self._current_context_var.set(value) # type: ignore
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
@@ -73,7 +51,7 @@ class PromptContext:
# 保存当前协程的上下文值,不影响其他协程
previous_context = self._current_context
# 设置当前协程的新上下文
token = self._current_context_var.set(context_id) if context_id else None
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
else:
# 如果没有提供新上下文,保持当前上下文不变
previous_context = self._current_context
@@ -111,6 +89,7 @@ class PromptContext:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
if target_context := context_id or self._current_context:
if prompt.name:
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
@@ -153,40 +132,15 @@ class PromptManager:
def add_prompt(self, name: str, fstr: str) -> "Prompt":
prompt = Prompt(fstr, name=name)
if prompt.name:
self._prompts[prompt.name] = prompt
return prompt
async def format_prompt(self, name: str, **kwargs) -> str:
# 获取当前提示词
prompt = await self.get_prompt_async(name)
# 获取当前会话ID
message_id = self._context._current_context
# 获取工具历史提示词
tool_history = ""
if name in ['action_prompt', 'replyer_prompt', 'planner_prompt', 'tool_executor_prompt']:
tool_history = get_tool_history_prompt(message_id)
# 获取基本格式化结果
result = prompt.format(**kwargs)
# 如果有工具历史,插入到适当位置
if tool_history:
# 查找合适的插入点
# 在人格信息和身份块之后,但在主要内容之前
identity_end = result.find("```\n现在,你说:")
if identity_end == -1:
# 如果找不到特定标记,尝试在第一个段落后插入
first_double_newline = result.find("\n\n")
if first_double_newline != -1:
# 在第一个双换行后插入
result = f"{result[:first_double_newline + 2]}{tool_history}\n{result[first_double_newline + 2:]}"
else:
# 如果找不到合适的位置,添加到开头
result = f"{tool_history}\n\n{result}"
else:
# 在找到的位置插入
result = f"{result[:identity_end]}\n{tool_history}\n{result[identity_end:]}"
return result
@@ -195,6 +149,11 @@ global_prompt_manager = PromptManager()
class Prompt(str):
template: str
name: Optional[str]
args: List[str]
_args: List[Any]
_kwargs: Dict[str, Any]
# 临时标记,作为类常量
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
@@ -215,7 +174,7 @@ class Prompt(str):
"""将临时标记还原为实际的花括号字符"""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
def __new__(cls, fstr, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs):
def __new__(cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs):
# 如果传入的是元组,转换为列表
if isinstance(args, tuple):
args = list(args)
@@ -251,7 +210,7 @@ class Prompt(str):
@classmethod
async def create_async(
cls, fstr, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs
cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs
):
"""异步创建Prompt实例"""
prompt = cls(fstr, name, args, **kwargs)
@@ -260,7 +219,9 @@ class Prompt(str):
return prompt
@classmethod
def _format_template(cls, template, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
def _format_template(cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None) -> str:
if kwargs is None:
kwargs = {}
# 预处理模板中的转义花括号
processed_template = cls._process_escaped_braces(template)

1
src/chat/utils/rust-video/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
/target

610
src/chat/utils/rust-video/Cargo.lock generated Normal file
View File

@@ -0,0 +1,610 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]]
name = "anstream"
version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ae563653d1938f79b1ab1b5e668c87c76a9930414574a6583a7b7e11a8e6192"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
"is_terminal_polyfill",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd"
[[package]]
name = "anstyle-parse"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e231f6134f61b71076a3eab506c379d4f36122f2af15a9ff04415ea4c3339e2"
dependencies = [
"windows-sys",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e0633414522a32ffaac8ac6cc8f748e090c5717661fddeea04219e2344f5f2a"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys",
]
[[package]]
name = "anyhow"
version = "1.0.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100"
[[package]]
name = "autocfg"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "bumpalo"
version = "3.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43"
[[package]]
name = "cc"
version = "1.2.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42bc4aea80032b7bf409b0bc7ccad88853858911b7713a8062fdc0623867bedc"
dependencies = [
"shlex",
]
[[package]]
name = "cfg-if"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9"
[[package]]
name = "chrono"
version = "0.4.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-link",
]
[[package]]
name = "clap"
version = "4.5.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c5e4fcf9c21d2e544ca1ee9d8552de13019a42aa7dbf32747fa7aaf1df76e57"
dependencies = [
"clap_builder",
"clap_derive",
]
[[package]]
name = "clap_builder"
version = "4.5.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fecb53a0e6fcfb055f686001bc2e2592fa527efaf38dbe81a6a9563562e57d41"
dependencies = [
"anstream",
"anstyle",
"clap_lex",
"strsim",
]
[[package]]
name = "clap_derive"
version = "4.5.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14cb31bb0a7d536caef2639baa7fad459e15c3144efefa6dbd1c84562c4739f6"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "clap_lex"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"
[[package]]
name = "colorchoice"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "iana-time-zone"
version = "0.1.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"log",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itoa"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "js-sys"
version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f"
dependencies = [
"once_cell",
"wasm-bindgen",
]
[[package]]
name = "libc"
version = "0.2.175"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
[[package]]
name = "log"
version = "0.4.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
[[package]]
name = "memchr"
version = "2.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
name = "once_cell"
version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "once_cell_polyfill"
version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad"
[[package]]
name = "proc-macro2"
version = "1.0.101"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rayon"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]]
name = "rust-video"
version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"clap",
"rayon",
"serde",
"serde_json",
]
[[package]]
name = "rustversion"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
[[package]]
name = "ryu"
version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "serde"
version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.143"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]]
name = "shlex"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "strsim"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "syn"
version = "2.0.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "unicode-ident"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
[[package]]
name = "utf8parse"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "wasm-bindgen"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5"
dependencies = [
"cfg-if",
"once_cell",
"rustversion",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6"
dependencies = [
"bumpalo",
"log",
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de"
dependencies = [
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d"
dependencies = [
"unicode-ident",
]
[[package]]
name = "windows-core"
version = "0.61.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3"
dependencies = [
"windows-implement",
"windows-interface",
"windows-link",
"windows-result",
"windows-strings",
]
[[package]]
name = "windows-implement"
version = "0.60.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "windows-interface"
version = "0.59.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "windows-link"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a"
[[package]]
name = "windows-result"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6"
dependencies = [
"windows-link",
]
[[package]]
name = "windows-strings"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57"
dependencies = [
"windows-link",
]
[[package]]
name = "windows-sys"
version = "0.60.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-targets"
version = "0.53.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91"
dependencies = [
"windows-link",
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_gnullvm",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
[[package]]
name = "windows_aarch64_msvc"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
[[package]]
name = "windows_i686_gnu"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3"
[[package]]
name = "windows_i686_gnullvm"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
[[package]]
name = "windows_i686_msvc"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
[[package]]
name = "windows_x86_64_gnu"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
[[package]]
name = "windows_x86_64_msvc"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"

View File

@@ -0,0 +1,24 @@
[package]
name = "rust-video"
version = "0.1.0"
edition = "2021"
authors = ["VideoAnalysis Team"]
description = "Ultra-fast video keyframe extraction tool in Rust"
license = "GPL-3.0"
[dependencies]
anyhow = "1.0"
clap = { version = "4.0", features = ["derive"] }
rayon = "1.11"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
chrono = { version = "0.4", features = ["serde"] }
[profile.release]
opt-level = 3
lto = true
codegen-units = 1
panic = "abort"
strip = true

View File

@@ -0,0 +1,221 @@
# 🎯 Rust Video Keyframe Extraction API
高性能视频关键帧提取API服务基于Rust后端 + Python FastAPI。
## 📁 项目结构
```
rust-video/
├── outputs/ # 关键帧输出目录
├── src/ # Rust源码
│ └── main.rs
├── target/ # Rust编译文件
├── api_server.py # 🚀 主API服务器 (整合版)
├── start_server.py # 生产启动脚本
├── config.py # 配置管理
├── config.toml # 配置文件
├── Cargo.toml # Rust项目配置
├── Cargo.lock # Rust依赖锁定
├── .gitignore # Git忽略文件
└── README.md # 项目文档
```
## 快速开始
### 1. 安装依赖
```bash
pip install fastapi uvicorn python-multipart aiofiles
```
### 2. 启动服务
```bash
# 开发模式
python api_server.py
# 生产模式
python start_server.py --mode prod --port 8050
```
### 3. 访问API
- **服务地址**: http://localhost:8050
- **API文档**: http://localhost:8050/docs
- **健康检查**: http://localhost:8050/health
- **性能指标**: http://localhost:8050/metrics
## API使用方法
### 主要端点
#### 1. 提取关键帧 (JSON响应)
```http
POST /extract-keyframes
Content-Type: multipart/form-data
- video: 视频文件 (.mp4, .avi, .mov, .mkv)
- scene_threshold: 场景变化阈值 (0.1-1.0, 默认0.3)
- max_frames: 最大关键帧数 (1-200, 默认50)
- resize_width: 调整宽度 (可选, 100-1920)
- time_interval: 时间间隔秒数 (可选, 0.1-60.0)
```
#### 2. 提取关键帧 (ZIP下载)
```http
POST /extract-keyframes-zip
Content-Type: multipart/form-data
参数同上返回包含所有关键帧的ZIP文件
```
#### 3. 健康检查
```http
GET /health
```
#### 4. 性能指标
```http
GET /metrics
```
### Python客户端示例
```python
import requests
# 上传视频并提取关键帧
files = {'video': open('video.mp4', 'rb')}
data = {
'scene_threshold': 0.3,
'max_frames': 50,
'resize_width': 800
}
response = requests.post(
'http://localhost:8050/extract-keyframes',
files=files,
data=data
)
result = response.json()
print(f"提取了 {result['keyframe_count']} 个关键帧")
print(f"处理时间: {result['performance']['total_api_time']:.2f}秒")
```
### JavaScript客户端示例
```javascript
const formData = new FormData();
formData.append('video', videoFile);
formData.append('scene_threshold', '0.3');
formData.append('max_frames', '50');
fetch('http://localhost:8050/extract-keyframes', {
method: 'POST',
body: formData
})
.then(response => response.json())
.then(data => {
console.log(`提取了 ${data.keyframe_count} 个关键帧`);
console.log(`处理时间: ${data.performance.total_api_time}秒`);
});
```
### cURL示例
```bash
curl -X POST "http://localhost:8050/extract-keyframes" \
-H "accept: application/json" \
-H "Content-Type: multipart/form-data" \
-F "video=@video.mp4" \
-F "scene_threshold=0.3" \
-F "max_frames=50"
```
## ⚙️ 配置
编辑 `config.toml` 文件:
```toml
[server]
host = "0.0.0.0"
port = 8050
debug = false
[processing]
default_scene_threshold = 0.3
default_max_frames = 50
timeout_seconds = 300
[performance]
async_workers = 4
max_file_size_mb = 500
```
## 性能特性
- **异步I/O**: 文件上传/下载异步处理
- **多线程处理**: 视频处理在独立线程池
- **内存优化**: 流式处理,减少内存占用
- **智能清理**: 自动临时文件管理
- **性能监控**: 实时处理时间和吞吐量统计
总之就是非常快()
## 响应格式
```json
{
"status": "success",
"processing_time": 4.5,
"output_directory": "/tmp/output_xxx",
"keyframe_count": 15,
"keyframes": [
"/tmp/output_xxx/frame_001.jpg",
"/tmp/output_xxx/frame_002.jpg"
],
"performance": {
"file_size_mb": 209.7,
"upload_time": 0.23,
"processing_time": 4.5,
"total_api_time": 4.73,
"upload_speed_mbps": 912.2
},
"rust_output": "处理完成",
"command": "rust-video input.mp4 output/ --scene-threshold 0.3 --max-frames 50"
}
```
## 故障排除
### 常见问题
1. **Rust binary not found**
```bash
cargo build # 重新构建Rust项目
```
2. **端口被占用**
```bash
# 修改config.toml中的端口号
port = 8051
```
3. **内存不足**
```bash
# 减少max_frames或resize_width参数
```
### 日志查看
服务启动时会显示详细的状态信息,包括:
- Rust二进制文件位置
- 配置加载状态
- 服务监听地址
## 集成支持
本API设计为独立服务可轻松集成到任何项目中
- **AI Bot项目**: 通过HTTP API调用
- **Web应用**: 直接前端调用或后端代理
- **移动应用**: REST API标准接口
- **批处理脚本**: Python/Shell脚本调用

View File

@@ -0,0 +1,472 @@
#!/usr/bin/env python3
"""
Rust Video Keyframe Extraction API Server
高性能视频关键帧提取API服务
功能:
- 视频上传和关键帧提取
- 异步多线程处理
- 性能监控和健康检查
- 自动资源清理
启动: python api_server.py
地址: http://localhost:8050
"""
import os
import json
import subprocess
import tempfile
import zipfile
import shutil
import asyncio
import time
import logging
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# 导入配置管理
from config import config
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ============================================================================
# 内置视频处理器 (整合版)
# ============================================================================
class VideoKeyframeExtractor:
    """Thin wrapper around the Rust ``rust-video`` CLI.

    Locates (or builds via ``cargo build``) the binary once at construction
    time and shells out to it per video.  Raises ``FileNotFoundError`` when
    no usable binary can be found or built.
    """

    def __init__(self, rust_binary_path: Optional[str] = None):
        # Allow an explicit override; otherwise search the cargo target dirs.
        self.rust_binary_path = rust_binary_path or self._find_rust_binary()
        if not self.rust_binary_path or not Path(self.rust_binary_path).exists():
            raise FileNotFoundError(f"Rust binary not found: {self.rust_binary_path}")

    def _find_rust_binary(self) -> str:
        """Return the absolute path to the rust-video binary, building it if needed."""
        possible_paths = [
            "./target/debug/rust-video.exe",
            "./target/release/rust-video.exe",
            "./target/debug/rust-video",
            "./target/release/rust-video"
        ]
        for path in possible_paths:
            if Path(path).exists():
                return str(Path(path).absolute())
        # Not built yet: attempt a one-off `cargo build` before giving up.
        try:
            subprocess.run(["cargo", "build"], check=True, capture_output=True)
            for path in possible_paths:
                if Path(path).exists():
                    return str(Path(path).absolute())
        except subprocess.CalledProcessError:
            pass
        raise FileNotFoundError("Rust binary not found and build failed")

    def process_video(
        self,
        video_path: str,
        output_dir: str = "outputs",
        scene_threshold: float = 0.3,
        max_frames: int = 50,
        resize_width: Optional[int] = None,
        time_interval: Optional[float] = None
    ) -> Dict[str, Any]:
        """Run the Rust extractor on *video_path*, writing JPEGs into *output_dir*.

        Returns a result dict with timing info, the sorted keyframe paths and
        the raw CLI stdout.  Raises ``FileNotFoundError`` for a missing input
        and ``HTTPException`` (408/500) on timeout or CLI failure so the
        FastAPI routes can surface the error directly.
        """
        video_path = Path(video_path)
        if not video_path.exists():
            raise FileNotFoundError(f"Video file not found: {video_path}")
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        # Build the CLI invocation.  The Rust side (clap derive, src/main.rs)
        # defines only named options -- --input/--output/--threshold/--max-save
        # -- and no positional arguments, so the previous positional form with
        # --scene-threshold would be rejected by the argument parser.
        # NOTE(review): the Rust --threshold default scale (2.0, mean abs pixel
        # diff) differs from this API's 0.3 default -- confirm the intended scale.
        cmd = [
            self.rust_binary_path,
            "--input", str(video_path),
            "--output", str(output_dir),
            "--threshold", str(scene_threshold),
            "--max-save", str(max_frames),
        ]
        if resize_width:
            # Not supported by the current Rust CLI; warn instead of failing.
            logger.warning("resize_width=%s is not supported by the Rust extractor; ignoring", resize_width)
        if time_interval:
            logger.warning("time_interval=%s is not supported by the Rust extractor; ignoring", time_interval)
        # Execute and time the extraction.
        start_time = time.time()
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                check=True,
                timeout=300  # 5-minute hard limit per video
            )
            processing_time = time.time() - start_time
            # Collect produced frames; sorted for a deterministic response order.
            output_files = sorted(output_dir.glob("*.jpg"))
            return {
                "status": "success",
                "processing_time": processing_time,
                "output_directory": str(output_dir),
                "keyframe_count": len(output_files),
                "keyframes": [str(f) for f in output_files],
                "rust_output": result.stdout,
                "command": " ".join(cmd)
            }
        except subprocess.TimeoutExpired:
            raise HTTPException(status_code=408, detail="Video processing timeout")
        except subprocess.CalledProcessError as e:
            raise HTTPException(
                status_code=500,
                detail=f"Video processing failed: {e.stderr}"
            )
# ============================================================================
# 异步处理器 (整合版)
# ============================================================================
class AsyncVideoProcessor:
    """Async facade over ``VideoKeyframeExtractor``.

    Persists the uploaded file to a temp path, runs the blocking CLI call in
    the default thread-pool executor and attaches upload/processing metrics
    to the result.  Temp artefacts are removed in ``finally`` blocks.
    """

    def __init__(self):
        # Constructing the extractor locates/builds the Rust binary eagerly.
        self.extractor = VideoKeyframeExtractor()

    async def process_video_async(
        self,
        upload_file: UploadFile,
        processing_params: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Save *upload_file* to disk, extract keyframes and return the result dict."""
        start_time = time.time()
        # 1. Persist the upload to a named temp file the CLI can read.
        upload_start = time.time()
        temp_fd, temp_path_str = tempfile.mkstemp(suffix='.mp4')
        temp_path = Path(temp_path_str)
        try:
            os.close(temp_fd)  # we reopen by path below; the raw fd is unused
            content = await upload_file.read()
            with open(temp_path, 'wb') as f:
                f.write(content)
            upload_time = time.time() - upload_start
            file_size = len(content)
            # 2. Run the blocking extraction off the event loop.
            process_start = time.time()
            temp_output_dir = tempfile.mkdtemp()
            output_path = Path(temp_output_dir)
            try:
                # get_running_loop() is the correct, non-deprecated way to get
                # the loop from inside a coroutine (get_event_loop() is
                # deprecated in this context since Python 3.10).
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(
                    None,
                    self._process_video_sync,
                    str(temp_path),
                    str(output_path),
                    processing_params
                )
                process_time = time.time() - process_start
                total_time = time.time() - start_time
                # 3. Attach performance metrics for the API response.
                result.update({
                    'performance': {
                        'file_size_mb': file_size / (1024 * 1024),
                        'upload_time': upload_time,
                        'processing_time': process_time,
                        'total_api_time': total_time,
                        'upload_speed_mbps': (file_size / (1024 * 1024)) / upload_time if upload_time > 0 else 0
                    }
                })
                return result
            finally:
                # NOTE(review): the frame files live in this directory, so the
                # "keyframes" paths in the returned dict point at files that are
                # deleted here before the caller sees them -- confirm clients
                # only rely on the counts/timings from this endpoint.
                try:
                    shutil.rmtree(temp_output_dir, ignore_errors=True)
                except Exception as e:
                    logger.warning(f"Failed to cleanup output directory: {e}")
        finally:
            # Always drop the uploaded temp copy.
            try:
                if temp_path.exists():
                    temp_path.unlink()
            except Exception as e:
                logger.warning(f"Failed to cleanup temp file: {e}")

    def _process_video_sync(self, video_path: str, output_dir: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Blocking helper executed inside the thread pool."""
        return self.extractor.process_video(
            video_path=video_path,
            output_dir=output_dir,
            **params
        )
# ============================================================================
# FastAPI 应用初始化
# ============================================================================
# FastAPI application instance; interactive docs served at /docs and /redoc.
app = FastAPI(
    title="Rust Video Keyframe API",
    description="高性能视频关键帧提取API服务",
    version="2.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)
# CORS middleware: wide open (all origins/methods/headers) for easy embedding.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Shared processor instance; constructing it locates/builds the Rust binary.
video_processor = AsyncVideoProcessor()
# Simple in-memory request statistics; reset on every process restart.
stats = {
    "total_requests": 0,
    "processing_times": [],
    "start_time": datetime.now()
}
# ============================================================================
# API 路由
# ============================================================================
@app.get("/", response_class=JSONResponse)
async def root():
    """Service landing endpoint: name, version, status and useful links."""
    info = {
        "message": "Rust Video Keyframe Extraction API",
        "version": "2.0.0",
        "status": "ready",
    }
    info.update({"docs": "/docs", "health": "/health", "metrics": "/metrics"})
    return info
@app.get("/health")
async def health_check():
    """Report whether the Rust extractor binary is still present on disk."""
    try:
        binary_path = video_processor.extractor.rust_binary_path
        binary_status = "ok" if Path(binary_path).exists() else "missing"
    except Exception as e:
        # Anything unexpected (e.g. processor not initialised) -> 503.
        raise HTTPException(status_code=503, detail=f"Health check failed: {str(e)}")
    return {
        "status": binary_status,
        "timestamp": datetime.now().isoformat(),
        "version": "2.0.0",
        "rust_binary": binary_path,
    }
@app.get("/metrics")
async def get_metrics():
    """Expose simple in-process counters and timing aggregates."""
    samples = stats["processing_times"]
    avg_time = (sum(samples) / len(samples)) if samples else 0
    uptime = (datetime.now() - stats["start_time"]).total_seconds()
    return {
        "total_requests": stats["total_requests"],
        "average_processing_time": avg_time,
        "last_24h_requests": stats["total_requests"],  # simplified: lifetime total
        "system_info": {
            "uptime_seconds": uptime,
            "memory_usage": "N/A",  # placeholder for future expansion
            "cpu_usage": "N/A",
        },
    }
@app.post("/extract-keyframes")
async def extract_keyframes(
    video: UploadFile = File(..., description="视频文件"),
    scene_threshold: float = Form(0.3, description="场景变化阈值"),
    max_frames: int = Form(50, description="最大关键帧数量"),
    resize_width: Optional[int] = Form(None, description="调整宽度"),
    time_interval: Optional[float] = Form(None, description="时间间隔")
):
    """Extract keyframes from an uploaded video and return a JSON report.

    Validates the file extension, forwards the tunables to the async
    processor and records per-request timing statistics.
    """
    # Guard against a missing filename (possible with some multipart clients;
    # the previous `video.filename.lower()` would raise AttributeError) and
    # reject unsupported extensions.
    if not video.filename or not video.filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
        raise HTTPException(status_code=400, detail="不支持的视频格式")
    stats["total_requests"] += 1
    try:
        # Only pass optional tunables when the client supplied them.
        params = {
            "scene_threshold": scene_threshold,
            "max_frames": max_frames
        }
        if resize_width:
            params["resize_width"] = resize_width
        if time_interval:
            params["time_interval"] = time_interval
        start_time = time.time()
        result = await video_processor.process_video_async(video, params)
        processing_time = time.time() - start_time
        # Keep a bounded window of the most recent 100 timings.
        stats["processing_times"].append(processing_time)
        if len(stats["processing_times"]) > 100:
            stats["processing_times"] = stats["processing_times"][-100:]
        return JSONResponse(content=result)
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. 408 timeout) instead of
        # re-wrapping them as generic 500s.
        raise
    except Exception as e:
        logger.error(f"Processing failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
@app.post("/extract-keyframes-zip")
async def extract_keyframes_zip(
    video: UploadFile = File(...),
    scene_threshold: float = Form(0.3),
    max_frames: int = Form(50),
    resize_width: Optional[int] = Form(None),
    time_interval: Optional[float] = Form(None)
):
    """Extract keyframes and return them bundled in a ZIP file.

    NOTE(review): unlike /extract-keyframes this calls the extractor
    synchronously on the event loop (no executor), and `video.filename`
    is dereferenced without a None guard -- confirm both are acceptable.
    """
    # Validate the file extension.
    if not video.filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
        raise HTTPException(status_code=400, detail="不支持的视频格式")
    # Create temp input file and output directory.
    temp_input_fd, temp_input_path = tempfile.mkstemp(suffix='.mp4')
    temp_output_dir = tempfile.mkdtemp()
    try:
        os.close(temp_input_fd)
        # Persist the uploaded video to disk for the CLI.
        content = await video.read()
        with open(temp_input_path, 'wb') as f:
            f.write(content)
        # Assemble processing parameters (optional tunables only when given).
        params = {
            "scene_threshold": scene_threshold,
            "max_frames": max_frames
        }
        if resize_width:
            params["resize_width"] = resize_width
        if time_interval:
            params["time_interval"] = time_interval
        # Run the extraction (blocking call).
        result = video_processor.extractor.process_video(
            video_path=temp_input_path,
            output_dir=temp_output_dir,
            **params
        )
        # Build the ZIP: frame images plus a JSON manifest.
        # NOTE(review): zip_path is never deleted afterwards (it cannot be
        # removed in `finally` because FileResponse streams it after return)
        # -- consider a BackgroundTasks cleanup.
        zip_fd, zip_path = tempfile.mkstemp(suffix='.zip')
        os.close(zip_fd)
        with zipfile.ZipFile(zip_path, 'w') as zip_file:
            # Add the keyframe images.
            for keyframe_path in result.get("keyframes", []):
                if Path(keyframe_path).exists():
                    zip_file.write(keyframe_path, Path(keyframe_path).name)
            # Add the processing metadata.
            info_content = json.dumps(result, indent=2, ensure_ascii=False)
            zip_file.writestr("processing_info.json", info_content)
        # Stream the ZIP back with a timestamped download name.
        return FileResponse(
            zip_path,
            media_type='application/zip',
            filename=f"keyframes_{video.filename}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
    finally:
        # Best-effort cleanup of the temp input file and output directory.
        for path in [temp_input_path, temp_output_dir]:
            try:
                if Path(path).is_file():
                    Path(path).unlink()
                elif Path(path).is_dir():
                    shutil.rmtree(path, ignore_errors=True)
            except Exception:
                pass
# ============================================================================
# 应用启动
# ============================================================================
def main():
    """Start the API server: read config, print a banner, run uvicorn.

    Falls back to host 0.0.0.0 / port 8050 when the [server] section lacks
    those keys.
    """
    # Read the [server] section from config.toml (with defaults).
    server_config = config.get('server')
    host = server_config.get('host', '0.0.0.0')
    port = server_config.get('port', 8050)
    print(f"""
 Rust Video Keyframe Extraction API
=====================================
 地址: http://{host}:{port}
 文档: http://{host}:{port}/docs
 健康检查: http://{host}:{port}/health
 性能指标: http://{host}:{port}/metrics
=====================================
    """)
    # Sanity-check that the Rust binary was located at import time.
    try:
        rust_binary = video_processor.extractor.rust_binary_path
        print(f"✓ Rust binary: {rust_binary}")
    except Exception as e:
        print(f"⚠️ Rust binary check failed: {e}")
    # Launch the ASGI server (module:attribute string form).
    uvicorn.run(
        "api_server:app",
        host=host,
        port=port,
        reload=False,  # hot-reload disabled for production
        access_log=True
    )
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,115 @@
"""
配置管理模块
处理 config.toml 文件的读取和管理
"""
import os
from pathlib import Path
from typing import Dict, Any
try:
import toml
except ImportError:
print("⚠️ 需要安装 toml: pip install toml")
# 提供基础配置作为后备
toml = None
class ConfigManager:
    """Loads ``config.toml`` with a hard-coded default fallback.

    If the ``toml`` package is missing or the file is absent/unreadable, the
    built-in defaults are used instead, so importing this module never fails.
    """

    def __init__(self, config_file: str = "config.toml"):
        # Path is resolved lazily relative to the current working directory.
        self.config_file = Path(config_file)
        self._config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Read the TOML file, falling back to defaults on any failure."""
        if toml is None or not self.config_file.exists():
            return self._get_default_config()
        try:
            with open(self.config_file, 'r', encoding='utf-8') as f:
                return toml.load(f)
        except Exception as e:
            print(f"⚠️ 配置文件读取失败: {e}")
            return self._get_default_config()

    def _get_default_config(self) -> Dict[str, Any]:
        """Built-in defaults mirroring config.toml's section layout."""
        return {
            "server": {
                "host": "0.0.0.0",
                "port": 8000,
                "workers": 1,
                "reload": False,
                "log_level": "info"
            },
            "api": {
                "title": "Video Keyframe Extraction API",
                "description": "高性能视频关键帧提取服务",
                "version": "1.0.0",
                "max_file_size": "100MB"
            },
            "processing": {
                "default_threshold": 0.3,
                "default_output_format": "png",
                "max_frames": 10000,
                "temp_dir": "temp",
                "upload_dir": "uploads",
                "output_dir": "outputs"
            },
            "rust": {
                "executable_name": "video_keyframe_extractor",
                "executable_path": "target/release"
            },
            "ffmpeg": {
                "auto_detect": True,
                "custom_path": "",
                "timeout": 300
            },
            "storage": {
                "cleanup_interval": 3600,
                "max_storage_size": "10GB",
                "result_retention_days": 7
            },
            "monitoring": {
                "enable_metrics": True,
                "enable_logging": True,
                "log_file": "logs/api.log",
                "max_log_size": "100MB"
            },
            "security": {
                "allowed_origins": ["*"],
                "max_concurrent_tasks": 10,
                "rate_limit_per_minute": 60
            },
            "development": {
                "debug": False,
                "auto_reload": False,
                "cors_enabled": True
            }
        }

    def get(self, section: str, key: "str | None" = None, default=None):
        """Return a whole *section* (when *key* is None) or one value from it.

        ``key`` is genuinely optional, so it is annotated ``str | None``
        (the previous ``key: str = None`` annotation was incorrect).
        """
        if key is None:
            return self._config.get(section, default)
        return self._config.get(section, {}).get(key, default)

    def get_server_config(self):
        """Shortcut for ``get("server")``."""
        return self.get("server")

    def get_api_config(self):
        """Shortcut for ``get("api")``."""
        return self.get("api")

    def get_processing_config(self):
        """Shortcut for ``get("processing")``."""
        return self.get("processing")

    def reload(self):
        """Re-read the configuration file from disk."""
        self._config = self._load_config()
# Global configuration instance shared across the application.
config = ConfigManager()

View File

@@ -0,0 +1,70 @@
# 🔧 Video Keyframe Extraction API 配置文件
[server]
# 服务器配置
host = "0.0.0.0"
port = 8050
workers = 1
reload = false
log_level = "info"
[api]
# API 基础配置
title = "Video Keyframe Extraction API"
description = "视频关键帧提取服务"
version = "1.0.0"
max_file_size = "100MB" # 最大文件大小
[processing]
# 视频处理配置
default_threshold = 0.3
default_output_format = "png"
max_frames = 10000
temp_dir = "temp"
upload_dir = "uploads"
output_dir = "outputs"
[rust]
# Rust 程序配置
executable_name = "video_keyframe_extractor"
executable_path = "target/release" # 相对路径,自动检测
[ffmpeg]
# FFmpeg 配置
auto_detect = true
custom_path = "" # 留空则自动检测
timeout = 300 # 秒
[performance]
# 性能优化配置
async_workers = 4 # 异步文件处理工作线程数
upload_chunk_size = 8192 # 上传块大小 (字节)
max_concurrent_uploads = 10 # 最大并发上传数
compression_level = 1 # ZIP 压缩级别 (0-9, 1=快速)
stream_chunk_size = 8192 # 流式响应块大小
enable_performance_metrics = true # 启用性能监控
[storage]
# 存储配置
cleanup_interval = 3600 # 清理间隔(秒)
max_storage_size = "10GB"
result_retention_days = 7
[monitoring]
# 监控配置
enable_metrics = true
enable_logging = true
log_file = "logs/api.log"
max_log_size = "100MB"
[security]
# 安全配置
allowed_origins = ["*"]
max_concurrent_tasks = 10
rate_limit_per_minute = 60
[development]
# 开发环境配置
debug = false
auto_reload = false
cors_enabled = true

View File

@@ -0,0 +1,710 @@
//! # Rust Video Keyframe Extractor
//!
//! Ultra-fast video keyframe extraction tool with SIMD optimization.
//!
//! ## Features
//! - AVX2/SSE2 SIMD optimization for maximum performance
//! - Memory-efficient streaming processing with FFmpeg
//! - Multi-threaded parallel processing
//! - Release-optimized for production use
//!
//! ## Performance
//! - 150+ FPS processing speed
//! - Real-time video analysis capability
//! - Minimal memory footprint
//!
//! ## Usage
//! ```bash
//! # Single video processing
//! rust-video --input video.mp4 --output ./keyframes --threshold 2.0
//!
//! # Benchmark mode
//! rust-video --benchmark --input video.mp4 --output ./results
//! ```
use anyhow::{Context, Result};
use chrono::prelude::*;
use clap::Parser;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::fs;
use std::io::{BufReader, Read};
use std::path::PathBuf;
use std::process::{Command, Stdio};
use std::time::Instant;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Ultra-fast video keyframe extraction tool
///
/// NOTE(review): every option here is a named flag (clap derive defines no
/// positional arguments), while the Python api_server builds a positional
/// `rust-video <input> <output> --scene-threshold ...` command line --
/// confirm which CLI shape the shipped binary actually accepts.
#[derive(Parser)]
#[command(name = "rust-video")]
#[command(version = "0.1.0")]
#[command(about = "Ultra-fast video keyframe extraction with SIMD optimization")]
#[command(long_about = None)]
struct Args {
    /// Input video file path
    #[arg(short, long, help = "Path to the input video file")]
    input: Option<PathBuf>,
    /// Output directory for keyframes and results
    #[arg(short, long, default_value = "./output", help = "Output directory")]
    output: PathBuf,
    /// Change threshold for keyframe detection (higher = fewer keyframes)
    /// (mean absolute per-pixel difference between consecutive frames)
    #[arg(short, long, default_value = "2.0", help = "Keyframe detection threshold")]
    threshold: f64,
    /// Number of parallel threads (0 = auto-detect)
    #[arg(short = 'j', long, default_value = "0", help = "Number of threads")]
    threads: usize,
    /// Maximum number of keyframes to save (0 = save all)
    #[arg(short, long, default_value = "50", help = "Maximum keyframes to save")]
    max_save: usize,
    /// Run performance benchmark suite
    #[arg(long, help = "Run comprehensive benchmark tests")]
    benchmark: bool,
    /// Maximum frames to process (0 = process all frames)
    #[arg(long, default_value = "0", help = "Limit number of frames to process")]
    max_frames: usize,
    /// FFmpeg executable path
    #[arg(long, default_value = "ffmpeg", help = "Path to FFmpeg executable")]
    ffmpeg_path: PathBuf,
    /// Enable SIMD optimizations (AVX2/SSE2)
    #[arg(long, default_value = "true", help = "Enable SIMD optimizations")]
    use_simd: bool,
    /// Processing block size for cache optimization
    #[arg(long, default_value = "8192", help = "Block size for processing")]
    block_size: usize,
    /// Verbose output
    #[arg(short, long, help = "Enable verbose output")]
    verbose: bool,
}
/// Video frame representation optimized for SIMD processing
#[derive(Debug, Clone)]
struct VideoFrame {
    frame_number: usize,
    width: usize,
    height: usize,
    data: Vec<u8>, // Grayscale data, aligned for SIMD
}
impl VideoFrame {
    /// Create a new video frame with SIMD-aligned data
    fn new(frame_number: usize, width: usize, height: usize, mut data: Vec<u8>) -> Self {
        // Ensure data length is multiple of 32 for AVX2 processing
        // (zero padding; the padded tail is ignored by the scalar paths,
        // which slice to width*height).
        let remainder = data.len() % 32;
        if remainder != 0 {
            data.resize(data.len() + (32 - remainder), 0);
        }
        Self {
            frame_number,
            width,
            height,
            data,
        }
    }
    /// Calculate frame difference using parallel SIMD processing
    ///
    /// Returns the mean absolute per-pixel difference, or `f64::MAX` when the
    /// two frames have different dimensions.  Work is split into `block_size`
    /// chunks processed in parallel by rayon; each chunk uses the best SIMD
    /// kernel detected at runtime, falling back to scalar code.
    fn calculate_difference_parallel_simd(&self, other: &VideoFrame, block_size: usize, use_simd: bool) -> f64 {
        if self.width != other.width || self.height != other.height {
            return f64::MAX;
        }
        let total_pixels = self.width * self.height;
        // Ceiling division: last block may be shorter than block_size.
        let num_blocks = (total_pixels + block_size - 1) / block_size;
        let total_diff: u64 = (0..num_blocks)
            .into_par_iter()
            .map(|block_idx| {
                let start = block_idx * block_size;
                let end = ((block_idx + 1) * block_size).min(total_pixels);
                let block_len = end - start;
                if use_simd {
                    #[cfg(target_arch = "x86_64")]
                    {
                        unsafe {
                            if std::arch::is_x86_feature_detected!("avx2") {
                                return self.calculate_difference_avx2_block(&other.data, start, block_len);
                            } else if std::arch::is_x86_feature_detected!("sse2") {
                                return self.calculate_difference_sse2_block(&other.data, start, block_len);
                            }
                        }
                    }
                }
                // Fallback scalar implementation
                self.data[start..end]
                    .iter()
                    .zip(other.data[start..end].iter())
                    .map(|(a, b)| (*a as i32 - *b as i32).abs() as u64)
                    .sum()
            })
            .sum();
        total_diff as f64 / total_pixels as f64
    }
    /// Standard frame difference calculation (non-SIMD)
    ///
    /// Single-threaded scalar version of the same mean-absolute-difference
    /// metric; used as the "Standard Parallel" baseline in benchmarks.
    fn calculate_difference_standard(&self, other: &VideoFrame) -> f64 {
        if self.width != other.width || self.height != other.height {
            return f64::MAX;
        }
        let len = self.width * self.height;
        let total_diff: u64 = self.data[..len]
            .iter()
            .zip(other.data[..len].iter())
            .map(|(a, b)| (*a as i32 - *b as i32).abs() as u64)
            .sum();
        total_diff as f64 / len as f64
    }
    /// AVX2 optimized block processing
    ///
    /// Safety: caller must ensure AVX2 is available and that
    /// `start + len <= data.len()` for both buffers (guaranteed by the
    /// 32-byte padding in `new` plus the block bounds computed above).
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn calculate_difference_avx2_block(&self, other_data: &[u8], start: usize, len: usize) -> u64 {
        let mut total_diff = 0u64;
        let chunks = len / 32;
        for i in 0..chunks {
            let offset = start + i * 32;
            let a = _mm256_loadu_si256(self.data.as_ptr().add(offset) as *const __m256i);
            let b = _mm256_loadu_si256(other_data.as_ptr().add(offset) as *const __m256i);
            // _mm256_sad_epu8: sums of absolute byte differences per 64-bit lane.
            let diff = _mm256_sad_epu8(a, b);
            let result = _mm256_extract_epi64(diff, 0) as u64 +
                        _mm256_extract_epi64(diff, 1) as u64 +
                        _mm256_extract_epi64(diff, 2) as u64 +
                        _mm256_extract_epi64(diff, 3) as u64;
            total_diff += result;
        }
        // Process remaining bytes
        for i in (start + chunks * 32)..(start + len) {
            total_diff += (self.data[i] as i32 - other_data[i] as i32).abs() as u64;
        }
        total_diff
    }
    /// SSE2 optimized block processing
    ///
    /// Safety: same bounds contract as the AVX2 kernel, with 16-byte chunks.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "sse2")]
    unsafe fn calculate_difference_sse2_block(&self, other_data: &[u8], start: usize, len: usize) -> u64 {
        let mut total_diff = 0u64;
        let chunks = len / 16;
        for i in 0..chunks {
            let offset = start + i * 16;
            let a = _mm_loadu_si128(self.data.as_ptr().add(offset) as *const __m128i);
            let b = _mm_loadu_si128(other_data.as_ptr().add(offset) as *const __m128i);
            let diff = _mm_sad_epu8(a, b);
            let result = _mm_extract_epi64(diff, 0) as u64 + _mm_extract_epi64(diff, 1) as u64;
            total_diff += result;
        }
        // Process remaining bytes
        for i in (start + chunks * 16)..(start + len) {
            total_diff += (self.data[i] as i32 - other_data[i] as i32).abs() as u64;
        }
        total_diff
    }
}
/// Performance measurement results
///
/// Serialized to JSON for benchmark reports.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PerformanceResult {
    test_name: String,
    video_file: String,
    total_time_ms: f64,              // wall-clock total, milliseconds
    frame_extraction_time_ms: f64,   // FFmpeg decode/stream phase
    keyframe_analysis_time_ms: f64,  // difference computation phase
    total_frames: usize,
    keyframes_extracted: usize,
    keyframe_ratio: f64,             // percentage (0-100)
    processing_fps: f64,             // frames per second over total time
    threshold: f64,
    optimization_type: String,       // e.g. "SIMD+Parallel(block:8192)"
    simd_enabled: bool,
    threads_used: usize,             // rayon worker count
    timestamp: String,               // local time, "%Y-%m-%d %H:%M:%S"
}
/// Extract video frames using FFmpeg memory streaming
///
/// Probes the video for its dimensions, then streams raw grayscale frames
/// from FFmpeg's stdout straight into memory.  Returns the frames plus the
/// parsed (width, height).  `max_frames == 0` means "read until EOF".
fn extract_frames_memory_stream(
    video_path: &PathBuf,
    ffmpeg_path: &PathBuf,
    max_frames: usize,
    verbose: bool,
) -> Result<(Vec<VideoFrame>, usize, usize)> {
    if verbose {
        println!("🎬 Extracting frames using FFmpeg memory streaming...");
        println!("📁 Video: {}", video_path.display());
    }
    // Get video information (FFmpeg prints stream metadata on stderr).
    let probe_output = Command::new(ffmpeg_path)
        .args(["-i", video_path.to_str().unwrap(), "-hide_banner"])
        .output()
        .context("Failed to probe video with FFmpeg")?;
    let probe_info = String::from_utf8_lossy(&probe_output.stderr);
    let (width, height) = parse_video_dimensions(&probe_info)
        .ok_or_else(|| anyhow::anyhow!("Cannot parse video dimensions"))?;
    if verbose {
        println!("📐 Video dimensions: {}x{}", width, height);
    }
    // Build optimized FFmpeg command
    // NOTE(review): "-preset" is an encoder option and is presumably a no-op
    // for rawvideo output -- confirm it doesn't trigger an FFmpeg warning.
    let mut cmd = Command::new(ffmpeg_path);
    cmd.args([
        "-i", video_path.to_str().unwrap(),
        "-f", "rawvideo",
        "-pix_fmt", "gray",
        "-an", // No audio
        "-threads", "0", // Auto-detect threads
        "-preset", "ultrafast", // Fastest preset
    ]);
    if max_frames > 0 {
        cmd.args(["-frames:v", &max_frames.to_string()]);
    }
    // "-" = write raw frames to stdout; stderr is discarded.
    cmd.args(["-"]).stdout(Stdio::piped()).stderr(Stdio::null());
    let start_time = Instant::now();
    let mut child = cmd.spawn().context("Failed to spawn FFmpeg process")?;
    let stdout = child.stdout.take().unwrap();
    let mut reader = BufReader::with_capacity(1024 * 1024, stdout); // 1MB buffer
    // One grayscale byte per pixel.
    let frame_size = width * height;
    let mut frames = Vec::new();
    let mut frame_count = 0;
    let mut frame_buffer = vec![0u8; frame_size];
    if verbose {
        println!("📦 Frame size: {} bytes", frame_size);
    }
    // Stream frame data directly into memory
    loop {
        match reader.read_exact(&mut frame_buffer) {
            Ok(()) => {
                frames.push(VideoFrame::new(
                    frame_count,
                    width,
                    height,
                    frame_buffer.clone(),
                ));
                frame_count += 1;
                if verbose && frame_count % 200 == 0 {
                    print!("\r⚡ Frames processed: {}", frame_count);
                }
                if max_frames > 0 && frame_count >= max_frames {
                    break;
                }
            }
            Err(_) => break, // End of stream (a trailing partial frame is dropped)
        }
    }
    // Reap the child; exit status is intentionally ignored.
    let _ = child.wait();
    if verbose {
        println!("\r✅ Frame extraction complete: {} frames in {:.2}s",
            frame_count, start_time.elapsed().as_secs_f64());
    }
    Ok((frames, width, height))
}
/// Parse the frame dimensions out of FFmpeg's probe output.
///
/// Scans lines mentioning "Video:" for the first whitespace-separated token
/// of the form `<width>x<height>[,...]` and returns the parsed pair, or
/// `None` when no such token exists.
fn parse_video_dimensions(probe_info: &str) -> Option<(usize, usize)> {
    probe_info
        .lines()
        .filter(|line| line.contains("Video:") && line.contains('x'))
        .flat_map(|line| line.split_whitespace())
        .find_map(|token| {
            let x_pos = token.find('x')?;
            let width = token[..x_pos].parse::<usize>().ok()?;
            // Dimensions may be followed by a comma (e.g. "1920x1080,").
            let tail = &token[x_pos + 1..];
            let height = tail.split(',').next().unwrap_or(tail).parse::<usize>().ok()?;
            Some((width, height))
        })
}
/// Select keyframe indices: frame i+1 is a keyframe when its difference to
/// frame i exceeds `threshold`.  Differences are computed in parallel with
/// rayon; SIMD kernels are used when `use_simd` is set.
fn extract_keyframes_optimized(
    frames: &[VideoFrame],
    threshold: f64,
    use_simd: bool,
    block_size: usize,
    verbose: bool,
) -> Result<Vec<usize>> {
    // Fewer than two frames means there is no pair to compare.
    if frames.len() < 2 {
        return Ok(Vec::new());
    }
    let mode_label = match use_simd {
        true => "SIMD+Parallel",
        false => "Standard Parallel",
    };
    if verbose {
        println!("🚀 Keyframe analysis (threshold: {}, optimization: {})", threshold, mode_label);
    }
    let t0 = Instant::now();
    // One difference value per adjacent frame pair (order-preserving).
    let frame_deltas: Vec<f64> = frames
        .par_windows(2)
        .map(|pair| match use_simd {
            true => pair[0].calculate_difference_parallel_simd(&pair[1], block_size, true),
            false => pair[0].calculate_difference_standard(&pair[1]),
        })
        .collect();
    // Delta i describes the change between frames i and i+1, so the
    // flagged keyframe index is i + 1.
    let selected: Vec<usize> = frame_deltas
        .par_iter()
        .enumerate()
        .filter_map(|(idx, &delta)| (delta > threshold).then_some(idx + 1))
        .collect();
    if verbose {
        println!("⚡ Analysis complete in {:.2}s", t0.elapsed().as_secs_f64());
        println!("🎯 Found {} keyframes", selected.len());
    }
    Ok(selected)
}
/// Save keyframes as JPEG images using FFmpeg
///
/// Seeks into the source video once per keyframe (one FFmpeg process each)
/// and writes `keyframe_NNN.jpg` files into `output_dir`.  Returns the
/// number of frames successfully written.
fn save_keyframes_optimized(
    video_path: &PathBuf,
    keyframe_indices: &[usize],
    output_dir: &PathBuf,
    ffmpeg_path: &PathBuf,
    max_save: usize,
    verbose: bool,
) -> Result<usize> {
    if keyframe_indices.is_empty() {
        if verbose {
            println!("⚠️ No keyframes to save");
        }
        return Ok(0);
    }
    if verbose {
        println!("💾 Saving keyframes...");
    }
    fs::create_dir_all(output_dir).context("Failed to create output directory")?;
    // Cap the number of saved frames at max_save.
    let save_count = keyframe_indices.len().min(max_save);
    let mut saved = 0;
    for (i, &frame_idx) in keyframe_indices.iter().take(save_count).enumerate() {
        let output_path = output_dir.join(format!("keyframe_{:03}.jpg", i + 1));
        // NOTE(review): the frame-index -> seconds mapping hard-codes 30 FPS;
        // for other frame rates the saved image will not match the detected
        // frame -- confirm whether the real FPS should be probed instead.
        let timestamp = frame_idx as f64 / 30.0; // Assume 30 FPS
        let output = Command::new(ffmpeg_path)
            .args([
                "-i", video_path.to_str().unwrap(),
                "-ss", &timestamp.to_string(),
                "-vframes", "1",
                "-q:v", "2", // High quality
                "-y",
                output_path.to_str().unwrap(),
            ])
            .output()
            .context("Failed to extract keyframe with FFmpeg")?;
        if output.status.success() {
            saved += 1;
            if verbose && (saved % 10 == 0 || saved == save_count) {
                print!("\r💾 Saved: {}/{} keyframes", saved, save_count);
            }
        } else if verbose {
            eprintln!("⚠️ Failed to save keyframe {}", frame_idx);
        }
    }
    if verbose {
        println!("\r✅ Keyframe saving complete: {}/{}", saved, save_count);
    }
    Ok(saved)
}
/// Run performance test
///
/// Times the extraction and analysis phases separately and packages the
/// measurements into a `PerformanceResult`.
/// NOTE(review): if extraction yields zero frames, `keyframe_ratio` and
/// `processing_fps` divide by zero (NaN/inf in f64) -- confirm callers
/// tolerate that.
fn run_performance_test(
    video_path: &PathBuf,
    threshold: f64,
    test_name: &str,
    ffmpeg_path: &PathBuf,
    max_frames: usize,
    use_simd: bool,
    block_size: usize,
    verbose: bool,
) -> Result<PerformanceResult> {
    if verbose {
        println!("\n{}", "=".repeat(60));
        println!("⚡ Running test: {}", test_name);
        println!("{}", "=".repeat(60));
    }
    let total_start = Instant::now();
    // Frame extraction
    let extraction_start = Instant::now();
    let (frames, _width, _height) = extract_frames_memory_stream(video_path, ffmpeg_path, max_frames, verbose)?;
    let extraction_time = extraction_start.elapsed().as_secs_f64() * 1000.0;
    // Keyframe analysis
    let analysis_start = Instant::now();
    let keyframe_indices = extract_keyframes_optimized(&frames, threshold, use_simd, block_size, verbose)?;
    let analysis_time = analysis_start.elapsed().as_secs_f64() * 1000.0;
    let total_time = total_start.elapsed().as_secs_f64() * 1000.0;
    let optimization_type = if use_simd {
        format!("SIMD+Parallel(block:{})", block_size)
    } else {
        "Standard Parallel".to_string()
    };
    let result = PerformanceResult {
        test_name: test_name.to_string(),
        video_file: video_path.file_name().unwrap().to_string_lossy().to_string(),
        total_time_ms: total_time,
        frame_extraction_time_ms: extraction_time,
        keyframe_analysis_time_ms: analysis_time,
        total_frames: frames.len(),
        keyframes_extracted: keyframe_indices.len(),
        keyframe_ratio: keyframe_indices.len() as f64 / frames.len() as f64 * 100.0,
        processing_fps: frames.len() as f64 / (total_time / 1000.0),
        threshold,
        optimization_type,
        simd_enabled: use_simd,
        threads_used: rayon::current_num_threads(),
        timestamp: Local::now().format("%Y-%m-%d %H:%M:%S").to_string(),
    };
    if verbose {
        println!("\n⚡ Test Results:");
        println!("  🕐 Total time: {:.2}ms ({:.2}s)", result.total_time_ms, result.total_time_ms / 1000.0);
        println!("  📥 Extraction: {:.2}ms ({:.1}%)", result.frame_extraction_time_ms,
            result.frame_extraction_time_ms / result.total_time_ms * 100.0);
        println!("  🧮 Analysis: {:.2}ms ({:.1}%)", result.keyframe_analysis_time_ms,
            result.keyframe_analysis_time_ms / result.total_time_ms * 100.0);
        println!("  📊 Frames: {}", result.total_frames);
        println!("  🎯 Keyframes: {}", result.keyframes_extracted);
        println!("  🚀 Speed: {:.1} FPS", result.processing_fps);
        println!("  ⚙️ Optimization: {}", result.optimization_type);
    }
    Ok(result)
}
/// Run comprehensive benchmark suite
///
/// Runs `run_performance_test` under several optimization configurations
/// (standard parallel plus SIMD with three block sizes), prints a comparison
/// table, reports the fastest configuration, and writes all results as a
/// timestamped JSON file into `output_dir`.
fn run_benchmark_suite(video_path: &PathBuf, output_dir: &PathBuf, ffmpeg_path: &PathBuf, args: &Args) -> Result<()> {
    println!("🚀 Rust Video Keyframe Extractor - Benchmark Suite");
    println!("🕐 Time: {}", Local::now().format("%Y-%m-%d %H:%M:%S"));
    println!("🎬 Video: {}", video_path.display());
    println!("🧵 Threads: {}", rayon::current_num_threads());
    // CPU feature detection — informational only; the tests below choose SIMD
    // via the `use_simd` flag regardless of what is printed here.
    #[cfg(target_arch = "x86_64")]
    {
        println!("🔧 CPU Features:");
        if std::arch::is_x86_feature_detected!("avx2") {
            println!(" ✅ AVX2 supported");
        } else if std::arch::is_x86_feature_detected!("sse2") {
            println!(" ✅ SSE2 supported");
        } else {
            println!(" ⚠️ Scalar only");
        }
    }
    // Each entry: (label, use_simd, block size). Block size only matters on
    // the SIMD paths.
    let test_configs = vec![
        ("Standard Parallel", false, 8192),
        ("SIMD 8K blocks", true, 8192),
        ("SIMD 16K blocks", true, 16384),
        ("SIMD 32K blocks", true, 32768),
    ];
    let mut results = Vec::new();
    // A failing configuration is reported but does not abort the suite.
    for (test_name, use_simd, block_size) in test_configs {
        match run_performance_test(
            video_path,
            args.threshold,
            test_name,
            ffmpeg_path,
            1000, // Test with 1000 frames
            use_simd,
            block_size,
            args.verbose,
        ) {
            Ok(result) => results.push(result),
            Err(e) => println!("❌ Test failed {}: {:?}", test_name, e),
        }
    }
    // Performance comparison table
    println!("\n{}", "=".repeat(120));
    println!("🏆 Benchmark Results");
    println!("{}", "=".repeat(120));
    println!("{:<20} {:<15} {:<12} {:<12} {:<12} {:<8} {:<8} {:<12} {:<20}",
        "Test", "Total(ms)", "Extract(ms)", "Analyze(ms)", "Speed(FPS)", "Frames", "Keyframes", "Threads", "Optimization");
    println!("{}", "-".repeat(120));
    for result in &results {
        println!("{:<20} {:<15.1} {:<12.1} {:<12.1} {:<12.1} {:<8} {:<8} {:<12} {:<20}",
            result.test_name,
            result.total_time_ms,
            result.frame_extraction_time_ms,
            result.keyframe_analysis_time_ms,
            result.processing_fps,
            result.total_frames,
            result.keyframes_extracted,
            result.threads_used,
            result.optimization_type);
    }
    // Find the best-performing configuration by throughput (FPS).
    if let Some(best_result) = results.iter().max_by(|a, b| a.processing_fps.partial_cmp(&b.processing_fps).unwrap()) {
        println!("\n🏆 Best Performance: {}", best_result.test_name);
        println!(" ⚡ Speed: {:.1} FPS", best_result.processing_fps);
        println!(" 🕐 Time: {:.2}s", best_result.total_time_ms / 1000.0);
        println!(" 🧮 Analysis: {:.2}s", best_result.keyframe_analysis_time_ms / 1000.0);
        println!(" ⚙️ Tech: {}", best_result.optimization_type);
    }
    // Save detailed results as pretty-printed JSON, named by timestamp so
    // repeated runs never overwrite each other.
    fs::create_dir_all(output_dir).context("Failed to create output directory")?;
    let timestamp = Local::now().format("%Y%m%d_%H%M%S").to_string();
    let results_file = output_dir.join(format!("benchmark_results_{}.json", timestamp));
    let json_results = serde_json::to_string_pretty(&results)?;
    fs::write(&results_file, json_results)?;
    println!("\n📄 Detailed results saved to: {}", results_file.display());
    println!("{}", "=".repeat(120));
    Ok(())
}
/// CLI entry point: parses arguments, configures the Rayon thread pool,
/// verifies FFmpeg is reachable, then runs either the benchmark suite or a
/// single extract-and-save pass depending on `--benchmark`.
fn main() -> Result<()> {
    let args = Args::parse();
    // Setup thread pool. `threads == 0` leaves Rayon's default (one per core).
    if args.threads > 0 {
        rayon::ThreadPoolBuilder::new()
            .num_threads(args.threads)
            .build_global()
            .context("Failed to set thread pool")?;
    }
    println!("🚀 Rust Video Keyframe Extractor v0.1.0");
    println!("🧵 Threads: {}", rayon::current_num_threads());
    // Verify FFmpeg availability: the bare default "ffmpeg" is resolved via
    // PATH; any other non-existent explicit path is an immediate error.
    if !args.ffmpeg_path.exists() && args.ffmpeg_path.to_str() == Some("ffmpeg") {
        // Try to find ffmpeg in PATH
        if Command::new("ffmpeg").arg("-version").output().is_err() {
            anyhow::bail!("FFmpeg not found. Please install FFmpeg or specify path with --ffmpeg-path");
        }
    } else if !args.ffmpeg_path.exists() {
        anyhow::bail!("FFmpeg not found at: {}", args.ffmpeg_path.display());
    }
    if args.benchmark {
        // Benchmark mode: input is mandatory here even though the arg is optional.
        let video_path = args.input.clone()
            .ok_or_else(|| anyhow::anyhow!("Benchmark requires input video file --input <path>"))?;
        if !video_path.exists() {
            anyhow::bail!("Video file not found: {}", video_path.display());
        }
        run_benchmark_suite(&video_path, &args.output, &args.ffmpeg_path, &args)?;
    } else {
        // Single processing mode
        let video_path = args.input
            .ok_or_else(|| anyhow::anyhow!("Please specify input video file --input <path>"))?;
        if !video_path.exists() {
            anyhow::bail!("Video file not found: {}", video_path.display());
        }
        // Run single keyframe extraction (also produces the timing report).
        let result = run_performance_test(
            &video_path,
            args.threshold,
            "Single Processing",
            &args.ffmpeg_path,
            args.max_frames,
            args.use_simd,
            args.block_size,
            args.verbose,
        )?;
        // Extract and save keyframes.
        // NOTE(review): run_performance_test above already decoded and analyzed
        // these frames; the two calls below repeat both passes just to obtain
        // the indices for saving. Consider returning the indices from
        // run_performance_test to avoid doing the work twice — confirm intended.
        let (frames, _, _) = extract_frames_memory_stream(&video_path, &args.ffmpeg_path, args.max_frames, args.verbose)?;
        let keyframe_indices = extract_keyframes_optimized(&frames, args.threshold, args.use_simd, args.block_size, args.verbose)?;
        let saved_count = save_keyframes_optimized(&video_path, &keyframe_indices, &args.output, &args.ffmpeg_path, args.max_save, args.verbose)?;
        println!("\n✅ Processing Complete!");
        println!("🎯 Keyframes extracted: {}", result.keyframes_extracted);
        println!("💾 Keyframes saved: {}", saved_count);
        println!("⚡ Processing speed: {:.1} FPS", result.processing_fps);
        println!("📁 Output directory: {}", args.output.display());
        // Save processing report as timestamped JSON alongside the keyframes.
        let timestamp = Local::now().format("%Y%m%d_%H%M%S").to_string();
        let report_file = args.output.join(format!("processing_report_{}.json", timestamp));
        let json_result = serde_json::to_string_pretty(&result)?;
        fs::write(&report_file, json_result)?;
        if args.verbose {
            println!("📄 Processing report saved to: {}", report_file.display());
        }
    }
    Ok(())
}

View File

@@ -0,0 +1,219 @@
#!/usr/bin/env python3
"""
启动脚本
支持开发模式和生产模式启动
"""
import os
import sys
import subprocess
import argparse
from pathlib import Path
from config import config
def check_rust_executable():
    """Locate the compiled Rust executable.

    Candidate locations come from the ``rust`` section of the config file
    (build directory first, then the current directory, each with and
    without a ``.exe`` suffix).

    Returns:
        The absolute path as a string when found, otherwise ``None`` after
        printing a hint to compile first.
    """
    rust_config = config.get("rust")
    name = rust_config.get("executable_name", "video_keyframe_extractor")
    build_dir = rust_config.get("executable_path", "target/release")

    candidates = [
        f"./{build_dir}/{name}.exe",
        f"./{build_dir}/{name}",
        f"./{name}.exe",
        f"./{name}",
    ]

    for candidate in candidates:
        candidate_path = Path(candidate)
        if candidate_path.exists():
            print(f"✓ Found Rust executable: {candidate}")
            return str(candidate_path.absolute())

    print("⚠ Warning: Rust executable not found")
    print("Please compile first: cargo build --release")
    return None
def check_dependencies():
    """Return True when the FastAPI/uvicorn stack is importable, else False."""
    try:
        import fastapi  # noqa: F401 -- availability probe only
        import uvicorn  # noqa: F401
    except ImportError as e:
        print(f"✗ Missing dependencies: {e}")
        print("Please install: pip install -r requirements.txt")
        return False
    print("✓ FastAPI dependencies available")
    return True
def install_dependencies():
    """Install the project's Python requirements via pip.

    Returns:
        True when ``pip install -r requirements.txt`` succeeded, else False.
    """
    print("Installing dependencies...")
    cmd = [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        print(f"✗ Failed to install dependencies: {e}")
        return False
    print("✓ Dependencies installed successfully")
    return True
def start_development_server(host="127.0.0.1", port=8050, reload=True):
    """Start the development server, preferring an in-process uvicorn run.

    Falls back to launching uvicorn through a subprocess when the module
    cannot be imported directly.

    Args:
        host: Interface to bind.
        port: TCP port to listen on.
        reload: Enable uvicorn's auto-reload on source changes.
    """
    print(f" Starting development server on http://{host}:{port}")
    print(f" API docs: http://{host}:{port}/docs")
    print(f" Health check: http://{host}:{port}/health")
    try:
        import uvicorn
        uvicorn.run(
            "api_server:app",
            host=host,
            port=port,
            reload=reload,
            log_level="info"
        )
    except ImportError:
        print("uvicorn not found, trying with subprocess...")
        cmd = [
            sys.executable, "-m", "uvicorn",
            "api_server:app",
            "--host", host,
            "--port", str(port),
        ]
        # Bug fix: the previous version appended "" when reload was False,
        # which uvicorn's CLI rejects as an unexpected extra argument.
        # Only append the flag when it is actually wanted.
        if reload:
            cmd.append("--reload")
        subprocess.run(cmd)
def start_production_server(host="0.0.0.0", port=8000, workers=4):
    """Launch a multi-worker uvicorn server via subprocess (blocking).

    Args:
        host: Interface to bind (defaults to all interfaces).
        port: TCP port to listen on.
        workers: Number of uvicorn worker processes.
    """
    print(f"🚀 Starting production server on http://{host}:{port}")
    print(f"Workers: {workers}")
    command = [
        sys.executable, "-m", "uvicorn",
        "api_server:app",
        "--host", host,
        "--port", str(port),
        "--workers", str(workers),
        "--log-level", "warning",
    ]
    subprocess.run(command)
def create_systemd_service():
    """Generate a systemd unit file for the API server.

    Tries to write ``/etc/systemd/system/video-keyframe-api.service``
    directly; on a permission error it falls back to writing a local copy
    and prints manual installation instructions instead.
    """
    current_dir = Path.cwd()
    python_path = sys.executable

    service_content = f"""[Unit]
Description=Video Keyframe Extraction API Server
After=network.target
[Service]
Type=exec
User=www-data
WorkingDirectory={current_dir}
Environment=PATH=/usr/bin:/usr/local/bin
ExecStart={python_path} -m uvicorn api_server:app --host 0.0.0.0 --port 8000 --workers 4
Restart=always
RestartSec=10
[Install]
WantedBy=multi-user.target
"""

    target = Path("/etc/systemd/system/video-keyframe-api.service")
    try:
        with open(target, 'w') as f:
            f.write(service_content)
    except PermissionError:
        print("✗ Permission denied. Please run with sudo for systemd service creation")
        # Fall back to a local copy the operator can install manually.
        fallback = Path("./video-keyframe-api.service")
        with open(fallback, 'w') as f:
            f.write(service_content)
        print(f"✓ Service file created locally: {fallback}")
        print(f"To install: sudo cp {fallback} /etc/systemd/system/")
    else:
        print(f"✓ Systemd service created: {target}")
        print("To enable and start:")
        print(" sudo systemctl enable video-keyframe-api")
        print(" sudo systemctl start video-keyframe-api")
def main():
    """CLI entry point: parse arguments, verify the environment, launch the server."""
    parser = argparse.ArgumentParser(description="Video Keyframe Extraction API Server")

    # Defaults come from the config file so CLI flags only need to override.
    server_config = config.get_server_config()

    parser.add_argument("--mode", choices=["dev", "prod", "install"], default="dev",
                        help="运行模式: dev (开发), prod (生产), install (安装依赖)")
    parser.add_argument("--host", default=server_config.get("host", "127.0.0.1"), help="绑定主机")
    parser.add_argument("--port", type=int, default=server_config.get("port", 8000), help="端口号")
    parser.add_argument("--workers", type=int, default=server_config.get("workers", 4), help="生产模式工作进程数")
    parser.add_argument("--no-reload", action="store_true", help="禁用自动重载")
    parser.add_argument("--check", action="store_true", help="仅检查环境")
    parser.add_argument("--create-service", action="store_true", help="创建 systemd 服务")
    args = parser.parse_args()

    print("=== Video Keyframe Extraction API Server ===")

    # Probe the environment up front; both results are reused below.
    rust_exe = check_rust_executable()
    deps_ok = check_dependencies()

    if args.check:
        print("\n=== Environment Check ===")
        print(f"Rust executable: {'' if rust_exe else ''}")
        print(f"Python dependencies: {'' if deps_ok else ''}")
        return

    if args.create_service:
        create_systemd_service()
        return

    # Install mode: only install when something is actually missing.
    if args.mode == "install":
        if deps_ok:
            print("✓ Dependencies already installed")
        else:
            install_dependencies()
        return

    # Hard requirement: the Rust binary must exist before serving.
    if not rust_exe:
        print("✗ Cannot start without Rust executable")
        print("Please run: cargo build --release")
        sys.exit(1)

    # Soft requirement: try a one-shot install before giving up.
    if not deps_ok:
        print("Installing missing dependencies...")
        if not install_dependencies():
            sys.exit(1)

    if args.mode == "dev":
        start_development_server(host=args.host, port=args.port, reload=not args.no_reload)
    elif args.mode == "prod":
        start_production_server(host=args.host, port=args.port, workers=args.workers)
# Script entry point.
if __name__ == "__main__":
    main()

View File

@@ -14,6 +14,7 @@ from src.common.vector_db import vector_db_service
logger = get_logger("cache_manager")
class CacheManager:
"""
一个支持分层和语义缓存的通用工具缓存管理器。
@@ -21,6 +22,7 @@ class CacheManager:
L1缓存: 内存字典 (KV) + FAISS (Vector)。
L2缓存: 数据库 (KV) + ChromaDB (Vector)。
"""
_instance = None
def __new__(cls, *args, **kwargs):
@@ -32,7 +34,7 @@ class CacheManager:
"""
初始化缓存管理器。
"""
if not hasattr(self, '_initialized'):
if not hasattr(self, "_initialized"):
self.default_ttl = default_ttl
self.semantic_cache_collection_name = "semantic_cache"
@@ -79,7 +81,7 @@ class CacheManager:
logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)")
return None
return embedding_array.astype('float32')
return embedding_array.astype("float32")
else:
logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}")
return None
@@ -104,12 +106,18 @@ class CacheManager:
logger.warning(f"无法获取文件信息: {tool_file_path},错误: {e}")
try:
sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
except TypeError:
sorted_args = repr(sorted(function_args.items()))
return f"{tool_name}::{sorted_args}::{file_hash}"
async def get(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], semantic_query: Optional[str] = None) -> Optional[Any]:
async def get(
self,
tool_name: str,
function_args: Dict[str, Any],
tool_file_path: Union[str, Path],
semantic_query: Optional[str] = None,
) -> Optional[Any]:
"""
从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。
"""
@@ -136,12 +144,12 @@ class CacheManager:
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
validated_embedding = self._validate_embedding(embedding_vector)
if validated_embedding is not None:
query_embedding = np.array([validated_embedding], dtype='float32')
query_embedding = np.array([validated_embedding], dtype="float32")
# 步骤 2a: L1 语义缓存 (FAISS)
if query_embedding is not None and self.l1_vector_index.ntotal > 0:
faiss.normalize_L2(query_embedding)
distances, indices = self.l1_vector_index.search(query_embedding, 1)
distances, indices = self.l1_vector_index.search(query_embedding, 1) # type: ignore
if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似
hit_index = indices[0][0]
l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
@@ -151,10 +159,7 @@ class CacheManager:
# 步骤 2b: L2 精确缓存 (数据库)
cache_results_obj = await db_query(
model_class=CacheEntries,
query_type="get",
filters={"cache_key": key},
single_result=True
model_class=CacheEntries, query_type="get", filters={"cache_key": key}, single_result=True
)
if cache_results_obj:
@@ -172,8 +177,8 @@ class CacheManager:
filters={"cache_key": key},
data={
"last_accessed": time.time(),
"access_count": getattr(cache_results_obj, "access_count", 0) + 1
}
"access_count": getattr(cache_results_obj, "access_count", 0) + 1,
},
)
# 回填 L1
@@ -181,11 +186,7 @@ class CacheManager:
return data
else:
# 删除过期的缓存条目
await db_query(
model_class=CacheEntries,
query_type="delete",
filters={"cache_key": key}
)
await db_query(model_class=CacheEntries, query_type="delete", filters={"cache_key": key})
# 步骤 2c: L2 语义缓存 (VectorDB Service)
if query_embedding is not None:
@@ -193,14 +194,16 @@ class CacheManager:
results = vector_db_service.query(
collection_name=self.semantic_cache_collection_name,
query_embeddings=query_embedding.tolist(),
n_results=1
n_results=1,
)
if results and results.get("ids") and results["ids"][0]:
distance = (
results["distances"][0][0] if results.get("distances") and results["distances"][0] else "N/A"
)
if results and results.get('ids') and results['ids'][0]:
distance = results['distances'][0][0] if results.get('distances') and results['distances'][0] else 'N/A'
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
if distance != 'N/A' and distance < 0.75:
l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0]
if distance != "N/A" and distance < 0.75:
l2_hit_key = results["ids"][0][0] if isinstance(results["ids"][0], list) else results["ids"][0]
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
# 从数据库获取缓存数据
@@ -208,7 +211,7 @@ class CacheManager:
model_class=CacheEntries,
query_type="get",
filters={"cache_key": l2_hit_key},
single_result=True
single_result=True,
)
if semantic_cache_results_obj:
@@ -235,7 +238,15 @@ class CacheManager:
logger.debug(f"缓存未命中: {key}")
return None
async def set(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None):
async def set(
self,
tool_name: str,
function_args: Dict[str, Any],
tool_file_path: Union[str, Path],
data: Any,
ttl: Optional[int] = None,
semantic_query: Optional[str] = None,
):
"""将结果存入所有缓存层。"""
if ttl is None:
ttl = self.default_ttl
@@ -251,20 +262,15 @@ class CacheManager:
# 写入 L2 (数据库)
cache_data = {
"cache_key": key,
"cache_value": orjson.dumps(data).decode('utf-8'),
"cache_value": orjson.dumps(data).decode("utf-8"),
"expires_at": expires_at,
"tool_name": tool_name,
"created_at": time.time(),
"last_accessed": time.time(),
"access_count": 1
"access_count": 1,
}
await db_save(
model_class=CacheEntries,
data=cache_data,
key_field="cache_key",
key_value=key
)
await db_save(model_class=CacheEntries, data=cache_data, key_field="cache_key", key_value=key)
# 写入语义缓存
if semantic_query and self.embedding_model:
@@ -274,7 +280,7 @@ class CacheManager:
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
validated_embedding = self._validate_embedding(embedding_vector)
if validated_embedding is not None:
embedding = np.array([validated_embedding], dtype='float32')
embedding = np.array([validated_embedding], dtype="float32")
# 写入 L1 Vector
new_id = self.l1_vector_index.ntotal
@@ -286,7 +292,7 @@ class CacheManager:
vector_db_service.add(
collection_name=self.semantic_cache_collection_name,
embeddings=embedding.tolist(),
ids=[key]
ids=[key],
)
except Exception as e:
logger.warning(f"语义缓存写入失败: {e}")
@@ -306,7 +312,7 @@ class CacheManager:
await db_query(
model_class=CacheEntries,
query_type="delete",
filters={} # 删除所有记录
filters={}, # 删除所有记录
)
# 清空 VectorDB
@@ -338,14 +344,12 @@ class CacheManager:
del self.l1_kv_cache[key]
# 清理L2过期条目
await db_query(
model_class=CacheEntries,
query_type="delete",
filters={"expires_at": {"$lt": current_time}}
)
await db_query(model_class=CacheEntries, query_type="delete", filters={"expires_at": {"$lt": current_time}})
if expired_keys:
logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目")
# 全局实例
tool_cache = CacheManager()

View File

@@ -1,405 +0,0 @@
"""工具执行历史记录模块"""
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
import json
from pathlib import Path
import inspect
from .logger import get_logger
from src.config.config import global_config
from src.common.cache_manager import tool_cache
logger = get_logger("tool_history")
class ToolHistoryManager:
    """Singleton manager that records and queries tool-execution history.

    Records live in memory and are mirrored to a JSONL file under
    ``data/tool_history/``. Each record carries a TTL counted in "times
    injected into a prompt": ``get_recent_history_prompt`` increments
    ``ttl_count`` on every injection, and a record whose count reaches its
    ``ttl`` is considered expired and cleaned up.
    """

    # Singleton state shared by every construction attempt.
    _instance = None
    _initialized = False

    def __new__(cls):
        # Classic singleton: always hand back the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Guard against re-running __init__ on repeated instantiations.
        if not self._initialized:
            self._history: List[Dict[str, Any]] = []
            self._initialized = True
            self._data_dir = Path("data/tool_history")
            self._data_dir.mkdir(parents=True, exist_ok=True)
            self._history_file = self._data_dir / "tool_history.jsonl"
            self._load_history()

    def _save_history(self):
        """Rewrite the whole JSONL file from the in-memory history."""
        try:
            with self._history_file.open("w", encoding="utf-8") as f:
                for record in self._history:
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")
        except Exception as e:
            logger.error(f"保存工具调用记录失败: {e}")

    def _save_record(self, record: Dict[str, Any]):
        """Append a single record to the JSONL file."""
        try:
            with self._history_file.open("a", encoding="utf-8") as f:
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
        except Exception as e:
            logger.error(f"保存工具调用记录失败: {e}")

    def _clean_expired_records(self):
        """Drop records whose ttl_count reached their ttl; persist if any were dropped."""
        original_count = len(self._history)
        self._history = [record for record in self._history if record.get("ttl_count", 0) < record.get("ttl", 5)]
        cleaned_count = original_count - len(self._history)
        if cleaned_count > 0:
            logger.info(f"清理了 {cleaned_count} 条过期的工具历史记录,剩余 {len(self._history)}")
            self._save_history()
        else:
            logger.debug("没有需要清理的过期工具历史记录")

    def record_tool_call(self,
                         tool_name: str,
                         args: Dict[str, Any],
                         result: Any,
                         execution_time: float,
                         status: str,
                         chat_id: Optional[str] = None,
                         ttl: int = 5):
        """Record one tool invocation.

        Args:
            tool_name: Name of the tool.
            args: Arguments the tool was called with (sanitized before storage).
            result: Value returned by the tool (sanitized before storage).
            execution_time: Wall-clock duration in seconds.
            status: Execution status ("completed" or "error").
            chat_id: Chat session id (matches ChatManager's chat_id), used to
                tag group/private conversations.
            ttl: Record lifetime, counted as the number of prompt injections
                before deletion. Defaults to 5.
        """
        # Skip entirely when history is disabled or the record could never live.
        if not global_config.tool.history.enable_history or ttl <= 0:
            return

        # Purge expired records first so the store does not grow unbounded.
        self._clean_expired_records()

        try:
            # Build the record; args/result are masked to avoid storing secrets.
            record = {
                "tool_name": tool_name,
                "timestamp": datetime.now().isoformat(),
                "arguments": self._sanitize_args(args),
                "result": self._sanitize_result(result),
                "execution_time": execution_time,
                "status": status,
                "chat_id": chat_id,
                "ttl": ttl,
                "ttl_count": 0
            }

            # Keep in memory and append to the JSONL file.
            self._history.append(record)
            self._save_record(record)

            if status == "completed":
                logger.info(f"工具 {tool_name} 调用完成,耗时:{execution_time:.2f}s")
            else:
                logger.error(f"工具 {tool_name} 调用失败:{result}")
        except Exception as e:
            logger.error(f"记录工具调用时发生错误: {e}")

    def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Mask sensitive values (api keys, tokens, ...) in an argument dict."""
        sensitive_keys = ['api_key', 'token', 'password', 'secret']
        sanitized = args.copy()

        def _sanitize_value(value):
            # NOTE(review): only nested dicts are recursed into; sensitive
            # values inside lists pass through unmasked — confirm intended.
            if isinstance(value, dict):
                return {k: '***' if k.lower() in sensitive_keys else _sanitize_value(v)
                        for k, v in value.items()}
            return value

        return {k: '***' if k.lower() in sensitive_keys else _sanitize_value(v)
                for k, v in sanitized.items()}

    def _sanitize_result(self, result: Any) -> Any:
        """Mask sensitive keys in dict results; other types pass through unchanged."""
        if isinstance(result, dict):
            return self._sanitize_args(result)
        return result

    def _load_history(self):
        """Load unexpired records from the JSONL history file into memory."""
        try:
            if self._history_file.exists():
                self._history = []
                with self._history_file.open("r", encoding="utf-8") as f:
                    for line in f:
                        try:
                            record = json.loads(line)
                            if record.get("ttl_count", 0) < record.get("ttl", 5):  # only keep unexpired records
                                self._history.append(record)
                        except json.JSONDecodeError:
                            # Skip corrupt lines rather than failing the whole load.
                            continue
                logger.info(f"成功加载了 {len(self._history)} 条历史记录")
        except Exception as e:
            logger.error(f"加载历史记录失败: {e}")

    def query_history(self,
                      tool_names: Optional[List[str]] = None,
                      start_time: Optional[Union[datetime, str]] = None,
                      end_time: Optional[Union[datetime, str]] = None,
                      chat_id: Optional[str] = None,
                      limit: Optional[int] = None,
                      status: Optional[str] = None) -> List[Dict[str, Any]]:
        """Query the tool-call history with optional filters.

        Args:
            tool_names: Restrict to these tool names; None/empty matches all.
            start_time: Inclusive lower bound; datetime or ISO-format string.
            end_time: Inclusive upper bound; datetime or ISO-format string.
            chat_id: Restrict to one chat session (matches ChatManager's chat_id).
            limit: Keep only the most recent N matching records.
            status: Filter by execution status ("completed" or "error").

        Returns:
            Records matching every supplied filter, in insertion order.
        """
        # Purge expired records before filtering.
        self._clean_expired_records()

        def _parse_time(time_str: Optional[Union[datetime, str]]) -> Optional[datetime]:
            # Accept datetimes as-is, parse ISO strings, pass None through.
            if isinstance(time_str, datetime):
                return time_str
            elif isinstance(time_str, str):
                return datetime.fromisoformat(time_str)
            return None

        filtered_history = self._history

        # Filter by tool name.
        if tool_names:
            filtered_history = [
                record for record in filtered_history
                if record["tool_name"] in tool_names
            ]

        # Filter by time range.
        start_dt = _parse_time(start_time)
        end_dt = _parse_time(end_time)
        if start_dt:
            filtered_history = [
                record for record in filtered_history
                if datetime.fromisoformat(record["timestamp"]) >= start_dt
            ]
        if end_dt:
            filtered_history = [
                record for record in filtered_history
                if datetime.fromisoformat(record["timestamp"]) <= end_dt
            ]

        # Filter by chat id.
        if chat_id:
            filtered_history = [
                record for record in filtered_history
                if record.get("chat_id") == chat_id
            ]

        # Filter by status.
        if status:
            filtered_history = [
                record for record in filtered_history
                if record["status"] == status
            ]

        # Apply the count limit (keeps the tail, i.e. the most recent records).
        if limit:
            filtered_history = filtered_history[-limit:]

        return filtered_history

    def get_recent_history_prompt(self,
                                  limit: Optional[int] = None,
                                  chat_id: Optional[str] = None) -> str:
        """
        Build a prompt snippet from recent tool-call history.

        Side effect: every rendered record's ``ttl_count`` is incremented, so a
        record is only ever injected into a prompt ``ttl`` times.

        Args:
            limit: Number of records to include; defaults to the configured
                ``max_history``.
            chat_id: Session id; when given, only that session's history is used.

        Returns:
            A formatted history prompt, or "" when disabled or empty.
        """
        # Respect the global on/off switch.
        if not global_config.tool.history.enable_history:
            return ""

        # Fall back to the configured maximum.
        if limit is None:
            limit = global_config.tool.history.max_history

        recent_history = self.query_history(
            chat_id=chat_id,
            limit=limit
        )

        if not recent_history:
            return ""

        prompt = "\n工具执行历史:\n"
        needs_save = False
        updated_history = []

        for record in recent_history:
            # Count this injection against the record's TTL.
            record["ttl_count"] = record.get("ttl_count", 0) + 1
            needs_save = True

            # Only include records that are still within their TTL.
            if record["ttl_count"] < record.get("ttl", 5):
                # Prefer the result's own name/content fields when present.
                result = record['result']
                if isinstance(result, dict):
                    name = result.get('name', record['tool_name'])
                    content = result.get('content', str(result))
                else:
                    name = record['tool_name']
                    content = str(result)

                # Collapse whitespace so the prompt stays compact.
                content = content.strip().replace('\n', ' ')

                # Truncate overly long content.
                if len(content) > 200:
                    content = content[:200] + "..."

                prompt += f"{name}: \n{content}\n\n"
                updated_history.append(record)

        # Persist the incremented counters.
        # NOTE(review): self._history is replaced with only the records that
        # were just rendered — anything outside `recent_history` (or past its
        # TTL) is dropped from memory and file here; confirm this is intended.
        if needs_save:
            self._history = updated_history
            self._save_history()

        return prompt

    def clear_history(self):
        """Drop all history records from memory and rewrite the (now empty) file."""
        self._history.clear()
        self._save_history()
        logger.info("工具调用历史记录已清除")
def wrap_tool_executor():
    """
    Wrap the tool executor to add history recording and caching.

    Monkey-patches ``ToolExecutor.execute_tool_call`` in place, so this must be
    called exactly once at system startup; calling it twice would wrap the
    already-wrapped method.
    """
    # Imported lazily to avoid a circular import at module load time.
    from src.plugin_system.core.tool_use import ToolExecutor
    from src.plugin_system.apis.tool_api import get_tool_instance

    original_execute = ToolExecutor.execute_tool_call
    history_manager = ToolHistoryManager()

    async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
        start_time = time.time()

        # Make sure we have a tool_instance.
        if not tool_instance:
            tool_instance = get_tool_instance(tool_call.func_name)

        # Without a tool_instance the cache cannot be consulted: execute
        # directly and record with a default TTL.
        if not tool_instance:
            result = await original_execute(self, tool_call, None)
            execution_time = time.time() - start_time
            history_manager.record_tool_call(
                tool_name=tool_call.func_name,
                args=tool_call.args,
                result=result,
                execution_time=execution_time,
                status="completed",
                chat_id=getattr(self, 'chat_id', None),
                ttl=5  # Default TTL
            )
            return result

        # Cache lookup: cache errors are logged and ignored so a broken cache
        # never blocks tool execution.
        if tool_instance.enable_cache:
            try:
                tool_file_path = inspect.getfile(tool_instance.__class__)
                semantic_query = None
                if tool_instance.semantic_cache_query_key:
                    semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)

                cached_result = await tool_cache.get(
                    tool_name=tool_call.func_name,
                    function_args=tool_call.args,
                    tool_file_path=tool_file_path,
                    semantic_query=semantic_query
                )
                # NOTE(review): truthiness check — a cached falsy result (e.g.
                # empty dict/string) is treated as a miss and re-executed;
                # confirm intended.
                if cached_result:
                    logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
                    return cached_result
            except Exception as e:
                logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")

        try:
            result = await original_execute(self, tool_call, tool_instance)
            execution_time = time.time() - start_time

            # Store the fresh result in the cache (best effort).
            if tool_instance.enable_cache:
                try:
                    tool_file_path = inspect.getfile(tool_instance.__class__)
                    semantic_query = None
                    if tool_instance.semantic_cache_query_key:
                        semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)

                    await tool_cache.set(
                        tool_name=tool_call.func_name,
                        function_args=tool_call.args,
                        tool_file_path=tool_file_path,
                        data=result,
                        ttl=tool_instance.cache_ttl,
                        semantic_query=semantic_query
                    )
                except Exception as e:
                    logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")

            # Record the successful call.
            history_manager.record_tool_call(
                tool_name=tool_call.func_name,
                args=tool_call.args,
                result=result,
                execution_time=execution_time,
                status="completed",
                chat_id=getattr(self, 'chat_id', None),
                ttl=tool_instance.history_ttl
            )
            return result
        except Exception as e:
            execution_time = time.time() - start_time
            # Record the failed call, then re-raise for the caller to handle.
            history_manager.record_tool_call(
                tool_name=tool_call.func_name,
                args=tool_call.args,
                result=str(e),
                execution_time=execution_time,
                status="error",
                chat_id=getattr(self, 'chat_id', None),
                ttl=tool_instance.history_ttl
            )
            raise

    # Replace the original method on the class (process-wide side effect).
    ToolExecutor.execute_tool_call = wrapped_execute_tool_call

View File

@@ -621,6 +621,7 @@ class WakeUpSystemConfig(ValidatedConfigBase):
decay_interval: float = Field(default=30.0, ge=1.0, description="唤醒度衰减间隔(秒)")
angry_duration: float = Field(default=300.0, ge=10.0, description="愤怒状态持续时间(秒)")
angry_prompt: str = Field(default="你被人吵醒了非常生气,说话带着怒气", description="被吵醒后的愤怒提示词")
re_sleep_delay_minutes: int = Field(default=5, ge=1, description="被唤醒后,如果多久没有新消息则尝试重新入睡(分钟)")
# --- 失眠机制相关参数 ---
enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统")

View File

@@ -5,7 +5,7 @@ import random
from enum import Enum
from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator
from src.common.logger import get_logger
from src.config.config import model_config
@@ -283,19 +283,26 @@ class LLMRequest:
tools: Optional[List[Dict[str, Any]]] = None,
raise_when_empty: bool = True,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""执行单次请求"""
# 模型选择和请求准备
start_time = time.time()
model_info, api_provider, client = self._select_model()
model_name = model_info.name
"""
执行单次请求,并在模型失败时按顺序切换到下一个可用模型。
"""
failed_models = set()
last_exception: Optional[Exception] = None
model_scheduler = self._model_scheduler(failed_models)
for model_info, api_provider, client in model_scheduler:
start_time = time.time()
model_name = model_info.name
logger.debug(f"正在尝试使用模型: {model_name}") # 你不许刷屏
try:
# 检查是否启用反截断
use_anti_truncation = getattr(api_provider, "anti_truncation", False)
processed_prompt = prompt
if use_anti_truncation:
processed_prompt += self.anti_truncation_instruction
logger.info(f"{api_provider} '{self.task_name}' 已启用反截断功能")
logger.info(f"'{model_name}' for task '{self.task_name}' 已启用反截断功能")
processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
@@ -304,13 +311,12 @@ class LLMRequest:
messages = [message_builder.build()]
tool_built = self._build_tool_options(tools)
# 空回复重试逻辑
# 针对当前模型的空回复/截断重试逻辑
empty_retry_count = 0
max_empty_retry = api_provider.max_retry
empty_retry_interval = api_provider.retry_interval
while empty_retry_count <= max_empty_retry:
try:
response = await self._execute_request(
api_provider=api_provider,
client=client,
@@ -321,93 +327,86 @@ class LLMRequest:
temperature=temperature,
max_tokens=max_tokens,
)
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
# 从内容中提取<think>标签的推理内容(向后兼容)
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
is_empty_reply = False
is_empty_reply = not tool_calls and (not content or content.strip() == "")
is_truncated = False
# 检测是否为空回复或截断
if not tool_calls:
is_empty_reply = not content or content.strip() == ""
is_truncated = False
if use_anti_truncation:
if content.endswith("[done]"):
content = content[:-6].strip()
logger.debug("检测到并已移除 [done] 标记")
else:
is_truncated = True
logger.warning("未检测到 [done] 标记,判定为截断")
if is_empty_reply or is_truncated:
if empty_retry_count < max_empty_retry:
empty_retry_count += 1
reason = "空回复" if is_empty_reply else "截断"
logger.warning(f"检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成")
if empty_retry_interval > 0:
await asyncio.sleep(empty_retry_interval)
model_info, api_provider, client = self._select_model()
continue
else:
# 已达到最大重试次数,但仍然是空回复或截断
reason = "空回复" if is_empty_reply else "截断"
# 抛出异常,由外层重试逻辑或最终的异常处理器捕获
raise RuntimeError(f"经过 {max_empty_retry + 1} 次尝试后仍然是{reason}的回复")
# 记录使用情况
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
time_cost=time.time() - start_time,
user_id="system",
request_type=self.request_type,
endpoint="/chat/completions",
)
# 处理空回复
if not content and not tool_calls:
if raise_when_empty:
raise RuntimeError(f"经过 {empty_retry_count} 次重试后仍然生成空回复")
content = "生成的响应为空,请检查模型配置或输入内容是否正确"
elif empty_retry_count > 0:
logger.info(f"经过 {empty_retry_count} 次重试后成功生成回复")
return content, (reasoning_content, model_info.name, tool_calls)
except Exception as e:
logger.error(f"请求执行失败: {e}")
if raise_when_empty:
# 在非并发模式下,如果第一次尝试就失败,则直接抛出异常
if empty_retry_count == 0:
raise
# 如果在重试过程中失败,则继续重试
empty_retry_count += 1
if empty_retry_count <= max_empty_retry:
logger.warning(f"请求失败,将在 {empty_retry_interval} 秒后进行第 {empty_retry_count}/{max_empty_retry} 次重试...")
reason = "空回复" if is_empty_reply else "截断"
logger.warning(f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成...")
if empty_retry_interval > 0:
await asyncio.sleep(empty_retry_interval)
continue
continue # 继续使用当前模型重试
else:
logger.error(f"经过 {max_empty_retry} 次重试后仍然失败")
raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复") from e
else:
# 在并发模式下,单个请求的失败不应中断整个并发流程,
# 而是将异常返回给调用者(即 execute_concurrently进行统一处理
raise # 重新抛出异常,由 execute_concurrently 中的 gather 捕获
# 当前模型重试次数用尽,跳出内层循环,触发外层循环切换模型
reason = "空回复" if is_empty_reply else "截断"
logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。")
raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数")
# 重试失败
# 成功获取响应
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
model_info=model_info, model_usage=usage, time_cost=time.time() - start_time,
user_id="system", request_type=self.request_type, endpoint="/chat/completions",
)
if not content and not tool_calls:
if raise_when_empty:
raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复")
return "生成的响应为空,请检查模型配置或输入内容是否正确", ("", model_name, None)
raise RuntimeError("生成空回复")
content = "生成的响应为空"
logger.debug(f"模型 '{model_name}' 成功生成回复。") # 你也不许刷屏
return content, (reasoning_content, model_name, tool_calls)
except RespNotOkException as e:
if e.status_code in [401, 403]:
logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将尝试下一个模型。")
failed_models.add(model_name)
last_exception = e
continue # 切换到下一个模型
else:
logger.error(f"模型 '{model_name}' 请求失败HTTP状态码: {e.status_code}")
if raise_when_empty:
raise
# 对于其他HTTP错误直接抛出不再尝试其他模型
return f"请求失败: {e}", ("", model_name, None)
except RuntimeError as e:
# 捕获所有重试失败(包括空回复和网络问题)
logger.error(f"模型 '{model_name}' 在所有重试后仍然失败: {e},将尝试下一个模型。")
failed_models.add(model_name)
last_exception = e
continue # 切换到下一个模型
except Exception as e:
logger.error(f"使用模型 '{model_name}' 时发生未知异常: {e}")
failed_models.add(model_name)
last_exception = e
continue # 切换到下一个模型
# 所有模型都尝试失败
logger.error("所有可用模型都已尝试失败。")
if raise_when_empty:
if last_exception:
raise RuntimeError("所有模型都请求失败") from last_exception
raise RuntimeError("所有模型都请求失败,且没有具体的异常信息")
return "所有模型都请求失败", ("", "unknown", None)
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
"""获取嵌入向量
@@ -446,9 +445,24 @@ class LLMRequest:
return embedding, model_info.name
def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]:
    """Yield ``(model_info, api_provider, client)`` triples for this task.

    Candidates are produced in the order configured for the task; any model
    whose name is in *failed_models* is skipped, letting the caller walk
    through fallback models one by one.
    """
    for candidate_name in self.model_for_task.model_list:
        if candidate_name in failed_models:
            continue
        info = model_config.get_model_info(candidate_name)
        provider = model_config.get_provider(info.api_provider)
        # Embedding requests always get a fresh client instance.
        fresh_client = self.request_type == "embedding"
        yield info, provider, client_registry.get_client_class_instance(provider, force_new=fresh_client)
def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
"""
根据总tokens和惩罚值选择的模型
根据总tokens和惩罚值选择的模型 (负载均衡)
"""
least_used_model_name = min(
self.model_usage,

View File

@@ -1,9 +1,7 @@
from typing import Any, Dict, List, Optional, Type, Union
from datetime import datetime
from typing import Optional, Type
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType
from src.common.tool_history import ToolHistoryManager
from src.common.logger import get_logger
logger = get_logger("tool_api")
@@ -34,109 +32,3 @@ def get_llm_available_tool_definitions():
llm_available_tools = component_registry.get_llm_available_tools()
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
def get_tool_history(
    tool_names: Optional[List[str]] = None,
    start_time: Optional[Union[datetime, str]] = None,
    end_time: Optional[Union[datetime, str]] = None,
    chat_id: Optional[str] = None,
    limit: Optional[int] = None,
    status: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Fetch tool invocation history records.

    Every filter is optional; an omitted filter matches all records.

    Args:
        tool_names: restrict to these tool names (None = all tools)
        start_time: earliest call time (datetime or ISO-format string)
        end_time: latest call time (datetime or ISO-format string)
        chat_id: restrict to a single conversation
        limit: maximum number of records to return
        status: execution status filter ("completed" or "error")

    Returns:
        List[Dict]: one dict per call with the keys tool_name, timestamp,
        arguments, result, execution_time, status and chat_id.
    """
    # Delegate directly to the history manager; it owns storage and filtering.
    return ToolHistoryManager().query_history(
        tool_names=tool_names,
        start_time=start_time,
        end_time=end_time,
        chat_id=chat_id,
        limit=limit,
        status=status,
    )
def get_tool_history_text(
    tool_names: Optional[List[str]] = None,
    start_time: Optional[Union[datetime, str]] = None,
    end_time: Optional[Union[datetime, str]] = None,
    chat_id: Optional[str] = None,
    limit: Optional[int] = None,
    status: Optional[str] = None
) -> str:
    """Return the tool invocation history formatted as human-readable text.

    Accepts the same optional filters as :func:`get_tool_history`.

    Returns:
        str: a header line followed by one "[time] name / result" entry per
        record, or a fixed not-found message when nothing matches.
    """
    records = get_tool_history(
        tool_names=tool_names,
        start_time=start_time,
        end_time=end_time,
        chat_id=chat_id,
        limit=limit,
        status=status
    )
    if not records:
        return "没有找到工具调用记录"

    parts = ["工具调用历史记录:\n"]
    for record in records:
        # Prefer the name/content embedded in a dict result; fall back to
        # the tool name and a stringified result otherwise.
        result = record['result']
        if isinstance(result, dict):
            name = result.get('name', record['tool_name'])
            content = result.get('content', str(result))
        else:
            name, content = record['tool_name'], str(result)

        # Flatten to one line and cap the length for readability.
        content = content.strip().replace('\n', ' ')
        if len(content) > 200:
            content = content[:200] + "..."

        ts = datetime.fromisoformat(record['timestamp']).strftime("%Y-%m-%d %H:%M:%S")
        parts.append(f"[{ts}] {name}\n结果: {content}\n\n")

    return "".join(parts)
def clear_tool_history() -> None:
    """Delete every stored tool invocation record."""
    ToolHistoryManager().clear_history()

View File

@@ -119,7 +119,7 @@ class BaseEvent:
for i, result in enumerate(results):
subscriber = sorted_subscribers[i]
handler_name = subscriber.handler_name if hasattr(subscriber, 'handler_name') else subscriber.__class__.__name__
if result:
if isinstance(result, Exception):
# 处理执行异常
logger.error(f"事件处理器 {handler_name} 执行失败: {result}")

View File

@@ -26,7 +26,6 @@ class BaseEventHandler(ABC):
def __init__(self):
self.log_prefix = "[EventHandler]"
self.plugin_name = ""
"""对应插件名"""
self.plugin_config: Optional[Dict] = None
"""插件配置字典"""

View File

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import List, Type, Tuple, Union, TYPE_CHECKING
from typing import List, Type, Tuple, Union
from .plugin_base import PluginBase
from src.common.logger import get_logger

View File

@@ -4,7 +4,7 @@
"""
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional, List
from typing import Tuple, Optional, List
import re
from src.common.logger import get_logger

View File

@@ -167,6 +167,7 @@ class ComponentRegistry:
logger.error(f"注册失败: {action_name} 不是有效的Action")
return False
action_class.plugin_name = action_info.plugin_name
self._action_registry[action_name] = action_class
# 如果启用,添加到默认动作集
@@ -184,6 +185,7 @@ class ComponentRegistry:
logger.error(f"注册失败: {command_name} 不是有效的Command")
return False
command_class.plugin_name = command_info.plugin_name
self._command_registry[command_name] = command_class
# 如果启用了且有匹配模式
@@ -213,6 +215,7 @@ class ComponentRegistry:
if not hasattr(self, '_plus_command_registry'):
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
plus_command_class.plugin_name = plus_command_info.plugin_name
self._plus_command_registry[plus_command_name] = plus_command_class
logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
@@ -222,6 +225,7 @@ class ComponentRegistry:
"""注册Tool组件到Tool特定注册表"""
tool_name = tool_info.name
tool_class.plugin_name = tool_info.plugin_name
self._tool_registry[tool_name] = tool_class
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
@@ -246,6 +250,7 @@ class ComponentRegistry:
logger.warning(f"EventHandler组件 {handler_name} 未启用")
return True # 未启用,但是也是注册成功
handler_class.plugin_name = handler_info.plugin_name
# 使用EventManager进行事件处理器注册
from src.plugin_system.core.event_manager import event_manager
return event_manager.register_event_handler(handler_class)

View File

@@ -7,8 +7,10 @@ from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
import inspect
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache
logger = get_logger("tool_use")
@@ -184,21 +186,65 @@ class ToolExecutor:
return tool_results, used_tools
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
# sourcery skip: use-assigned-variable
"""执行单个工具调用
"""执行单个工具调用,并处理缓存"""
Args:
tool_call: 工具调用对象
function_args = tool_call.args or {}
tool_instance = tool_instance or get_tool_instance(tool_call.func_name)
Returns:
Optional[Dict]: 工具调用结果如果失败则返回None
"""
# 如果工具不存在或未启用缓存,则直接执行
if not tool_instance or not tool_instance.enable_cache:
return await self._original_execute_tool_call(tool_call, tool_instance)
# --- 缓存逻辑开始 ---
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
cached_result = await tool_cache.get(
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
semantic_query=semantic_query
)
if cached_result:
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
return cached_result
except Exception as e:
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
# 缓存未命中,执行原始工具调用
result = await self._original_execute_tool_call(tool_call, tool_instance)
# 将结果存入缓存
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
await tool_cache.set(
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
data=result,
ttl=tool_instance.cache_ttl,
semantic_query=semantic_query
)
except Exception as e:
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
# --- 缓存逻辑结束 ---
return result
async def _original_execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
"""执行单个工具调用的原始逻辑"""
try:
function_name = tool_call.func_name
function_args = tool_call.args or {}
logger.info(f"🤖 {self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
logger.info(f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
tool_instance = tool_instance or get_tool_instance(function_name)
if not tool_instance:

View File

@@ -24,6 +24,7 @@ from .services.qzone_service import QZoneService
from .services.scheduler_service import SchedulerService
from .services.monitor_service import MonitorService
from .services.cookie_service import CookieService
from .services.reply_tracker_service import ReplyTrackerService
from .services.manager import register_service
logger = get_logger("MaiZone.Plugin")
@@ -99,11 +100,13 @@ class MaiZoneRefactoredPlugin(BasePlugin):
content_service = ContentService(self.get_config)
image_service = ImageService(self.get_config)
cookie_service = CookieService(self.get_config)
reply_tracker_service = ReplyTrackerService()
qzone_service = QZoneService(self.get_config, content_service, image_service, cookie_service)
scheduler_service = SchedulerService(self.get_config, qzone_service)
monitor_service = MonitorService(self.get_config, qzone_service)
register_service("qzone", qzone_service)
register_service("reply_tracker", reply_tracker_service)
register_service("get_config", self.get_config)
# 保存服务引用以便后续启动

View File

@@ -9,12 +9,9 @@ import datetime
import base64
import aiohttp
from src.common.logger import get_logger
import base64
import aiohttp
import imghdr
import asyncio
from src.common.logger import get_logger
from src.plugin_system.apis import llm_api, config_api, generator_api, person_api
from src.plugin_system.apis import llm_api, config_api, generator_api
from src.chat.message_receive.chat_stream import get_chat_manager
from maim_message import UserInfo
from src.llm_models.utils_model import LLMRequest

View File

@@ -27,6 +27,7 @@ from src.chat.utils.chat_message_builder import (
from .content_service import ContentService
from .image_service import ImageService
from .cookie_service import CookieService
from .reply_tracker_service import ReplyTrackerService
logger = get_logger("MaiZone.QZoneService")
@@ -55,6 +56,7 @@ class QZoneService:
self.content_service = content_service
self.image_service = image_service
self.cookie_service = cookie_service
self.reply_tracker = ReplyTrackerService()
# --- Public Methods (High-Level Business Logic) ---
@@ -154,7 +156,8 @@ class QZoneService:
# --- 第一步: 单独处理自己说说的评论 ---
if self.get_config("monitor.enable_auto_reply", False):
try:
own_feeds = await api_client["list_feeds"](qq_account, 5) # 获取自己最近5条说说
# 传入新参数,表明正在检查自己的说说
own_feeds = await api_client["list_feeds"](qq_account, 5, is_monitoring_own_feeds=True)
if own_feeds:
logger.info(f"获取到自己 {len(own_feeds)} 条说说,检查评论...")
for feed in own_feeds:
@@ -248,42 +251,83 @@ class QZoneService:
content = feed.get("content", "")
fid = feed.get("tid", "")
if not comments:
if not comments or not fid:
return
# 筛选出未被自己回复过的评论
if not comments:
# 1. 将评论分为用户评论和自己回复
user_comments = [c for c in comments if str(c.get('qq_account')) != str(qq_account)]
my_replies = [c for c in comments if str(c.get('qq_account')) == str(qq_account)]
if not user_comments:
return
# 找到所有我已经回复过的评论的ID
replied_to_tids = {
c['parent_tid'] for c in comments
if c.get('parent_tid') and str(c.get('qq_account')) == str(qq_account)
}
# 2. 验证已记录的回复是否仍然存在,清理已删除的回复记录
await self._validate_and_cleanup_reply_records(fid, my_replies)
# 找出所有非我发出且我未回复的评论
comments_to_reply = [
c for c in comments
if str(c.get('qq_account')) != str(qq_account) and c.get('comment_tid') not in replied_to_tids
]
# 3. 使用验证后的持久化记录来筛选未回复的评论
comments_to_reply = []
for comment in user_comments:
comment_tid = comment.get('comment_tid')
if not comment_tid:
continue
# 检查是否已经在持久化记录中标记为已回复
if not self.reply_tracker.has_replied(fid, comment_tid):
comments_to_reply.append(comment)
if not comments_to_reply:
logger.debug(f"说说 {fid} 下的所有评论都已回复过")
return
logger.info(f"发现自己说说下的 {len(comments_to_reply)} 条新评论,准备回复...")
for comment in comments_to_reply:
comment_tid = comment.get("comment_tid")
nickname = comment.get("nickname", "")
comment_content = comment.get("content", "")
try:
reply_content = await self.content_service.generate_comment_reply(
content, comment.get("content", ""), comment.get("nickname", "")
content, comment_content, nickname
)
if reply_content:
success = await api_client["reply"](
fid, qq_account, comment.get("nickname", ""), reply_content, comment.get("comment_tid")
fid, qq_account, nickname, reply_content, comment_tid
)
if success:
logger.info(f"成功回复'{comment.get('nickname', '')}'的评论: '{reply_content}'")
# 标记为已回复
self.reply_tracker.mark_as_replied(fid, comment_tid)
logger.info(f"成功回复'{nickname}'的评论: '{reply_content}'")
else:
logger.error(f"回复'{comment.get('nickname', '')}'的评论失败")
logger.error(f"回复'{nickname}'的评论失败")
await asyncio.sleep(random.uniform(10, 20))
else:
logger.warning(f"生成回复内容失败,跳过回复'{nickname}'的评论")
except Exception as e:
logger.error(f"回复'{nickname}'的评论时发生异常: {e}", exc_info=True)
async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: List[Dict]):
"""验证并清理已删除的回复记录"""
# 获取当前记录中该说说的所有已回复评论ID
recorded_replied_comments = self.reply_tracker.get_replied_comments(fid)
if not recorded_replied_comments:
return
# 从API返回的我的回复中提取parent_tid即被回复的评论ID
current_replied_comments = set()
for reply in my_replies:
parent_tid = reply.get('parent_tid')
if parent_tid:
current_replied_comments.add(parent_tid)
# 找出记录中有但实际已不存在的回复
deleted_replies = recorded_replied_comments - current_replied_comments
if deleted_replies:
logger.info(f"检测到 {len(deleted_replies)} 个回复已被删除,清理记录...")
for comment_tid in deleted_replies:
self.reply_tracker.remove_reply_record(fid, comment_tid)
logger.debug(f"已清理删除的回复记录: feed_id={fid}, comment_id={comment_tid}")
async def _process_single_feed(self, feed: Dict, api_client: Dict, target_qq: str, target_name: str):
"""处理单条说说,决定是否评论和点赞"""
@@ -641,7 +685,7 @@ class QZoneService:
logger.error(f"上传图片 {index+1} 异常: {e}", exc_info=True)
return None
async def _list_feeds(t_qq: str, num: int) -> List[Dict]:
async def _list_feeds(t_qq: str, num: int, is_monitoring_own_feeds: bool = False) -> List[Dict]:
"""获取指定用户说说列表"""
try:
params = {
@@ -667,10 +711,14 @@ class QZoneService:
feeds_list = []
my_name = json_data.get("logininfo", {}).get("name", "")
for msg in json_data.get("msglist", []):
# 只有在处理好友说说时,才检查是否已评论并跳过
if not is_monitoring_own_feeds:
is_commented = any(
c.get("name") == my_name for c in msg.get("commentlist", []) if isinstance(c, dict)
)
if not is_commented:
if is_commented:
continue
images = [pic['url1'] for pic in msg.get('pictotal', []) if 'url1' in pic]
comments = []

View File

@@ -0,0 +1,195 @@
# -*- coding: utf-8 -*-
"""
评论回复跟踪服务
负责记录和管理已回复过的评论ID避免重复回复
"""
import json
import time
from pathlib import Path
from typing import Set, Dict, Any
from src.common.logger import get_logger
logger = get_logger("MaiZone.ReplyTrackerService")
class ReplyTrackerService:
    """Track which QZone comments have already been replied to.

    Records are held in memory as ``{feed_id: {comment_id: timestamp}}`` and
    mirrored to a local JSON file so restarts do not cause duplicate replies.
    Entries older than ``max_record_days`` are purged automatically on save.
    """

    def __init__(self):
        # Storage location: a "data" directory next to the plugin package.
        self.data_dir = Path(__file__).resolve().parent.parent / "data"
        # parents=True keeps first-run setup robust when intermediate
        # directories are missing (plain mkdir(exist_ok=True) would raise).
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.reply_record_file = self.data_dir / "replied_comments.json"

        # In-memory record of replied comments.
        # Shape: {feed_id: {comment_id: reply_timestamp, ...}, ...}
        self.replied_comments: Dict[str, Dict[str, float]] = {}

        # Retention window for reply records, in days.
        self.max_record_days = 30

        self._load_data()

    def _load_data(self):
        """Load persisted records; tolerate a missing or corrupt file."""
        try:
            if self.reply_record_file.exists():
                with open(self.reply_record_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                # Only accept the expected mapping shape; anything else
                # (hand-edited or corrupted file) falls back to empty,
                # instead of crashing later inside has_replied().
                if isinstance(data, dict):
                    self.replied_comments = data
                    logger.info(f"已加载 {len(self.replied_comments)} 条说说的回复记录")
                else:
                    logger.warning("回复记录文件格式不正确,将重置为空记录")
                    self.replied_comments = {}
            else:
                logger.info("未找到回复记录文件,将创建新的记录")
        except Exception as e:
            logger.error(f"加载回复记录失败: {e}")
            self.replied_comments = {}

    def _save_data(self):
        """Persist the records atomically (temp file + rename) after pruning."""
        try:
            # Drop expired entries before every write.
            self._cleanup_old_records()

            # Write to a temporary file first, then atomically replace the
            # real file: a crash mid-write can no longer leave a truncated
            # or corrupt JSON file behind.
            tmp_file = self.reply_record_file.with_suffix(".json.tmp")
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(self.replied_comments, f, ensure_ascii=False, indent=2)
            tmp_file.replace(self.reply_record_file)
            logger.debug("回复记录已保存")
        except Exception as e:
            logger.error(f"保存回复记录失败: {e}")

    def _cleanup_old_records(self):
        """Remove reply records older than ``max_record_days`` days."""
        cutoff_time = time.time() - (self.max_record_days * 24 * 60 * 60)

        feeds_to_remove = []
        total_removed = 0

        for feed_id, comments in self.replied_comments.items():
            expired = [cid for cid, ts in comments.items() if ts < cutoff_time]
            for comment_id in expired:
                del comments[comment_id]
                total_removed += 1
            # A feed with no remaining records can be dropped entirely.
            if not comments:
                feeds_to_remove.append(feed_id)

        for feed_id in feeds_to_remove:
            del self.replied_comments[feed_id]

        if total_removed > 0:
            logger.info(f"清理了 {total_removed} 条过期的回复记录")

    def has_replied(self, feed_id: str, comment_id: str) -> bool:
        """Return True if *comment_id* under *feed_id* was already replied to.

        Empty/falsy IDs are treated as "not replied" rather than an error.
        """
        if not feed_id or not comment_id:
            return False
        return (feed_id in self.replied_comments and
                comment_id in self.replied_comments[feed_id])

    def mark_as_replied(self, feed_id: str, comment_id: str):
        """Record that *comment_id* under *feed_id* has been replied to."""
        if not feed_id or not comment_id:
            logger.warning("feed_id 或 comment_id 为空,无法标记为已回复")
            return

        self.replied_comments.setdefault(feed_id, {})[comment_id] = time.time()

        # Persist immediately so a crash cannot cause a duplicate reply.
        self._save_data()
        logger.info(f"已标记评论为已回复: feed_id={feed_id}, comment_id={comment_id}")

    def get_replied_comments(self, feed_id: str) -> Set[str]:
        """Return the set of comment IDs already replied to under *feed_id*."""
        if feed_id in self.replied_comments:
            return set(self.replied_comments[feed_id].keys())
        return set()

    def get_stats(self) -> Dict[str, Any]:
        """Return summary statistics about the stored reply records."""
        total_feeds = len(self.replied_comments)
        total_replies = sum(len(comments) for comments in self.replied_comments.values())
        return {
            "total_feeds_with_replies": total_feeds,
            "total_replied_comments": total_replies,
            "data_file": str(self.reply_record_file),
            "max_record_days": self.max_record_days
        }

    def remove_reply_record(self, feed_id: str, comment_id: str):
        """Delete the record for a single comment (e.g. our reply was removed)."""
        if feed_id in self.replied_comments and comment_id in self.replied_comments[feed_id]:
            del self.replied_comments[feed_id][comment_id]
            # Drop the feed entry entirely once it has no records left.
            if not self.replied_comments[feed_id]:
                del self.replied_comments[feed_id]
            self._save_data()
            logger.debug(f"已移除回复记录: feed_id={feed_id}, comment_id={comment_id}")

    def remove_feed_records(self, feed_id: str):
        """Delete every reply record belonging to *feed_id*."""
        if feed_id in self.replied_comments:
            del self.replied_comments[feed_id]
            self._save_data()
            logger.info(f"已移除说说 {feed_id} 的所有回复记录")

View File

@@ -16,7 +16,7 @@ from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.apis.logging_api import get_logger
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
from src.plugin_system.base.config_types import ConfigField
from src.plugin_system.utils.permission_decorators import require_permission, require_master, PermissionChecker
from src.plugin_system.utils.permission_decorators import require_permission
logger = get_logger("Permission")

View File

@@ -411,7 +411,6 @@ class ScheduleManager:
通过关键词匹配、唤醒度、睡眠压力等综合判断是否处于休眠时间。
新增弹性睡眠机制,允许在压力低时延迟入睡,并在入睡前发送通知。
"""
from src.chat.chat_loop.wakeup_manager import WakeUpManager
# --- 基础检查 ---
if not global_config.schedule.enable_is_sleep:
return False

View File

@@ -21,7 +21,7 @@ max_retry = 2
timeout = 30
retry_interval = 10
[[api_providers]] # 特殊Google的Gemini使用特殊API与OpenAI格式不兼容需要配置client为"gemini"
[[api_providers]] # 特殊Google的Gemini使用特殊API与OpenAI格式不兼容需要配置client为"aiohttp_gemini"
name = "Google"
base_url = "https://api.google.com/v1"
api_key = "your-google-api-key-1"