25
README.md
25
README.md
@@ -25,9 +25,11 @@
|
|||||||
|
|
||||||
**🍔MaiCore 是一个基于大语言模型的可交互智能体**
|
**🍔MaiCore 是一个基于大语言模型的可交互智能体**
|
||||||
|
|
||||||
- 💭 **智能对话系统**:基于 LLM 的自然语言交互。
|
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,支持normal和focus统一化处理。
|
||||||
|
- 🔌 **强大插件系统**:全面重构的插件架构,支持完整的管理API和权限控制。
|
||||||
- 🤔 **实时思维系统**:模拟人类思考过程。
|
- 🤔 **实时思维系统**:模拟人类思考过程。
|
||||||
- 💝 **情感表达系统**:丰富的表情包和情绪表达。
|
- 🧠 **表达学习功能**:学习群友的说话风格和表达方式
|
||||||
|
- 💝 **情感表达系统**:情绪系统和表情包系统。
|
||||||
- 🧠 **持久记忆系统**:基于图的长期记忆存储。
|
- 🧠 **持久记忆系统**:基于图的长期记忆存储。
|
||||||
- 🔄 **动态人格系统**:自适应的性格特征和表达方式。
|
- 🔄 **动态人格系统**:自适应的性格特征和表达方式。
|
||||||
|
|
||||||
@@ -44,11 +46,10 @@
|
|||||||
|
|
||||||
## 🔥 更新和安装
|
## 🔥 更新和安装
|
||||||
|
|
||||||
|
**最新版本: v0.9.1** ([更新日志](changelogs/changelog.md))
|
||||||
**最新版本: v0.8.1** ([更新日志](changelogs/changelog.md))
|
|
||||||
|
|
||||||
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
||||||
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/tag/v0.1.0)下载最新启动器
|
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
|
||||||
**GitHub 分支说明:**
|
**GitHub 分支说明:**
|
||||||
- `main`: 稳定发布版本(推荐)
|
- `main`: 稳定发布版本(推荐)
|
||||||
- `dev`: 开发测试版本(不稳定)
|
- `dev`: 开发测试版本(不稳定)
|
||||||
@@ -68,11 +69,17 @@
|
|||||||
|
|
||||||
## 💬 讨论
|
## 💬 讨论
|
||||||
|
|
||||||
- [四群](https://qm.qq.com/q/wGePTl1UyY) |
|
**技术交流群:**
|
||||||
[一群](https://qm.qq.com/q/VQ3XZrWgMs) |
|
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) |
|
||||||
[二群](https://qm.qq.com/q/RzmCiRtHEW) |
|
[二群](https://qm.qq.com/q/RzmCiRtHEW) |
|
||||||
[五群](https://qm.qq.com/q/JxvHZnxyec) |
|
[三群](https://qm.qq.com/q/wlH5eT8OmQ) |
|
||||||
[三群](https://qm.qq.com/q/wlH5eT8OmQ)
|
[四群](https://qm.qq.com/q/wGePTl1UyY)
|
||||||
|
|
||||||
|
**聊天吹水群:**
|
||||||
|
- [五群](https://qm.qq.com/q/JxvHZnxyec)
|
||||||
|
|
||||||
|
**插件开发测试版群:**
|
||||||
|
- [插件开发群](https://qm.qq.com/q/1036092828)
|
||||||
|
|
||||||
## 📚 文档
|
## 📚 文档
|
||||||
|
|
||||||
|
|||||||
@@ -1,29 +1,80 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
## [0.8.2] - 2025-7-5
|
## [0.9.1] - 2025-7-25
|
||||||
|
|
||||||
功能更新:
|
- 修复表达方式迁移空目录问题
|
||||||
|
- 修复reply_to空字段问题
|
||||||
|
- 将metioned bot 和 at应用到focus prompt中
|
||||||
|
|
||||||
- 新的情绪系统,麦麦现在拥有持续的情绪
|
|
||||||
-
|
|
||||||
|
|
||||||
优化和修复:
|
|
||||||
|
|
||||||
-
|
## [0.9.0] - 2025-7-25
|
||||||
- 优化no_reply逻辑
|
|
||||||
- 优化Log显示
|
### 摘要
|
||||||
- 优化关系配置
|
MaiBot 0.9.0 重磅升级!本版本带来两大核心突破:**全面重构的插件系统**提供更强大的扩展能力和管理功能;**normal和focus模式统一化处理**大幅简化架构并提升性能。同时新增s4u prompt模式优化、语音消息支持、全新情绪系统和mais4u直播互动功能,为MaiBot带来更自然、更智能的交互体验!
|
||||||
- 简化配置文件
|
|
||||||
- 修复在auto模式下,私聊会转为normal的bug
|
### 🌟 主要功能概览
|
||||||
- 修复一般过滤次序问题
|
|
||||||
- 优化normal_chat代码,采用和focus一致的关系构建
|
#### 🔌 插件系统全面重构 - 重点升级
|
||||||
- 优化计时信息和Log
|
- **完整管理API**: 全新的插件管理API,支持插件的启用、禁用、重载和卸载操作
|
||||||
- 添加回复超时检查
|
- **权限控制系统**: 为插件管理增加完善的权限控制,确保系统安全性
|
||||||
- normal的插件允许llm激活
|
- **智能依赖管理**: 优化插件依赖管理和自动注册机制,减少配置复杂度
|
||||||
- 合并action激活器
|
|
||||||
- emoji统一可选随机激活或llm激活
|
#### ⚡ Normal和Focus模式统一化处理 - 重点升级
|
||||||
- 移除observation和processor,简化focus的代码逻辑
|
- **架构统一**: 彻底统一normal和focus聊天模式,消除模式间的差异和复杂性
|
||||||
|
- **智能模式切换**: 优化频率控制和模式切换逻辑,normal可以无缝切换到focus
|
||||||
|
- **统一LLM激活**: normal模式现在支持LLM激活插件,与focus模式功能对等
|
||||||
|
- **一致的关系构建**: normal采用与focus一致的关系构建机制,提升交互质量
|
||||||
|
- **统一退出机制**: 为focus提供更合理的退出方法,简化状态管理
|
||||||
|
|
||||||
|
#### 🎯 s4u prompt模式
|
||||||
|
- **s4u prompt模式**: 新增专门的s4u prompt构建方式,提供更好的交互效果
|
||||||
|
- **配置化启用**: 可在配置文件中选择启用s4u prompt模式,灵活控制
|
||||||
|
- **兼容性保持**: 与现有系统完全兼容,可随时切换启用或禁用
|
||||||
|
|
||||||
|
#### 🎤 语音消息支持
|
||||||
|
- **Voice消息处理**: 新增对voice类型消息的支持,麦麦现在可以识别和处理语音消息(需要模型配置)
|
||||||
|
|
||||||
|
#### 全新情绪系统
|
||||||
|
- **持续情绪**: 麦麦现在拥有持续的情绪状态,情绪会影响回复风格和行为
|
||||||
|
|
||||||
|
|
||||||
|
### 💻 更新预览
|
||||||
|
|
||||||
|
#### 关系系统优化
|
||||||
|
- **prompt优化**: 优化关系prompt和person_info信息展示
|
||||||
|
- **构建间隔**: 让关系构建间隔可配置,提升灵活性
|
||||||
|
- **关系配置**: 优化关系配置,采用和focus一致的关系构建
|
||||||
|
|
||||||
|
#### 表情包系统升级
|
||||||
|
- **识别增强**: 加强emoji的识别能力,优化emoji显示
|
||||||
|
- **匹配精准**: 更精准的表情包匹配算法
|
||||||
|
|
||||||
|
#### 完善mais4u系统(需要amaidesu支持)
|
||||||
|
- **直播互动**: 新增mais4u直播功能,支持实时互动和思考状态展示
|
||||||
|
- **动作控制**: 支持眨眼、微动作、注视等多种动作适配
|
||||||
|
|
||||||
|
#### 日志系统优化
|
||||||
|
- **显示优化**: 优化Logger前缀映射、颜色格式和计时信息显示
|
||||||
|
- **级别优化**: 优化日志级别和信息过滤,提升调试体验
|
||||||
|
- **日志查看器**: 升级logger_viewer,移除无用脚本
|
||||||
|
|
||||||
|
#### 配置系统改进
|
||||||
|
- **配置简化**: 简化配置文件,让配置更加精简易懂
|
||||||
|
- **prompt显示**: 可选打开prompt显示功能
|
||||||
|
- **配置更新**: 更好的配置文件更新机制和更新内容显示
|
||||||
|
|
||||||
|
#### 问题修复与优化
|
||||||
|
|
||||||
|
- 修复normal planner没有超时退出问题,添加回复超时检查
|
||||||
|
- 重构no_reply逻辑,不再使用小模型,采用激活度决定
|
||||||
- 修复图片与文字混合兴趣值为0的情况
|
- 修复图片与文字混合兴趣值为0的情况
|
||||||
|
- 适配无兴趣度消息处理
|
||||||
|
- 优化Docker镜像构建流程,合并AMD64和ARM64构建步骤
|
||||||
|
- 移除vtb插件和take_picture_plugin,功能已由其他系统接管,移除pfc遗留代码和其他过时功能
|
||||||
|
- 移除observation和processor等冗余组件,大幅简化focus代码逻辑
|
||||||
|
- 修复了LPMM的学习问题
|
||||||
|
|
||||||
|
|
||||||
## [0.8.1] - 2025-7-5
|
## [0.8.1] - 2025-7-5
|
||||||
|
|
||||||
|
|||||||
19
changes.md
19
changes.md
@@ -20,6 +20,9 @@
|
|||||||
- `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
|
- `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
|
||||||
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
|
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
|
||||||
5. 增加了`logging_api`,可以用`get_logger`来获取日志记录器。
|
5. 增加了`logging_api`,可以用`get_logger`来获取日志记录器。
|
||||||
|
6. 增加了插件和组件管理的API。
|
||||||
|
7. `BaseCommand`的`execute`方法现在返回一个三元组,包含是否执行成功、可选的回复消息和是否拦截消息。
|
||||||
|
- 这意味着你终于可以动态控制是否继续后续消息的处理了。
|
||||||
|
|
||||||
# 插件系统修改
|
# 插件系统修改
|
||||||
1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)**
|
1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)**
|
||||||
@@ -45,6 +48,17 @@
|
|||||||
10. 修正了`main.py`中的错误输出。
|
10. 修正了`main.py`中的错误输出。
|
||||||
11. 修正了`command`所编译的`Pattern`注册时的错误输出。
|
11. 修正了`command`所编译的`Pattern`注册时的错误输出。
|
||||||
12. `events_manager`有了task相关逻辑了。
|
12. `events_manager`有了task相关逻辑了。
|
||||||
|
13. 现在有了插件卸载和重载功能了,也就是热插拔。
|
||||||
|
14. 实现了组件的全局启用和禁用功能。
|
||||||
|
- 通过`enable_component`和`disable_component`方法来启用或禁用组件。
|
||||||
|
- 不过这个操作不会保存到配置文件~
|
||||||
|
15. 实现了组件的局部禁用,也就是针对某一个聊天禁用的功能。
|
||||||
|
- 通过`disable_specific_chat_action`,`enable_specific_chat_action`,`disable_specific_chat_command`,`enable_specific_chat_command`,`disable_specific_chat_event_handler`,`enable_specific_chat_event_handler`来操作
|
||||||
|
- 同样不保存到配置文件~
|
||||||
|
|
||||||
|
# 官方插件修改
|
||||||
|
1. `HelloWorld`插件现在有一个样例的`EventHandler`。
|
||||||
|
2. 内置插件增加了一个通过`Command`来管理插件的功能。具体是使用`/pm`命令唤起。
|
||||||
|
|
||||||
### TODO
|
### TODO
|
||||||
把这个看起来就很别扭的config获取方式改一下
|
把这个看起来就很别扭的config获取方式改一下
|
||||||
@@ -64,4 +78,7 @@ else:
|
|||||||
plugin_path = Path(plugin_file)
|
plugin_path = Path(plugin_file)
|
||||||
module_name = ".".join(plugin_path.parent.parts)
|
module_name = ".".join(plugin_path.parent.parts)
|
||||||
```
|
```
|
||||||
这两个区别很大的。
|
这两个区别很大的。
|
||||||
|
|
||||||
|
### 执笔BGM
|
||||||
|
塞壬唱片!
|
||||||
@@ -4,42 +4,183 @@
|
|||||||
|
|
||||||
Action是给麦麦在回复之外提供额外功能的智能组件,**由麦麦的决策系统自主选择是否使用**,具有随机性和拟人化的调用特点。Action不是直接响应用户命令,而是让麦麦根据聊天情境智能地选择合适的动作,使其行为更加自然和真实。
|
Action是给麦麦在回复之外提供额外功能的智能组件,**由麦麦的决策系统自主选择是否使用**,具有随机性和拟人化的调用特点。Action不是直接响应用户命令,而是让麦麦根据聊天情境智能地选择合适的动作,使其行为更加自然和真实。
|
||||||
|
|
||||||
### 🎯 Action的特点
|
### Action的特点
|
||||||
|
|
||||||
- 🧠 **智能激活**:麦麦根据多种条件智能判断是否使用
|
- 🧠 **智能激活**:麦麦根据多种条件智能判断是否使用
|
||||||
- 🎲 **随机性**:增加行为的不可预测性,更接近真人交流
|
- 🎲 **可随机性**:可以使用随机数激活,增加行为的不可预测性,更接近真人交流
|
||||||
- 🤖 **拟人化**:让麦麦的回应更自然、更有个性
|
- 🤖 **拟人化**:让麦麦的回应更自然、更有个性
|
||||||
- 🔄 **情境感知**:基于聊天上下文做出合适的反应
|
- 🔄 **情境感知**:基于聊天上下文做出合适的反应
|
||||||
|
|
||||||
## 🎯 两层决策机制
|
---
|
||||||
|
|
||||||
|
## 🎯 Action组件的基本结构
|
||||||
|
首先,所有的Action都应该继承`BaseAction`类。
|
||||||
|
|
||||||
|
其次,每个Action组件都应该实现以下基本信息:
|
||||||
|
```python
|
||||||
|
class ExampleAction(BaseAction):
|
||||||
|
action_name = "example_action" # 动作的唯一标识符
|
||||||
|
action_description = "这是一个示例动作" # 动作描述
|
||||||
|
activation_type = ActionActivationType.ALWAYS # 这里以 ALWAYS 为例
|
||||||
|
mode_enable = ChatMode.ALL # 这里以 ALL 为例
|
||||||
|
associated_types = ["text", "emoji", ...] # 关联类型
|
||||||
|
parallel_action = False # 是否允许与其他Action并行执行
|
||||||
|
action_parameters = {"param1": "参数1的说明", "param2": "参数2的说明", ...}
|
||||||
|
# Action使用场景描述 - 帮助LLM判断何时"选择"使用
|
||||||
|
action_require = ["使用场景描述1", "使用场景描述2", ...]
|
||||||
|
|
||||||
|
async def execute(self) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
执行Action的主要逻辑
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, str]: (是否成功, 执行结果描述)
|
||||||
|
"""
|
||||||
|
# ---- 执行动作的逻辑 ----
|
||||||
|
return True, "执行成功"
|
||||||
|
```
|
||||||
|
#### associated_types: 该Action会发送的消息类型,例如文本、表情等。
|
||||||
|
|
||||||
|
这部分由Adapter传递给处理器。
|
||||||
|
|
||||||
|
以 MaiBot-Napcat-Adapter 为例,可选项目如下:
|
||||||
|
| 类型 | 说明 | 格式 |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| text | 文本消息 | str |
|
||||||
|
| emoji | 表情消息 | str: 表情包的无头base64|
|
||||||
|
| image | 图片消息 | str: 图片的无头base64 |
|
||||||
|
| reply | 回复消息 | str: 回复的消息ID |
|
||||||
|
| voice | 语音消息 | str: wav格式语音的无头base64 |
|
||||||
|
| command | 命令消息 | 参见Adapter文档 |
|
||||||
|
| voiceurl | 语音URL消息 | str: wav格式语音的URL |
|
||||||
|
| music | 音乐消息 | str: 这首歌在网易云音乐的音乐id |
|
||||||
|
| videourl | 视频URL消息 | str: 视频的URL |
|
||||||
|
| file | 文件消息 | str: 文件的路径 |
|
||||||
|
|
||||||
|
**请知悉,对于不同的处理器,其支持的消息类型可能会有所不同。在开发时请注意。**
|
||||||
|
|
||||||
|
#### action_parameters: 该Action的参数说明。
|
||||||
|
这是一个字典,键为参数名,值为参数说明。这个字段可以帮助LLM理解如何使用这个Action,并由LLM返回对应的参数,最后传递到 Action 的 action_data 属性中。其格式与你定义的格式完全相同 **(除非LLM哈气了,返回了错误的内容)**。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🎯 Action 调用的决策机制
|
||||||
|
|
||||||
Action采用**两层决策机制**来优化性能和决策质量:
|
Action采用**两层决策机制**来优化性能和决策质量:
|
||||||
|
|
||||||
### 第一层:激活控制(Activation Control)
|
> 设计目的:在加载许多插件的时候降低LLM决策压力,避免让麦麦在过多的选项中纠结。
|
||||||
|
|
||||||
**激活决定麦麦是否"知道"这个Action的存在**,即这个Action是否进入决策候选池。**不被激活的Action麦麦永远不会选择**。
|
**第一层:激活控制(Activation Control)**
|
||||||
|
|
||||||
> 🎯 **设计目的**:在加载许多插件的时候降低LLM决策压力,避免让麦麦在过多的选项中纠结。
|
激活决定麦麦是否 **“知道”** 这个Action的存在,即这个Action是否进入决策候选池。不被激活的Action麦麦永远不会选择。
|
||||||
|
|
||||||
#### 激活类型说明
|
**第二层:使用决策(Usage Decision)**
|
||||||
|
|
||||||
| 激活类型 | 说明 | 使用场景 |
|
在Action被激活后,使用条件决定麦麦什么时候会 **“选择”** 使用这个Action。
|
||||||
| ------------- | ------------------------------------------- | ------------------------ |
|
|
||||||
| `NEVER` | 从不激活,Action对麦麦不可见 | 临时禁用某个Action |
|
|
||||||
| `ALWAYS` | 永远激活,Action总是在麦麦的候选池中 | 核心功能,如回复、不回复 |
|
|
||||||
| `LLM_JUDGE` | 通过LLM智能判断当前情境是否需要激活此Action | 需要智能判断的复杂场景 |
|
|
||||||
| `RANDOM` | 基于随机概率决定是否激活 | 增加行为随机性的功能 |
|
|
||||||
| `KEYWORD` | 当检测到特定关键词时激活 | 明确触发条件的功能 |
|
|
||||||
|
|
||||||
#### 聊天模式控制
|
### 决策参数详解 🔧
|
||||||
|
|
||||||
| 模式 | 说明 |
|
#### 第一层:ActivationType 激活类型说明
|
||||||
| ------------------- | ------------------------ |
|
|
||||||
| `ChatMode.FOCUS` | 仅在专注聊天模式下可激活 |
|
|
||||||
| `ChatMode.NORMAL` | 仅在普通聊天模式下可激活 |
|
|
||||||
| `ChatMode.ALL` | 所有模式下都可激活 |
|
|
||||||
|
|
||||||
### 第二层:使用决策(Usage Decision)
|
| 激活类型 | 说明 | 使用场景 |
|
||||||
|
| ----------- | ---------------------------------------- | ---------------------- |
|
||||||
|
| [`NEVER`](#never-激活) | 从不激活,Action对麦麦不可见 | 临时禁用某个Action |
|
||||||
|
| [`ALWAYS`](#always-激活) | 永远激活,Action总是在麦麦的候选池中 | 核心功能,如回复、不回复 |
|
||||||
|
| [`LLM_JUDGE`](#llm_judge-激活) | 通过LLM智能判断当前情境是否需要激活此Action | 需要智能判断的复杂场景 |
|
||||||
|
| `RANDOM` | 基于随机概率决定是否激活 | 增加行为随机性的功能 |
|
||||||
|
| `KEYWORD` | 当检测到特定关键词时激活 | 明确触发条件的功能 |
|
||||||
|
|
||||||
|
#### `NEVER` 激活
|
||||||
|
|
||||||
|
`ActionActivationType.NEVER` 会使得 Action 永远不会被激活
|
||||||
|
|
||||||
|
```python
|
||||||
|
class DisabledAction(BaseAction):
|
||||||
|
activation_type = ActionActivationType.NEVER # 永远不激活
|
||||||
|
|
||||||
|
async def execute(self) -> Tuple[bool, str]:
|
||||||
|
# 这个Action永远不会被执行
|
||||||
|
return False, "这个Action被禁用"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `ALWAYS` 激活
|
||||||
|
|
||||||
|
`ActionActivationType.ALWAYS` 会使得 Action 永远会被激活,即一直在 Action 候选池中
|
||||||
|
|
||||||
|
这种激活方式常用于核心功能,如回复或不回复。
|
||||||
|
|
||||||
|
```python
|
||||||
|
class AlwaysActivatedAction(BaseAction):
|
||||||
|
activation_type = ActionActivationType.ALWAYS # 永远激活
|
||||||
|
|
||||||
|
async def execute(self) -> Tuple[bool, str]:
|
||||||
|
# 执行核心功能
|
||||||
|
return True, "执行了核心功能"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `LLM_JUDGE` 激活
|
||||||
|
|
||||||
|
`ActionActivationType.LLM_JUDGE`会使得这个 Action 根据 LLM 的判断来决定是否加入候选池。
|
||||||
|
|
||||||
|
而 LLM 的判断是基于代码中预设的`llm_judge_prompt`和自动提供的聊天上下文进行的。
|
||||||
|
|
||||||
|
因此使用此种方法需要实现`llm_judge_prompt`属性。
|
||||||
|
|
||||||
|
```python
|
||||||
|
class LLMJudgedAction(BaseAction):
|
||||||
|
activation_type = ActionActivationType.LLM_JUDGE # 通过LLM判断激活
|
||||||
|
# LLM判断提示词
|
||||||
|
llm_judge_prompt = (
|
||||||
|
"判定是否需要使用这个动作的条件:\n"
|
||||||
|
"1. 用户希望调用XXX这个动作\n"
|
||||||
|
"...\n"
|
||||||
|
"请回答\"是\"或\"否\"。\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute(self) -> Tuple[bool, str]:
|
||||||
|
# 根据LLM判断是否执行
|
||||||
|
return True, "执行了LLM判断功能"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `RANDOM` 激活
|
||||||
|
|
||||||
|
`ActionActivationType.RANDOM`会使得这个 Action 根据随机概率决定是否加入候选池。
|
||||||
|
|
||||||
|
概率则由代码中的`random_activation_probability`控制。在内部实现中我们使用了`random.random()`来生成一个0到1之间的随机数,并与这个概率进行比较。
|
||||||
|
|
||||||
|
因此使用这个方法需要实现`random_activation_probability`属性。
|
||||||
|
|
||||||
|
```python
|
||||||
|
class SurpriseAction(BaseAction):
|
||||||
|
activation_type = ActionActivationType.RANDOM # 基于随机概率激活
|
||||||
|
# 随机激活概率
|
||||||
|
random_activation_probability = 0.1 # 10%概率激活
|
||||||
|
|
||||||
|
async def execute(self) -> Tuple[bool, str]:
|
||||||
|
# 执行惊喜动作
|
||||||
|
return True, "发送了惊喜内容"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `KEYWORD` 激活
|
||||||
|
|
||||||
|
`ActionActivationType.KEYWORD`会使得这个 Action 在检测到特定关键词时激活。
|
||||||
|
|
||||||
|
关键词由代码中的`activation_keywords`定义,而`keyword_case_sensitive`则控制关键词匹配时是否区分大小写。在内部实现中,我们使用了`in`操作符来检查消息内容是否包含这些关键词。
|
||||||
|
|
||||||
|
因此,使用此种方法需要实现`activation_keywords`和`keyword_case_sensitive`属性。
|
||||||
|
|
||||||
|
```python
|
||||||
|
class GreetingAction(BaseAction):
|
||||||
|
activation_type = ActionActivationType.KEYWORD # 关键词激活
|
||||||
|
activation_keywords = ["你好", "hello", "hi", "嗨"] # 关键词配置
|
||||||
|
keyword_case_sensitive = False # 不区分大小写
|
||||||
|
|
||||||
|
async def execute(self) -> Tuple[bool, str]:
|
||||||
|
# 执行问候逻辑
|
||||||
|
return True, "发送了问候"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 第二层:使用决策
|
||||||
|
|
||||||
**在Action被激活后,使用条件决定麦麦什么时候会"选择"使用这个Action**。
|
**在Action被激活后,使用条件决定麦麦什么时候会"选择"使用这个Action**。
|
||||||
|
|
||||||
@@ -49,17 +190,16 @@ Action采用**两层决策机制**来优化性能和决策质量:
|
|||||||
- `action_parameters`:所需参数,影响Action的可执行性
|
- `action_parameters`:所需参数,影响Action的可执行性
|
||||||
- 当前聊天上下文和麦麦的决策逻辑
|
- 当前聊天上下文和麦麦的决策逻辑
|
||||||
|
|
||||||
### 🎬 决策流程示例
|
---
|
||||||
|
|
||||||
假设有一个"发送表情"Action:
|
### 决策流程示例
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class EmojiAction(BaseAction):
|
class EmojiAction(BaseAction):
|
||||||
# 第一层:激活控制
|
# 第一层:激活控制
|
||||||
focus_activation_type = ActionActivationType.RANDOM # 专注模式下随机激活
|
activation_type = ActionActivationType.RANDOM # 随机激活
|
||||||
normal_activation_type = ActionActivationType.KEYWORD # 普通模式下关键词激活
|
random_activation_probability = 0.1 # 10%概率激活
|
||||||
activation_keywords = ["表情", "emoji", "😊"]
|
|
||||||
|
|
||||||
# 第二层:使用决策
|
# 第二层:使用决策
|
||||||
action_require = [
|
action_require = [
|
||||||
"表达情绪时可以选择使用",
|
"表达情绪时可以选择使用",
|
||||||
@@ -72,311 +212,85 @@ class EmojiAction(BaseAction):
|
|||||||
|
|
||||||
1. **第一层激活判断**:
|
1. **第一层激活判断**:
|
||||||
|
|
||||||
- 普通模式:只有当用户消息包含"表情"、"emoji"或"😊"时,麦麦才"知道"可以使用这个Action
|
- 使用随机数进行决策,当`random.random() < self.random_activation_probability`时,麦麦才"知道"可以使用这个Action
|
||||||
- 专注模式:随机激活,有概率让麦麦"看到"这个Action
|
|
||||||
2. **第二层使用决策**:
|
2. **第二层使用决策**:
|
||||||
|
|
||||||
- 即使Action被激活,麦麦还会根据 `action_require`中的条件判断是否真正选择使用
|
- 即使Action被激活,麦麦还会根据 `action_require` 中的条件判断是否真正选择使用
|
||||||
- 例如:如果刚刚已经发过表情,根据"不要连续发送多个表情"的要求,麦麦可能不会选择这个Action
|
- 例如:如果刚刚已经发过表情,根据"不要连续发送多个表情"的要求,麦麦可能不会选择这个Action
|
||||||
|
|
||||||
## 📋 Action必须项清单
|
---
|
||||||
|
|
||||||
每个Action类都**必须**包含以下属性:
|
|
||||||
|
|
||||||
### 1. 激活控制必须项
|
|
||||||
|
|
||||||
|
## Action 内置属性说明
|
||||||
```python
|
```python
|
||||||
# 专注模式下的激活类型
|
class BaseAction:
|
||||||
focus_activation_type = ActionActivationType.LLM_JUDGE
|
|
||||||
|
|
||||||
# 普通模式下的激活类型
|
|
||||||
normal_activation_type = ActionActivationType.KEYWORD
|
|
||||||
|
|
||||||
# 启用的聊天模式
|
|
||||||
mode_enable = ChatMode.ALL
|
|
||||||
|
|
||||||
# 是否允许与其他Action并行执行
|
|
||||||
parallel_action = False
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 基本信息必须项
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Action的唯一标识名称
|
|
||||||
action_name = "my_action"
|
|
||||||
|
|
||||||
# Action的功能描述
|
|
||||||
action_description = "描述这个Action的具体功能和用途"
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 功能定义必须项
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Action参数定义 - 告诉LLM执行时需要什么参数
|
|
||||||
action_parameters = {
|
|
||||||
"param1": "参数1的说明",
|
|
||||||
"param2": "参数2的说明"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Action使用场景描述 - 帮助LLM判断何时"选择"使用
|
|
||||||
action_require = [
|
|
||||||
"使用场景描述1",
|
|
||||||
"使用场景描述2"
|
|
||||||
]
|
|
||||||
|
|
||||||
# 关联的消息类型 - 说明Action能处理什么类型的内容
|
|
||||||
associated_types = ["text", "emoji", "image"]
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. 执行方法必须项
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
"""
|
|
||||||
执行Action的主要逻辑
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, str]: (是否成功, 执行结果描述)
|
|
||||||
"""
|
|
||||||
# 执行动作的代码
|
|
||||||
success = True
|
|
||||||
message = "动作执行成功"
|
|
||||||
|
|
||||||
return success, message
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🔧 激活类型详解
|
|
||||||
|
|
||||||
### KEYWORD激活
|
|
||||||
|
|
||||||
当检测到特定关键词时激活Action:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class GreetingAction(BaseAction):
|
|
||||||
focus_activation_type = ActionActivationType.KEYWORD
|
|
||||||
normal_activation_type = ActionActivationType.KEYWORD
|
|
||||||
|
|
||||||
# 关键词配置
|
|
||||||
activation_keywords = ["你好", "hello", "hi", "嗨"]
|
|
||||||
keyword_case_sensitive = False # 不区分大小写
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
# 执行问候逻辑
|
|
||||||
return True, "发送了问候"
|
|
||||||
```
|
|
||||||
|
|
||||||
### LLM_JUDGE激活
|
|
||||||
|
|
||||||
通过LLM智能判断是否激活:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class HelpAction(BaseAction):
|
|
||||||
focus_activation_type = ActionActivationType.LLM_JUDGE
|
|
||||||
normal_activation_type = ActionActivationType.LLM_JUDGE
|
|
||||||
|
|
||||||
# LLM判断提示词
|
|
||||||
llm_judge_prompt = """
|
|
||||||
判定是否需要使用帮助动作的条件:
|
|
||||||
1. 用户表达了困惑或需要帮助
|
|
||||||
2. 用户提出了问题但没有得到满意答案
|
|
||||||
3. 对话中出现了技术术语或复杂概念
|
|
||||||
|
|
||||||
请回答"是"或"否"。
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
# 执行帮助逻辑
|
|
||||||
return True, "提供了帮助"
|
|
||||||
```
|
|
||||||
|
|
||||||
### RANDOM激活
|
|
||||||
|
|
||||||
基于随机概率激活:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class SurpriseAction(BaseAction):
|
|
||||||
focus_activation_type = ActionActivationType.RANDOM
|
|
||||||
normal_activation_type = ActionActivationType.RANDOM
|
|
||||||
|
|
||||||
# 随机激活概率
|
|
||||||
random_activation_probability = 0.1 # 10%概率激活
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
# 执行惊喜动作
|
|
||||||
return True, "发送了惊喜内容"
|
|
||||||
```
|
|
||||||
|
|
||||||
### ALWAYS激活
|
|
||||||
|
|
||||||
永远激活,常用于核心功能:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class CoreAction(BaseAction):
|
|
||||||
focus_activation_type = ActionActivationType.ALWAYS
|
|
||||||
normal_activation_type = ActionActivationType.ALWAYS
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
# 执行核心功能
|
|
||||||
return True, "执行了核心功能"
|
|
||||||
```
|
|
||||||
|
|
||||||
### NEVER激活
|
|
||||||
|
|
||||||
从不激活,用于临时禁用:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class DisabledAction(BaseAction):
|
|
||||||
focus_activation_type = ActionActivationType.NEVER
|
|
||||||
normal_activation_type = ActionActivationType.NEVER
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
# 这个方法不会被调用
|
|
||||||
return False, "已禁用"
|
|
||||||
```
|
|
||||||
|
|
||||||
## 📚 BaseAction内置属性和方法
|
|
||||||
|
|
||||||
### 内置属性
|
|
||||||
|
|
||||||
```python
|
|
||||||
class MyAction(BaseAction):
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 消息相关属性
|
# 消息相关属性
|
||||||
self.message # 当前消息对象
|
self.log_prefix: str # 日志前缀
|
||||||
self.chat_stream # 聊天流对象
|
self.group_id: str # 群组ID
|
||||||
self.user_id # 用户ID
|
self.group_name: str # 群组名称
|
||||||
self.user_nickname # 用户昵称
|
self.user_id: str # 用户ID
|
||||||
self.platform # 平台类型 (qq, telegram等)
|
self.user_nickname: str # 用户昵称
|
||||||
self.chat_id # 聊天ID
|
self.platform: str # 平台类型 (qq, telegram等)
|
||||||
self.is_group # 是否群聊
|
self.chat_id: str # 聊天ID
|
||||||
|
self.chat_stream: ChatStream # 聊天流对象
|
||||||
# Action相关属性
|
self.is_group: bool # 是否群聊
|
||||||
self.action_data # Action执行时的数据
|
|
||||||
self.thinking_id # 思考ID
|
|
||||||
self.matched_groups # 匹配到的组(如果有正则匹配)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 内置方法
|
# 消息体
|
||||||
|
self.action_message: dict # 消息数据
|
||||||
|
|
||||||
|
# Action相关属性
|
||||||
|
self.action_data: dict # Action执行时的数据
|
||||||
|
self.thinking_id: str # 思考ID
|
||||||
|
```
|
||||||
|
action_message为一个字典,包含的键值对如下(省略了不必要的键值对)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class MyAction(BaseAction):
|
{
|
||||||
|
"message_id": "1234567890", # 消息id,str
|
||||||
|
"time": 1627545600.0, # 时间戳,float
|
||||||
|
"chat_id": "abcdef123456", # 聊天ID,str
|
||||||
|
"reply_to": None, # 回复消息id,str或None
|
||||||
|
"interest_value": 0.85, # 兴趣值,float
|
||||||
|
"is_mentioned": True, # 是否被提及,bool
|
||||||
|
"chat_info_last_active_time": 1627548600.0, # 最后活跃时间,float
|
||||||
|
"processed_plain_text": None, # 处理后的文本,str或None
|
||||||
|
"additional_config": None, # Adapter传来的additional_config,dict或None
|
||||||
|
"is_emoji": False, # 是否为表情,bool
|
||||||
|
"is_picid": False, # 是否为图片ID,bool
|
||||||
|
"is_command": False # 是否为命令,bool
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
部分值的格式请自行查询数据库。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Action 内置方法说明
|
||||||
|
```python
|
||||||
|
class BaseAction:
|
||||||
# 配置相关
|
# 配置相关
|
||||||
def get_config(self, key: str, default=None):
|
def get_config(self, key: str, default=None):
|
||||||
"""获取配置值"""
|
"""获取插件配置值,使用嵌套键访问"""
|
||||||
pass
|
|
||||||
|
|
||||||
# 消息发送相关
|
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
|
||||||
async def send_text(self, text: str):
|
"""等待新消息或超时"""
|
||||||
|
|
||||||
|
async def send_text(self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False) -> bool:
|
||||||
"""发送文本消息"""
|
"""发送文本消息"""
|
||||||
pass
|
|
||||||
|
async def send_emoji(self, emoji_base64: str) -> bool:
|
||||||
async def send_emoji(self, emoji_base64: str):
|
|
||||||
"""发送表情包"""
|
"""发送表情包"""
|
||||||
pass
|
|
||||||
|
async def send_image(self, image_base64: str) -> bool:
|
||||||
async def send_image(self, image_base64: str):
|
|
||||||
"""发送图片"""
|
"""发送图片"""
|
||||||
pass
|
|
||||||
|
async def send_custom(self, message_type: str, content: str, typing: bool = False, reply_to: str = "") -> bool:
|
||||||
# 动作记录相关
|
"""发送自定义类型消息"""
|
||||||
async def store_action_info(self, **kwargs):
|
|
||||||
"""记录动作信息"""
|
async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None:
|
||||||
pass
|
"""存储动作信息到数据库"""
|
||||||
|
|
||||||
|
async def send_command(self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True) -> bool:
|
||||||
|
"""发送命令消息"""
|
||||||
```
|
```
|
||||||
|
具体参数与用法参见`BaseAction`基类的定义。
|
||||||
## 🎯 完整Action示例
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
class ExampleAction(BaseAction):
|
|
||||||
"""示例Action - 展示完整的Action结构"""
|
|
||||||
|
|
||||||
# === 激活控制 ===
|
|
||||||
focus_activation_type = ActionActivationType.LLM_JUDGE
|
|
||||||
normal_activation_type = ActionActivationType.KEYWORD
|
|
||||||
mode_enable = ChatMode.ALL
|
|
||||||
parallel_action = False
|
|
||||||
|
|
||||||
# 关键词激活配置
|
|
||||||
activation_keywords = ["示例", "测试", "example"]
|
|
||||||
keyword_case_sensitive = False
|
|
||||||
|
|
||||||
# LLM判断提示词
|
|
||||||
llm_judge_prompt = "当用户需要示例或测试功能时激活"
|
|
||||||
|
|
||||||
# 随机激活概率(如果使用RANDOM类型)
|
|
||||||
random_activation_probability = 0.2
|
|
||||||
|
|
||||||
# === 基本信息 ===
|
|
||||||
action_name = "example_action"
|
|
||||||
action_description = "这是一个示例Action,用于演示Action的完整结构"
|
|
||||||
|
|
||||||
# === 功能定义 ===
|
|
||||||
action_parameters = {
|
|
||||||
"content": "要处理的内容",
|
|
||||||
"type": "处理类型",
|
|
||||||
"options": "可选配置"
|
|
||||||
}
|
|
||||||
|
|
||||||
action_require = [
|
|
||||||
"用户需要示例功能时使用",
|
|
||||||
"适合用于测试和演示",
|
|
||||||
"不要在正式对话中频繁使用"
|
|
||||||
]
|
|
||||||
|
|
||||||
associated_types = ["text", "emoji"]
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
"""执行示例Action"""
|
|
||||||
try:
|
|
||||||
# 获取Action参数
|
|
||||||
content = self.action_data.get("content", "默认内容")
|
|
||||||
action_type = self.action_data.get("type", "default")
|
|
||||||
|
|
||||||
# 获取配置
|
|
||||||
enable_feature = self.get_config("example.enable_advanced", False)
|
|
||||||
max_length = self.get_config("example.max_length", 100)
|
|
||||||
|
|
||||||
# 执行具体逻辑
|
|
||||||
if action_type == "greeting":
|
|
||||||
await self.send_text(f"你好!这是示例内容:{content}")
|
|
||||||
elif action_type == "info":
|
|
||||||
await self.send_text(f"信息:{content[:max_length]}")
|
|
||||||
else:
|
|
||||||
await self.send_text("执行了示例Action")
|
|
||||||
|
|
||||||
# 记录动作信息
|
|
||||||
await self.store_action_info(
|
|
||||||
action_build_into_prompt=True,
|
|
||||||
action_prompt_display=f"执行了示例动作:{action_type}",
|
|
||||||
action_done=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return True, f"示例Action执行成功,类型:{action_type}"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return False, f"执行失败:{str(e)}"
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🎯 最佳实践
|
|
||||||
|
|
||||||
### 1. Action设计原则
|
|
||||||
|
|
||||||
- **单一职责**:每个Action只负责一个明确的功能
|
|
||||||
- **智能激活**:合理选择激活类型,避免过度激活
|
|
||||||
- **清晰描述**:提供准确的`action_require`帮助LLM决策
|
|
||||||
- **错误处理**:妥善处理执行过程中的异常情况
|
|
||||||
|
|
||||||
### 2. 性能优化
|
|
||||||
|
|
||||||
- **激活控制**:使用合适的激活类型减少不必要的LLM调用
|
|
||||||
- **并行执行**:谨慎设置`parallel_action`,避免冲突
|
|
||||||
- **资源管理**:及时释放占用的资源
|
|
||||||
|
|
||||||
### 3. 调试技巧
|
|
||||||
|
|
||||||
- **日志记录**:在关键位置添加日志
|
|
||||||
- **参数验证**:检查`action_data`的有效性
|
|
||||||
- **配置测试**:测试不同配置下的行为
|
|
||||||
@@ -8,6 +8,25 @@
|
|||||||
from src.plugin_system.apis import emoji_api
|
from src.plugin_system.apis import emoji_api
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 🆕 **二步走识别优化**
|
||||||
|
|
||||||
|
从最新版本开始,表情包识别系统采用了**二步走识别 + 智能缓存**的优化方案:
|
||||||
|
|
||||||
|
### **收到表情包时的识别流程**
|
||||||
|
1. **第一步**:VLM视觉分析 - 生成详细描述
|
||||||
|
2. **第二步**:LLM情感分析 - 基于详细描述提取核心情感标签
|
||||||
|
3. **缓存机制**:将情感标签缓存到数据库,详细描述保存到Images表
|
||||||
|
|
||||||
|
### **注册表情包时的优化**
|
||||||
|
- **智能复用**:优先从Images表获取已有的详细描述
|
||||||
|
- **避免重复**:如果表情包之前被收到过,跳过VLM调用
|
||||||
|
- **性能提升**:减少不必要的AI调用,降低延时和成本
|
||||||
|
|
||||||
|
### **缓存策略**
|
||||||
|
- **ImageDescriptions表**:缓存最终的情感标签(用于快速显示)
|
||||||
|
- **Images表**:保存详细描述(用于注册时复用)
|
||||||
|
- **双重检查**:防止并发情况下的重复生成
|
||||||
|
|
||||||
## 主要功能
|
## 主要功能
|
||||||
|
|
||||||
### 1. 表情包获取
|
### 1. 表情包获取
|
||||||
|
|||||||
@@ -77,9 +77,8 @@ class TimeCommand(BaseCommand):
|
|||||||
command_pattern = r"^/time$" # 精确匹配 "/time" 命令
|
command_pattern = r"^/time$" # 精确匹配 "/time" 命令
|
||||||
command_help = "查询当前时间"
|
command_help = "查询当前时间"
|
||||||
command_examples = ["/time"]
|
command_examples = ["/time"]
|
||||||
intercept_message = True # 拦截消息,不让其他组件处理
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
async def execute(self) -> Tuple[bool, str, bool]:
|
||||||
"""执行时间查询"""
|
"""执行时间查询"""
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
@@ -92,7 +91,7 @@ class TimeCommand(BaseCommand):
|
|||||||
message = f"⏰ 当前时间:{time_str}"
|
message = f"⏰ 当前时间:{time_str}"
|
||||||
await self.send_text(message)
|
await self.send_text(message)
|
||||||
|
|
||||||
return True, f"显示了当前时间: {time_str}"
|
return True, f"显示了当前时间: {time_str}", True
|
||||||
|
|
||||||
|
|
||||||
class PrintMessage(BaseEventHandler):
|
class PrintMessage(BaseEventHandler):
|
||||||
@@ -118,17 +117,17 @@ class HelloWorldPlugin(BasePlugin):
|
|||||||
"""Hello World插件 - 你的第一个MaiCore插件"""
|
"""Hello World插件 - 你的第一个MaiCore插件"""
|
||||||
|
|
||||||
# 插件基本信息
|
# 插件基本信息
|
||||||
plugin_name = "hello_world_plugin" # 内部标识符
|
plugin_name: str = "hello_world_plugin" # 内部标识符
|
||||||
enable_plugin = True
|
enable_plugin: bool = True
|
||||||
dependencies = [] # 插件依赖列表
|
dependencies: List[str] = [] # 插件依赖列表
|
||||||
python_dependencies = [] # Python包依赖列表
|
python_dependencies: List[str] = [] # Python包依赖列表
|
||||||
config_file_name = "config.toml" # 配置文件名
|
config_file_name: str = "config.toml" # 配置文件名
|
||||||
|
|
||||||
# 配置节描述
|
# 配置节描述
|
||||||
config_section_descriptions = {"plugin": "插件基本信息", "greeting": "问候功能配置", "time": "时间查询配置"}
|
config_section_descriptions = {"plugin": "插件基本信息", "greeting": "问候功能配置", "time": "时间查询配置"}
|
||||||
|
|
||||||
# 配置Schema定义
|
# 配置Schema定义
|
||||||
config_schema = {
|
config_schema: dict = {
|
||||||
"plugin": {
|
"plugin": {
|
||||||
"name": ConfigField(type=str, default="hello_world_plugin", description="插件名称"),
|
"name": ConfigField(type=str, default="hello_world_plugin", description="插件名称"),
|
||||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||||
|
|||||||
@@ -1,50 +0,0 @@
|
|||||||
{
|
|
||||||
"manifest_version": 1,
|
|
||||||
"name": "AI拍照插件 (Take Picture Plugin)",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": "基于AI图像生成的拍照插件,可以生成逼真的自拍照片,支持照片存储和展示功能。",
|
|
||||||
"author": {
|
|
||||||
"name": "SengokuCola",
|
|
||||||
"url": "https://github.com/SengokuCola"
|
|
||||||
},
|
|
||||||
"license": "GPL-v3.0-or-later",
|
|
||||||
|
|
||||||
"host_application": {
|
|
||||||
"min_version": "0.9.0"
|
|
||||||
},
|
|
||||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
|
||||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
|
||||||
"keywords": ["camera", "photo", "selfie", "ai", "image", "generation"],
|
|
||||||
"categories": ["AI Tools", "Image Processing", "Entertainment"],
|
|
||||||
|
|
||||||
"default_locale": "zh-CN",
|
|
||||||
"locales_path": "_locales",
|
|
||||||
|
|
||||||
"plugin_info": {
|
|
||||||
"is_built_in": false,
|
|
||||||
"plugin_type": "image_generator",
|
|
||||||
"api_dependencies": ["volcengine"],
|
|
||||||
"components": [
|
|
||||||
{
|
|
||||||
"type": "action",
|
|
||||||
"name": "take_picture",
|
|
||||||
"description": "生成一张用手机拍摄的照片,比如自拍或者近照",
|
|
||||||
"activation_modes": ["keyword"],
|
|
||||||
"keywords": ["拍张照", "自拍", "发张照片", "看看你", "你的照片"]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "command",
|
|
||||||
"name": "show_recent_pictures",
|
|
||||||
"description": "展示最近生成的5张照片",
|
|
||||||
"pattern": "/show_pics"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"features": [
|
|
||||||
"AI驱动的自拍照生成",
|
|
||||||
"个性化照片风格",
|
|
||||||
"照片历史记录",
|
|
||||||
"缓存机制优化",
|
|
||||||
"火山引擎API集成"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,517 +0,0 @@
|
|||||||
"""
|
|
||||||
拍照插件
|
|
||||||
|
|
||||||
功能特性:
|
|
||||||
- Action: 生成一张自拍照,prompt由人设和模板生成
|
|
||||||
- Command: 展示最近生成的照片
|
|
||||||
|
|
||||||
#此插件并不完善
|
|
||||||
#此插件并不完善
|
|
||||||
|
|
||||||
#此插件并不完善
|
|
||||||
|
|
||||||
#此插件并不完善
|
|
||||||
|
|
||||||
#此插件并不完善
|
|
||||||
|
|
||||||
#此插件并不完善
|
|
||||||
|
|
||||||
#此插件并不完善
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
包含组件:
|
|
||||||
- 拍照Action - 生成自拍照
|
|
||||||
- 展示照片Command - 展示最近生成的照片
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Tuple, Type, Optional
|
|
||||||
import random
|
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import asyncio
|
|
||||||
import urllib.request
|
|
||||||
import urllib.error
|
|
||||||
import base64
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
from src.plugin_system.base.base_plugin import BasePlugin
|
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
|
||||||
from src.plugin_system.base.base_command import BaseCommand
|
|
||||||
from src.plugin_system.base.component_types import ComponentInfo, ActionActivationType, ChatMode
|
|
||||||
from src.plugin_system.base.config_types import ConfigField
|
|
||||||
from src.plugin_system import register_plugin
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("take_picture_plugin")
|
|
||||||
|
|
||||||
# 定义数据目录常量
|
|
||||||
DATA_DIR = os.path.join("data", "take_picture_data")
|
|
||||||
# 确保数据目录存在
|
|
||||||
os.makedirs(DATA_DIR, exist_ok=True)
|
|
||||||
# 创建全局锁
|
|
||||||
file_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
class TakePictureAction(BaseAction):
|
|
||||||
"""生成一张自拍照"""
|
|
||||||
|
|
||||||
focus_activation_type = ActionActivationType.KEYWORD
|
|
||||||
normal_activation_type = ActionActivationType.KEYWORD
|
|
||||||
mode_enable = ChatMode.ALL
|
|
||||||
parallel_action = False
|
|
||||||
|
|
||||||
action_name = "take_picture"
|
|
||||||
action_description = "生成一张用手机拍摄,比如自拍或者近照"
|
|
||||||
activation_keywords = ["拍张照", "自拍", "发张照片", "看看你", "你的照片"]
|
|
||||||
keyword_case_sensitive = False
|
|
||||||
|
|
||||||
action_parameters = {}
|
|
||||||
|
|
||||||
action_require = ["当用户想看你的照片时使用", "当用户让你发自拍时使用当想随手拍眼前的场景时使用"]
|
|
||||||
|
|
||||||
associated_types = ["text", "image"]
|
|
||||||
|
|
||||||
# 内置的Prompt模板,如果配置文件中没有定义,将使用这些模板
|
|
||||||
DEFAULT_PROMPT_TEMPLATES = [
|
|
||||||
"极其频繁无奇的iPhone自拍照,没有明确的主体或构图感,就是随手一拍的快照照片略带运动模糊,阳光或室内打光不均匀导致的轻微曝光过度,整体呈现出一种刻意的平庸感,就像是从口袋里拿手机时不小心拍到的一张自拍。主角是{name},{personality}"
|
|
||||||
]
|
|
||||||
|
|
||||||
# 简单的请求缓存,避免短时间内重复请求
|
|
||||||
_request_cache = {}
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
|
||||||
logger.info(f"{self.log_prefix} 执行拍照动作")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 配置验证
|
|
||||||
http_base_url = self.api.get_config("api.base_url")
|
|
||||||
http_api_key = self.api.get_config("api.volcano_generate_api_key")
|
|
||||||
|
|
||||||
if not (http_base_url and http_api_key):
|
|
||||||
error_msg = "抱歉,照片生成功能所需的API配置(如API地址或密钥)不完整,无法提供服务。"
|
|
||||||
await self.send_text(error_msg)
|
|
||||||
logger.error(f"{self.log_prefix} HTTP调用配置缺失: base_url 或 volcano_generate_api_key.")
|
|
||||||
return False, "API配置不完整"
|
|
||||||
|
|
||||||
# API密钥验证
|
|
||||||
if http_api_key == "YOUR_DOUBAO_API_KEY_HERE":
|
|
||||||
error_msg = "照片生成功能尚未配置,请设置正确的API密钥。"
|
|
||||||
await self.send_text(error_msg)
|
|
||||||
logger.error(f"{self.log_prefix} API密钥未配置")
|
|
||||||
return False, "API密钥未配置"
|
|
||||||
|
|
||||||
# 获取全局配置信息
|
|
||||||
bot_nickname = self.api.get_global_config("bot.nickname", "麦麦")
|
|
||||||
bot_personality = self.api.get_global_config("personality.personality_core", "")
|
|
||||||
|
|
||||||
personality_side = self.api.get_global_config("personality.personality_side", [])
|
|
||||||
if personality_side:
|
|
||||||
bot_personality += random.choice(personality_side)
|
|
||||||
|
|
||||||
# 准备模板变量
|
|
||||||
template_vars = {"name": bot_nickname, "personality": bot_personality}
|
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 使用的全局配置: name={bot_nickname}, personality={bot_personality}")
|
|
||||||
|
|
||||||
# 尝试从配置文件获取模板,如果没有则使用默认模板
|
|
||||||
templates = self.api.get_config("picture.prompt_templates", self.DEFAULT_PROMPT_TEMPLATES)
|
|
||||||
if not templates:
|
|
||||||
logger.warning(f"{self.log_prefix} 未找到有效的提示词模板,使用默认模板")
|
|
||||||
templates = self.DEFAULT_PROMPT_TEMPLATES
|
|
||||||
|
|
||||||
prompt_template = random.choice(templates)
|
|
||||||
|
|
||||||
# 填充模板
|
|
||||||
final_prompt = prompt_template.format(**template_vars)
|
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 生成的最终Prompt: {final_prompt}")
|
|
||||||
|
|
||||||
# 从配置获取参数
|
|
||||||
model = self.api.get_config("picture.default_model", "doubao-seedream-3-0-t2i-250415")
|
|
||||||
size = self.api.get_config("picture.default_size", "1024x1024")
|
|
||||||
watermark = self.api.get_config("picture.default_watermark", True)
|
|
||||||
guidance_scale = self.api.get_config("picture.default_guidance_scale", 2.5)
|
|
||||||
seed = self.api.get_config("picture.default_seed", 42)
|
|
||||||
|
|
||||||
# 检查缓存
|
|
||||||
enable_cache = self.api.get_config("storage.enable_cache", True)
|
|
||||||
if enable_cache:
|
|
||||||
cache_key = self._get_cache_key(final_prompt, model, size)
|
|
||||||
if cache_key in self._request_cache:
|
|
||||||
cached_result = self._request_cache[cache_key]
|
|
||||||
logger.info(f"{self.log_prefix} 使用缓存的图片结果")
|
|
||||||
await self.send_text("我之前拍过类似的照片,用之前的结果~")
|
|
||||||
|
|
||||||
# 直接发送缓存的结果
|
|
||||||
send_success = await self._send_image(cached_result)
|
|
||||||
if send_success:
|
|
||||||
await self.send_text("这是我的照片,好看吗?")
|
|
||||||
return True, "照片已发送(缓存)"
|
|
||||||
else:
|
|
||||||
# 缓存失败,清除这个缓存项并继续正常流程
|
|
||||||
del self._request_cache[cache_key]
|
|
||||||
|
|
||||||
await self.send_text("正在为你拍照,请稍候...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
seed = random.randint(1, 1000000)
|
|
||||||
success, result = await asyncio.to_thread(
|
|
||||||
self._make_http_image_request,
|
|
||||||
prompt=final_prompt,
|
|
||||||
model=model,
|
|
||||||
size=size,
|
|
||||||
seed=seed,
|
|
||||||
guidance_scale=guidance_scale,
|
|
||||||
watermark=watermark,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} (HTTP) 异步请求执行失败: {e!r}", exc_info=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
success = False
|
|
||||||
result = f"照片生成服务遇到意外问题: {str(e)[:100]}"
|
|
||||||
|
|
||||||
if success:
|
|
||||||
image_url = result
|
|
||||||
logger.info(f"{self.log_prefix} 图片URL获取成功: {image_url[:70]}... 下载并编码.")
|
|
||||||
|
|
||||||
try:
|
|
||||||
encode_success, encode_result = await asyncio.to_thread(self._download_and_encode_base64, image_url)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} (B64) 异步下载/编码失败: {e!r}", exc_info=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
encode_success = False
|
|
||||||
encode_result = f"图片下载或编码时发生内部错误: {str(e)[:100]}"
|
|
||||||
|
|
||||||
if encode_success:
|
|
||||||
base64_image_string = encode_result
|
|
||||||
# 更新缓存
|
|
||||||
if enable_cache:
|
|
||||||
self._update_cache(final_prompt, model, size, base64_image_string)
|
|
||||||
|
|
||||||
# 发送图片
|
|
||||||
send_success = await self._send_image(base64_image_string)
|
|
||||||
if send_success:
|
|
||||||
# 存储到文件
|
|
||||||
await self._store_picture_info(final_prompt, image_url)
|
|
||||||
logger.info(f"{self.log_prefix} 成功生成并存储照片: {image_url}")
|
|
||||||
await self.send_text("当当当当~这是我刚拍的照片,好看吗?")
|
|
||||||
return True, f"成功生成照片: {image_url}"
|
|
||||||
else:
|
|
||||||
await self.send_text("照片生成了,但发送失败了,可能是格式问题...")
|
|
||||||
return False, "照片发送失败"
|
|
||||||
else:
|
|
||||||
await self.send_text(f"照片下载失败: {encode_result}")
|
|
||||||
return False, encode_result
|
|
||||||
else:
|
|
||||||
await self.send_text(f"哎呀,拍照失败了: {result}")
|
|
||||||
return False, result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 执行拍照动作失败: {e}", exc_info=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
await self.send_text("呜呜,拍照的时候出了一点小问题...")
|
|
||||||
return False, str(e)
|
|
||||||
|
|
||||||
async def _store_picture_info(self, prompt: str, image_url: str):
|
|
||||||
"""将照片信息存入日志文件"""
|
|
||||||
log_file = self.api.get_config("storage.log_file", "picture_log.json")
|
|
||||||
log_path = os.path.join(DATA_DIR, log_file)
|
|
||||||
max_photos = self.api.get_config("storage.max_photos", 50)
|
|
||||||
|
|
||||||
async with file_lock:
|
|
||||||
try:
|
|
||||||
if os.path.exists(log_path):
|
|
||||||
with open(log_path, "r", encoding="utf-8") as f:
|
|
||||||
log_data = json.load(f)
|
|
||||||
else:
|
|
||||||
log_data = []
|
|
||||||
except (json.JSONDecodeError, FileNotFoundError):
|
|
||||||
log_data = []
|
|
||||||
|
|
||||||
# 添加新照片
|
|
||||||
log_data.append(
|
|
||||||
{"prompt": prompt, "image_url": image_url, "timestamp": datetime.datetime.now().isoformat()}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 如果超过最大数量,删除最旧的
|
|
||||||
if len(log_data) > max_photos:
|
|
||||||
log_data = sorted(log_data, key=lambda x: x.get("timestamp", ""), reverse=True)[:max_photos]
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(log_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(log_data, f, ensure_ascii=False, indent=4)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 写入照片日志文件失败: {e}", exc_info=True)
|
|
||||||
|
|
||||||
def _make_http_image_request(
|
|
||||||
self, prompt: str, model: str, size: str, seed: int, guidance_scale: float, watermark: bool
|
|
||||||
) -> Tuple[bool, str]:
|
|
||||||
"""发送HTTP请求到火山引擎豆包API生成图片"""
|
|
||||||
try:
|
|
||||||
base_url = self.api.get_config("api.base_url")
|
|
||||||
api_key = self.api.get_config("api.volcano_generate_api_key")
|
|
||||||
|
|
||||||
# 构建请求URL和头部
|
|
||||||
endpoint = f"{base_url.rstrip('/')}/images/generations"
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {api_key}",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 构建请求体
|
|
||||||
request_body = {
|
|
||||||
"model": model,
|
|
||||||
"prompt": prompt,
|
|
||||||
"response_format": "url",
|
|
||||||
"size": size,
|
|
||||||
"seed": seed,
|
|
||||||
"guidance_scale": guidance_scale,
|
|
||||||
"watermark": watermark,
|
|
||||||
"api-key": api_key,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 创建请求对象
|
|
||||||
req = urllib.request.Request(
|
|
||||||
endpoint,
|
|
||||||
data=json.dumps(request_body).encode("utf-8"),
|
|
||||||
headers=headers,
|
|
||||||
method="POST",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 发送请求并获取响应
|
|
||||||
with urllib.request.urlopen(req, timeout=60) as response:
|
|
||||||
response_data = json.loads(response.read().decode("utf-8"))
|
|
||||||
|
|
||||||
# 解析响应
|
|
||||||
image_url = None
|
|
||||||
if (
|
|
||||||
isinstance(response_data.get("data"), list)
|
|
||||||
and response_data["data"]
|
|
||||||
and isinstance(response_data["data"][0], dict)
|
|
||||||
):
|
|
||||||
image_url = response_data["data"][0].get("url")
|
|
||||||
elif response_data.get("url"):
|
|
||||||
image_url = response_data.get("url")
|
|
||||||
|
|
||||||
if image_url:
|
|
||||||
return True, image_url
|
|
||||||
else:
|
|
||||||
error_msg = response_data.get("error", {}).get("message", "未知错误")
|
|
||||||
logger.error(f"API返回错误: {error_msg}")
|
|
||||||
return False, f"API错误: {error_msg}"
|
|
||||||
|
|
||||||
except urllib.error.HTTPError as e:
|
|
||||||
error_body = e.read().decode("utf-8")
|
|
||||||
logger.error(f"HTTP错误 {e.code}: {error_body}")
|
|
||||||
return False, f"HTTP错误 {e.code}: {error_body[:100]}..."
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"请求异常: {e}", exc_info=True)
|
|
||||||
return False, f"请求异常: {str(e)}"
|
|
||||||
|
|
||||||
def _download_and_encode_base64(self, image_url: str) -> Tuple[bool, str]:
|
|
||||||
"""下载图片并转换为Base64编码"""
|
|
||||||
try:
|
|
||||||
with urllib.request.urlopen(image_url) as response:
|
|
||||||
image_data = response.read()
|
|
||||||
|
|
||||||
base64_encoded = base64.b64encode(image_data).decode("utf-8")
|
|
||||||
return True, base64_encoded
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图片下载编码失败: {e}", exc_info=True)
|
|
||||||
return False, str(e)
|
|
||||||
|
|
||||||
async def _send_image(self, base64_image: str) -> bool:
|
|
||||||
"""发送图片"""
|
|
||||||
try:
|
|
||||||
# 使用聊天流信息确定发送目标
|
|
||||||
chat_stream = self.api.get_service("chat_stream")
|
|
||||||
if not chat_stream:
|
|
||||||
logger.error(f"{self.log_prefix} 没有可用的聊天流发送图片")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if chat_stream.group_info:
|
|
||||||
# 群聊
|
|
||||||
return await self.api.send_message_to_target(
|
|
||||||
message_type="image",
|
|
||||||
content=base64_image,
|
|
||||||
platform=chat_stream.platform,
|
|
||||||
target_id=str(chat_stream.group_info.group_id),
|
|
||||||
is_group=True,
|
|
||||||
display_message="发送生成的照片",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 私聊
|
|
||||||
return await self.api.send_message_to_target(
|
|
||||||
message_type="image",
|
|
||||||
content=base64_image,
|
|
||||||
platform=chat_stream.platform,
|
|
||||||
target_id=str(chat_stream.user_info.user_id),
|
|
||||||
is_group=False,
|
|
||||||
display_message="发送生成的照片",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 发送图片时出错: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_cache_key(cls, description: str, model: str, size: str) -> str:
|
|
||||||
"""生成缓存键"""
|
|
||||||
return f"{description}|{model}|{size}"
|
|
||||||
|
|
||||||
def _update_cache(self, description: str, model: str, size: str, base64_image: str):
|
|
||||||
"""更新缓存"""
|
|
||||||
max_cache_size = self.api.get_config("storage.max_cache_size", 10)
|
|
||||||
cache_key = self._get_cache_key(description, model, size)
|
|
||||||
|
|
||||||
# 添加到缓存
|
|
||||||
self._request_cache[cache_key] = base64_image
|
|
||||||
|
|
||||||
# 如果缓存超过最大大小,删除最旧的项
|
|
||||||
if len(self._request_cache) > max_cache_size:
|
|
||||||
oldest_key = next(iter(self._request_cache))
|
|
||||||
del self._request_cache[oldest_key]
|
|
||||||
|
|
||||||
|
|
||||||
class ShowRecentPicturesCommand(BaseCommand):
|
|
||||||
"""展示最近生成的照片"""
|
|
||||||
|
|
||||||
command_name = "show_recent_pictures"
|
|
||||||
command_description = "展示最近生成的5张照片"
|
|
||||||
command_pattern = r"^/show_pics$"
|
|
||||||
command_help = "用法: /show_pics"
|
|
||||||
command_examples = ["/show_pics"]
|
|
||||||
intercept_message = True
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
|
||||||
logger.info(f"{self.log_prefix} 执行展示最近照片命令")
|
|
||||||
log_file = self.api.get_config("storage.log_file", "picture_log.json")
|
|
||||||
log_path = os.path.join(DATA_DIR, log_file)
|
|
||||||
|
|
||||||
async with file_lock:
|
|
||||||
try:
|
|
||||||
if not os.path.exists(log_path):
|
|
||||||
await self.send_text("最近还没有拍过照片哦,快让我自拍一张吧!")
|
|
||||||
return True, "没有照片日志文件"
|
|
||||||
|
|
||||||
with open(log_path, "r", encoding="utf-8") as f:
|
|
||||||
log_data = json.load(f)
|
|
||||||
|
|
||||||
if not log_data:
|
|
||||||
await self.send_text("最近还没有拍过照片哦,快让我自拍一张吧!")
|
|
||||||
return True, "没有照片"
|
|
||||||
|
|
||||||
# 获取最新的5张照片
|
|
||||||
recent_pics = sorted(log_data, key=lambda x: x["timestamp"], reverse=True)[:5]
|
|
||||||
|
|
||||||
# 先发送文本消息
|
|
||||||
await self.send_text("这是我最近拍的几张照片~")
|
|
||||||
|
|
||||||
# 逐个发送图片
|
|
||||||
for pic in recent_pics:
|
|
||||||
# 尝试获取图片URL
|
|
||||||
image_url = pic.get("image_url")
|
|
||||||
if image_url:
|
|
||||||
try:
|
|
||||||
# 下载图片并转换为Base64
|
|
||||||
with urllib.request.urlopen(image_url) as response:
|
|
||||||
image_data = response.read()
|
|
||||||
base64_encoded = base64.b64encode(image_data).decode("utf-8")
|
|
||||||
|
|
||||||
# 发送图片
|
|
||||||
await self.send_type(
|
|
||||||
message_type="image", content=base64_encoded, display_message="发送最近的照片"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 下载或发送照片失败: {e}", exc_info=True)
|
|
||||||
|
|
||||||
return True, "成功展示最近的照片"
|
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
await self.send_text("照片记录文件好像损坏了...")
|
|
||||||
return False, "JSON解码错误"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 展示照片失败: {e}", exc_info=True)
|
|
||||||
await self.send_text("哎呀,查找照片的时候出错了。")
|
|
||||||
return False, str(e)
|
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
|
||||||
class TakePicturePlugin(BasePlugin):
|
|
||||||
"""拍照插件"""
|
|
||||||
|
|
||||||
plugin_name = "take_picture_plugin" # 内部标识符
|
|
||||||
enable_plugin = False
|
|
||||||
dependencies = [] # 插件依赖列表
|
|
||||||
python_dependencies = [] # Python包依赖列表
|
|
||||||
config_file_name = "config.toml"
|
|
||||||
|
|
||||||
# 配置节描述
|
|
||||||
config_section_descriptions = {
|
|
||||||
"plugin": "插件基本信息配置",
|
|
||||||
"api": "API相关配置,包含火山引擎API的访问信息",
|
|
||||||
"components": "组件启用控制",
|
|
||||||
"picture": "拍照功能核心配置",
|
|
||||||
"storage": "照片存储相关配置",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 配置Schema定义
|
|
||||||
config_schema = {
|
|
||||||
"plugin": {
|
|
||||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
|
||||||
},
|
|
||||||
"api": {
|
|
||||||
"base_url": ConfigField(
|
|
||||||
type=str,
|
|
||||||
default="https://ark.cn-beijing.volces.com/api/v3",
|
|
||||||
description="API基础URL",
|
|
||||||
example="https://api.example.com/v1",
|
|
||||||
),
|
|
||||||
"volcano_generate_api_key": ConfigField(
|
|
||||||
type=str, default="YOUR_DOUBAO_API_KEY_HERE", description="火山引擎豆包API密钥", required=True
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"components": {
|
|
||||||
"enable_take_picture_action": ConfigField(type=bool, default=True, description="是否启用拍照Action"),
|
|
||||||
"enable_show_pics_command": ConfigField(type=bool, default=True, description="是否启用展示照片Command"),
|
|
||||||
},
|
|
||||||
"picture": {
|
|
||||||
"default_model": ConfigField(
|
|
||||||
type=str,
|
|
||||||
default="doubao-seedream-3-0-t2i-250415",
|
|
||||||
description="默认使用的文生图模型",
|
|
||||||
choices=["doubao-seedream-3-0-t2i-250415", "doubao-seedream-2-0-t2i"],
|
|
||||||
),
|
|
||||||
"default_size": ConfigField(
|
|
||||||
type=str,
|
|
||||||
default="1024x1024",
|
|
||||||
description="默认图片尺寸",
|
|
||||||
example="1024x1024",
|
|
||||||
choices=["1024x1024", "1024x1280", "1280x1024", "1024x1536", "1536x1024"],
|
|
||||||
),
|
|
||||||
"default_watermark": ConfigField(type=bool, default=True, description="是否默认添加水印"),
|
|
||||||
"default_guidance_scale": ConfigField(
|
|
||||||
type=float, default=2.5, description="模型指导强度,影响图片与提示的关联性", example="2.0"
|
|
||||||
),
|
|
||||||
"default_seed": ConfigField(type=int, default=42, description="随机种子,用于复现图片"),
|
|
||||||
"prompt_templates": ConfigField(
|
|
||||||
type=list, default=TakePictureAction.DEFAULT_PROMPT_TEMPLATES, description="用于生成自拍照的prompt模板"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"storage": {
|
|
||||||
"max_photos": ConfigField(type=int, default=50, description="最大保存的照片数量"),
|
|
||||||
"log_file": ConfigField(type=str, default="picture_log.json", description="照片日志文件名"),
|
|
||||||
"enable_cache": ConfigField(type=bool, default=True, description="是否启用请求缓存"),
|
|
||||||
"max_cache_size": ConfigField(type=int, default=10, description="最大缓存数量"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
|
||||||
"""返回插件包含的组件列表"""
|
|
||||||
components = []
|
|
||||||
if self.get_config("components.enable_take_picture_action", True):
|
|
||||||
components.append((TakePictureAction.get_action_info(), TakePictureAction))
|
|
||||||
if self.get_config("components.enable_show_pics_command", True):
|
|
||||||
components.append((ShowRecentPicturesCommand.get_command_info(), ShowRecentPicturesCommand))
|
|
||||||
return components
|
|
||||||
@@ -1,192 +0,0 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
from typing import List, Dict, Tuple
|
|
||||||
import numpy as np
|
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
|
||||||
import glob
|
|
||||||
import sqlite3
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
|
||||||
def clean_group_name(name: str) -> str:
|
|
||||||
"""清理群组名称,只保留中文和英文字符"""
|
|
||||||
cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
|
|
||||||
if not cleaned:
|
|
||||||
cleaned = datetime.now().strftime("%Y%m%d")
|
|
||||||
return cleaned
|
|
||||||
|
|
||||||
|
|
||||||
def get_group_name(stream_id: str) -> str:
|
|
||||||
"""从数据库中获取群组名称"""
|
|
||||||
conn = sqlite3.connect("data/maibot.db")
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
SELECT group_name, user_nickname, platform
|
|
||||||
FROM chat_streams
|
|
||||||
WHERE stream_id = ?
|
|
||||||
""",
|
|
||||||
(stream_id,),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = cursor.fetchone()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if result:
|
|
||||||
group_name, user_nickname, platform = result
|
|
||||||
if group_name:
|
|
||||||
return clean_group_name(group_name)
|
|
||||||
if user_nickname:
|
|
||||||
return clean_group_name(user_nickname)
|
|
||||||
if platform:
|
|
||||||
return clean_group_name(f"{platform}{stream_id[:8]}")
|
|
||||||
return stream_id
|
|
||||||
|
|
||||||
|
|
||||||
def format_timestamp(timestamp: float) -> str:
|
|
||||||
"""将时间戳转换为可读的时间格式"""
|
|
||||||
if not timestamp:
|
|
||||||
return "未知"
|
|
||||||
try:
|
|
||||||
dt = datetime.fromtimestamp(timestamp)
|
|
||||||
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"时间戳格式化错误: {e}")
|
|
||||||
return "未知"
|
|
||||||
|
|
||||||
|
|
||||||
def load_expressions(chat_id: str) -> List[Dict]:
|
|
||||||
"""加载指定群聊的表达方式"""
|
|
||||||
style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
|
|
||||||
|
|
||||||
style_exprs = []
|
|
||||||
|
|
||||||
if os.path.exists(style_file):
|
|
||||||
with open(style_file, "r", encoding="utf-8") as f:
|
|
||||||
style_exprs = json.load(f)
|
|
||||||
|
|
||||||
return style_exprs
|
|
||||||
|
|
||||||
|
|
||||||
def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[str, List[Tuple[str, float]]]:
|
|
||||||
"""找出每个表达方式最相似的top_k个表达方式"""
|
|
||||||
if not expressions:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# 分别准备情景和表达方式的文本数据
|
|
||||||
situations = [expr["situation"] for expr in expressions]
|
|
||||||
styles = [expr["style"] for expr in expressions]
|
|
||||||
|
|
||||||
# 使用TF-IDF向量化
|
|
||||||
vectorizer = TfidfVectorizer()
|
|
||||||
situation_matrix = vectorizer.fit_transform(situations)
|
|
||||||
style_matrix = vectorizer.fit_transform(styles)
|
|
||||||
|
|
||||||
# 计算余弦相似度
|
|
||||||
situation_similarity = cosine_similarity(situation_matrix)
|
|
||||||
style_similarity = cosine_similarity(style_matrix)
|
|
||||||
|
|
||||||
# 对每个表达方式找出最相似的top_k个
|
|
||||||
similar_expressions = {}
|
|
||||||
for i, _ in enumerate(expressions):
|
|
||||||
# 获取相似度分数
|
|
||||||
situation_scores = situation_similarity[i]
|
|
||||||
style_scores = style_similarity[i]
|
|
||||||
|
|
||||||
# 获取top_k的索引(排除自己)
|
|
||||||
situation_indices = np.argsort(situation_scores)[::-1][1 : top_k + 1]
|
|
||||||
style_indices = np.argsort(style_scores)[::-1][1 : top_k + 1]
|
|
||||||
|
|
||||||
similar_situations = []
|
|
||||||
similar_styles = []
|
|
||||||
|
|
||||||
# 处理相似情景
|
|
||||||
for idx in situation_indices:
|
|
||||||
if situation_scores[idx] > 0: # 只保留有相似度的
|
|
||||||
similar_situations.append(
|
|
||||||
(
|
|
||||||
expressions[idx]["situation"],
|
|
||||||
expressions[idx]["style"], # 添加对应的原始表达
|
|
||||||
situation_scores[idx],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 处理相似表达
|
|
||||||
for idx in style_indices:
|
|
||||||
if style_scores[idx] > 0: # 只保留有相似度的
|
|
||||||
similar_styles.append(
|
|
||||||
(
|
|
||||||
expressions[idx]["style"],
|
|
||||||
expressions[idx]["situation"], # 添加对应的原始情景
|
|
||||||
style_scores[idx],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if similar_situations or similar_styles:
|
|
||||||
similar_expressions[i] = {"situations": similar_situations, "styles": similar_styles}
|
|
||||||
|
|
||||||
return similar_expressions
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# 获取所有群聊ID
|
|
||||||
style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
|
|
||||||
chat_ids = [os.path.basename(d) for d in style_dirs]
|
|
||||||
|
|
||||||
if not chat_ids:
|
|
||||||
print("没有找到任何群聊的表达方式数据")
|
|
||||||
return
|
|
||||||
|
|
||||||
print("可用的群聊:")
|
|
||||||
for i, chat_id in enumerate(chat_ids, 1):
|
|
||||||
group_name = get_group_name(chat_id)
|
|
||||||
print(f"{i}. {group_name}")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
choice = int(input("\n请选择要分析的群聊编号 (输入0退出): "))
|
|
||||||
if choice == 0:
|
|
||||||
break
|
|
||||||
if 1 <= choice <= len(chat_ids):
|
|
||||||
chat_id = chat_ids[choice - 1]
|
|
||||||
break
|
|
||||||
print("无效的选择,请重试")
|
|
||||||
except ValueError:
|
|
||||||
print("请输入有效的数字")
|
|
||||||
|
|
||||||
if choice == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 加载表达方式
|
|
||||||
style_exprs = load_expressions(chat_id)
|
|
||||||
|
|
||||||
group_name = get_group_name(chat_id)
|
|
||||||
print(f"\n分析群聊 {group_name} 的表达方式:")
|
|
||||||
|
|
||||||
similar_styles = find_similar_expressions(style_exprs)
|
|
||||||
for i, expr in enumerate(style_exprs):
|
|
||||||
if i in similar_styles:
|
|
||||||
print("\n" + "-" * 20)
|
|
||||||
print(f"表达方式:{expr['style']} <---> 情景:{expr['situation']}")
|
|
||||||
|
|
||||||
if similar_styles[i]["styles"]:
|
|
||||||
print("\n\033[33m相似表达:\033[0m")
|
|
||||||
for similar_style, original_situation, score in similar_styles[i]["styles"]:
|
|
||||||
print(f"\033[33m{similar_style},score:{score:.3f},对应情景:{original_situation}\033[0m")
|
|
||||||
|
|
||||||
if similar_styles[i]["situations"]:
|
|
||||||
print("\n\033[32m相似情景:\033[0m")
|
|
||||||
for similar_situation, original_style, score in similar_styles[i]["situations"]:
|
|
||||||
print(f"\033[32m{similar_situation},score:{score:.3f},对应表达:{original_style}\033[0m")
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"\n激活值:{expr.get('count', 1):.3f},上次激活时间:{format_timestamp(expr.get('last_active_time'))}"
|
|
||||||
)
|
|
||||||
print("-" * 20)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,215 +0,0 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Dict, List, Any
|
|
||||||
import sqlite3
|
|
||||||
|
|
||||||
|
|
||||||
def clean_group_name(name: str) -> str:
|
|
||||||
"""清理群组名称,只保留中文和英文字符"""
|
|
||||||
# 提取中文和英文字符
|
|
||||||
cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
|
|
||||||
# 如果清理后为空,使用当前日期
|
|
||||||
if not cleaned:
|
|
||||||
cleaned = datetime.now().strftime("%Y%m%d")
|
|
||||||
return cleaned
|
|
||||||
|
|
||||||
|
|
||||||
def get_group_name(stream_id: str) -> str:
|
|
||||||
"""从数据库中获取群组名称"""
|
|
||||||
conn = sqlite3.connect("data/maibot.db")
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
SELECT group_name, user_nickname, platform
|
|
||||||
FROM chat_streams
|
|
||||||
WHERE stream_id = ?
|
|
||||||
""",
|
|
||||||
(stream_id,),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = cursor.fetchone()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if result:
|
|
||||||
group_name, user_nickname, platform = result
|
|
||||||
if group_name:
|
|
||||||
return clean_group_name(group_name)
|
|
||||||
if user_nickname:
|
|
||||||
return clean_group_name(user_nickname)
|
|
||||||
if platform:
|
|
||||||
return clean_group_name(f"{platform}{stream_id[:8]}")
|
|
||||||
return stream_id
|
|
||||||
|
|
||||||
|
|
||||||
def load_expressions(chat_id: str) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
|
|
||||||
"""加载指定群组的表达方式"""
|
|
||||||
learnt_style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
|
|
||||||
learnt_grammar_file = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
|
|
||||||
personality_file = os.path.join("data", "expression", "personality", "expressions.json")
|
|
||||||
|
|
||||||
style_expressions = []
|
|
||||||
grammar_expressions = []
|
|
||||||
personality_expressions = []
|
|
||||||
|
|
||||||
if os.path.exists(learnt_style_file):
|
|
||||||
with open(learnt_style_file, "r", encoding="utf-8") as f:
|
|
||||||
style_expressions = json.load(f)
|
|
||||||
|
|
||||||
if os.path.exists(learnt_grammar_file):
|
|
||||||
with open(learnt_grammar_file, "r", encoding="utf-8") as f:
|
|
||||||
grammar_expressions = json.load(f)
|
|
||||||
|
|
||||||
if os.path.exists(personality_file):
|
|
||||||
with open(personality_file, "r", encoding="utf-8") as f:
|
|
||||||
personality_expressions = json.load(f)
|
|
||||||
|
|
||||||
return style_expressions, grammar_expressions, personality_expressions
|
|
||||||
|
|
||||||
|
|
||||||
def format_time(timestamp: float) -> str:
|
|
||||||
"""格式化时间戳为可读字符串"""
|
|
||||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
|
|
||||||
|
|
||||||
def write_expressions(f, expressions: List[Dict[str, Any]], title: str):
|
|
||||||
"""写入表达方式列表"""
|
|
||||||
if not expressions:
|
|
||||||
f.write(f"{title}:暂无数据\n")
|
|
||||||
f.write("-" * 40 + "\n")
|
|
||||||
return
|
|
||||||
|
|
||||||
f.write(f"{title}:\n")
|
|
||||||
for expr in expressions:
|
|
||||||
count = expr.get("count", 0)
|
|
||||||
last_active = expr.get("last_active_time", time.time())
|
|
||||||
f.write(f"场景: {expr['situation']}\n")
|
|
||||||
f.write(f"表达: {expr['style']}\n")
|
|
||||||
f.write(f"计数: {count:.4f}\n")
|
|
||||||
f.write(f"最后活跃: {format_time(last_active)}\n")
|
|
||||||
f.write("-" * 40 + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
def write_group_report(
|
|
||||||
group_file: str,
|
|
||||||
group_name: str,
|
|
||||||
chat_id: str,
|
|
||||||
style_exprs: List[Dict[str, Any]],
|
|
||||||
grammar_exprs: List[Dict[str, Any]],
|
|
||||||
):
|
|
||||||
"""写入群组详细报告"""
|
|
||||||
with open(group_file, "w", encoding="utf-8") as gf:
|
|
||||||
gf.write(f"群组: {group_name} (ID: {chat_id})\n")
|
|
||||||
gf.write("=" * 80 + "\n\n")
|
|
||||||
|
|
||||||
# 写入语言风格
|
|
||||||
gf.write("【语言风格】\n")
|
|
||||||
gf.write("=" * 40 + "\n")
|
|
||||||
write_expressions(gf, style_exprs, "语言风格")
|
|
||||||
gf.write("\n")
|
|
||||||
|
|
||||||
# 写入句法特点
|
|
||||||
gf.write("【句法特点】\n")
|
|
||||||
gf.write("=" * 40 + "\n")
|
|
||||||
write_expressions(gf, grammar_exprs, "句法特点")
|
|
||||||
|
|
||||||
|
|
||||||
def analyze_expressions():
|
|
||||||
"""分析所有群组的表达方式"""
|
|
||||||
# 获取所有群组ID
|
|
||||||
style_dir = os.path.join("data", "expression", "learnt_style")
|
|
||||||
chat_ids = [d for d in os.listdir(style_dir) if os.path.isdir(os.path.join(style_dir, d))]
|
|
||||||
|
|
||||||
# 创建输出目录
|
|
||||||
output_dir = "data/expression_analysis"
|
|
||||||
personality_dir = os.path.join(output_dir, "personality")
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
os.makedirs(personality_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# 生成时间戳
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
|
|
||||||
# 创建总报告
|
|
||||||
summary_file = os.path.join(output_dir, f"summary_{timestamp}.txt")
|
|
||||||
with open(summary_file, "w", encoding="utf-8") as f:
|
|
||||||
f.write(f"表达方式分析报告 - 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
|
||||||
f.write("=" * 80 + "\n\n")
|
|
||||||
|
|
||||||
# 先处理人格表达
|
|
||||||
personality_exprs = []
|
|
||||||
personality_file = os.path.join("data", "expression", "personality", "expressions.json")
|
|
||||||
if os.path.exists(personality_file):
|
|
||||||
with open(personality_file, "r", encoding="utf-8") as pf:
|
|
||||||
personality_exprs = json.load(pf)
|
|
||||||
|
|
||||||
# 保存人格表达总数
|
|
||||||
total_personality = len(personality_exprs)
|
|
||||||
|
|
||||||
# 排序并取前20条
|
|
||||||
personality_exprs.sort(key=lambda x: x.get("count", 0), reverse=True)
|
|
||||||
personality_exprs = personality_exprs[:20]
|
|
||||||
|
|
||||||
# 写入人格表达报告
|
|
||||||
personality_report = os.path.join(personality_dir, f"expressions_{timestamp}.txt")
|
|
||||||
with open(personality_report, "w", encoding="utf-8") as pf:
|
|
||||||
pf.write("【人格表达方式】\n")
|
|
||||||
pf.write("=" * 40 + "\n")
|
|
||||||
write_expressions(pf, personality_exprs, "人格表达")
|
|
||||||
|
|
||||||
# 写入总报告摘要中的人格表达部分
|
|
||||||
f.write("【人格表达方式】\n")
|
|
||||||
f.write("=" * 40 + "\n")
|
|
||||||
f.write(f"人格表达总数: {total_personality} (显示前20条)\n")
|
|
||||||
f.write(f"详细报告: {personality_report}\n")
|
|
||||||
f.write("-" * 40 + "\n\n")
|
|
||||||
|
|
||||||
# 处理各个群组的表达方式
|
|
||||||
f.write("【群组表达方式】\n")
|
|
||||||
f.write("=" * 40 + "\n\n")
|
|
||||||
|
|
||||||
for chat_id in chat_ids:
|
|
||||||
style_exprs, grammar_exprs, _ = load_expressions(chat_id)
|
|
||||||
|
|
||||||
# 保存总数
|
|
||||||
total_style = len(style_exprs)
|
|
||||||
total_grammar = len(grammar_exprs)
|
|
||||||
|
|
||||||
# 分别排序
|
|
||||||
style_exprs.sort(key=lambda x: x.get("count", 0), reverse=True)
|
|
||||||
grammar_exprs.sort(key=lambda x: x.get("count", 0), reverse=True)
|
|
||||||
|
|
||||||
# 只取前20条
|
|
||||||
style_exprs = style_exprs[:20]
|
|
||||||
grammar_exprs = grammar_exprs[:20]
|
|
||||||
|
|
||||||
# 获取群组名称
|
|
||||||
group_name = get_group_name(chat_id)
|
|
||||||
|
|
||||||
# 创建群组子目录(使用清理后的名称)
|
|
||||||
safe_group_name = clean_group_name(group_name)
|
|
||||||
group_dir = os.path.join(output_dir, f"{safe_group_name}_{chat_id}")
|
|
||||||
os.makedirs(group_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# 写入群组详细报告
|
|
||||||
group_file = os.path.join(group_dir, f"expressions_{timestamp}.txt")
|
|
||||||
write_group_report(group_file, group_name, chat_id, style_exprs, grammar_exprs)
|
|
||||||
|
|
||||||
# 写入总报告摘要
|
|
||||||
f.write(f"群组: {group_name} (ID: {chat_id})\n")
|
|
||||||
f.write("-" * 40 + "\n")
|
|
||||||
f.write(f"语言风格总数: {total_style} (显示前20条)\n")
|
|
||||||
f.write(f"句法特点总数: {total_grammar} (显示前20条)\n")
|
|
||||||
f.write(f"详细报告: {group_file}\n")
|
|
||||||
f.write("-" * 40 + "\n\n")
|
|
||||||
|
|
||||||
print("分析报告已生成:")
|
|
||||||
print(f"总报告: {summary_file}")
|
|
||||||
print(f"人格表达报告: {personality_report}")
|
|
||||||
print(f"各群组详细报告位于: {output_dir}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
analyze_expressions()
|
|
||||||
@@ -1,196 +0,0 @@
|
|||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
import numpy as np
|
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import seaborn as sns
|
|
||||||
import sqlite3
|
|
||||||
|
|
||||||
# 设置中文字体
|
|
||||||
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei"] # 使用微软雅黑
|
|
||||||
plt.rcParams["axes.unicode_minus"] = False # 用来正常显示负号
|
|
||||||
plt.rcParams["font.family"] = "sans-serif"
|
|
||||||
|
|
||||||
# 获取脚本所在目录
|
|
||||||
SCRIPT_DIR = Path(__file__).parent
|
|
||||||
|
|
||||||
|
|
||||||
def get_group_name(stream_id):
|
|
||||||
"""从数据库中获取群组名称"""
|
|
||||||
conn = sqlite3.connect("data/maibot.db")
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
SELECT group_name, user_nickname, platform
|
|
||||||
FROM chat_streams
|
|
||||||
WHERE stream_id = ?
|
|
||||||
""",
|
|
||||||
(stream_id,),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = cursor.fetchone()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if result:
|
|
||||||
group_name, user_nickname, platform = result
|
|
||||||
if group_name:
|
|
||||||
return group_name
|
|
||||||
if user_nickname:
|
|
||||||
return user_nickname
|
|
||||||
if platform:
|
|
||||||
return f"{platform}-{stream_id[:8]}"
|
|
||||||
return stream_id
|
|
||||||
|
|
||||||
|
|
||||||
def load_group_data(group_dir):
|
|
||||||
"""加载单个群组的数据"""
|
|
||||||
json_path = Path(group_dir) / "expressions.json"
|
|
||||||
if not json_path.exists():
|
|
||||||
return [], [], [], 0
|
|
||||||
|
|
||||||
with open(json_path, "r", encoding="utf-8") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
|
|
||||||
situations = []
|
|
||||||
styles = []
|
|
||||||
combined = []
|
|
||||||
total_count = sum(item["count"] for item in data)
|
|
||||||
|
|
||||||
for item in data:
|
|
||||||
count = item["count"]
|
|
||||||
situations.extend([item["situation"]] * int(count))
|
|
||||||
styles.extend([item["style"]] * int(count))
|
|
||||||
combined.extend([f"{item['situation']} {item['style']}"] * int(count))
|
|
||||||
|
|
||||||
return situations, styles, combined, total_count
|
|
||||||
|
|
||||||
|
|
||||||
def analyze_group_similarity():
|
|
||||||
# 获取所有群组目录
|
|
||||||
base_dir = Path("data/expression/learnt_style")
|
|
||||||
group_dirs = [d for d in base_dir.iterdir() if d.is_dir()]
|
|
||||||
|
|
||||||
# 加载所有群组的数据并过滤
|
|
||||||
valid_groups = []
|
|
||||||
valid_names = []
|
|
||||||
valid_situations = []
|
|
||||||
valid_styles = []
|
|
||||||
valid_combined = []
|
|
||||||
|
|
||||||
for d in group_dirs:
|
|
||||||
situations, styles, combined, total_count = load_group_data(d)
|
|
||||||
if total_count >= 50: # 只保留数据量大于等于50的群组
|
|
||||||
valid_groups.append(d)
|
|
||||||
valid_names.append(get_group_name(d.name))
|
|
||||||
valid_situations.append(" ".join(situations))
|
|
||||||
valid_styles.append(" ".join(styles))
|
|
||||||
valid_combined.append(" ".join(combined))
|
|
||||||
|
|
||||||
if not valid_groups:
|
|
||||||
print("没有找到数据量大于等于50的群组")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 创建TF-IDF向量化器
|
|
||||||
vectorizer = TfidfVectorizer()
|
|
||||||
|
|
||||||
# 计算三种相似度矩阵
|
|
||||||
situation_matrix = cosine_similarity(vectorizer.fit_transform(valid_situations))
|
|
||||||
style_matrix = cosine_similarity(vectorizer.fit_transform(valid_styles))
|
|
||||||
combined_matrix = cosine_similarity(vectorizer.fit_transform(valid_combined))
|
|
||||||
|
|
||||||
# 对相似度矩阵进行对数变换
|
|
||||||
log_situation_matrix = np.log10(situation_matrix * 100 + 1) * 10 / np.log10(4)
|
|
||||||
log_style_matrix = np.log10(style_matrix * 100 + 1) * 10 / np.log10(4)
|
|
||||||
log_combined_matrix = np.log10(combined_matrix * 100 + 1) * 10 / np.log10(4)
|
|
||||||
|
|
||||||
# 创建一个大图,包含三个子图
|
|
||||||
plt.figure(figsize=(45, 12))
|
|
||||||
|
|
||||||
# 场景相似度热力图
|
|
||||||
plt.subplot(1, 3, 1)
|
|
||||||
sns.heatmap(
|
|
||||||
log_situation_matrix,
|
|
||||||
xticklabels=valid_names,
|
|
||||||
yticklabels=valid_names,
|
|
||||||
cmap="YlOrRd",
|
|
||||||
annot=True,
|
|
||||||
fmt=".1f",
|
|
||||||
vmin=0,
|
|
||||||
vmax=30,
|
|
||||||
)
|
|
||||||
plt.title("群组场景相似度热力图 (对数百分比)")
|
|
||||||
plt.xticks(rotation=45, ha="right")
|
|
||||||
|
|
||||||
# 表达方式相似度热力图
|
|
||||||
plt.subplot(1, 3, 2)
|
|
||||||
sns.heatmap(
|
|
||||||
log_style_matrix,
|
|
||||||
xticklabels=valid_names,
|
|
||||||
yticklabels=valid_names,
|
|
||||||
cmap="YlOrRd",
|
|
||||||
annot=True,
|
|
||||||
fmt=".1f",
|
|
||||||
vmin=0,
|
|
||||||
vmax=30,
|
|
||||||
)
|
|
||||||
plt.title("群组表达方式相似度热力图 (对数百分比)")
|
|
||||||
plt.xticks(rotation=45, ha="right")
|
|
||||||
|
|
||||||
# 组合相似度热力图
|
|
||||||
plt.subplot(1, 3, 3)
|
|
||||||
sns.heatmap(
|
|
||||||
log_combined_matrix,
|
|
||||||
xticklabels=valid_names,
|
|
||||||
yticklabels=valid_names,
|
|
||||||
cmap="YlOrRd",
|
|
||||||
annot=True,
|
|
||||||
fmt=".1f",
|
|
||||||
vmin=0,
|
|
||||||
vmax=30,
|
|
||||||
)
|
|
||||||
plt.title("群组场景+表达方式相似度热力图 (对数百分比)")
|
|
||||||
plt.xticks(rotation=45, ha="right")
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig(SCRIPT_DIR / "group_similarity_heatmaps.png", dpi=300, bbox_inches="tight")
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
# 保存匹配详情到文本文件
|
|
||||||
with open(SCRIPT_DIR / "group_similarity_details.txt", "w", encoding="utf-8") as f:
|
|
||||||
f.write("群组相似度详情\n")
|
|
||||||
f.write("=" * 50 + "\n\n")
|
|
||||||
|
|
||||||
for i in range(len(valid_names)):
|
|
||||||
for j in range(i + 1, len(valid_names)):
|
|
||||||
if log_combined_matrix[i][j] > 50:
|
|
||||||
f.write(f"群组1: {valid_names[i]}\n")
|
|
||||||
f.write(f"群组2: {valid_names[j]}\n")
|
|
||||||
f.write(f"场景相似度: {situation_matrix[i][j]:.4f}\n")
|
|
||||||
f.write(f"表达方式相似度: {style_matrix[i][j]:.4f}\n")
|
|
||||||
f.write(f"组合相似度: {combined_matrix[i][j]:.4f}\n")
|
|
||||||
|
|
||||||
# 获取两个群组的数据
|
|
||||||
situations1, styles1, _ = load_group_data(valid_groups[i])
|
|
||||||
situations2, styles2, _ = load_group_data(valid_groups[j])
|
|
||||||
|
|
||||||
# 找出共同的场景
|
|
||||||
common_situations = set(situations1) & set(situations2)
|
|
||||||
if common_situations:
|
|
||||||
f.write("\n共同场景:\n")
|
|
||||||
for situation in common_situations:
|
|
||||||
f.write(f"- {situation}\n")
|
|
||||||
|
|
||||||
# 找出共同的表达方式
|
|
||||||
common_styles = set(styles1) & set(styles2)
|
|
||||||
if common_styles:
|
|
||||||
f.write("\n共同表达方式:\n")
|
|
||||||
for style in common_styles:
|
|
||||||
f.write(f"- {style}\n")
|
|
||||||
|
|
||||||
f.write("\n" + "-" * 50 + "\n\n")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
analyze_group_similarity()
|
|
||||||
208
scripts/expression_stats.py
Normal file
208
scripts/expression_stats.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
import time
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
# Add project root to Python path
|
||||||
|
from src.common.database.database_model import Expression, ChatStreams
|
||||||
|
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_name(chat_id: str) -> str:
|
||||||
|
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
||||||
|
try:
|
||||||
|
# 直接从数据库查询ChatStreams表
|
||||||
|
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||||
|
if chat_stream is None:
|
||||||
|
return f"未知聊天 ({chat_id})"
|
||||||
|
|
||||||
|
# 如果有群组信息,显示群组名称
|
||||||
|
if chat_stream.group_name:
|
||||||
|
return f"{chat_stream.group_name} ({chat_id})"
|
||||||
|
# 如果是私聊,显示用户昵称
|
||||||
|
elif chat_stream.user_nickname:
|
||||||
|
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
|
||||||
|
else:
|
||||||
|
return f"未知聊天 ({chat_id})"
|
||||||
|
except Exception:
|
||||||
|
return f"查询失败 ({chat_id})"
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_time_distribution(expressions) -> Dict[str, int]:
|
||||||
|
"""Calculate distribution of last active time in days"""
|
||||||
|
now = time.time()
|
||||||
|
distribution = {
|
||||||
|
'0-1天': 0,
|
||||||
|
'1-3天': 0,
|
||||||
|
'3-7天': 0,
|
||||||
|
'7-14天': 0,
|
||||||
|
'14-30天': 0,
|
||||||
|
'30-60天': 0,
|
||||||
|
'60-90天': 0,
|
||||||
|
'90+天': 0
|
||||||
|
}
|
||||||
|
for expr in expressions:
|
||||||
|
diff_days = (now - expr.last_active_time) / (24*3600)
|
||||||
|
if diff_days < 1:
|
||||||
|
distribution['0-1天'] += 1
|
||||||
|
elif diff_days < 3:
|
||||||
|
distribution['1-3天'] += 1
|
||||||
|
elif diff_days < 7:
|
||||||
|
distribution['3-7天'] += 1
|
||||||
|
elif diff_days < 14:
|
||||||
|
distribution['7-14天'] += 1
|
||||||
|
elif diff_days < 30:
|
||||||
|
distribution['14-30天'] += 1
|
||||||
|
elif diff_days < 60:
|
||||||
|
distribution['30-60天'] += 1
|
||||||
|
elif diff_days < 90:
|
||||||
|
distribution['60-90天'] += 1
|
||||||
|
else:
|
||||||
|
distribution['90+天'] += 1
|
||||||
|
return distribution
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_count_distribution(expressions) -> Dict[str, int]:
|
||||||
|
"""Calculate distribution of count values"""
|
||||||
|
distribution = {
|
||||||
|
'0-1': 0,
|
||||||
|
'1-2': 0,
|
||||||
|
'2-3': 0,
|
||||||
|
'3-4': 0,
|
||||||
|
'4-5': 0,
|
||||||
|
'5-10': 0,
|
||||||
|
'10+': 0
|
||||||
|
}
|
||||||
|
for expr in expressions:
|
||||||
|
cnt = expr.count
|
||||||
|
if cnt < 1:
|
||||||
|
distribution['0-1'] += 1
|
||||||
|
elif cnt < 2:
|
||||||
|
distribution['1-2'] += 1
|
||||||
|
elif cnt < 3:
|
||||||
|
distribution['2-3'] += 1
|
||||||
|
elif cnt < 4:
|
||||||
|
distribution['3-4'] += 1
|
||||||
|
elif cnt < 5:
|
||||||
|
distribution['4-5'] += 1
|
||||||
|
elif cnt < 10:
|
||||||
|
distribution['5-10'] += 1
|
||||||
|
else:
|
||||||
|
distribution['10+'] += 1
|
||||||
|
return distribution
|
||||||
|
|
||||||
|
|
||||||
|
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
|
||||||
|
"""Get top N most used expressions for a specific chat_id"""
|
||||||
|
return (Expression.select()
|
||||||
|
.where(Expression.chat_id == chat_id)
|
||||||
|
.order_by(Expression.count.desc())
|
||||||
|
.limit(top_n))
|
||||||
|
|
||||||
|
|
||||||
|
def show_overall_statistics(expressions, total: int) -> None:
|
||||||
|
"""Show overall statistics"""
|
||||||
|
time_dist = calculate_time_distribution(expressions)
|
||||||
|
count_dist = calculate_count_distribution(expressions)
|
||||||
|
|
||||||
|
print("\n=== 总体统计 ===")
|
||||||
|
print(f"总表达式数量: {total}")
|
||||||
|
|
||||||
|
print("\n上次激活时间分布:")
|
||||||
|
for period, count in time_dist.items():
|
||||||
|
print(f"{period}: {count} ({count/total*100:.2f}%)")
|
||||||
|
|
||||||
|
print("\ncount分布:")
|
||||||
|
for range_, count in count_dist.items():
|
||||||
|
print(f"{range_}: {count} ({count/total*100:.2f}%)")
|
||||||
|
|
||||||
|
|
||||||
|
def show_chat_statistics(chat_id: str, chat_name: str) -> None:
|
||||||
|
"""Show statistics for a specific chat"""
|
||||||
|
chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
|
||||||
|
chat_total = len(chat_exprs)
|
||||||
|
|
||||||
|
print(f"\n=== {chat_name} ===")
|
||||||
|
print(f"表达式数量: {chat_total}")
|
||||||
|
|
||||||
|
if chat_total == 0:
|
||||||
|
print("该聊天没有表达式数据")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Time distribution for this chat
|
||||||
|
time_dist = calculate_time_distribution(chat_exprs)
|
||||||
|
print("\n上次激活时间分布:")
|
||||||
|
for period, count in time_dist.items():
|
||||||
|
if count > 0:
|
||||||
|
print(f"{period}: {count} ({count/chat_total*100:.2f}%)")
|
||||||
|
|
||||||
|
# Count distribution for this chat
|
||||||
|
count_dist = calculate_count_distribution(chat_exprs)
|
||||||
|
print("\ncount分布:")
|
||||||
|
for range_, count in count_dist.items():
|
||||||
|
if count > 0:
|
||||||
|
print(f"{range_}: {count} ({count/chat_total*100:.2f}%)")
|
||||||
|
|
||||||
|
# Top expressions
|
||||||
|
print("\nTop 10使用最多的表达式:")
|
||||||
|
top_exprs = get_top_expressions_by_chat(chat_id, 10)
|
||||||
|
for i, expr in enumerate(top_exprs, 1):
|
||||||
|
print(f"{i}. [{expr.type}] Count: {expr.count}")
|
||||||
|
print(f" Situation: {expr.situation}")
|
||||||
|
print(f" Style: {expr.style}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def interactive_menu() -> None:
|
||||||
|
"""Interactive menu for expression statistics"""
|
||||||
|
# Get all expressions
|
||||||
|
expressions = list(Expression.select())
|
||||||
|
if not expressions:
|
||||||
|
print("数据库中没有找到表达式")
|
||||||
|
return
|
||||||
|
|
||||||
|
total = len(expressions)
|
||||||
|
|
||||||
|
# Get unique chat_ids and their names
|
||||||
|
chat_ids = list(set(expr.chat_id for expr in expressions))
|
||||||
|
chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
|
||||||
|
chat_info.sort(key=lambda x: x[1]) # Sort by chat name
|
||||||
|
|
||||||
|
while True:
|
||||||
|
print("\n" + "="*50)
|
||||||
|
print("表达式统计分析")
|
||||||
|
print("="*50)
|
||||||
|
print("0. 显示总体统计")
|
||||||
|
|
||||||
|
for i, (chat_id, chat_name) in enumerate(chat_info, 1):
|
||||||
|
chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
|
||||||
|
print(f"{i}. {chat_name} ({chat_count}个表达式)")
|
||||||
|
|
||||||
|
print("q. 退出")
|
||||||
|
|
||||||
|
choice = input("\n请选择要查看的统计 (输入序号): ").strip()
|
||||||
|
|
||||||
|
if choice.lower() == 'q':
|
||||||
|
print("再见!")
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
choice_num = int(choice)
|
||||||
|
if choice_num == 0:
|
||||||
|
show_overall_statistics(expressions, total)
|
||||||
|
elif 1 <= choice_num <= len(chat_info):
|
||||||
|
chat_id, chat_name = chat_info[choice_num - 1]
|
||||||
|
show_chat_statistics(chat_id, chat_name)
|
||||||
|
else:
|
||||||
|
print("无效的选择,请重新输入")
|
||||||
|
except ValueError:
|
||||||
|
print("请输入有效的数字")
|
||||||
|
|
||||||
|
input("\n按回车键继续...")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
interactive_menu()
|
||||||
@@ -1,252 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import List, Dict, Tuple
|
|
||||||
import numpy as np
|
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
|
||||||
import glob
|
|
||||||
import sqlite3
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
import random
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
|
|
||||||
def clean_group_name(name: str) -> str:
|
|
||||||
"""清理群组名称,只保留中文和英文字符"""
|
|
||||||
cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
|
|
||||||
if not cleaned:
|
|
||||||
cleaned = datetime.now().strftime("%Y%m%d")
|
|
||||||
return cleaned
|
|
||||||
|
|
||||||
|
|
||||||
def get_group_name(stream_id: str) -> str:
|
|
||||||
"""从数据库中获取群组名称"""
|
|
||||||
conn = sqlite3.connect("data/maibot.db")
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
SELECT group_name, user_nickname, platform
|
|
||||||
FROM chat_streams
|
|
||||||
WHERE stream_id = ?
|
|
||||||
""",
|
|
||||||
(stream_id,),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = cursor.fetchone()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if result:
|
|
||||||
group_name, user_nickname, platform = result
|
|
||||||
if group_name:
|
|
||||||
return clean_group_name(group_name)
|
|
||||||
if user_nickname:
|
|
||||||
return clean_group_name(user_nickname)
|
|
||||||
if platform:
|
|
||||||
return clean_group_name(f"{platform}{stream_id[:8]}")
|
|
||||||
return stream_id
|
|
||||||
|
|
||||||
|
|
||||||
def load_expressions(chat_id: str) -> List[Dict]:
|
|
||||||
"""加载指定群聊的表达方式"""
|
|
||||||
style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
|
|
||||||
|
|
||||||
style_exprs = []
|
|
||||||
|
|
||||||
if os.path.exists(style_file):
|
|
||||||
with open(style_file, "r", encoding="utf-8") as f:
|
|
||||||
style_exprs = json.load(f)
|
|
||||||
|
|
||||||
# 如果表达方式超过10个,随机选择10个
|
|
||||||
if len(style_exprs) > 50:
|
|
||||||
style_exprs = random.sample(style_exprs, 50)
|
|
||||||
print(f"\n从 {len(style_exprs)} 个表达方式中随机选择了 10 个进行匹配")
|
|
||||||
|
|
||||||
return style_exprs
|
|
||||||
|
|
||||||
|
|
||||||
def find_similar_expressions_tfidf(
|
|
||||||
input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 10
|
|
||||||
) -> List[Tuple[str, str, float]]:
|
|
||||||
"""使用TF-IDF方法找出与输入文本最相似的top_k个表达方式"""
|
|
||||||
if not expressions:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 准备文本数据
|
|
||||||
if mode == "style":
|
|
||||||
texts = [expr["style"] for expr in expressions]
|
|
||||||
elif mode == "situation":
|
|
||||||
texts = [expr["situation"] for expr in expressions]
|
|
||||||
else: # both
|
|
||||||
texts = [f"{expr['situation']} {expr['style']}" for expr in expressions]
|
|
||||||
|
|
||||||
texts.append(input_text) # 添加输入文本
|
|
||||||
|
|
||||||
# 使用TF-IDF向量化
|
|
||||||
vectorizer = TfidfVectorizer()
|
|
||||||
tfidf_matrix = vectorizer.fit_transform(texts)
|
|
||||||
|
|
||||||
# 计算余弦相似度
|
|
||||||
similarity_matrix = cosine_similarity(tfidf_matrix)
|
|
||||||
|
|
||||||
# 获取输入文本的相似度分数(最后一行)
|
|
||||||
scores = similarity_matrix[-1][:-1] # 排除与自身的相似度
|
|
||||||
|
|
||||||
# 获取top_k的索引
|
|
||||||
top_indices = np.argsort(scores)[::-1][:top_k]
|
|
||||||
|
|
||||||
# 获取相似表达
|
|
||||||
similar_exprs = []
|
|
||||||
for idx in top_indices:
|
|
||||||
if scores[idx] > 0: # 只保留有相似度的
|
|
||||||
similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], scores[idx]))
|
|
||||||
|
|
||||||
return similar_exprs
|
|
||||||
|
|
||||||
|
|
||||||
async def find_similar_expressions_embedding(
|
|
||||||
input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 5
|
|
||||||
) -> List[Tuple[str, str, float]]:
|
|
||||||
"""使用嵌入模型找出与输入文本最相似的top_k个表达方式"""
|
|
||||||
if not expressions:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 准备文本数据
|
|
||||||
if mode == "style":
|
|
||||||
texts = [expr["style"] for expr in expressions]
|
|
||||||
elif mode == "situation":
|
|
||||||
texts = [expr["situation"] for expr in expressions]
|
|
||||||
else: # both
|
|
||||||
texts = [f"{expr['situation']} {expr['style']}" for expr in expressions]
|
|
||||||
|
|
||||||
# 获取嵌入向量
|
|
||||||
llm_request = LLMRequest(global_config.model.embedding)
|
|
||||||
text_embeddings = []
|
|
||||||
for text in texts:
|
|
||||||
embedding = await llm_request.get_embedding(text)
|
|
||||||
if embedding:
|
|
||||||
text_embeddings.append(embedding)
|
|
||||||
|
|
||||||
input_embedding = await llm_request.get_embedding(input_text)
|
|
||||||
if not input_embedding or not text_embeddings:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 计算余弦相似度
|
|
||||||
text_embeddings = np.array(text_embeddings)
|
|
||||||
similarities = np.dot(text_embeddings, input_embedding) / (
|
|
||||||
np.linalg.norm(text_embeddings, axis=1) * np.linalg.norm(input_embedding)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取top_k的索引
|
|
||||||
top_indices = np.argsort(similarities)[::-1][:top_k]
|
|
||||||
|
|
||||||
# 获取相似表达
|
|
||||||
similar_exprs = []
|
|
||||||
for idx in top_indices:
|
|
||||||
if similarities[idx] > 0: # 只保留有相似度的
|
|
||||||
similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], similarities[idx]))
|
|
||||||
|
|
||||||
return similar_exprs
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
# 获取所有群聊ID
|
|
||||||
style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
|
|
||||||
chat_ids = [os.path.basename(d) for d in style_dirs]
|
|
||||||
|
|
||||||
if not chat_ids:
|
|
||||||
print("没有找到任何群聊的表达方式数据")
|
|
||||||
return
|
|
||||||
|
|
||||||
print("可用的群聊:")
|
|
||||||
for i, chat_id in enumerate(chat_ids, 1):
|
|
||||||
group_name = get_group_name(chat_id)
|
|
||||||
print(f"{i}. {group_name}")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
choice = int(input("\n请选择要分析的群聊编号 (输入0退出): "))
|
|
||||||
if choice == 0:
|
|
||||||
break
|
|
||||||
if 1 <= choice <= len(chat_ids):
|
|
||||||
chat_id = chat_ids[choice - 1]
|
|
||||||
break
|
|
||||||
print("无效的选择,请重试")
|
|
||||||
except ValueError:
|
|
||||||
print("请输入有效的数字")
|
|
||||||
|
|
||||||
if choice == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 加载表达方式
|
|
||||||
style_exprs = load_expressions(chat_id)
|
|
||||||
|
|
||||||
group_name = get_group_name(chat_id)
|
|
||||||
print(f"\n已选择群聊:{group_name}")
|
|
||||||
|
|
||||||
# 选择匹配模式
|
|
||||||
print("\n请选择匹配模式:")
|
|
||||||
print("1. 匹配表达方式")
|
|
||||||
print("2. 匹配情景")
|
|
||||||
print("3. 两者都考虑")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
mode_choice = int(input("\n请选择匹配模式 (1-3): "))
|
|
||||||
if 1 <= mode_choice <= 3:
|
|
||||||
break
|
|
||||||
print("无效的选择,请重试")
|
|
||||||
except ValueError:
|
|
||||||
print("请输入有效的数字")
|
|
||||||
|
|
||||||
mode_map = {1: "style", 2: "situation", 3: "both"}
|
|
||||||
mode = mode_map[mode_choice]
|
|
||||||
|
|
||||||
# 选择匹配方法
|
|
||||||
print("\n请选择匹配方法:")
|
|
||||||
print("1. TF-IDF方法")
|
|
||||||
print("2. 嵌入模型方法")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
method_choice = int(input("\n请选择匹配方法 (1-2): "))
|
|
||||||
if 1 <= method_choice <= 2:
|
|
||||||
break
|
|
||||||
print("无效的选择,请重试")
|
|
||||||
except ValueError:
|
|
||||||
print("请输入有效的数字")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
input_text = input("\n请输入要匹配的文本(输入q退出): ")
|
|
||||||
if input_text.lower() == "q":
|
|
||||||
break
|
|
||||||
|
|
||||||
if not input_text.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
if method_choice == 1:
|
|
||||||
similar_exprs = find_similar_expressions_tfidf(input_text, style_exprs, mode)
|
|
||||||
else:
|
|
||||||
similar_exprs = await find_similar_expressions_embedding(input_text, style_exprs, mode)
|
|
||||||
|
|
||||||
if similar_exprs:
|
|
||||||
print("\n找到以下相似表达:")
|
|
||||||
for style, situation, score in similar_exprs:
|
|
||||||
print(f"\n\033[33m表达方式:{style}\033[0m")
|
|
||||||
print(f"\033[32m对应情景:{situation}\033[0m")
|
|
||||||
print(f"相似度:{score:.3f}")
|
|
||||||
print("-" * 20)
|
|
||||||
else:
|
|
||||||
print("\n没有找到相似的表达方式")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
287
scripts/interest_value_analysis.py
Normal file
287
scripts/interest_value_analysis.py
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
import time
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Tuple, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from src.common.database.database_model import Messages, ChatStreams
|
||||||
|
# Add project root to Python path
|
||||||
|
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_name(chat_id: str) -> str:
|
||||||
|
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
||||||
|
try:
|
||||||
|
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||||
|
if chat_stream is None:
|
||||||
|
return f"未知聊天 ({chat_id})"
|
||||||
|
|
||||||
|
if chat_stream.group_name:
|
||||||
|
return f"{chat_stream.group_name} ({chat_id})"
|
||||||
|
elif chat_stream.user_nickname:
|
||||||
|
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
|
||||||
|
else:
|
||||||
|
return f"未知聊天 ({chat_id})"
|
||||||
|
except Exception:
|
||||||
|
return f"查询失败 ({chat_id})"
|
||||||
|
|
||||||
|
|
||||||
|
def format_timestamp(timestamp: float) -> str:
|
||||||
|
"""Format timestamp to readable date string"""
|
||||||
|
try:
|
||||||
|
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
except (ValueError, OSError):
|
||||||
|
return "未知时间"
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_interest_value_distribution(messages) -> Dict[str, int]:
|
||||||
|
"""Calculate distribution of interest_value"""
|
||||||
|
distribution = {
|
||||||
|
'0.000-0.010': 0,
|
||||||
|
'0.010-0.050': 0,
|
||||||
|
'0.050-0.100': 0,
|
||||||
|
'0.100-0.500': 0,
|
||||||
|
'0.500-1.000': 0,
|
||||||
|
'1.000-2.000': 0,
|
||||||
|
'2.000-5.000': 0,
|
||||||
|
'5.000-10.000': 0,
|
||||||
|
'10.000+': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
if msg.interest_value is None or msg.interest_value == 0.0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
value = float(msg.interest_value)
|
||||||
|
if value < 0.010:
|
||||||
|
distribution['0.000-0.010'] += 1
|
||||||
|
elif value < 0.050:
|
||||||
|
distribution['0.010-0.050'] += 1
|
||||||
|
elif value < 0.100:
|
||||||
|
distribution['0.050-0.100'] += 1
|
||||||
|
elif value < 0.500:
|
||||||
|
distribution['0.100-0.500'] += 1
|
||||||
|
elif value < 1.000:
|
||||||
|
distribution['0.500-1.000'] += 1
|
||||||
|
elif value < 2.000:
|
||||||
|
distribution['1.000-2.000'] += 1
|
||||||
|
elif value < 5.000:
|
||||||
|
distribution['2.000-5.000'] += 1
|
||||||
|
elif value < 10.000:
|
||||||
|
distribution['5.000-10.000'] += 1
|
||||||
|
else:
|
||||||
|
distribution['10.000+'] += 1
|
||||||
|
|
||||||
|
return distribution
|
||||||
|
|
||||||
|
|
||||||
|
def get_interest_value_stats(messages) -> Dict[str, float]:
|
||||||
|
"""Calculate basic statistics for interest_value"""
|
||||||
|
values = [float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0]
|
||||||
|
|
||||||
|
if not values:
|
||||||
|
return {
|
||||||
|
'count': 0,
|
||||||
|
'min': 0,
|
||||||
|
'max': 0,
|
||||||
|
'avg': 0,
|
||||||
|
'median': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
values.sort()
|
||||||
|
count = len(values)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'count': count,
|
||||||
|
'min': min(values),
|
||||||
|
'max': max(values),
|
||||||
|
'avg': sum(values) / count,
|
||||||
|
'median': values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_chats() -> List[Tuple[str, str, int]]:
|
||||||
|
"""Get all available chats with message counts"""
|
||||||
|
try:
|
||||||
|
# 获取所有有消息的chat_id
|
||||||
|
chat_counts = {}
|
||||||
|
for msg in Messages.select(Messages.chat_id).distinct():
|
||||||
|
chat_id = msg.chat_id
|
||||||
|
count = Messages.select().where(
|
||||||
|
(Messages.chat_id == chat_id) &
|
||||||
|
(Messages.interest_value.is_null(False)) &
|
||||||
|
(Messages.interest_value != 0.0)
|
||||||
|
).count()
|
||||||
|
if count > 0:
|
||||||
|
chat_counts[chat_id] = count
|
||||||
|
|
||||||
|
# 获取聊天名称
|
||||||
|
result = []
|
||||||
|
for chat_id, count in chat_counts.items():
|
||||||
|
chat_name = get_chat_name(chat_id)
|
||||||
|
result.append((chat_id, chat_name, count))
|
||||||
|
|
||||||
|
# 按消息数量排序
|
||||||
|
result.sort(key=lambda x: x[2], reverse=True)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取聊天列表失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||||
|
"""Get time range input from user"""
|
||||||
|
print("\n时间范围选择:")
|
||||||
|
print("1. 最近1天")
|
||||||
|
print("2. 最近3天")
|
||||||
|
print("3. 最近7天")
|
||||||
|
print("4. 最近30天")
|
||||||
|
print("5. 自定义时间范围")
|
||||||
|
print("6. 不限制时间")
|
||||||
|
|
||||||
|
choice = input("请选择时间范围 (1-6): ").strip()
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
if choice == "1":
|
||||||
|
return now - 24*3600, now
|
||||||
|
elif choice == "2":
|
||||||
|
return now - 3*24*3600, now
|
||||||
|
elif choice == "3":
|
||||||
|
return now - 7*24*3600, now
|
||||||
|
elif choice == "4":
|
||||||
|
return now - 30*24*3600, now
|
||||||
|
elif choice == "5":
|
||||||
|
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||||
|
start_str = input().strip()
|
||||||
|
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||||
|
end_str = input().strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||||
|
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||||
|
return start_time, end_time
|
||||||
|
except ValueError:
|
||||||
|
print("时间格式错误,将不限制时间范围")
|
||||||
|
return None, None
|
||||||
|
else:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
|
||||||
|
"""Analyze interest values with optional filters"""
|
||||||
|
|
||||||
|
# 构建查询条件
|
||||||
|
query = Messages.select().where(
|
||||||
|
(Messages.interest_value.is_null(False)) &
|
||||||
|
(Messages.interest_value != 0.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
if chat_id:
|
||||||
|
query = query.where(Messages.chat_id == chat_id)
|
||||||
|
|
||||||
|
if start_time:
|
||||||
|
query = query.where(Messages.time >= start_time)
|
||||||
|
|
||||||
|
if end_time:
|
||||||
|
query = query.where(Messages.time <= end_time)
|
||||||
|
|
||||||
|
messages = list(query)
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
print("没有找到符合条件的消息")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 计算统计信息
|
||||||
|
distribution = calculate_interest_value_distribution(messages)
|
||||||
|
stats = get_interest_value_stats(messages)
|
||||||
|
|
||||||
|
# 显示结果
|
||||||
|
print("\n=== Interest Value 分析结果 ===")
|
||||||
|
if chat_id:
|
||||||
|
print(f"聊天: {get_chat_name(chat_id)}")
|
||||||
|
else:
|
||||||
|
print("聊天: 全部聊天")
|
||||||
|
|
||||||
|
if start_time and end_time:
|
||||||
|
print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}")
|
||||||
|
elif start_time:
|
||||||
|
print(f"时间范围: {format_timestamp(start_time)} 之后")
|
||||||
|
elif end_time:
|
||||||
|
print(f"时间范围: {format_timestamp(end_time)} 之前")
|
||||||
|
else:
|
||||||
|
print("时间范围: 不限制")
|
||||||
|
|
||||||
|
print("\n基本统计:")
|
||||||
|
print(f"有效消息数量: {stats['count']} (排除null和0值)")
|
||||||
|
print(f"最小值: {stats['min']:.3f}")
|
||||||
|
print(f"最大值: {stats['max']:.3f}")
|
||||||
|
print(f"平均值: {stats['avg']:.3f}")
|
||||||
|
print(f"中位数: {stats['median']:.3f}")
|
||||||
|
|
||||||
|
print("\nInterest Value 分布:")
|
||||||
|
total = stats['count']
|
||||||
|
for range_name, count in distribution.items():
|
||||||
|
if count > 0:
|
||||||
|
percentage = count / total * 100
|
||||||
|
print(f"{range_name}: {count} ({percentage:.2f}%)")
|
||||||
|
|
||||||
|
|
||||||
|
def interactive_menu() -> None:
|
||||||
|
"""Interactive menu for interest value analysis"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
print("\n" + "="*50)
|
||||||
|
print("Interest Value 分析工具")
|
||||||
|
print("="*50)
|
||||||
|
print("1. 分析全部聊天")
|
||||||
|
print("2. 选择特定聊天分析")
|
||||||
|
print("q. 退出")
|
||||||
|
|
||||||
|
choice = input("\n请选择分析模式 (1-2, q): ").strip()
|
||||||
|
|
||||||
|
if choice.lower() == 'q':
|
||||||
|
print("再见!")
|
||||||
|
break
|
||||||
|
|
||||||
|
chat_id = None
|
||||||
|
|
||||||
|
if choice == "2":
|
||||||
|
# 显示可用的聊天列表
|
||||||
|
chats = get_available_chats()
|
||||||
|
if not chats:
|
||||||
|
print("没有找到有interest_value数据的聊天")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"\n可用的聊天 (共{len(chats)}个):")
|
||||||
|
for i, (_cid, name, count) in enumerate(chats, 1):
|
||||||
|
print(f"{i}. {name} ({count}条有效消息)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
|
||||||
|
if 1 <= chat_choice <= len(chats):
|
||||||
|
chat_id = chats[chat_choice - 1][0]
|
||||||
|
else:
|
||||||
|
print("无效选择")
|
||||||
|
continue
|
||||||
|
except ValueError:
|
||||||
|
print("请输入有效数字")
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif choice != "1":
|
||||||
|
print("无效选择")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取时间范围
|
||||||
|
start_time, end_time = get_time_range_input()
|
||||||
|
|
||||||
|
# 执行分析
|
||||||
|
analyze_interest_values(chat_id, start_time, end_time)
|
||||||
|
|
||||||
|
input("\n按回车键继续...")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
interactive_menu()
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
|||||||
import tkinter as tk
|
import tkinter as tk
|
||||||
from tkinter import ttk, messagebox, filedialog
|
from tkinter import ttk, messagebox, filedialog, colorchooser
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import threading
|
import threading
|
||||||
@@ -206,6 +206,23 @@ class LogFormatter:
|
|||||||
parts.append(str(event))
|
parts.append(str(event))
|
||||||
tags.append("message")
|
tags.append("message")
|
||||||
|
|
||||||
|
# 处理其他字段
|
||||||
|
extras = []
|
||||||
|
for key, value in log_entry.items():
|
||||||
|
if key not in ("timestamp", "level", "logger_name", "event"):
|
||||||
|
if isinstance(value, (dict, list)):
|
||||||
|
try:
|
||||||
|
value_str = json.dumps(value, ensure_ascii=False, indent=None)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
value_str = str(value)
|
||||||
|
else:
|
||||||
|
value_str = str(value)
|
||||||
|
extras.append(f"{key}={value_str}")
|
||||||
|
|
||||||
|
if extras:
|
||||||
|
parts.append(" ".join(extras))
|
||||||
|
tags.append("extras")
|
||||||
|
|
||||||
return parts, tags
|
return parts, tags
|
||||||
|
|
||||||
def format_timestamp(self, timestamp):
|
def format_timestamp(self, timestamp):
|
||||||
@@ -287,6 +304,7 @@ class VirtualLogDisplay:
|
|||||||
self.text_widget.tag_configure("level", foreground="#808080")
|
self.text_widget.tag_configure("level", foreground="#808080")
|
||||||
self.text_widget.tag_configure("module", foreground="#808080")
|
self.text_widget.tag_configure("module", foreground="#808080")
|
||||||
self.text_widget.tag_configure("message", foreground="#ffffff")
|
self.text_widget.tag_configure("message", foreground="#ffffff")
|
||||||
|
self.text_widget.tag_configure("extras", foreground="#808080")
|
||||||
|
|
||||||
# 日志级别颜色标签
|
# 日志级别颜色标签
|
||||||
for level, color in self.formatter.level_colors.items():
|
for level, color in self.formatter.level_colors.items():
|
||||||
@@ -449,7 +467,7 @@ class LogViewer:
|
|||||||
self.load_config()
|
self.load_config()
|
||||||
|
|
||||||
# 初始化日志格式化器
|
# 初始化日志格式化器
|
||||||
self.formatter = LogFormatter(self.log_config, {}, {})
|
self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors)
|
||||||
|
|
||||||
# 初始化日志文件路径
|
# 初始化日志文件路径
|
||||||
self.current_log_file = Path("logs/app.log.jsonl")
|
self.current_log_file = Path("logs/app.log.jsonl")
|
||||||
@@ -467,6 +485,9 @@ class LogViewer:
|
|||||||
self.main_frame = ttk.Frame(root)
|
self.main_frame = ttk.Frame(root)
|
||||||
self.main_frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)
|
self.main_frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)
|
||||||
|
|
||||||
|
# 创建菜单栏
|
||||||
|
self.create_menu()
|
||||||
|
|
||||||
# 创建控制面板
|
# 创建控制面板
|
||||||
self.create_control_panel()
|
self.create_control_panel()
|
||||||
|
|
||||||
@@ -477,12 +498,30 @@ class LogViewer:
|
|||||||
# 模块名映射
|
# 模块名映射
|
||||||
self.module_name_mapping = {
|
self.module_name_mapping = {
|
||||||
"api": "API接口",
|
"api": "API接口",
|
||||||
|
"async_task_manager": "异步任务管理器",
|
||||||
|
"background_tasks": "后台任务",
|
||||||
|
"base_tool": "基础工具",
|
||||||
|
"chat_stream": "聊天流",
|
||||||
|
"component_registry": "组件注册器",
|
||||||
"config": "配置",
|
"config": "配置",
|
||||||
"chat": "聊天",
|
"database_model": "数据库模型",
|
||||||
"plugin": "插件",
|
"emoji": "表情",
|
||||||
|
"heartflow": "心流",
|
||||||
|
"local_storage": "本地存储",
|
||||||
|
"lpmm": "LPMM",
|
||||||
|
"maibot_statistic": "MaiBot统计",
|
||||||
|
"main_message": "主消息",
|
||||||
"main": "主程序",
|
"main": "主程序",
|
||||||
|
"memory": "内存",
|
||||||
|
"mood": "情绪",
|
||||||
|
"plugin_manager": "插件管理器",
|
||||||
|
"remote": "远程",
|
||||||
|
"willing": "意愿",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 加载自定义映射
|
||||||
|
self.load_module_mapping()
|
||||||
|
|
||||||
# 选中的模块集合
|
# 选中的模块集合
|
||||||
self.selected_modules = set()
|
self.selected_modules = set()
|
||||||
self.modules = set()
|
self.modules = set()
|
||||||
@@ -491,19 +530,35 @@ class LogViewer:
|
|||||||
self.level_combo.bind("<<ComboboxSelected>>", self.filter_logs)
|
self.level_combo.bind("<<ComboboxSelected>>", self.filter_logs)
|
||||||
self.search_var.trace("w", self.filter_logs)
|
self.search_var.trace("w", self.filter_logs)
|
||||||
|
|
||||||
|
# 绑定快捷键
|
||||||
|
self.root.bind("<Control-o>", lambda e: self.select_log_file())
|
||||||
|
self.root.bind("<F5>", lambda e: self.refresh_log_file())
|
||||||
|
self.root.bind("<Control-s>", lambda e: self.export_logs())
|
||||||
|
|
||||||
# 初始加载文件
|
# 初始加载文件
|
||||||
if self.current_log_file.exists():
|
if self.current_log_file.exists():
|
||||||
self.load_log_file_async()
|
self.load_log_file_async()
|
||||||
|
|
||||||
def load_config(self):
|
def load_config(self):
|
||||||
"""加载配置文件"""
|
"""加载配置文件"""
|
||||||
|
# 默认配置
|
||||||
self.default_config = {
|
self.default_config = {
|
||||||
"log": {"date_style": "m-d H:i:s", "log_level_style": "lite", "color_text": "full"},
|
"log": {"date_style": "m-d H:i:s", "log_level_style": "lite", "color_text": "full", "log_level": "INFO"},
|
||||||
|
"viewer": {
|
||||||
|
"theme": "dark",
|
||||||
|
"font_size": 10,
|
||||||
|
"max_lines": 1000,
|
||||||
|
"auto_scroll": True,
|
||||||
|
"show_milliseconds": False,
|
||||||
|
"window": {"width": 1200, "height": 800, "remember_position": True},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
self.log_config = self.default_config["log"].copy()
|
# 从bot_config.toml加载日志配置
|
||||||
|
|
||||||
config_path = Path("config/bot_config.toml")
|
config_path = Path("config/bot_config.toml")
|
||||||
|
self.log_config = self.default_config["log"].copy()
|
||||||
|
self.viewer_config = self.default_config["viewer"].copy()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
@@ -511,7 +566,377 @@ class LogViewer:
|
|||||||
if "log" in bot_config:
|
if "log" in bot_config:
|
||||||
self.log_config.update(bot_config["log"])
|
self.log_config.update(bot_config["log"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"加载配置失败: {e}")
|
print(f"加载bot配置失败: {e}")
|
||||||
|
|
||||||
|
# 从viewer配置文件加载查看器配置
|
||||||
|
viewer_config_path = Path("config/log_viewer_config.toml")
|
||||||
|
self.custom_module_colors = {}
|
||||||
|
self.custom_level_colors = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if viewer_config_path.exists():
|
||||||
|
with open(viewer_config_path, "r", encoding="utf-8") as f:
|
||||||
|
viewer_config = toml.load(f)
|
||||||
|
if "viewer" in viewer_config:
|
||||||
|
self.viewer_config.update(viewer_config["viewer"])
|
||||||
|
|
||||||
|
# 加载自定义模块颜色
|
||||||
|
if "module_colors" in viewer_config["viewer"]:
|
||||||
|
self.custom_module_colors = viewer_config["viewer"]["module_colors"]
|
||||||
|
|
||||||
|
# 加载自定义级别颜色
|
||||||
|
if "level_colors" in viewer_config["viewer"]:
|
||||||
|
self.custom_level_colors = viewer_config["viewer"]["level_colors"]
|
||||||
|
|
||||||
|
if "log" in viewer_config:
|
||||||
|
self.log_config.update(viewer_config["log"])
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载查看器配置失败: {e}")
|
||||||
|
|
||||||
|
# 应用窗口配置
|
||||||
|
window_config = self.viewer_config.get("window", {})
|
||||||
|
window_width = window_config.get("width", 1200)
|
||||||
|
window_height = window_config.get("height", 800)
|
||||||
|
self.root.geometry(f"{window_width}x{window_height}")
|
||||||
|
|
||||||
|
def save_viewer_config(self):
|
||||||
|
"""保存查看器配置"""
|
||||||
|
# 准备完整的配置数据
|
||||||
|
viewer_config_copy = self.viewer_config.copy()
|
||||||
|
|
||||||
|
# 保存自定义颜色(只保存与默认值不同的颜色)
|
||||||
|
if self.custom_module_colors:
|
||||||
|
viewer_config_copy["module_colors"] = self.custom_module_colors
|
||||||
|
if self.custom_level_colors:
|
||||||
|
viewer_config_copy["level_colors"] = self.custom_level_colors
|
||||||
|
|
||||||
|
config_data = {"log": self.log_config, "viewer": viewer_config_copy}
|
||||||
|
|
||||||
|
config_path = Path("config/log_viewer_config.toml")
|
||||||
|
config_path.parent.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
|
toml.dump(config_data, f)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"保存查看器配置失败: {e}")
|
||||||
|
|
||||||
|
def create_menu(self):
|
||||||
|
"""创建菜单栏"""
|
||||||
|
menubar = tk.Menu(self.root)
|
||||||
|
self.root.config(menu=menubar)
|
||||||
|
|
||||||
|
# 配置菜单
|
||||||
|
config_menu = tk.Menu(menubar, tearoff=0)
|
||||||
|
menubar.add_cascade(label="配置", menu=config_menu)
|
||||||
|
config_menu.add_command(label="日志格式设置", command=self.show_format_settings)
|
||||||
|
config_menu.add_command(label="颜色设置", command=self.show_color_settings)
|
||||||
|
config_menu.add_command(label="查看器设置", command=self.show_viewer_settings)
|
||||||
|
config_menu.add_separator()
|
||||||
|
config_menu.add_command(label="重新加载配置", command=self.reload_config)
|
||||||
|
|
||||||
|
# 文件菜单
|
||||||
|
file_menu = tk.Menu(menubar, tearoff=0)
|
||||||
|
menubar.add_cascade(label="文件", menu=file_menu)
|
||||||
|
file_menu.add_command(label="选择日志文件", command=self.select_log_file, accelerator="Ctrl+O")
|
||||||
|
file_menu.add_command(label="刷新当前文件", command=self.refresh_log_file, accelerator="F5")
|
||||||
|
file_menu.add_separator()
|
||||||
|
file_menu.add_command(label="导出当前日志", command=self.export_logs, accelerator="Ctrl+S")
|
||||||
|
|
||||||
|
# 工具菜单
|
||||||
|
tools_menu = tk.Menu(menubar, tearoff=0)
|
||||||
|
menubar.add_cascade(label="工具", menu=tools_menu)
|
||||||
|
tools_menu.add_command(label="清空日志显示", command=self.clear_log_display)
|
||||||
|
|
||||||
|
def show_format_settings(self):
|
||||||
|
"""显示格式设置窗口"""
|
||||||
|
format_window = tk.Toplevel(self.root)
|
||||||
|
format_window.title("日志格式设置")
|
||||||
|
format_window.geometry("400x300")
|
||||||
|
|
||||||
|
frame = ttk.Frame(format_window)
|
||||||
|
frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
|
||||||
|
|
||||||
|
# 日期格式
|
||||||
|
ttk.Label(frame, text="日期格式:").pack(anchor="w", pady=2)
|
||||||
|
date_style_var = tk.StringVar(value=self.log_config.get("date_style", "m-d H:i:s"))
|
||||||
|
date_entry = ttk.Entry(frame, textvariable=date_style_var, width=30)
|
||||||
|
date_entry.pack(anchor="w", pady=2)
|
||||||
|
ttk.Label(frame, text="格式说明: Y=年份, m=月份, d=日期, H=小时, i=分钟, s=秒", font=("", 8)).pack(
|
||||||
|
anchor="w", pady=2
|
||||||
|
)
|
||||||
|
|
||||||
|
# 日志级别样式
|
||||||
|
ttk.Label(frame, text="日志级别样式:").pack(anchor="w", pady=(10, 2))
|
||||||
|
level_style_var = tk.StringVar(value=self.log_config.get("log_level_style", "lite"))
|
||||||
|
level_frame = ttk.Frame(frame)
|
||||||
|
level_frame.pack(anchor="w", pady=2)
|
||||||
|
|
||||||
|
ttk.Radiobutton(level_frame, text="简洁(lite)", variable=level_style_var, value="lite").pack(
|
||||||
|
side="left", padx=(0, 10)
|
||||||
|
)
|
||||||
|
ttk.Radiobutton(level_frame, text="紧凑(compact)", variable=level_style_var, value="compact").pack(
|
||||||
|
side="left", padx=(0, 10)
|
||||||
|
)
|
||||||
|
ttk.Radiobutton(level_frame, text="完整(full)", variable=level_style_var, value="full").pack(
|
||||||
|
side="left", padx=(0, 10)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 颜色文本设置
|
||||||
|
ttk.Label(frame, text="文本颜色设置:").pack(anchor="w", pady=(10, 2))
|
||||||
|
color_text_var = tk.StringVar(value=self.log_config.get("color_text", "full"))
|
||||||
|
color_frame = ttk.Frame(frame)
|
||||||
|
color_frame.pack(anchor="w", pady=2)
|
||||||
|
|
||||||
|
ttk.Radiobutton(color_frame, text="无颜色(none)", variable=color_text_var, value="none").pack(
|
||||||
|
side="left", padx=(0, 10)
|
||||||
|
)
|
||||||
|
ttk.Radiobutton(color_frame, text="仅标题(title)", variable=color_text_var, value="title").pack(
|
||||||
|
side="left", padx=(0, 10)
|
||||||
|
)
|
||||||
|
ttk.Radiobutton(color_frame, text="全部(full)", variable=color_text_var, value="full").pack(
|
||||||
|
side="left", padx=(0, 10)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 按钮
|
||||||
|
button_frame = ttk.Frame(frame)
|
||||||
|
button_frame.pack(fill="x", pady=(20, 0))
|
||||||
|
|
||||||
|
def apply_format():
|
||||||
|
self.log_config["date_style"] = date_style_var.get()
|
||||||
|
self.log_config["log_level_style"] = level_style_var.get()
|
||||||
|
self.log_config["color_text"] = color_text_var.get()
|
||||||
|
|
||||||
|
# 重新初始化格式化器
|
||||||
|
self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors)
|
||||||
|
self.log_display.formatter = self.formatter
|
||||||
|
self.log_display.configure_text_tags()
|
||||||
|
|
||||||
|
# 保存配置
|
||||||
|
self.save_viewer_config()
|
||||||
|
|
||||||
|
# 重新过滤日志以应用新格式
|
||||||
|
self.filter_logs()
|
||||||
|
|
||||||
|
format_window.destroy()
|
||||||
|
|
||||||
|
ttk.Button(button_frame, text="应用", command=apply_format).pack(side="right", padx=(5, 0))
|
||||||
|
ttk.Button(button_frame, text="取消", command=format_window.destroy).pack(side="right")
|
||||||
|
|
||||||
|
def show_viewer_settings(self):
|
||||||
|
"""显示查看器设置窗口"""
|
||||||
|
viewer_window = tk.Toplevel(self.root)
|
||||||
|
viewer_window.title("查看器设置")
|
||||||
|
viewer_window.geometry("350x250")
|
||||||
|
|
||||||
|
frame = ttk.Frame(viewer_window)
|
||||||
|
frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
|
||||||
|
|
||||||
|
# 主题设置
|
||||||
|
ttk.Label(frame, text="主题:").pack(anchor="w", pady=2)
|
||||||
|
theme_var = tk.StringVar(value=self.viewer_config.get("theme", "dark"))
|
||||||
|
theme_frame = ttk.Frame(frame)
|
||||||
|
theme_frame.pack(anchor="w", pady=2)
|
||||||
|
ttk.Radiobutton(theme_frame, text="深色", variable=theme_var, value="dark").pack(side="left", padx=(0, 10))
|
||||||
|
ttk.Radiobutton(theme_frame, text="浅色", variable=theme_var, value="light").pack(side="left")
|
||||||
|
|
||||||
|
# 字体大小
|
||||||
|
ttk.Label(frame, text="字体大小:").pack(anchor="w", pady=(10, 2))
|
||||||
|
font_size_var = tk.IntVar(value=self.viewer_config.get("font_size", 10))
|
||||||
|
font_size_spin = ttk.Spinbox(frame, from_=8, to=20, textvariable=font_size_var, width=10)
|
||||||
|
font_size_spin.pack(anchor="w", pady=2)
|
||||||
|
|
||||||
|
# 最大行数
|
||||||
|
ttk.Label(frame, text="最大显示行数:").pack(anchor="w", pady=(10, 2))
|
||||||
|
max_lines_var = tk.IntVar(value=self.viewer_config.get("max_lines", 1000))
|
||||||
|
max_lines_spin = ttk.Spinbox(frame, from_=100, to=10000, increment=100, textvariable=max_lines_var, width=10)
|
||||||
|
max_lines_spin.pack(anchor="w", pady=2)
|
||||||
|
|
||||||
|
# 自动滚动
|
||||||
|
auto_scroll_var = tk.BooleanVar(value=self.viewer_config.get("auto_scroll", True))
|
||||||
|
ttk.Checkbutton(frame, text="自动滚动到底部", variable=auto_scroll_var).pack(anchor="w", pady=(10, 2))
|
||||||
|
|
||||||
|
# 按钮
|
||||||
|
button_frame = ttk.Frame(frame)
|
||||||
|
button_frame.pack(fill="x", pady=(20, 0))
|
||||||
|
|
||||||
|
def apply_viewer_settings():
|
||||||
|
self.viewer_config["theme"] = theme_var.get()
|
||||||
|
self.viewer_config["font_size"] = font_size_var.get()
|
||||||
|
self.viewer_config["max_lines"] = max_lines_var.get()
|
||||||
|
self.viewer_config["auto_scroll"] = auto_scroll_var.get()
|
||||||
|
|
||||||
|
# 应用主题
|
||||||
|
self.apply_theme()
|
||||||
|
|
||||||
|
# 保存配置
|
||||||
|
self.save_viewer_config()
|
||||||
|
|
||||||
|
viewer_window.destroy()
|
||||||
|
|
||||||
|
ttk.Button(button_frame, text="应用", command=apply_viewer_settings).pack(side="right", padx=(5, 0))
|
||||||
|
ttk.Button(button_frame, text="取消", command=viewer_window.destroy).pack(side="right")
|
||||||
|
|
||||||
|
def apply_theme(self):
|
||||||
|
"""应用主题设置"""
|
||||||
|
theme = self.viewer_config.get("theme", "dark")
|
||||||
|
font_size = self.viewer_config.get("font_size", 10)
|
||||||
|
|
||||||
|
# 更新虚拟显示组件的主题
|
||||||
|
if theme == "dark":
|
||||||
|
bg_color = "#1e1e1e"
|
||||||
|
fg_color = "#ffffff"
|
||||||
|
select_bg = "#404040"
|
||||||
|
else:
|
||||||
|
bg_color = "#ffffff"
|
||||||
|
fg_color = "#000000"
|
||||||
|
select_bg = "#c0c0c0"
|
||||||
|
|
||||||
|
self.log_display.text_widget.config(
|
||||||
|
background=bg_color, foreground=fg_color, selectbackground=select_bg, font=("Consolas", font_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 重新配置标签样式
|
||||||
|
self.log_display.configure_text_tags()
|
||||||
|
|
||||||
|
def reload_config(self):
|
||||||
|
"""重新加载配置"""
|
||||||
|
self.load_config()
|
||||||
|
self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors)
|
||||||
|
self.log_display.formatter = self.formatter
|
||||||
|
self.log_display.configure_text_tags()
|
||||||
|
self.apply_theme()
|
||||||
|
self.filter_logs()
|
||||||
|
|
||||||
|
def clear_log_display(self):
|
||||||
|
"""清空日志显示"""
|
||||||
|
self.log_display.text_widget.delete(1.0, tk.END)
|
||||||
|
|
||||||
|
def export_logs(self):
|
||||||
|
"""导出当前显示的日志"""
|
||||||
|
filename = filedialog.asksaveasfilename(
|
||||||
|
defaultextension=".txt", filetypes=[("文本文件", "*.txt"), ("所有文件", "*.*")]
|
||||||
|
)
|
||||||
|
if filename:
|
||||||
|
try:
|
||||||
|
# 获取当前显示的所有日志条目
|
||||||
|
if self.log_index:
|
||||||
|
filtered_count = self.log_index.get_filtered_count()
|
||||||
|
log_lines = []
|
||||||
|
for i in range(filtered_count):
|
||||||
|
log_entry = self.log_index.get_entry_at_filtered_position(i)
|
||||||
|
if log_entry:
|
||||||
|
parts, tags = self.formatter.format_log_entry(log_entry)
|
||||||
|
line_text = " ".join(parts)
|
||||||
|
log_lines.append(line_text)
|
||||||
|
|
||||||
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
|
f.write("\n".join(log_lines))
|
||||||
|
messagebox.showinfo("导出成功", f"日志已导出到: {filename}")
|
||||||
|
else:
|
||||||
|
messagebox.showwarning("导出失败", "没有日志可导出")
|
||||||
|
except Exception as e:
|
||||||
|
messagebox.showerror("导出失败", f"导出日志时出错: {e}")
|
||||||
|
|
||||||
|
def load_module_mapping(self):
|
||||||
|
"""加载自定义模块映射"""
|
||||||
|
mapping_file = Path("config/module_mapping.json")
|
||||||
|
if mapping_file.exists():
|
||||||
|
try:
|
||||||
|
with open(mapping_file, "r", encoding="utf-8") as f:
|
||||||
|
custom_mapping = json.load(f)
|
||||||
|
self.module_name_mapping.update(custom_mapping)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载模块映射失败: {e}")
|
||||||
|
|
||||||
|
def save_module_mapping(self):
|
||||||
|
"""保存自定义模块映射"""
|
||||||
|
mapping_file = Path("config/module_mapping.json")
|
||||||
|
mapping_file.parent.mkdir(exist_ok=True)
|
||||||
|
try:
|
||||||
|
with open(mapping_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(self.module_name_mapping, f, ensure_ascii=False, indent=2)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"保存模块映射失败: {e}")
|
||||||
|
|
||||||
|
def show_color_settings(self):
|
||||||
|
"""显示颜色设置窗口"""
|
||||||
|
color_window = tk.Toplevel(self.root)
|
||||||
|
color_window.title("颜色设置")
|
||||||
|
color_window.geometry("300x400")
|
||||||
|
|
||||||
|
# 创建滚动框架
|
||||||
|
frame = ttk.Frame(color_window)
|
||||||
|
frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)
|
||||||
|
|
||||||
|
# 创建滚动条
|
||||||
|
scrollbar = ttk.Scrollbar(frame)
|
||||||
|
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
|
||||||
|
|
||||||
|
# 创建颜色设置列表
|
||||||
|
canvas = tk.Canvas(frame, yscrollcommand=scrollbar.set)
|
||||||
|
canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
|
||||||
|
scrollbar.config(command=canvas.yview)
|
||||||
|
|
||||||
|
# 创建内部框架
|
||||||
|
inner_frame = ttk.Frame(canvas)
|
||||||
|
canvas.create_window((0, 0), window=inner_frame, anchor="nw")
|
||||||
|
|
||||||
|
# 添加日志级别颜色设置
|
||||||
|
ttk.Label(inner_frame, text="日志级别颜色", font=("", 10, "bold")).pack(anchor="w", padx=5, pady=5)
|
||||||
|
for level in ["info", "warning", "error"]:
|
||||||
|
frame = ttk.Frame(inner_frame)
|
||||||
|
frame.pack(fill=tk.X, padx=5, pady=2)
|
||||||
|
ttk.Label(frame, text=level).pack(side=tk.LEFT)
|
||||||
|
color_btn = ttk.Button(
|
||||||
|
frame, text="选择颜色", command=lambda level_name=level: self.choose_color(level_name)
|
||||||
|
)
|
||||||
|
color_btn.pack(side=tk.RIGHT)
|
||||||
|
# 显示当前颜色
|
||||||
|
color_label = ttk.Label(frame, text="■", foreground=self.formatter.level_colors[level])
|
||||||
|
color_label.pack(side=tk.RIGHT, padx=5)
|
||||||
|
|
||||||
|
# 添加模块颜色设置
|
||||||
|
ttk.Label(inner_frame, text="\n模块颜色", font=("", 10, "bold")).pack(anchor="w", padx=5, pady=5)
|
||||||
|
for module in sorted(self.modules):
|
||||||
|
frame = ttk.Frame(inner_frame)
|
||||||
|
frame.pack(fill=tk.X, padx=5, pady=2)
|
||||||
|
ttk.Label(frame, text=module).pack(side=tk.LEFT)
|
||||||
|
color_btn = ttk.Button(frame, text="选择颜色", command=lambda m=module: self.choose_module_color(m))
|
||||||
|
color_btn.pack(side=tk.RIGHT)
|
||||||
|
# 显示当前颜色
|
||||||
|
color = self.formatter.module_colors.get(module, "black")
|
||||||
|
color_label = ttk.Label(frame, text="■", foreground=color)
|
||||||
|
color_label.pack(side=tk.RIGHT, padx=5)
|
||||||
|
|
||||||
|
# 更新画布滚动区域
|
||||||
|
inner_frame.update_idletasks()
|
||||||
|
canvas.config(scrollregion=canvas.bbox("all"))
|
||||||
|
|
||||||
|
# 添加确定按钮
|
||||||
|
ttk.Button(color_window, text="确定", command=color_window.destroy).pack(pady=5)
|
||||||
|
|
||||||
|
def choose_color(self, level):
|
||||||
|
"""选择日志级别颜色"""
|
||||||
|
color = colorchooser.askcolor(color=self.formatter.level_colors[level])[1]
|
||||||
|
if color:
|
||||||
|
self.formatter.level_colors[level] = color
|
||||||
|
self.custom_level_colors[level] = color # 保存到自定义颜色
|
||||||
|
self.log_display.formatter = self.formatter
|
||||||
|
self.log_display.configure_text_tags()
|
||||||
|
self.save_viewer_config() # 自动保存配置
|
||||||
|
self.filter_logs()
|
||||||
|
|
||||||
|
def choose_module_color(self, module):
|
||||||
|
"""选择模块颜色"""
|
||||||
|
color = colorchooser.askcolor(color=self.formatter.module_colors.get(module, "black"))[1]
|
||||||
|
if color:
|
||||||
|
self.formatter.module_colors[module] = color
|
||||||
|
self.custom_module_colors[module] = color # 保存到自定义颜色
|
||||||
|
self.log_display.formatter = self.formatter
|
||||||
|
self.log_display.configure_text_tags()
|
||||||
|
self.save_viewer_config() # 自动保存配置
|
||||||
|
self.filter_logs()
|
||||||
|
|
||||||
def create_control_panel(self):
|
def create_control_panel(self):
|
||||||
"""创建控制面板"""
|
"""创建控制面板"""
|
||||||
@@ -549,30 +974,43 @@ class LogViewer:
|
|||||||
side=tk.LEFT, padx=2
|
side=tk.LEFT, padx=2
|
||||||
)
|
)
|
||||||
|
|
||||||
# 过滤控制框架
|
# 模块选择框架
|
||||||
filter_frame = ttk.Frame(self.control_frame)
|
self.module_frame = ttk.LabelFrame(self.control_frame, text="模块")
|
||||||
filter_frame.pack(fill=tk.X, padx=5)
|
self.module_frame.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=5)
|
||||||
|
|
||||||
|
# 创建模块选择滚动区域
|
||||||
|
self.module_canvas = tk.Canvas(self.module_frame, height=80)
|
||||||
|
self.module_canvas.pack(side=tk.LEFT, fill=tk.X, expand=True)
|
||||||
|
|
||||||
|
# 创建模块选择内部框架
|
||||||
|
self.module_inner_frame = ttk.Frame(self.module_canvas)
|
||||||
|
self.module_canvas.create_window((0, 0), window=self.module_inner_frame, anchor="nw")
|
||||||
|
|
||||||
|
# 创建右侧控制区域(级别和搜索)
|
||||||
|
self.right_control_frame = ttk.Frame(self.control_frame)
|
||||||
|
self.right_control_frame.pack(side=tk.RIGHT, padx=5)
|
||||||
|
|
||||||
|
# 映射编辑按钮
|
||||||
|
mapping_btn = ttk.Button(self.right_control_frame, text="模块映射", command=self.edit_module_mapping)
|
||||||
|
mapping_btn.pack(side=tk.TOP, fill=tk.X, pady=1)
|
||||||
|
|
||||||
# 日志级别选择
|
# 日志级别选择
|
||||||
ttk.Label(filter_frame, text="级别:").pack(side=tk.LEFT, padx=2)
|
level_frame = ttk.Frame(self.right_control_frame)
|
||||||
|
level_frame.pack(side=tk.TOP, fill=tk.X, pady=1)
|
||||||
|
ttk.Label(level_frame, text="级别:").pack(side=tk.LEFT, padx=2)
|
||||||
self.level_var = tk.StringVar(value="全部")
|
self.level_var = tk.StringVar(value="全部")
|
||||||
self.level_combo = ttk.Combobox(filter_frame, textvariable=self.level_var, width=8)
|
self.level_combo = ttk.Combobox(level_frame, textvariable=self.level_var, width=8)
|
||||||
self.level_combo["values"] = ["全部", "debug", "info", "warning", "error", "critical"]
|
self.level_combo["values"] = ["全部", "debug", "info", "warning", "error", "critical"]
|
||||||
self.level_combo.pack(side=tk.LEFT, padx=2)
|
self.level_combo.pack(side=tk.LEFT, padx=2)
|
||||||
|
|
||||||
# 搜索框
|
# 搜索框
|
||||||
ttk.Label(filter_frame, text="搜索:").pack(side=tk.LEFT, padx=(20, 2))
|
search_frame = ttk.Frame(self.right_control_frame)
|
||||||
|
search_frame.pack(side=tk.TOP, fill=tk.X, pady=1)
|
||||||
|
ttk.Label(search_frame, text="搜索:").pack(side=tk.LEFT, padx=2)
|
||||||
self.search_var = tk.StringVar()
|
self.search_var = tk.StringVar()
|
||||||
self.search_entry = ttk.Entry(filter_frame, textvariable=self.search_var, width=20)
|
self.search_entry = ttk.Entry(search_frame, textvariable=self.search_var, width=15)
|
||||||
self.search_entry.pack(side=tk.LEFT, padx=2)
|
self.search_entry.pack(side=tk.LEFT, padx=2)
|
||||||
|
|
||||||
# 模块选择
|
|
||||||
ttk.Label(filter_frame, text="模块:").pack(side=tk.LEFT, padx=(20, 2))
|
|
||||||
self.module_var = tk.StringVar(value="全部")
|
|
||||||
self.module_combo = ttk.Combobox(filter_frame, textvariable=self.module_var, width=15)
|
|
||||||
self.module_combo.pack(side=tk.LEFT, padx=2)
|
|
||||||
self.module_combo.bind("<<ComboboxSelected>>", self.on_module_selected)
|
|
||||||
|
|
||||||
def on_file_loaded(self, log_index, error):
|
def on_file_loaded(self, log_index, error):
|
||||||
"""文件加载完成回调"""
|
"""文件加载完成回调"""
|
||||||
self.progress_bar.pack_forget()
|
self.progress_bar.pack_forget()
|
||||||
@@ -590,6 +1028,7 @@ class LogViewer:
|
|||||||
self.status_var.set(f"已加载 {log_index.total_entries} 条日志")
|
self.status_var.set(f"已加载 {log_index.total_entries} 条日志")
|
||||||
|
|
||||||
# 更新模块列表
|
# 更新模块列表
|
||||||
|
self.modules = set(log_index.module_index.keys())
|
||||||
self.update_module_list()
|
self.update_module_list()
|
||||||
|
|
||||||
# 应用过滤并显示
|
# 应用过滤并显示
|
||||||
@@ -623,22 +1062,11 @@ class LogViewer:
|
|||||||
|
|
||||||
# 清空当前数据
|
# 清空当前数据
|
||||||
self.log_index = LogIndex()
|
self.log_index = LogIndex()
|
||||||
self.modules.clear()
|
|
||||||
self.selected_modules.clear()
|
self.selected_modules.clear()
|
||||||
self.module_var.set("全部")
|
|
||||||
|
|
||||||
# 开始异步加载
|
# 开始异步加载
|
||||||
self.async_loader.load_file_async(str(self.current_log_file), self.on_loading_progress)
|
self.async_loader.load_file_async(str(self.current_log_file), self.on_loading_progress)
|
||||||
|
|
||||||
def on_module_selected(self, event=None):
|
|
||||||
"""模块选择事件"""
|
|
||||||
module = self.module_var.get()
|
|
||||||
if module == "全部":
|
|
||||||
self.selected_modules = {"全部"}
|
|
||||||
else:
|
|
||||||
self.selected_modules = {module}
|
|
||||||
self.filter_logs()
|
|
||||||
|
|
||||||
def filter_logs(self, *args):
|
def filter_logs(self, *args):
|
||||||
"""过滤日志"""
|
"""过滤日志"""
|
||||||
if not self.log_index:
|
if not self.log_index:
|
||||||
@@ -743,7 +1171,7 @@ class LogViewer:
|
|||||||
def read_new_logs(self, from_position):
|
def read_new_logs(self, from_position):
|
||||||
"""读取新的日志条目并返回它们"""
|
"""读取新的日志条目并返回它们"""
|
||||||
new_entries = []
|
new_entries = []
|
||||||
new_modules_found = False
|
new_modules = set() # 收集新发现的模块
|
||||||
with open(self.current_log_file, "r", encoding="utf-8") as f:
|
with open(self.current_log_file, "r", encoding="utf-8") as f:
|
||||||
f.seek(from_position)
|
f.seek(from_position)
|
||||||
line_count = self.log_index.total_entries
|
line_count = self.log_index.total_entries
|
||||||
@@ -756,14 +1184,20 @@ class LogViewer:
|
|||||||
|
|
||||||
logger_name = log_entry.get("logger_name", "")
|
logger_name = log_entry.get("logger_name", "")
|
||||||
if logger_name and logger_name not in self.modules:
|
if logger_name and logger_name not in self.modules:
|
||||||
self.modules.add(logger_name)
|
new_modules.add(logger_name)
|
||||||
new_modules_found = True
|
|
||||||
|
|
||||||
line_count += 1
|
line_count += 1
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
continue
|
continue
|
||||||
if new_modules_found:
|
|
||||||
self.root.after(0, self.update_module_list)
|
# 如果发现了新模块,在主线程中更新模块集合
|
||||||
|
if new_modules:
|
||||||
|
def update_modules():
|
||||||
|
self.modules.update(new_modules)
|
||||||
|
self.update_module_list()
|
||||||
|
|
||||||
|
self.root.after(0, update_modules)
|
||||||
|
|
||||||
return new_entries
|
return new_entries
|
||||||
|
|
||||||
def append_new_logs(self, new_entries):
|
def append_new_logs(self, new_entries):
|
||||||
@@ -791,15 +1225,196 @@ class LogViewer:
|
|||||||
self.status_var.set(f"显示 {total_count} 条日志")
|
self.status_var.set(f"显示 {total_count} 条日志")
|
||||||
|
|
||||||
def update_module_list(self):
|
def update_module_list(self):
|
||||||
"""更新模块下拉列表"""
|
"""更新模块列表"""
|
||||||
current_selection = self.module_var.get()
|
# 清空现有选项
|
||||||
self.modules = set(self.log_index.module_index.keys())
|
for widget in self.module_inner_frame.winfo_children():
|
||||||
module_values = ["全部"] + sorted(list(self.modules))
|
widget.destroy()
|
||||||
self.module_combo["values"] = module_values
|
|
||||||
if current_selection in module_values:
|
# 计算总模块数(包括"全部")
|
||||||
self.module_var.set(current_selection)
|
total_modules = len(self.modules) + 1
|
||||||
|
max_cols = min(4, max(2, total_modules)) # 减少最大列数,避免超出边界
|
||||||
|
|
||||||
|
# 配置网格列权重,让每列平均分配空间
|
||||||
|
for i in range(max_cols):
|
||||||
|
self.module_inner_frame.grid_columnconfigure(i, weight=1, uniform="module_col")
|
||||||
|
|
||||||
|
# 创建一个多行布局
|
||||||
|
current_row = 0
|
||||||
|
current_col = 0
|
||||||
|
|
||||||
|
# 添加"全部"选项
|
||||||
|
all_frame = ttk.Frame(self.module_inner_frame)
|
||||||
|
all_frame.grid(row=current_row, column=current_col, padx=3, pady=2, sticky="ew")
|
||||||
|
|
||||||
|
all_var = tk.BooleanVar(value="全部" in self.selected_modules)
|
||||||
|
all_check = ttk.Checkbutton(
|
||||||
|
all_frame, text="全部", variable=all_var, command=lambda: self.toggle_module("全部", all_var)
|
||||||
|
)
|
||||||
|
all_check.pack(side=tk.LEFT)
|
||||||
|
|
||||||
|
# 使用颜色标签替代按钮
|
||||||
|
all_color = self.formatter.module_colors.get("全部", "black")
|
||||||
|
all_color_label = ttk.Label(all_frame, text="■", foreground=all_color, width=2, cursor="hand2")
|
||||||
|
all_color_label.pack(side=tk.LEFT, padx=2)
|
||||||
|
all_color_label.bind("<Button-1>", lambda e: self.choose_module_color("全部"))
|
||||||
|
|
||||||
|
current_col += 1
|
||||||
|
|
||||||
|
# 添加其他模块选项
|
||||||
|
for module in sorted(self.modules):
|
||||||
|
if current_col >= max_cols:
|
||||||
|
current_row += 1
|
||||||
|
current_col = 0
|
||||||
|
|
||||||
|
frame = ttk.Frame(self.module_inner_frame)
|
||||||
|
frame.grid(row=current_row, column=current_col, padx=3, pady=2, sticky="ew")
|
||||||
|
|
||||||
|
var = tk.BooleanVar(value=module in self.selected_modules)
|
||||||
|
|
||||||
|
# 使用中文映射名称显示
|
||||||
|
display_name = self.get_display_name(module)
|
||||||
|
if len(display_name) > 12:
|
||||||
|
display_name = display_name[:10] + "..."
|
||||||
|
|
||||||
|
check = ttk.Checkbutton(
|
||||||
|
frame, text=display_name, variable=var, command=lambda m=module, v=var: self.toggle_module(m, v)
|
||||||
|
)
|
||||||
|
check.pack(side=tk.LEFT)
|
||||||
|
|
||||||
|
# 添加工具提示显示完整名称和英文名
|
||||||
|
full_tooltip = f"{self.get_display_name(module)}"
|
||||||
|
if module != self.get_display_name(module):
|
||||||
|
full_tooltip += f"\n({module})"
|
||||||
|
self.create_tooltip(check, full_tooltip)
|
||||||
|
|
||||||
|
# 使用颜色标签替代按钮
|
||||||
|
color = self.formatter.module_colors.get(module, "black")
|
||||||
|
color_label = ttk.Label(frame, text="■", foreground=color, width=2, cursor="hand2")
|
||||||
|
color_label.pack(side=tk.LEFT, padx=2)
|
||||||
|
color_label.bind("<Button-1>", lambda e, m=module: self.choose_module_color(m))
|
||||||
|
|
||||||
|
current_col += 1
|
||||||
|
|
||||||
|
# 更新画布滚动区域
|
||||||
|
self.module_inner_frame.update_idletasks()
|
||||||
|
self.module_canvas.config(scrollregion=self.module_canvas.bbox("all"))
|
||||||
|
|
||||||
|
# 添加垂直滚动条
|
||||||
|
if not hasattr(self, "module_scrollbar"):
|
||||||
|
self.module_scrollbar = ttk.Scrollbar(
|
||||||
|
self.module_frame, orient=tk.VERTICAL, command=self.module_canvas.yview
|
||||||
|
)
|
||||||
|
self.module_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
|
||||||
|
self.module_canvas.config(yscrollcommand=self.module_scrollbar.set)
|
||||||
|
|
||||||
|
def create_tooltip(self, widget, text):
|
||||||
|
"""为控件创建工具提示"""
|
||||||
|
|
||||||
|
def on_enter(event):
|
||||||
|
tooltip = tk.Toplevel()
|
||||||
|
tooltip.wm_overrideredirect(True)
|
||||||
|
tooltip.wm_geometry(f"+{event.x_root + 10}+{event.y_root + 10}")
|
||||||
|
label = ttk.Label(tooltip, text=text, background="lightyellow", relief="solid", borderwidth=1)
|
||||||
|
label.pack()
|
||||||
|
widget.tooltip = tooltip
|
||||||
|
|
||||||
|
def on_leave(event):
|
||||||
|
if hasattr(widget, "tooltip"):
|
||||||
|
widget.tooltip.destroy()
|
||||||
|
del widget.tooltip
|
||||||
|
|
||||||
|
widget.bind("<Enter>", on_enter)
|
||||||
|
widget.bind("<Leave>", on_leave)
|
||||||
|
|
||||||
|
def toggle_module(self, module, var):
|
||||||
|
"""切换模块选择状态"""
|
||||||
|
if module == "全部":
|
||||||
|
if var.get():
|
||||||
|
self.selected_modules = {"全部"}
|
||||||
|
else:
|
||||||
|
self.selected_modules.clear()
|
||||||
else:
|
else:
|
||||||
self.module_var.set("全部")
|
if var.get():
|
||||||
|
self.selected_modules.add(module)
|
||||||
|
if "全部" in self.selected_modules:
|
||||||
|
self.selected_modules.remove("全部")
|
||||||
|
else:
|
||||||
|
self.selected_modules.discard(module)
|
||||||
|
|
||||||
|
self.filter_logs()
|
||||||
|
|
||||||
|
def get_display_name(self, module_name):
|
||||||
|
"""获取模块的显示名称"""
|
||||||
|
return self.module_name_mapping.get(module_name, module_name)
|
||||||
|
|
||||||
|
def edit_module_mapping(self):
|
||||||
|
"""编辑模块映射"""
|
||||||
|
mapping_window = tk.Toplevel(self.root)
|
||||||
|
mapping_window.title("编辑模块映射")
|
||||||
|
mapping_window.geometry("500x600")
|
||||||
|
|
||||||
|
# 创建滚动框架
|
||||||
|
frame = ttk.Frame(mapping_window)
|
||||||
|
frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)
|
||||||
|
|
||||||
|
# 创建滚动条
|
||||||
|
scrollbar = ttk.Scrollbar(frame)
|
||||||
|
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
|
||||||
|
|
||||||
|
# 创建映射编辑列表
|
||||||
|
canvas = tk.Canvas(frame, yscrollcommand=scrollbar.set)
|
||||||
|
canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
|
||||||
|
scrollbar.config(command=canvas.yview)
|
||||||
|
|
||||||
|
# 创建内部框架
|
||||||
|
inner_frame = ttk.Frame(canvas)
|
||||||
|
canvas.create_window((0, 0), window=inner_frame, anchor="nw")
|
||||||
|
|
||||||
|
# 添加标题
|
||||||
|
ttk.Label(inner_frame, text="模块映射编辑", font=("", 12, "bold")).pack(anchor="w", padx=5, pady=5)
|
||||||
|
ttk.Label(inner_frame, text="英文名 -> 中文名", font=("", 10)).pack(anchor="w", padx=5, pady=2)
|
||||||
|
|
||||||
|
# 映射编辑字典
|
||||||
|
mapping_vars = {}
|
||||||
|
|
||||||
|
# 添加现有模块的映射编辑
|
||||||
|
all_modules = sorted(self.modules)
|
||||||
|
for module in all_modules:
|
||||||
|
frame_row = ttk.Frame(inner_frame)
|
||||||
|
frame_row.pack(fill=tk.X, padx=5, pady=2)
|
||||||
|
|
||||||
|
ttk.Label(frame_row, text=module, width=20).pack(side=tk.LEFT, padx=5)
|
||||||
|
ttk.Label(frame_row, text="->").pack(side=tk.LEFT, padx=5)
|
||||||
|
|
||||||
|
var = tk.StringVar(value=self.module_name_mapping.get(module, module))
|
||||||
|
mapping_vars[module] = var
|
||||||
|
entry = ttk.Entry(frame_row, textvariable=var, width=25)
|
||||||
|
entry.pack(side=tk.LEFT, padx=5)
|
||||||
|
|
||||||
|
# 更新画布滚动区域
|
||||||
|
inner_frame.update_idletasks()
|
||||||
|
canvas.config(scrollregion=canvas.bbox("all"))
|
||||||
|
|
||||||
|
def save_mappings():
|
||||||
|
# 更新映射
|
||||||
|
for module, var in mapping_vars.items():
|
||||||
|
new_name = var.get().strip()
|
||||||
|
if new_name and new_name != module:
|
||||||
|
self.module_name_mapping[module] = new_name
|
||||||
|
elif module in self.module_name_mapping and not new_name:
|
||||||
|
del self.module_name_mapping[module]
|
||||||
|
|
||||||
|
# 保存到文件
|
||||||
|
self.save_module_mapping()
|
||||||
|
# 更新模块列表显示
|
||||||
|
self.update_module_list()
|
||||||
|
mapping_window.destroy()
|
||||||
|
|
||||||
|
# 添加按钮
|
||||||
|
button_frame = ttk.Frame(mapping_window)
|
||||||
|
button_frame.pack(fill=tk.X, padx=5, pady=5)
|
||||||
|
ttk.Button(button_frame, text="保存", command=save_mappings).pack(side=tk.RIGHT, padx=5)
|
||||||
|
ttk.Button(button_frame, text="取消", command=mapping_window.destroy).pack(side=tk.RIGHT, padx=5)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
@@ -1,278 +0,0 @@
|
|||||||
import tkinter as tk
|
|
||||||
from tkinter import ttk
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import networkx as nx
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
|
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
|
|
||||||
class ExpressionViewer:
|
|
||||||
def __init__(self, root):
|
|
||||||
self.root = root
|
|
||||||
self.root.title("表达方式预览器")
|
|
||||||
self.root.geometry("1200x800")
|
|
||||||
|
|
||||||
# 创建主框架
|
|
||||||
self.main_frame = ttk.Frame(root)
|
|
||||||
self.main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
|
|
||||||
|
|
||||||
# 创建左侧控制面板
|
|
||||||
self.control_frame = ttk.Frame(self.main_frame)
|
|
||||||
self.control_frame.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 10))
|
|
||||||
|
|
||||||
# 创建搜索框
|
|
||||||
self.search_frame = ttk.Frame(self.control_frame)
|
|
||||||
self.search_frame.pack(fill=tk.X, pady=(0, 10))
|
|
||||||
|
|
||||||
self.search_var = tk.StringVar()
|
|
||||||
self.search_var.trace("w", self.filter_expressions)
|
|
||||||
self.search_entry = ttk.Entry(self.search_frame, textvariable=self.search_var)
|
|
||||||
self.search_entry.pack(side=tk.LEFT, fill=tk.X, expand=True)
|
|
||||||
ttk.Label(self.search_frame, text="搜索:").pack(side=tk.LEFT, padx=(0, 5))
|
|
||||||
|
|
||||||
# 创建文件选择下拉框
|
|
||||||
self.file_var = tk.StringVar()
|
|
||||||
self.file_combo = ttk.Combobox(self.search_frame, textvariable=self.file_var)
|
|
||||||
self.file_combo.pack(side=tk.LEFT, padx=5)
|
|
||||||
self.file_combo.bind("<<ComboboxSelected>>", self.load_file)
|
|
||||||
|
|
||||||
# 创建排序选项
|
|
||||||
self.sort_frame = ttk.LabelFrame(self.control_frame, text="排序选项")
|
|
||||||
self.sort_frame.pack(fill=tk.X, pady=5)
|
|
||||||
|
|
||||||
self.sort_var = tk.StringVar(value="count")
|
|
||||||
ttk.Radiobutton(
|
|
||||||
self.sort_frame, text="按计数排序", variable=self.sort_var, value="count", command=self.apply_sort
|
|
||||||
).pack(anchor=tk.W)
|
|
||||||
ttk.Radiobutton(
|
|
||||||
self.sort_frame, text="按情境排序", variable=self.sort_var, value="situation", command=self.apply_sort
|
|
||||||
).pack(anchor=tk.W)
|
|
||||||
ttk.Radiobutton(
|
|
||||||
self.sort_frame, text="按风格排序", variable=self.sort_var, value="style", command=self.apply_sort
|
|
||||||
).pack(anchor=tk.W)
|
|
||||||
|
|
||||||
# 创建分群选项
|
|
||||||
self.group_frame = ttk.LabelFrame(self.control_frame, text="分群选项")
|
|
||||||
self.group_frame.pack(fill=tk.X, pady=5)
|
|
||||||
|
|
||||||
self.group_var = tk.StringVar(value="none")
|
|
||||||
ttk.Radiobutton(
|
|
||||||
self.group_frame, text="不分群", variable=self.group_var, value="none", command=self.apply_grouping
|
|
||||||
).pack(anchor=tk.W)
|
|
||||||
ttk.Radiobutton(
|
|
||||||
self.group_frame, text="按情境分群", variable=self.group_var, value="situation", command=self.apply_grouping
|
|
||||||
).pack(anchor=tk.W)
|
|
||||||
ttk.Radiobutton(
|
|
||||||
self.group_frame, text="按风格分群", variable=self.group_var, value="style", command=self.apply_grouping
|
|
||||||
).pack(anchor=tk.W)
|
|
||||||
|
|
||||||
# 创建相似度阈值滑块
|
|
||||||
self.similarity_frame = ttk.LabelFrame(self.control_frame, text="相似度设置")
|
|
||||||
self.similarity_frame.pack(fill=tk.X, pady=5)
|
|
||||||
|
|
||||||
self.similarity_var = tk.DoubleVar(value=0.5)
|
|
||||||
self.similarity_scale = ttk.Scale(
|
|
||||||
self.similarity_frame,
|
|
||||||
from_=0.0,
|
|
||||||
to=1.0,
|
|
||||||
variable=self.similarity_var,
|
|
||||||
orient=tk.HORIZONTAL,
|
|
||||||
command=self.update_similarity,
|
|
||||||
)
|
|
||||||
self.similarity_scale.pack(fill=tk.X, padx=5, pady=5)
|
|
||||||
ttk.Label(self.similarity_frame, text="相似度阈值: 0.5").pack()
|
|
||||||
|
|
||||||
# 创建显示选项
|
|
||||||
self.view_frame = ttk.LabelFrame(self.control_frame, text="显示选项")
|
|
||||||
self.view_frame.pack(fill=tk.X, pady=5)
|
|
||||||
|
|
||||||
self.show_graph_var = tk.BooleanVar(value=True)
|
|
||||||
ttk.Checkbutton(
|
|
||||||
self.view_frame, text="显示关系图", variable=self.show_graph_var, command=self.toggle_graph
|
|
||||||
).pack(anchor=tk.W)
|
|
||||||
|
|
||||||
# 创建右侧内容区域
|
|
||||||
self.content_frame = ttk.Frame(self.main_frame)
|
|
||||||
self.content_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
|
|
||||||
|
|
||||||
# 创建文本显示区域
|
|
||||||
self.text_area = tk.Text(self.content_frame, wrap=tk.WORD)
|
|
||||||
self.text_area.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
|
|
||||||
|
|
||||||
# 添加滚动条
|
|
||||||
scrollbar = ttk.Scrollbar(self.text_area, command=self.text_area.yview)
|
|
||||||
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
|
|
||||||
self.text_area.config(yscrollcommand=scrollbar.set)
|
|
||||||
|
|
||||||
# 创建图形显示区域
|
|
||||||
self.graph_frame = ttk.Frame(self.content_frame)
|
|
||||||
self.graph_frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
|
|
||||||
|
|
||||||
# 初始化数据
|
|
||||||
self.current_data = []
|
|
||||||
self.graph = nx.Graph()
|
|
||||||
self.canvas = None
|
|
||||||
|
|
||||||
# 加载文件列表
|
|
||||||
self.load_file_list()
|
|
||||||
|
|
||||||
def load_file_list(self):
|
|
||||||
expression_dir = Path("data/expression")
|
|
||||||
files = []
|
|
||||||
for root, _, filenames in os.walk(expression_dir):
|
|
||||||
for filename in filenames:
|
|
||||||
if filename.endswith(".json"):
|
|
||||||
rel_path = os.path.relpath(os.path.join(root, filename), expression_dir)
|
|
||||||
files.append(rel_path)
|
|
||||||
|
|
||||||
self.file_combo["values"] = files
|
|
||||||
if files:
|
|
||||||
self.file_combo.set(files[0])
|
|
||||||
self.load_file(None)
|
|
||||||
|
|
||||||
def load_file(self, event):
|
|
||||||
selected_file = self.file_var.get()
|
|
||||||
if not selected_file:
|
|
||||||
return
|
|
||||||
|
|
||||||
file_path = os.path.join("data/expression", selected_file)
|
|
||||||
try:
|
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
|
||||||
self.current_data = json.load(f)
|
|
||||||
|
|
||||||
self.apply_sort()
|
|
||||||
self.update_similarity()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.text_area.delete(1.0, tk.END)
|
|
||||||
self.text_area.insert(tk.END, f"加载文件时出错: {str(e)}")
|
|
||||||
|
|
||||||
def apply_sort(self):
|
|
||||||
if not self.current_data:
|
|
||||||
return
|
|
||||||
|
|
||||||
sort_key = self.sort_var.get()
|
|
||||||
reverse = sort_key == "count"
|
|
||||||
|
|
||||||
self.current_data.sort(key=lambda x: x.get(sort_key, ""), reverse=reverse)
|
|
||||||
self.apply_grouping()
|
|
||||||
|
|
||||||
def apply_grouping(self):
|
|
||||||
if not self.current_data:
|
|
||||||
return
|
|
||||||
|
|
||||||
group_key = self.group_var.get()
|
|
||||||
if group_key == "none":
|
|
||||||
self.display_data(self.current_data)
|
|
||||||
return
|
|
||||||
|
|
||||||
grouped_data = defaultdict(list)
|
|
||||||
for item in self.current_data:
|
|
||||||
key = item.get(group_key, "未分类")
|
|
||||||
grouped_data[key].append(item)
|
|
||||||
|
|
||||||
self.text_area.delete(1.0, tk.END)
|
|
||||||
for group, items in grouped_data.items():
|
|
||||||
self.text_area.insert(tk.END, f"\n=== {group} ===\n\n")
|
|
||||||
for item in items:
|
|
||||||
self.text_area.insert(tk.END, f"情境: {item.get('situation', 'N/A')}\n")
|
|
||||||
self.text_area.insert(tk.END, f"风格: {item.get('style', 'N/A')}\n")
|
|
||||||
self.text_area.insert(tk.END, f"计数: {item.get('count', 'N/A')}\n")
|
|
||||||
self.text_area.insert(tk.END, "-" * 50 + "\n")
|
|
||||||
|
|
||||||
def display_data(self, data):
|
|
||||||
self.text_area.delete(1.0, tk.END)
|
|
||||||
for item in data:
|
|
||||||
self.text_area.insert(tk.END, f"情境: {item.get('situation', 'N/A')}\n")
|
|
||||||
self.text_area.insert(tk.END, f"风格: {item.get('style', 'N/A')}\n")
|
|
||||||
self.text_area.insert(tk.END, f"计数: {item.get('count', 'N/A')}\n")
|
|
||||||
self.text_area.insert(tk.END, "-" * 50 + "\n")
|
|
||||||
|
|
||||||
def update_similarity(self, *args):
|
|
||||||
if not self.current_data:
|
|
||||||
return
|
|
||||||
|
|
||||||
threshold = self.similarity_var.get()
|
|
||||||
self.similarity_frame.winfo_children()[-1].config(text=f"相似度阈值: {threshold:.2f}")
|
|
||||||
|
|
||||||
# 计算相似度
|
|
||||||
texts = [f"{item['situation']} {item['style']}" for item in self.current_data]
|
|
||||||
vectorizer = TfidfVectorizer()
|
|
||||||
tfidf_matrix = vectorizer.fit_transform(texts)
|
|
||||||
similarity_matrix = cosine_similarity(tfidf_matrix)
|
|
||||||
|
|
||||||
# 创建图
|
|
||||||
self.graph.clear()
|
|
||||||
for i, item in enumerate(self.current_data):
|
|
||||||
self.graph.add_node(i, label=f"{item['situation']}\n{item['style']}")
|
|
||||||
|
|
||||||
# 添加边
|
|
||||||
for i in range(len(self.current_data)):
|
|
||||||
for j in range(i + 1, len(self.current_data)):
|
|
||||||
if similarity_matrix[i, j] > threshold:
|
|
||||||
self.graph.add_edge(i, j, weight=similarity_matrix[i, j])
|
|
||||||
|
|
||||||
if self.show_graph_var.get():
|
|
||||||
self.draw_graph()
|
|
||||||
|
|
||||||
def draw_graph(self):
|
|
||||||
if self.canvas:
|
|
||||||
self.canvas.get_tk_widget().destroy()
|
|
||||||
|
|
||||||
fig = plt.figure(figsize=(8, 6))
|
|
||||||
pos = nx.spring_layout(self.graph)
|
|
||||||
|
|
||||||
# 绘制节点
|
|
||||||
nx.draw_networkx_nodes(self.graph, pos, node_color="lightblue", node_size=1000, alpha=0.6)
|
|
||||||
|
|
||||||
# 绘制边
|
|
||||||
nx.draw_networkx_edges(self.graph, pos, alpha=0.4)
|
|
||||||
|
|
||||||
# 添加标签
|
|
||||||
labels = nx.get_node_attributes(self.graph, "label")
|
|
||||||
nx.draw_networkx_labels(self.graph, pos, labels, font_size=8)
|
|
||||||
|
|
||||||
plt.title("表达方式关系图")
|
|
||||||
plt.axis("off")
|
|
||||||
|
|
||||||
self.canvas = FigureCanvasTkAgg(fig, master=self.graph_frame)
|
|
||||||
self.canvas.draw()
|
|
||||||
self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
|
|
||||||
|
|
||||||
def toggle_graph(self):
|
|
||||||
if self.show_graph_var.get():
|
|
||||||
self.draw_graph()
|
|
||||||
else:
|
|
||||||
if self.canvas:
|
|
||||||
self.canvas.get_tk_widget().destroy()
|
|
||||||
self.canvas = None
|
|
||||||
|
|
||||||
def filter_expressions(self, *args):
|
|
||||||
search_text = self.search_var.get().lower()
|
|
||||||
if not search_text:
|
|
||||||
self.apply_sort()
|
|
||||||
return
|
|
||||||
|
|
||||||
filtered_data = []
|
|
||||||
for item in self.current_data:
|
|
||||||
situation = item.get("situation", "").lower()
|
|
||||||
style = item.get("style", "").lower()
|
|
||||||
if search_text in situation or search_text in style:
|
|
||||||
filtered_data.append(item)
|
|
||||||
|
|
||||||
self.display_data(filtered_data)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
root = tk.Tk()
|
|
||||||
# app = ExpressionViewer(root)
|
|
||||||
root.mainloop()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,185 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
HFC性能统计数据查看工具
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
# 添加项目根目录到Python路径
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
||||||
|
|
||||||
|
|
||||||
def format_time(seconds: float) -> str:
|
|
||||||
"""格式化时间显示"""
|
|
||||||
if seconds < 1:
|
|
||||||
return f"{seconds * 1000:.1f}毫秒"
|
|
||||||
else:
|
|
||||||
return f"{seconds:.3f}秒"
|
|
||||||
|
|
||||||
|
|
||||||
def display_chat_stats(chat_id: str, stats: Dict[str, Any]):
|
|
||||||
"""显示单个聊天的统计数据"""
|
|
||||||
print(f"\n=== Chat ID: {chat_id} ===")
|
|
||||||
print(f"版本: {stats.get('version', 'unknown')}")
|
|
||||||
print(f"最后更新: {stats['last_updated']}")
|
|
||||||
|
|
||||||
overall = stats["overall"]
|
|
||||||
print("\n📊 总体统计:")
|
|
||||||
print(f" 总记录数: {overall['total_records']}")
|
|
||||||
print(f" 平均总时间: {format_time(overall['avg_total_time'])}")
|
|
||||||
|
|
||||||
print("\n⏱️ 各步骤平均时间:")
|
|
||||||
for step, avg_time in overall["avg_step_times"].items():
|
|
||||||
print(f" {step}: {format_time(avg_time)}")
|
|
||||||
|
|
||||||
print("\n🎯 按动作类型统计:")
|
|
||||||
by_action = stats["by_action"]
|
|
||||||
|
|
||||||
# 按比例排序
|
|
||||||
sorted_actions = sorted(by_action.items(), key=lambda x: x[1]["percentage"], reverse=True)
|
|
||||||
|
|
||||||
for action, action_stats in sorted_actions:
|
|
||||||
print(f" 📌 {action}:")
|
|
||||||
print(f" 次数: {action_stats['count']} ({action_stats['percentage']:.1f}%)")
|
|
||||||
print(f" 平均总时间: {format_time(action_stats['avg_total_time'])}")
|
|
||||||
|
|
||||||
if action_stats["avg_step_times"]:
|
|
||||||
print(" 步骤时间:")
|
|
||||||
for step, step_time in action_stats["avg_step_times"].items():
|
|
||||||
print(f" {step}: {format_time(step_time)}")
|
|
||||||
|
|
||||||
|
|
||||||
def display_comparison(stats_data: Dict[str, Dict[str, Any]]):
|
|
||||||
"""显示多个聊天的对比数据"""
|
|
||||||
if len(stats_data) < 2:
|
|
||||||
return
|
|
||||||
|
|
||||||
print("\n=== 多聊天对比 ===")
|
|
||||||
|
|
||||||
# 创建对比表格
|
|
||||||
chat_ids = list(stats_data.keys())
|
|
||||||
|
|
||||||
print("\n📊 总体对比:")
|
|
||||||
print(f"{'Chat ID':<20} {'版本':<12} {'记录数':<8} {'平均时间':<12} {'最常见动作':<15}")
|
|
||||||
print("-" * 70)
|
|
||||||
|
|
||||||
for chat_id in chat_ids:
|
|
||||||
stats = stats_data[chat_id]
|
|
||||||
overall = stats["overall"]
|
|
||||||
|
|
||||||
# 找到最常见的动作
|
|
||||||
most_common_action = max(stats["by_action"].items(), key=lambda x: x[1]["count"])
|
|
||||||
most_common_name = most_common_action[0]
|
|
||||||
most_common_pct = most_common_action[1]["percentage"]
|
|
||||||
|
|
||||||
version = stats.get("version", "unknown")
|
|
||||||
print(
|
|
||||||
f"{chat_id:<20} {version:<12} {overall['total_records']:<8} {format_time(overall['avg_total_time']):<12} {most_common_name}({most_common_pct:.0f}%)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def view_session_logs(chat_id: str = None, latest: bool = False):
|
|
||||||
"""查看会话日志文件"""
|
|
||||||
log_dir = Path("log/hfc_loop")
|
|
||||||
if not log_dir.exists():
|
|
||||||
print("❌ 日志目录不存在")
|
|
||||||
return
|
|
||||||
|
|
||||||
if chat_id:
|
|
||||||
pattern = f"{chat_id}_*.json"
|
|
||||||
else:
|
|
||||||
pattern = "*.json"
|
|
||||||
|
|
||||||
log_files = list(log_dir.glob(pattern))
|
|
||||||
|
|
||||||
if not log_files:
|
|
||||||
print(f"❌ 没有找到匹配的日志文件: {pattern}")
|
|
||||||
return
|
|
||||||
|
|
||||||
if latest:
|
|
||||||
# 按文件修改时间排序,取最新的
|
|
||||||
log_files.sort(key=lambda f: f.stat().st_mtime, reverse=True)
|
|
||||||
log_files = log_files[:1]
|
|
||||||
|
|
||||||
for log_file in log_files:
|
|
||||||
print(f"\n=== 会话日志: {log_file.name} ===")
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(log_file, "r", encoding="utf-8") as f:
|
|
||||||
records = json.load(f)
|
|
||||||
|
|
||||||
if not records:
|
|
||||||
print(" 空文件")
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f" 记录数: {len(records)}")
|
|
||||||
print(f" 时间范围: {records[0]['timestamp']} ~ {records[-1]['timestamp']}")
|
|
||||||
|
|
||||||
# 统计动作分布
|
|
||||||
action_counts = {}
|
|
||||||
total_time = 0
|
|
||||||
|
|
||||||
for record in records:
|
|
||||||
action = record["action_type"]
|
|
||||||
action_counts[action] = action_counts.get(action, 0) + 1
|
|
||||||
total_time += record["total_time"]
|
|
||||||
|
|
||||||
print(f" 总耗时: {format_time(total_time)}")
|
|
||||||
print(f" 平均耗时: {format_time(total_time / len(records))}")
|
|
||||||
print(f" 动作分布: {dict(action_counts)}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ❌ 读取文件失败: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="HFC性能统计数据查看工具")
|
|
||||||
parser.add_argument("--chat-id", help="指定要查看的Chat ID")
|
|
||||||
parser.add_argument("--logs", action="store_true", help="查看会话日志文件")
|
|
||||||
parser.add_argument("--latest", action="store_true", help="只显示最新的日志文件")
|
|
||||||
parser.add_argument("--compare", action="store_true", help="显示多聊天对比")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.logs:
|
|
||||||
view_session_logs(args.chat_id, args.latest)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 读取统计数据
|
|
||||||
stats_file = Path("data/hfc/time.json")
|
|
||||||
if not stats_file.exists():
|
|
||||||
print("❌ 统计数据文件不存在,请先运行一些HFC循环以生成数据")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(stats_file, "r", encoding="utf-8") as f:
|
|
||||||
stats_data = json.load(f)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 读取统计数据失败: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not stats_data:
|
|
||||||
print("❌ 统计数据为空")
|
|
||||||
return
|
|
||||||
|
|
||||||
if args.chat_id:
|
|
||||||
if args.chat_id in stats_data:
|
|
||||||
display_chat_stats(args.chat_id, stats_data[args.chat_id])
|
|
||||||
else:
|
|
||||||
print(f"❌ 没有找到Chat ID '{args.chat_id}' 的数据")
|
|
||||||
print(f"可用的Chat ID: {list(stats_data.keys())}")
|
|
||||||
else:
|
|
||||||
# 显示所有聊天的统计数据
|
|
||||||
for chat_id, stats in stats_data.items():
|
|
||||||
display_chat_stats(chat_id, stats)
|
|
||||||
|
|
||||||
if args.compare:
|
|
||||||
display_comparison(stats_data)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -20,10 +20,9 @@ from src.person_info.person_info import get_person_info_manager
|
|||||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||||
from src.plugin_system.apis import generator_api, send_api, message_api
|
from src.plugin_system.apis import generator_api, send_api, message_api
|
||||||
from src.chat.willing.willing_manager import get_willing_manager
|
from src.chat.willing.willing_manager import get_willing_manager
|
||||||
from src.chat.mai_thinking.mai_think import mai_thinking_manager
|
from src.mais4u.mai_think import mai_thinking_manager
|
||||||
from maim_message.message_base import GroupInfo,UserInfo
|
from maim_message.message_base import GroupInfo
|
||||||
|
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||||
ENABLE_THINKING = False
|
|
||||||
|
|
||||||
ERROR_LOOP_INFO = {
|
ERROR_LOOP_INFO = {
|
||||||
"loop_plan_info": {
|
"loop_plan_info": {
|
||||||
@@ -237,12 +236,12 @@ class HeartFChatting:
|
|||||||
if if_think:
|
if if_think:
|
||||||
factor = max(global_config.chat.focus_value, 0.1)
|
factor = max(global_config.chat.focus_value, 0.1)
|
||||||
self.energy_value *= 1.1 / factor
|
self.energy_value *= 1.1 / factor
|
||||||
logger.info(f"{self.log_prefix} 麦麦进行了思考,能量值按倍数增加,当前能量值:{self.energy_value}")
|
logger.info(f"{self.log_prefix} 进行了思考,能量值按倍数增加,当前能量值:{self.energy_value:.1f}")
|
||||||
else:
|
else:
|
||||||
self.energy_value += 0.1 / global_config.chat.focus_value
|
self.energy_value += 0.1 / global_config.chat.focus_value
|
||||||
logger.info(f"{self.log_prefix} 麦麦没有进行思考,能量值线性增加,当前能量值:{self.energy_value}")
|
logger.debug(f"{self.log_prefix} 没有进行思考,能量值线性增加,当前能量值:{self.energy_value:.1f}")
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix} 当前能量值:{self.energy_value}")
|
logger.debug(f"{self.log_prefix} 当前能量值:{self.energy_value:.1f}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
@@ -257,31 +256,29 @@ class HeartFChatting:
|
|||||||
)
|
)
|
||||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||||
return f"{person_name}:{message_data.get('processed_plain_text')}"
|
return f"{person_name}:{message_data.get('processed_plain_text')}"
|
||||||
|
|
||||||
async def send_typing(self):
|
async def send_typing(self):
|
||||||
group_info = GroupInfo(platform = "amaidesu_default",group_id = 114514,group_name = "内心")
|
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
platform = "amaidesu_default",
|
platform="amaidesu_default",
|
||||||
user_info = None,
|
user_info=None,
|
||||||
group_info = group_info
|
group_info=group_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
await send_api.custom_to_stream(
|
await send_api.custom_to_stream(
|
||||||
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
||||||
)
|
)
|
||||||
|
|
||||||
async def stop_typing(self):
|
async def stop_typing(self):
|
||||||
group_info = GroupInfo(platform = "amaidesu_default",group_id = 114514,group_name = "内心")
|
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
platform = "amaidesu_default",
|
platform="amaidesu_default",
|
||||||
user_info = None,
|
user_info=None,
|
||||||
group_info = group_info
|
group_info=group_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
await send_api.custom_to_stream(
|
await send_api.custom_to_stream(
|
||||||
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
||||||
)
|
)
|
||||||
@@ -296,7 +293,8 @@ class HeartFChatting:
|
|||||||
|
|
||||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考[模式:{self.loop_mode}]")
|
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考[模式:{self.loop_mode}]")
|
||||||
|
|
||||||
await self.send_typing()
|
if ENABLE_S4U:
|
||||||
|
await self.send_typing()
|
||||||
|
|
||||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||||
loop_start_time = time.time()
|
loop_start_time = time.time()
|
||||||
@@ -366,13 +364,13 @@ class HeartFChatting:
|
|||||||
# 发送回复 (不再需要传入 chat)
|
# 发送回复 (不再需要传入 chat)
|
||||||
reply_text = await self._send_response(response_set, reply_to_str, loop_start_time,message_data)
|
reply_text = await self._send_response(response_set, reply_to_str, loop_start_time,message_data)
|
||||||
|
|
||||||
await self.stop_typing()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if ENABLE_THINKING:
|
|
||||||
|
if ENABLE_S4U:
|
||||||
|
await self.stop_typing()
|
||||||
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
|
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
|
||||||
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -504,10 +502,9 @@ class HeartFChatting:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
interested_rate = (message_data.get("interest_value") or 0.0) * self.willing_amplifier
|
interested_rate = (message_data.get("interest_value") or 0.0) * self.willing_amplifier
|
||||||
|
|
||||||
self.willing_manager.setup(message_data, self.chat_stream)
|
self.willing_manager.setup(message_data, self.chat_stream)
|
||||||
|
|
||||||
|
|
||||||
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id", ""))
|
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id", ""))
|
||||||
|
|
||||||
talk_frequency = -1.00
|
talk_frequency = -1.00
|
||||||
@@ -517,7 +514,7 @@ class HeartFChatting:
|
|||||||
if additional_config and "maimcore_reply_probability_gain" in additional_config:
|
if additional_config and "maimcore_reply_probability_gain" in additional_config:
|
||||||
reply_probability += additional_config["maimcore_reply_probability_gain"]
|
reply_probability += additional_config["maimcore_reply_probability_gain"]
|
||||||
reply_probability = min(max(reply_probability, 0), 1) # 确保概率在 0-1 之间
|
reply_probability = min(max(reply_probability, 0), 1) # 确保概率在 0-1 之间
|
||||||
|
|
||||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id)
|
talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id)
|
||||||
reply_probability = talk_frequency * reply_probability
|
reply_probability = talk_frequency * reply_probability
|
||||||
|
|
||||||
@@ -527,9 +524,9 @@ class HeartFChatting:
|
|||||||
|
|
||||||
# 打印消息信息
|
# 打印消息信息
|
||||||
mes_name = self.chat_stream.group_info.group_name if self.chat_stream.group_info else "私聊"
|
mes_name = self.chat_stream.group_info.group_name if self.chat_stream.group_info else "私聊"
|
||||||
|
|
||||||
# logger.info(f"[{mes_name}] 当前聊天频率: {talk_frequency:.2f},兴趣值: {interested_rate:.2f},回复概率: {reply_probability * 100:.1f}%")
|
# logger.info(f"[{mes_name}] 当前聊天频率: {talk_frequency:.2f},兴趣值: {interested_rate:.2f},回复概率: {reply_probability * 100:.1f}%")
|
||||||
|
|
||||||
if reply_probability > 0.05:
|
if reply_probability > 0.05:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{mes_name}]"
|
f"[{mes_name}]"
|
||||||
@@ -545,7 +542,6 @@ class HeartFChatting:
|
|||||||
# 意愿管理器:注销当前message信息 (无论是否回复,只要处理过就删除)
|
# 意愿管理器:注销当前message信息 (无论是否回复,只要处理过就删除)
|
||||||
self.willing_manager.delete(message_data.get("message_id", ""))
|
self.willing_manager.delete(message_data.get("message_id", ""))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def _generate_response(
|
async def _generate_response(
|
||||||
self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str
|
self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str
|
||||||
@@ -570,7 +566,7 @@ class HeartFChatting:
|
|||||||
logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}")
|
logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _send_response(self, reply_set, reply_to, thinking_start_time,message_data):
|
async def _send_response(self, reply_set, reply_to, thinking_start_time, message_data):
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
new_message_count = message_api.count_new_messages(
|
new_message_count = message_api.count_new_messages(
|
||||||
chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
|
chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
|
||||||
@@ -581,9 +577,14 @@ class HeartFChatting:
|
|||||||
|
|
||||||
need_reply = new_message_count >= random.randint(2, 4)
|
need_reply = new_message_count >= random.randint(2, 4)
|
||||||
|
|
||||||
logger.info(
|
if need_reply:
|
||||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}引用回复"
|
logger.info(
|
||||||
)
|
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,不使用引用回复"
|
||||||
|
)
|
||||||
|
|
||||||
reply_text = ""
|
reply_text = ""
|
||||||
first_replied = False
|
first_replied = False
|
||||||
@@ -592,13 +593,27 @@ class HeartFChatting:
|
|||||||
if not first_replied:
|
if not first_replied:
|
||||||
if need_reply:
|
if need_reply:
|
||||||
await send_api.text_to_stream(
|
await send_api.text_to_stream(
|
||||||
text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False
|
text=data,
|
||||||
|
stream_id=self.chat_stream.stream_id,
|
||||||
|
reply_to=reply_to,
|
||||||
|
reply_to_platform_id=reply_to_platform_id,
|
||||||
|
typing=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=False)
|
await send_api.text_to_stream(
|
||||||
|
text=data,
|
||||||
|
stream_id=self.chat_stream.stream_id,
|
||||||
|
reply_to_platform_id=reply_to_platform_id,
|
||||||
|
typing=False,
|
||||||
|
)
|
||||||
first_replied = True
|
first_replied = True
|
||||||
else:
|
else:
|
||||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=True)
|
await send_api.text_to_stream(
|
||||||
|
text=data,
|
||||||
|
stream_id=self.chat_stream.stream_id,
|
||||||
|
reply_to_platform_id=reply_to_platform_id,
|
||||||
|
typing=True,
|
||||||
|
)
|
||||||
reply_text += data
|
reply_text += data
|
||||||
|
|
||||||
return reply_text
|
return reply_text
|
||||||
|
|||||||
@@ -836,7 +836,7 @@ class EmojiManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]:
|
async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]:
|
||||||
"""获取表情包描述和情感列表
|
"""获取表情包描述和情感列表,优化复用已有描述
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_base64: 图片的base64编码
|
image_base64: 图片的base64编码
|
||||||
@@ -850,18 +850,35 @@ class EmojiManager:
|
|||||||
if isinstance(image_base64, str):
|
if isinstance(image_base64, str):
|
||||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||||
|
|
||||||
# 调用AI获取描述
|
# 尝试从Images表获取已有的详细描述(可能在收到表情包时已生成)
|
||||||
if image_format == "gif" or image_format == "GIF":
|
existing_description = None
|
||||||
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
|
try:
|
||||||
if not image_base64:
|
from src.common.database.database_model import Images
|
||||||
raise RuntimeError("GIF表情包转换失败")
|
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
if existing_image and existing_image.description:
|
||||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
|
existing_description = existing_image.description
|
||||||
|
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"查询已有描述时出错: {e}")
|
||||||
|
|
||||||
|
# 第一步:VLM视觉分析(如果没有已有描述才调用)
|
||||||
|
if existing_description:
|
||||||
|
description = existing_description
|
||||||
|
logger.info("[优化] 复用已有的详细描述,跳过VLM调用")
|
||||||
else:
|
else:
|
||||||
prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
logger.info("[VLM分析] 生成新的详细描述")
|
||||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
if image_format == "gif" or image_format == "GIF":
|
||||||
|
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
|
||||||
|
if not image_base64:
|
||||||
|
raise RuntimeError("GIF表情包转换失败")
|
||||||
|
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||||
|
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
|
||||||
|
else:
|
||||||
|
prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||||
|
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
||||||
|
|
||||||
# 审核表情包
|
# 审核表情包
|
||||||
if global_config.emoji.content_filtration:
|
if global_config.emoji.content_filtration:
|
||||||
@@ -877,7 +894,7 @@ class EmojiManager:
|
|||||||
if content == "否":
|
if content == "否":
|
||||||
return "", []
|
return "", []
|
||||||
|
|
||||||
# 分析情感含义
|
# 第二步:LLM情感分析 - 基于详细描述生成情感标签列表
|
||||||
emotion_prompt = f"""
|
emotion_prompt = f"""
|
||||||
请你识别这个表情包的含义和适用场景,给我简短的描述,每个描述不要超过15个字
|
请你识别这个表情包的含义和适用场景,给我简短的描述,每个描述不要超过15个字
|
||||||
这是一个基于这个表情包的描述:'{description}'
|
这是一个基于这个表情包的描述:'{description}'
|
||||||
@@ -889,12 +906,14 @@ class EmojiManager:
|
|||||||
# 处理情感列表
|
# 处理情感列表
|
||||||
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
|
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
|
||||||
|
|
||||||
# 根据情感标签数量随机选择喵~超过5个选3个,超过2个选2个
|
# 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个
|
||||||
if len(emotions) > 5:
|
if len(emotions) > 5:
|
||||||
emotions = random.sample(emotions, 3)
|
emotions = random.sample(emotions, 3)
|
||||||
elif len(emotions) > 2:
|
elif len(emotions) > 2:
|
||||||
emotions = random.sample(emotions, 2)
|
emotions = random.sample(emotions, 2)
|
||||||
|
|
||||||
|
logger.info(f"[注册分析] 详细描述: {description[:50]}... -> 情感标签: {emotions}")
|
||||||
|
|
||||||
return f"[表情包:{description}]", emotions
|
return f"[表情包:{description}]", emotions
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import time
|
|||||||
import random
|
import random
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from typing import List, Dict, Optional, Any, Tuple
|
from typing import List, Dict, Optional, Any, Tuple
|
||||||
|
|
||||||
@@ -21,6 +22,16 @@ DECAY_MIN = 0.01 # 最小衰减值
|
|||||||
logger = get_logger("expressor")
|
logger = get_logger("expressor")
|
||||||
|
|
||||||
|
|
||||||
|
def format_create_date(timestamp: float) -> str:
|
||||||
|
"""
|
||||||
|
将时间戳格式化为可读的日期字符串
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
except (ValueError, OSError):
|
||||||
|
return "未知时间"
|
||||||
|
|
||||||
|
|
||||||
def init_prompt() -> None:
|
def init_prompt() -> None:
|
||||||
learn_style_prompt = """
|
learn_style_prompt = """
|
||||||
{chat_str}
|
{chat_str}
|
||||||
@@ -76,35 +87,90 @@ class ExpressionLearner:
|
|||||||
request_type="expressor.learner",
|
request_type="expressor.learner",
|
||||||
)
|
)
|
||||||
self.llm_model = None
|
self.llm_model = None
|
||||||
|
self._ensure_expression_directories()
|
||||||
self._auto_migrate_json_to_db()
|
self._auto_migrate_json_to_db()
|
||||||
|
self._migrate_old_data_create_date()
|
||||||
|
|
||||||
|
def _ensure_expression_directories(self):
|
||||||
|
"""
|
||||||
|
确保表达方式相关的目录结构存在
|
||||||
|
"""
|
||||||
|
base_dir = os.path.join("data", "expression")
|
||||||
|
directories_to_create = [
|
||||||
|
base_dir,
|
||||||
|
os.path.join(base_dir, "learnt_style"),
|
||||||
|
os.path.join(base_dir, "learnt_grammar"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for directory in directories_to_create:
|
||||||
|
try:
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
logger.debug(f"确保目录存在: {directory}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建目录失败 {directory}: {e}")
|
||||||
|
|
||||||
def _auto_migrate_json_to_db(self):
|
def _auto_migrate_json_to_db(self):
|
||||||
"""
|
"""
|
||||||
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
||||||
迁移完成后在/data/expression/done.done写入标记文件,存在则跳过。
|
迁移完成后在/data/expression/done.done写入标记文件,存在则跳过。
|
||||||
"""
|
"""
|
||||||
done_flag = os.path.join("data", "expression", "done.done")
|
base_dir = os.path.join("data", "expression")
|
||||||
|
done_flag = os.path.join(base_dir, "done.done")
|
||||||
|
|
||||||
|
# 确保基础目录存在
|
||||||
|
try:
|
||||||
|
os.makedirs(base_dir, exist_ok=True)
|
||||||
|
logger.debug(f"确保目录存在: {base_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建表达方式目录失败: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
if os.path.exists(done_flag):
|
if os.path.exists(done_flag):
|
||||||
logger.info("表达方式JSON已迁移,无需重复迁移。")
|
logger.info("表达方式JSON已迁移,无需重复迁移。")
|
||||||
return
|
return
|
||||||
base_dir = os.path.join("data", "expression")
|
|
||||||
|
logger.info("开始迁移表达方式JSON到数据库...")
|
||||||
|
migrated_count = 0
|
||||||
|
|
||||||
for type in ["learnt_style", "learnt_grammar"]:
|
for type in ["learnt_style", "learnt_grammar"]:
|
||||||
type_str = "style" if type == "learnt_style" else "grammar"
|
type_str = "style" if type == "learnt_style" else "grammar"
|
||||||
type_dir = os.path.join(base_dir, type)
|
type_dir = os.path.join(base_dir, type)
|
||||||
if not os.path.exists(type_dir):
|
if not os.path.exists(type_dir):
|
||||||
|
logger.debug(f"目录不存在,跳过: {type_dir}")
|
||||||
continue
|
continue
|
||||||
for chat_id in os.listdir(type_dir):
|
|
||||||
|
try:
|
||||||
|
chat_ids = os.listdir(type_dir)
|
||||||
|
logger.debug(f"在 {type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取目录失败 {type_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for chat_id in chat_ids:
|
||||||
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
|
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
|
||||||
if not os.path.exists(expr_file):
|
if not os.path.exists(expr_file):
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
with open(expr_file, "r", encoding="utf-8") as f:
|
with open(expr_file, "r", encoding="utf-8") as f:
|
||||||
expressions = json.load(f)
|
expressions = json.load(f)
|
||||||
|
|
||||||
|
if not isinstance(expressions, list):
|
||||||
|
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
|
||||||
|
continue
|
||||||
|
|
||||||
for expr in expressions:
|
for expr in expressions:
|
||||||
|
if not isinstance(expr, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
situation = expr.get("situation")
|
situation = expr.get("situation")
|
||||||
style_val = expr.get("style")
|
style_val = expr.get("style")
|
||||||
count = expr.get("count", 1)
|
count = expr.get("count", 1)
|
||||||
last_active_time = expr.get("last_active_time", time.time())
|
last_active_time = expr.get("last_active_time", time.time())
|
||||||
|
|
||||||
|
if not situation or not style_val:
|
||||||
|
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
|
||||||
|
continue
|
||||||
|
|
||||||
# 查重:同chat_id+type+situation+style
|
# 查重:同chat_id+type+situation+style
|
||||||
from src.common.database.database_model import Expression
|
from src.common.database.database_model import Expression
|
||||||
|
|
||||||
@@ -127,18 +193,54 @@ class ExpressionLearner:
|
|||||||
last_active_time=last_active_time,
|
last_active_time=last_active_time,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
type=type_str,
|
type=type_str,
|
||||||
|
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
|
||||||
)
|
)
|
||||||
logger.info(f"已迁移 {expr_file} 到数据库")
|
migrated_count += 1
|
||||||
|
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"JSON解析失败 {expr_file}: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
||||||
|
|
||||||
# 标记迁移完成
|
# 标记迁移完成
|
||||||
try:
|
try:
|
||||||
|
# 确保done.done文件的父目录存在
|
||||||
|
done_parent_dir = os.path.dirname(done_flag)
|
||||||
|
if not os.path.exists(done_parent_dir):
|
||||||
|
os.makedirs(done_parent_dir, exist_ok=True)
|
||||||
|
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
|
||||||
|
|
||||||
with open(done_flag, "w", encoding="utf-8") as f:
|
with open(done_flag, "w", encoding="utf-8") as f:
|
||||||
f.write("done\n")
|
f.write("done\n")
|
||||||
logger.info("表达方式JSON迁移已完成,已写入done.done标记文件")
|
logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件")
|
||||||
|
except PermissionError as e:
|
||||||
|
logger.error(f"权限不足,无法写入done.done标记文件: {e}")
|
||||||
|
except OSError as e:
|
||||||
|
logger.error(f"文件系统错误,无法写入done.done标记文件: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"写入done.done标记文件失败: {e}")
|
logger.error(f"写入done.done标记文件失败: {e}")
|
||||||
|
|
||||||
|
def _migrate_old_data_create_date(self):
|
||||||
|
"""
|
||||||
|
为没有create_date的老数据设置创建日期
|
||||||
|
使用last_active_time作为create_date的默认值
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 查找所有create_date为空的表达方式
|
||||||
|
old_expressions = Expression.select().where(Expression.create_date.is_null())
|
||||||
|
updated_count = 0
|
||||||
|
|
||||||
|
for expr in old_expressions:
|
||||||
|
# 使用last_active_time作为create_date
|
||||||
|
expr.create_date = expr.last_active_time
|
||||||
|
expr.save()
|
||||||
|
updated_count += 1
|
||||||
|
|
||||||
|
if updated_count > 0:
|
||||||
|
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"迁移老数据创建日期失败: {e}")
|
||||||
|
|
||||||
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||||
"""
|
"""
|
||||||
获取指定chat_id的style和grammar表达方式
|
获取指定chat_id的style和grammar表达方式
|
||||||
@@ -150,6 +252,8 @@ class ExpressionLearner:
|
|||||||
# 直接从数据库查询
|
# 直接从数据库查询
|
||||||
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||||
for expr in style_query:
|
for expr in style_query:
|
||||||
|
# 确保create_date存在,如果不存在则使用last_active_time
|
||||||
|
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||||
learnt_style_expressions.append(
|
learnt_style_expressions.append(
|
||||||
{
|
{
|
||||||
"situation": expr.situation,
|
"situation": expr.situation,
|
||||||
@@ -158,10 +262,13 @@ class ExpressionLearner:
|
|||||||
"last_active_time": expr.last_active_time,
|
"last_active_time": expr.last_active_time,
|
||||||
"source_id": chat_id,
|
"source_id": chat_id,
|
||||||
"type": "style",
|
"type": "style",
|
||||||
|
"create_date": create_date,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
|
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
|
||||||
for expr in grammar_query:
|
for expr in grammar_query:
|
||||||
|
# 确保create_date存在,如果不存在则使用last_active_time
|
||||||
|
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||||
learnt_grammar_expressions.append(
|
learnt_grammar_expressions.append(
|
||||||
{
|
{
|
||||||
"situation": expr.situation,
|
"situation": expr.situation,
|
||||||
@@ -170,10 +277,40 @@ class ExpressionLearner:
|
|||||||
"last_active_time": expr.last_active_time,
|
"last_active_time": expr.last_active_time,
|
||||||
"source_id": chat_id,
|
"source_id": chat_id,
|
||||||
"type": "grammar",
|
"type": "grammar",
|
||||||
|
"create_date": create_date,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return learnt_style_expressions, learnt_grammar_expressions
|
return learnt_style_expressions, learnt_grammar_expressions
|
||||||
|
|
||||||
|
def get_expression_create_info(self, chat_id: str, limit: int = 10) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取指定chat_id的表达方式创建信息,按创建日期排序
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
expressions = (Expression.select()
|
||||||
|
.where(Expression.chat_id == chat_id)
|
||||||
|
.order_by(Expression.create_date.desc())
|
||||||
|
.limit(limit))
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for expr in expressions:
|
||||||
|
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||||
|
result.append({
|
||||||
|
"situation": expr.situation,
|
||||||
|
"style": expr.style,
|
||||||
|
"type": expr.type,
|
||||||
|
"count": expr.count,
|
||||||
|
"create_date": create_date,
|
||||||
|
"create_date_formatted": format_create_date(create_date),
|
||||||
|
"last_active_time": expr.last_active_time,
|
||||||
|
"last_active_formatted": format_create_date(expr.last_active_time),
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取表达方式创建信息失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
def is_similar(self, s1: str, s2: str) -> bool:
|
def is_similar(self, s1: str, s2: str) -> bool:
|
||||||
"""
|
"""
|
||||||
判断两个字符串是否相似(只考虑长度大于5且有80%以上重合,不考虑子串)
|
判断两个字符串是否相似(只考虑长度大于5且有80%以上重合,不考虑子串)
|
||||||
@@ -197,9 +334,17 @@ class ExpressionLearner:
|
|||||||
for type in ["style", "grammar"]:
|
for type in ["style", "grammar"]:
|
||||||
base_dir = os.path.join("data", "expression", f"learnt_{type}")
|
base_dir = os.path.join("data", "expression", f"learnt_{type}")
|
||||||
if not os.path.exists(base_dir):
|
if not os.path.exists(base_dir):
|
||||||
|
logger.debug(f"目录不存在,跳过衰减: {base_dir}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for chat_id in os.listdir(base_dir):
|
try:
|
||||||
|
chat_ids = os.listdir(base_dir)
|
||||||
|
logger.debug(f"在 {base_dir} 中找到 {len(chat_ids)} 个聊天ID目录进行衰减")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取目录失败 {base_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for chat_id in chat_ids:
|
||||||
file_path = os.path.join(base_dir, chat_id, "expressions.json")
|
file_path = os.path.join(base_dir, chat_id, "expressions.json")
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
continue
|
continue
|
||||||
@@ -208,14 +353,24 @@ class ExpressionLearner:
|
|||||||
with open(file_path, "r", encoding="utf-8") as f:
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
expressions = json.load(f)
|
expressions = json.load(f)
|
||||||
|
|
||||||
|
if not isinstance(expressions, list):
|
||||||
|
logger.warning(f"表达方式文件格式错误,跳过衰减: {file_path}")
|
||||||
|
continue
|
||||||
|
|
||||||
# 应用全局衰减
|
# 应用全局衰减
|
||||||
decayed_expressions = self.apply_decay_to_expressions(expressions, current_time)
|
decayed_expressions = self.apply_decay_to_expressions(expressions, current_time)
|
||||||
|
|
||||||
# 保存衰减后的结果
|
# 保存衰减后的结果
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(decayed_expressions, f, ensure_ascii=False, indent=2)
|
json.dump(decayed_expressions, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
logger.debug(f"已对 {file_path} 应用衰减,剩余 {len(decayed_expressions)} 个表达方式")
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"JSON解析失败,跳过衰减 {file_path}: {e}")
|
||||||
|
except PermissionError as e:
|
||||||
|
logger.error(f"权限不足,无法更新 {file_path}: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"全局衰减{type}表达方式失败: {e}")
|
logger.error(f"全局衰减{type}表达方式失败 {file_path}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
learnt_style: Optional[List[Tuple[str, str, str]]] = []
|
learnt_style: Optional[List[Tuple[str, str, str]]] = []
|
||||||
@@ -350,6 +505,7 @@ class ExpressionLearner:
|
|||||||
last_active_time=current_time,
|
last_active_time=current_time,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
type=type,
|
type=type,
|
||||||
|
create_date=current_time, # 手动设置创建日期
|
||||||
)
|
)
|
||||||
# 限制最大数量
|
# 限制最大数量
|
||||||
exprs = list(
|
exprs = list(
|
||||||
|
|||||||
@@ -132,7 +132,8 @@ class ExpressionSelector:
|
|||||||
"count": expr.count,
|
"count": expr.count,
|
||||||
"last_active_time": expr.last_active_time,
|
"last_active_time": expr.last_active_time,
|
||||||
"source_id": cid,
|
"source_id": cid,
|
||||||
"type": "style"
|
"type": "style",
|
||||||
|
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||||
} for expr in style_query
|
} for expr in style_query
|
||||||
])
|
])
|
||||||
grammar_exprs.extend([
|
grammar_exprs.extend([
|
||||||
@@ -142,7 +143,8 @@ class ExpressionSelector:
|
|||||||
"count": expr.count,
|
"count": expr.count,
|
||||||
"last_active_time": expr.last_active_time,
|
"last_active_time": expr.last_active_time,
|
||||||
"source_id": cid,
|
"source_id": cid,
|
||||||
"type": "grammar"
|
"type": "grammar",
|
||||||
|
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||||
} for expr in grammar_query
|
} for expr in grammar_query
|
||||||
])
|
])
|
||||||
style_num = int(total_num * style_percentage)
|
style_num = int(total_num * style_percentage)
|
||||||
|
|||||||
@@ -111,9 +111,9 @@ class HeartFCMessageReceiver:
|
|||||||
subheartflow: SubHeartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) # type: ignore
|
subheartflow: SubHeartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) # type: ignore
|
||||||
|
|
||||||
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
|
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
|
||||||
|
if global_config.mood.enable_mood:
|
||||||
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
|
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
|
||||||
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
||||||
|
|
||||||
# 3. 日志记录
|
# 3. 日志记录
|
||||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||||
|
|||||||
@@ -13,10 +13,9 @@ from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
|||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.plugin_system.core import component_registry, events_manager # 导入新插件系统
|
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
||||||
from src.plugin_system.base import BaseCommand, EventType
|
from src.plugin_system.base import BaseCommand, EventType
|
||||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
|
|
||||||
# 定义日志配置
|
# 定义日志配置
|
||||||
|
|
||||||
@@ -92,8 +91,19 @@ class ChatBot:
|
|||||||
# 使用新的组件注册中心查找命令
|
# 使用新的组件注册中心查找命令
|
||||||
command_result = component_registry.find_command_by_text(text)
|
command_result = component_registry.find_command_by_text(text)
|
||||||
if command_result:
|
if command_result:
|
||||||
|
command_class, matched_groups, command_info = command_result
|
||||||
|
plugin_name = command_info.plugin_name
|
||||||
|
command_name = command_info.name
|
||||||
|
if (
|
||||||
|
message.chat_stream
|
||||||
|
and message.chat_stream.stream_id
|
||||||
|
and command_name
|
||||||
|
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
|
||||||
|
):
|
||||||
|
logger.info("用户禁用的命令,跳过处理")
|
||||||
|
return False, None, True
|
||||||
|
|
||||||
message.is_command = True
|
message.is_command = True
|
||||||
command_class, matched_groups, intercept_message, plugin_name = command_result
|
|
||||||
|
|
||||||
# 获取插件配置
|
# 获取插件配置
|
||||||
plugin_config = component_registry.get_plugin_config(plugin_name)
|
plugin_config = component_registry.get_plugin_config(plugin_name)
|
||||||
@@ -104,7 +114,7 @@ class ChatBot:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 执行命令
|
# 执行命令
|
||||||
success, response = await command_instance.execute()
|
success, response, intercept_message = await command_instance.execute()
|
||||||
|
|
||||||
# 记录命令执行结果
|
# 记录命令执行结果
|
||||||
if success:
|
if success:
|
||||||
@@ -117,8 +127,6 @@ class ChatBot:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
||||||
import traceback
|
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -127,7 +135,7 @@ class ChatBot:
|
|||||||
logger.error(f"发送错误消息失败: {send_error}")
|
logger.error(f"发送错误消息失败: {send_error}")
|
||||||
|
|
||||||
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
||||||
return True, str(e), not intercept_message
|
return True, str(e), False # 出错时继续处理消息
|
||||||
|
|
||||||
# 没有找到命令,继续处理消息
|
# 没有找到命令,继续处理消息
|
||||||
return False, None, True
|
return False, None, True
|
||||||
@@ -135,13 +143,12 @@ class ChatBot:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理命令时出错: {e}")
|
logger.error(f"处理命令时出错: {e}")
|
||||||
return False, None, True # 出错时继续处理消息
|
return False, None, True # 出错时继续处理消息
|
||||||
|
|
||||||
async def hanle_notice_message(self, message: MessageRecv):
|
async def hanle_notice_message(self, message: MessageRecv):
|
||||||
if message.message_info.message_id == "notice":
|
if message.message_info.message_id == "notice":
|
||||||
logger.info("收到notice消息,暂时不支持处理")
|
logger.info("收到notice消息,暂时不支持处理")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def do_s4u(self, message_data: Dict[str, Any]):
|
async def do_s4u(self, message_data: Dict[str, Any]):
|
||||||
message = MessageRecvS4U(message_data)
|
message = MessageRecvS4U(message_data)
|
||||||
group_info = message.message_info.group_info
|
group_info = message.message_info.group_info
|
||||||
@@ -163,7 +170,6 @@ class ChatBot:
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||||
"""处理转化后的统一格式消息
|
"""处理转化后的统一格式消息
|
||||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
||||||
@@ -179,8 +185,6 @@ class ChatBot:
|
|||||||
- 性能计时
|
- 性能计时
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
|
||||||
# 确保所有任务已启动
|
# 确保所有任务已启动
|
||||||
await self._ensure_started()
|
await self._ensure_started()
|
||||||
|
|
||||||
@@ -201,11 +205,10 @@ class ChatBot:
|
|||||||
# print(message_data)
|
# print(message_data)
|
||||||
# logger.debug(str(message_data))
|
# logger.debug(str(message_data))
|
||||||
message = MessageRecv(message_data)
|
message = MessageRecv(message_data)
|
||||||
|
|
||||||
if await self.hanle_notice_message(message):
|
if await self.hanle_notice_message(message):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
group_info = message.message_info.group_info
|
group_info = message.message_info.group_info
|
||||||
user_info = message.message_info.user_info
|
user_info = message.message_info.user_info
|
||||||
if message.message_info.additional_config:
|
if message.message_info.additional_config:
|
||||||
@@ -214,9 +217,6 @@ class ChatBot:
|
|||||||
await MessageStorage.update_message(message)
|
await MessageStorage.update_message(message)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message):
|
|
||||||
return
|
|
||||||
|
|
||||||
get_chat_manager().register_message(message)
|
get_chat_manager().register_message(message)
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
@@ -229,11 +229,10 @@ class ChatBot:
|
|||||||
|
|
||||||
# 处理消息内容,生成纯文本
|
# 处理消息内容,生成纯文本
|
||||||
await message.process()
|
await message.process()
|
||||||
|
|
||||||
# if await self.check_ban_content(message):
|
# if await self.check_ban_content(message):
|
||||||
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
|
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
|
||||||
# return
|
# return
|
||||||
|
|
||||||
|
|
||||||
# 过滤检查
|
# 过滤检查
|
||||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||||
@@ -252,6 +251,9 @@ class ChatBot:
|
|||||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message):
|
||||||
|
return
|
||||||
|
|
||||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||||
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
|
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
|
||||||
|
|||||||
@@ -163,20 +163,25 @@ class ChatManager:
|
|||||||
"""注册消息到聊天流"""
|
"""注册消息到聊天流"""
|
||||||
stream_id = self._generate_stream_id(
|
stream_id = self._generate_stream_id(
|
||||||
message.message_info.platform, # type: ignore
|
message.message_info.platform, # type: ignore
|
||||||
message.message_info.user_info, # type: ignore
|
message.message_info.user_info,
|
||||||
message.message_info.group_info,
|
message.message_info.group_info,
|
||||||
)
|
)
|
||||||
self.last_messages[stream_id] = message
|
self.last_messages[stream_id] = message
|
||||||
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
def _generate_stream_id(
|
||||||
|
platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
|
||||||
|
) -> str:
|
||||||
"""生成聊天流唯一ID"""
|
"""生成聊天流唯一ID"""
|
||||||
|
if not user_info and not group_info:
|
||||||
|
raise ValueError("用户信息或群组信息必须提供")
|
||||||
|
|
||||||
if group_info:
|
if group_info:
|
||||||
# 组合关键信息
|
# 组合关键信息
|
||||||
components = [platform, str(group_info.group_id)]
|
components = [platform, str(group_info.group_id)]
|
||||||
else:
|
else:
|
||||||
components = [platform, str(user_info.user_id), "private"]
|
components = [platform, str(user_info.user_id), "private"] # type: ignore
|
||||||
|
|
||||||
# 使用MD5生成唯一ID
|
# 使用MD5生成唯一ID
|
||||||
key = "_".join(components)
|
key = "_".join(components)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Optional, Type
|
from typing import Dict, Optional, Type
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
from src.plugin_system.base.base_action import BaseAction
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -22,53 +22,14 @@ class ActionManager:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""初始化动作管理器"""
|
"""初始化动作管理器"""
|
||||||
# 所有注册的动作集合
|
|
||||||
self._registered_actions: Dict[str, ActionInfo] = {}
|
|
||||||
# 当前正在使用的动作集合,默认加载默认动作
|
# 当前正在使用的动作集合,默认加载默认动作
|
||||||
self._using_actions: Dict[str, ActionInfo] = {}
|
self._using_actions: Dict[str, ActionInfo] = {}
|
||||||
|
|
||||||
# 加载插件动作
|
|
||||||
self._load_plugin_actions()
|
|
||||||
|
|
||||||
# 初始化时将默认动作加载到使用中的动作
|
# 初始化时将默认动作加载到使用中的动作
|
||||||
self._using_actions = component_registry.get_default_actions()
|
self._using_actions = component_registry.get_default_actions()
|
||||||
|
|
||||||
def _load_plugin_actions(self) -> None:
|
# === 执行Action方法 ===
|
||||||
"""
|
|
||||||
加载所有插件系统中的动作
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 从新插件系统获取Action组件
|
|
||||||
self._load_plugin_system_actions()
|
|
||||||
logger.debug("从插件系统加载Action组件成功")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"加载插件动作失败: {e}")
|
|
||||||
|
|
||||||
def _load_plugin_system_actions(self) -> None:
|
|
||||||
"""从插件系统的component_registry加载Action组件"""
|
|
||||||
try:
|
|
||||||
# 获取所有Action组件
|
|
||||||
action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore
|
|
||||||
|
|
||||||
for action_name, action_info in action_components.items():
|
|
||||||
if action_name in self._registered_actions:
|
|
||||||
logger.debug(f"Action组件 {action_name} 已存在,跳过")
|
|
||||||
continue
|
|
||||||
|
|
||||||
self._registered_actions[action_name] = action_info
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"加载了 {len(action_components)} 个Action动作")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"从插件系统加载Action组件失败: {e}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
def create_action(
|
def create_action(
|
||||||
self,
|
self,
|
||||||
@@ -139,36 +100,11 @@ class ActionManager:
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_registered_actions(self) -> Dict[str, ActionInfo]:
|
|
||||||
"""获取所有已注册的动作集"""
|
|
||||||
return self._registered_actions.copy()
|
|
||||||
|
|
||||||
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
||||||
"""获取当前正在使用的动作集合"""
|
"""获取当前正在使用的动作集合"""
|
||||||
return self._using_actions.copy()
|
return self._using_actions.copy()
|
||||||
|
|
||||||
def add_action_to_using(self, action_name: str) -> bool:
|
# === Modify相关方法 ===
|
||||||
"""
|
|
||||||
添加已注册的动作到当前使用的动作集
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_name: 动作名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 添加是否成功
|
|
||||||
"""
|
|
||||||
if action_name not in self._registered_actions:
|
|
||||||
logger.warning(f"添加失败: 动作 {action_name} 未注册")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if action_name in self._using_actions:
|
|
||||||
logger.info(f"动作 {action_name} 已经在使用中")
|
|
||||||
return True
|
|
||||||
|
|
||||||
self._using_actions[action_name] = self._registered_actions[action_name]
|
|
||||||
logger.info(f"添加动作 {action_name} 到使用集")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def remove_action_from_using(self, action_name: str) -> bool:
|
def remove_action_from_using(self, action_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
从当前使用的动作集中移除指定动作
|
从当前使用的动作集中移除指定动作
|
||||||
@@ -187,79 +123,8 @@ class ActionManager:
|
|||||||
logger.debug(f"已从使用集中移除动作 {action_name}")
|
logger.debug(f"已从使用集中移除动作 {action_name}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
|
|
||||||
# """
|
|
||||||
# 添加新的动作到注册集
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# action_name: 动作名称
|
|
||||||
# description: 动作描述
|
|
||||||
# parameters: 动作参数定义,默认为空字典
|
|
||||||
# require: 动作依赖项,默认为空列表
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# bool: 添加是否成功
|
|
||||||
# """
|
|
||||||
# if action_name in self._registered_actions:
|
|
||||||
# return False
|
|
||||||
|
|
||||||
# if parameters is None:
|
|
||||||
# parameters = {}
|
|
||||||
# if require is None:
|
|
||||||
# require = []
|
|
||||||
|
|
||||||
# action_info = {"description": description, "parameters": parameters, "require": require}
|
|
||||||
|
|
||||||
# self._registered_actions[action_name] = action_info
|
|
||||||
# return True
|
|
||||||
|
|
||||||
def remove_action(self, action_name: str) -> bool:
|
|
||||||
"""从注册集移除指定动作"""
|
|
||||||
if action_name not in self._registered_actions:
|
|
||||||
return False
|
|
||||||
del self._registered_actions[action_name]
|
|
||||||
# 如果在使用集中也存在,一并移除
|
|
||||||
if action_name in self._using_actions:
|
|
||||||
del self._using_actions[action_name]
|
|
||||||
return True
|
|
||||||
|
|
||||||
def temporarily_remove_actions(self, actions_to_remove: List[str]) -> None:
|
|
||||||
"""临时移除使用集中的指定动作"""
|
|
||||||
for name in actions_to_remove:
|
|
||||||
self._using_actions.pop(name, None)
|
|
||||||
|
|
||||||
def restore_actions(self) -> None:
|
def restore_actions(self) -> None:
|
||||||
"""恢复到默认动作集"""
|
"""恢复到默认动作集"""
|
||||||
actions_to_restore = list(self._using_actions.keys())
|
actions_to_restore = list(self._using_actions.keys())
|
||||||
self._using_actions = component_registry.get_default_actions()
|
self._using_actions = component_registry.get_default_actions()
|
||||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||||
|
|
||||||
def add_system_action_if_needed(self, action_name: str) -> bool:
|
|
||||||
"""
|
|
||||||
根据需要添加系统动作到使用集
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_name: 动作名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否成功添加
|
|
||||||
"""
|
|
||||||
if action_name in self._registered_actions and action_name not in self._using_actions:
|
|
||||||
self._using_actions[action_name] = self._registered_actions[action_name]
|
|
||||||
logger.info(f"临时添加系统动作到使用集: {action_name}")
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_action(self, action_name: str) -> Optional[Type[BaseAction]]:
|
|
||||||
"""
|
|
||||||
获取指定动作的处理器类
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_name: 动作名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
return component_registry.get_component_class(action_name, ComponentType.ACTION) # type: ignore
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import random
|
|||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
from typing import List, Any, Dict, TYPE_CHECKING
|
from typing import List, Any, Dict, TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
@@ -11,6 +11,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageCo
|
|||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||||
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
||||||
|
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
@@ -47,7 +48,6 @@ class ActionModifier:
|
|||||||
|
|
||||||
async def modify_actions(
|
async def modify_actions(
|
||||||
self,
|
self,
|
||||||
history_loop=None,
|
|
||||||
message_content: str = "",
|
message_content: str = "",
|
||||||
): # sourcery skip: use-named-expression
|
): # sourcery skip: use-named-expression
|
||||||
"""
|
"""
|
||||||
@@ -61,8 +61,9 @@ class ActionModifier:
|
|||||||
"""
|
"""
|
||||||
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
||||||
|
|
||||||
removals_s1 = []
|
removals_s1: List[Tuple[str, str]] = []
|
||||||
removals_s2 = []
|
removals_s2: List[Tuple[str, str]] = []
|
||||||
|
removals_s3: List[Tuple[str, str]] = []
|
||||||
|
|
||||||
self.action_manager.restore_actions()
|
self.action_manager.restore_actions()
|
||||||
all_actions = self.action_manager.get_using_actions()
|
all_actions = self.action_manager.get_using_actions()
|
||||||
@@ -84,25 +85,28 @@ class ActionModifier:
|
|||||||
if message_content:
|
if message_content:
|
||||||
chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}"
|
chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}"
|
||||||
|
|
||||||
# === 第一阶段:传统观察处理 ===
|
# === 第一阶段:去除用户自行禁用的 ===
|
||||||
# if history_loop:
|
disabled_actions = global_announcement_manager.get_disabled_chat_actions(self.chat_id)
|
||||||
# removals_from_loop = await self.analyze_loop_actions(history_loop)
|
if disabled_actions:
|
||||||
# if removals_from_loop:
|
for disabled_action_name in disabled_actions:
|
||||||
# removals_s1.extend(removals_from_loop)
|
if disabled_action_name in all_actions:
|
||||||
|
removals_s1.append((disabled_action_name, "用户自行禁用"))
|
||||||
|
self.action_manager.remove_action_from_using(disabled_action_name)
|
||||||
|
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
||||||
|
|
||||||
# 检查动作的关联类型
|
# === 第二阶段:检查动作的关联类型 ===
|
||||||
chat_context = self.chat_stream.context
|
chat_context = self.chat_stream.context
|
||||||
type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context)
|
type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context)
|
||||||
|
|
||||||
if type_mismatched_actions:
|
if type_mismatched_actions:
|
||||||
removals_s1.extend(type_mismatched_actions)
|
removals_s2.extend(type_mismatched_actions)
|
||||||
|
|
||||||
# 应用第一阶段的移除
|
# 应用第二阶段的移除
|
||||||
for action_name, reason in removals_s1:
|
for action_name, reason in removals_s2:
|
||||||
self.action_manager.remove_action_from_using(action_name)
|
self.action_manager.remove_action_from_using(action_name)
|
||||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {action_name},原因: {reason}")
|
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
|
||||||
|
|
||||||
# === 第二阶段:激活类型判定 ===
|
# === 第三阶段:激活类型判定 ===
|
||||||
if chat_content is not None:
|
if chat_content is not None:
|
||||||
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||||
|
|
||||||
@@ -110,18 +114,18 @@ class ActionModifier:
|
|||||||
current_using_actions = self.action_manager.get_using_actions()
|
current_using_actions = self.action_manager.get_using_actions()
|
||||||
|
|
||||||
# 获取因激活类型判定而需要移除的动作
|
# 获取因激活类型判定而需要移除的动作
|
||||||
removals_s2 = await self._get_deactivated_actions_by_type(
|
removals_s3 = await self._get_deactivated_actions_by_type(
|
||||||
current_using_actions,
|
current_using_actions,
|
||||||
chat_content,
|
chat_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 应用第二阶段的移除
|
# 应用第三阶段的移除
|
||||||
for action_name, reason in removals_s2:
|
for action_name, reason in removals_s3:
|
||||||
self.action_manager.remove_action_from_using(action_name)
|
self.action_manager.remove_action_from_using(action_name)
|
||||||
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
|
logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
|
||||||
|
|
||||||
# === 统一日志记录 ===
|
# === 统一日志记录 ===
|
||||||
all_removals = removals_s1 + removals_s2
|
all_removals = removals_s1 + removals_s2 + removals_s3
|
||||||
removals_summary: str = ""
|
removals_summary: str = ""
|
||||||
if all_removals:
|
if all_removals:
|
||||||
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
|
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
|
||||||
@@ -131,7 +135,7 @@ class ActionModifier:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||||
type_mismatched_actions = []
|
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||||
for action_name, action_info in all_actions.items():
|
for action_name, action_info in all_actions.items():
|
||||||
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
||||||
associated_types_str = ", ".join(action_info.associated_types)
|
associated_types_str = ", ".join(action_info.associated_types)
|
||||||
@@ -318,7 +322,7 @@ class ActionModifier:
|
|||||||
action_name: str,
|
action_name: str,
|
||||||
action_info: ActionInfo,
|
action_info: ActionInfo,
|
||||||
chat_content: str = "",
|
chat_content: str = "",
|
||||||
) -> bool:
|
) -> bool: # sourcery skip: move-assign-in-block, use-named-expression
|
||||||
"""
|
"""
|
||||||
使用LLM判定是否应该激活某个action
|
使用LLM判定是否应该激活某个action
|
||||||
|
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ComponentType
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
logger = get_logger("planner")
|
logger = get_logger("planner")
|
||||||
|
|
||||||
@@ -99,7 +99,7 @@ class ActionPlanner:
|
|||||||
|
|
||||||
async def plan(
|
async def plan(
|
||||||
self, mode: ChatMode = ChatMode.FOCUS
|
self, mode: ChatMode = ChatMode.FOCUS
|
||||||
) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]: # sourcery skip: dict-comprehension
|
) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||||
"""
|
"""
|
||||||
@@ -113,16 +113,17 @@ class ActionPlanner:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
is_group_chat = True
|
is_group_chat = True
|
||||||
|
|
||||||
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||||
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
||||||
|
|
||||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||||
|
|
||||||
# 获取完整的动作信息
|
# 获取完整的动作信息
|
||||||
all_registered_actions = self.action_manager.get_registered_actions()
|
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
|
||||||
|
ComponentType.ACTION
|
||||||
for action_name in current_available_actions_dict.keys():
|
)
|
||||||
|
current_available_actions = {}
|
||||||
|
for action_name in current_available_actions_dict:
|
||||||
if action_name in all_registered_actions:
|
if action_name in all_registered_actions:
|
||||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||||
else:
|
else:
|
||||||
@@ -234,10 +235,13 @@ class ActionPlanner:
|
|||||||
"is_parallel": is_parallel,
|
"is_parallel": is_parallel,
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return (
|
||||||
"action_result": action_result,
|
{
|
||||||
"action_prompt": prompt,
|
"action_result": action_result,
|
||||||
}, target_message
|
"action_prompt": prompt,
|
||||||
|
},
|
||||||
|
target_message,
|
||||||
|
)
|
||||||
|
|
||||||
async def build_planner_prompt(
|
async def build_planner_prompt(
|
||||||
self,
|
self,
|
||||||
@@ -275,23 +279,29 @@ class ActionPlanner:
|
|||||||
self.last_obs_time_mark = time.time()
|
self.last_obs_time_mark = time.time()
|
||||||
|
|
||||||
if mode == ChatMode.FOCUS:
|
if mode == ChatMode.FOCUS:
|
||||||
|
mentioned_bonus = ""
|
||||||
|
if global_config.chat.mentioned_bot_inevitable_reply:
|
||||||
|
mentioned_bonus = "\n- 有人提到你"
|
||||||
|
if global_config.chat.at_bot_inevitable_reply:
|
||||||
|
mentioned_bonus = "\n- 有人提到你,或者at你"
|
||||||
|
|
||||||
|
|
||||||
by_what = "聊天内容"
|
by_what = "聊天内容"
|
||||||
target_prompt = '\n "target_message_id":"触发action的消息id"'
|
target_prompt = '\n "target_message_id":"触发action的消息id"'
|
||||||
no_action_block = """重要说明1:
|
no_action_block = f"""重要说明1:
|
||||||
- 'no_reply' 表示只进行不进行回复,等待合适的回复时机
|
- 'no_reply' 表示只进行不进行回复,等待合适的回复时机
|
||||||
- 当你刚刚发送了消息,没有人回复时,选择no_reply
|
- 当你刚刚发送了消息,没有人回复时,选择no_reply
|
||||||
- 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_reply
|
- 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_reply
|
||||||
|
|
||||||
动作:reply
|
动作:reply
|
||||||
动作描述:参与聊天回复,发送文本进行表达
|
动作描述:参与聊天回复,发送文本进行表达
|
||||||
- 你想要闲聊或者随便附和
|
- 你想要闲聊或者随便附和{mentioned_bonus}
|
||||||
- 有人提到你
|
|
||||||
- 如果你刚刚进行了回复,不要对同一个话题重复回应
|
- 如果你刚刚进行了回复,不要对同一个话题重复回应
|
||||||
{
|
{{
|
||||||
"action": "reply",
|
"action": "reply",
|
||||||
"target_message_id":"触发action的消息id",
|
"target_message_id":"触发action的消息id",
|
||||||
"reason":"回复的原因"
|
"reason":"回复的原因"
|
||||||
}
|
}}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import re
|
|||||||
|
|
||||||
from typing import List, Optional, Dict, Any, Tuple
|
from typing import List, Optional, Dict, Any, Tuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from src.chat.mai_thinking.mai_think import mai_thinking_manager
|
from src.mais4u.mai_think import mai_thinking_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.individuality.individuality import get_individuality
|
from src.individuality.individuality import get_individuality
|
||||||
@@ -30,9 +30,6 @@ from src.plugin_system.base.component_types import ActionInfo
|
|||||||
|
|
||||||
logger = get_logger("replyer")
|
logger = get_logger("replyer")
|
||||||
|
|
||||||
ENABLE_S2S_MODE = True
|
|
||||||
|
|
||||||
|
|
||||||
def init_prompt():
|
def init_prompt():
|
||||||
Prompt("你正在qq群里聊天,下面是群里在聊的内容:", "chat_target_group1")
|
Prompt("你正在qq群里聊天,下面是群里在聊的内容:", "chat_target_group1")
|
||||||
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
|
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
|
||||||
@@ -60,7 +57,6 @@ def init_prompt():
|
|||||||
现在请你读读之前的聊天记录,并给出回复
|
现在请你读读之前的聊天记录,并给出回复
|
||||||
{config_expression_style}。注意不要复读你说过的话
|
{config_expression_style}。注意不要复读你说过的话
|
||||||
{keywords_reaction_prompt}
|
{keywords_reaction_prompt}
|
||||||
请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。
|
|
||||||
{moderation_prompt}
|
{moderation_prompt}
|
||||||
不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容""",
|
不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容""",
|
||||||
"default_generator_prompt",
|
"default_generator_prompt",
|
||||||
@@ -78,6 +74,7 @@ def init_prompt():
|
|||||||
|
|
||||||
你正在{chat_target_2},{reply_target_block}
|
你正在{chat_target_2},{reply_target_block}
|
||||||
对这句话,你想表达,原句:{raw_reply},原因是:{reason}。你现在要思考怎么组织回复
|
对这句话,你想表达,原句:{raw_reply},原因是:{reason}。你现在要思考怎么组织回复
|
||||||
|
你现在的心情是:{mood_state}
|
||||||
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。请你修改你想表达的原句,符合你的表达风格和语言习惯
|
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。请你修改你想表达的原句,符合你的表达风格和语言习惯
|
||||||
{config_expression_style},你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。
|
{config_expression_style},你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。
|
||||||
{keywords_reaction_prompt}
|
{keywords_reaction_prompt}
|
||||||
@@ -98,29 +95,29 @@ def init_prompt():
|
|||||||
{relation_info_block}
|
{relation_info_block}
|
||||||
{extra_info_block}
|
{extra_info_block}
|
||||||
|
|
||||||
你是一个AI虚拟主播,正在直播QQ聊天,同时也在直播间回复弹幕,不过回复的时候不用过多提及这点
|
|
||||||
|
|
||||||
{identity}
|
{identity}
|
||||||
|
|
||||||
{action_descriptions}
|
{action_descriptions}
|
||||||
你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。你现在的心情是:{mood_state}
|
你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
|
||||||
|
|
||||||
{background_dialogue_prompt}
|
{background_dialogue_prompt}
|
||||||
--------------------------------
|
--------------------------------
|
||||||
{time_block}
|
{time_block}
|
||||||
这是你和{sender_name}的对话,你们正在交流中:
|
这是你和{sender_name}的对话,你们正在交流中:
|
||||||
|
|
||||||
{core_dialogue_prompt}
|
{core_dialogue_prompt}
|
||||||
|
|
||||||
{reply_target_block}
|
{reply_target_block}
|
||||||
对方最新发送的内容:{message_txt}
|
对方最新发送的内容:{message_txt}
|
||||||
回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。
|
你现在的心情是:{mood_state}
|
||||||
{config_expression_style}。注意不要复读你说过的话
|
{config_expression_style}
|
||||||
|
注意不要复读你说过的话
|
||||||
{keywords_reaction_prompt}
|
{keywords_reaction_prompt}
|
||||||
请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。
|
请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。
|
||||||
{moderation_prompt}
|
{moderation_prompt}
|
||||||
不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
|
不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出一条回复内容就好
|
||||||
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。
|
现在,你说:
|
||||||
你的发言:
|
|
||||||
""",
|
""",
|
||||||
"s4u_style_prompt",
|
"s4u_style_prompt",
|
||||||
)
|
)
|
||||||
@@ -133,7 +130,6 @@ class DefaultReplyer:
|
|||||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||||
request_type: str = "focus.replyer",
|
request_type: str = "focus.replyer",
|
||||||
):
|
):
|
||||||
self.log_prefix = "replyer"
|
|
||||||
self.request_type = request_type
|
self.request_type = request_type
|
||||||
|
|
||||||
if model_configs:
|
if model_configs:
|
||||||
@@ -197,7 +193,7 @@ class DefaultReplyer:
|
|||||||
}
|
}
|
||||||
for key, value in reply_data.items():
|
for key, value in reply_data.items():
|
||||||
if not value:
|
if not value:
|
||||||
logger.debug(f"{self.log_prefix} 回复数据跳过{key},生成回复时将忽略。")
|
logger.debug(f"回复数据跳过{key},生成回复时将忽略。")
|
||||||
|
|
||||||
# 3. 构建 Prompt
|
# 3. 构建 Prompt
|
||||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||||
@@ -218,7 +214,7 @@ class DefaultReplyer:
|
|||||||
# 加权随机选择一个模型配置
|
# 加权随机选择一个模型配置
|
||||||
selected_model_config = self._select_weighted_model_config()
|
selected_model_config = self._select_weighted_model_config()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix} 使用模型配置: {selected_model_config.get('name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})"
|
f"使用模型生成回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})"
|
||||||
)
|
)
|
||||||
|
|
||||||
express_model = LLMRequest(
|
express_model = LLMRequest(
|
||||||
@@ -227,9 +223,9 @@ class DefaultReplyer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
if global_config.debug.show_prompt:
|
||||||
logger.info(f"{self.log_prefix}\n{prompt}\n")
|
logger.info(f"\n{prompt}\n")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"{self.log_prefix}\n{prompt}\n")
|
logger.debug(f"\n{prompt}\n")
|
||||||
|
|
||||||
content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt)
|
content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt)
|
||||||
|
|
||||||
@@ -237,13 +233,13 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
except Exception as llm_e:
|
except Exception as llm_e:
|
||||||
# 精简报错信息
|
# 精简报错信息
|
||||||
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
|
logger.error(f"LLM 生成失败: {llm_e}")
|
||||||
return False, None, prompt # LLM 调用失败则无法生成回复
|
return False, None, prompt # LLM 调用失败则无法生成回复
|
||||||
|
|
||||||
return True, content, prompt
|
return True, content, prompt
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
|
logger.error(f"回复生成意外失败: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False, None, prompt
|
return False, None, prompt
|
||||||
|
|
||||||
@@ -274,7 +270,7 @@ class DefaultReplyer:
|
|||||||
reasoning_content = None
|
reasoning_content = None
|
||||||
model_name = "unknown_model"
|
model_name = "unknown_model"
|
||||||
if not prompt:
|
if not prompt:
|
||||||
logger.error(f"{self.log_prefix}Prompt 构建失败,无法生成回复。")
|
logger.error("Prompt 构建失败,无法生成回复。")
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -282,7 +278,7 @@ class DefaultReplyer:
|
|||||||
# 加权随机选择一个模型配置
|
# 加权随机选择一个模型配置
|
||||||
selected_model_config = self._select_weighted_model_config()
|
selected_model_config = self._select_weighted_model_config()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix} 使用模型配置进行重写: {selected_model_config.get('name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})"
|
f"使用模型重写回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})"
|
||||||
)
|
)
|
||||||
|
|
||||||
express_model = LLMRequest(
|
express_model = LLMRequest(
|
||||||
@@ -296,13 +292,13 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
except Exception as llm_e:
|
except Exception as llm_e:
|
||||||
# 精简报错信息
|
# 精简报错信息
|
||||||
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
|
logger.error(f"LLM 生成失败: {llm_e}")
|
||||||
return False, None # LLM 调用失败则无法生成回复
|
return False, None # LLM 调用失败则无法生成回复
|
||||||
|
|
||||||
return True, content
|
return True, content
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
|
logger.error(f"回复生成意外失败: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
@@ -322,7 +318,7 @@ class DefaultReplyer:
|
|||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||||
if not person_id:
|
if not person_id:
|
||||||
logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID,跳过信息提取")
|
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||||
|
|
||||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||||
@@ -341,7 +337,7 @@ class DefaultReplyer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if selected_expressions:
|
if selected_expressions:
|
||||||
logger.debug(f"{self.log_prefix} 使用处理器选中的{len(selected_expressions)}个表达方式")
|
logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||||
for expr in selected_expressions:
|
for expr in selected_expressions:
|
||||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||||
expr_type = expr.get("type", "style")
|
expr_type = expr.get("type", "style")
|
||||||
@@ -350,7 +346,7 @@ class DefaultReplyer:
|
|||||||
else:
|
else:
|
||||||
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"{self.log_prefix} 没有从处理器获得表达方式,将使用空的表达方式")
|
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
|
||||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||||
|
|
||||||
style_habits_str = "\n".join(style_habits)
|
style_habits_str = "\n".join(style_habits)
|
||||||
@@ -358,10 +354,19 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
# 动态构建expression habits块
|
# 动态构建expression habits块
|
||||||
expression_habits_block = ""
|
expression_habits_block = ""
|
||||||
|
expression_habits_title = ""
|
||||||
if style_habits_str.strip():
|
if style_habits_str.strip():
|
||||||
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
|
expression_habits_title = "你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:"
|
||||||
|
expression_habits_block += f"{style_habits_str}\n"
|
||||||
if grammar_habits_str.strip():
|
if grammar_habits_str.strip():
|
||||||
expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habits_str}\n"
|
expression_habits_title = "你可以选择下面的句法进行回复,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式使用:"
|
||||||
|
expression_habits_block += f"{grammar_habits_str}\n"
|
||||||
|
|
||||||
|
if style_habits_str.strip() and grammar_habits_str.strip():
|
||||||
|
expression_habits_title = "你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中:"
|
||||||
|
|
||||||
|
expression_habits_block = f"{expression_habits_title}\n{expression_habits_block}"
|
||||||
|
|
||||||
|
|
||||||
return expression_habits_block
|
return expression_habits_block
|
||||||
|
|
||||||
@@ -432,19 +437,23 @@ class DefaultReplyer:
|
|||||||
tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n"
|
tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n"
|
||||||
|
|
||||||
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
|
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
|
||||||
logger.info(f"{self.log_prefix} 获取到 {len(tool_results)} 个工具结果")
|
logger.info(f"获取到 {len(tool_results)} 个工具结果")
|
||||||
|
|
||||||
return tool_info_str
|
return tool_info_str
|
||||||
else:
|
else:
|
||||||
logger.debug(f"{self.log_prefix} 未获取到任何工具结果")
|
logger.debug("未获取到任何工具结果")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 工具信息获取失败: {e}")
|
logger.error(f"工具信息获取失败: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def _parse_reply_target(self, target_message: str) -> tuple:
|
def _parse_reply_target(self, target_message: str) -> tuple:
|
||||||
sender = ""
|
sender = ""
|
||||||
target = ""
|
target = ""
|
||||||
|
# 添加None检查,防止NoneType错误
|
||||||
|
if target_message is None:
|
||||||
|
return sender, target
|
||||||
if ":" in target_message or ":" in target_message:
|
if ":" in target_message or ":" in target_message:
|
||||||
# 使用正则表达式匹配中文或英文冒号
|
# 使用正则表达式匹配中文或英文冒号
|
||||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||||
@@ -457,6 +466,10 @@ class DefaultReplyer:
|
|||||||
# 关键词检测与反应
|
# 关键词检测与反应
|
||||||
keywords_reaction_prompt = ""
|
keywords_reaction_prompt = ""
|
||||||
try:
|
try:
|
||||||
|
# 添加None检查,防止NoneType错误
|
||||||
|
if target is None:
|
||||||
|
return keywords_reaction_prompt
|
||||||
|
|
||||||
# 处理关键词规则
|
# 处理关键词规则
|
||||||
for rule in global_config.keyword_reaction.keyword_rules:
|
for rule in global_config.keyword_reaction.keyword_rules:
|
||||||
if any(keyword in target for keyword in rule.keywords):
|
if any(keyword in target for keyword in rule.keywords):
|
||||||
@@ -510,19 +523,21 @@ class DefaultReplyer:
|
|||||||
for msg_dict in message_list_before_now:
|
for msg_dict in message_list_before_now:
|
||||||
try:
|
try:
|
||||||
msg_user_id = str(msg_dict.get("user_id"))
|
msg_user_id = str(msg_dict.get("user_id"))
|
||||||
if msg_user_id == bot_id or msg_user_id == target_user_id:
|
reply_to = msg_dict.get("reply_to", "")
|
||||||
|
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
|
||||||
|
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
||||||
# bot 和目标用户的对话
|
# bot 和目标用户的对话
|
||||||
core_dialogue_list.append(msg_dict)
|
core_dialogue_list.append(msg_dict)
|
||||||
else:
|
else:
|
||||||
# 其他用户的对话
|
# 其他用户的对话
|
||||||
background_dialogue_list.append(msg_dict)
|
background_dialogue_list.append(msg_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
logger.error(f"记录: {msg_dict}, 错误: {e}")
|
||||||
|
|
||||||
# 构建背景对话 prompt
|
# 构建背景对话 prompt
|
||||||
background_dialogue_prompt = ""
|
background_dialogue_prompt = ""
|
||||||
if background_dialogue_list:
|
if background_dialogue_list:
|
||||||
latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size * 0.6) :]
|
latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size * 0.5) :]
|
||||||
background_dialogue_prompt_str = build_readable_messages(
|
background_dialogue_prompt_str = build_readable_messages(
|
||||||
latest_25_msgs,
|
latest_25_msgs,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
@@ -549,6 +564,34 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
return core_dialogue_prompt, background_dialogue_prompt
|
return core_dialogue_prompt, background_dialogue_prompt
|
||||||
|
|
||||||
|
def build_mai_think_context(
|
||||||
|
self,
|
||||||
|
chat_id: str,
|
||||||
|
memory_block: str,
|
||||||
|
relation_info: str,
|
||||||
|
time_block: str,
|
||||||
|
chat_target_1: str,
|
||||||
|
chat_target_2: str,
|
||||||
|
mood_prompt: str,
|
||||||
|
identity_block: str,
|
||||||
|
sender: str,
|
||||||
|
target: str,
|
||||||
|
chat_info: str,
|
||||||
|
):
|
||||||
|
"""构建 mai_think 上下文信息"""
|
||||||
|
mai_think = mai_thinking_manager.get_mai_think(chat_id)
|
||||||
|
mai_think.memory_block = memory_block
|
||||||
|
mai_think.relation_info_block = relation_info
|
||||||
|
mai_think.time_block = time_block
|
||||||
|
mai_think.chat_target = chat_target_1
|
||||||
|
mai_think.chat_target_2 = chat_target_2
|
||||||
|
mai_think.chat_info = chat_info
|
||||||
|
mai_think.mood_state = mood_prompt
|
||||||
|
mai_think.identity = identity_block
|
||||||
|
mai_think.sender = sender
|
||||||
|
mai_think.target = target
|
||||||
|
return mai_think
|
||||||
|
|
||||||
async def build_prompt_reply_context(
|
async def build_prompt_reply_context(
|
||||||
self,
|
self,
|
||||||
reply_data: Dict[str, Any],
|
reply_data: Dict[str, Any],
|
||||||
@@ -578,9 +621,12 @@ class DefaultReplyer:
|
|||||||
is_group_chat = bool(chat_stream.group_info)
|
is_group_chat = bool(chat_stream.group_info)
|
||||||
reply_to = reply_data.get("reply_to", "none")
|
reply_to = reply_data.get("reply_to", "none")
|
||||||
extra_info_block = reply_data.get("extra_info", "") or reply_data.get("extra_info_block", "")
|
extra_info_block = reply_data.get("extra_info", "") or reply_data.get("extra_info_block", "")
|
||||||
|
|
||||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
if global_config.mood.enable_mood:
|
||||||
mood_prompt = chat_mood.mood_state
|
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||||
|
mood_prompt = chat_mood.mood_state
|
||||||
|
else:
|
||||||
|
mood_prompt = ""
|
||||||
|
|
||||||
sender, target = self._parse_reply_target(reply_to)
|
sender, target = self._parse_reply_target(reply_to)
|
||||||
|
|
||||||
@@ -628,44 +674,51 @@ class DefaultReplyer:
|
|||||||
show_actions=True,
|
show_actions=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 并行执行四个构建任务
|
# 并行执行五个构建任务
|
||||||
task_results = await asyncio.gather(
|
task_results = await asyncio.gather(
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
self.build_expression_habits(chat_talking_prompt_short, target), "build_expression_habits"
|
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
||||||
),
|
),
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
self.build_relation_info(reply_data), "build_relation_info"
|
self.build_relation_info(reply_data), "relation_info"
|
||||||
),
|
),
|
||||||
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "build_memory_block"),
|
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block"),
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
self.build_tool_info(chat_talking_prompt_short, reply_data, enable_tool=enable_tool), "build_tool_info"
|
self.build_tool_info(chat_talking_prompt_short, reply_data, enable_tool=enable_tool), "tool_info"
|
||||||
|
),
|
||||||
|
self._time_and_run_task(
|
||||||
|
get_prompt_info(target, threshold=0.38), "prompt_info"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 任务名称中英文映射
|
||||||
|
task_name_mapping = {
|
||||||
|
"expression_habits": "选取表达方式",
|
||||||
|
"relation_info": "感受关系",
|
||||||
|
"memory_block": "回忆",
|
||||||
|
"tool_info": "使用工具",
|
||||||
|
"prompt_info": "获取知识"
|
||||||
|
}
|
||||||
|
|
||||||
# 处理结果
|
# 处理结果
|
||||||
timing_logs = []
|
timing_logs = []
|
||||||
results_dict = {}
|
results_dict = {}
|
||||||
for name, result, duration in task_results:
|
for name, result, duration in task_results:
|
||||||
results_dict[name] = result
|
results_dict[name] = result
|
||||||
timing_logs.append(f"{name}: {duration:.4f}s")
|
chinese_name = task_name_mapping.get(name, name)
|
||||||
|
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
|
||||||
if duration > 8:
|
if duration > 8:
|
||||||
logger.warning(f"回复生成前信息获取耗时过长: {name} 耗时: {duration:.4f}s,请使用更快的模型")
|
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s,请使用更快的模型")
|
||||||
logger.info(f"回复生成前信息获取耗时: {'; '.join(timing_logs)}")
|
logger.info(f"在回复前的步骤耗时: {'; '.join(timing_logs)}")
|
||||||
|
|
||||||
expression_habits_block = results_dict["build_expression_habits"]
|
expression_habits_block = results_dict["expression_habits"]
|
||||||
relation_info = results_dict["build_relation_info"]
|
relation_info = results_dict["relation_info"]
|
||||||
memory_block = results_dict["build_memory_block"]
|
memory_block = results_dict["memory_block"]
|
||||||
tool_info = results_dict["build_tool_info"]
|
tool_info = results_dict["tool_info"]
|
||||||
|
prompt_info = results_dict["prompt_info"] # 直接使用格式化后的结果
|
||||||
|
|
||||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||||
|
|
||||||
if tool_info:
|
|
||||||
tool_info_block = (
|
|
||||||
f"以下是你了解的额外信息信息,现在请你阅读以下内容,进行决策\n{tool_info}\n以上是一些额外的信息。"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
tool_info_block = ""
|
|
||||||
|
|
||||||
if extra_info_block:
|
if extra_info_block:
|
||||||
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
|
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
|
||||||
else:
|
else:
|
||||||
@@ -699,10 +752,6 @@ class DefaultReplyer:
|
|||||||
else:
|
else:
|
||||||
reply_target_block = ""
|
reply_target_block = ""
|
||||||
|
|
||||||
prompt_info = await get_prompt_info(target, threshold=0.38)
|
|
||||||
if prompt_info:
|
|
||||||
prompt_info = await global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=prompt_info)
|
|
||||||
|
|
||||||
template_name = "default_generator_prompt"
|
template_name = "default_generator_prompt"
|
||||||
if is_group_chat:
|
if is_group_chat:
|
||||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||||
@@ -742,24 +791,24 @@ class DefaultReplyer:
|
|||||||
message_list_before_now_long, target_user_id
|
message_list_before_now_long, target_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
mai_think = mai_thinking_manager.get_mai_think(chat_id)
|
self.build_mai_think_context(
|
||||||
mai_think.memory_block = memory_block
|
chat_id=chat_id,
|
||||||
mai_think.relation_info_block = relation_info
|
memory_block=memory_block,
|
||||||
mai_think.time_block = time_block
|
relation_info=relation_info,
|
||||||
mai_think.chat_target = chat_target_1
|
time_block=time_block,
|
||||||
mai_think.chat_target_2 = chat_target_2
|
chat_target_1=chat_target_1,
|
||||||
# mai_think.chat_info = chat_talking_prompt
|
chat_target_2=chat_target_2,
|
||||||
mai_think.mood_state = mood_prompt
|
mood_prompt=mood_prompt,
|
||||||
mai_think.identity = identity_block
|
identity_block=identity_block,
|
||||||
mai_think.sender = sender
|
sender=sender,
|
||||||
mai_think.target = target
|
target=target,
|
||||||
|
chat_info=f"""
|
||||||
mai_think.chat_info = f"""
|
|
||||||
{background_dialogue_prompt}
|
{background_dialogue_prompt}
|
||||||
--------------------------------
|
--------------------------------
|
||||||
{time_block}
|
{time_block}
|
||||||
这是你和{sender}的对话,你们正在交流中:
|
这是你和{sender}的对话,你们正在交流中:
|
||||||
{core_dialogue_prompt}"""
|
{core_dialogue_prompt}"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# 使用 s4u 风格的模板
|
# 使用 s4u 风格的模板
|
||||||
@@ -768,7 +817,7 @@ class DefaultReplyer:
|
|||||||
return await global_prompt_manager.format_prompt(
|
return await global_prompt_manager.format_prompt(
|
||||||
template_name,
|
template_name,
|
||||||
expression_habits_block=expression_habits_block,
|
expression_habits_block=expression_habits_block,
|
||||||
tool_info_block=tool_info_block,
|
tool_info_block=tool_info,
|
||||||
knowledge_prompt=prompt_info,
|
knowledge_prompt=prompt_info,
|
||||||
memory_block=memory_block,
|
memory_block=memory_block,
|
||||||
relation_info_block=relation_info,
|
relation_info_block=relation_info,
|
||||||
@@ -787,17 +836,19 @@ class DefaultReplyer:
|
|||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
mai_think = mai_thinking_manager.get_mai_think(chat_id)
|
self.build_mai_think_context(
|
||||||
mai_think.memory_block = memory_block
|
chat_id=chat_id,
|
||||||
mai_think.relation_info_block = relation_info
|
memory_block=memory_block,
|
||||||
mai_think.time_block = time_block
|
relation_info=relation_info,
|
||||||
mai_think.chat_target = chat_target_1
|
time_block=time_block,
|
||||||
mai_think.chat_target_2 = chat_target_2
|
chat_target_1=chat_target_1,
|
||||||
mai_think.chat_info = chat_talking_prompt
|
chat_target_2=chat_target_2,
|
||||||
mai_think.mood_state = mood_prompt
|
mood_prompt=mood_prompt,
|
||||||
mai_think.identity = identity_block
|
identity_block=identity_block,
|
||||||
mai_think.sender = sender
|
sender=sender,
|
||||||
mai_think.target = target
|
target=target,
|
||||||
|
chat_info=chat_talking_prompt
|
||||||
|
)
|
||||||
|
|
||||||
# 使用原有的模式
|
# 使用原有的模式
|
||||||
return await global_prompt_manager.format_prompt(
|
return await global_prompt_manager.format_prompt(
|
||||||
@@ -806,7 +857,7 @@ class DefaultReplyer:
|
|||||||
chat_target=chat_target_1,
|
chat_target=chat_target_1,
|
||||||
chat_info=chat_talking_prompt,
|
chat_info=chat_talking_prompt,
|
||||||
memory_block=memory_block,
|
memory_block=memory_block,
|
||||||
tool_info_block=tool_info_block,
|
tool_info_block=tool_info,
|
||||||
knowledge_prompt=prompt_info,
|
knowledge_prompt=prompt_info,
|
||||||
extra_info_block=extra_info_block,
|
extra_info_block=extra_info_block,
|
||||||
relation_info_block=relation_info,
|
relation_info_block=relation_info,
|
||||||
@@ -836,6 +887,13 @@ class DefaultReplyer:
|
|||||||
reason = reply_data.get("reason", "")
|
reason = reply_data.get("reason", "")
|
||||||
sender, target = self._parse_reply_target(reply_to)
|
sender, target = self._parse_reply_target(reply_to)
|
||||||
|
|
||||||
|
# 添加情绪状态获取
|
||||||
|
if global_config.mood.enable_mood:
|
||||||
|
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||||
|
mood_prompt = chat_mood.mood_state
|
||||||
|
else:
|
||||||
|
mood_prompt = ""
|
||||||
|
|
||||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
@@ -916,6 +974,7 @@ class DefaultReplyer:
|
|||||||
reply_target_block=reply_target_block,
|
reply_target_block=reply_target_block,
|
||||||
raw_reply=raw_reply,
|
raw_reply=raw_reply,
|
||||||
reason=reason,
|
reason=reason,
|
||||||
|
mood_state=mood_prompt, # 添加情绪状态参数
|
||||||
config_expression_style=global_config.expression.expression_style,
|
config_expression_style=global_config.expression.expression_style,
|
||||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
@@ -1012,7 +1071,10 @@ async def get_prompt_info(message: str, threshold: float):
|
|||||||
related_info += found_knowledge_from_lpmm
|
related_info += found_knowledge_from_lpmm
|
||||||
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
|
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
|
||||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||||
return related_info
|
|
||||||
|
# 格式化知识信息
|
||||||
|
formatted_prompt_info = await global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=related_info)
|
||||||
|
return formatted_prompt_info
|
||||||
else:
|
else:
|
||||||
logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...")
|
logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...")
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -78,7 +78,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
|||||||
# print(f"is_mentioned: {is_mentioned}")
|
# print(f"is_mentioned: {is_mentioned}")
|
||||||
# print(f"is_at: {is_at}")
|
# print(f"is_at: {is_at}")
|
||||||
|
|
||||||
if is_at and global_config.normal_chat.at_bot_inevitable_reply:
|
if is_at and global_config.chat.at_bot_inevitable_reply:
|
||||||
reply_probability = 1.0
|
reply_probability = 1.0
|
||||||
logger.debug("被@,回复概率设置为100%")
|
logger.debug("被@,回复概率设置为100%")
|
||||||
else:
|
else:
|
||||||
@@ -103,7 +103,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
|||||||
for nickname in nicknames:
|
for nickname in nicknames:
|
||||||
if nickname in message_content:
|
if nickname in message_content:
|
||||||
is_mentioned = True
|
is_mentioned = True
|
||||||
if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply:
|
if is_mentioned and global_config.chat.mentioned_bot_inevitable_reply:
|
||||||
reply_probability = 1.0
|
reply_probability = 1.0
|
||||||
logger.debug("被提及,回复概率设置为100%")
|
logger.debug("被提及,回复概率设置为100%")
|
||||||
return is_mentioned, reply_probability
|
return is_mentioned, reply_probability
|
||||||
@@ -619,9 +619,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
|||||||
chat_target_info = None
|
chat_target_info = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
if chat_stream := get_chat_manager().get_stream(chat_id):
|
||||||
|
|
||||||
if chat_stream:
|
|
||||||
if chat_stream.group_info:
|
if chat_stream.group_info:
|
||||||
is_group_chat = True
|
is_group_chat = True
|
||||||
chat_target_info = None # Explicitly None for group chat
|
chat_target_info = None # Explicitly None for group chat
|
||||||
@@ -660,8 +658,6 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
|||||||
chat_target_info = target_info
|
chat_target_info = target_info
|
||||||
else:
|
else:
|
||||||
logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
|
logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
|
||||||
# Keep defaults: is_group_chat=False, chat_target_info=None
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
|
logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
|
||||||
# Keep defaults on error
|
# Keep defaults on error
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ class ImageManager:
|
|||||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||||
|
|
||||||
async def get_emoji_description(self, image_base64: str) -> str:
|
async def get_emoji_description(self, image_base64: str) -> str:
|
||||||
"""获取表情包描述,带查重和保存功能"""
|
"""获取表情包描述,使用二步走识别并带缓存优化"""
|
||||||
try:
|
try:
|
||||||
# 计算图片哈希
|
# 计算图片哈希
|
||||||
# 确保base64字符串只包含ASCII字符
|
# 确保base64字符串只包含ASCII字符
|
||||||
@@ -107,33 +107,66 @@ class ImageManager:
|
|||||||
# 查询缓存的描述
|
# 查询缓存的描述
|
||||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||||
if cached_description:
|
if cached_description:
|
||||||
return f"[表情包,含义看起来是:{cached_description}]"
|
return f"[表情包:{cached_description}]"
|
||||||
|
|
||||||
# 调用AI获取描述
|
# === 二步走识别流程 ===
|
||||||
|
|
||||||
|
# 第一步:VLM视觉分析 - 生成详细描述
|
||||||
if image_format in ["gif", "GIF"]:
|
if image_format in ["gif", "GIF"]:
|
||||||
image_base64_processed = self.transform_gif(image_base64)
|
image_base64_processed = self.transform_gif(image_base64)
|
||||||
if image_base64_processed is None:
|
if image_base64_processed is None:
|
||||||
logger.warning("GIF转换失败,无法获取描述")
|
logger.warning("GIF转换失败,无法获取描述")
|
||||||
return "[表情包(GIF处理失败)]"
|
return "[表情包(GIF处理失败)]"
|
||||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些,输出一段平文本,只输出1-2个词就好,不要输出其他内容"
|
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
|
detailed_description, _ = await self._llm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg")
|
||||||
else:
|
else:
|
||||||
prompt = "图片是一个表情包,请用使用1-2个词描述一下表情包所表达的情感和内容,简短一些,输出一段平文本,只输出1-2个词就好,不要输出其他内容"
|
vlm_prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
detailed_description, _ = await self._llm.generate_response_for_image(vlm_prompt, image_base64, image_format)
|
||||||
|
|
||||||
if description is None:
|
if detailed_description is None:
|
||||||
logger.warning("AI未能生成表情包描述")
|
logger.warning("VLM未能生成表情包详细描述")
|
||||||
return "[表情包(描述生成失败)]"
|
return "[表情包(VLM描述生成失败)]"
|
||||||
|
|
||||||
|
# 第二步:LLM情感分析 - 基于详细描述生成简短的情感标签
|
||||||
|
emotion_prompt = f"""
|
||||||
|
请你基于这个表情包的详细描述,提取出最核心的情感含义,用1-2个词概括。
|
||||||
|
详细描述:'{detailed_description}'
|
||||||
|
|
||||||
|
要求:
|
||||||
|
1. 只输出1-2个最核心的情感词汇
|
||||||
|
2. 从互联网梗、meme的角度理解
|
||||||
|
3. 输出简短精准,不要解释
|
||||||
|
4. 如果有多个词用逗号分隔
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 使用较低温度确保输出稳定
|
||||||
|
emotion_llm = LLMRequest(model=global_config.model.utils, temperature=0.3, max_tokens=50, request_type="emoji")
|
||||||
|
emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt)
|
||||||
|
|
||||||
|
if emotion_result is None:
|
||||||
|
logger.warning("LLM未能生成情感标签,使用详细描述的前几个词")
|
||||||
|
# 降级处理:从详细描述中提取关键词
|
||||||
|
import jieba
|
||||||
|
words = list(jieba.cut(detailed_description))
|
||||||
|
emotion_result = ",".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情")
|
||||||
|
|
||||||
|
# 处理情感结果,取前1-2个最重要的标签
|
||||||
|
emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()]
|
||||||
|
final_emotion = emotions[0] if emotions else "表情"
|
||||||
|
|
||||||
|
# 如果有第二个情感且不重复,也包含进来
|
||||||
|
if len(emotions) > 1 and emotions[1] != emotions[0]:
|
||||||
|
final_emotion = f"{emotions[0]},{emotions[1]}"
|
||||||
|
|
||||||
|
logger.info(f"[二步走识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
|
||||||
|
|
||||||
# 再次检查缓存,防止并发写入时重复生成
|
# 再次检查缓存,防止并发写入时重复生成
|
||||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||||
if cached_description:
|
if cached_description:
|
||||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
||||||
return f"[表情包,含义看起来是:{cached_description}]"
|
return f"[表情包:{cached_description}]"
|
||||||
|
|
||||||
# 根据配置决定是否保存图片
|
# 保存表情包文件和元数据(用于可能的后续分析)
|
||||||
# if global_config.emoji.save_emoji:
|
|
||||||
# 生成文件名和路径
|
|
||||||
logger.debug(f"保存表情包: {image_hash}")
|
logger.debug(f"保存表情包: {image_hash}")
|
||||||
current_timestamp = time.time()
|
current_timestamp = time.time()
|
||||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||||
@@ -146,11 +179,11 @@ class ImageManager:
|
|||||||
with open(file_path, "wb") as f:
|
with open(file_path, "wb") as f:
|
||||||
f.write(image_bytes)
|
f.write(image_bytes)
|
||||||
|
|
||||||
# 保存到数据库 (Images表)
|
# 保存到数据库 (Images表) - 包含详细描述用于可能的注册流程
|
||||||
try:
|
try:
|
||||||
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||||
img_obj.path = file_path
|
img_obj.path = file_path
|
||||||
img_obj.description = description
|
img_obj.description = detailed_description # 保存详细描述
|
||||||
img_obj.timestamp = current_timestamp
|
img_obj.timestamp = current_timestamp
|
||||||
img_obj.save()
|
img_obj.save()
|
||||||
except Images.DoesNotExist: # type: ignore
|
except Images.DoesNotExist: # type: ignore
|
||||||
@@ -158,17 +191,17 @@ class ImageManager:
|
|||||||
emoji_hash=image_hash,
|
emoji_hash=image_hash,
|
||||||
path=file_path,
|
path=file_path,
|
||||||
type="emoji",
|
type="emoji",
|
||||||
description=description,
|
description=detailed_description, # 保存详细描述
|
||||||
timestamp=current_timestamp,
|
timestamp=current_timestamp,
|
||||||
)
|
)
|
||||||
# logger.debug(f"保存表情包元数据: {file_path}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||||
|
|
||||||
# 保存描述到数据库 (ImageDescriptions表)
|
# 保存最终的情感标签到缓存 (ImageDescriptions表)
|
||||||
self._save_description_to_db(image_hash, description, "emoji")
|
self._save_description_to_db(image_hash, final_emotion, "emoji")
|
||||||
|
|
||||||
return f"[表情包:{description}]"
|
return f"[表情包:{final_emotion}]"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取表情包描述失败: {str(e)}")
|
logger.error(f"获取表情包描述失败: {str(e)}")
|
||||||
return "[表情包]"
|
return "[表情包]"
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ logger = get_logger("chat_voice")
|
|||||||
|
|
||||||
async def get_voice_text(voice_base64: str) -> str:
|
async def get_voice_text(voice_base64: str) -> str:
|
||||||
"""获取音频文件描述"""
|
"""获取音频文件描述"""
|
||||||
if not global_config.chat.enable_asr:
|
if not global_config.voice.enable_asr:
|
||||||
logger.warning("语音识别未启用,无法处理语音消息")
|
logger.warning("语音识别未启用,无法处理语音消息")
|
||||||
return "[语音]"
|
return "[语音]"
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class ClassicalWillingManager(BaseWillingManager):
|
|||||||
if interested_rate > 0.2:
|
if interested_rate > 0.2:
|
||||||
current_willing += interested_rate - 0.2
|
current_willing += interested_rate - 0.2
|
||||||
|
|
||||||
if willing_info.is_mentioned_bot and global_config.normal_chat.mentioned_bot_inevitable_reply and current_willing < 2:
|
if willing_info.is_mentioned_bot and global_config.chat.mentioned_bot_inevitable_reply and current_willing < 2:
|
||||||
current_willing += 1 if current_willing < 1.0 else 0.05
|
current_willing += 1 if current_willing < 1.0 else 0.05
|
||||||
|
|
||||||
self.chat_reply_willing[chat_id] = min(current_willing, 1.0)
|
self.chat_reply_willing[chat_id] = min(current_willing, 1.0)
|
||||||
|
|||||||
@@ -306,6 +306,7 @@ class Expression(BaseModel):
|
|||||||
last_active_time = FloatField()
|
last_active_time = FloatField()
|
||||||
chat_id = TextField(index=True)
|
chat_id = TextField(index=True)
|
||||||
type = TextField()
|
type = TextField()
|
||||||
|
create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "expression"
|
table_name = "expression"
|
||||||
@@ -449,9 +450,12 @@ def initialize_database():
|
|||||||
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
|
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
|
||||||
alter_sql += " NULL" if field_obj.null else " NOT NULL"
|
alter_sql += " NULL" if field_obj.null else " NOT NULL"
|
||||||
if hasattr(field_obj, "default") and field_obj.default is not None:
|
if hasattr(field_obj, "default") and field_obj.default is not None:
|
||||||
# 正确处理不同类型的默认值
|
# 正确处理不同类型的默认值,跳过lambda函数
|
||||||
default_value = field_obj.default
|
default_value = field_obj.default
|
||||||
if isinstance(default_value, str):
|
if callable(default_value):
|
||||||
|
# 跳过lambda函数或其他可调用对象,这些无法在SQL中表示
|
||||||
|
pass
|
||||||
|
elif isinstance(default_value, str):
|
||||||
alter_sql += f" DEFAULT '{default_value}'"
|
alter_sql += f" DEFAULT '{default_value}'"
|
||||||
elif isinstance(default_value, bool):
|
elif isinstance(default_value, bool):
|
||||||
alter_sql += f" DEFAULT {int(default_value)}"
|
alter_sql += f" DEFAULT {int(default_value)}"
|
||||||
|
|||||||
@@ -321,7 +321,7 @@ MODULE_COLORS = {
|
|||||||
# 核心模块
|
# 核心模块
|
||||||
"main": "\033[1;97m", # 亮白色+粗体 (主程序)
|
"main": "\033[1;97m", # 亮白色+粗体 (主程序)
|
||||||
"api": "\033[92m", # 亮绿色
|
"api": "\033[92m", # 亮绿色
|
||||||
"emoji": "\033[33m", # 亮绿色
|
"emoji": "\033[38;5;214m", # 橙黄色,偏向橙色但与replyer和action_manager不同
|
||||||
"chat": "\033[92m", # 亮蓝色
|
"chat": "\033[92m", # 亮蓝色
|
||||||
"config": "\033[93m", # 亮黄色
|
"config": "\033[93m", # 亮黄色
|
||||||
"common": "\033[95m", # 亮紫色
|
"common": "\033[95m", # 亮紫色
|
||||||
@@ -329,35 +329,33 @@ MODULE_COLORS = {
|
|||||||
"lpmm": "\033[96m",
|
"lpmm": "\033[96m",
|
||||||
"plugin_system": "\033[91m", # 亮红色
|
"plugin_system": "\033[91m", # 亮红色
|
||||||
"person_info": "\033[32m", # 绿色
|
"person_info": "\033[32m", # 绿色
|
||||||
"individuality": "\033[34m", # 蓝色
|
"individuality": "\033[94m", # 显眼的亮蓝色
|
||||||
"manager": "\033[35m", # 紫色
|
"manager": "\033[35m", # 紫色
|
||||||
"llm_models": "\033[36m", # 青色
|
"llm_models": "\033[36m", # 青色
|
||||||
"plugins": "\033[31m", # 红色
|
"remote": "\033[38;5;242m", # 深灰色,更不显眼
|
||||||
"plugin_api": "\033[33m", # 黄色
|
|
||||||
"remote": "\033[38;5;93m", # 紫蓝色
|
|
||||||
"planner": "\033[36m",
|
"planner": "\033[36m",
|
||||||
"memory": "\033[34m",
|
"memory": "\033[34m",
|
||||||
"hfc": "\033[96m",
|
"hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读
|
||||||
"action_manager": "\033[38;5;166m",
|
"action_manager": "\033[38;5;208m", # 橙色,不与replyer重复
|
||||||
# 关系系统
|
# 关系系统
|
||||||
"relation": "\033[38;5;201m", # 深粉色
|
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
|
||||||
# 聊天相关模块
|
# 聊天相关模块
|
||||||
"normal_chat": "\033[38;5;81m", # 亮蓝绿色
|
"normal_chat": "\033[38;5;81m", # 亮蓝绿色
|
||||||
"normal_chat_response": "\033[38;5;123m", # 青绿色
|
"heartflow": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
|
||||||
"heartflow": "\033[38;5;213m", # 粉色
|
|
||||||
"sub_heartflow": "\033[38;5;207m", # 粉紫色
|
"sub_heartflow": "\033[38;5;207m", # 粉紫色
|
||||||
"subheartflow_manager": "\033[38;5;201m", # 深粉色
|
"subheartflow_manager": "\033[38;5;201m", # 深粉色
|
||||||
"background_tasks": "\033[38;5;240m", # 灰色
|
"background_tasks": "\033[38;5;240m", # 灰色
|
||||||
"chat_message": "\033[38;5;45m", # 青色
|
"chat_message": "\033[38;5;45m", # 青色
|
||||||
"chat_stream": "\033[38;5;51m", # 亮青色
|
"chat_stream": "\033[38;5;51m", # 亮青色
|
||||||
"sender": "\033[38;5;39m", # 蓝色
|
"sender": "\033[38;5;67m", # 稍微暗一些的蓝色,不显眼
|
||||||
"message_storage": "\033[38;5;33m", # 深蓝色
|
"message_storage": "\033[38;5;33m", # 深蓝色
|
||||||
|
"expressor": "\033[38;5;166m", # 橙色
|
||||||
# 专注聊天模块
|
# 专注聊天模块
|
||||||
"replyer": "\033[38;5;166m", # 橙色
|
"replyer": "\033[38;5;166m", # 橙色
|
||||||
"base_processor": "\033[38;5;190m", # 绿黄色
|
|
||||||
"working_memory": "\033[38;5;22m", # 深绿色
|
|
||||||
"memory_activator": "\033[34m", # 绿色
|
"memory_activator": "\033[34m", # 绿色
|
||||||
# 插件系统
|
# 插件系统
|
||||||
|
"plugins": "\033[31m", # 红色
|
||||||
|
"plugin_api": "\033[33m", # 黄色
|
||||||
"plugin_manager": "\033[38;5;208m", # 红色
|
"plugin_manager": "\033[38;5;208m", # 红色
|
||||||
"base_plugin": "\033[38;5;202m", # 橙红色
|
"base_plugin": "\033[38;5;202m", # 橙红色
|
||||||
"send_api": "\033[38;5;208m", # 橙色
|
"send_api": "\033[38;5;208m", # 橙色
|
||||||
@@ -378,9 +376,9 @@ MODULE_COLORS = {
|
|||||||
"local_storage": "\033[38;5;141m", # 紫色
|
"local_storage": "\033[38;5;141m", # 紫色
|
||||||
"willing": "\033[38;5;147m", # 浅紫色
|
"willing": "\033[38;5;147m", # 浅紫色
|
||||||
# 工具模块
|
# 工具模块
|
||||||
"tool_use": "\033[38;5;64m", # 深绿色
|
"tool_use": "\033[38;5;172m", # 橙褐色
|
||||||
"tool_executor": "\033[38;5;64m", # 深绿色
|
"tool_executor": "\033[38;5;172m", # 橙褐色
|
||||||
"base_tool": "\033[38;5;70m", # 绿色
|
"base_tool": "\033[38;5;178m", # 金黄色
|
||||||
# 工具和实用模块
|
# 工具和实用模块
|
||||||
"prompt_build": "\033[38;5;105m", # 紫色
|
"prompt_build": "\033[38;5;105m", # 紫色
|
||||||
"chat_utils": "\033[38;5;111m", # 蓝色
|
"chat_utils": "\033[38;5;111m", # 蓝色
|
||||||
@@ -388,14 +386,16 @@ MODULE_COLORS = {
|
|||||||
"maibot_statistic": "\033[38;5;129m", # 紫色
|
"maibot_statistic": "\033[38;5;129m", # 紫色
|
||||||
# 特殊功能插件
|
# 特殊功能插件
|
||||||
"mute_plugin": "\033[38;5;240m", # 灰色
|
"mute_plugin": "\033[38;5;240m", # 灰色
|
||||||
"example_comprehensive": "\033[38;5;246m", # 浅灰色
|
|
||||||
"core_actions": "\033[38;5;117m", # 深红色
|
"core_actions": "\033[38;5;117m", # 深红色
|
||||||
"tts_action": "\033[38;5;58m", # 深黄色
|
"tts_action": "\033[38;5;58m", # 深黄色
|
||||||
"doubao_pic_plugin": "\033[38;5;64m", # 深绿色
|
"doubao_pic_plugin": "\033[38;5;64m", # 深绿色
|
||||||
"vtb_action": "\033[38;5;70m", # 绿色
|
# Action组件
|
||||||
|
"no_reply_action": "\033[38;5;196m", # 亮红色,更显眼
|
||||||
|
"reply_action": "\033[38;5;46m", # 亮绿色
|
||||||
|
"base_action": "\033[38;5;250m", # 浅灰色
|
||||||
# 数据库和消息
|
# 数据库和消息
|
||||||
"database_model": "\033[38;5;94m", # 橙褐色
|
"database_model": "\033[38;5;94m", # 橙褐色
|
||||||
"maim_message": "\033[38;5;100m", # 绿褐色
|
"maim_message": "\033[38;5;140m", # 紫褐色
|
||||||
# 日志系统
|
# 日志系统
|
||||||
"logger": "\033[38;5;8m", # 深灰色
|
"logger": "\033[38;5;8m", # 深灰色
|
||||||
"confirm": "\033[1;93m", # 黄色+粗体
|
"confirm": "\033[1;93m", # 黄色+粗体
|
||||||
@@ -409,6 +409,34 @@ MODULE_COLORS = {
|
|||||||
"S4U_chat": "\033[92m", # 深灰色
|
"S4U_chat": "\033[92m", # 深灰色
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 定义模块别名映射 - 将真实的logger名称映射到显示的别名
|
||||||
|
MODULE_ALIASES = {
|
||||||
|
# 示例映射
|
||||||
|
"individuality": "人格特质",
|
||||||
|
"emoji": "表情包",
|
||||||
|
"no_reply_action": "摸鱼",
|
||||||
|
"reply_action": "回复",
|
||||||
|
"action_manager": "动作",
|
||||||
|
"memory_activator": "记忆",
|
||||||
|
"tool_use": "工具",
|
||||||
|
"expressor": "表达方式",
|
||||||
|
"database_model": "数据库",
|
||||||
|
"mood": "情绪",
|
||||||
|
"memory": "记忆",
|
||||||
|
"tool_executor": "工具",
|
||||||
|
"hfc": "聊天节奏",
|
||||||
|
"chat": "所见",
|
||||||
|
"plugin_manager": "插件",
|
||||||
|
"relationship_builder": "关系",
|
||||||
|
"llm_models": "模型",
|
||||||
|
"person_info": "人物",
|
||||||
|
"chat_stream": "聊天流",
|
||||||
|
"planner": "规划器",
|
||||||
|
"replyer": "言语",
|
||||||
|
"config": "配置",
|
||||||
|
"main": "主程序",
|
||||||
|
}
|
||||||
|
|
||||||
RESET_COLOR = "\033[0m"
|
RESET_COLOR = "\033[0m"
|
||||||
|
|
||||||
|
|
||||||
@@ -497,15 +525,18 @@ class ModuleColoredConsoleRenderer:
|
|||||||
if self._colors and self._enable_module_colors and logger_name:
|
if self._colors and self._enable_module_colors and logger_name:
|
||||||
module_color = MODULE_COLORS.get(logger_name, "")
|
module_color = MODULE_COLORS.get(logger_name, "")
|
||||||
|
|
||||||
# 模块名称(带颜色)
|
# 模块名称(带颜色和别名支持)
|
||||||
if logger_name:
|
if logger_name:
|
||||||
|
# 获取别名,如果没有别名则使用原名称
|
||||||
|
display_name = MODULE_ALIASES.get(logger_name, logger_name)
|
||||||
|
|
||||||
if self._colors and self._enable_module_colors:
|
if self._colors and self._enable_module_colors:
|
||||||
if module_color:
|
if module_color:
|
||||||
module_part = f"{module_color}[{logger_name}]{RESET_COLOR}"
|
module_part = f"{module_color}[{display_name}]{RESET_COLOR}"
|
||||||
else:
|
else:
|
||||||
module_part = f"[{logger_name}]"
|
module_part = f"[{display_name}]"
|
||||||
else:
|
else:
|
||||||
module_part = f"[{logger_name}]"
|
module_part = f"[{display_name}]"
|
||||||
parts.append(module_part)
|
parts.append(module_part)
|
||||||
|
|
||||||
# 消息内容(确保转换为字符串)
|
# 消息内容(确保转换为字符串)
|
||||||
@@ -715,19 +746,7 @@ def configure_logging(
|
|||||||
root_logger.setLevel(getattr(logging, level.upper()))
|
root_logger.setLevel(getattr(logging, level.upper()))
|
||||||
|
|
||||||
|
|
||||||
def set_module_color(module_name: str, color_code: str):
|
|
||||||
"""为指定模块设置颜色
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module_name: 模块名称
|
|
||||||
color_code: ANSI颜色代码,例如 '\033[92m' 表示亮绿色
|
|
||||||
"""
|
|
||||||
MODULE_COLORS[module_name] = color_code
|
|
||||||
|
|
||||||
|
|
||||||
def get_module_colors():
|
|
||||||
"""获取当前模块颜色配置"""
|
|
||||||
return MODULE_COLORS.copy()
|
|
||||||
|
|
||||||
|
|
||||||
def reload_log_config():
|
def reload_log_config():
|
||||||
@@ -918,9 +937,20 @@ def show_module_colors():
|
|||||||
for module_name, _color_code in MODULE_COLORS.items():
|
for module_name, _color_code in MODULE_COLORS.items():
|
||||||
# 临时创建一个该模块的logger来展示颜色
|
# 临时创建一个该模块的logger来展示颜色
|
||||||
demo_logger = structlog.get_logger(module_name).bind(logger_name=module_name)
|
demo_logger = structlog.get_logger(module_name).bind(logger_name=module_name)
|
||||||
demo_logger.info(f"这是 {module_name} 模块的颜色效果")
|
alias = MODULE_ALIASES.get(module_name, module_name)
|
||||||
|
if alias != module_name:
|
||||||
|
demo_logger.info(f"这是 {module_name} 模块的颜色效果 (显示为: {alias})")
|
||||||
|
else:
|
||||||
|
demo_logger.info(f"这是 {module_name} 模块的颜色效果")
|
||||||
|
|
||||||
print("=== 颜色展示结束 ===\n")
|
print("=== 颜色展示结束 ===\n")
|
||||||
|
|
||||||
|
# 显示别名映射表
|
||||||
|
if MODULE_ALIASES:
|
||||||
|
print("=== 当前别名映射 ===")
|
||||||
|
for module_name, alias in MODULE_ALIASES.items():
|
||||||
|
print(f" {module_name} -> {alias}")
|
||||||
|
print("=== 别名映射结束 ===\n")
|
||||||
|
|
||||||
|
|
||||||
def format_json_for_logging(data, indent=2, ensure_ascii=False):
|
def format_json_for_logging(data, indent=2, ensure_ascii=False):
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from src.config.official_configs import (
|
|||||||
LPMMKnowledgeConfig,
|
LPMMKnowledgeConfig,
|
||||||
RelationshipConfig,
|
RelationshipConfig,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
|
VoiceConfig,
|
||||||
DebugConfig,
|
DebugConfig,
|
||||||
CustomPromptConfig,
|
CustomPromptConfig,
|
||||||
)
|
)
|
||||||
@@ -64,7 +65,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
|||||||
|
|
||||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||||
MMC_VERSION = "0.9.0-snapshot.2"
|
MMC_VERSION = "0.9.1"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -616,7 +617,7 @@ class Config(ConfigBase):
|
|||||||
tool: ToolConfig
|
tool: ToolConfig
|
||||||
debug: DebugConfig
|
debug: DebugConfig
|
||||||
custom_prompt: CustomPromptConfig
|
custom_prompt: CustomPromptConfig
|
||||||
|
voice: VoiceConfig
|
||||||
|
|
||||||
def load_config(config_path: str) -> Config:
|
def load_config(config_path: str) -> Config:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ from packaging.version import Version
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BotConfig(ConfigBase):
|
class BotConfig(ConfigBase):
|
||||||
"""QQ机器人配置类"""
|
"""QQ机器人配置类"""
|
||||||
|
|
||||||
|
platform: str
|
||||||
|
"""平台"""
|
||||||
|
|
||||||
qq_account: str
|
qq_account: str
|
||||||
"""QQ账号"""
|
"""QQ账号"""
|
||||||
@@ -82,6 +85,12 @@ class ChatConfig(ConfigBase):
|
|||||||
use_s4u_prompt_mode: bool = False
|
use_s4u_prompt_mode: bool = False
|
||||||
"""是否使用 s4u 对话构建模式,该模式会分开处理当前对话对象和其他所有对话的内容进行 prompt 构建"""
|
"""是否使用 s4u 对话构建模式,该模式会分开处理当前对话对象和其他所有对话的内容进行 prompt 构建"""
|
||||||
|
|
||||||
|
mentioned_bot_inevitable_reply: bool = False
|
||||||
|
"""提及 bot 必然回复"""
|
||||||
|
|
||||||
|
at_bot_inevitable_reply: bool = False
|
||||||
|
"""@bot 必然回复"""
|
||||||
|
|
||||||
# 修改:基于时段的回复频率配置,改为数组格式
|
# 修改:基于时段的回复频率配置,改为数组格式
|
||||||
time_based_talk_frequency: list[str] = field(default_factory=lambda: [])
|
time_based_talk_frequency: list[str] = field(default_factory=lambda: [])
|
||||||
"""
|
"""
|
||||||
@@ -107,9 +116,6 @@ class ChatConfig(ConfigBase):
|
|||||||
focus_value: float = 1.0
|
focus_value: float = 1.0
|
||||||
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
||||||
|
|
||||||
enable_asr: bool = False
|
|
||||||
"""是否启用语音识别"""
|
|
||||||
|
|
||||||
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
||||||
"""
|
"""
|
||||||
根据当前时间和聊天流获取对应的 talk_frequency
|
根据当前时间和聊天流获取对应的 talk_frequency
|
||||||
@@ -271,11 +277,7 @@ class NormalChatConfig(ConfigBase):
|
|||||||
response_interested_rate_amplifier: float = 1.0
|
response_interested_rate_amplifier: float = 1.0
|
||||||
"""回复兴趣度放大系数"""
|
"""回复兴趣度放大系数"""
|
||||||
|
|
||||||
mentioned_bot_inevitable_reply: bool = False
|
|
||||||
"""提及 bot 必然回复"""
|
|
||||||
|
|
||||||
at_bot_inevitable_reply: bool = False
|
|
||||||
"""@bot 必然回复"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -310,6 +312,13 @@ class ToolConfig(ConfigBase):
|
|||||||
|
|
||||||
enable_in_focus_chat: bool = True
|
enable_in_focus_chat: bool = True
|
||||||
"""是否在专注聊天中启用工具"""
|
"""是否在专注聊天中启用工具"""
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VoiceConfig(ConfigBase):
|
||||||
|
"""语音识别配置类"""
|
||||||
|
|
||||||
|
enable_asr: bool = False
|
||||||
|
"""是否启用语音识别"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -400,15 +409,9 @@ class MoodConfig(ConfigBase):
|
|||||||
|
|
||||||
enable_mood: bool = False
|
enable_mood: bool = False
|
||||||
"""是否启用情绪系统"""
|
"""是否启用情绪系统"""
|
||||||
|
|
||||||
mood_update_interval: int = 1
|
mood_update_threshold: float = 1.0
|
||||||
"""情绪更新间隔(秒)"""
|
"""情绪更新阈值,越高,更新越慢"""
|
||||||
|
|
||||||
mood_decay_rate: float = 0.95
|
|
||||||
"""情绪衰减率"""
|
|
||||||
|
|
||||||
mood_intensity_factor: float = 0.7
|
|
||||||
"""情绪强度因子"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import ast
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import time
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
@@ -9,8 +9,6 @@ from src.llm_models.utils_model import LLMRequest
|
|||||||
from src.person_info.person_info import get_person_info_manager
|
from src.person_info.person_info import get_person_info_manager
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from .personality import Personality
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("individuality")
|
logger = get_logger("individuality")
|
||||||
@@ -20,12 +18,10 @@ class Individuality:
|
|||||||
"""个体特征管理类"""
|
"""个体特征管理类"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 正常初始化实例属性
|
|
||||||
self.personality: Personality = None # type: ignore
|
|
||||||
|
|
||||||
self.name = ""
|
self.name = ""
|
||||||
self.bot_person_id = ""
|
self.bot_person_id = ""
|
||||||
self.meta_info_file_path = "data/personality/meta.json"
|
self.meta_info_file_path = "data/personality/meta.json"
|
||||||
|
self.personality_data_file_path = "data/personality/personality_data.json"
|
||||||
|
|
||||||
self.model = LLMRequest(
|
self.model = LLMRequest(
|
||||||
model=global_config.model.utils,
|
model=global_config.model.utils,
|
||||||
@@ -33,20 +29,13 @@ class Individuality:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""初始化个体特征
|
"""初始化个体特征"""
|
||||||
|
|
||||||
Args:
|
|
||||||
bot_nickname: 机器人昵称
|
|
||||||
personality_core: 人格核心特点
|
|
||||||
personality_side: 人格侧面描述
|
|
||||||
identity: 身份细节描述
|
|
||||||
"""
|
|
||||||
bot_nickname = global_config.bot.nickname
|
bot_nickname = global_config.bot.nickname
|
||||||
personality_core = global_config.personality.personality_core
|
personality_core = global_config.personality.personality_core
|
||||||
personality_side = global_config.personality.personality_side
|
personality_side = global_config.personality.personality_side
|
||||||
identity = global_config.personality.identity
|
identity = global_config.personality.identity
|
||||||
|
|
||||||
logger.info("正在初始化个体特征")
|
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
self.bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
self.bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
||||||
self.name = bot_nickname
|
self.name = bot_nickname
|
||||||
@@ -56,129 +45,61 @@ class Individuality:
|
|||||||
bot_nickname, personality_core, personality_side, identity
|
bot_nickname, personality_core, personality_side, identity
|
||||||
)
|
)
|
||||||
|
|
||||||
# 初始化人格(现在包含身份)
|
logger.info("正在构建人设信息")
|
||||||
self.personality = Personality.initialize(
|
|
||||||
bot_nickname=bot_nickname,
|
|
||||||
personality_core=personality_core,
|
|
||||||
personality_side=personality_side,
|
|
||||||
identity=identity,
|
|
||||||
compress_personality=global_config.personality.compress_personality,
|
|
||||||
compress_identity=global_config.personality.compress_identity,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("正在将所有人设写入impression")
|
# 如果配置有变化,重新生成压缩版本
|
||||||
# 将所有人设写入impression
|
if personality_changed or identity_changed:
|
||||||
impression_parts = []
|
logger.info("检测到配置变化,重新生成压缩版本")
|
||||||
if personality_core:
|
personality_result = await self._create_personality(personality_core, personality_side)
|
||||||
impression_parts.append(f"核心人格: {personality_core}")
|
identity_result = await self._create_identity(identity)
|
||||||
if personality_side:
|
else:
|
||||||
impression_parts.append(f"人格侧面: {personality_side}")
|
logger.info("配置未变化,使用缓存版本")
|
||||||
if identity:
|
# 从文件中获取已有的结果
|
||||||
impression_parts.append(f"身份: {identity}")
|
personality_result, identity_result = self._get_personality_from_file()
|
||||||
logger.info(f"impression_parts: {impression_parts}")
|
if not personality_result or not identity_result:
|
||||||
|
logger.info("未找到有效缓存,重新生成")
|
||||||
|
personality_result = await self._create_personality(personality_core, personality_side)
|
||||||
|
identity_result = await self._create_identity(identity)
|
||||||
|
|
||||||
impression_text = "。".join(impression_parts)
|
# 保存到文件
|
||||||
if impression_text:
|
if personality_result and identity_result:
|
||||||
impression_text += "。"
|
self._save_personality_to_file(personality_result, identity_result)
|
||||||
|
logger.info("已将人设构建并保存到文件")
|
||||||
|
else:
|
||||||
|
logger.error("人设构建失败")
|
||||||
|
|
||||||
if impression_text:
|
# 如果任何一个发生变化,都需要清空数据库中的info_list(因为这影响整体人设)
|
||||||
|
if personality_changed or identity_changed:
|
||||||
|
logger.info("将清空数据库中原有的关键词缓存")
|
||||||
update_data = {
|
update_data = {
|
||||||
"platform": "system",
|
"platform": "system",
|
||||||
"user_id": "bot_id",
|
"user_id": "bot_id",
|
||||||
"person_name": self.name,
|
"person_name": self.name,
|
||||||
"nickname": self.name,
|
"nickname": self.name,
|
||||||
}
|
}
|
||||||
|
await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data)
|
||||||
await person_info_manager.update_one_field(
|
|
||||||
self.bot_person_id, "impression", impression_text, data=update_data
|
|
||||||
)
|
|
||||||
logger.debug("已将完整人设更新到bot的impression中")
|
|
||||||
|
|
||||||
# 根据变化情况决定是否重新创建
|
|
||||||
personality_result = None
|
|
||||||
identity_result = None
|
|
||||||
|
|
||||||
if personality_changed:
|
|
||||||
logger.info("检测到人格配置变化,重新生成压缩版本")
|
|
||||||
personality_result = await self._create_personality(personality_core, personality_side)
|
|
||||||
else:
|
|
||||||
logger.info("人格配置未变化,使用缓存版本")
|
|
||||||
# 从缓存中获取已有的personality结果
|
|
||||||
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
|
|
||||||
if existing_short_impression:
|
|
||||||
try:
|
|
||||||
existing_data = ast.literal_eval(existing_short_impression) # type: ignore
|
|
||||||
if isinstance(existing_data, list) and len(existing_data) >= 1:
|
|
||||||
personality_result = existing_data[0]
|
|
||||||
except (json.JSONDecodeError, TypeError, IndexError):
|
|
||||||
logger.warning("无法解析现有的short_impression,将重新生成人格部分")
|
|
||||||
personality_result = await self._create_personality(personality_core, personality_side)
|
|
||||||
else:
|
|
||||||
logger.info("未找到现有的人格缓存,重新生成")
|
|
||||||
personality_result = await self._create_personality(personality_core, personality_side)
|
|
||||||
|
|
||||||
if identity_changed:
|
|
||||||
logger.info("检测到身份配置变化,重新生成压缩版本")
|
|
||||||
identity_result = await self._create_identity(identity)
|
|
||||||
else:
|
|
||||||
logger.info("身份配置未变化,使用缓存版本")
|
|
||||||
# 从缓存中获取已有的identity结果
|
|
||||||
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
|
|
||||||
if existing_short_impression:
|
|
||||||
try:
|
|
||||||
existing_data = ast.literal_eval(existing_short_impression) # type: ignore
|
|
||||||
if isinstance(existing_data, list) and len(existing_data) >= 2:
|
|
||||||
identity_result = existing_data[1]
|
|
||||||
except (json.JSONDecodeError, TypeError, IndexError):
|
|
||||||
logger.warning("无法解析现有的short_impression,将重新生成身份部分")
|
|
||||||
identity_result = await self._create_identity(identity)
|
|
||||||
else:
|
|
||||||
logger.info("未找到现有的身份缓存,重新生成")
|
|
||||||
identity_result = await self._create_identity(identity)
|
|
||||||
|
|
||||||
result = [personality_result, identity_result]
|
|
||||||
|
|
||||||
# 更新short_impression字段
|
|
||||||
if personality_result and identity_result:
|
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
await person_info_manager.update_one_field(self.bot_person_id, "short_impression", result)
|
|
||||||
logger.info("已将人设构建")
|
|
||||||
else:
|
|
||||||
logger.error("人设构建失败")
|
|
||||||
|
|
||||||
async def get_personality_block(self) -> str:
|
async def get_personality_block(self) -> str:
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
|
||||||
|
|
||||||
bot_name = global_config.bot.nickname
|
bot_name = global_config.bot.nickname
|
||||||
if global_config.bot.alias_names:
|
if global_config.bot.alias_names:
|
||||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||||
else:
|
else:
|
||||||
bot_nickname = ""
|
bot_nickname = ""
|
||||||
short_impression = await person_info_manager.get_value(bot_person_id, "short_impression")
|
|
||||||
# 解析字符串形式的Python列表
|
# 从文件获取 short_impression
|
||||||
try:
|
personality, identity = self._get_personality_from_file()
|
||||||
if isinstance(short_impression, str) and short_impression.strip():
|
|
||||||
short_impression = ast.literal_eval(short_impression)
|
|
||||||
elif not short_impression:
|
|
||||||
logger.warning("short_impression为空,使用默认值")
|
|
||||||
short_impression = ["友好活泼", "人类"]
|
|
||||||
except (ValueError, SyntaxError) as e:
|
|
||||||
logger.error(f"解析short_impression失败: {e}, 原始值: {short_impression}")
|
|
||||||
short_impression = ["友好活泼", "人类"]
|
|
||||||
# 确保short_impression是列表格式且有足够的元素
|
# 确保short_impression是列表格式且有足够的元素
|
||||||
if not isinstance(short_impression, list) or len(short_impression) < 2:
|
if not personality or not identity:
|
||||||
logger.warning(f"short_impression格式不正确: {short_impression}, 使用默认值")
|
logger.warning(f"personality或identity为空: {personality}, {identity}, 使用默认值")
|
||||||
short_impression = ["友好活泼", "人类"]
|
personality = "友好活泼"
|
||||||
personality = short_impression[0]
|
identity = "人类"
|
||||||
identity = short_impression[1]
|
|
||||||
prompt_personality = f"{personality},{identity}"
|
prompt_personality = f"{personality}\n{identity}"
|
||||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
|
||||||
|
|
||||||
return identity_block
|
|
||||||
|
|
||||||
def _get_config_hash(
|
def _get_config_hash(
|
||||||
self, bot_nickname: str, personality_core: str, personality_side: str, identity: list
|
self, bot_nickname: str, personality_core: str, personality_side: str, identity: str
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
"""获取personality和identity配置的哈希值
|
"""获取personality和identity配置的哈希值
|
||||||
|
|
||||||
@@ -190,15 +111,15 @@ class Individuality:
|
|||||||
"nickname": bot_nickname,
|
"nickname": bot_nickname,
|
||||||
"personality_core": personality_core,
|
"personality_core": personality_core,
|
||||||
"personality_side": personality_side,
|
"personality_side": personality_side,
|
||||||
"compress_personality": self.personality.compress_personality if self.personality else True,
|
"compress_personality": global_config.personality.compress_personality,
|
||||||
}
|
}
|
||||||
personality_str = json.dumps(personality_config, sort_keys=True)
|
personality_str = json.dumps(personality_config, sort_keys=True)
|
||||||
personality_hash = hashlib.md5(personality_str.encode("utf-8")).hexdigest()
|
personality_hash = hashlib.md5(personality_str.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
# 身份配置哈希
|
# 身份配置哈希
|
||||||
identity_config = {
|
identity_config = {
|
||||||
"identity": sorted(identity),
|
"identity": identity,
|
||||||
"compress_identity": self.personality.compress_identity if self.personality else True,
|
"compress_identity": global_config.personality.compress_identity,
|
||||||
}
|
}
|
||||||
identity_str = json.dumps(identity_config, sort_keys=True)
|
identity_str = json.dumps(identity_config, sort_keys=True)
|
||||||
identity_hash = hashlib.md5(identity_str.encode("utf-8")).hexdigest()
|
identity_hash = hashlib.md5(identity_str.encode("utf-8")).hexdigest()
|
||||||
@@ -206,7 +127,7 @@ class Individuality:
|
|||||||
return personality_hash, identity_hash
|
return personality_hash, identity_hash
|
||||||
|
|
||||||
async def _check_config_and_clear_if_changed(
|
async def _check_config_and_clear_if_changed(
|
||||||
self, bot_nickname: str, personality_core: str, personality_side: str, identity: list
|
self, bot_nickname: str, personality_core: str, personality_side: str, identity: str
|
||||||
) -> tuple[bool, bool]:
|
) -> tuple[bool, bool]:
|
||||||
"""检查配置是否发生变化,如果变化则清空相应缓存
|
"""检查配置是否发生变化,如果变化则清空相应缓存
|
||||||
|
|
||||||
@@ -271,6 +192,53 @@ class Individuality:
|
|||||||
except IOError as e:
|
except IOError as e:
|
||||||
logger.error(f"保存meta_info文件失败: {e}")
|
logger.error(f"保存meta_info文件失败: {e}")
|
||||||
|
|
||||||
|
def _load_personality_data(self) -> dict:
|
||||||
|
"""从JSON文件中加载personality数据"""
|
||||||
|
if os.path.exists(self.personality_data_file_path):
|
||||||
|
try:
|
||||||
|
with open(self.personality_data_file_path, "r", encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
except (json.JSONDecodeError, IOError) as e:
|
||||||
|
logger.error(f"读取personality_data文件失败: {e}, 将创建新文件。")
|
||||||
|
return {}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _save_personality_data(self, personality_data: dict):
|
||||||
|
"""将personality数据保存到JSON文件"""
|
||||||
|
try:
|
||||||
|
os.makedirs(os.path.dirname(self.personality_data_file_path), exist_ok=True)
|
||||||
|
with open(self.personality_data_file_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(personality_data, f, ensure_ascii=False, indent=2)
|
||||||
|
logger.debug(f"已保存personality数据到文件: {self.personality_data_file_path}")
|
||||||
|
except IOError as e:
|
||||||
|
logger.error(f"保存personality_data文件失败: {e}")
|
||||||
|
|
||||||
|
def _get_personality_from_file(self) -> tuple[str, str]:
|
||||||
|
"""从文件获取personality数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (personality, identity)
|
||||||
|
"""
|
||||||
|
personality_data = self._load_personality_data()
|
||||||
|
personality = personality_data.get("personality", "友好活泼")
|
||||||
|
identity = personality_data.get("identity", "人类")
|
||||||
|
return personality, identity
|
||||||
|
|
||||||
|
def _save_personality_to_file(self, personality: str, identity: str):
|
||||||
|
"""保存personality数据到文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
personality: 压缩后的人格描述
|
||||||
|
identity: 压缩后的身份描述
|
||||||
|
"""
|
||||||
|
personality_data = {
|
||||||
|
"personality": personality,
|
||||||
|
"identity": identity,
|
||||||
|
"bot_nickname": self.name,
|
||||||
|
"last_updated": int(time.time())
|
||||||
|
}
|
||||||
|
self._save_personality_data(personality_data)
|
||||||
|
|
||||||
async def _create_personality(self, personality_core: str, personality_side: str) -> str:
|
async def _create_personality(self, personality_core: str, personality_side: str) -> str:
|
||||||
# sourcery skip: merge-list-append, move-assign
|
# sourcery skip: merge-list-append, move-assign
|
||||||
"""使用LLM创建压缩版本的impression
|
"""使用LLM创建压缩版本的impression
|
||||||
@@ -290,7 +258,7 @@ class Individuality:
|
|||||||
personality_parts.append(f"{personality_core}")
|
personality_parts.append(f"{personality_core}")
|
||||||
|
|
||||||
# 准备需要压缩的内容
|
# 准备需要压缩的内容
|
||||||
if self.personality.compress_personality:
|
if global_config.personality.compress_personality:
|
||||||
personality_to_compress = f"人格特质: {personality_side}"
|
personality_to_compress = f"人格特质: {personality_side}"
|
||||||
|
|
||||||
prompt = f"""请将以下人格信息进行简洁压缩,保留主要内容,用简练的中文表达:
|
prompt = f"""请将以下人格信息进行简洁压缩,保留主要内容,用简练的中文表达:
|
||||||
@@ -321,11 +289,11 @@ class Individuality:
|
|||||||
|
|
||||||
return personality_result
|
return personality_result
|
||||||
|
|
||||||
async def _create_identity(self, identity: list) -> str:
|
async def _create_identity(self, identity: str) -> str:
|
||||||
"""使用LLM创建压缩版本的impression"""
|
"""使用LLM创建压缩版本的impression"""
|
||||||
logger.info("正在构建身份.........")
|
logger.info("正在构建身份.........")
|
||||||
|
|
||||||
if self.personality.compress_identity:
|
if global_config.personality.compress_identity:
|
||||||
identity_to_compress = f"身份背景: {identity}"
|
identity_to_compress = f"身份背景: {identity}"
|
||||||
|
|
||||||
prompt = f"""请将以下身份信息进行简洁压缩,保留主要内容,用简练的中文表达:
|
prompt = f"""请将以下身份信息进行简洁压缩,保留主要内容,用简练的中文表达:
|
||||||
|
|||||||
@@ -1,91 +0,0 @@
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Personality:
|
|
||||||
"""人格特质类"""
|
|
||||||
|
|
||||||
bot_nickname: str # 机器人昵称
|
|
||||||
personality_core: str # 人格核心特点
|
|
||||||
personality_side: str # 人格侧面描述
|
|
||||||
identity: List[str] # 身份细节描述
|
|
||||||
compress_personality: bool # 是否压缩人格
|
|
||||||
compress_identity: bool # 是否压缩身份
|
|
||||||
|
|
||||||
_instance = None
|
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def __init__(self, personality_core: str = "", personality_side: str = "", identity: List[str] = None):
|
|
||||||
self.personality_core = personality_core
|
|
||||||
self.personality_side = personality_side
|
|
||||||
self.identity = identity
|
|
||||||
self.compress_personality = True
|
|
||||||
self.compress_identity = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_instance(cls) -> "Personality":
|
|
||||||
"""获取Personality单例实例
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Personality: 单例实例
|
|
||||||
"""
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = cls()
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def initialize(
|
|
||||||
cls,
|
|
||||||
bot_nickname: str,
|
|
||||||
personality_core: str,
|
|
||||||
personality_side: str,
|
|
||||||
identity: List[str] = None,
|
|
||||||
compress_personality: bool = True,
|
|
||||||
compress_identity: bool = True,
|
|
||||||
) -> "Personality":
|
|
||||||
"""初始化人格特质
|
|
||||||
|
|
||||||
Args:
|
|
||||||
bot_nickname: 机器人昵称
|
|
||||||
personality_core: 人格核心特点
|
|
||||||
personality_side: 人格侧面描述
|
|
||||||
identity: 身份细节描述
|
|
||||||
compress_personality: 是否压缩人格
|
|
||||||
compress_identity: 是否压缩身份
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Personality: 初始化后的人格特质实例
|
|
||||||
"""
|
|
||||||
instance = cls.get_instance()
|
|
||||||
instance.bot_nickname = bot_nickname
|
|
||||||
instance.personality_core = personality_core
|
|
||||||
instance.personality_side = personality_side
|
|
||||||
instance.identity = identity
|
|
||||||
instance.compress_personality = compress_personality
|
|
||||||
instance.compress_identity = compress_identity
|
|
||||||
return instance
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
|
||||||
"""将人格特质转换为字典格式"""
|
|
||||||
return {
|
|
||||||
"bot_nickname": self.bot_nickname,
|
|
||||||
"personality_core": self.personality_core,
|
|
||||||
"personality_side": self.personality_side,
|
|
||||||
"identity": self.identity,
|
|
||||||
"compress_personality": self.compress_personality,
|
|
||||||
"compress_identity": self.compress_identity,
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict) -> "Personality":
|
|
||||||
"""从字典创建人格特质实例"""
|
|
||||||
instance = cls.get_instance()
|
|
||||||
for key, value in data.items():
|
|
||||||
setattr(instance, key, value)
|
|
||||||
return instance
|
|
||||||
@@ -10,6 +10,7 @@ import base64
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
import copy # 添加copy模块用于深拷贝
|
||||||
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
||||||
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
|
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
@@ -69,23 +70,28 @@ error_code_mapping = {
|
|||||||
|
|
||||||
|
|
||||||
async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]):
|
async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]):
|
||||||
|
"""安全地记录请求体,用于调试日志,不会修改原始payload对象"""
|
||||||
|
# 创建payload的深拷贝,避免修改原始对象
|
||||||
|
safe_payload = copy.deepcopy(payload)
|
||||||
|
|
||||||
image_base64: str = request_content.get("image_base64")
|
image_base64: str = request_content.get("image_base64")
|
||||||
image_format: str = request_content.get("image_format")
|
image_format: str = request_content.get("image_format")
|
||||||
if (
|
if (
|
||||||
image_base64
|
image_base64
|
||||||
and payload
|
and safe_payload
|
||||||
and isinstance(payload, dict)
|
and isinstance(safe_payload, dict)
|
||||||
and "messages" in payload
|
and "messages" in safe_payload
|
||||||
and len(payload["messages"]) > 0
|
and len(safe_payload["messages"]) > 0
|
||||||
):
|
):
|
||||||
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
|
if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]:
|
||||||
content = payload["messages"][0]["content"]
|
content = safe_payload["messages"][0]["content"]
|
||||||
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
|
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
|
||||||
payload["messages"][0]["content"][1]["image_url"]["url"] = (
|
# 只修改拷贝的对象,用于安全的日志记录
|
||||||
|
safe_payload["messages"][0]["content"][1]["image_url"]["url"] = (
|
||||||
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
|
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
|
||||||
f"{image_base64[:10]}...{image_base64[-10:]}"
|
f"{image_base64[:10]}...{image_base64[-10:]}"
|
||||||
)
|
)
|
||||||
return payload
|
return safe_payload
|
||||||
|
|
||||||
|
|
||||||
class LLMRequest:
|
class LLMRequest:
|
||||||
@@ -109,10 +115,15 @@ class LLMRequest:
|
|||||||
|
|
||||||
def __init__(self, model: dict, **kwargs):
|
def __init__(self, model: dict, **kwargs):
|
||||||
# 将大写的配置键转换为小写并从config中获取实际值
|
# 将大写的配置键转换为小写并从config中获取实际值
|
||||||
|
logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('name', 'Unknown')}")
|
||||||
|
logger.debug(f"🔍 [模型初始化] 模型配置: {model}")
|
||||||
|
logger.debug(f"🔍 [模型初始化] 额外参数: {kwargs}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# print(f"model['provider']: {model['provider']}")
|
# print(f"model['provider']: {model['provider']}")
|
||||||
self.api_key = os.environ[f"{model['provider']}_KEY"]
|
self.api_key = os.environ[f"{model['provider']}_KEY"]
|
||||||
self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
|
self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
|
||||||
|
logger.debug(f"🔍 [模型初始化] 成功获取环境变量: {model['provider']}_KEY 和 {model['provider']}_BASE_URL")
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
logger.error(f"原始 model dict 信息:{model}")
|
logger.error(f"原始 model dict 信息:{model}")
|
||||||
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
||||||
@@ -124,6 +135,10 @@ class LLMRequest:
|
|||||||
self.model_name: str = model["name"]
|
self.model_name: str = model["name"]
|
||||||
self.params = kwargs
|
self.params = kwargs
|
||||||
|
|
||||||
|
# 记录配置文件中声明了哪些参数(不管值是什么)
|
||||||
|
self.has_enable_thinking = "enable_thinking" in model
|
||||||
|
self.has_thinking_budget = "thinking_budget" in model
|
||||||
|
|
||||||
self.enable_thinking = model.get("enable_thinking", False)
|
self.enable_thinking = model.get("enable_thinking", False)
|
||||||
self.temp = model.get("temp", 0.7)
|
self.temp = model.get("temp", 0.7)
|
||||||
self.thinking_budget = model.get("thinking_budget", 4096)
|
self.thinking_budget = model.get("thinking_budget", 4096)
|
||||||
@@ -132,12 +147,24 @@ class LLMRequest:
|
|||||||
self.pri_out = model.get("pri_out", 0)
|
self.pri_out = model.get("pri_out", 0)
|
||||||
self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length)
|
self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length)
|
||||||
# print(f"max_tokens: {self.max_tokens}")
|
# print(f"max_tokens: {self.max_tokens}")
|
||||||
|
|
||||||
|
logger.debug(f"🔍 [模型初始化] 模型参数设置完成:")
|
||||||
|
logger.debug(f" - model_name: {self.model_name}")
|
||||||
|
logger.debug(f" - has_enable_thinking: {self.has_enable_thinking}")
|
||||||
|
logger.debug(f" - enable_thinking: {self.enable_thinking}")
|
||||||
|
logger.debug(f" - has_thinking_budget: {self.has_thinking_budget}")
|
||||||
|
logger.debug(f" - thinking_budget: {self.thinking_budget}")
|
||||||
|
logger.debug(f" - temp: {self.temp}")
|
||||||
|
logger.debug(f" - stream: {self.stream}")
|
||||||
|
logger.debug(f" - max_tokens: {self.max_tokens}")
|
||||||
|
logger.debug(f" - base_url: {self.base_url}")
|
||||||
|
|
||||||
# 获取数据库实例
|
# 获取数据库实例
|
||||||
self._init_database()
|
self._init_database()
|
||||||
|
|
||||||
# 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
|
# 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
|
||||||
self.request_type = kwargs.pop("request_type", "default")
|
self.request_type = kwargs.pop("request_type", "default")
|
||||||
|
logger.debug(f"🔍 [模型初始化] 初始化完成,request_type: {self.request_type}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _init_database():
|
def _init_database():
|
||||||
@@ -262,11 +289,12 @@ class LLMRequest:
|
|||||||
if self.temp != 0.7:
|
if self.temp != 0.7:
|
||||||
payload["temperature"] = self.temp
|
payload["temperature"] = self.temp
|
||||||
|
|
||||||
# 添加enable_thinking参数(如果不是默认值False)
|
# 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false)
|
||||||
if not self.enable_thinking:
|
if self.has_enable_thinking:
|
||||||
payload["enable_thinking"] = False
|
payload["enable_thinking"] = self.enable_thinking
|
||||||
|
|
||||||
if self.thinking_budget != 4096:
|
# 添加thinking_budget参数(只有配置文件中声明了才添加)
|
||||||
|
if self.has_thinking_budget:
|
||||||
payload["thinking_budget"] = self.thinking_budget
|
payload["thinking_budget"] = self.thinking_budget
|
||||||
|
|
||||||
if self.max_tokens:
|
if self.max_tokens:
|
||||||
@@ -334,6 +362,19 @@ class LLMRequest:
|
|||||||
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
|
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
|
||||||
if request_content["stream_mode"]:
|
if request_content["stream_mode"]:
|
||||||
headers["Accept"] = "text/event-stream"
|
headers["Accept"] = "text/event-stream"
|
||||||
|
|
||||||
|
# 添加请求发送前的调试信息
|
||||||
|
logger.debug(f"🔍 [请求调试] 模型 {self.model_name} 准备发送请求")
|
||||||
|
logger.debug(f"🔍 [请求调试] API URL: {request_content['api_url']}")
|
||||||
|
logger.debug(f"🔍 [请求调试] 请求头: {await self._build_headers(no_key=True, is_formdata=file_bytes is not None)}")
|
||||||
|
|
||||||
|
if not file_bytes:
|
||||||
|
# 安全地记录请求体(隐藏敏感信息)
|
||||||
|
safe_payload = await _safely_record(request_content, request_content["payload"])
|
||||||
|
logger.debug(f"🔍 [请求调试] 请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"🔍 [请求调试] 文件上传请求,文件格式: {request_content['file_format']}")
|
||||||
|
|
||||||
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
|
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
|
||||||
post_kwargs = {"headers": headers}
|
post_kwargs = {"headers": headers}
|
||||||
# form-data数据上传方式不同
|
# form-data数据上传方式不同
|
||||||
@@ -491,7 +532,36 @@ class LLMRequest:
|
|||||||
logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
|
logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
|
||||||
raise RuntimeError("请求限制(429)")
|
raise RuntimeError("请求限制(429)")
|
||||||
elif response.status in policy["abort_codes"]:
|
elif response.status in policy["abort_codes"]:
|
||||||
if response.status != 403:
|
# 特别处理400错误,添加详细调试信息
|
||||||
|
if response.status == 400:
|
||||||
|
logger.error(f"🔍 [调试信息] 模型 {self.model_name} 参数错误 (400) - 开始详细诊断")
|
||||||
|
logger.error(f"🔍 [调试信息] 模型名称: {self.model_name}")
|
||||||
|
logger.error(f"🔍 [调试信息] API地址: {self.base_url}")
|
||||||
|
logger.error(f"🔍 [调试信息] 模型配置参数:")
|
||||||
|
logger.error(f" - enable_thinking: {self.enable_thinking}")
|
||||||
|
logger.error(f" - temp: {self.temp}")
|
||||||
|
logger.error(f" - thinking_budget: {self.thinking_budget}")
|
||||||
|
logger.error(f" - stream: {self.stream}")
|
||||||
|
logger.error(f" - max_tokens: {self.max_tokens}")
|
||||||
|
logger.error(f" - pri_in: {self.pri_in}")
|
||||||
|
logger.error(f" - pri_out: {self.pri_out}")
|
||||||
|
logger.error(f"🔍 [调试信息] 原始params: {self.params}")
|
||||||
|
|
||||||
|
# 尝试获取服务器返回的详细错误信息
|
||||||
|
try:
|
||||||
|
error_text = await response.text()
|
||||||
|
logger.error(f"🔍 [调试信息] 服务器返回的原始错误内容: {error_text}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
error_json = json.loads(error_text)
|
||||||
|
logger.error(f"🔍 [调试信息] 解析后的错误JSON: {json.dumps(error_json, indent=2, ensure_ascii=False)}")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error(f"🔍 [调试信息] 错误响应不是有效的JSON格式")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"🔍 [调试信息] 无法读取错误响应内容: {str(e)}")
|
||||||
|
|
||||||
|
raise RequestAbortException("参数错误,请检查调试信息", response)
|
||||||
|
elif response.status != 403:
|
||||||
raise RequestAbortException("请求出现错误,中断处理", response)
|
raise RequestAbortException("请求出现错误,中断处理", response)
|
||||||
else:
|
else:
|
||||||
raise PermissionDeniedException("模型禁止访问")
|
raise PermissionDeniedException("模型禁止访问")
|
||||||
@@ -510,6 +580,19 @@ class LLMRequest:
|
|||||||
logger.error(
|
logger.error(
|
||||||
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
|
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 如果是400错误,额外输出请求体信息用于调试
|
||||||
|
if response.status == 400:
|
||||||
|
logger.error(f"🔍 [异常调试] 400错误 - 请求体调试信息:")
|
||||||
|
try:
|
||||||
|
safe_payload = await _safely_record(request_content, payload)
|
||||||
|
logger.error(f"🔍 [异常调试] 发送的请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}")
|
||||||
|
except Exception as debug_error:
|
||||||
|
logger.error(f"🔍 [异常调试] 无法安全记录请求体: {str(debug_error)}")
|
||||||
|
logger.error(f"🔍 [异常调试] 原始payload类型: {type(payload)}")
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
logger.error(f"🔍 [异常调试] 原始payload键: {list(payload.keys())}")
|
||||||
|
|
||||||
# print(request_content)
|
# print(request_content)
|
||||||
# print(response)
|
# print(response)
|
||||||
# 尝试获取并记录服务器返回的详细错误信息
|
# 尝试获取并记录服务器返回的详细错误信息
|
||||||
@@ -654,14 +737,27 @@ class LLMRequest:
|
|||||||
"""
|
"""
|
||||||
# 复制一份参数,避免直接修改原始数据
|
# 复制一份参数,避免直接修改原始数据
|
||||||
new_params = dict(params)
|
new_params = dict(params)
|
||||||
|
|
||||||
|
logger.debug(f"🔍 [参数转换] 模型 {self.model_name} 开始参数转换")
|
||||||
|
logger.debug(f"🔍 [参数转换] 是否为CoT模型: {self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION}")
|
||||||
|
logger.debug(f"🔍 [参数转换] CoT模型列表: {self.MODELS_NEEDING_TRANSFORMATION}")
|
||||||
|
|
||||||
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
|
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
|
||||||
|
logger.debug(f"🔍 [参数转换] 检测到CoT模型,开始参数转换")
|
||||||
# 删除 'temperature' 参数(如果存在),但避免删除我们在_build_payload中添加的自定义温度
|
# 删除 'temperature' 参数(如果存在),但避免删除我们在_build_payload中添加的自定义温度
|
||||||
if "temperature" in new_params and new_params["temperature"] == 0.7:
|
if "temperature" in new_params and new_params["temperature"] == 0.7:
|
||||||
new_params.pop("temperature")
|
removed_temp = new_params.pop("temperature")
|
||||||
|
logger.debug(f"🔍 [参数转换] 移除默认temperature参数: {removed_temp}")
|
||||||
# 如果存在 'max_tokens',则重命名为 'max_completion_tokens'
|
# 如果存在 'max_tokens',则重命名为 'max_completion_tokens'
|
||||||
if "max_tokens" in new_params:
|
if "max_tokens" in new_params:
|
||||||
|
old_value = new_params["max_tokens"]
|
||||||
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
|
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
|
||||||
|
logger.debug(f"🔍 [参数转换] 参数重命名: max_tokens({old_value}) -> max_completion_tokens({new_params['max_completion_tokens']})")
|
||||||
|
else:
|
||||||
|
logger.debug(f"🔍 [参数转换] 非CoT模型,无需参数转换")
|
||||||
|
|
||||||
|
logger.debug(f"🔍 [参数转换] 转换前参数: {params}")
|
||||||
|
logger.debug(f"🔍 [参数转换] 转换后参数: {new_params}")
|
||||||
return new_params
|
return new_params
|
||||||
|
|
||||||
async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData:
|
async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData:
|
||||||
@@ -693,7 +789,12 @@ class LLMRequest:
|
|||||||
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
|
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
|
||||||
"""构建请求体"""
|
"""构建请求体"""
|
||||||
# 复制一份参数,避免直接修改 self.params
|
# 复制一份参数,避免直接修改 self.params
|
||||||
|
logger.debug(f"🔍 [参数构建] 模型 {self.model_name} 开始构建请求体")
|
||||||
|
logger.debug(f"🔍 [参数构建] 原始self.params: {self.params}")
|
||||||
|
|
||||||
params_copy = await self._transform_parameters(self.params)
|
params_copy = await self._transform_parameters(self.params)
|
||||||
|
logger.debug(f"🔍 [参数构建] 转换后的params_copy: {params_copy}")
|
||||||
|
|
||||||
if image_base64:
|
if image_base64:
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
@@ -715,26 +816,37 @@ class LLMRequest:
|
|||||||
"messages": messages,
|
"messages": messages,
|
||||||
**params_copy,
|
**params_copy,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.debug(f"🔍 [参数构建] 基础payload构建完成: {list(payload.keys())}")
|
||||||
|
|
||||||
# 添加temp参数(如果不是默认值0.7)
|
# 添加temp参数(如果不是默认值0.7)
|
||||||
if self.temp != 0.7:
|
if self.temp != 0.7:
|
||||||
payload["temperature"] = self.temp
|
payload["temperature"] = self.temp
|
||||||
|
logger.debug(f"🔍 [参数构建] 添加temperature参数: {self.temp}")
|
||||||
|
|
||||||
# 添加enable_thinking参数(如果不是默认值False)
|
# 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false)
|
||||||
if not self.enable_thinking:
|
if self.has_enable_thinking:
|
||||||
payload["enable_thinking"] = False
|
payload["enable_thinking"] = self.enable_thinking
|
||||||
|
logger.debug(f"🔍 [参数构建] 添加enable_thinking参数: {self.enable_thinking}")
|
||||||
|
|
||||||
if self.thinking_budget != 4096:
|
# 添加thinking_budget参数(只有配置文件中声明了才添加)
|
||||||
|
if self.has_thinking_budget:
|
||||||
payload["thinking_budget"] = self.thinking_budget
|
payload["thinking_budget"] = self.thinking_budget
|
||||||
|
logger.debug(f"🔍 [参数构建] 添加thinking_budget参数: {self.thinking_budget}")
|
||||||
|
|
||||||
if self.max_tokens:
|
if self.max_tokens:
|
||||||
payload["max_tokens"] = self.max_tokens
|
payload["max_tokens"] = self.max_tokens
|
||||||
|
logger.debug(f"🔍 [参数构建] 添加max_tokens参数: {self.max_tokens}")
|
||||||
|
|
||||||
# if "max_tokens" not in payload and "max_completion_tokens" not in payload:
|
# if "max_tokens" not in payload and "max_completion_tokens" not in payload:
|
||||||
# payload["max_tokens"] = global_config.model.model_max_output_length
|
# payload["max_tokens"] = global_config.model.model_max_output_length
|
||||||
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
||||||
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
||||||
|
old_value = payload["max_tokens"]
|
||||||
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
||||||
|
logger.debug(f"🔍 [参数构建] CoT模型参数转换: max_tokens({old_value}) -> max_completion_tokens({payload['max_completion_tokens']})")
|
||||||
|
|
||||||
|
logger.debug(f"🔍 [参数构建] 最终payload键列表: {list(payload.keys())}")
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
def _default_response_handler(
|
def _default_response_handler(
|
||||||
|
|||||||
@@ -115,7 +115,6 @@ class MainSystem:
|
|||||||
|
|
||||||
# 初始化个体特征
|
# 初始化个体特征
|
||||||
await self.individuality.initialize()
|
await self.individuality.initialize()
|
||||||
logger.info("个体特征初始化成功")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
init_time = int(1000 * (time.time() - init_start_time))
|
init_time = int(1000 * (time.time() - init_start_time))
|
||||||
|
|||||||
1
src/mais4u/constant_s4u.py
Normal file
1
src/mais4u/constant_s4u.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
ENABLE_S4U = False
|
||||||
@@ -3,7 +3,7 @@ import time
|
|||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.message import MessageSending, MessageRecv, MessageRecvS4U
|
from src.chat.message_receive.message import MessageRecvS4U
|
||||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import random
|
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from src.mais4u.s4u_config import s4u_config
|
|||||||
from src.person_info.person_info import PersonInfoManager
|
from src.person_info.person_info import PersonInfoManager
|
||||||
from .super_chat_manager import get_super_chat_manager
|
from .super_chat_manager import get_super_chat_manager
|
||||||
from .yes_or_no import yes_or_no_head
|
from .yes_or_no import yes_or_no_head
|
||||||
|
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||||
|
|
||||||
logger = get_logger("S4U_chat")
|
logger = get_logger("S4U_chat")
|
||||||
|
|
||||||
@@ -165,7 +166,10 @@ class S4UChatManager:
|
|||||||
return self.s4u_chats[chat_stream.stream_id]
|
return self.s4u_chats[chat_stream.stream_id]
|
||||||
|
|
||||||
|
|
||||||
s4u_chat_manager = S4UChatManager()
|
if not ENABLE_S4U:
|
||||||
|
s4u_chat_manager = None
|
||||||
|
else:
|
||||||
|
s4u_chat_manager = S4UChatManager()
|
||||||
|
|
||||||
|
|
||||||
def get_s4u_chat_manager() -> S4UChatManager:
|
def get_s4u_chat_manager() -> S4UChatManager:
|
||||||
@@ -486,7 +490,7 @@ class S4UChat:
|
|||||||
logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'")
|
logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'")
|
||||||
|
|
||||||
if s4u_config.enable_streaming_output:
|
if s4u_config.enable_streaming_output:
|
||||||
logger.info(f"[S4U] 开始流式输出")
|
logger.info("[S4U] 开始流式输出")
|
||||||
# 流式输出,边生成边发送
|
# 流式输出,边生成边发送
|
||||||
gen = self.gpt.generate_response(message, "")
|
gen = self.gpt.generate_response(message, "")
|
||||||
async for chunk in gen:
|
async for chunk in gen:
|
||||||
@@ -494,7 +498,7 @@ class S4UChat:
|
|||||||
await sender_container.add_message(chunk)
|
await sender_container.add_message(chunk)
|
||||||
total_chars_sent += len(chunk)
|
total_chars_sent += len(chunk)
|
||||||
else:
|
else:
|
||||||
logger.info(f"[S4U] 开始一次性输出")
|
logger.info("[S4U] 开始一次性输出")
|
||||||
# 一次性输出,先收集所有chunk
|
# 一次性输出,先收集所有chunk
|
||||||
all_chunks = []
|
all_chunks = []
|
||||||
gen = self.gpt.generate_response(message, "")
|
gen = self.gpt.generate_response(message, "")
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from src.config.config import global_config
|
|||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||||
from src.plugin_system.apis import send_api
|
from src.plugin_system.apis import send_api
|
||||||
|
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||||
|
|
||||||
"""
|
"""
|
||||||
情绪管理系统使用说明:
|
情绪管理系统使用说明:
|
||||||
@@ -446,9 +447,10 @@ class MoodManager:
|
|||||||
# 发送初始情绪状态到ws端
|
# 发送初始情绪状态到ws端
|
||||||
asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values))
|
asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values))
|
||||||
|
|
||||||
|
if ENABLE_S4U:
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
mood_manager = MoodManager()
|
||||||
mood_manager = MoodManager()
|
else:
|
||||||
|
mood_manager = None
|
||||||
|
|
||||||
"""全局情绪管理器"""
|
"""全局情绪管理器"""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import Tuple
|
|||||||
|
|
||||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||||
from maim_message.message_base import GroupInfo,UserInfo
|
from maim_message.message_base import GroupInfo
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
|
|||||||
@@ -10,13 +10,13 @@ from datetime import datetime
|
|||||||
import asyncio
|
import asyncio
|
||||||
from src.mais4u.s4u_config import s4u_config
|
from src.mais4u.s4u_config import s4u_config
|
||||||
from src.chat.message_receive.message import MessageRecvS4U
|
from src.chat.message_receive.message import MessageRecvS4U
|
||||||
from src.person_info.relationship_manager import get_relationship_manager
|
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||||
|
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
|
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
|
||||||
from src.mais4u.mais4u_chat.screen_manager import screen_manager
|
from src.mais4u.mais4u_chat.screen_manager import screen_manager
|
||||||
from src.chat.express.expression_selector import expression_selector
|
from src.chat.express.expression_selector import expression_selector
|
||||||
from .s4u_mood_manager import mood_manager
|
from .s4u_mood_manager import mood_manager
|
||||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
|
||||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||||
logger = get_logger("prompt")
|
logger = get_logger("prompt")
|
||||||
|
|
||||||
@@ -149,9 +149,17 @@ class PromptBuilder:
|
|||||||
|
|
||||||
relation_prompt = ""
|
relation_prompt = ""
|
||||||
if global_config.relationship.enable_relationship and who_chat_in_group:
|
if global_config.relationship.enable_relationship and who_chat_in_group:
|
||||||
relationship_manager = get_relationship_manager()
|
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_stream.stream_id)
|
||||||
|
|
||||||
|
# 将 (platform, user_id, nickname) 转换为 person_id
|
||||||
|
person_ids = []
|
||||||
|
for person in who_chat_in_group:
|
||||||
|
person_id = PersonInfoManager.get_person_id(person[0], person[1])
|
||||||
|
person_ids.append(person_id)
|
||||||
|
|
||||||
|
# 使用 RelationshipFetcher 的 build_relation_info 方法,设置 points_num=3 保持与原来相同的行为
|
||||||
relation_info_list = await asyncio.gather(
|
relation_info_list = await asyncio.gather(
|
||||||
*[relationship_manager.build_relationship_info(person) for person in who_chat_in_group]
|
*[relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids]
|
||||||
)
|
)
|
||||||
relation_info = "".join(relation_info_list)
|
relation_info = "".join(relation_info_list)
|
||||||
if relation_info:
|
if relation_info:
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from src.config.config import global_config
|
|||||||
from src.chat.message_receive.message import MessageRecvS4U
|
from src.chat.message_receive.message import MessageRecvS4U
|
||||||
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
|
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@@ -49,19 +48,19 @@ class S4UStreamGenerator:
|
|||||||
self.chat_stream =None
|
self.chat_stream =None
|
||||||
|
|
||||||
async def build_last_internal_message(self,message:MessageRecvS4U,previous_reply_context:str = ""):
|
async def build_last_internal_message(self,message:MessageRecvS4U,previous_reply_context:str = ""):
|
||||||
person_id = PersonInfoManager.get_person_id(
|
# person_id = PersonInfoManager.get_person_id(
|
||||||
message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
|
# message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
|
||||||
)
|
# )
|
||||||
person_info_manager = get_person_info_manager()
|
# person_info_manager = get_person_info_manager()
|
||||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
# person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||||
|
|
||||||
if message.chat_stream.user_info.user_nickname:
|
# if message.chat_stream.user_info.user_nickname:
|
||||||
if person_name:
|
# if person_name:
|
||||||
sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
|
# sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
|
||||||
else:
|
# else:
|
||||||
sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
|
# sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
|
||||||
else:
|
# else:
|
||||||
sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
# sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||||
|
|
||||||
# 构建prompt
|
# 构建prompt
|
||||||
if previous_reply_context:
|
if previous_reply_context:
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.apis import send_api
|
from src.plugin_system.apis import send_api
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ from dataclasses import dataclass
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.message_receive.message import MessageRecvS4U
|
from src.chat.message_receive.message import MessageRecvS4U
|
||||||
|
# 全局SuperChat管理器实例
|
||||||
|
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||||
|
|
||||||
logger = get_logger("super_chat_manager")
|
logger = get_logger("super_chat_manager")
|
||||||
|
|
||||||
@@ -296,10 +298,14 @@ class SuperChatManager:
|
|||||||
logger.info("SuperChat管理器已关闭")
|
logger.info("SuperChat管理器已关闭")
|
||||||
|
|
||||||
|
|
||||||
# 全局SuperChat管理器实例
|
|
||||||
super_chat_manager = SuperChatManager()
|
|
||||||
|
|
||||||
|
|
||||||
|
if ENABLE_S4U:
|
||||||
|
super_chat_manager = SuperChatManager()
|
||||||
|
else:
|
||||||
|
super_chat_manager = None
|
||||||
|
|
||||||
def get_super_chat_manager() -> SuperChatManager:
|
def get_super_chat_manager() -> SuperChatManager:
|
||||||
"""获取全局SuperChat管理器实例"""
|
"""获取全局SuperChat管理器实例"""
|
||||||
return super_chat_manager
|
|
||||||
|
return super_chat_manager
|
||||||
@@ -1,16 +1,6 @@
|
|||||||
import json
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
from src.chat.message_receive.message import MessageRecv
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
|
||||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
|
||||||
from src.plugin_system.apis import send_api
|
|
||||||
from json_repair import repair_json
|
|
||||||
from src.mais4u.s4u_config import s4u_config
|
|
||||||
from src.plugin_system.apis import send_api
|
from src.plugin_system.apis import send_api
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from tomlkit import TOMLDocument
|
|||||||
from tomlkit.items import Table
|
from tomlkit.items import Table
|
||||||
from dataclasses import dataclass, fields, MISSING, field
|
from dataclasses import dataclass, fields, MISSING, field
|
||||||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
|
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
|
||||||
|
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("s4u_config")
|
logger = get_logger("s4u_config")
|
||||||
@@ -353,12 +353,16 @@ def load_s4u_config(config_path: str) -> S4UGlobalConfig:
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
# 初始化S4U配置
|
if not ENABLE_S4U:
|
||||||
logger.info(f"S4U当前版本: {S4U_VERSION}")
|
s4u_config = None
|
||||||
update_s4u_config()
|
s4u_config_main = None
|
||||||
|
else:
|
||||||
|
# 初始化S4U配置
|
||||||
|
logger.info(f"S4U当前版本: {S4U_VERSION}")
|
||||||
|
update_s4u_config()
|
||||||
|
|
||||||
logger.info("正在加载S4U配置文件...")
|
logger.info("正在加载S4U配置文件...")
|
||||||
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
|
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
|
||||||
logger.info("S4U配置文件加载完成!")
|
logger.info("S4U配置文件加载完成!")
|
||||||
|
|
||||||
s4u_config: S4UConfig = s4u_config_main.s4u
|
s4u_config: S4UConfig = s4u_config_main.s4u
|
||||||
@@ -83,12 +83,12 @@ class ChatMood:
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}"
|
f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}"
|
||||||
)
|
)
|
||||||
update_probability = min(1.0, base_probability * time_multiplier * interest_multiplier)
|
update_probability = global_config.mood.mood_update_threshold * min(1.0, base_probability * time_multiplier * interest_multiplier)
|
||||||
|
|
||||||
if random.random() > update_probability:
|
if random.random() > update_probability:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate}, 更新概率: {update_probability}")
|
logger.debug(f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}")
|
||||||
|
|
||||||
message_time: float = message.message_info.time # type: ignore
|
message_time: float = message.message_info.time # type: ignore
|
||||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
@@ -201,7 +201,7 @@ class MoodRegressionTask(AsyncTask):
|
|||||||
if mood.regression_count >= 3:
|
if mood.regression_count >= 3:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.info(f"chat {mood.chat_id} 开始情绪回归, 这是第 {mood.regression_count + 1} 次")
|
logger.info(f"{mood.log_prefix} 开始情绪回归, 这是第 {mood.regression_count + 1} 次")
|
||||||
await mood.regress_mood()
|
await mood.regress_mood()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -41,8 +41,6 @@ person_info_default = {
|
|||||||
"know_times": 0,
|
"know_times": 0,
|
||||||
"know_since": None,
|
"know_since": None,
|
||||||
"last_know": None,
|
"last_know": None,
|
||||||
# "user_cardname": None, # This field is not in Peewee model PersonInfo
|
|
||||||
# "user_avatar": None, # This field is not in Peewee model PersonInfo
|
|
||||||
"impression": None, # Corrected from person_impression
|
"impression": None, # Corrected from person_impression
|
||||||
"short_impression": None,
|
"short_impression": None,
|
||||||
"info_list": None,
|
"info_list": None,
|
||||||
|
|||||||
@@ -112,15 +112,6 @@ class RelationshipFetcher:
|
|||||||
|
|
||||||
current_points = await person_info_manager.get_value(person_id, "points") or []
|
current_points = await person_info_manager.get_value(person_id, "points") or []
|
||||||
|
|
||||||
if isinstance(current_points, str):
|
|
||||||
try:
|
|
||||||
current_points = json.loads(current_points)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.error(f"解析points JSON失败: {current_points}")
|
|
||||||
current_points = []
|
|
||||||
elif not isinstance(current_points, list):
|
|
||||||
current_points = []
|
|
||||||
|
|
||||||
# 按时间排序forgotten_points
|
# 按时间排序forgotten_points
|
||||||
current_points.sort(key=lambda x: x[2])
|
current_points.sort(key=lambda x: x[2])
|
||||||
# 按权重加权随机抽取最多3个不重复的points,point[1]的值在1-10之间,权重越高被抽到概率越大
|
# 按权重加权随机抽取最多3个不重复的points,point[1]的值在1-10之间,权重越高被抽到概率越大
|
||||||
@@ -370,60 +361,6 @@ class RelationshipFetcher:
|
|||||||
logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}")
|
logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
def _organize_known_info(self) -> str:
|
|
||||||
"""组织已知的用户信息为字符串
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 格式化的用户信息字符串
|
|
||||||
"""
|
|
||||||
persons_infos_str = ""
|
|
||||||
|
|
||||||
if self.info_fetched_cache:
|
|
||||||
persons_with_known_info = [] # 有已知信息的人员
|
|
||||||
persons_with_unknown_info = [] # 有未知信息的人员
|
|
||||||
|
|
||||||
for person_id in self.info_fetched_cache:
|
|
||||||
person_known_infos = []
|
|
||||||
person_unknown_infos = []
|
|
||||||
person_name = ""
|
|
||||||
|
|
||||||
for info_type in self.info_fetched_cache[person_id]:
|
|
||||||
person_name = self.info_fetched_cache[person_id][info_type]["person_name"]
|
|
||||||
if not self.info_fetched_cache[person_id][info_type]["unknown"]:
|
|
||||||
info_content = self.info_fetched_cache[person_id][info_type]["info"]
|
|
||||||
person_known_infos.append(f"[{info_type}]:{info_content}")
|
|
||||||
else:
|
|
||||||
person_unknown_infos.append(info_type)
|
|
||||||
|
|
||||||
# 如果有已知信息,添加到已知信息列表
|
|
||||||
if person_known_infos:
|
|
||||||
known_info_str = ";".join(person_known_infos) + ";"
|
|
||||||
persons_with_known_info.append((person_name, known_info_str))
|
|
||||||
|
|
||||||
# 如果有未知信息,添加到未知信息列表
|
|
||||||
if person_unknown_infos:
|
|
||||||
persons_with_unknown_info.append((person_name, person_unknown_infos))
|
|
||||||
|
|
||||||
# 先输出有已知信息的人员
|
|
||||||
for person_name, known_info_str in persons_with_known_info:
|
|
||||||
persons_infos_str += f"你对 {person_name} 的了解:{known_info_str}\n"
|
|
||||||
|
|
||||||
# 统一处理未知信息,避免重复的警告文本
|
|
||||||
if persons_with_unknown_info:
|
|
||||||
unknown_persons_details = []
|
|
||||||
for person_name, unknown_types in persons_with_unknown_info:
|
|
||||||
unknown_types_str = "、".join(unknown_types)
|
|
||||||
unknown_persons_details.append(f"{person_name}的[{unknown_types_str}]")
|
|
||||||
|
|
||||||
if len(unknown_persons_details) == 1:
|
|
||||||
persons_infos_str += (
|
|
||||||
f"你不了解{unknown_persons_details[0]}信息,不要胡乱回答,可以直接说不知道或忘记了;\n"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
unknown_all_str = "、".join(unknown_persons_details)
|
|
||||||
persons_infos_str += f"你不了解{unknown_all_str}等信息,不要胡乱回答,可以直接说不知道或忘记了;\n"
|
|
||||||
|
|
||||||
return persons_infos_str
|
|
||||||
|
|
||||||
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
|
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
|
||||||
# sourcery skip: use-next
|
# sourcery skip: use-next
|
||||||
|
|||||||
@@ -55,60 +55,6 @@ class RelationshipManager:
|
|||||||
# person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar
|
# person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar
|
||||||
# )
|
# )
|
||||||
|
|
||||||
async def build_relationship_info(self, person, is_id: bool = False) -> str:
|
|
||||||
if is_id:
|
|
||||||
person_id = person
|
|
||||||
else:
|
|
||||||
person_id = PersonInfoManager.get_person_id(person[0], person[1])
|
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
|
||||||
if not person_name or person_name == "none":
|
|
||||||
return ""
|
|
||||||
short_impression = await person_info_manager.get_value(person_id, "short_impression")
|
|
||||||
|
|
||||||
current_points = await person_info_manager.get_value(person_id, "points") or []
|
|
||||||
# print(f"current_points: {current_points}")
|
|
||||||
if isinstance(current_points, str):
|
|
||||||
try:
|
|
||||||
current_points = json.loads(current_points)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.error(f"解析points JSON失败: {current_points}")
|
|
||||||
current_points = []
|
|
||||||
elif not isinstance(current_points, list):
|
|
||||||
current_points = []
|
|
||||||
|
|
||||||
# 按时间排序forgotten_points
|
|
||||||
current_points.sort(key=lambda x: x[2])
|
|
||||||
# 按权重加权随机抽取3个points,point[1]的值在1-10之间,权重越高被抽到概率越大
|
|
||||||
if len(current_points) > 3:
|
|
||||||
# point[1] 取值范围1-10,直接作为权重
|
|
||||||
weights = [max(1, min(10, int(point[1]))) for point in current_points]
|
|
||||||
points = random.choices(current_points, weights=weights, k=3)
|
|
||||||
else:
|
|
||||||
points = current_points
|
|
||||||
|
|
||||||
# 构建points文本
|
|
||||||
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
|
|
||||||
|
|
||||||
nickname_str = await person_info_manager.get_value(person_id, "nickname")
|
|
||||||
platform = await person_info_manager.get_value(person_id, "platform")
|
|
||||||
|
|
||||||
if person_name == nickname_str and not short_impression:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
if person_name == nickname_str:
|
|
||||||
relation_prompt = f"'{person_name}' :"
|
|
||||||
else:
|
|
||||||
relation_prompt = f"'{person_name}' ,ta在{platform}上的昵称是{nickname_str}。"
|
|
||||||
|
|
||||||
if short_impression:
|
|
||||||
relation_prompt += f"你对ta的印象是:{short_impression}。\n"
|
|
||||||
|
|
||||||
if points_text:
|
|
||||||
relation_prompt += f"你记得ta最近做的事:{points_text}"
|
|
||||||
|
|
||||||
return relation_prompt
|
|
||||||
|
|
||||||
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
|
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
|
||||||
"""更新用户印象
|
"""更新用户印象
|
||||||
|
|
||||||
|
|||||||
@@ -23,12 +23,6 @@ from .base import (
|
|||||||
EventType,
|
EventType,
|
||||||
MaiMessages,
|
MaiMessages,
|
||||||
)
|
)
|
||||||
from .core import (
|
|
||||||
plugin_manager,
|
|
||||||
component_registry,
|
|
||||||
dependency_manager,
|
|
||||||
events_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 导入工具模块
|
# 导入工具模块
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@@ -38,12 +32,42 @@ from .utils import (
|
|||||||
# generate_plugin_manifest,
|
# generate_plugin_manifest,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .apis import register_plugin, get_logger
|
from .apis import (
|
||||||
|
chat_api,
|
||||||
|
component_manage_api,
|
||||||
|
config_api,
|
||||||
|
database_api,
|
||||||
|
emoji_api,
|
||||||
|
generator_api,
|
||||||
|
llm_api,
|
||||||
|
message_api,
|
||||||
|
person_api,
|
||||||
|
plugin_manage_api,
|
||||||
|
send_api,
|
||||||
|
utils_api,
|
||||||
|
register_plugin,
|
||||||
|
get_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__version__ = "1.0.0"
|
__version__ = "1.0.0"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# API 模块
|
||||||
|
"chat_api",
|
||||||
|
"component_manage_api",
|
||||||
|
"config_api",
|
||||||
|
"database_api",
|
||||||
|
"emoji_api",
|
||||||
|
"generator_api",
|
||||||
|
"llm_api",
|
||||||
|
"message_api",
|
||||||
|
"person_api",
|
||||||
|
"plugin_manage_api",
|
||||||
|
"send_api",
|
||||||
|
"utils_api",
|
||||||
|
"register_plugin",
|
||||||
|
"get_logger",
|
||||||
# 基础类
|
# 基础类
|
||||||
"BasePlugin",
|
"BasePlugin",
|
||||||
"BaseAction",
|
"BaseAction",
|
||||||
@@ -62,11 +86,6 @@ __all__ = [
|
|||||||
"EventType",
|
"EventType",
|
||||||
# 消息
|
# 消息
|
||||||
"MaiMessages",
|
"MaiMessages",
|
||||||
# 管理器
|
|
||||||
"plugin_manager",
|
|
||||||
"component_registry",
|
|
||||||
"dependency_manager",
|
|
||||||
"events_manager",
|
|
||||||
# 装饰器
|
# 装饰器
|
||||||
"register_plugin",
|
"register_plugin",
|
||||||
"ConfigField",
|
"ConfigField",
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
# 导入所有API模块
|
# 导入所有API模块
|
||||||
from src.plugin_system.apis import (
|
from src.plugin_system.apis import (
|
||||||
chat_api,
|
chat_api,
|
||||||
|
component_manage_api,
|
||||||
config_api,
|
config_api,
|
||||||
database_api,
|
database_api,
|
||||||
emoji_api,
|
emoji_api,
|
||||||
@@ -14,15 +15,17 @@ from src.plugin_system.apis import (
|
|||||||
llm_api,
|
llm_api,
|
||||||
message_api,
|
message_api,
|
||||||
person_api,
|
person_api,
|
||||||
|
plugin_manage_api,
|
||||||
send_api,
|
send_api,
|
||||||
utils_api,
|
utils_api,
|
||||||
plugin_register_api,
|
|
||||||
)
|
)
|
||||||
from .logging_api import get_logger
|
from .logging_api import get_logger
|
||||||
from .plugin_register_api import register_plugin
|
from .plugin_register_api import register_plugin
|
||||||
|
|
||||||
# 导出所有API模块,使它们可以通过 apis.xxx 方式访问
|
# 导出所有API模块,使它们可以通过 apis.xxx 方式访问
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"chat_api",
|
"chat_api",
|
||||||
|
"component_manage_api",
|
||||||
"config_api",
|
"config_api",
|
||||||
"database_api",
|
"database_api",
|
||||||
"emoji_api",
|
"emoji_api",
|
||||||
@@ -30,9 +33,9 @@ __all__ = [
|
|||||||
"llm_api",
|
"llm_api",
|
||||||
"message_api",
|
"message_api",
|
||||||
"person_api",
|
"person_api",
|
||||||
|
"plugin_manage_api",
|
||||||
"send_api",
|
"send_api",
|
||||||
"utils_api",
|
"utils_api",
|
||||||
"plugin_register_api",
|
|
||||||
"get_logger",
|
"get_logger",
|
||||||
"register_plugin",
|
"register_plugin",
|
||||||
]
|
]
|
||||||
|
|||||||
245
src/plugin_system/apis/component_manage_api.py
Normal file
245
src/plugin_system/apis/component_manage_api.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
from typing import Optional, Union, Dict
|
||||||
|
from src.plugin_system.base.component_types import (
|
||||||
|
CommandInfo,
|
||||||
|
ActionInfo,
|
||||||
|
EventHandlerInfo,
|
||||||
|
PluginInfo,
|
||||||
|
ComponentType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# === 插件信息查询 ===
|
||||||
|
def get_all_plugin_info() -> Dict[str, PluginInfo]:
|
||||||
|
"""
|
||||||
|
获取所有插件的信息。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含所有插件信息的字典,键为插件名称,值为 PluginInfo 对象。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
|
return component_registry.get_all_plugins()
|
||||||
|
|
||||||
|
|
||||||
|
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
|
||||||
|
"""
|
||||||
|
获取指定插件的信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name (str): 插件名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PluginInfo: 插件信息对象,如果插件不存在则返回 None。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
|
return component_registry.get_plugin_info(plugin_name)
|
||||||
|
|
||||||
|
|
||||||
|
# === 组件查询方法 ===
|
||||||
|
def get_component_info(
|
||||||
|
component_name: str, component_type: ComponentType
|
||||||
|
) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||||
|
"""
|
||||||
|
获取指定组件的信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_name (str): 组件名称。
|
||||||
|
component_type (ComponentType): 组件类型。
|
||||||
|
Returns:
|
||||||
|
Union[CommandInfo, ActionInfo, EventHandlerInfo]: 组件信息对象,如果组件不存在则返回 None。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
|
return component_registry.get_component_info(component_name, component_type) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def get_components_info_by_type(
|
||||||
|
component_type: ComponentType,
|
||||||
|
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||||
|
"""
|
||||||
|
获取指定类型的所有组件信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_type (ComponentType): 组件类型。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
|
return component_registry.get_components_by_type(component_type) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def get_enabled_components_info_by_type(
|
||||||
|
component_type: ComponentType,
|
||||||
|
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||||
|
"""
|
||||||
|
获取指定类型的所有启用的组件信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_type (ComponentType): 组件类型。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
|
return component_registry.get_enabled_components_by_type(component_type) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
# === Action 查询方法 ===
|
||||||
|
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
|
||||||
|
"""
|
||||||
|
获取指定 Action 的注册信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_name (str): Action 名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ActionInfo: Action 信息对象,如果 Action 不存在则返回 None。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
|
return component_registry.get_registered_action_info(action_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
|
||||||
|
"""
|
||||||
|
获取指定 Command 的注册信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command_name (str): Command 名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CommandInfo: Command 信息对象,如果 Command 不存在则返回 None。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
|
return component_registry.get_registered_command_info(command_name)
|
||||||
|
|
||||||
|
|
||||||
|
# === EventHandler 特定查询方法 ===
|
||||||
|
def get_registered_event_handler_info(
|
||||||
|
event_handler_name: str,
|
||||||
|
) -> Optional[EventHandlerInfo]:
|
||||||
|
"""
|
||||||
|
获取指定 EventHandler 的注册信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_handler_name (str): EventHandler 名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EventHandlerInfo: EventHandler 信息对象,如果 EventHandler 不存在则返回 None。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
|
return component_registry.get_registered_event_handler_info(event_handler_name)
|
||||||
|
|
||||||
|
|
||||||
|
# === 组件管理方法 ===
|
||||||
|
def globally_enable_component(component_name: str, component_type: ComponentType) -> bool:
|
||||||
|
"""
|
||||||
|
全局启用指定组件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_name (str): 组件名称。
|
||||||
|
component_type (ComponentType): 组件类型。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 启用成功返回 True,否则返回 False。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
|
return component_registry.enable_component(component_name, component_type)
|
||||||
|
|
||||||
|
|
||||||
|
async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool:
|
||||||
|
"""
|
||||||
|
全局禁用指定组件。
|
||||||
|
|
||||||
|
**此函数是异步的,确保在异步环境中调用。**
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_name (str): 组件名称。
|
||||||
|
component_type (ComponentType): 组件类型。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 禁用成功返回 True,否则返回 False。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
|
return await component_registry.disable_component(component_name, component_type)
|
||||||
|
|
||||||
|
|
||||||
|
def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
局部启用指定组件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_name (str): 组件名称。
|
||||||
|
component_type (ComponentType): 组件类型。
|
||||||
|
stream_id (str): 消息流 ID。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 启用成功返回 True,否则返回 False。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||||
|
|
||||||
|
match component_type:
|
||||||
|
case ComponentType.ACTION:
|
||||||
|
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
|
||||||
|
case ComponentType.COMMAND:
|
||||||
|
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
|
||||||
|
case ComponentType.EVENT_HANDLER:
|
||||||
|
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"未知 component type: {component_type}")
|
||||||
|
|
||||||
|
|
||||||
|
def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
局部禁用指定组件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_name (str): 组件名称。
|
||||||
|
component_type (ComponentType): 组件类型。
|
||||||
|
stream_id (str): 消息流 ID。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 禁用成功返回 True,否则返回 False。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||||
|
|
||||||
|
match component_type:
|
||||||
|
case ComponentType.ACTION:
|
||||||
|
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
|
||||||
|
case ComponentType.COMMAND:
|
||||||
|
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
|
||||||
|
case ComponentType.EVENT_HANDLER:
|
||||||
|
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"未知 component type: {component_type}")
|
||||||
|
|
||||||
|
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
|
||||||
|
"""
|
||||||
|
获取指定消息流中禁用的组件列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id (str): 消息流 ID。
|
||||||
|
component_type (ComponentType): 组件类型。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: 禁用的组件名称列表。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||||
|
|
||||||
|
match component_type:
|
||||||
|
case ComponentType.ACTION:
|
||||||
|
return global_announcement_manager.get_disabled_chat_actions(stream_id)
|
||||||
|
case ComponentType.COMMAND:
|
||||||
|
return global_announcement_manager.get_disabled_chat_commands(stream_id)
|
||||||
|
case ComponentType.EVENT_HANDLER:
|
||||||
|
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"未知 component type: {component_type}")
|
||||||
95
src/plugin_system/apis/plugin_manage_api.py
Normal file
95
src/plugin_system/apis/plugin_manage_api.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
from typing import Tuple, List
|
||||||
|
def list_loaded_plugins() -> List[str]:
|
||||||
|
"""
|
||||||
|
列出所有当前加载的插件。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 当前加载的插件名称列表。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
|
|
||||||
|
return plugin_manager.list_loaded_plugins()
|
||||||
|
|
||||||
|
|
||||||
|
def list_registered_plugins() -> List[str]:
|
||||||
|
"""
|
||||||
|
列出所有已注册的插件。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 已注册的插件名称列表。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
|
|
||||||
|
return plugin_manager.list_registered_plugins()
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_plugin(plugin_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
卸载指定的插件。
|
||||||
|
|
||||||
|
**此函数是异步的,确保在异步环境中调用。**
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name (str): 要卸载的插件名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 卸载是否成功。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
|
|
||||||
|
return await plugin_manager.remove_registered_plugin(plugin_name)
|
||||||
|
|
||||||
|
|
||||||
|
async def reload_plugin(plugin_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
重新加载指定的插件。
|
||||||
|
|
||||||
|
**此函数是异步的,确保在异步环境中调用。**
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name (str): 要重新加载的插件名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 重新加载是否成功。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
|
|
||||||
|
return await plugin_manager.reload_registered_plugin(plugin_name)
|
||||||
|
|
||||||
|
|
||||||
|
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
|
||||||
|
"""
|
||||||
|
加载指定的插件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name (str): 要加载的插件名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, int]: 加载是否成功,成功或失败个数。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
|
|
||||||
|
return plugin_manager.load_registered_plugin_classes(plugin_name)
|
||||||
|
|
||||||
|
def add_plugin_directory(plugin_directory: str) -> bool:
|
||||||
|
"""
|
||||||
|
添加插件目录。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_directory (str): 要添加的插件目录路径。
|
||||||
|
Returns:
|
||||||
|
bool: 添加是否成功。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
|
|
||||||
|
return plugin_manager.add_plugin_directory(plugin_directory)
|
||||||
|
|
||||||
|
def rescan_plugin_directory() -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
重新扫描插件目录,加载新插件。
|
||||||
|
Returns:
|
||||||
|
Tuple[int, int]: 成功加载的插件数量和失败的插件数量。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
|
|
||||||
|
return plugin_manager.rescan_plugin_directory()
|
||||||
@@ -28,7 +28,6 @@ def register_plugin(cls):
|
|||||||
if "." in plugin_name:
|
if "." in plugin_name:
|
||||||
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||||
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||||
plugin_manager.plugin_classes[plugin_name] = cls
|
|
||||||
splitted_name = cls.__module__.split(".")
|
splitted_name = cls.__module__.split(".")
|
||||||
root_path = Path(__file__)
|
root_path = Path(__file__)
|
||||||
|
|
||||||
@@ -40,6 +39,7 @@ def register_plugin(cls):
|
|||||||
logger.error(f"注册 {plugin_name} 无法找到项目根目录")
|
logger.error(f"注册 {plugin_name} 无法找到项目根目录")
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
plugin_manager.plugin_classes[plugin_name] = cls
|
||||||
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
|
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
|
||||||
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")
|
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")
|
||||||
|
|
||||||
|
|||||||
@@ -49,12 +49,10 @@ class BaseAction(ABC):
|
|||||||
reasoning: 执行该动作的理由
|
reasoning: 执行该动作的理由
|
||||||
cycle_timers: 计时器字典
|
cycle_timers: 计时器字典
|
||||||
thinking_id: 思考ID
|
thinking_id: 思考ID
|
||||||
expressor: 表达器对象
|
|
||||||
replyer: 回复器对象
|
|
||||||
chat_stream: 聊天流对象
|
chat_stream: 聊天流对象
|
||||||
log_prefix: 日志前缀
|
log_prefix: 日志前缀
|
||||||
shutting_down: 是否正在关闭
|
|
||||||
plugin_config: 插件配置字典
|
plugin_config: 插件配置字典
|
||||||
|
action_message: 消息数据
|
||||||
**kwargs: 其他参数
|
**kwargs: 其他参数
|
||||||
"""
|
"""
|
||||||
if plugin_config is None:
|
if plugin_config is None:
|
||||||
@@ -65,21 +63,30 @@ class BaseAction(ABC):
|
|||||||
self.thinking_id = thinking_id
|
self.thinking_id = thinking_id
|
||||||
self.log_prefix = log_prefix
|
self.log_prefix = log_prefix
|
||||||
|
|
||||||
# 保存插件配置
|
|
||||||
self.plugin_config = plugin_config or {}
|
self.plugin_config = plugin_config or {}
|
||||||
|
"""对应的插件配置"""
|
||||||
|
|
||||||
# 设置动作基本信息实例属性
|
# 设置动作基本信息实例属性
|
||||||
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
||||||
|
"""Action的名字"""
|
||||||
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
||||||
|
"""Action的描述"""
|
||||||
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
||||||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
||||||
|
|
||||||
# 设置激活类型实例属性(从类属性复制,提供默认值)
|
# 设置激活类型实例属性(从类属性复制,提供默认值)
|
||||||
self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS)
|
self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS)
|
||||||
|
"""FOCUS模式下的激活类型"""
|
||||||
self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS)
|
self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS)
|
||||||
|
"""NORMAL模式下的激活类型"""
|
||||||
|
self.activation_type = getattr(self.__class__, "activation_type", self.focus_activation_type)
|
||||||
|
"""激活类型"""
|
||||||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
||||||
|
"""当激活类型为RANDOM时的概率"""
|
||||||
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
|
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
|
||||||
|
"""协助LLM进行判断的Prompt"""
|
||||||
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
||||||
|
"""激活类型为KEYWORD时的KEYWORDS列表"""
|
||||||
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
||||||
self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL)
|
self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL)
|
||||||
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
||||||
@@ -136,7 +143,7 @@ class BaseAction(ABC):
|
|||||||
self.target_id = self.user_id
|
self.target_id = self.user_id
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix} Action组件初始化完成")
|
logger.debug(f"{self.log_prefix} Action组件初始化完成")
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -405,23 +412,11 @@ class BaseAction(ABC):
|
|||||||
"""
|
"""
|
||||||
return await self.execute()
|
return await self.execute()
|
||||||
|
|
||||||
# def get_action_context(self, key: str, default=None):
|
|
||||||
# """获取action上下文信息
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# key: 上下文键名
|
|
||||||
# default: 默认值
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# Any: 上下文值或默认值
|
|
||||||
# """
|
|
||||||
# return self.api.get_action_context(key, default)
|
|
||||||
|
|
||||||
def get_config(self, key: str, default=None):
|
def get_config(self, key: str, default=None):
|
||||||
"""获取插件配置值,支持嵌套键访问
|
"""获取插件配置值,使用嵌套键访问
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
||||||
default: 默认值
|
default: 默认值
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|||||||
@@ -17,17 +17,18 @@ class BaseCommand(ABC):
|
|||||||
- command_pattern: 命令匹配的正则表达式
|
- command_pattern: 命令匹配的正则表达式
|
||||||
- command_help: 命令帮助信息
|
- command_help: 命令帮助信息
|
||||||
- command_examples: 命令使用示例列表
|
- command_examples: 命令使用示例列表
|
||||||
- intercept_message: 是否拦截消息处理(默认True拦截,False继续传递)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
command_name: str = ""
|
command_name: str = ""
|
||||||
|
"""Command组件的名称"""
|
||||||
command_description: str = ""
|
command_description: str = ""
|
||||||
|
"""Command组件的描述"""
|
||||||
# 默认命令设置(子类可以覆盖)
|
# 默认命令设置
|
||||||
command_pattern: str = ""
|
command_pattern: str = r""
|
||||||
|
"""命令匹配的正则表达式"""
|
||||||
command_help: str = ""
|
command_help: str = ""
|
||||||
|
"""命令帮助信息"""
|
||||||
command_examples: List[str] = []
|
command_examples: List[str] = []
|
||||||
intercept_message: bool = True # 默认拦截消息,不继续处理
|
|
||||||
|
|
||||||
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||||
"""初始化Command组件
|
"""初始化Command组件
|
||||||
@@ -53,11 +54,11 @@ class BaseCommand(ABC):
|
|||||||
self.matched_groups = groups
|
self.matched_groups = groups
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||||
"""执行Command的抽象方法,子类必须实现
|
"""执行Command的抽象方法,子类必须实现
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, Optional[str]]: (是否执行成功, 可选的回复消息)
|
Tuple[bool, Optional[str], bool]: (是否执行成功, 可选的回复消息, 是否拦截消息 不进行 后续处理)
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -229,5 +230,4 @@ class BaseCommand(ABC):
|
|||||||
command_pattern=cls.command_pattern,
|
command_pattern=cls.command_pattern,
|
||||||
command_help=cls.command_help,
|
command_help=cls.command_help,
|
||||||
command_examples=cls.command_examples.copy() if cls.command_examples else [],
|
command_examples=cls.command_examples.copy() if cls.command_examples else [],
|
||||||
intercept_message=cls.intercept_message,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,16 +13,23 @@ class BaseEventHandler(ABC):
|
|||||||
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
|
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
|
||||||
"""
|
"""
|
||||||
|
|
||||||
event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知
|
event_type: EventType = EventType.UNKNOWN
|
||||||
handler_name: str = "" # 处理器名称
|
"""事件类型,默认为未知"""
|
||||||
|
handler_name: str = ""
|
||||||
|
"""处理器名称"""
|
||||||
handler_description: str = ""
|
handler_description: str = ""
|
||||||
weight: int = 0 # 权重,数值越大优先级越高
|
"""处理器描述"""
|
||||||
intercept_message: bool = False # 是否拦截消息,默认为否
|
weight: int = 0
|
||||||
|
"""处理器权重,越大权重越高"""
|
||||||
|
intercept_message: bool = False
|
||||||
|
"""是否拦截消息,默认为否"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.log_prefix = "[EventHandler]"
|
self.log_prefix = "[EventHandler]"
|
||||||
self.plugin_name = "" # 对应插件名
|
self.plugin_name = ""
|
||||||
self.plugin_config: Optional[Dict] = None # 插件配置字典
|
"""对应插件名"""
|
||||||
|
self.plugin_config: Optional[Dict] = None
|
||||||
|
"""插件配置字典"""
|
||||||
if self.event_type == EventType.UNKNOWN:
|
if self.event_type == EventType.UNKNOWN:
|
||||||
raise NotImplementedError("事件处理器必须指定 event_type")
|
raise NotImplementedError("事件处理器必须指定 event_type")
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import List, Type, Tuple, Union
|
|||||||
from .plugin_base import PluginBase
|
from .plugin_base import PluginBase
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.component_types import ComponentInfo, ActionInfo, CommandInfo, EventHandlerInfo
|
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo
|
||||||
from .base_action import BaseAction
|
from .base_action import BaseAction
|
||||||
from .base_command import BaseCommand
|
from .base_command import BaseCommand
|
||||||
from .base_events_handler import BaseEventHandler
|
from .base_events_handler import BaseEventHandler
|
||||||
|
|||||||
@@ -142,7 +142,6 @@ class CommandInfo(ComponentInfo):
|
|||||||
command_pattern: str = "" # 命令匹配模式(正则表达式)
|
command_pattern: str = "" # 命令匹配模式(正则表达式)
|
||||||
command_help: str = "" # 命令帮助信息
|
command_help: str = "" # 命令帮助信息
|
||||||
command_examples: List[str] = field(default_factory=list) # 命令使用示例
|
command_examples: List[str] = field(default_factory=list) # 命令使用示例
|
||||||
intercept_message: bool = True # 是否拦截消息处理(默认拦截)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|||||||
@@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
|||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
from src.plugin_system.core.dependency_manager import dependency_manager
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
from src.plugin_system.core.events_manager import events_manager
|
||||||
|
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plugin_manager",
|
"plugin_manager",
|
||||||
"component_registry",
|
"component_registry",
|
||||||
"dependency_manager",
|
"dependency_manager",
|
||||||
"events_manager",
|
"events_manager",
|
||||||
|
"global_announcement_manager",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -25,27 +25,35 @@ class ComponentRegistry:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 组件注册表
|
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
|
||||||
self._components: Dict[str, ComponentInfo] = {} # 命名空间式组件名 -> 组件信息
|
self._components: Dict[str, ComponentInfo] = {}
|
||||||
# 类型 -> 命名空间式名称 -> 组件信息
|
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
||||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
||||||
# 命名空间式组件名 -> 组件类
|
"""类型 -> 组件原名称 -> 组件信息"""
|
||||||
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {}
|
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {}
|
||||||
|
"""命名空间式组件名 -> 组件类"""
|
||||||
|
|
||||||
# 插件注册表
|
# 插件注册表
|
||||||
self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息
|
self._plugins: Dict[str, PluginInfo] = {}
|
||||||
|
"""插件名 -> 插件信息"""
|
||||||
|
|
||||||
# Action特定注册表
|
# Action特定注册表
|
||||||
self._action_registry: Dict[str, Type[BaseAction]] = {} # action名 -> action类
|
self._action_registry: Dict[str, Type[BaseAction]] = {}
|
||||||
self._default_actions: Dict[str, ActionInfo] = {} # 默认动作集,即启用的Action集,用于重置ActionManager状态
|
"""Action注册表 action名 -> action类"""
|
||||||
|
self._default_actions: Dict[str, ActionInfo] = {}
|
||||||
|
"""默认动作集,即启用的Action集,用于重置ActionManager状态"""
|
||||||
|
|
||||||
# Command特定注册表
|
# Command特定注册表
|
||||||
self._command_registry: Dict[str, Type[BaseCommand]] = {} # command名 -> command类
|
self._command_registry: Dict[str, Type[BaseCommand]] = {}
|
||||||
self._command_patterns: Dict[Pattern, str] = {} # 编译后的正则 -> command名
|
"""Command类注册表 command名 -> command类"""
|
||||||
|
self._command_patterns: Dict[Pattern, str] = {}
|
||||||
|
"""编译后的正则 -> command名"""
|
||||||
|
|
||||||
# EventHandler特定注册表
|
# EventHandler特定注册表
|
||||||
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} # event_handler名 -> event_handler类
|
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
|
||||||
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {} # 启用的事件处理器
|
"""event_handler名 -> event_handler类"""
|
||||||
|
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {}
|
||||||
|
"""启用的事件处理器 event_handler名 -> event_handler类"""
|
||||||
|
|
||||||
logger.info("组件注册中心初始化完成")
|
logger.info("组件注册中心初始化完成")
|
||||||
|
|
||||||
@@ -110,11 +118,17 @@ class ComponentRegistry:
|
|||||||
# 根据组件类型进行特定注册(使用原始名称)
|
# 根据组件类型进行特定注册(使用原始名称)
|
||||||
match component_type:
|
match component_type:
|
||||||
case ComponentType.ACTION:
|
case ComponentType.ACTION:
|
||||||
ret = self._register_action_component(component_info, component_class) # type: ignore
|
assert isinstance(component_info, ActionInfo)
|
||||||
|
assert issubclass(component_class, BaseAction)
|
||||||
|
ret = self._register_action_component(component_info, component_class)
|
||||||
case ComponentType.COMMAND:
|
case ComponentType.COMMAND:
|
||||||
ret = self._register_command_component(component_info, component_class) # type: ignore
|
assert isinstance(component_info, CommandInfo)
|
||||||
|
assert issubclass(component_class, BaseCommand)
|
||||||
|
ret = self._register_command_component(component_info, component_class)
|
||||||
case ComponentType.EVENT_HANDLER:
|
case ComponentType.EVENT_HANDLER:
|
||||||
ret = self._register_event_handler_component(component_info, component_class) # type: ignore
|
assert isinstance(component_info, EventHandlerInfo)
|
||||||
|
assert issubclass(component_class, BaseEventHandler)
|
||||||
|
ret = self._register_event_handler_component(component_info, component_class)
|
||||||
case _:
|
case _:
|
||||||
logger.warning(f"未知组件类型: {component_type}")
|
logger.warning(f"未知组件类型: {component_type}")
|
||||||
|
|
||||||
@@ -160,7 +174,9 @@ class ComponentRegistry:
|
|||||||
if pattern not in self._command_patterns:
|
if pattern not in self._command_patterns:
|
||||||
self._command_patterns[pattern] = command_name
|
self._command_patterns[pattern] = command_name
|
||||||
else:
|
else:
|
||||||
logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令")
|
logger.warning(
|
||||||
|
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -176,6 +192,10 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
self._event_handler_registry[handler_name] = handler_class
|
self._event_handler_registry[handler_name] = handler_class
|
||||||
|
|
||||||
|
if not handler_info.enabled:
|
||||||
|
logger.warning(f"EventHandler组件 {handler_name} 未启用")
|
||||||
|
return True # 未启用,但是也是注册成功
|
||||||
|
|
||||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||||
|
|
||||||
if events_manager.register_event_subscriber(handler_info, handler_class):
|
if events_manager.register_event_subscriber(handler_info, handler_class):
|
||||||
@@ -185,6 +205,124 @@ class ComponentRegistry:
|
|||||||
logger.error(f"注册事件处理器 {handler_name} 失败")
|
logger.error(f"注册事件处理器 {handler_name} 失败")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# === 组件移除相关 ===
|
||||||
|
|
||||||
|
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
|
||||||
|
target_component_class = self.get_component_class(component_name, component_type)
|
||||||
|
if not target_component_class:
|
||||||
|
logger.warning(f"组件 {component_name} 未注册,无法移除")
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
match component_type:
|
||||||
|
case ComponentType.ACTION:
|
||||||
|
self._action_registry.pop(component_name)
|
||||||
|
self._default_actions.pop(component_name)
|
||||||
|
case ComponentType.COMMAND:
|
||||||
|
self._command_registry.pop(component_name)
|
||||||
|
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
|
||||||
|
for key in keys_to_remove:
|
||||||
|
self._command_patterns.pop(key)
|
||||||
|
case ComponentType.EVENT_HANDLER:
|
||||||
|
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||||
|
|
||||||
|
self._event_handler_registry.pop(component_name)
|
||||||
|
self._enabled_event_handlers.pop(component_name)
|
||||||
|
await events_manager.unregister_event_subscriber(component_name)
|
||||||
|
namespaced_name = f"{component_type}.{component_name}"
|
||||||
|
self._components.pop(namespaced_name)
|
||||||
|
self._components_by_type[component_type].pop(component_name)
|
||||||
|
self._components_classes.pop(namespaced_name)
|
||||||
|
logger.info(f"组件 {component_name} 已移除")
|
||||||
|
return True
|
||||||
|
except KeyError:
|
||||||
|
logger.warning(f"移除组件时未找到组件: {component_name}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def remove_plugin_registry(self, plugin_name: str) -> bool:
|
||||||
|
"""移除插件注册信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name: 插件名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功移除
|
||||||
|
"""
|
||||||
|
if plugin_name not in self._plugins:
|
||||||
|
logger.warning(f"插件 {plugin_name} 未注册,无法移除")
|
||||||
|
return False
|
||||||
|
del self._plugins[plugin_name]
|
||||||
|
logger.info(f"插件 {plugin_name} 已移除")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# === 组件全局启用/禁用方法 ===
|
||||||
|
|
||||||
|
def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
||||||
|
"""全局的启用某个组件
|
||||||
|
Parameters:
|
||||||
|
component_name: 组件名称
|
||||||
|
component_type: 组件类型
|
||||||
|
Returns:
|
||||||
|
bool: 启用成功返回True,失败返回False
|
||||||
|
"""
|
||||||
|
target_component_class = self.get_component_class(component_name, component_type)
|
||||||
|
target_component_info = self.get_component_info(component_name, component_type)
|
||||||
|
if not target_component_class or not target_component_info:
|
||||||
|
logger.warning(f"组件 {component_name} 未注册,无法启用")
|
||||||
|
return False
|
||||||
|
target_component_info.enabled = True
|
||||||
|
match component_type:
|
||||||
|
case ComponentType.ACTION:
|
||||||
|
assert isinstance(target_component_info, ActionInfo)
|
||||||
|
self._default_actions[component_name] = target_component_info
|
||||||
|
case ComponentType.COMMAND:
|
||||||
|
assert isinstance(target_component_info, CommandInfo)
|
||||||
|
pattern = target_component_info.command_pattern
|
||||||
|
self._command_patterns[re.compile(pattern)] = component_name
|
||||||
|
case ComponentType.EVENT_HANDLER:
|
||||||
|
assert isinstance(target_component_info, EventHandlerInfo)
|
||||||
|
assert issubclass(target_component_class, BaseEventHandler)
|
||||||
|
self._enabled_event_handlers[component_name] = target_component_class
|
||||||
|
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||||
|
|
||||||
|
events_manager.register_event_subscriber(target_component_info, target_component_class)
|
||||||
|
namespaced_name = f"{component_type}.{component_name}"
|
||||||
|
self._components[namespaced_name].enabled = True
|
||||||
|
self._components_by_type[component_type][component_name].enabled = True
|
||||||
|
logger.info(f"组件 {component_name} 已启用")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def disable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
||||||
|
"""全局的禁用某个组件
|
||||||
|
Parameters:
|
||||||
|
component_name: 组件名称
|
||||||
|
component_type: 组件类型
|
||||||
|
Returns:
|
||||||
|
bool: 禁用成功返回True,失败返回False
|
||||||
|
"""
|
||||||
|
target_component_class = self.get_component_class(component_name, component_type)
|
||||||
|
target_component_info = self.get_component_info(component_name, component_type)
|
||||||
|
if not target_component_class or not target_component_info:
|
||||||
|
logger.warning(f"组件 {component_name} 未注册,无法禁用")
|
||||||
|
return False
|
||||||
|
target_component_info.enabled = False
|
||||||
|
match component_type:
|
||||||
|
case ComponentType.ACTION:
|
||||||
|
self._default_actions.pop(component_name, None)
|
||||||
|
case ComponentType.COMMAND:
|
||||||
|
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
|
||||||
|
case ComponentType.EVENT_HANDLER:
|
||||||
|
self._enabled_event_handlers.pop(component_name, None)
|
||||||
|
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||||
|
|
||||||
|
await events_manager.unregister_event_subscriber(component_name)
|
||||||
|
self._components[component_name].enabled = False
|
||||||
|
self._components_by_type[component_type][component_name].enabled = False
|
||||||
|
logger.info(f"组件 {component_name} 已禁用")
|
||||||
|
return True
|
||||||
|
|
||||||
# === 组件查询方法 ===
|
# === 组件查询方法 ===
|
||||||
def get_component_info(
|
def get_component_info(
|
||||||
self, component_name: str, component_type: Optional[ComponentType] = None
|
self, component_name: str, component_type: Optional[ComponentType] = None
|
||||||
@@ -287,7 +425,7 @@ class ComponentRegistry:
|
|||||||
# === Action特定查询方法 ===
|
# === Action特定查询方法 ===
|
||||||
|
|
||||||
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
|
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
|
||||||
"""获取Action注册表(用于兼容现有系统)"""
|
"""获取Action注册表"""
|
||||||
return self._action_registry.copy()
|
return self._action_registry.copy()
|
||||||
|
|
||||||
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
||||||
@@ -314,7 +452,7 @@ class ComponentRegistry:
|
|||||||
"""获取Command模式注册表"""
|
"""获取Command模式注册表"""
|
||||||
return self._command_patterns.copy()
|
return self._command_patterns.copy()
|
||||||
|
|
||||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]:
|
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
|
||||||
# sourcery skip: use-named-expression, use-next
|
# sourcery skip: use-named-expression, use-next
|
||||||
"""根据文本查找匹配的命令
|
"""根据文本查找匹配的命令
|
||||||
|
|
||||||
@@ -335,11 +473,10 @@ class ComponentRegistry:
|
|||||||
return (
|
return (
|
||||||
self._command_registry[command_name],
|
self._command_registry[command_name],
|
||||||
candidates[0].match(text).groupdict(), # type: ignore
|
candidates[0].match(text).groupdict(), # type: ignore
|
||||||
command_info.intercept_message,
|
command_info,
|
||||||
command_info.plugin_name,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# === 事件处理器特定查询方法 ===
|
# === EventHandler 特定查询方法 ===
|
||||||
|
|
||||||
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
|
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
|
||||||
"""获取事件处理器注册表"""
|
"""获取事件处理器注册表"""
|
||||||
@@ -364,9 +501,9 @@ class ComponentRegistry:
|
|||||||
"""获取所有插件"""
|
"""获取所有插件"""
|
||||||
return self._plugins.copy()
|
return self._plugins.copy()
|
||||||
|
|
||||||
def get_enabled_plugins(self) -> Dict[str, PluginInfo]:
|
# def get_enabled_plugins(self) -> Dict[str, PluginInfo]:
|
||||||
"""获取所有启用的插件"""
|
# """获取所有启用的插件"""
|
||||||
return {name: info for name, info in self._plugins.items() if info.enabled}
|
# return {name: info for name, info in self._plugins.items() if info.enabled}
|
||||||
|
|
||||||
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
|
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
|
||||||
"""获取插件的所有组件"""
|
"""获取插件的所有组件"""
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from src.chat.message_receive.message import MessageRecv
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
|
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
|
||||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||||
|
from .global_announcement_manager import global_announcement_manager
|
||||||
|
|
||||||
logger = get_logger("events_manager")
|
logger = get_logger("events_manager")
|
||||||
|
|
||||||
@@ -28,18 +29,16 @@ class EventsManager:
|
|||||||
bool: 是否注册成功
|
bool: 是否注册成功
|
||||||
"""
|
"""
|
||||||
handler_name = handler_info.name
|
handler_name = handler_info.name
|
||||||
plugin_name = getattr(handler_info, "plugin_name", "unknown")
|
|
||||||
|
|
||||||
namespace_name = f"{plugin_name}.{handler_name}"
|
if handler_name in self._handler_mapping:
|
||||||
if namespace_name in self._handler_mapping:
|
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
|
||||||
logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册")
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not issubclass(handler_class, BaseEventHandler):
|
if not issubclass(handler_class, BaseEventHandler):
|
||||||
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self._handler_mapping[namespace_name] = handler_class
|
self._handler_mapping[handler_name] = handler_class
|
||||||
return self._insert_event_handler(handler_class, handler_info)
|
return self._insert_event_handler(handler_class, handler_info)
|
||||||
|
|
||||||
async def handle_mai_events(
|
async def handle_mai_events(
|
||||||
@@ -55,6 +54,10 @@ class EventsManager:
|
|||||||
continue_flag = True
|
continue_flag = True
|
||||||
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
||||||
for handler in self._events_subscribers.get(event_type, []):
|
for handler in self._events_subscribers.get(event_type, []):
|
||||||
|
if message.chat_stream and message.chat_stream.stream_id:
|
||||||
|
stream_id = message.chat_stream.stream_id
|
||||||
|
if handler.handler_name in global_announcement_manager.get_disabled_chat_event_handlers(stream_id):
|
||||||
|
continue
|
||||||
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
|
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
|
||||||
if handler.intercept_message:
|
if handler.intercept_message:
|
||||||
try:
|
try:
|
||||||
@@ -71,7 +74,9 @@ class EventsManager:
|
|||||||
try:
|
try:
|
||||||
handler_task = asyncio.create_task(handler.execute(transformed_message))
|
handler_task = asyncio.create_task(handler.execute(transformed_message))
|
||||||
handler_task.add_done_callback(self._task_done_callback)
|
handler_task.add_done_callback(self._task_done_callback)
|
||||||
handler_task.set_name(f"EventHandler-{handler.handler_name}-{event_type.name}")
|
handler_task.set_name(f"{handler.plugin_name}-{handler.handler_name}")
|
||||||
|
if handler.handler_name not in self._handler_tasks:
|
||||||
|
self._handler_tasks[handler.handler_name] = []
|
||||||
self._handler_tasks[handler.handler_name].append(handler_task)
|
self._handler_tasks[handler.handler_name].append(handler_task)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
|
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
|
||||||
@@ -91,7 +96,7 @@ class EventsManager:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
|
def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||||
"""从事件类型列表中移除事件处理器"""
|
"""从事件类型列表中移除事件处理器"""
|
||||||
display_handler_name = handler_class.handler_name or handler_class.__name__
|
display_handler_name = handler_class.handler_name or handler_class.__name__
|
||||||
if handler_class.event_type == EventType.UNKNOWN:
|
if handler_class.event_type == EventType.UNKNOWN:
|
||||||
@@ -190,5 +195,20 @@ class EventsManager:
|
|||||||
finally:
|
finally:
|
||||||
del self._handler_tasks[handler_name]
|
del self._handler_tasks[handler_name]
|
||||||
|
|
||||||
|
async def unregister_event_subscriber(self, handler_name: str) -> bool:
|
||||||
|
"""取消注册事件处理器"""
|
||||||
|
if handler_name not in self._handler_mapping:
|
||||||
|
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self.cancel_handler_tasks(handler_name)
|
||||||
|
|
||||||
|
handler_class = self._handler_mapping.pop(handler_name)
|
||||||
|
if not self._remove_event_handler_instance(handler_class):
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
events_manager = EventsManager()
|
events_manager = EventsManager()
|
||||||
|
|||||||
93
src/plugin_system/core/global_announcement_manager.py
Normal file
93
src/plugin_system/core/global_announcement_manager.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("global_announcement_manager")
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalAnnouncementManager:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# 用户禁用的动作,chat_id -> [action_name]
|
||||||
|
self._user_disabled_actions: Dict[str, List[str]] = {}
|
||||||
|
# 用户禁用的命令,chat_id -> [command_name]
|
||||||
|
self._user_disabled_commands: Dict[str, List[str]] = {}
|
||||||
|
# 用户禁用的事件处理器,chat_id -> [handler_name]
|
||||||
|
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
|
||||||
|
|
||||||
|
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||||
|
"""禁用特定聊天的某个动作"""
|
||||||
|
if chat_id not in self._user_disabled_actions:
|
||||||
|
self._user_disabled_actions[chat_id] = []
|
||||||
|
if action_name in self._user_disabled_actions[chat_id]:
|
||||||
|
logger.warning(f"动作 {action_name} 已经被禁用")
|
||||||
|
return False
|
||||||
|
self._user_disabled_actions[chat_id].append(action_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def enable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||||
|
"""启用特定聊天的某个动作"""
|
||||||
|
if chat_id in self._user_disabled_actions:
|
||||||
|
try:
|
||||||
|
self._user_disabled_actions[chat_id].remove(action_name)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f"动作 {action_name} 不在禁用列表中")
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
def disable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||||
|
"""禁用特定聊天的某个命令"""
|
||||||
|
if chat_id not in self._user_disabled_commands:
|
||||||
|
self._user_disabled_commands[chat_id] = []
|
||||||
|
if command_name in self._user_disabled_commands[chat_id]:
|
||||||
|
logger.warning(f"命令 {command_name} 已经被禁用")
|
||||||
|
return False
|
||||||
|
self._user_disabled_commands[chat_id].append(command_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def enable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||||
|
"""启用特定聊天的某个命令"""
|
||||||
|
if chat_id in self._user_disabled_commands:
|
||||||
|
try:
|
||||||
|
self._user_disabled_commands[chat_id].remove(command_name)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f"命令 {command_name} 不在禁用列表中")
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
def disable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||||
|
"""禁用特定聊天的某个事件处理器"""
|
||||||
|
if chat_id not in self._user_disabled_event_handlers:
|
||||||
|
self._user_disabled_event_handlers[chat_id] = []
|
||||||
|
if handler_name in self._user_disabled_event_handlers[chat_id]:
|
||||||
|
logger.warning(f"事件处理器 {handler_name} 已经被禁用")
|
||||||
|
return False
|
||||||
|
self._user_disabled_event_handlers[chat_id].append(handler_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def enable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||||
|
"""启用特定聊天的某个事件处理器"""
|
||||||
|
if chat_id in self._user_disabled_event_handlers:
|
||||||
|
try:
|
||||||
|
self._user_disabled_event_handlers[chat_id].remove(handler_name)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f"事件处理器 {handler_name} 不在禁用列表中")
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
|
||||||
|
"""获取特定聊天禁用的所有动作"""
|
||||||
|
return self._user_disabled_actions.get(chat_id, []).copy()
|
||||||
|
|
||||||
|
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
|
||||||
|
"""获取特定聊天禁用的所有命令"""
|
||||||
|
return self._user_disabled_commands.get(chat_id, []).copy()
|
||||||
|
|
||||||
|
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
|
||||||
|
"""获取特定聊天禁用的所有事件处理器"""
|
||||||
|
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||||
|
|
||||||
|
|
||||||
|
global_announcement_manager = GlobalAnnouncementManager()
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import inspect
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple, Type, Any
|
from typing import Dict, List, Optional, Tuple, Type, Any
|
||||||
@@ -8,11 +7,11 @@ from pathlib import Path
|
|||||||
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
|
||||||
from src.plugin_system.base.plugin_base import PluginBase
|
from src.plugin_system.base.plugin_base import PluginBase
|
||||||
from src.plugin_system.base.component_types import ComponentType, PluginInfo, PythonDependency
|
from src.plugin_system.base.component_types import ComponentType, PythonDependency
|
||||||
from src.plugin_system.utils.manifest_utils import VersionComparator
|
from src.plugin_system.utils.manifest_utils import VersionComparator
|
||||||
|
from .component_registry import component_registry
|
||||||
|
from .dependency_manager import dependency_manager
|
||||||
|
|
||||||
logger = get_logger("plugin_manager")
|
logger = get_logger("plugin_manager")
|
||||||
|
|
||||||
@@ -36,19 +35,7 @@ class PluginManager:
|
|||||||
self._ensure_plugin_directories()
|
self._ensure_plugin_directories()
|
||||||
logger.info("插件管理器初始化完成")
|
logger.info("插件管理器初始化完成")
|
||||||
|
|
||||||
def _ensure_plugin_directories(self) -> None:
|
# === 插件目录管理 ===
|
||||||
"""确保所有插件根目录存在,如果不存在则创建"""
|
|
||||||
default_directories = ["src/plugins/built_in", "plugins"]
|
|
||||||
|
|
||||||
for directory in default_directories:
|
|
||||||
if not os.path.exists(directory):
|
|
||||||
os.makedirs(directory, exist_ok=True)
|
|
||||||
logger.info(f"创建插件根目录: {directory}")
|
|
||||||
if directory not in self.plugin_directories:
|
|
||||||
self.plugin_directories.append(directory)
|
|
||||||
logger.debug(f"已添加插件根目录: {directory}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"根目录不可重复加载: {directory}")
|
|
||||||
|
|
||||||
def add_plugin_directory(self, directory: str) -> bool:
|
def add_plugin_directory(self, directory: str) -> bool:
|
||||||
"""添加插件目录"""
|
"""添加插件目录"""
|
||||||
@@ -63,6 +50,8 @@ class PluginManager:
|
|||||||
logger.warning(f"插件目录不存在: {directory}")
|
logger.warning(f"插件目录不存在: {directory}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# === 插件加载管理 ===
|
||||||
|
|
||||||
def load_all_plugins(self) -> Tuple[int, int]:
|
def load_all_plugins(self) -> Tuple[int, int]:
|
||||||
"""加载所有插件
|
"""加载所有插件
|
||||||
|
|
||||||
@@ -162,62 +151,50 @@ class PluginManager:
|
|||||||
logger.debug("详细错误信息: ", exc_info=True)
|
logger.debug("详细错误信息: ", exc_info=True)
|
||||||
return False, 1
|
return False, 1
|
||||||
|
|
||||||
def unload_registered_plugin_module(self, plugin_name: str) -> None:
|
async def remove_registered_plugin(self, plugin_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
卸载插件模块
|
禁用插件模块
|
||||||
"""
|
"""
|
||||||
pass
|
if not plugin_name:
|
||||||
|
raise ValueError("插件名称不能为空")
|
||||||
|
if plugin_name not in self.loaded_plugins:
|
||||||
|
logger.warning(f"插件 {plugin_name} 未加载")
|
||||||
|
return False
|
||||||
|
plugin_instance = self.loaded_plugins[plugin_name]
|
||||||
|
plugin_info = plugin_instance.plugin_info
|
||||||
|
success = True
|
||||||
|
for component in plugin_info.components:
|
||||||
|
success &= await component_registry.remove_component(component.name, component.component_type, plugin_name)
|
||||||
|
success &= component_registry.remove_plugin_registry(plugin_name)
|
||||||
|
del self.loaded_plugins[plugin_name]
|
||||||
|
return success
|
||||||
|
|
||||||
def reload_registered_plugin_module(self, plugin_name: str) -> None:
|
async def reload_registered_plugin(self, plugin_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
重载插件模块
|
重载插件模块
|
||||||
"""
|
"""
|
||||||
self.unload_registered_plugin_module(plugin_name)
|
if not await self.remove_registered_plugin(plugin_name):
|
||||||
self.load_registered_plugin_classes(plugin_name)
|
return False
|
||||||
|
if not self.load_registered_plugin_classes(plugin_name)[0]:
|
||||||
|
return False
|
||||||
|
logger.debug(f"插件 {plugin_name} 重载成功")
|
||||||
|
return True
|
||||||
|
|
||||||
def rescan_plugin_directory(self) -> None:
|
def rescan_plugin_directory(self) -> Tuple[int, int]:
|
||||||
"""
|
"""
|
||||||
重新扫描插件根目录
|
重新扫描插件根目录
|
||||||
"""
|
"""
|
||||||
# --------------------------------------- NEED REFACTORING ---------------------------------------
|
total_success = 0
|
||||||
|
total_fail = 0
|
||||||
for directory in self.plugin_directories:
|
for directory in self.plugin_directories:
|
||||||
if os.path.exists(directory):
|
if os.path.exists(directory):
|
||||||
logger.debug(f"重新扫描插件根目录: {directory}")
|
logger.debug(f"重新扫描插件根目录: {directory}")
|
||||||
self._load_plugin_modules_from_directory(directory)
|
success, fail = self._load_plugin_modules_from_directory(directory)
|
||||||
|
total_success += success
|
||||||
|
total_fail += fail
|
||||||
else:
|
else:
|
||||||
logger.warning(f"插件根目录不存在: {directory}")
|
logger.warning(f"插件根目录不存在: {directory}")
|
||||||
|
return total_success, total_fail
|
||||||
def get_loaded_plugins(self) -> List[PluginInfo]:
|
|
||||||
"""获取所有已加载的插件信息"""
|
|
||||||
return list(component_registry.get_all_plugins().values())
|
|
||||||
|
|
||||||
def get_enabled_plugins(self) -> List[PluginInfo]:
|
|
||||||
"""获取所有启用的插件信息"""
|
|
||||||
return list(component_registry.get_enabled_plugins().values())
|
|
||||||
|
|
||||||
# def enable_plugin(self, plugin_name: str) -> bool:
|
|
||||||
# # -------------------------------- NEED REFACTORING --------------------------------
|
|
||||||
# """启用插件"""
|
|
||||||
# if plugin_info := component_registry.get_plugin_info(plugin_name):
|
|
||||||
# plugin_info.enabled = True
|
|
||||||
# # 启用插件的所有组件
|
|
||||||
# for component in plugin_info.components:
|
|
||||||
# component_registry.enable_component(component.name)
|
|
||||||
# logger.debug(f"已启用插件: {plugin_name}")
|
|
||||||
# return True
|
|
||||||
# return False
|
|
||||||
|
|
||||||
# def disable_plugin(self, plugin_name: str) -> bool:
|
|
||||||
# # -------------------------------- NEED REFACTORING --------------------------------
|
|
||||||
# """禁用插件"""
|
|
||||||
# if plugin_info := component_registry.get_plugin_info(plugin_name):
|
|
||||||
# plugin_info.enabled = False
|
|
||||||
# # 禁用插件的所有组件
|
|
||||||
# for component in plugin_info.components:
|
|
||||||
# component_registry.disable_component(component.name)
|
|
||||||
# logger.debug(f"已禁用插件: {plugin_name}")
|
|
||||||
# return True
|
|
||||||
# return False
|
|
||||||
|
|
||||||
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
|
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
|
||||||
"""获取插件实例
|
"""获取插件实例
|
||||||
@@ -230,25 +207,6 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
return self.loaded_plugins.get(plugin_name)
|
return self.loaded_plugins.get(plugin_name)
|
||||||
|
|
||||||
def get_plugin_stats(self) -> Dict[str, Any]:
|
|
||||||
"""获取插件统计信息"""
|
|
||||||
all_plugins = component_registry.get_all_plugins()
|
|
||||||
enabled_plugins = component_registry.get_enabled_plugins()
|
|
||||||
|
|
||||||
action_components = component_registry.get_components_by_type(ComponentType.ACTION)
|
|
||||||
command_components = component_registry.get_components_by_type(ComponentType.COMMAND)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_plugins": len(all_plugins),
|
|
||||||
"enabled_plugins": len(enabled_plugins),
|
|
||||||
"failed_plugins": len(self.failed_plugins),
|
|
||||||
"total_components": len(action_components) + len(command_components),
|
|
||||||
"action_components": len(action_components),
|
|
||||||
"command_components": len(command_components),
|
|
||||||
"loaded_plugin_files": len(self.loaded_plugins),
|
|
||||||
"failed_plugin_details": self.failed_plugins.copy(),
|
|
||||||
}
|
|
||||||
|
|
||||||
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]:
|
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]:
|
||||||
"""检查所有插件的Python依赖包
|
"""检查所有插件的Python依赖包
|
||||||
|
|
||||||
@@ -347,6 +305,43 @@ class PluginManager:
|
|||||||
|
|
||||||
return dependency_manager.generate_requirements_file(all_dependencies, output_path)
|
return dependency_manager.generate_requirements_file(all_dependencies, output_path)
|
||||||
|
|
||||||
|
# === 查询方法 ===
|
||||||
|
def list_loaded_plugins(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
列出所有当前加载的插件。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 当前加载的插件名称列表。
|
||||||
|
"""
|
||||||
|
return list(self.loaded_plugins.keys())
|
||||||
|
|
||||||
|
def list_registered_plugins(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
列出所有已注册的插件类。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 已注册的插件类名称列表。
|
||||||
|
"""
|
||||||
|
return list(self.plugin_classes.keys())
|
||||||
|
|
||||||
|
# === 私有方法 ===
|
||||||
|
# == 目录管理 ==
|
||||||
|
def _ensure_plugin_directories(self) -> None:
|
||||||
|
"""确保所有插件根目录存在,如果不存在则创建"""
|
||||||
|
default_directories = ["src/plugins/built_in", "plugins"]
|
||||||
|
|
||||||
|
for directory in default_directories:
|
||||||
|
if not os.path.exists(directory):
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
logger.info(f"创建插件根目录: {directory}")
|
||||||
|
if directory not in self.plugin_directories:
|
||||||
|
self.plugin_directories.append(directory)
|
||||||
|
logger.debug(f"已添加插件根目录: {directory}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"根目录不可重复加载: {directory}")
|
||||||
|
|
||||||
|
# == 插件加载 ==
|
||||||
|
|
||||||
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
|
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
|
||||||
"""从指定目录加载插件模块"""
|
"""从指定目录加载插件模块"""
|
||||||
loaded_count = 0
|
loaded_count = 0
|
||||||
@@ -372,18 +367,6 @@ class PluginManager:
|
|||||||
|
|
||||||
return loaded_count, failed_count
|
return loaded_count, failed_count
|
||||||
|
|
||||||
def _find_plugin_directory(self, plugin_class: Type[PluginBase]) -> Optional[str]:
|
|
||||||
"""查找插件类对应的目录路径"""
|
|
||||||
try:
|
|
||||||
# module = getmodule(plugin_class)
|
|
||||||
# if module and hasattr(module, "__file__") and module.__file__:
|
|
||||||
# return os.path.dirname(module.__file__)
|
|
||||||
file_path = inspect.getfile(plugin_class)
|
|
||||||
return os.path.dirname(file_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"通过inspect获取插件目录失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _load_plugin_module_file(self, plugin_file: str) -> bool:
|
def _load_plugin_module_file(self, plugin_file: str) -> bool:
|
||||||
# sourcery skip: extract-method
|
# sourcery skip: extract-method
|
||||||
"""加载单个插件模块文件
|
"""加载单个插件模块文件
|
||||||
@@ -416,6 +399,8 @@ class PluginManager:
|
|||||||
self.failed_plugins[module_name] = error_msg
|
self.failed_plugins[module_name] = error_msg
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# == 兼容性检查 ==
|
||||||
|
|
||||||
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
||||||
"""检查插件版本兼容性
|
"""检查插件版本兼容性
|
||||||
|
|
||||||
@@ -451,6 +436,8 @@ class PluginManager:
|
|||||||
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
|
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
|
||||||
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
|
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
|
||||||
|
|
||||||
|
# == 显示统计与插件信息 ==
|
||||||
|
|
||||||
def _show_stats(self, total_registered: int, total_failed_registration: int):
|
def _show_stats(self, total_registered: int, total_failed_registration: int):
|
||||||
# sourcery skip: low-code-quality
|
# sourcery skip: low-code-quality
|
||||||
# 获取组件统计信息
|
# 获取组件统计信息
|
||||||
@@ -493,9 +480,15 @@ class PluginManager:
|
|||||||
|
|
||||||
# 组件列表
|
# 组件列表
|
||||||
if plugin_info.components:
|
if plugin_info.components:
|
||||||
action_components = [c for c in plugin_info.components if c.component_type == ComponentType.ACTION]
|
action_components = [
|
||||||
command_components = [c for c in plugin_info.components if c.component_type == ComponentType.COMMAND]
|
c for c in plugin_info.components if c.component_type == ComponentType.ACTION
|
||||||
event_handler_components = [c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER]
|
]
|
||||||
|
command_components = [
|
||||||
|
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
|
||||||
|
]
|
||||||
|
event_handler_components = [
|
||||||
|
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
|
||||||
|
]
|
||||||
|
|
||||||
if action_components:
|
if action_components:
|
||||||
action_names = [c.name for c in action_components]
|
action_names = [c.name for c in action_components]
|
||||||
@@ -504,7 +497,7 @@ class PluginManager:
|
|||||||
if command_components:
|
if command_components:
|
||||||
command_names = [c.name for c in command_components]
|
command_names = [c.name for c in command_components]
|
||||||
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
||||||
|
|
||||||
if event_handler_components:
|
if event_handler_components:
|
||||||
event_handler_names = [c.name for c in event_handler_components]
|
event_handler_names = [c.name for c in event_handler_components]
|
||||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from src.common.logger import get_logger
|
|||||||
# 导入API模块 - 标准Python包方式
|
# 导入API模块 - 标准Python包方式
|
||||||
from src.plugin_system.apis import emoji_api, llm_api, message_api
|
from src.plugin_system.apis import emoji_api, llm_api, message_api
|
||||||
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("emoji")
|
logger = get_logger("emoji")
|
||||||
@@ -102,7 +103,11 @@ class EmojiAction(BaseAction):
|
|||||||
这里是可用的情感标签:{available_emotions}
|
这里是可用的情感标签:{available_emotions}
|
||||||
请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。
|
请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。
|
||||||
"""
|
"""
|
||||||
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
|
||||||
|
if global_config.debug.show_prompt:
|
||||||
|
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||||
|
|
||||||
# 5. 调用LLM
|
# 5. 调用LLM
|
||||||
models = llm_api.get_available_models()
|
models = llm_api.get_available_models()
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from src.plugin_system.apis import message_api
|
|||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("core_actions")
|
logger = get_logger("no_reply_action")
|
||||||
|
|
||||||
|
|
||||||
class NoReplyAction(BaseAction):
|
class NoReplyAction(BaseAction):
|
||||||
|
|||||||
@@ -5,15 +5,10 @@
|
|||||||
这是系统的内置插件,提供基础的聊天交互功能
|
这是系统的内置插件,提供基础的聊天交互功能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
from typing import List, Tuple, Type
|
from typing import List, Tuple, Type
|
||||||
import asyncio
|
|
||||||
import re
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
# 导入新插件系统
|
# 导入新插件系统
|
||||||
from src.plugin_system import BasePlugin, register_plugin, BaseAction, ComponentInfo, ActionActivationType, ChatMode
|
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ActionActivationType
|
||||||
from src.plugin_system.base.config_types import ConfigField
|
from src.plugin_system.base.config_types import ConfigField
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
@@ -21,139 +16,12 @@ from src.config.config import global_config
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
# 导入API模块 - 标准Python包方式
|
# 导入API模块 - 标准Python包方式
|
||||||
from src.plugin_system.apis import generator_api, message_api
|
|
||||||
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
||||||
from src.plugins.built_in.core_actions.emoji import EmojiAction
|
from src.plugins.built_in.core_actions.emoji import EmojiAction
|
||||||
from src.person_info.person_info import get_person_info_manager
|
from src.plugins.built_in.core_actions.reply import ReplyAction
|
||||||
from src.chat.mai_thinking.mai_think import mai_thinking_manager
|
|
||||||
|
|
||||||
logger = get_logger("core_actions")
|
logger = get_logger("core_actions")
|
||||||
|
|
||||||
# 常量定义
|
|
||||||
WAITING_TIME_THRESHOLD = 1200 # 等待新消息时间阈值,单位秒
|
|
||||||
|
|
||||||
ENABLE_THINKING = False
|
|
||||||
|
|
||||||
class ReplyAction(BaseAction):
|
|
||||||
"""回复动作 - 参与聊天回复"""
|
|
||||||
|
|
||||||
# 激活设置
|
|
||||||
focus_activation_type = ActionActivationType.NEVER
|
|
||||||
normal_activation_type = ActionActivationType.NEVER
|
|
||||||
mode_enable = ChatMode.FOCUS
|
|
||||||
parallel_action = False
|
|
||||||
|
|
||||||
# 动作基本信息
|
|
||||||
action_name = "reply"
|
|
||||||
action_description = "参与聊天回复,发送文本进行表达"
|
|
||||||
|
|
||||||
# 动作参数定义
|
|
||||||
action_parameters = {}
|
|
||||||
|
|
||||||
# 动作使用场景
|
|
||||||
action_require = ["你想要闲聊或者随便附和", "有人提到你", "如果你刚刚进行了回复,不要对同一个话题重复回应"]
|
|
||||||
|
|
||||||
# 关联类型
|
|
||||||
associated_types = ["text"]
|
|
||||||
|
|
||||||
def _parse_reply_target(self, target_message: str) -> tuple:
|
|
||||||
sender = ""
|
|
||||||
target = ""
|
|
||||||
if ":" in target_message or ":" in target_message:
|
|
||||||
# 使用正则表达式匹配中文或英文冒号
|
|
||||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
|
||||||
if len(parts) == 2:
|
|
||||||
sender = parts[0].strip()
|
|
||||||
target = parts[1].strip()
|
|
||||||
return sender, target
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
"""执行回复动作"""
|
|
||||||
logger.info(f"{self.log_prefix} 决定进行回复")
|
|
||||||
start_time = self.action_data.get("loop_start_time", time.time())
|
|
||||||
|
|
||||||
user_id = self.user_id
|
|
||||||
platform = self.platform
|
|
||||||
# logger.info(f"{self.log_prefix} 用户ID: {user_id}, 平台: {platform}")
|
|
||||||
person_id = get_person_info_manager().get_person_id(platform, user_id)
|
|
||||||
# logger.info(f"{self.log_prefix} 人物ID: {person_id}")
|
|
||||||
person_name = get_person_info_manager().get_value_sync(person_id, "person_name")
|
|
||||||
reply_to = f"{person_name}:{self.action_message.get('processed_plain_text', '')}"
|
|
||||||
logger.info(f"{self.log_prefix} 回复目标: {reply_to}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
if prepared_reply := self.action_data.get("prepared_reply", ""):
|
|
||||||
reply_text = prepared_reply
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
success, reply_set, _ = await asyncio.wait_for(
|
|
||||||
generator_api.generate_reply(
|
|
||||||
extra_info="",
|
|
||||||
reply_to=reply_to,
|
|
||||||
chat_id=self.chat_id,
|
|
||||||
request_type="chat.replyer.focus",
|
|
||||||
enable_tool=global_config.tool.enable_in_focus_chat,
|
|
||||||
),
|
|
||||||
timeout=global_config.chat.thinking_timeout,
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning(f"{self.log_prefix} 回复生成超时 ({global_config.chat.thinking_timeout}s)")
|
|
||||||
return False, "timeout"
|
|
||||||
|
|
||||||
# 检查从start_time以来的新消息数量
|
|
||||||
# 获取动作触发时间或使用默认值
|
|
||||||
current_time = time.time()
|
|
||||||
new_message_count = message_api.count_new_messages(
|
|
||||||
chat_id=self.chat_id, start_time=start_time, end_time=current_time
|
|
||||||
)
|
|
||||||
|
|
||||||
# 根据新消息数量决定是否使用reply_to
|
|
||||||
need_reply = new_message_count >= random.randint(2, 4)
|
|
||||||
logger.info(
|
|
||||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}引用回复"
|
|
||||||
)
|
|
||||||
# 构建回复文本
|
|
||||||
reply_text = ""
|
|
||||||
first_replied = False
|
|
||||||
reply_to_platform_id = f"{platform}:{user_id}"
|
|
||||||
for reply_seg in reply_set:
|
|
||||||
data = reply_seg[1]
|
|
||||||
if not first_replied:
|
|
||||||
if need_reply:
|
|
||||||
await self.send_text(
|
|
||||||
content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=False)
|
|
||||||
first_replied = True
|
|
||||||
else:
|
|
||||||
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=True)
|
|
||||||
reply_text += data
|
|
||||||
|
|
||||||
# 存储动作记录
|
|
||||||
reply_text = f"你对{person_name}进行了回复:{reply_text}"
|
|
||||||
|
|
||||||
|
|
||||||
if ENABLE_THINKING:
|
|
||||||
await mai_thinking_manager.get_mai_think(self.chat_id).do_think_after_response(reply_text)
|
|
||||||
|
|
||||||
|
|
||||||
await self.store_action_info(
|
|
||||||
action_build_into_prompt=False,
|
|
||||||
action_prompt_display=reply_text,
|
|
||||||
action_done=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 重置NoReplyAction的连续计数器
|
|
||||||
NoReplyAction.reset_consecutive_count()
|
|
||||||
|
|
||||||
return success, reply_text
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 回复动作执行失败: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return False, f"回复失败: {str(e)}"
|
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
@register_plugin
|
||||||
class CoreActionsPlugin(BasePlugin):
|
class CoreActionsPlugin(BasePlugin):
|
||||||
@@ -168,11 +36,11 @@ class CoreActionsPlugin(BasePlugin):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 插件基本信息
|
# 插件基本信息
|
||||||
plugin_name = "core_actions" # 内部标识符
|
plugin_name: str = "core_actions" # 内部标识符
|
||||||
enable_plugin = True
|
enable_plugin: bool = True
|
||||||
dependencies = [] # 插件依赖列表
|
dependencies: list[str] = [] # 插件依赖列表
|
||||||
python_dependencies = [] # Python包依赖列表
|
python_dependencies: list[str] = [] # Python包依赖列表
|
||||||
config_file_name = "config.toml"
|
config_file_name: str = "config.toml"
|
||||||
|
|
||||||
# 配置节描述
|
# 配置节描述
|
||||||
config_section_descriptions = {
|
config_section_descriptions = {
|
||||||
@@ -181,7 +49,7 @@ class CoreActionsPlugin(BasePlugin):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 配置Schema定义
|
# 配置Schema定义
|
||||||
config_schema = {
|
config_schema: dict = {
|
||||||
"plugin": {
|
"plugin": {
|
||||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||||
"config_version": ConfigField(type=str, default="0.4.0", description="配置文件版本"),
|
"config_version": ConfigField(type=str, default="0.4.0", description="配置文件版本"),
|
||||||
|
|||||||
149
src/plugins/built_in/core_actions/reply.py
Normal file
149
src/plugins/built_in/core_actions/reply.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# 导入新插件系统
|
||||||
|
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
||||||
|
from src.config.config import global_config
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from typing import Tuple
|
||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
# 导入依赖的系统组件
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
# 导入API模块 - 标准Python包方式
|
||||||
|
from src.plugin_system.apis import generator_api, message_api
|
||||||
|
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
||||||
|
from src.person_info.person_info import get_person_info_manager
|
||||||
|
from src.mais4u.mai_think import mai_thinking_manager
|
||||||
|
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||||
|
|
||||||
|
logger = get_logger("reply_action")
|
||||||
|
|
||||||
|
|
||||||
|
class ReplyAction(BaseAction):
|
||||||
|
"""回复动作 - 参与聊天回复"""
|
||||||
|
|
||||||
|
# 激活设置
|
||||||
|
focus_activation_type = ActionActivationType.NEVER
|
||||||
|
normal_activation_type = ActionActivationType.NEVER
|
||||||
|
mode_enable = ChatMode.FOCUS
|
||||||
|
parallel_action = False
|
||||||
|
|
||||||
|
# 动作基本信息
|
||||||
|
action_name = "reply"
|
||||||
|
action_description = ""
|
||||||
|
|
||||||
|
# 动作参数定义
|
||||||
|
action_parameters = {}
|
||||||
|
|
||||||
|
# 动作使用场景
|
||||||
|
action_require = [""]
|
||||||
|
|
||||||
|
# 关联类型
|
||||||
|
associated_types = ["text"]
|
||||||
|
|
||||||
|
def _parse_reply_target(self, target_message: str) -> tuple:
|
||||||
|
sender = ""
|
||||||
|
target = ""
|
||||||
|
# 添加None检查,防止NoneType错误
|
||||||
|
if target_message is None:
|
||||||
|
return sender, target
|
||||||
|
if ":" in target_message or ":" in target_message:
|
||||||
|
# 使用正则表达式匹配中文或英文冒号
|
||||||
|
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
sender = parts[0].strip()
|
||||||
|
target = parts[1].strip()
|
||||||
|
return sender, target
|
||||||
|
|
||||||
|
async def execute(self) -> Tuple[bool, str]:
|
||||||
|
"""执行回复动作"""
|
||||||
|
logger.debug(f"{self.log_prefix} 决定进行回复")
|
||||||
|
start_time = self.action_data.get("loop_start_time", time.time())
|
||||||
|
|
||||||
|
user_id = self.user_id
|
||||||
|
platform = self.platform
|
||||||
|
# logger.info(f"{self.log_prefix} 用户ID: {user_id}, 平台: {platform}")
|
||||||
|
person_id = get_person_info_manager().get_person_id(platform, user_id) # type: ignore
|
||||||
|
# logger.info(f"{self.log_prefix} 人物ID: {person_id}")
|
||||||
|
person_name = get_person_info_manager().get_value_sync(person_id, "person_name")
|
||||||
|
reply_to = f"{person_name}:{self.action_message.get('processed_plain_text', '')}" # type: ignore
|
||||||
|
logger.info(f"{self.log_prefix} 决定进行回复,目标: {reply_to}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if prepared_reply := self.action_data.get("prepared_reply", ""):
|
||||||
|
reply_text = prepared_reply
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
success, reply_set, _ = await asyncio.wait_for(
|
||||||
|
generator_api.generate_reply(
|
||||||
|
extra_info="",
|
||||||
|
reply_to=reply_to,
|
||||||
|
chat_id=self.chat_id,
|
||||||
|
request_type="chat.replyer.focus",
|
||||||
|
enable_tool=global_config.tool.enable_in_focus_chat,
|
||||||
|
),
|
||||||
|
timeout=global_config.chat.thinking_timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(f"{self.log_prefix} 回复生成超时 ({global_config.chat.thinking_timeout}s)")
|
||||||
|
return False, "timeout"
|
||||||
|
|
||||||
|
# 检查从start_time以来的新消息数量
|
||||||
|
# 获取动作触发时间或使用默认值
|
||||||
|
current_time = time.time()
|
||||||
|
new_message_count = message_api.count_new_messages(
|
||||||
|
chat_id=self.chat_id, start_time=start_time, end_time=current_time
|
||||||
|
)
|
||||||
|
|
||||||
|
# 根据新消息数量决定是否使用reply_to
|
||||||
|
need_reply = new_message_count >= random.randint(2, 4)
|
||||||
|
if need_reply:
|
||||||
|
logger.info(
|
||||||
|
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,不使用引用回复"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建回复文本
|
||||||
|
reply_text = ""
|
||||||
|
first_replied = False
|
||||||
|
reply_to_platform_id = f"{platform}:{user_id}"
|
||||||
|
for reply_seg in reply_set:
|
||||||
|
data = reply_seg[1]
|
||||||
|
if not first_replied:
|
||||||
|
if need_reply:
|
||||||
|
await self.send_text(
|
||||||
|
content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=False)
|
||||||
|
first_replied = True
|
||||||
|
else:
|
||||||
|
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=True)
|
||||||
|
reply_text += data
|
||||||
|
|
||||||
|
# 存储动作记录
|
||||||
|
reply_text = f"你对{person_name}进行了回复:{reply_text}"
|
||||||
|
|
||||||
|
if ENABLE_S4U:
|
||||||
|
await mai_thinking_manager.get_mai_think(self.chat_id).do_think_after_response(reply_text)
|
||||||
|
|
||||||
|
await self.store_action_info(
|
||||||
|
action_build_into_prompt=False,
|
||||||
|
action_prompt_display=reply_text,
|
||||||
|
action_done=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 重置NoReplyAction的连续计数器
|
||||||
|
NoReplyAction.reset_consecutive_count()
|
||||||
|
|
||||||
|
return success, reply_text
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.log_prefix} 回复动作执行失败: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return False, f"回复失败: {str(e)}"
|
||||||
39
src/plugins/built_in/plugin_management/_manifest.json
Normal file
39
src/plugins/built_in/plugin_management/_manifest.json
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
{
|
||||||
|
"manifest_version": 1,
|
||||||
|
"name": "插件和组件管理 (Plugin and Component Management)",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "通过系统API管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。",
|
||||||
|
"author": {
|
||||||
|
"name": "MaiBot团队",
|
||||||
|
"url": "https://github.com/MaiM-with-u"
|
||||||
|
},
|
||||||
|
"license": "GPL-v3.0-or-later",
|
||||||
|
"host_application": {
|
||||||
|
"min_version": "0.9.0"
|
||||||
|
},
|
||||||
|
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||||
|
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||||
|
"keywords": [
|
||||||
|
"plugins",
|
||||||
|
"components",
|
||||||
|
"management",
|
||||||
|
"built-in"
|
||||||
|
],
|
||||||
|
"categories": [
|
||||||
|
"Core System",
|
||||||
|
"Plugin Management"
|
||||||
|
],
|
||||||
|
"default_locale": "zh-CN",
|
||||||
|
"locales_path": "_locales",
|
||||||
|
"plugin_info": {
|
||||||
|
"is_built_in": true,
|
||||||
|
"plugin_type": "plugin_management",
|
||||||
|
"components": [
|
||||||
|
{
|
||||||
|
"type": "command",
|
||||||
|
"name": "plugin_management",
|
||||||
|
"description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
440
src/plugins/built_in/plugin_management/plugin.py
Normal file
440
src/plugins/built_in/plugin_management/plugin.py
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from typing import List, Tuple, Type
|
||||||
|
from src.plugin_system import (
|
||||||
|
BasePlugin,
|
||||||
|
BaseCommand,
|
||||||
|
CommandInfo,
|
||||||
|
ConfigField,
|
||||||
|
register_plugin,
|
||||||
|
plugin_manage_api,
|
||||||
|
component_manage_api,
|
||||||
|
ComponentInfo,
|
||||||
|
ComponentType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ManagementCommand(BaseCommand):
|
||||||
|
command_name: str = "management"
|
||||||
|
description: str = "管理命令"
|
||||||
|
command_pattern: str = r"(?P<manage_command>^/pm(\s[a-zA-Z0-9_]+)*\s*$)"
|
||||||
|
|
||||||
|
async def execute(self) -> Tuple[bool, str, bool]:
|
||||||
|
# sourcery skip: merge-duplicate-blocks
|
||||||
|
if (
|
||||||
|
not self.message
|
||||||
|
or not self.message.message_info
|
||||||
|
or not self.message.message_info.user_info
|
||||||
|
or str(self.message.message_info.user_info.user_id) not in self.get_config("plugin.permission", []) # type: ignore
|
||||||
|
):
|
||||||
|
await self.send_text("你没有权限使用插件管理命令")
|
||||||
|
return False, "没有权限", True
|
||||||
|
command_list = self.matched_groups["manage_command"].strip().split(" ")
|
||||||
|
if len(command_list) == 1:
|
||||||
|
await self.show_help("all")
|
||||||
|
return True, "帮助已发送", True
|
||||||
|
if len(command_list) == 2:
|
||||||
|
match command_list[1]:
|
||||||
|
case "plugin":
|
||||||
|
await self.show_help("plugin")
|
||||||
|
case "component":
|
||||||
|
await self.show_help("component")
|
||||||
|
case "help":
|
||||||
|
await self.show_help("all")
|
||||||
|
case _:
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
if len(command_list) == 3:
|
||||||
|
if command_list[1] == "plugin":
|
||||||
|
match command_list[2]:
|
||||||
|
case "help":
|
||||||
|
await self.show_help("plugin")
|
||||||
|
case "list":
|
||||||
|
await self._list_registered_plugins()
|
||||||
|
case "list_enabled":
|
||||||
|
await self._list_loaded_plugins()
|
||||||
|
case "rescan":
|
||||||
|
await self._rescan_plugin_dirs()
|
||||||
|
case _:
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
elif command_list[1] == "component":
|
||||||
|
if command_list[2] == "list":
|
||||||
|
await self._list_all_registered_components()
|
||||||
|
elif command_list[2] == "help":
|
||||||
|
await self.show_help("component")
|
||||||
|
else:
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
else:
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
if len(command_list) == 4:
|
||||||
|
if command_list[1] == "plugin":
|
||||||
|
match command_list[2]:
|
||||||
|
case "load":
|
||||||
|
await self._load_plugin(command_list[3])
|
||||||
|
case "unload":
|
||||||
|
await self._unload_plugin(command_list[3])
|
||||||
|
case "reload":
|
||||||
|
await self._reload_plugin(command_list[3])
|
||||||
|
case "add_dir":
|
||||||
|
await self._add_dir(command_list[3])
|
||||||
|
case _:
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
elif command_list[1] == "component":
|
||||||
|
if command_list[2] != "list":
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
if command_list[3] == "enabled":
|
||||||
|
await self._list_enabled_components()
|
||||||
|
elif command_list[3] == "disabled":
|
||||||
|
await self._list_disabled_components()
|
||||||
|
else:
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
else:
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
if len(command_list) == 5:
|
||||||
|
if command_list[1] != "component":
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
if command_list[2] != "list":
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
if command_list[3] == "enabled":
|
||||||
|
await self._list_enabled_components(target_type=command_list[4])
|
||||||
|
elif command_list[3] == "disabled":
|
||||||
|
await self._list_disabled_components(target_type=command_list[4])
|
||||||
|
elif command_list[3] == "type":
|
||||||
|
await self._list_registered_components_by_type(command_list[4])
|
||||||
|
else:
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
if len(command_list) == 6:
|
||||||
|
if command_list[1] != "component":
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
if command_list[2] == "enable":
|
||||||
|
if command_list[3] == "global":
|
||||||
|
await self._globally_enable_component(command_list[4], command_list[5])
|
||||||
|
elif command_list[3] == "local":
|
||||||
|
await self._locally_enable_component(command_list[4], command_list[5])
|
||||||
|
else:
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
elif command_list[2] == "disable":
|
||||||
|
if command_list[3] == "global":
|
||||||
|
await self._globally_disable_component(command_list[4], command_list[5])
|
||||||
|
elif command_list[3] == "local":
|
||||||
|
await self._locally_disable_component(command_list[4], command_list[5])
|
||||||
|
else:
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
else:
|
||||||
|
await self.send_text("插件管理命令不合法")
|
||||||
|
return False, "命令不合法", True
|
||||||
|
|
||||||
|
return True, "命令执行完成", True
|
||||||
|
|
||||||
|
async def show_help(self, target: str):
|
||||||
|
help_msg = ""
|
||||||
|
match target:
|
||||||
|
case "all":
|
||||||
|
help_msg = (
|
||||||
|
"管理命令帮助\n"
|
||||||
|
"/pm help 管理命令提示\n"
|
||||||
|
"/pm plugin 插件管理命令\n"
|
||||||
|
"/pm component 组件管理命令\n"
|
||||||
|
"使用 /pm plugin help 或 /pm component help 获取具体帮助"
|
||||||
|
)
|
||||||
|
case "plugin":
|
||||||
|
help_msg = (
|
||||||
|
"插件管理命令帮助\n"
|
||||||
|
"/pm plugin help 插件管理命令提示\n"
|
||||||
|
"/pm plugin list 列出所有注册的插件\n"
|
||||||
|
"/pm plugin list_enabled 列出所有加载(启用)的插件\n"
|
||||||
|
"/pm plugin rescan 重新扫描所有目录\n"
|
||||||
|
"/pm plugin load <plugin_name> 加载指定插件\n"
|
||||||
|
"/pm plugin unload <plugin_name> 卸载指定插件\n"
|
||||||
|
"/pm plugin reload <plugin_name> 重新加载指定插件\n"
|
||||||
|
"/pm plugin add_dir <directory_path> 添加插件目录\n"
|
||||||
|
)
|
||||||
|
case "component":
|
||||||
|
help_msg = (
|
||||||
|
"组件管理命令帮助\n"
|
||||||
|
"/pm component help 组件管理命令提示\n"
|
||||||
|
"/pm component list 列出所有注册的组件\n"
|
||||||
|
"/pm component list enabled <可选: type> 列出所有启用的组件\n"
|
||||||
|
"/pm component list disabled <可选: type> 列出所有禁用的组件\n"
|
||||||
|
" - <type> 可选项: local,代表当前聊天中的;global,代表全局的\n"
|
||||||
|
" - <type> 不填时为 global\n"
|
||||||
|
"/pm component list type <component_type> 列出已经注册的指定类型的组件\n"
|
||||||
|
"/pm component enable global <component_name> <component_type> 全局启用组件\n"
|
||||||
|
"/pm component enable local <component_name> <component_type> 本聊天启用组件\n"
|
||||||
|
"/pm component disable global <component_name> <component_type> 全局禁用组件\n"
|
||||||
|
"/pm component disable local <component_name> <component_type> 本聊天禁用组件\n"
|
||||||
|
" - <component_type> 可选项: action, command, event_handler\n"
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
return
|
||||||
|
await self.send_text(help_msg)
|
||||||
|
|
||||||
|
async def _list_loaded_plugins(self):
|
||||||
|
plugins = plugin_manage_api.list_loaded_plugins()
|
||||||
|
await self.send_text(f"已加载的插件: {', '.join(plugins)}")
|
||||||
|
|
||||||
|
async def _list_registered_plugins(self):
|
||||||
|
plugins = plugin_manage_api.list_registered_plugins()
|
||||||
|
await self.send_text(f"已注册的插件: {', '.join(plugins)}")
|
||||||
|
|
||||||
|
async def _rescan_plugin_dirs(self):
|
||||||
|
plugin_manage_api.rescan_plugin_directory()
|
||||||
|
await self.send_text("插件目录重新扫描执行中")
|
||||||
|
|
||||||
|
async def _load_plugin(self, plugin_name: str):
|
||||||
|
success, count = plugin_manage_api.load_plugin(plugin_name)
|
||||||
|
if success:
|
||||||
|
await self.send_text(f"插件加载成功: {plugin_name}")
|
||||||
|
else:
|
||||||
|
if count == 0:
|
||||||
|
await self.send_text(f"插件{plugin_name}为禁用状态")
|
||||||
|
await self.send_text(f"插件加载失败: {plugin_name}")
|
||||||
|
|
||||||
|
async def _unload_plugin(self, plugin_name: str):
|
||||||
|
success = await plugin_manage_api.remove_plugin(plugin_name)
|
||||||
|
if success:
|
||||||
|
await self.send_text(f"插件卸载成功: {plugin_name}")
|
||||||
|
else:
|
||||||
|
await self.send_text(f"插件卸载失败: {plugin_name}")
|
||||||
|
|
||||||
|
async def _reload_plugin(self, plugin_name: str):
|
||||||
|
success = await plugin_manage_api.reload_plugin(plugin_name)
|
||||||
|
if success:
|
||||||
|
await self.send_text(f"插件重新加载成功: {plugin_name}")
|
||||||
|
else:
|
||||||
|
await self.send_text(f"插件重新加载失败: {plugin_name}")
|
||||||
|
|
||||||
|
async def _add_dir(self, dir_path: str):
|
||||||
|
await self.send_text(f"正在添加插件目录: {dir_path}")
|
||||||
|
success = plugin_manage_api.add_plugin_directory(dir_path)
|
||||||
|
await asyncio.sleep(0.5) # 防止乱序发送
|
||||||
|
if success:
|
||||||
|
await self.send_text(f"插件目录添加成功: {dir_path}")
|
||||||
|
else:
|
||||||
|
await self.send_text(f"插件目录添加失败: {dir_path}")
|
||||||
|
|
||||||
|
def _fetch_all_registered_components(self) -> List[ComponentInfo]:
|
||||||
|
all_plugin_info = component_manage_api.get_all_plugin_info()
|
||||||
|
if not all_plugin_info:
|
||||||
|
return []
|
||||||
|
|
||||||
|
components_info: List[ComponentInfo] = []
|
||||||
|
for plugin_info in all_plugin_info.values():
|
||||||
|
components_info.extend(plugin_info.components)
|
||||||
|
return components_info
|
||||||
|
|
||||||
|
def _fetch_locally_disabled_components(self) -> List[str]:
|
||||||
|
locally_disabled_components_actions = component_manage_api.get_locally_disabled_components(
|
||||||
|
self.message.chat_stream.stream_id, ComponentType.ACTION
|
||||||
|
)
|
||||||
|
locally_disabled_components_commands = component_manage_api.get_locally_disabled_components(
|
||||||
|
self.message.chat_stream.stream_id, ComponentType.COMMAND
|
||||||
|
)
|
||||||
|
locally_disabled_components_event_handlers = component_manage_api.get_locally_disabled_components(
|
||||||
|
self.message.chat_stream.stream_id, ComponentType.EVENT_HANDLER
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
locally_disabled_components_actions
|
||||||
|
+ locally_disabled_components_commands
|
||||||
|
+ locally_disabled_components_event_handlers
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _list_all_registered_components(self):
|
||||||
|
components_info = self._fetch_all_registered_components()
|
||||||
|
if not components_info:
|
||||||
|
await self.send_text("没有注册的组件")
|
||||||
|
return
|
||||||
|
|
||||||
|
all_components_str = ", ".join(
|
||||||
|
f"{component.name} ({component.component_type})" for component in components_info
|
||||||
|
)
|
||||||
|
await self.send_text(f"已注册的组件: {all_components_str}")
|
||||||
|
|
||||||
|
async def _list_enabled_components(self, target_type: str = "global"):
|
||||||
|
components_info = self._fetch_all_registered_components()
|
||||||
|
if not components_info:
|
||||||
|
await self.send_text("没有注册的组件")
|
||||||
|
return
|
||||||
|
|
||||||
|
if target_type == "global":
|
||||||
|
enabled_components = [component for component in components_info if component.enabled]
|
||||||
|
if not enabled_components:
|
||||||
|
await self.send_text("没有满足条件的已启用全局组件")
|
||||||
|
return
|
||||||
|
enabled_components_str = ", ".join(
|
||||||
|
f"{component.name} ({component.component_type})" for component in enabled_components
|
||||||
|
)
|
||||||
|
await self.send_text(f"满足条件的已启用全局组件: {enabled_components_str}")
|
||||||
|
elif target_type == "local":
|
||||||
|
locally_disabled_components = self._fetch_locally_disabled_components()
|
||||||
|
enabled_components = [
|
||||||
|
component
|
||||||
|
for component in components_info
|
||||||
|
if (component.name not in locally_disabled_components and component.enabled)
|
||||||
|
]
|
||||||
|
if not enabled_components:
|
||||||
|
await self.send_text("本聊天没有满足条件的已启用组件")
|
||||||
|
return
|
||||||
|
enabled_components_str = ", ".join(
|
||||||
|
f"{component.name} ({component.component_type})" for component in enabled_components
|
||||||
|
)
|
||||||
|
await self.send_text(f"本聊天满足条件的已启用组件: {enabled_components_str}")
|
||||||
|
|
||||||
|
async def _list_disabled_components(self, target_type: str = "global"):
|
||||||
|
components_info = self._fetch_all_registered_components()
|
||||||
|
if not components_info:
|
||||||
|
await self.send_text("没有注册的组件")
|
||||||
|
return
|
||||||
|
|
||||||
|
if target_type == "global":
|
||||||
|
disabled_components = [component for component in components_info if not component.enabled]
|
||||||
|
if not disabled_components:
|
||||||
|
await self.send_text("没有满足条件的已禁用全局组件")
|
||||||
|
return
|
||||||
|
disabled_components_str = ", ".join(
|
||||||
|
f"{component.name} ({component.component_type})" for component in disabled_components
|
||||||
|
)
|
||||||
|
await self.send_text(f"满足条件的已禁用全局组件: {disabled_components_str}")
|
||||||
|
elif target_type == "local":
|
||||||
|
locally_disabled_components = self._fetch_locally_disabled_components()
|
||||||
|
disabled_components = [
|
||||||
|
component
|
||||||
|
for component in components_info
|
||||||
|
if (component.name in locally_disabled_components or not component.enabled)
|
||||||
|
]
|
||||||
|
if not disabled_components:
|
||||||
|
await self.send_text("本聊天没有满足条件的已禁用组件")
|
||||||
|
return
|
||||||
|
disabled_components_str = ", ".join(
|
||||||
|
f"{component.name} ({component.component_type})" for component in disabled_components
|
||||||
|
)
|
||||||
|
await self.send_text(f"本聊天满足条件的已禁用组件: {disabled_components_str}")
|
||||||
|
|
||||||
|
async def _list_registered_components_by_type(self, target_type: str):
|
||||||
|
match target_type:
|
||||||
|
case "action":
|
||||||
|
component_type = ComponentType.ACTION
|
||||||
|
case "command":
|
||||||
|
component_type = ComponentType.COMMAND
|
||||||
|
case "event_handler":
|
||||||
|
component_type = ComponentType.EVENT_HANDLER
|
||||||
|
case _:
|
||||||
|
await self.send_text(f"未知组件类型: {target_type}")
|
||||||
|
return
|
||||||
|
|
||||||
|
components_info = component_manage_api.get_components_info_by_type(component_type)
|
||||||
|
if not components_info:
|
||||||
|
await self.send_text(f"没有注册的 {target_type} 组件")
|
||||||
|
return
|
||||||
|
|
||||||
|
components_str = ", ".join(
|
||||||
|
f"{name} ({component.component_type})" for name, component in components_info.items()
|
||||||
|
)
|
||||||
|
await self.send_text(f"注册的 {target_type} 组件: {components_str}")
|
||||||
|
|
||||||
|
async def _globally_enable_component(self, component_name: str, component_type: str):
|
||||||
|
match component_type:
|
||||||
|
case "action":
|
||||||
|
target_component_type = ComponentType.ACTION
|
||||||
|
case "command":
|
||||||
|
target_component_type = ComponentType.COMMAND
|
||||||
|
case "event_handler":
|
||||||
|
target_component_type = ComponentType.EVENT_HANDLER
|
||||||
|
case _:
|
||||||
|
await self.send_text(f"未知组件类型: {component_type}")
|
||||||
|
return
|
||||||
|
if component_manage_api.globally_enable_component(component_name, target_component_type):
|
||||||
|
await self.send_text(f"全局启用组件成功: {component_name}")
|
||||||
|
else:
|
||||||
|
await self.send_text(f"全局启用组件失败: {component_name}")
|
||||||
|
|
||||||
|
async def _globally_disable_component(self, component_name: str, component_type: str):
|
||||||
|
match component_type:
|
||||||
|
case "action":
|
||||||
|
target_component_type = ComponentType.ACTION
|
||||||
|
case "command":
|
||||||
|
target_component_type = ComponentType.COMMAND
|
||||||
|
case "event_handler":
|
||||||
|
target_component_type = ComponentType.EVENT_HANDLER
|
||||||
|
case _:
|
||||||
|
await self.send_text(f"未知组件类型: {component_type}")
|
||||||
|
return
|
||||||
|
success = await component_manage_api.globally_disable_component(component_name, target_component_type)
|
||||||
|
if success:
|
||||||
|
await self.send_text(f"全局禁用组件成功: {component_name}")
|
||||||
|
else:
|
||||||
|
await self.send_text(f"全局禁用组件失败: {component_name}")
|
||||||
|
|
||||||
|
async def _locally_enable_component(self, component_name: str, component_type: str):
|
||||||
|
match component_type:
|
||||||
|
case "action":
|
||||||
|
target_component_type = ComponentType.ACTION
|
||||||
|
case "command":
|
||||||
|
target_component_type = ComponentType.COMMAND
|
||||||
|
case "event_handler":
|
||||||
|
target_component_type = ComponentType.EVENT_HANDLER
|
||||||
|
case _:
|
||||||
|
await self.send_text(f"未知组件类型: {component_type}")
|
||||||
|
return
|
||||||
|
if component_manage_api.locally_enable_component(
|
||||||
|
component_name,
|
||||||
|
target_component_type,
|
||||||
|
self.message.chat_stream.stream_id,
|
||||||
|
):
|
||||||
|
await self.send_text(f"本地启用组件成功: {component_name}")
|
||||||
|
else:
|
||||||
|
await self.send_text(f"本地启用组件失败: {component_name}")
|
||||||
|
|
||||||
|
async def _locally_disable_component(self, component_name: str, component_type: str):
|
||||||
|
match component_type:
|
||||||
|
case "action":
|
||||||
|
target_component_type = ComponentType.ACTION
|
||||||
|
case "command":
|
||||||
|
target_component_type = ComponentType.COMMAND
|
||||||
|
case "event_handler":
|
||||||
|
target_component_type = ComponentType.EVENT_HANDLER
|
||||||
|
case _:
|
||||||
|
await self.send_text(f"未知组件类型: {component_type}")
|
||||||
|
return
|
||||||
|
if component_manage_api.locally_disable_component(
|
||||||
|
component_name,
|
||||||
|
target_component_type,
|
||||||
|
self.message.chat_stream.stream_id,
|
||||||
|
):
|
||||||
|
await self.send_text(f"本地禁用组件成功: {component_name}")
|
||||||
|
else:
|
||||||
|
await self.send_text(f"本地禁用组件失败: {component_name}")
|
||||||
|
|
||||||
|
|
||||||
|
@register_plugin
|
||||||
|
class PluginManagementPlugin(BasePlugin):
|
||||||
|
plugin_name: str = "plugin_management_plugin"
|
||||||
|
enable_plugin: bool = True
|
||||||
|
dependencies: list[str] = []
|
||||||
|
python_dependencies: list[str] = []
|
||||||
|
config_file_name: str = "config.toml"
|
||||||
|
config_schema: dict = {
|
||||||
|
"plugin": {
|
||||||
|
"enable": ConfigField(bool, default=True, description="是否启用插件"),
|
||||||
|
"permission": ConfigField(list, default=[], description="有权限使用插件管理命令的用户列表"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_plugin_components(self) -> List[Tuple[CommandInfo, Type[BaseCommand]]]:
|
||||||
|
components = []
|
||||||
|
if self.get_config("plugin.enable", True):
|
||||||
|
components.append((ManagementCommand.get_command_info(), ManagementCommand))
|
||||||
|
return components
|
||||||
@@ -92,7 +92,7 @@ class TTSAction(BaseAction):
|
|||||||
|
|
||||||
# 确保句子结尾有合适的标点
|
# 确保句子结尾有合适的标点
|
||||||
if not any(processed_text.endswith(end) for end in [".", "?", "!", "。", "!", "?"]):
|
if not any(processed_text.endswith(end) for end in [".", "?", "!", "。", "!", "?"]):
|
||||||
processed_text = processed_text + "。"
|
processed_text = f"{processed_text}。"
|
||||||
|
|
||||||
return processed_text
|
return processed_text
|
||||||
|
|
||||||
@@ -107,11 +107,11 @@ class TTSPlugin(BasePlugin):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 插件基本信息
|
# 插件基本信息
|
||||||
plugin_name = "tts_plugin" # 内部标识符
|
plugin_name: str = "tts_plugin" # 内部标识符
|
||||||
enable_plugin = True
|
enable_plugin: bool = True
|
||||||
dependencies = [] # 插件依赖列表
|
dependencies: list[str] = [] # 插件依赖列表
|
||||||
python_dependencies = [] # Python包依赖列表
|
python_dependencies: list[str] = [] # Python包依赖列表
|
||||||
config_file_name = "config.toml"
|
config_file_name: str = "config.toml"
|
||||||
|
|
||||||
# 配置节描述
|
# 配置节描述
|
||||||
config_section_descriptions = {
|
config_section_descriptions = {
|
||||||
@@ -121,7 +121,7 @@ class TTSPlugin(BasePlugin):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 配置Schema定义
|
# 配置Schema定义
|
||||||
config_schema = {
|
config_schema: dict = {
|
||||||
"plugin": {
|
"plugin": {
|
||||||
"name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True),
|
"name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True),
|
||||||
"version": ConfigField(type=str, default="0.1.0", description="插件版本号"),
|
"version": ConfigField(type=str, default="0.1.0", description="插件版本号"),
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[inner]
|
[inner]
|
||||||
version = "4.4.4"
|
version = "4.4.8"
|
||||||
|
|
||||||
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||||
#如果你想要修改配置文件,请在修改后将version的值进行变更
|
#如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||||
@@ -13,6 +13,7 @@ version = "4.4.4"
|
|||||||
#----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
#----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||||
|
|
||||||
[bot]
|
[bot]
|
||||||
|
platform = "qq"
|
||||||
qq_account = 1145141919810 # 麦麦的QQ账号
|
qq_account = 1145141919810 # 麦麦的QQ账号
|
||||||
nickname = "麦麦" # 麦麦的昵称
|
nickname = "麦麦" # 麦麦的昵称
|
||||||
alias_names = ["麦叠", "牢麦"] # 麦麦的别名
|
alias_names = ["麦叠", "牢麦"] # 麦麦的别名
|
||||||
@@ -33,7 +34,7 @@ compress_identity = true # 是否压缩身份,压缩后会精简身份信息
|
|||||||
# 表达方式
|
# 表达方式
|
||||||
enable_expression = true # 是否启用表达方式
|
enable_expression = true # 是否启用表达方式
|
||||||
# 描述麦麦说话的表达风格,表达习惯,例如:(请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景。)
|
# 描述麦麦说话的表达风格,表达习惯,例如:(请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景。)
|
||||||
expression_style = "请回复的平淡些,简短一些,说中文,可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,不要刻意突出自身学科背景。"
|
expression_style = "回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。"
|
||||||
enable_expression_learning = false # 是否启用表达学习,麦麦会学习不同群里人类说话风格(群之间不互通)
|
enable_expression_learning = false # 是否启用表达学习,麦麦会学习不同群里人类说话风格(群之间不互通)
|
||||||
learning_interval = 350 # 学习间隔 单位秒
|
learning_interval = 350 # 学习间隔 单位秒
|
||||||
|
|
||||||
@@ -58,6 +59,9 @@ max_context_size = 25 # 上下文长度
|
|||||||
thinking_timeout = 20 # 麦麦一次回复最长思考规划时间,超过这个时间的思考会放弃(往往是api反应太慢)
|
thinking_timeout = 20 # 麦麦一次回复最长思考规划时间,超过这个时间的思考会放弃(往往是api反应太慢)
|
||||||
replyer_random_probability = 0.5 # 首要replyer模型被选择的概率
|
replyer_random_probability = 0.5 # 首要replyer模型被选择的概率
|
||||||
|
|
||||||
|
mentioned_bot_inevitable_reply = true # 提及 bot 大概率回复
|
||||||
|
at_bot_inevitable_reply = true # @bot 或 提及bot 大概率回复
|
||||||
|
|
||||||
use_s4u_prompt_mode = true # 是否使用 s4u 对话构建模式,该模式会更好的把握当前对话对象的对话内容,但是对群聊整理理解能力较差(测试功能!!可能有未知问题!!)
|
use_s4u_prompt_mode = true # 是否使用 s4u 对话构建模式,该模式会更好的把握当前对话对象的对话内容,但是对群聊整理理解能力较差(测试功能!!可能有未知问题!!)
|
||||||
|
|
||||||
|
|
||||||
@@ -87,8 +91,6 @@ talk_frequency_adjust = [
|
|||||||
# - 时间支持跨天,例如 "00:10,0.3" 表示从凌晨0:10开始使用频率0.3
|
# - 时间支持跨天,例如 "00:10,0.3" 表示从凌晨0:10开始使用频率0.3
|
||||||
# - 系统会自动将 "platform:id:type" 转换为内部的哈希chat_id进行匹配
|
# - 系统会自动将 "platform:id:type" 转换为内部的哈希chat_id进行匹配
|
||||||
|
|
||||||
enable_asr = false # 是否启用语音识别,启用后麦麦可以通过语音输入进行对话,启用该功能需要配置语音识别模型[model.voice]
|
|
||||||
|
|
||||||
[message_receive]
|
[message_receive]
|
||||||
# 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息
|
# 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息
|
||||||
ban_words = [
|
ban_words = [
|
||||||
@@ -102,11 +104,8 @@ ban_msgs_regex = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[normal_chat] #普通聊天
|
[normal_chat] #普通聊天
|
||||||
#一般回复参数
|
|
||||||
willing_mode = "classical" # 回复意愿模式 —— 经典模式:classical,mxp模式:mxp,自定义模式:custom(需要你自己实现)
|
willing_mode = "classical" # 回复意愿模式 —— 经典模式:classical,mxp模式:mxp,自定义模式:custom(需要你自己实现)
|
||||||
response_interested_rate_amplifier = 1 # 麦麦回复兴趣度放大系数
|
response_interested_rate_amplifier = 1 # 麦麦回复兴趣度放大系数
|
||||||
mentioned_bot_inevitable_reply = true # 提及 bot 必然回复
|
|
||||||
at_bot_inevitable_reply = true # @bot 必然回复(包含提及)
|
|
||||||
|
|
||||||
[tool]
|
[tool]
|
||||||
enable_in_normal_chat = false # 是否在普通聊天中启用工具
|
enable_in_normal_chat = false # 是否在普通聊天中启用工具
|
||||||
@@ -144,14 +143,15 @@ enable_instant_memory = false # 是否启用即时记忆,测试功能,可能
|
|||||||
#不希望记忆的词,已经记忆的不会受到影响,需要手动清理
|
#不希望记忆的词,已经记忆的不会受到影响,需要手动清理
|
||||||
memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ]
|
memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ]
|
||||||
|
|
||||||
|
[voice]
|
||||||
|
enable_asr = false # 是否启用语音识别,启用后麦麦可以识别语音消息,启用该功能需要配置语音识别模型[model.voice]s
|
||||||
|
|
||||||
[mood]
|
[mood]
|
||||||
enable_mood = true # 是否启用情绪系统
|
enable_mood = true # 是否启用情绪系统
|
||||||
mood_update_interval = 1.0 # 情绪更新间隔 单位秒
|
mood_update_threshold = 1 # 情绪更新阈值,越高,更新越慢
|
||||||
mood_decay_rate = 0.95 # 情绪衰减率
|
|
||||||
mood_intensity_factor = 1.0 # 情绪强度因子
|
|
||||||
|
|
||||||
[lpmm_knowledge] # lpmm知识库配置
|
[lpmm_knowledge] # lpmm知识库配置
|
||||||
enable = true # 是否启用lpmm知识库
|
enable = false # 是否启用lpmm知识库
|
||||||
rag_synonym_search_top_k = 10 # 同义词搜索TopK
|
rag_synonym_search_top_k = 10 # 同义词搜索TopK
|
||||||
rag_synonym_threshold = 0.8 # 同义词阈值(相似度高于此阈值的词语会被认为是同义词)
|
rag_synonym_threshold = 0.8 # 同义词阈值(相似度高于此阈值的词语会被认为是同义词)
|
||||||
info_extraction_workers = 3 # 实体提取同时执行线程数,非Pro模型不要设置超过5
|
info_extraction_workers = 3 # 实体提取同时执行线程数,非Pro模型不要设置超过5
|
||||||
@@ -229,7 +229,7 @@ show_prompt = false # 是否显示prompt
|
|||||||
|
|
||||||
|
|
||||||
[model]
|
[model]
|
||||||
model_max_output_length = 1000 # 模型单次返回的最大token数
|
model_max_output_length = 1024 # 模型单次返回的最大token数
|
||||||
|
|
||||||
#------------必填:组件模型------------
|
#------------必填:组件模型------------
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user