diff --git a/.gitignore b/.gitignore index fe36449a9..e9506861e 100644 --- a/.gitignore +++ b/.gitignore @@ -325,6 +325,8 @@ run_pet.bat !/plugins/set_emoji_like !/plugins/permission_example !/plugins/hello_world_plugin +!/plugins/take_picture_plugin +!/plugins/napcat_adapter_plugin !/plugins/echo_example config.toml diff --git a/bot.py b/bot.py index 6298b09eb..29cae9a90 100644 --- a/bot.py +++ b/bot.py @@ -31,7 +31,6 @@ from src.manager.async_task_manager import async_task_manager # noqa from src.config.config import global_config # noqa from src.common.database.database import initialize_sql_database # noqa from src.common.database.sqlalchemy_models import initialize_database as init_db # noqa -from src.common.tool_history import wrap_tool_executor #noqa logger = get_logger("main") @@ -240,8 +239,7 @@ class MaiBotMain(BaseMain): self.setup_timezone() self.check_and_confirm_eula() self.initialize_database() - # 初始化工具历史记录 - wrap_tool_executor() + return self.create_main_system() diff --git a/changelogs/changelog.md b/changelogs/changelog.md index 9369fbdc2..33e6e5f23 100644 --- a/changelogs/changelog.md +++ b/changelogs/changelog.md @@ -1,5 +1,68 @@ # Changelog +## [0.10.0-alpha] - 2025-8-28 + +> **MoFox-Bot 0.10.0-alpha 版本发布!** +> +> 本次更新带来了多项核心功能增强和系统优化。 +> +> 在**新功能**方面,我们引入了**持久化回复跟踪**以避免重复回复,并为LLM请求实现了**模型故障转移机制**,提高了系统的健壮性。插件系统也得到了增强,增加了**事件触发和订阅的白名单机制**,并为事件处理器添加了**异步锁和并行执行**支持。此外,新版本还实现了对**说说中图片的识别与理解**,并引入了**弹性睡眠与睡前通知机制**,使机器人的作息更加智能化。 +> +> 在**修复**方面,我们解决了`enable`配置的bug,修复了event权限问题,并处理了专注模式下艾特不回复、模型信息不存在、cookie获取失败等多个问题,提升了系统的稳定性。 +> +> 在**重构**方面,我们移除了`changelog_config`并更新了模型配置模板,重构了工具缓存机制和LLM请求重试逻辑。同时,我们将项目从`MaiMbot-Pro-Max`正式更名为`MoFox_Bot`,并对内存、权限、配置等多个模块进行了优化。 +> +> 总的来说,0.10.0-alpha版本在功能、稳定性和代码质量上都有了显著提升。 + +### 新功能 +- **maizone**: 引入持久化回复跟踪以避免重复回复 +- **llm**: 为LLM请求实现模型故障转移机制 +- **plugin-system**: 添加事件触发和订阅的白名单机制 +- **plugin**: 为事件处理器添加异步锁和并行执行支持 +- **maizone**: 实现对说说中图片的识别与理解 +- **sleep**: 实现睡眠唤醒与重新入睡机制 +- **core**: 实现HFC及睡眠状态的持久化 +- **schedule**: 
引入弹性睡眠与睡前通知机制 +- **monthly_plan**: 增加月度计划数量上限并自动清理 +- **command**: 添加PlusCommand增强命令系统 +- **expression**: 重构表达学习配置,引入基于规则的结构化定义 +- **chat**: 实现睡眠压力和失眠系统 +- **schedule**: 重构日程与月度计划管理模块 +- **core**: 集成统一向量数据库服务并重构相关模块 +- **tool_system**: 实现工具的声明式缓存 +- **plugin_system**: 增加工具执行日志记录 +- **maizone**: 新增QQ空间互通组功能,根据聊天上下文生成说说 +- **monthly_plan**: 增强月度计划系统,引入状态管理和智能抽取 + +### 修复 +- 修复`enable`配置 +- 修复event权限,现在每个component都拥有`plugin_name`属性 +- 修复专注模式下艾特不回复的问题 +- 修复模型信息不存在时引发的属性错误 +- 修复`maizone_refactored`获取cookie时响应为空导致的错误 +- 修复关键词非列表形式时导致的解析错误 +- 修复即时记忆的 orjson 编码与解码问题 +- 修复空回复检测,同时修复`tool_call` +- 处理截断消息时`message`为`None`的情况 +- 修复`get_remaining`的起始索引 +- 修复回复自己评论的问题 +- 修复`enable`配置 + +### 重构 +- **config**: 移除`changelog_config`并更新模型配置模板 +- **cache**: 重构工具缓存机制并优化LLM请求重试逻辑 +- **core**: 移除工具历史记录管理器并将缓存集成到工具执行器中 +- 重构权限检查和装饰器用法 +- **core**: 将项目从`MaiMbot-Pro-Max`重命名为`MoFox_Bot` +- **memory**: 重构向量记忆清理逻辑以提高稳定性 +- **llm_models**: 移除官方Gemini客户端并改用`aiohttp`实现 +- **config**: 整合搜索服务配置并移除废弃选项 +- **config**: 将反截断设置移至模型配置 +- **video**: 重构视频分析,增加抽帧模式和间隔配置 + +从这里开始都是第三方改版的更新记录!!!!! 
+======================================================== + ## [0.10.0] - 2025-7-1 ### 主要功能更改 - 工具系统重构,现在合并到了插件系统中 diff --git a/changelogs/changelog_config.md b/changelogs/changelog_config.md deleted file mode 100644 index 5aa5fb922..000000000 --- a/changelogs/changelog_config.md +++ /dev/null @@ -1,51 +0,0 @@ -# Changelog - -## [1.0.3] - 2025-3-31 -### Added -- 新增了心流相关配置项: - - `heartflow` 配置项,用于控制心流功能 - -### Removed -- 移除了 `response` 配置项中的 `model_r1_probability` 和 `model_v3_probability` 选项 -- 移除了次级推理模型相关配置 - -## [1.0.1] - 2025-3-30 -### Added -- 增加了流式输出控制项 `stream` -- 修复 `LLM_Request` 不会自动为 `payload` 增加流式输出标志的问题 - -## [1.0.0] - 2025-3-30 -### Added -- 修复了错误的版本命名 -- 杀掉了所有无关文件 - -## [0.0.11] - 2025-3-12 -### Added -- 新增了 `schedule` 配置项,用于配置日程表生成功能 -- 新增了 `response_splitter` 配置项,用于控制回复分割 -- 新增了 `experimental` 配置项,用于实验性功能开关 -- 新增了 `llm_observation` 和 `llm_sub_heartflow` 模型配置 -- 新增了 `llm_heartflow` 模型配置 -- 在 `personality` 配置项中新增了 `prompt_schedule_gen` 参数 - -### Changed -- 优化了模型配置的组织结构 -- 调整了部分配置项的默认值 -- 调整了配置项的顺序,将 `groups` 配置项移到了更靠前的位置 -- 在 `message` 配置项中: - - 新增了 `model_max_output_length` 参数 -- 在 `willing` 配置项中新增了 `emoji_response_penalty` 参数 -- 将 `personality` 配置项中的 `prompt_schedule` 重命名为 `prompt_schedule_gen` - -### Removed -- 移除了 `min_text_length` 配置项 -- 移除了 `cq_code` 配置项 -- 移除了 `others` 配置项(其功能已整合到 `experimental` 中) - -## [0.0.5] - 2025-3-11 -### Added -- 新增了 `alias_names` 配置项,用于指定麦麦的别名。 - -## [0.0.4] - 2025-3-9 -### Added -- 新增了 `memory_ban_words` 配置项,用于指定不希望记忆的词汇。 \ No newline at end of file diff --git a/docs/PERMISSION_SYSTEM.md b/docs/architecture/PERMISSION_SYSTEM.md similarity index 100% rename from docs/PERMISSION_SYSTEM.md rename to docs/architecture/PERMISSION_SYSTEM.md diff --git a/docs/memory_system_design_v3.md b/docs/architecture/memory_system_design_v3.md similarity index 100% rename from docs/memory_system_design_v3.md rename to docs/architecture/memory_system_design_v3.md diff --git a/docs/plugins/image/quick-start/1750326700269.png 
b/docs/assets/1750326700269.png similarity index 100% rename from docs/plugins/image/quick-start/1750326700269.png rename to docs/assets/1750326700269.png diff --git a/docs/plugins/image/quick-start/1750332508760.png b/docs/assets/1750332508760.png similarity index 100% rename from docs/plugins/image/quick-start/1750332508760.png rename to docs/assets/1750332508760.png diff --git a/docs/image-1.png b/docs/assets/image-1.png similarity index 100% rename from docs/image-1.png rename to docs/assets/image-1.png diff --git a/docs/image.png b/docs/assets/image.png similarity index 100% rename from docs/image.png rename to docs/assets/image.png diff --git a/docs/deployment_guide.md b/docs/deployment_guide.md new file mode 100644 index 000000000..54cc618d3 --- /dev/null +++ b/docs/deployment_guide.md @@ -0,0 +1,124 @@ +# MoFox_Bot 部署指南 + +欢迎使用 MoFox_Bot!本指南将引导您完成在 Windows 环境下部署 MoFox_Bot 的全部过程。 + +## 1. 系统要求 + +- **操作系统**: Windows 10 或 Windows 11 +- **Python**: 版本 >= 3.10 +- **Git**: 用于克隆项目仓库 +- **uv**: 推荐的 Python 包管理器 (版本 >= 0.1.0) + +## 2. 部署步骤 + +### 第一步:获取必要的文件 + +首先,创建一个用于存放 MoFox_Bot 相关文件的文件夹,并通过 `git` 克隆 MoFox_Bot 主程序和 Napcat 适配器。 + +```shell +mkdir MoFox_Bot_Deployment +cd MoFox_Bot_Deployment +git clone https://github.com/MoFox-Studio/MoFox_Bot.git +git clone https://github.com/MoFox-Studio/Napcat-Adapter.git +``` + +### 第二步:环境配置 + +我们推荐使用 `uv` 来管理 Python 环境和依赖,因为它提供了更快的安装速度和更好的依赖管理体验。 + +**安装 uv:** + +```shell +pip install uv +``` + +### 第三步:依赖安装 + +**1. 安装 MoFox_Bot 依赖:** + +进入 `mmc` 文件夹,创建虚拟环境并安装依赖。 + +```shell +cd mmc +uv venv +uv pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple --upgrade +``` + +**2. 安装 Napcat-Adapter 依赖:** + +回到上一级目录,进入 `Napcat-Adapter` 文件夹,创建虚拟环境并安装依赖。 + +```shell +cd .. +cd Napcat-Adapter +uv venv +uv pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple --upgrade +``` + +### 第四步:配置 MoFox_Bot 和 Adapter + +**1. 
MoFox_Bot 配置:** + +- 在 `mmc` 文件夹中,将 `template/bot_config_template.toml` 复制到 `config/bot_config.toml`。 +- 将 `template/model_config_template.toml` 复制到 `config/model_config.toml`。 +- 根据 [模型配置指南](guides/model_configuration_guide.md) 和 `bot_config.toml` 文件中的注释,填写您的 API Key 和其他相关配置。 + +**2. Napcat-Adapter 配置:** + +- 在 `Napcat-Adapter` 文件夹中,将 `template/template_config.toml` 复制到根目录并改名为 `config.toml`。 +- 打开 `config.toml` 文件,配置 `[Napcat_Server]` 和 `[MaiBot_Server]` 字段。 + - `[Napcat_Server]` 的 `port` 应与 Napcat 设置的反向代理 URL 中的端口相同。 + - `[MaiBot_Server]` 的 `port` 应与 MoFox_Bot 的 `bot_config.toml` 中设置的端口相同。 + +### 第五步:运行 + +**1. 启动 Napcat:** + +请参考 [NapCatQQ 文档](https://napcat-qq.github.io/) 进行部署和启动。 + +**2. 启动 MoFox_Bot:** + +进入 `mmc` 文件夹,使用 `uv` 运行。 + +```shell +cd mmc +uv run python bot.py +``` + +**3. 启动 Napcat-Adapter:** + +打开一个新的终端窗口,进入 `Napcat-Adapter` 文件夹,使用 `uv` 运行。 + +```shell +cd Napcat-Adapter +uv run python main.py +``` + +至此,MoFox_Bot 已成功部署并运行。 + +## 3. 详细配置说明 + +### `bot_config.toml` + +这是 MoFox_Bot 的主配置文件,包含了机器人昵称、主人QQ、命令前缀、数据库设置等。请根据文件内的注释进行详细配置。 + +### `model_config.toml` + +此文件用于配置 AI 模型和 API 服务提供商。详细配置方法请参考 [模型配置指南](guides/model_configuration_guide.md)。 + +### 插件配置 + +每个插件都有独立的配置文件,位于 `mmc/config/plugins/` 目录下。插件的配置由其 `config_schema` 自动生成。详细信息请参考 [插件配置完整指南](plugins/configuration-guide.md)。 + +## 4. 
故障排除 + +- **依赖安装失败**: + - 尝试更换 PyPI 镜像源。 + - 检查网络连接。 +- **API 调用失败**: + - 检查 `model_config.toml` 中的 API Key 和 `base_url` 是否正确。 +- **无法连接到 Napcat**: + - 检查 Napcat 是否正常运行。 + - 确认 `Napcat-Adapter` 的 `config.toml` 中 `[Napcat_Server]` 的 `port` 是否与 Napcat 设置的端口一致。 + +如果遇到其他问题,请查看 `logs/` 目录下的日志文件以获取详细的错误信息。 \ No newline at end of file diff --git a/docs/CONTRIBUTE.md b/docs/development/CONTRIBUTE.md similarity index 100% rename from docs/CONTRIBUTE.md rename to docs/development/CONTRIBUTE.md diff --git a/docs/model_configuration_guide.md b/docs/guides/model_configuration_guide.md similarity index 85% rename from docs/model_configuration_guide.md rename to docs/guides/model_configuration_guide.md index 2753e92bd..3ef495eca 100644 --- a/docs/model_configuration_guide.md +++ b/docs/guides/model_configuration_guide.md @@ -43,12 +43,11 @@ retry_interval = 10 # 重试间隔(秒) | `name` | ✅ | 服务商名称,需要在模型配置中引用 | - | | `base_url` | ✅ | API服务的基础URL | - | | `api_key` | ✅ | API密钥,请替换为实际密钥 | - | -| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式,现在支持不良好) | `openai` | +| `client_type` | ❌ | 客户端类型:`openai`、`gemini` 或 `aiohttp_gemini` | `openai` | | `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 | | `timeout` | ❌ | API请求超时时间(秒) | 30 | | `retry_interval` | ❌ | 重试间隔时间(秒) | 10 | -**请注意,对于`client_type`为`gemini`的模型,`base_url`字段无效。** ### 2.3 支持的服务商示例 #### DeepSeek @@ -73,9 +72,9 @@ client_type = "openai" ```toml [[api_providers]] name = "Google" -base_url = "https://api.google.com/v1" +base_url = "https://generativelanguage.googleapis.com/v1beta" # 在MoFox-Bot中, 使用aiohttp_gemini客户端的提供商可以自定义base_url api_key = "your-google-api-key" -client_type = "gemini" # 注意:Gemini需要使用特殊客户端 +client_type = "aiohttp_gemini" # 注意:Gemini需要使用特殊客户端 ``` ## 3. 
模型配置 @@ -118,11 +117,11 @@ enable_thinking = false # 禁用思考 比如上面就是参考SiliconFlow的文档配置配置的`Qwen3`禁用思考参数。 -![SiliconFlow文档截图](image-1.png) +![SiliconFlow文档截图](../assets/image-1.png) 以豆包文档为另一个例子 -![豆包文档截图](image.png) +![豆包文档截图](../assets/image.png) 得到豆包`"doubao-seed-1-6-250615"`的禁用思考配置方法为 ```toml @@ -133,7 +132,6 @@ thinking = {type = "disabled"} # 禁用思考 ``` 请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。 -**请注意,对于`client_type`为`gemini`的模型,此字段无效。** ### 3.3 配置参数说明 | 参数 | 必填 | 说明 | @@ -145,6 +143,7 @@ thinking = {type = "disabled"} # 禁用思考 | `price_out` | ❌ | 输出价格(元/M token),用于成本统计 | | `force_stream_mode` | ❌ | 是否强制使用流式输出 | | `extra_params` | ❌ | 额外的模型参数配置 | +| `anti_truncation` | ❌ | 是否启用反截断功能 | ## 4. 模型任务配置 @@ -184,7 +183,7 @@ max_tokens = 800 ``` ### planner - 决策模型 -负责决定MaiBot该做什么: +负责决定MoFox_Bot该做什么: ```toml [model_task_config.planner] model_list = ["siliconflow-deepseek-v3"] @@ -193,7 +192,7 @@ max_tokens = 800 ``` ### emotion - 情绪模型 -负责MaiBot的情绪变化: +负责MoFox_Bot的情绪变化: ```toml [model_task_config.emotion] model_list = ["siliconflow-deepseek-v3"] @@ -262,6 +261,44 @@ temperature = 0.7 max_tokens = 800 ``` +### schedule_generator - 日程生成模型 +```toml +[model_task_config.schedule_generator] +model_list = ["deepseek-v3"] +temperature = 0.5 +max_tokens = 1024 +``` + +### monthly_plan_generator - 月度计划生成模型 +```toml +[model_task_config.monthly_plan_generator] +model_list = ["deepseek-v3"] +temperature = 0.7 +max_tokens = 1024 +``` + +### emoji_vlm - 表情包VLM模型 +```toml +[model_task_config.emoji_vlm] +model_list = ["qwen-vl-max"] +max_tokens = 800 +``` + +### anti_injection - 反注入模型 +```toml +[model_task_config.anti_injection] +model_list = ["deepseek-v3"] +temperature = 0.1 +max_tokens = 512 +``` + +### utils_video - 视频分析模型 +```toml +[model_task_config.utils_video] +model_list = ["qwen-vl-max"] +max_tokens = 800 +``` + ## 5. 
配置建议 ### 5.1 Temperature 参数选择 @@ -276,7 +313,7 @@ max_tokens = 800 | 任务类型 | 推荐模型类型 | 示例 | |----------|--------------|------| -| 高精度任务 | 大模型 | DeepSeek-V3, GPT-4 | +| 高精度任务 | 大模型 | DeepSeek-V3, GPT-5,Gemini-2.5-Pro | | 高频率任务 | 小模型 | Qwen3-8B | | 多模态任务 | 专用模型 | Qwen2.5-VL, SenseVoice | | 工具调用 | 支持Function Call的模型 | Qwen3-14B | @@ -285,7 +322,6 @@ max_tokens = 800 1. **分层使用**:核心功能使用高质量模型,辅助功能使用经济模型 2. **合理配置max_tokens**:根据实际需求设置,避免浪费 -3. **选择免费模型**:对于测试环境,优先使用price为0的模型 ## 6. 配置验证 diff --git a/docs/vector_db_usage_guide.md b/docs/guides/vector_db_usage_guide.md similarity index 100% rename from docs/vector_db_usage_guide.md rename to docs/guides/vector_db_usage_guide.md diff --git a/docs/Bing.md b/docs/integrations/Bing.md similarity index 100% rename from docs/Bing.md rename to docs/integrations/Bing.md diff --git a/docs/PLUS_COMMAND_GUIDE.md b/docs/plugins/PLUS_COMMAND_GUIDE.md similarity index 100% rename from docs/PLUS_COMMAND_GUIDE.md rename to docs/plugins/PLUS_COMMAND_GUIDE.md diff --git a/docs/plugins/command-components.md b/docs/plugins/command-components.md deleted file mode 100644 index 77cc8accf..000000000 --- a/docs/plugins/command-components.md +++ /dev/null @@ -1,89 +0,0 @@ -# 💻 Command组件详解 - -## 📖 什么是Command - -Command是直接响应用户明确指令的组件,与Action不同,Command是**被动触发**的,当用户输入特定格式的命令时立即执行。 - -Command通过正则表达式匹配用户输入,提供确定性的功能服务。 - -### 🎯 Command的特点 - -- 🎯 **确定性执行**:匹配到命令立即执行,无随机性 -- ⚡ **即时响应**:用户主动触发,快速响应 -- 🔍 **正则匹配**:通过正则表达式精确匹配用户输入 -- 🛑 **拦截控制**:可以控制是否阻止消息继续处理 -- 📝 **参数解析**:支持从用户输入中提取参数 - ---- - -## 🛠️ Command组件的基本结构 - -首先,Command组件需要继承自`BaseCommand`类,并实现必要的方法。 - -```python -class ExampleCommand(BaseCommand): - command_name = "example" # 命令名称,作为唯一标识符 - command_description = "这是一个示例命令" # 命令描述 - command_pattern = r"" # 命令匹配的正则表达式 - - async def execute(self) -> Tuple[bool, Optional[str], bool]: - """ - 执行Command的主要逻辑 - - Returns: - Tuple[bool, str, bool]: - - 第一个bool表示是否成功执行 - - 第二个str是执行结果消息 - - 第三个bool表示是否需要阻止消息继续处理 - """ - # ---- 执行命令的逻辑 ---- - return True, 
"执行成功", False -``` -**`command_pattern`**: 该Command匹配的正则表达式,用于精确匹配用户输入。 - -请注意:如果希望能获取到命令中的参数,请在正则表达式中使用有命名的捕获组,例如`(?Ppattern)`。 - -这样在匹配时,内部实现可以使用`re.match.groupdict()`方法获取到所有捕获组的参数,并以字典的形式存储在`self.matched_groups`中。 - -### 匹配样例 -假设我们有一个命令`/example param1=value1 param2=value2`,对应的正则表达式可以是: - -```python -class ExampleCommand(BaseCommand): - command_name = "example" - command_description = "这是一个示例命令" - command_pattern = r"/example (?P\w+) (?P\w+)" - - async def execute(self) -> Tuple[bool, Optional[str], bool]: - # 获取匹配的参数 - param1 = self.matched_groups.get("param1") - param2 = self.matched_groups.get("param2") - - # 执行逻辑 - return True, f"参数1: {param1}, 参数2: {param2}", False -``` - ---- - -## Command 内置方法说明 -```python -class BaseCommand: - def get_config(self, key: str, default=None): - """获取插件配置值,使用嵌套键访问""" - - async def send_text(self, content: str, reply_to: str = "") -> bool: - """发送回复消息""" - - async def send_type(self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = "") -> bool: - """发送指定类型的回复消息到当前聊天环境""" - - async def send_command(self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True) -> bool: - """发送命令消息""" - - async def send_emoji(self, emoji_base64: str) -> bool: - """发送表情包""" - - async def send_image(self, image_base64: str) -> bool: - """发送图片""" -``` -具体参数与用法参见`BaseCommand`基类的定义。 \ No newline at end of file diff --git a/docs/plugins/index.md b/docs/plugins/index.md index fe999f393..c39efe72e 100644 --- a/docs/plugins/index.md +++ b/docs/plugins/index.md @@ -9,7 +9,7 @@ ## 组件功能详解 - [🧱 Action组件详解](action-components.md) - 掌握最核心的Action组件 -- [💻 Command组件详解](command-components.md) - 学习直接响应命令的组件 +- [💻 Command组件详解](PLUS_COMMAND_GUIDE.md) - 学习直接响应命令的组件 - [🔧 Tool组件详解](tool-components.md) - 了解如何扩展信息获取能力 - [⚙️ 配置文件系统指南](configuration-guide.md) - 学会使用自动生成的插件配置文件 - [📄 Manifest系统指南](manifest-guide.md) - 了解插件元数据管理和配置架构 diff --git a/docs/plugins/quick-start.md 
b/docs/plugins/quick-start.md index 34431f80b..ff32a43eb 100644 --- a/docs/plugins/quick-start.md +++ b/docs/plugins/quick-start.md @@ -90,7 +90,7 @@ class HelloWorldPlugin(BasePlugin): 在日志中你应该能看到插件被加载的信息。虽然插件还没有任何功能,但它已经成功运行了! -![1750326700269](image/quick-start/1750326700269.png) +![1750326700269](../assets/1750326700269.png) ### 5. 添加第一个功能:问候Action @@ -180,7 +180,7 @@ MoFox_Bot可能会选择使用你的问候Action,发送回复: 嗨!很开心见到你!😊 ``` -![1750332508760](image/quick-start/1750332508760.png) +![1750332508760](../assets/1750332508760.png) > **💡 小提示**:MoFox_Bot会智能地决定什么时候使用它。如果没有立即看到效果,多试几次不同的消息。 diff --git a/docs/plugins/tool_caching_guide.md b/docs/plugins/tool_caching_guide.md deleted file mode 100644 index d670a9f1a..000000000 --- a/docs/plugins/tool_caching_guide.md +++ /dev/null @@ -1,124 +0,0 @@ -# 自动化工具缓存系统使用指南 - -为了提升性能并减少不必要的重复计算或API调用,MMC内置了一套强大且易于使用的自动化工具缓存系统。该系统同时支持传统的**精确缓存**和先进的**语义缓存**。工具开发者无需编写任何手动缓存逻辑,只需在工具类中设置几个属性,即可轻松启用和配置缓存行为。 - -## 核心概念 - -- **精确缓存 (KV Cache)**: 当一个工具被调用时,系统会根据工具名称和所有参数生成一个唯一的键。只有当**下一次调用的工具名和所有参数与之前完全一致**时,才会命中缓存。 -- **语义缓存 (Vector Cache)**: 它不要求参数完全一致,而是理解参数的**语义和意图**。例如,`"查询深圳今天的天气"` 和 `"今天深圳天气怎么样"` 这两个不同的查询,在语义上是高度相似的。如果启用了语义缓存,第二个查询就能成功命中由第一个查询产生的缓存结果。 - -## 如何为你的工具启用缓存 - -为你的工具(必须继承自 `BaseTool`)启用缓存非常简单,只需在你的工具类定义中添加以下一个或多个属性即可: - -### 1. `enable_cache: bool` - -这是启用缓存的总开关。 - -- **类型**: `bool` -- **默认值**: `False` -- **作用**: 设置为 `True` 即可为该工具启用缓存功能。如果为 `False`,后续的所有缓存配置都将无效。 - -**示例**: -```python -class MyAwesomeTool(BaseTool): - # ... 其他定义 ... - enable_cache: bool = True -``` - -### 2. `cache_ttl: int` - -设置缓存的生存时间(Time-To-Live)。 - -- **类型**: `int` -- **单位**: 秒 -- **默认值**: `3600` (1小时) -- **作用**: 定义缓存条目在被视为过期之前可以存活多长时间。 - -**示例**: -```python -class MyLongTermCacheTool(BaseTool): - # ... 其他定义 ... - enable_cache: bool = True - cache_ttl: int = 86400 # 缓存24小时 -``` - -### 3. 
`semantic_cache_query_key: Optional[str]` - -启用语义缓存的关键。 - -- **类型**: `Optional[str]` -- **默认值**: `None` -- **作用**: - - 将此属性的值设置为你工具的某个**参数的名称**(字符串)。 - - 自动化缓存系统在工作时,会提取该参数的值,将其转换为向量,并进行语义相似度搜索。 - - 如果该值为 `None`,则此工具**仅使用精确缓存**。 - -**示例**: -```python -class WebSurfingTool(BaseTool): - name: str = "web_search" - parameters = [ - ("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None), - # ... 其他参数 ... - ] - - # --- 缓存配置 --- - enable_cache: bool = True - cache_ttl: int = 7200 # 缓存2小时 - semantic_cache_query_key: str = "query" # <-- 关键! -``` -在上面的例子中,`web_search` 工具的 `"query"` 参数值(例如,用户输入的搜索词)将被用于语义缓存搜索。 - -## 完整示例 - -假设我们有一个调用外部API来获取股票价格的工具。由于股价在短时间内相对稳定,且查询意图可能相似(如 "苹果股价" vs "AAPL股价"),因此非常适合使用缓存。 - -```python -# in your_plugin/tools/stock_checker.py - -from src.plugin_system import BaseTool, ToolParamType - -class StockCheckerTool(BaseTool): - """ - 一个用于查询股票价格的工具。 - """ - name: str = "get_stock_price" - description: str = "获取指定公司或股票代码的最新价格。" - available_for_llm: bool = True - parameters = [ - ("symbol", ToolParamType.STRING, "公司名称或股票代码 (e.g., 'AAPL', '苹果')", True, None), - ] - - # --- 缓存配置 --- - # 1. 开启缓存 - enable_cache: bool = True - # 2. 股价信息缓存10分钟 - cache_ttl: int = 600 - # 3. 使用 "symbol" 参数进行语义搜索 - semantic_cache_query_key: str = "symbol" - # -------------------- - - async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: - symbol = function_args.get("symbol") - - # ... 这里是你调用外部API获取股票价格的逻辑 ... 
- # price = await some_stock_api.get_price(symbol) - price = 123.45 # 示例价格 - - return { - "type": "stock_price_result", - "content": f"{symbol} 的当前价格是 ${price}" - } - -``` - -通过以上简单的三行配置,`StockCheckerTool` 现在就拥有了强大的自动化缓存能力: - -- 当用户查询 `"苹果"` 时,工具会执行并缓存结果。 -- 在接下来的10分钟内,如果再次查询 `"苹果"`,将直接从精确缓存返回结果。 -- 更智能的是,如果另一个用户查询 `"AAPL"`,语义缓存系统会识别出 `"AAPL"` 和 `"苹果"` 在语义上高度相关,大概率也会直接返回缓存的结果,而无需再次调用API。 - ---- - -现在,你可以专注于实现工具的核心逻辑,把缓存的复杂性交给MMC的自动化系统来处理。 \ No newline at end of file diff --git a/docs/plugins/tool-components.md b/docs/plugins/tool_guide.md similarity index 60% rename from docs/plugins/tool-components.md rename to docs/plugins/tool_guide.md index e27658af8..6e150a1cc 100644 --- a/docs/plugins/tool-components.md +++ b/docs/plugins/tool_guide.md @@ -2,7 +2,7 @@ ## 📖 什么是工具 -工具是MoFox_Bot的信息获取能力扩展组件。如果说Action组件功能五花八门,可以拓展麦麦能做的事情,那么Tool就是在某个过程中拓宽了麦麦能够获得的信息量。 +工具是MoFox_Bot的信息获取能力扩展组件。如果说Action组件功能五花八门,可以拓展麦麦能做的事情,那么Tool就是在某个过程中拓宽了MoFox_Bot能够获得的信息量。 ### 🎯 工具的特点 @@ -191,8 +191,7 @@ class WeatherTool(BaseTool): name = "weather_query" # 清晰表达功能 name = "knowledge_search" # 描述性强 name = "stock_price_check" # 功能明确 -``` -#### ❌ 避免的命名 +```#### ❌ 避免的命名 ```python name = "tool1" # 无意义 name = "wq" # 过于简短 @@ -244,3 +243,130 @@ def _format_result(self, data): def _format_result(self, data): return str(data) # 直接返回原始数据 ``` + +--- + +# 自动化工具缓存系统使用指南 + +为了提升性能并减少不必要的重复计算或API调用,MMC内置了一套强大且易于使用的自动化工具缓存系统。该系统同时支持传统的**精确缓存**和先进的**语义缓存**。工具开发者无需编写任何手动缓存逻辑,只需在工具类中设置几个属性,即可轻松启用和配置缓存行为。 + +## 核心概念 + +- **精确缓存 (KV Cache)**: 当一个工具被调用时,系统会根据工具名称和所有参数生成一个唯一的键。只有当**下一次调用的工具名和所有参数与之前完全一致**时,才会命中缓存。 +- **语义缓存 (Vector Cache)**: 它不要求参数完全一致,而是理解参数的**语义和意图**。例如,`"查询深圳今天的天气"` 和 `"今天深圳天气怎么样"` 这两个不同的查询,在语义上是高度相似的。如果启用了语义缓存,第二个查询就能成功命中由第一个查询产生的缓存结果。 + +## 如何为你的工具启用缓存 + +为你的工具(必须继承自 `BaseTool`)启用缓存非常简单,只需在你的工具类定义中添加以下一个或多个属性即可: + +### 1. 
`enable_cache: bool` + +这是启用缓存的总开关。 + +- **类型**: `bool` +- **默认值**: `False` +- **作用**: 设置为 `True` 即可为该工具启用缓存功能。如果为 `False`,后续的所有缓存配置都将无效。 + +**示例**: +```python +class MyAwesomeTool(BaseTool): + # ... 其他定义 ... + enable_cache: bool = True +``` + +### 2. `cache_ttl: int` + +设置缓存的生存时间(Time-To-Live)。 + +- **类型**: `int` +- **单位**: 秒 +- **默认值**: `3600` (1小时) +- **作用**: 定义缓存条目在被视为过期之前可以存活多长时间。 + +**示例**: +```python +class MyLongTermCacheTool(BaseTool): + # ... 其他定义 ... + enable_cache: bool = True + cache_ttl: int = 86400 # 缓存24小时 +``` + +### 3. `semantic_cache_query_key: Optional[str]` + +启用语义缓存的关键。 + +- **类型**: `Optional[str]` +- **默认值**: `None` +- **作用**: + - 将此属性的值设置为你工具的某个**参数的名称**(字符串)。 + - 自动化缓存系统在工作时,会提取该参数的值,将其转换为向量,并进行语义相似度搜索。 + - 如果该值为 `None`,则此工具**仅使用精确缓存**。 + +**示例**: +```python +class WebSurfingTool(BaseTool): + name: str = "web_search" + parameters = [ + ("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None), + # ... 其他参数 ... + ] + + # --- 缓存配置 --- + enable_cache: bool = True + cache_ttl: int = 7200 # 缓存2小时 + semantic_cache_query_key: str = "query" # <-- 关键! +``` +在上面的例子中,`web_search` 工具的 `"query"` 参数值(例如,用户输入的搜索词)将被用于语义缓存搜索。 + +## 完整示例 + +假设我们有一个调用外部API来获取股票价格的工具。由于股价在短时间内相对稳定,且查询意图可能相似(如 "苹果股价" vs "AAPL股价"),因此非常适合使用缓存。 + +```python +# in your_plugin/tools/stock_checker.py + +from src.plugin_system import BaseTool, ToolParamType + +class StockCheckerTool(BaseTool): + """ + 一个用于查询股票价格的工具。 + """ + name: str = "get_stock_price" + description: str = "获取指定公司或股票代码的最新价格。" + available_for_llm: bool = True + parameters = [ + ("symbol", ToolParamType.STRING, "公司名称或股票代码 (e.g., 'AAPL', '苹果')", True, None), + ] + + # --- 缓存配置 --- + # 1. 开启缓存 + enable_cache: bool = True + # 2. 股价信息缓存10分钟 + cache_ttl: int = 600 + # 3. 使用 "symbol" 参数进行语义搜索 + semantic_cache_query_key: str = "symbol" + # -------------------- + + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: + symbol = function_args.get("symbol") + + # ... 这里是你调用外部API获取股票价格的逻辑 ... 
+ # price = await some_stock_api.get_price(symbol) + price = 123.45 # 示例价格 + + return { + "type": "stock_price_result", + "content": f"{symbol} 的当前价格是 ${price}" + } + +``` + +通过以上简单的三行配置,`StockCheckerTool` 现在就拥有了强大的自动化缓存能力: + +- 当用户查询 `"苹果"` 时,工具会执行并缓存结果。 +- 在接下来的10分钟内,如果再次查询 `"苹果"`,将直接从精确缓存返回结果。 +- 更智能的是,如果另一个用户查询 `"AAPL"`,语义缓存系统会识别出 `"AAPL"` 和 `"苹果"` 在语义上高度相关,大概率也会直接返回缓存的结果,而无需再次调用API。 + +--- + +现在,你可以专注于实现工具的核心逻辑,把缓存的复杂性交给MMC的自动化系统来处理。 \ No newline at end of file diff --git a/docs/schedule_enhancement (1).md b/docs/schedule_enhancement (1).md deleted file mode 100644 index 1dc2a9b8d..000000000 --- a/docs/schedule_enhancement (1).md +++ /dev/null @@ -1,121 +0,0 @@ -# “月层计划”系统架构设计文档 - -## 1. 系统概述与目标 - -本系统旨在为MoFox_Bot引入一个动态的、由大型语言模型(LLM)驱动的“月层计划”机制。其核心目标是取代静态、预设的任务模板,转而利用LLM在程序启动时自动生成符合Bot人设的、具有时效性的月度计划。这些计划将被存储、管理,并在构建每日日程时被动态抽取和使用,从而极大地丰富日程内容的个性和多样性。 - ---- - -## 2. 核心设计原则 - -- **动态性与智能化:** 所有计划内容均由LLM实时生成,确保其独特性和创造性。 -- **人设一致性:** 计划的生成将严格围绕Bot的核心人设进行,强化角色形象。 -- **持久化与可管理:** 生成的计划将被存入专用数据库表,便于管理和追溯。 -- **消耗性与随机性:** 计划在使用后有一定几率被消耗(删除),模拟真实世界中计划的完成与迭代。 - ---- - -## 3. 系统核心流程规划 - -本系统包含两大核心流程:**启动时的计划生成流程**和**日程构建时的计划使用流程**。 - -### 3.1 流程一:启动时计划生成 - -此流程在每次程序启动时触发,负责填充当月的计划池。 - -```mermaid -graph TD - A[程序启动] --> B{检查当月计划池}; - B -- 计划数量低于阈值 --> C[构建LLM Prompt]; - C -- prompt包含Bot人设、月份等信息 --> D[调用LLM服务]; - D -- LLM返回多个计划文本 --> E[解析并格式化计划]; - E -- 逐条处理 --> F[存入`monthly_plans`数据库表]; - F --> G[完成启动任务]; - B -- 计划数量充足 --> G; -``` - -### 3.2 流程二:日程构建时计划使用 - -此流程在构建每日日程的提示词(Prompt)时触发。 - -```mermaid -graph TD - H[构建日程Prompt] --> I{查询数据库}; - I -- 读取当月未使用的计划 --> J[随机抽取N个计划]; - J --> K[将计划文本嵌入日程Prompt]; - K --> L{随机数判断}; - L -- 概率命中 --> M[将已抽取的计划标记为删除]; - M --> N[完成Prompt构建]; - L -- 概率未命中 --> N; -``` - ---- - -## 4. 
数据库模型设计 - -为支撑本系统,需要新增一个数据库表。 - -**表名:** `monthly_plans` - -| 字段名 | 类型 | 描述 | -| :--- | :--- | :--- | -| `id` | Integer | 主键,自增。 | -| `plan_text` | Text | 由LLM生成的计划内容原文。 | -| `target_month` | String(7) | 计划所属的月份,格式为 "YYYY-MM"。 | -| `is_deleted` | Boolean | 软删除标记,默认为 `false`。 | -| `created_at` | DateTime | 记录创建时间。 | - ---- - -## 5. 详细模块规划 - -### 5.1 LLM Prompt生成模块 - -- **职责:** 构建高质量的Prompt以引导LLM生成符合要求的计划。 -- **输入:** Bot人设描述、当前月份、期望生成的计划数量。 -- **输出:** 一个结构化的Prompt字符串。 -- **Prompt示例:** - ``` - 你是一个[此处填入Bot人设描述,例如:活泼开朗、偶尔有些小迷糊的虚拟助手]。 - 请为即将到来的[YYYY年MM月]设计[N]个符合你身份的月度计划或目标。 - - 要求: - 1. 每个计划都是独立的、积极向上的。 - 2. 语言风格要自然、口语化,符合你的性格。 - 3. 每个计划用一句话或两句话简短描述。 - 4. 以JSON格式返回,格式为:{"plans": ["计划一", "计划二", ...]} - ``` - -### 5.2 数据库交互模块 - -- **职责:** 提供对 `monthly_plans` 表的增、删、改、查接口。 -- **规划函数列表:** - - `add_new_plans(plans: list[str], month: str)`: 批量添加新生成的计划。 - - `get_active_plans_for_month(month: str) -> list`: 获取指定月份所有未被删除的计划。 - - `soft_delete_plans(plan_ids: list[int])`: 将指定ID的计划标记为软删除。 - -### 5.3 配置项规划 - -需要在主配置文件 `config/bot_config.toml` 中添加以下配置项,以控制系统行为。 - -```toml -# ---------------------------------------------------------------- -# 月层计划系统设置 (Monthly Plan System Settings) -# ---------------------------------------------------------------- -[monthly_plan_system] - -# 是否启用本功能 -enable = true - -# 启动时,如果当月计划少于此数量,则触发LLM生成 -generation_threshold = 10 - -# 每次调用LLM期望生成的计划数量 -plans_per_generation = 5 - -# 计划被使用后,被删除的概率 (0.0 到 1.0) -deletion_probability_on_use = 0.5 -``` - ---- -**文档结束。** 本文档纯粹为架构规划,旨在提供清晰的设计思路和开发指引,不包含任何实现代码。 \ No newline at end of file diff --git a/plugins/hello_world_plugin/_manifest.json b/plugins/hello_world_plugin/_manifest.json deleted file mode 100644 index b1a4c4eb8..000000000 --- a/plugins/hello_world_plugin/_manifest.json +++ /dev/null @@ -1,53 +0,0 @@ -{ - "manifest_version": 1, - "name": "Hello World 示例插件 (Hello World Plugin)", - "version": "1.0.0", - "description": "我的第一个MaiCore插件,包含问候功能和时间查询等基础示例", - "author": { - "name": "MaiBot开发团队", - "url": 
"https://github.com/MaiM-with-u" - }, - "license": "GPL-v3.0-or-later", - - "host_application": { - "min_version": "0.8.0" - }, - "homepage_url": "https://github.com/MaiM-with-u/maibot", - "repository_url": "https://github.com/MaiM-with-u/maibot", - "keywords": ["demo", "example", "hello", "greeting", "tutorial"], - "categories": ["Examples", "Tutorial"], - - "default_locale": "zh-CN", - "locales_path": "_locales", - - "plugin_info": { - "is_built_in": false, - "plugin_type": "example", - "components": [ - { - "type": "action", - "name": "hello_greeting", - "description": "向用户发送问候消息" - }, - { - "type": "action", - "name": "bye_greeting", - "description": "向用户发送告别消息", - "activation_modes": ["keyword"], - "keywords": ["再见", "bye", "88", "拜拜"] - }, - { - "type": "command", - "name": "time", - "description": "查询当前时间", - "pattern": "/time" - } - ], - "features": [ - "问候和告别功能", - "时间查询命令", - "配置文件示例", - "新手教程代码" - ] - } -} \ No newline at end of file diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py deleted file mode 100644 index 949f824c0..000000000 --- a/plugins/hello_world_plugin/plugin.py +++ /dev/null @@ -1,282 +0,0 @@ -from typing import List, Tuple, Type, Any, Optional -from src.plugin_system import ( - BasePlugin, - register_plugin, - BaseAction, - BaseCommand, - BaseTool, - ComponentInfo, - ActionActivationType, - ConfigField, - ToolParamType -) - -from src.plugin_system.apis import send_api -from src.common.logger import get_logger -from src.plugin_system.base.component_types import ChatType - -logger = get_logger(__name__) - - -class GetGroupListCommand(BaseCommand): - """获取群列表命令""" - - command_name = "get_groups" - command_description = "获取机器人加入的群列表" - command_pattern = r"^/get_groups$" - command_help = "获取机器人加入的群列表" - command_examples = ["/get_groups"] - intercept_message = True - - - - - - async def execute(self) -> Tuple[bool, str, bool]: - try: - # 调用适配器命令API - domain = "user.qzone.qq.com" - response = await 
send_api.adapter_command_to_stream( - action="get_cookies", - platform="qq", - params={"domain": domain}, - timeout=40.0, - storage_message=False - ) - text = str(response) - await self.send_text(text) - return True, "获取群列表成功", True - - except Exception as e: - await self.send_text(f"获取群列表失败: {str(e)}") - return False, "获取群列表失败", True - -class CompareNumbersTool(BaseTool): - """比较两个数大小的工具""" - - name = "compare_numbers" - description = "使用工具 比较两个数的大小,返回较大的数" - parameters = [ - ("num1", ToolParamType.FLOAT, "第一个数字", True, None), - ("num2", ToolParamType.FLOAT, "第二个数字", True, None), - ] - - async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: - """执行比较两个数的大小 - - Args: - function_args: 工具参数 - - Returns: - dict: 工具执行结果 - """ - num1: int | float = function_args.get("num1") # type: ignore - num2: int | float = function_args.get("num2") # type: ignore - - try: - if num1 > num2: - result = f"{num1} 大于 {num2}" - elif num1 < num2: - result = f"{num1} 小于 {num2}" - else: - result = f"{num1} 等于 {num2}" - - return {"name": self.name, "content": result} - except Exception as e: - return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"} - - -# ===== Action组件 ===== -class HelloAction(BaseAction): - """问候Action - 简单的问候动作""" - - # === 基本信息(必须填写)=== - action_name = "hello_greeting" - action_description = "向用户发送问候消息" - activation_type = ActionActivationType.ALWAYS # 始终激活 - - # === 功能描述(必须填写)=== - action_parameters = {"greeting_message": "要发送的问候消息"} - action_require = ["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"] - associated_types = ["text"] - - async def execute(self) -> Tuple[bool, str]: - """执行问候动作 - 这是核心功能""" - # 发送问候消息 - greeting_message = self.action_data.get("greeting_message", "") - base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊") - message = base_message + greeting_message - await self.send_text(message) - - return True, "发送了问候消息" - - -class ByeAction(BaseAction): - """告别Action - 只在用户说再见时激活""" - - action_name = "bye_greeting" - 
action_description = "向用户发送告别消息" - - # 使用关键词激活 - activation_type = ActionActivationType.KEYWORD - - # 关键词设置 - activation_keywords = ["再见", "bye", "88", "拜拜"] - keyword_case_sensitive = False - - action_parameters = {"bye_message": "要发送的告别消息"} - action_require = [ - "用户要告别时使用", - "当有人要离开时使用", - "当有人和你说再见时使用", - ] - associated_types = ["text"] - - async def execute(self) -> Tuple[bool, str]: - bye_message = self.action_data.get("bye_message", "") - - message = f"再见!期待下次聊天!👋{bye_message}" - await self.send_text(message) - return True, "发送了告别消息" - - -class TimeCommand(BaseCommand): - """时间查询Command - 响应/time命令""" - - command_name = "time" - command_description = "获取当前时间" - - # === 命令设置(必须填写)=== - command_pattern = r"^/time$" # 精确匹配 "/time" 命令 - chat_type_allow = ChatType.GROUP # 仅在群聊中可用 - - async def execute(self) -> Tuple[bool, str, bool]: - """执行时间查询""" - import datetime - - # 获取当前时间 - time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore - now = datetime.datetime.now() - time_str = now.strftime(time_format) - - # 发送时间信息 - message = f"⏰ 当前时间:{time_str}" - await self.send_text(message) - - return True, f"显示了当前x时间: {time_str}", True - - -# class PrintMessage(BaseEventHandler): -# """打印消息事件处理器 - 处理打印消息事件""" -# -# event_type = EventType.ON_MESSAGE -# handler_name = "print_message_handler" -# handler_description = "打印接收到的消息" -# -# async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]: -# """执行打印消息事件处理""" -# # 打印接收到的消息 -# -# if self.get_config("print_message.enabled", False): -# print(f"接收到消息: {message.raw_message}") -# return True, True, "消息已打印1" - - -# ===== 插件注册 ===== - - -@register_plugin -class HelloWorldPlugin(BasePlugin): - """Hello World插件 - 你的第一个MaiCore插件""" - - # 插件基本信息 - plugin_name: str = "hello_world_plugin" # 内部标识符 - enable_plugin: bool = True - dependencies: List[str] = [] # 插件依赖列表 - python_dependencies: List[str] = [] # Python包依赖列表 - config_file_name: str = "config.toml" # 配置文件名 - - # 配置节描述 - 
config_section_descriptions = {"plugin": "插件基本信息", "greeting": "问候功能配置", "time": "时间查询配置"} - - # 配置Schema定义 - config_schema: dict = { - "plugin": { - "name": ConfigField(type=str, default="hello_world_plugin", description="插件名称"), - "version": ConfigField(type=str, default="1.0.0", description="插件版本"), - "enabled": ConfigField(type=bool, default=False, description="是否启用插件"), - }, - "greeting": { - "message": ConfigField( - type=list, default=["嗨!很开心见到你!😊", "Ciallo~(∠・ω< )⌒★"], description="默认问候消息" - ), - "enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"), - }, - "time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")}, - "print_message": {"enabled": ConfigField(type=bool, default=True, description="是否启用打印")}, - } - - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: - return [ - (HelloAction.get_action_info(), HelloAction), - (CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具 - (ByeAction.get_action_info(), ByeAction), # 添加告别Action - (TimeCommand.get_command_info(), TimeCommand), # 现在只能在群聊中使用 - (GetGroupListCommand.get_command_info(), GetGroupListCommand), # 添加获取群列表命令 - (PrivateInfoCommand.get_command_info(), PrivateInfoCommand), # 私聊专用命令 - (GroupOnlyAction.get_action_info(), GroupOnlyAction), # 群聊专用动作 - # (PrintMessage.get_handler_info(), PrintMessage), - ] - - -# @register_plugin -# class HelloWorldEventPlugin(BaseEPlugin): -# """Hello World事件插件 - 处理问候和告别事件""" - -# plugin_name = "hello_world_event_plugin" -# enable_plugin = False -# dependencies = [] -# python_dependencies = [] -# config_file_name = "event_config.toml" - -# config_schema = { -# "plugin": { -# "name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"), -# "version": ConfigField(type=str, default="1.0.0", description="插件版本"), -# "enabled": ConfigField(type=bool, default=True, description="是否启用插件"), -# }, -# } - -# def get_plugin_components(self) -> List[Tuple[ComponentInfo, 
Type]]: -# return [(PrintMessage.get_handler_info(), PrintMessage)] - -# 添加一个新的私聊专用命令 -class PrivateInfoCommand(BaseCommand): - command_name = "private_info" - command_description = "获取私聊信息" - command_pattern = r"^/私聊信息$" - chat_type_allow = ChatType.PRIVATE # 仅在私聊中可用 - - async def execute(self) -> Tuple[bool, Optional[str], bool]: - """执行私聊信息命令""" - try: - await self.send_text("这是一个只能在私聊中使用的命令!") - return True, "私聊信息命令执行成功", False - except Exception as e: - logger.error(f"私聊信息命令执行失败: {e}") - return False, f"命令执行失败: {e}", False - -# 添加一个新的仅群聊可用的Action -class GroupOnlyAction(BaseAction): - action_name = "group_only_test" - action_description = "群聊专用测试动作" - chat_type_allow = ChatType.GROUP # 仅在群聊中可用 - - async def execute(self) -> Tuple[bool, str]: - """执行群聊专用测试动作""" - try: - await self.send_text("这是一个只能在群聊中执行的动作!") - return True, "群聊专用动作执行成功" - except Exception as e: - logger.error(f"群聊专用动作执行失败: {e}") - return False, f"动作执行失败: {e}" diff --git a/plugins/napcat_adapter_plugin/.gitignore b/plugins/napcat_adapter_plugin/.gitignore new file mode 100644 index 000000000..0a12c0b74 --- /dev/null +++ b/plugins/napcat_adapter_plugin/.gitignore @@ -0,0 +1,279 @@ + +log/ +logs/ +out/ + +.env +.env.* +.cursor + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +uv.lock +llm_statistics.txt +mongodb +napcat +run_dev.bat +elua.confirmed +# C extensions +*.so +/results +config_backup/ +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. 
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# PyPI configuration file +.pypirc + +# jieba +jieba.cache + +# .vscode +!.vscode/settings.json + +# direnv +/.direnv + +# JetBrains +.idea +*.iml +*.ipr + +# PyEnv +# If using PyEnv and configured to use a specific Python version locally +# a .local-version file will be created in the root of the project to specify the version. 
+.python-version + +OtherRes.txt + +/eula.confirmed +/privacy.confirmed + +logs + +.ruff_cache + +.vscode + +/config/* +config/old/bot_config_20250405_212257.toml +temp/ + +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + +config.toml +feature.toml +config.toml.back +test +data/NapcatAdapter.db +data/NapcatAdapter.db-shm +data/NapcatAdapter.db-wal \ No newline at end of file diff --git a/plugins/napcat_adapter_plugin/CONSTS.py b/plugins/napcat_adapter_plugin/CONSTS.py new file mode 100644 index 000000000..174602208 --- /dev/null +++ b/plugins/napcat_adapter_plugin/CONSTS.py @@ -0,0 +1 @@ +PLUGIN_NAME = "napcat_adapter" \ No newline at end of file diff --git a/plugins/napcat_adapter_plugin/_manifest.json b/plugins/napcat_adapter_plugin/_manifest.json new file mode 100644 index 000000000..676aa3121 --- /dev/null +++ b/plugins/napcat_adapter_plugin/_manifest.json @@ -0,0 +1,42 @@ +{ + "manifest_version": 1, + "name": "napcat_plugin", + "version": "1.0.0", + "description": "基于OneBot 11协议的NapCat QQ协议插件,提供完整的QQ机器人API接口,使用现有adapter连接", + "author": { + "name": "Windpicker_owo", + "url": "https://github.com/Windpicker-owo" + }, + "license": "GPL-v3.0-or-later", + + "host_application": { + "min_version": "0.10.0", + "max_version": "0.10.0" + }, + "homepage_url": 
"https://github.com/Windpicker-owo/InternetSearchPlugin", + "repository_url": "https://github.com/Windpicker-owo/InternetSearchPlugin", + "keywords": ["qq", "bot", "napcat", "onebot", "api", "websocket"], + "categories": ["protocol"], + "default_locale": "zh-CN", + "locales_path": "_locales", + + "plugin_info": { + "is_built_in": false, + "components": [ + { + "type": "tool", + "name": "napcat_tool", + "description": "NapCat QQ协议综合工具,提供消息发送、群管理、好友管理、文件操作等完整功能" + } + ], + "features": [ + "消息发送与接收", + "群管理功能", + "好友管理功能", + "文件上传下载", + "AI语音功能", + "群签到与戳一戳", + "现有adapter连接" + ] + } +} \ No newline at end of file diff --git a/plugins/napcat_adapter_plugin/event_handlers.py b/plugins/napcat_adapter_plugin/event_handlers.py new file mode 100644 index 000000000..cc5c7c2ec --- /dev/null +++ b/plugins/napcat_adapter_plugin/event_handlers.py @@ -0,0 +1,6 @@ +from typing import List, Tuple + +from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType, ConfigField, BaseAction, ActionActivationType +from src.plugin_system.base.base_event import HandlerResult +from src.plugin_system.core.event_manager import event_manager + diff --git a/plugins/napcat_adapter_plugin/event_types.py b/plugins/napcat_adapter_plugin/event_types.py new file mode 100644 index 000000000..b199cdd37 --- /dev/null +++ b/plugins/napcat_adapter_plugin/event_types.py @@ -0,0 +1,69 @@ +from enum import Enum + +class NapcatEvent(Enum): + # napcat插件事件枚举类 + class ON_RECEIVED(Enum): + """ + 该分类下均为消息接受事件,只能由napcat_plugin触发 + """ + TEXT = "napcat_on_received_text" # 接收到文本消息 + FACE = "napcat_on_received_face" # 接收到表情消息 + REPLY = "napcat_on_received_reply" # 接收到回复消息 + IMAGE = "napcat_on_received_image" # 接收到图像消息 + RECORD = "napcat_on_received_record" # 接收到语音消息 + VIDEO = "napcat_on_received_video" # 接收到视频消息 + AT = "napcat_on_received_at" # 接收到at消息 + DICE = "napcat_on_received_dice" # 接收到骰子消息 + SHAKE = "napcat_on_received_shake" # 接收到屏幕抖动消息 + JSON = "napcat_on_received_json" # 接收到JSON消息 + 
RPS = "napcat_on_received_rps" # 接收到魔法猜拳消息 + FRIEND_INPUT = "napcat_on_friend_input" # 好友正在输入 + + class ACCOUNT(Enum): + """ + 该分类是对账户相关的操作,只能由外部触发,napcat_plugin负责处理 + """ + SET_PROFILE = "napcat_set_qq_profile" # 设置账号信息 + GET_ONLINE_CLIENTS = "napcat_get_online_clients" # 获取当前账号在线客户端列表 + SET_ONLINE_STATUS = "napcat_set_online_status" # 设置在线状态 + GET_FRIENDS_WITH_CATEGORY = "napcat_get_friends_with_category" # 获取好友分组列表 + SET_AVATAR = "napcat_set_qq_avatar" # 设置头像 + SEND_LIKE = "napcat_send_like" # 点赞 + SET_FRIEND_ADD_REQUEST = "napcat_set_friend_add_request" # 处理好友请求 + SET_SELF_LONGNICK = "napcat_set_self_longnick" # 设置个性签名 + GET_LOGIN_INFO = "napcat_get_login_info" # 获取登录号信息 + GET_RECENT_CONTACT = "napcat_get_recent_contact" # 最近消息列表 + GET_STRANGER_INFO = "napcat_get_stranger_info" # 获取(指定)账号信息 + GET_FRIEND_LIST = "napcat_get_friend_list" # 获取好友列表 + GET_PROFILE_LIKE = "napcat_get_profile_like" # 获取点赞列表 + DELETE_FRIEND = "napcat_delete_friend" # 删除好友 + GET_USER_STATUS = "napcat_get_user_status" # 获取用户状态 + GET_STATUS = "napcat_get_status" # 获取状态 + GET_MINI_APP_ARK = "napcat_get_mini_app_ark" # 获取小程序卡片 + SET_DIY_ONLINE_STATUS = "napcat_set_diy_online_status" # 设置自定义在线状态 + + class MESSAGE(Enum): + """ + 该分类是对信息相关的操作,只能由外部触发,napcat_plugin负责处理 + """ + SEND_GROUP_POKE = "napcat_send_group_poke" # 发送群聊戳一戳 + SEND_PRIVATE_MSG = "napcat_send_private_msg" # 发送私聊消息 + SEND_POKE = "napcat_send_friend_poke" # 发送戳一戳 + DELETE_MSG = "napcat_delete_msg" # 撤回消息 + GET_GROUP_MSG_HISTORY = "napcat_get_group_msg_history" # 获取群历史消息 + GET_MSG = "napcat_get_msg" # 获取消息详情 + GET_FORWARD_MSG = "napcat_get_forward_msg" # 获取合并转发消息 + SET_MSG_EMOJI_LIKE = "napcat_set_msg_emoji_like" # 贴表情 + GET_FRIEND_MSG_HISTORY = "napcat_get_friend_msg_history" # 获取好友历史消息 + FETCH_EMOJI_LIKE = "napcat_fetch_emoji_like" # 获取贴表情详情 + SEND_FORWARF_MSG = "napcat_send_forward_msg" # 发送合并转发消息 + GET_RECOED = "napcat_get_record" # 获取语音消息详情 + SEND_GROUP_AI_RECORD = "napcat_send_group_ai_record" # 发送群AI语音 + + class 
GROUP(Enum): + """ + 该分类是对群聊相关的操作,只能由外部触发,napcat_plugin负责处理 + """ + + + diff --git a/plugins/napcat_adapter_plugin/plugin.py b/plugins/napcat_adapter_plugin/plugin.py new file mode 100644 index 000000000..d6ec27853 --- /dev/null +++ b/plugins/napcat_adapter_plugin/plugin.py @@ -0,0 +1,131 @@ +import sys +import asyncio +import json +import websockets as Server +from . import event_types,CONSTS + +from typing import List, Tuple + +from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType, ConfigField, BaseAction, ActionActivationType +from src.plugin_system.base.base_event import HandlerResult +from src.plugin_system.core.event_manager import event_manager + +from pathlib import Path +from src.common.logger import get_logger +logger = get_logger("napcat_adapter") + +# 添加当前目录到Python路径,这样可以识别src包 +current_dir = Path(__file__).parent +sys.path.insert(0, str(current_dir)) + +from .src.recv_handler.message_handler import message_handler +from .src.recv_handler.meta_event_handler import meta_event_handler +from .src.recv_handler.notice_handler import notice_handler +from .src.recv_handler.message_sending import message_send_instance +from .src.send_handler import send_handler +from .src.config import global_config +from .src.config.features_config import features_manager +from .src.config.migrate_features import auto_migrate_features +from .src.mmc_com_layer import mmc_start_com, mmc_stop_com, router +from .src.response_pool import put_response, check_timeout_response +from .src.websocket_manager import websocket_manager + +message_queue = asyncio.Queue() + +class LauchNapcatAdapterHandler(BaseEventHandler): + """自动启动Adapter""" + + handler_name: str = "launch_napcat_adapter_handler" + handler_description: str = "自动启动napcat adapter" + weight: int = 100 + intercept_message: bool = False + init_subscribe = [EventType.ON_START] + + async def message_recv(self, server_connection: Server.ServerConnection): + await 
message_handler.set_server_connection(server_connection) + asyncio.create_task(notice_handler.set_server_connection(server_connection)) + await send_handler.set_server_connection(server_connection) + async for raw_message in server_connection: + logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message) + decoded_raw_message: dict = json.loads(raw_message) + post_type = decoded_raw_message.get("post_type") + if post_type in ["meta_event", "message", "notice"]: + await message_queue.put(decoded_raw_message) + elif post_type is None: + await put_response(decoded_raw_message) + + async def message_process(self): + while True: + message = await message_queue.get() + post_type = message.get("post_type") + if post_type == "message": + await message_handler.handle_raw_message(message) + elif post_type == "meta_event": + await meta_event_handler.handle_meta_event(message) + elif post_type == "notice": + await notice_handler.handle_notice(message) + else: + logger.warning(f"未知的post_type: {post_type}") + message_queue.task_done() + await asyncio.sleep(0.05) + + async def napcat_server(self): + """启动 Napcat WebSocket 连接(支持正向和反向连接)""" + mode = global_config.napcat_server.mode + logger.info(f"正在启动 adapter,连接模式: {mode}") + + try: + await websocket_manager.start_connection(self.message_recv) + except Exception as e: + logger.error(f"启动 WebSocket 连接失败: {e}") + raise + + async def execute(self, kwargs): + # 执行功能配置迁移(如果需要) + logger.info("检查功能配置迁移...") + auto_migrate_features() + + # 初始化功能管理器 + logger.info("正在初始化功能管理器...") + features_manager.load_config() + await features_manager.start_file_watcher(check_interval=2.0) + logger.info("功能管理器初始化完成") + logger.info("开始启动Napcat Adapter") + message_send_instance.maibot_router = router + # 创建单独的异步任务,防止阻塞主线程 + asyncio.create_task(self.napcat_server()) + asyncio.create_task(mmc_start_com()) + asyncio.create_task(self.message_process()) + asyncio.create_task(check_timeout_response()) + +@register_plugin +class 
NapcatAdapterPlugin(BasePlugin): + plugin_name = CONSTS.PLUGIN_NAME + enable_plugin: bool = True + dependencies: List[str] = [] # 插件依赖列表 + python_dependencies: List[str] = [] # Python包依赖列表 + config_file_name: str = "config.toml" # 配置文件名 + + # 配置节描述 + config_section_descriptions = {"plugin": "插件基本信息"} + + # 配置Schema定义 + config_schema: dict = { + "plugin": { + "name": ConfigField(type=str, default="napcat_adapter_plugin", description="插件名称"), + "version": ConfigField(type=str, default="1.0.0", description="插件版本"), + "enabled": ConfigField(type=bool, default=False, description="是否启用插件"), + } + } + + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for e in event_types.NapcatEvent.ON_RECEIVED: + event_manager.register_event(e ,allowed_triggers=[self.plugin_name]) + + def get_plugin_components(self): + components = [] + components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler)) + return components diff --git a/plugins/napcat_adapter_plugin/pyproject.toml b/plugins/napcat_adapter_plugin/pyproject.toml new file mode 100644 index 000000000..13c76e885 --- /dev/null +++ b/plugins/napcat_adapter_plugin/pyproject.toml @@ -0,0 +1,47 @@ +[project] +name = "MaiBotNapcatAdapter" +version = "0.4.8" +description = "A MaiBot adapter for Napcat" +dependencies = [ + "ruff>=0.12.9", +] + +[tool.ruff] + +include = ["*.py"] + +# 行长度设置 +line-length = 120 + +[tool.ruff.lint] +fixable = ["ALL"] +unfixable = [] + +# 启用的规则 +select = [ + "E", # pycodestyle 错误 + "F", # pyflakes + "B", # flake8-bugbear +] + +ignore = ["E711","E501"] + +[tool.ruff.format] +docstring-code-format = true +indent-style = "space" + + +# 使用双引号表示字符串 +quote-style = "double" + +# 尊重魔法尾随逗号 +# 例如: +# items = [ +# "apple", +# "banana", +# "cherry", +# ] +skip-magic-trailing-comma = false + +# 自动检测合适的换行符 +line-ending = "auto" diff --git a/plugins/napcat_adapter_plugin/src/__init__.py b/plugins/napcat_adapter_plugin/src/__init__.py new file mode 100644 index 
000000000..76c84e814 --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/__init__.py @@ -0,0 +1,31 @@ +from enum import Enum +import tomlkit +import os +from src.common.logger import get_logger +logger = get_logger("napcat_adapter") + + +class CommandType(Enum): + """命令类型""" + + GROUP_BAN = "set_group_ban" # 禁言用户 + GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言 + GROUP_KICK = "set_group_kick" # 踢出群聊 + SEND_POKE = "send_poke" # 戳一戳 + DELETE_MSG = "delete_msg" # 撤回消息 + AI_VOICE_SEND = "send_group_ai_record" # 发送群AI语音 + SET_EMOJI_LIKE = "set_emoji_like" # 设置表情回应 + SEND_AT_MESSAGE = "send_at_message" # 艾特用户并发送消息 + SEND_LIKE = "send_like" # 点赞 + + def __str__(self) -> str: + return self.value + + +pyproject_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "pyproject.toml" +) +toml_data = tomlkit.parse(open(pyproject_path, "r", encoding="utf-8").read()) +project_data = toml_data.get("project", {}) +version = project_data.get("version", "unknown") +logger.info(f"版本\n\nMaiBot-Napcat-Adapter 版本: {version}\n喜欢的话点个star喵~\n") diff --git a/plugins/napcat_adapter_plugin/src/config/__init__.py b/plugins/napcat_adapter_plugin/src/config/__init__.py new file mode 100644 index 000000000..40ba89aeb --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/config/__init__.py @@ -0,0 +1,5 @@ +from .config import global_config + +__all__ = [ + "global_config", +] diff --git a/plugins/napcat_adapter_plugin/src/config/config.py b/plugins/napcat_adapter_plugin/src/config/config.py new file mode 100644 index 000000000..c954a7c14 --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/config/config.py @@ -0,0 +1,148 @@ +import os +from dataclasses import dataclass +from datetime import datetime + +import tomlkit +import shutil + +from tomlkit import TOMLDocument +from tomlkit.items import Table +from src.common.logger import get_logger +logger = get_logger("napcat_adapter") +from rich.traceback import install + +from .config_base import ConfigBase +from .official_configs 
import ( + DebugConfig, + MaiBotServerConfig, + NapcatServerConfig, + NicknameConfig, + VoiceConfig, +) + +install(extra_lines=3) + +TEMPLATE_DIR = "plugins/napcat_adapter_plugin/template" +CONFIG_DIR = "plugins/napcat_adapter_plugin/config" +OLD_CONFIG_DIR = "plugins/napcat_adapter_plugin/config/old" + + +def ensure_config_directories(): + """确保配置目录存在""" + os.makedirs(CONFIG_DIR, exist_ok=True) + os.makedirs(OLD_CONFIG_DIR, exist_ok=True) + + +def update_config(): + """更新配置文件,统一使用 config/old 目录进行备份""" + # 确保目录存在 + ensure_config_directories() + + # 定义文件路径 + template_path = f"{TEMPLATE_DIR}/template_config.toml" + config_path = f"{CONFIG_DIR}/config.toml" + + # 检查配置文件是否存在 + if not os.path.exists(config_path): + logger.info("主配置文件不存在,从模板创建新配置") + shutil.copy2(template_path, config_path) + logger.info(f"已创建新配置文件: {config_path}") + logger.info("程序将退出,请检查配置文件后重启") + + # 读取配置文件和模板文件 + with open(config_path, "r", encoding="utf-8") as f: + old_config = tomlkit.load(f) + with open(template_path, "r", encoding="utf-8") as f: + new_config = tomlkit.load(f) + + # 检查version是否相同 + if old_config and "inner" in old_config and "inner" in new_config: + old_version = old_config["inner"].get("version") + new_version = new_config["inner"].get("version") + if old_version and new_version and old_version == new_version: + logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") + return + else: + logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}") + else: + logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") + + # 创建备份文件 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_path = os.path.join(OLD_CONFIG_DIR, f"config.toml.bak.{timestamp}") + + # 备份旧配置文件 + shutil.copy2(config_path, backup_path) + logger.info(f"已备份旧配置文件到: {backup_path}") + + # 复制模板文件到配置目录 + shutil.copy2(template_path, config_path) + logger.info(f"已创建新配置文件: {config_path}") + + def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict): + """将source字典的值更新到target字典中(如果target中存在相同的键)""" + for key, value 
in source.items(): + # 跳过version字段的更新 + if key == "version": + continue + if key in target: + if isinstance(value, dict) and isinstance(target[key], (dict, Table)): + update_dict(target[key], value) + else: + try: + # 对数组类型进行特殊处理 + if isinstance(value, list): + # 如果是空数组,确保它保持为空数组 + target[key] = tomlkit.array(str(value)) if value else tomlkit.array() + else: + # 其他类型使用item方法创建新值 + target[key] = tomlkit.item(value) + except (TypeError, ValueError): + # 如果转换失败,直接赋值 + target[key] = value + + # 将旧配置的值更新到新配置中 + logger.info("开始合并新旧配置...") + update_dict(new_config, old_config) + + # 保存更新后的配置(保留注释和格式) + with open(config_path, "w", encoding="utf-8") as f: + f.write(tomlkit.dumps(new_config)) + logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") + + +@dataclass +class Config(ConfigBase): + """总配置类""" + + nickname: NicknameConfig + napcat_server: NapcatServerConfig + maibot_server: MaiBotServerConfig + voice: VoiceConfig + debug: DebugConfig + + +def load_config(config_path: str) -> Config: + """ + 加载配置文件 + :param config_path: 配置文件路径 + :return: Config对象 + """ + # 读取配置文件 + with open(config_path, "r", encoding="utf-8") as f: + config_data = tomlkit.load(f) + + # 创建Config对象 + try: + return Config.from_dict(config_data) + except Exception as e: + logger.critical("配置文件解析失败") + raise e + + +# 更新配置 +update_config() + +logger.info("正在品鉴配置文件...") +global_config = load_config(config_path=f"{CONFIG_DIR}/config.toml") +logger.info("非常的新鲜,非常的美味!") diff --git a/plugins/napcat_adapter_plugin/src/config/config_base.py b/plugins/napcat_adapter_plugin/src/config/config_base.py new file mode 100644 index 000000000..87cb079d2 --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/config/config_base.py @@ -0,0 +1,136 @@ +from dataclasses import dataclass, fields, MISSING +from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union + +T = TypeVar("T", bound="ConfigBase") + +TOML_DICT_TYPE = { + int, + float, + str, + bool, + list, + dict, +} + + +@dataclass +class ConfigBase: + 
"""配置类的基类""" + + @classmethod + def from_dict(cls: Type[T], data: Dict[str, Any]) -> T: + """从字典加载配置字段""" + if not isinstance(data, dict): + raise TypeError(f"Expected a dictionary, got {type(data).__name__}") + + init_args: Dict[str, Any] = {} + + for f in fields(cls): + field_name = f.name + field_type = f.type + if field_name.startswith("_"): + # 跳过以 _ 开头的字段 + continue + + if field_name not in data: + if f.default is not MISSING or f.default_factory is not MISSING: + # 跳过未提供且有默认值/默认构造方法的字段 + continue + else: + raise ValueError(f"Missing required field: '{field_name}'") + + value = data[field_name] + try: + init_args[field_name] = cls._convert_field(value, field_type) + except TypeError as e: + raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e + except Exception as e: + raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e + + return cls(**init_args) + + @classmethod + def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any: + """ + 转换字段值为指定类型 + + 1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法 + 2. 对于泛型集合类型(list, set, tuple),递归转换每个元素 + 3. 对于基础类型(int, str, float, bool),直接转换 + 4. 
对于其他类型,尝试直接转换,如果失败则抛出异常 + """ + # 如果是嵌套的 dataclass,递归调用 from_dict 方法 + if isinstance(field_type, type) and issubclass(field_type, ConfigBase): + return field_type.from_dict(value) + + field_origin_type = get_origin(field_type) + field_args_type = get_args(field_type) + + # 处理泛型集合类型(list, set, tuple) + if field_origin_type in {list, set, tuple}: + # 检查提供的value是否为list + if not isinstance(value, list): + raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}") + + if field_origin_type is list: + return [cls._convert_field(item, field_args_type[0]) for item in value] + if field_origin_type is set: + return {cls._convert_field(item, field_args_type[0]) for item in value} + if field_origin_type is tuple: + # 检查提供的value长度是否与类型参数一致 + if len(value) != len(field_args_type): + raise TypeError( + f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}" + ) + return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type)) + + if field_origin_type is dict: + # 检查提供的value是否为dict + if not isinstance(value, dict): + raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}") + + # 检查字典的键值类型 + if len(field_args_type) != 2: + raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}") + key_type, value_type = field_args_type + + return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()} + + # 处理Optional类型 + if field_origin_type is Union: # assert get_origin(Optional[Any]) is Union + if value is None: + return None + # 如果有数据,检查实际类型 + if type(value) not in field_args_type: + raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}") + return cls._convert_field(value, field_args_type[0]) + + # 处理int, str, float, bool等基础类型 + if field_origin_type is None: + if isinstance(value, field_type): + return field_type(value) + else: + raise 
TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}") + + # 处理Literal类型 + if field_origin_type is Literal: + # 获取Literal的允许值 + allowed_values = get_args(field_type) + if value in allowed_values: + return value + else: + raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type") + + # 处理其他类型 + if field_type is Any: + return value + + # 其他类型直接转换 + try: + return field_type(value) + except (ValueError, TypeError) as e: + raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e + + def __str__(self): + """返回配置类的字符串表示""" + return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})" diff --git a/plugins/napcat_adapter_plugin/src/config/config_utils.py b/plugins/napcat_adapter_plugin/src/config/config_utils.py new file mode 100644 index 000000000..8aa994b4d --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/config/config_utils.py @@ -0,0 +1,146 @@ +""" +配置文件工具模块 +提供统一的配置文件生成和管理功能 +""" +import os +import shutil +from pathlib import Path +from datetime import datetime +from typing import Optional + +from src.common.logger import get_logger +logger = get_logger("napcat_adapter") + + +def ensure_config_directories(): + """确保配置目录存在""" + os.makedirs("config", exist_ok=True) + os.makedirs("config/old", exist_ok=True) + + +def create_config_from_template( + config_path: str, + template_path: str, + config_name: str = "配置文件", + should_exit: bool = True +) -> bool: + """ + 从模板创建配置文件的统一函数 + + Args: + config_path: 配置文件路径 + template_path: 模板文件路径 + config_name: 配置文件名称(用于日志显示) + should_exit: 创建后是否退出程序 + + Returns: + bool: 是否成功创建配置文件 + """ + try: + # 确保配置目录存在 + ensure_config_directories() + + config_path_obj = Path(config_path) + template_path_obj = Path(template_path) + + # 检查配置文件是否存在 + if config_path_obj.exists(): + return False # 配置文件已存在,无需创建 + + logger.info(f"{config_name}不存在,从模板创建新配置") + + # 检查模板文件是否存在 + if not template_path_obj.exists(): + 
logger.error(f"模板文件不存在: {template_path}") + if should_exit: + logger.critical("无法创建配置文件,程序退出") + quit(1) + return False + + # 确保配置文件目录存在 + config_path_obj.parent.mkdir(parents=True, exist_ok=True) + + # 复制模板文件到配置目录 + shutil.copy2(template_path_obj, config_path_obj) + logger.info(f"已创建新{config_name}: {config_path}") + + if should_exit: + logger.info("程序将退出,请检查配置文件后重启") + quit(0) + + return True + + except Exception as e: + logger.error(f"创建{config_name}失败: {e}") + if should_exit: + logger.critical("无法创建配置文件,程序退出") + quit(1) + return False + + +def create_default_config_dict(default_values: dict, config_path: str, config_name: str = "配置文件") -> bool: + """ + 创建默认配置文件(使用字典数据) + + Args: + default_values: 默认配置值字典 + config_path: 配置文件路径 + config_name: 配置文件名称(用于日志显示) + + Returns: + bool: 是否成功创建配置文件 + """ + try: + import tomlkit + + config_path_obj = Path(config_path) + + # 确保配置文件目录存在 + config_path_obj.parent.mkdir(parents=True, exist_ok=True) + + # 写入默认配置 + with open(config_path_obj, "w", encoding="utf-8") as f: + tomlkit.dump(default_values, f) + + logger.info(f"已创建默认{config_name}: {config_path}") + return True + + except Exception as e: + logger.error(f"创建默认{config_name}失败: {e}") + return False + + +def backup_config_file(config_path: str, backup_dir: str = "config/old") -> Optional[str]: + """ + 备份配置文件 + + Args: + config_path: 要备份的配置文件路径 + backup_dir: 备份目录 + + Returns: + Optional[str]: 备份文件路径,失败时返回None + """ + try: + config_path_obj = Path(config_path) + if not config_path_obj.exists(): + return None + + # 确保备份目录存在 + backup_dir_obj = Path(backup_dir) + backup_dir_obj.mkdir(parents=True, exist_ok=True) + + # 创建备份文件名 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_filename = f"{config_path_obj.stem}.toml.bak.{timestamp}" + backup_path = backup_dir_obj / backup_filename + + # 备份文件 + shutil.copy2(config_path_obj, backup_path) + logger.info(f"已备份配置文件到: {backup_path}") + + return str(backup_path) + + except Exception as e: + logger.error(f"备份配置文件失败: {e}") + 
return None \ No newline at end of file diff --git a/plugins/napcat_adapter_plugin/src/config/features_config.py b/plugins/napcat_adapter_plugin/src/config/features_config.py new file mode 100644 index 000000000..a8b25938d --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/config/features_config.py @@ -0,0 +1,359 @@ +import asyncio +from dataclasses import dataclass, field +from typing import Literal, Optional +from pathlib import Path +import tomlkit +from src.common.logger import get_logger +logger = get_logger("napcat_adapter") +from .config_base import ConfigBase +from .config_utils import create_config_from_template, create_default_config_dict + + +@dataclass +class FeaturesConfig(ConfigBase): + """功能配置类""" + + group_list_type: Literal["whitelist", "blacklist"] = "whitelist" + """群聊列表类型 白名单/黑名单""" + + group_list: list[int] = field(default_factory=list) + """群聊列表""" + + private_list_type: Literal["whitelist", "blacklist"] = "whitelist" + """私聊列表类型 白名单/黑名单""" + + private_list: list[int] = field(default_factory=list) + """私聊列表""" + + ban_user_id: list[int] = field(default_factory=list) + """被封禁的用户ID列表,封禁后将无法与其进行交互""" + + ban_qq_bot: bool = False + """是否屏蔽QQ官方机器人,若为True,则所有QQ官方机器人将无法与MaiMCore进行交互""" + + enable_poke: bool = True + """是否启用戳一戳功能""" + + ignore_non_self_poke: bool = False + """是否无视不是针对自己的戳一戳""" + + poke_debounce_seconds: int = 3 + """戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略""" + + enable_reply_at: bool = True + """是否启用引用回复时艾特用户的功能""" + + reply_at_rate: float = 0.5 + """引用回复时艾特用户的几率 (0.0 ~ 1.0)""" + + enable_video_analysis: bool = True + """是否启用视频识别功能""" + + max_video_size_mb: int = 100 + """视频文件最大大小限制(MB)""" + + download_timeout: int = 60 + """视频下载超时时间(秒)""" + + supported_formats: list[str] = field(default_factory=lambda: ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"]) + """支持的视频格式""" + + # 消息缓冲配置 + enable_message_buffer: bool = True + """是否启用消息缓冲合并功能""" + + message_buffer_enable_group: bool = True + """是否启用群消息缓冲合并""" + + message_buffer_enable_private: 
bool = True + """是否启用私聊消息缓冲合并""" + + message_buffer_interval: float = 3.0 + """消息合并间隔时间(秒),在此时间内的连续消息将被合并""" + + message_buffer_initial_delay: float = 0.5 + """消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并""" + + message_buffer_max_components: int = 50 + """单个会话最大缓冲消息组件数量,超过此数量将强制合并""" + + message_buffer_block_prefixes: list[str] = field(default_factory=lambda: ["/", "!", "!", ".", "。", "#", "%"]) + """消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲""" + + +class FeaturesManager: + """功能管理器,支持热重载""" + + def __init__(self, config_path: str = "plugins/napcat_adapter_plugin/config/features.toml"): + self.config_path = Path(config_path) + self.config: Optional[FeaturesConfig] = None + self._file_watcher_task: Optional[asyncio.Task] = None + self._last_modified: Optional[float] = None + self._callbacks: list = [] + + def add_reload_callback(self, callback): + """添加配置重载回调函数""" + self._callbacks.append(callback) + + def remove_reload_callback(self, callback): + """移除配置重载回调函数""" + if callback in self._callbacks: + self._callbacks.remove(callback) + + async def _notify_callbacks(self): + """通知所有回调函数配置已重载""" + for callback in self._callbacks: + try: + if asyncio.iscoroutinefunction(callback): + await callback(self.config) + else: + callback(self.config) + except Exception as e: + logger.error(f"配置重载回调执行失败: {e}") + + def load_config(self) -> FeaturesConfig: + """加载功能配置文件""" + try: + # 检查配置文件是否存在,如果不存在则创建并退出程序 + if not self.config_path.exists(): + logger.info(f"功能配置文件不存在: {self.config_path}") + self._create_default_config() + # 配置文件创建后程序应该退出,让用户检查配置 + logger.info("程序将退出,请检查功能配置文件后重启") + quit(0) + + with open(self.config_path, "r", encoding="utf-8") as f: + config_data = tomlkit.load(f) + + self.config = FeaturesConfig.from_dict(config_data) + self._last_modified = self.config_path.stat().st_mtime + logger.info(f"功能配置加载成功: {self.config_path}") + return self.config + + except Exception as e: + logger.error(f"功能配置加载失败: {e}") + logger.critical("无法加载功能配置文件,程序退出") + quit(1) + + def _create_default_config(self): + 
"""创建默认功能配置文件""" + template_path = "template/features_template.toml" + + # 尝试从模板创建配置文件 + if create_config_from_template( + str(self.config_path), + template_path, + "功能配置文件", + should_exit=False # 不在这里退出,由调用方决定 + ): + return + + # 如果模板文件不存在,创建基本配置 + logger.info("模板文件不存在,创建基本功能配置") + default_config = { + "group_list_type": "whitelist", + "group_list": [], + "private_list_type": "whitelist", + "private_list": [], + "ban_user_id": [], + "ban_qq_bot": False, + "enable_poke": True, + "ignore_non_self_poke": False, + "poke_debounce_seconds": 3, + "enable_reply_at": True, + "reply_at_rate": 0.5, + "enable_video_analysis": True, + "max_video_size_mb": 100, + "download_timeout": 60, + "supported_formats": ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], + # 消息缓冲配置 + "enable_message_buffer": True, + "message_buffer_enable_group": True, + "message_buffer_enable_private": True, + "message_buffer_interval": 3.0, + "message_buffer_initial_delay": 0.5, + "message_buffer_max_components": 50, + "message_buffer_block_prefixes": ["/", "!", "!", ".", "。", "#", "%"] + } + + if not create_default_config_dict(default_config, str(self.config_path), "功能配置文件"): + logger.critical("无法创建功能配置文件") + quit(1) + + async def reload_config(self) -> bool: + """重新加载配置文件""" + try: + if not self.config_path.exists(): + logger.warning(f"功能配置文件不存在,无法重载: {self.config_path}") + return False + + current_modified = self.config_path.stat().st_mtime + if self._last_modified and current_modified <= self._last_modified: + return False # 文件未修改 + + old_config = self.config + new_config = self.load_config() + + # 检查配置是否真的发生了变化 + if old_config and self._configs_equal(old_config, new_config): + return False + + logger.info("功能配置已重载") + await self._notify_callbacks() + return True + + except Exception as e: + logger.error(f"功能配置重载失败: {e}") + return False + + def _configs_equal(self, config1: FeaturesConfig, config2: FeaturesConfig) -> bool: + """比较两个配置是否相等""" + return ( + config1.group_list_type == 
config2.group_list_type and + set(config1.group_list) == set(config2.group_list) and + config1.private_list_type == config2.private_list_type and + set(config1.private_list) == set(config2.private_list) and + set(config1.ban_user_id) == set(config2.ban_user_id) and + config1.ban_qq_bot == config2.ban_qq_bot and + config1.enable_poke == config2.enable_poke and + config1.ignore_non_self_poke == config2.ignore_non_self_poke and + config1.poke_debounce_seconds == config2.poke_debounce_seconds and + config1.enable_reply_at == config2.enable_reply_at and + config1.reply_at_rate == config2.reply_at_rate and + config1.enable_video_analysis == config2.enable_video_analysis and + config1.max_video_size_mb == config2.max_video_size_mb and + config1.download_timeout == config2.download_timeout and + set(config1.supported_formats) == set(config2.supported_formats) and + # 消息缓冲配置比较 + config1.enable_message_buffer == config2.enable_message_buffer and + config1.message_buffer_enable_group == config2.message_buffer_enable_group and + config1.message_buffer_enable_private == config2.message_buffer_enable_private and + config1.message_buffer_interval == config2.message_buffer_interval and + config1.message_buffer_initial_delay == config2.message_buffer_initial_delay and + config1.message_buffer_max_components == config2.message_buffer_max_components and + set(config1.message_buffer_block_prefixes) == set(config2.message_buffer_block_prefixes) + ) + + async def start_file_watcher(self, check_interval: float = 1.0): + """启动文件监控,定期检查配置文件变化""" + if self._file_watcher_task and not self._file_watcher_task.done(): + logger.warning("文件监控已在运行") + return + + self._file_watcher_task = asyncio.create_task( + self._file_watcher_loop(check_interval) + ) + logger.info(f"功能配置文件监控已启动,检查间隔: {check_interval}秒") + + async def stop_file_watcher(self): + """停止文件监控""" + if self._file_watcher_task and not self._file_watcher_task.done(): + self._file_watcher_task.cancel() + try: + await 
self._file_watcher_task + except asyncio.CancelledError: + pass + logger.info("功能配置文件监控已停止") + + async def _file_watcher_loop(self, check_interval: float): + """文件监控循环""" + while True: + try: + await asyncio.sleep(check_interval) + await self.reload_config() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"文件监控循环出错: {e}") + await asyncio.sleep(check_interval) + + def get_config(self) -> FeaturesConfig: + """获取当前功能配置""" + if self.config is None: + return self.load_config() + return self.config + + def is_group_allowed(self, group_id: int) -> bool: + """检查群聊是否被允许""" + config = self.get_config() + if config.group_list_type == "whitelist": + return group_id in config.group_list + else: # blacklist + return group_id not in config.group_list + + def is_private_allowed(self, user_id: int) -> bool: + """检查私聊是否被允许""" + config = self.get_config() + if config.private_list_type == "whitelist": + return user_id in config.private_list + else: # blacklist + return user_id not in config.private_list + + def is_user_banned(self, user_id: int) -> bool: + """检查用户是否被全局禁止""" + config = self.get_config() + return user_id in config.ban_user_id + + def is_qq_bot_banned(self) -> bool: + """检查是否禁止QQ官方机器人""" + config = self.get_config() + return config.ban_qq_bot + + def is_poke_enabled(self) -> bool: + """检查戳一戳功能是否启用""" + config = self.get_config() + return config.enable_poke + + def is_non_self_poke_ignored(self) -> bool: + """检查是否忽略非自己戳一戳""" + config = self.get_config() + return config.ignore_non_self_poke + + def is_message_buffer_enabled(self) -> bool: + """检查消息缓冲功能是否启用""" + config = self.get_config() + return config.enable_message_buffer + + def is_message_buffer_group_enabled(self) -> bool: + """检查群消息缓冲是否启用""" + config = self.get_config() + return config.message_buffer_enable_group + + def is_message_buffer_private_enabled(self) -> bool: + """检查私聊消息缓冲是否启用""" + config = self.get_config() + return config.message_buffer_enable_private + + def 
get_message_buffer_interval(self) -> float: + """获取消息缓冲间隔时间""" + config = self.get_config() + return config.message_buffer_interval + + def get_message_buffer_initial_delay(self) -> float: + """获取消息缓冲初始延迟""" + config = self.get_config() + return config.message_buffer_initial_delay + + def get_message_buffer_max_components(self) -> int: + """获取消息缓冲最大组件数量""" + config = self.get_config() + return config.message_buffer_max_components + + def is_message_buffer_group_enabled(self) -> bool: + """检查是否启用群聊消息缓冲""" + config = self.get_config() + return config.message_buffer_enable_group + + def is_message_buffer_private_enabled(self) -> bool: + """检查是否启用私聊消息缓冲""" + config = self.get_config() + return config.message_buffer_enable_private + + def get_message_buffer_block_prefixes(self) -> list[str]: + """获取消息缓冲屏蔽前缀列表""" + config = self.get_config() + return config.message_buffer_block_prefixes + + +# 全局功能管理器实例 +features_manager = FeaturesManager() \ No newline at end of file diff --git a/plugins/napcat_adapter_plugin/src/config/migrate_features.py b/plugins/napcat_adapter_plugin/src/config/migrate_features.py new file mode 100644 index 000000000..46926bb7f --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/config/migrate_features.py @@ -0,0 +1,194 @@ +""" +功能配置迁移脚本 +用于将旧的配置文件中的聊天、权限、视频处理等设置迁移到新的独立功能配置文件 +""" + +import os +import shutil +from pathlib import Path +import tomlkit +from src.common.logger import get_logger +logger = get_logger("napcat_adapter") + + +def migrate_features_from_config(old_config_path: str = "plugins/napcat_adapter_plugin/config/config.toml", + new_features_path: str = "plugins/napcat_adapter_plugin/config/features.toml", + template_path: str = "plugins/napcat_adapter_plugin/template/features_template.toml"): + """ + 从旧配置文件迁移功能设置到新的功能配置文件 + + Args: + old_config_path: 旧配置文件路径 + new_features_path: 新功能配置文件路径 + template_path: 功能配置模板路径 + """ + try: + # 检查旧配置文件是否存在 + if not os.path.exists(old_config_path): + logger.warning(f"旧配置文件不存在: {old_config_path}") + 
return False + + # 读取旧配置文件 + with open(old_config_path, "r", encoding="utf-8") as f: + old_config = tomlkit.load(f) + + # 检查是否有chat配置段和video配置段 + chat_config = old_config.get("chat", {}) + video_config = old_config.get("video", {}) + + # 检查是否有权限相关配置 + permission_keys = ["group_list_type", "group_list", "private_list_type", + "private_list", "ban_user_id", "ban_qq_bot", + "enable_poke", "ignore_non_self_poke", "poke_debounce_seconds"] + video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"] + + has_permission_config = any(key in chat_config for key in permission_keys) + has_video_config = any(key in video_config for key in video_keys) + + if not has_permission_config and not has_video_config: + logger.info("旧配置文件中没有找到功能相关配置,无需迁移") + return False + + # 确保新功能配置目录存在 + new_features_dir = Path(new_features_path).parent + new_features_dir.mkdir(parents=True, exist_ok=True) + + # 如果新功能配置文件已存在,先备份 + if os.path.exists(new_features_path): + backup_path = f"{new_features_path}.backup" + shutil.copy2(new_features_path, backup_path) + logger.info(f"已备份现有功能配置文件到: {backup_path}") + + # 创建新的功能配置 + new_features_config = { + "group_list_type": chat_config.get("group_list_type", "whitelist"), + "group_list": chat_config.get("group_list", []), + "private_list_type": chat_config.get("private_list_type", "whitelist"), + "private_list": chat_config.get("private_list", []), + "ban_user_id": chat_config.get("ban_user_id", []), + "ban_qq_bot": chat_config.get("ban_qq_bot", False), + "enable_poke": chat_config.get("enable_poke", True), + "ignore_non_self_poke": chat_config.get("ignore_non_self_poke", False), + "poke_debounce_seconds": chat_config.get("poke_debounce_seconds", 3), + "enable_video_analysis": video_config.get("enable_video_analysis", True), + "max_video_size_mb": video_config.get("max_video_size_mb", 100), + "download_timeout": video_config.get("download_timeout", 60), + "supported_formats": video_config.get("supported_formats", 
["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"]) + } + + # 写入新的功能配置文件 + with open(new_features_path, "w", encoding="utf-8") as f: + tomlkit.dump(new_features_config, f) + + logger.info(f"功能配置已成功迁移到: {new_features_path}") + + # 显示迁移的配置内容 + logger.info("迁移的配置内容:") + for key, value in new_features_config.items(): + logger.info(f" {key}: {value}") + + return True + + except Exception as e: + logger.error(f"功能配置迁移失败: {e}") + return False + + +def remove_features_from_old_config(config_path: str = "plugins/napcat_adapter_plugin/config/config.toml"): + """ + 从旧配置文件中移除功能相关配置,并将旧配置移动到 config/old/ 目录 + + Args: + config_path: 配置文件路径 + """ + try: + if not os.path.exists(config_path): + logger.warning(f"配置文件不存在: {config_path}") + return False + + # 确保 config/old 目录存在 + old_config_dir = "plugins/napcat_adapter_plugin/config/old" + os.makedirs(old_config_dir, exist_ok=True) + + # 备份原配置文件到 config/old 目录 + old_config_path = os.path.join(old_config_dir, "config_with_features.toml") + shutil.copy2(config_path, old_config_path) + logger.info(f"已备份包含功能配置的原文件到: {old_config_path}") + + # 读取配置文件 + with open(config_path, "r", encoding="utf-8") as f: + config = tomlkit.load(f) + + # 移除chat段中的功能相关配置 + removed_keys = [] + if "chat" in config: + chat_config = config["chat"] + permission_keys = ["group_list_type", "group_list", "private_list_type", + "private_list", "ban_user_id", "ban_qq_bot", + "enable_poke", "ignore_non_self_poke", "poke_debounce_seconds"] + + for key in permission_keys: + if key in chat_config: + del chat_config[key] + removed_keys.append(key) + + if removed_keys: + logger.info(f"已从chat配置段中移除功能相关配置: {removed_keys}") + + # 移除video段中的配置 + if "video" in config: + video_config = config["video"] + video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"] + + video_removed_keys = [] + for key in video_keys: + if key in video_config: + del video_config[key] + video_removed_keys.append(key) + + if video_removed_keys: + 
logger.info(f"已从video配置段中移除配置: {video_removed_keys}") + removed_keys.extend(video_removed_keys) + + # 如果video段为空,则删除整个段 + if not video_config: + del config["video"] + logger.info("已删除空的video配置段") + + if removed_keys: + logger.info(f"总共移除的配置项: {removed_keys}") + + # 写回配置文件 + with open(config_path, "w", encoding="utf-8") as f: + f.write(tomlkit.dumps(config)) + + logger.info(f"已更新配置文件: {config_path}") + return True + + except Exception as e: + logger.error(f"移除功能配置失败: {e}") + return False + + +def auto_migrate_features(): + """ + 自动执行功能配置迁移 + """ + logger.info("开始自动功能配置迁移...") + + # 执行迁移 + if migrate_features_from_config(): + logger.info("功能配置迁移成功") + + # 询问是否要从旧配置文件中移除功能配置 + logger.info("功能配置已迁移到独立文件,建议从主配置文件中移除相关配置") + # 在实际使用中,这里可以添加用户确认逻辑 + # 为了自动化,这里直接执行移除 + remove_features_from_old_config() + + else: + logger.info("功能配置迁移跳过或失败") + + +if __name__ == "__main__": + auto_migrate_features() \ No newline at end of file diff --git a/plugins/napcat_adapter_plugin/src/config/official_configs.py b/plugins/napcat_adapter_plugin/src/config/official_configs.py new file mode 100644 index 000000000..d30c9be10 --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/config/official_configs.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass, field +from typing import Literal + +from .config_base import ConfigBase + +""" +须知: +1. 本文件中记录了所有的配置项 +2. 所有新增的class都需要继承自ConfigBase +3. 所有新增的class都应在config.py中的Config类中添加字段 +4. 
对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default +""" + +ADAPTER_PLATFORM = "qq" + + +@dataclass +class NicknameConfig(ConfigBase): + nickname: str + """机器人昵称""" + + +@dataclass +class NapcatServerConfig(ConfigBase): + mode: Literal["reverse", "forward"] = "reverse" + """连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)""" + + host: str = "localhost" + """主机地址""" + + port: int = 8095 + """端口号""" + + url: str = "" + """正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws""" + + access_token: str = "" + """WebSocket 连接的访问令牌,用于身份验证""" + + heartbeat_interval: int = 30 + """心跳间隔时间,单位为秒""" + + +@dataclass +class MaiBotServerConfig(ConfigBase): + platform_name: str = field(default=ADAPTER_PLATFORM, init=False) + """平台名称,“qq”""" + + host: str = "localhost" + """MaiMCore的主机地址""" + + port: int = 8000 + """MaiMCore的端口号""" + + + + +@dataclass +class VoiceConfig(ConfigBase): + use_tts: bool = False + """是否启用TTS功能""" + + +@dataclass +class DebugConfig(ConfigBase): + level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" + """日志级别,默认为INFO""" diff --git a/plugins/napcat_adapter_plugin/src/database.py b/plugins/napcat_adapter_plugin/src/database.py new file mode 100644 index 000000000..ae34f3b7d --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/database.py @@ -0,0 +1,163 @@ +import os +from typing import Optional, List +from dataclasses import dataclass +from sqlmodel import Field, Session, SQLModel, create_engine, select + +from src.common.logger import get_logger +logger = get_logger("napcat_adapter") + +""" +表记录的方式: +| group_id | user_id | lift_time | +|----------|---------|-----------| + +其中使用 user_id == 0 表示群全体禁言 +""" + + +@dataclass +class BanUser: + """ + 程序处理使用的实例 + """ + + user_id: int + group_id: int + lift_time: Optional[int] = Field(default=-1) + + +class DB_BanUser(SQLModel, table=True): + """ + 表示数据库中的用户禁言记录。 + 使用双重主键 + """ + + user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID + group_id: int = Field(index=True, primary_key=True) # 
用户被禁言的群组 ID + lift_time: Optional[int] # 禁言解除的时间(时间戳) + + +def is_identical(obj1: BanUser, obj2: BanUser) -> bool: + """ + 检查两个 BanUser 对象是否相同。 + """ + return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id + + +class DatabaseManager: + """ + 数据库管理类,负责与数据库交互。 + """ + + def __init__(self): + os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在 + DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db") + self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL + self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎 + self._ensure_database() # 确保数据库和表已创建 + + def _ensure_database(self) -> None: + """ + 确保数据库和表已创建。 + """ + logger.info("确保数据库文件和表已创建...") + SQLModel.metadata.create_all(self.engine) + logger.info("数据库和表已创建或已存在") + + def update_ban_record(self, ban_list: List[BanUser]) -> None: + # sourcery skip: class-extract-method + """ + 更新禁言列表到数据库。 + 支持在不存在时创建新记录,对于多余的项目自动删除。 + """ + with Session(self.engine) as session: + all_records = session.exec(select(DB_BanUser)).all() + for ban_user in ban_list: + statement = select(DB_BanUser).where( + DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id + ) + if existing_record := session.exec(statement).first(): + if existing_record.lift_time == ban_user.lift_time: + logger.debug(f"禁言记录未变更: {existing_record}") + continue + # 更新现有记录的 lift_time + existing_record.lift_time = ban_user.lift_time + session.add(existing_record) + logger.debug(f"更新禁言记录: {existing_record}") + else: + # 创建新记录 + db_record = DB_BanUser( + user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time + ) + session.add(db_record) + logger.debug(f"创建新禁言记录: {ban_user}") + # 删除不在 ban_list 中的记录 + for db_record in all_records: + record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time) + if not any(is_identical(record, ban_user) for ban_user in ban_list): + 
statement = select(DB_BanUser).where( + DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id + ) + if ban_record := session.exec(statement).first(): + session.delete(ban_record) + + logger.debug(f"删除禁言记录: {ban_record}") + else: + logger.info(f"未找到禁言记录: {ban_record}") + + + logger.info("禁言记录已更新") + + def get_ban_records(self) -> List[BanUser]: + """ + 读取所有禁言记录。 + """ + with Session(self.engine) as session: + statement = select(DB_BanUser) + records = session.exec(statement).all() + return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records] + + def create_ban_record(self, ban_record: BanUser) -> None: + """ + 为特定群组中的用户创建禁言记录。 + 一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。 + 其同时还是简化版的更新方式。 + """ + with Session(self.engine) as session: + # 检查记录是否已存在 + statement = select(DB_BanUser).where( + DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id + ) + existing_record = session.exec(statement).first() + if existing_record: + # 如果记录已存在,更新 lift_time + existing_record.lift_time = ban_record.lift_time + session.add(existing_record) + logger.debug(f"更新禁言记录: {ban_record}") + else: + # 如果记录不存在,创建新记录 + db_record = DB_BanUser( + user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time + ) + session.add(db_record) + logger.debug(f"创建新禁言记录: {ban_record}") + + + def delete_ban_record(self, ban_record: BanUser): + """ + 删除特定用户在特定群组中的禁言记录。 + 一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。 + """ + user_id = ban_record.user_id + group_id = ban_record.group_id + with Session(self.engine) as session: + statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id) + if ban_record := session.exec(statement).first(): + session.delete(ban_record) + + logger.debug(f"删除禁言记录: {ban_record}") + else: + logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}") + + +db_manager = DatabaseManager() diff --git 
a/plugins/napcat_adapter_plugin/src/message_buffer.py b/plugins/napcat_adapter_plugin/src/message_buffer.py new file mode 100644 index 000000000..531f56230 --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/message_buffer.py @@ -0,0 +1,320 @@ +import asyncio +import time +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, field + +from src.common.logger import get_logger +logger = get_logger("napcat_adapter") + +from .config.features_config import features_manager +from .recv_handler import RealMessageType + + +@dataclass +class TextMessage: + """文本消息""" + text: str + timestamp: float = field(default_factory=time.time) + + +@dataclass +class BufferedSession: + """缓冲会话数据""" + session_id: str + messages: List[TextMessage] = field(default_factory=list) + timer_task: Optional[asyncio.Task] = None + delay_task: Optional[asyncio.Task] = None + original_event: Any = None + created_at: float = field(default_factory=time.time) + + +class SimpleMessageBuffer: + + def __init__(self, merge_callback=None): + """ + 初始化消息缓冲器 + + Args: + merge_callback: 消息合并后的回调函数,接收(session_id, merged_text, original_event)参数 + """ + self.buffer_pool: Dict[str, BufferedSession] = {} + self.lock = asyncio.Lock() + self.merge_callback = merge_callback + self._shutdown = False + + def get_session_id(self, event_data: Dict[str, Any]) -> str: + """根据事件数据生成会话ID""" + message_type = event_data.get("message_type", "unknown") + user_id = event_data.get("user_id", "unknown") + + if message_type == "private": + return f"private_{user_id}" + elif message_type == "group": + group_id = event_data.get("group_id", "unknown") + return f"group_{group_id}_{user_id}" + else: + return f"{message_type}_{user_id}" + + def extract_text_from_message(self, message: List[Dict[str, Any]]) -> Optional[str]: + """从OneBot消息中提取纯文本,如果包含非文本内容则返回None""" + text_parts = [] + has_non_text = False + + logger.debug(f"正在提取消息文本,消息段数量: {len(message)}") + + for msg_seg in message: + msg_type = 
msg_seg.get("type", "") + logger.debug(f"处理消息段类型: {msg_type}") + + if msg_type == RealMessageType.text: + text = msg_seg.get("data", {}).get("text", "").strip() + if text: + text_parts.append(text) + logger.debug(f"提取到文本: {text[:50]}...") + else: + # 发现非文本消息段,标记为包含非文本内容 + has_non_text = True + logger.debug(f"发现非文本消息段: {msg_type},跳过缓冲") + + # 如果包含非文本内容,则不进行缓冲 + if has_non_text: + logger.debug("消息包含非文本内容,不进行缓冲") + return None + + if text_parts: + combined_text = " ".join(text_parts).strip() + logger.debug(f"成功提取纯文本: {combined_text[:50]}...") + return combined_text + + logger.debug("没有找到有效的文本内容") + return None + + def should_skip_message(self, text: str) -> bool: + """判断消息是否应该跳过缓冲""" + if not text or not text.strip(): + return True + + # 检查屏蔽前缀 + config = features_manager.get_config() + block_prefixes = tuple(config.message_buffer_block_prefixes) + + text = text.strip() + if text.startswith(block_prefixes): + logger.debug(f"消息以屏蔽前缀开头,跳过缓冲: {text[:20]}...") + return True + + return False + + async def add_text_message(self, event_data: Dict[str, Any], message: List[Dict[str, Any]], + original_event: Any = None) -> bool: + """ + 添加文本消息到缓冲区 + + Args: + event_data: 事件数据 + message: OneBot消息数组 + original_event: 原始事件对象 + + Returns: + 是否成功添加到缓冲区 + """ + if self._shutdown: + return False + + config = features_manager.get_config() + if not config.enable_message_buffer: + return False + + # 检查是否启用对应类型的缓冲 + message_type = event_data.get("message_type", "") + if message_type == "group" and not config.message_buffer_enable_group: + return False + elif message_type == "private" and not config.message_buffer_enable_private: + return False + + # 提取文本 + text = self.extract_text_from_message(message) + if not text: + return False + + # 检查是否应该跳过 + if self.should_skip_message(text): + return False + + session_id = self.get_session_id(event_data) + + async with self.lock: + # 获取或创建会话 + if session_id not in self.buffer_pool: + self.buffer_pool[session_id] = BufferedSession( + 
session_id=session_id, + original_event=original_event + ) + + session = self.buffer_pool[session_id] + + # 检查是否超过最大组件数量 + if len(session.messages) >= config.message_buffer_max_components: + logger.info(f"会话 {session_id} 消息数量达到上限,强制合并") + asyncio.create_task(self._force_merge_session(session_id)) + self.buffer_pool[session_id] = BufferedSession( + session_id=session_id, + original_event=original_event + ) + session = self.buffer_pool[session_id] + + # 添加文本消息 + session.messages.append(TextMessage(text=text)) + session.original_event = original_event # 更新事件 + + # 取消之前的定时器 + await self._cancel_session_timers(session) + + # 设置新的延迟任务 + session.delay_task = asyncio.create_task( + self._wait_and_start_merge(session_id) + ) + + logger.debug(f"文本消息已添加到缓冲器 {session_id}: {text[:50]}...") + return True + + async def _cancel_session_timers(self, session: BufferedSession): + """取消会话的所有定时器""" + for task_name in ['timer_task', 'delay_task']: + task = getattr(session, task_name) + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + setattr(session, task_name, None) + + async def _wait_and_start_merge(self, session_id: str): + """等待初始延迟后开始合并定时器""" + config = features_manager.get_config() + await asyncio.sleep(config.message_buffer_initial_delay) + + async with self.lock: + session = self.buffer_pool.get(session_id) + if session and session.messages: + # 取消旧的定时器 + if session.timer_task and not session.timer_task.done(): + session.timer_task.cancel() + try: + await session.timer_task + except asyncio.CancelledError: + pass + + # 设置合并定时器 + session.timer_task = asyncio.create_task( + self._wait_and_merge(session_id) + ) + + async def _wait_and_merge(self, session_id: str): + """等待合并间隔后执行合并""" + config = features_manager.get_config() + await asyncio.sleep(config.message_buffer_interval) + await self._merge_session(session_id) + + async def _force_merge_session(self, session_id: str): + """强制合并会话(不等待定时器)""" + await 
self._merge_session(session_id, force=True) + + async def _merge_session(self, session_id: str, force: bool = False): + """合并会话中的消息""" + async with self.lock: + session = self.buffer_pool.get(session_id) + if not session or not session.messages: + self.buffer_pool.pop(session_id, None) + return + + try: + # 合并文本消息 + text_parts = [] + for msg in session.messages: + if msg.text.strip(): + text_parts.append(msg.text.strip()) + + if not text_parts: + self.buffer_pool.pop(session_id, None) + return + + merged_text = ",".join(text_parts) # 使用中文逗号连接 + message_count = len(session.messages) + + logger.info(f"合并会话 {session_id} 的 {message_count} 条文本消息: {merged_text[:100]}...") + + # 调用回调函数 + if self.merge_callback: + try: + if asyncio.iscoroutinefunction(self.merge_callback): + await self.merge_callback(session_id, merged_text, session.original_event) + else: + self.merge_callback(session_id, merged_text, session.original_event) + except Exception as e: + logger.error(f"消息合并回调执行失败: {e}") + + except Exception as e: + logger.error(f"合并会话 {session_id} 时出错: {e}") + finally: + # 清理会话 + await self._cancel_session_timers(session) + self.buffer_pool.pop(session_id, None) + + async def flush_session(self, session_id: str): + """强制刷新指定会话的缓冲区""" + await self._force_merge_session(session_id) + + async def flush_all(self): + """强制刷新所有会话的缓冲区""" + session_ids = list(self.buffer_pool.keys()) + for session_id in session_ids: + await self._force_merge_session(session_id) + + async def get_buffer_stats(self) -> Dict[str, Any]: + """获取缓冲区统计信息""" + async with self.lock: + stats = { + "total_sessions": len(self.buffer_pool), + "sessions": {} + } + + for session_id, session in self.buffer_pool.items(): + stats["sessions"][session_id] = { + "message_count": len(session.messages), + "created_at": session.created_at, + "age": time.time() - session.created_at + } + + return stats + + async def clear_expired_sessions(self, max_age: float = 300.0): + """清理过期的会话""" + current_time = time.time() + 
expired_sessions = [] + + async with self.lock: + for session_id, session in self.buffer_pool.items(): + if current_time - session.created_at > max_age: + expired_sessions.append(session_id) + + for session_id in expired_sessions: + logger.info(f"清理过期会话: {session_id}") + await self._force_merge_session(session_id) + + async def shutdown(self): + """关闭消息缓冲器""" + self._shutdown = True + logger.info("正在关闭简化消息缓冲器...") + + # 刷新所有缓冲区 + await self.flush_all() + + # 确保所有任务都被取消 + async with self.lock: + for session in list(self.buffer_pool.values()): + await self._cancel_session_timers(session) + self.buffer_pool.clear() + + logger.info("简化消息缓冲器已关闭") diff --git a/plugins/napcat_adapter_plugin/src/mmc_com_layer.py b/plugins/napcat_adapter_plugin/src/mmc_com_layer.py new file mode 100644 index 000000000..14cddf102 --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/mmc_com_layer.py @@ -0,0 +1,26 @@ +from maim_message import Router, RouteConfig, TargetConfig +from .config import global_config +from src.common.logger import get_logger +from .send_handler import send_handler + +logger = get_logger("napcat_adapter") + +route_config = RouteConfig( + route_config={ + global_config.maibot_server.platform_name: TargetConfig( + url=f"ws://{global_config.maibot_server.host}:{global_config.maibot_server.port}/ws", + token=None, + ) + } +) +router = Router(route_config) + + +async def mmc_start_com(): + logger.info("正在连接MaiBot") + router.register_class_handler(send_handler.handle_message) + await router.run() + + +async def mmc_stop_com(): + await router.stop() diff --git a/plugins/napcat_adapter_plugin/src/recv_handler/__init__.py b/plugins/napcat_adapter_plugin/src/recv_handler/__init__.py new file mode 100644 index 000000000..b2fb9bad1 --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/recv_handler/__init__.py @@ -0,0 +1,89 @@ +from enum import Enum + + +class MetaEventType: + lifecycle = "lifecycle" # 生命周期 + + class Lifecycle: + connect = "connect" # 生命周期 - WebSocket 连接成功 + + 
heartbeat = "heartbeat" # 心跳 + + +class MessageType: # 接受消息大类 + private = "private" # 私聊消息 + + class Private: + friend = "friend" # 私聊消息 - 好友 + group = "group" # 私聊消息 - 群临时 + group_self = "group_self" # 私聊消息 - 群中自身发送 + other = "other" # 私聊消息 - 其他 + + group = "group" # 群聊消息 + + class Group: + normal = "normal" # 群聊消息 - 普通 + anonymous = "anonymous" # 群聊消息 - 匿名消息 + notice = "notice" # 群聊消息 - 系统提示 + + +class NoticeType: # 通知事件 + friend_recall = "friend_recall" # 私聊消息撤回 + group_recall = "group_recall" # 群聊消息撤回 + notify = "notify" + group_ban = "group_ban" # 群禁言 + + class Notify: + poke = "poke" # 戳一戳 + input_status = "input_status" # 正在输入 + + class GroupBan: + ban = "ban" # 禁言 + lift_ban = "lift_ban" # 解除禁言 + + +class RealMessageType: # 实际消息分类 + text = "text" # 纯文本 + face = "face" # qq表情 + image = "image" # 图片 + record = "record" # 语音 + video = "video" # 视频 + at = "at" # @某人 + rps = "rps" # 猜拳魔法表情 + dice = "dice" # 骰子 + shake = "shake" # 私聊窗口抖动(只收) + poke = "poke" # 群聊戳一戳 + share = "share" # 链接分享(json形式) + reply = "reply" # 回复消息 + forward = "forward" # 转发消息 + node = "node" # 转发消息节点 + json = "json" # json消息 + + +class MessageSentType: + private = "private" + + class Private: + friend = "friend" + group = "group" + + group = "group" + + class Group: + normal = "normal" + + +class CommandType(Enum): + """命令类型""" + + GROUP_BAN = "set_group_ban" # 禁言用户 + GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言 + GROUP_KICK = "set_group_kick" # 踢出群聊 + SEND_POKE = "send_poke" # 戳一戳 + DELETE_MSG = "delete_msg" # 撤回消息 + + def __str__(self) -> str: + return self.value + + +ACCEPT_FORMAT = ["text", "image", "emoji", "reply", "voice", "command", "voiceurl", "music", "videourl", "file"] diff --git a/plugins/napcat_adapter_plugin/src/recv_handler/message_handler.py b/plugins/napcat_adapter_plugin/src/recv_handler/message_handler.py new file mode 100644 index 000000000..0e4ba29fe --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/recv_handler/message_handler.py @@ -0,0 +1,942 @@ +from 
class MessageHandler:
    """Receives raw Napcat (OneBot) events, converts them to maim_message
    MessageBase objects, and forwards them to MaiBot."""

    def __init__(self):
        # Direct Napcat websocket connection; may remain None, in which case
        # get_server_connection() falls back to the shared websocket_manager.
        self.server_connection: Optional[Server.ServerConnection] = None
        # Cache of user_id -> whether that user is an official QQ bot
        # (filled lazily by check_allow_to_chat).
        self.bot_id_list: Dict[int, bool] = {}
        # Simplified message buffer; merged text is delivered back through
        # the _send_buffered_message callback.
        self.message_buffer = SimpleMessageBuffer(merge_callback=self._send_buffered_message)

    async def shutdown(self):
        """Shut down the handler and release buffered-message resources."""
        if self.message_buffer:
            await self.message_buffer.shutdown()

    async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
        """Set the direct Napcat connection."""
        self.server_connection = server_connection

    def get_server_connection(self) -> Server.ServerConnection:
        """Return the active server connection.

        Prefers the directly-set connection; otherwise falls back to the
        connection held by websocket_manager.
        """
        if self.server_connection:
            return self.server_connection
        return websocket_manager.get_connection()
ignore_global_list: Optional[bool] = False, + ) -> bool: + # sourcery skip: hoist-statement-from-if, merge-else-if-into-elif + """ + 检查是否允许聊天 + Parameters: + user_id: int: 用户ID + group_id: int: 群ID + ignore_bot: bool: 是否忽略机器人检查 + ignore_global_list: bool: 是否忽略全局黑名单检查 + Returns: + bool: 是否允许聊天 + """ + logger.debug(f"群聊id: {group_id}, 用户id: {user_id}") + logger.debug("开始检查聊天白名单/黑名单") + + # 使用新的权限管理器检查权限 + if group_id: + if not features_manager.is_group_allowed(group_id): + logger.warning("群聊不在聊天权限范围内,消息被丢弃") + return False + else: + if not features_manager.is_private_allowed(user_id): + logger.warning("私聊不在聊天权限范围内,消息被丢弃") + return False + + # 检查全局禁止名单 + if not ignore_global_list and features_manager.is_user_banned(user_id): + logger.warning("用户在全局黑名单中,消息被丢弃") + return False + + # 检查QQ官方机器人 + if features_manager.is_qq_bot_banned() and group_id and not ignore_bot: + logger.debug("开始判断是否为机器人") + member_info = await get_member_info(self.get_server_connection(), group_id, user_id) + if member_info: + is_bot = member_info.get("is_robot") + if is_bot is None: + logger.warning("无法获取用户是否为机器人,默认为不是但是不进行更新") + else: + if is_bot: + logger.warning("QQ官方机器人消息拦截已启用,消息被丢弃,新机器人加入拦截名单") + self.bot_id_list[user_id] = True + return False + else: + self.bot_id_list[user_id] = False + + return True + + async def handle_raw_message(self, raw_message: dict) -> None: + # sourcery skip: low-code-quality, remove-unreachable-code + """ + 从Napcat接受的原始消息处理 + + Parameters: + raw_message: dict: 原始消息 + """ + message_type: str = raw_message.get("message_type") + message_id: int = raw_message.get("message_id") + # message_time: int = raw_message.get("time") + message_time: float = time.time() # 应可乐要求,现在是float了 + + template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用 + format_info: FormatInfo = FormatInfo( + content_format=["text", "image", "emoji", "voice"], + accept_format=ACCEPT_FORMAT, + ) # 格式化信息 + if message_type == MessageType.private: + sub_type = raw_message.get("sub_type") + if sub_type == 
    async def handle_raw_message(self, raw_message: dict) -> None:
        # sourcery skip: low-code-quality, remove-unreachable-code
        """
        Process a raw event received from Napcat and forward it to MaiBot.

        Parameters:
            raw_message: dict: raw OneBot event payload
        """
        message_type: str = raw_message.get("message_type")
        message_id: int = raw_message.get("message_id")
        # message_time: int = raw_message.get("time")
        message_time: float = time.time()  # changed to float (receive time) by request

        template_info: TemplateInfo = None  # template info; unused for now
        format_info: FormatInfo = FormatInfo(
            content_format=["text", "image", "emoji", "voice"],
            accept_format=ACCEPT_FORMAT,
        )  # format info
        if message_type == MessageType.private:
            sub_type = raw_message.get("sub_type")
            if sub_type == MessageType.Private.friend:
                sender_info: dict = raw_message.get("sender")

                if not await self.check_allow_to_chat(sender_info.get("user_id"), None):
                    return None

                # Sender user info
                user_info: UserInfo = UserInfo(
                    platform=global_config.maibot_server.platform_name,
                    user_id=sender_info.get("user_id"),
                    user_nickname=sender_info.get("nickname"),
                    user_cardname=sender_info.get("card"),
                )

                # Friend chats have no group info
                group_info: GroupInfo = None
            elif sub_type == MessageType.Private.group:
                """
                Group temporary sessions are not supported yet; the code below
                is kept as a template for future support.
                """
                logger.warning("群临时消息类型不支持")
                return None
                # NOTE(review): everything from here to the end of this branch is
                # unreachable (hence the "remove-unreachable-code" sourcery skip);
                # it is intentionally retained as a draft implementation.

                sender_info: dict = raw_message.get("sender")

                # In temp sessions Napcat omits the member nickname, so fetch it.
                fetched_member_info: dict = await get_member_info(
                    self.get_server_connection(),
                    raw_message.get("group_id"),
                    sender_info.get("user_id"),
                )
                nickname = fetched_member_info.get("nickname") if fetched_member_info else None
                # Sender user info
                user_info: UserInfo = UserInfo(
                    platform=global_config.maibot_server.platform_name,
                    user_id=sender_info.get("user_id"),
                    user_nickname=nickname,
                    user_cardname=None,
                )

                # -------------------is group info needed here?-------------------

                # Fetch group info; group_name is resolved separately because the
                # raw event does not carry it by default.
                fetched_group_info: dict = await get_group_info(self.get_server_connection(), raw_message.get("group_id"))
                group_name = ""
                if fetched_group_info.get("group_name"):
                    group_name = fetched_group_info.get("group_name")

                group_info: GroupInfo = GroupInfo(
                    platform=global_config.maibot_server.platform_name,
                    group_id=raw_message.get("group_id"),
                    group_name=group_name,
                )

            else:
                logger.warning(f"私聊消息类型 {sub_type} 不支持")
                return None
        elif message_type == MessageType.group:
            sub_type = raw_message.get("sub_type")
            if sub_type == MessageType.Group.normal:
                sender_info: dict = raw_message.get("sender")

                if not await self.check_allow_to_chat(sender_info.get("user_id"), raw_message.get("group_id")):
                    return None

                # Sender user info
                user_info: UserInfo = UserInfo(
                    platform=global_config.maibot_server.platform_name,
                    user_id=sender_info.get("user_id"),
                    user_nickname=sender_info.get("nickname"),
                    user_cardname=sender_info.get("card"),
                )

                # Fetch group info; group_name resolved separately (absent by default).
                fetched_group_info = await get_group_info(self.get_server_connection(), raw_message.get("group_id"))
                group_name: str = None
                if fetched_group_info:
                    group_name = fetched_group_info.get("group_name")

                group_info: GroupInfo = GroupInfo(
                    platform=global_config.maibot_server.platform_name,
                    group_id=raw_message.get("group_id"),
                    group_name=group_name,
                )

            else:
                logger.warning(f"群聊消息类型 {sub_type} 不支持")
                return None
        # NOTE(review): if message_type is neither "private" nor "group",
        # user_info/group_info are never bound and the BaseMessageInfo call below
        # raises NameError — confirm upstream routing filters event types.

        additional_config: dict = {}
        if global_config.voice.use_tts:
            additional_config["allow_tts"] = True

        # Message envelope
        message_info: BaseMessageInfo = BaseMessageInfo(
            platform=global_config.maibot_server.platform_name,
            message_id=message_id,
            time=message_time,
            user_info=user_info,
            group_info=group_info,
            template_info=template_info,
            format_info=format_info,
            additional_config=additional_config,
        )

        # Process the actual content
        if not raw_message.get("message"):
            logger.warning("原始消息内容为空")
            return None

        # Build the Seg list
        seg_message: List[Seg] = await self.handle_real_message(raw_message)
        if not seg_message:
            logger.warning("处理后消息内容为空")
            return None

        # Decide whether the message buffer should be used
        if features_manager.is_message_buffer_enabled():
            # Buffering is opt-in per message type (group / private)
            message_type = raw_message.get("message_type")
            should_use_buffer = False

            if message_type == "group" and features_manager.is_message_buffer_group_enabled():
                should_use_buffer = True
            elif message_type == "private" and features_manager.is_message_buffer_private_enabled():
                should_use_buffer = True

            if should_use_buffer:
                logger.debug(f"尝试缓冲消息,消息类型: {message_type}, 用户: {user_info.user_id}")
                logger.debug(f"原始消息段: {raw_message.get('message', [])}")

                # Try to hand the message to the buffer; only pure-text messages
                # are accepted (see SimpleMessageBuffer.add_text_message).
                buffered = await self.message_buffer.add_text_message(
                    event_data={
                        "message_type": message_type,
                        "user_id": user_info.user_id,
                        "group_id": group_info.group_id if group_info else None,
                    },
                    message=raw_message.get("message", []),
                    original_event={
                        "message_info": message_info,
                        "raw_message": raw_message
                    }
                )

                if buffered:
                    logger.info(f"✅ 文本消息已成功缓冲: {user_info.user_id}")
                    return None  # buffered: do not send immediately
                # Buffering refused (non-text elements) — fall through to the
                # normal send path below; do not return here.
                logger.info(f"❌ 消息缓冲失败,包含非文本元素,走正常处理流程: {user_info.user_id}")

        logger.debug(f"准备发送消息到MaiBot,消息段数量: {len(seg_message)}")
        for i, seg in enumerate(seg_message):
            logger.debug(f"消息段 {i}: type={seg.type}, data={str(seg.data)[:100]}...")

        submit_seg: Seg = Seg(
            type="seglist",
            data=seg_message,
        )
        # Build the MessageBase
        message_base: MessageBase = MessageBase(
            message_info=message_info,
            message_segment=submit_seg,
            raw_message=raw_message.get("raw_message"),
        )

        logger.info("发送到Maibot处理信息")
        await message_send_instance.message_send(message_base)
    async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None:
        # sourcery skip: low-code-quality
        """
        Convert the event's segment list into a list of Seg objects.

        Each recognized segment also fires the matching NapcatEvent so plugins
        can observe incoming content.

        Parameters:
            raw_message: dict: raw event carrying the "message" segment list
            in_reply: bool: True while resolving a quoted message, to avoid
                recursing into nested reply segments
        Returns:
            seg_message: list[Seg]: processed segments, or None when empty
        """
        real_message: list = raw_message.get("message")
        if not real_message:
            return None
        seg_message: List[Seg] = []
        for sub_message in real_message:
            sub_message: dict
            sub_message_type = sub_message.get("type")
            match sub_message_type:
                case RealMessageType.text:
                    ret_seg = await self.handle_text_message(sub_message)
                    if ret_seg:
                        await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.TEXT,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
                        seg_message.append(ret_seg)
                    else:
                        logger.warning("text处理失败")
                case RealMessageType.face:
                    ret_seg = await self.handle_face_message(sub_message)
                    if ret_seg:
                        await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FACE,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
                        seg_message.append(ret_seg)
                    else:
                        logger.warning("face处理失败或不支持")
                case RealMessageType.reply:
                    # Nested replies are not expanded (in_reply guards recursion).
                    if not in_reply:
                        ret_seg = await self.handle_reply_message(sub_message)
                        if ret_seg:
                            await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.REPLY,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
                            seg_message += ret_seg
                        else:
                            logger.warning("reply处理失败")
                case RealMessageType.image:
                    logger.debug(f"开始处理图片消息段")
                    ret_seg = await self.handle_image_message(sub_message)
                    if ret_seg:
                        await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.IMAGE,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
                        seg_message.append(ret_seg)
                        logger.debug(f"图片处理成功,添加到消息段")
                    else:
                        logger.warning("image处理失败")
                    logger.debug(f"图片消息段处理完成")
                case RealMessageType.record:
                    ret_seg = await self.handle_record_message(sub_message)
                    if ret_seg:
                        await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.RECORD,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
                        # A voice message replaces everything else: drop previous
                        # segments and stop processing further ones.
                        seg_message.clear()
                        seg_message.append(ret_seg)
                        break  # the message carries only the record segment
                    else:
                        logger.warning("record处理失败或不支持")
                case RealMessageType.video:
                    ret_seg = await self.handle_video_message(sub_message)
                    if ret_seg:
                        await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.VIDEO,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
                        seg_message.append(ret_seg)
                    else:
                        logger.warning("video处理失败")
                case RealMessageType.at:
                    ret_seg = await self.handle_at_message(
                        sub_message,
                        raw_message.get("self_id"),
                        raw_message.get("group_id"),
                    )
                    if ret_seg:
                        await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.AT,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
                        seg_message.append(ret_seg)
                    else:
                        logger.warning("at处理失败")
                case RealMessageType.rps:
                    ret_seg = await self.handle_rps_message(sub_message)
                    if ret_seg:
                        await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.RPS,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
                        seg_message.append(ret_seg)
                    else:
                        logger.warning("rps处理失败")
                case RealMessageType.dice:
                    ret_seg = await self.handle_dice_message(sub_message)
                    if ret_seg:
                        await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.DICE,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
                        seg_message.append(ret_seg)
                    else:
                        logger.warning("dice处理失败")
                case RealMessageType.shake:
                    ret_seg = await self.handle_shake_message(sub_message)
                    if ret_seg:
                        await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.SHAKE,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
                        seg_message.append(ret_seg)
                    else:
                        logger.warning("shake处理失败")
                case RealMessageType.share:
                    print("\n\n哦哦哦噢噢噢哦哦你收到了一个超级无敌SHARE消息,快速速把你刚刚收到的消息截图发到MoFox-Bot群里!!!!\n\n")
                    logger.warning("暂时不支持链接解析")
                case RealMessageType.forward:
                    messages = await self._get_forward_message(sub_message)
                    if not messages:
                        logger.warning("转发消息内容为空或获取失败")
                        return None
                    ret_seg = await self.handle_forward_message(messages)
                    if ret_seg:
                        seg_message.append(ret_seg)
                    else:
                        logger.warning("转发消息处理失败")
                case RealMessageType.node:
                    print("\n\n哦哦哦噢噢噢哦哦你收到了一个超级无敌NODE消息,快速速把你刚刚收到的消息截图发到MoFox-Bot群里!!!!\n\n")
                    logger.warning("不支持转发消息节点解析")
                case RealMessageType.json:
                    ret_seg = await self.handle_json_message(sub_message)
                    if ret_seg:
                        await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.JSON,plugin_name=PLUGIN_NAME,message_seg=ret_seg)
                        seg_message.append(ret_seg)
                    else:
                        logger.warning("json处理失败")
                case _:
                    logger.warning(f"未知消息类型: {sub_message_type}")

        logger.debug(f"handle_real_message完成,处理了{len(real_message)}个消息段,生成了{len(seg_message)}个seg")
        return seg_message
event_manager.trigger_event(NapcatEvent.ON_RECEIVED.DICE,plugin_name=PLUGIN_NAME,message_seg=ret_seg) + seg_message.append(ret_seg) + else: + logger.warning("dice处理失败") + case RealMessageType.shake: + ret_seg = await self.handle_shake_message(sub_message) + if ret_seg: + await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.SHAKE,plugin_name=PLUGIN_NAME,message_seg=ret_seg) + seg_message.append(ret_seg) + else: + logger.warning("shake处理失败") + case RealMessageType.share: + print("\n\n哦哦哦噢噢噢哦哦你收到了一个超级无敌SHARE消息,快速速把你刚刚收到的消息截图发到MoFox-Bot群里!!!!\n\n") + logger.warning("暂时不支持链接解析") + case RealMessageType.forward: + messages = await self._get_forward_message(sub_message) + if not messages: + logger.warning("转发消息内容为空或获取失败") + return None + ret_seg = await self.handle_forward_message(messages) + if ret_seg: + seg_message.append(ret_seg) + else: + logger.warning("转发消息处理失败") + case RealMessageType.node: + print("\n\n哦哦哦噢噢噢哦哦你收到了一个超级无敌NODE消息,快速速把你刚刚收到的消息截图发到MoFox-Bot群里!!!!\n\n") + logger.warning("不支持转发消息节点解析") + case RealMessageType.json: + ret_seg = await self.handle_json_message(sub_message) + if ret_seg: + await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.JSON,plugin_name=PLUGIN_NAME,message_seg=ret_seg) + seg_message.append(ret_seg) + else: + logger.warning("json处理失败") + case _: + logger.warning(f"未知消息类型: {sub_message_type}") + + logger.debug(f"handle_real_message完成,处理了{len(real_message)}个消息段,生成了{len(seg_message)}个seg") + return seg_message + + async def handle_text_message(self, raw_message: dict) -> Seg: + """ + 处理纯文本信息 + Parameters: + raw_message: dict: 原始消息 + Returns: + seg_data: Seg: 处理后的消息段 + """ + message_data: dict = raw_message.get("data") + plain_text: str = message_data.get("text") + return Seg(type="text", data=plain_text) + + async def handle_face_message(self, raw_message: dict) -> Seg | None: + """ + 处理表情消息 + Parameters: + raw_message: dict: 原始消息 + Returns: + seg_data: Seg: 处理后的消息段 + """ + message_data: dict = raw_message.get("data") + 
face_raw_id: str = str(message_data.get("id")) + if face_raw_id in qq_face: + face_content: str = qq_face.get(face_raw_id) + return Seg(type="text", data=face_content) + else: + logger.warning(f"不支持的表情:{face_raw_id}") + return None + + async def handle_image_message(self, raw_message: dict) -> Seg | None: + """ + 处理图片消息与表情包消息 + Parameters: + raw_message: dict: 原始消息 + Returns: + seg_data: Seg: 处理后的消息段 + """ + message_data: dict = raw_message.get("data") + image_sub_type = message_data.get("sub_type") + try: + logger.debug(f"开始下载图片: {message_data.get('url')}") + image_base64 = await get_image_base64(message_data.get("url")) + logger.debug(f"图片下载成功,大小: {len(image_base64)} 字符") + except Exception as e: + logger.error(f"图片消息处理失败: {str(e)}") + return None + if image_sub_type == 0: + """这部分认为是图片""" + return Seg(type="image", data=image_base64) + elif image_sub_type not in [4, 9]: + """这部分认为是表情包""" + return Seg(type="emoji", data=image_base64) + else: + logger.warning(f"不支持的图片子类型:{image_sub_type}") + return None + + async def handle_at_message(self, raw_message: dict, self_id: int, group_id: int) -> Seg | None: + # sourcery skip: use-named-expression + """ + 处理at消息 + Parameters: + raw_message: dict: 原始消息 + self_id: int: 机器人QQ号 + group_id: int: 群号 + Returns: + seg_data: Seg: 处理后的消息段 + """ + message_data: dict = raw_message.get("data") + if message_data: + qq_id = message_data.get("qq") + if str(self_id) == str(qq_id): + logger.debug("机器人被at") + self_info: dict = await get_self_info(self.get_server_connection()) + if self_info: + return Seg(type="text", data=f"@<{self_info.get('nickname')}:{self_info.get('user_id')}>") + else: + return None + else: + member_info: dict = await get_member_info(self.get_server_connection(), group_id=group_id, user_id=qq_id) + if member_info: + return Seg(type="text", data=f"@<{member_info.get('nickname')}:{member_info.get('user_id')}>") + else: + return None + + async def handle_record_message(self, raw_message: dict) -> Seg | None: + """ + 
处理语音消息 + Parameters: + raw_message: dict: 原始消息 + Returns: + seg_data: Seg: 处理后的消息段 + """ + message_data: dict = raw_message.get("data") + file: str = message_data.get("file") + if not file: + logger.warning("语音消息缺少文件信息") + return None + try: + record_detail = await get_record_detail(self.get_server_connection(), file) + if not record_detail: + logger.warning("获取语音消息详情失败") + return None + audio_base64: str = record_detail.get("base64") + except Exception as e: + logger.error(f"语音消息处理失败: {str(e)}") + return None + if not audio_base64: + logger.error("语音消息处理失败,未获取到音频数据") + return None + return Seg(type="voice", data=audio_base64) + + async def handle_video_message(self, raw_message: dict) -> Seg | None: + """ + 处理视频消息 + Parameters: + raw_message: dict: 原始消息 + Returns: + seg_data: Seg: 处理后的消息段 + """ + message_data: dict = raw_message.get("data") + + # 添加详细的调试信息 + logger.debug(f"视频消息原始数据: {raw_message}") + logger.debug(f"视频消息数据: {message_data}") + + # QQ视频消息可能包含url或filePath字段 + video_url = message_data.get("url") + file_path = message_data.get("filePath") or message_data.get("file_path") + + logger.info(f"视频URL: {video_url}") + logger.info(f"视频文件路径: {file_path}") + + # 优先使用本地文件路径,其次使用URL + video_source = file_path if file_path else video_url + + if not video_source: + logger.warning("视频消息缺少URL或文件路径信息") + logger.warning(f"完整消息数据: {message_data}") + return None + + try: + # 检查是否为本地文件路径 + if file_path and Path(file_path).exists(): + logger.info(f"使用本地视频文件: {file_path}") + # 直接读取本地文件 + with open(file_path, "rb") as f: + video_data = f.read() + + # 将视频数据编码为base64用于传输 + video_base64 = base64.b64encode(video_data).decode('utf-8') + logger.info(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB") + + # 返回包含详细信息的字典格式 + return Seg(type="video", data={ + "base64": video_base64, + "filename": Path(file_path).name, + "size_mb": len(video_data) / (1024 * 1024) + }) + + elif video_url: + logger.info(f"使用视频URL下载: {video_url}") + # 使用video_handler下载视频 + video_downloader = 
    async def handle_video_message(self, raw_message: dict) -> Seg | None:
        """
        Convert a video segment into a video Seg (base64 payload + metadata).

        Parameters:
            raw_message: dict: raw video segment
        Returns:
            Seg | None: video Seg whose data dict carries base64/filename/size_mb
            (and url when downloaded), or None on failure
        """
        message_data: dict = raw_message.get("data")

        # Detailed debug output — video payload shapes vary across Napcat versions.
        logger.debug(f"视频消息原始数据: {raw_message}")
        logger.debug(f"视频消息数据: {message_data}")

        # QQ video segments may carry either a url or a local filePath field.
        video_url = message_data.get("url")
        file_path = message_data.get("filePath") or message_data.get("file_path")

        logger.info(f"视频URL: {video_url}")
        logger.info(f"视频文件路径: {file_path}")

        # Prefer the local file path; fall back to the URL.
        video_source = file_path if file_path else video_url

        if not video_source:
            logger.warning("视频消息缺少URL或文件路径信息")
            logger.warning(f"完整消息数据: {message_data}")
            return None

        try:
            # Local file path available and readable?
            if file_path and Path(file_path).exists():
                logger.info(f"使用本地视频文件: {file_path}")
                # Read the local file directly.
                with open(file_path, "rb") as f:
                    video_data = f.read()

                # Encode the video as base64 for transport.
                video_base64 = base64.b64encode(video_data).decode('utf-8')
                logger.info(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB")

                # Return a dict payload with the details.
                return Seg(type="video", data={
                    "base64": video_base64,
                    "filename": Path(file_path).name,
                    "size_mb": len(video_data) / (1024 * 1024)
                })

            elif video_url:
                logger.info(f"使用视频URL下载: {video_url}")
                # Download via the shared video_handler downloader.
                video_downloader = get_video_downloader()
                download_result = await video_downloader.download_video(video_url)

                if not download_result["success"]:
                    logger.warning(f"视频下载失败: {download_result.get('error', '未知错误')}")
                    logger.warning(f"失败的URL: {video_url}")
                    return None

                # Encode the downloaded video as base64 for transport.
                video_base64 = base64.b64encode(download_result["data"]).decode('utf-8')
                logger.info(f"视频下载成功,大小: {len(download_result['data']) / (1024 * 1024):.2f} MB")

                # Return a dict payload with the details.
                return Seg(type="video", data={
                    "base64": video_base64,
                    "filename": download_result.get("filename", "video.mp4"),
                    "size_mb": len(download_result["data"]) / (1024 * 1024),
                    "url": video_url
                })

            else:
                logger.warning("既没有有效的本地文件路径,也没有有效的视频URL")
                return None

        except Exception as e:
            logger.error(f"视频消息处理失败: {str(e)}")
            logger.error(f"视频源: {video_source}")
            return None
seg_message + + async def handle_forward_message(self, message_list: list) -> Seg | None: + """ + 递归处理转发消息,并按照动态方式确定图片处理方式 + Parameters: + message_list: list: 转发消息列表 + """ + handled_message, image_count = await self._handle_forward_message( + message_list, 0 + ) + handled_message: Seg + image_count: int + if not handled_message: + return None + + processed_message: Seg + if image_count < 5 and image_count > 0: + # 处理图片数量小于5的情况,此时解析图片为base64 + logger.info("图片数量小于5,开始解析图片为base64") + processed_message = await self._recursive_parse_image_seg( + handled_message, True + ) + elif image_count > 0: + logger.info("图片数量大于等于5,开始解析图片为占位符") + # 处理图片数量大于等于5的情况,此时解析图片为占位符 + processed_message = await self._recursive_parse_image_seg( + handled_message, False + ) + else: + # 处理没有图片的情况,此时直接返回 + logger.info("没有图片,直接返回") + processed_message = handled_message + + # 添加转发消息提示 + forward_hint = Seg(type="text", data="这是一条转发消息:\n") + return Seg(type="seglist", data=[forward_hint, processed_message]) + + async def handle_dice_message(self, raw_message: dict) -> Seg: + message_data: dict = raw_message.get("data",{}) + res = message_data.get("result","") + return Seg(type="text", data=f"[扔了一个骰子,点数是{res}]") + + async def handle_shake_message(self, raw_message: dict) -> Seg: + return Seg(type="text", data="[向你发送了窗口抖动,现在你的屏幕猛烈地震了一下!]") + + async def handle_json_message(self, raw_message: dict) -> Seg: + message_data: str = raw_message.get("data","").get("data","") + res = json.loads(message_data) + return Seg(type="json", data=res) + + async def handle_rps_message(self, raw_message: dict) -> Seg: + message_data: dict = raw_message.get("data",{}) + res = message_data.get("result","") + if res == "1": + shape = "布" + elif res == "2": + shape = "剪刀" + else: + shape = "石头" + return Seg(type="text", data=f"[发送了一个魔法猜拳表情,结果是:{shape}]") + + async def _recursive_parse_image_seg(self, seg_data: Seg, to_image: bool) -> Seg: + # sourcery skip: merge-else-if-into-elif + if to_image: + if seg_data.type == 
"seglist": + new_seg_list = [] + for i_seg in seg_data.data: + parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image) + new_seg_list.append(parsed_seg) + return Seg(type="seglist", data=new_seg_list) + elif seg_data.type == "image": + image_url = seg_data.data + try: + encoded_image = await get_image_base64(image_url) + except Exception as e: + logger.error(f"图片处理失败: {str(e)}") + return Seg(type="text", data="[图片]") + return Seg(type="image", data=encoded_image) + elif seg_data.type == "emoji": + image_url = seg_data.data + try: + encoded_image = await get_image_base64(image_url) + except Exception as e: + logger.error(f"图片处理失败: {str(e)}") + return Seg(type="text", data="[表情包]") + return Seg(type="emoji", data=encoded_image) + else: + logger.info(f"不处理类型: {seg_data.type}") + return seg_data + else: + if seg_data.type == "seglist": + new_seg_list = [] + for i_seg in seg_data.data: + parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image) + new_seg_list.append(parsed_seg) + return Seg(type="seglist", data=new_seg_list) + elif seg_data.type == "image": + return Seg(type="text", data="[图片]") + elif seg_data.type == "emoji": + return Seg(type="text", data="[动画表情]") + else: + logger.info(f"不处理类型: {seg_data.type}") + return seg_data + + async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[Seg, int] | Tuple[None, int]: + # sourcery skip: low-code-quality + """ + 递归处理实际转发消息 + Parameters: + message_list: list: 转发消息列表,首层对应messages字段,后面对应content字段 + layer: int: 当前层级 + Returns: + seg_data: Seg: 处理后的消息段 + image_count: int: 图片数量 + """ + seg_list: List[Seg] = [] + image_count = 0 + if message_list is None: + return None, 0 + for sub_message in message_list: + sub_message: dict + sender_info: dict = sub_message.get("sender") + user_nickname: str = sender_info.get("nickname", "QQ用户") + user_nickname_str = f"【{user_nickname}】:" + break_seg = Seg(type="text", data="\n") + message_of_sub_message_list: List[Dict[str, Any]] = 
    async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[Seg, int] | Tuple[None, int]:
        # sourcery skip: low-code-quality
        """
        Recursively flatten a forwarded-message bundle into Segs.

        Parameters:
            message_list: list: forwarded nodes — the `messages` field at the
                top level, nested nodes' `content` fields below that
            layer: int: current nesting depth (0 for the outermost level); used
                for the "--" indentation prefix and the depth limit
        Returns:
            seg_data: Seg: processed seglist (None when message_list is None)
            image_count: int: number of image segments encountered in the tree
        """
        seg_list: List[Seg] = []
        image_count = 0
        if message_list is None:
            return None, 0
        for sub_message in message_list:
            sub_message: dict
            sender_info: dict = sub_message.get("sender")
            user_nickname: str = sender_info.get("nickname", "QQ用户")
            user_nickname_str = f"【{user_nickname}】:"
            break_seg = Seg(type="text", data="\n")
            message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message")
            if not message_of_sub_message_list:
                logger.warning("转发消息内容为空")
                continue
            # NOTE(review): only the FIRST segment of each node is inspected;
            # any additional segments in the same node are dropped — confirm
            # this is intended.
            message_of_sub_message = message_of_sub_message_list[0]
            if message_of_sub_message.get("type") == RealMessageType.forward:
                if layer >= 3:
                    # Depth limit: summarize deeper forwards instead of recursing.
                    full_seg_data = Seg(
                        type="text",
                        data=("--" * layer) + f"【{user_nickname}】:【转发消息】\n",
                    )
                else:
                    sub_message_data = message_of_sub_message.get("data")
                    if not sub_message_data:
                        continue
                    contents = sub_message_data.get("content")
                    seg_data, count = await self._handle_forward_message(contents, layer + 1)
                    image_count += count
                    head_tip = Seg(
                        type="text",
                        data=("--" * layer) + f"【{user_nickname}】: 合并转发消息内容:\n",
                    )
                    full_seg_data = Seg(type="seglist", data=[head_tip, seg_data])
                seg_list.append(full_seg_data)
            elif message_of_sub_message.get("type") == RealMessageType.text:
                sub_message_data = message_of_sub_message.get("data")
                if not sub_message_data:
                    continue
                text_message = sub_message_data.get("text")
                seg_data = Seg(type="text", data=text_message)
                data_list: List[Any] = []
                # Nested levels get a "--" indentation prefix before the nickname.
                if layer > 0:
                    data_list = [
                        Seg(type="text", data=("--" * layer) + user_nickname_str),
                        seg_data,
                        break_seg,
                    ]
                else:
                    data_list = [
                        Seg(type="text", data=user_nickname_str),
                        seg_data,
                        break_seg,
                    ]
                seg_list.append(Seg(type="seglist", data=data_list))
            elif message_of_sub_message.get("type") == RealMessageType.image:
                image_count += 1
                image_data = message_of_sub_message.get("data")
                sub_type = image_data.get("sub_type")
                image_url = image_data.get("url")
                data_list: List[Any] = []
                # sub_type 0 is a real picture; everything else is a sticker.
                # The URL is resolved to base64/placeholder later by
                # _recursive_parse_image_seg.
                if sub_type == 0:
                    seg_data = Seg(type="image", data=image_url)
                else:
                    seg_data = Seg(type="emoji", data=image_url)
                if layer > 0:
                    data_list = [
                        Seg(type="text", data=("--" * layer) + user_nickname_str),
                        seg_data,
                        break_seg,
                    ]
                else:
                    data_list = [
                        Seg(type="text", data=user_nickname_str),
                        seg_data,
                        break_seg,
                    ]
                full_seg_data = Seg(type="seglist", data=data_list)
                seg_list.append(full_seg_data)
        return Seg(type="seglist", data=seg_list), image_count
seg_list.append(full_seg_data) + return Seg(type="seglist", data=seg_list), image_count + + async def _get_forward_message(self, raw_message: dict) -> Dict[str, Any] | None: + forward_message_data: Dict = raw_message.get("data") + if not forward_message_data: + logger.warning("转发消息内容为空") + return None + forward_message_id = forward_message_data.get("id") + request_uuid = str(uuid.uuid4()) + payload = json.dumps( + { + "action": "get_forward_msg", + "params": {"message_id": forward_message_id}, + "echo": request_uuid, + } + ) + try: + connection = self.get_server_connection() + if not connection: + logger.error("没有可用的 WebSocket 连接") + return None + await connection.send(payload) + response: dict = await get_response(request_uuid) + except TimeoutError: + logger.error("获取转发消息超时") + return None + except Exception as e: + logger.error(f"获取转发消息失败: {str(e)}") + return None + logger.debug( + f"转发消息原始格式:{json.dumps(response)[:80]}..." + if len(json.dumps(response)) > 80 + else json.dumps(response) + ) + response_data: Dict = response.get("data") + if not response_data: + logger.warning("转发消息内容为空或获取失败") + return None + return response_data.get("messages") + + async def _send_buffered_message(self, session_id: str, merged_text: str, original_event: Dict[str, Any]): + """发送缓冲的合并消息""" + try: + # 从原始事件数据中提取信息 + message_info = original_event.get("message_info") + raw_message = original_event.get("raw_message") + + if not message_info or not raw_message: + logger.error("缓冲消息缺少必要信息") + return + + # 创建合并后的消息段 - 将合并的文本转换为Seg格式 + from maim_message import Seg + merged_seg = Seg(type="text", data=merged_text) + submit_seg = Seg(type="seglist", data=[merged_seg]) + + # 创建新的消息ID + import time + new_message_id = f"buffered-{message_info.message_id}-{int(time.time() * 1000)}" + + # 更新消息信息 + from maim_message import BaseMessageInfo, MessageBase + buffered_message_info = BaseMessageInfo( + platform=message_info.platform, + message_id=new_message_id, + time=time.time(), + 
class MessageSending:
    """
    Forwards MessageBase objects from the adapter to MaiBot.

    The router is injected at startup (class attribute) before any send; the
    previous empty ``__init__`` has been dropped as a no-op.
    """

    # maim_message Router; must be assigned by the startup code before use.
    maibot_router: Router = None

    async def message_send(self, message_base: MessageBase) -> bool:
        """
        Send a message to MaiBot through the router.

        Parameters:
            message_base: MessageBase: envelope with target and content
        Returns:
            bool: True on success, False on any failure. (Fix: the previous
            version implicitly returned None on failure despite the declared
            ``-> bool``; None and False are both falsy, so callers are
            unaffected.)
        """
        try:
            send_status = await self.maibot_router.send_message(message_base)
            if not send_status:
                # Surface a descriptive error through the except path below.
                raise RuntimeError("可能是路由未正确配置或连接异常")
            return send_status
        except Exception as e:
            logger.error(f"发送消息失败: {str(e)}")
            logger.error("请检查与MaiBot之间的连接")
            return False


message_send_instance = MessageSending()
class MetaEventHandler:
    """
    Handles OneBot meta events (lifecycle + heartbeat) from Napcat and runs a
    heartbeat watchdog.

    Fixes over the previous version:
    - ``check_heartbeat`` was called without arguments from the heartbeat
      branch although its ``id`` parameter was required (TypeError); the
      parameter now defaults to None.
    - ``message.get("interval") / 1000`` crashed when the heartbeat payload
      carried no interval; it is now applied only when present.
    - ``last_heart_beat`` could be read by the watchdog before its first
      assignment; it is now initialized in ``__init__``.
    - ``_interval_checking`` was never reset, so a dead watchdog could never
      be restarted, while the lifecycle path could spawn duplicates; task
      creation is now guarded and the flag cleared when the loop exits.
    """

    def __init__(self):
        # Expected heartbeat interval in seconds; refined from the actual
        # heartbeat payload once events arrive.
        self.interval = global_config.napcat_server.heartbeat_interval
        # True while a heartbeat watchdog task is running.
        self._interval_checking = False
        # Timestamp of the most recent connect/heartbeat event.
        self.last_heart_beat = time.time()

    async def handle_meta_event(self, message: dict) -> None:
        """Dispatch a meta event: lifecycle.connect or heartbeat."""
        event_type = message.get("meta_event_type")
        if event_type == MetaEventType.lifecycle:
            if message.get("sub_type") == MetaEventType.Lifecycle.connect:
                self_id = message.get("self_id")
                self.last_heart_beat = time.time()
                logger.info(f"Bot {self_id} 连接成功")
                if not self._interval_checking:
                    asyncio.create_task(self.check_heartbeat(self_id))
        elif event_type == MetaEventType.heartbeat:
            status = message.get("status") or {}
            if status.get("online") and status.get("good"):
                # Record the beat BEFORE possibly starting the watchdog so the
                # watchdog never reads a stale timestamp.
                self.last_heart_beat = time.time()
                interval_ms = message.get("interval")
                if interval_ms:
                    # Napcat reports the interval in milliseconds.
                    self.interval = interval_ms / 1000
                if not self._interval_checking:
                    asyncio.create_task(self.check_heartbeat(message.get("self_id")))
            else:
                logger.warning(f"Bot {message.get('self_id')} Napcat 端异常!")

    async def check_heartbeat(self, id: int | None = None) -> None:
        """Watchdog loop: flag the bot as dead when two intervals pass silently.

        Parameters:
            id: the bot's QQ id for log messages; may be None when unknown.
        """
        self._interval_checking = True
        try:
            while True:
                if time.time() - self.last_heart_beat > self.interval * 2:
                    logger.error(f"Bot {id} 可能发生了连接断开,被下线,或者Napcat卡死!")
                    break
                logger.debug("心跳正常")
                await asyncio.sleep(self.interval)
        finally:
            # Allow a fresh watchdog to be started after this one exits.
            self._interval_checking = False


meta_event_handler = MetaEventHandler()
import BanUser, db_manager, is_identical +from . import NoticeType, ACCEPT_FORMAT +from .message_sending import message_send_instance +from .message_handler import message_handler +from maim_message import FormatInfo, UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase +from ..websocket_manager import websocket_manager + +from ..utils import ( + get_group_info, + get_member_info, + get_self_info, + get_stranger_info, + read_ban_list, +) + +from ...CONSTS import PLUGIN_NAME + +notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=100) +unsuccessful_notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=3) + + +class NoticeHandler: + banned_list: list[BanUser] = [] # 当前仍在禁言中的用户列表 + lifted_list: list[BanUser] = [] # 已经自然解除禁言 + + def __init__(self): + self.server_connection: Server.ServerConnection | None = None + self.last_poke_time: float = 0.0 # 记录最后一次针对机器人的戳一戳时间 + + async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + """设置Napcat连接""" + self.server_connection = server_connection + + while self.server_connection.state != Server.State.OPEN: + await asyncio.sleep(0.5) + self.banned_list, self.lifted_list = await read_ban_list(self.server_connection) + + asyncio.create_task(self.auto_lift_detect()) + asyncio.create_task(self.send_notice()) + asyncio.create_task(self.handle_natural_lift()) + + def get_server_connection(self) -> Server.ServerConnection: + """获取当前的服务器连接""" + # 优先使用直接设置的连接,否则从 websocket_manager 获取 + if self.server_connection: + return self.server_connection + return websocket_manager.get_connection() + + def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None: + """ + 将用户禁言记录添加到self.banned_list中 + 如果是全体禁言,则user_id为0 + """ + if user_id is None: + user_id = 0 # 使用0表示全体禁言 + lift_time = -1 + ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time) + for record in self.banned_list: + if is_identical(record, ban_record): + 
self.banned_list.remove(record) + self.banned_list.append(ban_record) + db_manager.create_ban_record(ban_record) # 作为更新 + return + self.banned_list.append(ban_record) + db_manager.create_ban_record(ban_record) # 添加到数据库 + + def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None: + """ + 从self.lifted_group_list中移除已经解除全体禁言的群 + """ + if user_id is None: + user_id = 0 # 使用0表示全体禁言 + ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1) + self.lifted_list.append(ban_record) + db_manager.delete_ban_record(ban_record) # 删除数据库中的记录 + + async def handle_notice(self, raw_message: dict) -> None: + notice_type = raw_message.get("notice_type") + # message_time: int = raw_message.get("time") + message_time: float = time.time() # 应可乐要求,现在是float了 + + group_id = raw_message.get("group_id") + user_id = raw_message.get("user_id") + target_id = raw_message.get("target_id") + + handled_message: Seg = None + user_info: UserInfo = None + system_notice: bool = False + + match notice_type: + case NoticeType.friend_recall: + logger.info("好友撤回一条消息") + logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}") + logger.warning("暂时不支持撤回消息处理") + case NoticeType.group_recall: + logger.info("群内用户撤回一条消息") + logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}") + logger.warning("暂时不支持撤回消息处理") + case NoticeType.notify: + sub_type = raw_message.get("sub_type") + match sub_type: + case NoticeType.Notify.poke: + if features_manager.is_poke_enabled() and await message_handler.check_allow_to_chat( + user_id, group_id, False, False + ): + logger.info("处理戳一戳消息") + handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id) + else: + logger.warning("戳一戳消息被禁用,取消戳一戳处理") + case NoticeType.Notify.input_status: + from src.plugin_system.core.event_manager import event_manager + from ...event_types import NapcatEvent + await 
event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT,plugin_name=PLUGIN_NAME) + case _: + logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") + case NoticeType.group_ban: + sub_type = raw_message.get("sub_type") + match sub_type: + case NoticeType.GroupBan.ban: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + logger.info("处理群禁言") + handled_message, user_info = await self.handle_ban_notify(raw_message, group_id) + system_notice = True + case NoticeType.GroupBan.lift_ban: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + logger.info("处理解除群禁言") + handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id) + system_notice = True + case _: + logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}") + case _: + logger.warning(f"不支持的notice类型: {notice_type}") + return None + if not handled_message or not user_info: + logger.warning("notice处理失败或不支持") + return None + + group_info: GroupInfo = None + if group_id: + fetched_group_info = await get_group_info(self.get_server_connection(), group_id) + group_name: str = None + if fetched_group_info: + group_name = fetched_group_info.get("group_name") + else: + logger.warning("无法获取notice消息所在群的名称") + group_info = GroupInfo( + platform=global_config.maibot_server.platform_name, + group_id=group_id, + group_name=group_name, + ) + + message_info: BaseMessageInfo = BaseMessageInfo( + platform=global_config.maibot_server.platform_name, + message_id="notice", + time=message_time, + user_info=user_info, + group_info=group_info, + template_info=None, + format_info=FormatInfo( + content_format=["text", "notify"], + accept_format=ACCEPT_FORMAT, + ), + additional_config={"target_id": target_id}, # 在这里塞了一个target_id,方便mmc那边知道被戳的人是谁 + ) + + message_base: MessageBase = MessageBase( + message_info=message_info, + message_segment=handled_message, + raw_message=json.dumps(raw_message), + ) + + if 
system_notice: + await self.put_notice(message_base) + else: + logger.info("发送到Maibot处理通知信息") + await message_send_instance.message_send(message_base) + + async def handle_poke_notify( + self, raw_message: dict, group_id: int, user_id: int + ) -> Tuple[Seg | None, UserInfo | None]: + # sourcery skip: merge-comparisons, merge-duplicate-blocks, remove-redundant-if, remove-unnecessary-else, swap-if-else-branches + self_info: dict = await get_self_info(self.get_server_connection()) + + if not self_info: + logger.error("自身信息获取失败") + return None, None + + self_id = raw_message.get("self_id") + target_id = raw_message.get("target_id") + + # 防抖检查:如果是针对机器人的戳一戳,检查防抖时间 + if self_id == target_id: + current_time = time.time() + debounce_seconds = features_manager.get_config().poke_debounce_seconds + + if self.last_poke_time > 0: + time_diff = current_time - self.last_poke_time + if time_diff < debounce_seconds: + logger.info(f"戳一戳防抖:用户 {user_id} 的戳一戳被忽略(距离上次戳一戳 {time_diff:.2f} 秒)") + return None, None + + # 记录这次戳一戳的时间 + self.last_poke_time = current_time + + target_name: str = None + raw_info: list = raw_message.get("raw_info") + + if group_id: + user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id) + else: + user_qq_info: dict = await get_stranger_info(self.get_server_connection(), user_id) + if user_qq_info: + user_name = user_qq_info.get("nickname") + user_cardname = user_qq_info.get("card") + else: + user_name = "QQ用户" + user_cardname = "QQ用户" + logger.info("无法获取戳一戳对方的用户昵称") + + # 计算Seg + if self_id == target_id: + display_name = "" + target_name = self_info.get("nickname") + + elif self_id == user_id: + # 让ada不发送麦麦戳别人的消息 + return None, None + + else: + # 如果配置为忽略不是针对自己的戳一戳,则直接返回None + if features_manager.is_non_self_poke_ignored(): + logger.info("忽略不是针对自己的戳一戳消息") + return None, None + + # 老实说这一步判定没啥意义,毕竟私聊是没有其他人之间的戳一戳,但是感觉可以有这个判定来强限制群聊环境 + if group_id: + fetched_member_info: dict = await get_member_info(self.get_server_connection(), 
group_id, target_id) + if fetched_member_info: + target_name = fetched_member_info.get("nickname") + else: + target_name = "QQ用户" + logger.info("无法获取被戳一戳方的用户昵称") + display_name = user_name + else: + return None, None + + first_txt: str = "戳了戳" + second_txt: str = "" + try: + first_txt = raw_info[2].get("txt", "戳了戳") + second_txt = raw_info[4].get("txt", "") + except Exception as e: + logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本") + + user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_name, + user_cardname=user_cardname, + ) + + seg_data: Seg = Seg( + type="text", + data=f"{display_name}{first_txt}{target_name}{second_txt}(这是QQ的一个功能,用于提及某人,但没那么明显)", + ) + return seg_data, user_info + + async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]: + if not group_id: + logger.error("群ID不能为空,无法处理禁言通知") + return None, None + + # 计算user_info + operator_id = raw_message.get("operator_id") + operator_nickname: str = None + operator_cardname: str = None + + member_info: dict = await get_member_info(self.get_server_connection(), group_id, operator_id) + if member_info: + operator_nickname = member_info.get("nickname") + operator_cardname = member_info.get("card") + else: + logger.warning("无法获取禁言执行者的昵称,消息可能会无效") + operator_nickname = "QQ用户" + + operator_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=operator_id, + user_nickname=operator_nickname, + user_cardname=operator_cardname, + ) + + # 计算Seg + user_id = raw_message.get("user_id") + banned_user_info: UserInfo = None + user_nickname: str = "QQ用户" + user_cardname: str = None + sub_type: str = None + + duration = raw_message.get("duration") + if duration is None: + logger.error("禁言时长不能为空,无法处理禁言通知") + return None, None + + if user_id == 0: # 为全体禁言 + sub_type: str = "whole_ban" + self._ban_operation(group_id) + else: # 为单人禁言 + # 获取被禁言人的信息 + sub_type: str = 
"ban" + fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id) + if fetched_member_info: + user_nickname = fetched_member_info.get("nickname") + user_cardname = fetched_member_info.get("card") + banned_user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + ) + self._ban_operation(group_id, user_id, int(time.time() + duration)) + + seg_data: Seg = Seg( + type="notify", + data={ + "sub_type": sub_type, + "duration": duration, + "banned_user_info": banned_user_info.to_dict() if banned_user_info else None, + }, + ) + + return seg_data, operator_info + + async def handle_lift_ban_notify( + self, raw_message: dict, group_id: int + ) -> Tuple[Seg, UserInfo] | Tuple[None, None]: + if not group_id: + logger.error("群ID不能为空,无法处理解除禁言通知") + return None, None + + # 计算user_info + operator_id = raw_message.get("operator_id") + operator_nickname: str = None + operator_cardname: str = None + + member_info: dict = await get_member_info(self.get_server_connection(), group_id, operator_id) + if member_info: + operator_nickname = member_info.get("nickname") + operator_cardname = member_info.get("card") + else: + logger.warning("无法获取解除禁言执行者的昵称,消息可能会无效") + operator_nickname = "QQ用户" + + operator_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=operator_id, + user_nickname=operator_nickname, + user_cardname=operator_cardname, + ) + + # 计算Seg + sub_type: str = None + user_nickname: str = "QQ用户" + user_cardname: str = None + lifted_user_info: UserInfo = None + + user_id = raw_message.get("user_id") + if user_id == 0: # 全体禁言解除 + sub_type = "whole_lift_ban" + self._lift_operation(group_id) + else: # 单人禁言解除 + sub_type = "lift_ban" + # 获取被解除禁言人的信息 + fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id) + if fetched_member_info: + user_nickname = 
fetched_member_info.get("nickname") + user_cardname = fetched_member_info.get("card") + else: + logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效") + lifted_user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + ) + self._lift_operation(group_id, user_id) + + seg_data: Seg = Seg( + type="notify", + data={ + "sub_type": sub_type, + "lifted_user_info": lifted_user_info.to_dict() if lifted_user_info else None, + }, + ) + return seg_data, operator_info + + async def put_notice(self, message_base: MessageBase) -> None: + """ + 将处理后的通知消息放入通知队列 + """ + if notice_queue.full() or unsuccessful_notice_queue.full(): + logger.warning("通知队列已满,可能是多次发送失败,消息丢弃") + else: + await notice_queue.put(message_base) + + async def handle_natural_lift(self) -> None: + while True: + if len(self.lifted_list) != 0: + lift_record = self.lifted_list.pop() + group_id = lift_record.group_id + user_id = lift_record.user_id + + db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录 + + seg_message: Seg = await self.natural_lift(group_id, user_id) + + fetched_group_info = await get_group_info(self.get_server_connection(), group_id) + group_name: str = None + if fetched_group_info: + group_name = fetched_group_info.get("group_name") + else: + logger.warning("无法获取notice消息所在群的名称") + group_info = GroupInfo( + platform=global_config.maibot_server.platform_name, + group_id=group_id, + group_name=group_name, + ) + + message_info: BaseMessageInfo = BaseMessageInfo( + platform=global_config.maibot_server.platform_name, + message_id="notice", + time=time.time(), + user_info=None, # 自然解除禁言没有操作者 + group_info=group_info, + template_info=None, + format_info=None, + ) + + message_base: MessageBase = MessageBase( + message_info=message_info, + message_segment=seg_message, + raw_message=json.dumps( + { + "post_type": "notice", + "notice_type": "group_ban", + "sub_type": "lift_ban", + "group_id": group_id, + 
"user_id": user_id, + "operator_id": None, # 自然解除禁言没有操作者 + } + ), + ) + + await self.put_notice(message_base) + await asyncio.sleep(0.5) # 确保队列处理间隔 + else: + await asyncio.sleep(5) # 每5秒检查一次 + + async def natural_lift(self, group_id: int, user_id: int) -> Seg | None: + if not group_id: + logger.error("群ID不能为空,无法处理解除禁言通知") + return None + + if user_id == 0: # 理论上永远不会触发 + return Seg( + type="notify", + data={ + "sub_type": "whole_lift_ban", + "lifted_user_info": None, + }, + ) + + user_nickname: str = "QQ用户" + user_cardname: str = None + fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id) + if fetched_member_info: + user_nickname = fetched_member_info.get("nickname") + user_cardname = fetched_member_info.get("card") + + lifted_user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + ) + + return Seg( + type="notify", + data={ + "sub_type": "lift_ban", + "lifted_user_info": lifted_user_info.to_dict(), + }, + ) + + async def auto_lift_detect(self) -> None: + while True: + if len(self.banned_list) == 0: + await asyncio.sleep(5) + continue + for ban_record in self.banned_list: + if ban_record.user_id == 0 or ban_record.lift_time == -1: + continue + if ban_record.lift_time <= int(time.time()): + # 触发自然解除禁言 + logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除") + self.lifted_list.append(ban_record) + self.banned_list.remove(ban_record) + await asyncio.sleep(5) + + async def send_notice(self) -> None: + """ + 发送通知消息到Napcat + """ + while True: + if not unsuccessful_notice_queue.empty(): + to_be_send: MessageBase = await unsuccessful_notice_queue.get() + try: + send_status = await message_send_instance.message_send(to_be_send) + if send_status: + unsuccessful_notice_queue.task_done() + else: + await unsuccessful_notice_queue.put(to_be_send) + except Exception as e: + logger.error(f"发送通知消息失败: 
{str(e)}") + await unsuccessful_notice_queue.put(to_be_send) + await asyncio.sleep(1) + continue + to_be_send: MessageBase = await notice_queue.get() + try: + send_status = await message_send_instance.message_send(to_be_send) + if send_status: + notice_queue.task_done() + else: + await unsuccessful_notice_queue.put(to_be_send) + except Exception as e: + logger.error(f"发送通知消息失败: {str(e)}") + await unsuccessful_notice_queue.put(to_be_send) + await asyncio.sleep(1) + + + + +notice_handler = NoticeHandler() diff --git a/plugins/napcat_adapter_plugin/src/recv_handler/qq_emoji_list.py b/plugins/napcat_adapter_plugin/src/recv_handler/qq_emoji_list.py new file mode 100644 index 000000000..51c32321a --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/recv_handler/qq_emoji_list.py @@ -0,0 +1,250 @@ +qq_face: dict = { + "0": "[表情:惊讶]", + "1": "[表情:撇嘴]", + "2": "[表情:色]", + "3": "[表情:发呆]", + "4": "[表情:得意]", + "5": "[表情:流泪]", + "6": "[表情:害羞]", + "7": "[表情:闭嘴]", + "8": "[表情:睡]", + "9": "[表情:大哭]", + "10": "[表情:尴尬]", + "11": "[表情:发怒]", + "12": "[表情:调皮]", + "13": "[表情:呲牙]", + "14": "[表情:微笑]", + "15": "[表情:难过]", + "16": "[表情:酷]", + "18": "[表情:抓狂]", + "19": "[表情:吐]", + "20": "[表情:偷笑]", + "21": "[表情:可爱]", + "22": "[表情:白眼]", + "23": "[表情:傲慢]", + "24": "[表情:饥饿]", + "25": "[表情:困]", + "26": "[表情:惊恐]", + "27": "[表情:流汗]", + "28": "[表情:憨笑]", + "29": "[表情:悠闲]", + "30": "[表情:奋斗]", + "31": "[表情:咒骂]", + "32": "[表情:疑问]", + "33": "[表情: 嘘]", + "34": "[表情:晕]", + "35": "[表情:折磨]", + "36": "[表情:衰]", + "37": "[表情:骷髅]", + "38": "[表情:敲打]", + "39": "[表情:再见]", + "41": "[表情:发抖]", + "42": "[表情:爱情]", + "43": "[表情:跳跳]", + "46": "[表情:猪头]", + "49": "[表情:拥抱]", + "53": "[表情:蛋糕]", + "56": "[表情:刀]", + "59": "[表情:便便]", + "60": "[表情:咖啡]", + "63": "[表情:玫瑰]", + "64": "[表情:凋谢]", + "66": "[表情:爱心]", + "67": "[表情:心碎]", + "74": "[表情:太阳]", + "75": "[表情:月亮]", + "76": "[表情:赞]", + "77": "[表情:踩]", + "78": "[表情:握手]", + "79": "[表情:胜利]", + "85": "[表情:飞吻]", + "86": "[表情:怄火]", + "89": "[表情:西瓜]", + "96": "[表情:冷汗]", + "97": "[表情:擦汗]", + 
"98": "[表情:抠鼻]", + "99": "[表情:鼓掌]", + "100": "[表情:糗大了]", + "101": "[表情:坏笑]", + "102": "[表情:左哼哼]", + "103": "[表情:右哼哼]", + "104": "[表情:哈欠]", + "105": "[表情:鄙视]", + "106": "[表情:委屈]", + "107": "[表情:快哭了]", + "108": "[表情:阴险]", + "109": "[表情:左亲亲]", + "110": "[表情:吓]", + "111": "[表情:可怜]", + "112": "[表情:菜刀]", + "114": "[表情:篮球]", + "116": "[表情:示爱]", + "118": "[表情:抱拳]", + "119": "[表情:勾引]", + "120": "[表情:拳头]", + "121": "[表情:差劲]", + "123": "[表情:NO]", + "124": "[表情:OK]", + "125": "[表情:转圈]", + "129": "[表情:挥手]", + "137": "[表情:鞭炮]", + "144": "[表情:喝彩]", + "146": "[表情:爆筋]", + "147": "[表情:棒棒糖]", + "169": "[表情:手枪]", + "171": "[表情:茶]", + "172": "[表情:眨眼睛]", + "173": "[表情:泪奔]", + "174": "[表情:无奈]", + "175": "[表情:卖萌]", + "176": "[表情:小纠结]", + "177": "[表情:喷血]", + "178": "[表情:斜眼笑]", + "179": "[表情:doge]", + "181": "[表情:戳一戳]", + "182": "[表情:笑哭]", + "183": "[表情:我最美]", + "185": "[表情:羊驼]", + "187": "[表情:幽灵]", + "201": "[表情:点赞]", + "212": "[表情:托腮]", + "262": "[表情:脑阔疼]", + "263": "[表情:沧桑]", + "264": "[表情:捂脸]", + "265": "[表情:辣眼睛]", + "266": "[表情:哦哟]", + "267": "[表情:头秃]", + "268": "[表情:问号脸]", + "269": "[表情:暗中观察]", + "270": "[表情:emm]", + "271": "[表情:吃 瓜]", + "272": "[表情:呵呵哒]", + "273": "[表情:我酸了]", + "277": "[表情:汪汪]", + "281": "[表情:无眼笑]", + "282": "[表情:敬礼]", + "283": "[表情:狂笑]", + "284": "[表情:面无表情]", + "285": "[表情:摸鱼]", + "286": "[表情:魔鬼笑]", + "287": "[表情:哦]", + "289": "[表情:睁眼]", + "293": "[表情:摸锦鲤]", + "294": "[表情:期待]", + "295": "[表情:拿到红包]", + "297": "[表情:拜谢]", + "298": "[表情:元宝]", + "299": "[表情:牛啊]", + "300": "[表情:胖三斤]", + "302": "[表情:左拜年]", + "303": "[表情:右拜年]", + "305": "[表情:右亲亲]", + "306": "[表情:牛气冲天]", + "307": "[表情:喵喵]", + "311": "[表情:打call]", + "312": "[表情:变形]", + "314": "[表情:仔细分析]", + "317": "[表情:菜汪]", + "318": "[表情:崇拜]", + "319": "[表情: 比心]", + "320": "[表情:庆祝]", + "323": "[表情:嫌弃]", + "324": "[表情:吃糖]", + "325": "[表情:惊吓]", + "326": "[表情:生气]", + "332": "[表情:举牌牌]", + "333": "[表情:烟花]", + "334": "[表情:虎虎生威]", + "336": "[表情:豹富]", + "337": "[表情:花朵脸]", + "338": "[表情:我想开了]", + "339": "[表情:舔屏]", + "341": 
"[表情:打招呼]", + "342": "[表情:酸Q]", + "343": "[表情:我方了]", + "344": "[表情:大怨种]", + "345": "[表情:红包多多]", + "346": "[表情:你真棒棒]", + "347": "[表情:大展宏兔]", + "349": "[表情:坚强]", + "350": "[表情:贴贴]", + "351": "[表情:敲敲]", + "352": "[表情:咦]", + "353": "[表情:拜托]", + "354": "[表情:尊嘟假嘟]", + "355": "[表情:耶]", + "356": "[表情:666]", + "357": "[表情:裂开]", + "392": "[表情:龙年 快乐]", + "393": "[表情:新年中龙]", + "394": "[表情:新年大龙]", + "395": "[表情:略略略]", + "😊": "[表情:嘿嘿]", + "😌": "[表情:羞涩]", + "😚": "[ 表情:亲亲]", + "😓": "[表情:汗]", + "😰": "[表情:紧张]", + "😝": "[表情:吐舌]", + "😁": "[表情:呲牙]", + "😜": "[表情:淘气]", + "☺": "[表情:可爱]", + "😍": "[表情:花痴]", + "😔": "[表情:失落]", + "😄": "[表情:高兴]", + "😏": "[表情:哼哼]", + "😒": "[表情:不屑]", + "😳": "[表情:瞪眼]", + "😘": "[表情:飞吻]", + "😭": "[表情:大哭]", + "😱": "[表情:害怕]", + "😂": "[表情:激动]", + "💪": "[表情:肌肉]", + "👊": "[表情:拳头]", + "👍": "[表情 :厉害]", + "👏": "[表情:鼓掌]", + "👎": "[表情:鄙视]", + "🙏": "[表情:合十]", + "👌": "[表情:好的]", + "👆": "[表情:向上]", + "👀": "[表情:眼睛]", + "🍜": "[表情:拉面]", + "🍧": "[表情:刨冰]", + "🍞": "[表情:面包]", + "🍺": "[表情:啤酒]", + "🍻": "[表情:干杯]", + "☕": "[表情:咖啡]", + "🍎": "[表情:苹果]", + "🍓": "[表情:草莓]", + "🍉": "[表情:西瓜]", + "🚬": "[表情:吸烟]", + "🌹": "[表情:玫瑰]", + "🎉": "[表情:庆祝]", + "💝": "[表情:礼物]", + "💣": "[表情:炸弹]", + "✨": "[表情:闪光]", + "💨": "[表情:吹气]", + "💦": "[表情:水]", + "🔥": "[表情:火]", + "💤": "[表情:睡觉]", + "💩": "[表情:便便]", + "💉": "[表情:打针]", + "📫": "[表情:邮箱]", + "🐎": "[表情:骑马]", + "👧": "[表情:女孩]", + "👦": "[表情:男孩]", + "🐵": "[表情:猴]", + "🐷": "[表情:猪]", + "🐮": "[表情:牛]", + "🐔": "[表情:公鸡]", + "🐸": "[表情:青蛙]", + "👻": "[表情:幽灵]", + "🐛": "[表情:虫]", + "🐶": "[表情:狗]", + "🐳": "[表情:鲸鱼]", + "👢": "[表情:靴子]", + "☀": "[表情:晴天]", + "❔": "[表情:问号]", + "🔫": "[表情:手枪]", + "💓": "[表情:爱 心]", + "🏪": "[表情:便利店]", +} diff --git a/plugins/napcat_adapter_plugin/src/response_pool.py b/plugins/napcat_adapter_plugin/src/response_pool.py new file mode 100644 index 000000000..ede85d04d --- /dev/null +++ b/plugins/napcat_adapter_plugin/src/response_pool.py @@ -0,0 +1,45 @@ +import asyncio +import time +from typing import Dict +from .config import global_config +from src.common.logger 
# In-memory pool pairing outgoing Napcat requests (by echo id) with their
# asynchronous responses.  ``check_timeout_response`` runs as a background
# task and evicts responses nobody collected.
response_dict: Dict = {}
response_time_dict: Dict = {}


async def get_response(request_id: str, timeout: int = 10) -> dict:
    """Wait up to *timeout* seconds for the response matching *request_id*.

    Raises:
        asyncio.TimeoutError: when no response arrives in time.
    """
    response = await asyncio.wait_for(_get_response(request_id), timeout)
    # BUGFIX: the timestamp entry may already have been evicted by the
    # concurrent check_timeout_response task, so pop defensively instead of
    # risking a KeyError here.
    response_time_dict.pop(request_id, None)
    logger.info(f"响应信息id: {request_id} 已从响应字典中取出")
    return response


async def _get_response(request_id: str) -> dict:
    """Poll until the response for *request_id* arrives, then pop and return it."""
    while request_id not in response_dict:
        await asyncio.sleep(0.2)
    return response_dict.pop(request_id)


async def put_response(response: dict):
    """Store a raw Napcat response keyed by its echo id, with receive time."""
    echo_id = response.get("echo")
    now_time = time.time()
    response_dict[echo_id] = response
    response_time_dict[echo_id] = now_time
    logger.info(f"响应信息id: {echo_id} 已存入响应字典")


async def check_timeout_response() -> None:
    """Background task: periodically drop responses nobody collected."""
    while True:
        cleaned_message_count: int = 0
        now_time = time.time()
        for echo_id, response_time in list(response_time_dict.items()):
            if now_time - response_time > global_config.napcat_server.heartbeat_interval:
                cleaned_message_count += 1
                # BUGFIX: the response may have been consumed between the
                # timestamp check and this eviction; pop with a default so a
                # race with get_response cannot crash the cleanup task.
                response_dict.pop(echo_id, None)
                response_time_dict.pop(echo_id, None)
                logger.warning(f"响应消息 {echo_id} 超时,已删除")
        if cleaned_message_count:
            logger.info(f"已删除 {cleaned_message_count} 条超时响应消息")
        await asyncio.sleep(global_config.napcat_server.heartbeat_interval)
import CommandType +from .config import global_config +from .response_pool import get_response +from src.common.logger import get_logger +logger = get_logger("napcat_adapter") +from .utils import get_image_format, convert_image_to_gif +from .recv_handler.message_sending import message_send_instance +from .websocket_manager import websocket_manager +from .config.features_config import features_manager + + +class SendHandler: + def __init__(self): + self.server_connection: Optional[Server.ServerConnection] = None + + async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + """设置Napcat连接""" + self.server_connection = server_connection + + def get_server_connection(self) -> Optional[Server.ServerConnection]: + """获取当前的服务器连接""" + # 优先使用直接设置的连接,否则从 websocket_manager 获取 + if self.server_connection: + return self.server_connection + return websocket_manager.get_connection() + + async def handle_message(self, raw_message_base_dict: dict) -> None: + raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_dict) + message_segment: Seg = raw_message_base.message_segment + logger.info("接收到来自MaiBot的消息,处理中") + if message_segment.type == "command": + logger.info("处理命令") + return await self.send_command(raw_message_base) + elif message_segment.type == "adapter_command": + logger.info("处理适配器命令") + return await self.handle_adapter_command(raw_message_base) + else: + logger.info("处理普通消息") + return await self.send_normal_message(raw_message_base) + + async def send_normal_message(self, raw_message_base: MessageBase) -> None: + """ + 处理普通消息发送 + """ + logger.info("处理普通信息中") + message_info: BaseMessageInfo = raw_message_base.message_info + message_segment: Seg = raw_message_base.message_segment + group_info: Optional[GroupInfo] = message_info.group_info + user_info: Optional[UserInfo] = message_info.user_info + target_id: Optional[int] = None + action: Optional[str] = None + id_name: Optional[str] = None + processed_message: list = [] + 
try: + if user_info: + processed_message = await self.handle_seg_recursive( + message_segment, user_info + ) + except Exception as e: + logger.error(f"处理消息时发生错误: {e}") + return + + if not processed_message: + logger.critical("现在暂时不支持解析此回复!") + return None + + if group_info and user_info: + logger.debug("发送群聊消息") + target_id = int(group_info.group_id) if group_info.group_id else None + action = "send_group_msg" + id_name = "group_id" + elif user_info: + logger.debug("发送私聊消息") + target_id = int(user_info.user_id) if user_info.user_id else None + action = "send_private_msg" + id_name = "user_id" + else: + logger.error("无法识别的消息类型") + return + logger.info("尝试发送到napcat") + response = await self.send_message_to_napcat( + action, + { + id_name: target_id, + "message": processed_message, + }, + ) + if response.get("status") == "ok": + logger.info("消息发送成功") + qq_message_id = response.get("data", {}).get("message_id") + await self.message_sent_back(raw_message_base, qq_message_id) + else: + logger.warning(f"消息发送失败,napcat返回:{str(response)}") + + async def send_command(self, raw_message_base: MessageBase) -> None: + """ + 处理命令类 + """ + logger.info("处理命令中") + message_info: BaseMessageInfo = raw_message_base.message_info + message_segment: Seg = raw_message_base.message_segment + group_info: Optional[GroupInfo] = message_info.group_info + seg_data: Dict[str, Any] = ( + message_segment.data + if isinstance(message_segment.data, dict) + else {} + ) + command_name: Optional[str] = seg_data.get("name") + try: + args = seg_data.get("args", {}) + if not isinstance(args, dict): + args = {} + + match command_name: + case CommandType.GROUP_BAN.name: + command, args_dict = self.handle_ban_command(args, group_info) + case CommandType.GROUP_WHOLE_BAN.name: + command, args_dict = self.handle_whole_ban_command( + args, group_info + ) + case CommandType.GROUP_KICK.name: + command, args_dict = self.handle_kick_command(args, group_info) + case CommandType.SEND_POKE.name: + command, args_dict = 
self.handle_poke_command(args, group_info) + case CommandType.DELETE_MSG.name: + command, args_dict = self.delete_msg_command(args) + case CommandType.AI_VOICE_SEND.name: + command, args_dict = self.handle_ai_voice_send_command( + args, group_info + ) + case CommandType.SET_EMOJI_LIKE.name: + command, args_dict = self.handle_set_emoji_like_command(args) + case CommandType.SEND_AT_MESSAGE.name: + command, args_dict = self.handle_at_message_command( + args, group_info + ) + case CommandType.SEND_LIKE.name: + command, args_dict = self.handle_send_like_command(args) + case _: + logger.error(f"未知命令: {command_name}") + return + except Exception as e: + logger.error(f"处理命令时发生错误: {e}") + return None + + if not command or not args_dict: + logger.error("命令或参数缺失") + return None + + response = await self.send_message_to_napcat(command, args_dict) + if response.get("status") == "ok": + logger.info(f"命令 {command_name} 执行成功") + else: + logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}") + + async def handle_adapter_command(self, raw_message_base: MessageBase) -> None: + """ + 处理适配器命令类 - 用于直接向Napcat发送命令并返回结果 + """ + logger.info("处理适配器命令中") + message_info: BaseMessageInfo = raw_message_base.message_info + message_segment: Seg = raw_message_base.message_segment + seg_data: Dict[str, Any] = ( + message_segment.data + if isinstance(message_segment.data, dict) + else {} + ) + + try: + action = seg_data.get("action") + params = seg_data.get("params", {}) + request_id = seg_data.get("request_id") + + if not action: + logger.error("适配器命令缺少action参数") + await self.send_adapter_command_response( + raw_message_base, + {"status": "error", "message": "缺少action参数"}, + request_id + ) + return + + logger.info(f"执行适配器命令: {action}") + + # 直接向Napcat发送命令并获取响应 + response_task = asyncio.create_task(self.send_message_to_napcat(action, params)) + response = await response_task + + # 发送响应回MaiBot + await self.send_adapter_command_response(raw_message_base, response, request_id) + + if 
response.get("status") == "ok": + logger.info(f"适配器命令 {action} 执行成功") + else: + logger.warning(f"适配器命令 {action} 执行失败,napcat返回:{str(response)}") + + except Exception as e: + logger.error(f"处理适配器命令时发生错误: {e}") + error_response = {"status": "error", "message": str(e)} + await self.send_adapter_command_response( + raw_message_base, + error_response, + seg_data.get("request_id") + ) + + def get_level(self, seg_data: Seg) -> int: + if seg_data.type == "seglist": + return 1 + max(self.get_level(seg) for seg in seg_data.data) + else: + return 1 + + async def handle_seg_recursive(self, seg_data: Seg, user_info: UserInfo) -> list: + payload: list = [] + if seg_data.type == "seglist": + # level = self.get_level(seg_data) # 给以后可能的多层嵌套做准备,此处不使用 + if not seg_data.data: + return [] + for seg in seg_data.data: + payload = await self.process_message_by_type(seg, payload, user_info) + else: + payload = await self.process_message_by_type(seg_data, payload, user_info) + return payload + + async def process_message_by_type( + self, seg: Seg, payload: list, user_info: UserInfo + ) -> list: + # sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression + new_payload = payload + if seg.type == "reply": + target_id = seg.data + if target_id == "notice": + return payload + new_payload = self.build_payload( + payload, + await self.handle_reply_message( + target_id if isinstance(target_id, str) else "", user_info + ), + True, + ) + elif seg.type == "text": + text = seg.data + if not text: + return payload + new_payload = self.build_payload( + payload, + self.handle_text_message(text if isinstance(text, str) else ""), + False, + ) + elif seg.type == "face": + logger.warning("MaiBot 发送了qq原生表情,暂时不支持") + elif seg.type == "image": + image = seg.data + new_payload = self.build_payload(payload, self.handle_image_message(image), False) + elif seg.type == "emoji": + emoji = seg.data + new_payload = self.build_payload(payload, self.handle_emoji_message(emoji), False) + elif seg.type 
== "voice": + voice = seg.data + new_payload = self.build_payload(payload, self.handle_voice_message(voice), False) + elif seg.type == "voiceurl": + voice_url = seg.data + new_payload = self.build_payload(payload, self.handle_voiceurl_message(voice_url), False) + elif seg.type == "music": + song_id = seg.data + new_payload = self.build_payload(payload, self.handle_music_message(song_id), False) + elif seg.type == "videourl": + video_url = seg.data + new_payload = self.build_payload(payload, self.handle_videourl_message(video_url), False) + elif seg.type == "file": + file_path = seg.data + new_payload = self.build_payload(payload, self.handle_file_message(file_path), False) + return new_payload + + def build_payload( + self, payload: list, addon: dict | list, is_reply: bool = False + ) -> list: + # sourcery skip: for-append-to-extend, merge-list-append, simplify-generator + """构建发送的消息体""" + if is_reply: + temp_list = [] + if isinstance(addon, list): + temp_list.extend(addon) + else: + temp_list.append(addon) + for i in payload: + if i.get("type") == "reply": + logger.debug("检测到多个回复,使用最新的回复") + continue + temp_list.append(i) + return temp_list + else: + if isinstance(addon, list): + payload.extend(addon) + else: + payload.append(addon) + return payload + + async def handle_reply_message(self, id: str, user_info: UserInfo) -> dict | list: + """处理回复消息""" + reply_seg = {"type": "reply", "data": {"id": id}} + + # 获取功能配置 + ft_config = features_manager.get_config() + + # 检查是否启用引用艾特功能 + if not ft_config.enable_reply_at: + return reply_seg + + try: + # 尝试通过 message_id 获取消息详情 + msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)}) + + replied_user_id = None + if msg_info_response and msg_info_response.get("status") == "ok": + sender_info = msg_info_response.get("data", {}).get("sender") + if sender_info: + replied_user_id = sender_info.get("user_id") + + # 如果没有获取到被回复者的ID,则直接返回,不进行@ + if not replied_user_id: + logger.warning(f"无法获取消息 {id} 
的发送者信息,跳过 @") + return reply_seg + + # 根据概率决定是否艾特用户 + if random.random() < ft_config.reply_at_rate: + at_seg = {"type": "at", "data": {"qq": str(replied_user_id)}} + # 在艾特后面添加一个空格 + text_seg = {"type": "text", "data": {"text": " "}} + return [reply_seg, at_seg, text_seg] + + except Exception as e: + logger.error(f"处理引用回复并尝试@时出错: {e}") + # 出现异常时,只发送普通的回复,避免程序崩溃 + return reply_seg + + return reply_seg + + def handle_text_message(self, message: str) -> dict: + """处理文本消息""" + return {"type": "text", "data": {"text": message}} + + def handle_image_message(self, encoded_image: str) -> dict: + """处理图片消息""" + return { + "type": "image", + "data": { + "file": f"base64://{encoded_image}", + "subtype": 0, + }, + } # base64 编码的图片 + + def handle_emoji_message(self, encoded_emoji: str) -> dict: + """处理表情消息""" + encoded_image = encoded_emoji + image_format = get_image_format(encoded_emoji) + if image_format != "gif": + encoded_image = convert_image_to_gif(encoded_emoji) + return { + "type": "image", + "data": { + "file": f"base64://{encoded_image}", + "subtype": 1, + "summary": "[动画表情]", + }, + } + + def handle_voice_message(self, encoded_voice: str) -> dict: + """处理语音消息""" + if not global_config.voice.use_tts: + logger.warning("未启用语音消息处理") + return {} + if not encoded_voice: + return {} + return { + "type": "record", + "data": {"file": f"base64://{encoded_voice}"}, + } + + def handle_voiceurl_message(self, voice_url: str) -> dict: + """处理语音链接消息""" + return { + "type": "record", + "data": {"file": voice_url}, + } + + def handle_music_message(self, song_id: str) -> dict: + """处理音乐消息""" + return { + "type": "music", + "data": {"type": "163", "id": song_id}, + } + def handle_videourl_message(self, video_url: str) -> dict: + """处理视频链接消息""" + return { + "type": "video", + "data": {"file": video_url}, + } + + def handle_file_message(self, file_path: str) -> dict: + """处理文件消息""" + return { + "type": "file", + "data": {"file": f"file://{file_path}"}, + } + + def delete_msg_command(self, 
def handle_ban_command(self, args: Dict[str, Any], group_info: "GroupInfo") -> Tuple[str, Dict[str, Any]]:
    """Build the napcat group-ban action.

    Args:
        args: must contain ``duration`` (seconds; 0 lifts the ban) and ``qq_id``.
        group_info: target group.

    Returns:
        (action, params) tuple for napcat.

    Raises:
        ValueError: missing/invalid ids or out-of-range duration.
    """
    duration: int = int(args["duration"])
    user_id: int = int(args["qq_id"])
    group_id: int = int(group_info.group_id)
    if duration < 0:
        raise ValueError("封禁时间必须大于等于0")
    if not user_id or not group_id:
        raise ValueError("封禁命令缺少必要参数")
    if duration > 2592000:  # QQ's upper bound: 30 days
        raise ValueError("封禁时间不能超过30天")
    return (
        CommandType.GROUP_BAN.value,
        {
            "group_id": group_id,
            "user_id": user_id,
            "duration": duration,
        },
    )

def handle_whole_ban_command(self, args: Dict[str, Any], group_info: "GroupInfo") -> Tuple[str, Dict[str, Any]]:
    """Build the napcat whole-group mute toggle.

    Raises:
        AssertionError: when ``enable`` is not a bool.
            NOTE(review): assert is stripped under ``python -O``; kept as-is
            because callers may rely on the exception type.
        ValueError: when the group id is invalid.
    """
    enable = args["enable"]
    assert isinstance(enable, bool), "enable参数必须是布尔值"
    group_id: int = int(group_info.group_id)
    if group_id <= 0:
        raise ValueError("群组ID无效")
    return (
        CommandType.GROUP_WHOLE_BAN.value,
        {
            "group_id": group_id,
            "enable": enable,
        },
    )

def handle_kick_command(self, args: Dict[str, Any], group_info: "GroupInfo") -> Tuple[str, Dict[str, Any]]:
    """Build the napcat kick-member action.

    Raises:
        ValueError: invalid group or user id.
    """
    user_id: int = int(args["qq_id"])
    group_id: int = int(group_info.group_id)
    if group_id <= 0:
        raise ValueError("群组ID无效")
    if user_id <= 0:
        raise ValueError("用户ID无效")
    return (
        CommandType.GROUP_KICK.value,
        {
            "group_id": group_id,
            "user_id": user_id,
            "reject_add_request": False,  # do not also reject future join requests
        },
    )

def handle_poke_command(self, args: Dict[str, Any], group_info: "GroupInfo") -> Tuple[str, Dict[str, Any]]:
    """Build the napcat poke action; *group_info* may be None for private pokes.

    Raises:
        ValueError: invalid user or group id.
    """
    user_id: int = int(args["qq_id"])
    if group_info is None:
        group_id = None  # private-chat poke carries no group_id
    else:
        group_id = int(group_info.group_id)
        # Bugfix: this check previously ran even when group_id was None,
        # raising TypeError ("'<=' not supported between NoneType and int")
        # instead of a meaningful error for private pokes.
        if group_id <= 0:
            raise ValueError("群组ID无效")
    if user_id <= 0:
        raise ValueError("用户ID无效")
    return (
        CommandType.SEND_POKE.value,
        {
            "group_id": group_id,
            "user_id": user_id,
        },
    )

def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Build the napcat emoji-reaction action.

    Raises:
        ValueError: when message_id/emoji_id/set are missing or non-numeric.
    """
    try:
        message_id = int(args["message_id"])
        emoji_id = int(args["emoji_id"])
        set_like = str(args["set"])
    except (KeyError, TypeError, ValueError) as e:  # was a bare except
        raise ValueError("缺少必需参数: message_id 或 emoji_id") from e
    return (
        CommandType.SET_EMOJI_LIKE.value,
        {
            "message_id": message_id,
            "emoji_id": emoji_id,
            "set": set_like,
        },
    )

def handle_send_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Build the napcat profile-like action.

    Raises:
        ValueError: when qq_id/times are missing or non-numeric.
    """
    try:
        user_id: int = int(args["qq_id"])
        times: int = int(args["times"])
    except (KeyError, ValueError) as e:  # chain the cause for debuggability
        raise ValueError("缺少必需参数: qq_id 或 times") from e
    return (
        CommandType.SEND_LIKE.value,
        {
            "user_id": user_id,
            "times": times,
        },
    )

def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: "GroupInfo") -> Tuple[str, Dict[str, Any]]:
    """Build the napcat AI-voice action (group chats only).

    Raises:
        ValueError: outside a group context, or with missing character/text.
    """
    if not group_info or not group_info.group_id:
        raise ValueError("AI语音发送命令必须在群聊上下文中使用")
    if not args:
        raise ValueError("AI语音发送命令缺少参数")
    group_id: int = int(group_info.group_id)
    character_id = args.get("character")
    text_content = args.get("text")
    if not character_id or not text_content:
        raise ValueError(f"AI语音发送命令参数不完整: character='{character_id}', text='{text_content}'")
    return (
        CommandType.AI_VOICE_SEND.value,
        {
            "group_id": group_id,
            "text": text_content,
            "character": character_id,
        },
    )

async def send_message_to_napcat(self, action: str, params: dict) -> dict:
    """Fire one napcat action over the active connection and await its echoed
    response.

    Never raises: connection/timeout/send failures are logged and returned as
    ``{"status": "error", "message": ...}`` so callers can branch uniformly.
    """
    request_uuid = str(uuid.uuid4())
    payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
    connection = self.get_server_connection()
    if not connection:
        logger.error("没有可用的 Napcat 连接")
        return {"status": "error", "message": "no connection"}
    try:
        await connection.send(payload)
        response = await get_response(request_uuid)
    except TimeoutError:
        logger.error("发送消息超时,未收到响应")
        return {"status": "error", "message": "timeout"}
    except Exception as e:
        logger.error(f"发送消息失败: {e}")
        return {"status": "error", "message": str(e)}
    return response

async def message_sent_back(self, message_base: "MessageBase", qq_message_id: str) -> None:
    """Echo the napcat-assigned message id back to MaiBot so it can map its
    own message id onto QQ's actual id."""
    if message_base.message_info.additional_config is None:
        message_base.message_info.additional_config = {}
    message_base.message_info.additional_config["echo"] = True
    # Preserve the original MMC message id before overwriting the segment.
    mmc_message_id = message_base.message_info.message_id
    message_base.message_segment = Seg(
        type="notify",
        data={"sub_type": "echo", "echo": mmc_message_id, "actual_id": qq_message_id},
    )
    await message_send_instance.message_send(message_base)
    logger.debug("已回送消息ID")
async def send_adapter_command_response(
    self, original_message: "MessageBase", response_data: dict, request_id: str
) -> None:
    """Echo an adapter-command result back to MaiBot.

    Args:
        original_message: the message that carried the command.
        response_data: result payload to hand back.
        request_id: correlation id from the original request.
    """
    try:
        info = original_message.message_info
        if info.additional_config is None:
            info.additional_config = {}
        info.additional_config["echo"] = True
        # Replace the segment with an adapter_response envelope.
        original_message.message_segment = Seg(
            type="adapter_response",
            data={
                "request_id": request_id,
                "response": response_data,
                "timestamp": int(time.time() * 1000),
            },
        )
        await message_send_instance.message_send(original_message)
        logger.debug(f"已发送适配器命令响应,request_id: {request_id}")
    except Exception as e:
        logger.error(f"发送适配器命令响应时出错: {e}")

def handle_at_message_command(self, args: Dict[str, Any], group_info: "GroupInfo") -> Tuple[str, Dict[str, Any]]:
    """Build a group message that @-mentions a user followed by text.

    Raises:
        ValueError: missing qq_id/text or missing group context.
    """
    at_user_id = args.get("qq_id")
    text = args.get("text")
    if not at_user_id or not text:
        raise ValueError("艾特消息命令缺少 qq_id 或 text 参数")
    if not group_info:
        raise ValueError("艾特消息命令必须在群聊上下文中使用")
    segments = [
        {"type": "at", "data": {"qq": str(at_user_id)}},
        {"type": "text", "data": {"text": " " + str(text)}},
    ]
    return (
        "send_group_msg",
        {
            "group_id": group_info.group_id,
            "message": segments,
        },
    )

send_handler = SendHandler()

# ---- plugins/napcat_adapter_plugin/src/utils.py (new file in this patch) ----
import websockets as Server
import json
import base64
import uuid
import urllib3
import ssl
import io

from .database import BanUser, db_manager
from src.common.logger import get_logger
from .response_pool import get_response

from PIL import Image
from typing import Union, List, Tuple, Optional

logger = get_logger("napcat_adapter")


class SSLAdapter(urllib3.PoolManager):
    """PoolManager whose TLS context tolerates legacy servers (SECLEVEL=1)
    while still requiring at least TLS 1.2."""

    def __init__(self, *args, **kwargs):
        ctx = ssl.create_default_context()
        ctx.set_ciphers("DEFAULT@SECLEVEL=1")
        ctx.minimum_version = ssl.TLSVersion.TLSv1_2
        kwargs["ssl_context"] = ctx
        super().__init__(*args, **kwargs)
async def get_group_info(websocket: "Server.ServerConnection", group_id: int) -> dict | None:
    """Query napcat for basic group info; returns None on timeout/failure."""
    logger.debug("获取群聊信息中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
    try:
        await websocket.send(payload)
        socket_response: dict = await get_response(request_uuid)
    except TimeoutError:
        logger.error(f"获取群信息超时,群号: {group_id}")
        return None
    except Exception as e:
        logger.error(f"获取群信息失败: {e}")
        return None
    logger.debug(socket_response)
    return socket_response.get("data")


async def get_group_detail_info(websocket: "Server.ServerConnection", group_id: int) -> dict | None:
    """Query napcat for detailed group info; returns None on timeout/failure."""
    logger.debug("获取群详细信息中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid})
    try:
        await websocket.send(payload)
        socket_response: dict = await get_response(request_uuid)
    except TimeoutError:
        logger.error(f"获取群详细信息超时,群号: {group_id}")
        return None
    except Exception as e:
        logger.error(f"获取群详细信息失败: {e}")
        return None
    logger.debug(socket_response)
    return socket_response.get("data")


async def get_member_info(websocket: "Server.ServerConnection", group_id: int, user_id: int) -> dict | None:
    """Query napcat for one group member (no_cache); returns None on failure."""
    logger.debug("获取群成员信息中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps(
        {
            "action": "get_group_member_info",
            "params": {"group_id": group_id, "user_id": user_id, "no_cache": True},
            "echo": request_uuid,
        }
    )
    try:
        await websocket.send(payload)
        socket_response: dict = await get_response(request_uuid)
    except TimeoutError:
        logger.error(f"获取成员信息超时,群号: {group_id}, 用户ID: {user_id}")
        return None
    except Exception as e:
        logger.error(f"获取成员信息失败: {e}")
        return None
    logger.debug(socket_response)
    return socket_response.get("data")


async def get_image_base64(url: str) -> str:
    """Download *url* via the legacy-TLS pool and return its base64 payload.

    Raises:
        Exception: on non-200 status or any download failure (re-raised after
            logging, so callers can decide how to degrade).
    """
    logger.debug(f"下载图片: {url}")
    http = SSLAdapter()
    try:
        response = http.request("GET", url, timeout=10)
        if response.status != 200:
            raise Exception(f"HTTP Error: {response.status}")
        return base64.b64encode(response.data).decode("utf-8")
    except Exception as e:
        logger.error(f"图片下载失败: {str(e)}")
        raise


def convert_image_to_gif(image_base64: str) -> str:
    """Re-encode a base64 image as GIF; returns the input unchanged on failure.

    Args:
        image_base64: base64-encoded source image.
    Returns:
        base64-encoded GIF data (or the original string if conversion fails).
    """
    logger.debug("转换图片为GIF格式")
    try:
        image = Image.open(io.BytesIO(base64.b64decode(image_base64)))
        output_buffer = io.BytesIO()
        image.save(output_buffer, format="GIF")
        output_buffer.seek(0)
        return base64.b64encode(output_buffer.read()).decode("utf-8")
    except Exception as e:
        logger.error(f"图片转换为GIF失败: {str(e)}")
        return image_base64


async def get_self_info(websocket: "Server.ServerConnection") -> dict | None:
    """Query napcat for the bot's own login info; None on failure."""
    logger.debug("获取自身信息中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
    try:
        await websocket.send(payload)
        response: dict = await get_response(request_uuid)
    except TimeoutError:
        logger.error("获取自身信息超时")
        return None
    except Exception as e:
        logger.error(f"获取自身信息失败: {e}")
        return None
    logger.debug(response)
    return response.get("data")


def get_image_format(raw_data: str) -> str:
    """Return the lowercase image format ('jpeg', 'png', 'gif', ...) of a
    base64-encoded image.

    Returns "" when PIL cannot name the format (``Image.format`` may be None),
    instead of crashing on ``None.lower()`` as before.
    """
    image_bytes = base64.b64decode(raw_data)
    fmt = Image.open(io.BytesIO(image_bytes)).format
    return (fmt or "").lower()


async def get_stranger_info(websocket: "Server.ServerConnection", user_id: int) -> dict | None:
    """Query napcat for a non-friend user's profile; None on failure."""
    logger.debug("获取陌生人信息中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid})
    try:
        await websocket.send(payload)
        response: dict = await get_response(request_uuid)
    except TimeoutError:
        logger.error(f"获取陌生人信息超时,用户ID: {user_id}")
        return None
    except Exception as e:
        logger.error(f"获取陌生人信息失败: {e}")
        return None
    logger.debug(response)
    return response.get("data")


async def get_message_detail(websocket: "Server.ServerConnection", message_id: Union[str, int]) -> dict | None:
    """Query napcat for one message's detail; None on failure."""
    logger.debug("获取消息详情中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid})
    try:
        await websocket.send(payload)
        response: dict = await get_response(request_uuid, 30)  # extended 30s timeout
    except TimeoutError:
        logger.error(f"获取消息详情超时,消息ID: {message_id}")
        return None
    except Exception as e:
        logger.error(f"获取消息详情失败: {e}")
        return None
    logger.debug(response)
    return response.get("data")


async def get_record_detail(
    websocket: "Server.ServerConnection", file: str, file_id: Optional[str] = None
) -> dict | None:
    """Fetch a voice message (converted to wav) from napcat; None on failure."""
    logger.debug("获取语音消息详情中")
    request_uuid = str(uuid.uuid4())
    payload = json.dumps(
        {
            "action": "get_record",
            "params": {"file": file, "file_id": file_id, "out_format": "wav"},
            "echo": request_uuid,
        }
    )
    try:
        await websocket.send(payload)
        response: dict = await get_response(request_uuid, 30)  # extended 30s timeout
    except TimeoutError:
        logger.error(f"获取语音消息详情超时,文件: {file}, 文件ID: {file_id}")
        return None
    except Exception as e:
        logger.error(f"获取语音消息详情失败: {e}")
        return None
    logger.debug(f"{str(response)[:200]}...")  # trim: voice base64 can be huge
    return response.get("data")


async def read_ban_list(
    websocket: "Server.ServerConnection",
) -> Tuple[List[BanUser], List[BanUser]]:
    """Load ban records from the database and reconcile them against napcat.

    Whole-group mutes are encoded as records with ``user_id == 0``. Records
    whose mute has expired — or whose group/member can no longer be queried —
    are treated as lifted. Still-active records get their lift time refreshed
    and are written back to the database.

    Returns:
        (still_banned, lifted) — two BanUser lists.
        (The previous docstring promised four lists; the function has always
        returned two.)
    """
    try:
        known_records = db_manager.get_ban_records()
        still_banned: List[BanUser] = []
        lifted: List[BanUser] = []
        logger.info("已经读取禁言列表")
        # Bugfix: the old implementation removed items from the list it was
        # iterating (`ban_list.remove(...)` inside `for ... in ban_list`),
        # which silently skipped the record following every removal.
        for record in known_records:
            if record.user_id == 0:
                # user_id == 0 marks a whole-group mute.
                info = await get_group_info(websocket, record.group_id)
                if info is None:
                    logger.warning(f"无法获取群信息,群号: {record.group_id},默认禁言解除")
                    lifted.append(record)
                elif info.get("group_all_shut") == 0:
                    lifted.append(record)
                else:
                    still_banned.append(record)
                continue
            member = await get_member_info(websocket, record.group_id, record.user_id)
            if member is None:
                logger.warning(
                    f"无法获取群成员信息,用户ID: {record.user_id}, 群号: {record.group_id},默认禁言解除"
                )
                lifted.append(record)
                continue
            lift_ban_time = member.get("shut_up_timestamp")
            if lift_ban_time == 0:
                lifted.append(record)
            else:
                record.lift_time = lift_ban_time
                still_banned.append(record)
        db_manager.update_ban_record(still_banned)
        return still_banned, lifted
    except Exception as e:
        logger.error(f"读取禁言列表失败: {e}")
        return [], []


def save_ban_record(records: List[BanUser]):
    """Persist *records* to the ban table.

    (Parameter renamed from ``list``, which shadowed the builtin.)
    """
    return db_manager.update_ban_record(records)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Video download/processing module: fetches videos referenced in QQ messages
and hands them to the bot for analysis.
"""

import asyncio
from pathlib import Path
from typing import Optional, Dict, Any

try:
    from src.common.logger import get_logger
    logger = get_logger("video_handler")
except ImportError:  # allow importing this module outside the bot runtime
    import logging
    logger = logging.getLogger("video_handler")


class VideoDownloader:
    """Downloads and size-checks video files.

    QQ/Tencent CDN URLs often lack file extensions, so URL screening is
    heuristic; the real verification happens against the HTTP Content-Type
    during download.
    """

    def __init__(self, max_size_mb: int = 100, download_timeout: int = 60):
        self.max_size_mb = max_size_mb
        self.download_timeout = download_timeout
        self.supported_formats = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}

    def is_video_url(self, url: str) -> bool:
        """Heuristically decide whether *url* may point at a video.

        Accepts URLs containing video keywords, known video extensions, or
        QQ/Tencent CDN domains; anything else is rejected here.
        """
        try:
            video_keywords = ['video', 'mp4', 'avi', 'mov', 'mkv', 'flv', 'wmv', 'webm', 'm4v']
            url_lower = url.lower()
            if any(keyword in url_lower for keyword in video_keywords):
                return True
            path = Path(url.split('?')[0])  # strip the query string
            if path.suffix.lower() in self.supported_formats:
                return True
            # QQ CDNs serve videos from extension-less URLs: let them through
            # and rely on the Content-Type check during download.
            qq_domains = ['qpic.cn', 'gtimg.cn', 'qq.com', 'tencent.com']
            return any(domain in url_lower for domain in qq_domains)
        except (AttributeError, ValueError):  # was a bare except
            return True  # fail open: the download step re-validates

    def check_file_size(self, content_length: Optional[str]) -> bool:
        """Return False only when Content-Length parses and exceeds the limit."""
        if content_length is None:
            return True  # unknown size: allow the download
        try:
            size_mb = int(content_length) / (1024 * 1024)
            return size_mb <= self.max_size_mb
        except (TypeError, ValueError):  # was a bare except
            return True

    async def download_video(self, url: str, filename: Optional[str] = None) -> Dict[str, Any]:
        """Download a video file.

        Args:
            url: video URL.
            filename: optional file name (derived from the URL otherwise).

        Returns:
            dict with ``success`` plus either ``data``/``filename``/``size_mb``
            or ``error``; never raises.
        """
        import aiohttp  # imported lazily: only needed when actually downloading

        try:
            logger.info(f"开始下载视频: {url}")
            if not self.is_video_url(url):
                logger.warning(f"URL格式检查失败: {url}")
                return {"success": False, "error": "不支持的视频格式", "url": url}

            async with aiohttp.ClientSession() as session:
                # Cheap pre-check: HEAD for the declared size before downloading.
                try:
                    async with session.head(url, timeout=aiohttp.ClientTimeout(total=10)) as response:
                        if response.status != 200:
                            logger.warning(f"HEAD请求失败,状态码: {response.status}")
                        else:
                            content_length = response.headers.get('Content-Length')
                            if not self.check_file_size(content_length):
                                return {
                                    "success": False,
                                    "error": f"视频文件过大,超过{self.max_size_mb}MB限制",
                                    "url": url,
                                }
                except Exception as e:
                    logger.warning(f"HEAD请求失败: {e},继续尝试下载")

                async with session.get(url, timeout=aiohttp.ClientTimeout(total=self.download_timeout)) as response:
                    if response.status != 200:
                        return {
                            "success": False,
                            "error": f"下载失败,HTTP状态码: {response.status}",
                            "url": url,
                        }

                    # Verify the response actually looks like video content.
                    content_type = response.headers.get('Content-Type', '').lower()
                    if content_type:
                        video_mime_types = [
                            'video/', 'application/octet-stream',
                            'application/x-msvideo', 'video/x-msvideo',
                        ]
                        is_video_content = any(mime in content_type for mime in video_mime_types)
                        if not is_video_content:
                            logger.warning(f"Content-Type不是视频格式: {content_type}")
                            # Ambiguous types may still be QQ's odd formats;
                            # only text/JSON responses are rejected outright.
                            if 'text/' in content_type or 'application/json' in content_type:
                                return {
                                    "success": False,
                                    "error": f"URL返回的不是视频内容,Content-Type: {content_type}",
                                    "url": url,
                                }

                    content_length = response.headers.get('Content-Length')
                    if not self.check_file_size(content_length):
                        return {
                            "success": False,
                            "error": f"视频文件过大,超过{self.max_size_mb}MB限制",
                            "url": url,
                        }

                    video_data = await response.read()

                    # Content-Length can lie: re-check the actual byte count.
                    actual_size_mb = len(video_data) / (1024 * 1024)
                    if actual_size_mb > self.max_size_mb:
                        return {
                            "success": False,
                            "error": f"视频文件过大,实际大小: {actual_size_mb:.2f}MB",
                            "url": url,
                        }

                    if filename is None:
                        filename = Path(url.split('?')[0]).name
                        if not filename or '.' not in filename:
                            filename = "video.mp4"

                    # Bugfix: the success log printed the literal "(unknown)"
                    # instead of the resolved filename.
                    logger.info(f"视频下载成功: {filename}, 大小: {actual_size_mb:.2f}MB")

                    return {
                        "success": True,
                        "data": video_data,
                        "filename": filename,
                        "size_mb": actual_size_mb,
                        "url": url,
                    }

        except asyncio.TimeoutError:
            return {"success": False, "error": "下载超时", "url": url}
        except Exception as e:
            logger.error(f"下载视频时出错: {e}")
            return {"success": False, "error": str(e), "url": url}


# Module-level singleton (created lazily by get_video_downloader).
_video_downloader = None


def get_video_downloader(max_size_mb: int = 100, download_timeout: int = 60) -> VideoDownloader:
    """Return the shared VideoDownloader.

    NOTE: the arguments only take effect on the first call; later calls return
    the existing instance unchanged.
    """
    global _video_downloader
    if _video_downloader is None:
        _video_downloader = VideoDownloader(max_size_mb, download_timeout)
    return _video_downloader
class WebSocketManager:
    """WebSocket connection manager for napcat, supporting both forward
    (client) and reverse (server) connection modes."""

    def __init__(self):
        self.connection: Optional["Server.ServerConnection"] = None
        self.server: Optional["Server.WebSocketServer"] = None
        self.is_running = False
        self.reconnect_interval = 5       # seconds between reconnect attempts
        self.max_reconnect_attempts = 10  # forward-mode retry budget

    async def start_connection(self, message_handler: "Callable[[Server.ServerConnection], Any]") -> None:
        """Start the configured connection mode (reverse or forward)."""
        mode = global_config.napcat_server.mode
        if mode == "reverse":
            await self._start_reverse_connection(message_handler)
        elif mode == "forward":
            await self._start_forward_connection(message_handler)
        else:
            raise ValueError(f"不支持的连接模式: {mode}")

    async def _start_reverse_connection(self, message_handler: "Callable[[Server.ServerConnection], Any]") -> None:
        """Run as a WebSocket server and wait for napcat to connect to us."""
        host = global_config.napcat_server.host
        port = global_config.napcat_server.port
        logger.info(f"正在启动反向连接模式,监听地址: ws://{host}:{port}")

        async def handle_client(websocket, path=None):
            # Track the single active client so other modules can send to it.
            self.connection = websocket
            logger.info(f"Napcat 客户端已连接: {websocket.remote_address}")
            try:
                await message_handler(websocket)
            except Exception as e:
                logger.error(f"处理客户端连接时出错: {e}")
            finally:
                self.connection = None
                logger.info("Napcat 客户端已断开连接")

        self.server = await Server.serve(handle_client, host, port, max_size=2**26)
        self.is_running = True
        logger.info(f"反向连接服务器已启动,监听地址: ws://{host}:{port}")
        await self.server.serve_forever()

    async def _start_forward_connection(self, message_handler: "Callable[[Server.ServerConnection], Any]") -> None:
        """Connect out to napcat as a client, retrying with a fixed interval."""
        url = self._get_forward_url()
        logger.info(f"正在启动正向连接模式,目标地址: {url}")
        reconnect_count = 0
        while reconnect_count < self.max_reconnect_attempts:
            try:
                logger.info(f"尝试连接到 Napcat 服务器: {url}")
                connect_kwargs = {"max_size": 2**26}
                token = global_config.napcat_server.access_token
                if token:
                    # Optional bearer-token auth via request headers.
                    connect_kwargs["additional_headers"] = {"Authorization": f"Bearer {token}"}
                    logger.info("已添加访问令牌到连接请求头")
                async with Server.connect(url, **connect_kwargs) as websocket:
                    self.connection = websocket
                    self.is_running = True
                    reconnect_count = 0  # reset the budget after a successful connect
                    logger.info(f"成功连接到 Napcat 服务器: {url}")
                    try:
                        await message_handler(websocket)
                    except Server.exceptions.ConnectionClosed:
                        logger.warning("与 Napcat 服务器的连接已断开")
                    except Exception as e:
                        logger.error(f"处理正向连接时出错: {e}")
                    finally:
                        self.connection = None
                        self.is_running = False
            except (Server.exceptions.ConnectionClosed, Server.exceptions.InvalidMessage, OSError, ConnectionRefusedError) as e:
                reconnect_count += 1
                logger.warning(f"连接失败 ({reconnect_count}/{self.max_reconnect_attempts}): {e}")
                if reconnect_count >= self.max_reconnect_attempts:
                    logger.error("已达到最大重连次数,停止重连")
                    raise
                logger.info(f"将在 {self.reconnect_interval} 秒后重试连接...")
                await asyncio.sleep(self.reconnect_interval)
            except Exception as e:
                logger.error(f"正向连接时发生未知错误: {e}")
                raise

    def _get_forward_url(self) -> str:
        """Return the forward-mode target URL (explicit url wins over host/port)."""
        config = global_config.napcat_server
        if config.url:
            return config.url
        return f"ws://{config.host}:{config.port}"

    async def stop_connection(self) -> None:
        """Close the client connection and/or shut down the server."""
        self.is_running = False
        if self.connection:
            try:
                await self.connection.close()
                logger.info("WebSocket 连接已关闭")
            except Exception as e:
                logger.error(f"关闭 WebSocket 连接时出错: {e}")
            finally:
                self.connection = None
        if self.server:
            try:
                self.server.close()
                await self.server.wait_closed()
                logger.info("WebSocket 服务器已关闭")
            except Exception as e:
                logger.error(f"关闭 WebSocket 服务器时出错: {e}")
            finally:
                self.server = None

    def get_connection(self) -> Optional["Server.ServerConnection"]:
        """Return the currently active WebSocket connection, if any."""
        return self.connection

    def is_connected(self) -> bool:
        """True when a connection exists and the manager is marked running."""
        return self.connection is not None and self.is_running


# Global WebSocket manager singleton.
websocket_manager = WebSocketManager()
群聊列表类型:whitelist(白名单)或 blacklist(黑名单) +group_list = [] # 群聊ID列表 +# 当 group_list_type 为 whitelist 时,只有列表中的群聊可以使用机器人 +# 当 group_list_type 为 blacklist 时,列表中的群聊无法使用机器人 +# 示例:group_list = [123456789, 987654321] + +# 私聊权限设置 +private_list_type = "whitelist" # 私聊列表类型:whitelist(白名单)或 blacklist(黑名单) +private_list = [] # 用户ID列表 +# 当 private_list_type 为 whitelist 时,只有列表中的用户可以私聊机器人 +# 当 private_list_type 为 blacklist 时,列表中的用户无法私聊机器人 +# 示例:private_list = [123456789, 987654321] + +# 全局禁止设置 +ban_user_id = [] # 全局禁止用户ID列表,这些用户无法在任何地方使用机器人 +ban_qq_bot = false # 是否屏蔽QQ官方机器人消息 + +# 聊天功能设置 +enable_poke = true # 是否启用戳一戳功能 +ignore_non_self_poke = false # 是否无视不是针对自己的戳一戳 +poke_debounce_seconds = 3 # 戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略 +enable_reply_at = true # 是否启用引用回复时艾特用户的功能 +reply_at_rate = 0.5 # 引用回复时艾特用户的几率 (0.0 ~ 1.0) + +# 视频处理设置 +enable_video_analysis = true # 是否启用视频识别功能 +max_video_size_mb = 100 # 视频文件最大大小限制(MB) +download_timeout = 60 # 视频下载超时时间(秒) +supported_formats = ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"] # 支持的视频格式 + +# 消息缓冲设置 +enable_message_buffer = true # 是否启用消息缓冲合并功能 +message_buffer_enable_group = true # 是否启用群聊消息缓冲合并 +message_buffer_enable_private = true # 是否启用私聊消息缓冲合并 +message_buffer_interval = 3.0 # 消息合并间隔时间(秒),在此时间内的连续消息将被合并 +message_buffer_initial_delay = 0.5 # 消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并 +message_buffer_max_components = 50 # 单个会话最大缓冲消息组件数量,超过此数量将强制合并 +message_buffer_block_prefixes = ["/", "!", "!", ".", "。", "#", "%"] # 消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲 \ No newline at end of file diff --git a/plugins/napcat_adapter_plugin/template/template_config.toml b/plugins/napcat_adapter_plugin/template/template_config.toml new file mode 100644 index 000000000..1ddca6cf5 --- /dev/null +++ b/plugins/napcat_adapter_plugin/template/template_config.toml @@ -0,0 +1,25 @@ +[inner] +version = "0.2.0" # 版本号 +# 请勿修改版本号,除非你知道自己在做什么 + +[nickname] # 现在没用 +nickname = "" + +[napcat_server] # Napcat连接的ws服务设置 +mode = "reverse" # 连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端) +host = "localhost" # 
主机地址 +port = 8095 # 端口号 +url = "" # 正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用) +access_token = "" # WebSocket 连接的访问令牌,用于身份验证(可选) +heartbeat_interval = 30 # 心跳间隔时间(按秒计) + +[maibot_server] # 连接麦麦的ws服务设置 +host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段 +port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段 + +[voice] # 发送语音设置 +use_tts = false # 是否使用tts语音(请确保你配置了tts并有对应的adapter) + +[debug] +level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL) + diff --git a/plugins/napcat_adapter_plugin/todo.md b/plugins/napcat_adapter_plugin/todo.md new file mode 100644 index 000000000..483426826 --- /dev/null +++ b/plugins/napcat_adapter_plugin/todo.md @@ -0,0 +1,89 @@ +# TODO List: + +[x] logger使用主程序的 +[ ] 使用插件系统的config系统 +[ ] 接收从napcat传递的所有信息 +[ ] 优化架构,各模块解耦,暴露关键方法用于提供接口 +[ ] 单独一个模块负责与主程序通信 +[ ] 使用event系统完善接口api + + +--- +Event分为两种,一种是对外输出的event,由napcat插件自主触发并传递参数,另一种是接收外界输入的event,由外部插件触发并向napcat传递参数 + + +## 例如, + +### 对外输出的event: + +napcat_on_received_text -> (message_seg: Seg) 接受到qq的文字消息,会向handler传递一个Seg +napcat_on_received_face -> (message_seg: Seg) 接受到qq的表情消息,会向handler传递一个Seg +napcat_on_received_reply -> (message_seg: Seg) 接受到qq的回复消息,会向handler传递一个Seg +napcat_on_received_image -> (message_seg: Seg) 接受到qq的图片消息,会向handler传递一个Seg +napcat_on_received_image -> (message_seg: Seg) 接受到qq的图片消息,会向handler传递一个Seg +napcat_on_received_record -> (message_seg: Seg) 接受到qq的语音消息,会向handler传递一个Seg +napcat_on_received_rps -> (message_seg: Seg) 接受到qq的猜拳魔法表情,会向handler传递一个Seg +napcat_on_received_friend_invitation -> (user_id: str) 接受到qq的好友请求,会向handler传递一个user_id +... 
+
+此类event不接受外部插件的触发,只能由napcat插件统一触发。
+
+外部插件需要编写handler并订阅此类事件。
+```python
+from src.plugin_system.core.event_manager import event_manager
+from src.plugin_system.base.base_event import BaseEventHandler, HandlerResult
+
+class MyEventHandler(BaseEventHandler):
+    handler_name = "my_handler"
+    handler_description = "我的自定义事件处理器"
+    weight = 10  # 权重,越大越先执行
+    intercept_message = False  # 是否拦截消息
+    init_subscribe = ["napcat_on_received_text"]  # 初始订阅的事件
+
+    async def execute(self, params: dict) -> HandlerResult:
+        """处理事件"""
+        try:
+            message = params.get("message_seg")
+            print(f"收到消息: {message.data}")
+
+            # 业务逻辑处理
+            # ...
+
+            return HandlerResult(
+                success=True,
+                continue_process=True,  # 是否继续让其他处理器处理
+                message="处理成功",
+                handler_name=self.handler_name
+            )
+        except Exception as e:
+            return HandlerResult(
+                success=False,
+                continue_process=True,
+                message=f"处理失败: {str(e)}",
+                handler_name=self.handler_name
+            )
+
+```
+
+### 接收外界输入的event:
+
+napcat_kick_group <- (user_id, group_id) 踢出某个群组中的某个用户
+napcat_mute_user <- (user_id, group_id, time) 禁言某个群组中的某个用户
+napcat_unmute_user <- (user_id, group_id) 取消禁言某个群组中的某个用户
+napcat_mute_group <- (user_id, group_id) 禁言某个群组
+napcat_unmute_group <- (user_id, group_id) 取消禁言某个群组
+napcat_add_friend <- (user_id) 向某个用户发出好友请求
+napcat_accept_friend <- (user_id) 接收某个用户的好友请求
+napcat_reject_friend <- (user_id) 拒绝某个用户的好友请求
+...
+此类事件只由外部插件触发并传递参数,由napcat完成请求任务。 + +外部插件需要触发此类的event并传递正确的参数。 + +```python +from src.plugin_system.core.event_manager import event_manager + +# 触发事件 +await event_manager.trigger_event("napcat_accept_friend", user_id = 1234123) +``` + diff --git a/requirements.txt b/requirements.txt index 7485a43cc..0bfafb8da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -64,4 +64,7 @@ chromadb asyncio tavily-python google-generativeai -lunar_python \ No newline at end of file +lunar_python + +python-multipart +aiofiles \ No newline at end of file diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index 6ca4dc916..ccb90da2d 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -10,7 +10,6 @@ from src.chat.express.expression_learner import expression_learner_manager from src.plugin_system.base.component_types import ChatMode from src.schedule.schedule_manager import schedule_manager from src.plugin_system.apis import message_api -from src.mood.mood_manager import mood_manager from .hfc_context import HfcContext from .energy_manager import EnergyManager diff --git a/src/chat/utils/prompt_builder.py b/src/chat/utils/prompt_builder.py index 95643c722..1db532b5d 100644 --- a/src/chat/utils/prompt_builder.py +++ b/src/chat/utils/prompt_builder.py @@ -7,33 +7,11 @@ from contextlib import asynccontextmanager from typing import Dict, Any, Optional, List, Union from src.common.logger import get_logger -from src.common.tool_history import ToolHistoryManager install(extra_lines=3) logger = get_logger("prompt_build") -# 创建工具历史管理器实例 -tool_history_manager = ToolHistoryManager() - -def get_tool_history_prompt(message_id: Optional[str] = None) -> str: - """获取工具历史提示词 - - Args: - message_id: 会话ID, 用于只获取当前会话的历史 - - Returns: - 格式化的工具历史提示词 - """ - from src.config.config import global_config - - if not global_config.tool.history.enable_prompt_history: - return "" - - return tool_history_manager.get_recent_history_prompt( - 
chat_id=message_id - ) - class PromptContext: def __init__(self): self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {} @@ -49,7 +27,7 @@ class PromptContext: @_current_context.setter def _current_context(self, value: Optional[str]): """设置当前协程的上下文ID""" - self._current_context_var.set(value) + self._current_context_var.set(value) # type: ignore @asynccontextmanager async def async_scope(self, context_id: Optional[str] = None): @@ -73,7 +51,7 @@ class PromptContext: # 保存当前协程的上下文值,不影响其他协程 previous_context = self._current_context # 设置当前协程的新上下文 - token = self._current_context_var.set(context_id) if context_id else None + token = self._current_context_var.set(context_id) if context_id else None # type: ignore else: # 如果没有提供新上下文,保持当前上下文不变 previous_context = self._current_context @@ -111,7 +89,8 @@ class PromptContext: """异步注册提示模板到指定作用域""" async with self._context_lock: if target_context := context_id or self._current_context: - self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt + if prompt.name: + self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt class PromptManager: @@ -153,40 +132,15 @@ class PromptManager: def add_prompt(self, name: str, fstr: str) -> "Prompt": prompt = Prompt(fstr, name=name) - self._prompts[prompt.name] = prompt + if prompt.name: + self._prompts[prompt.name] = prompt return prompt async def format_prompt(self, name: str, **kwargs) -> str: # 获取当前提示词 prompt = await self.get_prompt_async(name) - # 获取当前会话ID - message_id = self._context._current_context - - # 获取工具历史提示词 - tool_history = "" - if name in ['action_prompt', 'replyer_prompt', 'planner_prompt', 'tool_executor_prompt']: - tool_history = get_tool_history_prompt(message_id) - # 获取基本格式化结果 result = prompt.format(**kwargs) - - # 如果有工具历史,插入到适当位置 - if tool_history: - # 查找合适的插入点 - # 在人格信息和身份块之后,但在主要内容之前 - identity_end = result.find("```\n现在,你说:") - if identity_end == -1: - # 如果找不到特定标记,尝试在第一个段落后插入 - first_double_newline = result.find("\n\n") - if 
first_double_newline != -1: - # 在第一个双换行后插入 - result = f"{result[:first_double_newline + 2]}{tool_history}\n{result[first_double_newline + 2:]}" - else: - # 如果找不到合适的位置,添加到开头 - result = f"{tool_history}\n\n{result}" - else: - # 在找到的位置插入 - result = f"{result[:identity_end]}\n{tool_history}\n{result[identity_end:]}" return result @@ -195,6 +149,11 @@ global_prompt_manager = PromptManager() class Prompt(str): + template: str + name: Optional[str] + args: List[str] + _args: List[Any] + _kwargs: Dict[str, Any] # 临时标记,作为类常量 _TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__" _TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__" @@ -215,7 +174,7 @@ class Prompt(str): """将临时标记还原为实际的花括号字符""" return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}") - def __new__(cls, fstr, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs): + def __new__(cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs): # 如果传入的是元组,转换为列表 if isinstance(args, tuple): args = list(args) @@ -251,7 +210,7 @@ class Prompt(str): @classmethod async def create_async( - cls, fstr, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs + cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs ): """异步创建Prompt实例""" prompt = cls(fstr, name, args, **kwargs) @@ -260,7 +219,9 @@ class Prompt(str): return prompt @classmethod - def _format_template(cls, template, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str: + def _format_template(cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None) -> str: + if kwargs is None: + kwargs = {} # 预处理模板中的转义花括号 processed_template = cls._process_escaped_braces(template) diff --git a/src/chat/utils/rust-video/.gitignore b/src/chat/utils/rust-video/.gitignore new file mode 100644 index 000000000..ea8c4bf7f --- /dev/null +++ b/src/chat/utils/rust-video/.gitignore 
@@ -0,0 +1 @@ +/target diff --git a/src/chat/utils/rust-video/Cargo.lock b/src/chat/utils/rust-video/Cargo.lock new file mode 100644 index 000000000..8041152b2 --- /dev/null +++ b/src/chat/utils/rust-video/Cargo.lock @@ -0,0 +1,610 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "anstream" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ae563653d1938f79b1ab1b5e668c87c76a9930414574a6583a7b7e11a8e6192" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e231f6134f61b71076a3eab506c379d4f36122f2af15a9ff04415ea4c3339e2" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3e0633414522a32ffaac8ac6cc8f748e090c5717661fddeea04219e2344f5f2a" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys", +] + +[[package]] +name = "anyhow" +version = "1.0.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "bumpalo" +version = "3.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" + +[[package]] +name = "cc" +version = "1.2.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42bc4aea80032b7bf409b0bc7ccad88853858911b7713a8062fdc0623867bedc" +dependencies = [ + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" + +[[package]] +name = "chrono" +version = "0.4.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", +] + +[[package]] +name = "clap" +version = "4.5.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c5e4fcf9c21d2e544ca1ee9d8552de13019a42aa7dbf32747fa7aaf1df76e57" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fecb53a0e6fcfb055f686001bc2e2592fa527efaf38dbe81a6a9563562e57d41" +dependencies = [ + "anstream", + "anstyle", + 
"clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14cb31bb0a7d536caef2639baa7fad459e15c3144efefa6dbd1c84562c4739f6" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" + +[[package]] +name = "colorchoice" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "iana-time-zone" +version = "0.1.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "js-sys" +version = "0.3.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.175" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" + +[[package]] +name = "log" +version = "0.4.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" + +[[package]] +name = "memchr" +version = "2.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" + +[[package]] +name = "num-traits" +version = "0.2.19" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "once_cell_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" + +[[package]] +name = "proc-macro2" +version = "1.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "rust-video" +version = "0.1.0" +dependencies = [ + "anyhow", + "chrono", + "clap", + "rayon", + "serde", + "serde_json", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.20" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "serde" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.143" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "syn" +version = "2.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "wasm-bindgen" 
+version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" +dependencies = [ + "bumpalo", + "log", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "windows-core" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + 
"quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-result" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.53.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" diff --git a/src/chat/utils/rust-video/Cargo.toml b/src/chat/utils/rust-video/Cargo.toml new file mode 100644 index 000000000..4120f6f2c --- /dev/null +++ b/src/chat/utils/rust-video/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "rust-video" +version = "0.1.0" +edition = "2021" +authors = ["VideoAnalysis Team"] +description = "Ultra-fast video keyframe extraction tool in Rust" +license = "GPL-3.0" + +[dependencies] +anyhow = "1.0" +clap = { version = "4.0", features = ["derive"] } +rayon = "1.11" + +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + +chrono = { version = 
"0.4", features = ["serde"] } + +[profile.release] +opt-level = 3 +lto = true +codegen-units = 1 +panic = "abort" +strip = true diff --git a/src/chat/utils/rust-video/README.md b/src/chat/utils/rust-video/README.md new file mode 100644 index 000000000..e4c38b0ab --- /dev/null +++ b/src/chat/utils/rust-video/README.md @@ -0,0 +1,221 @@ +# 🎯 Rust Video Keyframe Extraction API + +高性能视频关键帧提取API服务,基于Rust后端 + Python FastAPI。 + +## 📁 项目结构 + +``` +rust-video/ +├── outputs/ # 关键帧输出目录 +├── src/ # Rust源码 +│ └── main.rs +├── target/ # Rust编译文件 +├── api_server.py # 🚀 主API服务器 (整合版) +├── start_server.py # 生产启动脚本 +├── config.py # 配置管理 +├── config.toml # 配置文件 +├── Cargo.toml # Rust项目配置 +├── Cargo.lock # Rust依赖锁定 +├── .gitignore # Git忽略文件 +└── README.md # 项目文档 +``` + +## 快速开始 + +### 1. 安装依赖 +```bash +pip install fastapi uvicorn python-multipart aiofiles +``` + +### 2. 启动服务 +```bash +# 开发模式 +python api_server.py + +# 生产模式 +python start_server.py --mode prod --port 8050 +``` + +### 3. 访问API +- **服务地址**: http://localhost:8050 +- **API文档**: http://localhost:8050/docs +- **健康检查**: http://localhost:8050/health +- **性能指标**: http://localhost:8050/metrics + +## API使用方法 + +### 主要端点 + +#### 1. 提取关键帧 (JSON响应) +```http +POST /extract-keyframes +Content-Type: multipart/form-data + +- video: 视频文件 (.mp4, .avi, .mov, .mkv) +- scene_threshold: 场景变化阈值 (0.1-1.0, 默认0.3) +- max_frames: 最大关键帧数 (1-200, 默认50) +- resize_width: 调整宽度 (可选, 100-1920) +- time_interval: 时间间隔秒数 (可选, 0.1-60.0) +``` + +#### 2. 提取关键帧 (ZIP下载) +```http +POST /extract-keyframes-zip +Content-Type: multipart/form-data + +参数同上,返回包含所有关键帧的ZIP文件 +``` + +#### 3. 健康检查 +```http +GET /health +``` + +#### 4. 
性能指标 +```http +GET /metrics +``` + +### Python客户端示例 + +```python +import requests + +# 上传视频并提取关键帧 +files = {'video': open('video.mp4', 'rb')} +data = { + 'scene_threshold': 0.3, + 'max_frames': 50, + 'resize_width': 800 +} + +response = requests.post( + 'http://localhost:8050/extract-keyframes', + files=files, + data=data +) + +result = response.json() +print(f"提取了 {result['keyframe_count']} 个关键帧") +print(f"处理时间: {result['performance']['total_api_time']:.2f}秒") +``` + +### JavaScript客户端示例 + +```javascript +const formData = new FormData(); +formData.append('video', videoFile); +formData.append('scene_threshold', '0.3'); +formData.append('max_frames', '50'); + +fetch('http://localhost:8050/extract-keyframes', { + method: 'POST', + body: formData +}) +.then(response => response.json()) +.then(data => { + console.log(`提取了 ${data.keyframe_count} 个关键帧`); + console.log(`处理时间: ${data.performance.total_api_time}秒`); +}); +``` + +### cURL示例 + +```bash +curl -X POST "http://localhost:8050/extract-keyframes" \ + -H "accept: application/json" \ + -H "Content-Type: multipart/form-data" \ + -F "video=@video.mp4" \ + -F "scene_threshold=0.3" \ + -F "max_frames=50" +``` + +## ⚙️ 配置 + +编辑 `config.toml` 文件: + +```toml +[server] +host = "0.0.0.0" +port = 8050 +debug = false + +[processing] +default_scene_threshold = 0.3 +default_max_frames = 50 +timeout_seconds = 300 + +[performance] +async_workers = 4 +max_file_size_mb = 500 +``` + +## 性能特性 + +- **异步I/O**: 文件上传/下载异步处理 +- **多线程处理**: 视频处理在独立线程池 +- **内存优化**: 流式处理,减少内存占用 +- **智能清理**: 自动临时文件管理 +- **性能监控**: 实时处理时间和吞吐量统计 + +总之就是非常快() + +## 响应格式 + +```json +{ + "status": "success", + "processing_time": 4.5, + "output_directory": "/tmp/output_xxx", + "keyframe_count": 15, + "keyframes": [ + "/tmp/output_xxx/frame_001.jpg", + "/tmp/output_xxx/frame_002.jpg" + ], + "performance": { + "file_size_mb": 209.7, + "upload_time": 0.23, + "processing_time": 4.5, + "total_api_time": 4.73, + "upload_speed_mbps": 912.2 + }, + "rust_output": "处理完成", + 
"command": "rust-video input.mp4 output/ --scene-threshold 0.3 --max-frames 50" +} +``` + +## 故障排除 + +### 常见问题 + +1. **Rust binary not found** + ```bash + cargo build # 重新构建Rust项目 + ``` + +2. **端口被占用** + ```bash + # 修改config.toml中的端口号 + port = 8051 + ``` + +3. **内存不足** + ```bash + # 减少max_frames或resize_width参数 + ``` + +### 日志查看 + +服务启动时会显示详细的状态信息,包括: +- Rust二进制文件位置 +- 配置加载状态 +- 服务监听地址 + +## 集成支持 + +本API设计为独立服务,可轻松集成到任何项目中: + +- **AI Bot项目**: 通过HTTP API调用 +- **Web应用**: 直接前端调用或后端代理 +- **移动应用**: REST API标准接口 +- **批处理脚本**: Python/Shell脚本调用 diff --git a/src/chat/utils/rust-video/api_server.py b/src/chat/utils/rust-video/api_server.py new file mode 100644 index 000000000..aeb3fa248 --- /dev/null +++ b/src/chat/utils/rust-video/api_server.py @@ -0,0 +1,472 @@ +#!/usr/bin/env python3 +""" +Rust Video Keyframe Extraction API Server +高性能视频关键帧提取API服务 + +功能: +- 视频上传和关键帧提取 +- 异步多线程处理 +- 性能监控和健康检查 +- 自动资源清理 + +启动: python api_server.py +地址: http://localhost:8050 +""" + +import os +import json +import subprocess +import tempfile +import zipfile +import shutil +import asyncio +import time +import logging +from datetime import datetime +from pathlib import Path +from typing import Optional, List, Dict, Any + +import uvicorn +from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks +from fastapi.responses import FileResponse, JSONResponse +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field + +# 导入配置管理 +from config import config + +# 配置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# ============================================================================ +# 内置视频处理器 (整合版) +# ============================================================================ + +class VideoKeyframeExtractor: + """整合的视频关键帧提取器""" + + def __init__(self, rust_binary_path: Optional[str] = None): + self.rust_binary_path = rust_binary_path or self._find_rust_binary() + if not self.rust_binary_path or not 
Path(self.rust_binary_path).exists(): + raise FileNotFoundError(f"Rust binary not found: {self.rust_binary_path}") + + def _find_rust_binary(self) -> str: + """查找Rust二进制文件""" + possible_paths = [ + "./target/debug/rust-video.exe", + "./target/release/rust-video.exe", + "./target/debug/rust-video", + "./target/release/rust-video" + ] + + for path in possible_paths: + if Path(path).exists(): + return str(Path(path).absolute()) + + # 尝试构建 + try: + subprocess.run(["cargo", "build"], check=True, capture_output=True) + for path in possible_paths: + if Path(path).exists(): + return str(Path(path).absolute()) + except subprocess.CalledProcessError: + pass + + raise FileNotFoundError("Rust binary not found and build failed") + + def process_video( + self, + video_path: str, + output_dir: str = "outputs", + scene_threshold: float = 0.3, + max_frames: int = 50, + resize_width: Optional[int] = None, + time_interval: Optional[float] = None + ) -> Dict[str, Any]: + """处理视频提取关键帧""" + + video_path = Path(video_path) + if not video_path.exists(): + raise FileNotFoundError(f"Video file not found: {video_path}") + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # 构建命令 + cmd = [self.rust_binary_path, str(video_path), str(output_dir)] + cmd.extend(["--scene-threshold", str(scene_threshold)]) + cmd.extend(["--max-frames", str(max_frames)]) + + if resize_width: + cmd.extend(["--resize-width", str(resize_width)]) + if time_interval: + cmd.extend(["--time-interval", str(time_interval)]) + + # 执行处理 + start_time = time.time() + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True, + timeout=300 # 5分钟超时 + ) + + processing_time = time.time() - start_time + + # 解析输出 + output_files = list(output_dir.glob("*.jpg")) + + return { + "status": "success", + "processing_time": processing_time, + "output_directory": str(output_dir), + "keyframe_count": len(output_files), + "keyframes": [str(f) for f in output_files], + "rust_output": 
result.stdout, + "command": " ".join(cmd) + } + + except subprocess.TimeoutExpired: + raise HTTPException(status_code=408, detail="Video processing timeout") + except subprocess.CalledProcessError as e: + raise HTTPException( + status_code=500, + detail=f"Video processing failed: {e.stderr}" + ) + +# ============================================================================ +# 异步处理器 (整合版) +# ============================================================================ + +class AsyncVideoProcessor: + """高性能异步视频处理器""" + + def __init__(self): + self.extractor = VideoKeyframeExtractor() + + async def process_video_async( + self, + upload_file: UploadFile, + processing_params: Dict[str, Any] + ) -> Dict[str, Any]: + """异步视频处理主流程""" + + start_time = time.time() + + # 1. 异步保存上传文件 + upload_start = time.time() + temp_fd, temp_path_str = tempfile.mkstemp(suffix='.mp4') + temp_path = Path(temp_path_str) + + try: + os.close(temp_fd) + + # 异步读取并保存文件 + content = await upload_file.read() + with open(temp_path, 'wb') as f: + f.write(content) + + upload_time = time.time() - upload_start + file_size = len(content) + + # 2. 
多线程处理视频 + process_start = time.time() + temp_output_dir = tempfile.mkdtemp() + output_path = Path(temp_output_dir) + + try: + # 在线程池中异步处理 + loop = asyncio.get_event_loop() + result = await loop.run_in_executor( + None, + self._process_video_sync, + str(temp_path), + str(output_path), + processing_params + ) + + process_time = time.time() - process_start + total_time = time.time() - start_time + + # 添加性能指标 + result.update({ + 'performance': { + 'file_size_mb': file_size / (1024 * 1024), + 'upload_time': upload_time, + 'processing_time': process_time, + 'total_api_time': total_time, + 'upload_speed_mbps': (file_size / (1024 * 1024)) / upload_time if upload_time > 0 else 0 + } + }) + + return result + + finally: + # 清理输出目录 + try: + shutil.rmtree(temp_output_dir, ignore_errors=True) + except Exception as e: + logger.warning(f"Failed to cleanup output directory: {e}") + + finally: + # 清理临时文件 + try: + if temp_path.exists(): + temp_path.unlink() + except Exception as e: + logger.warning(f"Failed to cleanup temp file: {e}") + + def _process_video_sync(self, video_path: str, output_dir: str, params: Dict[str, Any]) -> Dict[str, Any]: + """在线程池中同步处理视频""" + return self.extractor.process_video( + video_path=video_path, + output_dir=output_dir, + **params + ) + +# ============================================================================ +# FastAPI 应用初始化 +# ============================================================================ + +app = FastAPI( + title="Rust Video Keyframe API", + description="高性能视频关键帧提取API服务", + version="2.0.0", + docs_url="/docs", + redoc_url="/redoc" +) + +# CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 全局处理器实例 +video_processor = AsyncVideoProcessor() + +# 简单的统计 +stats = { + "total_requests": 0, + "processing_times": [], + "start_time": datetime.now() +} + +# ============================================================================ +# API 路由 +# 
============================================================================ + +@app.get("/", response_class=JSONResponse) +async def root(): + """API根路径""" + return { + "message": "Rust Video Keyframe Extraction API", + "version": "2.0.0", + "status": "ready", + "docs": "/docs", + "health": "/health", + "metrics": "/metrics" + } + +@app.get("/health") +async def health_check(): + """健康检查端点""" + try: + # 检查Rust二进制 + rust_binary = video_processor.extractor.rust_binary_path + rust_status = "ok" if Path(rust_binary).exists() else "missing" + + return { + "status": rust_status, + "timestamp": datetime.now().isoformat(), + "version": "2.0.0", + "rust_binary": rust_binary + } + except Exception as e: + raise HTTPException(status_code=503, detail=f"Health check failed: {str(e)}") + +@app.get("/metrics") +async def get_metrics(): + """获取性能指标""" + avg_time = sum(stats["processing_times"]) / len(stats["processing_times"]) if stats["processing_times"] else 0 + uptime = (datetime.now() - stats["start_time"]).total_seconds() + + return { + "total_requests": stats["total_requests"], + "average_processing_time": avg_time, + "last_24h_requests": stats["total_requests"], # 简化版本 + "system_info": { + "uptime_seconds": uptime, + "memory_usage": "N/A", # 可以扩展 + "cpu_usage": "N/A" + } + } + +@app.post("/extract-keyframes") +async def extract_keyframes( + video: UploadFile = File(..., description="视频文件"), + scene_threshold: float = Form(0.3, description="场景变化阈值"), + max_frames: int = Form(50, description="最大关键帧数量"), + resize_width: Optional[int] = Form(None, description="调整宽度"), + time_interval: Optional[float] = Form(None, description="时间间隔") +): + """提取视频关键帧 (主要API端点)""" + + # 参数验证 + if not video.filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')): + raise HTTPException(status_code=400, detail="不支持的视频格式") + + # 更新统计 + stats["total_requests"] += 1 + + try: + # 构建处理参数 + params = { + "scene_threshold": scene_threshold, + "max_frames": max_frames + } + if resize_width: + 
params["resize_width"] = resize_width + if time_interval: + params["time_interval"] = time_interval + + # 异步处理 + start_time = time.time() + result = await video_processor.process_video_async(video, params) + processing_time = time.time() - start_time + + # 更新统计 + stats["processing_times"].append(processing_time) + if len(stats["processing_times"]) > 100: # 保持最近100次记录 + stats["processing_times"] = stats["processing_times"][-100:] + + return JSONResponse(content=result) + + except Exception as e: + logger.error(f"Processing failed: {str(e)}") + raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}") + +@app.post("/extract-keyframes-zip") +async def extract_keyframes_zip( + video: UploadFile = File(...), + scene_threshold: float = Form(0.3), + max_frames: int = Form(50), + resize_width: Optional[int] = Form(None), + time_interval: Optional[float] = Form(None) +): + """提取关键帧并返回ZIP文件""" + + # 验证文件类型 + if not video.filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')): + raise HTTPException(status_code=400, detail="不支持的视频格式") + + # 创建临时目录 + temp_input_fd, temp_input_path = tempfile.mkstemp(suffix='.mp4') + temp_output_dir = tempfile.mkdtemp() + + try: + os.close(temp_input_fd) + + # 保存上传的视频 + content = await video.read() + with open(temp_input_path, 'wb') as f: + f.write(content) + + # 处理参数 + params = { + "scene_threshold": scene_threshold, + "max_frames": max_frames + } + if resize_width: + params["resize_width"] = resize_width + if time_interval: + params["time_interval"] = time_interval + + # 处理视频 + result = video_processor.extractor.process_video( + video_path=temp_input_path, + output_dir=temp_output_dir, + **params + ) + + # 创建ZIP文件 + zip_fd, zip_path = tempfile.mkstemp(suffix='.zip') + os.close(zip_fd) + + with zipfile.ZipFile(zip_path, 'w') as zip_file: + # 添加关键帧图片 + for keyframe_path in result.get("keyframes", []): + if Path(keyframe_path).exists(): + zip_file.write(keyframe_path, Path(keyframe_path).name) + + # 添加处理信息 + info_content = 
json.dumps(result, indent=2, ensure_ascii=False) + zip_file.writestr("processing_info.json", info_content) + + # 返回ZIP文件 + return FileResponse( + zip_path, + media_type='application/zip', + filename=f"keyframes_{video.filename}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip" + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}") + + finally: + # 清理临时文件 + for path in [temp_input_path, temp_output_dir]: + try: + if Path(path).is_file(): + Path(path).unlink() + elif Path(path).is_dir(): + shutil.rmtree(path, ignore_errors=True) + except Exception: + pass + +# ============================================================================ +# 应用启动 +# ============================================================================ + +def main(): + """启动API服务器""" + + # 获取配置 + server_config = config.get('server') + host = server_config.get('host', '0.0.0.0') + port = server_config.get('port', 8050) + + print(f""" +Rust Video Keyframe Extraction API +===================================== +地址: http://{host}:{port} +文档: http://{host}:{port}/docs +健康检查: http://{host}:{port}/health +性能指标: http://{host}:{port}/metrics +===================================== + """) + + # 检查Rust二进制 + try: + rust_binary = video_processor.extractor.rust_binary_path + print(f"✓ Rust binary: {rust_binary}") + except Exception as e: + print(f"⚠️ Rust binary check failed: {e}") + + # 启动服务器 + uvicorn.run( + "api_server:app", + host=host, + port=port, + reload=False, # 生产环境关闭热重载 + access_log=True + ) + +if __name__ == "__main__": + main() diff --git a/src/chat/utils/rust-video/config.py b/src/chat/utils/rust-video/config.py new file mode 100644 index 000000000..c85b8f9ea --- /dev/null +++ b/src/chat/utils/rust-video/config.py @@ -0,0 +1,115 @@ +""" +配置管理模块 +处理 config.toml 文件的读取和管理 +""" + +import os +from pathlib import Path +from typing import Dict, Any + +try: + import toml +except ImportError: + print("⚠️ 需要安装 toml: pip install toml") + # 提供基础配置作为后备 + toml = None + 
+class ConfigManager: + """配置管理器""" + + def __init__(self, config_file: str = "config.toml"): + self.config_file = Path(config_file) + self._config = self._load_config() + + def _load_config(self) -> Dict[str, Any]: + """加载配置文件""" + if toml is None or not self.config_file.exists(): + return self._get_default_config() + + try: + with open(self.config_file, 'r', encoding='utf-8') as f: + return toml.load(f) + except Exception as e: + print(f"⚠️ 配置文件读取失败: {e}") + return self._get_default_config() + + def _get_default_config(self) -> Dict[str, Any]: + """默认配置""" + return { + "server": { + "host": "0.0.0.0", + "port": 8000, + "workers": 1, + "reload": False, + "log_level": "info" + }, + "api": { + "title": "Video Keyframe Extraction API", + "description": "高性能视频关键帧提取服务", + "version": "1.0.0", + "max_file_size": "100MB" + }, + "processing": { + "default_threshold": 0.3, + "default_output_format": "png", + "max_frames": 10000, + "temp_dir": "temp", + "upload_dir": "uploads", + "output_dir": "outputs" + }, + "rust": { + "executable_name": "video_keyframe_extractor", + "executable_path": "target/release" + }, + "ffmpeg": { + "auto_detect": True, + "custom_path": "", + "timeout": 300 + }, + "storage": { + "cleanup_interval": 3600, + "max_storage_size": "10GB", + "result_retention_days": 7 + }, + "monitoring": { + "enable_metrics": True, + "enable_logging": True, + "log_file": "logs/api.log", + "max_log_size": "100MB" + }, + "security": { + "allowed_origins": ["*"], + "max_concurrent_tasks": 10, + "rate_limit_per_minute": 60 + }, + "development": { + "debug": False, + "auto_reload": False, + "cors_enabled": True + } + } + + def get(self, section: str, key: str = None, default=None): + """获取配置值""" + if key is None: + return self._config.get(section, default) + return self._config.get(section, {}).get(key, default) + + def get_server_config(self): + """获取服务器配置""" + return self.get("server") + + def get_api_config(self): + """获取API配置""" + return self.get("api") + + def 
get_processing_config(self): + """获取处理配置""" + return self.get("processing") + + def reload(self): + """重新加载配置""" + self._config = self._load_config() + +# 全局配置实例 +config = ConfigManager() diff --git a/src/chat/utils/rust-video/config.toml b/src/chat/utils/rust-video/config.toml new file mode 100644 index 000000000..56e7799cc --- /dev/null +++ b/src/chat/utils/rust-video/config.toml @@ -0,0 +1,70 @@ +# 🔧 Video Keyframe Extraction API 配置文件 + +[server] +# 服务器配置 +host = "0.0.0.0" +port = 8050 +workers = 1 +reload = false +log_level = "info" + +[api] +# API 基础配置 +title = "Video Keyframe Extraction API" +description = "视频关键帧提取服务" +version = "1.0.0" +max_file_size = "100MB" # 最大文件大小 + +[processing] +# 视频处理配置 +default_threshold = 0.3 +default_output_format = "png" +max_frames = 10000 +temp_dir = "temp" +upload_dir = "uploads" +output_dir = "outputs" + +[rust] +# Rust 程序配置 +executable_name = "video_keyframe_extractor" +executable_path = "target/release" # 相对路径,自动检测 + +[ffmpeg] +# FFmpeg 配置 +auto_detect = true +custom_path = "" # 留空则自动检测 +timeout = 300 # 秒 + +[performance] +# 性能优化配置 +async_workers = 4 # 异步文件处理工作线程数 +upload_chunk_size = 8192 # 上传块大小 (字节) +max_concurrent_uploads = 10 # 最大并发上传数 +compression_level = 1 # ZIP 压缩级别 (0-9, 1=快速) +stream_chunk_size = 8192 # 流式响应块大小 +enable_performance_metrics = true # 启用性能监控 + +[storage] +# 存储配置 +cleanup_interval = 3600 # 清理间隔(秒) +max_storage_size = "10GB" +result_retention_days = 7 + +[monitoring] +# 监控配置 +enable_metrics = true +enable_logging = true +log_file = "logs/api.log" +max_log_size = "100MB" + +[security] +# 安全配置 +allowed_origins = ["*"] +max_concurrent_tasks = 10 +rate_limit_per_minute = 60 + +[development] +# 开发环境配置 +debug = false +auto_reload = false +cors_enabled = true diff --git a/src/chat/utils/rust-video/src/main.rs b/src/chat/utils/rust-video/src/main.rs new file mode 100644 index 000000000..13fd98cbb --- /dev/null +++ b/src/chat/utils/rust-video/src/main.rs @@ -0,0 +1,710 @@ +//! 
# Rust Video Keyframe Extractor +//! +//! Ultra-fast video keyframe extraction tool with SIMD optimization. +//! +//! ## Features +//! - AVX2/SSE2 SIMD optimization for maximum performance +//! - Memory-efficient streaming processing with FFmpeg +//! - Multi-threaded parallel processing +//! - Release-optimized for production use +//! +//! ## Performance +//! - 150+ FPS processing speed +//! - Real-time video analysis capability +//! - Minimal memory footprint +//! +//! ## Usage +//! ```bash +//! # Single video processing +//! rust-video --input video.mp4 --output ./keyframes --threshold 2.0 +//! +//! # Benchmark mode +//! rust-video --benchmark --input video.mp4 --output ./results +//! ``` + +use anyhow::{Context, Result}; +use chrono::prelude::*; +use clap::Parser; +use rayon::prelude::*; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::io::{BufReader, Read}; +use std::path::PathBuf; +use std::process::{Command, Stdio}; +use std::time::Instant; + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +/// Ultra-fast video keyframe extraction tool +#[derive(Parser)] +#[command(name = "rust-video")] +#[command(version = "0.1.0")] +#[command(about = "Ultra-fast video keyframe extraction with SIMD optimization")] +#[command(long_about = None)] +struct Args { + /// Input video file path + #[arg(short, long, help = "Path to the input video file")] + input: Option, + + /// Output directory for keyframes and results + #[arg(short, long, default_value = "./output", help = "Output directory")] + output: PathBuf, + + /// Change threshold for keyframe detection (higher = fewer keyframes) + #[arg(short, long, default_value = "2.0", help = "Keyframe detection threshold")] + threshold: f64, + + /// Number of parallel threads (0 = auto-detect) + #[arg(short = 'j', long, default_value = "0", help = "Number of threads")] + threads: usize, + + /// Maximum number of keyframes to save (0 = save all) + #[arg(short, long, default_value = "50", help = "Maximum 
keyframes to save")] + max_save: usize, + + /// Run performance benchmark suite + #[arg(long, help = "Run comprehensive benchmark tests")] + benchmark: bool, + + /// Maximum frames to process (0 = process all frames) + #[arg(long, default_value = "0", help = "Limit number of frames to process")] + max_frames: usize, + + /// FFmpeg executable path + #[arg(long, default_value = "ffmpeg", help = "Path to FFmpeg executable")] + ffmpeg_path: PathBuf, + + /// Enable SIMD optimizations (AVX2/SSE2) + #[arg(long, default_value = "true", help = "Enable SIMD optimizations")] + use_simd: bool, + + /// Processing block size for cache optimization + #[arg(long, default_value = "8192", help = "Block size for processing")] + block_size: usize, + + /// Verbose output + #[arg(short, long, help = "Enable verbose output")] + verbose: bool, +} + +/// Video frame representation optimized for SIMD processing +#[derive(Debug, Clone)] +struct VideoFrame { + frame_number: usize, + width: usize, + height: usize, + data: Vec, // Grayscale data, aligned for SIMD +} + +impl VideoFrame { + /// Create a new video frame with SIMD-aligned data + fn new(frame_number: usize, width: usize, height: usize, mut data: Vec) -> Self { + // Ensure data length is multiple of 32 for AVX2 processing + let remainder = data.len() % 32; + if remainder != 0 { + data.resize(data.len() + (32 - remainder), 0); + } + + Self { + frame_number, + width, + height, + data, + } + } + + /// Calculate frame difference using parallel SIMD processing + fn calculate_difference_parallel_simd(&self, other: &VideoFrame, block_size: usize, use_simd: bool) -> f64 { + if self.width != other.width || self.height != other.height { + return f64::MAX; + } + + let total_pixels = self.width * self.height; + let num_blocks = (total_pixels + block_size - 1) / block_size; + + let total_diff: u64 = (0..num_blocks) + .into_par_iter() + .map(|block_idx| { + let start = block_idx * block_size; + let end = ((block_idx + 1) * 
block_size).min(total_pixels); + let block_len = end - start; + + if use_simd { + #[cfg(target_arch = "x86_64")] + { + unsafe { + if std::arch::is_x86_feature_detected!("avx2") { + return self.calculate_difference_avx2_block(&other.data, start, block_len); + } else if std::arch::is_x86_feature_detected!("sse2") { + return self.calculate_difference_sse2_block(&other.data, start, block_len); + } + } + } + } + + // Fallback scalar implementation + self.data[start..end] + .iter() + .zip(other.data[start..end].iter()) + .map(|(a, b)| (*a as i32 - *b as i32).abs() as u64) + .sum() + }) + .sum(); + + total_diff as f64 / total_pixels as f64 + } + + /// Standard frame difference calculation (non-SIMD) + fn calculate_difference_standard(&self, other: &VideoFrame) -> f64 { + if self.width != other.width || self.height != other.height { + return f64::MAX; + } + + let len = self.width * self.height; + let total_diff: u64 = self.data[..len] + .iter() + .zip(other.data[..len].iter()) + .map(|(a, b)| (*a as i32 - *b as i32).abs() as u64) + .sum(); + + total_diff as f64 / len as f64 + } + + /// AVX2 optimized block processing + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "avx2")] + unsafe fn calculate_difference_avx2_block(&self, other_data: &[u8], start: usize, len: usize) -> u64 { + let mut total_diff = 0u64; + let chunks = len / 32; + + for i in 0..chunks { + let offset = start + i * 32; + + let a = _mm256_loadu_si256(self.data.as_ptr().add(offset) as *const __m256i); + let b = _mm256_loadu_si256(other_data.as_ptr().add(offset) as *const __m256i); + + let diff = _mm256_sad_epu8(a, b); + let result = _mm256_extract_epi64(diff, 0) as u64 + + _mm256_extract_epi64(diff, 1) as u64 + + _mm256_extract_epi64(diff, 2) as u64 + + _mm256_extract_epi64(diff, 3) as u64; + + total_diff += result; + } + + // Process remaining bytes + for i in (start + chunks * 32)..(start + len) { + total_diff += (self.data[i] as i32 - other_data[i] as i32).abs() as u64; + } + + total_diff + } + 
+ /// SSE2 optimized block processing + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "sse2")] + unsafe fn calculate_difference_sse2_block(&self, other_data: &[u8], start: usize, len: usize) -> u64 { + let mut total_diff = 0u64; + let chunks = len / 16; + + for i in 0..chunks { + let offset = start + i * 16; + + let a = _mm_loadu_si128(self.data.as_ptr().add(offset) as *const __m128i); + let b = _mm_loadu_si128(other_data.as_ptr().add(offset) as *const __m128i); + + let diff = _mm_sad_epu8(a, b); + let result = _mm_extract_epi64(diff, 0) as u64 + _mm_extract_epi64(diff, 1) as u64; + + total_diff += result; + } + + // Process remaining bytes + for i in (start + chunks * 16)..(start + len) { + total_diff += (self.data[i] as i32 - other_data[i] as i32).abs() as u64; + } + + total_diff + } +} + +/// Performance measurement results +#[derive(Debug, Clone, Serialize, Deserialize)] +struct PerformanceResult { + test_name: String, + video_file: String, + total_time_ms: f64, + frame_extraction_time_ms: f64, + keyframe_analysis_time_ms: f64, + total_frames: usize, + keyframes_extracted: usize, + keyframe_ratio: f64, + processing_fps: f64, + threshold: f64, + optimization_type: String, + simd_enabled: bool, + threads_used: usize, + timestamp: String, +} + +/// Extract video frames using FFmpeg memory streaming +fn extract_frames_memory_stream( + video_path: &PathBuf, + ffmpeg_path: &PathBuf, + max_frames: usize, + verbose: bool, +) -> Result<(Vec, usize, usize)> { + if verbose { + println!("🎬 Extracting frames using FFmpeg memory streaming..."); + println!("📁 Video: {}", video_path.display()); + } + + // Get video information + let probe_output = Command::new(ffmpeg_path) + .args(["-i", video_path.to_str().unwrap(), "-hide_banner"]) + .output() + .context("Failed to probe video with FFmpeg")?; + + let probe_info = String::from_utf8_lossy(&probe_output.stderr); + let (width, height) = parse_video_dimensions(&probe_info) + .ok_or_else(|| anyhow::anyhow!("Cannot 
parse video dimensions"))?; + + if verbose { + println!("📐 Video dimensions: {}x{}", width, height); + } + + // Build optimized FFmpeg command + let mut cmd = Command::new(ffmpeg_path); + cmd.args([ + "-i", video_path.to_str().unwrap(), + "-f", "rawvideo", + "-pix_fmt", "gray", + "-an", // No audio + "-threads", "0", // Auto-detect threads + "-preset", "ultrafast", // Fastest preset + ]); + + if max_frames > 0 { + cmd.args(["-frames:v", &max_frames.to_string()]); + } + + cmd.args(["-"]).stdout(Stdio::piped()).stderr(Stdio::null()); + + let start_time = Instant::now(); + let mut child = cmd.spawn().context("Failed to spawn FFmpeg process")?; + let stdout = child.stdout.take().unwrap(); + let mut reader = BufReader::with_capacity(1024 * 1024, stdout); // 1MB buffer + + let frame_size = width * height; + let mut frames = Vec::new(); + let mut frame_count = 0; + let mut frame_buffer = vec![0u8; frame_size]; + + if verbose { + println!("📦 Frame size: {} bytes", frame_size); + } + + // Stream frame data directly into memory + loop { + match reader.read_exact(&mut frame_buffer) { + Ok(()) => { + frames.push(VideoFrame::new( + frame_count, + width, + height, + frame_buffer.clone(), + )); + frame_count += 1; + + if verbose && frame_count % 200 == 0 { + print!("\r⚡ Frames processed: {}", frame_count); + } + + if max_frames > 0 && frame_count >= max_frames { + break; + } + } + Err(_) => break, // End of stream + } + } + + let _ = child.wait(); + + if verbose { + println!("\r✅ Frame extraction complete: {} frames in {:.2}s", + frame_count, start_time.elapsed().as_secs_f64()); + } + + Ok((frames, width, height)) +} + +/// Parse video dimensions from FFmpeg probe output +fn parse_video_dimensions(probe_info: &str) -> Option<(usize, usize)> { + for line in probe_info.lines() { + if line.contains("Video:") && line.contains("x") { + for part in line.split_whitespace() { + if let Some(x_pos) = part.find('x') { + let width_str = &part[..x_pos]; + let height_part = &part[x_pos + 1..]; 
+ let height_str = height_part.split(',').next().unwrap_or(height_part); + + if let (Ok(width), Ok(height)) = (width_str.parse::<usize>(), height_str.parse::<usize>()) { + return Some((width, height)); + } + } + } + } + } + None +} + +/// Extract keyframes using optimized algorithms +fn extract_keyframes_optimized( + frames: &[VideoFrame], + threshold: f64, + use_simd: bool, + block_size: usize, + verbose: bool, +) -> Result<Vec<usize>> { + if frames.len() < 2 { + return Ok(Vec::new()); + } + + let optimization_name = if use_simd { "SIMD+Parallel" } else { "Standard Parallel" }; + if verbose { + println!("🚀 Keyframe analysis (threshold: {}, optimization: {})", threshold, optimization_name); + } + + let start_time = Instant::now(); + + // Parallel computation of frame differences + let differences: Vec<f64> = frames + .par_windows(2) + .map(|pair| { + if use_simd { + pair[0].calculate_difference_parallel_simd(&pair[1], block_size, true) + } else { + pair[0].calculate_difference_standard(&pair[1]) + } + }) + .collect(); + + // Find keyframes based on threshold + let keyframe_indices: Vec<usize> = differences + .par_iter() + .enumerate() + .filter_map(|(i, &diff)| { + if diff > threshold { + Some(i + 1) + } else { + None + } + }) + .collect(); + + if verbose { + println!("⚡ Analysis complete in {:.2}s", start_time.elapsed().as_secs_f64()); + println!("🎯 Found {} keyframes", keyframe_indices.len()); + } + + Ok(keyframe_indices) +} + +/// Save keyframes as JPEG images using FFmpeg +fn save_keyframes_optimized( + video_path: &PathBuf, + keyframe_indices: &[usize], + output_dir: &PathBuf, + ffmpeg_path: &PathBuf, + max_save: usize, + verbose: bool, +) -> Result<usize> { + if keyframe_indices.is_empty() { + if verbose { + println!("⚠️ No keyframes to save"); + } + return Ok(0); + } + + if verbose { + println!("💾 Saving keyframes..."); + } + + fs::create_dir_all(output_dir).context("Failed to create output directory")?; + + let save_count = keyframe_indices.len().min(max_save); + let mut saved = 0; + + for (i, 
&frame_idx) in keyframe_indices.iter().take(save_count).enumerate() { + let output_path = output_dir.join(format!("keyframe_{:03}.jpg", i + 1)); + let timestamp = frame_idx as f64 / 30.0; // Assume 30 FPS + + let output = Command::new(ffmpeg_path) + .args([ + "-i", video_path.to_str().unwrap(), + "-ss", &timestamp.to_string(), + "-vframes", "1", + "-q:v", "2", // High quality + "-y", + output_path.to_str().unwrap(), + ]) + .output() + .context("Failed to extract keyframe with FFmpeg")?; + + if output.status.success() { + saved += 1; + if verbose && (saved % 10 == 0 || saved == save_count) { + print!("\r💾 Saved: {}/{} keyframes", saved, save_count); + } + } else if verbose { + eprintln!("⚠️ Failed to save keyframe {}", frame_idx); + } + } + + if verbose { + println!("\r✅ Keyframe saving complete: {}/{}", saved, save_count); + } + + Ok(saved) +} + +/// Run performance test +fn run_performance_test( + video_path: &PathBuf, + threshold: f64, + test_name: &str, + ffmpeg_path: &PathBuf, + max_frames: usize, + use_simd: bool, + block_size: usize, + verbose: bool, +) -> Result<PerformanceResult> { + if verbose { + println!("\n{}", "=".repeat(60)); + println!("⚡ Running test: {}", test_name); + println!("{}", "=".repeat(60)); + } + + let total_start = Instant::now(); + + // Frame extraction + let extraction_start = Instant::now(); + let (frames, _width, _height) = extract_frames_memory_stream(video_path, ffmpeg_path, max_frames, verbose)?; + let extraction_time = extraction_start.elapsed().as_secs_f64() * 1000.0; + + // Keyframe analysis + let analysis_start = Instant::now(); + let keyframe_indices = extract_keyframes_optimized(&frames, threshold, use_simd, block_size, verbose)?; + let analysis_time = analysis_start.elapsed().as_secs_f64() * 1000.0; + + let total_time = total_start.elapsed().as_secs_f64() * 1000.0; + + let optimization_type = if use_simd { + format!("SIMD+Parallel(block:{})", block_size) + } else { + "Standard Parallel".to_string() + }; + + let result = PerformanceResult { 
test_name: test_name.to_string(), + video_file: video_path.file_name().unwrap().to_string_lossy().to_string(), + total_time_ms: total_time, + frame_extraction_time_ms: extraction_time, + keyframe_analysis_time_ms: analysis_time, + total_frames: frames.len(), + keyframes_extracted: keyframe_indices.len(), + keyframe_ratio: keyframe_indices.len() as f64 / frames.len() as f64 * 100.0, + processing_fps: frames.len() as f64 / (total_time / 1000.0), + threshold, + optimization_type, + simd_enabled: use_simd, + threads_used: rayon::current_num_threads(), + timestamp: Local::now().format("%Y-%m-%d %H:%M:%S").to_string(), + }; + + if verbose { + println!("\n⚡ Test Results:"); + println!(" 🕐 Total time: {:.2}ms ({:.2}s)", result.total_time_ms, result.total_time_ms / 1000.0); + println!(" 📥 Extraction: {:.2}ms ({:.1}%)", result.frame_extraction_time_ms, + result.frame_extraction_time_ms / result.total_time_ms * 100.0); + println!(" 🧮 Analysis: {:.2}ms ({:.1}%)", result.keyframe_analysis_time_ms, + result.keyframe_analysis_time_ms / result.total_time_ms * 100.0); + println!(" 📊 Frames: {}", result.total_frames); + println!(" 🎯 Keyframes: {}", result.keyframes_extracted); + println!(" 🚀 Speed: {:.1} FPS", result.processing_fps); + println!(" ⚙️ Optimization: {}", result.optimization_type); + } + + Ok(result) +} + +/// Run comprehensive benchmark suite +fn run_benchmark_suite(video_path: &PathBuf, output_dir: &PathBuf, ffmpeg_path: &PathBuf, args: &Args) -> Result<()> { + println!("🚀 Rust Video Keyframe Extractor - Benchmark Suite"); + println!("🕐 Time: {}", Local::now().format("%Y-%m-%d %H:%M:%S")); + println!("🎬 Video: {}", video_path.display()); + println!("🧵 Threads: {}", rayon::current_num_threads()); + + // CPU feature detection + #[cfg(target_arch = "x86_64")] + { + println!("🔧 CPU Features:"); + if std::arch::is_x86_feature_detected!("avx2") { + println!(" ✅ AVX2 supported"); + } else if std::arch::is_x86_feature_detected!("sse2") { + println!(" ✅ SSE2 supported"); + } 
else { + println!(" ⚠️ Scalar only"); + } + } + + let test_configs = vec![ + ("Standard Parallel", false, 8192), + ("SIMD 8K blocks", true, 8192), + ("SIMD 16K blocks", true, 16384), + ("SIMD 32K blocks", true, 32768), + ]; + + let mut results = Vec::new(); + + for (test_name, use_simd, block_size) in test_configs { + match run_performance_test( + video_path, + args.threshold, + test_name, + ffmpeg_path, + 1000, // Test with 1000 frames + use_simd, + block_size, + args.verbose, + ) { + Ok(result) => results.push(result), + Err(e) => println!("❌ Test failed {}: {:?}", test_name, e), + } + } + + // Performance comparison table + println!("\n{}", "=".repeat(120)); + println!("🏆 Benchmark Results"); + println!("{}", "=".repeat(120)); + + println!("{:<20} {:<15} {:<12} {:<12} {:<12} {:<8} {:<8} {:<12} {:<20}", + "Test", "Total(ms)", "Extract(ms)", "Analyze(ms)", "Speed(FPS)", "Frames", "Keyframes", "Threads", "Optimization"); + println!("{}", "-".repeat(120)); + + for result in &results { + println!("{:<20} {:<15.1} {:<12.1} {:<12.1} {:<12.1} {:<8} {:<8} {:<12} {:<20}", + result.test_name, + result.total_time_ms, + result.frame_extraction_time_ms, + result.keyframe_analysis_time_ms, + result.processing_fps, + result.total_frames, + result.keyframes_extracted, + result.threads_used, + result.optimization_type); + } + + // Find best performance + if let Some(best_result) = results.iter().max_by(|a, b| a.processing_fps.partial_cmp(&b.processing_fps).unwrap()) { + println!("\n🏆 Best Performance: {}", best_result.test_name); + println!(" ⚡ Speed: {:.1} FPS", best_result.processing_fps); + println!(" 🕐 Time: {:.2}s", best_result.total_time_ms / 1000.0); + println!(" 🧮 Analysis: {:.2}s", best_result.keyframe_analysis_time_ms / 1000.0); + println!(" ⚙️ Tech: {}", best_result.optimization_type); + } + + // Save detailed results + fs::create_dir_all(output_dir).context("Failed to create output directory")?; + let timestamp = Local::now().format("%Y%m%d_%H%M%S").to_string(); + let 
results_file = output_dir.join(format!("benchmark_results_{}.json", timestamp)); + + let json_results = serde_json::to_string_pretty(&results)?; + fs::write(&results_file, json_results)?; + + println!("\n📄 Detailed results saved to: {}", results_file.display()); + println!("{}", "=".repeat(120)); + + Ok(()) +} + +fn main() -> Result<()> { + let args = Args::parse(); + + // Setup thread pool + if args.threads > 0 { + rayon::ThreadPoolBuilder::new() + .num_threads(args.threads) + .build_global() + .context("Failed to set thread pool")?; + } + + println!("🚀 Rust Video Keyframe Extractor v0.1.0"); + println!("🧵 Threads: {}", rayon::current_num_threads()); + + // Verify FFmpeg availability + if !args.ffmpeg_path.exists() && args.ffmpeg_path.to_str() == Some("ffmpeg") { + // Try to find ffmpeg in PATH + if Command::new("ffmpeg").arg("-version").output().is_err() { + anyhow::bail!("FFmpeg not found. Please install FFmpeg or specify path with --ffmpeg-path"); + } + } else if !args.ffmpeg_path.exists() { + anyhow::bail!("FFmpeg not found at: {}", args.ffmpeg_path.display()); + } + + if args.benchmark { + // Benchmark mode + let video_path = args.input.clone() + .ok_or_else(|| anyhow::anyhow!("Benchmark requires input video file --input "))?; + + if !video_path.exists() { + anyhow::bail!("Video file not found: {}", video_path.display()); + } + + run_benchmark_suite(&video_path, &args.output, &args.ffmpeg_path, &args)?; + } else { + // Single processing mode + let video_path = args.input + .ok_or_else(|| anyhow::anyhow!("Please specify input video file --input "))?; + + if !video_path.exists() { + anyhow::bail!("Video file not found: {}", video_path.display()); + } + + // Run single keyframe extraction + let result = run_performance_test( + &video_path, + args.threshold, + "Single Processing", + &args.ffmpeg_path, + args.max_frames, + args.use_simd, + args.block_size, + args.verbose, + )?; + + // Extract and save keyframes + let (frames, _, _) = 
extract_frames_memory_stream(&video_path, &args.ffmpeg_path, args.max_frames, args.verbose)?; + let keyframe_indices = extract_keyframes_optimized(&frames, args.threshold, args.use_simd, args.block_size, args.verbose)?; + let saved_count = save_keyframes_optimized(&video_path, &keyframe_indices, &args.output, &args.ffmpeg_path, args.max_save, args.verbose)?; + + println!("\n✅ Processing Complete!"); + println!("🎯 Keyframes extracted: {}", result.keyframes_extracted); + println!("💾 Keyframes saved: {}", saved_count); + println!("⚡ Processing speed: {:.1} FPS", result.processing_fps); + println!("📁 Output directory: {}", args.output.display()); + + // Save processing report + let timestamp = Local::now().format("%Y%m%d_%H%M%S").to_string(); + let report_file = args.output.join(format!("processing_report_{}.json", timestamp)); + let json_result = serde_json::to_string_pretty(&result)?; + fs::write(&report_file, json_result)?; + + if args.verbose { + println!("📄 Processing report saved to: {}", report_file.display()); + } + } + + Ok(()) +} diff --git a/src/chat/utils/rust-video/start_server.py b/src/chat/utils/rust-video/start_server.py new file mode 100644 index 000000000..b1547d441 --- /dev/null +++ b/src/chat/utils/rust-video/start_server.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +""" +启动脚本 + +支持开发模式和生产模式启动 +""" + +import os +import sys +import subprocess +import argparse +from pathlib import Path +from config import config + + +def check_rust_executable(): + """检查 Rust 可执行文件是否存在""" + rust_config = config.get("rust") + executable_name = rust_config.get("executable_name", "video_keyframe_extractor") + executable_path = rust_config.get("executable_path", "target/release") + + possible_paths = [ + f"./{executable_path}/{executable_name}.exe", + f"./{executable_path}/{executable_name}", + f"./{executable_name}.exe", + f"./{executable_name}" + ] + + for path in possible_paths: + if Path(path).exists(): + print(f"✓ Found Rust executable: {path}") + return 
str(Path(path).absolute()) + + print("⚠ Warning: Rust executable not found") + print("Please compile first: cargo build --release") + return None + + +def check_dependencies(): + """检查 Python 依赖""" + try: + import fastapi + import uvicorn + print("✓ FastAPI dependencies available") + return True + except ImportError as e: + print(f"✗ Missing dependencies: {e}") + print("Please install: pip install -r requirements.txt") + return False + + +def install_dependencies(): + """安装依赖""" + print("Installing dependencies...") + try: + subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], + check=True) + print("✓ Dependencies installed successfully") + return True + except subprocess.CalledProcessError as e: + print(f"✗ Failed to install dependencies: {e}") + return False + + +def start_development_server(host="127.0.0.1", port=8050, reload=True): + """启动开发服务器""" + print(f" Starting development server on http://{host}:{port}") + print(f" API docs: http://{host}:{port}/docs") + print(f" Health check: http://{host}:{port}/health") + + try: + import uvicorn + uvicorn.run( + "api_server:app", + host=host, + port=port, + reload=reload, + log_level="info" + ) + except ImportError: + print("uvicorn not found, trying with subprocess...") + subprocess.run([ + sys.executable, "-m", "uvicorn", + "api_server:app", + "--host", host, + "--port", str(port), + "--reload" if reload else "" + ]) + + +def start_production_server(host="0.0.0.0", port=8000, workers=4): + """启动生产服务器""" + print(f"🚀 Starting production server on http://{host}:{port}") + print(f"Workers: {workers}") + + subprocess.run([ + sys.executable, "-m", "uvicorn", + "api_server:app", + "--host", host, + "--port", str(port), + "--workers", str(workers), + "--log-level", "warning" + ]) + + +def create_systemd_service(): + """创建 systemd 服务文件""" + current_dir = Path.cwd() + python_path = sys.executable + + service_content = f"""[Unit] +Description=Video Keyframe Extraction API Server 
+After=network.target + +[Service] +Type=exec +User=www-data +WorkingDirectory={current_dir} +Environment=PATH=/usr/bin:/usr/local/bin +ExecStart={python_path} -m uvicorn api_server:app --host 0.0.0.0 --port 8000 --workers 4 +Restart=always +RestartSec=10 + +[Install] +WantedBy=multi-user.target +""" + + service_file = Path("/etc/systemd/system/video-keyframe-api.service") + + try: + with open(service_file, 'w') as f: + f.write(service_content) + + print(f"✓ Systemd service created: {service_file}") + print("To enable and start:") + print(" sudo systemctl enable video-keyframe-api") + print(" sudo systemctl start video-keyframe-api") + + except PermissionError: + print("✗ Permission denied. Please run with sudo for systemd service creation") + + # 创建本地副本 + local_service = Path("./video-keyframe-api.service") + with open(local_service, 'w') as f: + f.write(service_content) + + print(f"✓ Service file created locally: {local_service}") + print(f"To install: sudo cp {local_service} /etc/systemd/system/") + + +def main(): + parser = argparse.ArgumentParser(description="Video Keyframe Extraction API Server") + + # 从配置文件获取默认值 + server_config = config.get_server_config() + + parser.add_argument("--mode", choices=["dev", "prod", "install"], default="dev", + help="运行模式: dev (开发), prod (生产), install (安装依赖)") + parser.add_argument("--host", default=server_config.get("host", "127.0.0.1"), help="绑定主机") + parser.add_argument("--port", type=int, default=server_config.get("port", 8000), help="端口号") + parser.add_argument("--workers", type=int, default=server_config.get("workers", 4), help="生产模式工作进程数") + parser.add_argument("--no-reload", action="store_true", help="禁用自动重载") + parser.add_argument("--check", action="store_true", help="仅检查环境") + parser.add_argument("--create-service", action="store_true", help="创建 systemd 服务") + + args = parser.parse_args() + + print("=== Video Keyframe Extraction API Server ===") + + # 检查环境 + rust_exe = check_rust_executable() + deps_ok = 
check_dependencies() + + if args.check: + print("\n=== Environment Check ===") + print(f"Rust executable: {'✓' if rust_exe else '✗'}") + print(f"Python dependencies: {'✓' if deps_ok else '✗'}") + return + + if args.create_service: + create_systemd_service() + return + + # 安装模式 + if args.mode == "install": + if not deps_ok: + install_dependencies() + else: + print("✓ Dependencies already installed") + return + + # 检查必要条件 + if not rust_exe: + print("✗ Cannot start without Rust executable") + print("Please run: cargo build --release") + sys.exit(1) + + if not deps_ok: + print("Installing missing dependencies...") + if not install_dependencies(): + sys.exit(1) + + # 启动服务器 + if args.mode == "dev": + start_development_server( + host=args.host, + port=args.port, + reload=not args.no_reload + ) + elif args.mode == "prod": + start_production_server( + host=args.host, + port=args.port, + workers=args.workers + ) + + +if __name__ == "__main__": + main() diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index 7b0a8ec92..a11ccaa7e 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -14,6 +14,7 @@ from src.common.vector_db import vector_db_service logger = get_logger("cache_manager") + class CacheManager: """ 一个支持分层和语义缓存的通用工具缓存管理器。 @@ -21,6 +22,7 @@ class CacheManager: L1缓存: 内存字典 (KV) + FAISS (Vector)。 L2缓存: 数据库 (KV) + ChromaDB (Vector)。 """ + _instance = None def __new__(cls, *args, **kwargs): @@ -32,7 +34,7 @@ class CacheManager: """ 初始化缓存管理器。 """ - if not hasattr(self, '_initialized'): + if not hasattr(self, "_initialized"): self.default_ttl = default_ttl self.semantic_cache_collection_name = "semantic_cache" @@ -41,7 +43,7 @@ class CacheManager: embedding_dim = global_config.lpmm_knowledge.embedding_dimension self.l1_vector_index = faiss.IndexFlatIP(embedding_dim) self.l1_vector_id_to_key: Dict[int, str] = {} - + # L2 向量缓存 (使用新的服务) vector_db_service.get_or_create_collection(self.semantic_cache_collection_name) @@ -58,32 +60,32 
@@ class CacheManager: try: if embedding_result is None: return None - + # 确保embedding_result是一维数组或列表 if isinstance(embedding_result, (list, tuple, np.ndarray)): # 转换为numpy数组进行处理 embedding_array = np.array(embedding_result) - + # 如果是多维数组,展平它 if embedding_array.ndim > 1: embedding_array = embedding_array.flatten() - + # 检查维度是否符合预期 expected_dim = global_config.lpmm_knowledge.embedding_dimension if embedding_array.shape[0] != expected_dim: logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}") return None - + # 检查是否包含有效的数值 if np.isnan(embedding_array).any() or np.isinf(embedding_array).any(): logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)") return None - - return embedding_array.astype('float32') + + return embedding_array.astype("float32") else: logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}") return None - + except Exception as e: logger.error(f"验证嵌入向量时发生错误: {e}") return None @@ -102,14 +104,20 @@ class CacheManager: except (OSError, TypeError) as e: file_hash = "unknown" logger.warning(f"无法获取文件信息: {tool_file_path},错误: {e}") - + try: - sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode('utf-8') + sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode("utf-8") except TypeError: sorted_args = repr(sorted(function_args.items())) return f"{tool_name}::{sorted_args}::{file_hash}" - async def get(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], semantic_query: Optional[str] = None) -> Optional[Any]: + async def get( + self, + tool_name: str, + function_args: Dict[str, Any], + tool_file_path: Union[str, Path], + semantic_query: Optional[str] = None, + ) -> Optional[Any]: """ 从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。 """ @@ -136,13 +144,13 @@ class CacheManager: embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result validated_embedding = self._validate_embedding(embedding_vector) if validated_embedding is 
not None: - query_embedding = np.array([validated_embedding], dtype='float32') + query_embedding = np.array([validated_embedding], dtype="float32") # 步骤 2a: L1 语义缓存 (FAISS) if query_embedding is not None and self.l1_vector_index.ntotal > 0: faiss.normalize_L2(query_embedding) - distances, indices = self.l1_vector_index.search(query_embedding, 1) - if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似 + distances, indices = self.l1_vector_index.search(query_embedding, 1) # type: ignore + if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似 hit_index = indices[0][0] l1_hit_key = self.l1_vector_id_to_key.get(hit_index) if l1_hit_key and l1_hit_key in self.l1_kv_cache: @@ -151,12 +159,9 @@ class CacheManager: # 步骤 2b: L2 精确缓存 (数据库) cache_results_obj = await db_query( - model_class=CacheEntries, - query_type="get", - filters={"cache_key": key}, - single_result=True + model_class=CacheEntries, query_type="get", filters={"cache_key": key}, single_result=True ) - + if cache_results_obj: # 使用 getattr 安全访问属性,避免 Pylance 类型检查错误 expires_at = getattr(cache_results_obj, "expires_at", 0) @@ -164,7 +169,7 @@ class CacheManager: logger.info(f"命中L2键值缓存: {key}") cache_value = getattr(cache_results_obj, "cache_value", "{}") data = orjson.loads(cache_value) - + # 更新访问统计 await db_query( model_class=CacheEntries, @@ -172,20 +177,16 @@ class CacheManager: filters={"cache_key": key}, data={ "last_accessed": time.time(), - "access_count": getattr(cache_results_obj, "access_count", 0) + 1 - } + "access_count": getattr(cache_results_obj, "access_count", 0) + 1, + }, ) - + # 回填 L1 self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} return data else: # 删除过期的缓存条目 - await db_query( - model_class=CacheEntries, - query_type="delete", - filters={"cache_key": key} - ) + await db_query(model_class=CacheEntries, query_type="delete", filters={"cache_key": key}) # 步骤 2c: L2 语义缓存 (VectorDB Service) if query_embedding is not None: @@ -193,31 +194,33 @@ class CacheManager: results = 
vector_db_service.query( collection_name=self.semantic_cache_collection_name, query_embeddings=query_embedding.tolist(), - n_results=1 + n_results=1, ) - if results and results.get('ids') and results['ids'][0]: - distance = results['distances'][0][0] if results.get('distances') and results['distances'][0] else 'N/A' + if results and results.get("ids") and results["ids"][0]: + distance = ( + results["distances"][0][0] if results.get("distances") and results["distances"][0] else "N/A" + ) logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}") - - if distance != 'N/A' and distance < 0.75: - l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0] + + if distance != "N/A" and distance < 0.75: + l2_hit_key = results["ids"][0][0] if isinstance(results["ids"][0], list) else results["ids"][0] logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}") - + # 从数据库获取缓存数据 semantic_cache_results_obj = await db_query( model_class=CacheEntries, query_type="get", filters={"cache_key": l2_hit_key}, - single_result=True + single_result=True, ) - + if semantic_cache_results_obj: expires_at = getattr(semantic_cache_results_obj, "expires_at", 0) if time.time() < expires_at: cache_value = getattr(semantic_cache_results_obj, "cache_value", "{}") data = orjson.loads(cache_value) logger.debug(f"L2语义缓存返回的数据: {data}") - + # 回填 L1 self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} if query_embedding is not None: @@ -235,7 +238,15 @@ class CacheManager: logger.debug(f"缓存未命中: {key}") return None - async def set(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None): + async def set( + self, + tool_name: str, + function_args: Dict[str, Any], + tool_file_path: Union[str, Path], + data: Any, + ttl: Optional[int] = None, + semantic_query: Optional[str] = None, + ): """将结果存入所有缓存层。""" if ttl is None: ttl = self.default_ttl 
@@ -244,27 +255,22 @@ class CacheManager: key = self._generate_key(tool_name, function_args, tool_file_path) expires_at = time.time() + ttl - + # 写入 L1 self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} # 写入 L2 (数据库) cache_data = { "cache_key": key, - "cache_value": orjson.dumps(data).decode('utf-8'), + "cache_value": orjson.dumps(data).decode("utf-8"), "expires_at": expires_at, "tool_name": tool_name, "created_at": time.time(), "last_accessed": time.time(), - "access_count": 1 + "access_count": 1, } - - await db_save( - model_class=CacheEntries, - data=cache_data, - key_field="cache_key", - key_value=key - ) + + await db_save(model_class=CacheEntries, data=cache_data, key_field="cache_key", key_value=key) # 写入语义缓存 if semantic_query and self.embedding_model: @@ -274,19 +280,19 @@ class CacheManager: embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result validated_embedding = self._validate_embedding(embedding_vector) if validated_embedding is not None: - embedding = np.array([validated_embedding], dtype='float32') - + embedding = np.array([validated_embedding], dtype="float32") + # 写入 L1 Vector new_id = self.l1_vector_index.ntotal faiss.normalize_L2(embedding) self.l1_vector_index.add(x=embedding) # type: ignore self.l1_vector_id_to_key[new_id] = key - + # 写入 L2 Vector (使用新的服务) vector_db_service.add( collection_name=self.semantic_cache_collection_name, embeddings=embedding.tolist(), - ids=[key] + ids=[key], ) except Exception as e: logger.warning(f"语义缓存写入失败: {e}") @@ -306,16 +312,16 @@ class CacheManager: await db_query( model_class=CacheEntries, query_type="delete", - filters={} # 删除所有记录 + filters={}, # 删除所有记录 ) - + # 清空 VectorDB try: vector_db_service.delete_collection(name=self.semantic_cache_collection_name) vector_db_service.get_or_create_collection(name=self.semantic_cache_collection_name) except Exception as e: logger.warning(f"清空 VectorDB 集合失败: {e}") - + logger.info("L2 (数据库 & VectorDB) 缓存已清空。") 
async def clear_all(self): @@ -327,25 +333,23 @@ class CacheManager: async def clean_expired(self): """清理过期的缓存条目""" current_time = time.time() - + # 清理L1过期条目 expired_keys = [] for key, entry in self.l1_kv_cache.items(): if current_time >= entry["expires_at"]: expired_keys.append(key) - + for key in expired_keys: del self.l1_kv_cache[key] - + # 清理L2过期条目 - await db_query( - model_class=CacheEntries, - query_type="delete", - filters={"expires_at": {"$lt": current_time}} - ) - + await db_query(model_class=CacheEntries, query_type="delete", filters={"expires_at": {"$lt": current_time}}) + if expired_keys: logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目") + # 全局实例 -tool_cache = CacheManager() \ No newline at end of file +tool_cache = CacheManager() + diff --git a/src/common/tool_history.py b/src/common/tool_history.py deleted file mode 100644 index b3edb12ce..000000000 --- a/src/common/tool_history.py +++ /dev/null @@ -1,405 +0,0 @@ -"""工具执行历史记录模块""" -import time -from datetime import datetime -from typing import Any, Dict, List, Optional, Union -import json -from pathlib import Path -import inspect - -from .logger import get_logger -from src.config.config import global_config -from src.common.cache_manager import tool_cache - -logger = get_logger("tool_history") - -class ToolHistoryManager: - """工具执行历史记录管理器""" - - _instance = None - _initialized = False - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self): - if not self._initialized: - self._history: List[Dict[str, Any]] = [] - self._initialized = True - self._data_dir = Path("data/tool_history") - self._data_dir.mkdir(parents=True, exist_ok=True) - self._history_file = self._data_dir / "tool_history.jsonl" - self._load_history() - - def _save_history(self): - """保存所有历史记录到文件""" - try: - with self._history_file.open("w", encoding="utf-8") as f: - for record in self._history: - f.write(json.dumps(record, ensure_ascii=False) + "\n") - except 
Exception as e: - logger.error(f"保存工具调用记录失败: {e}") - - def _save_record(self, record: Dict[str, Any]): - """保存单条记录到文件""" - try: - with self._history_file.open("a", encoding="utf-8") as f: - f.write(json.dumps(record, ensure_ascii=False) + "\n") - except Exception as e: - logger.error(f"保存工具调用记录失败: {e}") - - def _clean_expired_records(self): - """清理已过期的记录""" - original_count = len(self._history) - self._history = [record for record in self._history if record.get("ttl_count", 0) < record.get("ttl", 5)] - cleaned_count = original_count - len(self._history) - - if cleaned_count > 0: - logger.info(f"清理了 {cleaned_count} 条过期的工具历史记录,剩余 {len(self._history)} 条") - self._save_history() - else: - logger.debug("没有需要清理的过期工具历史记录") - - def record_tool_call(self, - tool_name: str, - args: Dict[str, Any], - result: Any, - execution_time: float, - status: str, - chat_id: Optional[str] = None, - ttl: int = 5): - """记录工具调用 - - Args: - tool_name: 工具名称 - args: 工具调用参数 - result: 工具返回结果 - execution_time: 执行时间(秒) - status: 执行状态("completed"或"error") - chat_id: 聊天ID,与ChatManager中的chat_id对应,用于标识群聊或私聊会话 - ttl: 该记录的生命周期值,插入提示词多少次后删除,默认为5 - """ - # 检查是否启用历史记录且ttl大于0 - if not global_config.tool.history.enable_history or ttl <= 0: - return - - # 先清理过期记录 - self._clean_expired_records() - - try: - # 创建记录 - record = { - "tool_name": tool_name, - "timestamp": datetime.now().isoformat(), - "arguments": self._sanitize_args(args), - "result": self._sanitize_result(result), - "execution_time": execution_time, - "status": status, - "chat_id": chat_id, - "ttl": ttl, - "ttl_count": 0 - } - - # 添加到内存中的历史记录 - self._history.append(record) - - # 保存到文件 - self._save_record(record) - - if status == "completed": - logger.info(f"工具 {tool_name} 调用完成,耗时:{execution_time:.2f}s") - else: - logger.error(f"工具 {tool_name} 调用失败:{result}") - - except Exception as e: - logger.error(f"记录工具调用时发生错误: {e}") - - def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]: - """清理参数中的敏感信息""" - sensitive_keys = ['api_key', 'token', 
'password', 'secret'] - sanitized = args.copy() - - def _sanitize_value(value): - if isinstance(value, dict): - return {k: '***' if k.lower() in sensitive_keys else _sanitize_value(v) - for k, v in value.items()} - return value - - return {k: '***' if k.lower() in sensitive_keys else _sanitize_value(v) - for k, v in sanitized.items()} - - def _sanitize_result(self, result: Any) -> Any: - """清理结果中的敏感信息""" - if isinstance(result, dict): - return self._sanitize_args(result) - return result - - def _load_history(self): - """加载历史记录文件""" - try: - if self._history_file.exists(): - self._history = [] - with self._history_file.open("r", encoding="utf-8") as f: - for line in f: - try: - record = json.loads(line) - if record.get("ttl_count", 0) < record.get("ttl", 5): # 只加载未过期的记录 - self._history.append(record) - except json.JSONDecodeError: - continue - logger.info(f"成功加载了 {len(self._history)} 条历史记录") - except Exception as e: - logger.error(f"加载历史记录失败: {e}") - - def query_history(self, - tool_names: Optional[List[str]] = None, - start_time: Optional[Union[datetime, str]] = None, - end_time: Optional[Union[datetime, str]] = None, - chat_id: Optional[str] = None, - limit: Optional[int] = None, - status: Optional[str] = None) -> List[Dict[str, Any]]: - """查询工具调用历史 - - Args: - tool_names: 工具名称列表,为空则查询所有工具 - start_time: 开始时间,可以是datetime对象或ISO格式字符串 - end_time: 结束时间,可以是datetime对象或ISO格式字符串 - chat_id: 聊天ID,与ChatManager中的chat_id对应,用于查询特定群聊或私聊的历史记录 - limit: 返回记录数量限制 - status: 执行状态筛选("completed"或"error") - - Returns: - 符合条件的历史记录列表 - """ - # 先清理过期记录 - self._clean_expired_records() - def _parse_time(time_str: Optional[Union[datetime, str]]) -> Optional[datetime]: - if isinstance(time_str, datetime): - return time_str - elif isinstance(time_str, str): - return datetime.fromisoformat(time_str) - return None - - filtered_history = self._history - - # 按工具名筛选 - if tool_names: - filtered_history = [ - record for record in filtered_history - if record["tool_name"] in tool_names - ] - - # 按时间范围筛选 
- start_dt = _parse_time(start_time) - end_dt = _parse_time(end_time) - - if start_dt: - filtered_history = [ - record for record in filtered_history - if datetime.fromisoformat(record["timestamp"]) >= start_dt - ] - - if end_dt: - filtered_history = [ - record for record in filtered_history - if datetime.fromisoformat(record["timestamp"]) <= end_dt - ] - - # 按聊天ID筛选 - if chat_id: - filtered_history = [ - record for record in filtered_history - if record.get("chat_id") == chat_id - ] - - # 按状态筛选 - if status: - filtered_history = [ - record for record in filtered_history - if record["status"] == status - ] - - # 应用数量限制 - if limit: - filtered_history = filtered_history[-limit:] - - return filtered_history - - def get_recent_history_prompt(self, - limit: Optional[int] = None, - chat_id: Optional[str] = None) -> str: - """ - 获取最近工具调用历史的提示词 - - Args: - limit: 返回的历史记录数量,如果不提供则使用配置中的max_history - chat_id: 会话ID,用于只获取当前会话的历史 - - Returns: - 格式化的历史记录提示词 - """ - # 检查是否启用历史记录 - if not global_config.tool.history.enable_history: - return "" - - # 使用配置中的最大历史记录数 - if limit is None: - limit = global_config.tool.history.max_history - - recent_history = self.query_history( - chat_id=chat_id, - limit=limit - ) - - if not recent_history: - return "" - - prompt = "\n工具执行历史:\n" - needs_save = False - updated_history = [] - - for record in recent_history: - # 增加ttl计数 - record["ttl_count"] = record.get("ttl_count", 0) + 1 - needs_save = True - - # 如果未超过ttl,则添加到提示词中 - if record["ttl_count"] < record.get("ttl", 5): - # 提取结果中的name和content - result = record['result'] - if isinstance(result, dict): - name = result.get('name', record['tool_name']) - content = result.get('content', str(result)) - else: - name = record['tool_name'] - content = str(result) - - # 格式化内容,去除多余空白和换行 - content = content.strip().replace('\n', ' ') - - # 如果内容太长则截断 - if len(content) > 200: - content = content[:200] + "..." 
- - prompt += f"{name}: \n{content}\n\n" - updated_history.append(record) - - # 更新历史记录并保存 - if needs_save: - self._history = updated_history - self._save_history() - - return prompt - - def clear_history(self): - """清除历史记录""" - self._history.clear() - self._save_history() - logger.info("工具调用历史记录已清除") - - -def wrap_tool_executor(): - """ - 包装工具执行器以添加历史记录和缓存功能 - 这个函数应该在系统启动时被调用一次 - """ - from src.plugin_system.core.tool_use import ToolExecutor - from src.plugin_system.apis.tool_api import get_tool_instance - original_execute = ToolExecutor.execute_tool_call - history_manager = ToolHistoryManager() - - async def wrapped_execute_tool_call(self, tool_call, tool_instance=None): - start_time = time.time() - - # 确保我们有 tool_instance - if not tool_instance: - tool_instance = get_tool_instance(tool_call.func_name) - - # 如果没有 tool_instance,就无法进行缓存检查,直接执行 - if not tool_instance: - result = await original_execute(self, tool_call, None) - execution_time = time.time() - start_time - history_manager.record_tool_call( - tool_name=tool_call.func_name, - args=tool_call.args, - result=result, - execution_time=execution_time, - status="completed", - chat_id=getattr(self, 'chat_id', None), - ttl=5 # Default TTL - ) - return result - - # 新的缓存逻辑 - if tool_instance.enable_cache: - try: - tool_file_path = inspect.getfile(tool_instance.__class__) - semantic_query = None - if tool_instance.semantic_cache_query_key: - semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key) - - cached_result = await tool_cache.get( - tool_name=tool_call.func_name, - function_args=tool_call.args, - tool_file_path=tool_file_path, - semantic_query=semantic_query - ) - if cached_result: - logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行") - return cached_result - except Exception as e: - logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}") - - try: - result = await original_execute(self, tool_call, tool_instance) - execution_time = time.time() - start_time - - # 缓存结果 - if 
tool_instance.enable_cache: - try: - tool_file_path = inspect.getfile(tool_instance.__class__) - semantic_query = None - if tool_instance.semantic_cache_query_key: - semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key) - - await tool_cache.set( - tool_name=tool_call.func_name, - function_args=tool_call.args, - tool_file_path=tool_file_path, - data=result, - ttl=tool_instance.cache_ttl, - semantic_query=semantic_query - ) - except Exception as e: - logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}") - - # 记录成功的调用 - history_manager.record_tool_call( - tool_name=tool_call.func_name, - args=tool_call.args, - result=result, - execution_time=execution_time, - status="completed", - chat_id=getattr(self, 'chat_id', None), - ttl=tool_instance.history_ttl - ) - - return result - - except Exception as e: - execution_time = time.time() - start_time - # 记录失败的调用 - history_manager.record_tool_call( - tool_name=tool_call.func_name, - args=tool_call.args, - result=str(e), - execution_time=execution_time, - status="error", - chat_id=getattr(self, 'chat_id', None), - ttl=tool_instance.history_ttl - ) - raise - - # 替换原始方法 - ToolExecutor.execute_tool_call = wrapped_execute_tool_call \ No newline at end of file diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 6a8d47187..28e617474 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -621,6 +621,7 @@ class WakeUpSystemConfig(ValidatedConfigBase): decay_interval: float = Field(default=30.0, ge=1.0, description="唤醒度衰减间隔(秒)") angry_duration: float = Field(default=300.0, ge=10.0, description="愤怒状态持续时间(秒)") angry_prompt: str = Field(default="你被人吵醒了非常生气,说话带着怒气", description="被吵醒后的愤怒提示词") + re_sleep_delay_minutes: int = Field(default=5, ge=1, description="被唤醒后,如果多久没有新消息则尝试重新入睡(分钟)") # --- 失眠机制相关参数 --- enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统") diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 
7c9f19869..8341653ff 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -5,7 +5,7 @@ import random from enum import Enum from rich.traceback import install -from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine +from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator from src.common.logger import get_logger from src.config.config import model_config @@ -283,131 +283,130 @@ class LLMRequest: tools: Optional[List[Dict[str, Any]]] = None, raise_when_empty: bool = True, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - """执行单次请求""" - # 模型选择和请求准备 - start_time = time.time() - model_info, api_provider, client = self._select_model() - model_name = model_info.name - - # 检查是否启用反截断 - use_anti_truncation = getattr(api_provider, "anti_truncation", False) - - processed_prompt = prompt - if use_anti_truncation: - processed_prompt += self.anti_truncation_instruction - logger.info(f"{api_provider} '{self.task_name}' 已启用反截断功能") - - processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider) - - message_builder = MessageBuilder() - message_builder.add_text_content(processed_prompt) - messages = [message_builder.build()] - tool_built = self._build_tool_options(tools) - - # 空回复重试逻辑 - empty_retry_count = 0 - max_empty_retry = api_provider.max_retry - empty_retry_interval = api_provider.retry_interval - - while empty_retry_count <= max_empty_retry: + """ + 执行单次请求,并在模型失败时按顺序切换到下一个可用模型。 + """ + failed_models = set() + last_exception: Optional[Exception] = None + + model_scheduler = self._model_scheduler(failed_models) + + for model_info, api_provider, client in model_scheduler: + start_time = time.time() + model_name = model_info.name + logger.debug(f"正在尝试使用模型: {model_name}") # 你不许刷屏 + try: - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.RESPONSE, - model_info=model_info, - message_list=messages, - 
tool_options=tool_built, - temperature=temperature, - max_tokens=max_tokens, - ) - content = response.content or "" - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - # 从内容中提取标签的推理内容(向后兼容) - if not reasoning_content and content: - content, extracted_reasoning = self._extract_reasoning(content) - reasoning_content = extracted_reasoning - - is_empty_reply = False - is_truncated = False - # 检测是否为空回复或截断 - if not tool_calls: - is_empty_reply = not content or content.strip() == "" - is_truncated = False - + # 检查是否启用反截断 + use_anti_truncation = getattr(api_provider, "anti_truncation", False) + processed_prompt = prompt if use_anti_truncation: - if content.endswith("[done]"): - content = content[:-6].strip() - logger.debug("检测到并已移除 [done] 标记") - else: - is_truncated = True - logger.warning("未检测到 [done] 标记,判定为截断") + processed_prompt += self.anti_truncation_instruction + logger.info(f"'{model_name}' for task '{self.task_name}' 已启用反截断功能") - if is_empty_reply or is_truncated: - if empty_retry_count < max_empty_retry: - empty_retry_count += 1 - reason = "空回复" if is_empty_reply else "截断" - logger.warning(f"检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成") + processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider) - if empty_retry_interval > 0: - await asyncio.sleep(empty_retry_interval) + message_builder = MessageBuilder() + message_builder.add_text_content(processed_prompt) + messages = [message_builder.build()] + tool_built = self._build_tool_options(tools) - model_info, api_provider, client = self._select_model() - continue - else: - # 已达到最大重试次数,但仍然是空回复或截断 - reason = "空回复" if is_empty_reply else "截断" - # 抛出异常,由外层重试逻辑或最终的异常处理器捕获 - raise RuntimeError(f"经过 {max_empty_retry + 1} 次尝试后仍然是{reason}的回复") + # 针对当前模型的空回复/截断重试逻辑 + empty_retry_count = 0 + max_empty_retry = api_provider.max_retry + empty_retry_interval = api_provider.retry_interval - # 记录使用情况 - if usage := response.usage: - 
llm_usage_recorder.record_usage_to_database( + while empty_retry_count <= max_empty_retry: + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.RESPONSE, model_info=model_info, - model_usage=usage, - time_cost=time.time() - start_time, - user_id="system", - request_type=self.request_type, - endpoint="/chat/completions", + message_list=messages, + tool_options=tool_built, + temperature=temperature, + max_tokens=max_tokens, ) - # 处理空回复 - if not content and not tool_calls: - if raise_when_empty: - raise RuntimeError(f"经过 {empty_retry_count} 次重试后仍然生成空回复") - content = "生成的响应为空,请检查模型配置或输入内容是否正确" - elif empty_retry_count > 0: - logger.info(f"经过 {empty_retry_count} 次重试后成功生成回复") + content = response.content or "" + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls - return content, (reasoning_content, model_info.name, tool_calls) + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + + is_empty_reply = not tool_calls and (not content or content.strip() == "") + is_truncated = False + if use_anti_truncation: + if content.endswith("[done]"): + content = content[:-6].strip() + else: + is_truncated = True + + if is_empty_reply or is_truncated: + empty_retry_count += 1 + if empty_retry_count <= max_empty_retry: + reason = "空回复" if is_empty_reply else "截断" + logger.warning(f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成...") + if empty_retry_interval > 0: + await asyncio.sleep(empty_retry_interval) + continue # 继续使用当前模型重试 + else: + # 当前模型重试次数用尽,跳出内层循环,触发外层循环切换模型 + reason = "空回复" if is_empty_reply else "截断" + logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。") + raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数") + + # 成功获取响应 + if usage := response.usage: + llm_usage_recorder.record_usage_to_database( + model_info=model_info, 
model_usage=usage, time_cost=time.time() - start_time, + user_id="system", request_type=self.request_type, endpoint="/chat/completions", + ) + + if not content and not tool_calls: + if raise_when_empty: + raise RuntimeError("生成空回复") + content = "生成的响应为空" + + logger.debug(f"模型 '{model_name}' 成功生成回复。") # 你也不许刷屏 + return content, (reasoning_content, model_name, tool_calls) + + except RespNotOkException as e: + if e.status_code in [401, 403]: + logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将尝试下一个模型。") + failed_models.add(model_name) + last_exception = e + continue # 切换到下一个模型 + else: + logger.error(f"模型 '{model_name}' 请求失败,HTTP状态码: {e.status_code}") + if raise_when_empty: + raise + # 对于其他HTTP错误,直接抛出,不再尝试其他模型 + return f"请求失败: {e}", ("", model_name, None) + + except RuntimeError as e: + # 捕获所有重试失败(包括空回复和网络问题) + logger.error(f"模型 '{model_name}' 在所有重试后仍然失败: {e},将尝试下一个模型。") + failed_models.add(model_name) + last_exception = e + continue # 切换到下一个模型 except Exception as e: - logger.error(f"请求执行失败: {e}") - if raise_when_empty: - # 在非并发模式下,如果第一次尝试就失败,则直接抛出异常 - if empty_retry_count == 0: - raise + logger.error(f"使用模型 '{model_name}' 时发生未知异常: {e}") + failed_models.add(model_name) + last_exception = e + continue # 切换到下一个模型 - # 如果在重试过程中失败,则继续重试 - empty_retry_count += 1 - if empty_retry_count <= max_empty_retry: - logger.warning(f"请求失败,将在 {empty_retry_interval} 秒后进行第 {empty_retry_count}/{max_empty_retry} 次重试...") - if empty_retry_interval > 0: - await asyncio.sleep(empty_retry_interval) - continue - else: - logger.error(f"经过 {max_empty_retry} 次重试后仍然失败") - raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复") from e - else: - # 在并发模式下,单个请求的失败不应中断整个并发流程, - # 而是将异常返回给调用者(即 execute_concurrently)进行统一处理 - raise # 重新抛出异常,由 execute_concurrently 中的 gather 捕获 - - # 重试失败 + # 所有模型都尝试失败 + logger.error("所有可用模型都已尝试失败。") if raise_when_empty: - raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复") - return "生成的响应为空,请检查模型配置或输入内容是否正确", ("", model_name, None) + if 
last_exception: + raise RuntimeError("所有模型都请求失败") from last_exception + raise RuntimeError("所有模型都请求失败,且没有具体的异常信息") + + return "所有模型都请求失败", ("", "unknown", None) async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: """获取嵌入向量 @@ -446,9 +445,24 @@ class LLMRequest: return embedding, model_info.name + def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]: + """ + 一个模型调度器,按顺序提供模型,并跳过已失败的模型。 + """ + for model_name in self.model_for_task.model_list: + if model_name in failed_models: + continue + + model_info = model_config.get_model_info(model_name) + api_provider = model_config.get_provider(model_info.api_provider) + force_new_client = (self.request_type == "embedding") + client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) + + yield model_info, api_provider, client + def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]: """ - 根据总tokens和惩罚值选择的模型 + 根据总tokens和惩罚值选择的模型 (负载均衡) """ least_used_model_name = min( self.model_usage, diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py index ec8ddec39..60b9f17de 100644 --- a/src/plugin_system/apis/tool_api.py +++ b/src/plugin_system/apis/tool_api.py @@ -1,9 +1,7 @@ -from typing import Any, Dict, List, Optional, Type, Union -from datetime import datetime +from typing import Optional, Type from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.component_types import ComponentType -from src.common.tool_history import ToolHistoryManager from src.common.logger import get_logger logger = get_logger("tool_api") @@ -33,110 +31,4 @@ def get_llm_available_tool_definitions(): from src.plugin_system.core import component_registry llm_available_tools = component_registry.get_llm_available_tools() - return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()] - -def get_tool_history( - tool_names: 
Optional[List[str]] = None, - start_time: Optional[Union[datetime, str]] = None, - end_time: Optional[Union[datetime, str]] = None, - chat_id: Optional[str] = None, - limit: Optional[int] = None, - status: Optional[str] = None -) -> List[Dict[str, Any]]: - """ - 获取工具调用历史记录 - - Args: - tool_names: 工具名称列表,为空则查询所有工具 - start_time: 开始时间,可以是datetime对象或ISO格式字符串 - end_time: 结束时间,可以是datetime对象或ISO格式字符串 - chat_id: 会话ID,用于筛选特定会话的调用 - limit: 返回记录数量限制 - status: 执行状态筛选("completed"或"error") - - Returns: - List[Dict]: 工具调用记录列表,每条记录包含以下字段: - - tool_name: 工具名称 - - timestamp: 调用时间 - - arguments: 调用参数 - - result: 调用结果 - - execution_time: 执行时间 - - status: 执行状态 - - chat_id: 会话ID - """ - history_manager = ToolHistoryManager() - return history_manager.query_history( - tool_names=tool_names, - start_time=start_time, - end_time=end_time, - chat_id=chat_id, - limit=limit, - status=status - ) - - -def get_tool_history_text( - tool_names: Optional[List[str]] = None, - start_time: Optional[Union[datetime, str]] = None, - end_time: Optional[Union[datetime, str]] = None, - chat_id: Optional[str] = None, - limit: Optional[int] = None, - status: Optional[str] = None -) -> str: - """ - 获取工具调用历史记录的文本格式 - - Args: - tool_names: 工具名称列表,为空则查询所有工具 - start_time: 开始时间,可以是datetime对象或ISO格式字符串 - end_time: 结束时间,可以是datetime对象或ISO格式字符串 - chat_id: 会话ID,用于筛选特定会话的调用 - limit: 返回记录数量限制 - status: 执行状态筛选("completed"或"error") - - Returns: - str: 格式化的工具调用历史记录文本 - """ - history = get_tool_history( - tool_names=tool_names, - start_time=start_time, - end_time=end_time, - chat_id=chat_id, - limit=limit, - status=status - ) - - if not history: - return "没有找到工具调用记录" - - text = "工具调用历史记录:\n" - for record in history: - # 提取结果中的name和content - result = record['result'] - if isinstance(result, dict): - name = result.get('name', record['tool_name']) - content = result.get('content', str(result)) - else: - name = record['tool_name'] - content = str(result) - - # 格式化内容 - content = content.strip().replace('\n', ' ') - if len(content) > 
200: - content = content[:200] + "..." - - # 格式化时间 - timestamp = datetime.fromisoformat(record['timestamp']).strftime("%Y-%m-%d %H:%M:%S") - - text += f"[{timestamp}] {name}\n" - text += f"结果: {content}\n\n" - - return text - - -def clear_tool_history() -> None: - """ - 清除所有工具调用历史记录 - """ - history_manager = ToolHistoryManager() - history_manager.clear_history() \ No newline at end of file + return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()] \ No newline at end of file diff --git a/src/plugin_system/base/base_event.py b/src/plugin_system/base/base_event.py index c527752d5..a9ab38911 100644 --- a/src/plugin_system/base/base_event.py +++ b/src/plugin_system/base/base_event.py @@ -119,17 +119,17 @@ class BaseEvent: for i, result in enumerate(results): subscriber = sorted_subscribers[i] handler_name = subscriber.handler_name if hasattr(subscriber, 'handler_name') else subscriber.__class__.__name__ - - if isinstance(result, Exception): - # 处理执行异常 - logger.error(f"事件处理器 {handler_name} 执行失败: {result}") - processed_results.append(HandlerResult(False, True, str(result), handler_name)) - else: - # 正常执行结果 - if not result.handler_name: - # 补充handler_name - result.handler_name = handler_name - processed_results.append(result) + if result: + if isinstance(result, Exception): + # 处理执行异常 + logger.error(f"事件处理器 {handler_name} 执行失败: {result}") + processed_results.append(HandlerResult(False, True, str(result), handler_name)) + else: + # 正常执行结果 + if not result.handler_name: + # 补充handler_name + result.handler_name = handler_name + processed_results.append(result) return HandlerResultsCollection(processed_results) diff --git a/src/plugin_system/base/base_events_handler.py b/src/plugin_system/base/base_events_handler.py index bfc0e5636..1d023ae02 100644 --- a/src/plugin_system/base/base_events_handler.py +++ b/src/plugin_system/base/base_events_handler.py @@ -26,7 +26,6 @@ class BaseEventHandler(ABC): def __init__(self): self.log_prefix 
= "[EventHandler]" - self.plugin_name = "" """对应插件名""" self.plugin_config: Optional[Dict] = None """插件配置字典""" diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index 57f131ba1..8916fadfd 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import List, Type, Tuple, Union, TYPE_CHECKING +from typing import List, Type, Tuple, Union from .plugin_base import PluginBase from src.common.logger import get_logger diff --git a/src/plugin_system/base/plus_command.py b/src/plugin_system/base/plus_command.py index 16af685a1..1e68a2276 100644 --- a/src/plugin_system/base/plus_command.py +++ b/src/plugin_system/base/plus_command.py @@ -4,7 +4,7 @@ """ from abc import ABC, abstractmethod -from typing import Dict, Tuple, Optional, List +from typing import Tuple, Optional, List import re from src.common.logger import get_logger diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 7e925e3f0..eea3a247e 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -166,7 +166,8 @@ class ComponentRegistry: if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction): logger.error(f"注册失败: {action_name} 不是有效的Action") return False - + + action_class.plugin_name = action_info.plugin_name self._action_registry[action_name] = action_class # 如果启用,添加到默认动作集 @@ -184,6 +185,7 @@ class ComponentRegistry: logger.error(f"注册失败: {command_name} 不是有效的Command") return False + command_class.plugin_name = command_info.plugin_name self._command_registry[command_name] = command_class # 如果启用了且有匹配模式 @@ -213,6 +215,7 @@ class ComponentRegistry: if not hasattr(self, '_plus_command_registry'): self._plus_command_registry: Dict[str, Type[PlusCommand]] = {} + plus_command_class.plugin_name = plus_command_info.plugin_name 
self._plus_command_registry[plus_command_name] = plus_command_class logger.debug(f"已注册PlusCommand组件: {plus_command_name}") @@ -222,6 +225,7 @@ class ComponentRegistry: """注册Tool组件到Tool特定注册表""" tool_name = tool_info.name + tool_class.plugin_name = tool_info.plugin_name self._tool_registry[tool_name] = tool_class # 如果是llm可用的且启用的工具,添加到 llm可用工具列表 @@ -246,6 +250,7 @@ class ComponentRegistry: logger.warning(f"EventHandler组件 {handler_name} 未启用") return True # 未启用,但是也是注册成功 + handler_class.plugin_name = handler_info.plugin_name # 使用EventManager进行事件处理器注册 from src.plugin_system.core.event_manager import event_manager return event_manager.register_event_handler(handler_class) diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index dee611c8c..180085f6d 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -7,8 +7,10 @@ from src.llm_models.utils_model import LLMRequest from src.llm_models.payload_content import ToolCall from src.config.config import global_config, model_config from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +import inspect from src.chat.message_receive.chat_stream import get_chat_manager from src.common.logger import get_logger +from src.common.cache_manager import tool_cache logger = get_logger("tool_use") @@ -184,21 +186,65 @@ class ToolExecutor: return tool_results, used_tools async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]: - # sourcery skip: use-assigned-variable - """执行单个工具调用 + """执行单个工具调用,并处理缓存""" + + function_args = tool_call.args or {} + tool_instance = tool_instance or get_tool_instance(tool_call.func_name) - Args: - tool_call: 工具调用对象 + # 如果工具不存在或未启用缓存,则直接执行 + if not tool_instance or not tool_instance.enable_cache: + return await self._original_execute_tool_call(tool_call, tool_instance) - Returns: - Optional[Dict]: 工具调用结果,如果失败则返回None - """ + # --- 缓存逻辑开始 --- + try: + 
tool_file_path = inspect.getfile(tool_instance.__class__) + semantic_query = None + if tool_instance.semantic_cache_query_key: + semantic_query = function_args.get(tool_instance.semantic_cache_query_key) + + cached_result = await tool_cache.get( + tool_name=tool_call.func_name, + function_args=function_args, + tool_file_path=tool_file_path, + semantic_query=semantic_query + ) + if cached_result: + logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行") + return cached_result + except Exception as e: + logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}") + + # 缓存未命中,执行原始工具调用 + result = await self._original_execute_tool_call(tool_call, tool_instance) + + # 将结果存入缓存 + try: + tool_file_path = inspect.getfile(tool_instance.__class__) + semantic_query = None + if tool_instance.semantic_cache_query_key: + semantic_query = function_args.get(tool_instance.semantic_cache_query_key) + + await tool_cache.set( + tool_name=tool_call.func_name, + function_args=function_args, + tool_file_path=tool_file_path, + data=result, + ttl=tool_instance.cache_ttl, + semantic_query=semantic_query + ) + except Exception as e: + logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}") + # --- 缓存逻辑结束 --- + + return result + + async def _original_execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]: + """执行单个工具调用的原始逻辑""" try: function_name = tool_call.func_name function_args = tool_call.args or {} - logger.info(f"🤖 {self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}") + logger.info(f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}") function_args["llm_called"] = True # 标记为LLM调用 - # 获取对应工具实例 tool_instance = tool_instance or get_tool_instance(function_name) if not tool_instance: diff --git a/src/plugins/built_in/maizone_refactored/plugin.py b/src/plugins/built_in/maizone_refactored/plugin.py index ca9a8c72c..6507c1c92 100644 --- 
a/src/plugins/built_in/maizone_refactored/plugin.py +++ b/src/plugins/built_in/maizone_refactored/plugin.py @@ -24,6 +24,7 @@ from .services.qzone_service import QZoneService from .services.scheduler_service import SchedulerService from .services.monitor_service import MonitorService from .services.cookie_service import CookieService +from .services.reply_tracker_service import ReplyTrackerService from .services.manager import register_service logger = get_logger("MaiZone.Plugin") @@ -99,11 +100,13 @@ class MaiZoneRefactoredPlugin(BasePlugin): content_service = ContentService(self.get_config) image_service = ImageService(self.get_config) cookie_service = CookieService(self.get_config) + reply_tracker_service = ReplyTrackerService() qzone_service = QZoneService(self.get_config, content_service, image_service, cookie_service) scheduler_service = SchedulerService(self.get_config, qzone_service) monitor_service = MonitorService(self.get_config, qzone_service) register_service("qzone", qzone_service) + register_service("reply_tracker", reply_tracker_service) register_service("get_config", self.get_config) # 保存服务引用以便后续启动 diff --git a/src/plugins/built_in/maizone_refactored/services/content_service.py b/src/plugins/built_in/maizone_refactored/services/content_service.py index 7a98a7cdc..cda1fa714 100644 --- a/src/plugins/built_in/maizone_refactored/services/content_service.py +++ b/src/plugins/built_in/maizone_refactored/services/content_service.py @@ -9,12 +9,9 @@ import datetime import base64 import aiohttp from src.common.logger import get_logger -import base64 -import aiohttp import imghdr import asyncio -from src.common.logger import get_logger -from src.plugin_system.apis import llm_api, config_api, generator_api, person_api +from src.plugin_system.apis import llm_api, config_api, generator_api from src.chat.message_receive.chat_stream import get_chat_manager from maim_message import UserInfo from src.llm_models.utils_model import LLMRequest diff --git 
a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index f97ff0991..ea422b7e5 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -27,6 +27,7 @@ from src.chat.utils.chat_message_builder import ( from .content_service import ContentService from .image_service import ImageService from .cookie_service import CookieService +from .reply_tracker_service import ReplyTrackerService logger = get_logger("MaiZone.QZoneService") @@ -55,6 +56,7 @@ class QZoneService: self.content_service = content_service self.image_service = image_service self.cookie_service = cookie_service + self.reply_tracker = ReplyTrackerService() # --- Public Methods (High-Level Business Logic) --- @@ -154,7 +156,8 @@ class QZoneService: # --- 第一步: 单独处理自己说说的评论 --- if self.get_config("monitor.enable_auto_reply", False): try: - own_feeds = await api_client["list_feeds"](qq_account, 5) # 获取自己最近5条说说 + # 传入新参数,表明正在检查自己的说说 + own_feeds = await api_client["list_feeds"](qq_account, 5, is_monitoring_own_feeds=True) if own_feeds: logger.info(f"获取到自己 {len(own_feeds)} 条说说,检查评论...") for feed in own_feeds: @@ -248,42 +251,83 @@ class QZoneService: content = feed.get("content", "") fid = feed.get("tid", "") - if not comments: + if not comments or not fid: return - # 筛选出未被自己回复过的评论 - if not comments: + # 1. 将评论分为用户评论和自己的回复 + user_comments = [c for c in comments if str(c.get('qq_account')) != str(qq_account)] + my_replies = [c for c in comments if str(c.get('qq_account')) == str(qq_account)] + + if not user_comments: return - # 找到所有我已经回复过的评论的ID - replied_to_tids = { - c['parent_tid'] for c in comments - if c.get('parent_tid') and str(c.get('qq_account')) == str(qq_account) - } + # 2. 
验证已记录的回复是否仍然存在,清理已删除的回复记录 + await self._validate_and_cleanup_reply_records(fid, my_replies) - # 找出所有非我发出且我未回复过的评论 - comments_to_reply = [ - c for c in comments - if str(c.get('qq_account')) != str(qq_account) and c.get('comment_tid') not in replied_to_tids - ] + # 3. 使用验证后的持久化记录来筛选未回复的评论 + comments_to_reply = [] + for comment in user_comments: + comment_tid = comment.get('comment_tid') + if not comment_tid: + continue + + # 检查是否已经在持久化记录中标记为已回复 + if not self.reply_tracker.has_replied(fid, comment_tid): + comments_to_reply.append(comment) if not comments_to_reply: + logger.debug(f"说说 {fid} 下的所有评论都已回复过") return logger.info(f"发现自己说说下的 {len(comments_to_reply)} 条新评论,准备回复...") for comment in comments_to_reply: - reply_content = await self.content_service.generate_comment_reply( - content, comment.get("content", ""), comment.get("nickname", "") - ) - if reply_content: - success = await api_client["reply"]( - fid, qq_account, comment.get("nickname", ""), reply_content, comment.get("comment_tid") + comment_tid = comment.get("comment_tid") + nickname = comment.get("nickname", "") + comment_content = comment.get("content", "") + + try: + reply_content = await self.content_service.generate_comment_reply( + content, comment_content, nickname ) - if success: - logger.info(f"成功回复'{comment.get('nickname', '')}'的评论: '{reply_content}'") + if reply_content: + success = await api_client["reply"]( + fid, qq_account, nickname, reply_content, comment_tid + ) + if success: + # 标记为已回复 + self.reply_tracker.mark_as_replied(fid, comment_tid) + logger.info(f"成功回复'{nickname}'的评论: '{reply_content}'") + else: + logger.error(f"回复'{nickname}'的评论失败") + await asyncio.sleep(random.uniform(10, 20)) else: - logger.error(f"回复'{comment.get('nickname', '')}'的评论失败") - await asyncio.sleep(random.uniform(10, 20)) + logger.warning(f"生成回复内容失败,跳过回复'{nickname}'的评论") + except Exception as e: + logger.error(f"回复'{nickname}'的评论时发生异常: {e}", exc_info=True) + + async def _validate_and_cleanup_reply_records(self, fid: 
str, my_replies: List[Dict]): + """验证并清理已删除的回复记录""" + # 获取当前记录中该说说的所有已回复评论ID + recorded_replied_comments = self.reply_tracker.get_replied_comments(fid) + + if not recorded_replied_comments: + return + + # 从API返回的我的回复中提取parent_tid(即被回复的评论ID) + current_replied_comments = set() + for reply in my_replies: + parent_tid = reply.get('parent_tid') + if parent_tid: + current_replied_comments.add(parent_tid) + + # 找出记录中有但实际已不存在的回复 + deleted_replies = recorded_replied_comments - current_replied_comments + + if deleted_replies: + logger.info(f"检测到 {len(deleted_replies)} 个回复已被删除,清理记录...") + for comment_tid in deleted_replies: + self.reply_tracker.remove_reply_record(fid, comment_tid) + logger.debug(f"已清理删除的回复记录: feed_id={fid}, comment_id={comment_tid}") async def _process_single_feed(self, feed: Dict, api_client: Dict, target_qq: str, target_name: str): """处理单条说说,决定是否评论和点赞""" @@ -641,7 +685,7 @@ class QZoneService: logger.error(f"上传图片 {index+1} 异常: {e}", exc_info=True) return None - async def _list_feeds(t_qq: str, num: int) -> List[Dict]: + async def _list_feeds(t_qq: str, num: int, is_monitoring_own_feeds: bool = False) -> List[Dict]: """获取指定用户说说列表""" try: params = { @@ -667,37 +711,41 @@ class QZoneService: feeds_list = [] my_name = json_data.get("logininfo", {}).get("name", "") for msg in json_data.get("msglist", []): - is_commented = any( - c.get("name") == my_name for c in msg.get("commentlist", []) if isinstance(c, dict) - ) - if not is_commented: - images = [pic['url1'] for pic in msg.get('pictotal', []) if 'url1' in pic] - - comments = [] - if 'commentlist' in msg: - for c in msg['commentlist']: - comments.append({ - 'qq_account': c.get('uin'), - 'nickname': c.get('name'), - 'content': c.get('content'), - 'comment_tid': c.get('tid'), - 'parent_tid': c.get('parent_tid') # API直接返回了父ID - }) - - feeds_list.append( - { - "tid": msg.get("tid", ""), - "content": msg.get("content", ""), - "created_time": time.strftime( - "%Y-%m-%d %H:%M:%S", 
time.localtime(msg.get("created_time", 0)) - ), - "rt_con": msg.get("rt_con", {}).get("content", "") - if isinstance(msg.get("rt_con"), dict) - else "", - "images": images, - "comments": comments - } + # 只有在处理好友说说时,才检查是否已评论并跳过 + if not is_monitoring_own_feeds: + is_commented = any( + c.get("name") == my_name for c in msg.get("commentlist", []) if isinstance(c, dict) ) + if is_commented: + continue + + images = [pic['url1'] for pic in msg.get('pictotal', []) if 'url1' in pic] + + comments = [] + if 'commentlist' in msg: + for c in msg['commentlist']: + comments.append({ + 'qq_account': c.get('uin'), + 'nickname': c.get('name'), + 'content': c.get('content'), + 'comment_tid': c.get('tid'), + 'parent_tid': c.get('parent_tid') # API直接返回了父ID + }) + + feeds_list.append( + { + "tid": msg.get("tid", ""), + "content": msg.get("content", ""), + "created_time": time.strftime( + "%Y-%m-%d %H:%M:%S", time.localtime(msg.get("created_time", 0)) + ), + "rt_con": msg.get("rt_con", {}).get("content", "") + if isinstance(msg.get("rt_con"), dict) + else "", + "images": images, + "comments": comments + } + ) return feeds_list except Exception as e: logger.error(f"获取说说列表失败: {e}", exc_info=True) diff --git a/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py new file mode 100644 index 000000000..a90c88d9f --- /dev/null +++ b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +""" +评论回复跟踪服务 +负责记录和管理已回复过的评论ID,避免重复回复 +""" + +import json +import time +from pathlib import Path +from typing import Set, Dict, Any +from src.common.logger import get_logger + +logger = get_logger("MaiZone.ReplyTrackerService") + + +class ReplyTrackerService: + """ + 评论回复跟踪服务 + 使用本地JSON文件持久化存储已回复的评论ID + """ + + def __init__(self): + # 数据存储路径 + self.data_dir = Path(__file__).resolve().parent.parent / "data" + self.data_dir.mkdir(exist_ok=True) + 
self.reply_record_file = self.data_dir / "replied_comments.json" + + # 内存中的已回复评论记录 + # 格式: {feed_id: {comment_id: timestamp, ...}, ...} + self.replied_comments: Dict[str, Dict[str, float]] = {} + + # 数据清理配置 + self.max_record_days = 30 # 保留30天的记录 + + # 加载已有数据 + self._load_data() + + def _load_data(self): + """从文件加载已回复评论数据""" + try: + if self.reply_record_file.exists(): + with open(self.reply_record_file, 'r', encoding='utf-8') as f: + data = json.load(f) + self.replied_comments = data + logger.info(f"已加载 {len(self.replied_comments)} 条说说的回复记录") + else: + logger.info("未找到回复记录文件,将创建新的记录") + except Exception as e: + logger.error(f"加载回复记录失败: {e}") + self.replied_comments = {} + + def _save_data(self): + """保存已回复评论数据到文件""" + try: + # 清理过期数据 + self._cleanup_old_records() + + with open(self.reply_record_file, 'w', encoding='utf-8') as f: + json.dump(self.replied_comments, f, ensure_ascii=False, indent=2) + logger.debug("回复记录已保存") + except Exception as e: + logger.error(f"保存回复记录失败: {e}") + + def _cleanup_old_records(self): + """清理超过保留期限的记录""" + current_time = time.time() + cutoff_time = current_time - (self.max_record_days * 24 * 60 * 60) + + feeds_to_remove = [] + total_removed = 0 + + for feed_id, comments in self.replied_comments.items(): + comments_to_remove = [] + + for comment_id, timestamp in comments.items(): + if timestamp < cutoff_time: + comments_to_remove.append(comment_id) + + # 移除过期的评论记录 + for comment_id in comments_to_remove: + del comments[comment_id] + total_removed += 1 + + # 如果该说说下没有任何记录了,标记删除整个说说记录 + if not comments: + feeds_to_remove.append(feed_id) + + # 移除空的说说记录 + for feed_id in feeds_to_remove: + del self.replied_comments[feed_id] + + if total_removed > 0: + logger.info(f"清理了 {total_removed} 条过期的回复记录") + + def has_replied(self, feed_id: str, comment_id: str) -> bool: + """ + 检查是否已经回复过指定的评论 + + Args: + feed_id: 说说ID + comment_id: 评论ID + + Returns: + bool: 如果已回复过返回True,否则返回False + """ + if not feed_id or not comment_id: + return False + + return 
(feed_id in self.replied_comments and + comment_id in self.replied_comments[feed_id]) + + def mark_as_replied(self, feed_id: str, comment_id: str): + """ + 标记指定评论为已回复 + + Args: + feed_id: 说说ID + comment_id: 评论ID + """ + if not feed_id or not comment_id: + logger.warning("feed_id 或 comment_id 为空,无法标记为已回复") + return + + current_time = time.time() + + if feed_id not in self.replied_comments: + self.replied_comments[feed_id] = {} + + self.replied_comments[feed_id][comment_id] = current_time + + # 保存到文件 + self._save_data() + + logger.info(f"已标记评论为已回复: feed_id={feed_id}, comment_id={comment_id}") + + def get_replied_comments(self, feed_id: str) -> Set[str]: + """ + 获取指定说说下所有已回复的评论ID + + Args: + feed_id: 说说ID + + Returns: + Set[str]: 已回复的评论ID集合 + """ + if feed_id in self.replied_comments: + return set(self.replied_comments[feed_id].keys()) + return set() + + def get_stats(self) -> Dict[str, Any]: + """ + 获取回复记录统计信息 + + Returns: + Dict: 包含统计信息的字典 + """ + total_feeds = len(self.replied_comments) + total_replies = sum(len(comments) for comments in self.replied_comments.values()) + + return { + "total_feeds_with_replies": total_feeds, + "total_replied_comments": total_replies, + "data_file": str(self.reply_record_file), + "max_record_days": self.max_record_days + } + + def remove_reply_record(self, feed_id: str, comment_id: str): + """ + 移除指定评论的回复记录 + + Args: + feed_id: 说说ID + comment_id: 评论ID + """ + if feed_id in self.replied_comments and comment_id in self.replied_comments[feed_id]: + del self.replied_comments[feed_id][comment_id] + + # 如果该说说下没有任何回复记录了,删除整个说说记录 + if not self.replied_comments[feed_id]: + del self.replied_comments[feed_id] + + self._save_data() + logger.debug(f"已移除回复记录: feed_id={feed_id}, comment_id={comment_id}") + + def remove_feed_records(self, feed_id: str): + """ + 移除指定说说的所有回复记录 + + Args: + feed_id: 说说ID + """ + if feed_id in self.replied_comments: + del self.replied_comments[feed_id] + self._save_data() + logger.info(f"已移除说说 {feed_id} 的所有回复记录") \ No 
newline at end of file diff --git a/src/plugins/built_in/permission_management/plugin.py b/src/plugins/built_in/permission_management/plugin.py index d8a39107a..bad227787 100644 --- a/src/plugins/built_in/permission_management/plugin.py +++ b/src/plugins/built_in/permission_management/plugin.py @@ -16,7 +16,7 @@ from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.apis.logging_api import get_logger from src.plugin_system.base.component_types import PlusCommandInfo, ChatType from src.plugin_system.base.config_types import ConfigField -from src.plugin_system.utils.permission_decorators import require_permission, require_master, PermissionChecker +from src.plugin_system.utils.permission_decorators import require_permission logger = get_logger("Permission") diff --git a/src/schedule/schedule_manager.py b/src/schedule/schedule_manager.py index 82578046d..84b87c657 100644 --- a/src/schedule/schedule_manager.py +++ b/src/schedule/schedule_manager.py @@ -411,7 +411,6 @@ class ScheduleManager: 通过关键词匹配、唤醒度、睡眠压力等综合判断是否处于休眠时间。 新增弹性睡眠机制,允许在压力低时延迟入睡,并在入睡前发送通知。 """ - from src.chat.chat_loop.wakeup_manager import WakeUpManager # --- 基础检查 --- if not global_config.schedule.enable_is_sleep: return False diff --git a/template/model_config_template.toml b/template/model_config_template.toml index 730f43e21..5d4b1c08f 100644 --- a/template/model_config_template.toml +++ b/template/model_config_template.toml @@ -21,7 +21,7 @@ max_retry = 2 timeout = 30 retry_interval = 10 -[[api_providers]] # 特殊:Google的Gimini使用特殊API,与OpenAI格式不兼容,需要配置client为"gemini" +[[api_providers]] # 特殊:Google的Gimini使用特殊API,与OpenAI格式不兼容,需要配置client为"aiohttp_gemini" name = "Google" base_url = "https://api.google.com/v1" api_key = "your-google-api-key-1"