diff --git a/README.md b/README.md index f450fc0a4..3a9e14f80 100644 --- a/README.md +++ b/README.md @@ -25,9 +25,11 @@ **🍔MaiCore 是一个基于大语言模型的可交互智能体** -- 💭 **智能对话系统**:基于 LLM 的自然语言交互。 +- 💭 **智能对话系统**:基于 LLM 的自然语言交互,支持normal和focus统一化处理。 +- 🔌 **强大插件系统**:全面重构的插件架构,支持完整的管理API和权限控制。 - 🤔 **实时思维系统**:模拟人类思考过程。 -- 💝 **情感表达系统**:丰富的表情包和情绪表达。 +- 🧠 **表达学习功能**:学习群友的说话风格和表达方式。 +- 💝 **情感表达系统**:情绪系统和表情包系统。 - 🧠 **持久记忆系统**:基于图的长期记忆存储。 - 🔄 **动态人格系统**:自适应的性格特征和表达方式。 @@ -44,11 +46,10 @@ ## 🔥 更新和安装 - -**最新版本: v0.8.1** ([更新日志](changelogs/changelog.md)) +**最新版本: v0.9.1** ([更新日志](changelogs/changelog.md)) 可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本 -可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/tag/v0.1.0)下载最新启动器 +可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器 **GitHub 分支说明:** - `main`: 稳定发布版本(推荐) - `dev`: 开发测试版本(不稳定) @@ -68,11 +69,17 @@ ## 💬 讨论 -- [四群](https://qm.qq.com/q/wGePTl1UyY) | - [一群](https://qm.qq.com/q/VQ3XZrWgMs) | +**技术交流群:** +- [一群](https://qm.qq.com/q/VQ3XZrWgMs) | [二群](https://qm.qq.com/q/RzmCiRtHEW) | - [五群](https://qm.qq.com/q/JxvHZnxyec) | - [三群](https://qm.qq.com/q/wlH5eT8OmQ) + [三群](https://qm.qq.com/q/wlH5eT8OmQ) | + [四群](https://qm.qq.com/q/wGePTl1UyY) + +**聊天吹水群:** +- [五群](https://qm.qq.com/q/JxvHZnxyec) + +**插件开发测试版群:** +- [插件开发群](https://qm.qq.com/q/1036092828) ## 📚 文档 diff --git a/changelogs/changelog.md b/changelogs/changelog.md index 4d9760629..c56426a72 100644 --- a/changelogs/changelog.md +++ b/changelogs/changelog.md @@ -1,29 +1,80 @@ # Changelog -## [0.8.2] - 2025-7-5 +## [0.9.1] - 2025-7-25 -功能更新: +- 修复表达方式迁移空目录问题 +- 修复reply_to空字段问题 +- 将mentioned bot 和 at应用到focus prompt中 -- 新的情绪系统,麦麦现在拥有持续的情绪 -- -优化和修复: -- -- 优化no_reply逻辑 -- 优化Log显示 -- 优化关系配置 -- 简化配置文件 -- 修复在auto模式下,私聊会转为normal的bug -- 修复一般过滤次序问题 -- 优化normal_chat代码,采用和focus一致的关系构建 -- 优化计时信息和Log -- 添加回复超时检查 -- normal的插件允许llm激活 -- 合并action激活器 -- emoji统一可选随机激活或llm激活 -- 移除observation和processor,简化focus的代码逻辑 +## [0.9.0] - 2025-7-25 + +### 摘要 
+MaiBot 0.9.0 重磅升级!本版本带来两大核心突破:**全面重构的插件系统**提供更强大的扩展能力和管理功能;**normal和focus模式统一化处理**大幅简化架构并提升性能。同时新增s4u prompt模式优化、语音消息支持、全新情绪系统和mais4u直播互动功能,为MaiBot带来更自然、更智能的交互体验! + +### 🌟 主要功能概览 + +#### 🔌 插件系统全面重构 - 重点升级 +- **完整管理API**: 全新的插件管理API,支持插件的启用、禁用、重载和卸载操作 +- **权限控制系统**: 为插件管理增加完善的权限控制,确保系统安全性 +- **智能依赖管理**: 优化插件依赖管理和自动注册机制,减少配置复杂度 + +#### ⚡ Normal和Focus模式统一化处理 - 重点升级 +- **架构统一**: 彻底统一normal和focus聊天模式,消除模式间的差异和复杂性 +- **智能模式切换**: 优化频率控制和模式切换逻辑,normal可以无缝切换到focus +- **统一LLM激活**: normal模式现在支持LLM激活插件,与focus模式功能对等 +- **一致的关系构建**: normal采用与focus一致的关系构建机制,提升交互质量 +- **统一退出机制**: 为focus提供更合理的退出方法,简化状态管理 + +#### 🎯 s4u prompt模式 +- **s4u prompt模式**: 新增专门的s4u prompt构建方式,提供更好的交互效果 +- **配置化启用**: 可在配置文件中选择启用s4u prompt模式,灵活控制 +- **兼容性保持**: 与现有系统完全兼容,可随时切换启用或禁用 + +#### 🎤 语音消息支持 +- **Voice消息处理**: 新增对voice类型消息的支持,麦麦现在可以识别和处理语音消息(需要模型配置) + +#### 全新情绪系统 +- **持续情绪**: 麦麦现在拥有持续的情绪状态,情绪会影响回复风格和行为 + + +### 💻 更新预览 + +#### 关系系统优化 +- **prompt优化**: 优化关系prompt和person_info信息展示 +- **构建间隔**: 让关系构建间隔可配置,提升灵活性 +- **关系配置**: 优化关系配置,采用和focus一致的关系构建 + +#### 表情包系统升级 +- **识别增强**: 加强emoji的识别能力,优化emoji显示 +- **匹配精准**: 更精准的表情包匹配算法 + +#### 完善mais4u系统(需要amaidesu支持) +- **直播互动**: 新增mais4u直播功能,支持实时互动和思考状态展示 +- **动作控制**: 支持眨眼、微动作、注视等多种动作适配 + +#### 日志系统优化 +- **显示优化**: 优化Logger前缀映射、颜色格式和计时信息显示 +- **级别优化**: 优化日志级别和信息过滤,提升调试体验 +- **日志查看器**: 升级logger_viewer,移除无用脚本 + +#### 配置系统改进 +- **配置简化**: 简化配置文件,让配置更加精简易懂 +- **prompt显示**: 可选打开prompt显示功能 +- **配置更新**: 更好的配置文件更新机制和更新内容显示 + +#### 问题修复与优化 + +- 修复normal planner没有超时退出问题,添加回复超时检查 +- 重构no_reply逻辑,不再使用小模型,采用激活度决定 - 修复图片与文字混合兴趣值为0的情况 +- 适配无兴趣度消息处理 +- 优化Docker镜像构建流程,合并AMD64和ARM64构建步骤 +- 移除vtb插件和take_picture_plugin,功能已由其他系统接管,移除pfc遗留代码和其他过时功能 +- 移除observation和processor等冗余组件,大幅简化focus代码逻辑 +- 修复了LPMM的学习问题 + ## [0.8.1] - 2025-7-5 diff --git a/changes.md b/changes.md index 407537d28..7d4f2ae8f 100644 --- a/changes.md +++ b/changes.md @@ -20,6 +20,9 @@ - `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。 - 
`database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。 5. 增加了`logging_api`,可以用`get_logger`来获取日志记录器。 +6. 增加了插件和组件管理的API。 +7. `BaseCommand`的`execute`方法现在返回一个三元组,包含是否执行成功、可选的回复消息和是否拦截消息。 + - 这意味着你终于可以动态控制是否继续后续消息的处理了。 # 插件系统修改 1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)** @@ -45,6 +48,17 @@ 10. 修正了`main.py`中的错误输出。 11. 修正了`command`所编译的`Pattern`注册时的错误输出。 12. `events_manager`有了task相关逻辑了。 +13. 现在有了插件卸载和重载功能了,也就是热插拔。 +14. 实现了组件的全局启用和禁用功能。 + - 通过`enable_component`和`disable_component`方法来启用或禁用组件。 + - 不过这个操作不会保存到配置文件~ +15. 实现了组件的局部禁用,也就是针对某一个聊天禁用的功能。 + - 通过`disable_specific_chat_action`,`enable_specific_chat_action`,`disable_specific_chat_command`,`enable_specific_chat_command`,`disable_specific_chat_event_handler`,`enable_specific_chat_event_handler`来操作 + - 同样不保存到配置文件~ + +# 官方插件修改 +1. `HelloWorld`插件现在有一个样例的`EventHandler`。 +2. 内置插件增加了一个通过`Command`来管理插件的功能。具体是使用`/pm`命令唤起。 ### TODO 把这个看起来就很别扭的config获取方式改一下 @@ -64,4 +78,7 @@ else: plugin_path = Path(plugin_file) module_name = ".".join(plugin_path.parent.parts) ``` -这两个区别很大的。 \ No newline at end of file +这两个区别很大的。 + +### 执笔BGM +塞壬唱片! 
\ No newline at end of file diff --git a/docs/plugins/action-components.md b/docs/plugins/action-components.md index d68d87076..3953c79c2 100644 --- a/docs/plugins/action-components.md +++ b/docs/plugins/action-components.md @@ -4,42 +4,183 @@ Action是给麦麦在回复之外提供额外功能的智能组件,**由麦麦的决策系统自主选择是否使用**,具有随机性和拟人化的调用特点。Action不是直接响应用户命令,而是让麦麦根据聊天情境智能地选择合适的动作,使其行为更加自然和真实。 -### 🎯 Action的特点 +### Action的特点 - 🧠 **智能激活**:麦麦根据多种条件智能判断是否使用 -- 🎲 **随机性**:增加行为的不可预测性,更接近真人交流 +- 🎲 **可随机性**:可以使用随机数激活,增加行为的不可预测性,更接近真人交流 - 🤖 **拟人化**:让麦麦的回应更自然、更有个性 - 🔄 **情境感知**:基于聊天上下文做出合适的反应 -## 🎯 两层决策机制 +--- + +## 🎯 Action组件的基本结构 +首先,所有的Action都应该继承`BaseAction`类。 + +其次,每个Action组件都应该实现以下基本信息: +```python +class ExampleAction(BaseAction): + action_name = "example_action" # 动作的唯一标识符 + action_description = "这是一个示例动作" # 动作描述 + activation_type = ActionActivationType.ALWAYS # 这里以 ALWAYS 为例 + mode_enable = ChatMode.ALL # 这里以 ALL 为例 + associated_types = ["text", "emoji", ...] # 关联类型 + parallel_action = False # 是否允许与其他Action并行执行 + action_parameters = {"param1": "参数1的说明", "param2": "参数2的说明", ...} + # Action使用场景描述 - 帮助LLM判断何时"选择"使用 + action_require = ["使用场景描述1", "使用场景描述2", ...] 
+ + async def execute(self) -> Tuple[bool, str]: + """ + 执行Action的主要逻辑 + + Returns: + Tuple[bool, str]: (是否成功, 执行结果描述) + """ + # ---- 执行动作的逻辑 ---- + return True, "执行成功" +``` +#### associated_types: 该Action会发送的消息类型,例如文本、表情等。 + +这部分由Adapter传递给处理器。 + +以 MaiBot-Napcat-Adapter 为例,可选项目如下: +| 类型 | 说明 | 格式 | +| --- | --- | --- | +| text | 文本消息 | str | +| emoji | 表情消息 | str: 表情包的无头base64| +| image | 图片消息 | str: 图片的无头base64 | +| reply | 回复消息 | str: 回复的消息ID | +| voice | 语音消息 | str: wav格式语音的无头base64 | +| command | 命令消息 | 参见Adapter文档 | +| voiceurl | 语音URL消息 | str: wav格式语音的URL | +| music | 音乐消息 | str: 这首歌在网易云音乐的音乐id | +| videourl | 视频URL消息 | str: 视频的URL | +| file | 文件消息 | str: 文件的路径 | + +**请知悉,对于不同的处理器,其支持的消息类型可能会有所不同。在开发时请注意。** + +#### action_parameters: 该Action的参数说明。 +这是一个字典,键为参数名,值为参数说明。这个字段可以帮助LLM理解如何使用这个Action,并由LLM返回对应的参数,最后传递到 Action 的 action_data 属性中。其格式与你定义的格式完全相同 **(除非LLM哈气了,返回了错误的内容)**。 + +--- + +## 🎯 Action 调用的决策机制 Action采用**两层决策机制**来优化性能和决策质量: -### 第一层:激活控制(Activation Control) +> 设计目的:在加载许多插件的时候降低LLM决策压力,避免让麦麦在过多的选项中纠结。 -**激活决定麦麦是否"知道"这个Action的存在**,即这个Action是否进入决策候选池。**不被激活的Action麦麦永远不会选择**。 +**第一层:激活控制(Activation Control)** -> 🎯 **设计目的**:在加载许多插件的时候降低LLM决策压力,避免让麦麦在过多的选项中纠结。 +激活决定麦麦是否 **“知道”** 这个Action的存在,即这个Action是否进入决策候选池。不被激活的Action麦麦永远不会选择。 -#### 激活类型说明 +**第二层:使用决策(Usage Decision)** -| 激活类型 | 说明 | 使用场景 | -| ------------- | ------------------------------------------- | ------------------------ | -| `NEVER` | 从不激活,Action对麦麦不可见 | 临时禁用某个Action | -| `ALWAYS` | 永远激活,Action总是在麦麦的候选池中 | 核心功能,如回复、不回复 | -| `LLM_JUDGE` | 通过LLM智能判断当前情境是否需要激活此Action | 需要智能判断的复杂场景 | -| `RANDOM` | 基于随机概率决定是否激活 | 增加行为随机性的功能 | -| `KEYWORD` | 当检测到特定关键词时激活 | 明确触发条件的功能 | +在Action被激活后,使用条件决定麦麦什么时候会 **“选择”** 使用这个Action。 -#### 聊天模式控制 +### 决策参数详解 🔧 -| 模式 | 说明 | -| ------------------- | ------------------------ | -| `ChatMode.FOCUS` | 仅在专注聊天模式下可激活 | -| `ChatMode.NORMAL` | 仅在普通聊天模式下可激活 | -| `ChatMode.ALL` | 所有模式下都可激活 | +#### 第一层:ActivationType 激活类型说明 -### 第二层:使用决策(Usage Decision) +| 激活类型 | 说明 | 使用场景 | +| 
----------- | ---------------------------------------- | ---------------------- | +| [`NEVER`](#never-激活) | 从不激活,Action对麦麦不可见 | 临时禁用某个Action | +| [`ALWAYS`](#always-激活) | 永远激活,Action总是在麦麦的候选池中 | 核心功能,如回复、不回复 | +| [`LLM_JUDGE`](#llm_judge-激活) | 通过LLM智能判断当前情境是否需要激活此Action | 需要智能判断的复杂场景 | +| `RANDOM` | 基于随机概率决定是否激活 | 增加行为随机性的功能 | +| `KEYWORD` | 当检测到特定关键词时激活 | 明确触发条件的功能 | + +#### `NEVER` 激活 + +`ActionActivationType.NEVER` 会使得 Action 永远不会被激活 + +```python +class DisabledAction(BaseAction): + activation_type = ActionActivationType.NEVER # 永远不激活 + + async def execute(self) -> Tuple[bool, str]: + # 这个Action永远不会被执行 + return False, "这个Action被禁用" +``` + +#### `ALWAYS` 激活 + +`ActionActivationType.ALWAYS` 会使得 Action 永远会被激活,即一直在 Action 候选池中 + +这种激活方式常用于核心功能,如回复或不回复。 + +```python +class AlwaysActivatedAction(BaseAction): + activation_type = ActionActivationType.ALWAYS # 永远激活 + + async def execute(self) -> Tuple[bool, str]: + # 执行核心功能 + return True, "执行了核心功能" +``` + +#### `LLM_JUDGE` 激活 + +`ActionActivationType.LLM_JUDGE`会使得这个 Action 根据 LLM 的判断来决定是否加入候选池。 + +而 LLM 的判断是基于代码中预设的`llm_judge_prompt`和自动提供的聊天上下文进行的。 + +因此使用此种方法需要实现`llm_judge_prompt`属性。 + +```python +class LLMJudgedAction(BaseAction): + activation_type = ActionActivationType.LLM_JUDGE # 通过LLM判断激活 + # LLM判断提示词 + llm_judge_prompt = ( + "判定是否需要使用这个动作的条件:\n" + "1. 
用户希望调用XXX这个动作\n" + "...\n" + "请回答\"是\"或\"否\"。\n" + ) + + async def execute(self) -> Tuple[bool, str]: + # 根据LLM判断是否执行 + return True, "执行了LLM判断功能" +``` + +#### `RANDOM` 激活 + +`ActionActivationType.RANDOM`会使得这个 Action 根据随机概率决定是否加入候选池。 + +概率则由代码中的`random_activation_probability`控制。在内部实现中我们使用了`random.random()`来生成一个0到1之间的随机数,并与这个概率进行比较。 + +因此使用这个方法需要实现`random_activation_probability`属性。 + +```python +class SurpriseAction(BaseAction): + activation_type = ActionActivationType.RANDOM # 基于随机概率激活 + # 随机激活概率 + random_activation_probability = 0.1 # 10%概率激活 + + async def execute(self) -> Tuple[bool, str]: + # 执行惊喜动作 + return True, "发送了惊喜内容" +``` + +#### `KEYWORD` 激活 + +`ActionActivationType.KEYWORD`会使得这个 Action 在检测到特定关键词时激活。 + +关键词由代码中的`activation_keywords`定义,而`keyword_case_sensitive`则控制关键词匹配时是否区分大小写。在内部实现中,我们使用了`in`操作符来检查消息内容是否包含这些关键词。 + +因此,使用此种方法需要实现`activation_keywords`和`keyword_case_sensitive`属性。 + +```python +class GreetingAction(BaseAction): + activation_type = ActionActivationType.KEYWORD # 关键词激活 + activation_keywords = ["你好", "hello", "hi", "嗨"] # 关键词配置 + keyword_case_sensitive = False # 不区分大小写 + + async def execute(self) -> Tuple[bool, str]: + # 执行问候逻辑 + return True, "发送了问候" +``` + +#### 第二层:使用决策 **在Action被激活后,使用条件决定麦麦什么时候会"选择"使用这个Action**。 @@ -49,17 +190,16 @@ Action采用**两层决策机制**来优化性能和决策质量: - `action_parameters`:所需参数,影响Action的可执行性 - 当前聊天上下文和麦麦的决策逻辑 -### 🎬 决策流程示例 +--- -假设有一个"发送表情"Action: +### 决策流程示例 ```python class EmojiAction(BaseAction): # 第一层:激活控制 - focus_activation_type = ActionActivationType.RANDOM # 专注模式下随机激活 - normal_activation_type = ActionActivationType.KEYWORD # 普通模式下关键词激活 - activation_keywords = ["表情", "emoji", "😊"] - + activation_type = ActionActivationType.RANDOM # 随机激活 + random_activation_probability = 0.1 # 10%概率激活 + # 第二层:使用决策 action_require = [ "表达情绪时可以选择使用", @@ -72,311 +212,85 @@ class EmojiAction(BaseAction): 1. 
**第一层激活判断**: - - 普通模式:只有当用户消息包含"表情"、"emoji"或"😊"时,麦麦才"知道"可以使用这个Action - - 专注模式:随机激活,有概率让麦麦"看到"这个Action + - 使用随机数进行决策,当`random.random() < self.random_activation_probability`时,麦麦才"知道"可以使用这个Action 2. **第二层使用决策**: - - 即使Action被激活,麦麦还会根据 `action_require`中的条件判断是否真正选择使用 + - 即使Action被激活,麦麦还会根据 `action_require` 中的条件判断是否真正选择使用 - 例如:如果刚刚已经发过表情,根据"不要连续发送多个表情"的要求,麦麦可能不会选择这个Action -## 📋 Action必须项清单 - -每个Action类都**必须**包含以下属性: - -### 1. 激活控制必须项 +--- +## Action 内置属性说明 ```python -# 专注模式下的激活类型 -focus_activation_type = ActionActivationType.LLM_JUDGE - -# 普通模式下的激活类型 -normal_activation_type = ActionActivationType.KEYWORD - -# 启用的聊天模式 -mode_enable = ChatMode.ALL - -# 是否允许与其他Action并行执行 -parallel_action = False -``` - -### 2. 基本信息必须项 - -```python -# Action的唯一标识名称 -action_name = "my_action" - -# Action的功能描述 -action_description = "描述这个Action的具体功能和用途" -``` - -### 3. 功能定义必须项 - -```python -# Action参数定义 - 告诉LLM执行时需要什么参数 -action_parameters = { - "param1": "参数1的说明", - "param2": "参数2的说明" -} - -# Action使用场景描述 - 帮助LLM判断何时"选择"使用 -action_require = [ - "使用场景描述1", - "使用场景描述2" -] - -# 关联的消息类型 - 说明Action能处理什么类型的内容 -associated_types = ["text", "emoji", "image"] -``` - -### 4. 
执行方法必须项 - -```python -async def execute(self) -> Tuple[bool, str]: - """ - 执行Action的主要逻辑 - - Returns: - Tuple[bool, str]: (是否成功, 执行结果描述) - """ - # 执行动作的代码 - success = True - message = "动作执行成功" - - return success, message -``` - -## 🔧 激活类型详解 - -### KEYWORD激活 - -当检测到特定关键词时激活Action: - -```python -class GreetingAction(BaseAction): - focus_activation_type = ActionActivationType.KEYWORD - normal_activation_type = ActionActivationType.KEYWORD - - # 关键词配置 - activation_keywords = ["你好", "hello", "hi", "嗨"] - keyword_case_sensitive = False # 不区分大小写 - - async def execute(self) -> Tuple[bool, str]: - # 执行问候逻辑 - return True, "发送了问候" -``` - -### LLM_JUDGE激活 - -通过LLM智能判断是否激活: - -```python -class HelpAction(BaseAction): - focus_activation_type = ActionActivationType.LLM_JUDGE - normal_activation_type = ActionActivationType.LLM_JUDGE - - # LLM判断提示词 - llm_judge_prompt = """ - 判定是否需要使用帮助动作的条件: - 1. 用户表达了困惑或需要帮助 - 2. 用户提出了问题但没有得到满意答案 - 3. 对话中出现了技术术语或复杂概念 - - 请回答"是"或"否"。 - """ - - async def execute(self) -> Tuple[bool, str]: - # 执行帮助逻辑 - return True, "提供了帮助" -``` - -### RANDOM激活 - -基于随机概率激活: - -```python -class SurpriseAction(BaseAction): - focus_activation_type = ActionActivationType.RANDOM - normal_activation_type = ActionActivationType.RANDOM - - # 随机激活概率 - random_activation_probability = 0.1 # 10%概率激活 - - async def execute(self) -> Tuple[bool, str]: - # 执行惊喜动作 - return True, "发送了惊喜内容" -``` - -### ALWAYS激活 - -永远激活,常用于核心功能: - -```python -class CoreAction(BaseAction): - focus_activation_type = ActionActivationType.ALWAYS - normal_activation_type = ActionActivationType.ALWAYS - - async def execute(self) -> Tuple[bool, str]: - # 执行核心功能 - return True, "执行了核心功能" -``` - -### NEVER激活 - -从不激活,用于临时禁用: - -```python -class DisabledAction(BaseAction): - focus_activation_type = ActionActivationType.NEVER - normal_activation_type = ActionActivationType.NEVER - - async def execute(self) -> Tuple[bool, str]: - # 这个方法不会被调用 - return False, "已禁用" -``` - -## 📚 BaseAction内置属性和方法 - -### 内置属性 - -```python 
-class MyAction(BaseAction): +class BaseAction: def __init__(self): # 消息相关属性 - self.message # 当前消息对象 - self.chat_stream # 聊天流对象 - self.user_id # 用户ID - self.user_nickname # 用户昵称 - self.platform # 平台类型 (qq, telegram等) - self.chat_id # 聊天ID - self.is_group # 是否群聊 - - # Action相关属性 - self.action_data # Action执行时的数据 - self.thinking_id # 思考ID - self.matched_groups # 匹配到的组(如果有正则匹配) -``` + self.log_prefix: str # 日志前缀 + self.group_id: str # 群组ID + self.group_name: str # 群组名称 + self.user_id: str # 用户ID + self.user_nickname: str # 用户昵称 + self.platform: str # 平台类型 (qq, telegram等) + self.chat_id: str # 聊天ID + self.chat_stream: ChatStream # 聊天流对象 + self.is_group: bool # 是否群聊 -### 内置方法 + # 消息体 + self.action_message: dict # 消息数据 + + # Action相关属性 + self.action_data: dict # Action执行时的数据 + self.thinking_id: str # 思考ID +``` +action_message为一个字典,包含的键值对如下(省略了不必要的键值对) ```python -class MyAction(BaseAction): +{ + "message_id": "1234567890", # 消息id,str + "time": 1627545600.0, # 时间戳,float + "chat_id": "abcdef123456", # 聊天ID,str + "reply_to": None, # 回复消息id,str或None + "interest_value": 0.85, # 兴趣值,float + "is_mentioned": True, # 是否被提及,bool + "chat_info_last_active_time": 1627548600.0, # 最后活跃时间,float + "processed_plain_text": None, # 处理后的文本,str或None + "additional_config": None, # Adapter传来的additional_config,dict或None + "is_emoji": False, # 是否为表情,bool + "is_picid": False, # 是否为图片ID,bool + "is_command": False # 是否为命令,bool +} +``` + +部分值的格式请自行查询数据库。 + +--- + +## Action 内置方法说明 +```python +class BaseAction: # 配置相关 def get_config(self, key: str, default=None): - """获取配置值""" - pass + """获取插件配置值,使用嵌套键访问""" - # 消息发送相关 - async def send_text(self, text: str): + async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]: + """等待新消息或超时""" + + async def send_text(self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False) -> bool: """发送文本消息""" - pass - - async def send_emoji(self, emoji_base64: str): + + async def send_emoji(self, emoji_base64: str) -> 
bool: """发送表情包""" - pass - - async def send_image(self, image_base64: str): + + async def send_image(self, image_base64: str) -> bool: """发送图片""" - pass - - # 动作记录相关 - async def store_action_info(self, **kwargs): - """记录动作信息""" - pass + + async def send_custom(self, message_type: str, content: str, typing: bool = False, reply_to: str = "") -> bool: + """发送自定义类型消息""" + + async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None: + """存储动作信息到数据库""" + + async def send_command(self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True) -> bool: + """发送命令消息""" ``` - -## 🎯 完整Action示例 - -```python -from src.plugin_system import BaseAction, ActionActivationType, ChatMode -from typing import Tuple - -class ExampleAction(BaseAction): - """示例Action - 展示完整的Action结构""" - - # === 激活控制 === - focus_activation_type = ActionActivationType.LLM_JUDGE - normal_activation_type = ActionActivationType.KEYWORD - mode_enable = ChatMode.ALL - parallel_action = False - - # 关键词激活配置 - activation_keywords = ["示例", "测试", "example"] - keyword_case_sensitive = False - - # LLM判断提示词 - llm_judge_prompt = "当用户需要示例或测试功能时激活" - - # 随机激活概率(如果使用RANDOM类型) - random_activation_probability = 0.2 - - # === 基本信息 === - action_name = "example_action" - action_description = "这是一个示例Action,用于演示Action的完整结构" - - # === 功能定义 === - action_parameters = { - "content": "要处理的内容", - "type": "处理类型", - "options": "可选配置" - } - - action_require = [ - "用户需要示例功能时使用", - "适合用于测试和演示", - "不要在正式对话中频繁使用" - ] - - associated_types = ["text", "emoji"] - - async def execute(self) -> Tuple[bool, str]: - """执行示例Action""" - try: - # 获取Action参数 - content = self.action_data.get("content", "默认内容") - action_type = self.action_data.get("type", "default") - - # 获取配置 - enable_feature = self.get_config("example.enable_advanced", False) - max_length = self.get_config("example.max_length", 100) - - # 执行具体逻辑 - if action_type 
== "greeting": - await self.send_text(f"你好!这是示例内容:{content}") - elif action_type == "info": - await self.send_text(f"信息:{content[:max_length]}") - else: - await self.send_text("执行了示例Action") - - # 记录动作信息 - await self.store_action_info( - action_build_into_prompt=True, - action_prompt_display=f"执行了示例动作:{action_type}", - action_done=True - ) - - return True, f"示例Action执行成功,类型:{action_type}" - - except Exception as e: - return False, f"执行失败:{str(e)}" -``` - -## 🎯 最佳实践 - -### 1. Action设计原则 - -- **单一职责**:每个Action只负责一个明确的功能 -- **智能激活**:合理选择激活类型,避免过度激活 -- **清晰描述**:提供准确的`action_require`帮助LLM决策 -- **错误处理**:妥善处理执行过程中的异常情况 - -### 2. 性能优化 - -- **激活控制**:使用合适的激活类型减少不必要的LLM调用 -- **并行执行**:谨慎设置`parallel_action`,避免冲突 -- **资源管理**:及时释放占用的资源 - -### 3. 调试技巧 - -- **日志记录**:在关键位置添加日志 -- **参数验证**:检查`action_data`的有效性 -- **配置测试**:测试不同配置下的行为 +具体参数与用法参见`BaseAction`基类的定义。 \ No newline at end of file diff --git a/docs/plugins/api/emoji-api.md b/docs/plugins/api/emoji-api.md index 3346db9f9..6dd071b9a 100644 --- a/docs/plugins/api/emoji-api.md +++ b/docs/plugins/api/emoji-api.md @@ -8,6 +8,25 @@ from src.plugin_system.apis import emoji_api ``` +## 🆕 **二步走识别优化** + +从最新版本开始,表情包识别系统采用了**二步走识别 + 智能缓存**的优化方案: + +### **收到表情包时的识别流程** +1. **第一步**:VLM视觉分析 - 生成详细描述 +2. **第二步**:LLM情感分析 - 基于详细描述提取核心情感标签 +3. **缓存机制**:将情感标签缓存到数据库,详细描述保存到Images表 + +### **注册表情包时的优化** +- **智能复用**:优先从Images表获取已有的详细描述 +- **避免重复**:如果表情包之前被收到过,跳过VLM调用 +- **性能提升**:减少不必要的AI调用,降低延时和成本 + +### **缓存策略** +- **ImageDescriptions表**:缓存最终的情感标签(用于快速显示) +- **Images表**:保存详细描述(用于注册时复用) +- **双重检查**:防止并发情况下的重复生成 + ## 主要功能 ### 1. 
表情包获取 diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index 14a9d16c5..11ff22bd8 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -77,9 +77,8 @@ class TimeCommand(BaseCommand): command_pattern = r"^/time$" # 精确匹配 "/time" 命令 command_help = "查询当前时间" command_examples = ["/time"] - intercept_message = True # 拦截消息,不让其他组件处理 - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> Tuple[bool, str, bool]: """执行时间查询""" import datetime @@ -92,7 +91,7 @@ class TimeCommand(BaseCommand): message = f"⏰ 当前时间:{time_str}" await self.send_text(message) - return True, f"显示了当前时间: {time_str}" + return True, f"显示了当前时间: {time_str}", True class PrintMessage(BaseEventHandler): @@ -118,17 +117,17 @@ class HelloWorldPlugin(BasePlugin): """Hello World插件 - 你的第一个MaiCore插件""" # 插件基本信息 - plugin_name = "hello_world_plugin" # 内部标识符 - enable_plugin = True - dependencies = [] # 插件依赖列表 - python_dependencies = [] # Python包依赖列表 - config_file_name = "config.toml" # 配置文件名 + plugin_name: str = "hello_world_plugin" # 内部标识符 + enable_plugin: bool = True + dependencies: List[str] = [] # 插件依赖列表 + python_dependencies: List[str] = [] # Python包依赖列表 + config_file_name: str = "config.toml" # 配置文件名 # 配置节描述 config_section_descriptions = {"plugin": "插件基本信息", "greeting": "问候功能配置", "time": "时间查询配置"} # 配置Schema定义 - config_schema = { + config_schema: dict = { "plugin": { "name": ConfigField(type=str, default="hello_world_plugin", description="插件名称"), "version": ConfigField(type=str, default="1.0.0", description="插件版本"), diff --git a/plugins/take_picture_plugin/_manifest.json b/plugins/take_picture_plugin/_manifest.json deleted file mode 100644 index 0488d1de1..000000000 --- a/plugins/take_picture_plugin/_manifest.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "manifest_version": 1, - "name": "AI拍照插件 (Take Picture Plugin)", - "version": "1.0.0", - "description": "基于AI图像生成的拍照插件,可以生成逼真的自拍照片,支持照片存储和展示功能。", - "author": { - 
"name": "SengokuCola", - "url": "https://github.com/SengokuCola" - }, - "license": "GPL-v3.0-or-later", - - "host_application": { - "min_version": "0.9.0" - }, - "homepage_url": "https://github.com/MaiM-with-u/maibot", - "repository_url": "https://github.com/MaiM-with-u/maibot", - "keywords": ["camera", "photo", "selfie", "ai", "image", "generation"], - "categories": ["AI Tools", "Image Processing", "Entertainment"], - - "default_locale": "zh-CN", - "locales_path": "_locales", - - "plugin_info": { - "is_built_in": false, - "plugin_type": "image_generator", - "api_dependencies": ["volcengine"], - "components": [ - { - "type": "action", - "name": "take_picture", - "description": "生成一张用手机拍摄的照片,比如自拍或者近照", - "activation_modes": ["keyword"], - "keywords": ["拍张照", "自拍", "发张照片", "看看你", "你的照片"] - }, - { - "type": "command", - "name": "show_recent_pictures", - "description": "展示最近生成的5张照片", - "pattern": "/show_pics" - } - ], - "features": [ - "AI驱动的自拍照生成", - "个性化照片风格", - "照片历史记录", - "缓存机制优化", - "火山引擎API集成" - ] - } -} \ No newline at end of file diff --git a/plugins/take_picture_plugin/plugin(deprecated).py b/plugins/take_picture_plugin/plugin(deprecated).py deleted file mode 100644 index 24e86fece..000000000 --- a/plugins/take_picture_plugin/plugin(deprecated).py +++ /dev/null @@ -1,517 +0,0 @@ -""" -拍照插件 - -功能特性: -- Action: 生成一张自拍照,prompt由人设和模板生成 -- Command: 展示最近生成的照片 - -#此插件并不完善 -#此插件并不完善 - -#此插件并不完善 - -#此插件并不完善 - -#此插件并不完善 - -#此插件并不完善 - -#此插件并不完善 - - - -包含组件: -- 拍照Action - 生成自拍照 -- 展示照片Command - 展示最近生成的照片 -""" - -from typing import List, Tuple, Type, Optional -import random -import datetime -import json -import os -import asyncio -import urllib.request -import urllib.error -import base64 -import traceback - -from src.plugin_system.base.base_plugin import BasePlugin -from src.plugin_system.base.base_action import BaseAction -from src.plugin_system.base.base_command import BaseCommand -from src.plugin_system.base.component_types import ComponentInfo, ActionActivationType, 
ChatMode -from src.plugin_system.base.config_types import ConfigField -from src.plugin_system import register_plugin -from src.common.logger import get_logger - -logger = get_logger("take_picture_plugin") - -# 定义数据目录常量 -DATA_DIR = os.path.join("data", "take_picture_data") -# 确保数据目录存在 -os.makedirs(DATA_DIR, exist_ok=True) -# 创建全局锁 -file_lock = asyncio.Lock() - - -class TakePictureAction(BaseAction): - """生成一张自拍照""" - - focus_activation_type = ActionActivationType.KEYWORD - normal_activation_type = ActionActivationType.KEYWORD - mode_enable = ChatMode.ALL - parallel_action = False - - action_name = "take_picture" - action_description = "生成一张用手机拍摄,比如自拍或者近照" - activation_keywords = ["拍张照", "自拍", "发张照片", "看看你", "你的照片"] - keyword_case_sensitive = False - - action_parameters = {} - - action_require = ["当用户想看你的照片时使用", "当用户让你发自拍时使用当想随手拍眼前的场景时使用"] - - associated_types = ["text", "image"] - - # 内置的Prompt模板,如果配置文件中没有定义,将使用这些模板 - DEFAULT_PROMPT_TEMPLATES = [ - "极其频繁无奇的iPhone自拍照,没有明确的主体或构图感,就是随手一拍的快照照片略带运动模糊,阳光或室内打光不均匀导致的轻微曝光过度,整体呈现出一种刻意的平庸感,就像是从口袋里拿手机时不小心拍到的一张自拍。主角是{name},{personality}" - ] - - # 简单的请求缓存,避免短时间内重复请求 - _request_cache = {} - - async def execute(self) -> Tuple[bool, Optional[str]]: - logger.info(f"{self.log_prefix} 执行拍照动作") - - try: - # 配置验证 - http_base_url = self.api.get_config("api.base_url") - http_api_key = self.api.get_config("api.volcano_generate_api_key") - - if not (http_base_url and http_api_key): - error_msg = "抱歉,照片生成功能所需的API配置(如API地址或密钥)不完整,无法提供服务。" - await self.send_text(error_msg) - logger.error(f"{self.log_prefix} HTTP调用配置缺失: base_url 或 volcano_generate_api_key.") - return False, "API配置不完整" - - # API密钥验证 - if http_api_key == "YOUR_DOUBAO_API_KEY_HERE": - error_msg = "照片生成功能尚未配置,请设置正确的API密钥。" - await self.send_text(error_msg) - logger.error(f"{self.log_prefix} API密钥未配置") - return False, "API密钥未配置" - - # 获取全局配置信息 - bot_nickname = self.api.get_global_config("bot.nickname", "麦麦") - bot_personality = 
self.api.get_global_config("personality.personality_core", "") - - personality_side = self.api.get_global_config("personality.personality_side", []) - if personality_side: - bot_personality += random.choice(personality_side) - - # 准备模板变量 - template_vars = {"name": bot_nickname, "personality": bot_personality} - - logger.info(f"{self.log_prefix} 使用的全局配置: name={bot_nickname}, personality={bot_personality}") - - # 尝试从配置文件获取模板,如果没有则使用默认模板 - templates = self.api.get_config("picture.prompt_templates", self.DEFAULT_PROMPT_TEMPLATES) - if not templates: - logger.warning(f"{self.log_prefix} 未找到有效的提示词模板,使用默认模板") - templates = self.DEFAULT_PROMPT_TEMPLATES - - prompt_template = random.choice(templates) - - # 填充模板 - final_prompt = prompt_template.format(**template_vars) - - logger.info(f"{self.log_prefix} 生成的最终Prompt: {final_prompt}") - - # 从配置获取参数 - model = self.api.get_config("picture.default_model", "doubao-seedream-3-0-t2i-250415") - size = self.api.get_config("picture.default_size", "1024x1024") - watermark = self.api.get_config("picture.default_watermark", True) - guidance_scale = self.api.get_config("picture.default_guidance_scale", 2.5) - seed = self.api.get_config("picture.default_seed", 42) - - # 检查缓存 - enable_cache = self.api.get_config("storage.enable_cache", True) - if enable_cache: - cache_key = self._get_cache_key(final_prompt, model, size) - if cache_key in self._request_cache: - cached_result = self._request_cache[cache_key] - logger.info(f"{self.log_prefix} 使用缓存的图片结果") - await self.send_text("我之前拍过类似的照片,用之前的结果~") - - # 直接发送缓存的结果 - send_success = await self._send_image(cached_result) - if send_success: - await self.send_text("这是我的照片,好看吗?") - return True, "照片已发送(缓存)" - else: - # 缓存失败,清除这个缓存项并继续正常流程 - del self._request_cache[cache_key] - - await self.send_text("正在为你拍照,请稍候...") - - try: - seed = random.randint(1, 1000000) - success, result = await asyncio.to_thread( - self._make_http_image_request, - prompt=final_prompt, - model=model, - size=size, - seed=seed, - 
guidance_scale=guidance_scale, - watermark=watermark, - ) - except Exception as e: - logger.error(f"{self.log_prefix} (HTTP) 异步请求执行失败: {e!r}", exc_info=True) - traceback.print_exc() - success = False - result = f"照片生成服务遇到意外问题: {str(e)[:100]}" - - if success: - image_url = result - logger.info(f"{self.log_prefix} 图片URL获取成功: {image_url[:70]}... 下载并编码.") - - try: - encode_success, encode_result = await asyncio.to_thread(self._download_and_encode_base64, image_url) - except Exception as e: - logger.error(f"{self.log_prefix} (B64) 异步下载/编码失败: {e!r}", exc_info=True) - traceback.print_exc() - encode_success = False - encode_result = f"图片下载或编码时发生内部错误: {str(e)[:100]}" - - if encode_success: - base64_image_string = encode_result - # 更新缓存 - if enable_cache: - self._update_cache(final_prompt, model, size, base64_image_string) - - # 发送图片 - send_success = await self._send_image(base64_image_string) - if send_success: - # 存储到文件 - await self._store_picture_info(final_prompt, image_url) - logger.info(f"{self.log_prefix} 成功生成并存储照片: {image_url}") - await self.send_text("当当当当~这是我刚拍的照片,好看吗?") - return True, f"成功生成照片: {image_url}" - else: - await self.send_text("照片生成了,但发送失败了,可能是格式问题...") - return False, "照片发送失败" - else: - await self.send_text(f"照片下载失败: {encode_result}") - return False, encode_result - else: - await self.send_text(f"哎呀,拍照失败了: {result}") - return False, result - - except Exception as e: - logger.error(f"{self.log_prefix} 执行拍照动作失败: {e}", exc_info=True) - traceback.print_exc() - await self.send_text("呜呜,拍照的时候出了一点小问题...") - return False, str(e) - - async def _store_picture_info(self, prompt: str, image_url: str): - """将照片信息存入日志文件""" - log_file = self.api.get_config("storage.log_file", "picture_log.json") - log_path = os.path.join(DATA_DIR, log_file) - max_photos = self.api.get_config("storage.max_photos", 50) - - async with file_lock: - try: - if os.path.exists(log_path): - with open(log_path, "r", encoding="utf-8") as f: - log_data = json.load(f) - else: - log_data = [] - 
except (json.JSONDecodeError, FileNotFoundError): - log_data = [] - - # 添加新照片 - log_data.append( - {"prompt": prompt, "image_url": image_url, "timestamp": datetime.datetime.now().isoformat()} - ) - - # 如果超过最大数量,删除最旧的 - if len(log_data) > max_photos: - log_data = sorted(log_data, key=lambda x: x.get("timestamp", ""), reverse=True)[:max_photos] - - try: - with open(log_path, "w", encoding="utf-8") as f: - json.dump(log_data, f, ensure_ascii=False, indent=4) - except Exception as e: - logger.error(f"{self.log_prefix} 写入照片日志文件失败: {e}", exc_info=True) - - def _make_http_image_request( - self, prompt: str, model: str, size: str, seed: int, guidance_scale: float, watermark: bool - ) -> Tuple[bool, str]: - """发送HTTP请求到火山引擎豆包API生成图片""" - try: - base_url = self.api.get_config("api.base_url") - api_key = self.api.get_config("api.volcano_generate_api_key") - - # 构建请求URL和头部 - endpoint = f"{base_url.rstrip('/')}/images/generations" - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - } - - # 构建请求体 - request_body = { - "model": model, - "prompt": prompt, - "response_format": "url", - "size": size, - "seed": seed, - "guidance_scale": guidance_scale, - "watermark": watermark, - "api-key": api_key, - } - - # 创建请求对象 - req = urllib.request.Request( - endpoint, - data=json.dumps(request_body).encode("utf-8"), - headers=headers, - method="POST", - ) - - # 发送请求并获取响应 - with urllib.request.urlopen(req, timeout=60) as response: - response_data = json.loads(response.read().decode("utf-8")) - - # 解析响应 - image_url = None - if ( - isinstance(response_data.get("data"), list) - and response_data["data"] - and isinstance(response_data["data"][0], dict) - ): - image_url = response_data["data"][0].get("url") - elif response_data.get("url"): - image_url = response_data.get("url") - - if image_url: - return True, image_url - else: - error_msg = response_data.get("error", {}).get("message", "未知错误") - logger.error(f"API返回错误: {error_msg}") - return False, f"API错误: 
{error_msg}" - - except urllib.error.HTTPError as e: - error_body = e.read().decode("utf-8") - logger.error(f"HTTP错误 {e.code}: {error_body}") - return False, f"HTTP错误 {e.code}: {error_body[:100]}..." - except Exception as e: - logger.error(f"请求异常: {e}", exc_info=True) - return False, f"请求异常: {str(e)}" - - def _download_and_encode_base64(self, image_url: str) -> Tuple[bool, str]: - """下载图片并转换为Base64编码""" - try: - with urllib.request.urlopen(image_url) as response: - image_data = response.read() - - base64_encoded = base64.b64encode(image_data).decode("utf-8") - return True, base64_encoded - except Exception as e: - logger.error(f"图片下载编码失败: {e}", exc_info=True) - return False, str(e) - - async def _send_image(self, base64_image: str) -> bool: - """发送图片""" - try: - # 使用聊天流信息确定发送目标 - chat_stream = self.api.get_service("chat_stream") - if not chat_stream: - logger.error(f"{self.log_prefix} 没有可用的聊天流发送图片") - return False - - if chat_stream.group_info: - # 群聊 - return await self.api.send_message_to_target( - message_type="image", - content=base64_image, - platform=chat_stream.platform, - target_id=str(chat_stream.group_info.group_id), - is_group=True, - display_message="发送生成的照片", - ) - else: - # 私聊 - return await self.api.send_message_to_target( - message_type="image", - content=base64_image, - platform=chat_stream.platform, - target_id=str(chat_stream.user_info.user_id), - is_group=False, - display_message="发送生成的照片", - ) - except Exception as e: - logger.error(f"{self.log_prefix} 发送图片时出错: {e}") - return False - - @classmethod - def _get_cache_key(cls, description: str, model: str, size: str) -> str: - """生成缓存键""" - return f"{description}|{model}|{size}" - - def _update_cache(self, description: str, model: str, size: str, base64_image: str): - """更新缓存""" - max_cache_size = self.api.get_config("storage.max_cache_size", 10) - cache_key = self._get_cache_key(description, model, size) - - # 添加到缓存 - self._request_cache[cache_key] = base64_image - - # 如果缓存超过最大大小,删除最旧的项 - if 
len(self._request_cache) > max_cache_size: - oldest_key = next(iter(self._request_cache)) - del self._request_cache[oldest_key] - - -class ShowRecentPicturesCommand(BaseCommand): - """展示最近生成的照片""" - - command_name = "show_recent_pictures" - command_description = "展示最近生成的5张照片" - command_pattern = r"^/show_pics$" - command_help = "用法: /show_pics" - command_examples = ["/show_pics"] - intercept_message = True - - async def execute(self) -> Tuple[bool, Optional[str]]: - logger.info(f"{self.log_prefix} 执行展示最近照片命令") - log_file = self.api.get_config("storage.log_file", "picture_log.json") - log_path = os.path.join(DATA_DIR, log_file) - - async with file_lock: - try: - if not os.path.exists(log_path): - await self.send_text("最近还没有拍过照片哦,快让我自拍一张吧!") - return True, "没有照片日志文件" - - with open(log_path, "r", encoding="utf-8") as f: - log_data = json.load(f) - - if not log_data: - await self.send_text("最近还没有拍过照片哦,快让我自拍一张吧!") - return True, "没有照片" - - # 获取最新的5张照片 - recent_pics = sorted(log_data, key=lambda x: x["timestamp"], reverse=True)[:5] - - # 先发送文本消息 - await self.send_text("这是我最近拍的几张照片~") - - # 逐个发送图片 - for pic in recent_pics: - # 尝试获取图片URL - image_url = pic.get("image_url") - if image_url: - try: - # 下载图片并转换为Base64 - with urllib.request.urlopen(image_url) as response: - image_data = response.read() - base64_encoded = base64.b64encode(image_data).decode("utf-8") - - # 发送图片 - await self.send_type( - message_type="image", content=base64_encoded, display_message="发送最近的照片" - ) - except Exception as e: - logger.error(f"{self.log_prefix} 下载或发送照片失败: {e}", exc_info=True) - - return True, "成功展示最近的照片" - - except json.JSONDecodeError: - await self.send_text("照片记录文件好像损坏了...") - return False, "JSON解码错误" - except Exception as e: - logger.error(f"{self.log_prefix} 展示照片失败: {e}", exc_info=True) - await self.send_text("哎呀,查找照片的时候出错了。") - return False, str(e) - - -@register_plugin -class TakePicturePlugin(BasePlugin): - """拍照插件""" - - plugin_name = "take_picture_plugin" # 内部标识符 - enable_plugin 
= False - dependencies = [] # 插件依赖列表 - python_dependencies = [] # Python包依赖列表 - config_file_name = "config.toml" - - # 配置节描述 - config_section_descriptions = { - "plugin": "插件基本信息配置", - "api": "API相关配置,包含火山引擎API的访问信息", - "components": "组件启用控制", - "picture": "拍照功能核心配置", - "storage": "照片存储相关配置", - } - - # 配置Schema定义 - config_schema = { - "plugin": { - "enabled": ConfigField(type=bool, default=False, description="是否启用插件"), - }, - "api": { - "base_url": ConfigField( - type=str, - default="https://ark.cn-beijing.volces.com/api/v3", - description="API基础URL", - example="https://api.example.com/v1", - ), - "volcano_generate_api_key": ConfigField( - type=str, default="YOUR_DOUBAO_API_KEY_HERE", description="火山引擎豆包API密钥", required=True - ), - }, - "components": { - "enable_take_picture_action": ConfigField(type=bool, default=True, description="是否启用拍照Action"), - "enable_show_pics_command": ConfigField(type=bool, default=True, description="是否启用展示照片Command"), - }, - "picture": { - "default_model": ConfigField( - type=str, - default="doubao-seedream-3-0-t2i-250415", - description="默认使用的文生图模型", - choices=["doubao-seedream-3-0-t2i-250415", "doubao-seedream-2-0-t2i"], - ), - "default_size": ConfigField( - type=str, - default="1024x1024", - description="默认图片尺寸", - example="1024x1024", - choices=["1024x1024", "1024x1280", "1280x1024", "1024x1536", "1536x1024"], - ), - "default_watermark": ConfigField(type=bool, default=True, description="是否默认添加水印"), - "default_guidance_scale": ConfigField( - type=float, default=2.5, description="模型指导强度,影响图片与提示的关联性", example="2.0" - ), - "default_seed": ConfigField(type=int, default=42, description="随机种子,用于复现图片"), - "prompt_templates": ConfigField( - type=list, default=TakePictureAction.DEFAULT_PROMPT_TEMPLATES, description="用于生成自拍照的prompt模板" - ), - }, - "storage": { - "max_photos": ConfigField(type=int, default=50, description="最大保存的照片数量"), - "log_file": ConfigField(type=str, default="picture_log.json", description="照片日志文件名"), - "enable_cache": 
ConfigField(type=bool, default=True, description="是否启用请求缓存"), - "max_cache_size": ConfigField(type=int, default=10, description="最大缓存数量"), - }, - } - - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: - """返回插件包含的组件列表""" - components = [] - if self.get_config("components.enable_take_picture_action", True): - components.append((TakePictureAction.get_action_info(), TakePictureAction)) - if self.get_config("components.enable_show_pics_command", True): - components.append((ShowRecentPicturesCommand.get_command_info(), ShowRecentPicturesCommand)) - return components diff --git a/scripts/analyze_expression_similarity.py b/scripts/analyze_expression_similarity.py deleted file mode 100644 index d84d21db1..000000000 --- a/scripts/analyze_expression_similarity.py +++ /dev/null @@ -1,192 +0,0 @@ -import os -import json -from typing import List, Dict, Tuple -import numpy as np -from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.metrics.pairwise import cosine_similarity -import glob -import sqlite3 -import re -from datetime import datetime - - -def clean_group_name(name: str) -> str: - """清理群组名称,只保留中文和英文字符""" - cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name) - if not cleaned: - cleaned = datetime.now().strftime("%Y%m%d") - return cleaned - - -def get_group_name(stream_id: str) -> str: - """从数据库中获取群组名称""" - conn = sqlite3.connect("data/maibot.db") - cursor = conn.cursor() - - cursor.execute( - """ - SELECT group_name, user_nickname, platform - FROM chat_streams - WHERE stream_id = ? 
- """, - (stream_id,), - ) - - result = cursor.fetchone() - conn.close() - - if result: - group_name, user_nickname, platform = result - if group_name: - return clean_group_name(group_name) - if user_nickname: - return clean_group_name(user_nickname) - if platform: - return clean_group_name(f"{platform}{stream_id[:8]}") - return stream_id - - -def format_timestamp(timestamp: float) -> str: - """将时间戳转换为可读的时间格式""" - if not timestamp: - return "未知" - try: - dt = datetime.fromtimestamp(timestamp) - return dt.strftime("%Y-%m-%d %H:%M:%S") - except Exception as e: - print(f"时间戳格式化错误: {e}") - return "未知" - - -def load_expressions(chat_id: str) -> List[Dict]: - """加载指定群聊的表达方式""" - style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json") - - style_exprs = [] - - if os.path.exists(style_file): - with open(style_file, "r", encoding="utf-8") as f: - style_exprs = json.load(f) - - return style_exprs - - -def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[str, List[Tuple[str, float]]]: - """找出每个表达方式最相似的top_k个表达方式""" - if not expressions: - return {} - - # 分别准备情景和表达方式的文本数据 - situations = [expr["situation"] for expr in expressions] - styles = [expr["style"] for expr in expressions] - - # 使用TF-IDF向量化 - vectorizer = TfidfVectorizer() - situation_matrix = vectorizer.fit_transform(situations) - style_matrix = vectorizer.fit_transform(styles) - - # 计算余弦相似度 - situation_similarity = cosine_similarity(situation_matrix) - style_similarity = cosine_similarity(style_matrix) - - # 对每个表达方式找出最相似的top_k个 - similar_expressions = {} - for i, _ in enumerate(expressions): - # 获取相似度分数 - situation_scores = situation_similarity[i] - style_scores = style_similarity[i] - - # 获取top_k的索引(排除自己) - situation_indices = np.argsort(situation_scores)[::-1][1 : top_k + 1] - style_indices = np.argsort(style_scores)[::-1][1 : top_k + 1] - - similar_situations = [] - similar_styles = [] - - # 处理相似情景 - for idx in situation_indices: - if 
situation_scores[idx] > 0: # 只保留有相似度的 - similar_situations.append( - ( - expressions[idx]["situation"], - expressions[idx]["style"], # 添加对应的原始表达 - situation_scores[idx], - ) - ) - - # 处理相似表达 - for idx in style_indices: - if style_scores[idx] > 0: # 只保留有相似度的 - similar_styles.append( - ( - expressions[idx]["style"], - expressions[idx]["situation"], # 添加对应的原始情景 - style_scores[idx], - ) - ) - - if similar_situations or similar_styles: - similar_expressions[i] = {"situations": similar_situations, "styles": similar_styles} - - return similar_expressions - - -def main(): - # 获取所有群聊ID - style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*")) - chat_ids = [os.path.basename(d) for d in style_dirs] - - if not chat_ids: - print("没有找到任何群聊的表达方式数据") - return - - print("可用的群聊:") - for i, chat_id in enumerate(chat_ids, 1): - group_name = get_group_name(chat_id) - print(f"{i}. {group_name}") - - while True: - try: - choice = int(input("\n请选择要分析的群聊编号 (输入0退出): ")) - if choice == 0: - break - if 1 <= choice <= len(chat_ids): - chat_id = chat_ids[choice - 1] - break - print("无效的选择,请重试") - except ValueError: - print("请输入有效的数字") - - if choice == 0: - return - - # 加载表达方式 - style_exprs = load_expressions(chat_id) - - group_name = get_group_name(chat_id) - print(f"\n分析群聊 {group_name} 的表达方式:") - - similar_styles = find_similar_expressions(style_exprs) - for i, expr in enumerate(style_exprs): - if i in similar_styles: - print("\n" + "-" * 20) - print(f"表达方式:{expr['style']} <---> 情景:{expr['situation']}") - - if similar_styles[i]["styles"]: - print("\n\033[33m相似表达:\033[0m") - for similar_style, original_situation, score in similar_styles[i]["styles"]: - print(f"\033[33m{similar_style},score:{score:.3f},对应情景:{original_situation}\033[0m") - - if similar_styles[i]["situations"]: - print("\n\033[32m相似情景:\033[0m") - for similar_situation, original_style, score in similar_styles[i]["situations"]: - 
print(f"\033[32m{similar_situation},score:{score:.3f},对应表达:{original_style}\033[0m") - - print( - f"\n激活值:{expr.get('count', 1):.3f},上次激活时间:{format_timestamp(expr.get('last_active_time'))}" - ) - print("-" * 20) - - -if __name__ == "__main__": - main() diff --git a/scripts/analyze_expressions.py b/scripts/analyze_expressions.py deleted file mode 100644 index ecbb3f381..000000000 --- a/scripts/analyze_expressions.py +++ /dev/null @@ -1,215 +0,0 @@ -import os -import json -import time -import re -from datetime import datetime -from typing import Dict, List, Any -import sqlite3 - - -def clean_group_name(name: str) -> str: - """清理群组名称,只保留中文和英文字符""" - # 提取中文和英文字符 - cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name) - # 如果清理后为空,使用当前日期 - if not cleaned: - cleaned = datetime.now().strftime("%Y%m%d") - return cleaned - - -def get_group_name(stream_id: str) -> str: - """从数据库中获取群组名称""" - conn = sqlite3.connect("data/maibot.db") - cursor = conn.cursor() - - cursor.execute( - """ - SELECT group_name, user_nickname, platform - FROM chat_streams - WHERE stream_id = ? 
- """, - (stream_id,), - ) - - result = cursor.fetchone() - conn.close() - - if result: - group_name, user_nickname, platform = result - if group_name: - return clean_group_name(group_name) - if user_nickname: - return clean_group_name(user_nickname) - if platform: - return clean_group_name(f"{platform}{stream_id[:8]}") - return stream_id - - -def load_expressions(chat_id: str) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]: - """加载指定群组的表达方式""" - learnt_style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json") - learnt_grammar_file = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json") - personality_file = os.path.join("data", "expression", "personality", "expressions.json") - - style_expressions = [] - grammar_expressions = [] - personality_expressions = [] - - if os.path.exists(learnt_style_file): - with open(learnt_style_file, "r", encoding="utf-8") as f: - style_expressions = json.load(f) - - if os.path.exists(learnt_grammar_file): - with open(learnt_grammar_file, "r", encoding="utf-8") as f: - grammar_expressions = json.load(f) - - if os.path.exists(personality_file): - with open(personality_file, "r", encoding="utf-8") as f: - personality_expressions = json.load(f) - - return style_expressions, grammar_expressions, personality_expressions - - -def format_time(timestamp: float) -> str: - """格式化时间戳为可读字符串""" - return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") - - -def write_expressions(f, expressions: List[Dict[str, Any]], title: str): - """写入表达方式列表""" - if not expressions: - f.write(f"{title}:暂无数据\n") - f.write("-" * 40 + "\n") - return - - f.write(f"{title}:\n") - for expr in expressions: - count = expr.get("count", 0) - last_active = expr.get("last_active_time", time.time()) - f.write(f"场景: {expr['situation']}\n") - f.write(f"表达: {expr['style']}\n") - f.write(f"计数: {count:.4f}\n") - f.write(f"最后活跃: {format_time(last_active)}\n") - 
f.write("-" * 40 + "\n") - - -def write_group_report( - group_file: str, - group_name: str, - chat_id: str, - style_exprs: List[Dict[str, Any]], - grammar_exprs: List[Dict[str, Any]], -): - """写入群组详细报告""" - with open(group_file, "w", encoding="utf-8") as gf: - gf.write(f"群组: {group_name} (ID: {chat_id})\n") - gf.write("=" * 80 + "\n\n") - - # 写入语言风格 - gf.write("【语言风格】\n") - gf.write("=" * 40 + "\n") - write_expressions(gf, style_exprs, "语言风格") - gf.write("\n") - - # 写入句法特点 - gf.write("【句法特点】\n") - gf.write("=" * 40 + "\n") - write_expressions(gf, grammar_exprs, "句法特点") - - -def analyze_expressions(): - """分析所有群组的表达方式""" - # 获取所有群组ID - style_dir = os.path.join("data", "expression", "learnt_style") - chat_ids = [d for d in os.listdir(style_dir) if os.path.isdir(os.path.join(style_dir, d))] - - # 创建输出目录 - output_dir = "data/expression_analysis" - personality_dir = os.path.join(output_dir, "personality") - os.makedirs(output_dir, exist_ok=True) - os.makedirs(personality_dir, exist_ok=True) - - # 生成时间戳 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - - # 创建总报告 - summary_file = os.path.join(output_dir, f"summary_{timestamp}.txt") - with open(summary_file, "w", encoding="utf-8") as f: - f.write(f"表达方式分析报告 - 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") - f.write("=" * 80 + "\n\n") - - # 先处理人格表达 - personality_exprs = [] - personality_file = os.path.join("data", "expression", "personality", "expressions.json") - if os.path.exists(personality_file): - with open(personality_file, "r", encoding="utf-8") as pf: - personality_exprs = json.load(pf) - - # 保存人格表达总数 - total_personality = len(personality_exprs) - - # 排序并取前20条 - personality_exprs.sort(key=lambda x: x.get("count", 0), reverse=True) - personality_exprs = personality_exprs[:20] - - # 写入人格表达报告 - personality_report = os.path.join(personality_dir, f"expressions_{timestamp}.txt") - with open(personality_report, "w", encoding="utf-8") as pf: - pf.write("【人格表达方式】\n") - pf.write("=" * 40 + "\n") - 
write_expressions(pf, personality_exprs, "人格表达") - - # 写入总报告摘要中的人格表达部分 - f.write("【人格表达方式】\n") - f.write("=" * 40 + "\n") - f.write(f"人格表达总数: {total_personality} (显示前20条)\n") - f.write(f"详细报告: {personality_report}\n") - f.write("-" * 40 + "\n\n") - - # 处理各个群组的表达方式 - f.write("【群组表达方式】\n") - f.write("=" * 40 + "\n\n") - - for chat_id in chat_ids: - style_exprs, grammar_exprs, _ = load_expressions(chat_id) - - # 保存总数 - total_style = len(style_exprs) - total_grammar = len(grammar_exprs) - - # 分别排序 - style_exprs.sort(key=lambda x: x.get("count", 0), reverse=True) - grammar_exprs.sort(key=lambda x: x.get("count", 0), reverse=True) - - # 只取前20条 - style_exprs = style_exprs[:20] - grammar_exprs = grammar_exprs[:20] - - # 获取群组名称 - group_name = get_group_name(chat_id) - - # 创建群组子目录(使用清理后的名称) - safe_group_name = clean_group_name(group_name) - group_dir = os.path.join(output_dir, f"{safe_group_name}_{chat_id}") - os.makedirs(group_dir, exist_ok=True) - - # 写入群组详细报告 - group_file = os.path.join(group_dir, f"expressions_{timestamp}.txt") - write_group_report(group_file, group_name, chat_id, style_exprs, grammar_exprs) - - # 写入总报告摘要 - f.write(f"群组: {group_name} (ID: {chat_id})\n") - f.write("-" * 40 + "\n") - f.write(f"语言风格总数: {total_style} (显示前20条)\n") - f.write(f"句法特点总数: {total_grammar} (显示前20条)\n") - f.write(f"详细报告: {group_file}\n") - f.write("-" * 40 + "\n\n") - - print("分析报告已生成:") - print(f"总报告: {summary_file}") - print(f"人格表达报告: {personality_report}") - print(f"各群组详细报告位于: {output_dir}") - - -if __name__ == "__main__": - analyze_expressions() diff --git a/scripts/analyze_group_similarity.py b/scripts/analyze_group_similarity.py deleted file mode 100644 index f1d53ee20..000000000 --- a/scripts/analyze_group_similarity.py +++ /dev/null @@ -1,196 +0,0 @@ -import json -from pathlib import Path -import numpy as np -from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.metrics.pairwise import cosine_similarity -import matplotlib.pyplot as plt -import seaborn as 
sns -import sqlite3 - -# 设置中文字体 -plt.rcParams["font.sans-serif"] = ["Microsoft YaHei"] # 使用微软雅黑 -plt.rcParams["axes.unicode_minus"] = False # 用来正常显示负号 -plt.rcParams["font.family"] = "sans-serif" - -# 获取脚本所在目录 -SCRIPT_DIR = Path(__file__).parent - - -def get_group_name(stream_id): - """从数据库中获取群组名称""" - conn = sqlite3.connect("data/maibot.db") - cursor = conn.cursor() - - cursor.execute( - """ - SELECT group_name, user_nickname, platform - FROM chat_streams - WHERE stream_id = ? - """, - (stream_id,), - ) - - result = cursor.fetchone() - conn.close() - - if result: - group_name, user_nickname, platform = result - if group_name: - return group_name - if user_nickname: - return user_nickname - if platform: - return f"{platform}-{stream_id[:8]}" - return stream_id - - -def load_group_data(group_dir): - """加载单个群组的数据""" - json_path = Path(group_dir) / "expressions.json" - if not json_path.exists(): - return [], [], [], 0 - - with open(json_path, "r", encoding="utf-8") as f: - data = json.load(f) - - situations = [] - styles = [] - combined = [] - total_count = sum(item["count"] for item in data) - - for item in data: - count = item["count"] - situations.extend([item["situation"]] * int(count)) - styles.extend([item["style"]] * int(count)) - combined.extend([f"{item['situation']} {item['style']}"] * int(count)) - - return situations, styles, combined, total_count - - -def analyze_group_similarity(): - # 获取所有群组目录 - base_dir = Path("data/expression/learnt_style") - group_dirs = [d for d in base_dir.iterdir() if d.is_dir()] - - # 加载所有群组的数据并过滤 - valid_groups = [] - valid_names = [] - valid_situations = [] - valid_styles = [] - valid_combined = [] - - for d in group_dirs: - situations, styles, combined, total_count = load_group_data(d) - if total_count >= 50: # 只保留数据量大于等于50的群组 - valid_groups.append(d) - valid_names.append(get_group_name(d.name)) - valid_situations.append(" ".join(situations)) - valid_styles.append(" ".join(styles)) - valid_combined.append(" ".join(combined)) - 
- if not valid_groups: - print("没有找到数据量大于等于50的群组") - return - - # 创建TF-IDF向量化器 - vectorizer = TfidfVectorizer() - - # 计算三种相似度矩阵 - situation_matrix = cosine_similarity(vectorizer.fit_transform(valid_situations)) - style_matrix = cosine_similarity(vectorizer.fit_transform(valid_styles)) - combined_matrix = cosine_similarity(vectorizer.fit_transform(valid_combined)) - - # 对相似度矩阵进行对数变换 - log_situation_matrix = np.log10(situation_matrix * 100 + 1) * 10 / np.log10(4) - log_style_matrix = np.log10(style_matrix * 100 + 1) * 10 / np.log10(4) - log_combined_matrix = np.log10(combined_matrix * 100 + 1) * 10 / np.log10(4) - - # 创建一个大图,包含三个子图 - plt.figure(figsize=(45, 12)) - - # 场景相似度热力图 - plt.subplot(1, 3, 1) - sns.heatmap( - log_situation_matrix, - xticklabels=valid_names, - yticklabels=valid_names, - cmap="YlOrRd", - annot=True, - fmt=".1f", - vmin=0, - vmax=30, - ) - plt.title("群组场景相似度热力图 (对数百分比)") - plt.xticks(rotation=45, ha="right") - - # 表达方式相似度热力图 - plt.subplot(1, 3, 2) - sns.heatmap( - log_style_matrix, - xticklabels=valid_names, - yticklabels=valid_names, - cmap="YlOrRd", - annot=True, - fmt=".1f", - vmin=0, - vmax=30, - ) - plt.title("群组表达方式相似度热力图 (对数百分比)") - plt.xticks(rotation=45, ha="right") - - # 组合相似度热力图 - plt.subplot(1, 3, 3) - sns.heatmap( - log_combined_matrix, - xticklabels=valid_names, - yticklabels=valid_names, - cmap="YlOrRd", - annot=True, - fmt=".1f", - vmin=0, - vmax=30, - ) - plt.title("群组场景+表达方式相似度热力图 (对数百分比)") - plt.xticks(rotation=45, ha="right") - - plt.tight_layout() - plt.savefig(SCRIPT_DIR / "group_similarity_heatmaps.png", dpi=300, bbox_inches="tight") - plt.close() - - # 保存匹配详情到文本文件 - with open(SCRIPT_DIR / "group_similarity_details.txt", "w", encoding="utf-8") as f: - f.write("群组相似度详情\n") - f.write("=" * 50 + "\n\n") - - for i in range(len(valid_names)): - for j in range(i + 1, len(valid_names)): - if log_combined_matrix[i][j] > 50: - f.write(f"群组1: {valid_names[i]}\n") - f.write(f"群组2: {valid_names[j]}\n") - f.write(f"场景相似度: 
{situation_matrix[i][j]:.4f}\n") - f.write(f"表达方式相似度: {style_matrix[i][j]:.4f}\n") - f.write(f"组合相似度: {combined_matrix[i][j]:.4f}\n") - - # 获取两个群组的数据 - situations1, styles1, _ = load_group_data(valid_groups[i]) - situations2, styles2, _ = load_group_data(valid_groups[j]) - - # 找出共同的场景 - common_situations = set(situations1) & set(situations2) - if common_situations: - f.write("\n共同场景:\n") - for situation in common_situations: - f.write(f"- {situation}\n") - - # 找出共同的表达方式 - common_styles = set(styles1) & set(styles2) - if common_styles: - f.write("\n共同表达方式:\n") - for style in common_styles: - f.write(f"- {style}\n") - - f.write("\n" + "-" * 50 + "\n\n") - - -if __name__ == "__main__": - analyze_group_similarity() diff --git a/scripts/expression_stats.py b/scripts/expression_stats.py new file mode 100644 index 000000000..4e761d8d1 --- /dev/null +++ b/scripts/expression_stats.py @@ -0,0 +1,208 @@ +import time +import sys +import os +from typing import Dict, List + +# Add project root to Python path before importing project modules (src.*) +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, project_root) +from src.common.database.database_model import Expression, ChatStreams # noqa: E402 + + + + +def get_chat_name(chat_id: str) -> str: + """Get chat name from chat_id by querying ChatStreams table directly""" + try: + # 直接从数据库查询ChatStreams表 + chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) + if chat_stream is None: + return f"未知聊天 ({chat_id})" + + # 如果有群组信息,显示群组名称 + if chat_stream.group_name: + return f"{chat_stream.group_name} ({chat_id})" + # 如果是私聊,显示用户昵称 + elif chat_stream.user_nickname: + return f"{chat_stream.user_nickname}的私聊 ({chat_id})" + else: + return f"未知聊天 ({chat_id})" + except Exception: + return f"查询失败 ({chat_id})" + + +def calculate_time_distribution(expressions) -> Dict[str, int]: + """Calculate distribution of last active time in days""" + now = time.time() + distribution = { + '0-1天': 0, + '1-3天': 0, + '3-7天': 0, + '7-14天': 0, + '14-30天':
0, + '30-60天': 0, + '60-90天': 0, + '90+天': 0 + } + for expr in expressions: + diff_days = (now - expr.last_active_time) / (24*3600) + if diff_days < 1: + distribution['0-1天'] += 1 + elif diff_days < 3: + distribution['1-3天'] += 1 + elif diff_days < 7: + distribution['3-7天'] += 1 + elif diff_days < 14: + distribution['7-14天'] += 1 + elif diff_days < 30: + distribution['14-30天'] += 1 + elif diff_days < 60: + distribution['30-60天'] += 1 + elif diff_days < 90: + distribution['60-90天'] += 1 + else: + distribution['90+天'] += 1 + return distribution + + +def calculate_count_distribution(expressions) -> Dict[str, int]: + """Calculate distribution of count values""" + distribution = { + '0-1': 0, + '1-2': 0, + '2-3': 0, + '3-4': 0, + '4-5': 0, + '5-10': 0, + '10+': 0 + } + for expr in expressions: + cnt = expr.count + if cnt < 1: + distribution['0-1'] += 1 + elif cnt < 2: + distribution['1-2'] += 1 + elif cnt < 3: + distribution['2-3'] += 1 + elif cnt < 4: + distribution['3-4'] += 1 + elif cnt < 5: + distribution['4-5'] += 1 + elif cnt < 10: + distribution['5-10'] += 1 + else: + distribution['10+'] += 1 + return distribution + + +def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]: + """Get top N most used expressions for a specific chat_id""" + return (Expression.select() + .where(Expression.chat_id == chat_id) + .order_by(Expression.count.desc()) + .limit(top_n)) + + +def show_overall_statistics(expressions, total: int) -> None: + """Show overall statistics""" + time_dist = calculate_time_distribution(expressions) + count_dist = calculate_count_distribution(expressions) + + print("\n=== 总体统计 ===") + print(f"总表达式数量: {total}") + + print("\n上次激活时间分布:") + for period, count in time_dist.items(): + print(f"{period}: {count} ({count/total*100:.2f}%)") + + print("\ncount分布:") + for range_, count in count_dist.items(): + print(f"{range_}: {count} ({count/total*100:.2f}%)") + + +def show_chat_statistics(chat_id: str, chat_name: str) -> None: + """Show 
statistics for a specific chat""" + chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id)) + chat_total = len(chat_exprs) + + print(f"\n=== {chat_name} ===") + print(f"表达式数量: {chat_total}") + + if chat_total == 0: + print("该聊天没有表达式数据") + return + + # Time distribution for this chat + time_dist = calculate_time_distribution(chat_exprs) + print("\n上次激活时间分布:") + for period, count in time_dist.items(): + if count > 0: + print(f"{period}: {count} ({count/chat_total*100:.2f}%)") + + # Count distribution for this chat + count_dist = calculate_count_distribution(chat_exprs) + print("\ncount分布:") + for range_, count in count_dist.items(): + if count > 0: + print(f"{range_}: {count} ({count/chat_total*100:.2f}%)") + + # Top expressions + print("\nTop 10使用最多的表达式:") + top_exprs = get_top_expressions_by_chat(chat_id, 10) + for i, expr in enumerate(top_exprs, 1): + print(f"{i}. [{expr.type}] Count: {expr.count}") + print(f" Situation: {expr.situation}") + print(f" Style: {expr.style}") + print() + + +def interactive_menu() -> None: + """Interactive menu for expression statistics""" + # Get all expressions + expressions = list(Expression.select()) + if not expressions: + print("数据库中没有找到表达式") + return + + total = len(expressions) + + # Get unique chat_ids and their names + chat_ids = list(set(expr.chat_id for expr in expressions)) + chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids] + chat_info.sort(key=lambda x: x[1]) # Sort by chat name + + while True: + print("\n" + "="*50) + print("表达式统计分析") + print("="*50) + print("0. 显示总体统计") + + for i, (chat_id, chat_name) in enumerate(chat_info, 1): + chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id) + print(f"{i}. {chat_name} ({chat_count}个表达式)") + + print("q. 
退出") + + choice = input("\n请选择要查看的统计 (输入序号): ").strip() + + if choice.lower() == 'q': + print("再见!") + break + + try: + choice_num = int(choice) + if choice_num == 0: + show_overall_statistics(expressions, total) + elif 1 <= choice_num <= len(chat_info): + chat_id, chat_name = chat_info[choice_num - 1] + show_chat_statistics(chat_id, chat_name) + else: + print("无效的选择,请重新输入") + except ValueError: + print("请输入有效的数字") + + input("\n按回车键继续...") + + +if __name__ == "__main__": + interactive_menu() \ No newline at end of file diff --git a/scripts/find_similar_expression.py b/scripts/find_similar_expression.py deleted file mode 100644 index 23f9e63d9..000000000 --- a/scripts/find_similar_expression.py +++ /dev/null @@ -1,252 +0,0 @@ -import os -import sys - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -import json -from typing import List, Dict, Tuple -import numpy as np -from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.metrics.pairwise import cosine_similarity -import glob -import sqlite3 -import re -from datetime import datetime -import random -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config - - -def clean_group_name(name: str) -> str: - """清理群组名称,只保留中文和英文字符""" - cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name) - if not cleaned: - cleaned = datetime.now().strftime("%Y%m%d") - return cleaned - - -def get_group_name(stream_id: str) -> str: - """从数据库中获取群组名称""" - conn = sqlite3.connect("data/maibot.db") - cursor = conn.cursor() - - cursor.execute( - """ - SELECT group_name, user_nickname, platform - FROM chat_streams - WHERE stream_id = ? 
- """, - (stream_id,), - ) - - result = cursor.fetchone() - conn.close() - - if result: - group_name, user_nickname, platform = result - if group_name: - return clean_group_name(group_name) - if user_nickname: - return clean_group_name(user_nickname) - if platform: - return clean_group_name(f"{platform}{stream_id[:8]}") - return stream_id - - -def load_expressions(chat_id: str) -> List[Dict]: - """加载指定群聊的表达方式""" - style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json") - - style_exprs = [] - - if os.path.exists(style_file): - with open(style_file, "r", encoding="utf-8") as f: - style_exprs = json.load(f) - - # 如果表达方式超过10个,随机选择10个 - if len(style_exprs) > 50: - style_exprs = random.sample(style_exprs, 50) - print(f"\n从 {len(style_exprs)} 个表达方式中随机选择了 10 个进行匹配") - - return style_exprs - - -def find_similar_expressions_tfidf( - input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 10 -) -> List[Tuple[str, str, float]]: - """使用TF-IDF方法找出与输入文本最相似的top_k个表达方式""" - if not expressions: - return [] - - # 准备文本数据 - if mode == "style": - texts = [expr["style"] for expr in expressions] - elif mode == "situation": - texts = [expr["situation"] for expr in expressions] - else: # both - texts = [f"{expr['situation']} {expr['style']}" for expr in expressions] - - texts.append(input_text) # 添加输入文本 - - # 使用TF-IDF向量化 - vectorizer = TfidfVectorizer() - tfidf_matrix = vectorizer.fit_transform(texts) - - # 计算余弦相似度 - similarity_matrix = cosine_similarity(tfidf_matrix) - - # 获取输入文本的相似度分数(最后一行) - scores = similarity_matrix[-1][:-1] # 排除与自身的相似度 - - # 获取top_k的索引 - top_indices = np.argsort(scores)[::-1][:top_k] - - # 获取相似表达 - similar_exprs = [] - for idx in top_indices: - if scores[idx] > 0: # 只保留有相似度的 - similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], scores[idx])) - - return similar_exprs - - -async def find_similar_expressions_embedding( - input_text: str, expressions: List[Dict], mode: str = "both", 
top_k: int = 5 -) -> List[Tuple[str, str, float]]: - """使用嵌入模型找出与输入文本最相似的top_k个表达方式""" - if not expressions: - return [] - - # 准备文本数据 - if mode == "style": - texts = [expr["style"] for expr in expressions] - elif mode == "situation": - texts = [expr["situation"] for expr in expressions] - else: # both - texts = [f"{expr['situation']} {expr['style']}" for expr in expressions] - - # 获取嵌入向量 - llm_request = LLMRequest(global_config.model.embedding) - text_embeddings = [] - for text in texts: - embedding = await llm_request.get_embedding(text) - if embedding: - text_embeddings.append(embedding) - - input_embedding = await llm_request.get_embedding(input_text) - if not input_embedding or not text_embeddings: - return [] - - # 计算余弦相似度 - text_embeddings = np.array(text_embeddings) - similarities = np.dot(text_embeddings, input_embedding) / ( - np.linalg.norm(text_embeddings, axis=1) * np.linalg.norm(input_embedding) - ) - - # 获取top_k的索引 - top_indices = np.argsort(similarities)[::-1][:top_k] - - # 获取相似表达 - similar_exprs = [] - for idx in top_indices: - if similarities[idx] > 0: # 只保留有相似度的 - similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], similarities[idx])) - - return similar_exprs - - -async def main(): - # 获取所有群聊ID - style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*")) - chat_ids = [os.path.basename(d) for d in style_dirs] - - if not chat_ids: - print("没有找到任何群聊的表达方式数据") - return - - print("可用的群聊:") - for i, chat_id in enumerate(chat_ids, 1): - group_name = get_group_name(chat_id) - print(f"{i}. 
{group_name}") - - while True: - try: - choice = int(input("\n请选择要分析的群聊编号 (输入0退出): ")) - if choice == 0: - break - if 1 <= choice <= len(chat_ids): - chat_id = chat_ids[choice - 1] - break - print("无效的选择,请重试") - except ValueError: - print("请输入有效的数字") - - if choice == 0: - return - - # 加载表达方式 - style_exprs = load_expressions(chat_id) - - group_name = get_group_name(chat_id) - print(f"\n已选择群聊:{group_name}") - - # 选择匹配模式 - print("\n请选择匹配模式:") - print("1. 匹配表达方式") - print("2. 匹配情景") - print("3. 两者都考虑") - - while True: - try: - mode_choice = int(input("\n请选择匹配模式 (1-3): ")) - if 1 <= mode_choice <= 3: - break - print("无效的选择,请重试") - except ValueError: - print("请输入有效的数字") - - mode_map = {1: "style", 2: "situation", 3: "both"} - mode = mode_map[mode_choice] - - # 选择匹配方法 - print("\n请选择匹配方法:") - print("1. TF-IDF方法") - print("2. 嵌入模型方法") - - while True: - try: - method_choice = int(input("\n请选择匹配方法 (1-2): ")) - if 1 <= method_choice <= 2: - break - print("无效的选择,请重试") - except ValueError: - print("请输入有效的数字") - - while True: - input_text = input("\n请输入要匹配的文本(输入q退出): ") - if input_text.lower() == "q": - break - - if not input_text.strip(): - continue - - if method_choice == 1: - similar_exprs = find_similar_expressions_tfidf(input_text, style_exprs, mode) - else: - similar_exprs = await find_similar_expressions_embedding(input_text, style_exprs, mode) - - if similar_exprs: - print("\n找到以下相似表达:") - for style, situation, score in similar_exprs: - print(f"\n\033[33m表达方式:{style}\033[0m") - print(f"\033[32m对应情景:{situation}\033[0m") - print(f"相似度:{score:.3f}") - print("-" * 20) - else: - print("\n没有找到相似的表达方式") - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main()) diff --git a/scripts/interest_value_analysis.py b/scripts/interest_value_analysis.py new file mode 100644 index 000000000..19007f68a --- /dev/null +++ b/scripts/interest_value_analysis.py @@ -0,0 +1,287 @@ +import time +import sys +import os +from typing import Dict, List, Tuple, Optional +from datetime 
import datetime +from src.common.database.database_model import Messages, ChatStreams +# Add project root to Python path +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, project_root) + + + + +def get_chat_name(chat_id: str) -> str: + """Get chat name from chat_id by querying ChatStreams table directly""" + try: + chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) + if chat_stream is None: + return f"未知聊天 ({chat_id})" + + if chat_stream.group_name: + return f"{chat_stream.group_name} ({chat_id})" + elif chat_stream.user_nickname: + return f"{chat_stream.user_nickname}的私聊 ({chat_id})" + else: + return f"未知聊天 ({chat_id})" + except Exception: + return f"查询失败 ({chat_id})" + + +def format_timestamp(timestamp: float) -> str: + """Format timestamp to readable date string""" + try: + return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") + except (ValueError, OSError): + return "未知时间" + + +def calculate_interest_value_distribution(messages) -> Dict[str, int]: + """Calculate distribution of interest_value""" + distribution = { + '0.000-0.010': 0, + '0.010-0.050': 0, + '0.050-0.100': 0, + '0.100-0.500': 0, + '0.500-1.000': 0, + '1.000-2.000': 0, + '2.000-5.000': 0, + '5.000-10.000': 0, + '10.000+': 0 + } + + for msg in messages: + if msg.interest_value is None or msg.interest_value == 0.0: + continue + + value = float(msg.interest_value) + if value < 0.010: + distribution['0.000-0.010'] += 1 + elif value < 0.050: + distribution['0.010-0.050'] += 1 + elif value < 0.100: + distribution['0.050-0.100'] += 1 + elif value < 0.500: + distribution['0.100-0.500'] += 1 + elif value < 1.000: + distribution['0.500-1.000'] += 1 + elif value < 2.000: + distribution['1.000-2.000'] += 1 + elif value < 5.000: + distribution['2.000-5.000'] += 1 + elif value < 10.000: + distribution['5.000-10.000'] += 1 + else: + distribution['10.000+'] += 1 + + return distribution + + +def get_interest_value_stats(messages) -> 
Dict[str, float]: + """Calculate basic statistics for interest_value""" + values = [float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0] + + if not values: + return { + 'count': 0, + 'min': 0, + 'max': 0, + 'avg': 0, + 'median': 0 + } + + values.sort() + count = len(values) + + return { + 'count': count, + 'min': min(values), + 'max': max(values), + 'avg': sum(values) / count, + 'median': values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2 + } + + +def get_available_chats() -> List[Tuple[str, str, int]]: + """Get all available chats with message counts""" + try: + # 获取所有有消息的chat_id + chat_counts = {} + for msg in Messages.select(Messages.chat_id).distinct(): + chat_id = msg.chat_id + count = Messages.select().where( + (Messages.chat_id == chat_id) & + (Messages.interest_value.is_null(False)) & + (Messages.interest_value != 0.0) + ).count() + if count > 0: + chat_counts[chat_id] = count + + # 获取聊天名称 + result = [] + for chat_id, count in chat_counts.items(): + chat_name = get_chat_name(chat_id) + result.append((chat_id, chat_name, count)) + + # 按消息数量排序 + result.sort(key=lambda x: x[2], reverse=True) + return result + except Exception as e: + print(f"获取聊天列表失败: {e}") + return [] + + +def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: + """Get time range input from user""" + print("\n时间范围选择:") + print("1. 最近1天") + print("2. 最近3天") + print("3. 最近7天") + print("4. 最近30天") + print("5. 自定义时间范围") + print("6. 
不限制时间") + + choice = input("请选择时间范围 (1-6): ").strip() + + now = time.time() + + if choice == "1": + return now - 24*3600, now + elif choice == "2": + return now - 3*24*3600, now + elif choice == "3": + return now - 7*24*3600, now + elif choice == "4": + return now - 30*24*3600, now + elif choice == "5": + print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):") + start_str = input().strip() + print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):") + end_str = input().strip() + + try: + start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp() + end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp() + return start_time, end_time + except ValueError: + print("时间格式错误,将不限制时间范围") + return None, None + else: + return None, None + + +def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None: + """Analyze interest values with optional filters""" + + # 构建查询条件 + query = Messages.select().where( + (Messages.interest_value.is_null(False)) & + (Messages.interest_value != 0.0) + ) + + if chat_id: + query = query.where(Messages.chat_id == chat_id) + + if start_time: + query = query.where(Messages.time >= start_time) + + if end_time: + query = query.where(Messages.time <= end_time) + + messages = list(query) + + if not messages: + print("没有找到符合条件的消息") + return + + # 计算统计信息 + distribution = calculate_interest_value_distribution(messages) + stats = get_interest_value_stats(messages) + + # 显示结果 + print("\n=== Interest Value 分析结果 ===") + if chat_id: + print(f"聊天: {get_chat_name(chat_id)}") + else: + print("聊天: 全部聊天") + + if start_time and end_time: + print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}") + elif start_time: + print(f"时间范围: {format_timestamp(start_time)} 之后") + elif end_time: + print(f"时间范围: {format_timestamp(end_time)} 之前") + else: + print("时间范围: 不限制") + + print("\n基本统计:") + print(f"有效消息数量: {stats['count']} (排除null和0值)") + print(f"最小值: {stats['min']:.3f}") + 
print(f"最大值: {stats['max']:.3f}") + print(f"平均值: {stats['avg']:.3f}") + print(f"中位数: {stats['median']:.3f}") + + print("\nInterest Value 分布:") + total = stats['count'] + for range_name, count in distribution.items(): + if count > 0: + percentage = count / total * 100 + print(f"{range_name}: {count} ({percentage:.2f}%)") + + +def interactive_menu() -> None: + """Interactive menu for interest value analysis""" + + while True: + print("\n" + "="*50) + print("Interest Value 分析工具") + print("="*50) + print("1. 分析全部聊天") + print("2. 选择特定聊天分析") + print("q. 退出") + + choice = input("\n请选择分析模式 (1-2, q): ").strip() + + if choice.lower() == 'q': + print("再见!") + break + + chat_id = None + + if choice == "2": + # 显示可用的聊天列表 + chats = get_available_chats() + if not chats: + print("没有找到有interest_value数据的聊天") + continue + + print(f"\n可用的聊天 (共{len(chats)}个):") + for i, (_cid, name, count) in enumerate(chats, 1): + print(f"{i}. {name} ({count}条有效消息)") + + try: + chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip()) + if 1 <= chat_choice <= len(chats): + chat_id = chats[chat_choice - 1][0] + else: + print("无效选择") + continue + except ValueError: + print("请输入有效数字") + continue + + elif choice != "1": + print("无效选择") + continue + + # 获取时间范围 + start_time, end_time = get_time_range_input() + + # 执行分析 + analyze_interest_values(chat_id, start_time, end_time) + + input("\n按回车键继续...") + + +if __name__ == "__main__": + interactive_menu() \ No newline at end of file diff --git a/scripts/log_viewer.py b/scripts/log_viewer.py deleted file mode 100644 index 248919fa8..000000000 --- a/scripts/log_viewer.py +++ /dev/null @@ -1,1185 +0,0 @@ -import tkinter as tk -from tkinter import ttk, colorchooser, messagebox, filedialog -import json -from pathlib import Path -import threading -import queue -import time -import toml -from datetime import datetime - - -class LogFormatter: - """日志格式化器,同步logger.py的格式""" - - def __init__(self, config, custom_module_colors=None, custom_level_colors=None): - 
self.config = config - - # 日志级别颜色 - self.level_colors = { - "debug": "#FFA500", # 橙色 - "info": "#0000FF", # 蓝色 - "success": "#008000", # 绿色 - "warning": "#FFFF00", # 黄色 - "error": "#FF0000", # 红色 - "critical": "#800080", # 紫色 - } - - # 模块颜色映射 - 同步logger.py中的MODULE_COLORS - self.module_colors = { - "api": "#00FF00", # 亮绿色 - "emoji": "#00FF00", # 亮绿色 - "chat": "#0080FF", # 亮蓝色 - "config": "#FFFF00", # 亮黄色 - "common": "#FF00FF", # 亮紫色 - "tools": "#00FFFF", # 亮青色 - "lpmm": "#00FFFF", # 亮青色 - "plugin_system": "#FF0080", # 亮红色 - "experimental": "#FFFFFF", # 亮白色 - "person_info": "#008000", # 绿色 - "individuality": "#000080", # 蓝色 - "manager": "#800080", # 紫色 - "llm_models": "#008080", # 青色 - "plugins": "#800000", # 红色 - "plugin_api": "#808000", # 黄色 - "remote": "#8000FF", # 紫蓝色 - } - - # 应用自定义颜色 - if custom_module_colors: - self.module_colors.update(custom_module_colors) - if custom_level_colors: - self.level_colors.update(custom_level_colors) - - # 根据配置决定颜色启用状态 - color_text = self.config.get("color_text", "full") - if color_text == "none": - self.enable_colors = False - self.enable_module_colors = False - self.enable_level_colors = False - elif color_text == "title": - self.enable_colors = True - self.enable_module_colors = True - self.enable_level_colors = False - elif color_text == "full": - self.enable_colors = True - self.enable_module_colors = True - self.enable_level_colors = True - else: - self.enable_colors = True - self.enable_module_colors = True - self.enable_level_colors = False - - def format_log_entry(self, log_entry): - """格式化日志条目,返回格式化后的文本和样式标签""" - # 获取基本信息 - timestamp = log_entry.get("timestamp", "") - level = log_entry.get("level", "info") - logger_name = log_entry.get("logger_name", "") - event = log_entry.get("event", "") - - # 格式化时间戳 - formatted_timestamp = self.format_timestamp(timestamp) - - # 构建输出部分 - parts = [] - tags = [] - - # 日志级别样式配置 - log_level_style = self.config.get("log_level_style", "lite") - - # 时间戳 - if formatted_timestamp: - if 
log_level_style == "lite" and self.enable_level_colors: - # lite模式下时间戳按级别着色 - parts.append(formatted_timestamp) - tags.append(f"level_{level}") - else: - parts.append(formatted_timestamp) - tags.append("timestamp") - - # 日志级别显示 - if log_level_style == "full": - # 显示完整级别名 - level_text = f"[{level.upper():>8}]" - parts.append(level_text) - if self.enable_level_colors: - tags.append(f"level_{level}") - else: - tags.append("level") - elif log_level_style == "compact": - # 只显示首字母 - level_text = f"[{level.upper()[0]:>8}]" - parts.append(level_text) - if self.enable_level_colors: - tags.append(f"level_{level}") - else: - tags.append("level") - # lite模式不显示级别 - - # 模块名称 - if logger_name: - module_text = f"[{logger_name}]" - parts.append(module_text) - if self.enable_module_colors: - tags.append(f"module_{logger_name}") - else: - tags.append("module") - - # 消息内容 - if isinstance(event, str): - parts.append(event) - elif isinstance(event, dict): - try: - parts.append(json.dumps(event, ensure_ascii=False, indent=None)) - except (TypeError, ValueError): - parts.append(str(event)) - else: - parts.append(str(event)) - tags.append("message") - - # 处理其他字段 - extras = [] - for key, value in log_entry.items(): - if key not in ("timestamp", "level", "logger_name", "event"): - if isinstance(value, (dict, list)): - try: - value_str = json.dumps(value, ensure_ascii=False, indent=None) - except (TypeError, ValueError): - value_str = str(value) - else: - value_str = str(value) - extras.append(f"{key}={value_str}") - - if extras: - parts.append(" ".join(extras)) - tags.append("extras") - - return parts, tags - - def format_timestamp(self, timestamp): - """格式化时间戳""" - if not timestamp: - return "" - - try: - # 尝试解析ISO格式时间戳 - if "T" in timestamp: - dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) - else: - # 假设已经是格式化的字符串 - return timestamp - - # 根据配置格式化 - date_style = self.config.get("date_style", "m-d H:i:s") - format_map = { - "Y": "%Y", # 4位年份 - "m": "%m", # 月份(01-12) - "d": 
"%d", # 日期(01-31) - "H": "%H", # 小时(00-23) - "i": "%M", # 分钟(00-59) - "s": "%S", # 秒数(00-59) - } - - python_format = date_style - for php_char, python_char in format_map.items(): - python_format = python_format.replace(php_char, python_char) - - return dt.strftime(python_format) - except Exception: - return timestamp - - -class LogViewer: - def __init__(self, root): - self.root = root - self.root.title("MaiBot日志查看器") - self.root.geometry("1200x800") - - # 加载配置 - self.load_config() - - # 初始化日志格式化器 - self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors) - - # 初始化日志文件路径 - self.current_log_file = Path("logs/app.log.jsonl") - - # 创建主框架 - self.main_frame = ttk.Frame(root) - self.main_frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) - - # 创建菜单栏 - self.create_menu() - - # 创建控制面板 - self.control_frame = ttk.Frame(self.main_frame) - self.control_frame.pack(fill=tk.X, pady=(0, 5)) - - # 文件选择框架 - self.file_frame = ttk.LabelFrame(self.control_frame, text="日志文件") - self.file_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=(0, 5)) - - # 当前文件显示 - self.current_file_var = tk.StringVar(value=str(self.current_log_file)) - self.file_label = ttk.Label(self.file_frame, textvariable=self.current_file_var, foreground="blue") - self.file_label.pack(side=tk.LEFT, padx=5, pady=2) - - # 选择文件按钮 - select_file_btn = ttk.Button(self.file_frame, text="选择文件", command=self.select_log_file) - select_file_btn.pack(side=tk.RIGHT, padx=5, pady=2) - - # 刷新按钮 - refresh_btn = ttk.Button(self.file_frame, text="刷新", command=self.refresh_log_file) - refresh_btn.pack(side=tk.RIGHT, padx=2, pady=2) - - # 模块选择框架 - self.module_frame = ttk.LabelFrame(self.control_frame, text="模块") - self.module_frame.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=5) - - # 创建模块选择滚动区域 - self.module_canvas = tk.Canvas(self.module_frame, height=80) - self.module_canvas.pack(side=tk.LEFT, fill=tk.X, expand=True) - - # 创建模块选择内部框架 - self.module_inner_frame = ttk.Frame(self.module_canvas) 
- self.module_canvas.create_window((0, 0), window=self.module_inner_frame, anchor="nw") - - # 创建右侧控制区域(级别和搜索) - self.right_control_frame = ttk.Frame(self.control_frame) - self.right_control_frame.pack(side=tk.RIGHT, padx=5) - - # 映射编辑按钮 - mapping_btn = ttk.Button(self.right_control_frame, text="模块映射", command=self.edit_module_mapping) - mapping_btn.pack(side=tk.TOP, fill=tk.X, pady=1) - - # 日志级别选择 - level_frame = ttk.Frame(self.right_control_frame) - level_frame.pack(side=tk.TOP, fill=tk.X, pady=1) - ttk.Label(level_frame, text="级别:").pack(side=tk.LEFT, padx=2) - self.level_var = tk.StringVar(value="全部") - self.level_combo = ttk.Combobox(level_frame, textvariable=self.level_var, width=8) - self.level_combo["values"] = ["全部", "debug", "info", "warning", "error", "critical"] - self.level_combo.pack(side=tk.LEFT, padx=2) - - # 搜索框 - search_frame = ttk.Frame(self.right_control_frame) - search_frame.pack(side=tk.TOP, fill=tk.X, pady=1) - ttk.Label(search_frame, text="搜索:").pack(side=tk.LEFT, padx=2) - self.search_var = tk.StringVar() - self.search_entry = ttk.Entry(search_frame, textvariable=self.search_var, width=15) - self.search_entry.pack(side=tk.LEFT, padx=2) - - # 创建日志显示区域 - self.log_frame = ttk.Frame(self.main_frame) - self.log_frame.pack(fill=tk.BOTH, expand=True) - - # 创建文本框和滚动条 - self.scrollbar = ttk.Scrollbar(self.log_frame) - self.scrollbar.pack(side=tk.RIGHT, fill=tk.Y) - - self.log_text = tk.Text( - self.log_frame, - wrap=tk.WORD, - yscrollcommand=self.scrollbar.set, - background="#1e1e1e", - foreground="#ffffff", - insertbackground="#ffffff", - selectbackground="#404040", - ) - self.log_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) - self.scrollbar.config(command=self.log_text.yview) - - # 配置文本标签样式 - self.configure_text_tags() - - # 模块名映射 - self.module_name_mapping = { - "api": "API接口", - "async_task_manager": "异步任务管理器", - "background_tasks": "后台任务", - "base_tool": "基础工具", - "chat_stream": "聊天流", - "component_registry": "组件注册器", - "config": "配置", - 
"database_model": "数据库模型", - "emoji": "表情", - "heartflow": "心流", - "local_storage": "本地存储", - "lpmm": "LPMM", - "maibot_statistic": "MaiBot统计", - "main_message": "主消息", - "main": "主程序", - "memory": "内存", - "mood": "情绪", - "plugin_manager": "插件管理器", - "remote": "远程", - "willing": "意愿", - } - - # 加载自定义映射 - self.load_module_mapping() - - # 创建日志队列和缓存 - self.log_queue = queue.Queue() - self.log_cache = [] - - # 选中的模块集合 - self.selected_modules = set() - - # 初始化模块列表 - self.modules = set() - self.update_module_list() - - # 绑定事件 - self.level_combo.bind("<>", self.filter_logs) - self.search_var.trace("w", self.filter_logs) - - # 启动日志监控线程 - self.running = True - self.monitor_thread = threading.Thread(target=self.monitor_log_file) - self.monitor_thread.daemon = True - self.monitor_thread.start() - - # 启动日志更新线程 - self.update_thread = threading.Thread(target=self.update_logs) - self.update_thread.daemon = True - self.update_thread.start() - - # 绑定快捷键 - self.root.bind("", lambda e: self.select_log_file()) - self.root.bind("", lambda e: self.refresh_log_file()) - self.root.bind("", lambda e: self.export_logs()) - - # 更新窗口标题 - self.update_window_title() - - def load_config(self): - """加载配置文件""" - # 默认配置 - self.default_config = { - "log": {"date_style": "m-d H:i:s", "log_level_style": "lite", "color_text": "full", "log_level": "INFO"}, - "viewer": { - "theme": "dark", - "font_size": 10, - "max_lines": 1000, - "auto_scroll": True, - "show_milliseconds": False, - "window": {"width": 1200, "height": 800, "remember_position": True}, - }, - } - - # 从bot_config.toml加载日志配置 - config_path = Path("config/bot_config.toml") - self.log_config = self.default_config["log"].copy() - self.viewer_config = self.default_config["viewer"].copy() - - try: - if config_path.exists(): - with open(config_path, "r", encoding="utf-8") as f: - bot_config = toml.load(f) - if "log" in bot_config: - self.log_config.update(bot_config["log"]) - except Exception as e: - print(f"加载bot配置失败: {e}") - - # 
从viewer配置文件加载查看器配置 - viewer_config_path = Path("config/log_viewer_config.toml") - self.custom_module_colors = {} - self.custom_level_colors = {} - - try: - if viewer_config_path.exists(): - with open(viewer_config_path, "r", encoding="utf-8") as f: - viewer_config = toml.load(f) - if "viewer" in viewer_config: - self.viewer_config.update(viewer_config["viewer"]) - - # 加载自定义模块颜色 - if "module_colors" in viewer_config["viewer"]: - self.custom_module_colors = viewer_config["viewer"]["module_colors"] - - # 加载自定义级别颜色 - if "level_colors" in viewer_config["viewer"]: - self.custom_level_colors = viewer_config["viewer"]["level_colors"] - - if "log" in viewer_config: - self.log_config.update(viewer_config["log"]) - except Exception as e: - print(f"加载查看器配置失败: {e}") - - # 应用窗口配置 - window_config = self.viewer_config.get("window", {}) - window_width = window_config.get("width", 1200) - window_height = window_config.get("height", 800) - self.root.geometry(f"{window_width}x{window_height}") - - def save_viewer_config(self): - """保存查看器配置""" - # 准备完整的配置数据 - viewer_config_copy = self.viewer_config.copy() - - # 保存自定义颜色(只保存与默认值不同的颜色) - if self.custom_module_colors: - viewer_config_copy["module_colors"] = self.custom_module_colors - if self.custom_level_colors: - viewer_config_copy["level_colors"] = self.custom_level_colors - - config_data = {"log": self.log_config, "viewer": viewer_config_copy} - - config_path = Path("config/log_viewer_config.toml") - config_path.parent.mkdir(exist_ok=True) - - try: - with open(config_path, "w", encoding="utf-8") as f: - toml.dump(config_data, f) - except Exception as e: - print(f"保存查看器配置失败: {e}") - - def create_menu(self): - """创建菜单栏""" - menubar = tk.Menu(self.root) - self.root.config(menu=menubar) - - # 配置菜单 - config_menu = tk.Menu(menubar, tearoff=0) - menubar.add_cascade(label="配置", menu=config_menu) - config_menu.add_command(label="日志格式设置", command=self.show_format_settings) - config_menu.add_command(label="颜色设置", command=self.show_color_settings) 
- config_menu.add_command(label="查看器设置", command=self.show_viewer_settings) - config_menu.add_separator() - config_menu.add_command(label="重新加载配置", command=self.reload_config) - - # 文件菜单 - file_menu = tk.Menu(menubar, tearoff=0) - menubar.add_cascade(label="文件", menu=file_menu) - file_menu.add_command(label="选择日志文件", command=self.select_log_file, accelerator="Ctrl+O") - file_menu.add_command(label="刷新当前文件", command=self.refresh_log_file, accelerator="F5") - file_menu.add_separator() - file_menu.add_command(label="导出当前日志", command=self.export_logs, accelerator="Ctrl+S") - - # 工具菜单 - tools_menu = tk.Menu(menubar, tearoff=0) - menubar.add_cascade(label="工具", menu=tools_menu) - tools_menu.add_command(label="清空日志显示", command=self.clear_log_display) - - def show_format_settings(self): - """显示格式设置窗口""" - format_window = tk.Toplevel(self.root) - format_window.title("日志格式设置") - format_window.geometry("400x300") - - frame = ttk.Frame(format_window) - frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) - - # 日期格式 - ttk.Label(frame, text="日期格式:").pack(anchor="w", pady=2) - date_style_var = tk.StringVar(value=self.log_config.get("date_style", "m-d H:i:s")) - date_entry = ttk.Entry(frame, textvariable=date_style_var, width=30) - date_entry.pack(anchor="w", pady=2) - ttk.Label(frame, text="格式说明: Y=年份, m=月份, d=日期, H=小时, i=分钟, s=秒", font=("", 8)).pack( - anchor="w", pady=2 - ) - - # 日志级别样式 - ttk.Label(frame, text="日志级别样式:").pack(anchor="w", pady=(10, 2)) - level_style_var = tk.StringVar(value=self.log_config.get("log_level_style", "lite")) - level_frame = ttk.Frame(frame) - level_frame.pack(anchor="w", pady=2) - - ttk.Radiobutton(level_frame, text="简洁(lite)", variable=level_style_var, value="lite").pack( - side="left", padx=(0, 10) - ) - ttk.Radiobutton(level_frame, text="紧凑(compact)", variable=level_style_var, value="compact").pack( - side="left", padx=(0, 10) - ) - ttk.Radiobutton(level_frame, text="完整(full)", variable=level_style_var, value="full").pack( - side="left", 
padx=(0, 10) - ) - - # 颜色文本设置 - ttk.Label(frame, text="文本颜色设置:").pack(anchor="w", pady=(10, 2)) - color_text_var = tk.StringVar(value=self.log_config.get("color_text", "full")) - color_frame = ttk.Frame(frame) - color_frame.pack(anchor="w", pady=2) - - ttk.Radiobutton(color_frame, text="无颜色(none)", variable=color_text_var, value="none").pack( - side="left", padx=(0, 10) - ) - ttk.Radiobutton(color_frame, text="仅标题(title)", variable=color_text_var, value="title").pack( - side="left", padx=(0, 10) - ) - ttk.Radiobutton(color_frame, text="全部(full)", variable=color_text_var, value="full").pack( - side="left", padx=(0, 10) - ) - - # 按钮 - button_frame = ttk.Frame(frame) - button_frame.pack(fill="x", pady=(20, 0)) - - def apply_format(): - self.log_config["date_style"] = date_style_var.get() - self.log_config["log_level_style"] = level_style_var.get() - self.log_config["color_text"] = color_text_var.get() - - # 重新初始化格式化器 - self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors) - self.configure_text_tags() - - # 保存配置 - self.save_viewer_config() - - # 重新过滤日志以应用新格式 - self.filter_logs() - - format_window.destroy() - - ttk.Button(button_frame, text="应用", command=apply_format).pack(side="right", padx=(5, 0)) - ttk.Button(button_frame, text="取消", command=format_window.destroy).pack(side="right") - - def show_viewer_settings(self): - """显示查看器设置窗口""" - viewer_window = tk.Toplevel(self.root) - viewer_window.title("查看器设置") - viewer_window.geometry("350x250") - - frame = ttk.Frame(viewer_window) - frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) - - # 主题设置 - ttk.Label(frame, text="主题:").pack(anchor="w", pady=2) - theme_var = tk.StringVar(value=self.viewer_config.get("theme", "dark")) - theme_frame = ttk.Frame(frame) - theme_frame.pack(anchor="w", pady=2) - ttk.Radiobutton(theme_frame, text="深色", variable=theme_var, value="dark").pack(side="left", padx=(0, 10)) - ttk.Radiobutton(theme_frame, text="浅色", variable=theme_var, 
value="light").pack(side="left") - - # 字体大小 - ttk.Label(frame, text="字体大小:").pack(anchor="w", pady=(10, 2)) - font_size_var = tk.IntVar(value=self.viewer_config.get("font_size", 10)) - font_size_spin = ttk.Spinbox(frame, from_=8, to=20, textvariable=font_size_var, width=10) - font_size_spin.pack(anchor="w", pady=2) - - # 最大行数 - ttk.Label(frame, text="最大显示行数:").pack(anchor="w", pady=(10, 2)) - max_lines_var = tk.IntVar(value=self.viewer_config.get("max_lines", 1000)) - max_lines_spin = ttk.Spinbox(frame, from_=100, to=10000, increment=100, textvariable=max_lines_var, width=10) - max_lines_spin.pack(anchor="w", pady=2) - - # 自动滚动 - auto_scroll_var = tk.BooleanVar(value=self.viewer_config.get("auto_scroll", True)) - ttk.Checkbutton(frame, text="自动滚动到底部", variable=auto_scroll_var).pack(anchor="w", pady=(10, 2)) - - # 按钮 - button_frame = ttk.Frame(frame) - button_frame.pack(fill="x", pady=(20, 0)) - - def apply_viewer_settings(): - self.viewer_config["theme"] = theme_var.get() - self.viewer_config["font_size"] = font_size_var.get() - self.viewer_config["max_lines"] = max_lines_var.get() - self.viewer_config["auto_scroll"] = auto_scroll_var.get() - - # 应用主题 - self.apply_theme() - - # 保存配置 - self.save_viewer_config() - - viewer_window.destroy() - - ttk.Button(button_frame, text="应用", command=apply_viewer_settings).pack(side="right", padx=(5, 0)) - ttk.Button(button_frame, text="取消", command=viewer_window.destroy).pack(side="right") - - def apply_theme(self): - """应用主题设置""" - theme = self.viewer_config.get("theme", "dark") - font_size = self.viewer_config.get("font_size", 10) - - if theme == "dark": - bg_color = "#1e1e1e" - fg_color = "#ffffff" - select_bg = "#404040" - else: - bg_color = "#ffffff" - fg_color = "#000000" - select_bg = "#c0c0c0" - - self.log_text.config( - background=bg_color, foreground=fg_color, selectbackground=select_bg, font=("Consolas", font_size) - ) - - # 重新配置标签样式 - self.configure_text_tags() - - def configure_text_tags(self): - """配置文本标签样式""" - # 
清除现有标签 - for tag in self.log_text.tag_names(): - if tag != "sel": - self.log_text.tag_delete(tag) - - # 基础标签 - self.log_text.tag_configure("timestamp", foreground="#808080") - self.log_text.tag_configure("level", foreground="#808080") - self.log_text.tag_configure("module", foreground="#808080") - self.log_text.tag_configure("message", foreground=self.log_text.cget("foreground")) - self.log_text.tag_configure("extras", foreground="#808080") - - # 日志级别颜色标签 - for level, color in self.formatter.level_colors.items(): - self.log_text.tag_configure(f"level_{level}", foreground=color) - - # 模块颜色标签 - for module, color in self.formatter.module_colors.items(): - self.log_text.tag_configure(f"module_{module}", foreground=color) - - def reload_config(self): - """重新加载配置""" - self.load_config() - self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors) - self.configure_text_tags() - self.apply_theme() - self.filter_logs() - - def clear_log_display(self): - """清空日志显示""" - self.log_text.delete(1.0, tk.END) - - def export_logs(self): - """导出当前显示的日志""" - filename = filedialog.asksaveasfilename( - defaultextension=".txt", filetypes=[("文本文件", "*.txt"), ("所有文件", "*.*")] - ) - if filename: - try: - with open(filename, "w", encoding="utf-8") as f: - f.write(self.log_text.get(1.0, tk.END)) - messagebox.showinfo("导出成功", f"日志已导出到: {filename}") - except Exception as e: - messagebox.showerror("导出失败", f"导出日志时出错: {e}") - - def load_module_mapping(self): - """加载自定义模块映射""" - mapping_file = Path("config/module_mapping.json") - if mapping_file.exists(): - try: - with open(mapping_file, "r", encoding="utf-8") as f: - custom_mapping = json.load(f) - self.module_name_mapping.update(custom_mapping) - except Exception as e: - print(f"加载模块映射失败: {e}") - - def save_module_mapping(self): - """保存自定义模块映射""" - mapping_file = Path("config/module_mapping.json") - mapping_file.parent.mkdir(exist_ok=True) - try: - with open(mapping_file, "w", encoding="utf-8") as f: - 
json.dump(self.module_name_mapping, f, ensure_ascii=False, indent=2) - except Exception as e: - print(f"保存模块映射失败: {e}") - - def show_color_settings(self): - """显示颜色设置窗口""" - color_window = tk.Toplevel(self.root) - color_window.title("颜色设置") - color_window.geometry("300x400") - - # 创建滚动框架 - frame = ttk.Frame(color_window) - frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) - - # 创建滚动条 - scrollbar = ttk.Scrollbar(frame) - scrollbar.pack(side=tk.RIGHT, fill=tk.Y) - - # 创建颜色设置列表 - canvas = tk.Canvas(frame, yscrollcommand=scrollbar.set) - canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) - scrollbar.config(command=canvas.yview) - - # 创建内部框架 - inner_frame = ttk.Frame(canvas) - canvas.create_window((0, 0), window=inner_frame, anchor="nw") - - # 添加日志级别颜色设置 - ttk.Label(inner_frame, text="日志级别颜色", font=("", 10, "bold")).pack(anchor="w", padx=5, pady=5) - for level in ["info", "warning", "error"]: - frame = ttk.Frame(inner_frame) - frame.pack(fill=tk.X, padx=5, pady=2) - ttk.Label(frame, text=level).pack(side=tk.LEFT) - color_btn = ttk.Button( - frame, text="选择颜色", command=lambda level_name=level: self.choose_color(level_name) - ) - color_btn.pack(side=tk.RIGHT) - # 显示当前颜色 - color_label = ttk.Label(frame, text="■", foreground=self.formatter.level_colors[level]) - color_label.pack(side=tk.RIGHT, padx=5) - - # 添加模块颜色设置 - ttk.Label(inner_frame, text="\n模块颜色", font=("", 10, "bold")).pack(anchor="w", padx=5, pady=5) - for module in sorted(self.modules): - frame = ttk.Frame(inner_frame) - frame.pack(fill=tk.X, padx=5, pady=2) - ttk.Label(frame, text=module).pack(side=tk.LEFT) - color_btn = ttk.Button(frame, text="选择颜色", command=lambda m=module: self.choose_module_color(m)) - color_btn.pack(side=tk.RIGHT) - # 显示当前颜色 - color = self.formatter.module_colors.get(module, "black") - color_label = ttk.Label(frame, text="■", foreground=color) - color_label.pack(side=tk.RIGHT, padx=5) - - # 更新画布滚动区域 - inner_frame.update_idletasks() - canvas.config(scrollregion=canvas.bbox("all")) - - 
# 添加确定按钮 - ttk.Button(color_window, text="确定", command=color_window.destroy).pack(pady=5) - - def choose_color(self, level): - """选择日志级别颜色""" - color = colorchooser.askcolor(color=self.formatter.level_colors[level])[1] - if color: - self.formatter.level_colors[level] = color - self.custom_level_colors[level] = color # 保存到自定义颜色 - self.configure_text_tags() - self.save_viewer_config() # 自动保存配置 - self.filter_logs() - - def choose_module_color(self, module): - """选择模块颜色""" - color = colorchooser.askcolor(color=self.formatter.module_colors.get(module, "black"))[1] - if color: - self.formatter.module_colors[module] = color - self.custom_module_colors[module] = color # 保存到自定义颜色 - self.configure_text_tags() - self.save_viewer_config() # 自动保存配置 - self.filter_logs() - - def update_module_list(self): - """更新模块列表""" - if self.current_log_file.exists(): - with open(self.current_log_file, "r", encoding="utf-8") as f: - for line in f: - try: - log_entry = json.loads(line) - if "logger_name" in log_entry: - self.modules.add(log_entry["logger_name"]) - except json.JSONDecodeError: - continue - - # 清空现有选项 - for widget in self.module_inner_frame.winfo_children(): - widget.destroy() - - # 计算总模块数(包括"全部") - total_modules = len(self.modules) + 1 - max_cols = min(4, max(2, total_modules)) # 减少最大列数,避免超出边界 - - # 配置网格列权重,让每列平均分配空间 - for i in range(max_cols): - self.module_inner_frame.grid_columnconfigure(i, weight=1, uniform="module_col") - - # 创建一个多行布局 - current_row = 0 - current_col = 0 - - # 添加"全部"选项 - all_frame = ttk.Frame(self.module_inner_frame) - all_frame.grid(row=current_row, column=current_col, padx=3, pady=2, sticky="ew") - - all_var = tk.BooleanVar(value="全部" in self.selected_modules) - all_check = ttk.Checkbutton( - all_frame, text="全部", variable=all_var, command=lambda: self.toggle_module("全部", all_var) - ) - all_check.pack(side=tk.LEFT) - - # 使用颜色标签替代按钮 - all_color = self.formatter.module_colors.get("全部", "black") - all_color_label = ttk.Label(all_frame, text="■", 
foreground=all_color, width=2, cursor="hand2") - all_color_label.pack(side=tk.LEFT, padx=2) - all_color_label.bind("", lambda e: self.choose_module_color("全部")) - - current_col += 1 - - # 添加其他模块选项 - for module in sorted(self.modules): - if current_col >= max_cols: - current_row += 1 - current_col = 0 - - frame = ttk.Frame(self.module_inner_frame) - frame.grid(row=current_row, column=current_col, padx=3, pady=2, sticky="ew") - - var = tk.BooleanVar(value=module in self.selected_modules) - - # 使用中文映射名称显示 - display_name = self.get_display_name(module) - if len(display_name) > 12: - display_name = display_name[:10] + "..." - - check = ttk.Checkbutton( - frame, text=display_name, variable=var, command=lambda m=module, v=var: self.toggle_module(m, v) - ) - check.pack(side=tk.LEFT) - - # 添加工具提示显示完整名称和英文名 - full_tooltip = f"{self.get_display_name(module)}" - if module != self.get_display_name(module): - full_tooltip += f"\n({module})" - self.create_tooltip(check, full_tooltip) - - # 使用颜色标签替代按钮 - color = self.formatter.module_colors.get(module, "black") - color_label = ttk.Label(frame, text="■", foreground=color, width=2, cursor="hand2") - color_label.pack(side=tk.LEFT, padx=2) - color_label.bind("", lambda e, m=module: self.choose_module_color(m)) - - current_col += 1 - - # 更新画布滚动区域 - self.module_inner_frame.update_idletasks() - self.module_canvas.config(scrollregion=self.module_canvas.bbox("all")) - - # 添加垂直滚动条 - if not hasattr(self, "module_scrollbar"): - self.module_scrollbar = ttk.Scrollbar( - self.module_frame, orient=tk.VERTICAL, command=self.module_canvas.yview - ) - self.module_scrollbar.pack(side=tk.RIGHT, fill=tk.Y) - self.module_canvas.config(yscrollcommand=self.module_scrollbar.set) - - def create_tooltip(self, widget, text): - """为控件创建工具提示""" - - def on_enter(event): - tooltip = tk.Toplevel() - tooltip.wm_overrideredirect(True) - tooltip.wm_geometry(f"+{event.x_root + 10}+{event.y_root + 10}") - label = ttk.Label(tooltip, text=text, background="lightyellow", 
relief="solid", borderwidth=1) - label.pack() - widget.tooltip = tooltip - - def on_leave(event): - if hasattr(widget, "tooltip"): - widget.tooltip.destroy() - del widget.tooltip - - widget.bind("", on_enter) - widget.bind("", on_leave) - - def toggle_module(self, module, var): - """切换模块选择状态""" - if module == "全部": - if var.get(): - self.selected_modules = {"全部"} - else: - self.selected_modules.clear() - else: - if var.get(): - self.selected_modules.add(module) - if "全部" in self.selected_modules: - self.selected_modules.remove("全部") - else: - self.selected_modules.discard(module) - - self.filter_logs() - - def monitor_log_file(self): - """监控日志文件变化""" - last_position = 0 - current_monitored_file = None - - while self.running: - # 检查是否需要切换监控的文件 - if current_monitored_file != self.current_log_file: - current_monitored_file = self.current_log_file - last_position = 0 # 重置位置 - - if current_monitored_file.exists(): - try: - # 使用共享读取模式,避免文件锁定 - with open(current_monitored_file, "r", encoding="utf-8", buffering=1) as f: - f.seek(last_position) - new_lines = f.readlines() - last_position = f.tell() - - for line in new_lines: - try: - log_entry = json.loads(line) - self.log_queue.put(log_entry) - self.log_cache.append(log_entry) - - # 检查是否有新模块 - if "logger_name" in log_entry: - logger_name = log_entry["logger_name"] - if logger_name not in self.modules: - self.modules.add(logger_name) - # 在主线程中更新模块列表UI - self.root.after(0, self.update_module_list) - - except json.JSONDecodeError: - continue - except (FileNotFoundError, PermissionError) as e: - # 文件被占用或不存在时,等待更长时间 - print(f"日志文件访问受限: {e}") - time.sleep(1) - continue - except Exception as e: - print(f"读取日志文件时出错: {e}") - - time.sleep(0.1) - - def update_logs(self): - """更新日志显示""" - while self.running: - try: - log_entry = self.log_queue.get(timeout=0.1) - self.process_log_entry(log_entry) - except queue.Empty: - continue - - def process_log_entry(self, log_entry): - """处理日志条目""" - # 检查过滤条件 - if not 
self.should_show_log(log_entry): - return - - # 使用格式化器格式化日志 - parts, tags = self.formatter.format_log_entry(log_entry) - - # 在主线程中更新UI - self.root.after(0, lambda: self.add_formatted_log_line(parts, tags, log_entry)) - - def add_formatted_log_line(self, parts, tags, log_entry): - """添加格式化的日志行到文本框""" - # 控制最大行数 - max_lines = self.viewer_config.get("max_lines", 1000) - current_lines = int(self.log_text.index("end-1c").split(".")[0]) - - if current_lines > max_lines: - # 删除前面的行 - lines_to_delete = current_lines - max_lines + 100 # 一次删除多一些,减少频繁操作 - self.log_text.delete(1.0, f"{lines_to_delete}.0") - - # 插入格式化的文本 - for i, part in enumerate(parts): - if i < len(tags): - tag = tags[i] - # 根据内容类型选择合适的标签 - if tag.startswith("level_"): - if self.formatter.enable_level_colors: - self.log_text.insert(tk.END, part, tag) - else: - self.log_text.insert(tk.END, part, "level") - elif tag.startswith("module_"): - if self.formatter.enable_module_colors: - self.log_text.insert(tk.END, part, tag) - else: - self.log_text.insert(tk.END, part, "module") - else: - self.log_text.insert(tk.END, part, tag) - else: - self.log_text.insert(tk.END, part) - - # 在部分之间添加空格(除了最后一个) - if i < len(parts) - 1: - self.log_text.insert(tk.END, " ") - - self.log_text.insert(tk.END, "\n") - - # 自动滚动 - if self.viewer_config.get("auto_scroll", True): - if self.log_text.yview()[1] >= 0.99: - self.log_text.see(tk.END) - - def should_show_log(self, log_entry): - """检查日志是否应该显示""" - # 检查模块过滤 - if self.selected_modules: - if "全部" not in self.selected_modules: - if log_entry.get("logger_name") not in self.selected_modules: - return False - - # 检查级别过滤 - if self.level_var.get() != "全部": - if log_entry.get("level") != self.level_var.get(): - return False - - # 检查搜索过滤 - search_text = self.search_var.get().lower() - if search_text: - event = str(log_entry.get("event", "")).lower() - logger_name = str(log_entry.get("logger_name", "")).lower() - if search_text not in event and search_text not in logger_name: - return False - 
- return True - - def filter_logs(self, *args): - """过滤日志""" - # 保存当前滚动位置 - scroll_position = self.log_text.yview() - - # 清空显示 - self.log_text.delete(1.0, tk.END) - - # 重新显示所有符合条件的日志 - for log_entry in self.log_cache: - if self.should_show_log(log_entry): - parts, tags = self.formatter.format_log_entry(log_entry) - self.add_formatted_log_line(parts, tags, log_entry) - - # 恢复滚动位置(如果不是自动滚动模式) - if not self.viewer_config.get("auto_scroll", True): - self.log_text.yview_moveto(scroll_position[0]) - - def get_display_name(self, module_name): - """获取模块的显示名称""" - return self.module_name_mapping.get(module_name, module_name) - - def edit_module_mapping(self): - """编辑模块映射""" - mapping_window = tk.Toplevel(self.root) - mapping_window.title("编辑模块映射") - mapping_window.geometry("500x600") - - # 创建滚动框架 - frame = ttk.Frame(mapping_window) - frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) - - # 创建滚动条 - scrollbar = ttk.Scrollbar(frame) - scrollbar.pack(side=tk.RIGHT, fill=tk.Y) - - # 创建映射编辑列表 - canvas = tk.Canvas(frame, yscrollcommand=scrollbar.set) - canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) - scrollbar.config(command=canvas.yview) - - # 创建内部框架 - inner_frame = ttk.Frame(canvas) - canvas.create_window((0, 0), window=inner_frame, anchor="nw") - - # 添加标题 - ttk.Label(inner_frame, text="模块映射编辑", font=("", 12, "bold")).pack(anchor="w", padx=5, pady=5) - ttk.Label(inner_frame, text="英文名 -> 中文名", font=("", 10)).pack(anchor="w", padx=5, pady=2) - - # 映射编辑字典 - mapping_vars = {} - - # 添加现有模块的映射编辑 - all_modules = sorted(self.modules) - for module in all_modules: - frame_row = ttk.Frame(inner_frame) - frame_row.pack(fill=tk.X, padx=5, pady=2) - - ttk.Label(frame_row, text=module, width=20).pack(side=tk.LEFT, padx=5) - ttk.Label(frame_row, text="->").pack(side=tk.LEFT, padx=5) - - var = tk.StringVar(value=self.module_name_mapping.get(module, module)) - mapping_vars[module] = var - entry = ttk.Entry(frame_row, textvariable=var, width=25) - entry.pack(side=tk.LEFT, padx=5) - - # 
更新画布滚动区域 - inner_frame.update_idletasks() - canvas.config(scrollregion=canvas.bbox("all")) - - def save_mappings(): - # 更新映射 - for module, var in mapping_vars.items(): - new_name = var.get().strip() - if new_name and new_name != module: - self.module_name_mapping[module] = new_name - elif module in self.module_name_mapping and not new_name: - del self.module_name_mapping[module] - - # 保存到文件 - self.save_module_mapping() - # 更新模块列表显示 - self.update_module_list() - mapping_window.destroy() - - # 添加按钮 - button_frame = ttk.Frame(mapping_window) - button_frame.pack(fill=tk.X, padx=5, pady=5) - ttk.Button(button_frame, text="保存", command=save_mappings).pack(side=tk.RIGHT, padx=5) - ttk.Button(button_frame, text="取消", command=mapping_window.destroy).pack(side=tk.RIGHT, padx=5) - - def select_log_file(self): - """选择日志文件""" - filename = filedialog.askopenfilename( - title="选择日志文件", - filetypes=[("JSONL日志文件", "*.jsonl"), ("所有文件", "*.*")], - initialdir="logs" if Path("logs").exists() else ".", - ) - if filename: - new_file = Path(filename) - if new_file != self.current_log_file: - self.current_log_file = new_file - self.current_file_var.set(str(self.current_log_file)) - self.reload_log_file() - - def refresh_log_file(self): - """刷新日志文件""" - self.reload_log_file() - - def reload_log_file(self): - """重新加载日志文件""" - # 清空当前缓存和显示 - self.log_cache.clear() - self.modules.clear() - self.selected_modules.clear() - self.log_text.delete(1.0, tk.END) - - # 清空日志队列 - while not self.log_queue.empty(): - try: - self.log_queue.get_nowait() - except queue.Empty: - break - - # 重新读取整个文件 - if self.current_log_file.exists(): - try: - with open(self.current_log_file, "r", encoding="utf-8") as f: - for line in f: - try: - log_entry = json.loads(line) - self.log_cache.append(log_entry) - - # 收集模块信息 - if "logger_name" in log_entry: - self.modules.add(log_entry["logger_name"]) - - except json.JSONDecodeError: - continue - except Exception as e: - messagebox.showerror("错误", f"读取日志文件失败: {e}") - return - - # 
更新模块列表UI - self.update_module_list() - - # 过滤并显示日志 - self.filter_logs() - - # 更新窗口标题 - self.update_window_title() - - def update_window_title(self): - """更新窗口标题""" - filename = self.current_log_file.name - self.root.title(f"MaiBot日志查看器 - {filename}") - - -def main(): - root = tk.Tk() - LogViewer(root) - root.mainloop() - - -if __name__ == "__main__": - main() diff --git a/scripts/log_viewer_optimized.py b/scripts/log_viewer_optimized.py index 3a96e4aac..8f19fb6cf 100644 --- a/scripts/log_viewer_optimized.py +++ b/scripts/log_viewer_optimized.py @@ -1,5 +1,5 @@ import tkinter as tk -from tkinter import ttk, messagebox, filedialog +from tkinter import ttk, messagebox, filedialog, colorchooser import json from pathlib import Path import threading @@ -206,6 +206,23 @@ class LogFormatter: parts.append(str(event)) tags.append("message") + # 处理其他字段 + extras = [] + for key, value in log_entry.items(): + if key not in ("timestamp", "level", "logger_name", "event"): + if isinstance(value, (dict, list)): + try: + value_str = json.dumps(value, ensure_ascii=False, indent=None) + except (TypeError, ValueError): + value_str = str(value) + else: + value_str = str(value) + extras.append(f"{key}={value_str}") + + if extras: + parts.append(" ".join(extras)) + tags.append("extras") + return parts, tags def format_timestamp(self, timestamp): @@ -287,6 +304,7 @@ class VirtualLogDisplay: self.text_widget.tag_configure("level", foreground="#808080") self.text_widget.tag_configure("module", foreground="#808080") self.text_widget.tag_configure("message", foreground="#ffffff") + self.text_widget.tag_configure("extras", foreground="#808080") # 日志级别颜色标签 for level, color in self.formatter.level_colors.items(): @@ -449,7 +467,7 @@ class LogViewer: self.load_config() # 初始化日志格式化器 - self.formatter = LogFormatter(self.log_config, {}, {}) + self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors) # 初始化日志文件路径 self.current_log_file = Path("logs/app.log.jsonl") 
@@ -467,6 +485,9 @@ class LogViewer: self.main_frame = ttk.Frame(root) self.main_frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) + # 创建菜单栏 + self.create_menu() + # 创建控制面板 self.create_control_panel() @@ -477,12 +498,30 @@ class LogViewer: # 模块名映射 self.module_name_mapping = { "api": "API接口", + "async_task_manager": "异步任务管理器", + "background_tasks": "后台任务", + "base_tool": "基础工具", + "chat_stream": "聊天流", + "component_registry": "组件注册器", "config": "配置", - "chat": "聊天", - "plugin": "插件", + "database_model": "数据库模型", + "emoji": "表情", + "heartflow": "心流", + "local_storage": "本地存储", + "lpmm": "LPMM", + "maibot_statistic": "MaiBot统计", + "main_message": "主消息", "main": "主程序", + "memory": "内存", + "mood": "情绪", + "plugin_manager": "插件管理器", + "remote": "远程", + "willing": "意愿", } + # 加载自定义映射 + self.load_module_mapping() + # 选中的模块集合 self.selected_modules = set() self.modules = set() @@ -491,19 +530,35 @@ class LogViewer: self.level_combo.bind("<>", self.filter_logs) self.search_var.trace("w", self.filter_logs) + # 绑定快捷键 + self.root.bind("", lambda e: self.select_log_file()) + self.root.bind("", lambda e: self.refresh_log_file()) + self.root.bind("", lambda e: self.export_logs()) + # 初始加载文件 if self.current_log_file.exists(): self.load_log_file_async() def load_config(self): """加载配置文件""" + # 默认配置 self.default_config = { - "log": {"date_style": "m-d H:i:s", "log_level_style": "lite", "color_text": "full"}, + "log": {"date_style": "m-d H:i:s", "log_level_style": "lite", "color_text": "full", "log_level": "INFO"}, + "viewer": { + "theme": "dark", + "font_size": 10, + "max_lines": 1000, + "auto_scroll": True, + "show_milliseconds": False, + "window": {"width": 1200, "height": 800, "remember_position": True}, + }, } - self.log_config = self.default_config["log"].copy() - + # 从bot_config.toml加载日志配置 config_path = Path("config/bot_config.toml") + self.log_config = self.default_config["log"].copy() + self.viewer_config = self.default_config["viewer"].copy() + try: if 
config_path.exists(): with open(config_path, "r", encoding="utf-8") as f: @@ -511,7 +566,377 @@ class LogViewer: if "log" in bot_config: self.log_config.update(bot_config["log"]) except Exception as e: - print(f"加载配置失败: {e}") + print(f"加载bot配置失败: {e}") + + # 从viewer配置文件加载查看器配置 + viewer_config_path = Path("config/log_viewer_config.toml") + self.custom_module_colors = {} + self.custom_level_colors = {} + + try: + if viewer_config_path.exists(): + with open(viewer_config_path, "r", encoding="utf-8") as f: + viewer_config = toml.load(f) + if "viewer" in viewer_config: + self.viewer_config.update(viewer_config["viewer"]) + + # 加载自定义模块颜色 + if "module_colors" in viewer_config["viewer"]: + self.custom_module_colors = viewer_config["viewer"]["module_colors"] + + # 加载自定义级别颜色 + if "level_colors" in viewer_config["viewer"]: + self.custom_level_colors = viewer_config["viewer"]["level_colors"] + + if "log" in viewer_config: + self.log_config.update(viewer_config["log"]) + except Exception as e: + print(f"加载查看器配置失败: {e}") + + # 应用窗口配置 + window_config = self.viewer_config.get("window", {}) + window_width = window_config.get("width", 1200) + window_height = window_config.get("height", 800) + self.root.geometry(f"{window_width}x{window_height}") + + def save_viewer_config(self): + """保存查看器配置""" + # 准备完整的配置数据 + viewer_config_copy = self.viewer_config.copy() + + # 保存自定义颜色(只保存与默认值不同的颜色) + if self.custom_module_colors: + viewer_config_copy["module_colors"] = self.custom_module_colors + if self.custom_level_colors: + viewer_config_copy["level_colors"] = self.custom_level_colors + + config_data = {"log": self.log_config, "viewer": viewer_config_copy} + + config_path = Path("config/log_viewer_config.toml") + config_path.parent.mkdir(exist_ok=True) + + try: + with open(config_path, "w", encoding="utf-8") as f: + toml.dump(config_data, f) + except Exception as e: + print(f"保存查看器配置失败: {e}") + + def create_menu(self): + """创建菜单栏""" + menubar = tk.Menu(self.root) + 
self.root.config(menu=menubar) + + # 配置菜单 + config_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="配置", menu=config_menu) + config_menu.add_command(label="日志格式设置", command=self.show_format_settings) + config_menu.add_command(label="颜色设置", command=self.show_color_settings) + config_menu.add_command(label="查看器设置", command=self.show_viewer_settings) + config_menu.add_separator() + config_menu.add_command(label="重新加载配置", command=self.reload_config) + + # 文件菜单 + file_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="文件", menu=file_menu) + file_menu.add_command(label="选择日志文件", command=self.select_log_file, accelerator="Ctrl+O") + file_menu.add_command(label="刷新当前文件", command=self.refresh_log_file, accelerator="F5") + file_menu.add_separator() + file_menu.add_command(label="导出当前日志", command=self.export_logs, accelerator="Ctrl+S") + + # 工具菜单 + tools_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="工具", menu=tools_menu) + tools_menu.add_command(label="清空日志显示", command=self.clear_log_display) + + def show_format_settings(self): + """显示格式设置窗口""" + format_window = tk.Toplevel(self.root) + format_window.title("日志格式设置") + format_window.geometry("400x300") + + frame = ttk.Frame(format_window) + frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # 日期格式 + ttk.Label(frame, text="日期格式:").pack(anchor="w", pady=2) + date_style_var = tk.StringVar(value=self.log_config.get("date_style", "m-d H:i:s")) + date_entry = ttk.Entry(frame, textvariable=date_style_var, width=30) + date_entry.pack(anchor="w", pady=2) + ttk.Label(frame, text="格式说明: Y=年份, m=月份, d=日期, H=小时, i=分钟, s=秒", font=("", 8)).pack( + anchor="w", pady=2 + ) + + # 日志级别样式 + ttk.Label(frame, text="日志级别样式:").pack(anchor="w", pady=(10, 2)) + level_style_var = tk.StringVar(value=self.log_config.get("log_level_style", "lite")) + level_frame = ttk.Frame(frame) + level_frame.pack(anchor="w", pady=2) + + ttk.Radiobutton(level_frame, text="简洁(lite)", variable=level_style_var, 
value="lite").pack( + side="left", padx=(0, 10) + ) + ttk.Radiobutton(level_frame, text="紧凑(compact)", variable=level_style_var, value="compact").pack( + side="left", padx=(0, 10) + ) + ttk.Radiobutton(level_frame, text="完整(full)", variable=level_style_var, value="full").pack( + side="left", padx=(0, 10) + ) + + # 颜色文本设置 + ttk.Label(frame, text="文本颜色设置:").pack(anchor="w", pady=(10, 2)) + color_text_var = tk.StringVar(value=self.log_config.get("color_text", "full")) + color_frame = ttk.Frame(frame) + color_frame.pack(anchor="w", pady=2) + + ttk.Radiobutton(color_frame, text="无颜色(none)", variable=color_text_var, value="none").pack( + side="left", padx=(0, 10) + ) + ttk.Radiobutton(color_frame, text="仅标题(title)", variable=color_text_var, value="title").pack( + side="left", padx=(0, 10) + ) + ttk.Radiobutton(color_frame, text="全部(full)", variable=color_text_var, value="full").pack( + side="left", padx=(0, 10) + ) + + # 按钮 + button_frame = ttk.Frame(frame) + button_frame.pack(fill="x", pady=(20, 0)) + + def apply_format(): + self.log_config["date_style"] = date_style_var.get() + self.log_config["log_level_style"] = level_style_var.get() + self.log_config["color_text"] = color_text_var.get() + + # 重新初始化格式化器 + self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors) + self.log_display.formatter = self.formatter + self.log_display.configure_text_tags() + + # 保存配置 + self.save_viewer_config() + + # 重新过滤日志以应用新格式 + self.filter_logs() + + format_window.destroy() + + ttk.Button(button_frame, text="应用", command=apply_format).pack(side="right", padx=(5, 0)) + ttk.Button(button_frame, text="取消", command=format_window.destroy).pack(side="right") + + def show_viewer_settings(self): + """显示查看器设置窗口""" + viewer_window = tk.Toplevel(self.root) + viewer_window.title("查看器设置") + viewer_window.geometry("350x250") + + frame = ttk.Frame(viewer_window) + frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # 主题设置 + ttk.Label(frame, 
text="主题:").pack(anchor="w", pady=2) + theme_var = tk.StringVar(value=self.viewer_config.get("theme", "dark")) + theme_frame = ttk.Frame(frame) + theme_frame.pack(anchor="w", pady=2) + ttk.Radiobutton(theme_frame, text="深色", variable=theme_var, value="dark").pack(side="left", padx=(0, 10)) + ttk.Radiobutton(theme_frame, text="浅色", variable=theme_var, value="light").pack(side="left") + + # 字体大小 + ttk.Label(frame, text="字体大小:").pack(anchor="w", pady=(10, 2)) + font_size_var = tk.IntVar(value=self.viewer_config.get("font_size", 10)) + font_size_spin = ttk.Spinbox(frame, from_=8, to=20, textvariable=font_size_var, width=10) + font_size_spin.pack(anchor="w", pady=2) + + # 最大行数 + ttk.Label(frame, text="最大显示行数:").pack(anchor="w", pady=(10, 2)) + max_lines_var = tk.IntVar(value=self.viewer_config.get("max_lines", 1000)) + max_lines_spin = ttk.Spinbox(frame, from_=100, to=10000, increment=100, textvariable=max_lines_var, width=10) + max_lines_spin.pack(anchor="w", pady=2) + + # 自动滚动 + auto_scroll_var = tk.BooleanVar(value=self.viewer_config.get("auto_scroll", True)) + ttk.Checkbutton(frame, text="自动滚动到底部", variable=auto_scroll_var).pack(anchor="w", pady=(10, 2)) + + # 按钮 + button_frame = ttk.Frame(frame) + button_frame.pack(fill="x", pady=(20, 0)) + + def apply_viewer_settings(): + self.viewer_config["theme"] = theme_var.get() + self.viewer_config["font_size"] = font_size_var.get() + self.viewer_config["max_lines"] = max_lines_var.get() + self.viewer_config["auto_scroll"] = auto_scroll_var.get() + + # 应用主题 + self.apply_theme() + + # 保存配置 + self.save_viewer_config() + + viewer_window.destroy() + + ttk.Button(button_frame, text="应用", command=apply_viewer_settings).pack(side="right", padx=(5, 0)) + ttk.Button(button_frame, text="取消", command=viewer_window.destroy).pack(side="right") + + def apply_theme(self): + """应用主题设置""" + theme = self.viewer_config.get("theme", "dark") + font_size = self.viewer_config.get("font_size", 10) + + # 更新虚拟显示组件的主题 + if theme == "dark": + bg_color 
= "#1e1e1e" + fg_color = "#ffffff" + select_bg = "#404040" + else: + bg_color = "#ffffff" + fg_color = "#000000" + select_bg = "#c0c0c0" + + self.log_display.text_widget.config( + background=bg_color, foreground=fg_color, selectbackground=select_bg, font=("Consolas", font_size) + ) + + # 重新配置标签样式 + self.log_display.configure_text_tags() + + def reload_config(self): + """重新加载配置""" + self.load_config() + self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors) + self.log_display.formatter = self.formatter + self.log_display.configure_text_tags() + self.apply_theme() + self.filter_logs() + + def clear_log_display(self): + """清空日志显示""" + self.log_display.text_widget.delete(1.0, tk.END) + + def export_logs(self): + """导出当前显示的日志""" + filename = filedialog.asksaveasfilename( + defaultextension=".txt", filetypes=[("文本文件", "*.txt"), ("所有文件", "*.*")] + ) + if filename: + try: + # 获取当前显示的所有日志条目 + if self.log_index: + filtered_count = self.log_index.get_filtered_count() + log_lines = [] + for i in range(filtered_count): + log_entry = self.log_index.get_entry_at_filtered_position(i) + if log_entry: + parts, tags = self.formatter.format_log_entry(log_entry) + line_text = " ".join(parts) + log_lines.append(line_text) + + with open(filename, "w", encoding="utf-8") as f: + f.write("\n".join(log_lines)) + messagebox.showinfo("导出成功", f"日志已导出到: {filename}") + else: + messagebox.showwarning("导出失败", "没有日志可导出") + except Exception as e: + messagebox.showerror("导出失败", f"导出日志时出错: {e}") + + def load_module_mapping(self): + """加载自定义模块映射""" + mapping_file = Path("config/module_mapping.json") + if mapping_file.exists(): + try: + with open(mapping_file, "r", encoding="utf-8") as f: + custom_mapping = json.load(f) + self.module_name_mapping.update(custom_mapping) + except Exception as e: + print(f"加载模块映射失败: {e}") + + def save_module_mapping(self): + """保存自定义模块映射""" + mapping_file = Path("config/module_mapping.json") + 
mapping_file.parent.mkdir(exist_ok=True) + try: + with open(mapping_file, "w", encoding="utf-8") as f: + json.dump(self.module_name_mapping, f, ensure_ascii=False, indent=2) + except Exception as e: + print(f"保存模块映射失败: {e}") + + def show_color_settings(self): + """显示颜色设置窗口""" + color_window = tk.Toplevel(self.root) + color_window.title("颜色设置") + color_window.geometry("300x400") + + # 创建滚动框架 + frame = ttk.Frame(color_window) + frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) + + # 创建滚动条 + scrollbar = ttk.Scrollbar(frame) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # 创建颜色设置列表 + canvas = tk.Canvas(frame, yscrollcommand=scrollbar.set) + canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.config(command=canvas.yview) + + # 创建内部框架 + inner_frame = ttk.Frame(canvas) + canvas.create_window((0, 0), window=inner_frame, anchor="nw") + + # 添加日志级别颜色设置 + ttk.Label(inner_frame, text="日志级别颜色", font=("", 10, "bold")).pack(anchor="w", padx=5, pady=5) + for level in ["info", "warning", "error"]: + frame = ttk.Frame(inner_frame) + frame.pack(fill=tk.X, padx=5, pady=2) + ttk.Label(frame, text=level).pack(side=tk.LEFT) + color_btn = ttk.Button( + frame, text="选择颜色", command=lambda level_name=level: self.choose_color(level_name) + ) + color_btn.pack(side=tk.RIGHT) + # 显示当前颜色 + color_label = ttk.Label(frame, text="■", foreground=self.formatter.level_colors[level]) + color_label.pack(side=tk.RIGHT, padx=5) + + # 添加模块颜色设置 + ttk.Label(inner_frame, text="\n模块颜色", font=("", 10, "bold")).pack(anchor="w", padx=5, pady=5) + for module in sorted(self.modules): + frame = ttk.Frame(inner_frame) + frame.pack(fill=tk.X, padx=5, pady=2) + ttk.Label(frame, text=module).pack(side=tk.LEFT) + color_btn = ttk.Button(frame, text="选择颜色", command=lambda m=module: self.choose_module_color(m)) + color_btn.pack(side=tk.RIGHT) + # 显示当前颜色 + color = self.formatter.module_colors.get(module, "black") + color_label = ttk.Label(frame, text="■", foreground=color) + color_label.pack(side=tk.RIGHT, 
padx=5) + + # 更新画布滚动区域 + inner_frame.update_idletasks() + canvas.config(scrollregion=canvas.bbox("all")) + + # 添加确定按钮 + ttk.Button(color_window, text="确定", command=color_window.destroy).pack(pady=5) + + def choose_color(self, level): + """选择日志级别颜色""" + color = colorchooser.askcolor(color=self.formatter.level_colors[level])[1] + if color: + self.formatter.level_colors[level] = color + self.custom_level_colors[level] = color # 保存到自定义颜色 + self.log_display.formatter = self.formatter + self.log_display.configure_text_tags() + self.save_viewer_config() # 自动保存配置 + self.filter_logs() + + def choose_module_color(self, module): + """选择模块颜色""" + color = colorchooser.askcolor(color=self.formatter.module_colors.get(module, "black"))[1] + if color: + self.formatter.module_colors[module] = color + self.custom_module_colors[module] = color # 保存到自定义颜色 + self.log_display.formatter = self.formatter + self.log_display.configure_text_tags() + self.save_viewer_config() # 自动保存配置 + self.filter_logs() def create_control_panel(self): """创建控制面板""" @@ -549,30 +974,43 @@ class LogViewer: side=tk.LEFT, padx=2 ) - # 过滤控制框架 - filter_frame = ttk.Frame(self.control_frame) - filter_frame.pack(fill=tk.X, padx=5) + # 模块选择框架 + self.module_frame = ttk.LabelFrame(self.control_frame, text="模块") + self.module_frame.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=5) + + # 创建模块选择滚动区域 + self.module_canvas = tk.Canvas(self.module_frame, height=80) + self.module_canvas.pack(side=tk.LEFT, fill=tk.X, expand=True) + + # 创建模块选择内部框架 + self.module_inner_frame = ttk.Frame(self.module_canvas) + self.module_canvas.create_window((0, 0), window=self.module_inner_frame, anchor="nw") + + # 创建右侧控制区域(级别和搜索) + self.right_control_frame = ttk.Frame(self.control_frame) + self.right_control_frame.pack(side=tk.RIGHT, padx=5) + + # 映射编辑按钮 + mapping_btn = ttk.Button(self.right_control_frame, text="模块映射", command=self.edit_module_mapping) + mapping_btn.pack(side=tk.TOP, fill=tk.X, pady=1) # 日志级别选择 - ttk.Label(filter_frame, 
text="级别:").pack(side=tk.LEFT, padx=2) + level_frame = ttk.Frame(self.right_control_frame) + level_frame.pack(side=tk.TOP, fill=tk.X, pady=1) + ttk.Label(level_frame, text="级别:").pack(side=tk.LEFT, padx=2) self.level_var = tk.StringVar(value="全部") - self.level_combo = ttk.Combobox(filter_frame, textvariable=self.level_var, width=8) + self.level_combo = ttk.Combobox(level_frame, textvariable=self.level_var, width=8) self.level_combo["values"] = ["全部", "debug", "info", "warning", "error", "critical"] self.level_combo.pack(side=tk.LEFT, padx=2) # 搜索框 - ttk.Label(filter_frame, text="搜索:").pack(side=tk.LEFT, padx=(20, 2)) + search_frame = ttk.Frame(self.right_control_frame) + search_frame.pack(side=tk.TOP, fill=tk.X, pady=1) + ttk.Label(search_frame, text="搜索:").pack(side=tk.LEFT, padx=2) self.search_var = tk.StringVar() - self.search_entry = ttk.Entry(filter_frame, textvariable=self.search_var, width=20) + self.search_entry = ttk.Entry(search_frame, textvariable=self.search_var, width=15) self.search_entry.pack(side=tk.LEFT, padx=2) - # 模块选择 - ttk.Label(filter_frame, text="模块:").pack(side=tk.LEFT, padx=(20, 2)) - self.module_var = tk.StringVar(value="全部") - self.module_combo = ttk.Combobox(filter_frame, textvariable=self.module_var, width=15) - self.module_combo.pack(side=tk.LEFT, padx=2) - self.module_combo.bind("<>", self.on_module_selected) - def on_file_loaded(self, log_index, error): """文件加载完成回调""" self.progress_bar.pack_forget() @@ -590,6 +1028,7 @@ class LogViewer: self.status_var.set(f"已加载 {log_index.total_entries} 条日志") # 更新模块列表 + self.modules = set(log_index.module_index.keys()) self.update_module_list() # 应用过滤并显示 @@ -623,22 +1062,11 @@ class LogViewer: # 清空当前数据 self.log_index = LogIndex() - self.modules.clear() self.selected_modules.clear() - self.module_var.set("全部") # 开始异步加载 self.async_loader.load_file_async(str(self.current_log_file), self.on_loading_progress) - def on_module_selected(self, event=None): - """模块选择事件""" - module = self.module_var.get() - if 
module == "全部": - self.selected_modules = {"全部"} - else: - self.selected_modules = {module} - self.filter_logs() - def filter_logs(self, *args): """过滤日志""" if not self.log_index: @@ -743,7 +1171,7 @@ class LogViewer: def read_new_logs(self, from_position): """读取新的日志条目并返回它们""" new_entries = [] - new_modules_found = False + new_modules = set() # 收集新发现的模块 with open(self.current_log_file, "r", encoding="utf-8") as f: f.seek(from_position) line_count = self.log_index.total_entries @@ -756,14 +1184,20 @@ class LogViewer: logger_name = log_entry.get("logger_name", "") if logger_name and logger_name not in self.modules: - self.modules.add(logger_name) - new_modules_found = True + new_modules.add(logger_name) line_count += 1 except json.JSONDecodeError: continue - if new_modules_found: - self.root.after(0, self.update_module_list) + + # 如果发现了新模块,在主线程中更新模块集合 + if new_modules: + def update_modules(): + self.modules.update(new_modules) + self.update_module_list() + + self.root.after(0, update_modules) + return new_entries def append_new_logs(self, new_entries): @@ -791,15 +1225,196 @@ class LogViewer: self.status_var.set(f"显示 {total_count} 条日志") def update_module_list(self): - """更新模块下拉列表""" - current_selection = self.module_var.get() - self.modules = set(self.log_index.module_index.keys()) - module_values = ["全部"] + sorted(list(self.modules)) - self.module_combo["values"] = module_values - if current_selection in module_values: - self.module_var.set(current_selection) + """更新模块列表""" + # 清空现有选项 + for widget in self.module_inner_frame.winfo_children(): + widget.destroy() + + # 计算总模块数(包括"全部") + total_modules = len(self.modules) + 1 + max_cols = min(4, max(2, total_modules)) # 减少最大列数,避免超出边界 + + # 配置网格列权重,让每列平均分配空间 + for i in range(max_cols): + self.module_inner_frame.grid_columnconfigure(i, weight=1, uniform="module_col") + + # 创建一个多行布局 + current_row = 0 + current_col = 0 + + # 添加"全部"选项 + all_frame = ttk.Frame(self.module_inner_frame) + all_frame.grid(row=current_row, 
column=current_col, padx=3, pady=2, sticky="ew") + + all_var = tk.BooleanVar(value="全部" in self.selected_modules) + all_check = ttk.Checkbutton( + all_frame, text="全部", variable=all_var, command=lambda: self.toggle_module("全部", all_var) + ) + all_check.pack(side=tk.LEFT) + + # 使用颜色标签替代按钮 + all_color = self.formatter.module_colors.get("全部", "black") + all_color_label = ttk.Label(all_frame, text="■", foreground=all_color, width=2, cursor="hand2") + all_color_label.pack(side=tk.LEFT, padx=2) + all_color_label.bind("", lambda e: self.choose_module_color("全部")) + + current_col += 1 + + # 添加其他模块选项 + for module in sorted(self.modules): + if current_col >= max_cols: + current_row += 1 + current_col = 0 + + frame = ttk.Frame(self.module_inner_frame) + frame.grid(row=current_row, column=current_col, padx=3, pady=2, sticky="ew") + + var = tk.BooleanVar(value=module in self.selected_modules) + + # 使用中文映射名称显示 + display_name = self.get_display_name(module) + if len(display_name) > 12: + display_name = display_name[:10] + "..." 
+ + check = ttk.Checkbutton( + frame, text=display_name, variable=var, command=lambda m=module, v=var: self.toggle_module(m, v) + ) + check.pack(side=tk.LEFT) + + # 添加工具提示显示完整名称和英文名 + full_tooltip = f"{self.get_display_name(module)}" + if module != self.get_display_name(module): + full_tooltip += f"\n({module})" + self.create_tooltip(check, full_tooltip) + + # 使用颜色标签替代按钮 + color = self.formatter.module_colors.get(module, "black") + color_label = ttk.Label(frame, text="■", foreground=color, width=2, cursor="hand2") + color_label.pack(side=tk.LEFT, padx=2) + color_label.bind("", lambda e, m=module: self.choose_module_color(m)) + + current_col += 1 + + # 更新画布滚动区域 + self.module_inner_frame.update_idletasks() + self.module_canvas.config(scrollregion=self.module_canvas.bbox("all")) + + # 添加垂直滚动条 + if not hasattr(self, "module_scrollbar"): + self.module_scrollbar = ttk.Scrollbar( + self.module_frame, orient=tk.VERTICAL, command=self.module_canvas.yview + ) + self.module_scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + self.module_canvas.config(yscrollcommand=self.module_scrollbar.set) + + def create_tooltip(self, widget, text): + """为控件创建工具提示""" + + def on_enter(event): + tooltip = tk.Toplevel() + tooltip.wm_overrideredirect(True) + tooltip.wm_geometry(f"+{event.x_root + 10}+{event.y_root + 10}") + label = ttk.Label(tooltip, text=text, background="lightyellow", relief="solid", borderwidth=1) + label.pack() + widget.tooltip = tooltip + + def on_leave(event): + if hasattr(widget, "tooltip"): + widget.tooltip.destroy() + del widget.tooltip + + widget.bind("", on_enter) + widget.bind("", on_leave) + + def toggle_module(self, module, var): + """切换模块选择状态""" + if module == "全部": + if var.get(): + self.selected_modules = {"全部"} + else: + self.selected_modules.clear() else: - self.module_var.set("全部") + if var.get(): + self.selected_modules.add(module) + if "全部" in self.selected_modules: + self.selected_modules.remove("全部") + else: + self.selected_modules.discard(module) + + 
self.filter_logs() + + def get_display_name(self, module_name): + """获取模块的显示名称""" + return self.module_name_mapping.get(module_name, module_name) + + def edit_module_mapping(self): + """编辑模块映射""" + mapping_window = tk.Toplevel(self.root) + mapping_window.title("编辑模块映射") + mapping_window.geometry("500x600") + + # 创建滚动框架 + frame = ttk.Frame(mapping_window) + frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) + + # 创建滚动条 + scrollbar = ttk.Scrollbar(frame) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # 创建映射编辑列表 + canvas = tk.Canvas(frame, yscrollcommand=scrollbar.set) + canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.config(command=canvas.yview) + + # 创建内部框架 + inner_frame = ttk.Frame(canvas) + canvas.create_window((0, 0), window=inner_frame, anchor="nw") + + # 添加标题 + ttk.Label(inner_frame, text="模块映射编辑", font=("", 12, "bold")).pack(anchor="w", padx=5, pady=5) + ttk.Label(inner_frame, text="英文名 -> 中文名", font=("", 10)).pack(anchor="w", padx=5, pady=2) + + # 映射编辑字典 + mapping_vars = {} + + # 添加现有模块的映射编辑 + all_modules = sorted(self.modules) + for module in all_modules: + frame_row = ttk.Frame(inner_frame) + frame_row.pack(fill=tk.X, padx=5, pady=2) + + ttk.Label(frame_row, text=module, width=20).pack(side=tk.LEFT, padx=5) + ttk.Label(frame_row, text="->").pack(side=tk.LEFT, padx=5) + + var = tk.StringVar(value=self.module_name_mapping.get(module, module)) + mapping_vars[module] = var + entry = ttk.Entry(frame_row, textvariable=var, width=25) + entry.pack(side=tk.LEFT, padx=5) + + # 更新画布滚动区域 + inner_frame.update_idletasks() + canvas.config(scrollregion=canvas.bbox("all")) + + def save_mappings(): + # 更新映射 + for module, var in mapping_vars.items(): + new_name = var.get().strip() + if new_name and new_name != module: + self.module_name_mapping[module] = new_name + elif module in self.module_name_mapping and not new_name: + del self.module_name_mapping[module] + + # 保存到文件 + self.save_module_mapping() + # 更新模块列表显示 + self.update_module_list() + 
mapping_window.destroy() + + # 添加按钮 + button_frame = ttk.Frame(mapping_window) + button_frame.pack(fill=tk.X, padx=5, pady=5) + ttk.Button(button_frame, text="保存", command=save_mappings).pack(side=tk.RIGHT, padx=5) + ttk.Button(button_frame, text="取消", command=mapping_window.destroy).pack(side=tk.RIGHT, padx=5) def main(): diff --git a/scripts/preview_expressions.py b/scripts/preview_expressions.py deleted file mode 100644 index 1e71120d8..000000000 --- a/scripts/preview_expressions.py +++ /dev/null @@ -1,278 +0,0 @@ -import tkinter as tk -from tkinter import ttk -import json -import os -from pathlib import Path -import networkx as nx -import matplotlib.pyplot as plt -from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg -from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.metrics.pairwise import cosine_similarity -from collections import defaultdict - - -class ExpressionViewer: - def __init__(self, root): - self.root = root - self.root.title("表达方式预览器") - self.root.geometry("1200x800") - - # 创建主框架 - self.main_frame = ttk.Frame(root) - self.main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) - - # 创建左侧控制面板 - self.control_frame = ttk.Frame(self.main_frame) - self.control_frame.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 10)) - - # 创建搜索框 - self.search_frame = ttk.Frame(self.control_frame) - self.search_frame.pack(fill=tk.X, pady=(0, 10)) - - self.search_var = tk.StringVar() - self.search_var.trace("w", self.filter_expressions) - self.search_entry = ttk.Entry(self.search_frame, textvariable=self.search_var) - self.search_entry.pack(side=tk.LEFT, fill=tk.X, expand=True) - ttk.Label(self.search_frame, text="搜索:").pack(side=tk.LEFT, padx=(0, 5)) - - # 创建文件选择下拉框 - self.file_var = tk.StringVar() - self.file_combo = ttk.Combobox(self.search_frame, textvariable=self.file_var) - self.file_combo.pack(side=tk.LEFT, padx=5) - self.file_combo.bind("<>", self.load_file) - - # 创建排序选项 - self.sort_frame = ttk.LabelFrame(self.control_frame, 
text="排序选项") - self.sort_frame.pack(fill=tk.X, pady=5) - - self.sort_var = tk.StringVar(value="count") - ttk.Radiobutton( - self.sort_frame, text="按计数排序", variable=self.sort_var, value="count", command=self.apply_sort - ).pack(anchor=tk.W) - ttk.Radiobutton( - self.sort_frame, text="按情境排序", variable=self.sort_var, value="situation", command=self.apply_sort - ).pack(anchor=tk.W) - ttk.Radiobutton( - self.sort_frame, text="按风格排序", variable=self.sort_var, value="style", command=self.apply_sort - ).pack(anchor=tk.W) - - # 创建分群选项 - self.group_frame = ttk.LabelFrame(self.control_frame, text="分群选项") - self.group_frame.pack(fill=tk.X, pady=5) - - self.group_var = tk.StringVar(value="none") - ttk.Radiobutton( - self.group_frame, text="不分群", variable=self.group_var, value="none", command=self.apply_grouping - ).pack(anchor=tk.W) - ttk.Radiobutton( - self.group_frame, text="按情境分群", variable=self.group_var, value="situation", command=self.apply_grouping - ).pack(anchor=tk.W) - ttk.Radiobutton( - self.group_frame, text="按风格分群", variable=self.group_var, value="style", command=self.apply_grouping - ).pack(anchor=tk.W) - - # 创建相似度阈值滑块 - self.similarity_frame = ttk.LabelFrame(self.control_frame, text="相似度设置") - self.similarity_frame.pack(fill=tk.X, pady=5) - - self.similarity_var = tk.DoubleVar(value=0.5) - self.similarity_scale = ttk.Scale( - self.similarity_frame, - from_=0.0, - to=1.0, - variable=self.similarity_var, - orient=tk.HORIZONTAL, - command=self.update_similarity, - ) - self.similarity_scale.pack(fill=tk.X, padx=5, pady=5) - ttk.Label(self.similarity_frame, text="相似度阈值: 0.5").pack() - - # 创建显示选项 - self.view_frame = ttk.LabelFrame(self.control_frame, text="显示选项") - self.view_frame.pack(fill=tk.X, pady=5) - - self.show_graph_var = tk.BooleanVar(value=True) - ttk.Checkbutton( - self.view_frame, text="显示关系图", variable=self.show_graph_var, command=self.toggle_graph - ).pack(anchor=tk.W) - - # 创建右侧内容区域 - self.content_frame = ttk.Frame(self.main_frame) - 
self.content_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) - - # 创建文本显示区域 - self.text_area = tk.Text(self.content_frame, wrap=tk.WORD) - self.text_area.pack(side=tk.TOP, fill=tk.BOTH, expand=True) - - # 添加滚动条 - scrollbar = ttk.Scrollbar(self.text_area, command=self.text_area.yview) - scrollbar.pack(side=tk.RIGHT, fill=tk.Y) - self.text_area.config(yscrollcommand=scrollbar.set) - - # 创建图形显示区域 - self.graph_frame = ttk.Frame(self.content_frame) - self.graph_frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True) - - # 初始化数据 - self.current_data = [] - self.graph = nx.Graph() - self.canvas = None - - # 加载文件列表 - self.load_file_list() - - def load_file_list(self): - expression_dir = Path("data/expression") - files = [] - for root, _, filenames in os.walk(expression_dir): - for filename in filenames: - if filename.endswith(".json"): - rel_path = os.path.relpath(os.path.join(root, filename), expression_dir) - files.append(rel_path) - - self.file_combo["values"] = files - if files: - self.file_combo.set(files[0]) - self.load_file(None) - - def load_file(self, event): - selected_file = self.file_var.get() - if not selected_file: - return - - file_path = os.path.join("data/expression", selected_file) - try: - with open(file_path, "r", encoding="utf-8") as f: - self.current_data = json.load(f) - - self.apply_sort() - self.update_similarity() - - except Exception as e: - self.text_area.delete(1.0, tk.END) - self.text_area.insert(tk.END, f"加载文件时出错: {str(e)}") - - def apply_sort(self): - if not self.current_data: - return - - sort_key = self.sort_var.get() - reverse = sort_key == "count" - - self.current_data.sort(key=lambda x: x.get(sort_key, ""), reverse=reverse) - self.apply_grouping() - - def apply_grouping(self): - if not self.current_data: - return - - group_key = self.group_var.get() - if group_key == "none": - self.display_data(self.current_data) - return - - grouped_data = defaultdict(list) - for item in self.current_data: - key = item.get(group_key, "未分类") - 
grouped_data[key].append(item) - - self.text_area.delete(1.0, tk.END) - for group, items in grouped_data.items(): - self.text_area.insert(tk.END, f"\n=== {group} ===\n\n") - for item in items: - self.text_area.insert(tk.END, f"情境: {item.get('situation', 'N/A')}\n") - self.text_area.insert(tk.END, f"风格: {item.get('style', 'N/A')}\n") - self.text_area.insert(tk.END, f"计数: {item.get('count', 'N/A')}\n") - self.text_area.insert(tk.END, "-" * 50 + "\n") - - def display_data(self, data): - self.text_area.delete(1.0, tk.END) - for item in data: - self.text_area.insert(tk.END, f"情境: {item.get('situation', 'N/A')}\n") - self.text_area.insert(tk.END, f"风格: {item.get('style', 'N/A')}\n") - self.text_area.insert(tk.END, f"计数: {item.get('count', 'N/A')}\n") - self.text_area.insert(tk.END, "-" * 50 + "\n") - - def update_similarity(self, *args): - if not self.current_data: - return - - threshold = self.similarity_var.get() - self.similarity_frame.winfo_children()[-1].config(text=f"相似度阈值: {threshold:.2f}") - - # 计算相似度 - texts = [f"{item['situation']} {item['style']}" for item in self.current_data] - vectorizer = TfidfVectorizer() - tfidf_matrix = vectorizer.fit_transform(texts) - similarity_matrix = cosine_similarity(tfidf_matrix) - - # 创建图 - self.graph.clear() - for i, item in enumerate(self.current_data): - self.graph.add_node(i, label=f"{item['situation']}\n{item['style']}") - - # 添加边 - for i in range(len(self.current_data)): - for j in range(i + 1, len(self.current_data)): - if similarity_matrix[i, j] > threshold: - self.graph.add_edge(i, j, weight=similarity_matrix[i, j]) - - if self.show_graph_var.get(): - self.draw_graph() - - def draw_graph(self): - if self.canvas: - self.canvas.get_tk_widget().destroy() - - fig = plt.figure(figsize=(8, 6)) - pos = nx.spring_layout(self.graph) - - # 绘制节点 - nx.draw_networkx_nodes(self.graph, pos, node_color="lightblue", node_size=1000, alpha=0.6) - - # 绘制边 - nx.draw_networkx_edges(self.graph, pos, alpha=0.4) - - # 添加标签 - labels = 
nx.get_node_attributes(self.graph, "label") - nx.draw_networkx_labels(self.graph, pos, labels, font_size=8) - - plt.title("表达方式关系图") - plt.axis("off") - - self.canvas = FigureCanvasTkAgg(fig, master=self.graph_frame) - self.canvas.draw() - self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True) - - def toggle_graph(self): - if self.show_graph_var.get(): - self.draw_graph() - else: - if self.canvas: - self.canvas.get_tk_widget().destroy() - self.canvas = None - - def filter_expressions(self, *args): - search_text = self.search_var.get().lower() - if not search_text: - self.apply_sort() - return - - filtered_data = [] - for item in self.current_data: - situation = item.get("situation", "").lower() - style = item.get("style", "").lower() - if search_text in situation or search_text in style: - filtered_data.append(item) - - self.display_data(filtered_data) - - -def main(): - root = tk.Tk() - # app = ExpressionViewer(root) - root.mainloop() - - -if __name__ == "__main__": - main() diff --git a/scripts/view_hfc_stats.py b/scripts/view_hfc_stats.py deleted file mode 100644 index 75e792e25..000000000 --- a/scripts/view_hfc_stats.py +++ /dev/null @@ -1,185 +0,0 @@ -#!/usr/bin/env python3 -""" -HFC性能统计数据查看工具 -""" - -import sys -import json -import argparse -from pathlib import Path -from typing import Dict, Any - -# 添加项目根目录到Python路径 -sys.path.insert(0, str(Path(__file__).parent.parent)) - - -def format_time(seconds: float) -> str: - """格式化时间显示""" - if seconds < 1: - return f"{seconds * 1000:.1f}毫秒" - else: - return f"{seconds:.3f}秒" - - -def display_chat_stats(chat_id: str, stats: Dict[str, Any]): - """显示单个聊天的统计数据""" - print(f"\n=== Chat ID: {chat_id} ===") - print(f"版本: {stats.get('version', 'unknown')}") - print(f"最后更新: {stats['last_updated']}") - - overall = stats["overall"] - print("\n📊 总体统计:") - print(f" 总记录数: {overall['total_records']}") - print(f" 平均总时间: {format_time(overall['avg_total_time'])}") - - print("\n⏱️ 各步骤平均时间:") - for step, avg_time in 
overall["avg_step_times"].items(): - print(f" {step}: {format_time(avg_time)}") - - print("\n🎯 按动作类型统计:") - by_action = stats["by_action"] - - # 按比例排序 - sorted_actions = sorted(by_action.items(), key=lambda x: x[1]["percentage"], reverse=True) - - for action, action_stats in sorted_actions: - print(f" 📌 {action}:") - print(f" 次数: {action_stats['count']} ({action_stats['percentage']:.1f}%)") - print(f" 平均总时间: {format_time(action_stats['avg_total_time'])}") - - if action_stats["avg_step_times"]: - print(" 步骤时间:") - for step, step_time in action_stats["avg_step_times"].items(): - print(f" {step}: {format_time(step_time)}") - - -def display_comparison(stats_data: Dict[str, Dict[str, Any]]): - """显示多个聊天的对比数据""" - if len(stats_data) < 2: - return - - print("\n=== 多聊天对比 ===") - - # 创建对比表格 - chat_ids = list(stats_data.keys()) - - print("\n📊 总体对比:") - print(f"{'Chat ID':<20} {'版本':<12} {'记录数':<8} {'平均时间':<12} {'最常见动作':<15}") - print("-" * 70) - - for chat_id in chat_ids: - stats = stats_data[chat_id] - overall = stats["overall"] - - # 找到最常见的动作 - most_common_action = max(stats["by_action"].items(), key=lambda x: x[1]["count"]) - most_common_name = most_common_action[0] - most_common_pct = most_common_action[1]["percentage"] - - version = stats.get("version", "unknown") - print( - f"{chat_id:<20} {version:<12} {overall['total_records']:<8} {format_time(overall['avg_total_time']):<12} {most_common_name}({most_common_pct:.0f}%)" - ) - - -def view_session_logs(chat_id: str = None, latest: bool = False): - """查看会话日志文件""" - log_dir = Path("log/hfc_loop") - if not log_dir.exists(): - print("❌ 日志目录不存在") - return - - if chat_id: - pattern = f"{chat_id}_*.json" - else: - pattern = "*.json" - - log_files = list(log_dir.glob(pattern)) - - if not log_files: - print(f"❌ 没有找到匹配的日志文件: {pattern}") - return - - if latest: - # 按文件修改时间排序,取最新的 - log_files.sort(key=lambda f: f.stat().st_mtime, reverse=True) - log_files = log_files[:1] - - for log_file in log_files: - print(f"\n=== 会话日志: 
{log_file.name} ===") - - try: - with open(log_file, "r", encoding="utf-8") as f: - records = json.load(f) - - if not records: - print(" 空文件") - continue - - print(f" 记录数: {len(records)}") - print(f" 时间范围: {records[0]['timestamp']} ~ {records[-1]['timestamp']}") - - # 统计动作分布 - action_counts = {} - total_time = 0 - - for record in records: - action = record["action_type"] - action_counts[action] = action_counts.get(action, 0) + 1 - total_time += record["total_time"] - - print(f" 总耗时: {format_time(total_time)}") - print(f" 平均耗时: {format_time(total_time / len(records))}") - print(f" 动作分布: {dict(action_counts)}") - - except Exception as e: - print(f" ❌ 读取文件失败: {e}") - - -def main(): - parser = argparse.ArgumentParser(description="HFC性能统计数据查看工具") - parser.add_argument("--chat-id", help="指定要查看的Chat ID") - parser.add_argument("--logs", action="store_true", help="查看会话日志文件") - parser.add_argument("--latest", action="store_true", help="只显示最新的日志文件") - parser.add_argument("--compare", action="store_true", help="显示多聊天对比") - - args = parser.parse_args() - - if args.logs: - view_session_logs(args.chat_id, args.latest) - return - - # 读取统计数据 - stats_file = Path("data/hfc/time.json") - if not stats_file.exists(): - print("❌ 统计数据文件不存在,请先运行一些HFC循环以生成数据") - return - - try: - with open(stats_file, "r", encoding="utf-8") as f: - stats_data = json.load(f) - except Exception as e: - print(f"❌ 读取统计数据失败: {e}") - return - - if not stats_data: - print("❌ 统计数据为空") - return - - if args.chat_id: - if args.chat_id in stats_data: - display_chat_stats(args.chat_id, stats_data[args.chat_id]) - else: - print(f"❌ 没有找到Chat ID '{args.chat_id}' 的数据") - print(f"可用的Chat ID: {list(stats_data.keys())}") - else: - # 显示所有聊天的统计数据 - for chat_id, stats in stats_data.items(): - display_chat_stats(chat_id, stats) - - if args.compare: - display_comparison(stats_data) - - -if __name__ == "__main__": - main() diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index 5a82e8390..41101b2dd 
100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -20,10 +20,9 @@ from src.person_info.person_info import get_person_info_manager from src.plugin_system.base.component_types import ActionInfo, ChatMode from src.plugin_system.apis import generator_api, send_api, message_api from src.chat.willing.willing_manager import get_willing_manager -from src.chat.mai_thinking.mai_think import mai_thinking_manager -from maim_message.message_base import GroupInfo,UserInfo - -ENABLE_THINKING = False +from src.mais4u.mai_think import mai_thinking_manager +from maim_message.message_base import GroupInfo +from src.mais4u.constant_s4u import ENABLE_S4U ERROR_LOOP_INFO = { "loop_plan_info": { @@ -237,12 +236,12 @@ class HeartFChatting: if if_think: factor = max(global_config.chat.focus_value, 0.1) self.energy_value *= 1.1 / factor - logger.info(f"{self.log_prefix} 麦麦进行了思考,能量值按倍数增加,当前能量值:{self.energy_value}") + logger.info(f"{self.log_prefix} 进行了思考,能量值按倍数增加,当前能量值:{self.energy_value:.1f}") else: self.energy_value += 0.1 / global_config.chat.focus_value - logger.info(f"{self.log_prefix} 麦麦没有进行思考,能量值线性增加,当前能量值:{self.energy_value}") + logger.debug(f"{self.log_prefix} 没有进行思考,能量值线性增加,当前能量值:{self.energy_value:.1f}") - logger.debug(f"{self.log_prefix} 当前能量值:{self.energy_value}") + logger.debug(f"{self.log_prefix} 当前能量值:{self.energy_value:.1f}") return True await asyncio.sleep(1) @@ -257,31 +256,29 @@ class HeartFChatting: ) person_name = await person_info_manager.get_value(person_id, "person_name") return f"{person_name}:{message_data.get('processed_plain_text')}" - + async def send_typing(self): - group_info = GroupInfo(platform = "amaidesu_default",group_id = 114514,group_name = "内心") - - chat = await get_chat_manager().get_or_create_stream( - platform = "amaidesu_default", - user_info = None, - group_info = group_info + group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心") + + chat = await 
get_chat_manager().get_or_create_stream( + platform="amaidesu_default", + user_info=None, + group_info=group_info, ) - - + await send_api.custom_to_stream( message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False ) - + async def stop_typing(self): - group_info = GroupInfo(platform = "amaidesu_default",group_id = 114514,group_name = "内心") - - chat = await get_chat_manager().get_or_create_stream( - platform = "amaidesu_default", - user_info = None, - group_info = group_info + group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心") + + chat = await get_chat_manager().get_or_create_stream( + platform="amaidesu_default", + user_info=None, + group_info=group_info, ) - - + await send_api.custom_to_stream( message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False ) @@ -296,7 +293,8 @@ class HeartFChatting: logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考[模式:{self.loop_mode}]") - await self.send_typing() + if ENABLE_S4U: + await self.send_typing() async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): loop_start_time = time.time() @@ -366,13 +364,13 @@ class HeartFChatting: # 发送回复 (不再需要传入 chat) reply_text = await self._send_response(response_set, reply_to_str, loop_start_time,message_data) - await self.stop_typing() - if ENABLE_THINKING: + + if ENABLE_S4U: + await self.stop_typing() await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text) - return True @@ -504,10 +502,9 @@ class HeartFChatting: """ interested_rate = (message_data.get("interest_value") or 0.0) * self.willing_amplifier - + self.willing_manager.setup(message_data, self.chat_stream) - - + reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id", "")) talk_frequency = -1.00 @@ -517,7 +514,7 @@ class HeartFChatting: if additional_config and "maimcore_reply_probability_gain" in 
additional_config: reply_probability += additional_config["maimcore_reply_probability_gain"] reply_probability = min(max(reply_probability, 0), 1) # 确保概率在 0-1 之间 - + talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id) reply_probability = talk_frequency * reply_probability @@ -527,9 +524,9 @@ class HeartFChatting: # 打印消息信息 mes_name = self.chat_stream.group_info.group_name if self.chat_stream.group_info else "私聊" - + # logger.info(f"[{mes_name}] 当前聊天频率: {talk_frequency:.2f},兴趣值: {interested_rate:.2f},回复概率: {reply_probability * 100:.1f}%") - + if reply_probability > 0.05: logger.info( f"[{mes_name}]" @@ -545,7 +542,6 @@ class HeartFChatting: # 意愿管理器:注销当前message信息 (无论是否回复,只要处理过就删除) self.willing_manager.delete(message_data.get("message_id", "")) return False - async def _generate_response( self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str @@ -570,7 +566,7 @@ class HeartFChatting: logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}") return None - async def _send_response(self, reply_set, reply_to, thinking_start_time,message_data): + async def _send_response(self, reply_set, reply_to, thinking_start_time, message_data): current_time = time.time() new_message_count = message_api.count_new_messages( chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time @@ -581,9 +577,14 @@ class HeartFChatting: need_reply = new_message_count >= random.randint(2, 4) - logger.info( - f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}引用回复" - ) + if need_reply: + logger.info( + f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复" + ) + else: + logger.debug( + f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,不使用引用回复" + ) reply_text = "" first_replied = False @@ -592,13 +593,27 @@ class HeartFChatting: if not first_replied: if need_reply: await send_api.text_to_stream( - text=data, stream_id=self.chat_stream.stream_id, 
reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False + text=data, + stream_id=self.chat_stream.stream_id, + reply_to=reply_to, + reply_to_platform_id=reply_to_platform_id, + typing=False, ) else: - await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=False) + await send_api.text_to_stream( + text=data, + stream_id=self.chat_stream.stream_id, + reply_to_platform_id=reply_to_platform_id, + typing=False, + ) first_replied = True else: - await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=True) + await send_api.text_to_stream( + text=data, + stream_id=self.chat_stream.stream_id, + reply_to_platform_id=reply_to_platform_id, + typing=True, + ) reply_text += data return reply_text diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index dd9f12c0d..b3c2493d3 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -836,7 +836,7 @@ class EmojiManager: return False async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]: - """获取表情包描述和情感列表 + """获取表情包描述和情感列表,优化复用已有描述 Args: image_base64: 图片的base64编码 @@ -850,18 +850,35 @@ class EmojiManager: if isinstance(image_base64, str): image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_bytes = base64.b64decode(image_base64) + image_hash = hashlib.md5(image_bytes).hexdigest() image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore - # 调用AI获取描述 - if image_format == "gif" or image_format == "GIF": - image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore - if not image_base64: - raise RuntimeError("GIF表情包转换失败") - prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" - description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, 
"jpg") + # 尝试从Images表获取已有的详细描述(可能在收到表情包时已生成) + existing_description = None + try: + from src.common.database.database_model import Images + existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji")) + if existing_image and existing_image.description: + existing_description = existing_image.description + logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...") + except Exception as e: + logger.debug(f"查询已有描述时出错: {e}") + + # 第一步:VLM视觉分析(如果没有已有描述才调用) + if existing_description: + description = existing_description + logger.info("[优化] 复用已有的详细描述,跳过VLM调用") else: - prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" - description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) + logger.info("[VLM分析] 生成新的详细描述") + if image_format == "gif" or image_format == "GIF": + image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore + if not image_base64: + raise RuntimeError("GIF表情包转换失败") + prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" + description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg") + else: + prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" + description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) # 审核表情包 if global_config.emoji.content_filtration: @@ -877,7 +894,7 @@ class EmojiManager: if content == "否": return "", [] - # 分析情感含义 + # 第二步:LLM情感分析 - 基于详细描述生成情感标签列表 emotion_prompt = f""" 请你识别这个表情包的含义和适用场景,给我简短的描述,每个描述不要超过15个字 这是一个基于这个表情包的描述:'{description}' @@ -889,12 +906,14 @@ class EmojiManager: # 处理情感列表 emotions = [e.strip() for e in emotions_text.split(",") if e.strip()] - # 根据情感标签数量随机选择喵~超过5个选3个,超过2个选2个 + # 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个 if len(emotions) > 5: emotions = random.sample(emotions, 3) elif len(emotions) > 2: emotions = random.sample(emotions, 2) + logger.info(f"[注册分析] 详细描述: {description[:50]}... 
-> 情感标签: {emotions}") + return f"[表情包:{description}]", emotions except Exception as e: diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index e02ff7311..ac41b12a3 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -2,6 +2,7 @@ import time import random import json import os +from datetime import datetime from typing import List, Dict, Optional, Any, Tuple @@ -21,6 +22,16 @@ DECAY_MIN = 0.01 # 最小衰减值 logger = get_logger("expressor") +def format_create_date(timestamp: float) -> str: + """ + 将时间戳格式化为可读的日期字符串 + """ + try: + return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") + except (ValueError, OSError): + return "未知时间" + + def init_prompt() -> None: learn_style_prompt = """ {chat_str} @@ -76,35 +87,90 @@ class ExpressionLearner: request_type="expressor.learner", ) self.llm_model = None + self._ensure_expression_directories() self._auto_migrate_json_to_db() + self._migrate_old_data_create_date() + + def _ensure_expression_directories(self): + """ + 确保表达方式相关的目录结构存在 + """ + base_dir = os.path.join("data", "expression") + directories_to_create = [ + base_dir, + os.path.join(base_dir, "learnt_style"), + os.path.join(base_dir, "learnt_grammar"), + ] + + for directory in directories_to_create: + try: + os.makedirs(directory, exist_ok=True) + logger.debug(f"确保目录存在: {directory}") + except Exception as e: + logger.error(f"创建目录失败 {directory}: {e}") def _auto_migrate_json_to_db(self): """ 自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。 迁移完成后在/data/expression/done.done写入标记文件,存在则跳过。 """ - done_flag = os.path.join("data", "expression", "done.done") + base_dir = os.path.join("data", "expression") + done_flag = os.path.join(base_dir, "done.done") + + # 确保基础目录存在 + try: + os.makedirs(base_dir, exist_ok=True) + logger.debug(f"确保目录存在: {base_dir}") + except Exception as e: + logger.error(f"创建表达方式目录失败: {e}") + return + if os.path.exists(done_flag): 
logger.info("表达方式JSON已迁移,无需重复迁移。") return - base_dir = os.path.join("data", "expression") + + logger.info("开始迁移表达方式JSON到数据库...") + migrated_count = 0 + for type in ["learnt_style", "learnt_grammar"]: type_str = "style" if type == "learnt_style" else "grammar" type_dir = os.path.join(base_dir, type) if not os.path.exists(type_dir): + logger.debug(f"目录不存在,跳过: {type_dir}") continue - for chat_id in os.listdir(type_dir): + + try: + chat_ids = os.listdir(type_dir) + logger.debug(f"在 {type_dir} 中找到 {len(chat_ids)} 个聊天ID目录") + except Exception as e: + logger.error(f"读取目录失败 {type_dir}: {e}") + continue + + for chat_id in chat_ids: expr_file = os.path.join(type_dir, chat_id, "expressions.json") if not os.path.exists(expr_file): continue try: with open(expr_file, "r", encoding="utf-8") as f: expressions = json.load(f) + + if not isinstance(expressions, list): + logger.warning(f"表达方式文件格式错误,跳过: {expr_file}") + continue + for expr in expressions: + if not isinstance(expr, dict): + continue + situation = expr.get("situation") style_val = expr.get("style") count = expr.get("count", 1) last_active_time = expr.get("last_active_time", time.time()) + + if not situation or not style_val: + logger.warning(f"表达方式缺少必要字段,跳过: {expr}") + continue + # 查重:同chat_id+type+situation+style from src.common.database.database_model import Expression @@ -127,18 +193,54 @@ class ExpressionLearner: last_active_time=last_active_time, chat_id=chat_id, type=type_str, + create_date=last_active_time, # 迁移时使用last_active_time作为创建时间 ) - logger.info(f"已迁移 {expr_file} 到数据库") + migrated_count += 1 + logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式") + except json.JSONDecodeError as e: + logger.error(f"JSON解析失败 {expr_file}: {e}") except Exception as e: logger.error(f"迁移表达方式 {expr_file} 失败: {e}") + # 标记迁移完成 try: + # 确保done.done文件的父目录存在 + done_parent_dir = os.path.dirname(done_flag) + if not os.path.exists(done_parent_dir): + os.makedirs(done_parent_dir, exist_ok=True) + logger.debug(f"为done.done创建父目录: 
{done_parent_dir}") + with open(done_flag, "w", encoding="utf-8") as f: f.write("done\n") - logger.info("表达方式JSON迁移已完成,已写入done.done标记文件") + logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件") + except PermissionError as e: + logger.error(f"权限不足,无法写入done.done标记文件: {e}") + except OSError as e: + logger.error(f"文件系统错误,无法写入done.done标记文件: {e}") except Exception as e: logger.error(f"写入done.done标记文件失败: {e}") + def _migrate_old_data_create_date(self): + """ + 为没有create_date的老数据设置创建日期 + 使用last_active_time作为create_date的默认值 + """ + try: + # 查找所有create_date为空的表达方式 + old_expressions = Expression.select().where(Expression.create_date.is_null()) + updated_count = 0 + + for expr in old_expressions: + # 使用last_active_time作为create_date + expr.create_date = expr.last_active_time + expr.save() + updated_count += 1 + + if updated_count > 0: + logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期") + except Exception as e: + logger.error(f"迁移老数据创建日期失败: {e}") + def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]: """ 获取指定chat_id的style和grammar表达方式 @@ -150,6 +252,8 @@ class ExpressionLearner: # 直接从数据库查询 style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style")) for expr in style_query: + # 确保create_date存在,如果不存在则使用last_active_time + create_date = expr.create_date if expr.create_date is not None else expr.last_active_time learnt_style_expressions.append( { "situation": expr.situation, @@ -158,10 +262,13 @@ class ExpressionLearner: "last_active_time": expr.last_active_time, "source_id": chat_id, "type": "style", + "create_date": create_date, } ) grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar")) for expr in grammar_query: + # 确保create_date存在,如果不存在则使用last_active_time + create_date = expr.create_date if expr.create_date is not None else expr.last_active_time learnt_grammar_expressions.append( { "situation": expr.situation, @@ 
-170,10 +277,40 @@ class ExpressionLearner: "last_active_time": expr.last_active_time, "source_id": chat_id, "type": "grammar", + "create_date": create_date, } ) return learnt_style_expressions, learnt_grammar_expressions + def get_expression_create_info(self, chat_id: str, limit: int = 10) -> List[Dict[str, Any]]: + """ + 获取指定chat_id的表达方式创建信息,按创建日期排序 + """ + try: + expressions = (Expression.select() + .where(Expression.chat_id == chat_id) + .order_by(Expression.create_date.desc()) + .limit(limit)) + + result = [] + for expr in expressions: + create_date = expr.create_date if expr.create_date is not None else expr.last_active_time + result.append({ + "situation": expr.situation, + "style": expr.style, + "type": expr.type, + "count": expr.count, + "create_date": create_date, + "create_date_formatted": format_create_date(create_date), + "last_active_time": expr.last_active_time, + "last_active_formatted": format_create_date(expr.last_active_time), + }) + + return result + except Exception as e: + logger.error(f"获取表达方式创建信息失败: {e}") + return [] + def is_similar(self, s1: str, s2: str) -> bool: """ 判断两个字符串是否相似(只考虑长度大于5且有80%以上重合,不考虑子串) @@ -197,9 +334,17 @@ class ExpressionLearner: for type in ["style", "grammar"]: base_dir = os.path.join("data", "expression", f"learnt_{type}") if not os.path.exists(base_dir): + logger.debug(f"目录不存在,跳过衰减: {base_dir}") continue - for chat_id in os.listdir(base_dir): + try: + chat_ids = os.listdir(base_dir) + logger.debug(f"在 {base_dir} 中找到 {len(chat_ids)} 个聊天ID目录进行衰减") + except Exception as e: + logger.error(f"读取目录失败 {base_dir}: {e}") + continue + + for chat_id in chat_ids: file_path = os.path.join(base_dir, chat_id, "expressions.json") if not os.path.exists(file_path): continue @@ -208,14 +353,24 @@ class ExpressionLearner: with open(file_path, "r", encoding="utf-8") as f: expressions = json.load(f) + if not isinstance(expressions, list): + logger.warning(f"表达方式文件格式错误,跳过衰减: {file_path}") + continue + # 应用全局衰减 decayed_expressions = 
self.apply_decay_to_expressions(expressions, current_time) # 保存衰减后的结果 with open(file_path, "w", encoding="utf-8") as f: json.dump(decayed_expressions, f, ensure_ascii=False, indent=2) + + logger.debug(f"已对 {file_path} 应用衰减,剩余 {len(decayed_expressions)} 个表达方式") + except json.JSONDecodeError as e: + logger.error(f"JSON解析失败,跳过衰减 {file_path}: {e}") + except PermissionError as e: + logger.error(f"权限不足,无法更新 {file_path}: {e}") except Exception as e: - logger.error(f"全局衰减{type}表达方式失败: {e}") + logger.error(f"全局衰减{type}表达方式失败 {file_path}: {e}") continue learnt_style: Optional[List[Tuple[str, str, str]]] = [] @@ -350,6 +505,7 @@ class ExpressionLearner: last_active_time=current_time, chat_id=chat_id, type=type, + create_date=current_time, # 手动设置创建日期 ) # 限制最大数量 exprs = list( diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 4ebad5a0e..d83d3a472 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -132,7 +132,8 @@ class ExpressionSelector: "count": expr.count, "last_active_time": expr.last_active_time, "source_id": cid, - "type": "style" + "type": "style", + "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, } for expr in style_query ]) grammar_exprs.extend([ @@ -142,7 +143,8 @@ class ExpressionSelector: "count": expr.count, "last_active_time": expr.last_active_time, "source_id": cid, - "type": "grammar" + "type": "grammar", + "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, } for expr in grammar_query ]) style_num = int(total_num * style_percentage) diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index a9d118286..3aa174bb5 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -111,9 +111,9 @@ class HeartFCMessageReceiver: subheartflow: SubHeartflow = await 
heartflow.get_or_create_subheartflow(chat.stream_id) # type: ignore # subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) - - chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) - asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate)) + if global_config.mood.enable_mood: + chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) + asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate)) # 3. 日志记录 mes_name = chat.group_info.group_name if chat.group_info else "私聊" diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index d229fc94f..a4228b89a 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -13,10 +13,9 @@ from src.chat.message_receive.message import MessageRecv, MessageRecvS4U from src.chat.message_receive.storage import MessageStorage from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.plugin_system.core import component_registry, events_manager # 导入新插件系统 +from src.plugin_system.core import component_registry, events_manager, global_announcement_manager from src.plugin_system.base import BaseCommand, EventType from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor -from src.llm_models.utils_model import LLMRequest # 定义日志配置 @@ -92,8 +91,19 @@ class ChatBot: # 使用新的组件注册中心查找命令 command_result = component_registry.find_command_by_text(text) if command_result: + command_class, matched_groups, command_info = command_result + plugin_name = command_info.plugin_name + command_name = command_info.name + if ( + message.chat_stream + and message.chat_stream.stream_id + and command_name + in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id) + ): + logger.info("用户禁用的命令,跳过处理") + return False, None, True + message.is_command = True - command_class, 
matched_groups, intercept_message, plugin_name = command_result # 获取插件配置 plugin_config = component_registry.get_plugin_config(plugin_name) @@ -104,7 +114,7 @@ class ChatBot: try: # 执行命令 - success, response = await command_instance.execute() + success, response, intercept_message = await command_instance.execute() # 记录命令执行结果 if success: @@ -117,8 +127,6 @@ class ChatBot: except Exception as e: logger.error(f"执行命令时出错: {command_class.__name__} - {e}") - import traceback - logger.error(traceback.format_exc()) try: @@ -127,7 +135,7 @@ class ChatBot: logger.error(f"发送错误消息失败: {send_error}") # 命令出错时,根据命令的拦截设置决定是否继续处理消息 - return True, str(e), not intercept_message + return True, str(e), False # 出错时继续处理消息 # 没有找到命令,继续处理消息 return False, None, True @@ -135,13 +143,12 @@ class ChatBot: except Exception as e: logger.error(f"处理命令时出错: {e}") return False, None, True # 出错时继续处理消息 - + async def hanle_notice_message(self, message: MessageRecv): if message.message_info.message_id == "notice": logger.info("收到notice消息,暂时不支持处理") return True - - + async def do_s4u(self, message_data: Dict[str, Any]): message = MessageRecvS4U(message_data) group_info = message.message_info.group_info @@ -163,7 +170,6 @@ class ChatBot: return - async def message_process(self, message_data: Dict[str, Any]) -> None: """处理转化后的统一格式消息 这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中 @@ -179,8 +185,6 @@ class ChatBot: - 性能计时 """ try: - - # 确保所有任务已启动 await self._ensure_started() @@ -201,11 +205,10 @@ class ChatBot: # print(message_data) # logger.debug(str(message_data)) message = MessageRecv(message_data) - + if await self.hanle_notice_message(message): return - - + group_info = message.message_info.group_info user_info = message.message_info.user_info if message.message_info.additional_config: @@ -214,9 +217,6 @@ class ChatBot: await MessageStorage.update_message(message) return - if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message): - return - get_chat_manager().register_message(message) chat 
= await get_chat_manager().get_or_create_stream( @@ -229,11 +229,10 @@ class ChatBot: # 处理消息内容,生成纯文本 await message.process() - + # if await self.check_ban_content(message): # logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}") # return - # 过滤检查 if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore @@ -252,6 +251,9 @@ class ChatBot: logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}") return + if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message): + return + # 确认从接口发来的message是否有自定义的prompt模板信息 if message.message_info.template_info and not message.message_info.template_info.template_default: template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index e4a61900e..2ee2be05a 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -163,20 +163,25 @@ class ChatManager: """注册消息到聊天流""" stream_id = self._generate_stream_id( message.message_info.platform, # type: ignore - message.message_info.user_info, # type: ignore + message.message_info.user_info, message.message_info.group_info, ) self.last_messages[stream_id] = message # logger.debug(f"注册消息到聊天流: {stream_id}") @staticmethod - def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: + def _generate_stream_id( + platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None + ) -> str: """生成聊天流唯一ID""" + if not user_info and not group_info: + raise ValueError("用户信息或群组信息必须提供") + if group_info: # 组合关键信息 components = [platform, str(group_info.group_id)] else: - components = [platform, str(user_info.user_id), "private"] + components = [platform, str(user_info.user_id), "private"] # type: ignore # 使用MD5生成唯一ID key = 
"_".join(components) diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index a2f4c37bd..37f939b92 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Type +from typing import Dict, Optional, Type from src.plugin_system.base.base_action import BaseAction from src.chat.message_receive.chat_stream import ChatStream from src.common.logger import get_logger @@ -22,53 +22,14 @@ class ActionManager: def __init__(self): """初始化动作管理器""" - # 所有注册的动作集合 - self._registered_actions: Dict[str, ActionInfo] = {} + # 当前正在使用的动作集合,默认加载默认动作 self._using_actions: Dict[str, ActionInfo] = {} - # 加载插件动作 - self._load_plugin_actions() - # 初始化时将默认动作加载到使用中的动作 self._using_actions = component_registry.get_default_actions() - def _load_plugin_actions(self) -> None: - """ - 加载所有插件系统中的动作 - """ - try: - # 从新插件系统获取Action组件 - self._load_plugin_system_actions() - logger.debug("从插件系统加载Action组件成功") - - except Exception as e: - logger.error(f"加载插件动作失败: {e}") - - def _load_plugin_system_actions(self) -> None: - """从插件系统的component_registry加载Action组件""" - try: - # 获取所有Action组件 - action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore - - for action_name, action_info in action_components.items(): - if action_name in self._registered_actions: - logger.debug(f"Action组件 {action_name} 已存在,跳过") - continue - - self._registered_actions[action_name] = action_info - - logger.debug( - f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})" - ) - - logger.info(f"加载了 {len(action_components)} 个Action动作") - - except Exception as e: - logger.error(f"从插件系统加载Action组件失败: {e}") - import traceback - - logger.error(traceback.format_exc()) + # === 执行Action方法 === def create_action( self, @@ -139,36 +100,11 @@ class ActionManager: logger.error(traceback.format_exc()) return None 
- def get_registered_actions(self) -> Dict[str, ActionInfo]: - """获取所有已注册的动作集""" - return self._registered_actions.copy() - def get_using_actions(self) -> Dict[str, ActionInfo]: """获取当前正在使用的动作集合""" return self._using_actions.copy() - def add_action_to_using(self, action_name: str) -> bool: - """ - 添加已注册的动作到当前使用的动作集 - - Args: - action_name: 动作名称 - - Returns: - bool: 添加是否成功 - """ - if action_name not in self._registered_actions: - logger.warning(f"添加失败: 动作 {action_name} 未注册") - return False - - if action_name in self._using_actions: - logger.info(f"动作 {action_name} 已经在使用中") - return True - - self._using_actions[action_name] = self._registered_actions[action_name] - logger.info(f"添加动作 {action_name} 到使用集") - return True - + # === Modify相关方法 === def remove_action_from_using(self, action_name: str) -> bool: """ 从当前使用的动作集中移除指定动作 @@ -187,79 +123,8 @@ class ActionManager: logger.debug(f"已从使用集中移除动作 {action_name}") return True - # def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool: - # """ - # 添加新的动作到注册集 - - # Args: - # action_name: 动作名称 - # description: 动作描述 - # parameters: 动作参数定义,默认为空字典 - # require: 动作依赖项,默认为空列表 - - # Returns: - # bool: 添加是否成功 - # """ - # if action_name in self._registered_actions: - # return False - - # if parameters is None: - # parameters = {} - # if require is None: - # require = [] - - # action_info = {"description": description, "parameters": parameters, "require": require} - - # self._registered_actions[action_name] = action_info - # return True - - def remove_action(self, action_name: str) -> bool: - """从注册集移除指定动作""" - if action_name not in self._registered_actions: - return False - del self._registered_actions[action_name] - # 如果在使用集中也存在,一并移除 - if action_name in self._using_actions: - del self._using_actions[action_name] - return True - - def temporarily_remove_actions(self, actions_to_remove: List[str]) -> None: - """临时移除使用集中的指定动作""" - for name in actions_to_remove: - 
self._using_actions.pop(name, None) - def restore_actions(self) -> None: """恢复到默认动作集""" actions_to_restore = list(self._using_actions.keys()) self._using_actions = component_registry.get_default_actions() logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") - - def add_system_action_if_needed(self, action_name: str) -> bool: - """ - 根据需要添加系统动作到使用集 - - Args: - action_name: 动作名称 - - Returns: - bool: 是否成功添加 - """ - if action_name in self._registered_actions and action_name not in self._using_actions: - self._using_actions[action_name] = self._registered_actions[action_name] - logger.info(f"临时添加系统动作到使用集: {action_name}") - return True - return False - - def get_action(self, action_name: str) -> Optional[Type[BaseAction]]: - """ - 获取指定动作的处理器类 - - Args: - action_name: 动作名称 - - Returns: - Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None - """ - from src.plugin_system.core.component_registry import component_registry - - return component_registry.get_component_class(action_name, ComponentType.ACTION) # type: ignore diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 93be49842..c7964edc9 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -2,7 +2,7 @@ import random import asyncio import hashlib import time -from typing import List, Any, Dict, TYPE_CHECKING +from typing import List, Any, Dict, TYPE_CHECKING, Tuple from src.common.logger import get_logger from src.config.config import global_config @@ -11,6 +11,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageCo from src.chat.planner_actions.action_manager import ActionManager from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages from src.plugin_system.base.component_types import ActionInfo, ActionActivationType +from src.plugin_system.core.global_announcement_manager import 
global_announcement_manager if TYPE_CHECKING: from src.chat.message_receive.chat_stream import ChatStream @@ -47,7 +48,6 @@ class ActionModifier: async def modify_actions( self, - history_loop=None, message_content: str = "", ): # sourcery skip: use-named-expression """ @@ -61,8 +61,9 @@ class ActionModifier: """ logger.debug(f"{self.log_prefix}开始完整动作修改流程") - removals_s1 = [] - removals_s2 = [] + removals_s1: List[Tuple[str, str]] = [] + removals_s2: List[Tuple[str, str]] = [] + removals_s3: List[Tuple[str, str]] = [] self.action_manager.restore_actions() all_actions = self.action_manager.get_using_actions() @@ -84,25 +85,28 @@ class ActionModifier: if message_content: chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}" - # === 第一阶段:传统观察处理 === - # if history_loop: - # removals_from_loop = await self.analyze_loop_actions(history_loop) - # if removals_from_loop: - # removals_s1.extend(removals_from_loop) + # === 第一阶段:去除用户自行禁用的 === + disabled_actions = global_announcement_manager.get_disabled_chat_actions(self.chat_id) + if disabled_actions: + for disabled_action_name in disabled_actions: + if disabled_action_name in all_actions: + removals_s1.append((disabled_action_name, "用户自行禁用")) + self.action_manager.remove_action_from_using(disabled_action_name) + logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用") - # 检查动作的关联类型 + # === 第二阶段:检查动作的关联类型 === chat_context = self.chat_stream.context type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context) if type_mismatched_actions: - removals_s1.extend(type_mismatched_actions) + removals_s2.extend(type_mismatched_actions) - # 应用第一阶段的移除 - for action_name, reason in removals_s1: + # 应用第二阶段的移除 + for action_name, reason in removals_s2: self.action_manager.remove_action_from_using(action_name) - logger.debug(f"{self.log_prefix}阶段一移除动作: {action_name},原因: {reason}") + logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}") - # === 第二阶段:激活类型判定 === + # === 
第三阶段:激活类型判定 === if chat_content is not None: logger.debug(f"{self.log_prefix}开始激活类型判定阶段") @@ -110,18 +114,18 @@ class ActionModifier: current_using_actions = self.action_manager.get_using_actions() # 获取因激活类型判定而需要移除的动作 - removals_s2 = await self._get_deactivated_actions_by_type( + removals_s3 = await self._get_deactivated_actions_by_type( current_using_actions, chat_content, ) - # 应用第二阶段的移除 - for action_name, reason in removals_s2: + # 应用第三阶段的移除 + for action_name, reason in removals_s3: self.action_manager.remove_action_from_using(action_name) - logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}") + logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}") # === 统一日志记录 === - all_removals = removals_s1 + removals_s2 + all_removals = removals_s1 + removals_s2 + removals_s3 removals_summary: str = "" if all_removals: removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals]) @@ -131,7 +135,7 @@ class ActionModifier: ) def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext): - type_mismatched_actions = [] + type_mismatched_actions: List[Tuple[str, str]] = [] for action_name, action_info in all_actions.items(): if action_info.associated_types and not chat_context.check_types(action_info.associated_types): associated_types_str = ", ".join(action_info.associated_types) @@ -318,7 +322,7 @@ class ActionModifier: action_name: str, action_info: ActionInfo, chat_content: str = "", - ) -> bool: + ) -> bool: # sourcery skip: move-assign-in-block, use-named-expression """ 使用LLM判定是否应该激活某个action diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 61fc2f4d6..e3d1edef9 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -19,8 +19,8 @@ from src.chat.utils.chat_message_builder import ( from src.chat.utils.utils import get_chat_type_and_target_info from src.chat.planner_actions.action_manager 
import ActionManager from src.chat.message_receive.chat_stream import get_chat_manager -from src.plugin_system.base.component_types import ActionInfo, ChatMode - +from src.plugin_system.base.component_types import ActionInfo, ChatMode, ComponentType +from src.plugin_system.core.component_registry import component_registry logger = get_logger("planner") @@ -99,7 +99,7 @@ class ActionPlanner: async def plan( self, mode: ChatMode = ChatMode.FOCUS - ) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]: # sourcery skip: dict-comprehension + ) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]: """ 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 """ @@ -113,16 +113,17 @@ class ActionPlanner: try: is_group_chat = True - is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id) logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}") current_available_actions_dict = self.action_manager.get_using_actions() # 获取完整的动作信息 - all_registered_actions = self.action_manager.get_registered_actions() - - for action_name in current_available_actions_dict.keys(): + all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore + ComponentType.ACTION + ) + current_available_actions = {} + for action_name in current_available_actions_dict: if action_name in all_registered_actions: current_available_actions[action_name] = all_registered_actions[action_name] else: @@ -234,10 +235,13 @@ class ActionPlanner: "is_parallel": is_parallel, } - return { - "action_result": action_result, - "action_prompt": prompt, - }, target_message + return ( + { + "action_result": action_result, + "action_prompt": prompt, + }, + target_message, + ) async def build_planner_prompt( self, @@ -275,23 +279,29 @@ class ActionPlanner: self.last_obs_time_mark = time.time() if mode == ChatMode.FOCUS: + mentioned_bonus = "" + if global_config.chat.mentioned_bot_inevitable_reply: + mentioned_bonus = "\n- 
有人提到你" + if global_config.chat.at_bot_inevitable_reply: + mentioned_bonus = "\n- 有人提到你,或者at你" + + by_what = "聊天内容" target_prompt = '\n "target_message_id":"触发action的消息id"' - no_action_block = """重要说明1: + no_action_block = f"""重要说明1: - 'no_reply' 表示只进行不进行回复,等待合适的回复时机 - 当你刚刚发送了消息,没有人回复时,选择no_reply - 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_reply 动作:reply 动作描述:参与聊天回复,发送文本进行表达 -- 你想要闲聊或者随便附和 -- 有人提到你 +- 你想要闲聊或者随便附和{mentioned_bonus} - 如果你刚刚进行了回复,不要对同一个话题重复回应 -{ +{{ "action": "reply", "target_message_id":"触发action的消息id", "reason":"回复的原因" -} +}} """ else: diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index f41ca8ddc..9d75671c6 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -6,7 +6,7 @@ import re from typing import List, Optional, Dict, Any, Tuple from datetime import datetime -from src.chat.mai_thinking.mai_think import mai_thinking_manager +from src.mais4u.mai_think import mai_thinking_manager from src.common.logger import get_logger from src.config.config import global_config from src.individuality.individuality import get_individuality @@ -30,9 +30,6 @@ from src.plugin_system.base.component_types import ActionInfo logger = get_logger("replyer") -ENABLE_S2S_MODE = True - - def init_prompt(): Prompt("你正在qq群里聊天,下面是群里在聊的内容:", "chat_target_group1") Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1") @@ -60,7 +57,6 @@ def init_prompt(): 现在请你读读之前的聊天记录,并给出回复 {config_expression_style}。注意不要复读你说过的话 {keywords_reaction_prompt} -请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。 {moderation_prompt} 不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容""", "default_generator_prompt", @@ -78,6 +74,7 @@ def init_prompt(): 你正在{chat_target_2},{reply_target_block} 对这句话,你想表达,原句:{raw_reply},原因是:{reason}。你现在要思考怎么组织回复 +你现在的心情是:{mood_state} 你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。请你修改你想表达的原句,符合你的表达风格和语言习惯 {config_expression_style},你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。 {keywords_reaction_prompt} @@ -98,29 
+95,29 @@ def init_prompt(): {relation_info_block} {extra_info_block} -你是一个AI虚拟主播,正在直播QQ聊天,同时也在直播间回复弹幕,不过回复的时候不用过多提及这点 {identity} {action_descriptions} -你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。你现在的心情是:{mood_state} +你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。 {background_dialogue_prompt} -------------------------------- {time_block} 这是你和{sender_name}的对话,你们正在交流中: + {core_dialogue_prompt} {reply_target_block} 对方最新发送的内容:{message_txt} -回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。 -{config_expression_style}。注意不要复读你说过的话 +你现在的心情是:{mood_state} +{config_expression_style} +注意不要复读你说过的话 {keywords_reaction_prompt} 请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。 {moderation_prompt} -不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。 -你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。 -你的发言: +不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出一条回复内容就好 +现在,你说: """, "s4u_style_prompt", ) @@ -133,7 +130,6 @@ class DefaultReplyer: model_configs: Optional[List[Dict[str, Any]]] = None, request_type: str = "focus.replyer", ): - self.log_prefix = "replyer" self.request_type = request_type if model_configs: @@ -197,7 +193,7 @@ class DefaultReplyer: } for key, value in reply_data.items(): if not value: - logger.debug(f"{self.log_prefix} 回复数据跳过{key},生成回复时将忽略。") + logger.debug(f"回复数据跳过{key},生成回复时将忽略。") # 3. 
构建 Prompt with Timer("构建Prompt", {}): # 内部计时器,可选保留 @@ -218,7 +214,7 @@ class DefaultReplyer: # 加权随机选择一个模型配置 selected_model_config = self._select_weighted_model_config() logger.info( - f"{self.log_prefix} 使用模型配置: {selected_model_config.get('name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})" + f"使用模型生成回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})" ) express_model = LLMRequest( @@ -227,9 +223,9 @@ class DefaultReplyer: ) if global_config.debug.show_prompt: - logger.info(f"{self.log_prefix}\n{prompt}\n") + logger.info(f"\n{prompt}\n") else: - logger.debug(f"{self.log_prefix}\n{prompt}\n") + logger.debug(f"\n{prompt}\n") content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt) @@ -237,13 +233,13 @@ class DefaultReplyer: except Exception as llm_e: # 精简报错信息 - logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}") + logger.error(f"LLM 生成失败: {llm_e}") return False, None, prompt # LLM 调用失败则无法生成回复 return True, content, prompt except Exception as e: - logger.error(f"{self.log_prefix}回复生成意外失败: {e}") + logger.error(f"回复生成意外失败: {e}") traceback.print_exc() return False, None, prompt @@ -274,7 +270,7 @@ class DefaultReplyer: reasoning_content = None model_name = "unknown_model" if not prompt: - logger.error(f"{self.log_prefix}Prompt 构建失败,无法生成回复。") + logger.error("Prompt 构建失败,无法生成回复。") return False, None try: @@ -282,7 +278,7 @@ class DefaultReplyer: # 加权随机选择一个模型配置 selected_model_config = self._select_weighted_model_config() logger.info( - f"{self.log_prefix} 使用模型配置进行重写: {selected_model_config.get('name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})" + f"使用模型重写回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})" ) express_model = LLMRequest( @@ -296,13 +292,13 @@ class DefaultReplyer: except Exception as llm_e: # 精简报错信息 - logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}") + logger.error(f"LLM 生成失败: {llm_e}") return False, 
None # LLM 调用失败则无法生成回复 return True, content except Exception as e: - logger.error(f"{self.log_prefix}回复生成意外失败: {e}") + logger.error(f"回复生成意外失败: {e}") traceback.print_exc() return False, None @@ -322,7 +318,7 @@ class DefaultReplyer: person_info_manager = get_person_info_manager() person_id = person_info_manager.get_person_id_by_person_name(sender) if not person_id: - logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID,跳过信息提取") + logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取") return f"你完全不认识{sender},不理解ta的相关信息。" return await relationship_fetcher.build_relation_info(person_id, points_num=5) @@ -341,7 +337,7 @@ class DefaultReplyer: ) if selected_expressions: - logger.debug(f"{self.log_prefix} 使用处理器选中的{len(selected_expressions)}个表达方式") + logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式") for expr in selected_expressions: if isinstance(expr, dict) and "situation" in expr and "style" in expr: expr_type = expr.get("type", "style") @@ -350,7 +346,7 @@ class DefaultReplyer: else: style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}") else: - logger.debug(f"{self.log_prefix} 没有从处理器获得表达方式,将使用空的表达方式") + logger.debug("没有从处理器获得表达方式,将使用空的表达方式") # 不再在replyer中进行随机选择,全部交给处理器处理 style_habits_str = "\n".join(style_habits) @@ -358,10 +354,19 @@ class DefaultReplyer: # 动态构建expression habits块 expression_habits_block = "" + expression_habits_title = "" if style_habits_str.strip(): - expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n" + expression_habits_title = "你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:" + expression_habits_block += f"{style_habits_str}\n" if grammar_habits_str.strip(): - expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habits_str}\n" + expression_habits_title = "你可以选择下面的句法进行回复,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式使用:" + expression_habits_block += f"{grammar_habits_str}\n" + + if style_habits_str.strip() and grammar_habits_str.strip(): + expression_habits_title = 
"你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中:" + + expression_habits_block = f"{expression_habits_title}\n{expression_habits_block}" + return expression_habits_block @@ -432,19 +437,23 @@ class DefaultReplyer: tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n" tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。" - logger.info(f"{self.log_prefix} 获取到 {len(tool_results)} 个工具结果") + logger.info(f"获取到 {len(tool_results)} 个工具结果") + return tool_info_str else: - logger.debug(f"{self.log_prefix} 未获取到任何工具结果") + logger.debug("未获取到任何工具结果") return "" except Exception as e: - logger.error(f"{self.log_prefix} 工具信息获取失败: {e}") + logger.error(f"工具信息获取失败: {e}") return "" def _parse_reply_target(self, target_message: str) -> tuple: sender = "" target = "" + # 添加None检查,防止NoneType错误 + if target_message is None: + return sender, target if ":" in target_message or ":" in target_message: # 使用正则表达式匹配中文或英文冒号 parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1) @@ -457,6 +466,10 @@ class DefaultReplyer: # 关键词检测与反应 keywords_reaction_prompt = "" try: + # 添加None检查,防止NoneType错误 + if target is None: + return keywords_reaction_prompt + # 处理关键词规则 for rule in global_config.keyword_reaction.keyword_rules: if any(keyword in target for keyword in rule.keywords): @@ -510,19 +523,21 @@ class DefaultReplyer: for msg_dict in message_list_before_now: try: msg_user_id = str(msg_dict.get("user_id")) - if msg_user_id == bot_id or msg_user_id == target_user_id: + reply_to = msg_dict.get("reply_to", "") + _platform, reply_to_user_id = self._parse_reply_target(reply_to) + if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id: # bot 和目标用户的对话 core_dialogue_list.append(msg_dict) else: # 其他用户的对话 background_dialogue_list.append(msg_dict) except Exception as e: - logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}") + logger.error(f"![1753364551656](image/default_generator/1753364551656.png)记录: {msg_dict}, 错误: {e}") # 构建背景对话 prompt 
background_dialogue_prompt = "" if background_dialogue_list: - latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size * 0.6) :] + latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size * 0.5) :] background_dialogue_prompt_str = build_readable_messages( latest_25_msgs, replace_bot_name=True, @@ -549,6 +564,34 @@ class DefaultReplyer: return core_dialogue_prompt, background_dialogue_prompt + def build_mai_think_context( + self, + chat_id: str, + memory_block: str, + relation_info: str, + time_block: str, + chat_target_1: str, + chat_target_2: str, + mood_prompt: str, + identity_block: str, + sender: str, + target: str, + chat_info: str, + ): + """构建 mai_think 上下文信息""" + mai_think = mai_thinking_manager.get_mai_think(chat_id) + mai_think.memory_block = memory_block + mai_think.relation_info_block = relation_info + mai_think.time_block = time_block + mai_think.chat_target = chat_target_1 + mai_think.chat_target_2 = chat_target_2 + mai_think.chat_info = chat_info + mai_think.mood_state = mood_prompt + mai_think.identity = identity_block + mai_think.sender = sender + mai_think.target = target + return mai_think + async def build_prompt_reply_context( self, reply_data: Dict[str, Any], @@ -578,9 +621,12 @@ class DefaultReplyer: is_group_chat = bool(chat_stream.group_info) reply_to = reply_data.get("reply_to", "none") extra_info_block = reply_data.get("extra_info", "") or reply_data.get("extra_info_block", "") - - chat_mood = mood_manager.get_mood_by_chat_id(chat_id) - mood_prompt = chat_mood.mood_state + + if global_config.mood.enable_mood: + chat_mood = mood_manager.get_mood_by_chat_id(chat_id) + mood_prompt = chat_mood.mood_state + else: + mood_prompt = "" sender, target = self._parse_reply_target(reply_to) @@ -628,44 +674,51 @@ class DefaultReplyer: show_actions=True, ) - # 并行执行四个构建任务 + # 并行执行五个构建任务 task_results = await asyncio.gather( self._time_and_run_task( - 
self.build_expression_habits(chat_talking_prompt_short, target), "build_expression_habits" + self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits" ), self._time_and_run_task( - self.build_relation_info(reply_data), "build_relation_info" + self.build_relation_info(reply_data), "relation_info" ), - self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "build_memory_block"), + self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block"), self._time_and_run_task( - self.build_tool_info(chat_talking_prompt_short, reply_data, enable_tool=enable_tool), "build_tool_info" + self.build_tool_info(chat_talking_prompt_short, reply_data, enable_tool=enable_tool), "tool_info" + ), + self._time_and_run_task( + get_prompt_info(target, threshold=0.38), "prompt_info" ), ) + # 任务名称中英文映射 + task_name_mapping = { + "expression_habits": "选取表达方式", + "relation_info": "感受关系", + "memory_block": "回忆", + "tool_info": "使用工具", + "prompt_info": "获取知识" + } + # 处理结果 timing_logs = [] results_dict = {} for name, result, duration in task_results: results_dict[name] = result - timing_logs.append(f"{name}: {duration:.4f}s") + chinese_name = task_name_mapping.get(name, name) + timing_logs.append(f"{chinese_name}: {duration:.1f}s") if duration > 8: - logger.warning(f"回复生成前信息获取耗时过长: {name} 耗时: {duration:.4f}s,请使用更快的模型") - logger.info(f"回复生成前信息获取耗时: {'; '.join(timing_logs)}") + logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s,请使用更快的模型") + logger.info(f"在回复前的步骤耗时: {'; '.join(timing_logs)}") - expression_habits_block = results_dict["build_expression_habits"] - relation_info = results_dict["build_relation_info"] - memory_block = results_dict["build_memory_block"] - tool_info = results_dict["build_tool_info"] + expression_habits_block = results_dict["expression_habits"] + relation_info = results_dict["relation_info"] + memory_block = results_dict["memory_block"] + tool_info = 
results_dict["tool_info"] + prompt_info = results_dict["prompt_info"] # 直接使用格式化后的结果 keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target) - if tool_info: - tool_info_block = ( - f"以下是你了解的额外信息信息,现在请你阅读以下内容,进行决策\n{tool_info}\n以上是一些额外的信息。" - ) - else: - tool_info_block = "" - if extra_info_block: extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策" else: @@ -699,10 +752,6 @@ class DefaultReplyer: else: reply_target_block = "" - prompt_info = await get_prompt_info(target, threshold=0.38) - if prompt_info: - prompt_info = await global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=prompt_info) - template_name = "default_generator_prompt" if is_group_chat: chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1") @@ -742,24 +791,24 @@ class DefaultReplyer: message_list_before_now_long, target_user_id ) - mai_think = mai_thinking_manager.get_mai_think(chat_id) - mai_think.memory_block = memory_block - mai_think.relation_info_block = relation_info - mai_think.time_block = time_block - mai_think.chat_target = chat_target_1 - mai_think.chat_target_2 = chat_target_2 - # mai_think.chat_info = chat_talking_prompt - mai_think.mood_state = mood_prompt - mai_think.identity = identity_block - mai_think.sender = sender - mai_think.target = target - - mai_think.chat_info = f""" + self.build_mai_think_context( + chat_id=chat_id, + memory_block=memory_block, + relation_info=relation_info, + time_block=time_block, + chat_target_1=chat_target_1, + chat_target_2=chat_target_2, + mood_prompt=mood_prompt, + identity_block=identity_block, + sender=sender, + target=target, + chat_info=f""" {background_dialogue_prompt} -------------------------------- {time_block} 这是你和{sender}的对话,你们正在交流中: {core_dialogue_prompt}""" + ) # 使用 s4u 风格的模板 @@ -768,7 +817,7 @@ class DefaultReplyer: return await global_prompt_manager.format_prompt( template_name, 
expression_habits_block=expression_habits_block, - tool_info_block=tool_info_block, + tool_info_block=tool_info, knowledge_prompt=prompt_info, memory_block=memory_block, relation_info_block=relation_info, @@ -787,17 +836,19 @@ class DefaultReplyer: moderation_prompt=moderation_prompt_block, ) else: - mai_think = mai_thinking_manager.get_mai_think(chat_id) - mai_think.memory_block = memory_block - mai_think.relation_info_block = relation_info - mai_think.time_block = time_block - mai_think.chat_target = chat_target_1 - mai_think.chat_target_2 = chat_target_2 - mai_think.chat_info = chat_talking_prompt - mai_think.mood_state = mood_prompt - mai_think.identity = identity_block - mai_think.sender = sender - mai_think.target = target + self.build_mai_think_context( + chat_id=chat_id, + memory_block=memory_block, + relation_info=relation_info, + time_block=time_block, + chat_target_1=chat_target_1, + chat_target_2=chat_target_2, + mood_prompt=mood_prompt, + identity_block=identity_block, + sender=sender, + target=target, + chat_info=chat_talking_prompt + ) # 使用原有的模式 return await global_prompt_manager.format_prompt( @@ -806,7 +857,7 @@ class DefaultReplyer: chat_target=chat_target_1, chat_info=chat_talking_prompt, memory_block=memory_block, - tool_info_block=tool_info_block, + tool_info_block=tool_info, knowledge_prompt=prompt_info, extra_info_block=extra_info_block, relation_info_block=relation_info, @@ -836,6 +887,13 @@ class DefaultReplyer: reason = reply_data.get("reason", "") sender, target = self._parse_reply_target(reply_to) + # 添加情绪状态获取 + if global_config.mood.enable_mood: + chat_mood = mood_manager.get_mood_by_chat_id(chat_id) + mood_prompt = chat_mood.mood_state + else: + mood_prompt = "" + message_list_before_now_half = get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), @@ -916,6 +974,7 @@ class DefaultReplyer: reply_target_block=reply_target_block, raw_reply=raw_reply, reason=reason, + mood_state=mood_prompt, # 添加情绪状态参数 
config_expression_style=global_config.expression.expression_style, keywords_reaction_prompt=keywords_reaction_prompt, moderation_prompt=moderation_prompt_block, @@ -1012,7 +1071,10 @@ async def get_prompt_info(message: str, threshold: float): related_info += found_knowledge_from_lpmm logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒") logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") - return related_info + + # 格式化知识信息 + formatted_prompt_info = await global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=related_info) + return formatted_prompt_info else: logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...") return "" diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index bce8856e5..aa000df7a 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -1,8 +1,5 @@ import asyncio import concurrent.futures -import json -import os -import glob from collections import defaultdict from datetime import datetime, timedelta @@ -43,20 +40,6 @@ ONLINE_TIME = "online_time" TOTAL_MSG_CNT = "total_messages" MSG_CNT_BY_CHAT = "messages_by_chat" -# Focus统计数据的键 -FOCUS_TOTAL_CYCLES = "focus_total_cycles" -FOCUS_AVG_TIMES_BY_STAGE = "focus_avg_times_by_stage" -FOCUS_ACTION_RATIOS = "focus_action_ratios" -FOCUS_CYCLE_CNT_BY_CHAT = "focus_cycle_count_by_chat" -FOCUS_CYCLE_CNT_BY_ACTION = "focus_cycle_count_by_action" -FOCUS_AVG_TIMES_BY_CHAT_ACTION = "focus_avg_times_by_chat_action" -FOCUS_AVG_TIMES_BY_ACTION = "focus_avg_times_by_action" -FOCUS_TOTAL_TIME_BY_CHAT = "focus_total_time_by_chat" -FOCUS_TOTAL_TIME_BY_ACTION = "focus_total_time_by_action" -FOCUS_CYCLE_CNT_BY_VERSION = "focus_cycle_count_by_version" -FOCUS_ACTION_RATIOS_BY_VERSION = "focus_action_ratios_by_version" -FOCUS_AVG_TIMES_BY_VERSION = "focus_avg_times_by_version" - class OnlineTimeRecordTask(AsyncTask): """在线时间记录任务""" @@ -196,8 +179,6 @@ class StatisticOutputTask(AsyncTask): self._format_model_classified_stat(stats["last_hour"]), 
"", self._format_chat_stat(stats["last_hour"]), - "", - self._format_focus_stat(stats["last_hour"]), self.SEP_LINE, "", ] @@ -466,189 +447,7 @@ class StatisticOutputTask(AsyncTask): break return stats - def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: - """ - 收集指定时间段的Focus统计数据 - - :param collect_period: 统计时间段 - """ - if not collect_period: - return {} - - collect_period.sort(key=lambda x: x[1], reverse=True) - - stats = { - period_key: { - FOCUS_TOTAL_CYCLES: 0, - FOCUS_AVG_TIMES_BY_STAGE: defaultdict(list), - FOCUS_ACTION_RATIOS: defaultdict(int), - FOCUS_CYCLE_CNT_BY_CHAT: defaultdict(int), - FOCUS_CYCLE_CNT_BY_ACTION: defaultdict(int), - FOCUS_AVG_TIMES_BY_CHAT_ACTION: defaultdict(lambda: defaultdict(list)), - FOCUS_AVG_TIMES_BY_ACTION: defaultdict(lambda: defaultdict(list)), - "focus_exec_times_by_chat_action": defaultdict(lambda: defaultdict(list)), - FOCUS_TOTAL_TIME_BY_CHAT: defaultdict(float), - FOCUS_TOTAL_TIME_BY_ACTION: defaultdict(float), - FOCUS_CYCLE_CNT_BY_VERSION: defaultdict(int), - FOCUS_ACTION_RATIOS_BY_VERSION: defaultdict(lambda: defaultdict(int)), - FOCUS_AVG_TIMES_BY_VERSION: defaultdict(lambda: defaultdict(list)), - "focus_exec_times_by_version_action": defaultdict(lambda: defaultdict(list)), - "focus_action_ratios_by_chat": defaultdict(lambda: defaultdict(int)), - } - for period_key, _ in collect_period - } - - # 获取 log/hfc_loop 目录下的所有 json 文件 - log_dir = "log/hfc_loop" - if not os.path.exists(log_dir): - logger.warning(f"Focus log directory {log_dir} does not exist") - return stats - - json_files = glob.glob(os.path.join(log_dir, "*.json")) - query_start_time = collect_period[-1][1] - - for json_file in json_files: - try: - # 从文件名解析时间戳 (格式: hash_version_date_time.json) - filename = os.path.basename(json_file) - name_parts = filename.replace(".json", "").split("_") - if len(name_parts) >= 4: - date_str = name_parts[-2] # YYYYMMDD - time_str = name_parts[-1] # HHMMSS - 
file_time_str = f"{date_str}_{time_str}" - file_time = datetime.strptime(file_time_str, "%Y%m%d_%H%M%S") - - # 如果文件时间在查询范围内,则处理该文件 - if file_time >= query_start_time: - with open(json_file, "r", encoding="utf-8") as f: - cycles_data = json.load(f) - self._process_focus_file_data(cycles_data, stats, collect_period, file_time) - except Exception as e: - logger.warning(f"Failed to process focus file {json_file}: {e}") - continue - - # 计算平均值 - self._calculate_focus_averages(stats) - return stats - - def _process_focus_file_data( - self, - cycles_data: List[Dict], - stats: Dict[str, Any], - collect_period: List[Tuple[str, datetime]], - file_time: datetime, - ): - """ - 处理单个focus文件的数据 - """ - for cycle_data in cycles_data: - try: - # 解析时间戳 - timestamp_str = cycle_data.get("timestamp", "") - if timestamp_str: - cycle_time = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00")) - else: - cycle_time = file_time # 使用文件时间作为后备 - - chat_id = cycle_data.get("chat_id", "unknown") - action_type = cycle_data.get("action_type", "unknown") - total_time = cycle_data.get("total_time", 0.0) - step_times = cycle_data.get("step_times", {}) - version = cycle_data.get("version", "unknown") - - # 更新聊天ID名称映射 - if chat_id not in self.name_mapping: - # 尝试获取实际的聊天名称 - display_name = self._get_chat_display_name_from_id(chat_id) - self.name_mapping[chat_id] = (display_name, cycle_time.timestamp()) - - # 对每个时间段进行统计 - for idx, (_, period_start) in enumerate(collect_period): - if cycle_time >= period_start: - for period_key, _ in collect_period[idx:]: - stat = stats[period_key] - - # 基础统计 - stat[FOCUS_TOTAL_CYCLES] += 1 - stat[FOCUS_ACTION_RATIOS][action_type] += 1 - stat[FOCUS_CYCLE_CNT_BY_CHAT][chat_id] += 1 - stat[FOCUS_CYCLE_CNT_BY_ACTION][action_type] += 1 - stat["focus_action_ratios_by_chat"][chat_id][action_type] += 1 - stat[FOCUS_TOTAL_TIME_BY_CHAT][chat_id] += total_time - stat[FOCUS_TOTAL_TIME_BY_ACTION][action_type] += total_time - - # 版本统计 - stat[FOCUS_CYCLE_CNT_BY_VERSION][version] 
+= 1 - stat[FOCUS_ACTION_RATIOS_BY_VERSION][version][action_type] += 1 - - # 阶段时间统计 - for stage, time_val in step_times.items(): - stat[FOCUS_AVG_TIMES_BY_STAGE][stage].append(time_val) - stat[FOCUS_AVG_TIMES_BY_CHAT_ACTION][chat_id][stage].append(time_val) - stat[FOCUS_AVG_TIMES_BY_ACTION][action_type][stage].append(time_val) - stat[FOCUS_AVG_TIMES_BY_VERSION][version][stage].append(time_val) - - # 专门收集执行动作阶段的时间,按聊天流和action类型分组 - if stage == "执行动作": - stat["focus_exec_times_by_chat_action"][chat_id][action_type].append(time_val) - # 按版本和action类型收集执行时间 - stat["focus_exec_times_by_version_action"][version][action_type].append(time_val) - break - except Exception as e: - logger.warning(f"Failed to process cycle data: {e}") - continue - - def _calculate_focus_averages(self, stats: Dict[str, Any]): - """ - 计算Focus统计的平均值 - """ - for _period_key, stat in stats.items(): - # 计算全局阶段平均时间 - for stage, times in stat[FOCUS_AVG_TIMES_BY_STAGE].items(): - if times: - stat[FOCUS_AVG_TIMES_BY_STAGE][stage] = sum(times) / len(times) - else: - stat[FOCUS_AVG_TIMES_BY_STAGE][stage] = 0.0 - - # 计算按chat_id和action_type的阶段平均时间 - for chat_id, stage_times in stat[FOCUS_AVG_TIMES_BY_CHAT_ACTION].items(): - for stage, times in stage_times.items(): - if times: - stat[FOCUS_AVG_TIMES_BY_CHAT_ACTION][chat_id][stage] = sum(times) / len(times) - else: - stat[FOCUS_AVG_TIMES_BY_CHAT_ACTION][chat_id][stage] = 0.0 - - # 计算按action_type的阶段平均时间 - for action_type, stage_times in stat[FOCUS_AVG_TIMES_BY_ACTION].items(): - for stage, times in stage_times.items(): - if times: - stat[FOCUS_AVG_TIMES_BY_ACTION][action_type][stage] = sum(times) / len(times) - else: - stat[FOCUS_AVG_TIMES_BY_ACTION][action_type][stage] = 0.0 - - # 计算按聊天流和action类型的执行时间平均值 - for chat_id, action_times in stat["focus_exec_times_by_chat_action"].items(): - for action_type, times in action_times.items(): - if times: - stat["focus_exec_times_by_chat_action"][chat_id][action_type] = sum(times) / len(times) - else: - 
stat["focus_exec_times_by_chat_action"][chat_id][action_type] = 0.0 - - # 计算按版本的阶段平均时间 - for version, stage_times in stat[FOCUS_AVG_TIMES_BY_VERSION].items(): - for stage, times in stage_times.items(): - if times: - stat[FOCUS_AVG_TIMES_BY_VERSION][version][stage] = sum(times) / len(times) - else: - stat[FOCUS_AVG_TIMES_BY_VERSION][version][stage] = 0.0 - - # 计算按版本和action类型的执行时间平均值 - for version, action_times in stat["focus_exec_times_by_version_action"].items(): - for action_type, times in action_times.items(): - if times: - stat["focus_exec_times_by_version_action"][version][action_type] = sum(times) / len(times) - else: - stat["focus_exec_times_by_version_action"][version][action_type] = 0.0 + def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: """ @@ -675,15 +474,13 @@ class StatisticOutputTask(AsyncTask): model_req_stat = self._collect_model_request_for_period(stat_start_timestamp) online_time_stat = self._collect_online_time_for_period(stat_start_timestamp, now) message_count_stat = self._collect_message_count_for_period(stat_start_timestamp) - focus_stat = self._collect_focus_statistics_for_period(stat_start_timestamp) # 统计数据合并 - # 合并四类统计数据 + # 合并三类统计数据 for period_key, _ in stat_start_timestamp: stat[period_key].update(model_req_stat[period_key]) stat[period_key].update(online_time_stat[period_key]) stat[period_key].update(message_count_stat[period_key]) - stat[period_key].update(focus_stat[period_key]) if last_all_time_stat: # 若存在上次完整统计数据,则将其与当前统计数据合并 @@ -800,41 +597,6 @@ class StatisticOutputTask(AsyncTask): output.append("") return "\n".join(output) - def _format_focus_stat(self, stats: Dict[str, Any]) -> str: - """ - 格式化Focus统计数据 - """ - if stats[FOCUS_TOTAL_CYCLES] <= 0: - return "" - - output = ["Focus系统统计:", f"总循环数: {stats[FOCUS_TOTAL_CYCLES]}", ""] - - # 全局阶段平均时间 - if stats[FOCUS_AVG_TIMES_BY_STAGE]: - output.append("全局阶段平均时间:") - output.extend(f" {stage}: {avg_time:.3f}秒" for stage, avg_time in 
stats[FOCUS_AVG_TIMES_BY_STAGE].items()) - output.append("") - - # Action类型比例 - if stats[FOCUS_ACTION_RATIOS]: - total_actions = sum(stats[FOCUS_ACTION_RATIOS].values()) - output.append("Action类型分布:") - for action_type, count in sorted(stats[FOCUS_ACTION_RATIOS].items()): - ratio = (count / total_actions) * 100 if total_actions > 0 else 0 - output.append(f" {action_type}: {count} ({ratio:.1f}%)") - output.append("") - - # 按Chat统计(仅显示前10个) - if stats[FOCUS_CYCLE_CNT_BY_CHAT]: - output.append("按聊天流统计 (前10):") - sorted_chats = sorted(stats[FOCUS_CYCLE_CNT_BY_CHAT].items(), key=lambda x: x[1], reverse=True)[:10] - for chat_id, count in sorted_chats: - chat_name = self.name_mapping.get(chat_id, (chat_id, 0))[0] - output.append(f" {chat_name[:30]}: {count} 循环") - output.append("") - - return "\n".join(output) - def _get_chat_display_name_from_id(self, chat_id: str) -> str: """从chat_id获取显示名称""" try: @@ -865,6 +627,8 @@ class StatisticOutputTask(AsyncTask): logger.warning(f"获取聊天显示名称失败: {e}") return chat_id + # 移除_generate_versions_tab方法 + def _generate_html_report(self, stat: dict[str, Any], now: datetime): """ 生成HTML格式的统计报告 @@ -873,13 +637,11 @@ class StatisticOutputTask(AsyncTask): :return: HTML格式的统计报告 """ + # 移除版本对比内容相关tab和内容 tab_list = [ f'' for period in self.stat_period ] - # 添加Focus统计、版本对比和图表选项卡 - tab_list.append('') - tab_list.append('') tab_list.append('') def _format_stat_data(stat_data: dict[str, Any], div_id: str, start_time: datetime) -> str: @@ -941,53 +703,6 @@ class StatisticOutputTask(AsyncTask): for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items()) ] ) - - # Focus统计数据 - # focus_action_rows = "" - # focus_chat_rows = "" - # focus_stage_rows = "" - # focus_action_stage_rows = "" - - if stat_data.get(FOCUS_TOTAL_CYCLES, 0) > 0: - # Action类型统计 - total_actions = sum(stat_data[FOCUS_ACTION_RATIOS].values()) if stat_data[FOCUS_ACTION_RATIOS] else 0 - _focus_action_rows = "\n".join( - [ - f"{action_type}{count}{(count / total_actions * 100):.1f}%" - 
for action_type, count in sorted(stat_data[FOCUS_ACTION_RATIOS].items()) - ] - ) - - # 按聊天流统计 - _focus_chat_rows = "\n".join( - [ - f"{self.name_mapping.get(chat_id, (chat_id, 0))[0]}{count}{stat_data[FOCUS_TOTAL_TIME_BY_CHAT].get(chat_id, 0):.2f}秒" - for chat_id, count in sorted( - stat_data[FOCUS_CYCLE_CNT_BY_CHAT].items(), key=lambda x: x[1], reverse=True - ) - ] - ) - - # 全局阶段时间统计 - _focus_stage_rows = "\n".join( - [ - f"{stage}{avg_time:.3f}秒" - for stage, avg_time in sorted(stat_data[FOCUS_AVG_TIMES_BY_STAGE].items()) - ] - ) - - # 按Action类型的阶段时间统计 - focus_action_stage_items = [] - for action_type, stage_times in stat_data[FOCUS_AVG_TIMES_BY_ACTION].items(): - for stage, avg_time in stage_times.items(): - focus_action_stage_items.append((action_type, stage, avg_time)) - - _focus_action_stage_rows = "\n".join( - [ - f"{action_type}{stage}{avg_time:.3f}秒" - for action_type, stage, avg_time in sorted(focus_action_stage_items) - ] - ) # 生成HTML return f"""
@@ -1052,14 +767,7 @@ class StatisticOutputTask(AsyncTask): _format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"])) # type: ignore ) - # 添加Focus统计内容 - focus_tab = self._generate_focus_tab(stat) - tab_content_list.append(focus_tab) - - # 添加版本对比内容 - versions_tab = self._generate_versions_tab(stat) - tab_content_list.append(versions_tab) - + # 不再添加版本对比内容 # 添加图表内容 chart_data = self._generate_chart_data(stat) tab_content_list.append(self._generate_chart_tab(chart_data)) @@ -1210,609 +918,6 @@ class StatisticOutputTask(AsyncTask): with open(self.record_file_path, "w", encoding="utf-8") as f: f.write(html_template) - def _generate_focus_tab(self, stat: dict[str, Any]) -> str: - # sourcery skip: for-append-to-extend, list-comprehension, use-any, use-named-expression, use-next - """生成Focus统计独立分页的HTML内容""" - - # 为每个时间段准备Focus数据 - focus_sections = [] - - for period_name, period_delta, period_desc in self.stat_period: - stat_data = stat.get(period_name, {}) - - if stat_data.get(FOCUS_TOTAL_CYCLES, 0) <= 0: - continue - - # 生成Focus统计数据行 - focus_action_rows = "" - focus_chat_rows = "" - focus_stage_rows = "" - focus_action_stage_rows = "" - - # Action类型统计 - total_actions = sum(stat_data[FOCUS_ACTION_RATIOS].values()) if stat_data[FOCUS_ACTION_RATIOS] else 0 - if total_actions > 0: - focus_action_rows = "\n".join( - [ - f"{action_type}{count}{(count / total_actions * 100):.1f}%" - for action_type, count in sorted(stat_data[FOCUS_ACTION_RATIOS].items()) - ] - ) - - # 按聊天流统计(横向表格,显示各阶段时间差异和不同action的平均时间) - focus_chat_rows = "" - if stat_data[FOCUS_AVG_TIMES_BY_CHAT_ACTION]: - # 获取前三个阶段(不包括执行动作) - basic_stages = ["观察", "规划器"] - existing_basic_stages = [] - for stage in basic_stages: - # 检查是否有任何聊天流在这个阶段有数据 - stage_exists = False - for _chat_id, stage_times in stat_data[FOCUS_AVG_TIMES_BY_CHAT_ACTION].items(): - if stage in stage_times: - stage_exists = True - break - if stage_exists: - existing_basic_stages.append(stage) - - # 
获取所有action类型(按出现频率排序) - all_action_types = sorted( - stat_data[FOCUS_ACTION_RATIOS].keys(), key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x], reverse=True - ) - - # 为每个聊天流生成一行 - chat_rows = [] - for chat_id in sorted( - stat_data[FOCUS_CYCLE_CNT_BY_CHAT].keys(), - key=lambda x: stat_data[FOCUS_CYCLE_CNT_BY_CHAT][x], - reverse=True, - ): - chat_name = self.name_mapping.get(chat_id, (chat_id, 0))[0] - cycle_count = stat_data[FOCUS_CYCLE_CNT_BY_CHAT][chat_id] - - # 获取该聊天流的各阶段平均时间 - stage_times = stat_data[FOCUS_AVG_TIMES_BY_CHAT_ACTION].get(chat_id, {}) - - row_cells = [f"{chat_name}
({cycle_count}次循环)"] - - # 添加基础阶段时间 - for stage in existing_basic_stages: - time_val = stage_times.get(stage, 0.0) - row_cells.append(f"{time_val:.3f}秒") - - # 添加每个action类型的平均执行时间 - for action_type in all_action_types: - # 使用真实的按聊天流+action类型分组的执行时间数据 - exec_times_by_chat_action = stat_data.get("focus_exec_times_by_chat_action", {}) - chat_action_times = exec_times_by_chat_action.get(chat_id, {}) - avg_exec_time = chat_action_times.get(action_type, 0.0) - - if avg_exec_time > 0: - row_cells.append(f"{avg_exec_time:.3f}秒") - else: - row_cells.append("-") - - chat_rows.append(f"{''.join(row_cells)}") - - # 生成表头 - stage_headers = "".join([f"{stage}" for stage in existing_basic_stages]) - action_headers = "".join( - [f"{action_type}
(执行)" for action_type in all_action_types] - ) - focus_chat_table_header = f"聊天流{stage_headers}{action_headers}" - focus_chat_rows = focus_chat_table_header + "\n" + "\n".join(chat_rows) - - # 全局阶段时间统计 - focus_stage_rows = "\n".join( - [ - f"{stage}{avg_time:.3f}秒" - for stage, avg_time in sorted(stat_data[FOCUS_AVG_TIMES_BY_STAGE].items()) - ] - ) - - # 聊天流Action选择比例对比表(横向表格) - focus_chat_action_ratios_rows = "" - if stat_data.get("focus_action_ratios_by_chat"): - if all_action_types_for_ratio := sorted( - stat_data[FOCUS_ACTION_RATIOS].keys(), - key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x], - reverse=True, - ): - # 为每个聊天流生成数据行(按循环数排序) - chat_ratio_rows = [] - for chat_id in sorted( - stat_data[FOCUS_CYCLE_CNT_BY_CHAT].keys(), - key=lambda x: stat_data[FOCUS_CYCLE_CNT_BY_CHAT][x], - reverse=True, - ): - chat_name = self.name_mapping.get(chat_id, (chat_id, 0))[0] - total_cycles = stat_data[FOCUS_CYCLE_CNT_BY_CHAT][chat_id] - chat_action_counts = stat_data["focus_action_ratios_by_chat"].get(chat_id, {}) - - row_cells = [f"{chat_name}
({total_cycles}次循环)"] - - # 添加每个action类型的数量和百分比 - for action_type in all_action_types_for_ratio: - count = chat_action_counts.get(action_type, 0) - ratio = (count / total_cycles * 100) if total_cycles > 0 else 0 - if count > 0: - row_cells.append(f"{count}
({ratio:.1f}%)") - else: - row_cells.append("-
(0%)") - - chat_ratio_rows.append(f"{''.join(row_cells)}") - - # 生成表头 - action_headers = "".join([f"{action_type}" for action_type in all_action_types_for_ratio]) - chat_action_ratio_table_header = f"聊天流{action_headers}" - focus_chat_action_ratios_rows = chat_action_ratio_table_header + "\n" + "\n".join(chat_ratio_rows) - - # 按Action类型的阶段时间统计(横向表格) - focus_action_stage_rows = "" - if stat_data[FOCUS_AVG_TIMES_BY_ACTION]: - # 获取所有阶段(按固定顺序) - stage_order = ["观察", "规划器", "执行动作"] - all_stages = [] - for stage in stage_order: - if any(stage in stage_times for stage_times in stat_data[FOCUS_AVG_TIMES_BY_ACTION].values()): - all_stages.append(stage) - - # 为每个Action类型生成一行 - action_rows = [] - for action_type in sorted(stat_data[FOCUS_AVG_TIMES_BY_ACTION].keys()): - stage_times = stat_data[FOCUS_AVG_TIMES_BY_ACTION][action_type] - row_cells = [f"{action_type}"] - - for stage in all_stages: - time_val = stage_times.get(stage, 0.0) - row_cells.append(f"{time_val:.3f}秒") - - action_rows.append(f"{''.join(row_cells)}") - - # 生成表头 - stage_headers = "".join([f"{stage}" for stage in all_stages]) - focus_action_stage_table_header = f"Action类型{stage_headers}" - focus_action_stage_rows = focus_action_stage_table_header + "\n" + "\n".join(action_rows) - - # 计算时间范围 - if period_name == "all_time": - from src.manager.local_store_manager import local_storage - - start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore - else: - start_time = datetime.now() - period_delta - - time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - # 生成该时间段的Focus统计HTML - section_html = f""" -
-

{period_desc}Focus统计

-

统计时段: {time_range}

-

总循环数: {stat_data.get(FOCUS_TOTAL_CYCLES, 0)}

- -
-
-

全局阶段平均时间

- - - {focus_stage_rows} -
阶段平均时间
-
- -
-

Action类型分布

- - - {focus_action_rows} -
Action类型次数占比
-
-
- -
-

按聊天流各阶段时间统计

- - - {focus_chat_rows} -
-
- -
-

聊天流Action选择比例对比

- - - {focus_chat_action_ratios_rows} -
-
- -
-

Action类型阶段时间详情

- - - {focus_action_stage_rows} -
-
-
- """ - - focus_sections.append(section_html) - - # 如果没有任何Focus数据 - if not focus_sections: - focus_sections.append(""" -
-

暂无Focus统计数据

-

在指定时间段内未找到任何Focus循环数据。

-

请确保 log/hfc_loop/ 目录下存在相应的JSON文件。

-
- """) - - return f""" -
-

Focus系统详细统计

-

- 数据来源: log/hfc_loop/ 目录下的JSON文件
- 统计内容: 各时间段的Focus循环性能分析 -

- - {"".join(focus_sections)} - - -
- """ - - def _generate_versions_tab(self, stat: dict[str, Any]) -> str: - # sourcery skip: use-named-expression, use-next - """生成版本对比独立分页的HTML内容""" - - # 为每个时间段准备版本对比数据 - version_sections = [] - - for period_name, period_delta, period_desc in self.stat_period: - stat_data = stat.get(period_name, {}) - - if not stat_data.get(FOCUS_CYCLE_CNT_BY_VERSION): - continue - - # 获取所有版本(按循环数排序) - all_versions = sorted( - stat_data[FOCUS_CYCLE_CNT_BY_VERSION].keys(), - key=lambda x: stat_data[FOCUS_CYCLE_CNT_BY_VERSION][x], - reverse=True, - ) - - # 生成版本Action分布表 - focus_version_action_rows = "" - if stat_data[FOCUS_ACTION_RATIOS_BY_VERSION]: - # 获取所有action类型 - all_action_types_for_version = set() - for version_actions in stat_data[FOCUS_ACTION_RATIOS_BY_VERSION].values(): - all_action_types_for_version.update(version_actions.keys()) - all_action_types_for_version = sorted(all_action_types_for_version) - - if all_action_types_for_version: - version_action_rows = [] - for version in all_versions: - version_actions = stat_data[FOCUS_ACTION_RATIOS_BY_VERSION].get(version, {}) - total_cycles = stat_data[FOCUS_CYCLE_CNT_BY_VERSION][version] - - row_cells = [f"{version}
({total_cycles}次循环)"] - - for action_type in all_action_types_for_version: - count = version_actions.get(action_type, 0) - ratio = (count / total_cycles * 100) if total_cycles > 0 else 0 - row_cells.append(f"{count}
({ratio:.1f}%)") - - version_action_rows.append(f"{''.join(row_cells)}") - - # 生成表头 - action_headers = "".join( - [f"{action_type}" for action_type in all_action_types_for_version] - ) - version_action_table_header = f"版本{action_headers}" - focus_version_action_rows = version_action_table_header + "\n" + "\n".join(version_action_rows) - - # 生成版本阶段时间表(按action类型分解执行时间) - focus_version_stage_rows = "" - if stat_data[FOCUS_AVG_TIMES_BY_VERSION]: - # 基础三个阶段 - basic_stages = ["观察", "规划器"] - - # 获取所有action类型用于执行时间列 - all_action_types_for_exec = set() - if stat_data.get("focus_exec_times_by_version_action"): - for version_actions in stat_data["focus_exec_times_by_version_action"].values(): - all_action_types_for_exec.update(version_actions.keys()) - all_action_types_for_exec = sorted(all_action_types_for_exec) - - # 检查哪些基础阶段存在数据 - existing_basic_stages = [] - for stage in basic_stages: - stage_exists = False - for version_stages in stat_data[FOCUS_AVG_TIMES_BY_VERSION].values(): - if stage in version_stages: - stage_exists = True - break - if stage_exists: - existing_basic_stages.append(stage) - - # 构建表格 - if existing_basic_stages or all_action_types_for_exec: - version_stage_rows = [] - - # 为每个版本生成数据行 - for version in all_versions: - version_stages = stat_data[FOCUS_AVG_TIMES_BY_VERSION].get(version, {}) - total_cycles = stat_data[FOCUS_CYCLE_CNT_BY_VERSION][version] - - row_cells = [f"{version}
({total_cycles}次循环)"] - - # 添加基础阶段时间 - for stage in existing_basic_stages: - time_val = version_stages.get(stage, 0.0) - row_cells.append(f"{time_val:.3f}秒") - - # 添加不同action类型的执行时间 - for action_type in all_action_types_for_exec: - # 获取该版本该action类型的平均执行时间 - version_exec_times = stat_data.get("focus_exec_times_by_version_action", {}) - if version in version_exec_times and action_type in version_exec_times[version]: - exec_time = version_exec_times[version][action_type] - row_cells.append(f"{exec_time:.3f}秒") - else: - row_cells.append("-") - - version_stage_rows.append(f"{''.join(row_cells)}") - - # 生成表头 - basic_headers = "".join([f"{stage}" for stage in existing_basic_stages]) - action_headers = "".join( - [ - f"执行时间
[{action_type}]" - for action_type in all_action_types_for_exec - ] - ) - version_stage_table_header = f"版本{basic_headers}{action_headers}" - focus_version_stage_rows = version_stage_table_header + "\n" + "\n".join(version_stage_rows) - - # 计算时间范围 - if period_name == "all_time": - from src.manager.local_store_manager import local_storage - - start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore - else: - start_time = datetime.now() - period_delta - time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - # 生成该时间段的版本对比HTML - section_html = f""" -
-

{period_desc}版本对比

-

统计时段: {time_range}

-

包含版本: {len(all_versions)} 个版本

- -
-
-

版本Action类型分布对比

- - - {focus_version_action_rows} -
-
- -
-

版本阶段时间对比

- - - {focus_version_stage_rows} -
-
-
-
- """ - - version_sections.append(section_html) - - # 如果没有任何版本数据 - if not version_sections: - version_sections.append(""" -
-

暂无版本对比数据

-

在指定时间段内未找到任何版本信息。

-

请确保 log/hfc_loop/ 目录下的JSON文件包含版本信息。

-
- """) - - return f""" -
-

Focus HFC版本对比分析

-

- 对比内容: 不同版本的Action类型分布和各阶段性能表现
- 数据来源: log/hfc_loop/ 目录下JSON文件中的version字段 -

- - {"".join(version_sections)} - - -
- """ - def _generate_chart_data(self, stat: dict[str, Any]) -> dict: """生成图表数据""" now = datetime.now() @@ -1906,68 +1011,12 @@ class StatisticOutputTask(AsyncTask): message_by_chat[chat_name] = [0] * len(time_points) message_by_chat[chat_name][interval_index] += 1 - # 查询Focus循环记录 - focus_cycles_by_action = {} - focus_time_by_stage = {} - - log_dir = "log/hfc_loop" - if os.path.exists(log_dir): - json_files = glob.glob(os.path.join(log_dir, "*.json")) - for json_file in json_files: - try: - # 解析文件时间 - filename = os.path.basename(json_file) - name_parts = filename.replace(".json", "").split("_") - if len(name_parts) >= 4: - date_str = name_parts[-2] - time_str = name_parts[-1] - file_time_str = f"{date_str}_{time_str}" - file_time = datetime.strptime(file_time_str, "%Y%m%d_%H%M%S") - - if file_time >= start_time: - with open(json_file, "r", encoding="utf-8") as f: - cycles_data = json.load(f) - - for cycle in cycles_data: - try: - timestamp_str = cycle.get("timestamp", "") - if timestamp_str: - cycle_time = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00")) - else: - cycle_time = file_time - - if cycle_time >= start_time: - # 计算时间间隔索引 - time_diff = (cycle_time - start_time).total_seconds() - interval_index = int(time_diff // interval_seconds) - - if 0 <= interval_index < len(time_points): - action_type = cycle.get("action_type", "unknown") - step_times = cycle.get("step_times", {}) - - # 累计action类型数据 - if action_type not in focus_cycles_by_action: - focus_cycles_by_action[action_type] = [0] * len(time_points) - focus_cycles_by_action[action_type][interval_index] += 1 - - # 累计阶段时间数据 - for stage, time_val in step_times.items(): - if stage not in focus_time_by_stage: - focus_time_by_stage[stage] = [0] * len(time_points) - focus_time_by_stage[stage][interval_index] += time_val - except Exception: - continue - except Exception: - continue - return { "time_labels": time_labels, "total_cost_data": total_cost_data, "cost_by_model": cost_by_model, "cost_by_module": 
cost_by_module, "message_by_chat": message_by_chat, - "focus_cycles_by_action": focus_cycles_by_action, - "focus_time_by_stage": focus_time_by_stage, } def _generate_chart_tab(self, chart_data: dict) -> str: @@ -2059,14 +1108,8 @@ class StatisticOutputTask(AsyncTask):
-
- -
-
- -
- +
@@ -2169,8 +1212,6 @@ class StatisticOutputTask(AsyncTask): createChart('costByModule', data, timeRange); createChart('costByModel', data, timeRange); createChart('messageByChat', data, timeRange); - createChart('focusCyclesByAction', data, timeRange); - createChart('focusTimeByStage', data, timeRange); }} function createChart(chartType, data, timeRange) {{ @@ -2327,21 +1368,6 @@ class AsyncStatisticOutputTask(AsyncTask): def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore - def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: - return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) # type: ignore - - def _process_focus_file_data( - self, - cycles_data: List[Dict], - stats: Dict[str, Any], - collect_period: List[Tuple[str, datetime]], - file_time: datetime, - ): - return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) # type: ignore - - def _calculate_focus_averages(self, stats: Dict[str, Any]): - return StatisticOutputTask._calculate_focus_averages(self, stats) # type: ignore - @staticmethod def _format_total_stat(stats: Dict[str, Any]) -> str: return StatisticOutputTask._format_total_stat(stats) @@ -2353,9 +1379,6 @@ class AsyncStatisticOutputTask(AsyncTask): def _format_chat_stat(self, stats: Dict[str, Any]) -> str: return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore - def _format_focus_stat(self, stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_focus_stat(self, stats) # type: ignore - def _generate_chart_data(self, stat: dict[str, Any]) -> dict: return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore @@ -2368,11 +1391,5 @@ class AsyncStatisticOutputTask(AsyncTask): def _get_chat_display_name_from_id(self, chat_id: 
str) -> str: return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore - def _generate_focus_tab(self, stat: dict[str, Any]) -> str: - return StatisticOutputTask._generate_focus_tab(self, stat) # type: ignore - - def _generate_versions_tab(self, stat: dict[str, Any]) -> str: - return StatisticOutputTask._generate_versions_tab(self, stat) # type: ignore - def _convert_defaultdict_to_dict(self, data): return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 071f1886c..3ee4ae7b1 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -78,7 +78,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: # print(f"is_mentioned: {is_mentioned}") # print(f"is_at: {is_at}") - if is_at and global_config.normal_chat.at_bot_inevitable_reply: + if is_at and global_config.chat.at_bot_inevitable_reply: reply_probability = 1.0 logger.debug("被@,回复概率设置为100%") else: @@ -103,7 +103,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: for nickname in nicknames: if nickname in message_content: is_mentioned = True - if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply: + if is_mentioned and global_config.chat.mentioned_bot_inevitable_reply: reply_probability = 1.0 logger.debug("被提及,回复概率设置为100%") return is_mentioned, reply_probability @@ -619,9 +619,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: chat_target_info = None try: - chat_stream = get_chat_manager().get_stream(chat_id) - - if chat_stream: + if chat_stream := get_chat_manager().get_stream(chat_id): if chat_stream.group_info: is_group_chat = True chat_target_info = None # Explicitly None for group chat @@ -660,8 +658,6 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: chat_target_info = target_info else: logger.warning(f"无法获取 chat_stream for {chat_id} 
in utils") - # Keep defaults: is_group_chat=False, chat_target_info=None - except Exception as e: logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True) # Keep defaults on error diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 0ab5559cb..858d95aa3 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -94,7 +94,7 @@ class ImageManager: logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}") async def get_emoji_description(self, image_base64: str) -> str: - """获取表情包描述,带查重和保存功能""" + """获取表情包描述,使用二步走识别并带缓存优化""" try: # 计算图片哈希 # 确保base64字符串只包含ASCII字符 @@ -107,33 +107,66 @@ class ImageManager: # 查询缓存的描述 cached_description = self._get_description_from_db(image_hash, "emoji") if cached_description: - return f"[表情包,含义看起来是:{cached_description}]" + return f"[表情包:{cached_description}]" - # 调用AI获取描述 + # === 二步走识别流程 === + + # 第一步:VLM视觉分析 - 生成详细描述 if image_format in ["gif", "GIF"]: image_base64_processed = self.transform_gif(image_base64) if image_base64_processed is None: logger.warning("GIF转换失败,无法获取描述") return "[表情包(GIF处理失败)]" - prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些,输出一段平文本,只输出1-2个词就好,不要输出其他内容" - description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg") + vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" + detailed_description, _ = await self._llm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg") else: - prompt = "图片是一个表情包,请用使用1-2个词描述一下表情包所表达的情感和内容,简短一些,输出一段平文本,只输出1-2个词就好,不要输出其他内容" - description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) + vlm_prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" + detailed_description, _ = await self._llm.generate_response_for_image(vlm_prompt, image_base64, image_format) - if description is None: - logger.warning("AI未能生成表情包描述") - return "[表情包(描述生成失败)]" + if detailed_description is None: + 
logger.warning("VLM未能生成表情包详细描述") + return "[表情包(VLM描述生成失败)]" + + # 第二步:LLM情感分析 - 基于详细描述生成简短的情感标签 + emotion_prompt = f""" + 请你基于这个表情包的详细描述,提取出最核心的情感含义,用1-2个词概括。 + 详细描述:'{detailed_description}' + + 要求: + 1. 只输出1-2个最核心的情感词汇 + 2. 从互联网梗、meme的角度理解 + 3. 输出简短精准,不要解释 + 4. 如果有多个词用逗号分隔 + """ + + # 使用较低温度确保输出稳定 + emotion_llm = LLMRequest(model=global_config.model.utils, temperature=0.3, max_tokens=50, request_type="emoji") + emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt) + + if emotion_result is None: + logger.warning("LLM未能生成情感标签,使用详细描述的前几个词") + # 降级处理:从详细描述中提取关键词 + import jieba + words = list(jieba.cut(detailed_description)) + emotion_result = ",".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情") + + # 处理情感结果,取前1-2个最重要的标签 + emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()] + final_emotion = emotions[0] if emotions else "表情" + + # 如果有第二个情感且不重复,也包含进来 + if len(emotions) > 1 and emotions[1] != emotions[0]: + final_emotion = f"{emotions[0]},{emotions[1]}" + + logger.info(f"[二步走识别] 详细描述: {detailed_description[:50]}... 
-> 情感标签: {final_emotion}") # 再次检查缓存,防止并发写入时重复生成 cached_description = self._get_description_from_db(image_hash, "emoji") if cached_description: logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") - return f"[表情包,含义看起来是:{cached_description}]" + return f"[表情包:{cached_description}]" - # 根据配置决定是否保存图片 - # if global_config.emoji.save_emoji: - # 生成文件名和路径 + # 保存表情包文件和元数据(用于可能的后续分析) logger.debug(f"保存表情包: {image_hash}") current_timestamp = time.time() filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" @@ -146,11 +179,11 @@ class ImageManager: with open(file_path, "wb") as f: f.write(image_bytes) - # 保存到数据库 (Images表) + # 保存到数据库 (Images表) - 包含详细描述用于可能的注册流程 try: img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji")) img_obj.path = file_path - img_obj.description = description + img_obj.description = detailed_description # 保存详细描述 img_obj.timestamp = current_timestamp img_obj.save() except Images.DoesNotExist: # type: ignore @@ -158,17 +191,17 @@ class ImageManager: emoji_hash=image_hash, path=file_path, type="emoji", - description=description, + description=detailed_description, # 保存详细描述 timestamp=current_timestamp, ) - # logger.debug(f"保存表情包元数据: {file_path}") except Exception as e: logger.error(f"保存表情包文件或元数据失败: {str(e)}") - # 保存描述到数据库 (ImageDescriptions表) - self._save_description_to_db(image_hash, description, "emoji") + # 保存最终的情感标签到缓存 (ImageDescriptions表) + self._save_description_to_db(image_hash, final_emotion, "emoji") - return f"[表情包:{description}]" + return f"[表情包:{final_emotion}]" + except Exception as e: logger.error(f"获取表情包描述失败: {str(e)}") return "[表情包]" diff --git a/src/chat/utils/utils_voice.py b/src/chat/utils/utils_voice.py index 1bc3e7dda..cf71dc56f 100644 --- a/src/chat/utils/utils_voice.py +++ b/src/chat/utils/utils_voice.py @@ -11,7 +11,7 @@ logger = get_logger("chat_voice") async def get_voice_text(voice_base64: str) -> str: """获取音频文件描述""" - if not global_config.chat.enable_asr: + if not 
global_config.voice.enable_asr: logger.warning("语音识别未启用,无法处理语音消息") return "[语音]" try: diff --git a/src/chat/willing/mode_classical.py b/src/chat/willing/mode_classical.py index e15272332..57400c44d 100644 --- a/src/chat/willing/mode_classical.py +++ b/src/chat/willing/mode_classical.py @@ -35,7 +35,7 @@ class ClassicalWillingManager(BaseWillingManager): if interested_rate > 0.2: current_willing += interested_rate - 0.2 - if willing_info.is_mentioned_bot and global_config.normal_chat.mentioned_bot_inevitable_reply and current_willing < 2: + if willing_info.is_mentioned_bot and global_config.chat.mentioned_bot_inevitable_reply and current_willing < 2: current_willing += 1 if current_willing < 1.0 else 0.05 self.chat_reply_willing[chat_id] = min(current_willing, 1.0) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 645b0a5d6..1d0b8a397 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -306,6 +306,7 @@ class Expression(BaseModel): last_active_time = FloatField() chat_id = TextField(index=True) type = TextField() + create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据 class Meta: table_name = "expression" @@ -449,9 +450,12 @@ def initialize_database(): alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}" alter_sql += " NULL" if field_obj.null else " NOT NULL" if hasattr(field_obj, "default") and field_obj.default is not None: - # 正确处理不同类型的默认值 + # 正确处理不同类型的默认值,跳过lambda函数 default_value = field_obj.default - if isinstance(default_value, str): + if callable(default_value): + # 跳过lambda函数或其他可调用对象,这些无法在SQL中表示 + pass + elif isinstance(default_value, str): alter_sql += f" DEFAULT '{default_value}'" elif isinstance(default_value, bool): alter_sql += f" DEFAULT {int(default_value)}" diff --git a/src/common/logger.py b/src/common/logger.py index aa80af551..a6bfc2634 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -321,7 +321,7 @@ MODULE_COLORS = { # 
核心模块 "main": "\033[1;97m", # 亮白色+粗体 (主程序) "api": "\033[92m", # 亮绿色 - "emoji": "\033[33m", # 亮绿色 + "emoji": "\033[38;5;214m", # 橙黄色,偏向橙色但与replyer和action_manager不同 "chat": "\033[92m", # 亮蓝色 "config": "\033[93m", # 亮黄色 "common": "\033[95m", # 亮紫色 @@ -329,35 +329,33 @@ MODULE_COLORS = { "lpmm": "\033[96m", "plugin_system": "\033[91m", # 亮红色 "person_info": "\033[32m", # 绿色 - "individuality": "\033[34m", # 蓝色 + "individuality": "\033[94m", # 显眼的亮蓝色 "manager": "\033[35m", # 紫色 "llm_models": "\033[36m", # 青色 - "plugins": "\033[31m", # 红色 - "plugin_api": "\033[33m", # 黄色 - "remote": "\033[38;5;93m", # 紫蓝色 + "remote": "\033[38;5;242m", # 深灰色,更不显眼 "planner": "\033[36m", "memory": "\033[34m", - "hfc": "\033[96m", - "action_manager": "\033[38;5;166m", + "hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读 + "action_manager": "\033[38;5;208m", # 橙色,不与replyer重复 # 关系系统 - "relation": "\033[38;5;201m", # 深粉色 + "relation": "\033[38;5;139m", # 柔和的紫色,不刺眼 # 聊天相关模块 "normal_chat": "\033[38;5;81m", # 亮蓝绿色 - "normal_chat_response": "\033[38;5;123m", # 青绿色 - "heartflow": "\033[38;5;213m", # 粉色 + "heartflow": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系 "sub_heartflow": "\033[38;5;207m", # 粉紫色 "subheartflow_manager": "\033[38;5;201m", # 深粉色 "background_tasks": "\033[38;5;240m", # 灰色 "chat_message": "\033[38;5;45m", # 青色 "chat_stream": "\033[38;5;51m", # 亮青色 - "sender": "\033[38;5;39m", # 蓝色 + "sender": "\033[38;5;67m", # 稍微暗一些的蓝色,不显眼 "message_storage": "\033[38;5;33m", # 深蓝色 + "expressor": "\033[38;5;166m", # 橙色 # 专注聊天模块 "replyer": "\033[38;5;166m", # 橙色 - "base_processor": "\033[38;5;190m", # 绿黄色 - "working_memory": "\033[38;5;22m", # 深绿色 "memory_activator": "\033[34m", # 绿色 # 插件系统 + "plugins": "\033[31m", # 红色 + "plugin_api": "\033[33m", # 黄色 "plugin_manager": "\033[38;5;208m", # 红色 "base_plugin": "\033[38;5;202m", # 橙红色 "send_api": "\033[38;5;208m", # 橙色 @@ -378,9 +376,9 @@ MODULE_COLORS = { "local_storage": "\033[38;5;141m", # 紫色 "willing": "\033[38;5;147m", # 浅紫色 # 工具模块 - "tool_use": "\033[38;5;64m", # 深绿色 - 
"tool_executor": "\033[38;5;64m", # 深绿色 - "base_tool": "\033[38;5;70m", # 绿色 + "tool_use": "\033[38;5;172m", # 橙褐色 + "tool_executor": "\033[38;5;172m", # 橙褐色 + "base_tool": "\033[38;5;178m", # 金黄色 # 工具和实用模块 "prompt_build": "\033[38;5;105m", # 紫色 "chat_utils": "\033[38;5;111m", # 蓝色 @@ -388,14 +386,16 @@ MODULE_COLORS = { "maibot_statistic": "\033[38;5;129m", # 紫色 # 特殊功能插件 "mute_plugin": "\033[38;5;240m", # 灰色 - "example_comprehensive": "\033[38;5;246m", # 浅灰色 "core_actions": "\033[38;5;117m", # 深红色 "tts_action": "\033[38;5;58m", # 深黄色 "doubao_pic_plugin": "\033[38;5;64m", # 深绿色 - "vtb_action": "\033[38;5;70m", # 绿色 + # Action组件 + "no_reply_action": "\033[38;5;196m", # 亮红色,更显眼 + "reply_action": "\033[38;5;46m", # 亮绿色 + "base_action": "\033[38;5;250m", # 浅灰色 # 数据库和消息 "database_model": "\033[38;5;94m", # 橙褐色 - "maim_message": "\033[38;5;100m", # 绿褐色 + "maim_message": "\033[38;5;140m", # 紫褐色 # 日志系统 "logger": "\033[38;5;8m", # 深灰色 "confirm": "\033[1;93m", # 黄色+粗体 @@ -409,6 +409,34 @@ MODULE_COLORS = { "S4U_chat": "\033[92m", # 深灰色 } +# 定义模块别名映射 - 将真实的logger名称映射到显示的别名 +MODULE_ALIASES = { + # 示例映射 + "individuality": "人格特质", + "emoji": "表情包", + "no_reply_action": "摸鱼", + "reply_action": "回复", + "action_manager": "动作", + "memory_activator": "记忆", + "tool_use": "工具", + "expressor": "表达方式", + "database_model": "数据库", + "mood": "情绪", + "memory": "记忆", + "tool_executor": "工具", + "hfc": "聊天节奏", + "chat": "所见", + "plugin_manager": "插件", + "relationship_builder": "关系", + "llm_models": "模型", + "person_info": "人物", + "chat_stream": "聊天流", + "planner": "规划器", + "replyer": "言语", + "config": "配置", + "main": "主程序", +} + RESET_COLOR = "\033[0m" @@ -497,15 +525,18 @@ class ModuleColoredConsoleRenderer: if self._colors and self._enable_module_colors and logger_name: module_color = MODULE_COLORS.get(logger_name, "") - # 模块名称(带颜色) + # 模块名称(带颜色和别名支持) if logger_name: + # 获取别名,如果没有别名则使用原名称 + display_name = MODULE_ALIASES.get(logger_name, logger_name) + if self._colors and 
self._enable_module_colors: if module_color: - module_part = f"{module_color}[{logger_name}]{RESET_COLOR}" + module_part = f"{module_color}[{display_name}]{RESET_COLOR}" else: - module_part = f"[{logger_name}]" + module_part = f"[{display_name}]" else: - module_part = f"[{logger_name}]" + module_part = f"[{display_name}]" parts.append(module_part) # 消息内容(确保转换为字符串) @@ -715,19 +746,7 @@ def configure_logging( root_logger.setLevel(getattr(logging, level.upper())) -def set_module_color(module_name: str, color_code: str): - """为指定模块设置颜色 - Args: - module_name: 模块名称 - color_code: ANSI颜色代码,例如 '\033[92m' 表示亮绿色 - """ - MODULE_COLORS[module_name] = color_code - - -def get_module_colors(): - """获取当前模块颜色配置""" - return MODULE_COLORS.copy() def reload_log_config(): @@ -918,9 +937,20 @@ def show_module_colors(): for module_name, _color_code in MODULE_COLORS.items(): # 临时创建一个该模块的logger来展示颜色 demo_logger = structlog.get_logger(module_name).bind(logger_name=module_name) - demo_logger.info(f"这是 {module_name} 模块的颜色效果") + alias = MODULE_ALIASES.get(module_name, module_name) + if alias != module_name: + demo_logger.info(f"这是 {module_name} 模块的颜色效果 (显示为: {alias})") + else: + demo_logger.info(f"这是 {module_name} 模块的颜色效果") print("=== 颜色展示结束 ===\n") + + # 显示别名映射表 + if MODULE_ALIASES: + print("=== 当前别名映射 ===") + for module_name, alias in MODULE_ALIASES.items(): + print(f" {module_name} -> {alias}") + print("=== 别名映射结束 ===\n") def format_json_for_logging(data, indent=2, ensure_ascii=False): diff --git a/src/config/config.py b/src/config/config.py index d14b89583..bd2d58f04 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -36,6 +36,7 @@ from src.config.official_configs import ( LPMMKnowledgeConfig, RelationshipConfig, ToolConfig, + VoiceConfig, DebugConfig, CustomPromptConfig, ) @@ -64,7 +65,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template") # 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 # 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/ -MMC_VERSION = "0.9.0-snapshot.2" 
+MMC_VERSION = "0.9.1" @@ -616,7 +617,7 @@ class Config(ConfigBase): tool: ToolConfig debug: DebugConfig custom_prompt: CustomPromptConfig - + voice: VoiceConfig def load_config(config_path: str) -> Config: """ diff --git a/src/config/official_configs.py b/src/config/official_configs.py index af561bec3..335a2be59 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -18,6 +18,9 @@ from packaging.version import Version @dataclass class BotConfig(ConfigBase): """QQ机器人配置类""" + + platform: str + """平台""" qq_account: str """QQ账号""" @@ -82,6 +85,12 @@ class ChatConfig(ConfigBase): use_s4u_prompt_mode: bool = False """是否使用 s4u 对话构建模式,该模式会分开处理当前对话对象和其他所有对话的内容进行 prompt 构建""" + mentioned_bot_inevitable_reply: bool = False + """提及 bot 必然回复""" + + at_bot_inevitable_reply: bool = False + """@bot 必然回复""" + # 修改:基于时段的回复频率配置,改为数组格式 time_based_talk_frequency: list[str] = field(default_factory=lambda: []) """ @@ -107,9 +116,6 @@ class ChatConfig(ConfigBase): focus_value: float = 1.0 """麦麦的专注思考能力,越低越容易专注,消耗token也越多""" - enable_asr: bool = False - """是否启用语音识别""" - def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float: """ 根据当前时间和聊天流获取对应的 talk_frequency @@ -271,11 +277,7 @@ class NormalChatConfig(ConfigBase): response_interested_rate_amplifier: float = 1.0 """回复兴趣度放大系数""" - mentioned_bot_inevitable_reply: bool = False - """提及 bot 必然回复""" - at_bot_inevitable_reply: bool = False - """@bot 必然回复""" @dataclass @@ -310,6 +312,13 @@ class ToolConfig(ConfigBase): enable_in_focus_chat: bool = True """是否在专注聊天中启用工具""" + +@dataclass +class VoiceConfig(ConfigBase): + """语音识别配置类""" + + enable_asr: bool = False + """是否启用语音识别""" @dataclass @@ -400,15 +409,9 @@ class MoodConfig(ConfigBase): enable_mood: bool = False """是否启用情绪系统""" - - mood_update_interval: int = 1 - """情绪更新间隔(秒)""" - - mood_decay_rate: float = 0.95 - """情绪衰减率""" - - mood_intensity_factor: float = 0.7 - """情绪强度因子""" + + mood_update_threshold: float = 1.0 + """情绪更新阈值,越高,更新越慢""" 
@dataclass diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index ac2281d39..fc7156e14 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -1,7 +1,7 @@ -import ast import json import os import hashlib +import time from src.common.logger import get_logger from src.config.config import global_config @@ -9,8 +9,6 @@ from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import get_person_info_manager from rich.traceback import install -from .personality import Personality - install(extra_lines=3) logger = get_logger("individuality") @@ -20,12 +18,10 @@ class Individuality: """个体特征管理类""" def __init__(self): - # 正常初始化实例属性 - self.personality: Personality = None # type: ignore - self.name = "" self.bot_person_id = "" self.meta_info_file_path = "data/personality/meta.json" + self.personality_data_file_path = "data/personality/personality_data.json" self.model = LLMRequest( model=global_config.model.utils, @@ -33,20 +29,13 @@ class Individuality: ) async def initialize(self) -> None: - """初始化个体特征 - - Args: - bot_nickname: 机器人昵称 - personality_core: 人格核心特点 - personality_side: 人格侧面描述 - identity: 身份细节描述 - """ + """初始化个体特征""" bot_nickname = global_config.bot.nickname personality_core = global_config.personality.personality_core personality_side = global_config.personality.personality_side identity = global_config.personality.identity - logger.info("正在初始化个体特征") + person_info_manager = get_person_info_manager() self.bot_person_id = person_info_manager.get_person_id("system", "bot_id") self.name = bot_nickname @@ -56,129 +45,61 @@ class Individuality: bot_nickname, personality_core, personality_side, identity ) - # 初始化人格(现在包含身份) - self.personality = Personality.initialize( - bot_nickname=bot_nickname, - personality_core=personality_core, - personality_side=personality_side, - identity=identity, - compress_personality=global_config.personality.compress_personality, - 
compress_identity=global_config.personality.compress_identity, - ) + logger.info("正在构建人设信息") - logger.info("正在将所有人设写入impression") - # 将所有人设写入impression - impression_parts = [] - if personality_core: - impression_parts.append(f"核心人格: {personality_core}") - if personality_side: - impression_parts.append(f"人格侧面: {personality_side}") - if identity: - impression_parts.append(f"身份: {identity}") - logger.info(f"impression_parts: {impression_parts}") + # 如果配置有变化,重新生成压缩版本 + if personality_changed or identity_changed: + logger.info("检测到配置变化,重新生成压缩版本") + personality_result = await self._create_personality(personality_core, personality_side) + identity_result = await self._create_identity(identity) + else: + logger.info("配置未变化,使用缓存版本") + # 从文件中获取已有的结果 + personality_result, identity_result = self._get_personality_from_file() + if not personality_result or not identity_result: + logger.info("未找到有效缓存,重新生成") + personality_result = await self._create_personality(personality_core, personality_side) + identity_result = await self._create_identity(identity) - impression_text = "。".join(impression_parts) - if impression_text: - impression_text += "。" + # 保存到文件 + if personality_result and identity_result: + self._save_personality_to_file(personality_result, identity_result) + logger.info("已将人设构建并保存到文件") + else: + logger.error("人设构建失败") - if impression_text: + # 如果任何一个发生变化,都需要清空数据库中的info_list(因为这影响整体人设) + if personality_changed or identity_changed: + logger.info("将清空数据库中原有的关键词缓存") update_data = { "platform": "system", "user_id": "bot_id", "person_name": self.name, "nickname": self.name, } - - await person_info_manager.update_one_field( - self.bot_person_id, "impression", impression_text, data=update_data - ) - logger.debug("已将完整人设更新到bot的impression中") - - # 根据变化情况决定是否重新创建 - personality_result = None - identity_result = None - - if personality_changed: - logger.info("检测到人格配置变化,重新生成压缩版本") - personality_result = await self._create_personality(personality_core, personality_side) - else: - 
logger.info("人格配置未变化,使用缓存版本") - # 从缓存中获取已有的personality结果 - existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression") - if existing_short_impression: - try: - existing_data = ast.literal_eval(existing_short_impression) # type: ignore - if isinstance(existing_data, list) and len(existing_data) >= 1: - personality_result = existing_data[0] - except (json.JSONDecodeError, TypeError, IndexError): - logger.warning("无法解析现有的short_impression,将重新生成人格部分") - personality_result = await self._create_personality(personality_core, personality_side) - else: - logger.info("未找到现有的人格缓存,重新生成") - personality_result = await self._create_personality(personality_core, personality_side) - - if identity_changed: - logger.info("检测到身份配置变化,重新生成压缩版本") - identity_result = await self._create_identity(identity) - else: - logger.info("身份配置未变化,使用缓存版本") - # 从缓存中获取已有的identity结果 - existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression") - if existing_short_impression: - try: - existing_data = ast.literal_eval(existing_short_impression) # type: ignore - if isinstance(existing_data, list) and len(existing_data) >= 2: - identity_result = existing_data[1] - except (json.JSONDecodeError, TypeError, IndexError): - logger.warning("无法解析现有的short_impression,将重新生成身份部分") - identity_result = await self._create_identity(identity) - else: - logger.info("未找到现有的身份缓存,重新生成") - identity_result = await self._create_identity(identity) - - result = [personality_result, identity_result] - - # 更新short_impression字段 - if personality_result and identity_result: - person_info_manager = get_person_info_manager() - await person_info_manager.update_one_field(self.bot_person_id, "short_impression", result) - logger.info("已将人设构建") - else: - logger.error("人设构建失败") + await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data) async def get_personality_block(self) -> str: - person_info_manager = 
get_person_info_manager() - bot_person_id = person_info_manager.get_person_id("system", "bot_id") - bot_name = global_config.bot.nickname if global_config.bot.alias_names: bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}" else: bot_nickname = "" - short_impression = await person_info_manager.get_value(bot_person_id, "short_impression") - # 解析字符串形式的Python列表 - try: - if isinstance(short_impression, str) and short_impression.strip(): - short_impression = ast.literal_eval(short_impression) - elif not short_impression: - logger.warning("short_impression为空,使用默认值") - short_impression = ["友好活泼", "人类"] - except (ValueError, SyntaxError) as e: - logger.error(f"解析short_impression失败: {e}, 原始值: {short_impression}") - short_impression = ["友好活泼", "人类"] + + # 从文件获取 short_impression + personality, identity = self._get_personality_from_file() + # 确保short_impression是列表格式且有足够的元素 - if not isinstance(short_impression, list) or len(short_impression) < 2: - logger.warning(f"short_impression格式不正确: {short_impression}, 使用默认值") - short_impression = ["友好活泼", "人类"] - personality = short_impression[0] - identity = short_impression[1] - prompt_personality = f"{personality},{identity}" - identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" - - return identity_block + if not personality or not identity: + logger.warning(f"personality或identity为空: {personality}, {identity}, 使用默认值") + personality = "友好活泼" + identity = "人类" + + prompt_personality = f"{personality}\n{identity}" + return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" def _get_config_hash( - self, bot_nickname: str, personality_core: str, personality_side: str, identity: list + self, bot_nickname: str, personality_core: str, personality_side: str, identity: str ) -> tuple[str, str]: """获取personality和identity配置的哈希值 @@ -190,15 +111,15 @@ class Individuality: "nickname": bot_nickname, "personality_core": personality_core, "personality_side": personality_side, - "compress_personality": 
self.personality.compress_personality if self.personality else True, + "compress_personality": global_config.personality.compress_personality, } personality_str = json.dumps(personality_config, sort_keys=True) personality_hash = hashlib.md5(personality_str.encode("utf-8")).hexdigest() # 身份配置哈希 identity_config = { - "identity": sorted(identity), - "compress_identity": self.personality.compress_identity if self.personality else True, + "identity": identity, + "compress_identity": global_config.personality.compress_identity, } identity_str = json.dumps(identity_config, sort_keys=True) identity_hash = hashlib.md5(identity_str.encode("utf-8")).hexdigest() @@ -206,7 +127,7 @@ class Individuality: return personality_hash, identity_hash async def _check_config_and_clear_if_changed( - self, bot_nickname: str, personality_core: str, personality_side: str, identity: list + self, bot_nickname: str, personality_core: str, personality_side: str, identity: str ) -> tuple[bool, bool]: """检查配置是否发生变化,如果变化则清空相应缓存 @@ -271,6 +192,53 @@ class Individuality: except IOError as e: logger.error(f"保存meta_info文件失败: {e}") + def _load_personality_data(self) -> dict: + """从JSON文件中加载personality数据""" + if os.path.exists(self.personality_data_file_path): + try: + with open(self.personality_data_file_path, "r", encoding="utf-8") as f: + return json.load(f) + except (json.JSONDecodeError, IOError) as e: + logger.error(f"读取personality_data文件失败: {e}, 将创建新文件。") + return {} + return {} + + def _save_personality_data(self, personality_data: dict): + """将personality数据保存到JSON文件""" + try: + os.makedirs(os.path.dirname(self.personality_data_file_path), exist_ok=True) + with open(self.personality_data_file_path, "w", encoding="utf-8") as f: + json.dump(personality_data, f, ensure_ascii=False, indent=2) + logger.debug(f"已保存personality数据到文件: {self.personality_data_file_path}") + except IOError as e: + logger.error(f"保存personality_data文件失败: {e}") + + def _get_personality_from_file(self) -> tuple[str, str]: + 
"""从文件获取personality数据 + + Returns: + tuple: (personality, identity) + """ + personality_data = self._load_personality_data() + personality = personality_data.get("personality", "友好活泼") + identity = personality_data.get("identity", "人类") + return personality, identity + + def _save_personality_to_file(self, personality: str, identity: str): + """保存personality数据到文件 + + Args: + personality: 压缩后的人格描述 + identity: 压缩后的身份描述 + """ + personality_data = { + "personality": personality, + "identity": identity, + "bot_nickname": self.name, + "last_updated": int(time.time()) + } + self._save_personality_data(personality_data) + async def _create_personality(self, personality_core: str, personality_side: str) -> str: # sourcery skip: merge-list-append, move-assign """使用LLM创建压缩版本的impression @@ -290,7 +258,7 @@ class Individuality: personality_parts.append(f"{personality_core}") # 准备需要压缩的内容 - if self.personality.compress_personality: + if global_config.personality.compress_personality: personality_to_compress = f"人格特质: {personality_side}" prompt = f"""请将以下人格信息进行简洁压缩,保留主要内容,用简练的中文表达: @@ -321,11 +289,11 @@ class Individuality: return personality_result - async def _create_identity(self, identity: list) -> str: + async def _create_identity(self, identity: str) -> str: """使用LLM创建压缩版本的impression""" logger.info("正在构建身份.........") - if self.personality.compress_identity: + if global_config.personality.compress_identity: identity_to_compress = f"身份背景: {identity}" prompt = f"""请将以下身份信息进行简洁压缩,保留主要内容,用简练的中文表达: diff --git a/src/individuality/template_scene.json b/src/individuality/not_using/template_scene.json similarity index 100% rename from src/individuality/template_scene.json rename to src/individuality/not_using/template_scene.json diff --git a/src/individuality/personality.py b/src/individuality/personality.py deleted file mode 100644 index 87907df76..000000000 --- a/src/individuality/personality.py +++ /dev/null @@ -1,91 +0,0 @@ - -from dataclasses import dataclass -from typing import 
Dict, List - - -@dataclass -class Personality: - """人格特质类""" - - bot_nickname: str # 机器人昵称 - personality_core: str # 人格核心特点 - personality_side: str # 人格侧面描述 - identity: List[str] # 身份细节描述 - compress_personality: bool # 是否压缩人格 - compress_identity: bool # 是否压缩身份 - - _instance = None - - def __new__(cls, *args, **kwargs): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self, personality_core: str = "", personality_side: str = "", identity: List[str] = None): - self.personality_core = personality_core - self.personality_side = personality_side - self.identity = identity - self.compress_personality = True - self.compress_identity = True - - @classmethod - def get_instance(cls) -> "Personality": - """获取Personality单例实例 - - Returns: - Personality: 单例实例 - """ - if cls._instance is None: - cls._instance = cls() - return cls._instance - - @classmethod - def initialize( - cls, - bot_nickname: str, - personality_core: str, - personality_side: str, - identity: List[str] = None, - compress_personality: bool = True, - compress_identity: bool = True, - ) -> "Personality": - """初始化人格特质 - - Args: - bot_nickname: 机器人昵称 - personality_core: 人格核心特点 - personality_side: 人格侧面描述 - identity: 身份细节描述 - compress_personality: 是否压缩人格 - compress_identity: 是否压缩身份 - - Returns: - Personality: 初始化后的人格特质实例 - """ - instance = cls.get_instance() - instance.bot_nickname = bot_nickname - instance.personality_core = personality_core - instance.personality_side = personality_side - instance.identity = identity - instance.compress_personality = compress_personality - instance.compress_identity = compress_identity - return instance - - def to_dict(self) -> Dict: - """将人格特质转换为字典格式""" - return { - "bot_nickname": self.bot_nickname, - "personality_core": self.personality_core, - "personality_side": self.personality_side, - "identity": self.identity, - "compress_personality": self.compress_personality, - "compress_identity": self.compress_identity, - } - - 
@classmethod - def from_dict(cls, data: Dict) -> "Personality": - """从字典创建人格特质实例""" - instance = cls.get_instance() - for key, value in data.items(): - setattr(instance, key, value) - return instance diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 3621b4502..c994cd173 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -10,6 +10,7 @@ import base64 from PIL import Image import io import os +import copy # 添加copy模块用于深拷贝 from src.common.database.database import db # 确保 db 被导入用于 create_tables from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型 from src.config.config import global_config @@ -69,23 +70,28 @@ error_code_mapping = { async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]): + """安全地记录请求体,用于调试日志,不会修改原始payload对象""" + # 创建payload的深拷贝,避免修改原始对象 + safe_payload = copy.deepcopy(payload) + image_base64: str = request_content.get("image_base64") image_format: str = request_content.get("image_format") if ( image_base64 - and payload - and isinstance(payload, dict) - and "messages" in payload - and len(payload["messages"]) > 0 + and safe_payload + and isinstance(safe_payload, dict) + and "messages" in safe_payload + and len(safe_payload["messages"]) > 0 ): - if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]: - content = payload["messages"][0]["content"] + if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]: + content = safe_payload["messages"][0]["content"] if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]: - payload["messages"][0]["content"][1]["image_url"]["url"] = ( + # 只修改拷贝的对象,用于安全的日志记录 + safe_payload["messages"][0]["content"][1]["image_url"]["url"] = ( f"data:image/{image_format.lower() if image_format else 'jpeg'};base64," f"{image_base64[:10]}...{image_base64[-10:]}" ) - return payload + return safe_payload class LLMRequest: @@ -109,10 +115,15 @@ 
class LLMRequest: def __init__(self, model: dict, **kwargs): # 将大写的配置键转换为小写并从config中获取实际值 + logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('name', 'Unknown')}") + logger.debug(f"🔍 [模型初始化] 模型配置: {model}") + logger.debug(f"🔍 [模型初始化] 额外参数: {kwargs}") + try: # print(f"model['provider']: {model['provider']}") self.api_key = os.environ[f"{model['provider']}_KEY"] self.base_url = os.environ[f"{model['provider']}_BASE_URL"] + logger.debug(f"🔍 [模型初始化] 成功获取环境变量: {model['provider']}_KEY 和 {model['provider']}_BASE_URL") except AttributeError as e: logger.error(f"原始 model dict 信息:{model}") logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") @@ -124,6 +135,10 @@ class LLMRequest: self.model_name: str = model["name"] self.params = kwargs + # 记录配置文件中声明了哪些参数(不管值是什么) + self.has_enable_thinking = "enable_thinking" in model + self.has_thinking_budget = "thinking_budget" in model + self.enable_thinking = model.get("enable_thinking", False) self.temp = model.get("temp", 0.7) self.thinking_budget = model.get("thinking_budget", 4096) @@ -132,12 +147,24 @@ class LLMRequest: self.pri_out = model.get("pri_out", 0) self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length) # print(f"max_tokens: {self.max_tokens}") + + logger.debug(f"🔍 [模型初始化] 模型参数设置完成:") + logger.debug(f" - model_name: {self.model_name}") + logger.debug(f" - has_enable_thinking: {self.has_enable_thinking}") + logger.debug(f" - enable_thinking: {self.enable_thinking}") + logger.debug(f" - has_thinking_budget: {self.has_thinking_budget}") + logger.debug(f" - thinking_budget: {self.thinking_budget}") + logger.debug(f" - temp: {self.temp}") + logger.debug(f" - stream: {self.stream}") + logger.debug(f" - max_tokens: {self.max_tokens}") + logger.debug(f" - base_url: {self.base_url}") # 获取数据库实例 self._init_database() # 从 kwargs 中提取 request_type,如果没有提供则默认为 "default" self.request_type = kwargs.pop("request_type", "default") + logger.debug(f"🔍 [模型初始化] 初始化完成,request_type: {self.request_type}") @staticmethod def 
_init_database(): @@ -262,11 +289,12 @@ class LLMRequest: if self.temp != 0.7: payload["temperature"] = self.temp - # 添加enable_thinking参数(如果不是默认值False) - if not self.enable_thinking: - payload["enable_thinking"] = False + # 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false) + if self.has_enable_thinking: + payload["enable_thinking"] = self.enable_thinking - if self.thinking_budget != 4096: + # 添加thinking_budget参数(只有配置文件中声明了才添加) + if self.has_thinking_budget: payload["thinking_budget"] = self.thinking_budget if self.max_tokens: @@ -334,6 +362,19 @@ class LLMRequest: # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响 if request_content["stream_mode"]: headers["Accept"] = "text/event-stream" + + # 添加请求发送前的调试信息 + logger.debug(f"🔍 [请求调试] 模型 {self.model_name} 准备发送请求") + logger.debug(f"🔍 [请求调试] API URL: {request_content['api_url']}") + logger.debug(f"🔍 [请求调试] 请求头: {await self._build_headers(no_key=True, is_formdata=file_bytes is not None)}") + + if not file_bytes: + # 安全地记录请求体(隐藏敏感信息) + safe_payload = await _safely_record(request_content, request_content["payload"]) + logger.debug(f"🔍 [请求调试] 请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}") + else: + logger.debug(f"🔍 [请求调试] 文件上传请求,文件格式: {request_content['file_format']}") + async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session: post_kwargs = {"headers": headers} # form-data数据上传方式不同 @@ -491,7 +532,36 @@ class LLMRequest: logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...") raise RuntimeError("请求限制(429)") elif response.status in policy["abort_codes"]: - if response.status != 403: + # 特别处理400错误,添加详细调试信息 + if response.status == 400: + logger.error(f"🔍 [调试信息] 模型 {self.model_name} 参数错误 (400) - 开始详细诊断") + logger.error(f"🔍 [调试信息] 模型名称: {self.model_name}") + logger.error(f"🔍 [调试信息] API地址: {self.base_url}") + logger.error(f"🔍 [调试信息] 模型配置参数:") + logger.error(f" - enable_thinking: {self.enable_thinking}") + logger.error(f" - temp: {self.temp}") + logger.error(f" - thinking_budget: 
{self.thinking_budget}") + logger.error(f" - stream: {self.stream}") + logger.error(f" - max_tokens: {self.max_tokens}") + logger.error(f" - pri_in: {self.pri_in}") + logger.error(f" - pri_out: {self.pri_out}") + logger.error(f"🔍 [调试信息] 原始params: {self.params}") + + # 尝试获取服务器返回的详细错误信息 + try: + error_text = await response.text() + logger.error(f"🔍 [调试信息] 服务器返回的原始错误内容: {error_text}") + + try: + error_json = json.loads(error_text) + logger.error(f"🔍 [调试信息] 解析后的错误JSON: {json.dumps(error_json, indent=2, ensure_ascii=False)}") + except json.JSONDecodeError: + logger.error(f"🔍 [调试信息] 错误响应不是有效的JSON格式") + except Exception as e: + logger.error(f"🔍 [调试信息] 无法读取错误响应内容: {str(e)}") + + raise RequestAbortException("参数错误,请检查调试信息", response) + elif response.status != 403: raise RequestAbortException("请求出现错误,中断处理", response) else: raise PermissionDeniedException("模型禁止访问") @@ -510,6 +580,19 @@ class LLMRequest: logger.error( f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}" ) + + # 如果是400错误,额外输出请求体信息用于调试 + if response.status == 400: + logger.error(f"🔍 [异常调试] 400错误 - 请求体调试信息:") + try: + safe_payload = await _safely_record(request_content, payload) + logger.error(f"🔍 [异常调试] 发送的请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}") + except Exception as debug_error: + logger.error(f"🔍 [异常调试] 无法安全记录请求体: {str(debug_error)}") + logger.error(f"🔍 [异常调试] 原始payload类型: {type(payload)}") + if isinstance(payload, dict): + logger.error(f"🔍 [异常调试] 原始payload键: {list(payload.keys())}") + # print(request_content) # print(response) # 尝试获取并记录服务器返回的详细错误信息 @@ -654,14 +737,27 @@ class LLMRequest: """ # 复制一份参数,避免直接修改原始数据 new_params = dict(params) + + logger.debug(f"🔍 [参数转换] 模型 {self.model_name} 开始参数转换") + logger.debug(f"🔍 [参数转换] 是否为CoT模型: {self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION}") + logger.debug(f"🔍 [参数转换] CoT模型列表: {self.MODELS_NEEDING_TRANSFORMATION}") if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION: + logger.debug(f"🔍 
[参数转换] 检测到CoT模型,开始参数转换") # 删除 'temperature' 参数(如果存在),但避免删除我们在_build_payload中添加的自定义温度 if "temperature" in new_params and new_params["temperature"] == 0.7: - new_params.pop("temperature") + removed_temp = new_params.pop("temperature") + logger.debug(f"🔍 [参数转换] 移除默认temperature参数: {removed_temp}") # 如果存在 'max_tokens',则重命名为 'max_completion_tokens' if "max_tokens" in new_params: + old_value = new_params["max_tokens"] new_params["max_completion_tokens"] = new_params.pop("max_tokens") + logger.debug(f"🔍 [参数转换] 参数重命名: max_tokens({old_value}) -> max_completion_tokens({new_params['max_completion_tokens']})") + else: + logger.debug(f"🔍 [参数转换] 非CoT模型,无需参数转换") + + logger.debug(f"🔍 [参数转换] 转换前参数: {params}") + logger.debug(f"🔍 [参数转换] 转换后参数: {new_params}") return new_params async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData: @@ -693,7 +789,12 @@ class LLMRequest: async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict: """构建请求体""" # 复制一份参数,避免直接修改 self.params + logger.debug(f"🔍 [参数构建] 模型 {self.model_name} 开始构建请求体") + logger.debug(f"🔍 [参数构建] 原始self.params: {self.params}") + params_copy = await self._transform_parameters(self.params) + logger.debug(f"🔍 [参数构建] 转换后的params_copy: {params_copy}") + if image_base64: messages = [ { @@ -715,26 +816,37 @@ class LLMRequest: "messages": messages, **params_copy, } + + logger.debug(f"🔍 [参数构建] 基础payload构建完成: {list(payload.keys())}") # 添加temp参数(如果不是默认值0.7) if self.temp != 0.7: payload["temperature"] = self.temp + logger.debug(f"🔍 [参数构建] 添加temperature参数: {self.temp}") - # 添加enable_thinking参数(如果不是默认值False) - if not self.enable_thinking: - payload["enable_thinking"] = False + # 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false) + if self.has_enable_thinking: + payload["enable_thinking"] = self.enable_thinking + logger.debug(f"🔍 [参数构建] 添加enable_thinking参数: {self.enable_thinking}") - if self.thinking_budget != 4096: + # 添加thinking_budget参数(只有配置文件中声明了才添加) + if 
self.has_thinking_budget: payload["thinking_budget"] = self.thinking_budget + logger.debug(f"🔍 [参数构建] 添加thinking_budget参数: {self.thinking_budget}") if self.max_tokens: payload["max_tokens"] = self.max_tokens + logger.debug(f"🔍 [参数构建] 添加max_tokens参数: {self.max_tokens}") # if "max_tokens" not in payload and "max_completion_tokens" not in payload: # payload["max_tokens"] = global_config.model.model_max_output_length # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: + old_value = payload["max_tokens"] payload["max_completion_tokens"] = payload.pop("max_tokens") + logger.debug(f"🔍 [参数构建] CoT模型参数转换: max_tokens({old_value}) -> max_completion_tokens({payload['max_completion_tokens']})") + + logger.debug(f"🔍 [参数构建] 最终payload键列表: {list(payload.keys())}") return payload def _default_response_handler( diff --git a/src/main.py b/src/main.py index 3cd2107d1..aed9a2bf1 100644 --- a/src/main.py +++ b/src/main.py @@ -115,7 +115,6 @@ class MainSystem: # 初始化个体特征 await self.individuality.initialize() - logger.info("个体特征初始化成功") try: init_time = int(1000 * (time.time() - init_start_time)) diff --git a/src/mais4u/constant_s4u.py b/src/mais4u/constant_s4u.py new file mode 100644 index 000000000..8a7446405 --- /dev/null +++ b/src/mais4u/constant_s4u.py @@ -0,0 +1 @@ +ENABLE_S4U = False \ No newline at end of file diff --git a/src/chat/mai_thinking/mai_think.py b/src/mais4u/mai_think.py similarity index 98% rename from src/chat/mai_thinking/mai_think.py rename to src/mais4u/mai_think.py index c438b32eb..867ba8bef 100644 --- a/src/chat/mai_thinking/mai_think.py +++ b/src/mais4u/mai_think.py @@ -3,7 +3,7 @@ import time from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.llm_models.utils_model import LLMRequest from src.config.config import global_config -from src.chat.message_receive.message import MessageSending, MessageRecv, MessageRecvS4U +from 
src.chat.message_receive.message import MessageRecvS4U from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.mais4u.mais4u_chat.internal_manager import internal_manager from src.common.logger import get_logger diff --git a/src/mais4u/mais4u_chat/body_emotion_action_manager.py b/src/mais4u/mais4u_chat/body_emotion_action_manager.py index e67cc7e38..e7380822d 100644 --- a/src/mais4u/mais4u_chat/body_emotion_action_manager.py +++ b/src/mais4u/mais4u_chat/body_emotion_action_manager.py @@ -1,6 +1,5 @@ import json import time -import random from src.chat.message_receive.message import MessageRecv from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger diff --git a/src/mais4u/mais4u_chat/s4u_chat.py b/src/mais4u/mais4u_chat/s4u_chat.py index 832c1f788..e447ae193 100644 --- a/src/mais4u/mais4u_chat/s4u_chat.py +++ b/src/mais4u/mais4u_chat/s4u_chat.py @@ -19,6 +19,7 @@ from src.mais4u.s4u_config import s4u_config from src.person_info.person_info import PersonInfoManager from .super_chat_manager import get_super_chat_manager from .yes_or_no import yes_or_no_head +from src.mais4u.constant_s4u import ENABLE_S4U logger = get_logger("S4U_chat") @@ -165,7 +166,10 @@ class S4UChatManager: return self.s4u_chats[chat_stream.stream_id] -s4u_chat_manager = S4UChatManager() +if not ENABLE_S4U: + s4u_chat_manager = None +else: + s4u_chat_manager = S4UChatManager() def get_s4u_chat_manager() -> S4UChatManager: @@ -486,7 +490,7 @@ class S4UChat: logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'") if s4u_config.enable_streaming_output: - logger.info(f"[S4U] 开始流式输出") + logger.info("[S4U] 开始流式输出") # 流式输出,边生成边发送 gen = self.gpt.generate_response(message, "") async for chunk in gen: @@ -494,7 +498,7 @@ class S4UChat: await sender_container.add_message(chunk) total_chars_sent += len(chunk) else: - logger.info(f"[S4U] 开始一次性输出") + logger.info("[S4U] 开始一次性输出") # 一次性输出,先收集所有chunk all_chunks = [] gen = 
self.gpt.generate_response(message, "") diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py index ffa0b3b01..c936cea17 100644 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py @@ -10,6 +10,7 @@ from src.config.config import global_config from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.manager.async_task_manager import AsyncTask, async_task_manager from src.plugin_system.apis import send_api +from src.mais4u.constant_s4u import ENABLE_S4U """ 情绪管理系统使用说明: @@ -446,9 +447,10 @@ class MoodManager: # 发送初始情绪状态到ws端 asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values)) - -init_prompt() - -mood_manager = MoodManager() +if ENABLE_S4U: + init_prompt() + mood_manager = MoodManager() +else: + mood_manager = None """全局情绪管理器""" diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py index 7e5d8e438..cbc7d3fac 100644 --- a/src/mais4u/mais4u_chat/s4u_msg_processor.py +++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py @@ -4,7 +4,7 @@ from typing import Tuple from src.chat.memory_system.Hippocampus import hippocampus_manager from src.chat.message_receive.message import MessageRecv, MessageRecvS4U -from maim_message.message_base import GroupInfo,UserInfo +from maim_message.message_base import GroupInfo from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.timer_calculator import Timer diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index 92a9ed277..d748c25e5 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -10,13 +10,13 @@ from datetime import datetime import asyncio from src.mais4u.s4u_config import s4u_config from src.chat.message_receive.message import MessageRecvS4U -from src.person_info.relationship_manager import 
get_relationship_manager +from src.person_info.relationship_fetcher import relationship_fetcher_manager +from src.person_info.person_info import PersonInfoManager, get_person_info_manager from src.chat.message_receive.chat_stream import ChatStream from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager from src.mais4u.mais4u_chat.screen_manager import screen_manager from src.chat.express.expression_selector import expression_selector from .s4u_mood_manager import mood_manager -from src.person_info.person_info import PersonInfoManager, get_person_info_manager from src.mais4u.mais4u_chat.internal_manager import internal_manager logger = get_logger("prompt") @@ -149,9 +149,17 @@ class PromptBuilder: relation_prompt = "" if global_config.relationship.enable_relationship and who_chat_in_group: - relationship_manager = get_relationship_manager() + relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_stream.stream_id) + + # 将 (platform, user_id, nickname) 转换为 person_id + person_ids = [] + for person in who_chat_in_group: + person_id = PersonInfoManager.get_person_id(person[0], person[1]) + person_ids.append(person_id) + + # 使用 RelationshipFetcher 的 build_relation_info 方法,设置 points_num=3 保持与原来相同的行为 relation_info_list = await asyncio.gather( - *[relationship_manager.build_relationship_info(person) for person in who_chat_in_group] + *[relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids] ) relation_info = "".join(relation_info_list) if relation_info: diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py index a7c96a254..339b46c33 100644 --- a/src/mais4u/mais4u_chat/s4u_stream_generator.py +++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py @@ -5,7 +5,6 @@ from src.config.config import global_config from src.chat.message_receive.message import MessageRecvS4U from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder from src.common.logger 
import get_logger -from src.person_info.person_info import PersonInfoManager, get_person_info_manager import asyncio import re @@ -49,19 +48,19 @@ class S4UStreamGenerator: self.chat_stream =None async def build_last_internal_message(self,message:MessageRecvS4U,previous_reply_context:str = ""): - person_id = PersonInfoManager.get_person_id( - message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id - ) - person_info_manager = get_person_info_manager() - person_name = await person_info_manager.get_value(person_id, "person_name") + # person_id = PersonInfoManager.get_person_id( + # message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id + # ) + # person_info_manager = get_person_info_manager() + # person_name = await person_info_manager.get_value(person_id, "person_name") - if message.chat_stream.user_info.user_nickname: - if person_name: - sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})" - else: - sender_name = f"[{message.chat_stream.user_info.user_nickname}]" - else: - sender_name = f"用户({message.chat_stream.user_info.user_id})" + # if message.chat_stream.user_info.user_nickname: + # if person_name: + # sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})" + # else: + # sender_name = f"[{message.chat_stream.user_info.user_nickname}]" + # else: + # sender_name = f"用户({message.chat_stream.user_info.user_id})" # 构建prompt if previous_reply_context: diff --git a/src/mais4u/mais4u_chat/s4u_watching_manager.py b/src/mais4u/mais4u_chat/s4u_watching_manager.py index f02a1da3a..62ef6d86a 100644 --- a/src/mais4u/mais4u_chat/s4u_watching_manager.py +++ b/src/mais4u/mais4u_chat/s4u_watching_manager.py @@ -1,7 +1,3 @@ -import asyncio -import time -from enum import Enum -from typing import Optional from src.common.logger import get_logger from src.plugin_system.apis import send_api diff --git a/src/mais4u/mais4u_chat/super_chat_manager.py 
b/src/mais4u/mais4u_chat/super_chat_manager.py index b5706ca37..528eaecca 100644 --- a/src/mais4u/mais4u_chat/super_chat_manager.py +++ b/src/mais4u/mais4u_chat/super_chat_manager.py @@ -4,6 +4,8 @@ from dataclasses import dataclass from typing import Dict, List, Optional from src.common.logger import get_logger from src.chat.message_receive.message import MessageRecvS4U +# 全局SuperChat管理器实例 +from src.mais4u.constant_s4u import ENABLE_S4U logger = get_logger("super_chat_manager") @@ -296,10 +298,14 @@ class SuperChatManager: logger.info("SuperChat管理器已关闭") -# 全局SuperChat管理器实例 -super_chat_manager = SuperChatManager() +if ENABLE_S4U: + super_chat_manager = SuperChatManager() +else: + super_chat_manager = None + def get_super_chat_manager() -> SuperChatManager: """获取全局SuperChat管理器实例""" - return super_chat_manager \ No newline at end of file + + return super_chat_manager \ No newline at end of file diff --git a/src/mais4u/mais4u_chat/yes_or_no.py b/src/mais4u/mais4u_chat/yes_or_no.py index 9e234082d..edc200f65 100644 --- a/src/mais4u/mais4u_chat/yes_or_no.py +++ b/src/mais4u/mais4u_chat/yes_or_no.py @@ -1,16 +1,6 @@ -import json -import time -import random -from src.chat.message_receive.message import MessageRecv from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger -from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.config.config import global_config -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.manager.async_task_manager import AsyncTask, async_task_manager -from src.plugin_system.apis import send_api -from json_repair import repair_json -from src.mais4u.s4u_config import s4u_config from src.plugin_system.apis import send_api logger = get_logger(__name__) diff --git a/src/mais4u/s4u_config.py b/src/mais4u/s4u_config.py index 180513025..dbd7f3947 100644 --- a/src/mais4u/s4u_config.py +++ b/src/mais4u/s4u_config.py @@ -6,7 
+6,7 @@ from tomlkit import TOMLDocument from tomlkit.items import Table from dataclasses import dataclass, fields, MISSING, field from typing import TypeVar, Type, Any, get_origin, get_args, Literal - +from src.mais4u.constant_s4u import ENABLE_S4U from src.common.logger import get_logger logger = get_logger("s4u_config") @@ -353,12 +353,16 @@ def load_s4u_config(config_path: str) -> S4UGlobalConfig: raise e -# 初始化S4U配置 -logger.info(f"S4U当前版本: {S4U_VERSION}") -update_s4u_config() +if not ENABLE_S4U: + s4u_config = None + s4u_config_main = None +else: + # 初始化S4U配置 + logger.info(f"S4U当前版本: {S4U_VERSION}") + update_s4u_config() -logger.info("正在加载S4U配置文件...") -s4u_config_main = load_s4u_config(config_path=CONFIG_PATH) -logger.info("S4U配置文件加载完成!") + logger.info("正在加载S4U配置文件...") + s4u_config_main = load_s4u_config(config_path=CONFIG_PATH) + logger.info("S4U配置文件加载完成!") -s4u_config: S4UConfig = s4u_config_main.s4u \ No newline at end of file + s4u_config: S4UConfig = s4u_config_main.s4u \ No newline at end of file diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index 4134de9b9..38ed39bcc 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -83,12 +83,12 @@ class ChatMood: logger.debug( f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}" ) - update_probability = min(1.0, base_probability * time_multiplier * interest_multiplier) + update_probability = global_config.mood.mood_update_threshold * min(1.0, base_probability * time_multiplier * interest_multiplier) if random.random() > update_probability: return - logger.info(f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate}, 更新概率: {update_probability}") + logger.debug(f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}") message_time: float = message.message_info.time # type: ignore message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( @@ -201,7 +201,7 @@ class 
MoodRegressionTask(AsyncTask): if mood.regression_count >= 3: continue - logger.info(f"chat {mood.chat_id} 开始情绪回归, 这是第 {mood.regression_count + 1} 次") + logger.info(f"{mood.log_prefix} 开始情绪回归, 这是第 {mood.regression_count + 1} 次") await mood.regress_mood() diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index eb463da35..6be0ad277 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -41,8 +41,6 @@ person_info_default = { "know_times": 0, "know_since": None, "last_know": None, - # "user_cardname": None, # This field is not in Peewee model PersonInfo - # "user_avatar": None, # This field is not in Peewee model PersonInfo "impression": None, # Corrected from person_impression "short_impression": None, "info_list": None, diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index deeb4c370..99f3be303 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -112,15 +112,6 @@ class RelationshipFetcher: current_points = await person_info_manager.get_value(person_id, "points") or [] - if isinstance(current_points, str): - try: - current_points = json.loads(current_points) - except json.JSONDecodeError: - logger.error(f"解析points JSON失败: {current_points}") - current_points = [] - elif not isinstance(current_points, list): - current_points = [] - # 按时间排序forgotten_points current_points.sort(key=lambda x: x[2]) # 按权重加权随机抽取最多3个不重复的points,point[1]的值在1-10之间,权重越高被抽到概率越大 @@ -370,60 +361,6 @@ class RelationshipFetcher: logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}") logger.error(traceback.format_exc()) - def _organize_known_info(self) -> str: - """组织已知的用户信息为字符串 - - Returns: - str: 格式化的用户信息字符串 - """ - persons_infos_str = "" - - if self.info_fetched_cache: - persons_with_known_info = [] # 有已知信息的人员 - persons_with_unknown_info = [] # 有未知信息的人员 - - for person_id in self.info_fetched_cache: - person_known_infos = [] - person_unknown_infos = [] - 
person_name = "" - - for info_type in self.info_fetched_cache[person_id]: - person_name = self.info_fetched_cache[person_id][info_type]["person_name"] - if not self.info_fetched_cache[person_id][info_type]["unknown"]: - info_content = self.info_fetched_cache[person_id][info_type]["info"] - person_known_infos.append(f"[{info_type}]:{info_content}") - else: - person_unknown_infos.append(info_type) - - # 如果有已知信息,添加到已知信息列表 - if person_known_infos: - known_info_str = ";".join(person_known_infos) + ";" - persons_with_known_info.append((person_name, known_info_str)) - - # 如果有未知信息,添加到未知信息列表 - if person_unknown_infos: - persons_with_unknown_info.append((person_name, person_unknown_infos)) - - # 先输出有已知信息的人员 - for person_name, known_info_str in persons_with_known_info: - persons_infos_str += f"你对 {person_name} 的了解:{known_info_str}\n" - - # 统一处理未知信息,避免重复的警告文本 - if persons_with_unknown_info: - unknown_persons_details = [] - for person_name, unknown_types in persons_with_unknown_info: - unknown_types_str = "、".join(unknown_types) - unknown_persons_details.append(f"{person_name}的[{unknown_types_str}]") - - if len(unknown_persons_details) == 1: - persons_infos_str += ( - f"你不了解{unknown_persons_details[0]}信息,不要胡乱回答,可以直接说不知道或忘记了;\n" - ) - else: - unknown_all_str = "、".join(unknown_persons_details) - persons_infos_str += f"你不了解{unknown_all_str}等信息,不要胡乱回答,可以直接说不知道或忘记了;\n" - - return persons_infos_str async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str): # sourcery skip: use-next diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index ecce06c65..01cc89e9a 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -55,60 +55,6 @@ class RelationshipManager: # person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar # ) - async def build_relationship_info(self, person, is_id: bool = False) -> str: - if is_id: - person_id = 
person - else: - person_id = PersonInfoManager.get_person_id(person[0], person[1]) - person_info_manager = get_person_info_manager() - person_name = await person_info_manager.get_value(person_id, "person_name") - if not person_name or person_name == "none": - return "" - short_impression = await person_info_manager.get_value(person_id, "short_impression") - - current_points = await person_info_manager.get_value(person_id, "points") or [] - # print(f"current_points: {current_points}") - if isinstance(current_points, str): - try: - current_points = json.loads(current_points) - except json.JSONDecodeError: - logger.error(f"解析points JSON失败: {current_points}") - current_points = [] - elif not isinstance(current_points, list): - current_points = [] - - # 按时间排序forgotten_points - current_points.sort(key=lambda x: x[2]) - # 按权重加权随机抽取3个points,point[1]的值在1-10之间,权重越高被抽到概率越大 - if len(current_points) > 3: - # point[1] 取值范围1-10,直接作为权重 - weights = [max(1, min(10, int(point[1]))) for point in current_points] - points = random.choices(current_points, weights=weights, k=3) - else: - points = current_points - - # 构建points文本 - points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points]) - - nickname_str = await person_info_manager.get_value(person_id, "nickname") - platform = await person_info_manager.get_value(person_id, "platform") - - if person_name == nickname_str and not short_impression: - return "" - - if person_name == nickname_str: - relation_prompt = f"'{person_name}' :" - else: - relation_prompt = f"'{person_name}' ,ta在{platform}上的昵称是{nickname_str}。" - - if short_impression: - relation_prompt += f"你对ta的印象是:{short_impression}。\n" - - if points_text: - relation_prompt += f"你记得ta最近做的事:{points_text}" - - return relation_prompt - async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]): """更新用户印象 diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index 491da7c1c..eb07dbc92 100644 --- 
a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -23,12 +23,6 @@ from .base import ( EventType, MaiMessages, ) -from .core import ( - plugin_manager, - component_registry, - dependency_manager, - events_manager, -) # 导入工具模块 from .utils import ( @@ -38,12 +32,42 @@ from .utils import ( # generate_plugin_manifest, ) -from .apis import register_plugin, get_logger +from .apis import ( + chat_api, + component_manage_api, + config_api, + database_api, + emoji_api, + generator_api, + llm_api, + message_api, + person_api, + plugin_manage_api, + send_api, + utils_api, + register_plugin, + get_logger, +) __version__ = "1.0.0" __all__ = [ + # API 模块 + "chat_api", + "component_manage_api", + "config_api", + "database_api", + "emoji_api", + "generator_api", + "llm_api", + "message_api", + "person_api", + "plugin_manage_api", + "send_api", + "utils_api", + "register_plugin", + "get_logger", # 基础类 "BasePlugin", "BaseAction", @@ -62,11 +86,6 @@ __all__ = [ "EventType", # 消息 "MaiMessages", - # 管理器 - "plugin_manager", - "component_registry", - "dependency_manager", - "events_manager", # 装饰器 "register_plugin", "ConfigField", diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py index 05cc62c72..0882fbdc6 100644 --- a/src/plugin_system/apis/__init__.py +++ b/src/plugin_system/apis/__init__.py @@ -7,6 +7,7 @@ # 导入所有API模块 from src.plugin_system.apis import ( chat_api, + component_manage_api, config_api, database_api, emoji_api, @@ -14,15 +15,17 @@ from src.plugin_system.apis import ( llm_api, message_api, person_api, + plugin_manage_api, send_api, utils_api, - plugin_register_api, ) from .logging_api import get_logger from .plugin_register_api import register_plugin + # 导出所有API模块,使它们可以通过 apis.xxx 方式访问 __all__ = [ "chat_api", + "component_manage_api", "config_api", "database_api", "emoji_api", @@ -30,9 +33,9 @@ __all__ = [ "llm_api", "message_api", "person_api", + "plugin_manage_api", "send_api", "utils_api", - "plugin_register_api", 
"get_logger", "register_plugin", ] diff --git a/src/plugin_system/apis/component_manage_api.py b/src/plugin_system/apis/component_manage_api.py new file mode 100644 index 000000000..d9ea051d9 --- /dev/null +++ b/src/plugin_system/apis/component_manage_api.py @@ -0,0 +1,245 @@ +from typing import Optional, Union, Dict +from src.plugin_system.base.component_types import ( + CommandInfo, + ActionInfo, + EventHandlerInfo, + PluginInfo, + ComponentType, +) + + +# === 插件信息查询 === +def get_all_plugin_info() -> Dict[str, PluginInfo]: + """ + 获取所有插件的信息。 + + Returns: + dict: 包含所有插件信息的字典,键为插件名称,值为 PluginInfo 对象。 + """ + from src.plugin_system.core.component_registry import component_registry + + return component_registry.get_all_plugins() + + +def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]: + """ + 获取指定插件的信息。 + + Args: + plugin_name (str): 插件名称。 + + Returns: + PluginInfo: 插件信息对象,如果插件不存在则返回 None。 + """ + from src.plugin_system.core.component_registry import component_registry + + return component_registry.get_plugin_info(plugin_name) + + +# === 组件查询方法 === +def get_component_info( + component_name: str, component_type: ComponentType +) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]: + """ + 获取指定组件的信息。 + + Args: + component_name (str): 组件名称。 + component_type (ComponentType): 组件类型。 + Returns: + Union[CommandInfo, ActionInfo, EventHandlerInfo]: 组件信息对象,如果组件不存在则返回 None。 + """ + from src.plugin_system.core.component_registry import component_registry + + return component_registry.get_component_info(component_name, component_type) # type: ignore + + +def get_components_info_by_type( + component_type: ComponentType, +) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]: + """ + 获取指定类型的所有组件信息。 + + Args: + component_type (ComponentType): 组件类型。 + + Returns: + dict: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。 + """ + from src.plugin_system.core.component_registry import component_registry + + return component_registry.get_components_by_type(component_type) # 
type: ignore + + +def get_enabled_components_info_by_type( + component_type: ComponentType, +) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]: + """ + 获取指定类型的所有启用的组件信息。 + + Args: + component_type (ComponentType): 组件类型。 + + Returns: + dict: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。 + """ + from src.plugin_system.core.component_registry import component_registry + + return component_registry.get_enabled_components_by_type(component_type) # type: ignore + + +# === Action 查询方法 === +def get_registered_action_info(action_name: str) -> Optional[ActionInfo]: + """ + 获取指定 Action 的注册信息。 + + Args: + action_name (str): Action 名称。 + + Returns: + ActionInfo: Action 信息对象,如果 Action 不存在则返回 None。 + """ + from src.plugin_system.core.component_registry import component_registry + + return component_registry.get_registered_action_info(action_name) + + +def get_registered_command_info(command_name: str) -> Optional[CommandInfo]: + """ + 获取指定 Command 的注册信息。 + + Args: + command_name (str): Command 名称。 + + Returns: + CommandInfo: Command 信息对象,如果 Command 不存在则返回 None。 + """ + from src.plugin_system.core.component_registry import component_registry + + return component_registry.get_registered_command_info(command_name) + + +# === EventHandler 特定查询方法 === +def get_registered_event_handler_info( + event_handler_name: str, +) -> Optional[EventHandlerInfo]: + """ + 获取指定 EventHandler 的注册信息。 + + Args: + event_handler_name (str): EventHandler 名称。 + + Returns: + EventHandlerInfo: EventHandler 信息对象,如果 EventHandler 不存在则返回 None。 + """ + from src.plugin_system.core.component_registry import component_registry + + return component_registry.get_registered_event_handler_info(event_handler_name) + + +# === 组件管理方法 === +def globally_enable_component(component_name: str, component_type: ComponentType) -> bool: + """ + 全局启用指定组件。 + + Args: + component_name (str): 组件名称。 + component_type (ComponentType): 组件类型。 + + Returns: + bool: 启用成功返回 True,否则返回 False。 + """ + from 
src.plugin_system.core.component_registry import component_registry + + return component_registry.enable_component(component_name, component_type) + + +async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool: + """ + 全局禁用指定组件。 + + **此函数是异步的,确保在异步环境中调用。** + + Args: + component_name (str): 组件名称。 + component_type (ComponentType): 组件类型。 + + Returns: + bool: 禁用成功返回 True,否则返回 False。 + """ + from src.plugin_system.core.component_registry import component_registry + + return await component_registry.disable_component(component_name, component_type) + + +def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool: + """ + 局部启用指定组件。 + + Args: + component_name (str): 组件名称。 + component_type (ComponentType): 组件类型。 + stream_id (str): 消息流 ID。 + + Returns: + bool: 启用成功返回 True,否则返回 False。 + """ + from src.plugin_system.core.global_announcement_manager import global_announcement_manager + + match component_type: + case ComponentType.ACTION: + return global_announcement_manager.enable_specific_chat_action(stream_id, component_name) + case ComponentType.COMMAND: + return global_announcement_manager.enable_specific_chat_command(stream_id, component_name) + case ComponentType.EVENT_HANDLER: + return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name) + case _: + raise ValueError(f"未知 component type: {component_type}") + + +def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool: + """ + 局部禁用指定组件。 + + Args: + component_name (str): 组件名称。 + component_type (ComponentType): 组件类型。 + stream_id (str): 消息流 ID。 + + Returns: + bool: 禁用成功返回 True,否则返回 False。 + """ + from src.plugin_system.core.global_announcement_manager import global_announcement_manager + + match component_type: + case ComponentType.ACTION: + return global_announcement_manager.disable_specific_chat_action(stream_id, component_name) + case 
ComponentType.COMMAND: + return global_announcement_manager.disable_specific_chat_command(stream_id, component_name) + case ComponentType.EVENT_HANDLER: + return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name) + case _: + raise ValueError(f"未知 component type: {component_type}") + +def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]: + """ + 获取指定消息流中禁用的组件列表。 + + Args: + stream_id (str): 消息流 ID。 + component_type (ComponentType): 组件类型。 + + Returns: + list[str]: 禁用的组件名称列表。 + """ + from src.plugin_system.core.global_announcement_manager import global_announcement_manager + + match component_type: + case ComponentType.ACTION: + return global_announcement_manager.get_disabled_chat_actions(stream_id) + case ComponentType.COMMAND: + return global_announcement_manager.get_disabled_chat_commands(stream_id) + case ComponentType.EVENT_HANDLER: + return global_announcement_manager.get_disabled_chat_event_handlers(stream_id) + case _: + raise ValueError(f"未知 component type: {component_type}") \ No newline at end of file diff --git a/src/plugin_system/apis/plugin_manage_api.py b/src/plugin_system/apis/plugin_manage_api.py new file mode 100644 index 000000000..1c01119b2 --- /dev/null +++ b/src/plugin_system/apis/plugin_manage_api.py @@ -0,0 +1,95 @@ +from typing import Tuple, List +def list_loaded_plugins() -> List[str]: + """ + 列出所有当前加载的插件。 + + Returns: + list: 当前加载的插件名称列表。 + """ + from src.plugin_system.core.plugin_manager import plugin_manager + + return plugin_manager.list_loaded_plugins() + + +def list_registered_plugins() -> List[str]: + """ + 列出所有已注册的插件。 + + Returns: + list: 已注册的插件名称列表。 + """ + from src.plugin_system.core.plugin_manager import plugin_manager + + return plugin_manager.list_registered_plugins() + + +async def remove_plugin(plugin_name: str) -> bool: + """ + 卸载指定的插件。 + + **此函数是异步的,确保在异步环境中调用。** + + Args: + plugin_name (str): 要卸载的插件名称。 + + Returns: + bool: 卸载是否成功。 + """ + from 
src.plugin_system.core.plugin_manager import plugin_manager + + return await plugin_manager.remove_registered_plugin(plugin_name) + + +async def reload_plugin(plugin_name: str) -> bool: + """ + 重新加载指定的插件。 + + **此函数是异步的,确保在异步环境中调用。** + + Args: + plugin_name (str): 要重新加载的插件名称。 + + Returns: + bool: 重新加载是否成功。 + """ + from src.plugin_system.core.plugin_manager import plugin_manager + + return await plugin_manager.reload_registered_plugin(plugin_name) + + +def load_plugin(plugin_name: str) -> Tuple[bool, int]: + """ + 加载指定的插件。 + + Args: + plugin_name (str): 要加载的插件名称。 + + Returns: + Tuple[bool, int]: 加载是否成功,成功或失败个数。 + """ + from src.plugin_system.core.plugin_manager import plugin_manager + + return plugin_manager.load_registered_plugin_classes(plugin_name) + +def add_plugin_directory(plugin_directory: str) -> bool: + """ + 添加插件目录。 + + Args: + plugin_directory (str): 要添加的插件目录路径。 + Returns: + bool: 添加是否成功。 + """ + from src.plugin_system.core.plugin_manager import plugin_manager + + return plugin_manager.add_plugin_directory(plugin_directory) + +def rescan_plugin_directory() -> Tuple[int, int]: + """ + 重新扫描插件目录,加载新插件。 + Returns: + Tuple[int, int]: 成功加载的插件数量和失败的插件数量。 + """ + from src.plugin_system.core.plugin_manager import plugin_manager + + return plugin_manager.rescan_plugin_directory() \ No newline at end of file diff --git a/src/plugin_system/apis/plugin_register_api.py b/src/plugin_system/apis/plugin_register_api.py index 879c09b32..e4ba2ee48 100644 --- a/src/plugin_system/apis/plugin_register_api.py +++ b/src/plugin_system/apis/plugin_register_api.py @@ -28,7 +28,6 @@ def register_plugin(cls): if "." 
in plugin_name: logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代") raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代") - plugin_manager.plugin_classes[plugin_name] = cls splitted_name = cls.__module__.split(".") root_path = Path(__file__) @@ -40,6 +39,7 @@ def register_plugin(cls): logger.error(f"注册 {plugin_name} 无法找到项目根目录") return cls + plugin_manager.plugin_classes[plugin_name] = cls plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve()) logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}") diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index a61a0339d..c108c5d86 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -49,12 +49,10 @@ class BaseAction(ABC): reasoning: 执行该动作的理由 cycle_timers: 计时器字典 thinking_id: 思考ID - expressor: 表达器对象 - replyer: 回复器对象 chat_stream: 聊天流对象 log_prefix: 日志前缀 - shutting_down: 是否正在关闭 plugin_config: 插件配置字典 + action_message: 消息数据 **kwargs: 其他参数 """ if plugin_config is None: @@ -65,21 +63,30 @@ class BaseAction(ABC): self.thinking_id = thinking_id self.log_prefix = log_prefix - # 保存插件配置 self.plugin_config = plugin_config or {} + """对应的插件配置""" # 设置动作基本信息实例属性 self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", "")) + """Action的名字""" self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件") + """Action的描述""" self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy() self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy() # 设置激活类型实例属性(从类属性复制,提供默认值) self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS) + """FOCUS模式下的激活类型""" self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS) + """NORMAL模式下的激活类型""" + self.activation_type = 
getattr(self.__class__, "activation_type", self.focus_activation_type) + """激活类型""" self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0) + """当激活类型为RANDOM时的概率""" self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "") + """协助LLM进行判断的Prompt""" self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy() + """激活类型为KEYWORD时的KEYWORDS列表""" self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False) self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL) self.parallel_action: bool = getattr(self.__class__, "parallel_action", True) @@ -136,7 +143,7 @@ class BaseAction(ABC): self.target_id = self.user_id logger.debug(f"{self.log_prefix} Action组件初始化完成") - logger.info( + logger.debug( f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}" ) @@ -405,23 +412,11 @@ class BaseAction(ABC): """ return await self.execute() - # def get_action_context(self, key: str, default=None): - # """获取action上下文信息 - - # Args: - # key: 上下文键名 - # default: 默认值 - - # Returns: - # Any: 上下文值或默认值 - # """ - # return self.api.get_action_context(key, default) - def get_config(self, key: str, default=None): - """获取插件配置值,支持嵌套键访问 + """获取插件配置值,使用嵌套键访问 Args: - key: 配置键名,支持嵌套访问如 "section.subsection.key" + key: 配置键名,使用嵌套访问如 "section.subsection.key" default: 默认值 Returns: diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 5387e01dd..b79f68845 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -17,17 +17,18 @@ class BaseCommand(ABC): - command_pattern: 命令匹配的正则表达式 - command_help: 命令帮助信息 - command_examples: 命令使用示例列表 - - intercept_message: 是否拦截消息处理(默认True拦截,False继续传递) """ command_name: str = "" + """Command组件的名称""" command_description: str = "" - - # 默认命令设置(子类可以覆盖) - command_pattern: str = "" + 
"""Command组件的描述""" + # 默认命令设置 + command_pattern: str = r"" + """命令匹配的正则表达式""" command_help: str = "" + """命令帮助信息""" command_examples: List[str] = [] - intercept_message: bool = True # 默认拦截消息,不继续处理 def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None): """初始化Command组件 @@ -53,11 +54,11 @@ class BaseCommand(ABC): self.matched_groups = groups @abstractmethod - async def execute(self) -> Tuple[bool, Optional[str]]: + async def execute(self) -> Tuple[bool, Optional[str], bool]: """执行Command的抽象方法,子类必须实现 Returns: - Tuple[bool, Optional[str]]: (是否执行成功, 可选的回复消息) + Tuple[bool, Optional[str], bool]: (是否执行成功, 可选的回复消息, 是否拦截消息 不进行 后续处理) """ pass @@ -229,5 +230,4 @@ class BaseCommand(ABC): command_pattern=cls.command_pattern, command_help=cls.command_help, command_examples=cls.command_examples.copy() if cls.command_examples else [], - intercept_message=cls.intercept_message, ) diff --git a/src/plugin_system/base/base_events_handler.py b/src/plugin_system/base/base_events_handler.py index b6c9e965d..5118885ff 100644 --- a/src/plugin_system/base/base_events_handler.py +++ b/src/plugin_system/base/base_events_handler.py @@ -13,16 +13,23 @@ class BaseEventHandler(ABC): 所有事件处理器都应该继承这个基类,提供事件处理的基本接口 """ - event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知 - handler_name: str = "" # 处理器名称 + event_type: EventType = EventType.UNKNOWN + """事件类型,默认为未知""" + handler_name: str = "" + """处理器名称""" handler_description: str = "" - weight: int = 0 # 权重,数值越大优先级越高 - intercept_message: bool = False # 是否拦截消息,默认为否 + """处理器描述""" + weight: int = 0 + """处理器权重,越大权重越高""" + intercept_message: bool = False + """是否拦截消息,默认为否""" def __init__(self): self.log_prefix = "[EventHandler]" - self.plugin_name = "" # 对应插件名 - self.plugin_config: Optional[Dict] = None # 插件配置字典 + self.plugin_name = "" + """对应插件名""" + self.plugin_config: Optional[Dict] = None + """插件配置字典""" if self.event_type == EventType.UNKNOWN: raise NotImplementedError("事件处理器必须指定 event_type") diff --git 
a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index 1e6841eba..3cf82390e 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -3,7 +3,7 @@ from typing import List, Type, Tuple, Union from .plugin_base import PluginBase from src.common.logger import get_logger -from src.plugin_system.base.component_types import ComponentInfo, ActionInfo, CommandInfo, EventHandlerInfo +from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo from .base_action import BaseAction from .base_command import BaseCommand from .base_events_handler import BaseEventHandler diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 774daa598..74b01ddd7 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -142,7 +142,6 @@ class CommandInfo(ComponentInfo): command_pattern: str = "" # 命令匹配模式(正则表达式) command_help: str = "" # 命令帮助信息 command_examples: List[str] = field(default_factory=list) # 命令使用示例 - intercept_message: bool = True # 是否拦截消息处理(默认拦截) def __post_init__(self): super().__post_init__() diff --git a/src/plugin_system/core/__init__.py b/src/plugin_system/core/__init__.py index c6041ece7..3193828bf 100644 --- a/src/plugin_system/core/__init__.py +++ b/src/plugin_system/core/__init__.py @@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.dependency_manager import dependency_manager from src.plugin_system.core.events_manager import events_manager +from src.plugin_system.core.global_announcement_manager import global_announcement_manager __all__ = [ "plugin_manager", "component_registry", "dependency_manager", "events_manager", + "global_announcement_manager", ] diff --git a/src/plugin_system/core/component_registry.py 
b/src/plugin_system/core/component_registry.py index 7283cf9eb..2ea89b880 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -25,27 +25,35 @@ class ComponentRegistry: """ def __init__(self): - # 组件注册表 - self._components: Dict[str, ComponentInfo] = {} # 命名空间式组件名 -> 组件信息 - # 类型 -> 命名空间式名称 -> 组件信息 + # 命名空间式组件名构成法 f"{component_type}.{component_name}" + self._components: Dict[str, ComponentInfo] = {} + """组件注册表 命名空间式组件名 -> 组件信息""" self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType} - # 命名空间式组件名 -> 组件类 + """类型 -> 组件原名称 -> 组件信息""" self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {} + """命名空间式组件名 -> 组件类""" # 插件注册表 - self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息 + self._plugins: Dict[str, PluginInfo] = {} + """插件名 -> 插件信息""" # Action特定注册表 - self._action_registry: Dict[str, Type[BaseAction]] = {} # action名 -> action类 - self._default_actions: Dict[str, ActionInfo] = {} # 默认动作集,即启用的Action集,用于重置ActionManager状态 + self._action_registry: Dict[str, Type[BaseAction]] = {} + """Action注册表 action名 -> action类""" + self._default_actions: Dict[str, ActionInfo] = {} + """默认动作集,即启用的Action集,用于重置ActionManager状态""" # Command特定注册表 - self._command_registry: Dict[str, Type[BaseCommand]] = {} # command名 -> command类 - self._command_patterns: Dict[Pattern, str] = {} # 编译后的正则 -> command名 + self._command_registry: Dict[str, Type[BaseCommand]] = {} + """Command类注册表 command名 -> command类""" + self._command_patterns: Dict[Pattern, str] = {} + """编译后的正则 -> command名""" # EventHandler特定注册表 - self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} # event_handler名 -> event_handler类 - self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {} # 启用的事件处理器 + self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} + """event_handler名 -> event_handler类""" + self._enabled_event_handlers: Dict[str, 
Type[BaseEventHandler]] = {} + """启用的事件处理器 event_handler名 -> event_handler类""" logger.info("组件注册中心初始化完成") @@ -110,11 +118,17 @@ class ComponentRegistry: # 根据组件类型进行特定注册(使用原始名称) match component_type: case ComponentType.ACTION: - ret = self._register_action_component(component_info, component_class) # type: ignore + assert isinstance(component_info, ActionInfo) + assert issubclass(component_class, BaseAction) + ret = self._register_action_component(component_info, component_class) case ComponentType.COMMAND: - ret = self._register_command_component(component_info, component_class) # type: ignore + assert isinstance(component_info, CommandInfo) + assert issubclass(component_class, BaseCommand) + ret = self._register_command_component(component_info, component_class) case ComponentType.EVENT_HANDLER: - ret = self._register_event_handler_component(component_info, component_class) # type: ignore + assert isinstance(component_info, EventHandlerInfo) + assert issubclass(component_class, BaseEventHandler) + ret = self._register_event_handler_component(component_info, component_class) case _: logger.warning(f"未知组件类型: {component_type}") @@ -160,7 +174,9 @@ class ComponentRegistry: if pattern not in self._command_patterns: self._command_patterns[pattern] = command_name else: - logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令") + logger.warning( + f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令" + ) return True @@ -176,6 +192,10 @@ class ComponentRegistry: self._event_handler_registry[handler_name] = handler_class + if not handler_info.enabled: + logger.warning(f"EventHandler组件 {handler_name} 未启用") + return True # 未启用,但是也是注册成功 + from .events_manager import events_manager # 延迟导入防止循环导入问题 if events_manager.register_event_subscriber(handler_info, handler_class): @@ -185,6 +205,124 @@ class ComponentRegistry: logger.error(f"注册事件处理器 {handler_name} 失败") return False + # === 组件移除相关 === + + async def remove_component(self, 
component_name: str, component_type: ComponentType, plugin_name: str) -> bool: + target_component_class = self.get_component_class(component_name, component_type) + if not target_component_class: + logger.warning(f"组件 {component_name} 未注册,无法移除") + return False + try: + match component_type: + case ComponentType.ACTION: + self._action_registry.pop(component_name) + self._default_actions.pop(component_name) + case ComponentType.COMMAND: + self._command_registry.pop(component_name) + keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name] + for key in keys_to_remove: + self._command_patterns.pop(key) + case ComponentType.EVENT_HANDLER: + from .events_manager import events_manager # 延迟导入防止循环导入问题 + + self._event_handler_registry.pop(component_name) + self._enabled_event_handlers.pop(component_name) + await events_manager.unregister_event_subscriber(component_name) + namespaced_name = f"{component_type}.{component_name}" + self._components.pop(namespaced_name) + self._components_by_type[component_type].pop(component_name) + self._components_classes.pop(namespaced_name) + logger.info(f"组件 {component_name} 已移除") + return True + except KeyError: + logger.warning(f"移除组件时未找到组件: {component_name}") + return False + except Exception as e: + logger.error(f"移除组件 {component_name} 时发生错误: {e}") + return False + + def remove_plugin_registry(self, plugin_name: str) -> bool: + """移除插件注册信息 + + Args: + plugin_name: 插件名称 + + Returns: + bool: 是否成功移除 + """ + if plugin_name not in self._plugins: + logger.warning(f"插件 {plugin_name} 未注册,无法移除") + return False + del self._plugins[plugin_name] + logger.info(f"插件 {plugin_name} 已移除") + return True + + # === 组件全局启用/禁用方法 === + + def enable_component(self, component_name: str, component_type: ComponentType) -> bool: + """全局的启用某个组件 + Parameters: + component_name: 组件名称 + component_type: 组件类型 + Returns: + bool: 启用成功返回True,失败返回False + """ + target_component_class = self.get_component_class(component_name, component_type) + 
target_component_info = self.get_component_info(component_name, component_type) + if not target_component_class or not target_component_info: + logger.warning(f"组件 {component_name} 未注册,无法启用") + return False + target_component_info.enabled = True + match component_type: + case ComponentType.ACTION: + assert isinstance(target_component_info, ActionInfo) + self._default_actions[component_name] = target_component_info + case ComponentType.COMMAND: + assert isinstance(target_component_info, CommandInfo) + pattern = target_component_info.command_pattern + self._command_patterns[re.compile(pattern)] = component_name + case ComponentType.EVENT_HANDLER: + assert isinstance(target_component_info, EventHandlerInfo) + assert issubclass(target_component_class, BaseEventHandler) + self._enabled_event_handlers[component_name] = target_component_class + from .events_manager import events_manager # 延迟导入防止循环导入问题 + + events_manager.register_event_subscriber(target_component_info, target_component_class) + namespaced_name = f"{component_type}.{component_name}" + self._components[namespaced_name].enabled = True + self._components_by_type[component_type][component_name].enabled = True + logger.info(f"组件 {component_name} 已启用") + return True + + async def disable_component(self, component_name: str, component_type: ComponentType) -> bool: + """全局的禁用某个组件 + Parameters: + component_name: 组件名称 + component_type: 组件类型 + Returns: + bool: 禁用成功返回True,失败返回False + """ + target_component_class = self.get_component_class(component_name, component_type) + target_component_info = self.get_component_info(component_name, component_type) + if not target_component_class or not target_component_info: + logger.warning(f"组件 {component_name} 未注册,无法禁用") + return False + target_component_info.enabled = False + match component_type: + case ComponentType.ACTION: + self._default_actions.pop(component_name, None) + case ComponentType.COMMAND: + self._command_patterns = {k: v for k, v in self._command_patterns.items() 
if v != component_name} + case ComponentType.EVENT_HANDLER: + self._enabled_event_handlers.pop(component_name, None) + from .events_manager import events_manager # 延迟导入防止循环导入问题 + + await events_manager.unregister_event_subscriber(component_name) + self._components[component_name].enabled = False + self._components_by_type[component_type][component_name].enabled = False + logger.info(f"组件 {component_name} 已禁用") + return True + # === 组件查询方法 === def get_component_info( self, component_name: str, component_type: Optional[ComponentType] = None @@ -287,7 +425,7 @@ class ComponentRegistry: # === Action特定查询方法 === def get_action_registry(self) -> Dict[str, Type[BaseAction]]: - """获取Action注册表(用于兼容现有系统)""" + """获取Action注册表""" return self._action_registry.copy() def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]: @@ -314,7 +452,7 @@ class ComponentRegistry: """获取Command模式注册表""" return self._command_patterns.copy() - def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]: + def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]: # sourcery skip: use-named-expression, use-next """根据文本查找匹配的命令 @@ -335,11 +473,10 @@ class ComponentRegistry: return ( self._command_registry[command_name], candidates[0].match(text).groupdict(), # type: ignore - command_info.intercept_message, - command_info.plugin_name, + command_info, ) - # === 事件处理器特定查询方法 === + # === EventHandler 特定查询方法 === def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]: """获取事件处理器注册表""" @@ -364,9 +501,9 @@ class ComponentRegistry: """获取所有插件""" return self._plugins.copy() - def get_enabled_plugins(self) -> Dict[str, PluginInfo]: - """获取所有启用的插件""" - return {name: info for name, info in self._plugins.items() if info.enabled} + # def get_enabled_plugins(self) -> Dict[str, PluginInfo]: + # """获取所有启用的插件""" + # return {name: info for name, info in self._plugins.items() if info.enabled} def 
get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]: """获取插件的所有组件""" diff --git a/src/plugin_system/core/events_manager.py b/src/plugin_system/core/events_manager.py index 6352c4a09..1f01b4ab4 100644 --- a/src/plugin_system/core/events_manager.py +++ b/src/plugin_system/core/events_manager.py @@ -6,6 +6,7 @@ from src.chat.message_receive.message import MessageRecv from src.common.logger import get_logger from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages from src.plugin_system.base.base_events_handler import BaseEventHandler +from .global_announcement_manager import global_announcement_manager logger = get_logger("events_manager") @@ -28,18 +29,16 @@ class EventsManager: bool: 是否注册成功 """ handler_name = handler_info.name - plugin_name = getattr(handler_info, "plugin_name", "unknown") - namespace_name = f"{plugin_name}.{handler_name}" - if namespace_name in self._handler_mapping: - logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册") + if handler_name in self._handler_mapping: + logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册") return False if not issubclass(handler_class, BaseEventHandler): logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类") return False - self._handler_mapping[namespace_name] = handler_class + self._handler_mapping[handler_name] = handler_class return self._insert_event_handler(handler_class, handler_info) async def handle_mai_events( @@ -55,6 +54,10 @@ class EventsManager: continue_flag = True transformed_message = self._transform_event_message(message, llm_prompt, llm_response) for handler in self._events_subscribers.get(event_type, []): + if message.chat_stream and message.chat_stream.stream_id: + stream_id = message.chat_stream.stream_id + if handler.handler_name in global_announcement_manager.get_disabled_chat_event_handlers(stream_id): + continue handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {}) if handler.intercept_message: try: @@ 
-71,7 +74,9 @@ class EventsManager: try: handler_task = asyncio.create_task(handler.execute(transformed_message)) handler_task.add_done_callback(self._task_done_callback) - handler_task.set_name(f"EventHandler-{handler.handler_name}-{event_type.name}") + handler_task.set_name(f"{handler.plugin_name}-{handler.handler_name}") + if handler.handler_name not in self._handler_tasks: + self._handler_tasks[handler.handler_name] = [] self._handler_tasks[handler.handler_name].append(handler_task) except Exception as e: logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}") @@ -91,7 +96,7 @@ class EventsManager: return True - def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool: + def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool: """从事件类型列表中移除事件处理器""" display_handler_name = handler_class.handler_name or handler_class.__name__ if handler_class.event_type == EventType.UNKNOWN: @@ -190,5 +195,20 @@ class EventsManager: finally: del self._handler_tasks[handler_name] + async def unregister_event_subscriber(self, handler_name: str) -> bool: + """取消注册事件处理器""" + if handler_name not in self._handler_mapping: + logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册") + return False + + await self.cancel_handler_tasks(handler_name) + + handler_class = self._handler_mapping.pop(handler_name) + if not self._remove_event_handler_instance(handler_class): + return False + + logger.info(f"事件处理器 {handler_name} 已成功取消注册") + return True + events_manager = EventsManager() diff --git a/src/plugin_system/core/global_announcement_manager.py b/src/plugin_system/core/global_announcement_manager.py new file mode 100644 index 000000000..9f7052f5d --- /dev/null +++ b/src/plugin_system/core/global_announcement_manager.py @@ -0,0 +1,93 @@ +from typing import List, Dict + +from src.common.logger import get_logger + +logger = get_logger("global_announcement_manager") + + +class GlobalAnnouncementManager: + def __init__(self) -> None: + # 
用户禁用的动作,chat_id -> [action_name] + self._user_disabled_actions: Dict[str, List[str]] = {} + # 用户禁用的命令,chat_id -> [command_name] + self._user_disabled_commands: Dict[str, List[str]] = {} + # 用户禁用的事件处理器,chat_id -> [handler_name] + self._user_disabled_event_handlers: Dict[str, List[str]] = {} + + def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool: + """禁用特定聊天的某个动作""" + if chat_id not in self._user_disabled_actions: + self._user_disabled_actions[chat_id] = [] + if action_name in self._user_disabled_actions[chat_id]: + logger.warning(f"动作 {action_name} 已经被禁用") + return False + self._user_disabled_actions[chat_id].append(action_name) + return True + + def enable_specific_chat_action(self, chat_id: str, action_name: str) -> bool: + """启用特定聊天的某个动作""" + if chat_id in self._user_disabled_actions: + try: + self._user_disabled_actions[chat_id].remove(action_name) + return True + except ValueError: + logger.warning(f"动作 {action_name} 不在禁用列表中") + return False + return False + + def disable_specific_chat_command(self, chat_id: str, command_name: str) -> bool: + """禁用特定聊天的某个命令""" + if chat_id not in self._user_disabled_commands: + self._user_disabled_commands[chat_id] = [] + if command_name in self._user_disabled_commands[chat_id]: + logger.warning(f"命令 {command_name} 已经被禁用") + return False + self._user_disabled_commands[chat_id].append(command_name) + return True + + def enable_specific_chat_command(self, chat_id: str, command_name: str) -> bool: + """启用特定聊天的某个命令""" + if chat_id in self._user_disabled_commands: + try: + self._user_disabled_commands[chat_id].remove(command_name) + return True + except ValueError: + logger.warning(f"命令 {command_name} 不在禁用列表中") + return False + return False + + def disable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool: + """禁用特定聊天的某个事件处理器""" + if chat_id not in self._user_disabled_event_handlers: + self._user_disabled_event_handlers[chat_id] = [] + if handler_name in 
self._user_disabled_event_handlers[chat_id]: + logger.warning(f"事件处理器 {handler_name} 已经被禁用") + return False + self._user_disabled_event_handlers[chat_id].append(handler_name) + return True + + def enable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool: + """启用特定聊天的某个事件处理器""" + if chat_id in self._user_disabled_event_handlers: + try: + self._user_disabled_event_handlers[chat_id].remove(handler_name) + return True + except ValueError: + logger.warning(f"事件处理器 {handler_name} 不在禁用列表中") + return False + return False + + def get_disabled_chat_actions(self, chat_id: str) -> List[str]: + """获取特定聊天禁用的所有动作""" + return self._user_disabled_actions.get(chat_id, []).copy() + + def get_disabled_chat_commands(self, chat_id: str) -> List[str]: + """获取特定聊天禁用的所有命令""" + return self._user_disabled_commands.get(chat_id, []).copy() + + def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]: + """获取特定聊天禁用的所有事件处理器""" + return self._user_disabled_event_handlers.get(chat_id, []).copy() + + +global_announcement_manager = GlobalAnnouncementManager() diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 3ce9c9e52..8bb005a94 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -1,5 +1,4 @@ import os -import inspect import traceback from typing import Dict, List, Optional, Tuple, Type, Any @@ -8,11 +7,11 @@ from pathlib import Path from src.common.logger import get_logger -from src.plugin_system.core.component_registry import component_registry -from src.plugin_system.core.dependency_manager import dependency_manager from src.plugin_system.base.plugin_base import PluginBase -from src.plugin_system.base.component_types import ComponentType, PluginInfo, PythonDependency +from src.plugin_system.base.component_types import ComponentType, PythonDependency from src.plugin_system.utils.manifest_utils import VersionComparator +from .component_registry import 
component_registry +from .dependency_manager import dependency_manager logger = get_logger("plugin_manager") @@ -36,19 +35,7 @@ class PluginManager: self._ensure_plugin_directories() logger.info("插件管理器初始化完成") - def _ensure_plugin_directories(self) -> None: - """确保所有插件根目录存在,如果不存在则创建""" - default_directories = ["src/plugins/built_in", "plugins"] - - for directory in default_directories: - if not os.path.exists(directory): - os.makedirs(directory, exist_ok=True) - logger.info(f"创建插件根目录: {directory}") - if directory not in self.plugin_directories: - self.plugin_directories.append(directory) - logger.debug(f"已添加插件根目录: {directory}") - else: - logger.warning(f"根目录不可重复加载: {directory}") + # === 插件目录管理 === def add_plugin_directory(self, directory: str) -> bool: """添加插件目录""" @@ -63,6 +50,8 @@ class PluginManager: logger.warning(f"插件目录不存在: {directory}") return False + # === 插件加载管理 === + def load_all_plugins(self) -> Tuple[int, int]: """加载所有插件 @@ -162,62 +151,50 @@ class PluginManager: logger.debug("详细错误信息: ", exc_info=True) return False, 1 - def unload_registered_plugin_module(self, plugin_name: str) -> None: + async def remove_registered_plugin(self, plugin_name: str) -> bool: """ - 卸载插件模块 + 禁用插件模块 """ - pass + if not plugin_name: + raise ValueError("插件名称不能为空") + if plugin_name not in self.loaded_plugins: + logger.warning(f"插件 {plugin_name} 未加载") + return False + plugin_instance = self.loaded_plugins[plugin_name] + plugin_info = plugin_instance.plugin_info + success = True + for component in plugin_info.components: + success &= await component_registry.remove_component(component.name, component.component_type, plugin_name) + success &= component_registry.remove_plugin_registry(plugin_name) + del self.loaded_plugins[plugin_name] + return success - def reload_registered_plugin_module(self, plugin_name: str) -> None: + async def reload_registered_plugin(self, plugin_name: str) -> bool: """ 重载插件模块 """ - self.unload_registered_plugin_module(plugin_name) - 
self.load_registered_plugin_classes(plugin_name) + if not await self.remove_registered_plugin(plugin_name): + return False + if not self.load_registered_plugin_classes(plugin_name)[0]: + return False + logger.debug(f"插件 {plugin_name} 重载成功") + return True - def rescan_plugin_directory(self) -> None: + def rescan_plugin_directory(self) -> Tuple[int, int]: """ 重新扫描插件根目录 """ - # --------------------------------------- NEED REFACTORING --------------------------------------- + total_success = 0 + total_fail = 0 for directory in self.plugin_directories: if os.path.exists(directory): logger.debug(f"重新扫描插件根目录: {directory}") - self._load_plugin_modules_from_directory(directory) + success, fail = self._load_plugin_modules_from_directory(directory) + total_success += success + total_fail += fail else: logger.warning(f"插件根目录不存在: {directory}") - - def get_loaded_plugins(self) -> List[PluginInfo]: - """获取所有已加载的插件信息""" - return list(component_registry.get_all_plugins().values()) - - def get_enabled_plugins(self) -> List[PluginInfo]: - """获取所有启用的插件信息""" - return list(component_registry.get_enabled_plugins().values()) - - # def enable_plugin(self, plugin_name: str) -> bool: - # # -------------------------------- NEED REFACTORING -------------------------------- - # """启用插件""" - # if plugin_info := component_registry.get_plugin_info(plugin_name): - # plugin_info.enabled = True - # # 启用插件的所有组件 - # for component in plugin_info.components: - # component_registry.enable_component(component.name) - # logger.debug(f"已启用插件: {plugin_name}") - # return True - # return False - - # def disable_plugin(self, plugin_name: str) -> bool: - # # -------------------------------- NEED REFACTORING -------------------------------- - # """禁用插件""" - # if plugin_info := component_registry.get_plugin_info(plugin_name): - # plugin_info.enabled = False - # # 禁用插件的所有组件 - # for component in plugin_info.components: - # component_registry.disable_component(component.name) - # logger.debug(f"已禁用插件: {plugin_name}") 
- # return True - # return False + return total_success, total_fail def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]: """获取插件实例 @@ -230,25 +207,6 @@ class PluginManager: """ return self.loaded_plugins.get(plugin_name) - def get_plugin_stats(self) -> Dict[str, Any]: - """获取插件统计信息""" - all_plugins = component_registry.get_all_plugins() - enabled_plugins = component_registry.get_enabled_plugins() - - action_components = component_registry.get_components_by_type(ComponentType.ACTION) - command_components = component_registry.get_components_by_type(ComponentType.COMMAND) - - return { - "total_plugins": len(all_plugins), - "enabled_plugins": len(enabled_plugins), - "failed_plugins": len(self.failed_plugins), - "total_components": len(action_components) + len(command_components), - "action_components": len(action_components), - "command_components": len(command_components), - "loaded_plugin_files": len(self.loaded_plugins), - "failed_plugin_details": self.failed_plugins.copy(), - } - def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]: """检查所有插件的Python依赖包 @@ -347,6 +305,43 @@ class PluginManager: return dependency_manager.generate_requirements_file(all_dependencies, output_path) + # === 查询方法 === + def list_loaded_plugins(self) -> List[str]: + """ + 列出所有当前加载的插件。 + + Returns: + list: 当前加载的插件名称列表。 + """ + return list(self.loaded_plugins.keys()) + + def list_registered_plugins(self) -> List[str]: + """ + 列出所有已注册的插件类。 + + Returns: + list: 已注册的插件类名称列表。 + """ + return list(self.plugin_classes.keys()) + + # === 私有方法 === + # == 目录管理 == + def _ensure_plugin_directories(self) -> None: + """确保所有插件根目录存在,如果不存在则创建""" + default_directories = ["src/plugins/built_in", "plugins"] + + for directory in default_directories: + if not os.path.exists(directory): + os.makedirs(directory, exist_ok=True) + logger.info(f"创建插件根目录: {directory}") + if directory not in self.plugin_directories: + self.plugin_directories.append(directory) + 
logger.debug(f"已添加插件根目录: {directory}") + else: + logger.warning(f"根目录不可重复加载: {directory}") + + # == 插件加载 == + def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]: """从指定目录加载插件模块""" loaded_count = 0 @@ -372,18 +367,6 @@ class PluginManager: return loaded_count, failed_count - def _find_plugin_directory(self, plugin_class: Type[PluginBase]) -> Optional[str]: - """查找插件类对应的目录路径""" - try: - # module = getmodule(plugin_class) - # if module and hasattr(module, "__file__") and module.__file__: - # return os.path.dirname(module.__file__) - file_path = inspect.getfile(plugin_class) - return os.path.dirname(file_path) - except Exception as e: - logger.debug(f"通过inspect获取插件目录失败: {e}") - return None - def _load_plugin_module_file(self, plugin_file: str) -> bool: # sourcery skip: extract-method """加载单个插件模块文件 @@ -416,6 +399,8 @@ class PluginManager: self.failed_plugins[module_name] = error_msg return False + # == 兼容性检查 == + def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: """检查插件版本兼容性 @@ -451,6 +436,8 @@ class PluginManager: logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}") return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载 + # == 显示统计与插件信息 == + def _show_stats(self, total_registered: int, total_failed_registration: int): # sourcery skip: low-code-quality # 获取组件统计信息 @@ -493,9 +480,15 @@ class PluginManager: # 组件列表 if plugin_info.components: - action_components = [c for c in plugin_info.components if c.component_type == ComponentType.ACTION] - command_components = [c for c in plugin_info.components if c.component_type == ComponentType.COMMAND] - event_handler_components = [c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER] + action_components = [ + c for c in plugin_info.components if c.component_type == ComponentType.ACTION + ] + command_components = [ + c for c in plugin_info.components if c.component_type == ComponentType.COMMAND + ] + 
event_handler_components = [ + c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER + ] if action_components: action_names = [c.name for c in action_components] @@ -504,7 +497,7 @@ class PluginManager: if command_components: command_names = [c.name for c in command_components] logger.info(f" ⚡ Command组件: {', '.join(command_names)}") - + if event_handler_components: event_handler_names = [c.name for c in event_handler_components] logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}") diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index 1f1727adf..d44183c83 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -10,6 +10,7 @@ from src.common.logger import get_logger # 导入API模块 - 标准Python包方式 from src.plugin_system.apis import emoji_api, llm_api, message_api from src.plugins.built_in.core_actions.no_reply import NoReplyAction +from src.config.config import global_config logger = get_logger("emoji") @@ -102,7 +103,11 @@ class EmojiAction(BaseAction): 这里是可用的情感标签:{available_emotions} 请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。 """ - logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}") + + if global_config.debug.show_prompt: + logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}") + else: + logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}") # 5. 
调用LLM models = llm_api.get_available_models() diff --git a/src/plugins/built_in/core_actions/no_reply.py b/src/plugins/built_in/core_actions/no_reply.py index f275bfc42..e9fad9107 100644 --- a/src/plugins/built_in/core_actions/no_reply.py +++ b/src/plugins/built_in/core_actions/no_reply.py @@ -13,7 +13,7 @@ from src.plugin_system.apis import message_api from src.config.config import global_config -logger = get_logger("core_actions") +logger = get_logger("no_reply_action") class NoReplyAction(BaseAction): diff --git a/src/plugins/built_in/core_actions/plugin.py b/src/plugins/built_in/core_actions/plugin.py index 015189e2b..99bff18aa 100644 --- a/src/plugins/built_in/core_actions/plugin.py +++ b/src/plugins/built_in/core_actions/plugin.py @@ -5,15 +5,10 @@ 这是系统的内置插件,提供基础的聊天交互功能 """ -import random -import time from typing import List, Tuple, Type -import asyncio -import re -import traceback # 导入新插件系统 -from src.plugin_system import BasePlugin, register_plugin, BaseAction, ComponentInfo, ActionActivationType, ChatMode +from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ActionActivationType from src.plugin_system.base.config_types import ConfigField from src.config.config import global_config @@ -21,139 +16,12 @@ from src.config.config import global_config from src.common.logger import get_logger # 导入API模块 - 标准Python包方式 -from src.plugin_system.apis import generator_api, message_api from src.plugins.built_in.core_actions.no_reply import NoReplyAction from src.plugins.built_in.core_actions.emoji import EmojiAction -from src.person_info.person_info import get_person_info_manager -from src.chat.mai_thinking.mai_think import mai_thinking_manager +from src.plugins.built_in.core_actions.reply import ReplyAction logger = get_logger("core_actions") -# 常量定义 -WAITING_TIME_THRESHOLD = 1200 # 等待新消息时间阈值,单位秒 - -ENABLE_THINKING = False - -class ReplyAction(BaseAction): - """回复动作 - 参与聊天回复""" - - # 激活设置 - focus_activation_type = ActionActivationType.NEVER - 
normal_activation_type = ActionActivationType.NEVER - mode_enable = ChatMode.FOCUS - parallel_action = False - - # 动作基本信息 - action_name = "reply" - action_description = "参与聊天回复,发送文本进行表达" - - # 动作参数定义 - action_parameters = {} - - # 动作使用场景 - action_require = ["你想要闲聊或者随便附和", "有人提到你", "如果你刚刚进行了回复,不要对同一个话题重复回应"] - - # 关联类型 - associated_types = ["text"] - - def _parse_reply_target(self, target_message: str) -> tuple: - sender = "" - target = "" - if ":" in target_message or ":" in target_message: - # 使用正则表达式匹配中文或英文冒号 - parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1) - if len(parts) == 2: - sender = parts[0].strip() - target = parts[1].strip() - return sender, target - - async def execute(self) -> Tuple[bool, str]: - """执行回复动作""" - logger.info(f"{self.log_prefix} 决定进行回复") - start_time = self.action_data.get("loop_start_time", time.time()) - - user_id = self.user_id - platform = self.platform - # logger.info(f"{self.log_prefix} 用户ID: {user_id}, 平台: {platform}") - person_id = get_person_info_manager().get_person_id(platform, user_id) - # logger.info(f"{self.log_prefix} 人物ID: {person_id}") - person_name = get_person_info_manager().get_value_sync(person_id, "person_name") - reply_to = f"{person_name}:{self.action_message.get('processed_plain_text', '')}" - logger.info(f"{self.log_prefix} 回复目标: {reply_to}") - - try: - if prepared_reply := self.action_data.get("prepared_reply", ""): - reply_text = prepared_reply - else: - try: - success, reply_set, _ = await asyncio.wait_for( - generator_api.generate_reply( - extra_info="", - reply_to=reply_to, - chat_id=self.chat_id, - request_type="chat.replyer.focus", - enable_tool=global_config.tool.enable_in_focus_chat, - ), - timeout=global_config.chat.thinking_timeout, - ) - except asyncio.TimeoutError: - logger.warning(f"{self.log_prefix} 回复生成超时 ({global_config.chat.thinking_timeout}s)") - return False, "timeout" - - # 检查从start_time以来的新消息数量 - # 获取动作触发时间或使用默认值 - current_time = time.time() - new_message_count = 
message_api.count_new_messages( - chat_id=self.chat_id, start_time=start_time, end_time=current_time - ) - - # 根据新消息数量决定是否使用reply_to - need_reply = new_message_count >= random.randint(2, 4) - logger.info( - f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}引用回复" - ) - # 构建回复文本 - reply_text = "" - first_replied = False - reply_to_platform_id = f"{platform}:{user_id}" - for reply_seg in reply_set: - data = reply_seg[1] - if not first_replied: - if need_reply: - await self.send_text( - content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False - ) - else: - await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=False) - first_replied = True - else: - await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=True) - reply_text += data - - # 存储动作记录 - reply_text = f"你对{person_name}进行了回复:{reply_text}" - - - if ENABLE_THINKING: - await mai_thinking_manager.get_mai_think(self.chat_id).do_think_after_response(reply_text) - - - await self.store_action_info( - action_build_into_prompt=False, - action_prompt_display=reply_text, - action_done=True, - ) - - # 重置NoReplyAction的连续计数器 - NoReplyAction.reset_consecutive_count() - - return success, reply_text - - except Exception as e: - logger.error(f"{self.log_prefix} 回复动作执行失败: {e}") - traceback.print_exc() - return False, f"回复失败: {str(e)}" - @register_plugin class CoreActionsPlugin(BasePlugin): @@ -168,11 +36,11 @@ class CoreActionsPlugin(BasePlugin): """ # 插件基本信息 - plugin_name = "core_actions" # 内部标识符 - enable_plugin = True - dependencies = [] # 插件依赖列表 - python_dependencies = [] # Python包依赖列表 - config_file_name = "config.toml" + plugin_name: str = "core_actions" # 内部标识符 + enable_plugin: bool = True + dependencies: list[str] = [] # 插件依赖列表 + python_dependencies: list[str] = [] # Python包依赖列表 + config_file_name: str = "config.toml" # 配置节描述 config_section_descriptions = { @@ -181,7 +49,7 @@ class 
CoreActionsPlugin(BasePlugin): } # 配置Schema定义 - config_schema = { + config_schema: dict = { "plugin": { "enabled": ConfigField(type=bool, default=True, description="是否启用插件"), "config_version": ConfigField(type=str, default="0.4.0", description="配置文件版本"), diff --git a/src/plugins/built_in/core_actions/reply.py b/src/plugins/built_in/core_actions/reply.py new file mode 100644 index 000000000..d73337b29 --- /dev/null +++ b/src/plugins/built_in/core_actions/reply.py @@ -0,0 +1,149 @@ +# 导入新插件系统 +from src.plugin_system import BaseAction, ActionActivationType, ChatMode +from src.config.config import global_config +import random +import time +from typing import Tuple +import asyncio +import re +import traceback + +# 导入依赖的系统组件 +from src.common.logger import get_logger + +# 导入API模块 - 标准Python包方式 +from src.plugin_system.apis import generator_api, message_api +from src.plugins.built_in.core_actions.no_reply import NoReplyAction +from src.person_info.person_info import get_person_info_manager +from src.mais4u.mai_think import mai_thinking_manager +from src.mais4u.constant_s4u import ENABLE_S4U + +logger = get_logger("reply_action") + + +class ReplyAction(BaseAction): + """回复动作 - 参与聊天回复""" + + # 激活设置 + focus_activation_type = ActionActivationType.NEVER + normal_activation_type = ActionActivationType.NEVER + mode_enable = ChatMode.FOCUS + parallel_action = False + + # 动作基本信息 + action_name = "reply" + action_description = "" + + # 动作参数定义 + action_parameters = {} + + # 动作使用场景 + action_require = [""] + + # 关联类型 + associated_types = ["text"] + + def _parse_reply_target(self, target_message: str) -> tuple: + sender = "" + target = "" + # 添加None检查,防止NoneType错误 + if target_message is None: + return sender, target + if ":" in target_message or ":" in target_message: + # 使用正则表达式匹配中文或英文冒号 + parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1) + if len(parts) == 2: + sender = parts[0].strip() + target = parts[1].strip() + return sender, target + + async def execute(self) -> 
Tuple[bool, str]: + """执行回复动作""" + logger.debug(f"{self.log_prefix} 决定进行回复") + start_time = self.action_data.get("loop_start_time", time.time()) + + user_id = self.user_id + platform = self.platform + # logger.info(f"{self.log_prefix} 用户ID: {user_id}, 平台: {platform}") + person_id = get_person_info_manager().get_person_id(platform, user_id) # type: ignore + # logger.info(f"{self.log_prefix} 人物ID: {person_id}") + person_name = get_person_info_manager().get_value_sync(person_id, "person_name") + reply_to = f"{person_name}:{self.action_message.get('processed_plain_text', '')}" # type: ignore + logger.info(f"{self.log_prefix} 决定进行回复,目标: {reply_to}") + + try: + if prepared_reply := self.action_data.get("prepared_reply", ""): + reply_text = prepared_reply + else: + try: + success, reply_set, _ = await asyncio.wait_for( + generator_api.generate_reply( + extra_info="", + reply_to=reply_to, + chat_id=self.chat_id, + request_type="chat.replyer.focus", + enable_tool=global_config.tool.enable_in_focus_chat, + ), + timeout=global_config.chat.thinking_timeout, + ) + except asyncio.TimeoutError: + logger.warning(f"{self.log_prefix} 回复生成超时 ({global_config.chat.thinking_timeout}s)") + return False, "timeout" + + # 检查从start_time以来的新消息数量 + # 获取动作触发时间或使用默认值 + current_time = time.time() + new_message_count = message_api.count_new_messages( + chat_id=self.chat_id, start_time=start_time, end_time=current_time + ) + + # 根据新消息数量决定是否使用reply_to + need_reply = new_message_count >= random.randint(2, 4) + if need_reply: + logger.info( + f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复" + ) + else: + logger.debug( + f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,不使用引用回复" + ) + + # 构建回复文本 + reply_text = "" + first_replied = False + reply_to_platform_id = f"{platform}:{user_id}" + for reply_seg in reply_set: + data = reply_seg[1] + if not first_replied: + if need_reply: + await self.send_text( + content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, 
typing=False + ) + else: + await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=False) + first_replied = True + else: + await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=True) + reply_text += data + + # 存储动作记录 + reply_text = f"你对{person_name}进行了回复:{reply_text}" + + if ENABLE_S4U: + await mai_thinking_manager.get_mai_think(self.chat_id).do_think_after_response(reply_text) + + await self.store_action_info( + action_build_into_prompt=False, + action_prompt_display=reply_text, + action_done=True, + ) + + # 重置NoReplyAction的连续计数器 + NoReplyAction.reset_consecutive_count() + + return success, reply_text + + except Exception as e: + logger.error(f"{self.log_prefix} 回复动作执行失败: {e}") + traceback.print_exc() + return False, f"回复失败: {str(e)}" diff --git a/src/plugins/built_in/plugin_management/_manifest.json b/src/plugins/built_in/plugin_management/_manifest.json new file mode 100644 index 000000000..41b3cd9ce --- /dev/null +++ b/src/plugins/built_in/plugin_management/_manifest.json @@ -0,0 +1,39 @@ +{ + "manifest_version": 1, + "name": "插件和组件管理 (Plugin and Component Management)", + "version": "1.0.0", + "description": "通过系统API管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。", + "author": { + "name": "MaiBot团队", + "url": "https://github.com/MaiM-with-u" + }, + "license": "GPL-v3.0-or-later", + "host_application": { + "min_version": "0.9.0" + }, + "homepage_url": "https://github.com/MaiM-with-u/maibot", + "repository_url": "https://github.com/MaiM-with-u/maibot", + "keywords": [ + "plugins", + "components", + "management", + "built-in" + ], + "categories": [ + "Core System", + "Plugin Management" + ], + "default_locale": "zh-CN", + "locales_path": "_locales", + "plugin_info": { + "is_built_in": true, + "plugin_type": "plugin_management", + "components": [ + { + "type": "command", + "name": "plugin_management", + "description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。" + } + ] + } +} \ No newline at end of file diff --git 
a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py new file mode 100644 index 000000000..cbdf567ac --- /dev/null +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -0,0 +1,440 @@ +import asyncio + +from typing import List, Tuple, Type +from src.plugin_system import ( + BasePlugin, + BaseCommand, + CommandInfo, + ConfigField, + register_plugin, + plugin_manage_api, + component_manage_api, + ComponentInfo, + ComponentType, +) + + +class ManagementCommand(BaseCommand): + command_name: str = "management" + description: str = "管理命令" + command_pattern: str = r"(?P<manage_command>^/pm(\s[a-zA-Z0-9_]+)*\s*$)" + + async def execute(self) -> Tuple[bool, str, bool]: + # sourcery skip: merge-duplicate-blocks + if ( + not self.message + or not self.message.message_info + or not self.message.message_info.user_info + or str(self.message.message_info.user_info.user_id) not in self.get_config("plugin.permission", []) # type: ignore + ): + await self.send_text("你没有权限使用插件管理命令") + return False, "没有权限", True + command_list = self.matched_groups["manage_command"].strip().split(" ") + if len(command_list) == 1: + await self.show_help("all") + return True, "帮助已发送", True + if len(command_list) == 2: + match command_list[1]: + case "plugin": + await self.show_help("plugin") + case "component": + await self.show_help("component") + case "help": + await self.show_help("all") + case _: + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + if len(command_list) == 3: + if command_list[1] == "plugin": + match command_list[2]: + case "help": + await self.show_help("plugin") + case "list": + await self._list_registered_plugins() + case "list_enabled": + await self._list_loaded_plugins() + case "rescan": + await self._rescan_plugin_dirs() + case _: + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + elif command_list[1] == "component": + if command_list[2] == "list": + await self._list_all_registered_components() + elif 
command_list[2] == "help": + await self.show_help("component") + else: + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + else: + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + if len(command_list) == 4: + if command_list[1] == "plugin": + match command_list[2]: + case "load": + await self._load_plugin(command_list[3]) + case "unload": + await self._unload_plugin(command_list[3]) + case "reload": + await self._reload_plugin(command_list[3]) + case "add_dir": + await self._add_dir(command_list[3]) + case _: + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + elif command_list[1] == "component": + if command_list[2] != "list": + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + if command_list[3] == "enabled": + await self._list_enabled_components() + elif command_list[3] == "disabled": + await self._list_disabled_components() + else: + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + else: + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + if len(command_list) == 5: + if command_list[1] != "component": + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + if command_list[2] != "list": + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + if command_list[3] == "enabled": + await self._list_enabled_components(target_type=command_list[4]) + elif command_list[3] == "disabled": + await self._list_disabled_components(target_type=command_list[4]) + elif command_list[3] == "type": + await self._list_registered_components_by_type(command_list[4]) + else: + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + if len(command_list) == 6: + if command_list[1] != "component": + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + if command_list[2] == "enable": + if command_list[3] == "global": + await self._globally_enable_component(command_list[4], command_list[5]) + elif command_list[3] == "local": + await 
self._locally_enable_component(command_list[4], command_list[5]) + else: + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + elif command_list[2] == "disable": + if command_list[3] == "global": + await self._globally_disable_component(command_list[4], command_list[5]) + elif command_list[3] == "local": + await self._locally_disable_component(command_list[4], command_list[5]) + else: + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + else: + await self.send_text("插件管理命令不合法") + return False, "命令不合法", True + + return True, "命令执行完成", True + + async def show_help(self, target: str): + help_msg = "" + match target: + case "all": + help_msg = ( + "管理命令帮助\n" + "/pm help 管理命令提示\n" + "/pm plugin 插件管理命令\n" + "/pm component 组件管理命令\n" + "使用 /pm plugin help 或 /pm component help 获取具体帮助" + ) + case "plugin": + help_msg = ( + "插件管理命令帮助\n" + "/pm plugin help 插件管理命令提示\n" + "/pm plugin list 列出所有注册的插件\n" + "/pm plugin list_enabled 列出所有加载(启用)的插件\n" + "/pm plugin rescan 重新扫描所有目录\n" + "/pm plugin load 加载指定插件\n" + "/pm plugin unload 卸载指定插件\n" + "/pm plugin reload 重新加载指定插件\n" + "/pm plugin add_dir 添加插件目录\n" + ) + case "component": + help_msg = ( + "组件管理命令帮助\n" + "/pm component help 组件管理命令提示\n" + "/pm component list 列出所有注册的组件\n" + "/pm component list enabled <可选: type> 列出所有启用的组件\n" + "/pm component list disabled <可选: type> 列出所有禁用的组件\n" + " - 可选项: local,代表当前聊天中的;global,代表全局的\n" + " - 不填时为 global\n" + "/pm component list type 列出已经注册的指定类型的组件\n" + "/pm component enable global 全局启用组件\n" + "/pm component enable local 本聊天启用组件\n" + "/pm component disable global 全局禁用组件\n" + "/pm component disable local 本聊天禁用组件\n" + " - 可选项: action, command, event_handler\n" + ) + case _: + return + await self.send_text(help_msg) + + async def _list_loaded_plugins(self): + plugins = plugin_manage_api.list_loaded_plugins() + await self.send_text(f"已加载的插件: {', '.join(plugins)}") + + async def _list_registered_plugins(self): + plugins = plugin_manage_api.list_registered_plugins() + await 
self.send_text(f"已注册的插件: {', '.join(plugins)}") + + async def _rescan_plugin_dirs(self): + plugin_manage_api.rescan_plugin_directory() + await self.send_text("插件目录重新扫描执行中") + + async def _load_plugin(self, plugin_name: str): + success, count = plugin_manage_api.load_plugin(plugin_name) + if success: + await self.send_text(f"插件加载成功: {plugin_name}") + else: + if count == 0: + await self.send_text(f"插件{plugin_name}为禁用状态") + await self.send_text(f"插件加载失败: {plugin_name}") + + async def _unload_plugin(self, plugin_name: str): + success = await plugin_manage_api.remove_plugin(plugin_name) + if success: + await self.send_text(f"插件卸载成功: {plugin_name}") + else: + await self.send_text(f"插件卸载失败: {plugin_name}") + + async def _reload_plugin(self, plugin_name: str): + success = await plugin_manage_api.reload_plugin(plugin_name) + if success: + await self.send_text(f"插件重新加载成功: {plugin_name}") + else: + await self.send_text(f"插件重新加载失败: {plugin_name}") + + async def _add_dir(self, dir_path: str): + await self.send_text(f"正在添加插件目录: {dir_path}") + success = plugin_manage_api.add_plugin_directory(dir_path) + await asyncio.sleep(0.5) # 防止乱序发送 + if success: + await self.send_text(f"插件目录添加成功: {dir_path}") + else: + await self.send_text(f"插件目录添加失败: {dir_path}") + + def _fetch_all_registered_components(self) -> List[ComponentInfo]: + all_plugin_info = component_manage_api.get_all_plugin_info() + if not all_plugin_info: + return [] + + components_info: List[ComponentInfo] = [] + for plugin_info in all_plugin_info.values(): + components_info.extend(plugin_info.components) + return components_info + + def _fetch_locally_disabled_components(self) -> List[str]: + locally_disabled_components_actions = component_manage_api.get_locally_disabled_components( + self.message.chat_stream.stream_id, ComponentType.ACTION + ) + locally_disabled_components_commands = component_manage_api.get_locally_disabled_components( + self.message.chat_stream.stream_id, ComponentType.COMMAND + ) + 
locally_disabled_components_event_handlers = component_manage_api.get_locally_disabled_components( + self.message.chat_stream.stream_id, ComponentType.EVENT_HANDLER + ) + return ( + locally_disabled_components_actions + + locally_disabled_components_commands + + locally_disabled_components_event_handlers + ) + + async def _list_all_registered_components(self): + components_info = self._fetch_all_registered_components() + if not components_info: + await self.send_text("没有注册的组件") + return + + all_components_str = ", ".join( + f"{component.name} ({component.component_type})" for component in components_info + ) + await self.send_text(f"已注册的组件: {all_components_str}") + + async def _list_enabled_components(self, target_type: str = "global"): + components_info = self._fetch_all_registered_components() + if not components_info: + await self.send_text("没有注册的组件") + return + + if target_type == "global": + enabled_components = [component for component in components_info if component.enabled] + if not enabled_components: + await self.send_text("没有满足条件的已启用全局组件") + return + enabled_components_str = ", ".join( + f"{component.name} ({component.component_type})" for component in enabled_components + ) + await self.send_text(f"满足条件的已启用全局组件: {enabled_components_str}") + elif target_type == "local": + locally_disabled_components = self._fetch_locally_disabled_components() + enabled_components = [ + component + for component in components_info + if (component.name not in locally_disabled_components and component.enabled) + ] + if not enabled_components: + await self.send_text("本聊天没有满足条件的已启用组件") + return + enabled_components_str = ", ".join( + f"{component.name} ({component.component_type})" for component in enabled_components + ) + await self.send_text(f"本聊天满足条件的已启用组件: {enabled_components_str}") + + async def _list_disabled_components(self, target_type: str = "global"): + components_info = self._fetch_all_registered_components() + if not components_info: + await 
self.send_text("没有注册的组件") + return + + if target_type == "global": + disabled_components = [component for component in components_info if not component.enabled] + if not disabled_components: + await self.send_text("没有满足条件的已禁用全局组件") + return + disabled_components_str = ", ".join( + f"{component.name} ({component.component_type})" for component in disabled_components + ) + await self.send_text(f"满足条件的已禁用全局组件: {disabled_components_str}") + elif target_type == "local": + locally_disabled_components = self._fetch_locally_disabled_components() + disabled_components = [ + component + for component in components_info + if (component.name in locally_disabled_components or not component.enabled) + ] + if not disabled_components: + await self.send_text("本聊天没有满足条件的已禁用组件") + return + disabled_components_str = ", ".join( + f"{component.name} ({component.component_type})" for component in disabled_components + ) + await self.send_text(f"本聊天满足条件的已禁用组件: {disabled_components_str}") + + async def _list_registered_components_by_type(self, target_type: str): + match target_type: + case "action": + component_type = ComponentType.ACTION + case "command": + component_type = ComponentType.COMMAND + case "event_handler": + component_type = ComponentType.EVENT_HANDLER + case _: + await self.send_text(f"未知组件类型: {target_type}") + return + + components_info = component_manage_api.get_components_info_by_type(component_type) + if not components_info: + await self.send_text(f"没有注册的 {target_type} 组件") + return + + components_str = ", ".join( + f"{name} ({component.component_type})" for name, component in components_info.items() + ) + await self.send_text(f"注册的 {target_type} 组件: {components_str}") + + async def _globally_enable_component(self, component_name: str, component_type: str): + match component_type: + case "action": + target_component_type = ComponentType.ACTION + case "command": + target_component_type = ComponentType.COMMAND + case "event_handler": + target_component_type = 
ComponentType.EVENT_HANDLER + case _: + await self.send_text(f"未知组件类型: {component_type}") + return + if component_manage_api.globally_enable_component(component_name, target_component_type): + await self.send_text(f"全局启用组件成功: {component_name}") + else: + await self.send_text(f"全局启用组件失败: {component_name}") + + async def _globally_disable_component(self, component_name: str, component_type: str): + match component_type: + case "action": + target_component_type = ComponentType.ACTION + case "command": + target_component_type = ComponentType.COMMAND + case "event_handler": + target_component_type = ComponentType.EVENT_HANDLER + case _: + await self.send_text(f"未知组件类型: {component_type}") + return + success = await component_manage_api.globally_disable_component(component_name, target_component_type) + if success: + await self.send_text(f"全局禁用组件成功: {component_name}") + else: + await self.send_text(f"全局禁用组件失败: {component_name}") + + async def _locally_enable_component(self, component_name: str, component_type: str): + match component_type: + case "action": + target_component_type = ComponentType.ACTION + case "command": + target_component_type = ComponentType.COMMAND + case "event_handler": + target_component_type = ComponentType.EVENT_HANDLER + case _: + await self.send_text(f"未知组件类型: {component_type}") + return + if component_manage_api.locally_enable_component( + component_name, + target_component_type, + self.message.chat_stream.stream_id, + ): + await self.send_text(f"本地启用组件成功: {component_name}") + else: + await self.send_text(f"本地启用组件失败: {component_name}") + + async def _locally_disable_component(self, component_name: str, component_type: str): + match component_type: + case "action": + target_component_type = ComponentType.ACTION + case "command": + target_component_type = ComponentType.COMMAND + case "event_handler": + target_component_type = ComponentType.EVENT_HANDLER + case _: + await self.send_text(f"未知组件类型: {component_type}") + return + if 
component_manage_api.locally_disable_component( + component_name, + target_component_type, + self.message.chat_stream.stream_id, + ): + await self.send_text(f"本地禁用组件成功: {component_name}") + else: + await self.send_text(f"本地禁用组件失败: {component_name}") + + +@register_plugin +class PluginManagementPlugin(BasePlugin): + plugin_name: str = "plugin_management_plugin" + enable_plugin: bool = True + dependencies: list[str] = [] + python_dependencies: list[str] = [] + config_file_name: str = "config.toml" + config_schema: dict = { + "plugin": { + "enable": ConfigField(bool, default=True, description="是否启用插件"), + "permission": ConfigField(list, default=[], description="有权限使用插件管理命令的用户列表"), + }, + } + + def get_plugin_components(self) -> List[Tuple[CommandInfo, Type[BaseCommand]]]: + components = [] + if self.get_config("plugin.enable", True): + components.append((ManagementCommand.get_command_info(), ManagementCommand)) + return components diff --git a/src/plugins/built_in/tts_plugin/plugin.py b/src/plugins/built_in/tts_plugin/plugin.py index 7d45f4d30..6683735e4 100644 --- a/src/plugins/built_in/tts_plugin/plugin.py +++ b/src/plugins/built_in/tts_plugin/plugin.py @@ -92,7 +92,7 @@ class TTSAction(BaseAction): # 确保句子结尾有合适的标点 if not any(processed_text.endswith(end) for end in [".", "?", "!", "。", "!", "?"]): - processed_text = processed_text + "。" + processed_text = f"{processed_text}。" return processed_text @@ -107,11 +107,11 @@ class TTSPlugin(BasePlugin): """ # 插件基本信息 - plugin_name = "tts_plugin" # 内部标识符 - enable_plugin = True - dependencies = [] # 插件依赖列表 - python_dependencies = [] # Python包依赖列表 - config_file_name = "config.toml" + plugin_name: str = "tts_plugin" # 内部标识符 + enable_plugin: bool = True + dependencies: list[str] = [] # 插件依赖列表 + python_dependencies: list[str] = [] # Python包依赖列表 + config_file_name: str = "config.toml" # 配置节描述 config_section_descriptions = { @@ -121,7 +121,7 @@ class TTSPlugin(BasePlugin): } # 配置Schema定义 - config_schema = { + config_schema: dict = 
{ "plugin": { "name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True), "version": ConfigField(type=str, default="0.1.0", description="插件版本号"), diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 04cf745da..ff8a79e73 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "4.4.4" +version = "4.4.8" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请在修改后将version的值进行变更 @@ -13,6 +13,7 @@ version = "4.4.4" #----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- [bot] +platform = "qq" qq_account = 1145141919810 # 麦麦的QQ账号 nickname = "麦麦" # 麦麦的昵称 alias_names = ["麦叠", "牢麦"] # 麦麦的别名 @@ -33,7 +34,7 @@ compress_identity = true # 是否压缩身份,压缩后会精简身份信息 # 表达方式 enable_expression = true # 是否启用表达方式 # 描述麦麦说话的表达风格,表达习惯,例如:(请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景。) -expression_style = "请回复的平淡些,简短一些,说中文,可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,不要刻意突出自身学科背景。" +expression_style = "回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。" enable_expression_learning = false # 是否启用表达学习,麦麦会学习不同群里人类说话风格(群之间不互通) learning_interval = 350 # 学习间隔 单位秒 @@ -58,6 +59,9 @@ max_context_size = 25 # 上下文长度 thinking_timeout = 20 # 麦麦一次回复最长思考规划时间,超过这个时间的思考会放弃(往往是api反应太慢) replyer_random_probability = 0.5 # 首要replyer模型被选择的概率 +mentioned_bot_inevitable_reply = true # 提及 bot 大概率回复 +at_bot_inevitable_reply = true # @bot 或 提及bot 大概率回复 + use_s4u_prompt_mode = true # 是否使用 s4u 对话构建模式,该模式会更好的把握当前对话对象的对话内容,但是对群聊整理理解能力较差(测试功能!!可能有未知问题!!) 
@@ -87,8 +91,6 @@ talk_frequency_adjust = [ # - 时间支持跨天,例如 "00:10,0.3" 表示从凌晨0:10开始使用频率0.3 # - 系统会自动将 "platform:id:type" 转换为内部的哈希chat_id进行匹配 -enable_asr = false # 是否启用语音识别,启用后麦麦可以通过语音输入进行对话,启用该功能需要配置语音识别模型[model.voice] - [message_receive] # 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息 ban_words = [ @@ -102,11 +104,8 @@ ban_msgs_regex = [ ] [normal_chat] #普通聊天 -#一般回复参数 willing_mode = "classical" # 回复意愿模式 —— 经典模式:classical,mxp模式:mxp,自定义模式:custom(需要你自己实现) response_interested_rate_amplifier = 1 # 麦麦回复兴趣度放大系数 -mentioned_bot_inevitable_reply = true # 提及 bot 必然回复 -at_bot_inevitable_reply = true # @bot 必然回复(包含提及) [tool] enable_in_normal_chat = false # 是否在普通聊天中启用工具 @@ -144,14 +143,15 @@ enable_instant_memory = false # 是否启用即时记忆,测试功能,可能 #不希望记忆的词,已经记忆的不会受到影响,需要手动清理 memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ] +[voice] +enable_asr = false # 是否启用语音识别,启用后麦麦可以识别语音消息,启用该功能需要配置语音识别模型[model.voice] + [mood] enable_mood = true # 是否启用情绪系统 -mood_update_interval = 1.0 # 情绪更新间隔 单位秒 -mood_decay_rate = 0.95 # 情绪衰减率 -mood_intensity_factor = 1.0 # 情绪强度因子 +mood_update_threshold = 1 # 情绪更新阈值,越高,更新越慢 [lpmm_knowledge] # lpmm知识库配置 -enable = true # 是否启用lpmm知识库 +enable = false # 是否启用lpmm知识库 rag_synonym_search_top_k = 10 # 同义词搜索TopK rag_synonym_threshold = 0.8 # 同义词阈值(相似度高于此阈值的词语会被认为是同义词) info_extraction_workers = 3 # 实体提取同时执行线程数,非Pro模型不要设置超过5 @@ -229,7 +229,7 @@ show_prompt = false # 是否显示prompt [model] -model_max_output_length = 1000 # 模型单次返回的最大token数 +model_max_output_length = 1024 # 模型单次返回的最大token数 #------------必填:组件模型------------