diff --git a/MoFox 重构指导总览.md b/MoFox 重构指导总览.md index 10349f5db..d46b40242 100644 --- a/MoFox 重构指导总览.md +++ b/MoFox 重构指导总览.md @@ -95,10 +95,39 @@ components:基本插件组件管理 mcp_tool_manager.py:MCP工具管理器 permission_manager.py:权限管理器 plugin_manager.py:插件管理器 + prompt_component_manager.py:Prompt组件管理器 tool_manager:工具相关管理 + __init__.py:导出 tool_histoty.py:工具调用历史记录 tool_use.py:实际工具调用器 types.py:组件类型 registry.py:组件注册管理 state_manager.py:组件状态管理 - +prompt:提示词管理系统 + __init__.py:导出 + prompt.py:Prompt基类 + manager.py:全局prompt管理器 + params.py:Prompt参数系统 +perception:感知学习系统 + __init__.py:导出 + memory:常规记忆 + ... + knowledge:知识库 + ... + meme:黑话库 + ... + express:表达学习 + ... +transport:通讯传输系统 + __init__.py:导出 + message_receive:消息接收 + ... + message_send:消息发送 + ... + router:api路由 + ... + sink:针对适配器的core sink和ws接收器 + ... +models:基本模型 + __init__.py:导出 + \ No newline at end of file diff --git a/docs/development/emoji_prompt_limit.md b/docs/development/emoji_prompt_limit.md new file mode 100644 index 000000000..7127dff1f --- /dev/null +++ b/docs/development/emoji_prompt_limit.md @@ -0,0 +1,22 @@ +# 表情替换候选数量说明 + +## 背景 +`MAX_EMOJI_FOR_PROMPT` 用于 `replace_a_emoji` 等场景,限制送入 LLM 的候选表情数量,避免上下文过长导致响应变慢或 token 开销过大。 + +## 为什么是 20 +- 平衡:超过十几项后决策收益递减,但 token/时间成本线性增加。 +- 性能:在常用模型和硬件下,20 个描述可在可接受延迟内返回决策。 +- 兼容:历史实现也使用 20,保持行为稳定。 + +## 何时调整 +- 设备/模型更强且希望更广覆盖:可提升到 30-40,但注意延迟和费用。 +- 低算力或对延迟敏感:可下调到 10-15 以加快决策。 +- 特殊场景(主题集中、库很小):下调有助于避免无意义的冗余候选。 + +## 如何修改 +- 常量位置:`src/chat/emoji_system/emoji_constants.py` 中的 `MAX_EMOJI_FOR_PROMPT`。 +- 如需动态配置,可将其迁移到 `global_config.emoji` 下的配置项并在 `emoji_manager` 读取。 + +## 建议 +- 调整后观察:替换决策耗时、模型费用、误删率(删除的表情是否被实际需要)。 +- 如继续扩展表情库规模,建议为候选列表增加基于使用频次或时间的预筛选策略。 diff --git a/docs/development/emoji_system_refactor.md b/docs/development/emoji_system_refactor.md new file mode 100644 index 000000000..5d5538f51 --- /dev/null +++ b/docs/development/emoji_system_refactor.md @@ -0,0 +1,33 @@ +# 表情系统重构说明 + +日期:2025-12-15 + +## 目标 +- 拆分单体的 `emoji_manager.py`,将实体、常量、文件工具解耦。 +- 减少扫描/注册期间的事件循环阻塞。 +- 保留现有行为(LLM/VLM 流程、容量替换、缓存查找),同时提升可维护性。 + +## 新结构 +- `src/chat/emoji_system/emoji_constants.py`:共享路径与提示/数量上限。 +- `src/chat/emoji_system/emoji_entities.py`:`MaiEmoji`(哈希、格式检测、入库/删除、缓存失效)。 +- `src/chat/emoji_system/emoji_utils.py`:目录保证、临时清理、增量文件扫描、DB 行到实体转换。 +- `src/chat/emoji_system/emoji_manager.py`:负责完整性检查、扫描、注册、VLM/LLM 描述、替换与缓存,现委托给上述模块。 +- `src/chat/emoji_system/README.md`:快速使用/生命周期指引。 + +## 行为变化 +- 完整性检查改为游标+批量增量扫描,每处理 50 个让出一次事件循环。 +- 循环内的重文件操作(exists、listdir、remove、makedirs)通过 `asyncio.to_thread` 释放主循环。 +- 目录扫描使用 `os.scandir`(经 `list_image_files`),减少重复 stat,并返回文件列表与是否为空。 +- 快速查找:加载时重建 `_emoji_index`,增删时保持同步;`get_emoji_from_manager` 优先走索引。 +- 注册与替换流程在更新索引的同时,异步清理失败/重复文件。 + +## 迁移提示 +- 现有调用继续使用 `get_emoji_manager()` 与 `EmojiManager` API,外部接口未改动。 +- 如曾直接从 `emoji_manager` 引入常量或工具,请改为从 `emoji_constants`、`emoji_entities`、`emoji_utils` 引入。 +- 依赖同步文件时序的测试/脚本可能观察到不同的耗时,但逻辑等价。 + +## 后续建议 +1. 为 `list_image_files`、`clean_unused_emojis`、完整性扫描游标行为补充单测。 +2. 将 VLM/LLM 提示词模板外置为配置,便于迭代。 +3. 暴露扫描耗时、清理数量、注册延迟等指标,便于观测。 +4. 为 `replace_a_emoji` 的 LLM 调用添加重试上限,并记录 prompt/决策日志以便审计。 diff --git a/docs/express_similarity.md b/docs/express_similarity.md new file mode 100644 index 000000000..04055e29c --- /dev/null +++ b/docs/express_similarity.md @@ -0,0 +1,36 @@ +# 表达相似度计算策略 + +本文档说明 `calculate_similarity` 的实现与配置,帮助在质量与性能间做权衡。 + +## 总览 +- 支持两种路径: + 1) **向量化路径(默认优先)**:TF-IDF + 余弦相似度(依赖 `scikit-learn`) + 2) **回退路径**:`difflib.SequenceMatcher` +- 参数 `prefer_vector` 控制是否优先尝试向量化,默认 `True`。 +- 依赖缺失或文本过短时,自动回退,无需额外配置。 + +## 调用方式 +```python +from src.chat.express.express_utils import calculate_similarity + +sim = calculate_similarity(text1, text2) # 默认优先向量化 +sim_fast = calculate_similarity(text1, text2, prefer_vector=False) # 强制使用 SequenceMatcher +``` + +## 依赖与回退 +- 可选依赖:`scikit-learn` + - 缺失时自动回退到 `SequenceMatcher`,不会抛异常。 +- 文本过短(长度 < 2)时直接回退,避免稀疏向量噪声。 + +## 适用建议 +- 文本较长、对鲁棒性/语义相似度有更高要求:保持默认(向量化优先)。 +- 环境无 `scikit-learn` 或追求极简依赖:调用时设置 `prefer_vector=False`。 +- 高并发性能敏感:可在调用点酌情关闭向量化或加缓存。 + +## 返回范围 +- 相似度范围始终在 `[0, 1]`。 +- 空字符串 → `0.0`;完全相同 → `1.0`。 + +## 额外建议 +- 若需更强语义能力,可替换为向量数据库或句向量模型(需新增依赖与配置)。 +- 对热路径可增加缓存(按文本哈希),或限制输入长度以控制向量维度与内存。 diff --git a/docs/napcat_video_configuration_guide.md b/docs/napcat_video_configuration_guide.md new file mode 100644 index 000000000..c4fa894a2 --- /dev/null +++ b/docs/napcat_video_configuration_guide.md @@ -0,0 +1,283 @@ +# Napcat 视频处理配置指南 + +## 概述 + +本指南说明如何在 MoFox-Bot 中配置和控制 Napcat 适配器的视频消息处理功能。 + +**相关 Issue**: [#10 - 强烈请求有个开关选择是否下载视频](https://github.com/MoFox-Studio/MoFox-Core/issues/10) + +--- + +## 快速开始 + +### 关闭视频下载(推荐用于低配机器或有限带宽) + +编辑 `config/bot_config.toml`,找到 `[napcat_adapter.features]` 段落,修改: + +```toml +[napcat_adapter.features] +enable_video_processing = false # 改为 false 关闭视频处理 +``` + +**效果**:视频消息会显示为 `[视频消息]`,不会进行下载。 + +--- + +## 配置选项详解 + +### 主开关:`enable_video_processing` + +| 属性 | 值 | +|------|-----| +| **类型** | 布尔值 (`true` / `false`) | +| **默认值** | `true` | +| **说明** | 是否启用视频消息的下载和处理 | + +**启用 (`true`)**: +- ✅ 自动下载视频 +- ✅ 将视频转换为 base64 并发送给 AI +- ⚠️ 消耗网络带宽和 CPU 资源 + +**禁用 (`false`)**: +- ✅ 跳过视频下载 +- ✅ 显示 `[视频消息]` 占位符 +- ✅ 显著降低带宽和 CPU 占用 + +### 高级选项 + +#### `video_max_size_mb` + +| 属性 | 值 | +|------|-----| +| **类型** | 整数 | +| **默认值** | `100` (MB) | +| **建议范围** | 10 - 500 MB | +| **说明** | 允许下载的最大视频文件大小 | + +**用途**:防止下载过大的视频文件。 + +**建议**: +- **低配机器** (2GB RAM): 设置为 10-20 MB +- **中等配置** (8GB RAM): 设置为 50-100 MB +- **高配机器** (16GB+ RAM): 设置为 100-500 MB + +```toml +# 只允许下载 50MB 以下的视频 +video_max_size_mb = 50 +``` + +#### `video_download_timeout` + +| 属性 | 值 | +|------|-----| +| **类型** | 整数 | +| **默认值** | `60` (秒) | +| **建议范围** | 30 - 180 秒 | +| **说明** | 视频下载超时时间 | + +**用途**:防止卡住等待无法下载的视频。 + +**建议**: +- **网络较差** (2-5 Mbps): 设置为 120-180 秒 +- **网络一般** (5-20 Mbps): 设置为 60-120 秒 +- **网络较好** (20+ Mbps): 设置为 30-60 秒 + +```toml +# 下载超时时间改为 120 秒 +video_download_timeout = 120 +``` + +--- + +## 常见配置场景 + +### 场景 1:服务器带宽有限 + +**症状**:群聊消息中经常出现大量视频,导致网络流量爆满。 + +**解决方案**: +```toml +[napcat_adapter.features] +enable_video_processing = false # 完全关闭 +``` + +### 场景 2:机器性能较低 + +**症状**:处理视频消息时 CPU 占用率高,其他功能响应变慢。 + +**解决方案**: +```toml +[napcat_adapter.features] +enable_video_processing = true +video_max_size_mb = 20 # 限制小视频 +video_download_timeout = 30 # 快速超时 +``` + +### 场景 3:特定时间段关闭视频处理 + +如果需要在特定时间段内关闭视频处理,可以: + +1. 修改配置文件 +2. 调用 API 重新加载配置(如果支持) + +例如:在工作时间关闭,下班后打开。 + +### 场景 4:保留所有视频处理(默认行为) + +```toml +[napcat_adapter.features] +enable_video_processing = true +video_max_size_mb = 100 +video_download_timeout = 60 +``` + +--- + +## 工作原理 + +### 启用视频处理的流程 + +``` +消息到达 + ↓ +检查 enable_video_processing + ├─ false → 返回 [视频消息] 占位符 ✓ + └─ true ↓ + 检查文件大小 + ├─ > video_max_size_mb → 返回错误信息 ✓ + └─ ≤ video_max_size_mb ↓ + 开始下载(最多等待 video_download_timeout 秒) + ├─ 成功 → 返回视频数据 ✓ + ├─ 超时 → 返回超时错误 ✓ + └─ 失败 → 返回错误信息 ✓ +``` + +### 禁用视频处理的流程 + +``` +消息到达 + ↓ +检查 enable_video_processing + └─ false → 立即返回 [视频消息] 占位符 ✓ + (节省带宽和 CPU) +``` + +--- + +## 错误处理 + +当视频处理出现问题时,用户会看到以下占位符消息: + +| 消息 | 含义 | +|------|------| +| `[视频消息]` | 视频处理已禁用或信息不完整 | +| `[视频消息] (文件过大)` | 视频大小超过限制 | +| `[视频消息] (下载失败)` | 网络错误或服务不可用 | +| `[视频消息处理出错]` | 其他异常错误 | + +这些占位符确保消息不会因为视频处理失败而导致程序崩溃。 + +--- + +## 性能对比 + +| 配置 | 带宽消耗 | CPU 占用 | 内存占用 | 响应速度 | +|------|----------|---------|---------|----------| +| **禁用** (`false`) | 🟢 极低 | 🟢 极低 | 🟢 极低 | 🟢 极快 | +| **启用,小视频** (≤20MB) | 🟡 中等 | 🟡 中等 | 🟡 中等 | 🟡 一般 | +| **启用,大视频** (≤100MB) | 🔴 较高 | 🔴 较高 | 🔴 较高 | 🔴 较慢 | + +--- + +## 监控和调试 + +### 检查配置是否生效 + +启动 bot 后,查看日志中是否有类似信息: + +``` +[napcat_adapter] 视频下载器已初始化: max_size=100MB, timeout=60s +``` + +如果看到这条信息,说明配置已成功加载。 + +### 监控视频处理 + +当处理视频消息时,日志中会记录: + +``` +[video_handler] 开始下载视频: https://... +[video_handler] 视频下载成功,大小: 25.50 MB +``` + +或者: + +``` +[napcat_adapter] 视频消息处理已禁用,跳过 +``` + +--- + +## 常见问题 + +### Q1: 关闭视频处理会影响 AI 的回复吗? + +**A**: 不会。AI 仍然能看到 `[视频消息]` 占位符,可以根据上下文判断是否涉及视频内容。 + +### Q2: 可以为不同群组设置不同的视频处理策略吗? + +**A**: 当前版本不支持。所有群组使用相同的配置。如需支持,请在 Issue 或讨论中提出。 + +### Q3: 视频下载会影响消息处理延迟吗? + +**A**: 会。下载大视频可能需要几秒钟。建议: +- 设置合理的 `video_download_timeout` +- 或禁用视频处理以获得最快响应 + +### Q4: 修改配置后需要重启吗? + +**A**: 是的。需要重启 bot 才能应用新配置。 + +### Q5: 如何快速诊断视频下载问题? + +**A**: +1. 检查日志中的错误信息 +2. 验证网络连接 +3. 检查 `video_max_size_mb` 是否设置过小 +4. 尝试增加 `video_download_timeout` + +--- + +## 最佳实践 + +1. **新用户建议**:先启用视频处理,如果出现性能问题再调整参数或关闭。 + +2. **生产环境建议**: + - 定期监控日志中的视频处理错误 + - 根据实际网络和 CPU 情况调整参数 + - 在高峰期可考虑关闭视频处理 + +3. **开发调试**: + - 启用日志中的 DEBUG 级别输出 + - 测试各个 `video_max_size_mb` 值的实际表现 + - 检查超时时间是否符合网络条件 + +--- + +## 相关链接 + +- **GitHub Issue #10**: [强烈请求有个开关选择是否下载视频](https://github.com/MoFox-Studio/MoFox-Core/issues/10) +- **配置文件**: `config/bot_config.toml` +- **实现代码**: + - `src/plugins/built_in/napcat_adapter/plugin.py` + - `src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py` + - `src/plugins/built_in/napcat_adapter/src/handlers/video_handler.py` + +--- + +## 反馈和建议 + +如有其他问题或建议,欢迎在 GitHub Issue 中提出。 + +**版本**: v2.1.0 +**最后更新**: 2025-12-16 diff --git a/docs/changelogs/short_term_pressure_patch.md b/docs/short_term_pressure_patch.md similarity index 95% rename from docs/changelogs/short_term_pressure_patch.md rename to docs/short_term_pressure_patch.md index 65dd7ea76..0e124932c 100644 --- a/docs/changelogs/short_term_pressure_patch.md +++ b/docs/short_term_pressure_patch.md @@ -30,7 +30,7 @@ ## 影响范围 -- 默认行为保持与补丁前一致(开关默认 `on`)。 +- 默认行为保持与补丁前一致(开关默认 `off`)。 - 如果关闭开关,短期层将不再做强制删除,只依赖自动转移机制。 ## 回滚 diff --git a/docs/style_learner_resource_limit.md b/docs/style_learner_resource_limit.md new file mode 100644 index 000000000..da2550742 --- /dev/null +++ b/docs/style_learner_resource_limit.md @@ -0,0 +1,60 @@ +# StyleLearner 资源上限开关(默认开启) + +## 概览 +StyleLearner 支持资源上限控制,用于约束风格容量与清理行为。开关默认 **开启**,以防止模型无限膨胀;可在运行时动态关闭。 + +## 开关位置与用法(务必看这里) + +开关在 **代码层**,默认开启,不依赖配置文件。 + +1) **全局运行时切换(推荐)** + 路径:`src/chat/express/style_learner.py` 暴露的单例 `style_learner_manager` + ```python + from src.chat.express.style_learner import style_learner_manager + + # 关闭资源上限(放开容量,谨慎使用) + style_learner_manager.set_resource_limit(False) + + # 再次开启资源上限 + style_learner_manager.set_resource_limit(True) + ``` + - 影响范围:实时作用于已创建的全部 learner(逐个同步 `resource_limit_enabled`)。 + - 生效时机:调用后立即生效,无需重启。 + +2) **构造时指定(不常用)** + - `StyleLearner(resource_limit_enabled: True|False, ...)` + - `StyleLearnerManager(resource_limit_enabled: True|False, ...)` + 用于自定义实例化逻辑(通常保持默认即可)。 + +3) **默认行为** + - 开关默认 **开启**,即启用容量管理与清理。 + - 没有配置文件项;若需持久化开关状态,可自行在启动代码中显式调用 `set_resource_limit`。 + +## 资源上限行为(开启时) +- 容量参数(每个 chat): + - `max_styles = 2000` + - `cleanup_threshold = 0.9`(≥90% 容量触发清理) + - `cleanup_ratio = 0.2`(清理低价值风格约 20%) +- 价值评分:结合使用频率(log 平滑)与最近使用时间(指数衰减),得分低者优先清理。 +- 仅对单个 learner 的容量管理生效;LRU 淘汰逻辑保持不变。 + +> ⚙️ 开关作用面: +> - **开启**:在 add_style 时会检查容量并触发 `_cleanup_styles`;预测/学习逻辑不变。 +> - **关闭**:不再触发容量清理,但 LRU 管理器仍可能在进程层面淘汰不活跃 learner。 + +## I/O 与健壮性 +- 模型与元数据保存采用原子写(`.tmp` + `os.replace`),避免部分写入。 +- `pickle` 使用 `HIGHEST_PROTOCOL`,并执行 `fsync` 确保落盘。 + +## 兼容性 +- 默认开启,无需修改配置文件;关闭后行为与旧版本类似。 +- 已有模型文件可直接加载,开关仅影响运行时清理策略。 + +## 何时建议开启/关闭 +- 开启(默认):内存/磁盘受限,或聊天风格高频增长,需防止模型膨胀。 +- 关闭:需要完整保留所有历史风格且资源充足,或进行一次性数据收集实验。 + +## 监控与调优建议 +- 监控:每 chat 风格数量、清理触发次数、删除数量、预测延迟 p95。 +- 如清理过于激进:提高 `cleanup_threshold` 或降低 `cleanup_ratio`。 +- 如内存/磁盘依旧偏高:降低 `max_styles`,或增加定期持久化与压缩策略。 diff --git a/docs/video_download_configuration_changelog.md b/docs/video_download_configuration_changelog.md new file mode 100644 index 000000000..7ce8a06c8 --- /dev/null +++ b/docs/video_download_configuration_changelog.md @@ -0,0 +1,134 @@ +# Napcat 适配器视频处理配置完成总结 + +## 修改内容 + +### 1. **增强配置定义** (`plugin.py`) + - 添加 `video_max_size_mb`: 视频最大大小限制(默认 100MB) + - 添加 `video_download_timeout`: 下载超时时间(默认 60秒) + - 改进 `enable_video_processing` 的描述文字 + - **位置**: `src/plugins/built_in/napcat_adapter/plugin.py` L417-430 + +### 2. **改进消息处理器** (`message_handler.py`) + - 添加 `_video_downloader` 成员变量存储下载器实例 + - 改进 `set_plugin_config()` 方法,根据配置初始化视频下载器 + - 改进视频下载调用,使用初始化时的配置 + - **位置**: `src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py` L32-54, L327-334 + +### 3. **添加配置示例** (`bot_config.toml`) + - 添加 `[napcat_adapter]` 配置段 + - 添加完整的 Napcat 服务器配置示例 + - 添加详细的特性配置(消息过滤、视频处理等) + - 包含详尽的中文注释和使用建议 + - **位置**: `config/bot_config.toml` L680-724 + +### 4. **编写使用文档** (新文件) + - 创建 `docs/napcat_video_configuration_guide.md` + - 详细说明所有配置选项的含义和用法 + - 提供常见场景的配置模板 + - 包含故障排查和性能对比 + +--- + +## 功能清单 + +### 核心功能 +- ✅ 全局开关控制视频处理 (`enable_video_processing`) +- ✅ 视频大小限制 (`video_max_size_mb`) +- ✅ 下载超时控制 (`video_download_timeout`) +- ✅ 根据配置初始化下载器 +- ✅ 友好的错误提示信息 + +### 用户体验 +- ✅ 详细的配置说明文档 +- ✅ 代码中的中文注释 +- ✅ 启动日志反馈 +- ✅ 配置示例可直接使用 + +--- + +## 如何使用 + +### 快速关闭视频下载(解决 Issue #10) + +编辑 `config/bot_config.toml`: + +```toml +[napcat_adapter.features] +enable_video_processing = false # 改为 false +``` + +重启 bot 后生效。 + +### 调整视频大小限制 + +```toml +[napcat_adapter.features] +video_max_size_mb = 50 # 只允许下载 50MB 以下的视频 +``` + +### 调整下载超时 + +```toml +[napcat_adapter.features] +video_download_timeout = 120 # 增加到 120 秒 +``` + +--- + +## 向下兼容性 + +- ✅ 旧配置文件无需修改(使用默认值) +- ✅ 现有视频处理流程完全兼容 +- ✅ 所有功能都带有合理的默认值 + +--- + +## 测试场景 + +已验证的工作场景: + +| 场景 | 行为 | 状态 | +|------|------|------| +| 视频处理启用 | 正常下载视频 | ✅ | +| 视频处理禁用 | 返回占位符 | ✅ | +| 视频超过大小限制 | 返回错误信息 | ✅ | +| 下载超时 | 返回超时错误 | ✅ | +| 网络错误 | 返回友好错误 | ✅ | +| 启动时初始化 | 日志输出配置 | ✅ | + +--- + +## 文件修改清单 + +``` +修改文件: + - src/plugins/built_in/napcat_adapter/plugin.py + - src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py + - config/bot_config.toml + +新增文件: + - docs/napcat_video_configuration_guide.md +``` + +--- + +## 关联信息 + +- **GitHub Issue**: #10 - 强烈请求有个开关选择是否下载视频 +- **修复时间**: 2025-12-16 +- **相关文档**: [Napcat 视频处理配置指南](./napcat_video_configuration_guide.md) + +--- + +## 后续改进建议 + +1. **分组配置** - 为不同群组设置不同的视频处理策略 +2. **动态开关** - 提供运行时 API 动态开启/关闭视频处理 +3. **性能监控** - 添加视频处理的性能统计指标 +4. **队列管理** - 实现视频下载队列,限制并发下载数 +5. **缓存机制** - 缓存已下载的视频避免重复下载 + +--- + +**版本**: v2.1.0 +**状态**: ✅ 完成 diff --git a/src/chat/emoji_system/README.md b/src/chat/emoji_system/README.md new file mode 100644 index 000000000..ab9f50820 --- /dev/null +++ b/src/chat/emoji_system/README.md @@ -0,0 +1,37 @@ +# 新表情系统概览 + +本目录存放表情包的采集、注册与选择逻辑。 + +## 模块 +- `emoji_constants.py`:共享路径与数量上限。 +- `emoji_entities.py`:`MaiEmoji` 实体,负责哈希/格式检测、数据库注册与删除。 +- `emoji_utils.py`:文件系统工具(目录保证、临时清理、DB 行转换、文件列表扫描)。 +- `emoji_manager.py`:核心管理器,定期扫描、完整性检查、VLM/LLM 标注、容量替换、缓存查找。 +- `emoji_history.py`:按会话保存的内存历史。 + +## 生命周期 +1. 通过 `EmojiManager.start()` 启动后台任务(或在已有事件循环中直接 await `start_periodic_check_register()`)。 +2. 循环会加载数据库状态、做完整性清理、清理临时缓存,并扫描 `data/emoji` 中的新文件。 +3. 新图片会生成哈希,调用 VLM/LLM 生成描述后注册入库,并移动到 `data/emoji_registed`。 +4. 达到容量上限时,`replace_a_emoji()` 可能在 LLM 协助下删除低使用量表情再注册新表情。 + +## 关键行为 +- 完整性检查增量扫描,批量让出事件循环避免长阻塞。 +- 循环内的文件操作使用 `asyncio.to_thread` 以保持事件循环可响应。 +- 哈希索引 `_emoji_index` 加速内存查找;数据库为事实来源,内存为镜像。 +- 描述与标签使用缓存(见管理器上的 `@cached`)。 + +## 常用操作 +- `get_emoji_for_text(text_emotion)`:按目标情绪选取表情路径与描述。 +- `record_usage(emoji_hash)`:累加使用次数。 +- `delete_emoji(emoji_hash)`:删除文件与数据库记录并清缓存。 + +## 目录 +- 待注册:`data/emoji` +- 已注册:`data/emoji_registed` +- 临时图片:`data/image`, `data/images` + +## 说明 +- 通过 `config/bot_config.toml`、`config/model_config.toml` 配置上限与模型。 +- GIF 支持保留,注册前会提取关键帧再送 VLM。 +- 避免直接使用 `Session`,请使用本模块提供的 API。 diff --git a/src/chat/emoji_system/emoji_constants.py b/src/chat/emoji_system/emoji_constants.py new file mode 100644 index 000000000..3c4d70a6d --- /dev/null +++ b/src/chat/emoji_system/emoji_constants.py @@ -0,0 +1,6 @@ +import os + +BASE_DIR = os.path.join("data") +EMOJI_DIR = os.path.join(BASE_DIR, "emoji") +EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") +MAX_EMOJI_FOR_PROMPT = 20 diff --git a/src/chat/emoji_system/emoji_entities.py b/src/chat/emoji_system/emoji_entities.py new file mode 100644 index 000000000..e2b5190dc --- /dev/null +++ b/src/chat/emoji_system/emoji_entities.py @@ -0,0 +1,192 @@ +import asyncio +import base64 +import binascii +import hashlib +import io +import os +import time +import traceback + +from PIL import Image + +from src.chat.emoji_system.emoji_constants import EMOJI_REGISTERED_DIR +from src.chat.utils.utils_image import image_path_to_base64 +from src.common.database.api.crud import CRUDBase +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import Emoji +from src.common.database.optimization.cache_manager import get_cache +from src.common.database.utils.decorators import generate_cache_key +from src.common.logger import get_logger + +logger = get_logger("emoji") + + +class MaiEmoji: + """定义一个表情包""" + + def __init__(self, full_path: str): + if not full_path: + raise ValueError("full_path cannot be empty") + self.full_path = full_path + self.path = os.path.dirname(full_path) + self.filename = os.path.basename(full_path) + self.embedding = [] + self.hash = "" + self.description = "" + self.emotion: list[str] = [] + self.usage_count = 0 + self.last_used_time = time.time() + self.register_time = time.time() + self.is_deleted = False + self.format = "" + + async def initialize_hash_format(self) -> bool | None: + """从文件创建表情包实例, 计算哈希值和格式""" + try: + if not os.path.exists(self.full_path): + logger.error(f"[初始化错误] 表情包文件不存在: {self.full_path}") + self.is_deleted = True + return None + + logger.debug(f"[初始化] 正在读取文件: {self.full_path}") + image_base64 = image_path_to_base64(self.full_path) + if image_base64 is None: + logger.error(f"[初始化错误] 无法读取或转换Base64: {self.full_path}") + self.is_deleted = True + return None + logger.debug(f"[初始化] 文件读取成功 (Base64预览: {image_base64[:50]}...)") + + logger.debug(f"[初始化] 正在解码Base64并计算哈希: {self.filename}") + if isinstance(image_base64, str): + image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") + image_bytes = base64.b64decode(image_base64) + self.hash = hashlib.md5(image_bytes).hexdigest() + logger.debug(f"[初始化] 哈希计算成功: {self.hash}") + + logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}") + try: + with Image.open(io.BytesIO(image_bytes)) as img: + self.format = (img.format or "jpeg").lower() + logger.debug(f"[初始化] 格式获取成功: {self.format}") + except Exception as pil_error: + logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}") + logger.error(traceback.format_exc()) + self.is_deleted = True + return None + + return True + + except FileNotFoundError: + logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}") + self.is_deleted = True + return None + except (binascii.Error, ValueError) as b64_error: + logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}") + self.is_deleted = True + return None + except Exception as e: + logger.error(f"[初始化错误] 初始化表情包时发生未预期错误 ({self.filename}): {e!s}") + logger.error(traceback.format_exc()) + self.is_deleted = True + return None + + async def register_to_db(self) -> bool: + """注册表情包,将文件移动到注册目录并保存数据库""" + try: + source_full_path = self.full_path + destination_full_path = os.path.join(EMOJI_REGISTERED_DIR, self.filename) + + if not await asyncio.to_thread(os.path.exists, source_full_path): + logger.error(f"[错误] 源文件不存在: {source_full_path}") + return False + + try: + if await asyncio.to_thread(os.path.exists, destination_full_path): + await asyncio.to_thread(os.remove, destination_full_path) + + await asyncio.to_thread(os.rename, source_full_path, destination_full_path) + logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}") + self.full_path = destination_full_path + self.path = EMOJI_REGISTERED_DIR + except Exception as move_error: + logger.error(f"[错误] 移动文件失败: {move_error!s}") + return False + + try: + async with get_db_session() as session: + emotion_str = ",".join(self.emotion) if self.emotion else "" + + emoji = Emoji( + emoji_hash=self.hash, + full_path=self.full_path, + format=self.format, + description=self.description, + emotion=emotion_str, + query_count=0, + is_registered=True, + is_banned=False, + record_time=self.register_time, + register_time=self.register_time, + usage_count=self.usage_count, + last_used_time=self.last_used_time, + ) + session.add(emoji) + await session.commit() + + logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") + + return True + + except Exception as db_error: + logger.error(f"[错误] 保存数据库失败 ({self.filename}): {db_error!s}") + return False + + except Exception as e: + logger.error(f"[错误] 注册表情包失败 ({self.filename}): {e!s}") + logger.error(traceback.format_exc()) + return False + + async def delete(self) -> bool: + """删除表情包文件及数据库记录""" + try: + file_to_delete = self.full_path + if await asyncio.to_thread(os.path.exists, file_to_delete): + try: + await asyncio.to_thread(os.remove, file_to_delete) + logger.debug(f"[删除] 文件: {file_to_delete}") + except Exception as e: + logger.error(f"[错误] 删除文件失败 {file_to_delete}: {e!s}") + + try: + crud = CRUDBase(Emoji) + will_delete_emoji = await crud.get_by(emoji_hash=self.hash) + if will_delete_emoji is None: + logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") + result = 0 + else: + await crud.delete(will_delete_emoji.id) + result = 1 + + cache = await get_cache() + await cache.delete(generate_cache_key("emoji_by_hash", self.hash)) + await cache.delete(generate_cache_key("emoji_description", self.hash)) + await cache.delete(generate_cache_key("emoji_tag", self.hash)) + except Exception as e: + logger.error(f"[错误] 删除数据库记录时出错: {e!s}") + result = 0 + + if result > 0: + logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})") + self.is_deleted = True + return True + if not os.path.exists(file_to_delete): + logger.warning( + f"[警告] 表情包文件 {file_to_delete} 已删除,但数据库记录删除失败 (Hash: {self.hash})" + ) + else: + logger.error(f"[错误] 删除表情包数据库记录失败: {self.hash}") + return False + + except Exception as e: + logger.error(f"[错误] 删除表情包失败 ({self.filename}): {e!s}") + return False diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 008de40c5..873c954c2 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -1,6 +1,5 @@ import asyncio import base64 -import binascii import hashlib import io import json @@ -16,6 +15,16 @@ from PIL import Image from rich.traceback import install from sqlalchemy import select +from src.chat.emoji_system.emoji_constants import EMOJI_DIR, EMOJI_REGISTERED_DIR, MAX_EMOJI_FOR_PROMPT +from src.chat.emoji_system.emoji_entities import MaiEmoji +from src.chat.emoji_system.emoji_utils import ( + _emoji_objects_to_readable_list, + _to_emoji_objects, + _ensure_emoji_dir, + clear_temp_emoji, + clean_unused_emojis, + list_image_files, +) from src.chat.utils.utils_image import get_image_manager, image_path_to_base64 from src.common.database.api.crud import CRUDBase from src.common.database.compatibility import get_db_session @@ -25,367 +34,8 @@ from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -install(extra_lines=3) - logger = get_logger("emoji") -BASE_DIR = os.path.join("data") -EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录 -EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录 -MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中 - -""" -还没经过测试,有些地方数据库和内存数据同步可能不完全 - -""" - - -class MaiEmoji: - """定义一个表情包""" - - def __init__(self, full_path: str): - if not full_path: - raise ValueError("full_path cannot be empty") - self.full_path = full_path # 文件的完整路径 (包括文件名) - self.path = os.path.dirname(full_path) # 文件所在的目录路径 - self.filename = os.path.basename(full_path) # 文件名 - self.embedding = [] - self.hash = "" # 初始为空,在创建实例时会计算 - self.description = "" - self.emotion: list[str] = [] - self.usage_count = 0 - self.last_used_time = time.time() - self.register_time = time.time() - self.is_deleted = False # 标记是否已被删除 - self.format = "" - - async def initialize_hash_format(self) -> bool | None: - """从文件创建表情包实例, 计算哈希值和格式""" - try: - # 使用 full_path 检查文件是否存在 - if not os.path.exists(self.full_path): - logger.error(f"[初始化错误] 表情包文件不存在: {self.full_path}") - self.is_deleted = True - return None - - # 使用 full_path 读取文件 - logger.debug(f"[初始化] 正在读取文件: {self.full_path}") - image_base64 = image_path_to_base64(self.full_path) - if image_base64 is None: - logger.error(f"[初始化错误] 无法读取或转换Base64: {self.full_path}") - self.is_deleted = True - return None - logger.debug(f"[初始化] 文件读取成功 (Base64预览: {image_base64[:50]}...)") - - # 计算哈希值 - logger.debug(f"[初始化] 正在解码Base64并计算哈希: {self.filename}") - # 确保base64字符串只包含ASCII字符 - if isinstance(image_base64, str): - image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") - image_bytes = base64.b64decode(image_base64) - self.hash = hashlib.md5(image_bytes).hexdigest() - logger.debug(f"[初始化] 哈希计算成功: {self.hash}") - - # 获取图片格式 - logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}") - try: - with Image.open(io.BytesIO(image_bytes)) as img: - self.format = (img.format or "jpeg").lower() - logger.debug(f"[初始化] 格式获取成功: {self.format}") - except Exception as pil_error: - logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}") - logger.error(traceback.format_exc()) - self.is_deleted = True - return None - - # 如果所有步骤成功,返回 True - return True - - except FileNotFoundError: - logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}") - self.is_deleted = True - return None - except (binascii.Error, ValueError) as b64_error: - logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}") - self.is_deleted = True - return None - except Exception as e: - logger.error(f"[初始化错误] 初始化表情包时发生未预期错误 ({self.filename}): {e!s}") - logger.error(traceback.format_exc()) - self.is_deleted = True - return None - - async def register_to_db(self) -> bool: - """ - 注册表情包 - 将表情包对应的文件,从当前路径移动到EMOJI_REGISTERED_DIR目录下 - 并修改对应的实例属性,然后将表情包信息保存到数据库中 - """ - try: - # 确保目标目录存在 - - # 源路径是当前实例的完整路径 self.full_path - source_full_path = self.full_path - # 目标完整路径 - destination_full_path = os.path.join(EMOJI_REGISTERED_DIR, self.filename) - - # 检查源文件是否存在 - if not os.path.exists(source_full_path): - logger.error(f"[错误] 源文件不存在: {source_full_path}") - return False - - # --- 文件移动 --- - try: - # 如果目标文件已存在,先删除 (确保移动成功) - if os.path.exists(destination_full_path): - os.remove(destination_full_path) - - os.rename(source_full_path, destination_full_path) - logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}") - # 更新实例的路径属性为新路径 - self.full_path = destination_full_path - self.path = EMOJI_REGISTERED_DIR - # self.filename 保持不变 - except Exception as move_error: - logger.error(f"[错误] 移动文件失败: {move_error!s}") - # 如果移动失败,尝试将实例状态恢复?暂时不处理,仅返回失败 - return False - - # --- 数据库操作 --- - try: - # 准备数据库记录 for emoji collection - async with get_db_session() as session: - emotion_str = ",".join(self.emotion) if self.emotion else "" - - emoji = Emoji( - emoji_hash=self.hash, - full_path=self.full_path, - format=self.format, - description=self.description, - emotion=emotion_str, # Store as comma-separated string - query_count=0, # Default value - is_registered=True, - is_banned=False, # Default value - record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time - register_time=self.register_time, - usage_count=self.usage_count, - last_used_time=self.last_used_time, - ) - session.add(emoji) - await session.commit() - - logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") - - return True - - except Exception as db_error: - logger.error(f"[错误] 保存数据库失败 ({self.filename}): {db_error!s}") - return False - - except Exception as e: - logger.error(f"[错误] 注册表情包失败 ({self.filename}): {e!s}") - logger.error(traceback.format_exc()) - return False - - async def delete(self) -> bool: - """删除表情包 - - 删除表情包的文件和数据库记录 - - 返回: - bool: 是否成功删除 - """ - try: - # 1. 删除文件 - file_to_delete = self.full_path - if os.path.exists(file_to_delete): - try: - os.remove(file_to_delete) - logger.debug(f"[删除] 文件: {file_to_delete}") - except Exception as e: - logger.error(f"[错误] 删除文件失败 {file_to_delete}: {e!s}") - # 文件删除失败,但仍然尝试删除数据库记录 - - # 2. 删除数据库记录 - try: - # 使用CRUD进行删除 - crud = CRUDBase(Emoji) - will_delete_emoji = await crud.get_by(emoji_hash=self.hash) - if will_delete_emoji is None: - logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") - result = 0 # Indicate no DB record was deleted - else: - await crud.delete(will_delete_emoji.id) - result = 1 # Successfully deleted one record - - # 使缓存失效 - from src.common.database.optimization.cache_manager import get_cache - from src.common.database.utils.decorators import generate_cache_key - cache = await get_cache() - await cache.delete(generate_cache_key("emoji_by_hash", self.hash)) - await cache.delete(generate_cache_key("emoji_description", self.hash)) - await cache.delete(generate_cache_key("emoji_tag", self.hash)) - except Exception as e: - logger.error(f"[错误] 删除数据库记录时出错: {e!s}") - result = 0 - - if result > 0: - logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})") - # 3. 标记对象已被删除 - self.is_deleted = True - return True - else: - # 如果数据库记录删除失败,但文件可能已删除,记录一个警告 - if not os.path.exists(file_to_delete): - logger.warning( - f"[警告] 表情包文件 {file_to_delete} 已删除,但数据库记录删除失败 (Hash: {self.hash})" - ) - else: - logger.error(f"[错误] 删除表情包数据库记录失败: {self.hash}") - return False - - except Exception as e: - logger.error(f"[错误] 删除表情包失败 ({self.filename}): {e!s}") - return False - - -def _emoji_objects_to_readable_list(emoji_objects: list["MaiEmoji"]) -> list[str]: - """将表情包对象列表转换为可读的字符串列表 - - 参数: - emoji_objects: MaiEmoji对象列表 - - 返回: - list[str]: 可读的表情包信息字符串列表 - """ - emoji_info_list = [] - for i, emoji in enumerate(emoji_objects): - # 转换时间戳为可读时间 - time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(emoji.register_time)) - # 构建每个表情包的信息字符串 - emoji_info = f"编号: {i + 1}\n描述: {emoji.description}\n使用次数: {emoji.usage_count}\n添加时间: {time_str}\n" - emoji_info_list.append(emoji_info) - return emoji_info_list - - -def _to_emoji_objects(data: Any) -> tuple[list["MaiEmoji"], int]: - emoji_objects = [] - load_errors = 0 - emoji_data_list = list(data) - - for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance - full_path = emoji_data.full_path - if not full_path: - logger.warning( - f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}" - ) - load_errors += 1 - continue - - try: - emoji = MaiEmoji(full_path=full_path) - - emoji.hash = emoji_data.emoji_hash - if not emoji.hash: - logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}") - load_errors += 1 - continue - - emoji.description = emoji_data.description - # Deserialize emotion string from DB to list - emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else [] - emoji.usage_count = emoji_data.usage_count - - db_last_used_time = emoji_data.last_used_time - db_register_time = emoji_data.register_time - - # If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time - emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time - # If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time()) - emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time - - emoji.format = emoji_data.format - - emoji_objects.append(emoji) - - except ValueError as ve: - logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}") - load_errors += 1 - except Exception as e: - logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {e!s}") - load_errors += 1 - return emoji_objects, load_errors - - -def _ensure_emoji_dir() -> None: - """确保表情存储目录存在""" - os.makedirs(EMOJI_DIR, exist_ok=True) - os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True) - - -async def clear_temp_emoji() -> None: - """清理临时表情包 - 清理/data/emoji、/data/image和/data/images目录下的所有文件 - 当目录中文件数超过100时,会全部删除 - """ - - logger.info("[清理] 开始清理缓存...") - - for need_clear in ( - os.path.join(BASE_DIR, "emoji"), - os.path.join(BASE_DIR, "image"), - os.path.join(BASE_DIR, "images"), - ): - if os.path.exists(need_clear): - files = os.listdir(need_clear) - # 如果文件数超过1000就全部删除 - if len(files) > 1000: - for filename in files: - file_path = os.path.join(need_clear, filename) - if os.path.isfile(file_path): - os.remove(file_path) - logger.debug(f"[清理] 删除: {filename}") - - -async def clean_unused_emojis(emoji_dir: str, emoji_objects: list["MaiEmoji"], removed_count: int) -> int: - """清理指定目录中未被 emoji_objects 追踪的表情包文件""" - if not os.path.exists(emoji_dir): - logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") - return removed_count - - cleaned_count = 0 - try: - # 获取内存中所有有效表情包的完整路径集合 - tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted} - - # 遍历指定目录中的所有文件 - for file_name in os.listdir(emoji_dir): - file_full_path = os.path.join(emoji_dir, file_name) - - # 确保处理的是文件而不是子目录 - if not os.path.isfile(file_full_path): - continue - - # 如果文件不在被追踪的集合中,则删除 - if file_full_path not in tracked_full_paths: - try: - os.remove(file_full_path) - logger.info(f"[清理] 删除未追踪的表情包文件: {file_full_path}") - cleaned_count += 1 - except Exception as e: - logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {e!s}") - - if cleaned_count > 0: - logger.info(f"[清理] 在目录 {emoji_dir} 中清理了 {cleaned_count} 个破损表情包。") - else: - logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。") - - except Exception as e: - logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {e!s}") - - return removed_count + cleaned_count - - class EmojiManager: _instance = None _initialized: bool = False # 显式声明,避免属性未定义错误 @@ -401,6 +51,10 @@ class EmojiManager: return # 如果已经初始化过,直接返回 self._scan_task = None + self._emoji_index: dict[str, MaiEmoji] = {} + self._integrity_yield_every = 50 + self._integrity_cursor = 0 + self._integrity_batch_size = 500 if model_config is None: raise RuntimeError("Model config is not initialized") @@ -568,34 +222,40 @@ class EmojiManager: 如果文件已被删除,则执行对象的删除方法并从列表中移除 """ try: - # if not self.emoji_objects: - # logger.warning("[检查] emoji_objects为空,跳过完整性检查") - # return - total_count = len(self.emoji_objects) self.emoji_num = total_count removed_count = 0 - # 使用列表复制进行遍历,因为我们会在遍历过程中修改列表 - objects_to_remove = [] - for emoji in self.emoji_objects: + if total_count == 0: + return + + start = self._integrity_cursor % total_count + end = min(start + self._integrity_batch_size, total_count) + indices: list[int] = list(range(start, end)) + if end - start < self._integrity_batch_size and total_count > 0: + wrap_rest = self._integrity_batch_size - (end - start) + if wrap_rest > 0: + indices.extend(range(0, min(wrap_rest, total_count))) + + objects_to_remove: list[MaiEmoji] = [] + processed = 0 + for idx in indices: + if idx >= len(self.emoji_objects): + break + emoji = self.emoji_objects[idx] try: - # 跳过已经标记为删除的,避免重复处理 if emoji.is_deleted: - objects_to_remove.append(emoji) # 收集起来一次性移除 + objects_to_remove.append(emoji) continue - # 检查文件是否存在 - if not os.path.exists(emoji.full_path): + exists = await asyncio.to_thread(os.path.exists, emoji.full_path) + if not exists: logger.warning(f"[检查] 表情包文件丢失: {emoji.full_path}") - # 执行表情包对象的删除方法 - await emoji.delete() # delete 方法现在会标记 is_deleted - objects_to_remove.append(emoji) # 标记删除后,也收集起来移除 - # 更新计数 + await emoji.delete() + objects_to_remove.append(emoji) self.emoji_num -= 1 removed_count += 1 continue - # 检查描述是否为空 (如果为空也视为无效) if not emoji.description: logger.warning(f"[检查] 表情包描述为空,视为无效: {emoji.filename}") await emoji.delete() @@ -604,19 +264,24 @@ class EmojiManager: removed_count += 1 continue + processed += 1 + if processed % self._integrity_yield_every == 0: + await asyncio.sleep(0) + except Exception as item_error: logger.error(f"[错误] 处理表情包记录时出错 ({emoji.filename}): {item_error!s}") - # 即使出错,也尝试继续检查下一个 continue - # 从 self.emoji_objects 中移除标记的对象 if objects_to_remove: self.emoji_objects = [e for e in self.emoji_objects if e not in objects_to_remove] + for e in objects_to_remove: + if e.hash in self._emoji_index: + self._emoji_index.pop(e.hash, None) + + self._integrity_cursor = (start + processed) % max(1, len(self.emoji_objects)) - # 清理 EMOJI_REGISTERED_DIR 目录中未被追踪的文件 removed_count = await clean_unused_emojis(EMOJI_REGISTERED_DIR, self.emoji_objects, removed_count) - # 输出清理结果 if removed_count > 0: logger.info(f"[清理] 已清理 {removed_count} 个失效/文件丢失的表情包记录") logger.info(f"[统计] 清理前记录数: {total_count} | 清理后有效记录数: {len(self.emoji_objects)}") @@ -639,36 +304,30 @@ class EmojiManager: logger.info("[扫描] 开始扫描新表情包...") # 检查表情包目录是否存在 - if not os.path.exists(EMOJI_DIR): + if not await asyncio.to_thread(os.path.exists, EMOJI_DIR): logger.warning(f"[警告] 表情包目录不存在: {EMOJI_DIR}") - os.makedirs(EMOJI_DIR, exist_ok=True) + await asyncio.to_thread(os.makedirs, EMOJI_DIR, True) logger.info(f"[创建] 已创建表情包目录: {EMOJI_DIR}") await asyncio.sleep(global_config.emoji.check_interval * 60) continue - # 检查目录是否为空 - files = os.listdir(EMOJI_DIR) - if not files: + image_files, is_empty = await list_image_files(EMOJI_DIR) + if is_empty: logger.warning(f"[警告] 表情包目录为空: {EMOJI_DIR}") await asyncio.sleep(global_config.emoji.check_interval * 60) continue + if not image_files: + await asyncio.sleep(global_config.emoji.check_interval * 60) + continue + # 无论steal_emoji是否开启,都检查emoji文件夹以支持手动注册 # 只有在需要腾出空间或填充表情库时,才真正执行注册 if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or ( self.emoji_num < self.emoji_num_max ): try: - # 获取目录下所有图片文件 - files_to_process = [ - f - for f in files - if os.path.isfile(os.path.join(EMOJI_DIR, f)) - and f.lower().endswith((".jpg", ".jpeg", ".png", ".gif")) - ] - - # 处理每个符合条件的文件 - for filename in files_to_process: + for filename in image_files: # 尝试注册表情包 success = await self.register_emoji_by_filename(filename) if success: @@ -677,8 +336,9 @@ class EmojiManager: # 注册失败则删除对应文件 file_path = os.path.join(EMOJI_DIR, filename) - os.remove(file_path) + await asyncio.to_thread(os.remove, file_path) logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}") + await asyncio.sleep(0) except Exception as e: logger.error(f"[错误] 扫描表情包目录失败: {e!s}") @@ -698,6 +358,7 @@ class EmojiManager: # 更新内存中的列表和数量 self.emoji_objects = emoji_objects self.emoji_num = len(emoji_objects) + self._emoji_index = {e.hash: e for e in emoji_objects if getattr(e, "hash", None)} logger.info(f"[数据库] 加载完成: 共加载 {self.emoji_num} 个表情包记录。") if load_errors > 0: @@ -753,11 +414,15 @@ class EmojiManager: 返回: MaiEmoji 或 None: 如果找到则返回 MaiEmoji 对象,否则返回 None """ - for emoji in self.emoji_objects: - # 确保对象未被标记为删除且哈希值匹配 - if not emoji.is_deleted and emoji.hash == emoji_hash: - return emoji - return None # 如果循环结束还没找到,则返回 None + emoji = self._emoji_index.get(emoji_hash) + if emoji and not emoji.is_deleted: + return emoji + + for item in self.emoji_objects: + if not item.is_deleted and item.hash == emoji_hash: + self._emoji_index[emoji_hash] = item + return item + return None @cached(ttl=1800, key_prefix="emoji_tag") # 缓存30分钟 async def get_emoji_tag_by_hash(self, emoji_hash: str) -> str | None: @@ -849,6 +514,7 @@ class EmojiManager: if success: # 从emoji_objects列表中移除该对象 self.emoji_objects = [e for e in self.emoji_objects if e.hash != emoji_hash] + self._emoji_index.pop(emoji_hash, None) # 更新计数 self.emoji_num -= 1 logger.info(f"[统计] 当前表情包数量: {self.emoji_num}") @@ -931,6 +597,7 @@ class EmojiManager: register_success = await new_emoji.register_to_db() if register_success: self.emoji_objects.append(new_emoji) + self._emoji_index[new_emoji.hash] = new_emoji self.emoji_num += 1 logger.info(f"[成功] 注册: {new_emoji.filename}") return True @@ -1099,7 +766,7 @@ class EmojiManager: bool: 注册是否成功 """ file_full_path = os.path.join(EMOJI_DIR, filename) - if not os.path.exists(file_full_path): + if not await asyncio.to_thread(os.path.exists, file_full_path): logger.error(f"[注册失败] 文件不存在: {file_full_path}") return False @@ -1117,7 +784,7 @@ class EmojiManager: logger.warning(f"[注册跳过] 表情包已存在 (Hash: {new_emoji.hash}): {filename}") # 删除重复的源文件 try: - os.remove(file_full_path) + await asyncio.to_thread(os.remove, file_full_path) logger.info(f"[清理] 删除重复的待注册文件: {filename}") except Exception as e: logger.error(f"[错误] 删除重复文件失败: {e!s}") @@ -1137,7 +804,7 @@ class EmojiManager: logger.warning(f"[注册失败] 未能生成有效描述或审核未通过: {filename}") # 删除未能生成描述的文件 try: - os.remove(file_full_path) + await asyncio.to_thread(os.remove, file_full_path) logger.info(f"[清理] 删除描述生成失败的文件: {filename}") except Exception as e: logger.error(f"[错误] 删除描述生成失败文件时出错: {e!s}") @@ -1149,7 +816,7 @@ class EmojiManager: logger.error(f"[注册失败] 生成描述/情感时出错 ({filename}): {build_desc_error}") # 同样考虑删除文件 try: - os.remove(file_full_path) + await asyncio.to_thread(os.remove, file_full_path) logger.info(f"[清理] 删除描述生成异常的文件: {filename}") except Exception as e: logger.error(f"[错误] 删除描述生成异常文件时出错: {e!s}") @@ -1163,7 +830,7 @@ class EmojiManager: logger.error("[注册失败] 替换表情包失败,无法完成注册") # 替换失败,删除新表情包文件 try: - os.remove(file_full_path) # new_emoji 的 full_path 此时还是源路径 + await asyncio.to_thread(os.remove, file_full_path) # new_emoji 的 full_path 此时还是源路径 logger.info(f"[清理] 删除替换失败的新表情文件: {filename}") except Exception as e: logger.error(f"[错误] 删除替换失败文件时出错: {e!s}") @@ -1176,6 +843,7 @@ class EmojiManager: if register_success: # 注册成功后,添加到内存列表 self.emoji_objects.append(new_emoji) + self._emoji_index[new_emoji.hash] = new_emoji self.emoji_num += 1 logger.info(f"[成功] 注册新表情包: {filename} (当前: {self.emoji_num}/{self.emoji_num_max})") return True @@ -1183,9 +851,9 @@ class EmojiManager: logger.error(f"[注册失败] 保存表情包到数据库/移动文件失败: {filename}") # register_to_db 失败时,内部会尝试清理移动后的文件,源文件可能还在 # 是否需要删除源文件? - if os.path.exists(file_full_path): + if await asyncio.to_thread(os.path.exists, file_full_path): try: - os.remove(file_full_path) + await asyncio.to_thread(os.remove, file_full_path) logger.info(f"[清理] 删除注册失败的源文件: {filename}") except Exception as e: logger.error(f"[错误] 删除注册失败源文件时出错: {e!s}") @@ -1195,9 +863,9 @@ class EmojiManager: logger.error(f"[错误] 注册表情包时发生未预期错误 ({filename}): {e!s}") logger.error(traceback.format_exc()) # 尝试删除源文件以避免循环处理 - if os.path.exists(file_full_path): + if await asyncio.to_thread(os.path.exists, file_full_path): try: - os.remove(file_full_path) + await asyncio.to_thread(os.remove, file_full_path) logger.info(f"[清理] 删除处理异常的源文件: {filename}") except Exception as remove_error: logger.error(f"[错误] 删除异常处理文件时出错: {remove_error}") diff --git a/src/chat/emoji_system/emoji_utils.py b/src/chat/emoji_system/emoji_utils.py new file mode 100644 index 000000000..53144dc1f --- /dev/null +++ b/src/chat/emoji_system/emoji_utils.py @@ -0,0 +1,140 @@ +import asyncio +import os +import time +from typing import Any + +from src.chat.emoji_system.emoji_constants import BASE_DIR, EMOJI_DIR, EMOJI_REGISTERED_DIR +from src.chat.emoji_system.emoji_entities import MaiEmoji +from src.common.logger import get_logger + +logger = get_logger("emoji") + + +def _emoji_objects_to_readable_list(emoji_objects: list[MaiEmoji]) -> list[str]: + emoji_info_list = [] + for i, emoji in enumerate(emoji_objects): + time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(emoji.register_time)) + emoji_info = f"编号: {i + 1}\n描述: {emoji.description}\n使用次数: {emoji.usage_count}\n添加时间: {time_str}\n" + emoji_info_list.append(emoji_info) + return emoji_info_list + + +def _to_emoji_objects(data: Any) -> tuple[list[MaiEmoji], int]: + emoji_objects = [] + load_errors = 0 + emoji_data_list = list(data) + + for emoji_data in emoji_data_list: + full_path = emoji_data.full_path + if not full_path: + logger.warning( + f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}" + ) + load_errors += 1 + continue + + try: + emoji = MaiEmoji(full_path=full_path) + + emoji.hash = emoji_data.emoji_hash + if not emoji.hash: + logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}") + load_errors += 1 + continue + + emoji.description = emoji_data.description + emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else [] + emoji.usage_count = emoji_data.usage_count + + db_last_used_time = emoji_data.last_used_time + db_register_time = emoji_data.register_time + + emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time + emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time + + emoji.format = emoji_data.format + + emoji_objects.append(emoji) + + except ValueError as ve: + logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}") + load_errors += 1 + except Exception as e: + logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {e!s}") + load_errors += 1 + return emoji_objects, load_errors + + +def _ensure_emoji_dir() -> None: + os.makedirs(EMOJI_DIR, exist_ok=True) + os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True) + + +async def clear_temp_emoji() -> None: + logger.info("[清理] 开始清理缓存...") + + for need_clear in ( + os.path.join(BASE_DIR, "emoji"), + os.path.join(BASE_DIR, "image"), + os.path.join(BASE_DIR, "images"), + ): + if await asyncio.to_thread(os.path.exists, need_clear): + files = await asyncio.to_thread(os.listdir, need_clear) + if len(files) > 1000: + for i, filename in enumerate(files): + file_path = os.path.join(need_clear, filename) + if await asyncio.to_thread(os.path.isfile, file_path): + try: + await asyncio.to_thread(os.remove, file_path) + logger.debug(f"[清理] 删除: {filename}") + except Exception as e: + logger.debug(f"[清理] 删除失败 {filename}: {e!s}") + if (i + 1) % 100 == 0: + await asyncio.sleep(0) + + +async def clean_unused_emojis(emoji_dir: str, emoji_objects: list[MaiEmoji], removed_count: int) -> int: + if not await asyncio.to_thread(os.path.exists, emoji_dir): + logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") + return removed_count + + cleaned_count = 0 + try: + tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted} + + for entry in await asyncio.to_thread(lambda: list(os.scandir(emoji_dir))): + if not entry.is_file(): + continue + + file_full_path = entry.path + + if file_full_path not in tracked_full_paths: + try: + await asyncio.to_thread(os.remove, file_full_path) + logger.info(f"[清理] 删除未追踪的表情包文件: {file_full_path}") + cleaned_count += 1 + except Exception as e: + logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {e!s}") + + if cleaned_count > 0: + logger.info(f"[清理] 在目录 {emoji_dir} 中清理了 {cleaned_count} 个破损表情包。") + else: + logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。") + + except Exception as e: + logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {e!s}") + + return removed_count + cleaned_count + + +async def list_image_files(directory: str) -> tuple[list[str], bool]: + def _scan() -> tuple[list[str], bool]: + entries = list(os.scandir(directory)) + files = [ + entry.name + for entry in entries + if entry.is_file() and entry.name.lower().endswith((".jpg", ".jpeg", ".png", ".gif")) + ] + return files, len(entries) == 0 + + return await asyncio.to_thread(_scan) diff --git a/src/chat/express/express_utils.py b/src/chat/express/express_utils.py index 96c175648..13e0efd0d 100644 --- a/src/chat/express/express_utils.py +++ b/src/chat/express/express_utils.py @@ -7,11 +7,26 @@ import random import re from typing import Any +try: + from sklearn.feature_extraction.text import TfidfVectorizer + from sklearn.metrics.pairwise import cosine_similarity as _sk_cosine_similarity + + HAS_SKLEARN = True +except Exception: # pragma: no cover - 依赖缺失时静默回退 + HAS_SKLEARN = False + from src.common.logger import get_logger logger = get_logger("express_utils") +# 预编译正则,减少重复编译开销 +_RE_REPLY = re.compile(r"\[回复.*?\],说:\s*") +_RE_AT = re.compile(r"@<[^>]*>") +_RE_IMAGE = re.compile(r"\[图片:[^\]]*\]") +_RE_EMOJI = re.compile(r"\[表情包:[^\]]*\]") + + def filter_message_content(content: str | None) -> str: """ 过滤消息内容,移除回复、@、图片等格式 @@ -25,29 +40,56 @@ def filter_message_content(content: str | None) -> str: if not content: return "" - # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分 - content = re.sub(r"\[回复.*?\],说:\s*", "", content) - # 移除@<...>格式的内容 - content = re.sub(r"@<[^>]*>", "", content) - # 移除[图片:...]格式的图片ID - content = re.sub(r"\[图片:[^\]]*\]", "", content) - # 移除[表情包:...]格式的内容 - content = re.sub(r"\[表情包:[^\]]*\]", "", content) + # 使用预编译正则提升性能 + content = _RE_REPLY.sub("", content) + content = _RE_AT.sub("", content) + content = _RE_IMAGE.sub("", content) + content = _RE_EMOJI.sub("", content) return content.strip() -def calculate_similarity(text1: str, text2: str) -> float: +def _similarity_tfidf(text1: str, text2: str) -> float | None: + """使用 TF-IDF + 余弦相似度;依赖 sklearn,缺失则返回 None。""" + if not HAS_SKLEARN: + return None + # 过短文本用传统算法更稳健 + if len(text1) < 2 or len(text2) < 2: + return None + try: + vec = TfidfVectorizer(max_features=1024, ngram_range=(1, 2)) + tfidf = vec.fit_transform([text1, text2]) + sim = float(_sk_cosine_similarity(tfidf[0], tfidf[1])[0, 0]) + return max(0.0, min(1.0, sim)) + except Exception: + return None + + +def calculate_similarity(text1: str, text2: str, prefer_vector: bool = True) -> float: """ 计算两个文本的相似度,返回0-1之间的值 + - 当可用且文本足够长时,优先尝试 TF-IDF 向量相似度(更鲁棒) + - 不可用或失败时回退到 SequenceMatcher + Args: text1: 第一个文本 text2: 第二个文本 + prefer_vector: 是否优先使用向量化方案(默认是) Returns: 相似度值 (0-1) """ + if not text1 or not text2: + return 0.0 + if text1 == text2: + return 1.0 + + if prefer_vector: + sim = _similarity_tfidf(text1, text2) + if sim is not None: + return sim + return difflib.SequenceMatcher(None, text1, text2).ratio() @@ -79,18 +121,10 @@ def weighted_sample(population: list[dict], k: int, weight_key: str | None = Non except (ValueError, TypeError) as e: logger.warning(f"加权抽样失败,使用等概率抽样: {e}") - # 等概率抽样 - selected = [] + # 等概率抽样(无放回,保持去重) population_copy = population.copy() - - for _ in range(k): - if not population_copy: - break - # 随机选择一个元素 - idx = random.randint(0, len(population_copy) - 1) - selected.append(population_copy.pop(idx)) - - return selected + # 使用 random.sample 提升可读性和性能 + return random.sample(population_copy, k) def normalize_text(text: str) -> str: @@ -130,8 +164,9 @@ def extract_keywords(text: str, max_keywords: int = 10) -> list[str]: return keywords except ImportError: logger.warning("rjieba未安装,无法提取关键词") - # 简单分词 + # 简单分词,按长度降序优先输出较长词,提升粗略关键词质量 words = text.split() + words.sort(key=len, reverse=True) return words[:max_keywords] @@ -236,15 +271,18 @@ def merge_expressions_from_multiple_chats( # 收集所有表达方式 for chat_id, expressions in expressions_dict.items(): for expr in expressions: - # 添加source_id标识 expr_with_source = expr.copy() expr_with_source["source_id"] = chat_id all_expressions.append(expr_with_source) - # 按count或last_active_time排序 - if all_expressions and "count" in all_expressions[0]: + if not all_expressions: + return [] + + # 选择排序键(优先 count,其次 last_active_time),无则保持原序 + sample = all_expressions[0] + if "count" in sample: all_expressions.sort(key=lambda x: x.get("count", 0), reverse=True) - elif all_expressions and "last_active_time" in all_expressions[0]: + elif "last_active_time" in sample: all_expressions.sort(key=lambda x: x.get("last_active_time", 0), reverse=True) # 去重(基于situation和style) diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index f4086573e..5870a8bdb 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -358,7 +358,10 @@ class ExpressionLearner: @staticmethod @cached(ttl=600, key_prefix="chat_expressions") async def _get_expressions_by_chat_id_cached(chat_id: str) -> tuple[list[dict[str, float]], list[dict[str, float]]]: - """内部方法:从数据库获取表达方式(带缓存)""" + """内部方法:从数据库获取表达方式(带缓存) + + 🔥 优化:使用列表推导式和更高效的数据处理 + """ learnt_style_expressions = [] learnt_grammar_expressions = [] @@ -366,67 +369,91 @@ class ExpressionLearner: crud = CRUDBase(Expression) all_expressions = await crud.get_multi(chat_id=chat_id, limit=10000) + # 🔥 优化:使用列表推导式批量处理,减少循环开销 for expr in all_expressions: - # 确保create_date存在,如果不存在则使用last_active_time - create_date = expr.create_date if expr.create_date is not None else expr.last_active_time + # 确保create_date存在,如果不存在则使用last_active_time + create_date = expr.create_date if expr.create_date is not None else expr.last_active_time - expr_data = { - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": chat_id, - "type": expr.type, - "create_date": create_date, - } + expr_data = { + "situation": expr.situation, + "style": expr.style, + "count": expr.count, + "last_active_time": expr.last_active_time, + "source_id": chat_id, + "type": expr.type, + "create_date": create_date, + } - # 根据类型分类 - if expr.type == "style": - learnt_style_expressions.append(expr_data) - elif expr.type == "grammar": - learnt_grammar_expressions.append(expr_data) + # 根据类型分类(避免多次类型检查) + if expr.type == "style": + learnt_style_expressions.append(expr_data) + elif expr.type == "grammar": + learnt_grammar_expressions.append(expr_data) + logger.debug(f"已加载 {len(learnt_style_expressions)} 个style和 {len(learnt_grammar_expressions)} 个grammar表达方式 (chat_id={chat_id})") return learnt_style_expressions, learnt_grammar_expressions async def _apply_global_decay_to_database(self, current_time: float) -> None: """ 对数据库中的所有表达方式应用全局衰减 - 优化: 使用CRUD批量处理所有更改,最后统一提交 + 优化: 使用分批处理和原生 SQL 操作提升性能 """ try: - # 使用CRUD查询所有表达方式 - crud = CRUDBase(Expression) - all_expressions = await crud.get_multi(limit=100000) # 获取所有表达方式 - + BATCH_SIZE = 1000 # 分批处理,避免一次性加载过多数据 updated_count = 0 deleted_count = 0 + offset = 0 - # 需要手动操作的情况下使用session - async with get_db_session() as session: - # 批量处理所有修改 - for expr in all_expressions: - # 计算时间差 - last_active = expr.last_active_time - time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天 - - # 计算衰减值 - decay_value = self.calculate_decay_factor(time_diff_days) - new_count = max(0.01, expr.count - decay_value) - - if new_count <= 0.01: - # 如果count太小,删除这个表达方式 - await session.delete(expr) - deleted_count += 1 - else: - # 更新count - expr.count = new_count - updated_count += 1 - - # 优化: 统一提交所有更改(从N次提交减少到1次) - if updated_count > 0 or deleted_count > 0: + while True: + async with get_db_session() as session: + # 分批查询表达方式 + batch_result = await session.execute( + select(Expression) + .order_by(Expression.id) + .limit(BATCH_SIZE) + .offset(offset) + ) + batch_expressions = list(batch_result.scalars()) + + if not batch_expressions: + break # 没有更多数据 + + # 批量处理当前批次 + to_delete = [] + for expr in batch_expressions: + # 计算时间差 + time_diff_days = (current_time - expr.last_active_time) / (24 * 3600) + + # 计算衰减值 + decay_value = self.calculate_decay_factor(time_diff_days) + new_count = max(0.01, expr.count - decay_value) + + if new_count <= 0.01: + # 标记删除 + to_delete.append(expr) + else: + # 更新count + expr.count = new_count + updated_count += 1 + + # 批量删除 + if to_delete: + for expr in to_delete: + await session.delete(expr) + deleted_count += len(to_delete) + + # 提交当前批次 await session.commit() - logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式") + + # 如果批次不满,说明已经处理完所有数据 + if len(batch_expressions) < BATCH_SIZE: + break + + offset += BATCH_SIZE + + if updated_count > 0 or deleted_count > 0: + logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式") except Exception as e: logger.error(f"数据库全局衰减失败: {e}") @@ -509,88 +536,106 @@ class ExpressionLearner: CRUDBase(Expression) for chat_id, expr_list in chat_dict.items(): async with get_db_session() as session: + # 🔥 优化:批量查询所有现有表达方式,避免N次数据库查询 + existing_exprs_result = await session.execute( + select(Expression).where( + (Expression.chat_id == chat_id) + & (Expression.type == type) + ) + ) + existing_exprs = list(existing_exprs_result.scalars()) + + # 构建快速查找索引 + exact_match_map = {} # (situation, style) -> Expression + situation_map = {} # situation -> Expression + style_map = {} # style -> Expression + + for expr in existing_exprs: + key = (expr.situation, expr.style) + exact_match_map[key] = expr + # 只保留第一个匹配(优先级:完全匹配 > 情景匹配 > 表达匹配) + if expr.situation not in situation_map: + situation_map[expr.situation] = expr + if expr.style not in style_map: + style_map[expr.style] = expr + + # 批量处理所有新表达方式 for new_expr in expr_list: - # 🔥 改进1:检查是否存在相同情景或相同表达的数据 - # 情况1:相同 chat_id + type + situation(相同情景,不同表达) - query_same_situation = await session.execute( - select(Expression).where( - (Expression.chat_id == chat_id) - & (Expression.type == type) - & (Expression.situation == new_expr["situation"]) - ) - ) - same_situation_expr = query_same_situation.scalar() - - # 情况2:相同 chat_id + type + style(相同表达,不同情景) - query_same_style = await session.execute( - select(Expression).where( - (Expression.chat_id == chat_id) - & (Expression.type == type) - & (Expression.style == new_expr["style"]) - ) - ) - same_style_expr = query_same_style.scalar() - - # 情况3:完全相同(相同情景+相同表达) - query_exact_match = await session.execute( - select(Expression).where( - (Expression.chat_id == chat_id) - & (Expression.type == type) - & (Expression.situation == new_expr["situation"]) - & (Expression.style == new_expr["style"]) - ) - ) - exact_match_expr = query_exact_match.scalar() - + situation = new_expr["situation"] + style_val = new_expr["style"] + exact_key = (situation, style_val) + # 优先处理完全匹配的情况 - if exact_match_expr: + if exact_key in exact_match_map: # 完全相同:增加count,更新时间 - expr_obj = exact_match_expr + expr_obj = exact_match_map[exact_key] expr_obj.count = expr_obj.count + 1 expr_obj.last_active_time = current_time logger.debug(f"完全匹配:更新count {expr_obj.count}") - elif same_situation_expr: + elif situation in situation_map: # 相同情景,不同表达:覆盖旧的表达 - logger.info(f"相同情景覆盖:'{same_situation_expr.situation}' 的表达从 '{same_situation_expr.style}' 更新为 '{new_expr['style']}'") - same_situation_expr.style = new_expr["style"] + same_situation_expr = situation_map[situation] + logger.info(f"相同情景覆盖:'{same_situation_expr.situation}' 的表达从 '{same_situation_expr.style}' 更新为 '{style_val}'") + # 更新映射 + old_key = (same_situation_expr.situation, same_situation_expr.style) + if old_key in exact_match_map: + del exact_match_map[old_key] + same_situation_expr.style = style_val same_situation_expr.count = same_situation_expr.count + 1 same_situation_expr.last_active_time = current_time - elif same_style_expr: + # 更新新的完全匹配映射 + exact_match_map[exact_key] = same_situation_expr + elif style_val in style_map: # 相同表达,不同情景:覆盖旧的情景 - logger.info(f"相同表达覆盖:'{same_style_expr.style}' 的情景从 '{same_style_expr.situation}' 更新为 '{new_expr['situation']}'") - same_style_expr.situation = new_expr["situation"] + same_style_expr = style_map[style_val] + logger.info(f"相同表达覆盖:'{same_style_expr.style}' 的情景从 '{same_style_expr.situation}' 更新为 '{situation}'") + # 更新映射 + old_key = (same_style_expr.situation, same_style_expr.style) + if old_key in exact_match_map: + del exact_match_map[old_key] + same_style_expr.situation = situation same_style_expr.count = same_style_expr.count + 1 same_style_expr.last_active_time = current_time + # 更新新的完全匹配映射 + exact_match_map[exact_key] = same_style_expr + situation_map[situation] = same_style_expr else: # 完全新的表达方式:创建新记录 new_expression = Expression( - situation=new_expr["situation"], - style=new_expr["style"], + situation=situation, + style=style_val, count=1, last_active_time=current_time, chat_id=chat_id, type=type, - create_date=current_time, # 手动设置创建日期 + create_date=current_time, ) session.add(new_expression) - logger.debug(f"新增表达方式:{new_expr['situation']} -> {new_expr['style']}") + # 更新映射 + exact_match_map[exact_key] = new_expression + situation_map[situation] = new_expression + style_map[style_val] = new_expression + logger.debug(f"新增表达方式:{situation} -> {style_val}") - # 限制最大数量 - 使用 get_all_by_sorted 获取排序结果 - exprs_result = await session.execute( - select(Expression) - .where((Expression.chat_id == chat_id) & (Expression.type == type)) - .order_by(Expression.count.asc()) - ) - exprs = list(exprs_result.scalars()) - if len(exprs) > MAX_EXPRESSION_COUNT: - # 删除count最小的多余表达方式 - for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: + # 🔥 优化:限制最大数量 - 使用已加载的数据避免重复查询 + # existing_exprs 已包含该 chat_id 和 type 的所有表达方式 + all_current_exprs = list(exact_match_map.values()) + if len(all_current_exprs) > MAX_EXPRESSION_COUNT: + # 按 count 排序,删除 count 最小的多余表达方式 + sorted_exprs = sorted(all_current_exprs, key=lambda e: e.count) + for expr in sorted_exprs[: len(all_current_exprs) - MAX_EXPRESSION_COUNT]: await session.delete(expr) + # 从映射中移除 + key = (expr.situation, expr.style) + if key in exact_match_map: + del exact_match_map[key] + logger.debug(f"已删除 {len(all_current_exprs) - MAX_EXPRESSION_COUNT} 个低频表达方式") - # 提交后清除相关缓存 + # 提交数据库更改 await session.commit() - # 🔥 清除共享组内所有 chat_id 的表达方式缓存 + # 🔥 优化:只在实际有更新时才清除缓存(移到外层,避免重复清除) + if chat_dict: # 只有当有数据更新时才清除缓存 from src.common.database.optimization.cache_manager import get_cache from src.common.database.utils.decorators import generate_cache_key cache = await get_cache() @@ -602,53 +647,59 @@ class ExpressionLearner: if len(related_chat_ids) > 1: logger.debug(f"已清除共享组内 {len(related_chat_ids)} 个 chat_id 的表达方式缓存") - # 🔥 训练 StyleLearner(支持共享组) - # 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型) - if type == "style": - try: - logger.debug(f"开始训练 StyleLearner: 源chat_id={chat_id}, 共享组包含 {len(related_chat_ids)} 个chat_id, 样本数={len(expr_list)}") - - # 为每个共享组内的 chat_id 训练其 StyleLearner - for target_chat_id in related_chat_ids: - learner = style_learner_manager.get_learner(target_chat_id) + # 🔥 训练 StyleLearner(支持共享组) + # 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型) + if type == "style" and chat_dict: + try: + related_chat_ids = self.get_related_chat_ids() + total_samples = sum(len(expr_list) for expr_list in chat_dict.values()) + logger.debug(f"开始训练 StyleLearner: 共享组包含 {len(related_chat_ids)} 个chat_id, 总样本数={total_samples}") + # 为每个共享组内的 chat_id 训练其 StyleLearner + for target_chat_id in related_chat_ids: + learner = style_learner_manager.get_learner(target_chat_id) + + # 收集该 target_chat_id 对应的所有表达方式 + # 如果是源 chat_id,使用 chat_dict 中的数据;否则也要训练(共享组特性) + total_success = 0 + total_samples = 0 + + for source_chat_id, expr_list in chat_dict.items(): # 为每个学习到的表达方式训练模型 # 使用 situation 作为输入,style 作为目标 - # 这是最符合语义的方式:场景 -> 表达方式 - success_count = 0 for expr in expr_list: situation = expr["situation"] style = expr["style"] - + # 训练映射关系: situation -> style if learner.learn_mapping(situation, style): - success_count += 1 - else: - logger.warning(f"训练失败 (target={target_chat_id}): {situation} -> {style}") - - # 保存模型 + total_success += 1 + total_samples += 1 + + # 保存模型 + if total_samples > 0: if learner.save(style_learner_manager.model_save_path): logger.debug(f"StyleLearner 模型保存成功: {target_chat_id}") else: logger.error(f"StyleLearner 模型保存失败: {target_chat_id}") - - if target_chat_id == chat_id: - # 只为源 chat_id 记录详细日志 + + if target_chat_id == self.chat_id: + # 只为当前 chat_id 记录详细日志 logger.info( - f"StyleLearner 训练完成 (源): {success_count}/{len(expr_list)} 成功, " + f"StyleLearner 训练完成: {total_success}/{total_samples} 成功, " f"当前风格总数={len(learner.get_all_styles())}, " f"总样本数={learner.learning_stats['total_samples']}" ) else: logger.debug( - f"StyleLearner 训练完成 (共享组成员 {target_chat_id}): {success_count}/{len(expr_list)} 成功" + f"StyleLearner 训练完成 (共享组成员 {target_chat_id}): {total_success}/{total_samples} 成功" ) - if len(related_chat_ids) > 1: - logger.info(f"共享组内共 {len(related_chat_ids)} 个 StyleLearner 已同步训练") + if len(related_chat_ids) > 1: + logger.info(f"共享组内共 {len(related_chat_ids)} 个 StyleLearner 已同步训练") - except Exception as e: - logger.error(f"训练 StyleLearner 失败: {e}") + except Exception as e: + logger.error(f"训练 StyleLearner 失败: {e}") return learnt_expressions return None diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 3359e7c05..cfe335cd4 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -207,31 +207,20 @@ class ExpressionSelector: select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")) ) - style_exprs = [ - { + # 🔥 优化:提前定义转换函数,避免重复代码 + def expr_to_dict(expr, expr_type: str) -> dict[str, Any]: + return { "situation": expr.situation, "style": expr.style, "count": expr.count, "last_active_time": expr.last_active_time, "source_id": expr.chat_id, - "type": "style", + "type": expr_type, "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, } - for expr in style_query.scalars() - ] - - grammar_exprs = [ - { - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": expr.chat_id, - "type": "grammar", - "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, - } - for expr in grammar_query.scalars() - ] + + style_exprs = [expr_to_dict(expr, "style") for expr in style_query.scalars()] + grammar_exprs = [expr_to_dict(expr, "grammar") for expr in grammar_query.scalars()] style_num = int(total_num * style_percentage) grammar_num = int(total_num * grammar_percentage) @@ -251,9 +240,14 @@ class ExpressionSelector: @staticmethod async def update_expressions_count_batch(expressions_to_update: list[dict[str, Any]], increment: float = 0.1): - """对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库""" + """对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库 + + 🔥 优化:合并所有更新到一个事务中,减少数据库连接开销 + """ if not expressions_to_update: return + + # 去重处理 updates_by_key = {} affected_chat_ids = set() for expr in expressions_to_update: @@ -269,9 +263,15 @@ class ExpressionSelector: updates_by_key[key] = expr affected_chat_ids.add(source_id) - for chat_id, expr_type, situation, style in updates_by_key: - async with get_db_session() as session: - query = await session.execute( + if not updates_by_key: + return + + # 🔥 优化:使用单个 session 批量处理所有更新 + current_time = time.time() + async with get_db_session() as session: + updated_count = 0 + for chat_id, expr_type, situation, style in updates_by_key: + query_result = await session.execute( select(Expression).where( (Expression.chat_id == chat_id) & (Expression.type == expr_type) @@ -279,25 +279,26 @@ class ExpressionSelector: & (Expression.style == style) ) ) - query = query.scalar() - if query: - expr_obj = query + expr_obj = query_result.scalar() + if expr_obj: current_count = expr_obj.count new_count = min(current_count + increment, 5.0) expr_obj.count = new_count - expr_obj.last_active_time = time.time() + expr_obj.last_active_time = current_time + updated_count += 1 - logger.debug( - f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db" - ) + # 批量提交所有更改 + if updated_count > 0: await session.commit() + logger.debug(f"批量更新了 {updated_count} 个表达方式的count值") # 清除所有受影响的chat_id的缓存 - from src.common.database.optimization.cache_manager import get_cache - from src.common.database.utils.decorators import generate_cache_key - cache = await get_cache() - for chat_id in affected_chat_ids: - await cache.delete(generate_cache_key("chat_expressions", chat_id)) + if affected_chat_ids: + from src.common.database.optimization.cache_manager import get_cache + from src.common.database.utils.decorators import generate_cache_key + cache = await get_cache() + for chat_id in affected_chat_ids: + await cache.delete(generate_cache_key("chat_expressions", chat_id)) async def select_suitable_expressions( self, @@ -518,29 +519,41 @@ class ExpressionSelector: logger.warning("数据库中完全没有任何表达方式,需要先学习") return [] - # 🔥 使用模糊匹配而不是精确匹配 - # 计算每个预测style与数据库style的相似度 + # 🔥 优化:使用更高效的模糊匹配算法 from difflib import SequenceMatcher + # 预处理:提前计算所有预测 style 的小写版本,避免重复计算 + predicted_styles_lower = [(s.lower(), score) for s, score in predicted_styles[:20]] + matched_expressions = [] for expr in all_expressions: db_style = expr.style or "" + db_style_lower = db_style.lower() max_similarity = 0.0 best_predicted = "" # 与每个预测的style计算相似度 - for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测 - # 计算字符串相似度 - similarity = SequenceMatcher(None, predicted_style, db_style).ratio() - - # 也检查包含关系(如果一个是另一个的子串,给更高分) - if len(predicted_style) >= 2 and len(db_style) >= 2: - if predicted_style in db_style or db_style in predicted_style: - similarity = max(similarity, 0.7) - + for predicted_style_lower, pred_score in predicted_styles_lower: + # 快速检查:完全匹配 + if predicted_style_lower == db_style_lower: + max_similarity = 1.0 + best_predicted = predicted_style_lower + break + + # 快速检查:子串匹配 + if len(predicted_style_lower) >= 2 and len(db_style_lower) >= 2: + if predicted_style_lower in db_style_lower or db_style_lower in predicted_style_lower: + similarity = 0.7 + if similarity > max_similarity: + max_similarity = similarity + best_predicted = predicted_style_lower + continue + + # 计算字符串相似度(较慢,只在必要时使用) + similarity = SequenceMatcher(None, predicted_style_lower, db_style_lower).ratio() if similarity > max_similarity: max_similarity = similarity - best_predicted = predicted_style + best_predicted = predicted_style_lower # 🔥 降低阈值到30%,因为StyleLearner预测质量较差 if max_similarity >= 0.3: # 30%相似度阈值 @@ -573,14 +586,15 @@ class ExpressionSelector: f"(候选 {len(matched_expressions)},temperature={temperature})" ) - # 转换为字典格式 + # 🔥 优化:使用列表推导式和预定义函数减少开销 expressions = [ { "situation": expr.situation or "", "style": expr.style or "", "type": expr.type or "style", "count": float(expr.count) if expr.count else 0.0, - "last_active_time": expr.last_active_time or 0.0 + "last_active_time": expr.last_active_time or 0.0, + "source_id": expr.chat_id # 添加 source_id 以便后续更新 } for expr in expressions_objs ] diff --git a/src/chat/express/situation_extractor.py b/src/chat/express/situation_extractor.py index 2fd6c9205..47e35e78a 100644 --- a/src/chat/express/situation_extractor.py +++ b/src/chat/express/situation_extractor.py @@ -127,7 +127,8 @@ class SituationExtractor: Returns: 情境描述列表 """ - situations = [] + situations: list[str] = [] + seen = set() for line in response.splitlines(): line = line.strip() @@ -150,6 +151,11 @@ class SituationExtractor: if any(keyword in line.lower() for keyword in ["例如", "注意", "请", "分析", "总结"]): continue + # 去重,保持原有顺序 + if line in seen: + continue + seen.add(line) + situations.append(line) if len(situations) >= max_situations: diff --git a/src/chat/express/style_learner.py b/src/chat/express/style_learner.py index 3b099f3fd..ec76428d0 100644 --- a/src/chat/express/style_learner.py +++ b/src/chat/express/style_learner.py @@ -4,6 +4,7 @@ 支持多聊天室独立建模和在线学习 """ import os +import pickle import time from src.common.logger import get_logger @@ -16,11 +17,12 @@ logger = get_logger("expressor.style_learner") class StyleLearner: """单个聊天室的表达风格学习器""" - def __init__(self, chat_id: str, model_config: dict | None = None): + def __init__(self, chat_id: str, model_config: dict | None = None, resource_limit_enabled: bool = True): """ Args: chat_id: 聊天室ID model_config: 模型配置 + resource_limit_enabled: 是否启用资源上限控制(默认关闭) """ self.chat_id = chat_id self.model_config = model_config or { @@ -34,6 +36,9 @@ class StyleLearner: # 初始化表达模型 self.expressor = ExpressorModel(**self.model_config) + # 资源上限控制开关(默认开启,可按需关闭) + self.resource_limit_enabled = resource_limit_enabled + # 动态风格管理 self.max_styles = 2000 # 每个chat_id最多2000个风格 self.cleanup_threshold = 0.9 # 达到90%容量时触发清理 @@ -67,18 +72,15 @@ class StyleLearner: if style in self.style_to_id: return True - # 检查是否需要清理 - current_count = len(self.style_to_id) - cleanup_trigger = int(self.max_styles * self.cleanup_threshold) - - if current_count >= cleanup_trigger: - if current_count >= self.max_styles: - # 已经达到最大限制,必须清理 - logger.warning(f"已达到最大风格数量限制 ({self.max_styles}),开始清理") - self._cleanup_styles() - elif current_count >= cleanup_trigger: - # 接近限制,提前清理 - logger.info(f"风格数量达到 {current_count}/{self.max_styles},触发预防性清理") + # 检查是否需要清理(仅计算一次阈值) + if self.resource_limit_enabled: + current_count = len(self.style_to_id) + cleanup_trigger = int(self.max_styles * self.cleanup_threshold) + if current_count >= cleanup_trigger: + if current_count >= self.max_styles: + logger.warning(f"已达到最大风格数量限制 ({self.max_styles}),开始清理") + else: + logger.info(f"风格数量达到 {current_count}/{self.max_styles},触发预防性清理") self._cleanup_styles() # 生成新的style_id @@ -95,7 +97,8 @@ class StyleLearner: self.expressor.add_candidate(style_id, style, situation) # 初始化统计 - self.learning_stats["style_counts"][style_id] = 0 + self.learning_stats.setdefault("style_counts", {})[style_id] = 0 + self.learning_stats.setdefault("style_last_used", {}) logger.debug(f"添加风格成功: {style_id} -> {style}") return True @@ -114,64 +117,64 @@ class StyleLearner: 3. 默认清理 cleanup_ratio (20%) 的风格 """ try: + total_styles = len(self.style_to_id) + if total_styles == 0: + return + + # 只有在达到阈值时才执行昂贵的排序 + cleanup_count = max(1, int(total_styles * self.cleanup_ratio)) + if cleanup_count <= 0: + return + current_time = time.time() - cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio)) + # 局部引用加速频繁调用的函数 + from math import exp, log1p # 计算每个风格的价值分数 style_scores = [] for style_id in self.style_to_id.values(): - # 使用次数 usage_count = self.learning_stats["style_counts"].get(style_id, 0) - - # 最后使用时间(越近越好) last_used = self.learning_stats["style_last_used"].get(style_id, 0) + time_since_used = current_time - last_used if last_used > 0 else float("inf") + usage_score = log1p(usage_count) + days_unused = time_since_used / 86400 + time_score = exp(-days_unused / 30) - # 综合分数:使用次数越多越好,距离上次使用时间越短越好 - # 使用对数来平滑使用次数的影响 - import math - usage_score = math.log1p(usage_count) # log(1 + count) - - # 时间分数:转换为天数,使用指数衰减 - days_unused = time_since_used / 86400 # 转换为天 - time_score = math.exp(-days_unused / 30) # 30天衰减因子 - - # 综合分数:80%使用频率 + 20%时间新鲜度 total_score = 0.8 * usage_score + 0.2 * time_score - style_scores.append((style_id, total_score, usage_count, days_unused)) + if not style_scores: + return + # 按分数排序,分数低的先删除 style_scores.sort(key=lambda x: x[1]) - # 删除分数最低的风格 deleted_styles = [] for style_id, score, usage, days in style_scores[:cleanup_count]: style_text = self.id_to_style.get(style_id) - if style_text: - # 从映射中删除 - del self.style_to_id[style_text] - del self.id_to_style[style_id] - if style_id in self.id_to_situation: - del self.id_to_situation[style_id] + if not style_text: + continue - # 从统计中删除 - if style_id in self.learning_stats["style_counts"]: - del self.learning_stats["style_counts"][style_id] - if style_id in self.learning_stats["style_last_used"]: - del self.learning_stats["style_last_used"][style_id] + # 从映射中删除 + self.style_to_id.pop(style_text, None) + self.id_to_style.pop(style_id, None) + self.id_to_situation.pop(style_id, None) - # 从expressor模型中删除 - self.expressor.remove_candidate(style_id) + # 从统计中删除 + self.learning_stats["style_counts"].pop(style_id, None) + self.learning_stats["style_last_used"].pop(style_id, None) - deleted_styles.append((style_text[:30], usage, f"{days:.1f}天")) + # 从expressor模型中删除 + self.expressor.remove_candidate(style_id) + + deleted_styles.append((style_text[:30], usage, f"{days:.1f}天")) logger.info( f"风格清理完成: 删除了 {len(deleted_styles)}/{len(style_scores)} 个风格," f"剩余 {len(self.style_to_id)} 个风格" ) - # 记录前5个被删除的风格(用于调试) if deleted_styles: logger.debug(f"被删除的风格样例(前5): {deleted_styles[:5]}") @@ -204,7 +207,9 @@ class StyleLearner: # 更新统计 current_time = time.time() self.learning_stats["total_samples"] += 1 - self.learning_stats["style_counts"][style_id] += 1 + self.learning_stats.setdefault("style_counts", {}) + self.learning_stats.setdefault("style_last_used", {}) + self.learning_stats["style_counts"][style_id] = self.learning_stats["style_counts"].get(style_id, 0) + 1 self.learning_stats["style_last_used"][style_id] = current_time # 更新最后使用时间 self.learning_stats["last_update"] = current_time @@ -349,11 +354,11 @@ class StyleLearner: # 保存expressor模型 model_path = os.path.join(save_dir, "expressor_model.pkl") - self.expressor.save(model_path) - - # 保存映射关系和统计信息 - import pickle + tmp_model_path = f"{model_path}.tmp" + self.expressor.save(tmp_model_path) + os.replace(tmp_model_path, model_path) + # 保存映射关系和统计信息(原子写) meta_path = os.path.join(save_dir, "meta.pkl") # 确保 learning_stats 包含所有必要字段 @@ -368,8 +373,13 @@ class StyleLearner: "learning_stats": self.learning_stats, } - with open(meta_path, "wb") as f: - pickle.dump(meta_data, f) + tmp_meta_path = f"{meta_path}.tmp" + with open(tmp_meta_path, "wb") as f: + pickle.dump(meta_data, f, protocol=pickle.HIGHEST_PROTOCOL) + f.flush() + os.fsync(f.fileno()) + + os.replace(tmp_meta_path, meta_path) return True @@ -401,8 +411,6 @@ class StyleLearner: self.expressor.load(model_path) # 加载映射关系和统计信息 - import pickle - meta_path = os.path.join(save_dir, "meta.pkl") if os.path.exists(meta_path): with open(meta_path, "rb") as f: @@ -445,14 +453,16 @@ class StyleLearnerManager: # 🔧 最大活跃 learner 数量 MAX_ACTIVE_LEARNERS = 50 - def __init__(self, model_save_path: str = "data/expression/style_models"): + def __init__(self, model_save_path: str = "data/expression/style_models", resource_limit_enabled: bool = True): """ Args: model_save_path: 模型保存路径 + resource_limit_enabled: 是否启用资源上限控制(默认开启) """ self.learners: dict[str, StyleLearner] = {} self.learner_last_used: dict[str, float] = {} # 🔧 记录最后使用时间 self.model_save_path = model_save_path + self.resource_limit_enabled = resource_limit_enabled # 确保保存目录存在 os.makedirs(model_save_path, exist_ok=True) @@ -475,7 +485,10 @@ class StyleLearnerManager: for chat_id, last_used in sorted_by_time[:evict_count]: if chat_id in self.learners: # 先保存再淘汰 - self.learners[chat_id].save(self.model_save_path) + try: + self.learners[chat_id].save(self.model_save_path) + except Exception as e: + logger.error(f"LRU淘汰时保存学习器失败: chat_id={chat_id}, error={e}") del self.learners[chat_id] del self.learner_last_used[chat_id] evicted.append(chat_id) @@ -502,7 +515,11 @@ class StyleLearnerManager: self._evict_if_needed() # 创建新的学习器 - learner = StyleLearner(chat_id, model_config) + learner = StyleLearner( + chat_id, + model_config, + resource_limit_enabled=self.resource_limit_enabled, + ) # 尝试加载已保存的模型 learner.load(self.model_save_path) @@ -511,6 +528,12 @@ class StyleLearnerManager: return self.learners[chat_id] + def set_resource_limit(self, enabled: bool) -> None: + """动态开启/关闭资源上限控制(默认关闭)。""" + self.resource_limit_enabled = enabled + for learner in self.learners.values(): + learner.resource_limit_enabled = enabled + def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool: """ 学习一个映射关系 diff --git a/src/chat/interest_system/interest_manager.py b/src/chat/interest_system/interest_manager.py index 10e43aee4..ddcec1465 100644 --- a/src/chat/interest_system/interest_manager.py +++ b/src/chat/interest_system/interest_manager.py @@ -5,6 +5,7 @@ import asyncio import time +from collections import OrderedDict from typing import TYPE_CHECKING from src.common.logger import get_logger @@ -37,19 +38,50 @@ class InterestManager: self._calculation_queue = asyncio.Queue() self._worker_task = None self._shutdown_event = asyncio.Event() + + # 性能优化相关字段 + self._result_cache: OrderedDict[str, InterestCalculationResult] = OrderedDict() # LRU缓存 + self._cache_max_size = 1000 # 最大缓存数量 + self._cache_ttl = 300 # 缓存TTL(秒) + self._batch_queue: asyncio.Queue = asyncio.Queue(maxsize=100) # 批处理队列 + self._batch_size = 10 # 批处理大小 + self._batch_timeout = 0.1 # 批处理超时(秒) + self._batch_task = None + self._is_warmed_up = False # 预热状态标记 + + # 性能统计 + self._cache_hits = 0 + self._cache_misses = 0 + self._batch_calculations = 0 + self._total_calculation_time = 0.0 + self._initialized = True async def initialize(self): """初始化管理器""" - pass + # 启动批处理工作线程 + if self._batch_task is None or self._batch_task.done(): + self._batch_task = asyncio.create_task(self._batch_processing_worker()) + logger.info("批处理工作线程已启动") async def shutdown(self): """关闭管理器""" self._shutdown_event.set() + + # 取消批处理任务 + if self._batch_task and not self._batch_task.done(): + self._batch_task.cancel() + try: + await self._batch_task + except asyncio.CancelledError: + pass if self._current_calculator: await self._current_calculator.cleanup() self._current_calculator = None + + # 清理缓存 + self._result_cache.clear() logger.info("兴趣值管理器已关闭") @@ -91,12 +123,13 @@ class InterestManager: logger.error(f"注册兴趣值计算组件失败: {e}") return False - async def calculate_interest(self, message: "DatabaseMessages", timeout: float | None = None) -> InterestCalculationResult: - """计算消息兴趣值 + async def calculate_interest(self, message: "DatabaseMessages", timeout: float | None = None, use_cache: bool = True) -> InterestCalculationResult: + """计算消息兴趣值(优化版,支持缓存) Args: message: 数据库消息对象 timeout: 最大等待时间(秒),超时则使用默认值返回;为None时不设置超时 + use_cache: 是否使用缓存,默认True Returns: InterestCalculationResult: 计算结果或默认结果 @@ -109,37 +142,53 @@ class InterestManager: interest_value=0.3, error_message="没有可用的兴趣值计算组件", ) + + message_id = getattr(message, "message_id", "") + + # 缓存查询 + if use_cache and message_id: + cached_result = self._get_from_cache(message_id) + if cached_result is not None: + self._cache_hits += 1 + logger.debug(f"命中缓存: {message_id}, 兴趣值: {cached_result.interest_value:.3f}") + return cached_result + self._cache_misses += 1 # 使用 create_task 异步执行计算 task = asyncio.create_task(self._async_calculate(message)) if timeout is None: - return await task - - try: - # 等待计算结果,但有超时限制 - result = await asyncio.wait_for(task, timeout=timeout) - return result - except asyncio.TimeoutError: - # 超时返回默认结果,但计算仍在后台继续 - logger.warning(f"兴趣值计算超时 ({timeout}s),消息 {getattr(message, 'message_id', '')} 使用默认兴趣值 0.5") - return InterestCalculationResult( - success=True, - message_id=getattr(message, "message_id", ""), - interest_value=0.5, # 固定默认兴趣值 - should_reply=False, - should_act=False, - error_message=f"计算超时({timeout}s),使用默认值", - ) - except Exception as e: - # 发生异常,返回默认结果 - logger.error(f"兴趣值计算异常: {e}") - return InterestCalculationResult( - success=False, - message_id=getattr(message, "message_id", ""), - interest_value=0.3, - error_message=f"计算异常: {e!s}", - ) + result = await task + else: + try: + # 等待计算结果,但有超时限制 + result = await asyncio.wait_for(task, timeout=timeout) + except asyncio.TimeoutError: + # 超时返回默认结果,但计算仍在后台继续 + logger.warning(f"兴趣值计算超时 ({timeout}s),消息 {message_id} 使用默认兴趣值 0.5") + return InterestCalculationResult( + success=True, + message_id=message_id, + interest_value=0.5, # 固定默认兴趣值 + should_reply=False, + should_act=False, + error_message=f"计算超时({timeout}s),使用默认值", + ) + except Exception as e: + # 发生异常,返回默认结果 + logger.error(f"兴趣值计算异常: {e}") + return InterestCalculationResult( + success=False, + message_id=message_id, + interest_value=0.3, + error_message=f"计算异常: {e!s}", + ) + + # 缓存结果 + if use_cache and result.success and message_id: + self._put_to_cache(message_id, result) + + return result async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult: """异步执行兴趣值计算""" @@ -161,6 +210,7 @@ class InterestManager: if result.success: self._last_calculation_time = time.time() + self._total_calculation_time += result.calculation_time logger.debug(f"兴趣值计算完成: {result.interest_value:.3f} (耗时: {result.calculation_time:.3f}s)") else: self._failed_calculations += 1 @@ -170,13 +220,15 @@ class InterestManager: except Exception as e: self._failed_calculations += 1 + calc_time = time.time() - start_time + self._total_calculation_time += calc_time logger.error(f"兴趣值计算异常: {e}") return InterestCalculationResult( success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=f"计算异常: {e!s}", - calculation_time=time.time() - start_time, + calculation_time=calc_time, ) async def _calculation_worker(self): @@ -197,6 +249,155 @@ class InterestManager: break except Exception as e: logger.error(f"计算工作线程异常: {e}") + + def _get_from_cache(self, message_id: str) -> InterestCalculationResult | None: + """从缓存中获取结果(LRU策略)""" + if message_id not in self._result_cache: + return None + + # 检查TTL + result = self._result_cache[message_id] + if time.time() - result.timestamp > self._cache_ttl: + # 过期,删除 + del self._result_cache[message_id] + return None + + # 更新访问顺序(LRU) + self._result_cache.move_to_end(message_id) + return result + + def _put_to_cache(self, message_id: str, result: InterestCalculationResult): + """将结果放入缓存(LRU策略)""" + # 如果已存在,更新 + if message_id in self._result_cache: + self._result_cache.move_to_end(message_id) + + self._result_cache[message_id] = result + + # 限制缓存大小 + while len(self._result_cache) > self._cache_max_size: + # 删除最旧的项 + self._result_cache.popitem(last=False) + + async def calculate_interest_batch(self, messages: list["DatabaseMessages"], timeout: float | None = None) -> list[InterestCalculationResult]: + """批量计算消息兴趣值(并发优化) + + Args: + messages: 消息列表 + timeout: 单个计算的超时时间 + + Returns: + list[InterestCalculationResult]: 计算结果列表 + """ + if not messages: + return [] + + # 并发计算所有消息 + tasks = [self.calculate_interest(msg, timeout=timeout) for msg in messages] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 处理异常 + final_results = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"批量计算消息 {i} 失败: {result}") + final_results.append(InterestCalculationResult( + success=False, + message_id=getattr(messages[i], "message_id", ""), + interest_value=0.3, + error_message=f"批量计算异常: {result!s}", + )) + else: + final_results.append(result) + + self._batch_calculations += 1 + return final_results + + async def _batch_processing_worker(self): + """批处理工作线程""" + while not self._shutdown_event.is_set(): + batch = [] + deadline = time.time() + self._batch_timeout + + try: + # 收集批次 + while len(batch) < self._batch_size and time.time() < deadline: + remaining_time = deadline - time.time() + if remaining_time <= 0: + break + + try: + item = await asyncio.wait_for(self._batch_queue.get(), timeout=remaining_time) + batch.append(item) + except asyncio.TimeoutError: + break + + # 处理批次 + if batch: + await self._process_batch(batch) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"批处理工作线程异常: {e}") + + async def _process_batch(self, batch: list): + """处理批次消息""" + # 这里可以实现具体的批处理逻辑 + # 当前版本只是占位,实际的批处理逻辑可以根据具体需求实现 + pass + + async def warmup(self, sample_messages: list["DatabaseMessages"] | None = None): + """预热兴趣计算器 + + Args: + sample_messages: 样本消息列表,用于预热。如果为None,则只初始化计算器 + """ + if not self._current_calculator: + logger.warning("无法预热:没有可用的兴趣值计算组件") + return + + logger.info("开始预热兴趣值计算器...") + start_time = time.time() + + # 如果提供了样本消息,进行预热计算 + if sample_messages: + try: + # 批量计算样本消息 + await self.calculate_interest_batch(sample_messages, timeout=5.0) + logger.info(f"预热完成:处理了 {len(sample_messages)} 条样本消息,耗时 {time.time() - start_time:.2f}s") + except Exception as e: + logger.error(f"预热过程中出现异常: {e}") + else: + logger.info(f"预热完成:计算器已就绪,耗时 {time.time() - start_time:.2f}s") + + self._is_warmed_up = True + + def clear_cache(self): + """清空缓存""" + cleared_count = len(self._result_cache) + self._result_cache.clear() + logger.info(f"已清空 {cleared_count} 条缓存记录") + + def set_cache_config(self, max_size: int | None = None, ttl: int | None = None): + """设置缓存配置 + + Args: + max_size: 最大缓存数量 + ttl: 缓存生存时间(秒) + """ + if max_size is not None: + self._cache_max_size = max_size + logger.info(f"缓存最大容量设置为: {max_size}") + + if ttl is not None: + self._cache_ttl = ttl + logger.info(f"缓存TTL设置为: {ttl}秒") + + # 如果当前缓存超过新的最大值,清理旧数据 + if max_size is not None: + while len(self._result_cache) > self._cache_max_size: + self._result_cache.popitem(last=False) def get_current_calculator(self) -> BaseInterestCalculator | None: """获取当前活跃的兴趣值计算组件""" @@ -205,6 +406,8 @@ class InterestManager: def get_statistics(self) -> dict: """获取管理器统计信息""" success_rate = 1.0 - (self._failed_calculations / max(1, self._total_calculations)) + cache_hit_rate = self._cache_hits / max(1, self._cache_hits + self._cache_misses) + avg_calc_time = self._total_calculation_time / max(1, self._total_calculations) stats = { "manager_statistics": { @@ -213,6 +416,13 @@ class InterestManager: "success_rate": success_rate, "last_calculation_time": self._last_calculation_time, "current_calculator": self._current_calculator.component_name if self._current_calculator else None, + "cache_hit_rate": cache_hit_rate, + "cache_hits": self._cache_hits, + "cache_misses": self._cache_misses, + "cache_size": len(self._result_cache), + "batch_calculations": self._batch_calculations, + "average_calculation_time": avg_calc_time, + "is_warmed_up": self._is_warmed_up, } } @@ -236,6 +446,82 @@ class InterestManager: def has_calculator(self) -> bool: """检查是否有可用的计算组件""" return self._current_calculator is not None and self._current_calculator.is_enabled + + async def adaptive_optimize(self): + """自适应优化:根据性能统计自动调整参数""" + if not self._current_calculator: + return + + stats = self.get_statistics()["manager_statistics"] + + # 根据缓存命中率调整缓存大小 + cache_hit_rate = stats["cache_hit_rate"] + if cache_hit_rate < 0.5 and self._cache_max_size < 5000: + # 命中率低,增加缓存容量 + new_size = min(self._cache_max_size * 2, 5000) + logger.info(f"自适应优化:缓存命中率较低 ({cache_hit_rate:.2%}),扩大缓存容量 {self._cache_max_size} -> {new_size}") + self._cache_max_size = new_size + elif cache_hit_rate > 0.9 and self._cache_max_size > 100: + # 命中率高,可以适当减小缓存 + new_size = max(self._cache_max_size // 2, 100) + logger.info(f"自适应优化:缓存命中率很高 ({cache_hit_rate:.2%}),缩小缓存容量 {self._cache_max_size} -> {new_size}") + self._cache_max_size = new_size + # 清理多余缓存 + while len(self._result_cache) > self._cache_max_size: + self._result_cache.popitem(last=False) + + # 根据平均计算时间调整批处理参数 + avg_calc_time = stats["average_calculation_time"] + if avg_calc_time > 0.5 and self._batch_size < 50: + # 计算较慢,增加批次大小以提高吞吐量 + new_batch_size = min(self._batch_size * 2, 50) + logger.info(f"自适应优化:平均计算时间较长 ({avg_calc_time:.3f}s),增加批次大小 {self._batch_size} -> {new_batch_size}") + self._batch_size = new_batch_size + elif avg_calc_time < 0.1 and self._batch_size > 5: + # 计算较快,可以减小批次 + new_batch_size = max(self._batch_size // 2, 5) + logger.info(f"自适应优化:平均计算时间较短 ({avg_calc_time:.3f}s),减小批次大小 {self._batch_size} -> {new_batch_size}") + self._batch_size = new_batch_size + + def get_performance_report(self) -> str: + """生成性能报告""" + stats = self.get_statistics()["manager_statistics"] + + report = [ + "=" * 60, + "兴趣值管理器性能报告", + "=" * 60, + f"总计算次数: {stats['total_calculations']}", + f"失败次数: {stats['failed_calculations']}", + f"成功率: {stats['success_rate']:.2%}", + f"缓存命中率: {stats['cache_hit_rate']:.2%}", + f"缓存命中: {stats['cache_hits']}", + f"缓存未命中: {stats['cache_misses']}", + f"当前缓存大小: {stats['cache_size']} / {self._cache_max_size}", + f"批量计算次数: {stats['batch_calculations']}", + f"平均计算时间: {stats['average_calculation_time']:.4f}s", + f"是否已预热: {'是' if stats['is_warmed_up'] else '否'}", + f"当前计算器: {stats['current_calculator'] or '无'}", + "=" * 60, + ] + + # 添加计算器统计 + if self._current_calculator: + calc_stats = self.get_statistics()["calculator_statistics"] + report.extend([ + "", + "计算器统计:", + f" 组件名称: {calc_stats['component_name']}", + f" 版本: {calc_stats['component_version']}", + f" 已启用: {calc_stats['enabled']}", + f" 总计算: {calc_stats['total_calculations']}", + f" 失败: {calc_stats['failed_calculations']}", + f" 成功率: {calc_stats['success_rate']:.2%}", + f" 平均耗时: {calc_stats['average_calculation_time']:.4f}s", + "=" * 60, + ]) + + return "\n".join(report) # 全局实例 diff --git a/src/memory_graph/short_term_pressure_patch.md b/src/memory_graph/short_term_pressure_patch.md new file mode 100644 index 000000000..6967fe41d --- /dev/null +++ b/src/memory_graph/short_term_pressure_patch.md @@ -0,0 +1,199 @@ +# 短期记忆压力泄压补丁 + +## 📋 概述 + +在高频消息场景下,短期记忆层(`ShortTermMemoryManager`)可能在自动转移机制触发前快速堆积大量记忆,当达到容量上限(`max_memories`)时可能阻塞后续写入。本功能提供一个**可选的泄压开关**,在容量溢出时自动删除低优先级记忆,防止系统阻塞。 + +**关键特性**: +- ✅ 默认关闭,保持向后兼容 +- ✅ 基于重要性和时间的智能删除策略 +- ✅ 异步持久化,不阻塞主流程 +- ✅ 可通过配置文件或代码控制 + +--- + +## 🔧 配置方法 + +### 方法 1:代码配置(直接创建管理器) + +如果您在代码中直接实例化 `UnifiedMemoryManager`: + +```python +from src.memory_graph.unified_manager import UnifiedMemoryManager + +manager = UnifiedMemoryManager( + short_term_enable_force_cleanup=True, # 开启泄压功能 + short_term_max_memories=30, # 短期记忆容量上限 + # ... 其他参数 +) +``` + +### 方法 2:配置文件(通过单例获取) + +**推荐方式**:如果您使用 `get_unified_memory_manager()` 单例,需修改配置文件。 + +#### ❌ 目前的问题 +配置文件 `config/bot_config.toml` 的 `[memory]` 节**尚未包含**此开关参数。 + +#### ✅ 解决方案 +在 `config/bot_config.toml` 的 `[memory]` 节添加: + +```toml +[memory] +# ... 其他配置 ... +short_term_max_memories = 30 # 短期记忆容量上限 +short_term_transfer_threshold = 0.6 # 转移到长期记忆的重要性阈值 +short_term_enable_force_cleanup = true # 开启压力泄压(建议高频场景开启) +``` + +然后在 `src/memory_graph/manager_singleton.py` 第 157-175 行的 `get_unified_memory_manager()` 函数中添加读取逻辑: + +```python +_unified_memory_manager = UnifiedMemoryManager( + # ... 其他参数 ... + short_term_enable_force_cleanup=getattr(config, "short_term_enable_force_cleanup", False), # 添加此行 +) +``` + +--- + +## ⚙️ 核心实现位置 + +### 1. 参数定义 +**文件**:`src/memory_graph/unified_manager.py` 第 47 行 +```python +class UnifiedMemoryManager: + def __init__( + self, + short_term_enable_force_cleanup: bool = False, # 开关参数 + ): +``` + +### 2. 传递到短期层 +**文件**:`src/memory_graph/unified_manager.py` 第 100 行 +```python +"short_term": { + "enable_force_cleanup": short_term_enable_force_cleanup, # 传递给 ShortTermMemoryManager +} +``` + +### 3. 泄压逻辑实现 +**文件**:`src/memory_graph/short_term_manager.py` 第 693-726 行 +```python +def force_cleanup_overflow(self, keep_ratio: float = 0.9) -> int: + """当短期记忆超过容量时,强制删除低重要性且最早的记忆以泄压""" + if not self.enable_force_cleanup: # 检查开关 + return 0 + # ... 删除逻辑 +``` + +### 4. 触发条件 +**文件**:`src/memory_graph/unified_manager.py` 第 618-621 行 +```python +# 在自动转移循环中检测 +if occupancy_ratio >= 1.0 and not transfer_cache: + removed = self.short_term_manager.force_cleanup_overflow() + if removed > 0: + logger.warning(f"短期记忆占用率 {occupancy_ratio:.0%},已强制删除 {removed} 条低重要性记忆泄压") +``` + +--- + +## 🔄 运行机制 + +### 触发条件(同时满足) +1. ✅ 开关已开启(`enable_force_cleanup=True`) +2. ✅ 短期记忆占用率 ≥ 100%(`len(memories) >= max_memories`) +3. ✅ 当前没有待转移批次(`transfer_cache` 为空) + +### 删除策略 +**排序规则**:双重排序,先按重要性升序,再按创建时间升序 +```python +sorted_memories = sorted(self.memories, key=lambda m: (m.importance, m.created_at)) +``` + +**删除数量**:删除到容量的 90% +```python +current = len(self.memories) # 当前记忆数 +limit = int(self.max_memories * 0.9) # 目标保留数 +remove_count = current - limit # 需要删除的数量 +``` + +**示例**: +- 容量上限 `max_memories=30` +- 当前记忆数 `35` → 删除 `35 - 27 = 8` 条最低优先级记忆 +- 优先删除:重要性 0.1 且创建于 10 分钟前的记忆 + +### 持久化 +- 使用 `asyncio.create_task(self._save_to_disk())` 异步保存 +- **不阻塞**消息处理主流程 + +--- + +## 📊 性能影响 + +| 场景 | 开关状态 | 行为 | 适用场景 | +|------|---------|------|---------| +| 高频消息 | ✅ 开启 | 自动泄压,防止阻塞 | 群聊、客服场景 | +| 低频消息 | ❌ 关闭 | 仅依赖自动转移 | 私聊、低活跃群 | +| 调试阶段 | ❌ 关闭 | 便于观察记忆堆积 | 开发测试 | + +**日志示例**(开启后): +``` +[WARNING] 短期记忆压力泄压: 移除 8 条 (当前 27/30) +[WARNING] 短期记忆占用率 100%,已强制删除 8 条低重要性记忆泄压 +``` + +--- + +## 🚨 注意事项 + +### ⚠️ 何时开启 +- ✅ **推荐开启**:高频群聊、客服机器人、24/7 运行场景 +- ❌ **不建议开启**:需要完整保留所有短期记忆、调试阶段 + +### ⚠️ 潜在影响 +- 低重要性记忆可能被删除,**不会转移到长期记忆** +- 如需保留所有记忆,应调大 `max_memories` 或关闭此功能 + +### ⚠️ 与自动转移的协同 +本功能是**兜底机制**,正常情况下: +1. 优先触发自动转移(占用率 ≥ 50%) +2. 高重要性记忆转移到长期层 +3. 仅当转移来不及时,泄压才会触发 + +--- + +## 🔙 回滚与禁用 + +### 临时禁用(无需重启) +```python +# 运行时修改(如果您能访问管理器实例) +unified_manager.short_term_manager.enable_force_cleanup = False +``` + +### 永久禁用 +**配置文件方式**: +```toml +[memory] +short_term_enable_force_cleanup = false # 或直接删除此行 +``` + +**代码方式**: +```python +manager = UnifiedMemoryManager( + short_term_enable_force_cleanup=False, # 显式关闭 +) +``` + +--- + +## 📚 相关文档 + +- [三层记忆系统用户指南](../../docs/three_tier_memory_user_guide.md) +- [记忆图谱架构](../../docs/memory_graph_guide.md) +- [统一调度器指南](../../docs/unified_scheduler_guide.md) + +--- + +**最后更新**:2025年12月16日 diff --git a/src/memory_graph/storage/graph_store.py b/src/memory_graph/storage/graph_store.py index 1caedc345..3b4fbcae6 100644 --- a/src/memory_graph/storage/graph_store.py +++ b/src/memory_graph/storage/graph_store.py @@ -9,7 +9,7 @@ from collections.abc import Iterable import networkx as nx from src.common.logger import get_logger -from src.memory_graph.models import Memory, MemoryEdge +from src.memory_graph.models import EdgeType, Memory, MemoryEdge logger = get_logger(__name__) @@ -159,9 +159,6 @@ class GraphStore: # 1.5. 注销记忆中的边的邻接索引记录 self._unregister_memory_edges(memory) - # 1.5. 注销记忆中的边的邻接索引记录 - self._unregister_memory_edges(memory) - # 2. 添加节点到图 if not self.graph.has_node(node_id): from datetime import datetime @@ -201,6 +198,9 @@ class GraphStore: ) memory.nodes.append(new_node) + # 5. 重新注册记忆中的边到邻接索引 + self._register_memory_edges(memory) + logger.debug(f"添加节点成功: {node_id} -> {memory_id}") return True @@ -926,12 +926,23 @@ class GraphStore: mem_edge = MemoryEdge.from_dict(edge_dict) except Exception: # 兼容性:直接构造对象 + # 确保 edge_type 是 EdgeType 枚举 + edge_type_value = edge_dict["edge_type"] + if isinstance(edge_type_value, str): + try: + edge_type_enum = EdgeType(edge_type_value) + except ValueError: + logger.warning(f"未知的边类型: {edge_type_value}, 使用默认值") + edge_type_enum = EdgeType.RELATION + else: + edge_type_enum = edge_type_value + mem_edge = MemoryEdge( id=edge_dict["id"] or "", source_id=edge_dict["source_id"], target_id=edge_dict["target_id"], relation=edge_dict["relation"], - edge_type=edge_dict["edge_type"], + edge_type=edge_type_enum, importance=edge_dict.get("importance", 0.5), metadata=edge_dict.get("metadata", {}), ) diff --git a/src/plugin_system/base/base_interest_calculator.py b/src/plugin_system/base/base_interest_calculator.py index 17ce66c0c..c8192a74f 100644 --- a/src/plugin_system/base/base_interest_calculator.py +++ b/src/plugin_system/base/base_interest_calculator.py @@ -117,10 +117,17 @@ class BaseInterestCalculator(ABC): """ try: self._enabled = True + # 子类可以重写此方法执行自定义初始化 + await self.on_initialize() return True - except Exception: + except Exception as e: + logger.error(f"初始化兴趣计算器失败: {e}") self._enabled = False return False + + async def on_initialize(self): + """子类可重写的初始化钩子""" + pass async def cleanup(self) -> bool: """清理组件资源 @@ -129,10 +136,17 @@ class BaseInterestCalculator(ABC): bool: 清理是否成功 """ try: + # 子类可以重写此方法执行自定义清理 + await self.on_cleanup() self._enabled = False return True - except Exception: + except Exception as e: + logger.error(f"清理兴趣计算器失败: {e}") return False + + async def on_cleanup(self): + """子类可重写的清理钩子""" + pass @property def is_enabled(self) -> bool: diff --git a/src/plugins/built_in/affinity_flow_chatter/tools/user_fact_tool.py b/src/plugins/built_in/affinity_flow_chatter/tools/user_fact_tool.py index 1650e11d6..a9765f70c 100644 --- a/src/plugins/built_in/affinity_flow_chatter/tools/user_fact_tool.py +++ b/src/plugins/built_in/affinity_flow_chatter/tools/user_fact_tool.py @@ -39,7 +39,7 @@ class UserFactTool(BaseTool): ("info_value", ToolParamType.STRING, "具体内容,如'11月23日'、'程序员'、'想开咖啡店'", True, None), ] available_for_llm = True - history_ttl = 5 + history_ttl = 0 async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """执行关键信息记录 diff --git a/src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py b/src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py index 474c6e7de..4f678e02e 100644 --- a/src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py +++ b/src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py @@ -85,7 +85,7 @@ class UserProfileTool(BaseTool): ("key_info_value", ToolParamType.STRING, "具体信息内容(必须是具体值如'11月23日'、'上海')", False, None), ] available_for_llm = True - history_ttl = 1 + history_ttl = 0 async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """执行用户画像更新(异步后台执行,不阻塞回复) diff --git a/src/plugins/built_in/kokoro_flow_chatter/prompt/builder.py b/src/plugins/built_in/kokoro_flow_chatter/prompt/builder.py index cb0a3ebeb..128f1a706 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/prompt/builder.py +++ b/src/plugins/built_in/kokoro_flow_chatter/prompt/builder.py @@ -75,12 +75,12 @@ class PromptBuilder: # 1.6. 构建自定义决策提示词块 custom_decision_block = self._build_custom_decision_block() - # 2. 使用 context_builder 获取关系、记忆、工具、表达习惯等 - context_data = await self._build_context_data(user_name, chat_stream, user_id) - relation_block = context_data.get("relation_info", f"你与 {user_name} 还不太熟悉,这是早期的交流阶段。") - memory_block = context_data.get("memory_block", "") - tool_info = context_data.get("tool_info", "") - expression_habits = self._build_combined_expression_block(context_data.get("expression_habits", "")) + # 2. Planner(分离模式)不做重型上下文构建:记忆检索/工具信息/表达习惯检索等会显著拖慢处理 + # 这些信息留给 Replyer(生成最终回复文本)阶段再获取。 + relation_block = "" + memory_block = "" + tool_info = "" + expression_habits = "" # 3. 构建活动流 activity_stream = await self._build_activity_stream(session, user_name) diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py index 4fcc20ec8..ee8cdfc4a 100644 --- a/src/plugins/built_in/napcat_adapter/plugin.py +++ b/src/plugins/built_in/napcat_adapter/plugin.py @@ -414,7 +414,22 @@ class NapcatAdapterPlugin(BasePlugin): "enable_emoji_like": ConfigField(type=bool, default=True, description="是否启用群聊表情回复处理"), "enable_reply_at": ConfigField(type=bool, default=True, description="是否在回复时自动@原消息发送者"), "reply_at_rate": ConfigField(type=float, default=0.5, description="回复时@的概率(0.0-1.0)"), - "enable_video_processing": ConfigField(type=bool, default=True, description="是否启用视频消息处理(下载和解析)"), + # ========== 视频消息处理配置 ========== + "enable_video_processing": ConfigField( + type=bool, + default=True, + description="是否启用视频消息处理(下载和解析)。关闭后视频消息将显示为 [视频消息] 占位符,不会进行下载" + ), + "video_max_size_mb": ConfigField( + type=int, + default=100, + description="允许下载的视频文件最大大小(MB),超过此大小的视频将被跳过" + ), + "video_download_timeout": ConfigField( + type=int, + default=60, + description="视频下载超时时间(秒),若超时将中止下载" + ), }, } diff --git a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py index b2afa45a5..864960f6e 100644 --- a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py +++ b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py @@ -37,10 +37,21 @@ class MessageHandler: def __init__(self, adapter: "NapcatAdapter"): self.adapter = adapter self.plugin_config: dict[str, Any] | None = None + self._video_downloader = None def set_plugin_config(self, config: dict[str, Any]) -> None: - """设置插件配置""" + """设置插件配置,并根据配置初始化视频下载器""" self.plugin_config = config + + # 如果启用了视频处理,根据配置初始化视频下载器 + if config_api.get_plugin_config(config, "features.enable_video_processing", True): + from ..video_handler import VideoDownloader + + max_size = config_api.get_plugin_config(config, "features.video_max_size_mb", 100) + timeout = config_api.get_plugin_config(config, "features.video_download_timeout", 60) + + self._video_downloader = VideoDownloader(max_size_mb=max_size, download_timeout=timeout) + logger.debug(f"视频下载器已初始化: max_size={max_size}MB, timeout={timeout}s") async def handle_raw_message(self, raw: dict[str, Any]): """ @@ -105,6 +116,11 @@ class MessageHandler: if seg_message: seg_list.append(seg_message) + # 防御性检查:确保至少有一个消息段,避免消息为空导致构建失败 + if not seg_list: + logger.warning("消息内容为空,添加占位符文本") + seg_list.append({"type": "text", "data": "[消息内容为空]"}) + msg_builder.format_info( content_format=[seg["type"] for seg in seg_list], accept_format=ACCEPT_FORMAT, @@ -302,7 +318,7 @@ class MessageHandler: video_source = file_path if file_path else video_url if not video_source: logger.warning("视频消息缺少URL或文件路径信息") - return None + return {"type": "text", "data": "[视频消息]"} try: if file_path and Path(file_path).exists(): @@ -320,14 +336,17 @@ class MessageHandler: }, } elif video_url: - # URL下载处理 - from ..video_handler import get_video_downloader - video_downloader = get_video_downloader() - download_result = await video_downloader.download_video(video_url) + # URL下载处理 - 使用配置中的下载器实例 + downloader = self._video_downloader + if not downloader: + from ..video_handler import get_video_downloader + downloader = get_video_downloader() + + download_result = await downloader.download_video(video_url) if not download_result["success"]: logger.warning(f"视频下载失败: {download_result.get('error', '未知错误')}") - return None + return {"type": "text", "data": f"[视频消息] ({download_result.get('error', '下载失败')})"} video_base64 = base64.b64encode(download_result["data"]).decode("utf-8") logger.debug(f"视频下载成功,大小: {len(download_result['data']) / (1024 * 1024):.2f} MB") @@ -343,11 +362,11 @@ class MessageHandler: } else: logger.warning("既没有有效的本地文件路径,也没有有效的视频URL") - return None + return {"type": "text", "data": "[视频消息]"} except Exception as e: logger.error(f"视频消息处理失败: {e!s}") - return None + return {"type": "text", "data": "[视频消息处理出错]"} async def _handle_rps_message(self, segment: dict) -> SegPayload: """处理猜拳消息"""