diff --git a/.gitignore b/.gitignore
index df3ab670f..ac400b137 100644
--- a/.gitignore
+++ b/.gitignore
@@ -301,3 +301,5 @@ $RECYCLE.BIN/
# Windows shortcuts
*.lnk
+src/chat/focus_chat/working_memory/test/test1.txt
+src/chat/focus_chat/working_memory/test/test4.txt
diff --git a/README.md b/README.md
index f07e7d57f..b1c271245 100644
--- a/README.md
+++ b/README.md
@@ -1,188 +1,125 @@
-# 麦麦!MaiCore-MaiMBot (编辑中)
-
-
+
+
+
+
- 
- 
- 
- 
- 
- 
- 
+# 麦麦!MaiCore-MaiBot (编辑中)
-
+
+
+
+
+
+
+
+[](https://deepwiki.com/DrSmoothl/MaiBot)
-
-
-
-
-
-
- 画师:略nd
-
+
+🌟 演示视频 |
+🚀 快速入门 |
+📃 教程 |
+💬 讨论 |
+🙋 贡献指南
+
-
MaiBot(麦麦)
-
- 一款专注于 群组聊天 的赛博网友
-
- 探索本项目的文档 »
-
-
-
- 报告Bug
- ·
- 提出新特性
-
-
+## 🎉 介绍
-## 新版0.6.x部署前先阅读:https://docs.mai-mai.org/faq/maibot/backup_update.html
+**🍔MaiCore 是一个基于大语言模型的可交互智能体**
-
-## 📝 项目简介
-
-**🍔MaiCore是一个基于大语言模型的可交互智能体**
-
-
-- 💭 **智能对话系统**:基于LLM的自然语言交互
-- 🤔 **实时思维系统**:模拟人类思考过程
-- 💝 **情感表达系统**:丰富的表情包和情绪表达
-- 🧠 **持久记忆系统**:基于MongoDB的长期记忆存储
-- 🔄 **动态人格系统**:自适应的性格特征
+- 💭 **智能对话系统**:基于 LLM 的自然语言交互。
+- 🤔 **实时思维系统**:模拟人类思考过程。
+- 💝 **情感表达系统**:丰富的表情包和情绪表达。
+- 🧠 **持久记忆系统**:基于 MongoDB 的长期记忆存储。
+- 🔄 **动态人格系统**:自适应的性格特征。
+## 🔥 更新和安装
-### 📢 版本信息
-
-**最新版本: v0.6.3** ([查看更新日志](changelogs/changelog.md))
-> [!WARNING]
-> 请阅读教程后更新!!!!!!!
-> 请阅读教程后更新!!!!!!!
-> 请阅读教程后更新!!!!!!!
-> 次版本MaiBot将基于MaiCore运行,不再依赖于nonebot相关组件运行。
-> MaiBot将通过nonebot的插件与nonebot建立联系,然后nonebot与QQ建立联系,实现MaiBot与QQ的交互
-
-**分支说明:**
-- `main`: 稳定发布版本
-- `dev`: 开发测试版本(不知道什么意思就别下)
-- `classical`: 0.6.0之前的版本
-
+**最新版本: v0.6.3** ([更新日志](changelogs/changelog.md))
+可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
+**GitHub 分支说明:**
+- `main`: 稳定发布版本(推荐)
+- `dev`: 开发测试版本(不稳定)
+- `classical`: 旧版本(停止维护)
+### 最新版本部署教程 (MaiCore 版本)
+- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容)
> [!WARNING]
-> - 项目处于活跃开发阶段,代码可能随时更改
-> - 文档未完善,有问题可以提交 Issue 或者 Discussion
-> - QQ机器人存在被限制风险,请自行了解,谨慎使用
-> - 由于持续迭代,可能存在一些已知或未知的bug
-> - 由于开发中,可能消耗较多token
-
-### ⚠️ 重要提示
-
-- 升级到v0.6.x版本前请务必阅读:[升级指南](https://docs.mai-mai.org/faq/maibot/backup_update.html)
-- 本版本基于MaiCore重构,通过nonebot插件与QQ平台交互
-- 项目处于活跃开发阶段,功能和API可能随时调整
-
-### 💬交流群(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
-- [一群](https://qm.qq.com/q/VQ3XZrWgMs) 766798517
-- [二群](https://qm.qq.com/q/RzmCiRtHEW) 571780722
-- [五群](https://qm.qq.com/q/JxvHZnxyec) 1022489779
-- [三群](https://qm.qq.com/q/wlH5eT8OmQ) 1035228475【已满】
-- [四群](https://qm.qq.com/q/wGePTl1UyY) 729957033【已满】
+> - 从 0.5.x 旧版本升级前请务必阅读:[升级指南](https://docs.mai-mai.org/faq/maibot/backup_update.html)
+> - 项目处于活跃开发阶段,功能和 API 可能随时调整。
+> - 文档未完善,有问题可以提交 Issue 或者 Discussion。
+> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
+> - 由于持续迭代,可能存在一些已知或未知的 bug。
+> - 由于程序处于开发中,可能消耗较多 token。
+## 💬 讨论
+- [一群](https://qm.qq.com/q/VQ3XZrWgMs) |
+ [二群](https://qm.qq.com/q/RzmCiRtHEW) |
+ [五群](https://qm.qq.com/q/JxvHZnxyec) |
+ [三群](https://qm.qq.com/q/wlH5eT8OmQ)(已满)|
+ [四群](https://qm.qq.com/q/wGePTl1UyY)(已满)
## 📚 文档
+**部分内容可能更新不够及时,请注意版本对应**
-### (部分内容可能过时,请注意版本对应)
+- [📚 核心 Wiki 文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切。
-### 核心文档
-- [📚 核心Wiki文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切
-
-### 最新版本部署教程(MaiCore版本)
-- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于MaiCore的新版本部署方式(与旧版本不兼容)
-
-
-## 🎯 0.6.3 功能介绍
-
-| 模块 | 主要功能 | 特点 |
-|----------|------------------------------------------------------------------|-------|
-| 💬 聊天系统 | • **统一调控不同回复逻辑**
• 智能交互模式 (普通聊天/专注聊天)
• 关键词主动发言
• 多模型支持
• 动态prompt构建
• 私聊功能(PFC)增强 | 拟人化交互 |
-| 🧠 心流系统 | • 实时思考生成
• **智能状态管理**
• **概率回复机制**
• 自动启停机制
• 日程系统联动
• **上下文感知工具调用** | 智能化决策 |
-| 🧠 记忆系统 | • **记忆整合与提取**
• 海马体记忆机制
• 聊天记录概括 | 持久化记忆 |
-| 😊 表情系统 | • **全新表情包系统**
• **优化选择逻辑**
• 情绪匹配发送
• GIF支持
• 自动收集与审查 | 丰富表达 |
-| 📅 日程系统 | • 动态日程生成
• 自定义想象力
• 思维流联动 | 智能规划 |
-| 👥 关系系统 | • **工具调用动态更新**
• 关系管理优化
• 丰富接口支持
• 个性化交互 | 深度社交 |
-| 📊 统计系统 | • 使用数据统计
• LLM调用记录
• 实时控制台显示 | 数据可视 |
-| 🛠️ 工具系统 | • **LPMM知识库集成**
• **上下文感知调用**
• 知识获取工具
• 自动注册机制
• 多工具支持 | 扩展功能 |
-| 📚 **知识库(LPMM)** | • **全新LPMM系统**
• **强大的信息检索能力** | 知识增强 |
-| ✨ **昵称系统** | • **自动为群友取昵称**
• **降低认错人概率** (早期阶段) | 身份识别 |
-
-## 📐 项目架构
-
-```mermaid
-graph TD
- A[MaiCore] --> B[对话系统]
- A --> C[心流系统]
- A --> D[记忆系统]
- A --> E[情感系统]
- B --> F[多模型支持]
- B --> G[动态Prompt]
- C --> H[实时思考]
- C --> I[日程联动]
- D --> J[记忆存储]
- D --> K[记忆检索]
- E --> L[表情管理]
- E --> M[情绪识别]
-```
-
-## ✍️如何给本项目报告BUG/提交建议/做贡献
-
-MaiCore是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交bug报告、功能需求还是代码pr,都对项目非常宝贵。我们非常感谢你的支持!🎉 但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](depends-data/CONTRIBUTE.md)(待补完)
-
-
-
-## 设计理念(原始时代的火花)
+### 设计理念(原始时代的火花)
> **千石可乐说:**
-> - 这个项目最初只是为了给牛牛bot添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在QQ群聊的"生命体"。可以目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
-> - 程序的功能设计理念基于一个核心的原则:"最像而不是好"
-> - 如果人类真的需要一个AI来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。
-> - 代码会保持开源和开放,但个人希望MaiMbot的运行时数据保持封闭,尽量避免以显式命令来对其进行控制和调试.我认为一个你无法完全掌控的个体才更能让你感觉到它的自主性,而视其成为一个对话机器.
-> - SengokuCola~~纯编程外行,面向cursor编程,很多代码写得不好多多包涵~~已得到大脑升级
+> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
+> - 程序的功能设计理念基于一个核心的原则:"最像而不是好"。
+> - 如果人类真的需要一个 AI 来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。
+> - 代码会保持开源和开放,但个人希望 MaiMbot 的运行时数据保持封闭,尽量避免以显式命令来对其进行控制和调试。我认为一个你无法完全掌控的个体才更能让你感觉到它的自主性,而视其成为一个对话机器。
+> - SengokuCola~~纯编程外行,面向 cursor 编程,很多代码写得不好多多包涵~~已得到大脑升级。
-
-## 📌 注意事项
-
-> [!WARNING]
-> 使用本项目前必须阅读和同意用户协议和隐私协议
-> 本应用生成内容来自人工智能模型,由 AI 生成,请仔细甄别,请勿用于违反法律的用途,AI生成内容不代表本人观点和立场。
-
-## 致谢
-
-- [NapCat](https://github.com/NapNeko/NapCatQQ): 现代化的基于 NTQQ 的 Bot 协议端实现
-
-## 麦麦仓库状态
-
-
+## 🙋 贡献和致谢
+你可以阅读[开发文档](https://docs.mai-mai.org/develop/)来更好的了解麦麦!
+MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr,都对项目非常宝贵。我们非常感谢你的支持!🎉
+但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)。(待补完)
### 贡献者
-感谢各位大佬!
+感谢各位大佬!
-**也感谢每一位给麦麦发展提出宝贵意见与建议的用户,感谢陪伴麦麦走到现在的你们**
+### 致谢
-## Stargazers over time
+- [略nd](https://space.bilibili.com/1344099355): 为麦麦绘制人设。
+- [NapCat](https://github.com/NapNeko/NapCatQQ): 现代化的基于 NTQQ 的 Bot 协议端实现。
-[](https://starchart.cc/MaiM-with-u/MaiBot)
+**也感谢每一位给麦麦发展提出宝贵意见与建议的用户,感谢陪伴麦麦走到现在的你们!**
+
+## 📌 注意事项
+
+> [!WARNING]
+> 使用本项目前必须阅读和同意[用户协议](EULA.md)和[隐私协议](PRIVACY.md)。
+> 本应用生成内容来自人工智能模型,由 AI 生成,请仔细甄别,请勿用于违反法律的用途,AI 生成内容不代表本项目团队的观点和立场。
+
+## 麦麦仓库状态
+
+
+
+### Star 趋势
+
+[](https://starchart.cc/MaiM-with-u/MaiBot)
+
+## License
+
+GPL-3.0
diff --git a/changelogs/changelog_dev.md b/changelogs/changelog_dev.md
deleted file mode 100644
index 663ad9629..000000000
--- a/changelogs/changelog_dev.md
+++ /dev/null
@@ -1,27 +0,0 @@
-这里放置了测试版本的细节更新
-
-## [test-0.6.1-snapshot-1] - 2025-4-5
-- 修复pfc回复出错bug
-- 修复表情包打字时间,不会卡表情包
-- 改进了知识库的提取
-- 提供了新的数据库连接方式
-- 修复了ban_user无效的问题
-
-## [test-0.6.0-snapshot-9] - 2025-4-4
-- 可以识别gif表情包
-
-## [test-0.6.0-snapshot-8] - 2025-4-3
-- 修复了表情包的注册,获取和发送逻辑
-- 表情包增加存储上限
-- 更改了回复引用的逻辑,从基于时间改为基于新消息
-- 增加了调试信息
-- 自动清理缓存图片
-- 修复并重启了关系系统
-
-## [test-0.6.0-snapshot-7] - 2025-4-2
-- 修改版本号命名:test-前缀为测试版,无前缀为正式版
-- 提供私聊的PFC模式,可以进行有目的,自由多轮对话
-
-## [0.6.0-mmc-4] - 2025-4-1
-- 提供两种聊天逻辑,思维流聊天(ThinkFlowChat 和 推理聊天(ReasoningChat)
-- 从结构上可支持多种回复消息逻辑
\ No newline at end of file
diff --git a/src/0.6Bing.md b/docs/0.6Bing.md
similarity index 100%
rename from src/0.6Bing.md
rename to docs/0.6Bing.md
diff --git a/depends-data/CONTRIBUTE.md b/docs/CONTRIBUTE.md
similarity index 100%
rename from depends-data/CONTRIBUTE.md
rename to docs/CONTRIBUTE.md
diff --git a/src/heartFC_chatting_logic.md b/docs/HeartFC_chatting_logic.md
similarity index 100%
rename from src/heartFC_chatting_logic.md
rename to docs/HeartFC_chatting_logic.md
diff --git a/src/heartFC_readme.md b/docs/HeartFC_readme.md
similarity index 100%
rename from src/heartFC_readme.md
rename to docs/HeartFC_readme.md
diff --git a/src/README.md b/docs/HeartFC_system.md
similarity index 99%
rename from src/README.md
rename to docs/HeartFC_system.md
index a55f1c973..e48a7b5d7 100644
--- a/src/README.md
+++ b/docs/HeartFC_system.md
@@ -149,7 +149,7 @@ c HeartFChatting工作方式
- **状态及含义**:
- `ChatState.ABSENT` (不参与/没在看): 初始或停用状态。子心流不观察新信息,不进行思考,也不回复。
- `ChatState.CHAT` (随便看看/水群): 普通聊天模式。激活 `NormalChatInstance`。
- * `ChatState.FOCUSED` (专注/认真水群): 专注聊天模式。激活 `HeartFlowChatInstance`。
+ * `ChatState.FOCUSED` (专注/认真聊天): 专注聊天模式。激活 `HeartFlowChatInstance`。
- **选择**: 子心流可以根据外部指令(来自 `SubHeartflowManager`)或内部逻辑(未来的扩展)选择进入 `ABSENT` 状态(不回复不观察),或进入 `CHAT` / `FOCUSED` 中的一种回复模式。
- **状态转换机制** (由 `SubHeartflowManager` 驱动,更细致的说明):
- **初始状态**: 新创建的 `SubHeartflow` 默认为 `ABSENT` 状态。
diff --git a/src/tools/tool_can_use/README.md b/docs/use_tool.md
similarity index 100%
rename from src/tools/tool_can_use/README.md
rename to docs/use_tool.md
diff --git a/requirements.txt b/requirements.txt
index 7abdffb48..0e60bc192 100644
Binary files a/requirements.txt and b/requirements.txt differ
diff --git a/src/api/config_api.py b/src/api/config_api.py
index 0b23fb993..8b99fb93e 100644
--- a/src/api/config_api.py
+++ b/src/api/config_api.py
@@ -41,7 +41,7 @@ class APIBotConfig:
allow_focus_mode: bool # 是否允许专注聊天状态
base_normal_chat_num: int # 最多允许多少个群进行普通聊天
base_focused_chat_num: int # 最多允许多少个群进行专注聊天
- observation_context_size: int # 观察到的最长上下文大小
+ chat.observation_context_size: int # 观察到的最长上下文大小
message_buffer: bool # 是否启用消息缓冲
ban_words: List[str] # 禁止词列表
ban_msgs_regex: List[str] # 禁止消息的正则表达式列表
@@ -128,7 +128,7 @@ class APIBotConfig:
llm_reasoning: Dict[str, Any] # 推理模型配置
llm_normal: Dict[str, Any] # 普通模型配置
llm_topic_judge: Dict[str, Any] # 主题判断模型配置
- llm_summary: Dict[str, Any] # 总结模型配置
+ model.summary: Dict[str, Any] # 总结模型配置
vlm: Dict[str, Any] # VLM模型配置
llm_heartflow: Dict[str, Any] # 心流模型配置
llm_observation: Dict[str, Any] # 观察模型配置
@@ -203,7 +203,7 @@ class APIBotConfig:
"llm_reasoning",
"llm_normal",
"llm_topic_judge",
- "llm_summary",
+ "model.summary",
"vlm",
"llm_heartflow",
"llm_observation",
diff --git a/src/api/reload_config.py b/src/api/reload_config.py
index a5f36e3db..1772800b6 100644
--- a/src/api/reload_config.py
+++ b/src/api/reload_config.py
@@ -1,6 +1,6 @@
from fastapi import HTTPException
from rich.traceback import install
-from src.config.config import BotConfig
+from src.config.config import Config
from src.common.logger_manager import get_logger
import os
@@ -14,8 +14,8 @@ async def reload_config():
from src.config import config as config_module
logger.debug("正在重载配置文件...")
- bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml")
- config_module.global_config = BotConfig.load_config(config_path=bot_config_path)
+ bot_config_path = os.path.join(Config.get_config_dir(), "bot_config.toml")
+ config_module.global_config = Config.load_config(config_path=bot_config_path)
logger.debug("配置文件重载成功")
return {"status": "reloaded"}
except FileNotFoundError as e:
diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py
index 5d800866f..fda0a63fd 100644
--- a/src/chat/emoji_system/emoji_manager.py
+++ b/src/chat/emoji_system/emoji_manager.py
@@ -5,12 +5,15 @@ import os
import random
import time
import traceback
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List, Any
from PIL import Image
import io
import re
-from ...common.database import db
+# from gradio_client import file
+
+from ...common.database.database_model import Emoji
+from ...common.database.database import db as peewee_db
from ...config.config import global_config
from ..utils.utils_image import image_path_to_base64, image_manager
from ..models.utils_model import LLMRequest
@@ -51,7 +54,7 @@ class MaiEmoji:
self.is_deleted = False # 标记是否已被删除
self.format = ""
- async def initialize_hash_format(self):
+ async def initialize_hash_format(self) -> Optional[bool]:
"""从文件创建表情包实例, 计算哈希值和格式"""
try:
# 使用 full_path 检查文件是否存在
@@ -104,7 +107,7 @@ class MaiEmoji:
self.is_deleted = True
return None
- async def register_to_db(self):
+ async def register_to_db(self) -> bool:
"""
注册表情包
将表情包对应的文件,从当前路径移动到EMOJI_REGISTED_DIR目录下
@@ -143,22 +146,22 @@ class MaiEmoji:
# --- 数据库操作 ---
try:
# 准备数据库记录 for emoji collection
- emoji_record = {
- "filename": self.filename,
- "path": self.path, # 存储目录路径
- "full_path": self.full_path, # 存储完整文件路径
- "embedding": self.embedding,
- "description": self.description,
- "emotion": self.emotion,
- "hash": self.hash,
- "format": self.format,
- "timestamp": int(self.register_time),
- "usage_count": self.usage_count,
- "last_used_time": self.last_used_time,
- }
+ emotion_str = ",".join(self.emotion) if self.emotion else ""
- # 使用upsert确保记录存在或被更新
- db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True)
+ Emoji.create(
+ hash=self.hash,
+ full_path=self.full_path,
+ format=self.format,
+ description=self.description,
+ emotion=emotion_str, # Store as comma-separated string
+ query_count=0, # Default value
+ is_registered=True,
+ is_banned=False, # Default value
+ record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time
+ register_time=self.register_time,
+ usage_count=self.usage_count,
+ last_used_time=self.last_used_time,
+ )
logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -166,14 +169,6 @@ class MaiEmoji:
except Exception as db_error:
logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}")
- # 数据库保存失败,是否需要将文件移回?为了简化,暂时只记录错误
- # 可以考虑在这里尝试删除已移动的文件,避免残留
- try:
- if os.path.exists(self.full_path): # full_path 此时是目标路径
- os.remove(self.full_path)
- logger.warning(f"[回滚] 已删除移动失败后残留的文件: {self.full_path}")
- except Exception as remove_error:
- logger.error(f"[错误] 回滚删除文件失败: {remove_error}")
return False
except Exception as e:
@@ -181,7 +176,7 @@ class MaiEmoji:
logger.error(traceback.format_exc())
return False
- async def delete(self):
+ async def delete(self) -> bool:
"""删除表情包
删除表情包的文件和数据库记录
@@ -201,10 +196,14 @@ class MaiEmoji:
# 文件删除失败,但仍然尝试删除数据库记录
# 2. 删除数据库记录
- result = db.emoji.delete_one({"hash": self.hash})
- deleted_in_db = result.deleted_count > 0
+ try:
+ will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
+ result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
+ except Emoji.DoesNotExist:
+ logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
+ result = 0 # Indicate no DB record was deleted
- if deleted_in_db:
+ if result > 0:
logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})")
# 3. 标记对象已被删除
self.is_deleted = True
@@ -224,7 +223,7 @@ class MaiEmoji:
return False
-def _emoji_objects_to_readable_list(emoji_objects):
+def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]:
"""将表情包对象列表转换为可读的字符串列表
参数:
@@ -243,47 +242,48 @@ def _emoji_objects_to_readable_list(emoji_objects):
return emoji_info_list
-def _to_emoji_objects(data):
+def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji_objects = []
load_errors = 0
+ # data is now an iterable of Peewee Emoji model instances
emoji_data_list = list(data)
- for emoji_data in emoji_data_list:
- full_path = emoji_data.get("full_path")
+ for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
+ full_path = emoji_data.full_path
if not full_path:
- logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: {emoji_data.get('_id')}")
+ logger.warning(
+ f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}"
+ )
load_errors += 1
- continue # 跳过缺少 full_path 的记录
+ continue
try:
- # 使用 full_path 初始化 MaiEmoji 对象
emoji = MaiEmoji(full_path=full_path)
- # 设置从数据库加载的属性
- emoji.hash = emoji_data.get("hash", "")
- # 如果 hash 为空,也跳过?取决于业务逻辑
+ emoji.hash = emoji_data.emoji_hash
if not emoji.hash:
logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}")
load_errors += 1
continue
- emoji.description = emoji_data.get("description", "")
- emoji.emotion = emoji_data.get("emotion", [])
- emoji.usage_count = emoji_data.get("usage_count", 0)
- # 优先使用 last_used_time,否则用 timestamp,最后用当前时间
- last_used = emoji_data.get("last_used_time")
- timestamp = emoji_data.get("timestamp")
- emoji.last_used_time = (
- last_used if last_used is not None else (timestamp if timestamp is not None else time.time())
- )
- emoji.register_time = timestamp if timestamp is not None else time.time()
- emoji.format = emoji_data.get("format", "") # 加载格式
+ emoji.description = emoji_data.description
+ # Deserialize emotion string from DB to list
+ emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else []
+ emoji.usage_count = emoji_data.usage_count
- # 不需要再手动设置 path 和 filename,__init__ 会自动处理
+ db_last_used_time = emoji_data.last_used_time
+ db_register_time = emoji_data.register_time
+
+ # If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time
+ emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time
+ # If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time())
+ emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time
+
+ emoji.format = emoji_data.format
emoji_objects.append(emoji)
- except ValueError as ve: # 捕获 __init__ 可能的错误
+ except ValueError as ve:
logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}")
load_errors += 1
except Exception as e:
@@ -292,13 +292,13 @@ def _to_emoji_objects(data):
return emoji_objects, load_errors
-def _ensure_emoji_dir():
+def _ensure_emoji_dir() -> None:
"""确保表情存储目录存在"""
os.makedirs(EMOJI_DIR, exist_ok=True)
os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True)
-async def clear_temp_emoji():
+async def clear_temp_emoji() -> None:
"""清理临时表情包
清理/data/emoji和/data/image目录下的所有文件
当目录中文件数超过100时,会全部删除
@@ -320,7 +320,7 @@ async def clear_temp_emoji():
logger.success("[清理] 完成")
-async def clean_unused_emojis(emoji_dir, emoji_objects):
+async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"]) -> None:
"""清理指定目录中未被 emoji_objects 追踪的表情包文件"""
if not os.path.exists(emoji_dir):
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
@@ -360,74 +360,52 @@ async def clean_unused_emojis(emoji_dir, emoji_objects):
class EmojiManager:
_instance = None
- def __new__(cls):
+ def __new__(cls) -> "EmojiManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
- def __init__(self):
+ def __init__(self) -> None:
self._initialized = None
self._scan_task = None
- self.vlm = LLMRequest(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
+
+ self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
self.llm_emotion_judge = LLMRequest(
- model=global_config.llm_normal, max_tokens=600, request_type="emoji"
+ model=global_config.model.normal, max_tokens=600, request_type="emoji"
) # 更高的温度,更少的token(后续可以根据情绪来调整温度)
self.emoji_num = 0
- self.emoji_num_max = global_config.max_emoji_num
- self.emoji_num_max_reach_deletion = global_config.max_reach_deletion
+ self.emoji_num_max = global_config.emoji.max_reg_num
+ self.emoji_num_max_reach_deletion = global_config.emoji.do_replace
self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表,使用类型注解明确列表元素类型
logger.info("启动表情包管理器")
- def initialize(self):
+ def initialize(self) -> None:
"""初始化数据库连接和表情目录"""
- if not self._initialized:
- try:
- self._ensure_emoji_collection()
- _ensure_emoji_dir()
- self._initialized = True
- # 更新表情包数量
- # 启动时执行一次完整性检查
- # await self.check_emoji_file_integrity()
- except Exception as e:
- logger.exception(f"初始化表情管理器失败: {e}")
+ peewee_db.connect(reuse_if_open=True)
+ if peewee_db.is_closed():
+ raise RuntimeError("数据库连接失败")
+ _ensure_emoji_dir()
+ Emoji.create_table(safe=True) # Ensures table exists
- def _ensure_db(self):
+ def _ensure_db(self) -> None:
"""确保数据库已初始化"""
if not self._initialized:
self.initialize()
if not self._initialized:
raise RuntimeError("EmojiManager not initialized")
- @staticmethod
- def _ensure_emoji_collection():
- """确保emoji集合存在并创建索引
-
- 这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
-
- 索引的作用是加快数据库查询速度:
- - embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包
- - tags字段的普通索引: 加快按标签搜索表情包的速度
- - filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度
-
- 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
- """
- if "emoji" not in db.list_collection_names():
- db.create_collection("emoji")
- db.emoji.create_index([("embedding", "2dsphere")])
- db.emoji.create_index([("filename", 1)], unique=True)
-
- def record_usage(self, emoji_hash: str):
+ def record_usage(self, emoji_hash: str) -> None:
"""记录表情使用次数"""
try:
- db.emoji.update_one({"hash": emoji_hash}, {"$inc": {"usage_count": 1}})
- for emoji in self.emoji_objects:
- if emoji.hash == emoji_hash:
- emoji.usage_count += 1
- break
-
+ emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash)
+ emoji_update.usage_count += 1
+ emoji_update.last_used_time = time.time() # Update last used time
+ emoji_update.save() # Persist changes to DB
+ except Emoji.DoesNotExist:
+ logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
@@ -447,7 +425,6 @@ class EmojiManager:
if not all_emojis:
logger.warning("内存中没有任何表情包对象")
- # 可以考虑再查一次数据库?或者依赖定期任务更新
return None
# 计算每个表情包与输入文本的最大情感相似度
@@ -463,40 +440,38 @@ class EmojiManager:
# 计算与每个emotion标签的相似度,取最大值
max_similarity = 0
- best_matching_emotion = "" # 记录最匹配的 emotion 喵~
+ best_matching_emotion = ""
for emotion in emotions:
# 使用编辑距离计算相似度
distance = self._levenshtein_distance(text_emotion, emotion)
max_len = max(len(text_emotion), len(emotion))
similarity = 1 - (distance / max_len if max_len > 0 else 0)
- if similarity > max_similarity: # 如果找到更相似的喵~
+ if similarity > max_similarity:
max_similarity = similarity
- best_matching_emotion = emotion # 就记下这个 emotion 喵~
+ best_matching_emotion = emotion
- if best_matching_emotion: # 确保有匹配的情感才添加喵~
- emoji_similarities.append((emoji, max_similarity, best_matching_emotion)) # 把 emotion 也存起来喵~
+ if best_matching_emotion:
+ emoji_similarities.append((emoji, max_similarity, best_matching_emotion))
# 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前10个最相似的表情包
- top_emojis = (
- emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities
- ) # 改个名字,更清晰喵~
+ top_emojis = emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities
if not top_emojis:
logger.warning("未找到匹配的表情包")
return None
# 从前几个中随机选择一个
- selected_emoji, similarity, matched_emotion = random.choice(top_emojis) # 把匹配的 emotion 也拿出来喵~
+ selected_emoji, similarity, matched_emotion = random.choice(top_emojis)
# 更新使用次数
- self.record_usage(selected_emoji.hash)
+ self.record_usage(selected_emoji.emoji_hash)
_time_end = time.time()
- logger.info( # 使用匹配到的 emotion 记录日志喵~
+ logger.info(
f"为[{text_emotion}]找到表情包: {matched_emotion} ({selected_emoji.filename}), Similarity: {similarity:.4f}"
)
# 返回完整文件路径和描述
@@ -534,7 +509,7 @@ class EmojiManager:
return previous_row[-1]
- async def check_emoji_file_integrity(self):
+ async def check_emoji_file_integrity(self) -> None:
"""检查表情包文件完整性
遍历self.emoji_objects中的所有对象,检查文件是否存在
如果文件已被删除,则执行对象的删除方法并从列表中移除
@@ -599,7 +574,7 @@ class EmojiManager:
logger.error(f"[错误] 检查表情包完整性失败: {str(e)}")
logger.error(traceback.format_exc())
- async def start_periodic_check_register(self):
+ async def start_periodic_check_register(self) -> None:
"""定期检查表情包完整性和数量"""
await self.get_all_emoji_from_db()
while True:
@@ -613,18 +588,18 @@ class EmojiManager:
logger.warning(f"[警告] 表情包目录不存在: {EMOJI_DIR}")
os.makedirs(EMOJI_DIR, exist_ok=True)
logger.info(f"[创建] 已创建表情包目录: {EMOJI_DIR}")
- await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
+ await asyncio.sleep(global_config.emoji.check_interval * 60)
continue
# 检查目录是否为空
files = os.listdir(EMOJI_DIR)
if not files:
logger.warning(f"[警告] 表情包目录为空: {EMOJI_DIR}")
- await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
+ await asyncio.sleep(global_config.emoji.check_interval * 60)
continue
# 检查是否需要处理表情包(数量超过最大值或不足)
- if (self.emoji_num > self.emoji_num_max and global_config.max_reach_deletion) or (
+ if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
self.emoji_num < self.emoji_num_max
):
try:
@@ -651,15 +626,16 @@ class EmojiManager:
except Exception as e:
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
- await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
+ await asyncio.sleep(global_config.emoji.check_interval * 60)
- async def get_all_emoji_from_db(self):
+ async def get_all_emoji_from_db(self) -> None:
"""获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects"""
try:
self._ensure_db()
- logger.info("[数据库] 开始加载所有表情包记录...")
+ logger.info("[数据库] 开始加载所有表情包记录 (Peewee)...")
- emoji_objects, load_errors = _to_emoji_objects(db.emoji.find())
+ emoji_peewee_instances = Emoji.select()
+ emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
# 更新内存中的列表和数量
self.emoji_objects = emoji_objects
@@ -674,7 +650,7 @@ class EmojiManager:
self.emoji_objects = [] # 加载失败则清空列表
self.emoji_num = 0
- async def get_emoji_from_db(self, emoji_hash=None):
+ async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]:
"""获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找)
参数:
@@ -686,15 +662,16 @@ class EmojiManager:
try:
self._ensure_db()
- query = {}
if emoji_hash:
- query = {"hash": emoji_hash}
+ query = Emoji.select().where(Emoji.emoji_hash == emoji_hash)
else:
logger.warning(
"[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。"
)
+ query = Emoji.select()
- emoji_objects, load_errors = _to_emoji_objects(db.emoji.find(query))
+ emoji_peewee_instances = query
+ emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
if load_errors > 0:
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
@@ -705,7 +682,7 @@ class EmojiManager:
logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}")
return []
- async def get_emoji_from_manager(self, emoji_hash) -> Optional[MaiEmoji]:
+ async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
"""从内存中的 emoji_objects 列表获取表情包
参数:
@@ -758,7 +735,7 @@ class EmojiManager:
logger.error(traceback.format_exc())
return False
- async def replace_a_emoji(self, new_emoji: MaiEmoji):
+ async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
"""替换一个表情包
Args:
@@ -788,7 +765,7 @@ class EmojiManager:
# 构建提示词
prompt = (
- f"{global_config.BOT_NICKNAME}的表情包存储已满({self.emoji_num}/{self.emoji_num_max}),"
+ f"{global_config.bot.nickname}的表情包存储已满({self.emoji_num}/{self.emoji_num_max}),"
f"需要决定是否删除一个旧表情包来为新表情包腾出空间。\n\n"
f"新表情包信息:\n"
f"描述: {new_emoji.description}\n\n"
@@ -819,7 +796,7 @@ class EmojiManager:
# 删除选定的表情包
logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}")
- delete_success = await self.delete_emoji(emoji_to_delete.hash)
+ delete_success = await self.delete_emoji(emoji_to_delete.emoji_hash)
if delete_success:
# 修复:等待异步注册完成
@@ -847,7 +824,7 @@ class EmojiManager:
logger.error(traceback.format_exc())
return False
- async def build_emoji_description(self, image_base64: str) -> Tuple[str, list]:
+ async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]:
"""获取表情包描述和情感列表
Args:
@@ -871,10 +848,10 @@ class EmojiManager:
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
# 审核表情包
- if global_config.EMOJI_CHECK:
+ if global_config.emoji.content_filtration:
prompt = f'''
这是一个表情包,请对这个表情包进行审核,标准如下:
- 1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求
+ 1. 必须符合"{global_config.emoji.filtration_prompt}"的要求
2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗
3. 不能是任何形式的截图,聊天记录或视频截图
4. 不要出现5个以上文字
diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py
index 37c50c0dc..d3d21e074 100644
--- a/src/chat/focus_chat/expressors/default_expressor.py
+++ b/src/chat/focus_chat/expressors/default_expressor.py
@@ -10,7 +10,6 @@ from src.config.config import global_config
from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.emoji_system.emoji_manager import emoji_manager
-from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder
from src.chat.focus_chat.heartFC_sender import HeartFCSender
from src.chat.utils.utils import process_llm_response
from src.chat.utils.info_catcher import info_catcher_manager
@@ -18,16 +17,69 @@ from src.manager.mood_manager import mood_manager
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
+from src.individuality.individuality import Individuality
+from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
+from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
+import time
+from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
+import random
logger = get_logger("expressor")
+def init_prompt():
+ Prompt(
+ """
+你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
+{style_habbits}
+
+你现在正在群里聊天,以下是群里正在进行的聊天内容:
+{chat_info}
+
+以上是聊天内容,你需要了解聊天记录中的内容
+
+{chat_target}
+你的名字是{bot_name},{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复
+你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。
+请你根据情景使用以下句法:
+{grammar_habbits}
+回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。
+回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 ),只输出一条回复就好。
+现在,你说:
+""",
+ "default_expressor_prompt",
+ )
+
+ Prompt(
+ """
+你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
+{style_habbits}
+
+你现在正在群里聊天,以下是群里正在进行的聊天内容:
+{chat_info}
+
+以上是聊天内容,你需要了解聊天记录中的内容
+
+{chat_target}
+你的名字是{bot_name},{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复
+你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。
+请你根据情景使用以下句法:
+{grammar_habbits}
+回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。
+回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 ),只输出一条回复就好。
+现在,你说:
+""",
+ "default_expressor_private_prompt", # New template for private FOCUSED chat
+ )
+
+
class DefaultExpressor:
def __init__(self, chat_id: str):
self.log_prefix = "expressor"
+ # TODO: API-Adapter修改标记
self.express_model = LLMRequest(
- model=global_config.llm_normal,
- temperature=global_config.llm_normal["temp"],
+ model=global_config.model.normal,
+ temperature=global_config.model.normal["temp"],
max_tokens=256,
request_type="response_heartflow",
)
@@ -51,8 +103,8 @@ class DefaultExpressor:
messageinfo = anchor_message.message_info
thinking_time_point = parse_thinking_id_to_timestamp(thinking_id)
bot_user_info = UserInfo(
- user_id=global_config.BOT_QQ,
- user_nickname=global_config.BOT_NICKNAME,
+ user_id=global_config.bot.qq_account,
+ user_nickname=global_config.bot.nickname,
platform=messageinfo.platform,
)
# logger.debug(f"创建思考消息:{anchor_message}")
@@ -66,7 +118,7 @@ class DefaultExpressor:
reply=anchor_message, # 回复的是锚点消息
thinking_start_time=thinking_time_point,
)
- logger.debug(f"创建思考消息thinking_message:{thinking_message}")
+ # logger.debug(f"创建思考消息thinking_message:{thinking_message}")
await self.heart_fc_sender.register_thinking(thinking_message)
@@ -106,7 +158,7 @@ class DefaultExpressor:
if reply:
with Timer("发送消息", cycle_timers):
- sent_msg_list = await self._send_response_messages(
+ sent_msg_list = await self.send_response_messages(
anchor_message=anchor_message,
thinking_id=thinking_id,
response_set=reply,
@@ -141,7 +193,7 @@ class DefaultExpressor:
try:
# 1. 获取情绪影响因子并调整模型温度
arousal_multiplier = mood_manager.get_arousal_multiplier()
- current_temp = float(global_config.llm_normal["temp"]) * arousal_multiplier
+ current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier
self.express_model.params["temperature"] = current_temp # 动态调整温度
# 2. 获取信息捕捉器
@@ -162,13 +214,10 @@ class DefaultExpressor:
# 3. 构建 Prompt
with Timer("构建Prompt", {}): # 内部计时器,可选保留
- prompt = await prompt_builder.build_prompt(
- build_mode="focus",
+ prompt = await self.build_prompt_focus(
chat_stream=self.chat_stream, # Pass the stream object
in_mind_reply=in_mind_reply,
reason=reason,
- current_mind_info="",
- structured_info="",
sender_name=sender_name_for_prompt, # Pass determined name
target_message=target_message,
)
@@ -183,10 +232,11 @@ class DefaultExpressor:
try:
with Timer("LLM生成", {}): # 内部计时器,可选保留
+ # TODO: API-Adapter修改标记
# logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n")
content, reasoning_content, model_name = await self.express_model.generate_response(prompt)
- logger.info(f"{self.log_prefix}\nPrompt:\n{prompt}\n---------------------------\n")
+ # logger.info(f"{self.log_prefix}\nPrompt:\n{prompt}\n---------------------------\n")
logger.info(f"想要表达:{in_mind_reply}")
logger.info(f"理由:{reason}")
@@ -223,10 +273,108 @@ class DefaultExpressor:
traceback.print_exc()
return None
+ async def build_prompt_focus(
+ self,
+ reason,
+ chat_stream,
+ sender_name,
+ in_mind_reply,
+ target_message,
+ ) -> str:
+ individuality = Individuality.get_instance()
+ prompt_personality = individuality.get_prompt(x_person=0, level=2)
+
+ # Determine if it's a group chat
+ is_group_chat = bool(chat_stream.group_info)
+
+ # Use sender_name passed from caller for private chat, otherwise use a default for group
+ # Default sender_name for group chat isn't used in the group prompt template, but set for consistency
+ effective_sender_name = sender_name if not is_group_chat else "某人"
+
+ message_list_before_now = get_raw_msg_before_timestamp_with_chat(
+ chat_id=chat_stream.stream_id,
+ timestamp=time.time(),
+ limit=global_config.chat.observation_context_size,
+ )
+ chat_talking_prompt = await build_readable_messages(
+ message_list_before_now,
+ replace_bot_name=True,
+ merge_messages=True,
+ timestamp_mode="relative",
+ read_mark=0.0,
+ truncate=True,
+ )
+
+ (
+ learnt_style_expressions,
+ learnt_grammar_expressions,
+ personality_expressions,
+ ) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id)
+
+ style_habbits = []
+ grammar_habbits = []
+ # 1. learnt_expressions加权随机选3条
+ if learnt_style_expressions:
+ weights = [expr["count"] for expr in learnt_style_expressions]
+ selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 3)
+ for expr in selected_learnt:
+ if isinstance(expr, dict) and "situation" in expr and "style" in expr:
+ style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
+ # 2. learnt_grammar_expressions加权随机选3条
+ if learnt_grammar_expressions:
+ weights = [expr["count"] for expr in learnt_grammar_expressions]
+ selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 3)
+ for expr in selected_learnt:
+ if isinstance(expr, dict) and "situation" in expr and "style" in expr:
+ grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
+ # 3. personality_expressions随机选1条
+ if personality_expressions:
+ expr = random.choice(personality_expressions)
+ if isinstance(expr, dict) and "situation" in expr and "style" in expr:
+ style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
+
+ style_habbits_str = "\n".join(style_habbits)
+ grammar_habbits_str = "\n".join(grammar_habbits)
+
+ logger.debug("开始构建 focus prompt")
+
+ # --- Choose template based on chat type ---
+ if is_group_chat:
+ template_name = "default_expressor_prompt"
+ # Group specific formatting variables (already fetched or default)
+ chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
+ # chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
+
+ prompt = await global_prompt_manager.format_prompt(
+ template_name,
+ style_habbits=style_habbits_str,
+ grammar_habbits=grammar_habbits_str,
+ chat_target=chat_target_1,
+ chat_info=chat_talking_prompt,
+ bot_name=global_config.bot.nickname,
+ prompt_personality="",
+ reason=reason,
+ in_mind_reply=in_mind_reply,
+ target_message=target_message,
+ )
+ else: # Private chat
+ template_name = "default_expressor_private_prompt"
+ prompt = await global_prompt_manager.format_prompt(
+ template_name,
+ sender_name=effective_sender_name, # Used in private template
+ chat_talking_prompt=chat_talking_prompt,
+ bot_name=global_config.bot.nickname,
+ prompt_personality=prompt_personality,
+ reason=reason,
+ moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
+ )
+
+ return prompt
+
# --- 发送器 (Sender) --- #
- async def _send_response_messages(
- self, anchor_message: Optional[MessageRecv], response_set: List[Tuple[str, str]], thinking_id: str
+ async def send_response_messages(
+ self, anchor_message: Optional[MessageRecv], response_set: List[Tuple[str, str]], thinking_id: str = ""
) -> Optional[MessageSending]:
"""发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
chat = self.chat_stream
@@ -241,7 +389,11 @@ class DefaultExpressor:
stream_name = chat_manager.get_stream_name(chat_id) or chat_id # 获取流名称用于日志
# 检查思考过程是否仍在进行,并获取开始时间
- thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id)
+ if thinking_id:
+ thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id)
+ else:
+ thinking_id = "ds" + str(round(time.time(), 2))
+ thinking_start_time = time.time()
if thinking_start_time is None:
logger.error(f"[{stream_name}]思考过程未找到或已结束,无法发送回复。")
@@ -274,6 +426,7 @@ class DefaultExpressor:
reply_to=reply_to,
is_emoji=is_emoji,
thinking_id=thinking_id,
+ thinking_start_time=thinking_start_time,
)
try:
@@ -295,6 +448,7 @@ class DefaultExpressor:
except Exception as e:
logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}")
+ traceback.print_exc()
# 这里可以选择是继续发送下一个片段还是中止
# 在尝试发送完所有片段后,完成原始的 thinking_id 状态
@@ -325,13 +479,13 @@ class DefaultExpressor:
reply_to: bool,
is_emoji: bool,
thinking_id: str,
+ thinking_start_time: float,
) -> MessageSending:
"""构建单个发送消息"""
- thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(self.chat_id, thinking_id)
bot_user_info = UserInfo(
- user_id=global_config.BOT_QQ,
- user_nickname=global_config.BOT_NICKNAME,
+ user_id=global_config.bot.qq_account,
+ user_nickname=global_config.bot.nickname,
platform=self.chat_stream.platform,
)
@@ -348,3 +502,40 @@ class DefaultExpressor:
)
return bot_message
+
+
+def weighted_sample_no_replacement(items, weights, k) -> list:
+ """
+ 加权且不放回地随机抽取k个元素。
+
+ 参数:
+ items: 待抽取的元素列表
+ weights: 每个元素对应的权重(与items等长,且为正数)
+ k: 需要抽取的元素个数
+ 返回:
+ selected: 按权重加权且不重复抽取的k个元素组成的列表
+
+ 如果 items 中的元素不足 k 个,就只会返回所有可用的元素
+
+ 实现思路:
+ 每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。
+ 这样保证了:
+ 1. count越大被选中概率越高
+ 2. 不会重复选中同一个元素
+ """
+ selected = []
+ pool = list(zip(items, weights))
+ for _ in range(min(k, len(pool))):
+ total = sum(w for _, w in pool)
+ r = random.uniform(0, total)
+ upto = 0
+ for idx, (item, weight) in enumerate(pool):
+ upto += weight
+ if upto >= r:
+ selected.append(item)
+ pool.pop(idx)
+ break
+ return selected
+
+
+init_prompt()
diff --git a/src/chat/focus_chat/expressors/exprssion_learner.py b/src/chat/focus_chat/expressors/exprssion_learner.py
index 942162bc8..7766fde56 100644
--- a/src/chat/focus_chat/expressors/exprssion_learner.py
+++ b/src/chat/focus_chat/expressors/exprssion_learner.py
@@ -77,8 +77,9 @@ def init_prompt() -> None:
class ExpressionLearner:
def __init__(self) -> None:
+ # TODO: API-Adapter修改标记
self.express_learn_model: LLMRequest = LLMRequest(
- model=global_config.llm_normal,
+ model=global_config.model.normal,
temperature=0.1,
max_tokens=256,
request_type="response_heartflow",
@@ -289,7 +290,7 @@ class ExpressionLearner:
# 构建prompt
prompt = await global_prompt_manager.format_prompt(
"personality_expression_prompt",
- personality=global_config.expression_style,
+ personality=global_config.personality.expression_style,
)
# logger.info(f"个性表达方式提取prompt: {prompt}")
diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py
index 4a28652d1..4f17f9bdf 100644
--- a/src/chat/focus_chat/heartFC_chat.py
+++ b/src/chat/focus_chat/heartFC_chat.py
@@ -14,15 +14,17 @@ from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.info_processors.chattinginfo_processor import ChattingInfoProcessor
from src.chat.focus_chat.info_processors.mind_processor import MindProcessor
-from src.chat.heart_flow.observation.memory_observation import MemoryObservation
+from src.chat.focus_chat.info_processors.working_memory_processor import WorkingMemoryProcessor
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
-from src.chat.heart_flow.observation.working_observation import WorkingObservation
+from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
from src.chat.focus_chat.memory_activator import MemoryActivator
from src.chat.focus_chat.info_processors.base_processor import BaseProcessor
+from src.chat.focus_chat.info_processors.self_processor import SelfProcessor
from src.chat.focus_chat.planners.planner import ActionPlanner
-from src.chat.focus_chat.planners.action_factory import ActionManager
+from src.chat.focus_chat.planners.action_manager import ActionManager
+from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
install(extra_lines=3)
@@ -57,7 +59,7 @@ async def _handle_cycle_delay(action_taken_this_cycle: bool, cycle_start_time: f
class HeartFChatting:
"""
- 管理一个连续的Plan-Replier-Sender循环
+ 管理一个连续的Focus Chat循环
用于在特定聊天流中生成回复。
其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。
"""
@@ -66,7 +68,6 @@ class HeartFChatting:
self,
chat_id: str,
observations: list[Observation],
- on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]],
):
"""
HeartFChatting 初始化函数
@@ -74,24 +75,27 @@ class HeartFChatting:
参数:
chat_id: 聊天流唯一标识符(如stream_id)
observations: 关联的观察列表
- on_consecutive_no_reply_callback: 连续不回复达到阈值时调用的异步回调函数
"""
# 基础属性
self.stream_id: str = chat_id # 聊天流ID
self.chat_stream: Optional[ChatStream] = None # 关联的聊天流
- self.observations: List[Observation] = observations # 关联的观察列表,用于监控聊天流状态
- self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback
self.log_prefix: str = str(chat_id) # Initial default, will be updated
-
- self.memory_observation = MemoryObservation(observe_id=self.stream_id)
self.hfcloop_observation = HFCloopObservation(observe_id=self.stream_id)
- self.working_observation = WorkingObservation(observe_id=self.stream_id)
+ self.chatting_observation = observations[0]
+
self.memory_activator = MemoryActivator()
+ self.working_memory = WorkingMemory(chat_id=self.stream_id)
+ self.working_observation = WorkingMemoryObservation(
+ observe_id=self.stream_id, working_memory=self.working_memory
+ )
+
self.expressor = DefaultExpressor(chat_id=self.stream_id)
self.action_manager = ActionManager()
self.action_planner = ActionPlanner(log_prefix=self.log_prefix, action_manager=self.action_manager)
+ self.hfcloop_observation.set_action_manager(self.action_manager)
+ self.all_observations = observations
# --- 处理器列表 ---
self.processors: List[BaseProcessor] = []
self._register_default_processors()
@@ -108,9 +112,7 @@ class HeartFChatting:
self._cycle_counter = 0
self._cycle_history: Deque[CycleDetail] = deque(maxlen=10) # 保留最近10个循环的信息
self._current_cycle: Optional[CycleDetail] = None
- self.total_no_reply_count: int = 0 # 连续不回复计数器
self._shutting_down: bool = False # 关闭标志位
- self.total_waiting_time: float = 0.0 # 累计等待时间
async def _initialize(self) -> bool:
"""
@@ -151,6 +153,8 @@ class HeartFChatting:
self.processors.append(ChattingInfoProcessor())
self.processors.append(MindProcessor(subheartflow_id=self.stream_id))
self.processors.append(ToolProcessor(subheartflow_id=self.stream_id))
+ self.processors.append(WorkingMemoryProcessor(subheartflow_id=self.stream_id))
+ self.processors.append(SelfProcessor(subheartflow_id=self.stream_id))
logger.info(f"{self.log_prefix} 已注册默认处理器: {[p.__class__.__name__ for p in self.processors]}")
async def start(self):
@@ -158,7 +162,7 @@ class HeartFChatting:
启动 HeartFChatting 的主循环。
注意:调用此方法前必须确保已经成功初始化。
"""
- logger.info(f"{self.log_prefix} 开始认真水群(HFC)...")
+ logger.info(f"{self.log_prefix} 开始认真聊天(HFC)...")
await self._start_loop_if_needed()
async def _start_loop_if_needed(self):
@@ -328,6 +332,7 @@ class HeartFChatting:
f"{self.log_prefix} 处理器 {processor_name} 执行失败,耗时 (自并行开始): {duration_since_parallel_start:.2f}秒. 错误: {e}",
exc_info=True,
)
+ traceback.print_exc()
# 即使出错,也认为该任务结束了,已从 pending_tasks 中移除
if pending_tasks:
@@ -349,13 +354,12 @@ class HeartFChatting:
async def _observe_process_plan_action_loop(self, cycle_timers: dict, thinking_id: str) -> tuple[bool, str]:
try:
with Timer("观察", cycle_timers):
- await self.observations[0].observe()
- await self.memory_observation.observe()
+ # await self.observations[0].observe()
+ await self.chatting_observation.observe()
await self.working_observation.observe()
await self.hfcloop_observation.observe()
observations: List[Observation] = []
- observations.append(self.observations[0])
- observations.append(self.memory_observation)
+ observations.append(self.chatting_observation)
observations.append(self.working_observation)
observations.append(self.hfcloop_observation)
@@ -363,6 +367,8 @@ class HeartFChatting:
"observations": observations,
}
+ self.all_observations = observations
+
with Timer("回忆", cycle_timers):
running_memorys = await self.memory_activator.activate_memory(observations)
@@ -395,8 +401,7 @@ class HeartFChatting:
elif action_type == "no_reply":
action_str = "不回复"
else:
- action_type = "unknown"
- action_str = "未知动作"
+ action_str = action_type
logger.info(f"{self.log_prefix} 麦麦决定'{action_str}', 原因'{reasoning}'")
@@ -452,14 +457,10 @@ class HeartFChatting:
reasoning=reasoning,
cycle_timers=cycle_timers,
thinking_id=thinking_id,
- observations=self.observations,
+ observations=self.all_observations,
expressor=self.expressor,
chat_stream=self.chat_stream,
- current_cycle=self._current_cycle,
log_prefix=self.log_prefix,
- on_consecutive_no_reply_callback=self.on_consecutive_no_reply_callback,
- total_no_reply_count=self.total_no_reply_count,
- total_waiting_time=self.total_waiting_time,
shutting_down=self._shutting_down,
)
@@ -470,14 +471,6 @@ class HeartFChatting:
# 处理动作并获取结果
success, reply_text = await action_handler.handle_action()
- # 更新状态计数器
- if action == "no_reply":
- self.total_no_reply_count = getattr(action_handler, "total_no_reply_count", self.total_no_reply_count)
- self.total_waiting_time = getattr(action_handler, "total_waiting_time", self.total_waiting_time)
- elif action == "reply":
- self.total_no_reply_count = 0
- self.total_waiting_time = 0.0
-
return success, reply_text
except Exception as e:
@@ -526,5 +519,3 @@ class HeartFChatting:
if last_n is not None:
history = history[-last_n:]
return [cycle.to_dict() for cycle in history]
-
-
diff --git a/src/chat/focus_chat/heartFC_sender.py b/src/chat/focus_chat/heartFC_sender.py
index 057668579..81d463b02 100644
--- a/src/chat/focus_chat/heartFC_sender.py
+++ b/src/chat/focus_chat/heartFC_sender.py
@@ -106,6 +106,7 @@ class HeartFCSender:
and not message.is_private_message()
and message.reply.processed_plain_text != "[System Trigger Context]"
):
+ message.set_reply(message.reply)
logger.debug(f"[{chat_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}...")
await message.process()
diff --git a/src/chat/focus_chat/heartflow_processor.py b/src/chat/focus_chat/heartflow_processor.py
index bbfa4ce46..a4cf360a5 100644
--- a/src/chat/focus_chat/heartflow_processor.py
+++ b/src/chat/focus_chat/heartflow_processor.py
@@ -112,7 +112,7 @@ def _check_ban_words(text: str, chat, userinfo) -> bool:
Returns:
bool: 是否包含过滤词
"""
- for word in global_config.ban_words:
+ for word in global_config.chat.ban_words:
if word in text:
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
@@ -132,7 +132,7 @@ def _check_ban_regex(text: str, chat, userinfo) -> bool:
Returns:
bool: 是否匹配过滤正则
"""
- for pattern in global_config.ban_msgs_regex:
+ for pattern in global_config.chat.ban_msgs_regex:
if pattern.search(text):
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py
index 55fb79b46..532ceccd1 100644
--- a/src/chat/focus_chat/heartflow_prompt_builder.py
+++ b/src/chat/focus_chat/heartflow_prompt_builder.py
@@ -6,43 +6,21 @@ from src.chat.utils.chat_message_builder import build_readable_messages, get_raw
from src.chat.person_info.relationship_manager import relationship_manager
from src.chat.utils.utils import get_embedding
import time
-from typing import Union, Optional, Dict, Any
-from src.common.database import db
+from typing import Union, Optional
from src.chat.utils.utils import get_recent_group_speaker
from src.manager.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import HippocampusManager
from src.chat.knowledge.knowledge_lib import qa_manager
-from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
-import traceback
import random
+import json
+import math
+from src.common.database.database_model import Knowledges
logger = get_logger("prompt")
def init_prompt():
- Prompt(
- """
-你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
-{style_habbits}
-
-你现在正在群里聊天,以下是群里正在进行的聊天内容:
-{chat_info}
-
-以上是聊天内容,你需要了解聊天记录中的内容
-
-{chat_target}
-你的名字是{bot_name},{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复
-你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。
-请你根据情景使用以下句法:
-{grammar_habbits}
-回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。
-回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 ),只输出一条回复就好。
-现在,你说:
-""",
- "heart_flow_prompt",
- )
-
Prompt(
"""
你有以下信息可供参考:
@@ -69,7 +47,7 @@ def init_prompt():
你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},{reply_style1},
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}。{prompt_ger}
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。
-请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
+请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容。
{moderation_prompt}
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容""",
"reasoning_prompt_main",
@@ -82,29 +60,6 @@ def init_prompt():
Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
- # --- Template for HeartFChatting (FOCUSED mode) ---
- Prompt(
- """
-{info_from_tools}
-你正在和 {sender_name} 私聊。
-聊天记录如下:
-{chat_talking_prompt}
-现在你想要回复。
-
-你需要扮演一位网名叫{bot_name}的人进行回复,这个人的特点是:"{prompt_personality}"。
-你正在和 {sender_name} 私聊, 现在请你读读你们之前的聊天记录,然后给出日常且口语化的回复,平淡一些。
-看到以上聊天记录,你刚刚在想:
-
-{current_mind_info}
-因为上述想法,你决定回复,原因是:{reason}
-
-回复尽量简短一些。请注意把握聊天内容,{reply_style2}。{prompt_ger},不要复读自己说的话
-{reply_style1},说中文,不要刻意突出自身学科背景,注意只输出回复内容。
-{moderation_prompt}。注意:回复不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
- "heart_flow_private_prompt", # New template for private FOCUSED chat
- )
-
- # --- Template for NormalChat (CHAT mode) ---
Prompt(
"""
{memory_prompt}
@@ -126,118 +81,6 @@ def init_prompt():
)
-async def _build_prompt_focus(
- reason, current_mind_info, structured_info, chat_stream, sender_name, in_mind_reply, target_message
-) -> str:
- individuality = Individuality.get_instance()
- prompt_personality = individuality.get_prompt(x_person=0, level=2)
-
- # Determine if it's a group chat
- is_group_chat = bool(chat_stream.group_info)
-
- # Use sender_name passed from caller for private chat, otherwise use a default for group
- # Default sender_name for group chat isn't used in the group prompt template, but set for consistency
- effective_sender_name = sender_name if not is_group_chat else "某人"
-
- message_list_before_now = get_raw_msg_before_timestamp_with_chat(
- chat_id=chat_stream.stream_id,
- timestamp=time.time(),
- limit=global_config.observation_context_size,
- )
- chat_talking_prompt = await build_readable_messages(
- message_list_before_now,
- replace_bot_name=True,
- merge_messages=True,
- timestamp_mode="relative",
- read_mark=0.0,
- truncate=True,
- )
-
- if structured_info:
- structured_info_prompt = await global_prompt_manager.format_prompt(
- "info_from_tools", structured_info=structured_info
- )
- else:
- structured_info_prompt = ""
-
- # 从/data/expression/对应chat_id/expressions.json中读取表达方式
- (
- learnt_style_expressions,
- learnt_grammar_expressions,
- personality_expressions,
- ) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id)
-
- style_habbits = []
- grammar_habbits = []
- # 1. learnt_expressions加权随机选3条
- if learnt_style_expressions:
- weights = [expr["count"] for expr in learnt_style_expressions]
- selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 3)
- for expr in selected_learnt:
- if isinstance(expr, dict) and "situation" in expr and "style" in expr:
- style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
- # 2. learnt_grammar_expressions加权随机选3条
- if learnt_grammar_expressions:
- weights = [expr["count"] for expr in learnt_grammar_expressions]
- selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 3)
- for expr in selected_learnt:
- if isinstance(expr, dict) and "situation" in expr and "style" in expr:
- grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
- # 3. personality_expressions随机选1条
- if personality_expressions:
- expr = random.choice(personality_expressions)
- if isinstance(expr, dict) and "situation" in expr and "style" in expr:
- style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
-
- style_habbits_str = "\n".join(style_habbits)
- grammar_habbits_str = "\n".join(grammar_habbits)
-
- logger.debug("开始构建 focus prompt")
-
- # --- Choose template based on chat type ---
- if is_group_chat:
- template_name = "heart_flow_prompt"
- # Group specific formatting variables (already fetched or default)
- chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
- # chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
-
- prompt = await global_prompt_manager.format_prompt(
- template_name,
- # info_from_tools=structured_info_prompt,
- style_habbits=style_habbits_str,
- grammar_habbits=grammar_habbits_str,
- chat_target=chat_target_1, # Used in group template
- # chat_talking_prompt=chat_talking_prompt,
- chat_info=chat_talking_prompt,
- bot_name=global_config.BOT_NICKNAME,
- # prompt_personality=prompt_personality,
- prompt_personality="",
- reason=reason,
- in_mind_reply=in_mind_reply,
- target_message=target_message,
- # moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
- # sender_name is not used in the group template
- )
- else: # Private chat
- template_name = "heart_flow_private_prompt"
- prompt = await global_prompt_manager.format_prompt(
- template_name,
- info_from_tools=structured_info_prompt,
- sender_name=effective_sender_name, # Used in private template
- chat_talking_prompt=chat_talking_prompt,
- bot_name=global_config.BOT_NICKNAME,
- prompt_personality=prompt_personality,
- # chat_target and chat_target_2 are not used in private template
- current_mind_info=current_mind_info,
- reason=reason,
- moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
- )
- # --- End choosing template ---
-
- # logger.debug(f"focus_chat_prompt (is_group={is_group_chat}): \n{prompt}")
- return prompt
-
-
class PromptBuilder:
def __init__(self):
self.prompt_built = ""
@@ -257,17 +100,6 @@ class PromptBuilder:
) -> Optional[str]:
if build_mode == "normal":
return await self._build_prompt_normal(chat_stream, message_txt or "", sender_name)
-
- elif build_mode == "focus":
- return await _build_prompt_focus(
- reason,
- current_mind_info,
- structured_info,
- chat_stream,
- sender_name,
- in_mind_reply,
- target_message,
- )
return None
async def _build_prompt_normal(self, chat_stream, message_txt: str, sender_name: str = "某人") -> str:
@@ -280,7 +112,7 @@ class PromptBuilder:
who_chat_in_group = get_recent_group_speaker(
chat_stream.stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
- limit=global_config.observation_context_size,
+ limit=global_config.chat.observation_context_size,
)
elif chat_stream.user_info:
who_chat_in_group.append(
@@ -328,7 +160,7 @@ class PromptBuilder:
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
- limit=global_config.observation_context_size,
+ limit=global_config.chat.observation_context_size,
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
@@ -340,18 +172,15 @@ class PromptBuilder:
# 关键词检测与反应
keywords_reaction_prompt = ""
- for rule in global_config.keywords_reaction_rules:
- if rule.get("enable", False):
- if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
- logger.info(
- f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}"
- )
- keywords_reaction_prompt += rule.get("reaction", "") + ","
+ for rule in global_config.keyword_reaction.rules:
+ if rule.enable:
+ if any(keyword in message_txt for keyword in rule.keywords):
+ logger.info(f"检测到以下关键词之一:{rule.keywords},触发反应:{rule.reaction}")
+ keywords_reaction_prompt += f"{rule.reaction},"
else:
- for pattern in rule.get("regex", []):
- result = pattern.search(message_txt)
- if result:
- reaction = rule.get("reaction", "")
+ for pattern in rule.regex:
+ if result := pattern.search(message_txt):
+ reaction = rule.reaction
for name, content in result.groupdict().items():
reaction = reaction.replace(f"[{name}]", content)
logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}")
@@ -397,15 +226,16 @@ class PromptBuilder:
chat_target_2=chat_target_2,
chat_talking_prompt=chat_talking_prompt,
message_txt=message_txt,
- bot_name=global_config.BOT_NICKNAME,
- bot_other_names="/".join(global_config.BOT_ALIAS_NAMES),
+ bot_name=global_config.bot.nickname,
+ bot_other_names="/".join(global_config.bot.alias_names),
prompt_personality=prompt_personality,
mood_prompt=mood_prompt,
reply_style1=reply_style1_chosen,
reply_style2=reply_style2_chosen,
keywords_reaction_prompt=keywords_reaction_prompt,
prompt_ger=prompt_ger,
- moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
+ # moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
+ moderation_prompt="",
)
else:
template_name = "reasoning_prompt_private_main"
@@ -419,15 +249,16 @@ class PromptBuilder:
prompt_info=prompt_info,
chat_talking_prompt=chat_talking_prompt,
message_txt=message_txt,
- bot_name=global_config.BOT_NICKNAME,
- bot_other_names="/".join(global_config.BOT_ALIAS_NAMES),
+ bot_name=global_config.bot.nickname,
+ bot_other_names="/".join(global_config.bot.alias_names),
prompt_personality=prompt_personality,
mood_prompt=mood_prompt,
reply_style1=reply_style1_chosen,
reply_style2=reply_style2_chosen,
keywords_reaction_prompt=keywords_reaction_prompt,
prompt_ger=prompt_ger,
- moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
+ # moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
+ moderation_prompt="",
)
# --- End choosing template ---
@@ -439,30 +270,6 @@ class PromptBuilder:
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. 先从LLM获取主题,类似于记忆系统的做法
topics = []
- # try:
- # # 先尝试使用记忆系统的方法获取主题
- # hippocampus = HippocampusManager.get_instance()._hippocampus
- # topic_num = min(5, max(1, int(len(message) * 0.1)))
- # topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
-
- # # 提取关键词
- # topics = re.findall(r"<([^>]+)>", topics_response[0])
- # if not topics:
- # topics = []
- # else:
- # topics = [
- # topic.strip()
- # for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
- # if topic.strip()
- # ]
-
- # logger.info(f"从LLM提取的主题: {', '.join(topics)}")
- # except Exception as e:
- # logger.error(f"从LLM提取主题失败: {str(e)}")
- # # 如果LLM提取失败,使用jieba分词提取关键词作为备选
- # words = jieba.cut(message)
- # topics = [word for word in words if len(word) > 1][:5]
- # logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# 如果无法提取到主题,直接使用整个消息
if not topics:
@@ -572,8 +379,6 @@ class PromptBuilder:
for _i, result in enumerate(results, 1):
_similarity = result["similarity"]
content = result["content"].strip()
- # 调试:为内容添加序号和相似度信息
- # related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n"
related_info += "\n"
@@ -602,14 +407,14 @@ class PromptBuilder:
return related_info
else:
logger.debug("从LPMM知识库获取知识失败,使用旧版数据库进行检索")
- knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
+ knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
try:
- knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
+ knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old
logger.debug(
f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}"
@@ -625,104 +430,70 @@ class PromptBuilder:
) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []
- # 使用余弦相似度计算
- pipeline = [
- {
- "$addFields": {
- "dotProduct": {
- "$reduce": {
- "input": {"$range": [0, {"$size": "$embedding"}]},
- "initialValue": 0,
- "in": {
- "$add": [
- "$$value",
- {
- "$multiply": [
- {"$arrayElemAt": ["$embedding", "$$this"]},
- {"$arrayElemAt": [query_embedding, "$$this"]},
- ]
- },
- ]
- },
- }
- },
- "magnitude1": {
- "$sqrt": {
- "$reduce": {
- "input": "$embedding",
- "initialValue": 0,
- "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
- }
- }
- },
- "magnitude2": {
- "$sqrt": {
- "$reduce": {
- "input": query_embedding,
- "initialValue": 0,
- "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
- }
- }
- },
- }
- },
- {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
- {
- "$match": {
- "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
- }
- },
- {"$sort": {"similarity": -1}},
- {"$limit": limit},
- {"$project": {"content": 1, "similarity": 1}},
- ]
- results = list(db.knowledges.aggregate(pipeline))
- logger.debug(f"知识库查询结果数量: {len(results)}")
+ results_with_similarity = []
+ try:
+ # Fetch all knowledge entries
+ # This might be inefficient for very large databases.
+ # Consider strategies like FAISS or other vector search libraries if performance becomes an issue.
+ all_knowledges = Knowledges.select()
- if not results:
+ if not all_knowledges:
+ return [] if return_raw else ""
+
+ query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding))
+ if query_embedding_magnitude == 0: # Avoid division by zero
+ return "" if not return_raw else []
+
+ for knowledge_item in all_knowledges:
+ try:
+ db_embedding_str = knowledge_item.embedding
+ db_embedding = json.loads(db_embedding_str)
+
+ if len(db_embedding) != len(query_embedding):
+ logger.warning(
+ f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping."
+ )
+ continue
+
+ # Calculate Cosine Similarity
+ dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding))
+ db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding))
+
+ if db_embedding_magnitude == 0: # Avoid division by zero
+ similarity = 0.0
+ else:
+ similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude)
+
+ if similarity >= threshold:
+ results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity})
+ except json.JSONDecodeError:
+ logger.error(
+ f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}"
+ )
+ except Exception as e:
+ logger.error(f"Error processing knowledge item: {e}")
+
+ # Sort by similarity in descending order
+ results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
+
+ # Limit results
+ limited_results = results_with_similarity[:limit]
+
+ logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}")
+
+ if not limited_results:
+ return "" if not return_raw else []
+
+ if return_raw:
+ return limited_results
+ else:
+ return "\n".join(str(result["content"]) for result in limited_results)
+
+ except Exception as e:
+ logger.error(f"Error querying Knowledges with Peewee: {e}")
return "" if not return_raw else []
- if return_raw:
- return results
- else:
- # 返回所有找到的内容,用换行分隔
- return "\n".join(str(result["content"]) for result in results)
-
-
-def weighted_sample_no_replacement(items, weights, k) -> list:
- """
- 加权且不放回地随机抽取k个元素。
-
- 参数:
- items: 待抽取的元素列表
- weights: 每个元素对应的权重(与items等长,且为正数)
- k: 需要抽取的元素个数
- 返回:
- selected: 按权重加权且不重复抽取的k个元素组成的列表
-
- 如果 items 中的元素不足 k 个,就只会返回所有可用的元素
-
- 实现思路:
- 每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。
- 这样保证了:
- 1. count越大被选中概率越高
- 2. 不会重复选中同一个元素
- """
- selected = []
- pool = list(zip(items, weights))
- for _ in range(min(k, len(pool))):
- total = sum(w for _, w in pool)
- r = random.uniform(0, total)
- upto = 0
- for idx, (item, weight) in enumerate(pool):
- upto += weight
- if upto >= r:
- selected.append(item)
- pool.pop(idx)
- break
- return selected
-
init_prompt()
prompt_builder = PromptBuilder()
diff --git a/src/chat/focus_chat/info/action_info.py b/src/chat/focus_chat/info/action_info.py
new file mode 100644
index 000000000..1bb6b96a6
--- /dev/null
+++ b/src/chat/focus_chat/info/action_info.py
@@ -0,0 +1,83 @@
+from typing import Dict, Optional, Any, List
+from dataclasses import dataclass
+from .info_base import InfoBase
+
+
+@dataclass
+class ActionInfo(InfoBase):
+ """动作信息类
+
+ 用于管理和记录动作的变更信息,包括需要添加或移除的动作。
+ 继承自 InfoBase 类,使用字典存储具体数据。
+
+ Attributes:
+ type (str): 信息类型标识符,固定为 "action"
+
+ Data Fields:
+ add_actions (List[str]): 需要添加的动作列表
+ remove_actions (List[str]): 需要移除的动作列表
+ reason (str): 变更原因说明
+ """
+
+ type: str = "action"
+
+ def get_type(self) -> str:
+ """获取信息类型"""
+ return self.type
+
+ def get_data(self) -> Dict[str, Any]:
+ """获取信息数据"""
+ return self.data
+
+ def set_action_changes(self, action_changes: Dict[str, List[str]]) -> None:
+ """设置动作变更信息
+
+ Args:
+ action_changes (Dict[str, List[str]]): 包含要增加和删除的动作列表
+ {
+ "add": ["action1", "action2"],
+ "remove": ["action3"]
+ }
+ """
+ self.data["add_actions"] = action_changes.get("add", [])
+ self.data["remove_actions"] = action_changes.get("remove", [])
+
+ def set_reason(self, reason: str) -> None:
+ """设置变更原因
+
+ Args:
+ reason (str): 动作变更的原因说明
+ """
+ self.data["reason"] = reason
+
+ def get_add_actions(self) -> List[str]:
+ """获取需要添加的动作列表
+
+ Returns:
+ List[str]: 需要添加的动作列表
+ """
+ return self.data.get("add_actions", [])
+
+ def get_remove_actions(self) -> List[str]:
+ """获取需要移除的动作列表
+
+ Returns:
+ List[str]: 需要移除的动作列表
+ """
+ return self.data.get("remove_actions", [])
+
+ def get_reason(self) -> Optional[str]:
+ """获取变更原因
+
+ Returns:
+ Optional[str]: 动作变更的原因说明,如果未设置则返回 None
+ """
+ return self.data.get("reason")
+
+ def has_changes(self) -> bool:
+ """检查是否有动作变更
+
+ Returns:
+ bool: 如果有任何动作需要添加或移除则返回True
+ """
+ return bool(self.get_add_actions() or self.get_remove_actions())
\ No newline at end of file
diff --git a/src/chat/focus_chat/info/info_base.py b/src/chat/focus_chat/info/info_base.py
index 7779d913a..53ad30230 100644
--- a/src/chat/focus_chat/info/info_base.py
+++ b/src/chat/focus_chat/info/info_base.py
@@ -17,6 +17,7 @@ class InfoBase:
type: str = "base"
data: Dict[str, Any] = field(default_factory=dict)
+ processed_info: str = ""
def get_type(self) -> str:
"""获取信息类型
@@ -58,3 +59,11 @@ class InfoBase:
if isinstance(value, list):
return value
return []
+
+ def get_processed_info(self) -> str:
+ """获取处理后的信息
+
+ Returns:
+ str: 处理后的信息字符串
+ """
+ return self.processed_info
diff --git a/src/chat/focus_chat/info/self_info.py b/src/chat/focus_chat/info/self_info.py
new file mode 100644
index 000000000..866457956
--- /dev/null
+++ b/src/chat/focus_chat/info/self_info.py
@@ -0,0 +1,40 @@
+from dataclasses import dataclass
+from .info_base import InfoBase
+
+
+@dataclass
+class SelfInfo(InfoBase):
+ """思维信息类
+
+ 用于存储和管理当前思维状态的信息。
+
+ Attributes:
+ type (str): 信息类型标识符,默认为 "mind"
+ data (Dict[str, Any]): 包含 current_mind 的数据字典
+ """
+
+ type: str = "self"
+
+ def get_self_info(self) -> str:
+ """获取当前思维状态
+
+ Returns:
+ str: 当前思维状态
+ """
+ return self.get_info("self_info") or ""
+
+ def set_self_info(self, self_info: str) -> None:
+ """设置当前思维状态
+
+ Args:
+ self_info: 要设置的思维状态
+ """
+ self.data["self_info"] = self_info
+
+ def get_processed_info(self) -> str:
+ """获取处理后的信息
+
+ Returns:
+ str: 处理后的信息
+ """
+ return self.get_self_info()
diff --git a/src/chat/focus_chat/info/workingmemory_info.py b/src/chat/focus_chat/info/workingmemory_info.py
new file mode 100644
index 000000000..0edce8944
--- /dev/null
+++ b/src/chat/focus_chat/info/workingmemory_info.py
@@ -0,0 +1,89 @@
+from typing import Dict, Optional, List
+from dataclasses import dataclass
+from .info_base import InfoBase
+
+
+@dataclass
+class WorkingMemoryInfo(InfoBase):
+ type: str = "workingmemory"
+
+ processed_info: str = ""
+
+ def set_talking_message(self, message: str) -> None:
+ """设置说话消息
+
+ Args:
+ message (str): 说话消息内容
+ """
+ self.data["talking_message"] = message
+
+ def set_working_memory(self, working_memory: List[str]) -> None:
+ """设置工作记忆
+
+ Args:
+ working_memory (str): 工作记忆内容
+ """
+ self.data["working_memory"] = working_memory
+
+ def add_working_memory(self, working_memory: str) -> None:
+ """添加工作记忆
+
+ Args:
+ working_memory (str): 工作记忆内容
+ """
+ working_memory_list = self.data.get("working_memory", [])
+ # print(f"working_memory_list: {working_memory_list}")
+ working_memory_list.append(working_memory)
+ # print(f"working_memory_list: {working_memory_list}")
+ self.data["working_memory"] = working_memory_list
+
+ def get_working_memory(self) -> List[str]:
+ """获取工作记忆
+
+ Returns:
+ List[str]: 工作记忆内容
+ """
+ return self.data.get("working_memory", [])
+
+ def get_type(self) -> str:
+ """获取信息类型
+
+ Returns:
+ str: 当前信息对象的类型标识符
+ """
+ return self.type
+
+ def get_data(self) -> Dict[str, str]:
+ """获取所有信息数据
+
+ Returns:
+ Dict[str, str]: 包含所有信息数据的字典
+ """
+ return self.data
+
+ def get_info(self, key: str) -> Optional[str]:
+ """获取特定属性的信息
+
+ Args:
+ key: 要获取的属性键名
+
+ Returns:
+ Optional[str]: 属性值,如果键不存在则返回 None
+ """
+ return self.data.get(key)
+
+ def get_processed_info(self) -> Dict[str, str]:
+ """获取处理后的信息
+
+ Returns:
+ Dict[str, str]: 处理后的信息数据
+ """
+ all_memory = self.get_working_memory()
+ # print(f"all_memory: {all_memory}")
+ memory_str = ""
+ for memory in all_memory:
+ memory_str += f"{memory}\n"
+
+ self.processed_info = memory_str
+
+ return self.processed_info
diff --git a/src/chat/focus_chat/info_processors/action_processor.py b/src/chat/focus_chat/info_processors/action_processor.py
new file mode 100644
index 000000000..a952b38c8
--- /dev/null
+++ b/src/chat/focus_chat/info_processors/action_processor.py
@@ -0,0 +1,126 @@
+from typing import List, Optional, Any
+from src.chat.focus_chat.info.obs_info import ObsInfo
+from src.chat.heart_flow.observation.observation import Observation
+from src.chat.focus_chat.info.info_base import InfoBase
+from src.chat.focus_chat.info.action_info import ActionInfo
+from .base_processor import BaseProcessor
+from src.common.logger_manager import get_logger
+from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
+from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
+from src.chat.focus_chat.info.cycle_info import CycleInfo
+from datetime import datetime
+from typing import Dict
+from src.chat.models.utils_model import LLMRequest
+from src.config.config import global_config
+import random
+
+logger = get_logger("processor")
+
+
+class ActionProcessor(BaseProcessor):
+ """动作处理器
+
+ 用于处理Observation对象,将其转换为ObsInfo对象。
+ """
+
+ log_prefix = "聊天信息处理"
+
+ def __init__(self):
+ """初始化观察处理器"""
+ super().__init__()
+ # TODO: API-Adapter修改标记
+ self.model_summary = LLMRequest(
+ model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
+ )
+
+ async def process_info(
+ self,
+ observations: Optional[List[Observation]] = None,
+ running_memorys: Optional[List[Dict]] = None,
+ **kwargs: Any,
+ ) -> List[InfoBase]:
+ """处理Observation对象
+
+ Args:
+ infos: InfoBase对象列表
+ observations: 可选的Observation对象列表
+ **kwargs: 其他可选参数
+
+ Returns:
+ List[InfoBase]: 处理后的ObsInfo实例列表
+ """
+ # print(f"observations: {observations}")
+ processed_infos = []
+
+ # 处理Observation对象
+ if observations:
+ for obs in observations:
+
+ if isinstance(obs, HFCloopObservation):
+
+
+ # 创建动作信息
+ action_info = ActionInfo()
+ action_changes = await self.analyze_loop_actions(obs)
+ if action_changes["add"] or action_changes["remove"]:
+ action_info.set_action_changes(action_changes)
+ # 设置变更原因
+ reasons = []
+ if action_changes["add"]:
+ reasons.append(f"添加动作{action_changes['add']}因为检测到大量无回复")
+ if action_changes["remove"]:
+ reasons.append(f"移除动作{action_changes['remove']}因为检测到连续回复")
+ action_info.set_reason(" | ".join(reasons))
+ processed_infos.append(action_info)
+
+ return processed_infos
+
+
+ async def analyze_loop_actions(self, obs: HFCloopObservation) -> Dict[str, List[str]]:
+ """分析最近的循环内容并决定动作的增减
+
+ Returns:
+ Dict[str, List[str]]: 包含要增加和删除的动作
+ {
+ "add": ["action1", "action2"],
+ "remove": ["action3"]
+ }
+ """
+ result = {"add": [], "remove": []}
+
+ # 获取最近10次循环
+ recent_cycles = obs.history_loop[-10:] if len(obs.history_loop) > 10 else obs.history_loop
+ if not recent_cycles:
+ return result
+
+ # 统计no_reply的数量
+ no_reply_count = 0
+ reply_sequence = [] # 记录最近的动作序列
+
+ for cycle in recent_cycles:
+ action_type = cycle.loop_plan_info["action_result"]["action_type"]
+ if action_type == "no_reply":
+ no_reply_count += 1
+ reply_sequence.append(action_type == "reply")
+
+ # 检查no_reply比例
+ if len(recent_cycles) >= 5 and (no_reply_count / len(recent_cycles)) >= 0.8:
+ result["add"].append("exit_focus_chat")
+
+ # 获取最近三次的reply状态
+ last_three = reply_sequence[-3:] if len(reply_sequence) >= 3 else reply_sequence
+
+ # 根据最近的reply情况决定是否移除reply动作
+ if len(last_three) >= 3 and all(last_three):
+ # 如果最近三次都是reply,直接移除
+ result["remove"].append("reply")
+ elif len(last_three) >= 2 and all(last_three[-2:]):
+ # 如果最近两次都是reply,40%概率移除
+ if random.random() < 0.4:
+ result["remove"].append("reply")
+ elif last_three and last_three[-1]:
+ # 如果最近一次是reply,20%概率移除
+ if random.random() < 0.2:
+ result["remove"].append("reply")
+
+ return result
diff --git a/src/chat/focus_chat/info_processors/chattinginfo_processor.py b/src/chat/focus_chat/info_processors/chattinginfo_processor.py
index 12bc8560a..5b46d16bb 100644
--- a/src/chat/focus_chat/info_processors/chattinginfo_processor.py
+++ b/src/chat/focus_chat/info_processors/chattinginfo_processor.py
@@ -26,8 +26,9 @@ class ChattingInfoProcessor(BaseProcessor):
def __init__(self):
"""初始化观察处理器"""
super().__init__()
- self.llm_summary = LLMRequest(
- model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
+ # TODO: API-Adapter修改标记
+ self.model_summary = LLMRequest(
+ model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
)
async def process_info(
@@ -54,19 +55,24 @@ class ChattingInfoProcessor(BaseProcessor):
for obs in observations:
# print(f"obs: {obs}")
if isinstance(obs, ChattingObservation):
+ # print("1111111111111111111111读取111111111111111")
+
obs_info = ObsInfo()
await self.chat_compress(obs)
# 设置说话消息
if hasattr(obs, "talking_message_str"):
+ # print(f"设置说话消息:obs.talking_message_str: {obs.talking_message_str}")
obs_info.set_talking_message(obs.talking_message_str)
# 设置截断后的说话消息
if hasattr(obs, "talking_message_str_truncate"):
+ # print(f"设置截断后的说话消息:obs.talking_message_str_truncate: {obs.talking_message_str_truncate}")
obs_info.set_talking_message_str_truncate(obs.talking_message_str_truncate)
if hasattr(obs, "mid_memory_info"):
+ # print(f"设置之前聊天信息:obs.mid_memory_info: {obs.mid_memory_info}")
obs_info.set_previous_chat_info(obs.mid_memory_info)
# 设置聊天类型
@@ -91,7 +97,7 @@ class ChattingInfoProcessor(BaseProcessor):
async def chat_compress(self, obs: ChattingObservation):
if obs.compressor_prompt:
try:
- summary_result, _, _ = await self.llm_summary.generate_response(obs.compressor_prompt)
+ summary_result, _, _ = await self.model_summary.generate_response(obs.compressor_prompt)
summary = "没有主题的闲聊" # 默认值
if summary_result: # 确保结果不为空
summary = summary_result
@@ -108,12 +114,12 @@ class ChattingInfoProcessor(BaseProcessor):
"created_at": datetime.now().timestamp(),
}
- obs.mid_memorys.append(mid_memory)
- if len(obs.mid_memorys) > obs.max_mid_memory_len:
- obs.mid_memorys.pop(0) # 移除最旧的
+ obs.mid_memories.append(mid_memory)
+ if len(obs.mid_memories) > obs.max_mid_memory_len:
+ obs.mid_memories.pop(0) # 移除最旧的
mid_memory_str = "之前聊天的内容概述是:\n"
- for mid_memory_item in obs.mid_memorys: # 重命名循环变量以示区分
+ for mid_memory_item in obs.mid_memories: # 重命名循环变量以示区分
time_diff = int((datetime.now().timestamp() - mid_memory_item["created_at"]) / 60)
mid_memory_str += (
f"距离现在{time_diff}分钟前(聊天记录id:{mid_memory_item['id']}):{mid_memory_item['theme']}\n"
diff --git a/src/chat/focus_chat/info_processors/mind_processor.py b/src/chat/focus_chat/info_processors/mind_processor.py
index 1a104e123..afd7921d4 100644
--- a/src/chat/focus_chat/info_processors/mind_processor.py
+++ b/src/chat/focus_chat/info_processors/mind_processor.py
@@ -6,21 +6,14 @@ import time
import traceback
from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality
-import random
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import safe_json_dumps
from src.chat.message_receive.chat_stream import chat_manager
-import difflib
from src.chat.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor
from src.chat.focus_chat.info.mind_info import MindInfo
from typing import List, Optional
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
-from src.chat.focus_chat.info_processors.processor_utils import (
- calculate_similarity,
- calculate_replacement_probability,
- get_spark,
-)
from typing import Dict
from src.chat.focus_chat.info.info_base import InfoBase
@@ -28,7 +21,6 @@ logger = get_logger("processor")
def init_prompt():
- # --- Group Chat Prompt ---
group_prompt = """
你的名字是{bot_name}
{memory_str}
@@ -44,31 +36,29 @@ def init_prompt():
现在请你继续输出观察和规划,输出要求:
1. 先关注未读新消息的内容和近期回复历史
2. 根据新信息,修改和删除之前的观察和规划
-3. 根据聊天内容继续输出观察和规划,{hf_do_next}
+3. 根据聊天内容继续输出观察和规划
4. 注意群聊的时间线索,话题由谁发起,进展状况如何,思考聊天的时间线。
6. 语言简洁自然,不要分点,不要浮夸,不要修辞,仅输出思考内容就好"""
Prompt(group_prompt, "sub_heartflow_prompt_before")
- # --- Private Chat Prompt ---
private_prompt = """
+你的名字是{bot_name}
{memory_str}
{extra_info}
{relation_prompt}
-你的名字是{bot_name},{prompt_personality},你现在{mood_info}
{cycle_info_block}
-现在是{time_now},你正在上网,和 {chat_target_name} 私聊,以下是你们的聊天内容:
+现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容:
{chat_observe_info}
-以下是你之前对聊天的观察和规划:
+
+以下是你之前对聊天的观察和规划,你的名字是{bot_name}:
{last_mind}
-请仔细阅读聊天内容,想想你和 {chat_target_name} 的关系,回顾你们刚刚的交流,你刚刚发言和对方的反应,思考聊天的主题。
-请思考你要不要回复以及如何回复对方。
-思考并输出你的内心想法
-输出要求:
-1. 根据聊天内容生成你的想法,{hf_do_next}
-2. 不要分点、不要使用表情符号
-3. 避免多余符号(冒号、引号、括号等)
-4. 语言简洁自然,不要浮夸
-5. 如果你刚发言,对方没有回复你,请谨慎回复"""
+
+现在请你继续输出观察和规划,输出要求:
+1. 先关注未读新消息的内容和近期回复历史
+2. 根据新信息,修改和删除之前的观察和规划
+3. 根据聊天内容继续输出观察和规划
+4. 注意群聊的时间线索,话题由谁发起,进展状况如何,思考聊天的时间线。
+6. 语言简洁自然,不要分点,不要浮夸,不要修辞,仅输出思考内容就好"""
Prompt(private_prompt, "sub_heartflow_prompt_private_before")
@@ -81,8 +71,8 @@ class MindProcessor(BaseProcessor):
self.subheartflow_id = subheartflow_id
self.llm_model = LLMRequest(
- model=global_config.llm_sub_heartflow,
- temperature=global_config.llm_sub_heartflow["temp"],
+ model=global_config.model.sub_heartflow,
+ temperature=global_config.model.sub_heartflow["temp"],
max_tokens=800,
request_type="sub_heart_flow",
)
@@ -210,45 +200,26 @@ class MindProcessor(BaseProcessor):
for person in person_list:
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
- # 构建个性部分
- # prompt_personality = individuality.get_prompt(x_person=2, level=2)
-
- # 获取当前时间
- time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
-
- spark_prompt = get_spark()
-
- # ---------- 5. 构建最终提示词 ----------
template_name = "sub_heartflow_prompt_before" if is_group_chat else "sub_heartflow_prompt_private_before"
logger.debug(f"{self.log_prefix} 使用{'群聊' if is_group_chat else '私聊'}思考模板")
prompt = (await global_prompt_manager.get_prompt_async(template_name)).format(
+ bot_name=individuality.name,
memory_str=memory_str,
extra_info=self.structured_info_str,
- # prompt_personality=prompt_personality,
relation_prompt=relation_prompt,
- bot_name=individuality.name,
- time_now=time_now,
+ time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
chat_observe_info=chat_observe_info,
- # mood_info="mood_info",
- hf_do_next=spark_prompt,
last_mind=previous_mind,
cycle_info_block=hfcloop_observe_info,
chat_target_name=chat_target_name,
)
- # 在构建完提示词后,生成最终的prompt字符串
- final_prompt = prompt
-
- content = "" # 初始化内容变量
-
+ content = "(不知道该想些什么...)"
try:
- # 调用LLM生成响应
- response, _ = await self.llm_model.generate_response_async(prompt=final_prompt)
-
- # 直接使用LLM返回的文本响应作为 content
- content = response if response else ""
-
+ content, _ = await self.llm_model.generate_response_async(prompt=prompt)
+ if not content:
+ logger.warning(f"{self.log_prefix} LLM返回空结果,思考失败。")
except Exception as e:
# 处理总体异常
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
@@ -256,16 +227,8 @@ class MindProcessor(BaseProcessor):
content = "思考过程中出现错误"
# 记录初步思考结果
- logger.debug(f"{self.log_prefix} 思考prompt: \n{final_prompt}\n")
-
- # 处理空响应情况
- if not content:
- content = "(不知道该想些什么...)"
- logger.warning(f"{self.log_prefix} LLM返回空结果,思考失败。")
-
- # ---------- 8. 更新思考状态并返回结果 ----------
+ logger.debug(f"{self.log_prefix} 思考prompt: \n{prompt}\n")
logger.info(f"{self.log_prefix} 思考结果: {content}")
- # 更新当前思考内容
self.update_current_mind(content)
return content
@@ -275,138 +238,5 @@ class MindProcessor(BaseProcessor):
self.past_mind.append(self.current_mind)
self.current_mind = response
- def de_similar(self, previous_mind, new_content):
- try:
- similarity = calculate_similarity(previous_mind, new_content)
- replacement_prob = calculate_replacement_probability(similarity)
- logger.debug(f"{self.log_prefix} 新旧想法相似度: {similarity:.2f}, 替换概率: {replacement_prob:.2f}")
-
- # 定义词语列表 (移到判断之前)
- yu_qi_ci_liebiao = ["嗯", "哦", "啊", "唉", "哈", "唔"]
- zhuan_zhe_liebiao = ["但是", "不过", "然而", "可是", "只是"]
- cheng_jie_liebiao = ["然后", "接着", "此外", "而且", "另外"]
- zhuan_jie_ci_liebiao = zhuan_zhe_liebiao + cheng_jie_liebiao
-
- if random.random() < replacement_prob:
- # 相似度非常高时,尝试去重或特殊处理
- if similarity == 1.0:
- logger.debug(f"{self.log_prefix} 想法完全重复 (相似度 1.0),执行特殊处理...")
- # 随机截取大约一半内容
- if len(new_content) > 1: # 避免内容过短无法截取
- split_point = max(
- 1, len(new_content) // 2 + random.randint(-len(new_content) // 4, len(new_content) // 4)
- )
- truncated_content = new_content[:split_point]
- else:
- truncated_content = new_content # 如果只有一个字符或者为空,就不截取了
-
- # 添加语气词和转折/承接词
- yu_qi_ci = random.choice(yu_qi_ci_liebiao)
- zhuan_jie_ci = random.choice(zhuan_jie_ci_liebiao)
- content = f"{yu_qi_ci}{zhuan_jie_ci},{truncated_content}"
- logger.debug(f"{self.log_prefix} 想法重复,特殊处理后: {content}")
-
- else:
- # 相似度较高但非100%,执行标准去重逻辑
- logger.debug(f"{self.log_prefix} 执行概率性去重 (概率: {replacement_prob:.2f})...")
- logger.debug(
- f"{self.log_prefix} previous_mind类型: {type(previous_mind)}, new_content类型: {type(new_content)}"
- )
-
- matcher = difflib.SequenceMatcher(None, previous_mind, new_content)
- logger.debug(f"{self.log_prefix} matcher类型: {type(matcher)}")
-
- deduplicated_parts = []
- last_match_end_in_b = 0
-
- # 获取并记录所有匹配块
- matching_blocks = matcher.get_matching_blocks()
- logger.debug(f"{self.log_prefix} 匹配块数量: {len(matching_blocks)}")
- logger.debug(
- f"{self.log_prefix} 匹配块示例(前3个): {matching_blocks[:3] if len(matching_blocks) > 3 else matching_blocks}"
- )
-
- # get_matching_blocks()返回形如[(i, j, n), ...]的列表,其中i是a中的索引,j是b中的索引,n是匹配的长度
- for idx, match in enumerate(matching_blocks):
- if not isinstance(match, tuple):
- logger.error(f"{self.log_prefix} 匹配块 {idx} 不是元组类型,而是 {type(match)}: {match}")
- continue
-
- try:
- _i, j, n = match # 解包元组为三个变量
- logger.debug(f"{self.log_prefix} 匹配块 {idx}: i={_i}, j={j}, n={n}")
-
- if last_match_end_in_b < j:
- # 确保添加的是字符串,而不是元组
- try:
- non_matching_part = new_content[last_match_end_in_b:j]
- logger.debug(
- f"{self.log_prefix} 添加非匹配部分: '{non_matching_part}', 类型: {type(non_matching_part)}"
- )
- if not isinstance(non_matching_part, str):
- logger.warning(
- f"{self.log_prefix} 非匹配部分不是字符串类型: {type(non_matching_part)}"
- )
- non_matching_part = str(non_matching_part)
- deduplicated_parts.append(non_matching_part)
- except Exception as e:
- logger.error(f"{self.log_prefix} 处理非匹配部分时出错: {e}")
- logger.error(traceback.format_exc())
- last_match_end_in_b = j + n
- except Exception as e:
- logger.error(f"{self.log_prefix} 处理匹配块时出错: {e}")
- logger.error(traceback.format_exc())
-
- logger.debug(f"{self.log_prefix} 去重前部分列表: {deduplicated_parts}")
- logger.debug(f"{self.log_prefix} 列表元素类型: {[type(part) for part in deduplicated_parts]}")
-
- # 确保所有元素都是字符串
- deduplicated_parts = [str(part) for part in deduplicated_parts]
-
- # 防止列表为空
- if not deduplicated_parts:
- logger.warning(f"{self.log_prefix} 去重后列表为空,添加空字符串")
- deduplicated_parts = [""]
-
- logger.debug(f"{self.log_prefix} 处理后的部分列表: {deduplicated_parts}")
-
- try:
- deduplicated_content = "".join(deduplicated_parts).strip()
- logger.debug(f"{self.log_prefix} 拼接后的去重内容: '{deduplicated_content}'")
- except Exception as e:
- logger.error(f"{self.log_prefix} 拼接去重内容时出错: {e}")
- logger.error(traceback.format_exc())
- deduplicated_content = ""
-
- if deduplicated_content:
- # 根据概率决定是否添加词语
- prefix_str = ""
- if random.random() < 0.3: # 30% 概率添加语气词
- prefix_str += random.choice(yu_qi_ci_liebiao)
- if random.random() < 0.7: # 70% 概率添加转折/承接词
- prefix_str += random.choice(zhuan_jie_ci_liebiao)
-
- # 组合最终结果
- if prefix_str:
- content = f"{prefix_str},{deduplicated_content}" # 更新 content
- logger.debug(f"{self.log_prefix} 去重并添加引导词后: {content}")
- else:
- content = deduplicated_content # 更新 content
- logger.debug(f"{self.log_prefix} 去重后 (未添加引导词): {content}")
- else:
- logger.warning(f"{self.log_prefix} 去重后内容为空,保留原始LLM输出: {new_content}")
- content = new_content # 保留原始 content
- else:
- logger.debug(f"{self.log_prefix} 未执行概率性去重 (概率: {replacement_prob:.2f})")
- # content 保持 new_content 不变
-
- except Exception as e:
- logger.error(f"{self.log_prefix} 应用概率性去重或特殊处理时出错: {e}")
- logger.error(traceback.format_exc())
- # 出错时保留原始 content
- content = new_content
-
- return content
-
init_prompt()
diff --git a/src/chat/focus_chat/info_processors/processor_utils.py b/src/chat/focus_chat/info_processors/processor_utils.py
deleted file mode 100644
index 77cdc7a6b..000000000
--- a/src/chat/focus_chat/info_processors/processor_utils.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import difflib
-import random
-import time
-
-
-def calculate_similarity(text_a: str, text_b: str) -> float:
- """
- 计算两个文本字符串的相似度。
- """
- if not text_a or not text_b:
- return 0.0
- matcher = difflib.SequenceMatcher(None, text_a, text_b)
- return matcher.ratio()
-
-
-def calculate_replacement_probability(similarity: float) -> float:
- """
- 根据相似度计算替换的概率。
- 规则:
- - 相似度 <= 0.4: 概率 = 0
- - 相似度 >= 0.9: 概率 = 1
- - 相似度 == 0.6: 概率 = 0.7
- - 0.4 < 相似度 <= 0.6: 线性插值 (0.4, 0) 到 (0.6, 0.7)
- - 0.6 < 相似度 < 0.9: 线性插值 (0.6, 0.7) 到 (0.9, 1.0)
- """
- if similarity <= 0.4:
- return 0.0
- elif similarity >= 0.9:
- return 1.0
- elif 0.4 < similarity <= 0.6:
- # p = 3.5 * s - 1.4
- probability = 3.5 * similarity - 1.4
- return max(0.0, probability)
- else: # 0.6 < similarity < 0.9
- # p = s + 0.1
- probability = similarity + 0.1
- return min(1.0, max(0.0, probability))
-
-
-def get_spark():
- local_random = random.Random()
- current_minute = int(time.strftime("%M"))
- local_random.seed(current_minute)
-
- hf_options = [
- ("可以参考之前的想法,在原来想法的基础上继续思考", 0.2),
- ("可以参考之前的想法,在原来的想法上尝试新的话题", 0.4),
- ("不要太深入", 0.2),
- ("进行深入思考", 0.2),
- ]
- # 加权随机选择思考指导
- hf_do_next = local_random.choices(
- [option[0] for option in hf_options], weights=[option[1] for option in hf_options], k=1
- )[0]
-
- return hf_do_next
diff --git a/src/chat/focus_chat/info_processors/self_processor.py b/src/chat/focus_chat/info_processors/self_processor.py
new file mode 100644
index 000000000..4e7e8544b
--- /dev/null
+++ b/src/chat/focus_chat/info_processors/self_processor.py
@@ -0,0 +1,164 @@
+from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
+from src.chat.heart_flow.observation.observation import Observation
+from src.chat.models.utils_model import LLMRequest
+from src.config.config import global_config
+import time
+import traceback
+from src.common.logger_manager import get_logger
+from src.individuality.individuality import Individuality
+from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
+from src.chat.message_receive.chat_stream import chat_manager
+from src.chat.person_info.relationship_manager import relationship_manager
+from .base_processor import BaseProcessor
+from typing import List, Optional
+from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
+from typing import Dict
+from src.chat.focus_chat.info.info_base import InfoBase
+from src.chat.focus_chat.info.self_info import SelfInfo
+
+logger = get_logger("processor")
+
+
+def init_prompt():
+ indentify_prompt = """
+你的名字是{bot_name},你是一个AI,但是你伪装成了一个人类,你的人格是,{prompt_personality}。
+你的头像形象是一只橙色的鱼,头上有绿色的树叶。
+
+{relation_prompt}
+{memory_str}
+
+现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容:
+{chat_observe_info}
+
+现在请你根据现有的信息,思考自我认同
+1. 你是一个什么样的人,你和群里的人关系如何
+2. 思考有没有人提到你,或者图片与你有关
+3. 你的自我认同是否有助于你的回答,如果你需要自我相关的信息来帮你参与聊天,请输出,否则请输出十个字以内的简短自我认同
+4. 一般情况下不用输出自我认同,只需要输出十几个字的简短自我认同就好,除非有明显需要自我认同的场景
+
+请回复的平淡一些,简短一些,说中文,不要浮夸,平淡一些。
+请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出内容。
+
+"""
+ Prompt(indentify_prompt, "indentify_prompt")
+
+
+class SelfProcessor(BaseProcessor):
+ log_prefix = "自我认同"
+
+ def __init__(self, subheartflow_id: str):
+ super().__init__()
+
+ self.subheartflow_id = subheartflow_id
+
+ self.llm_model = LLMRequest(
+ model=global_config.model.sub_heartflow,
+ temperature=global_config.model.sub_heartflow["temp"],
+ max_tokens=800,
+ request_type="self_identify",
+ )
+
+ name = chat_manager.get_stream_name(self.subheartflow_id)
+ self.log_prefix = f"[{name}] "
+
+ async def process_info(
+ self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
+ ) -> List[InfoBase]:
+ """处理信息对象
+
+ Args:
+ *infos: 可变数量的InfoBase类型的信息对象
+
+ Returns:
+ List[InfoBase]: 处理后的结构化信息列表
+ """
+ self_info_str = await self.self_indentify(observations, running_memorys)
+
+ if self_info_str:
+ self_info = SelfInfo()
+ self_info.set_self_info(self_info_str)
+ else:
+ self_info = None
+ return None
+
+ return [self_info]
+
+ async def self_indentify(
+ self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None
+ ):
+ """
+ 在回复前进行思考,生成内心想法并收集工具调用结果
+
+ 参数:
+ observations: 观察信息
+
+ 返回:
+ 如果return_prompt为False:
+ tuple: (current_mind, past_mind) 当前想法和过去的想法列表
+ 如果return_prompt为True:
+ tuple: (current_mind, past_mind, prompt) 当前想法、过去的想法列表和使用的prompt
+ """
+
+ memory_str = ""
+ if running_memorys:
+ memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
+ for running_memory in running_memorys:
+ memory_str += f"{running_memory['topic']}: {running_memory['content']}\n"
+
+ if observations is None:
+ observations = []
+ for observation in observations:
+ if isinstance(observation, ChattingObservation):
+ # 获取聊天元信息
+ is_group_chat = observation.is_group_chat
+ chat_target_info = observation.chat_target_info
+ chat_target_name = "对方" # 私聊默认名称
+ if not is_group_chat and chat_target_info:
+ # 优先使用person_name,其次user_nickname,最后回退到默认值
+ chat_target_name = (
+ chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or chat_target_name
+ )
+ # 获取聊天内容
+ chat_observe_info = observation.get_observe_info()
+ person_list = observation.person_list
+ if isinstance(observation, HFCloopObservation):
+ # hfcloop_observe_info = observation.get_observe_info()
+ pass
+
+ individuality = Individuality.get_instance()
+ personality_block = individuality.get_prompt(x_person=2, level=2)
+
+ relation_prompt = ""
+ for person in person_list:
+ relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
+
+ prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format(
+ bot_name=individuality.name,
+ prompt_personality=personality_block,
+ memory_str=memory_str,
+ relation_prompt=relation_prompt,
+ time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
+ chat_observe_info=chat_observe_info,
+ )
+
+ content = ""
+ try:
+ content, _ = await self.llm_model.generate_response_async(prompt=prompt)
+ if not content:
+ logger.warning(f"{self.log_prefix} LLM返回空结果,自我识别失败。")
+ except Exception as e:
+ # 处理总体异常
+ logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
+ logger.error(traceback.format_exc())
+ content = "自我识别过程中出现错误"
+
+ if content == "None":
+ content = ""
+ # 记录初步思考结果
+ logger.debug(f"{self.log_prefix} 自我识别prompt: \n{prompt}\n")
+ logger.info(f"{self.log_prefix} 自我识别结果: {content}")
+
+ return content
+
+
+init_prompt()
diff --git a/src/chat/focus_chat/info_processors/tool_processor.py b/src/chat/focus_chat/info_processors/tool_processor.py
index 8840c1ae4..de9a9a216 100644
--- a/src/chat/focus_chat/info_processors/tool_processor.py
+++ b/src/chat/focus_chat/info_processors/tool_processor.py
@@ -11,8 +11,8 @@ from src.chat.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor
from typing import List, Optional, Dict
from src.chat.heart_flow.observation.observation import Observation
-from src.chat.heart_flow.observation.working_observation import WorkingObservation
from src.chat.focus_chat.info.structured_info import StructuredInfo
+from src.chat.heart_flow.observation.structure_observation import StructureObservation
logger = get_logger("processor")
@@ -24,9 +24,6 @@ def init_prompt():
tool_executor_prompt = """
你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}。
-你要在群聊中扮演以下角色:
-{prompt_personality}
-
你当前的额外信息:
{memory_str}
@@ -52,7 +49,7 @@ class ToolProcessor(BaseProcessor):
self.subheartflow_id = subheartflow_id
self.log_prefix = f"[{subheartflow_id}:ToolExecutor] "
self.llm_model = LLMRequest(
- model=global_config.llm_tool_use,
+ model=global_config.model.tool_use,
max_tokens=500,
request_type="tool_execution",
)
@@ -70,6 +67,8 @@ class ToolProcessor(BaseProcessor):
list: 处理后的结构化信息列表
"""
+ working_infos = []
+
if observations:
for observation in observations:
if isinstance(observation, ChattingObservation):
@@ -77,7 +76,7 @@ class ToolProcessor(BaseProcessor):
# 更新WorkingObservation中的结构化信息
for observation in observations:
- if isinstance(observation, WorkingObservation):
+ if isinstance(observation, StructureObservation):
for structured_info in result:
logger.debug(f"{self.log_prefix} 更新WorkingObservation中的结构化信息: {structured_info}")
observation.add_structured_info(structured_info)
@@ -86,8 +85,9 @@ class ToolProcessor(BaseProcessor):
logger.debug(f"{self.log_prefix} 获取更新后WorkingObservation中的结构化信息: {working_infos}")
structured_info = StructuredInfo()
- for working_info in working_infos:
- structured_info.set_info(working_info.get("type"), working_info.get("content"))
+ if working_infos:
+ for working_info in working_infos:
+ structured_info.set_info(working_info.get("type"), working_info.get("content"))
return [structured_info]
@@ -134,7 +134,7 @@ class ToolProcessor(BaseProcessor):
# 获取个性信息
individuality = Individuality.get_instance()
- prompt_personality = individuality.get_prompt(x_person=2, level=2)
+ # prompt_personality = individuality.get_prompt(x_person=2, level=2)
# 获取时间信息
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
@@ -148,14 +148,14 @@ class ToolProcessor(BaseProcessor):
# chat_target_name=chat_target_name,
is_group_chat=is_group_chat,
# relation_prompt=relation_prompt,
- prompt_personality=prompt_personality,
+ # prompt_personality=prompt_personality,
# mood_info=mood_info,
bot_name=individuality.name,
time_now=time_now,
)
# 调用LLM,专注于工具使用
- logger.debug(f"开始执行工具调用{prompt}")
+ # logger.debug(f"开始执行工具调用{prompt}")
response, _, tool_calls = await self.llm_model.generate_response_tool_async(prompt=prompt, tools=tools)
logger.debug(f"获取到工具原始输出:\n{tool_calls}")
diff --git a/src/chat/focus_chat/info_processors/working_memory_processor.py b/src/chat/focus_chat/info_processors/working_memory_processor.py
new file mode 100644
index 000000000..c79c8363d
--- /dev/null
+++ b/src/chat/focus_chat/info_processors/working_memory_processor.py
@@ -0,0 +1,236 @@
+from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
+from src.chat.heart_flow.observation.observation import Observation
+from src.chat.models.utils_model import LLMRequest
+from src.config.config import global_config
+import time
+import traceback
+from src.common.logger_manager import get_logger
+from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
+from src.chat.message_receive.chat_stream import chat_manager
+from .base_processor import BaseProcessor
+from src.chat.focus_chat.info.mind_info import MindInfo
+from typing import List, Optional
+from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
+from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
+from typing import Dict
+from src.chat.focus_chat.info.info_base import InfoBase
+from json_repair import repair_json
+from src.chat.focus_chat.info.workingmemory_info import WorkingMemoryInfo
+import asyncio
+import json
+
+logger = get_logger("processor")
+
+
+def init_prompt():
+ memory_proces_prompt = """
+你的名字是{bot_name}
+
+现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容:
+{chat_observe_info}
+
+以下是你已经总结的记忆摘要,你可以调取这些记忆查看内容来帮助你聊天,不要一次调取太多记忆,最多调取3个左右记忆:
+{memory_str}
+
+观察聊天内容和已经总结的记忆,思考是否有新内容需要总结成记忆,如果有,就输出 true,否则输出 false
+如果当前聊天记录的内容已经被总结,千万不要总结新记忆,输出false
+如果已经总结的记忆包含了当前聊天记录的内容,千万不要总结新记忆,输出false
+如果已经总结的记忆摘要,包含了当前聊天记录的内容,千万不要总结新记忆,输出false
+
+如果有相近的记忆,请合并记忆,输出merge_memory,格式为[["id1", "id2"], ["id3", "id4"],...],你可以进行多组合并,但是每组合并只能有两个记忆id,不要输出其他内容
+
+请根据聊天内容选择你需要调取的记忆并考虑是否添加新记忆,以JSON格式输出,格式如下:
+```json
+{{
+ "selected_memory_ids": ["id1", "id2", ...],
+ "new_memory": "true" or "false",
+ "merge_memory": [["id1", "id2"], ["id3", "id4"],...]
+
+}}
+```
+"""
+ Prompt(memory_proces_prompt, "prompt_memory_proces")
+
+
+class WorkingMemoryProcessor(BaseProcessor):
+ log_prefix = "工作记忆"
+
+ def __init__(self, subheartflow_id: str):
+ super().__init__()
+
+ self.subheartflow_id = subheartflow_id
+
+ self.llm_model = LLMRequest(
+ model=global_config.model.sub_heartflow,
+ temperature=global_config.model.sub_heartflow["temp"],
+ max_tokens=800,
+ request_type="working_memory",
+ )
+
+ name = chat_manager.get_stream_name(self.subheartflow_id)
+ self.log_prefix = f"[{name}] "
+
+ async def process_info(
+ self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
+ ) -> List[InfoBase]:
+ """处理信息对象
+
+ Args:
+ *infos: 可变数量的InfoBase类型的信息对象
+
+ Returns:
+ List[InfoBase]: 处理后的结构化信息列表
+ """
+ working_memory = None
+ chat_info = ""
+ try:
+ for observation in observations:
+ if isinstance(observation, WorkingMemoryObservation):
+ working_memory = observation.get_observe_info()
+ # working_memory_obs = observation
+ if isinstance(observation, ChattingObservation):
+ chat_info = observation.get_observe_info()
+ # chat_info_truncate = observation.talking_message_str_truncate
+
+ if not working_memory:
+ logger.warning(f"{self.log_prefix} 没有找到工作记忆对象")
+ mind_info = MindInfo()
+ return [mind_info]
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 处理观察时出错: {e}")
+ logger.error(traceback.format_exc())
+ return []
+
+ all_memory = working_memory.get_all_memories()
+ memory_prompts = []
+ for memory in all_memory:
+ # memory_content = memory.data
+ memory_summary = memory.summary
+ memory_id = memory.id
+ memory_brief = memory_summary.get("brief")
+ # memory_detailed = memory_summary.get("detailed")
+ memory_keypoints = memory_summary.get("keypoints")
+ memory_events = memory_summary.get("events")
+ memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n"
+ memory_prompts.append(memory_single_prompt)
+
+ memory_choose_str = "".join(memory_prompts)
+
+ # 使用提示模板进行处理
+ prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format(
+ bot_name=global_config.bot.nickname,
+ time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
+ chat_observe_info=chat_info,
+ memory_str=memory_choose_str,
+ )
+
+ # 调用LLM处理记忆
+ content = ""
+ try:
+ logger.debug(f"{self.log_prefix} 处理工作记忆的prompt: {prompt}")
+
+ content, _ = await self.llm_model.generate_response_async(prompt=prompt)
+ if not content:
+ logger.warning(f"{self.log_prefix} LLM返回空结果,处理工作记忆失败。")
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
+ logger.error(traceback.format_exc())
+
+ # 解析LLM返回的JSON
+ try:
+ result = repair_json(content)
+ if isinstance(result, str):
+ result = json.loads(result)
+ if not isinstance(result, dict):
+ logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败,结果不是字典类型: {type(result)}")
+ return []
+
+ selected_memory_ids = result.get("selected_memory_ids", [])
+ new_memory = result.get("new_memory", "")
+ merge_memory = result.get("merge_memory", [])
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}")
+ logger.error(traceback.format_exc())
+ return []
+
+ logger.debug(f"{self.log_prefix} 解析LLM返回的JSON成功: {result}")
+
+ # 根据selected_memory_ids,调取记忆
+ memory_str = ""
+ if selected_memory_ids:
+ for memory_id in selected_memory_ids:
+ memory = await working_memory.retrieve_memory(memory_id)
+ if memory:
+ # memory_content = memory.data
+ memory_summary = memory.summary
+ memory_id = memory.id
+ memory_brief = memory_summary.get("brief")
+ # memory_detailed = memory_summary.get("detailed")
+ memory_keypoints = memory_summary.get("keypoints")
+ memory_events = memory_summary.get("events")
+ for keypoint in memory_keypoints:
+ memory_str += f"记忆要点:{keypoint}\n"
+ for event in memory_events:
+ memory_str += f"记忆事件:{event}\n"
+ # memory_str += f"记忆摘要:{memory_detailed}\n"
+ # memory_str += f"记忆主题:{memory_brief}\n"
+
+ working_memory_info = WorkingMemoryInfo()
+ if memory_str:
+ working_memory_info.add_working_memory(memory_str)
+ logger.debug(f"{self.log_prefix} 取得工作记忆: {memory_str}")
+ else:
+ logger.warning(f"{self.log_prefix} 没有找到工作记忆")
+
+ # 根据聊天内容添加新记忆
+ if new_memory:
+ # 使用异步方式添加新记忆,不阻塞主流程
+ logger.debug(f"{self.log_prefix} {new_memory}新记忆: ")
+ asyncio.create_task(self.add_memory_async(working_memory, chat_info))
+
+ if merge_memory:
+ for merge_pairs in merge_memory:
+ memory1 = await working_memory.retrieve_memory(merge_pairs[0])
+ memory2 = await working_memory.retrieve_memory(merge_pairs[1])
+ if memory1 and memory2:
+ memory_str = f"记忆id:{memory1.id},记忆摘要:{memory1.summary.get('brief')}\n"
+ memory_str += f"记忆id:{memory2.id},记忆摘要:{memory2.summary.get('brief')}\n"
+ asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1]))
+
+ return [working_memory_info]
+
+ async def add_memory_async(self, working_memory: WorkingMemory, content: str):
+ """异步添加记忆,不阻塞主流程
+
+ Args:
+ working_memory: 工作记忆对象
+ content: 记忆内容
+ """
+ try:
+ await working_memory.add_memory(content=content, from_source="chat_text")
+ logger.debug(f"{self.log_prefix} 异步添加新记忆成功: {content[:30]}...")
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 异步添加新记忆失败: {e}")
+ logger.error(traceback.format_exc())
+
+ async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str):
+ """异步合并记忆,不阻塞主流程
+
+ Args:
+ working_memory: 工作记忆对象
+ memory_str: 记忆内容
+ """
+ try:
+ merged_memory = await working_memory.merge_memory(memory_id1, memory_id2)
+ logger.debug(f"{self.log_prefix} 异步合并记忆成功: {memory_id1} 和 {memory_id2}...")
+ logger.debug(f"{self.log_prefix} 合并后的记忆梗概: {merged_memory.summary.get('brief')}")
+ logger.debug(f"{self.log_prefix} 合并后的记忆详情: {merged_memory.summary.get('detailed')}")
+ logger.debug(f"{self.log_prefix} 合并后的记忆要点: {merged_memory.summary.get('keypoints')}")
+ logger.debug(f"{self.log_prefix} 合并后的记忆事件: {merged_memory.summary.get('events')}")
+
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}")
+ logger.error(traceback.format_exc())
+
+
+init_prompt()
diff --git a/src/chat/focus_chat/memory_activator.py b/src/chat/focus_chat/memory_activator.py
index 2d7fea034..4fcd37302 100644
--- a/src/chat/focus_chat/memory_activator.py
+++ b/src/chat/focus_chat/memory_activator.py
@@ -1,5 +1,5 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
-from src.chat.heart_flow.observation.working_observation import WorkingObservation
+from src.chat.heart_flow.observation.structure_observation import StructureObservation
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
@@ -34,8 +34,9 @@ def init_prompt():
class MemoryActivator:
def __init__(self):
+ # TODO: API-Adapter修改标记
self.summary_model = LLMRequest(
- model=global_config.llm_summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
+ model=global_config.model.summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
)
self.running_memory = []
@@ -53,7 +54,7 @@ class MemoryActivator:
for observation in observations:
if isinstance(observation, ChattingObservation):
obs_info_text += observation.get_observe_info()
- elif isinstance(observation, WorkingObservation):
+ elif isinstance(observation, StructureObservation):
working_info = observation.get_observe_info()
for working_info_item in working_info:
obs_info_text += f"{working_info_item['type']}: {working_info_item['content']}\n"
diff --git a/src/chat/focus_chat/planners/action_factory.py b/src/chat/focus_chat/planners/action_manager.py
similarity index 75%
rename from src/chat/focus_chat/planners/action_factory.py
rename to src/chat/focus_chat/planners/action_manager.py
index 257156a25..60ab0babf 100644
--- a/src/chat/focus_chat/planners/action_factory.py
+++ b/src/chat/focus_chat/planners/action_manager.py
@@ -1,18 +1,18 @@
-from typing import Dict, List, Optional, Callable, Coroutine, Type, Any, Union
-import os
-import importlib
-from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY, _DEFAULT_ACTIONS
+from typing import Dict, List, Optional, Callable, Coroutine, Type, Any
+from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY
from src.chat.heart_flow.observation.observation import Observation
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from src.common.logger_manager import get_logger
+import importlib
+import pkgutil
+import os
# 导入动作类,确保装饰器被执行
-from src.chat.focus_chat.planners.actions.reply_action import ReplyAction
-from src.chat.focus_chat.planners.actions.no_reply_action import NoReplyAction
+import src.chat.focus_chat.planners.actions # noqa
-logger = get_logger("action_factory")
+logger = get_logger("action_manager")
# 定义动作信息类型
ActionInfo = Dict[str, Any]
@@ -31,20 +31,18 @@ class ActionManager:
self._using_actions: Dict[str, ActionInfo] = {}
# 临时备份原始使用中的动作
self._original_actions_backup: Optional[Dict[str, ActionInfo]] = None
-
+
# 默认动作集,仅作为快照,用于恢复默认
self._default_actions: Dict[str, ActionInfo] = {}
-
+
# 加载所有已注册动作
self._load_registered_actions()
-
+
+ # 加载插件动作
+ self._load_plugin_actions()
+
# 初始化时将默认动作加载到使用中的动作
self._using_actions = self._default_actions.copy()
-
- # logger.info(f"当前可用动作: {list(self._using_actions.keys())}")
- # for action_name, action_info in self._using_actions.items():
- # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
-
def _load_registered_actions(self) -> None:
"""
@@ -54,38 +52,81 @@ class ActionManager:
# 从_ACTION_REGISTRY获取所有已注册动作
for action_name, action_class in _ACTION_REGISTRY.items():
# 获取动作相关信息
- action_description:str = getattr(action_class, "action_description", "")
- action_parameters:dict[str:str] = getattr(action_class, "action_parameters", {})
- action_require:list[str] = getattr(action_class, "action_require", [])
- is_default:bool = getattr(action_class, "default", False)
-
+
+ # 不读取插件动作和基类
+ if action_name == "base_action" or action_name == "plugin_action":
+ continue
+
+ action_description: str = getattr(action_class, "action_description", "")
+ action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {})
+ action_require: list[str] = getattr(action_class, "action_require", [])
+ is_default: bool = getattr(action_class, "default", False)
+
if action_name and action_description:
# 创建动作信息字典
action_info = {
"description": action_description,
"parameters": action_parameters,
- "require": action_require
+ "require": action_require,
}
-
- # 注册2
- print("注册2")
- print(action_info)
-
+
# 添加到所有已注册的动作
self._registered_actions[action_name] = action_info
-
+
# 添加到默认动作(如果是默认动作)
if is_default:
self._default_actions[action_name] = action_info
-
- logger.info(f"所有注册动作: {list(self._registered_actions.keys())}")
- logger.info(f"默认动作: {list(self._default_actions.keys())}")
+
+ # logger.info(f"所有注册动作: {list(self._registered_actions.keys())}")
+ # logger.info(f"默认动作: {list(self._default_actions.keys())}")
# for action_name, action_info in self._default_actions.items():
- # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
-
+ # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
+
except Exception as e:
logger.error(f"加载已注册动作失败: {e}")
+ def _load_plugin_actions(self) -> None:
+ """
+ 加载所有插件目录中的动作
+ """
+ try:
+ # 检查插件目录是否存在
+ plugin_path = "src.plugins"
+ plugin_dir = plugin_path.replace(".", os.path.sep)
+ if not os.path.exists(plugin_dir):
+ logger.info(f"插件目录 {plugin_dir} 不存在,跳过插件动作加载")
+ return
+
+ # 导入插件包
+ try:
+ plugins_package = importlib.import_module(plugin_path)
+ except ImportError as e:
+ logger.error(f"导入插件包失败: {e}")
+ return
+
+ # 遍历插件包中的所有子包
+ for _, plugin_name, is_pkg in pkgutil.iter_modules(
+ plugins_package.__path__, plugins_package.__name__ + "."
+ ):
+ if not is_pkg:
+ continue
+
+ # 检查插件是否有actions子包
+ plugin_actions_path = f"{plugin_name}.actions"
+ try:
+ # 尝试导入插件的actions包
+ importlib.import_module(plugin_actions_path)
+ logger.info(f"成功加载插件动作模块: {plugin_actions_path}")
+ except ImportError as e:
+ logger.debug(f"插件 {plugin_name} 没有actions子包或导入失败: {e}")
+ continue
+
+ # 再次从_ACTION_REGISTRY获取所有动作(包括刚刚从插件加载的)
+ self._load_registered_actions()
+
+ except Exception as e:
+ logger.error(f"加载插件动作失败: {e}")
+
def create_action(
self,
action_name: str,
@@ -96,11 +137,7 @@ class ActionManager:
observations: List[Observation],
expressor: DefaultExpressor,
chat_stream: ChatStream,
- current_cycle: CycleDetail,
log_prefix: str,
- on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]],
- total_no_reply_count: int = 0,
- total_waiting_time: float = 0.0,
shutting_down: bool = False,
) -> Optional[BaseAction]:
"""
@@ -115,11 +152,7 @@ class ActionManager:
observations: 观察列表
expressor: 表达器
chat_stream: 聊天流
- current_cycle: 当前循环信息
log_prefix: 日志前缀
- on_consecutive_no_reply_callback: 连续不回复回调
- total_no_reply_count: 连续不回复计数
- total_waiting_time: 累计等待时间
shutting_down: 是否正在关闭
Returns:
@@ -129,31 +162,26 @@ class ActionManager:
if action_name not in self._using_actions:
logger.warning(f"当前不可用的动作类型: {action_name}")
return None
-
+
handler_class = _ACTION_REGISTRY.get(action_name)
if not handler_class:
logger.warning(f"未注册的动作类型: {action_name}")
return None
try:
- # 创建动作实例并传递所有必要参数
+ # 创建动作实例
instance = handler_class(
- action_name=action_name,
action_data=action_data,
reasoning=reasoning,
cycle_timers=cycle_timers,
thinking_id=thinking_id,
observations=observations,
- on_consecutive_no_reply_callback=on_consecutive_no_reply_callback,
- current_cycle=current_cycle,
- log_prefix=log_prefix,
- total_no_reply_count=total_no_reply_count,
- total_waiting_time=total_waiting_time,
- shutting_down=shutting_down,
expressor=expressor,
chat_stream=chat_stream,
+ log_prefix=log_prefix,
+ shutting_down=shutting_down,
)
-
+
return instance
except Exception as e:
@@ -167,7 +195,7 @@ class ActionManager:
def get_default_actions(self) -> Dict[str, ActionInfo]:
"""获取默认动作集"""
return self._default_actions.copy()
-
+
def get_using_actions(self) -> Dict[str, ActionInfo]:
"""获取当前正在使用的动作集"""
return self._using_actions.copy()
@@ -175,21 +203,21 @@ class ActionManager:
def add_action_to_using(self, action_name: str) -> bool:
"""
添加已注册的动作到当前使用的动作集
-
+
Args:
action_name: 动作名称
-
+
Returns:
bool: 添加是否成功
"""
if action_name not in self._registered_actions:
logger.warning(f"添加失败: 动作 {action_name} 未注册")
return False
-
+
if action_name in self._using_actions:
logger.info(f"动作 {action_name} 已经在使用中")
return True
-
+
self._using_actions[action_name] = self._registered_actions[action_name]
logger.info(f"添加动作 {action_name} 到使用集")
return True
@@ -197,17 +225,17 @@ class ActionManager:
def remove_action_from_using(self, action_name: str) -> bool:
"""
从当前使用的动作集中移除指定动作
-
+
Args:
action_name: 动作名称
-
+
Returns:
bool: 移除是否成功
"""
if action_name not in self._using_actions:
logger.warning(f"移除失败: 动作 {action_name} 不在当前使用的动作集中")
return False
-
+
del self._using_actions[action_name]
logger.info(f"已从使用集中移除动作 {action_name}")
return True
@@ -215,30 +243,26 @@ class ActionManager:
def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
"""
添加新的动作到注册集
-
+
Args:
action_name: 动作名称
description: 动作描述
parameters: 动作参数定义,默认为空字典
require: 动作依赖项,默认为空列表
-
+
Returns:
bool: 添加是否成功
"""
if action_name in self._registered_actions:
return False
-
+
if parameters is None:
parameters = {}
if require is None:
require = []
-
- action_info = {
- "description": description,
- "parameters": parameters,
- "require": require
- }
-
+
+ action_info = {"description": description, "parameters": parameters, "require": require}
+
self._registered_actions[action_name] = action_info
return True
@@ -264,7 +288,7 @@ class ActionManager:
if self._original_actions_backup is not None:
self._using_actions = self._original_actions_backup.copy()
self._original_actions_backup = None
-
+
def restore_default_actions(self) -> None:
"""恢复默认动作集到使用集"""
self._using_actions = self._default_actions.copy()
@@ -273,15 +297,11 @@ class ActionManager:
def get_action(self, action_name: str) -> Optional[Type[BaseAction]]:
"""
获取指定动作的处理器类
-
+
Args:
action_name: 动作名称
-
+
Returns:
Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None
"""
return _ACTION_REGISTRY.get(action_name)
-
-
-# 创建全局实例
-ActionFactory = ActionManager()
diff --git a/src/chat/focus_chat/planners/actions/__init__.py b/src/chat/focus_chat/planners/actions/__init__.py
new file mode 100644
index 000000000..3f2baf665
--- /dev/null
+++ b/src/chat/focus_chat/planners/actions/__init__.py
@@ -0,0 +1,5 @@
+# 导入所有动作模块以确保装饰器被执行
+from . import reply_action # noqa
+from . import no_reply_action # noqa
+
+# 在此处添加更多动作模块导入
diff --git a/src/chat/focus_chat/planners/actions/base_action.py b/src/chat/focus_chat/planners/actions/base_action.py
index 7c77c300c..82d259677 100644
--- a/src/chat/focus_chat/planners/actions/base_action.py
+++ b/src/chat/focus_chat/planners/actions/base_action.py
@@ -12,7 +12,7 @@ _DEFAULT_ACTIONS: Dict[str, str] = {}
def register_action(cls):
"""
动作注册装饰器
-
+
用法:
@register_action
class MyAction(BaseAction):
@@ -24,22 +24,22 @@ def register_action(cls):
if not hasattr(cls, "action_name") or not hasattr(cls, "action_description"):
logger.error(f"动作类 {cls.__name__} 缺少必要的属性: action_name 或 action_description")
return cls
-
- action_name = getattr(cls, "action_name")
- action_description = getattr(cls, "action_description")
+
+ action_name = cls.action_name
+ action_description = cls.action_description
is_default = getattr(cls, "default", False)
-
+
if not action_name or not action_description:
logger.error(f"动作类 {cls.__name__} 的 action_name 或 action_description 为空")
return cls
-
+
# 将动作类注册到全局注册表
_ACTION_REGISTRY[action_name] = cls
-
+
# 如果是默认动作,添加到默认动作集
if is_default:
_DEFAULT_ACTIONS[action_name] = action_description
-
+
logger.info(f"已注册动作: {action_name} -> {cls.__name__},默认: {is_default}")
return cls
@@ -60,15 +60,14 @@ class BaseAction(ABC):
cycle_timers: 计时器字典
thinking_id: 思考ID
"""
- #每个动作必须实现
- self.action_name:str = "base_action"
- self.action_description:str = "基础动作"
- self.action_parameters:dict = {}
- self.action_require:list[str] = []
-
- self.default:bool = False
-
-
+ # 每个动作必须实现
+ self.action_name: str = "base_action"
+ self.action_description: str = "基础动作"
+ self.action_parameters: dict = {}
+ self.action_require: list[str] = []
+
+ self.default: bool = False
+
self.action_data = action_data
self.reasoning = reasoning
self.cycle_timers = cycle_timers
diff --git a/src/chat/focus_chat/planners/actions/exit_focus_chat_action.py b/src/chat/focus_chat/planners/actions/exit_focus_chat_action.py
new file mode 100644
index 000000000..6aeb68ccd
--- /dev/null
+++ b/src/chat/focus_chat/planners/actions/exit_focus_chat_action.py
@@ -0,0 +1,108 @@
+import asyncio
+import traceback
+from src.common.logger_manager import get_logger
+from src.chat.utils.timer_calculator import Timer
+from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
+from typing import Tuple, List, Callable, Coroutine
+from src.chat.heart_flow.observation.observation import Observation
+from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
+from src.chat.heart_flow.sub_heartflow import SubHeartFlow
+from src.chat.message_receive.chat_stream import ChatStream
+from src.chat.heart_flow.heartflow import heartflow
+from src.chat.heart_flow.sub_heartflow import ChatState
+
+logger = get_logger("action_taken")
+
+
+@register_action
+class ExitFocusChatAction(BaseAction):
+ """退出专注聊天动作处理类
+
+ 处理决定退出专注聊天的动作。
+ 执行后会将所属的sub heartflow转变为normal_chat状态。
+ """
+
+ action_name = "exit_focus_chat"
+ action_description = "退出专注聊天,转为普通聊天模式"
+ action_parameters = {}
+ action_require = [
+ "很长时间没有回复,你决定退出专注聊天",
+ "当前内容不需要持续专注关注,你决定退出专注聊天",
+ "聊天内容已经完成,你决定退出专注聊天",
+ ]
+ default = True
+
+ def __init__(
+ self,
+ action_data: dict,
+ reasoning: str,
+ cycle_timers: dict,
+ thinking_id: str,
+ observations: List[Observation],
+ log_prefix: str,
+ chat_stream: ChatStream,
+ shutting_down: bool = False,
+ **kwargs,
+ ):
+ """初始化退出专注聊天动作处理器
+
+ Args:
+ action_data: 动作数据
+ reasoning: 执行该动作的理由
+ cycle_timers: 计时器字典
+ thinking_id: 思考ID
+ observations: 观察列表
+ log_prefix: 日志前缀
+ shutting_down: 是否正在关闭
+ """
+ super().__init__(action_data, reasoning, cycle_timers, thinking_id)
+ self.observations = observations
+ self.log_prefix = log_prefix
+ self._shutting_down = shutting_down
+ self.chat_id = chat_stream.stream_id
+
+
+
+ async def handle_action(self) -> Tuple[bool, str]:
+ """
+ 处理退出专注聊天的情况
+
+ 工作流程:
+ 1. 将sub heartflow转换为normal_chat状态
+ 2. 等待新消息、超时或关闭信号
+ 3. 根据等待结果更新连续不回复计数
+ 4. 如果达到阈值,触发回调
+
+ Returns:
+ Tuple[bool, str]: (是否执行成功, 状态转换消息)
+ """
+ try:
+ # 转换状态
+ status_message = ""
+ self.sub_heartflow = await heartflow.get_or_create_subheartflow(self.chat_id)
+ if self.sub_heartflow:
+ try:
+ # 转换为normal_chat状态
+ await self.sub_heartflow.change_chat_state(ChatState.NORMAL_CHAT)
+ status_message = "已成功切换到普通聊天模式"
+ logger.info(f"{self.log_prefix} {status_message}")
+ except Exception as e:
+ error_msg = f"切换到普通聊天模式失败: {str(e)}"
+ logger.error(f"{self.log_prefix} {error_msg}")
+ return False, error_msg
+ else:
+ warning_msg = "未找到有效的sub heartflow实例,无法切换状态"
+ logger.warning(f"{self.log_prefix} {warning_msg}")
+ return False, warning_msg
+
+
+ return True, status_message
+
+ except asyncio.CancelledError:
+ logger.info(f"{self.log_prefix} 处理 'exit_focus_chat' 时等待被中断 (CancelledError)")
+ raise
+ except Exception as e:
+ error_msg = f"处理 'exit_focus_chat' 时发生错误: {str(e)}"
+ logger.error(f"{self.log_prefix} {error_msg}")
+ logger.error(traceback.format_exc())
+ return False, error_msg
\ No newline at end of file
diff --git a/src/chat/focus_chat/planners/actions/no_reply_action.py b/src/chat/focus_chat/planners/actions/no_reply_action.py
index a29812c7a..6e31d5abb 100644
--- a/src/chat/focus_chat/planners/actions/no_reply_action.py
+++ b/src/chat/focus_chat/planners/actions/no_reply_action.py
@@ -6,14 +6,12 @@ from src.chat.focus_chat.planners.actions.base_action import BaseAction, registe
from typing import Tuple, List, Callable, Coroutine
from src.chat.heart_flow.observation.observation import Observation
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
-from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
logger = get_logger("action_taken")
# 常量定义
WAITING_TIME_THRESHOLD = 300 # 等待新消息时间阈值,单位秒
-CONSECUTIVE_NO_REPLY_THRESHOLD = 3 # 连续不回复的阈值
@register_action
@@ -29,7 +27,7 @@ class NoReplyAction(BaseAction):
action_require = [
"话题无关/无聊/不感兴趣/不懂",
"最后一条消息是你自己发的且无人回应你",
- "你发送了太多消息,且无人回复"
+ "你发送了太多消息,且无人回复",
]
default = True
@@ -40,13 +38,9 @@ class NoReplyAction(BaseAction):
cycle_timers: dict,
thinking_id: str,
observations: List[Observation],
- on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]],
- current_cycle: CycleDetail,
log_prefix: str,
- total_no_reply_count: int = 0,
- total_waiting_time: float = 0.0,
shutting_down: bool = False,
- **kwargs
+ **kwargs,
):
"""初始化不回复动作处理器
@@ -57,20 +51,12 @@ class NoReplyAction(BaseAction):
cycle_timers: 计时器字典
thinking_id: 思考ID
observations: 观察列表
- on_consecutive_no_reply_callback: 连续不回复达到阈值时调用的回调函数
- current_cycle: 当前循环信息
log_prefix: 日志前缀
- total_no_reply_count: 连续不回复计数
- total_waiting_time: 累计等待时间
shutting_down: 是否正在关闭
"""
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
self.observations = observations
- self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback
- self._current_cycle = current_cycle
self.log_prefix = log_prefix
- self.total_no_reply_count = total_no_reply_count
- self.total_waiting_time = total_waiting_time
self._shutting_down = shutting_down
async def handle_action(self) -> Tuple[bool, str]:
@@ -93,37 +79,6 @@ class NoReplyAction(BaseAction):
with Timer("等待新消息", self.cycle_timers):
# 等待新消息、超时或关闭信号,并获取结果
await self._wait_for_new_message(observation, self.thinking_id, self.log_prefix)
- # 从计时器获取实际等待时间
- current_waiting = self.cycle_timers.get("等待新消息", 0.0)
-
- if not self._shutting_down:
- self.total_no_reply_count += 1
- self.total_waiting_time += current_waiting # 累加等待时间
- logger.debug(
- f"{self.log_prefix} 连续不回复计数增加: {self.total_no_reply_count}/{CONSECUTIVE_NO_REPLY_THRESHOLD}, "
- f"本次等待: {current_waiting:.2f}秒, 累计等待: {self.total_waiting_time:.2f}秒"
- )
-
- # 检查是否同时达到次数和时间阈值
- time_threshold = 0.66 * WAITING_TIME_THRESHOLD * CONSECUTIVE_NO_REPLY_THRESHOLD
- if (
- self.total_no_reply_count >= CONSECUTIVE_NO_REPLY_THRESHOLD
- and self.total_waiting_time >= time_threshold
- ):
- logger.info(
- f"{self.log_prefix} 连续不回复达到阈值 ({self.total_no_reply_count}次) "
- f"且累计等待时间达到 {self.total_waiting_time:.2f}秒 (阈值 {time_threshold}秒),"
- f"调用回调请求状态转换"
- )
- # 调用回调。注意:这里不重置计数器和时间,依赖回调函数成功改变状态来隐式重置上下文。
- await self.on_consecutive_no_reply_callback()
- elif self.total_no_reply_count >= CONSECUTIVE_NO_REPLY_THRESHOLD:
- # 仅次数达到阈值,但时间未达到
- logger.debug(
- f"{self.log_prefix} 连续不回复次数达到阈值 ({self.total_no_reply_count}次) "
- f"但累计等待时间 {self.total_waiting_time:.2f}秒 未达到时间阈值 ({time_threshold}秒),暂不调用回调"
- )
- # else: 次数和时间都未达到阈值,不做处理
return True, "" # 不回复动作没有回复文本
diff --git a/src/chat/focus_chat/planners/actions/plugin_action.py b/src/chat/focus_chat/planners/actions/plugin_action.py
new file mode 100644
index 000000000..94754d021
--- /dev/null
+++ b/src/chat/focus_chat/planners/actions/plugin_action.py
@@ -0,0 +1,203 @@
+import traceback
+from typing import Tuple, Dict, List, Any, Optional
+from src.chat.focus_chat.planners.actions.base_action import BaseAction
+from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
+from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
+from src.common.logger_manager import get_logger
+from src.chat.person_info.person_info import person_info_manager
+from abc import abstractmethod
+
+logger = get_logger("plugin_action")
+
+
+class PluginAction(BaseAction):
+ """插件动作基类
+
+ 封装了主程序内部依赖,提供简化的API接口给插件开发者
+ """
+
+ def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str, **kwargs):
+ """初始化插件动作基类"""
+ super().__init__(action_data, reasoning, cycle_timers, thinking_id)
+
+ # 存储内部服务和对象引用
+ self._services = {}
+
+ # 从kwargs提取必要的内部服务
+ if "observations" in kwargs:
+ self._services["observations"] = kwargs["observations"]
+ if "expressor" in kwargs:
+ self._services["expressor"] = kwargs["expressor"]
+ if "chat_stream" in kwargs:
+ self._services["chat_stream"] = kwargs["chat_stream"]
+
+ self.log_prefix = kwargs.get("log_prefix", "")
+
+ async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]:
+ """根据用户名获取用户ID"""
+ person_id = person_info_manager.get_person_id_by_person_name(person_name)
+ user_id = await person_info_manager.get_value(person_id, "user_id")
+ platform = await person_info_manager.get_value(person_id, "platform")
+ return platform, user_id
+
+ # 提供简化的API方法
+ async def send_message(self, text: str, target: Optional[str] = None) -> bool:
+ """发送消息的简化方法
+
+ Args:
+ text: 要发送的消息文本
+ target: 目标消息(可选)
+
+ Returns:
+ bool: 是否发送成功
+ """
+ try:
+ expressor = self._services.get("expressor")
+ chat_stream = self._services.get("chat_stream")
+
+ if not expressor or not chat_stream:
+ logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
+ return False
+
+ # 构造简化的动作数据
+ reply_data = {"text": text, "target": target or "", "emojis": []}
+
+ # 获取锚定消息(如果有)
+ observations = self._services.get("observations", [])
+
+ chatting_observation: ChattingObservation = next(
+ obs for obs in observations if isinstance(obs, ChattingObservation)
+ )
+ anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
+
+ # 如果没有找到锚点消息,创建一个占位符
+ if not anchor_message:
+ logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
+ anchor_message = await create_empty_anchor_message(
+ chat_stream.platform, chat_stream.group_info, chat_stream
+ )
+ else:
+ anchor_message.update_chat_stream(chat_stream)
+
+ response_set = [
+ ("text", text),
+ ]
+
+ # 调用内部方法发送消息
+ success = await expressor.send_response_messages(
+ anchor_message=anchor_message,
+ response_set=response_set,
+ )
+
+ return success
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 发送消息时出错: {e}")
+ traceback.print_exc()
+ return False
+
+ async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool:
+ """发送消息的简化方法
+
+ Args:
+ text: 要发送的消息文本
+ target: 目标消息(可选)
+
+ Returns:
+ bool: 是否发送成功
+ """
+ try:
+ expressor = self._services.get("expressor")
+ chat_stream = self._services.get("chat_stream")
+
+ if not expressor or not chat_stream:
+ logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
+ return False
+
+ # 构造简化的动作数据
+ reply_data = {"text": text, "target": target or "", "emojis": []}
+
+ # 获取锚定消息(如果有)
+ observations = self._services.get("observations", [])
+
+ chatting_observation: ChattingObservation = next(
+ obs for obs in observations if isinstance(obs, ChattingObservation)
+ )
+ anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
+
+ # 如果没有找到锚点消息,创建一个占位符
+ if not anchor_message:
+ logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
+ anchor_message = await create_empty_anchor_message(
+ chat_stream.platform, chat_stream.group_info, chat_stream
+ )
+ else:
+ anchor_message.update_chat_stream(chat_stream)
+
+ # 调用内部方法发送消息
+ success, _ = await expressor.deal_reply(
+ cycle_timers=self.cycle_timers,
+ action_data=reply_data,
+ anchor_message=anchor_message,
+ reasoning=self.reasoning,
+ thinking_id=self.thinking_id,
+ )
+
+ return success
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 发送消息时出错: {e}")
+ return False
+
+ def get_chat_type(self) -> str:
+ """获取当前聊天类型
+
+ Returns:
+ str: 聊天类型 ("group" 或 "private")
+ """
+ chat_stream = self._services.get("chat_stream")
+ if chat_stream and hasattr(chat_stream, "group_info"):
+ return "group" if chat_stream.group_info else "private"
+ return "unknown"
+
+ def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]:
+ """获取最近的消息
+
+ Args:
+ count: 要获取的消息数量
+
+ Returns:
+ List[Dict]: 消息列表,每个消息包含发送者、内容等信息
+ """
+ messages = []
+ observations = self._services.get("observations", [])
+
+ if observations and len(observations) > 0:
+ obs = observations[0]
+ if hasattr(obs, "get_talking_message"):
+ raw_messages = obs.get_talking_message()
+ # 转换为简化格式
+ for msg in raw_messages[-count:]:
+ simple_msg = {
+ "sender": msg.get("sender", "未知"),
+ "content": msg.get("content", ""),
+ "timestamp": msg.get("timestamp", 0),
+ }
+ messages.append(simple_msg)
+
+ return messages
+
+ @abstractmethod
+ async def process(self) -> Tuple[bool, str]:
+ """插件处理逻辑,子类必须实现此方法
+
+ Returns:
+ Tuple[bool, str]: (是否执行成功, 回复文本)
+ """
+ pass
+
+ async def handle_action(self) -> Tuple[bool, str]:
+ """实现BaseAction的抽象方法,调用子类的process方法
+
+ Returns:
+ Tuple[bool, str]: (是否执行成功, 回复文本)
+ """
+ return await self.process()
diff --git a/src/chat/focus_chat/planners/actions/reply_action.py b/src/chat/focus_chat/planners/actions/reply_action.py
index 7b2e88fa0..45a4340d5 100644
--- a/src/chat/focus_chat/planners/actions/reply_action.py
+++ b/src/chat/focus_chat/planners/actions/reply_action.py
@@ -1,14 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-
from src.common.logger_manager import get_logger
-from src.chat.utils.timer_calculator import Timer
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
-from typing import Tuple, List, Optional
+from typing import Tuple, List
from src.chat.heart_flow.observation.observation import Observation
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
from src.chat.message_receive.chat_stream import ChatStream
-from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
@@ -22,29 +19,27 @@ class ReplyAction(BaseAction):
处理构建和发送消息回复的动作。
"""
- action_name:str = "reply"
- action_description:str = "表达想法,可以只包含文本、表情或两者都有"
- action_parameters:dict[str:str] = {
+ action_name: str = "reply"
+ action_description: str = "表达想法,可以只包含文本、表情或两者都有"
+ action_parameters: dict[str:str] = {
"text": "你想要表达的内容(可选)",
- "emojis": "描述当前使用表情包的场景(可选)",
+ "emojis": "描述当前使用表情包的场景,一段话描述(可选)",
"target": "你想要回复的原始文本内容(非必须,仅文本,不包含发送者)(可选)",
}
- action_require:list[str] = [
+ action_require: list[str] = [
"有实质性内容需要表达",
"有人提到你,但你还没有回应他",
- "在合适的时候添加表情(不要总是添加)",
- "如果你要回复特定某人的某句话,或者你想回复较早的消息,请在target中指定那句话的原始文本",
- "除非有明确的回复目标,如果选择了target,不用特别提到某个人的人名",
+ "在合适的时候添加表情(不要总是添加),表情描述要详细,描述当前场景,一段话描述",
+ "如果你有明确的,要回复特定某人的某句话,或者你想回复较早的消息,请在target中指定那句话的原始文本",
"一次只回复一个人,一次只回复一个话题,突出重点",
"如果是自己发的消息想继续,需自然衔接",
"避免重复或评价自己的发言,不要和自己聊天",
- "注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。"
+ "注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。不要有额外的符号,尽量简单简短",
]
default = True
def __init__(
self,
- action_name: str,
action_data: dict,
reasoning: str,
cycle_timers: dict,
@@ -52,9 +47,8 @@ class ReplyAction(BaseAction):
observations: List[Observation],
expressor: DefaultExpressor,
chat_stream: ChatStream,
- current_cycle: CycleDetail,
log_prefix: str,
- **kwargs
+ **kwargs,
):
"""初始化回复动作处理器
@@ -67,14 +61,12 @@ class ReplyAction(BaseAction):
observations: 观察列表
expressor: 表达器
chat_stream: 聊天流
- current_cycle: 当前循环信息
log_prefix: 日志前缀
"""
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
self.observations = observations
self.expressor = expressor
self.chat_stream = chat_stream
- self._current_cycle = current_cycle
self.log_prefix = log_prefix
async def handle_action(self) -> Tuple[bool, str]:
@@ -89,9 +81,9 @@ class ReplyAction(BaseAction):
reasoning=self.reasoning,
reply_data=self.action_data,
cycle_timers=self.cycle_timers,
- thinking_id=self.thinking_id
+ thinking_id=self.thinking_id,
)
-
+
async def _handle_reply(
self, reasoning: str, reply_data: dict, cycle_timers: dict, thinking_id: str
) -> tuple[bool, str]:
@@ -105,13 +97,15 @@ class ReplyAction(BaseAction):
"emojis": "微笑" # 表情关键词列表(可选)
}
"""
- # 重置连续不回复计数器
- self.total_no_reply_count = 0
- self.total_waiting_time = 0.0
# 从聊天观察获取锚定消息
- observations: ChattingObservation = self.observations[0]
- anchor_message = observations.serch_message_by_text(reply_data["target"])
+ chatting_observation: ChattingObservation = next(
+ obs for obs in self.observations if isinstance(obs, ChattingObservation)
+ )
+ if reply_data.get("target"):
+ anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
+ else:
+ anchor_message = None
# 如果没有找到锚点消息,创建一个占位符
if not anchor_message:
diff --git a/src/chat/focus_chat/planners/planner.py b/src/chat/focus_chat/planners/planner.py
index bb87e1da7..ca35d3096 100644
--- a/src/chat/focus_chat/planners/planner.py
+++ b/src/chat/focus_chat/planners/planner.py
@@ -4,25 +4,30 @@ from typing import List, Dict, Any, Optional
from rich.traceback import install
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
-from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder
from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.info.obs_info import ObsInfo
from src.chat.focus_chat.info.cycle_info import CycleInfo
from src.chat.focus_chat.info.mind_info import MindInfo
+from src.chat.focus_chat.info.action_info import ActionInfo
from src.chat.focus_chat.info.structured_info import StructuredInfo
from src.common.logger_manager import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.individuality.individuality import Individuality
-from src.chat.focus_chat.planners.action_factory import ActionManager
-from src.chat.focus_chat.planners.action_factory import ActionInfo
+from src.chat.focus_chat.planners.action_manager import ActionManager
+
logger = get_logger("planner")
install(extra_lines=3)
+
def init_prompt():
Prompt(
- """你的名字是{bot_name},{prompt_personality},{chat_context_description}。需要基于以下信息决定如何参与对话:
+ """{extra_info_block}
+
+你需要基于以下信息决定如何参与对话
+这些信息可能会有冲突,请你整合这些信息,并选择一个最合适的action:
{chat_content_block}
+
{mind_info_block}
{cycle_info_block}
@@ -44,31 +49,31 @@ def init_prompt():
}}
请输出你的决策 JSON:""",
-"planner_prompt",)
-
+ "planner_prompt",
+ )
+
Prompt(
"""
action_name: {action_name}
描述:{action_description}
参数:
- {action_parameters}
+{action_parameters}
动作要求:
- {action_require}
- """,
+{action_require}""",
"action_prompt",
)
-
+
class ActionPlanner:
def __init__(self, log_prefix: str, action_manager: ActionManager):
self.log_prefix = log_prefix
# LLM规划器配置
self.planner_llm = LLMRequest(
- model=global_config.llm_plan,
+ model=global_config.model.plan,
max_tokens=1000,
request_type="action_planning", # 用于动作规划
)
-
+
self.action_manager = action_manager
async def plan(self, all_plan_info: List[InfoBase], cycle_timers: dict) -> Dict[str, Any]:
@@ -82,31 +87,69 @@ class ActionPlanner:
action = "no_reply" # 默认动作
reasoning = "规划器初始化默认"
+ action_data = {}
try:
# 获取观察信息
+ extra_info: list[str] = []
+
+ # 首先处理动作变更
+ for info in all_plan_info:
+ if isinstance(info, ActionInfo) and info.has_changes():
+ add_actions = info.get_add_actions()
+ remove_actions = info.get_remove_actions()
+ reason = info.get_reason()
+
+ # 处理动作的增加
+ for action_name in add_actions:
+ if action_name in self.action_manager.get_registered_actions():
+ self.action_manager.add_action_to_using(action_name)
+ logger.debug(f"{self.log_prefix}添加动作: {action_name}, 原因: {reason}")
+
+ # 处理动作的移除
+ for action_name in remove_actions:
+ self.action_manager.remove_action_from_using(action_name)
+ logger.debug(f"{self.log_prefix}移除动作: {action_name}, 原因: {reason}")
+
+ # 如果当前选择的动作被移除了,更新为no_reply
+ if action in remove_actions:
+ action = "no_reply"
+ reasoning = f"之前选择的动作{action}已被移除,原因: {reason}"
+
+ # 继续处理其他信息
for info in all_plan_info:
if isinstance(info, ObsInfo):
- logger.debug(f"{self.log_prefix} 观察信息: {info}")
observed_messages = info.get_talking_message()
observed_messages_str = info.get_talking_message_str_truncate()
chat_type = info.get_chat_type()
- if chat_type == "group":
- is_group_chat = True
- else:
- is_group_chat = False
+ is_group_chat = (chat_type == "group")
elif isinstance(info, MindInfo):
- logger.debug(f"{self.log_prefix} 思维信息: {info}")
current_mind = info.get_current_mind()
elif isinstance(info, CycleInfo):
- logger.debug(f"{self.log_prefix} 循环信息: {info}")
cycle_info = info.get_observe_info()
elif isinstance(info, StructuredInfo):
- logger.debug(f"{self.log_prefix} 结构化信息: {info}")
- structured_info = info.get_data()
+ _structured_info = info.get_data()
+ elif not isinstance(info, ActionInfo): # 跳过已处理的ActionInfo
+ extra_info.append(info.get_processed_info())
+ # 获取当前可用的动作
current_available_actions = self.action_manager.get_using_actions()
+ # 如果没有可用动作,直接返回no_reply
+ if not current_available_actions:
+ logger.warning(f"{self.log_prefix}没有可用的动作,将使用no_reply")
+ action = "no_reply"
+ reasoning = "没有可用的动作"
+ return {
+ "action_result": {
+ "action_type": action,
+ "action_data": action_data,
+ "reasoning": reasoning
+ },
+ "current_mind": current_mind,
+ "observed_messages": observed_messages
+ }
+
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
prompt = await self.build_planner_prompt(
is_group_chat=is_group_chat, # <-- Pass HFC state
@@ -116,6 +159,7 @@ class ActionPlanner:
# structured_info=structured_info, # <-- Pass SubMind info
current_available_actions=current_available_actions, # <-- Pass determined actions
cycle_info=cycle_info, # <-- Pass cycle info
+ extra_info=extra_info,
)
# --- 调用 LLM (普通文本生成) ---
@@ -142,15 +186,13 @@ class ActionPlanner:
extracted_action = parsed_json.get("action", "no_reply")
extracted_reasoning = parsed_json.get("reasoning", "LLM未提供理由")
- # 新的reply格式
- if extracted_action == "reply":
- action_data = {
- "text": parsed_json.get("text", []),
- "emojis": parsed_json.get("emojis", []),
- "target": parsed_json.get("target", ""),
- }
- else:
- action_data = {} # 其他动作可能不需要额外数据
+ # 将所有其他属性添加到action_data
+ action_data = {}
+ for key, value in parsed_json.items():
+ if key not in ["action", "reasoning"]:
+ action_data[key] = value
+
+ # 对于reply动作不需要额外处理,因为相关字段已经在上面的循环中添加到action_data
if extracted_action not in current_available_actions:
logger.warning(
@@ -173,7 +215,7 @@ class ActionPlanner:
except Exception as outer_e:
logger.error(f"{self.log_prefix}Planner 处理过程中发生意外错误,规划失败,将执行 no_reply: {outer_e}")
traceback.print_exc()
- action = "no_reply" # 发生未知错误,标记为 error 动作
+ action = "no_reply"
reasoning = f"Planner 内部处理错误: {outer_e}"
logger.debug(
@@ -194,10 +236,8 @@ class ActionPlanner:
"observed_messages": observed_messages,
}
- # 返回结果字典
return plan_result
-
async def build_planner_prompt(
self,
is_group_chat: bool, # Now passed as argument
@@ -206,6 +246,7 @@ class ActionPlanner:
current_mind: Optional[str],
current_available_actions: Dict[str, ActionInfo],
cycle_info: Optional[str],
+ extra_info: list[str],
) -> str:
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try:
@@ -218,7 +259,6 @@ class ActionPlanner:
)
chat_context_description = f"你正在和 {chat_target_name} 私聊"
-
chat_content_block = ""
if observed_messages_str:
chat_content_block = f"聊天记录:\n{observed_messages_str}"
@@ -234,7 +274,6 @@ class ActionPlanner:
individuality = Individuality.get_instance()
personality_block = individuality.get_prompt(x_person=2, level=2)
-
action_options_block = ""
for using_actions_name, using_actions_info in current_available_actions.items():
# print(using_actions_name)
@@ -242,38 +281,39 @@ class ActionPlanner:
# print(using_actions_info["parameters"])
# print(using_actions_info["require"])
# print(using_actions_info["description"])
-
+
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
-
+
param_text = ""
for param_name, param_description in using_actions_info["parameters"].items():
- param_text += f"{param_name}: {param_description}\n"
-
+ param_text += f" {param_name}: {param_description}\n"
+
require_text = ""
for require_item in using_actions_info["require"]:
- require_text += f"- {require_item}\n"
-
+ require_text += f" - {require_item}\n"
+
using_action_prompt = using_action_prompt.format(
action_name=using_actions_name,
action_description=using_actions_info["description"],
action_parameters=param_text,
action_require=require_text,
)
-
+
action_options_block += using_action_prompt
-
+ extra_info_block = "\n".join(extra_info)
+ extra_info_block = f"以下是一些额外的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是一些额外的信息,现在请你阅读以下内容,进行决策"
-
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format(
- bot_name=global_config.BOT_NICKNAME,
+ bot_name=global_config.bot.nickname,
prompt_personality=personality_block,
chat_context_description=chat_context_description,
chat_content_block=chat_content_block,
mind_info_block=mind_info_block,
cycle_info_block=cycle_info,
action_options_text=action_options_block,
+ extra_info_block=extra_info_block,
)
return prompt
diff --git a/src/chat/focus_chat/working_memory/memory_item.py b/src/chat/focus_chat/working_memory/memory_item.py
new file mode 100644
index 000000000..15724a387
--- /dev/null
+++ b/src/chat/focus_chat/working_memory/memory_item.py
@@ -0,0 +1,112 @@
+from typing import Dict, Any, List, Optional, Set, Tuple
+import time
+import random
+import string
+
+
+class MemoryItem:
+ """记忆项类,用于存储单个记忆的所有相关信息"""
+
+ def __init__(self, data: Any, from_source: str = "", tags: Optional[List[str]] = None):
+ """
+ 初始化记忆项
+
+ Args:
+ data: 记忆数据
+ from_source: 数据来源
+ tags: 数据标签列表
+ """
+ # 生成可读ID:时间戳_随机字符串
+ timestamp = int(time.time())
+ random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
+ self.id = f"{timestamp}_{random_str}"
+ self.data = data
+ self.data_type = type(data)
+ self.from_source = from_source
+ self.tags = set(tags) if tags else set()
+ self.timestamp = time.time()
+ # 修改summary的结构说明,用于存储可能的总结信息
+ # summary结构:{
+ # "brief": "记忆内容主题",
+ # "detailed": "记忆内容概括",
+ # "keypoints": ["关键概念1", "关键概念2"],
+ # "events": ["事件1", "事件2"]
+ # }
+ self.summary = None
+
+ # 记忆精简次数
+ self.compress_count = 0
+
+ # 记忆提取次数
+ self.retrieval_count = 0
+
+ # 记忆强度 (初始为10)
+ self.memory_strength = 10.0
+
+ # 记忆操作历史记录
+ # 格式: [(操作类型, 时间戳, 当时精简次数, 当时强度), ...]
+ self.history = [("create", self.timestamp, self.compress_count, self.memory_strength)]
+
+ def add_tag(self, tag: str) -> None:
+ """添加标签"""
+ self.tags.add(tag)
+
+ def remove_tag(self, tag: str) -> None:
+ """移除标签"""
+ if tag in self.tags:
+ self.tags.remove(tag)
+
+ def has_tag(self, tag: str) -> bool:
+ """检查是否有特定标签"""
+ return tag in self.tags
+
+ def has_all_tags(self, tags: List[str]) -> bool:
+ """检查是否有所有指定的标签"""
+ return all(tag in self.tags for tag in tags)
+
+ def matches_source(self, source: str) -> bool:
+ """检查来源是否匹配"""
+ return self.from_source == source
+
+ def set_summary(self, summary: Dict[str, Any]) -> None:
+ """设置总结信息"""
+ self.summary = summary
+
+ def increase_strength(self, amount: float) -> None:
+ """增加记忆强度"""
+ self.memory_strength = min(10.0, self.memory_strength + amount)
+ # 记录操作历史
+ self.record_operation("strengthen")
+
+ def decrease_strength(self, amount: float) -> None:
+ """减少记忆强度"""
+ self.memory_strength = max(0.1, self.memory_strength - amount)
+ # 记录操作历史
+ self.record_operation("weaken")
+
+ def increase_compress_count(self) -> None:
+ """增加精简次数并减弱记忆强度"""
+ self.compress_count += 1
+ # 记录操作历史
+ self.record_operation("compress")
+
+ def record_retrieval(self) -> None:
+ """记录记忆被提取的情况"""
+ self.retrieval_count += 1
+ # 提取后强度翻倍
+ self.memory_strength = min(10.0, self.memory_strength * 2)
+ # 记录操作历史
+ self.record_operation("retrieval")
+
+ def record_operation(self, operation_type: str) -> None:
+ """记录操作历史"""
+ current_time = time.time()
+ self.history.append((operation_type, current_time, self.compress_count, self.memory_strength))
+
+ def to_tuple(self) -> Tuple[Any, str, Set[str], float, str]:
+ """转换为元组格式(为了兼容性)"""
+ return (self.data, self.from_source, self.tags, self.timestamp, self.id)
+
+ def is_memory_valid(self) -> bool:
+ """检查记忆是否有效(强度是否大于等于1)"""
+ return self.memory_strength >= 1.0
diff --git a/src/chat/focus_chat/working_memory/memory_manager.py b/src/chat/focus_chat/working_memory/memory_manager.py
new file mode 100644
index 000000000..7fda40239
--- /dev/null
+++ b/src/chat/focus_chat/working_memory/memory_manager.py
@@ -0,0 +1,781 @@
+from typing import Dict, Any, Type, TypeVar, List, Optional
+import traceback
+from json_repair import repair_json
+from rich.traceback import install
+from src.common.logger_manager import get_logger
+from src.chat.models.utils_model import LLMRequest
+from src.config.config import global_config
+from src.chat.focus_chat.working_memory.memory_item import MemoryItem
+import json # 添加json模块导入
+
+
+install(extra_lines=3)
+logger = get_logger("working_memory")
+
+T = TypeVar("T")
+
+
+class MemoryManager:
+ def __init__(self, chat_id: str):
+ """
+ 初始化工作记忆
+
+ Args:
+ chat_id: 关联的聊天ID,用于标识该工作记忆属于哪个聊天
+ """
+ # 关联的聊天ID
+ self._chat_id = chat_id
+
+ # 主存储: 数据类型 -> 记忆项列表
+ self._memory: Dict[Type, List[MemoryItem]] = {}
+
+ # ID到记忆项的映射
+ self._id_map: Dict[str, MemoryItem] = {}
+
+ self.llm_summarizer = LLMRequest(
+ model=global_config.model.summary, temperature=0.3, max_tokens=512, request_type="memory_summarization"
+ )
+
+ @property
+ def chat_id(self) -> str:
+ """获取关联的聊天ID"""
+ return self._chat_id
+
+ @chat_id.setter
+ def chat_id(self, value: str):
+ """设置关联的聊天ID"""
+ self._chat_id = value
+
+ def push_item(self, memory_item: MemoryItem) -> str:
+ """
+ 推送一个已创建的记忆项到工作记忆中
+
+ Args:
+ memory_item: 要存储的记忆项
+
+ Returns:
+ 记忆项的ID
+ """
+ data_type = memory_item.data_type
+
+ # 确保存在该类型的存储列表
+ if data_type not in self._memory:
+ self._memory[data_type] = []
+
+ # 添加到内存和ID映射
+ self._memory[data_type].append(memory_item)
+ self._id_map[memory_item.id] = memory_item
+
+ return memory_item.id
+
+ async def push_with_summary(self, data: T, from_source: str = "", tags: Optional[List[str]] = None) -> MemoryItem:
+ """
+ 推送一段有类型的信息到工作记忆中,并自动生成总结
+
+ Args:
+ data: 要存储的数据
+ from_source: 数据来源
+ tags: 数据标签列表
+
+ Returns:
+ 包含原始数据和总结信息的字典
+ """
+ # 如果数据是字符串类型,则先进行总结
+ if isinstance(data, str):
+ # 先生成总结
+ summary = await self.summarize_memory_item(data)
+
+ # 准备标签
+ memory_tags = list(tags) if tags else []
+
+ # 创建记忆项
+ memory_item = MemoryItem(data, from_source, memory_tags)
+
+ # 将总结信息保存到记忆项中
+ memory_item.set_summary(summary)
+
+ # 推送记忆项
+ self.push_item(memory_item)
+
+ return memory_item
+ else:
+ # 非字符串类型,直接创建并推送记忆项
+ memory_item = MemoryItem(data, from_source, tags)
+ self.push_item(memory_item)
+
+ return memory_item
+
+ def get_by_id(self, memory_id: str) -> Optional[MemoryItem]:
+ """
+ 通过ID获取记忆项
+
+ Args:
+ memory_id: 记忆项ID
+
+ Returns:
+ 找到的记忆项,如果不存在则返回None
+ """
+ memory_item = self._id_map.get(memory_id)
+ if memory_item:
+ # 检查记忆强度,如果小于1则删除
+ if not memory_item.is_memory_valid():
+ print(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除")
+ self.delete(memory_id)
+ return None
+
+ return memory_item
+
+ def get_all_items(self) -> List[MemoryItem]:
+ """获取所有记忆项"""
+ return list(self._id_map.values())
+
+ def find_items(
+ self,
+ data_type: Optional[Type] = None,
+ source: Optional[str] = None,
+ tags: Optional[List[str]] = None,
+ start_time: Optional[float] = None,
+ end_time: Optional[float] = None,
+ memory_id: Optional[str] = None,
+ limit: Optional[int] = None,
+ newest_first: bool = False,
+ min_strength: float = 0.0,
+ ) -> List[MemoryItem]:
+ """
+ 按条件查找记忆项
+
+ Args:
+ data_type: 要查找的数据类型
+ source: 数据来源
+ tags: 必须包含的标签列表
+ start_time: 开始时间戳
+ end_time: 结束时间戳
+ memory_id: 特定记忆项ID
+ limit: 返回结果的最大数量
+ newest_first: 是否按最新优先排序
+ min_strength: 最小记忆强度
+
+ Returns:
+ 符合条件的记忆项列表
+ """
+ # 如果提供了特定ID,直接查找
+ if memory_id:
+ item = self.get_by_id(memory_id)
+ return [item] if item else []
+
+ results = []
+
+ # 确定要搜索的类型列表
+ types_to_search = [data_type] if data_type else list(self._memory.keys())
+
+ # 对每个类型进行搜索
+ for typ in types_to_search:
+ if typ not in self._memory:
+ continue
+
+ # 获取该类型的所有项目
+ items = self._memory[typ]
+
+ # 如果需要最新优先,则反转遍历顺序
+ if newest_first:
+ items_to_check = list(reversed(items))
+ else:
+ items_to_check = items
+
+ # 遍历项目
+ for item in items_to_check:
+ # 检查来源是否匹配
+ if source is not None and not item.matches_source(source):
+ continue
+
+ # 检查标签是否匹配
+ if tags is not None and not item.has_all_tags(tags):
+ continue
+
+ # 检查时间范围
+ if start_time is not None and item.timestamp < start_time:
+ continue
+ if end_time is not None and item.timestamp > end_time:
+ continue
+
+ # 检查记忆强度
+ if min_strength > 0 and item.memory_strength < min_strength:
+ continue
+
+ # 所有条件都满足,添加到结果中
+ results.append(item)
+
+ # 如果达到限制数量,提前返回
+ if limit is not None and len(results) >= limit:
+ return results
+
+ return results
+
+ async def summarize_memory_item(self, content: str) -> Dict[str, Any]:
+ """
+ 使用LLM总结记忆项
+
+ Args:
+ content: 需要总结的内容
+
+ Returns:
+ 包含总结、概括、关键概念和事件的字典
+ """
+ prompt = f"""请对以下内容进行总结,总结成记忆,输出四部分:
+1. 记忆内容主题(精简,20字以内):让用户可以一眼看出记忆内容是什么
+2. 记忆内容概括(200字以内):让用户可以了解记忆内容的大致内容
+3. 关键概念和知识(keypoints):多条,提取关键的概念、知识点和关键词,要包含对概念的解释
+4. 事件描述(events):多条,描述谁(人物)在什么时候(时间)做了什么(事件)
+
+内容:
+{content}
+
+请按以下JSON格式输出:
+```json
+{{
+ "brief": "记忆内容主题(20字以内)",
+ "detailed": "记忆内容概括(200字以内)",
+ "keypoints": [
+ "概念1:解释",
+ "概念2:解释",
+ ...
+ ],
+ "events": [
+ "事件1:谁在什么时候做了什么",
+ "事件2:谁在什么时候做了什么",
+ ...
+ ]
+}}
+```
+请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
+"""
+ default_summary = {
+ "brief": "主题未知的记忆",
+ "detailed": "大致内容未知的记忆",
+ "keypoints": ["未知的概念"],
+ "events": ["未知的事件"],
+ }
+
+ try:
+ # 调用LLM生成总结
+ response, _ = await self.llm_summarizer.generate_response_async(prompt)
+
+ # 使用repair_json解析响应
+ try:
+ # 使用repair_json修复JSON格式
+ fixed_json_string = repair_json(response)
+
+ # 如果repair_json返回的是字符串,需要解析为Python对象
+ if isinstance(fixed_json_string, str):
+ try:
+ json_result = json.loads(fixed_json_string)
+ except json.JSONDecodeError as decode_error:
+ logger.error(f"JSON解析错误: {str(decode_error)}")
+ return default_summary
+ else:
+ # 如果repair_json直接返回了字典对象,直接使用
+ json_result = fixed_json_string
+
+ # 进行额外的类型检查
+ if not isinstance(json_result, dict):
+ logger.error(f"修复后的JSON不是字典类型: {type(json_result)}")
+ return default_summary
+
+ # 确保所有必要字段都存在且类型正确
+ if "brief" not in json_result or not isinstance(json_result["brief"], str):
+ json_result["brief"] = "主题未知的记忆"
+
+ if "detailed" not in json_result or not isinstance(json_result["detailed"], str):
+ json_result["detailed"] = "大致内容未知的记忆"
+
+ # 处理关键概念
+ if "keypoints" not in json_result or not isinstance(json_result["keypoints"], list):
+ json_result["keypoints"] = ["未知的概念"]
+ else:
+ # 确保keypoints中的每个项目都是字符串
+ json_result["keypoints"] = [str(point) for point in json_result["keypoints"] if point is not None]
+ if not json_result["keypoints"]:
+ json_result["keypoints"] = ["未知的概念"]
+
+ # 处理事件
+ if "events" not in json_result or not isinstance(json_result["events"], list):
+ json_result["events"] = ["未知的事件"]
+ else:
+ # 确保events中的每个项目都是字符串
+ json_result["events"] = [str(event) for event in json_result["events"] if event is not None]
+ if not json_result["events"]:
+ json_result["events"] = ["未知的事件"]
+
+ # 兼容旧版,将keypoints和events合并到key_points中
+ json_result["key_points"] = json_result["keypoints"] + json_result["events"]
+
+ return json_result
+
+ except Exception as json_error:
+ logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要")
+ # 返回默认结构
+ return default_summary
+
+ except Exception as e:
+ # 出错时返回简单的结构
+ logger.error(f"生成总结时出错: {str(e)}")
+ return default_summary
+
+ async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]:
+ """
+ 对记忆进行精简操作,根据要求修改要点、总结和概括
+
+ Args:
+ memory_id: 记忆ID
+ requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点
+
+ Returns:
+ 修改后的记忆总结字典
+ """
+ # 获取指定ID的记忆项
+ logger.info(f"精简记忆: {memory_id}")
+ memory_item = self.get_by_id(memory_id)
+ if not memory_item:
+ raise ValueError(f"未找到ID为{memory_id}的记忆项")
+
+ # 增加精简次数
+ memory_item.increase_compress_count()
+
+ summary = memory_item.summary
+
+ # 使用LLM根据要求对总结、概括和要点进行精简修改
+ prompt = f"""
+请根据以下要求,对记忆内容的主题、概括、关键概念和事件进行精简,模拟记忆的遗忘过程:
+要求:{requirements}
+你可以随机对关键概念和事件进行压缩,模糊或者丢弃,修改后,同样修改主题和概括
+
+目前主题:{summary["brief"]}
+
+目前概括:{summary["detailed"]}
+
+目前关键概念:
+{chr(10).join([f"- {point}" for point in summary.get("keypoints", [])])}
+
+目前事件:
+{chr(10).join([f"- {point}" for point in summary.get("events", [])])}
+
+请生成修改后的主题、概括、关键概念和事件,遵循以下格式:
+```json
+{{
+ "brief": "修改后的主题(20字以内)",
+ "detailed": "修改后的概括(200字以内)",
+ "keypoints": [
+ "修改后的概念1:解释",
+ "修改后的概念2:解释"
+ ],
+ "events": [
+ "修改后的事件1:谁在什么时候做了什么",
+ "修改后的事件2:谁在什么时候做了什么"
+ ]
+}}
+```
+请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
+"""
+ # 检查summary中是否有旧版结构,转换为新版结构
+ if "keypoints" not in summary and "events" not in summary and "key_points" in summary:
+ # 尝试区分key_points中的keypoints和events
+ # 简单地将前半部分视为keypoints,后半部分视为events
+ key_points = summary.get("key_points", [])
+ halfway = len(key_points) // 2
+ summary["keypoints"] = key_points[:halfway] or ["未知的概念"]
+ summary["events"] = key_points[halfway:] or ["未知的事件"]
+
+ # 定义默认的精简结果
+ default_refined = {
+ "brief": summary["brief"],
+ "detailed": summary["detailed"],
+ "keypoints": summary.get("keypoints", ["未知的概念"])[:1], # 默认只保留第一个关键概念
+ "events": summary.get("events", ["未知的事件"])[:1], # 默认只保留第一个事件
+ }
+
+ try:
+ # 调用LLM修改总结、概括和要点
+ response, _ = await self.llm_summarizer.generate_response_async(prompt)
+ logger.info(f"精简记忆响应: {response}")
+ # 使用repair_json处理响应
+ try:
+ # 修复JSON格式
+ fixed_json_string = repair_json(response)
+
+ # 将修复后的字符串解析为Python对象
+ if isinstance(fixed_json_string, str):
+ try:
+ refined_data = json.loads(fixed_json_string)
+ except json.JSONDecodeError as decode_error:
+ logger.error(f"JSON解析错误: {str(decode_error)}")
+ refined_data = default_refined
+ else:
+ # 如果repair_json直接返回了字典对象,直接使用
+ refined_data = fixed_json_string
+
+ # 确保是字典类型
+ if not isinstance(refined_data, dict):
+ logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}")
+ refined_data = default_refined
+
+ # 更新总结、概括
+ summary["brief"] = refined_data.get("brief", "主题未知的记忆")
+ summary["detailed"] = refined_data.get("detailed", "大致内容未知的记忆")
+
+ # 更新关键概念
+ keypoints = refined_data.get("keypoints", [])
+ if isinstance(keypoints, list) and keypoints:
+ # 确保所有关键概念都是字符串
+ summary["keypoints"] = [str(point) for point in keypoints if point is not None]
+ else:
+ # 如果keypoints不是列表或为空,使用默认值
+ summary["keypoints"] = ["主要概念已遗忘"]
+
+ # 更新事件
+ events = refined_data.get("events", [])
+ if isinstance(events, list) and events:
+ # 确保所有事件都是字符串
+ summary["events"] = [str(event) for event in events if event is not None]
+ else:
+ # 如果events不是列表或为空,使用默认值
+ summary["events"] = ["事件细节已遗忘"]
+
+ # 兼容旧版,维护key_points
+ summary["key_points"] = summary["keypoints"] + summary["events"]
+
+ except Exception as e:
+ logger.error(f"精简记忆出错: {str(e)}")
+ traceback.print_exc()
+
+ # 出错时使用简化的默认精简
+ summary["brief"] = summary["brief"] + " (已简化)"
+ summary["keypoints"] = summary.get("keypoints", ["未知的概念"])[:1]
+ summary["events"] = summary.get("events", ["未知的事件"])[:1]
+ summary["key_points"] = summary["keypoints"] + summary["events"]
+
+ except Exception as e:
+ logger.error(f"精简记忆调用LLM出错: {str(e)}")
+ traceback.print_exc()
+
+ # 更新原记忆项的总结
+ memory_item.set_summary(summary)
+
+ return memory_item
+
+ def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
+ """
+ 使单个记忆衰减
+
+ Args:
+ memory_id: 记忆ID
+ decay_factor: 衰减因子(0-1之间)
+
+ Returns:
+ 是否成功衰减
+ """
+ memory_item = self.get_by_id(memory_id)
+ if not memory_item:
+ return False
+
+ # 计算衰减量(当前强度 * (1-衰减因子))
+ old_strength = memory_item.memory_strength
+ decay_amount = old_strength * (1 - decay_factor)
+
+ # 更新强度
+ memory_item.memory_strength = decay_amount
+
+ return True
+
+ def delete(self, memory_id: str) -> bool:
+ """
+ 删除指定ID的记忆项
+
+ Args:
+ memory_id: 要删除的记忆项ID
+
+ Returns:
+ 是否成功删除
+ """
+ if memory_id not in self._id_map:
+ return False
+
+ # 获取要删除的项
+ item = self._id_map[memory_id]
+
+ # 从内存中删除
+ data_type = item.data_type
+ if data_type in self._memory:
+ self._memory[data_type] = [i for i in self._memory[data_type] if i.id != memory_id]
+
+ # 从ID映射中删除
+ del self._id_map[memory_id]
+
+ return True
+
+ def clear(self, data_type: Optional[Type] = None) -> None:
+ """
+ 清除记忆中的数据
+
+ Args:
+ data_type: 要清除的数据类型,如果为None则清除所有数据
+ """
+ if data_type is None:
+ # 清除所有数据
+ self._memory.clear()
+ self._id_map.clear()
+ elif data_type in self._memory:
+ # 清除指定类型的数据
+ for item in self._memory[data_type]:
+ if item.id in self._id_map:
+ del self._id_map[item.id]
+ del self._memory[data_type]
+
+ async def merge_memories(
+ self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
+ ) -> MemoryItem:
+ """
+ 合并两个记忆项
+
+ Args:
+ memory_id1: 第一个记忆项ID
+ memory_id2: 第二个记忆项ID
+ reason: 合并原因
+ delete_originals: 是否删除原始记忆,默认为True
+
+ Returns:
+ 包含合并后的记忆信息的字典
+ """
+ # 获取两个记忆项
+ memory_item1 = self.get_by_id(memory_id1)
+ memory_item2 = self.get_by_id(memory_id2)
+
+ if not memory_item1 or not memory_item2:
+ raise ValueError("无法找到指定的记忆项")
+
+ content1 = memory_item1.data
+ content2 = memory_item2.data
+
+ # 获取记忆的摘要信息(如果有)
+ summary1 = memory_item1.summary
+ summary2 = memory_item2.summary
+
+ # 构建合并提示
+ prompt = f"""
+请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。
+合并时保留两段记忆的重要信息,避免重复,确保生成的内容连贯、自然。
+
+合并原因:{reason}
+"""
+
+ # 如果有摘要信息,添加到提示中
+ if summary1:
+ prompt += f"记忆1主题:{summary1['brief']}\n"
+ prompt += f"记忆1概括:{summary1['detailed']}\n"
+
+ if "keypoints" in summary1:
+ prompt += "记忆1关键概念:\n" + "\n".join([f"- {point}" for point in summary1["keypoints"]]) + "\n\n"
+
+ if "events" in summary1:
+ prompt += "记忆1事件:\n" + "\n".join([f"- {point}" for point in summary1["events"]]) + "\n\n"
+ elif "key_points" in summary1:
+ prompt += "记忆1要点:\n" + "\n".join([f"- {point}" for point in summary1["key_points"]]) + "\n\n"
+
+ if summary2:
+ prompt += f"记忆2主题:{summary2['brief']}\n"
+ prompt += f"记忆2概括:{summary2['detailed']}\n"
+
+ if "keypoints" in summary2:
+ prompt += "记忆2关键概念:\n" + "\n".join([f"- {point}" for point in summary2["keypoints"]]) + "\n\n"
+
+ if "events" in summary2:
+ prompt += "记忆2事件:\n" + "\n".join([f"- {point}" for point in summary2["events"]]) + "\n\n"
+ elif "key_points" in summary2:
+ prompt += "记忆2要点:\n" + "\n".join([f"- {point}" for point in summary2["key_points"]]) + "\n\n"
+
+ # 添加记忆原始内容
+ prompt += f"""
+记忆1原始内容:
+{content1}
+
+记忆2原始内容:
+{content2}
+
+请按以下JSON格式输出合并结果:
+```json
+{{
+ "content": "合并后的记忆内容文本(尽可能保留原信息,但去除重复)",
+ "brief": "合并后的主题(20字以内)",
+ "detailed": "合并后的概括(200字以内)",
+ "keypoints": [
+ "合并后的概念1:解释",
+ "合并后的概念2:解释",
+ "合并后的概念3:解释"
+ ],
+ "events": [
+ "合并后的事件1:谁在什么时候做了什么",
+ "合并后的事件2:谁在什么时候做了什么"
+ ]
+}}
+```
+请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
+"""
+
+ # 默认合并结果
+ default_merged = {
+ "content": f"{content1}\n\n{content2}",
+ "brief": f"合并:{summary1['brief']} + {summary2['brief']}",
+ "detailed": f"合并了两个记忆:{summary1['detailed']} 以及 {summary2['detailed']}",
+ "keypoints": [],
+ "events": [],
+ }
+
+ # 合并旧版key_points
+ if "key_points" in summary1:
+ default_merged["keypoints"].extend(summary1.get("keypoints", []))
+ default_merged["events"].extend(summary1.get("events", []))
+ # 如果没有新的结构,尝试从旧结构分离
+ if not default_merged["keypoints"] and not default_merged["events"] and "key_points" in summary1:
+ key_points = summary1["key_points"]
+ halfway = len(key_points) // 2
+ default_merged["keypoints"].extend(key_points[:halfway])
+ default_merged["events"].extend(key_points[halfway:])
+
+ if "key_points" in summary2:
+ default_merged["keypoints"].extend(summary2.get("keypoints", []))
+ default_merged["events"].extend(summary2.get("events", []))
+ # 如果没有新的结构,尝试从旧结构分离
+ if not default_merged["keypoints"] and not default_merged["events"] and "key_points" in summary2:
+ key_points = summary2["key_points"]
+ halfway = len(key_points) // 2
+ default_merged["keypoints"].extend(key_points[:halfway])
+ default_merged["events"].extend(key_points[halfway:])
+
+ # 确保列表不为空
+ if not default_merged["keypoints"]:
+ default_merged["keypoints"] = ["合并的关键概念"]
+ if not default_merged["events"]:
+ default_merged["events"] = ["合并的事件"]
+
+ # 添加key_points兼容
+ default_merged["key_points"] = default_merged["keypoints"] + default_merged["events"]
+
+ try:
+ # 调用LLM合并记忆
+ response, _ = await self.llm_summarizer.generate_response_async(prompt)
+
+ # 处理LLM返回的合并结果
+ try:
+ # 修复JSON格式
+ fixed_json_string = repair_json(response)
+
+ # 将修复后的字符串解析为Python对象
+ if isinstance(fixed_json_string, str):
+ try:
+ merged_data = json.loads(fixed_json_string)
+ except json.JSONDecodeError as decode_error:
+ logger.error(f"JSON解析错误: {str(decode_error)}")
+ merged_data = default_merged
+ else:
+ # 如果repair_json直接返回了字典对象,直接使用
+ merged_data = fixed_json_string
+
+ # 确保是字典类型
+ if not isinstance(merged_data, dict):
+ logger.error(f"修复后的JSON不是字典类型: {type(merged_data)}")
+ merged_data = default_merged
+
+ # 确保所有必要字段都存在且类型正确
+ if "content" not in merged_data or not isinstance(merged_data["content"], str):
+ merged_data["content"] = default_merged["content"]
+
+ if "brief" not in merged_data or not isinstance(merged_data["brief"], str):
+ merged_data["brief"] = default_merged["brief"]
+
+ if "detailed" not in merged_data or not isinstance(merged_data["detailed"], str):
+ merged_data["detailed"] = default_merged["detailed"]
+
+ # 处理关键概念
+ if "keypoints" not in merged_data or not isinstance(merged_data["keypoints"], list):
+ merged_data["keypoints"] = default_merged["keypoints"]
+ else:
+ # 确保keypoints中的每个项目都是字符串
+ merged_data["keypoints"] = [str(point) for point in merged_data["keypoints"] if point is not None]
+ if not merged_data["keypoints"]:
+ merged_data["keypoints"] = ["合并的关键概念"]
+
+ # 处理事件
+ if "events" not in merged_data or not isinstance(merged_data["events"], list):
+ merged_data["events"] = default_merged["events"]
+ else:
+ # 确保events中的每个项目都是字符串
+ merged_data["events"] = [str(event) for event in merged_data["events"] if event is not None]
+ if not merged_data["events"]:
+ merged_data["events"] = ["合并的事件"]
+
+ # 添加key_points兼容
+ merged_data["key_points"] = merged_data["keypoints"] + merged_data["events"]
+
+ except Exception as e:
+ logger.error(f"合并记忆时处理JSON出错: {str(e)}")
+ traceback.print_exc()
+ merged_data = default_merged
+ except Exception as e:
+ logger.error(f"合并记忆调用LLM出错: {str(e)}")
+ traceback.print_exc()
+ merged_data = default_merged
+
+ # 创建新的记忆项
+ # 合并记忆项的标签
+ merged_tags = memory_item1.tags.union(memory_item2.tags)
+
+ # 取两个记忆项中更强的来源
+ merged_source = (
+ memory_item1.from_source
+ if memory_item1.memory_strength >= memory_item2.memory_strength
+ else memory_item2.from_source
+ )
+
+ # 创建新的记忆项
+ merged_memory = MemoryItem(data=merged_data["content"], from_source=merged_source, tags=list(merged_tags))
+
+ # 设置合并后的摘要
+ summary = {
+ "brief": merged_data["brief"],
+ "detailed": merged_data["detailed"],
+ "keypoints": merged_data["keypoints"],
+ "events": merged_data["events"],
+ "key_points": merged_data["key_points"],
+ }
+ merged_memory.set_summary(summary)
+
+ # 记忆强度取两者最大值
+ merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
+
+ # 添加到存储中
+ self.push_item(merged_memory)
+
+ # 如果需要,删除原始记忆
+ if delete_originals:
+ self.delete(memory_id1)
+ self.delete(memory_id2)
+
+ return merged_memory
+
+ def delete_earliest_memory(self) -> bool:
+ """
+ 删除最早的记忆项
+
+ Returns:
+ 是否成功删除
+ """
+ # 获取所有记忆项
+ all_memories = self.get_all_items()
+
+ if not all_memories:
+ return False
+
+ # 按时间戳排序,找到最早的记忆项
+ earliest_memory = min(all_memories, key=lambda item: item.timestamp)
+
+ # 删除最早的记忆项
+ return self.delete(earliest_memory.id)
diff --git a/src/chat/focus_chat/working_memory/working_memory.py b/src/chat/focus_chat/working_memory/working_memory.py
new file mode 100644
index 000000000..db9824150
--- /dev/null
+++ b/src/chat/focus_chat/working_memory/working_memory.py
@@ -0,0 +1,192 @@
+from typing import List, Any, Optional
+import asyncio
+import random
+from src.common.logger_manager import get_logger
+from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
+
+logger = get_logger(__name__)
+
+# 问题是我不知道这个manager是不是需要和其他manager统一管理,因为这个manager是从属于每一个聊天流,都有自己的定时任务
+
+
+class WorkingMemory:
+ """
+ 工作记忆,负责协调和运作记忆
+ 从属于特定的流,用chat_id来标识
+ """
+
+ def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
+ """
+ 初始化工作记忆管理器
+
+ Args:
+ max_memories_per_chat: 每个聊天的最大记忆数量
+ auto_decay_interval: 自动衰减记忆的时间间隔(秒)
+ """
+ self.memory_manager = MemoryManager(chat_id)
+
+ # 记忆容量上限
+ self.max_memories_per_chat = max_memories_per_chat
+
+ # 自动衰减间隔
+ self.auto_decay_interval = auto_decay_interval
+
+ # 衰减任务
+ self.decay_task = None
+
+ # 启动自动衰减任务
+ self._start_auto_decay()
+
+ def _start_auto_decay(self):
+ """启动自动衰减任务"""
+ if self.decay_task is None:
+ self.decay_task = asyncio.create_task(self._auto_decay_loop())
+
+ async def _auto_decay_loop(self):
+ """自动衰减循环"""
+ while True:
+ await asyncio.sleep(self.auto_decay_interval)
+ try:
+ await self.decay_all_memories()
+ except Exception as e:
+ print(f"自动衰减记忆时出错: {str(e)}")
+
+ async def add_memory(self, content: Any, from_source: str = "", tags: Optional[List[str]] = None):
+ """
+ 添加一段记忆到指定聊天
+
+ Args:
+ content: 记忆内容
+ from_source: 数据来源
+ tags: 数据标签列表
+
+ Returns:
+ 包含记忆信息的字典
+ """
+ memory = await self.memory_manager.push_with_summary(content, from_source, tags)
+ if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat:
+ self.remove_earliest_memory()
+
+ return memory
+
+ def remove_earliest_memory(self):
+ """
+ 删除最早的记忆
+ """
+ return self.memory_manager.delete_earliest_memory()
+
+ async def retrieve_memory(self, memory_id: str) -> Optional[MemoryItem]:
+ """
+ 检索记忆
+
+ Args:
+ chat_id: 聊天ID
+ memory_id: 记忆ID
+
+ Returns:
+ 检索到的记忆项,如果不存在则返回None
+ """
+ memory_item = self.memory_manager.get_by_id(memory_id)
+ if memory_item:
+ memory_item.retrieval_count += 1
+ memory_item.increase_strength(5)
+ return memory_item
+ return None
+
+ async def decay_all_memories(self, decay_factor: float = 0.5):
+ """
+ 对所有聊天的所有记忆进行衰减
+ 衰减:对记忆进行refine压缩,强度会变为原先的0.5
+
+ Args:
+ decay_factor: 衰减因子(0-1之间)
+ """
+ logger.debug(f"开始对所有记忆进行衰减,衰减因子: {decay_factor}")
+
+ all_memories = self.memory_manager.get_all_items()
+
+ for memory_item in all_memories:
+ # 如果压缩完小于1会被删除
+ memory_id = memory_item.id
+ self.memory_manager.decay_memory(memory_id, decay_factor)
+ if memory_item.memory_strength < 1:
+ self.memory_manager.delete(memory_id)
+ continue
+ # 计算衰减量
+ if memory_item.memory_strength < 5:
+ await self.memory_manager.refine_memory(
+ memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩"
+ )
+
+ async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem:
+ """合并记忆
+
+ Args:
+ memory_str: 记忆内容
+ """
+ return await self.memory_manager.merge_memories(
+ memory_id1=memory_id1, memory_id2=memory_id2, reason="两端记忆有重复的内容"
+ )
+
+ # 暂时没用,先留着
+ async def simulate_memory_blur(self, chat_id: str, blur_rate: float = 0.2):
+ """
+ 模拟记忆模糊过程,随机选择一部分记忆进行精简
+
+ Args:
+ chat_id: 聊天ID
+ blur_rate: 模糊比率(0-1之间),表示有多少比例的记忆会被精简
+ """
+ memory = self.get_memory(chat_id)
+
+ # 获取所有字符串类型且有总结的记忆
+ all_summarized_memories = []
+ for type_items in memory._memory.values():
+ for item in type_items:
+ if isinstance(item.data, str) and hasattr(item, "summary") and item.summary:
+ all_summarized_memories.append(item)
+
+ if not all_summarized_memories:
+ return
+
+ # 计算要模糊的记忆数量
+ blur_count = max(1, int(len(all_summarized_memories) * blur_rate))
+
+ # 随机选择要模糊的记忆
+ memories_to_blur = random.sample(all_summarized_memories, min(blur_count, len(all_summarized_memories)))
+
+ # 对选中的记忆进行精简
+ for memory_item in memories_to_blur:
+ try:
+ # 根据记忆强度决定模糊程度
+ if memory_item.memory_strength > 7:
+ requirement = "保留所有重要信息,仅略微精简"
+ elif memory_item.memory_strength > 4:
+ requirement = "保留核心要点,适度精简细节"
+ else:
+ requirement = "只保留最关键的1-2个要点,大幅精简内容"
+
+ # 进行精简
+ await memory.refine_memory(memory_item.id, requirement)
+ print(f"已模糊记忆 {memory_item.id},强度: {memory_item.memory_strength}, 要求: {requirement}")
+
+ except Exception as e:
+ print(f"模糊记忆 {memory_item.id} 时出错: {str(e)}")
+
+ async def shutdown(self) -> None:
+ """关闭管理器,停止所有任务"""
+ if self.decay_task and not self.decay_task.done():
+ self.decay_task.cancel()
+ try:
+ await self.decay_task
+ except asyncio.CancelledError:
+ pass
+
+ def get_all_memories(self) -> List[MemoryItem]:
+ """
+ 获取所有记忆项目
+
+ Returns:
+ List[MemoryItem]: 当前工作记忆中的所有记忆项目列表
+ """
+ return self.memory_manager.get_all_items()
diff --git a/src/chat/heart_flow/background_tasks.py b/src/chat/heart_flow/background_tasks.py
index d9fa1c9d3..28b248bdc 100644
--- a/src/chat/heart_flow/background_tasks.py
+++ b/src/chat/heart_flow/background_tasks.py
@@ -1,13 +1,9 @@
import asyncio
import traceback
from typing import Optional, Coroutine, Callable, Any, List
-
from src.common.logger_manager import get_logger
-
-# Need manager types for dependency injection
from src.chat.heart_flow.mai_state_manager import MaiStateManager, MaiStateInfo
from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
-from src.chat.heart_flow.interest_logger import InterestLogger
logger = get_logger("background_tasks")
@@ -62,23 +58,18 @@ class BackgroundTaskManager:
mai_state_info: MaiStateInfo, # Needs current state info
mai_state_manager: MaiStateManager,
subheartflow_manager: SubHeartflowManager,
- interest_logger: InterestLogger,
):
self.mai_state_info = mai_state_info
self.mai_state_manager = mai_state_manager
self.subheartflow_manager = subheartflow_manager
- self.interest_logger = interest_logger
# Task references
self._state_update_task: Optional[asyncio.Task] = None
self._cleanup_task: Optional[asyncio.Task] = None
- self._logging_task: Optional[asyncio.Task] = None
- self._normal_chat_timeout_check_task: Optional[asyncio.Task] = None
self._hf_judge_state_update_task: Optional[asyncio.Task] = None
self._into_focus_task: Optional[asyncio.Task] = None
self._private_chat_activation_task: Optional[asyncio.Task] = None # 新增私聊激活任务引用
self._tasks: List[Optional[asyncio.Task]] = [] # Keep track of all tasks
- self._detect_command_from_gui_task: Optional[asyncio.Task] = None # 新增GUI命令检测任务引用
async def start_tasks(self):
"""启动所有后台任务
@@ -97,30 +88,12 @@ class BackgroundTaskManager:
f"聊天状态更新任务已启动 间隔:{STATE_UPDATE_INTERVAL_SECONDS}s",
"_state_update_task",
),
- (
- lambda: self._run_normal_chat_timeout_check_cycle(NORMAL_CHAT_TIMEOUT_CHECK_INTERVAL_SECONDS),
- "debug",
- f"聊天超时检查任务已启动 间隔:{NORMAL_CHAT_TIMEOUT_CHECK_INTERVAL_SECONDS}s",
- "_normal_chat_timeout_check_task",
- ),
- (
- lambda: self._run_absent_into_chat(HF_JUDGE_STATE_UPDATE_INTERVAL_SECONDS),
- "debug",
- f"状态评估任务已启动 间隔:{HF_JUDGE_STATE_UPDATE_INTERVAL_SECONDS}s",
- "_hf_judge_state_update_task",
- ),
(
self._run_cleanup_cycle,
"info",
f"清理任务已启动 间隔:{CLEANUP_INTERVAL_SECONDS}s",
"_cleanup_task",
),
- (
- self._run_logging_cycle,
- "info",
- f"日志任务已启动 间隔:{LOG_INTERVAL_SECONDS}s",
- "_logging_task",
- ),
# 新增兴趣评估任务配置
(
self._run_into_focus_cycle,
@@ -136,13 +109,6 @@ class BackgroundTaskManager:
f"私聊激活检查任务已启动 间隔:{PRIVATE_CHAT_ACTIVATION_CHECK_INTERVAL_SECONDS}s",
"_private_chat_activation_task",
),
- # 新增GUI命令检测任务配置
- # (
- # lambda: self._run_detect_command_from_gui_cycle(3),
- # "debug",
- # f"GUI命令检测任务已启动 间隔:{3}s",
- # "_detect_command_from_gui_task",
- # ),
]
# 统一启动所有任务
@@ -207,7 +173,6 @@ class BackgroundTaskManager:
if state_changed:
current_state = self.mai_state_info.get_current_state()
- await self.subheartflow_manager.enforce_subheartflow_limits()
# 状态转换处理
@@ -218,15 +183,6 @@ class BackgroundTaskManager:
logger.info("检测到离线,停用所有子心流")
await self.subheartflow_manager.deactivate_all_subflows()
- async def _perform_absent_into_chat(self):
- """调用llm检测是否转换ABSENT-CHAT状态"""
- logger.debug("[状态评估任务] 开始基于LLM评估子心流状态...")
- await self.subheartflow_manager.sbhf_absent_into_chat()
-
- async def _normal_chat_timeout_check_work(self):
- """检查处于CHAT状态的子心流是否因长时间未发言而超时,并将其转为ABSENT"""
- logger.debug("[聊天超时检查] 开始检查处于CHAT状态的子心流...")
- await self.subheartflow_manager.sbhf_chat_into_absent()
async def _perform_cleanup_work(self):
"""执行子心流清理任务
@@ -253,42 +209,23 @@ class BackgroundTaskManager:
# 记录最终清理结果
logger.info(f"[清理任务] 清理完成, 共停止 {stopped_count}/{len(flows_to_stop)} 个子心流")
- async def _perform_logging_work(self):
- """执行一轮状态日志记录。"""
- await self.interest_logger.log_all_states()
# --- 新增兴趣评估工作函数 ---
async def _perform_into_focus_work(self):
"""执行一轮子心流兴趣评估与提升检查。"""
# 直接调用 subheartflow_manager 的方法,并传递当前状态信息
await self.subheartflow_manager.sbhf_absent_into_focus()
-
- # --- 结束新增 ---
-
- # --- 结束新增 ---
-
- # --- Specific Task Runners --- #
+
async def _run_state_update_cycle(self, interval: int):
await _run_periodic_loop(task_name="State Update", interval=interval, task_func=self._perform_state_update_work)
- async def _run_absent_into_chat(self, interval: int):
- await _run_periodic_loop(task_name="Into Chat", interval=interval, task_func=self._perform_absent_into_chat)
- async def _run_normal_chat_timeout_check_cycle(self, interval: int):
- await _run_periodic_loop(
- task_name="Normal Chat Timeout Check", interval=interval, task_func=self._normal_chat_timeout_check_work
- )
async def _run_cleanup_cycle(self):
await _run_periodic_loop(
task_name="Subflow Cleanup", interval=CLEANUP_INTERVAL_SECONDS, task_func=self._perform_cleanup_work
)
- async def _run_logging_cycle(self):
- await _run_periodic_loop(
- task_name="State Logging", interval=LOG_INTERVAL_SECONDS, task_func=self._perform_logging_work
- )
-
# --- 新增兴趣评估任务运行器 ---
async def _run_into_focus_cycle(self):
await _run_periodic_loop(
@@ -304,11 +241,3 @@ class BackgroundTaskManager:
interval=interval,
task_func=self.subheartflow_manager.sbhf_absent_private_into_focus,
)
-
- # # 有api之后删除
- # async def _run_detect_command_from_gui_cycle(self, interval: int):
- # await _run_periodic_loop(
- # task_name="Detect Command from GUI",
- # interval=interval,
- # task_func=self.subheartflow_manager.detect_command_from_gui,
- # )
diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py
index ad876bcf0..bad0683ce 100644
--- a/src/chat/heart_flow/heartflow.py
+++ b/src/chat/heart_flow/heartflow.py
@@ -4,10 +4,8 @@ from src.config.config import global_config
from src.common.logger_manager import get_logger
from typing import Any, Optional
from src.tools.tool_use import ToolUser
-from src.chat.person_info.relationship_manager import relationship_manager # Module instance
from src.chat.heart_flow.mai_state_manager import MaiStateInfo, MaiStateManager
from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
-from src.chat.heart_flow.interest_logger import InterestLogger # Import InterestLogger
from src.chat.heart_flow.background_tasks import BackgroundTaskManager # Import BackgroundTaskManager
logger = get_logger("heartflow")
@@ -17,16 +15,10 @@ class Heartflow:
"""主心流协调器,负责初始化并协调各个子系统:
- 状态管理 (MaiState)
- 子心流管理 (SubHeartflow)
- - 思考过程 (Mind)
- - 日志记录 (InterestLogger)
- 后台任务 (BackgroundTaskManager)
"""
def __init__(self):
- # 核心状态
- self.current_mind = "什么也没想" # 当前主心流想法
- self.past_mind = [] # 历史想法记录
-
# 状态管理相关
self.current_state: MaiStateInfo = MaiStateInfo() # 当前状态信息
self.mai_state_manager: MaiStateManager = MaiStateManager() # 状态决策管理器
@@ -34,23 +26,11 @@ class Heartflow:
# 子心流管理 (在初始化时传入 current_state)
self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager(self.current_state)
- # LLM模型配置
- self.llm_model = LLMRequest(
- model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow"
- )
-
- # 外部依赖模块
- self.tool_user_instance = ToolUser() # 工具使用模块
- self.relationship_manager_instance = relationship_manager # 关系管理模块
-
- self.interest_logger: InterestLogger = InterestLogger(self.subheartflow_manager, self) # 兴趣日志记录器
-
# 后台任务管理器 (整合所有定时任务)
self.background_task_manager: BackgroundTaskManager = BackgroundTaskManager(
mai_state_info=self.current_state,
mai_state_manager=self.mai_state_manager,
subheartflow_manager=self.subheartflow_manager,
- interest_logger=self.interest_logger,
)
async def get_or_create_subheartflow(self, subheartflow_id: Any) -> Optional["SubHeartflow"]:
diff --git a/src/chat/heart_flow/interest_chatting.py b/src/chat/heart_flow/interest_chatting.py
index 45f7fe952..bce372b5c 100644
--- a/src/chat/heart_flow/interest_chatting.py
+++ b/src/chat/heart_flow/interest_chatting.py
@@ -20,9 +20,9 @@ MAX_REPLY_PROBABILITY = 1
class InterestChatting:
def __init__(
self,
- decay_rate=global_config.default_decay_rate_per_second,
+ decay_rate=global_config.focus_chat.default_decay_rate_per_second,
max_interest=MAX_INTEREST,
- trigger_threshold=global_config.reply_trigger_threshold,
+ trigger_threshold=global_config.focus_chat.reply_trigger_threshold,
max_probability=MAX_REPLY_PROBABILITY,
):
# 基础属性初始化
diff --git a/src/chat/heart_flow/interest_logger.py b/src/chat/heart_flow/interest_logger.py
deleted file mode 100644
index b33f449db..000000000
--- a/src/chat/heart_flow/interest_logger.py
+++ /dev/null
@@ -1,212 +0,0 @@
-import asyncio
-import time
-import json
-import os
-import traceback
-from typing import TYPE_CHECKING, Dict, List
-
-from src.common.logger_manager import get_logger
-
-# Need chat_manager to get stream names
-from src.chat.message_receive.chat_stream import chat_manager
-
-if TYPE_CHECKING:
- from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
- from src.chat.heart_flow.sub_heartflow import SubHeartflow
- from src.chat.heart_flow.heartflow import Heartflow # 导入 Heartflow 类型
-
-
-logger = get_logger("interest")
-
-# Consider moving log directory/filename constants here
-LOG_DIRECTORY = "logs/interest"
-HISTORY_LOG_FILENAME = "interest_history.log"
-
-
-def _ensure_log_directory():
- """确保日志目录存在。"""
- os.makedirs(LOG_DIRECTORY, exist_ok=True)
- logger.info(f"已确保日志目录 '{LOG_DIRECTORY}' 存在")
-
-
-def _clear_and_create_log_file():
- """清除日志文件并创建新的日志文件。"""
- if os.path.exists(os.path.join(LOG_DIRECTORY, HISTORY_LOG_FILENAME)):
- os.remove(os.path.join(LOG_DIRECTORY, HISTORY_LOG_FILENAME))
- with open(os.path.join(LOG_DIRECTORY, HISTORY_LOG_FILENAME), "w", encoding="utf-8") as f:
- f.write("")
-
-
-class InterestLogger:
- """负责定期记录主心流和所有子心流的状态到日志文件。"""
-
- def __init__(self, subheartflow_manager: "SubHeartflowManager", heartflow: "Heartflow"):
- """
- 初始化 InterestLogger。
-
- Args:
- subheartflow_manager: 子心流管理器实例。
- heartflow: 主心流实例,用于获取主心流状态。
- """
- self.subheartflow_manager = subheartflow_manager
- self.heartflow = heartflow # 存储 Heartflow 实例
- self._history_log_file_path = os.path.join(LOG_DIRECTORY, HISTORY_LOG_FILENAME)
- _ensure_log_directory()
- _clear_and_create_log_file()
-
- async def get_all_subflow_states(self) -> Dict[str, Dict]:
- """并发获取所有活跃子心流的当前完整状态。"""
- all_flows: List["SubHeartflow"] = self.subheartflow_manager.get_all_subheartflows()
- tasks = []
- results = {}
-
- if not all_flows:
- # logger.debug("未找到任何子心流状态")
- return results
-
- for subheartflow in all_flows:
- if await self.subheartflow_manager.get_or_create_subheartflow(subheartflow.subheartflow_id):
- tasks.append(
- asyncio.create_task(subheartflow.get_full_state(), name=f"get_state_{subheartflow.subheartflow_id}")
- )
- else:
- logger.warning(f"子心流 {subheartflow.subheartflow_id} 在创建任务前已消失")
-
- if tasks:
- done, pending = await asyncio.wait(tasks, timeout=5.0)
-
- if pending:
- logger.warning(f"获取子心流状态超时,有 {len(pending)} 个任务未完成")
- for task in pending:
- task.cancel()
-
- for task in done:
- stream_id_str = task.get_name().split("get_state_")[-1]
- stream_id = stream_id_str
-
- if task.cancelled():
- logger.warning(f"获取子心流 {stream_id} 状态的任务已取消(超时)", exc_info=False)
- elif task.exception():
- exc = task.exception()
- logger.warning(f"获取子心流 {stream_id} 状态出错: {exc}")
- else:
- result = task.result()
- results[stream_id] = result
-
- logger.trace(f"成功获取 {len(results)} 个子心流的完整状态")
- return results
-
- async def log_all_states(self):
- """获取主心流状态和所有子心流的完整状态并写入日志文件。"""
- try:
- current_timestamp = time.time()
-
- # main_mind = self.heartflow.current_mind
- # 获取 Mai 状态名称
- mai_state_name = self.heartflow.current_state.get_current_state().name
-
- all_subflow_states = await self.get_all_subflow_states()
-
- log_entry_base = {
- "timestamp": round(current_timestamp, 2),
- # "main_mind": main_mind,
- "mai_state": mai_state_name,
- "subflow_count": len(all_subflow_states),
- "subflows": [],
- }
-
- if not all_subflow_states:
- # logger.debug("没有获取到任何子心流状态,仅记录主心流状态")
- with open(self._history_log_file_path, "a", encoding="utf-8") as f:
- f.write(json.dumps(log_entry_base, ensure_ascii=False) + "\n")
- return
-
- subflow_details = []
- items_snapshot = list(all_subflow_states.items())
- for stream_id, state in items_snapshot:
- group_name = stream_id
- try:
- chat_stream = chat_manager.get_stream(stream_id)
- if chat_stream:
- if chat_stream.group_info:
- group_name = chat_stream.group_info.group_name
- elif chat_stream.user_info:
- group_name = f"私聊_{chat_stream.user_info.user_nickname}"
- except Exception as e:
- logger.trace(f"无法获取 stream_id {stream_id} 的群组名: {e}")
-
- interest_state = state.get("interest_state", {})
-
- subflow_entry = {
- "stream_id": stream_id,
- "group_name": group_name,
- "sub_mind": state.get("current_mind", "未知"),
- "sub_chat_state": state.get("chat_state", "未知"),
- "interest_level": interest_state.get("interest_level", 0.0),
- "start_hfc_probability": interest_state.get("start_hfc_probability", 0.0),
- # "is_above_threshold": interest_state.get("is_above_threshold", False),
- }
- subflow_details.append(subflow_entry)
-
- log_entry_base["subflows"] = subflow_details
-
- with open(self._history_log_file_path, "a", encoding="utf-8") as f:
- f.write(json.dumps(log_entry_base, ensure_ascii=False) + "\n")
-
- except IOError as e:
- logger.error(f"写入状态日志到 {self._history_log_file_path} 出错: {e}")
- except Exception as e:
- logger.error(f"记录状态时发生意外错误: {e}")
- logger.error(traceback.format_exc())
-
- async def api_get_all_states(self):
- """获取主心流和所有子心流的状态。"""
- try:
- current_timestamp = time.time()
-
- # main_mind = self.heartflow.current_mind
- # 获取 Mai 状态名称
- mai_state_name = self.heartflow.current_state.get_current_state().name
-
- all_subflow_states = await self.get_all_subflow_states()
-
- log_entry_base = {
- "timestamp": round(current_timestamp, 2),
- # "main_mind": main_mind,
- "mai_state": mai_state_name,
- "subflow_count": len(all_subflow_states),
- "subflows": [],
- }
-
- subflow_details = []
- items_snapshot = list(all_subflow_states.items())
- for stream_id, state in items_snapshot:
- group_name = stream_id
- try:
- chat_stream = chat_manager.get_stream(stream_id)
- if chat_stream:
- if chat_stream.group_info:
- group_name = chat_stream.group_info.group_name
- elif chat_stream.user_info:
- group_name = f"私聊_{chat_stream.user_info.user_nickname}"
- except Exception as e:
- logger.trace(f"无法获取 stream_id {stream_id} 的群组名: {e}")
-
- interest_state = state.get("interest_state", {})
-
- subflow_entry = {
- "stream_id": stream_id,
- "group_name": group_name,
- "sub_mind": state.get("current_mind", "未知"),
- "sub_chat_state": state.get("chat_state", "未知"),
- "interest_level": interest_state.get("interest_level", 0.0),
- "start_hfc_probability": interest_state.get("start_hfc_probability", 0.0),
- # "is_above_threshold": interest_state.get("is_above_threshold", False),
- }
- subflow_details.append(subflow_entry)
-
- log_entry_base["subflows"] = subflow_details
- return subflow_details
- except Exception as e:
- logger.error(f"记录状态时发生意外错误: {e}")
- logger.error(traceback.format_exc())
diff --git a/src/chat/heart_flow/mai_state_manager.py b/src/chat/heart_flow/mai_state_manager.py
index 7dea910e9..c5e272796 100644
--- a/src/chat/heart_flow/mai_state_manager.py
+++ b/src/chat/heart_flow/mai_state_manager.py
@@ -13,72 +13,24 @@ logger = get_logger("mai_state")
# The line `enable_unlimited_hfc_chat = False` is setting a configuration parameter that controls
# whether a specific debugging feature is enabled or not. When `enable_unlimited_hfc_chat` is set to
# `False`, it means that the debugging feature for unlimited focused chatting is disabled.
-enable_unlimited_hfc_chat = True # 调试用:无限专注聊天
-# enable_unlimited_hfc_chat = False
+# enable_unlimited_hfc_chat = True # 调试用:无限专注聊天
+enable_unlimited_hfc_chat = False
prevent_offline_state = True
-# 目前默认不启用OFFLINE状态
-
-# 不同状态下普通聊天的最大消息数
-base_normal_chat_num = global_config.base_normal_chat_num
-base_focused_chat_num = global_config.base_focused_chat_num
-
-
-MAX_NORMAL_CHAT_NUM_PEEKING = int(base_normal_chat_num / 2)
-MAX_NORMAL_CHAT_NUM_NORMAL = base_normal_chat_num
-MAX_NORMAL_CHAT_NUM_FOCUSED = base_normal_chat_num + 1
-
-# 不同状态下专注聊天的最大消息数
-MAX_FOCUSED_CHAT_NUM_PEEKING = int(base_focused_chat_num / 2)
-MAX_FOCUSED_CHAT_NUM_NORMAL = base_focused_chat_num
-MAX_FOCUSED_CHAT_NUM_FOCUSED = base_focused_chat_num + 2
-
-# -- 状态定义 --
+# 目前默认不启用OFFLINE状
class MaiState(enum.Enum):
"""
聊天状态:
OFFLINE: 不在线:回复概率极低,不会进行任何聊天
- PEEKING: 看一眼手机:回复概率较低,会进行一些普通聊天
NORMAL_CHAT: 正常看手机:回复概率较高,会进行一些普通聊天和少量的专注聊天
FOCUSED_CHAT: 专注聊天:回复概率极高,会进行专注聊天和少量的普通聊天
"""
OFFLINE = "不在线"
- PEEKING = "看一眼手机"
NORMAL_CHAT = "正常看手机"
FOCUSED_CHAT = "专心看手机"
- def get_normal_chat_max_num(self):
- # 调试用
- if enable_unlimited_hfc_chat:
- return 1000
-
- if self == MaiState.OFFLINE:
- return 0
- elif self == MaiState.PEEKING:
- return MAX_NORMAL_CHAT_NUM_PEEKING
- elif self == MaiState.NORMAL_CHAT:
- return MAX_NORMAL_CHAT_NUM_NORMAL
- elif self == MaiState.FOCUSED_CHAT:
- return MAX_NORMAL_CHAT_NUM_FOCUSED
- return None
-
- def get_focused_chat_max_num(self):
- # 调试用
- if enable_unlimited_hfc_chat:
- return 1000
-
- if self == MaiState.OFFLINE:
- return 0
- elif self == MaiState.PEEKING:
- return MAX_FOCUSED_CHAT_NUM_PEEKING
- elif self == MaiState.NORMAL_CHAT:
- return MAX_FOCUSED_CHAT_NUM_NORMAL
- elif self == MaiState.FOCUSED_CHAT:
- return MAX_FOCUSED_CHAT_NUM_FOCUSED
- return None
-
class MaiStateInfo:
def __init__(self):
@@ -148,34 +100,18 @@ class MaiStateManager:
_time_since_last_min_check = current_time - current_state_info.last_min_check_time
next_state: Optional[MaiState] = None
- # 辅助函数:根据 prevent_offline_state 标志调整目标状态
def _resolve_offline(candidate_state: MaiState) -> MaiState:
- # 现在不再切换到OFFLINE,直接返回当前状态
if candidate_state == MaiState.OFFLINE:
return current_status
return candidate_state
if current_status == MaiState.OFFLINE:
logger.info("当前[离线],没看手机,思考要不要上线看看......")
- elif current_status == MaiState.PEEKING:
- logger.info("当前[看一眼手机],思考要不要继续聊下去......")
elif current_status == MaiState.NORMAL_CHAT:
logger.info("当前在[正常看手机]思考要不要继续聊下去......")
elif current_status == MaiState.FOCUSED_CHAT:
logger.info("当前在[专心看手机]思考要不要继续聊下去......")
- # 1. 移除每分钟概率切换到OFFLINE的逻辑
- # if time_since_last_min_check >= 60:
- # if current_status != MaiState.OFFLINE:
- # if random.random() < 0.03: # 3% 概率切换到 OFFLINE
- # potential_next = MaiState.OFFLINE
- # resolved_next = _resolve_offline(potential_next)
- # logger.debug(f"概率触发下线,resolve 为 {resolved_next.value}")
- # # 只有当解析后的状态与当前状态不同时才设置 next_state
- # if resolved_next != current_status:
- # next_state = resolved_next
-
- # 2. 状态持续时间规则 (只有在规则1没有触发状态改变时才检查)
if next_state is None:
time_limit_exceeded = False
choices_list = []
@@ -183,44 +119,33 @@ class MaiStateManager:
rule_id = ""
if current_status == MaiState.OFFLINE:
- # OFFLINE 状态不再自动切换,直接返回 None
return None
- elif current_status == MaiState.PEEKING:
- if time_in_current_status >= 600: # PEEKING 最多持续 600 秒
- time_limit_exceeded = True
- rule_id = "2.2 (From PEEKING)"
- weights = [50, 50]
- choices_list = [MaiState.NORMAL_CHAT, MaiState.FOCUSED_CHAT]
elif current_status == MaiState.NORMAL_CHAT:
if time_in_current_status >= 300: # NORMAL_CHAT 最多持续 300 秒
time_limit_exceeded = True
rule_id = "2.3 (From NORMAL_CHAT)"
- weights = [50, 50]
- choices_list = [MaiState.PEEKING, MaiState.FOCUSED_CHAT]
+ weights = [100]
+ choices_list = [MaiState.FOCUSED_CHAT]
elif current_status == MaiState.FOCUSED_CHAT:
if time_in_current_status >= 600: # FOCUSED_CHAT 最多持续 600 秒
time_limit_exceeded = True
rule_id = "2.4 (From FOCUSED_CHAT)"
- weights = [50, 50]
- choices_list = [MaiState.NORMAL_CHAT, MaiState.PEEKING]
+ weights = [100]
+ choices_list = [MaiState.NORMAL_CHAT]
if time_limit_exceeded:
next_state_candidate = random.choices(choices_list, weights=weights, k=1)[0]
resolved_candidate = _resolve_offline(next_state_candidate)
logger.debug(
- f"规则{rule_id}:时间到,随机选择 {next_state_candidate.value},resolve 为 {resolved_candidate.value}"
+ f"规则{rule_id}:时间到,切换到 {next_state_candidate.value},resolve 为 {resolved_candidate.value}"
)
- next_state = resolved_candidate # 直接使用解析后的状态
+ next_state = resolved_candidate
- # 注意:enable_unlimited_hfc_chat 优先级高于 prevent_offline_state
- # 如果触发了这个,它会覆盖上面规则2设置的 next_state
if enable_unlimited_hfc_chat:
logger.debug("调试用:开挂了,强制切换到专注聊天")
next_state = MaiState.FOCUSED_CHAT
- # --- 最终决策 --- #
- # 如果决定了下一个状态,且这个状态与当前状态不同,则返回下一个状态
if next_state is not None and next_state != current_status:
return next_state
else:
- return None # 没有状态转换发生或无需重置计时器
+ return None
diff --git a/src/chat/heart_flow/observation/chatting_observation.py b/src/chat/heart_flow/observation/chatting_observation.py
index a51eba5e2..9bd10e511 100644
--- a/src/chat/heart_flow/observation/chatting_observation.py
+++ b/src/chat/heart_flow/observation/chatting_observation.py
@@ -14,6 +14,7 @@ from typing import Optional
import difflib
from src.chat.message_receive.message import MessageRecv # 添加 MessageRecv 导入
from src.chat.heart_flow.observation.observation import Observation
+
from src.common.logger_manager import get_logger
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import Prompt
@@ -43,6 +44,7 @@ class ChattingObservation(Observation):
def __init__(self, chat_id):
super().__init__(chat_id)
self.chat_id = chat_id
+ self.platform = "qq"
# --- Initialize attributes (defaults) ---
self.is_group_chat: bool = False
@@ -53,19 +55,20 @@ class ChattingObservation(Observation):
self.talking_message = []
self.talking_message_str = ""
self.talking_message_str_truncate = ""
- self.name = global_config.BOT_NICKNAME
- self.nick_name = global_config.BOT_ALIAS_NAMES
- self.max_now_obs_len = global_config.observation_context_size
- self.overlap_len = global_config.compressed_length
- self.mid_memorys = []
- self.max_mid_memory_len = global_config.compress_length_limit
+ self.name = global_config.bot.nickname
+ self.nick_name = global_config.bot.alias_names
+ self.max_now_obs_len = global_config.chat.observation_context_size
+ self.overlap_len = global_config.focus_chat.compressed_length
+ self.mid_memories = []
+ self.max_mid_memory_len = global_config.focus_chat.compress_length_limit
self.mid_memory_info = ""
self.person_list = []
self.oldest_messages = []
self.oldest_messages_str = ""
self.compressor_prompt = ""
- self.llm_summary = LLMRequest(
- model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
+ # TODO: API-Adapter修改标记
+ self.model_summary = LLMRequest(
+ model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
)
async def initialize(self):
@@ -83,7 +86,7 @@ class ChattingObservation(Observation):
for id in ids:
print(f"id:{id}")
try:
- for mid_memory in self.mid_memorys:
+ for mid_memory in self.mid_memories:
if mid_memory["id"] == id:
mid_memory_by_id = mid_memory
msg_str = ""
@@ -101,11 +104,11 @@ class ChattingObservation(Observation):
else:
mid_memory_str = "之前的聊天内容:\n"
- for mid_memory in self.mid_memorys:
+ for mid_memory in self.mid_memories:
mid_memory_str += f"{mid_memory['theme']}\n"
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str
- def serch_message_by_text(self, text: str) -> Optional[MessageRecv]:
+ def search_message_by_text(self, text: str) -> Optional[MessageRecv]:
"""
根据回复的纯文本
1. 在talking_message中查找最新的,最匹配的消息
@@ -118,12 +121,12 @@ class ChattingObservation(Observation):
for message in reverse_talking_message:
if message["processed_plain_text"] == text:
find_msg = message
- logger.debug(f"找到的锚定消息:find_msg: {find_msg}")
+ # logger.debug(f"找到的锚定消息:find_msg: {find_msg}")
break
else:
similarity = difflib.SequenceMatcher(None, text, message["processed_plain_text"]).ratio()
msg_list.append({"message": message, "similarity": similarity})
- logger.debug(f"对锚定消息检查:message: {message['processed_plain_text']},similarity: {similarity}")
+ # logger.debug(f"对锚定消息检查:message: {message['processed_plain_text']},similarity: {similarity}")
if not find_msg:
if msg_list:
msg_list.sort(key=lambda x: x["similarity"], reverse=True)
@@ -137,8 +140,23 @@ class ChattingObservation(Observation):
return None
# logger.debug(f"找到的锚定消息:find_msg: {find_msg}")
- group_info = find_msg.get("chat_info", {}).get("group_info")
- user_info = find_msg.get("chat_info", {}).get("user_info")
+
+ # 创建所需的user_info字段
+ user_info = {
+ "platform": find_msg.get("user_platform", ""),
+ "user_id": find_msg.get("user_id", ""),
+ "user_nickname": find_msg.get("user_nickname", ""),
+ "user_cardname": find_msg.get("user_cardname", ""),
+ }
+
+ # 创建所需的group_info字段,如果是群聊的话
+ group_info = {}
+ if find_msg.get("chat_info_group_id"):
+ group_info = {
+ "platform": find_msg.get("chat_info_group_platform", ""),
+ "group_id": find_msg.get("chat_info_group_id", ""),
+ "group_name": find_msg.get("chat_info_group_name", ""),
+ }
content_format = ""
accept_format = ""
@@ -150,7 +168,7 @@ class ChattingObservation(Observation):
}
message_info = {
- "platform": find_msg.get("platform"),
+ "platform": self.platform,
"message_id": find_msg.get("message_id"),
"time": find_msg.get("time"),
"group_info": group_info,
@@ -179,6 +197,8 @@ class ChattingObservation(Observation):
limit_mode="latest",
)
+ # print(f"new_messages_list: {new_messages_list}")
+
last_obs_time_mark = self.last_observe_time
if new_messages_list:
self.last_observe_time = new_messages_list[-1]["time"]
@@ -190,6 +210,7 @@ class ChattingObservation(Observation):
oldest_messages = self.talking_message[:messages_to_remove_count]
self.talking_message = self.talking_message[messages_to_remove_count:] # 保留后半部分,即最新的
+ # print(f"压缩中:oldest_messages: {oldest_messages}")
oldest_messages_str = await build_readable_messages(
messages=oldest_messages, timestamp_mode="normal", read_mark=0
)
@@ -232,21 +253,24 @@ class ChattingObservation(Observation):
self.oldest_messages = oldest_messages
self.oldest_messages_str = oldest_messages_str
+ # 构建中
+ # print(f"构建中:self.talking_message: {self.talking_message}")
self.talking_message_str = await build_readable_messages(
messages=self.talking_message,
timestamp_mode="lite",
read_mark=last_obs_time_mark,
)
+ # print(f"构建中:self.talking_message_str: {self.talking_message_str}")
self.talking_message_str_truncate = await build_readable_messages(
messages=self.talking_message,
timestamp_mode="normal",
read_mark=last_obs_time_mark,
truncate=True,
)
+ # print(f"构建中:self.talking_message_str_truncate: {self.talking_message_str_truncate}")
self.person_list = await get_person_id_list(self.talking_message)
-
- # print(f"self.11111person_list: {self.person_list}")
+ # print(f"构建中:self.person_list: {self.person_list}")
logger.trace(
f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}"
diff --git a/src/chat/heart_flow/observation/hfcloop_observation.py b/src/chat/heart_flow/observation/hfcloop_observation.py
index 470671e28..d712b83be 100644
--- a/src/chat/heart_flow/observation/hfcloop_observation.py
+++ b/src/chat/heart_flow/observation/hfcloop_observation.py
@@ -3,6 +3,7 @@
from datetime import datetime
from src.common.logger_manager import get_logger
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
+from src.chat.focus_chat.planners.action_manager import ActionManager
from typing import List
# Import the new utility function
@@ -16,15 +17,20 @@ class HFCloopObservation:
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
self.history_loop: List[CycleDetail] = []
+ self.action_manager: ActionManager = None
+
+ self.all_actions = {}
def get_observe_info(self):
return self.observe_info
def add_loop_info(self, loop_info: CycleDetail):
- # logger.debug(f"添加循环信息111111111111111111111111111111111111: {loop_info}")
- # print(f"添加循环信息111111111111111111111111111111111111: {loop_info}")
self.history_loop.append(loop_info)
+ def set_action_manager(self, action_manager: ActionManager):
+ self.action_manager = action_manager
+ self.all_actions = self.action_manager.get_registered_actions()
+
async def observe(self):
recent_active_cycles: List[CycleDetail] = []
for cycle in reversed(self.history_loop):
@@ -62,7 +68,6 @@ class HFCloopObservation:
if cycle_info_block:
cycle_info_block = f"\n你最近的回复\n{cycle_info_block}\n"
else:
- # 如果最近的活动循环不是文本回复,或者没有活动循环
cycle_info_block = "\n"
# 获取history_loop中最新添加的
@@ -72,8 +77,16 @@ class HFCloopObservation:
end_time = last_loop.end_time
if start_time is not None and end_time is not None:
time_diff = int(end_time - start_time)
- cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}分钟\n"
+ if time_diff > 60:
+ cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff / 60}分钟\n"
+ else:
+ cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}秒\n"
else:
- cycle_info_block += "\n无法获取上一次阅读消息的时间\n"
+ cycle_info_block += "\n你还没看过消息\n"
+
+ using_actions = self.action_manager.get_using_actions()
+ for action_name, action_info in using_actions.items():
+ action_description = action_info["description"]
+ cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n"
self.observe_info = cycle_info_block
diff --git a/src/chat/heart_flow/observation/memory_observation.py b/src/chat/heart_flow/observation/memory_observation.py
deleted file mode 100644
index 1938a47d3..000000000
--- a/src/chat/heart_flow/observation/memory_observation.py
+++ /dev/null
@@ -1,55 +0,0 @@
-from src.chat.heart_flow.observation.observation import Observation
-from datetime import datetime
-from src.common.logger_manager import get_logger
-import traceback
-
-# Import the new utility function
-from src.chat.memory_system.Hippocampus import HippocampusManager
-import jieba
-from typing import List
-
-logger = get_logger("memory")
-
-
-class MemoryObservation(Observation):
- def __init__(self, observe_id):
- super().__init__(observe_id)
- self.observe_info: str = ""
- self.context: str = ""
- self.running_memory: List[dict] = []
-
- def get_observe_info(self):
- for memory in self.running_memory:
- self.observe_info += f"{memory['topic']}:{memory['content']}\n"
- return self.observe_info
-
- async def observe(self):
- # ---------- 2. 获取记忆 ----------
- try:
- # 从聊天内容中提取关键词
- chat_words = set(jieba.cut(self.context))
- # 过滤掉停用词和单字词
- keywords = [word for word in chat_words if len(word) > 1]
- # 去重并限制数量
- keywords = list(set(keywords))[:5]
-
- logger.debug(f"取的关键词: {keywords}")
-
- # 调用记忆系统获取相关记忆
- related_memory = await HippocampusManager.get_instance().get_memory_from_topic(
- valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
- )
-
- logger.debug(f"获取到的记忆: {related_memory}")
-
- if related_memory:
- for topic, memory in related_memory:
- # 将记忆添加到 running_memory
- self.running_memory.append(
- {"topic": topic, "content": memory, "timestamp": datetime.now().isoformat()}
- )
- logger.debug(f"添加新记忆: {topic} - {memory}")
-
- except Exception as e:
- logger.error(f"观察 记忆时出错: {e}")
- logger.error(traceback.format_exc())
diff --git a/src/chat/heart_flow/observation/structure_observation.py b/src/chat/heart_flow/observation/structure_observation.py
new file mode 100644
index 000000000..2732ef0b1
--- /dev/null
+++ b/src/chat/heart_flow/observation/structure_observation.py
@@ -0,0 +1,32 @@
+from datetime import datetime
+from src.common.logger_manager import get_logger
+
+# Import the new utility function
+
+logger = get_logger("observation")
+
+
+# 所有观察的基类
+class StructureObservation:
+ def __init__(self, observe_id):
+ self.observe_info = ""
+ self.observe_id = observe_id
+ self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
+ self.history_loop = []
+ self.structured_info = []
+
+ def get_observe_info(self):
+ return self.structured_info
+
+ def add_structured_info(self, structured_info: dict):
+ self.structured_info.append(structured_info)
+
+ async def observe(self):
+ observed_structured_infos = []
+ for structured_info in self.structured_info:
+ if structured_info.get("ttl") > 0:
+ structured_info["ttl"] -= 1
+ observed_structured_infos.append(structured_info)
+ logger.debug(f"观察到结构化信息仍旧在: {structured_info}")
+
+ self.structured_info = observed_structured_infos
diff --git a/src/chat/heart_flow/observation/working_observation.py b/src/chat/heart_flow/observation/working_observation.py
index 27b6ab92d..7013c3a2b 100644
--- a/src/chat/heart_flow/observation/working_observation.py
+++ b/src/chat/heart_flow/observation/working_observation.py
@@ -2,33 +2,33 @@
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
from datetime import datetime
from src.common.logger_manager import get_logger
-
+from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
+from src.chat.focus_chat.working_memory.memory_item import MemoryItem
+from typing import List
# Import the new utility function
logger = get_logger("observation")
# 所有观察的基类
-class WorkingObservation:
- def __init__(self, observe_id):
+class WorkingMemoryObservation:
+ def __init__(self, observe_id, working_memory: WorkingMemory):
self.observe_info = ""
self.observe_id = observe_id
- self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
- self.history_loop = []
- self.structured_info = []
+ self.last_observe_time = datetime.now().timestamp()
+
+ self.working_memory = working_memory
+
+ self.retrieved_working_memory = []
def get_observe_info(self):
- return self.structured_info
+ return self.working_memory
- def add_structured_info(self, structured_info: dict):
- self.structured_info.append(structured_info)
+ def add_retrieved_working_memory(self, retrieved_working_memory: List[MemoryItem]):
+ self.retrieved_working_memory.append(retrieved_working_memory)
+
+ def get_retrieved_working_memory(self):
+ return self.retrieved_working_memory
async def observe(self):
- observed_structured_infos = []
- for structured_info in self.structured_info:
- if structured_info.get("ttl") > 0:
- structured_info["ttl"] -= 1
- observed_structured_infos.append(structured_info)
- logger.debug(f"观察到结构化信息仍旧在: {structured_info}")
-
- self.structured_info = observed_structured_infos
+ pass
diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py
index 157c1c957..c440f8cfd 100644
--- a/src/chat/heart_flow/sub_heartflow.py
+++ b/src/chat/heart_flow/sub_heartflow.py
@@ -89,6 +89,14 @@ class SubHeartflow:
await self.interest_chatting.initialize()
logger.debug(f"{self.log_prefix} InterestChatting 实例已初始化。")
+ # 创建并初始化 normal_chat_instance
+ chat_stream = chat_manager.get_stream(self.chat_id)
+ if chat_stream:
+ self.normal_chat_instance = NormalChat(chat_stream=chat_stream,interest_dict=self.get_interest_dict())
+ await self.normal_chat_instance.initialize()
+ await self.normal_chat_instance.start_chat()
+ logger.info(f"{self.log_prefix} NormalChat 实例已创建并启动。")
+
def update_last_chat_state_time(self):
self.chat_state_last_time = time.time() - self.chat_state_changed_time
@@ -181,8 +189,7 @@ class SubHeartflow:
# 创建 HeartFChatting 实例,并传递 从构造函数传入的 回调函数
self.heart_fc_instance = HeartFChatting(
chat_id=self.subheartflow_id,
- observations=self.observations, # 传递所有观察者
- on_consecutive_no_reply_callback=self.hfc_no_reply_callback, # <-- Use stored callback
+ observations=self.observations,
)
# 初始化并启动 HeartFChatting
@@ -200,55 +207,41 @@ class SubHeartflow:
self.heart_fc_instance = None # 创建或初始化异常,清理实例
return False
- async def change_chat_state(self, new_state: "ChatState"):
- """更新sub_heartflow的聊天状态,并管理 HeartFChatting 和 NormalChat 实例及任务"""
+ async def change_chat_state(self, new_state: ChatState) -> None:
+ """
+ 改变聊天状态。
+ 如果转换到CHAT或FOCUSED状态时超过限制,会保持当前状态。
+ """
current_state = self.chat_state.chat_status
+ state_changed = False
+ log_prefix = f"[{self.log_prefix}]"
- if current_state == new_state:
- return
-
- log_prefix = self.log_prefix
- state_changed = False # 标记状态是否实际发生改变
-
- # --- 状态转换逻辑 ---
if new_state == ChatState.CHAT:
- # 移除限额检查逻辑
- logger.debug(f"{log_prefix} 准备进入或保持 聊天 状态")
- if current_state == ChatState.FOCUSED:
- if await self._start_normal_chat(rewind=False):
- # logger.info(f"{log_prefix} 成功进入或保持 NormalChat 状态。")
- state_changed = True
- else:
- logger.error(f"{log_prefix} 从FOCUSED状态启动 NormalChat 失败,无法进入 CHAT 状态。")
- # 考虑是否需要回滚状态或采取其他措施
- return # 启动失败,不改变状态
+ logger.debug(f"{log_prefix} 准备进入或保持 普通聊天 状态")
+ if await self._start_normal_chat():
+ logger.debug(f"{log_prefix} 成功进入或保持 NormalChat 状态。")
+ state_changed = True
else:
- if await self._start_normal_chat(rewind=True):
- # logger.info(f"{log_prefix} 成功进入或保持 NormalChat 状态。")
- state_changed = True
- else:
- logger.error(f"{log_prefix} 从ABSENT状态启动 NormalChat 失败,无法进入 CHAT 状态。")
- # 考虑是否需要回滚状态或采取其他措施
- return # 启动失败,不改变状态
+ logger.error(f"{log_prefix} 启动 NormalChat 失败,无法进入 CHAT 状态。")
+ # 启动失败时,保持当前状态
+ return
elif new_state == ChatState.FOCUSED:
- # 移除限额检查逻辑
logger.debug(f"{log_prefix} 准备进入或保持 专注聊天 状态")
if await self._start_heart_fc_chat():
logger.debug(f"{log_prefix} 成功进入或保持 HeartFChatting 状态。")
state_changed = True
else:
logger.error(f"{log_prefix} 启动 HeartFChatting 失败,无法进入 FOCUSED 状态。")
- # 启动失败,状态回滚到之前的状态或ABSENT?这里保持不改变
- return # 启动失败,不改变状态
+ # 启动失败时,保持当前状态
+ return
elif new_state == ChatState.ABSENT:
logger.info(f"{log_prefix} 进入 ABSENT 状态,停止所有聊天活动...")
self.clear_interest_dict()
-
await self._stop_normal_chat()
await self._stop_heart_fc_chat()
- state_changed = True # 总是可以成功转换到 ABSENT
+ state_changed = True
# --- 更新状态和最后活动时间 ---
if state_changed:
@@ -263,7 +256,6 @@ class SubHeartflow:
self.chat_state_last_time = 0
self.chat_state_changed_time = time.time()
else:
- # 如果因为某些原因(如启动失败)没有成功改变状态,记录一下
logger.debug(
f"{log_prefix} 尝试将状态从 {current_state.value} 变为 {new_state.value},但未成功或未执行更改。"
)
diff --git a/src/chat/heart_flow/subheartflow_manager.py b/src/chat/heart_flow/subheartflow_manager.py
index a4bff8338..22bab6a40 100644
--- a/src/chat/heart_flow/subheartflow_manager.py
+++ b/src/chat/heart_flow/subheartflow_manager.py
@@ -1,26 +1,14 @@
import asyncio
import time
import random
-from typing import Dict, Any, Optional, List, Tuple
-import json # 导入 json 模块
-import functools # <-- 新增导入
-
-# 导入日志模块
+from typing import Dict, Any, Optional, List
+import functools
from src.common.logger_manager import get_logger
-
-# 导入聊天流管理模块
from src.chat.message_receive.chat_stream import chat_manager
-
-# 导入心流相关类
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
from src.chat.heart_flow.mai_state_manager import MaiStateInfo
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
-
-# 导入LLM请求工具
-from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
-from src.individuality.individuality import Individuality
-import traceback
# 初始化日志记录器
@@ -74,14 +62,6 @@ class SubHeartflowManager:
self._lock = asyncio.Lock() # 用于保护 self.subheartflows 的访问
self.mai_state_info: MaiStateInfo = mai_state_info # 存储传入的 MaiStateInfo 实例
- # 为 LLM 状态评估创建一个 LLMRequest 实例
- # 使用与 Heartflow 相同的模型和参数
- self.llm_state_evaluator = LLMRequest(
- model=global_config.llm_heartflow, # 与 Heartflow 一致
- temperature=0.6, # 与 Heartflow 一致
- max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多)
- request_type="subheartflow_state_eval", # 保留特定的请求类型
- )
async def force_change_state(self, subflow_id: Any, target_state: ChatState) -> bool:
"""强制改变指定子心流的状态"""
@@ -155,10 +135,6 @@ class SubHeartflowManager:
logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True)
return None
- # --- 新增:内部方法,用于尝试将单个子心流设置为 ABSENT ---
-
- # --- 结束新增 ---
-
async def sleep_subheartflow(self, subheartflow_id: Any, reason: str) -> bool:
"""停止指定的子心流并将其状态设置为 ABSENT"""
log_prefix = "[子心流管理]"
@@ -189,54 +165,6 @@ class SubHeartflowManager:
return flows_to_stop
- async def enforce_subheartflow_limits(self):
- """根据主状态限制停止超额子心流(优先停不活跃的)"""
- # 使用 self.mai_state_info 获取当前状态和限制
- current_mai_state = self.mai_state_info.get_current_state()
- normal_limit = current_mai_state.get_normal_chat_max_num()
- focused_limit = current_mai_state.get_focused_chat_max_num()
- logger.debug(f"[限制] 状态:{current_mai_state.value}, 普通限:{normal_limit}, 专注限:{focused_limit}")
-
- # 分类统计当前子心流
- normal_flows = []
- focused_flows = []
- for flow_id, flow in list(self.subheartflows.items()):
- if flow.chat_state.chat_status == ChatState.CHAT:
- normal_flows.append((flow_id, getattr(flow, "last_active_time", 0)))
- elif flow.chat_state.chat_status == ChatState.FOCUSED:
- focused_flows.append((flow_id, getattr(flow, "last_active_time", 0)))
-
- logger.debug(f"[限制] 当前数量 - 普通:{len(normal_flows)}, 专注:{len(focused_flows)}")
- stopped = 0
-
- # 处理普通聊天超额
- if len(normal_flows) > normal_limit:
- excess = len(normal_flows) - normal_limit
- logger.info(f"[限制] 普通聊天超额({len(normal_flows)}>{normal_limit}), 停止{excess}个")
- normal_flows.sort(key=lambda x: x[1])
- for flow_id, _ in normal_flows[:excess]:
- if await self.sleep_subheartflow(flow_id, f"普通聊天超额(限{normal_limit})"):
- stopped += 1
-
- # 处理专注聊天超额(需重新统计)
- focused_flows = [
- (fid, t)
- for fid, f in list(self.subheartflows.items())
- if (t := getattr(f, "last_active_time", 0)) and f.chat_state.chat_status == ChatState.FOCUSED
- ]
- if len(focused_flows) > focused_limit:
- excess = len(focused_flows) - focused_limit
- logger.info(f"[限制] 专注聊天超额({len(focused_flows)}>{focused_limit}), 停止{excess}个")
- focused_flows.sort(key=lambda x: x[1])
- for flow_id, _ in focused_flows[:excess]:
- if await self.sleep_subheartflow(flow_id, f"专注聊天超额(限{focused_limit})"):
- stopped += 1
-
- if stopped:
- logger.info(f"[限制] 已停止{stopped}个子心流, 剩余:{len(self.subheartflows)}")
- else:
- logger.debug(f"[限制] 无需停止, 当前总数:{len(self.subheartflows)}")
-
async def deactivate_all_subflows(self):
"""将所有子心流的状态更改为 ABSENT (例如主状态变为OFFLINE时调用)"""
log_prefix = "[停用]"
@@ -272,27 +200,14 @@ class SubHeartflowManager:
)
async def sbhf_absent_into_focus(self):
- """评估子心流兴趣度,满足条件且未达上限则提升到FOCUSED状态(基于start_hfc_probability)"""
+ """评估子心流兴趣度,满足条件则提升到FOCUSED状态(基于start_hfc_probability)"""
try:
current_state = self.mai_state_info.get_current_state()
- focused_limit = current_state.get_focused_chat_max_num()
- # --- 新增:检查是否允许进入 FOCUS 模式 --- #
- if not global_config.allow_focus_mode:
+ # 检查是否允许进入 FOCUS 模式
+ if not global_config.chat.allow_focus_mode:
if int(time.time()) % 60 == 0: # 每60秒输出一次日志避免刷屏
logger.trace("未开启 FOCUSED 状态 (allow_focus_mode=False)")
- return # 如果不允许,直接返回
- # --- 结束新增 ---
-
- logger.info(f"当前状态 ({current_state.value}) 可以在{focused_limit}个群 专注聊天")
-
- if focused_limit <= 0:
- # logger.debug(f"{log_prefix} 当前状态 ({current_state.value}) 不允许 FOCUSED 子心流")
- return
-
- current_focused_count = self.count_subflows_by_state(ChatState.FOCUSED)
- if current_focused_count >= focused_limit:
- logger.debug(f"已达专注上限 ({current_focused_count}/{focused_limit})")
return
for sub_hf in list(self.subheartflows.values()):
@@ -320,11 +235,6 @@ class SubHeartflowManager:
if random.random() >= sub_hf.interest_chatting.start_hfc_probability:
continue
- # 再次检查是否达到上限
- if current_focused_count >= focused_limit:
- logger.debug(f"{stream_name} 已达专注上限")
- break
-
# 获取最新状态并执行提升
current_subflow = self.subheartflows.get(flow_id)
if not current_subflow:
@@ -337,283 +247,57 @@ class SubHeartflowManager:
# 执行状态提升
await current_subflow.change_chat_state(ChatState.FOCUSED)
- # 验证提升结果
- if (
- final_subflow := self.subheartflows.get(flow_id)
- ) and final_subflow.chat_state.chat_status == ChatState.FOCUSED:
- current_focused_count += 1
except Exception as e:
logger.error(f"启动HFC 兴趣评估失败: {e}", exc_info=True)
- async def sbhf_absent_into_chat(self):
+
+ async def sbhf_focus_into_absent_or_chat(self, subflow_id: Any):
"""
- 随机选一个 ABSENT 状态的 *群聊* 子心流,评估是否应转换为 CHAT 状态。
- 每次调用最多转换一个。
- 私聊会被忽略。
- """
- current_mai_state = self.mai_state_info.get_current_state()
- chat_limit = current_mai_state.get_normal_chat_max_num()
-
- async with self._lock:
- # 1. 筛选出所有 ABSENT 状态的 *群聊* 子心流
- absent_group_subflows = [
- hf
- for hf in self.subheartflows.values()
- if hf.chat_state.chat_status == ChatState.ABSENT and hf.is_group_chat
- ]
-
- if not absent_group_subflows:
- # logger.debug("没有摸鱼的群聊子心流可以评估。") # 日志太频繁
- return # 没有目标,直接返回
-
- # 2. 随机选一个幸运儿
- sub_hf_to_evaluate = random.choice(absent_group_subflows)
- flow_id = sub_hf_to_evaluate.subheartflow_id
- stream_name = chat_manager.get_stream_name(flow_id) or flow_id
- log_prefix = f"[{stream_name}]"
-
- # 3. 检查 CHAT 上限
- current_chat_count = self.count_subflows_by_state_nolock(ChatState.CHAT)
- if current_chat_count >= chat_limit:
- logger.info(f"{log_prefix} 想看看能不能聊,但是聊天太多了, ({current_chat_count}/{chat_limit}) 满了。")
- return # 满了,这次就算了
-
- # --- 获取 FOCUSED 计数 ---
- current_focused_count = self.count_subflows_by_state_nolock(ChatState.FOCUSED)
- focused_limit = current_mai_state.get_focused_chat_max_num()
-
- # --- 新增:获取聊天和专注群名 ---
- chatting_group_names = []
- focused_group_names = []
- for flow_id, hf in self.subheartflows.items():
- stream_name = chat_manager.get_stream_name(flow_id) or str(flow_id) # 保证有名字
- if hf.chat_state.chat_status == ChatState.CHAT:
- chatting_group_names.append(stream_name)
- elif hf.chat_state.chat_status == ChatState.FOCUSED:
- focused_group_names.append(stream_name)
- # --- 结束新增 ---
-
- # --- 获取观察信息和构建 Prompt ---
- first_observation = sub_hf_to_evaluate.observations[0] # 喵~第一个观察者肯定存在的说
- await first_observation.observe()
- current_chat_log = first_observation.talking_message_str or "当前没啥聊天内容。"
- _observation_summary = f"在[{stream_name}]这个群中,你最近看群友聊了这些:\n{current_chat_log}"
-
- _mai_state_description = f"你当前状态: {current_mai_state.value}。"
- individuality = Individuality.get_instance()
- personality_prompt = individuality.get_prompt(x_person=2, level=2)
- prompt_personality = f"你正在扮演名为{individuality.name}的人类,{personality_prompt}"
-
- # --- 修改:在 prompt 中加入当前聊天计数和群名信息 (条件显示) ---
- chat_status_lines = []
- if chatting_group_names:
- chat_status_lines.append(
- f"正在这些群闲聊 ({current_chat_count}/{chat_limit}): {', '.join(chatting_group_names)}"
- )
- if focused_group_names:
- chat_status_lines.append(
- f"正在这些群专注的聊天 ({current_focused_count}/{focused_limit}): {', '.join(focused_group_names)}"
- )
-
- chat_status_prompt = "当前没有在任何群聊中。" # 默认消息喵~
- if chat_status_lines:
- chat_status_prompt = "当前聊天情况,你已经参与了下面这几个群的聊天:\n" + "\n".join(
- chat_status_lines
- ) # 拼接状态信息
-
- prompt = (
- f"{prompt_personality}\n"
- f"{chat_status_prompt}\n" # <-- 喵!用了新的状态信息~
- f"你当前尚未加入 [{stream_name}] 群聊天。\n"
- f"{_observation_summary}\n---\n"
- f"基于以上信息,你想不想开始在这个群闲聊?\n"
- f"请说明理由,并以 JSON 格式回答,包含 'decision' (布尔值) 和 'reason' (字符串)。\n"
- f'例如:{{"decision": true, "reason": "看起来挺热闹的,插个话"}}\n'
- f'例如:{{"decision": false, "reason": "已经聊了好多,休息一下"}}\n'
- f"请只输出有效的 JSON 对象。"
- )
- # --- 结束修改 ---
-
- # --- 4. LLM 评估是否想聊 ---
- yao_kai_shi_liao_ma, reason = await self._llm_evaluate_state_transition(prompt)
-
- if reason:
- if yao_kai_shi_liao_ma:
- logger.info(f"{log_prefix} 打算开始聊,原因是: {reason}")
- else:
- logger.info(f"{log_prefix} 不打算聊,原因是: {reason}")
- else:
- logger.info(f"{log_prefix} 结果: {yao_kai_shi_liao_ma}")
-
- if yao_kai_shi_liao_ma is None:
- logger.debug(f"{log_prefix} 问AI想不想聊失败了,这次算了。")
- return # 评估失败,结束
-
- if not yao_kai_shi_liao_ma:
- # logger.info(f"{log_prefix} 现在不想聊这个群。")
- return # 不想聊,结束
-
- # --- 5. AI想聊,再次检查额度并尝试转换 ---
- # 再次检查以防万一
- current_chat_count_before_change = self.count_subflows_by_state_nolock(ChatState.CHAT)
- if current_chat_count_before_change < chat_limit:
- logger.info(
- f"{log_prefix} 想聊,而且还有精力 ({current_chat_count_before_change}/{chat_limit}),这就去聊!"
- )
- await sub_hf_to_evaluate.change_chat_state(ChatState.CHAT)
- # 确认转换成功
- if sub_hf_to_evaluate.chat_state.chat_status == ChatState.CHAT:
- logger.debug(f"{log_prefix} 成功进入聊天状态!本次评估圆满结束。")
- else:
- logger.warning(
- f"{log_prefix} 奇怪,尝试进入聊天状态失败了。当前状态: {sub_hf_to_evaluate.chat_state.chat_status.value}"
- )
- else:
- logger.warning(
- f"{log_prefix} AI说想聊,但是刚问完就没空位了 ({current_chat_count_before_change}/{chat_limit})。真不巧,下次再说吧。"
- )
- # 无论转换成功与否,本次评估都结束了
-
- # 锁在这里自动释放
-
- # --- 新增:单独检查 CHAT 状态超时的任务 ---
- async def sbhf_chat_into_absent(self):
- """定期检查处于 CHAT 状态的子心流是否因长时间未发言而超时,并将其转为 ABSENT。"""
- log_prefix_task = "[聊天超时检查]"
- transitioned_to_absent = 0
- checked_count = 0
-
- async with self._lock:
- subflows_snapshot = list(self.subheartflows.values())
- checked_count = len(subflows_snapshot)
-
- if not subflows_snapshot:
- return
-
- for sub_hf in subflows_snapshot:
- # 只检查 CHAT 状态的子心流
- if sub_hf.chat_state.chat_status != ChatState.CHAT:
- continue
-
- flow_id = sub_hf.subheartflow_id
- stream_name = chat_manager.get_stream_name(flow_id) or flow_id
- log_prefix = f"[{stream_name}]({log_prefix_task})"
-
- should_deactivate = False
- reason = ""
-
- try:
- last_bot_dong_zuo_time = sub_hf.get_normal_chat_last_speak_time()
-
- if last_bot_dong_zuo_time > 0:
- current_time = time.time()
- time_since_last_bb = current_time - last_bot_dong_zuo_time
- minutes_since_last_bb = time_since_last_bb / 60
-
- # 60分钟强制退出
- if minutes_since_last_bb >= 60:
- should_deactivate = True
- reason = "超过60分钟未发言,强制退出"
- else:
- # 根据时间区间确定退出概率
- exit_probability = 0
- if minutes_since_last_bb < 5:
- exit_probability = 0.01 # 1%
- elif minutes_since_last_bb < 15:
- exit_probability = 0.02 # 2%
- elif minutes_since_last_bb < 30:
- exit_probability = 0.04 # 4%
- else:
- exit_probability = 0.08 # 8%
-
- # 随机判断是否退出
- if random.random() < exit_probability:
- should_deactivate = True
- reason = f"已{minutes_since_last_bb:.1f}分钟未发言,触发{exit_probability * 100:.0f}%退出概率"
-
- except AttributeError:
- logger.error(
- f"{log_prefix} 无法获取 Bot 最后 BB 时间,请确保 SubHeartflow 相关实现正确。跳过超时检查。"
- )
- except Exception as e:
- logger.error(f"{log_prefix} 检查 Bot 超时状态时出错: {e}", exc_info=True)
-
- # 执行状态转换(如果超时)
- if should_deactivate:
- logger.debug(f"{log_prefix} 因超时 ({reason}),尝试转换为 ABSENT 状态。")
- await sub_hf.change_chat_state(ChatState.ABSENT)
- # 再次检查确保状态已改变
- if sub_hf.chat_state.chat_status == ChatState.ABSENT:
- transitioned_to_absent += 1
- logger.info(f"{log_prefix} 不看了。")
- else:
- logger.warning(f"{log_prefix} 尝试因超时转换为 ABSENT 失败。")
-
- if transitioned_to_absent > 0:
- logger.debug(
- f"{log_prefix_task} 完成,共检查 {checked_count} 个子心流,{transitioned_to_absent} 个因超时转为 ABSENT。"
- )
-
- # --- 结束新增 ---
-
- async def _llm_evaluate_state_transition(self, prompt: str) -> Tuple[Optional[bool], Optional[str]]:
- """
- 使用 LLM 评估是否应进行状态转换,期望 LLM 返回 JSON 格式。
+ 接收来自 HeartFChatting 的请求,将特定子心流的状态转换为 CHAT。
+ 通常在连续多次 "no_reply" 后被调用。
+ 对于私聊和群聊,都转换为 CHAT。
Args:
- prompt: 提供给 LLM 的提示信息,要求返回 {"decision": true/false}。
-
- Returns:
- Optional[bool]: 如果成功解析 LLM 的 JSON 响应并提取了 'decision' 键的值,则返回该布尔值。
- 如果 LLM 调用失败、返回无效 JSON 或 JSON 中缺少 'decision' 键或其值不是布尔型,则返回 None。
+ subflow_id: 需要转换状态的子心流 ID。
"""
- log_prefix = "[LLM状态评估]"
- try:
- # --- 真实的 LLM 调用 ---
- response_text, _ = await self.llm_state_evaluator.generate_response_async(prompt)
- # logger.debug(f"{log_prefix} 使用模型 {self.llm_state_evaluator.model_name} 评估")
- logger.debug(f"{log_prefix} 原始输入: {prompt}")
- logger.debug(f"{log_prefix} 原始评估结果: {response_text}")
+ async with self._lock:
+ subflow = self.subheartflows.get(subflow_id)
+ if not subflow:
+ logger.warning(f"[状态转换请求] 尝试转换不存在的子心流 {subflow_id} 到 CHAT")
+ return
- # --- 解析 JSON 响应 ---
- try:
- # 尝试去除可能的Markdown代码块标记
- cleaned_response = response_text.strip().strip("`").strip()
- if cleaned_response.startswith("json"):
- cleaned_response = cleaned_response[4:].strip()
+ stream_name = chat_manager.get_stream_name(subflow_id) or subflow_id
+ current_state = subflow.chat_state.chat_status
- data = json.loads(cleaned_response)
- decision = data.get("decision") # 使用 .get() 避免 KeyError
- reason = data.get("reason")
+ if current_state == ChatState.FOCUSED:
+ target_state = ChatState.CHAT
+ log_reason = "转为CHAT"
- if isinstance(decision, bool):
- logger.debug(f"{log_prefix} LLM评估结果 (来自JSON): {'建议转换' if decision else '建议不转换'}")
-
- return decision, reason
- else:
- logger.warning(
- f"{log_prefix} LLM 返回的 JSON 中 'decision' 键的值不是布尔型: {decision}。响应: {response_text}"
+ logger.info(
+ f"[状态转换请求] 接收到请求,将 {stream_name} (当前: {current_state.value}) 尝试转换为 {target_state.value} ({log_reason})"
+ )
+ try:
+ # 从HFC到CHAT时,清空兴趣字典
+ subflow.clear_interest_dict()
+ await subflow.change_chat_state(target_state)
+ final_state = subflow.chat_state.chat_status
+ if final_state == target_state:
+ logger.debug(f"[状态转换请求] {stream_name} 状态已成功转换为 {final_state.value}")
+ else:
+ logger.warning(
+ f"[状态转换请求] 尝试将 {stream_name} 转换为 {target_state.value} 后,状态实际为 {final_state.value}"
+ )
+ except Exception as e:
+ logger.error(
+ f"[状态转换请求] 转换 {stream_name} 到 {target_state.value} 时出错: {e}", exc_info=True
)
- return None, None # 值类型不正确
-
- except json.JSONDecodeError as json_err:
- logger.warning(f"{log_prefix} LLM 返回的响应不是有效的 JSON: {json_err}。响应: {response_text}")
- # 尝试在非JSON响应中查找关键词作为后备方案 (可选)
- if "true" in response_text.lower():
- logger.debug(f"{log_prefix} 在非JSON响应中找到 'true',解释为建议转换")
- return True, None
- if "false" in response_text.lower():
- logger.debug(f"{log_prefix} 在非JSON响应中找到 'false',解释为建议不转换")
- return False, None
- return None, None # JSON 解析失败,也未找到关键词
- except Exception as parse_err: # 捕获其他可能的解析错误
- logger.warning(f"{log_prefix} 解析 LLM JSON 响应时发生意外错误: {parse_err}。响应: {response_text}")
- return None, None
-
- except Exception as e:
- logger.error(f"{log_prefix} 调用 LLM 或处理其响应时出错: {e}", exc_info=True)
- traceback.print_exc()
- return None, None # LLM 调用或处理失败
+ elif current_state == ChatState.ABSENT:
+ logger.debug(f"[状态转换请求] {stream_name} 处于 ABSENT 状态,尝试转为 CHAT")
+ await subflow.change_chat_state(ChatState.CHAT)
+ else:
+ logger.debug(
+ f"[状态转换请求] {stream_name} 当前状态为 {current_state.value},无需转换"
+ )
def count_subflows_by_state(self, state: ChatState) -> int:
"""统计指定状态的子心流数量"""
@@ -636,23 +320,6 @@ class SubHeartflowManager:
count += 1
return count
- def get_active_subflow_minds(self) -> List[str]:
- """获取所有活跃(非ABSENT)子心流的当前想法"""
- minds = []
- for subheartflow in self.subheartflows.values():
- # 检查子心流是否活跃(非ABSENT状态)
- if subheartflow.chat_state.chat_status != ChatState.ABSENT:
- minds.append(subheartflow.sub_mind.current_mind)
- return minds
-
- def update_main_mind_in_subflows(self, main_mind: str):
- """更新所有子心流的主心流想法"""
- updated_count = sum(
- 1
- for _, subheartflow in list(self.subheartflows.items())
- if subheartflow.subheartflow_id in self.subheartflows
- )
- logger.debug(f"[子心流管理器] 更新了{updated_count}个子心流的主想法")
async def delete_subflow(self, subheartflow_id: Any):
"""删除指定的子心流。"""
@@ -669,91 +336,13 @@ class SubHeartflowManager:
else:
logger.warning(f"尝试删除不存在的 SubHeartflow: {subheartflow_id}")
- # --- 新增:处理 HFC 无回复回调的专用方法 --- #
+
async def _handle_hfc_no_reply(self, subheartflow_id: Any):
"""处理来自 HeartFChatting 的连续无回复信号 (通过 partial 绑定 ID)"""
- # 注意:这里不需要再获取锁,因为 sbhf_focus_into_absent 内部会处理锁
+ # 注意:这里不需要再获取锁,因为 sbhf_focus_into_absent_or_chat 内部会处理锁
logger.debug(f"[管理器 HFC 处理器] 接收到来自 {subheartflow_id} 的 HFC 无回复信号")
await self.sbhf_focus_into_absent_or_chat(subheartflow_id)
- # --- 结束新增 --- #
-
- # --- 新增:处理来自 HeartFChatting 的状态转换请求 --- #
- async def sbhf_focus_into_absent_or_chat(self, subflow_id: Any):
- """
- 接收来自 HeartFChatting 的请求,将特定子心流的状态转换为 ABSENT 或 CHAT。
- 通常在连续多次 "no_reply" 后被调用。
- 对于私聊,总是转换为 ABSENT。
- 对于群聊,随机决定转换为 ABSENT 或 CHAT (如果 CHAT 未达上限)。
-
- Args:
- subflow_id: 需要转换状态的子心流 ID。
- """
- async with self._lock:
- subflow = self.subheartflows.get(subflow_id)
- if not subflow:
- logger.warning(f"[状态转换请求] 尝试转换不存在的子心流 {subflow_id} 到 ABSENT/CHAT")
- return
-
- stream_name = chat_manager.get_stream_name(subflow_id) or subflow_id
- current_state = subflow.chat_state.chat_status
-
- if current_state == ChatState.FOCUSED:
- target_state = ChatState.ABSENT # Default target
- log_reason = "默认转换 (私聊或群聊)"
-
- # --- Modify logic based on chat type --- #
- if subflow.is_group_chat:
- # Group chat: Decide between ABSENT or CHAT
- if random.random() < 0.5: # 50% chance to try CHAT
- current_mai_state = self.mai_state_info.get_current_state()
- chat_limit = current_mai_state.get_normal_chat_max_num()
- current_chat_count = self.count_subflows_by_state_nolock(ChatState.CHAT)
-
- if current_chat_count < chat_limit:
- target_state = ChatState.CHAT
- log_reason = f"群聊随机选择 CHAT (当前 {current_chat_count}/{chat_limit})"
- else:
- target_state = ChatState.ABSENT # Fallback to ABSENT if CHAT limit reached
- log_reason = (
- f"群聊随机选择 CHAT 但已达上限 ({current_chat_count}/{chat_limit}),转为 ABSENT"
- )
- else: # 50% chance to go directly to ABSENT
- target_state = ChatState.ABSENT
- log_reason = "群聊随机选择 ABSENT"
- else:
- # Private chat: Always go to ABSENT
- target_state = ChatState.ABSENT
- log_reason = "私聊退出 FOCUSED,转为 ABSENT"
- # --- End modification --- #
-
- logger.info(
- f"[状态转换请求] 接收到请求,将 {stream_name} (当前: {current_state.value}) 尝试转换为 {target_state.value} ({log_reason})"
- )
- try:
- # 从HFC到CHAT时,清空兴趣字典
- subflow.clear_interest_dict()
- await subflow.change_chat_state(target_state)
- final_state = subflow.chat_state.chat_status
- if final_state == target_state:
- logger.debug(f"[状态转换请求] {stream_name} 状态已成功转换为 {final_state.value}")
- else:
- logger.warning(
- f"[状态转换请求] 尝试将 {stream_name} 转换为 {target_state.value} 后,状态实际为 {final_state.value}"
- )
- except Exception as e:
- logger.error(
- f"[状态转换请求] 转换 {stream_name} 到 {target_state.value} 时出错: {e}", exc_info=True
- )
- elif current_state == ChatState.ABSENT:
- logger.debug(f"[状态转换请求] {stream_name} 已处于 ABSENT 状态,无需转换")
- else:
- logger.warning(
- f"[状态转换请求] 收到对 {stream_name} 的请求,但其状态为 {current_state.value} (非 FOCUSED),不执行转换"
- )
-
- # --- 结束新增 --- #
-
# --- 新增:处理私聊从 ABSENT 直接到 FOCUSED 的逻辑 --- #
async def sbhf_absent_private_into_focus(self):
"""检查 ABSENT 状态的私聊子心流是否有新活动,若有且未达 FOCUSED 上限,则直接转换为 FOCUSED。"""
@@ -761,19 +350,8 @@ class SubHeartflowManager:
transitioned_count = 0
checked_count = 0
- # --- 获取当前状态和 FOCUSED 上限 --- #
- current_mai_state = self.mai_state_info.get_current_state()
- focused_limit = current_mai_state.get_focused_chat_max_num()
-
# --- 检查是否允许 FOCUS 模式 --- #
- if not global_config.allow_focus_mode:
- # Log less frequently to avoid spam
- # if int(time.time()) % 60 == 0:
- # logger.debug(f"{log_prefix_task} 配置不允许进入 FOCUSED 状态")
- return
-
- if focused_limit <= 0:
- # logger.debug(f"{log_prefix_task} 当前状态 ({current_mai_state.value}) 不允许 FOCUSED 子心流")
+ if not global_config.chat.allow_focus_mode:
return
async with self._lock:
@@ -794,12 +372,6 @@ class SubHeartflowManager:
# --- 遍历评估每个符合条件的私聊 --- #
for sub_hf in eligible_subflows:
- # --- 再次检查 FOCUSED 上限,因为可能有多个同时激活 --- #
- if current_focused_count >= focused_limit:
- logger.debug(
- f"{log_prefix_task} 已达专注上限 ({current_focused_count}/{focused_limit}),停止检查后续私聊。"
- )
- break # 已满,无需再检查其他私聊
flow_id = sub_hf.subheartflow_id
stream_name = chat_manager.get_stream_name(flow_id) or flow_id
@@ -823,9 +395,6 @@ class SubHeartflowManager:
# --- 如果活跃且未达上限,则尝试转换 --- #
if is_active:
- logger.info(
- f"{log_prefix} 检测到活跃且未达专注上限 ({current_focused_count}/{focused_limit}),尝试转换为 FOCUSED。"
- )
await sub_hf.change_chat_state(ChatState.FOCUSED)
# 确认转换成功
if sub_hf.chat_state.chat_status == ChatState.FOCUSED:
diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py
index 70eb679c9..1695a3948 100644
--- a/src/chat/memory_system/Hippocampus.py
+++ b/src/chat/memory_system/Hippocampus.py
@@ -4,13 +4,14 @@ import math
import random
import time
import re
+import json
from itertools import combinations
import jieba
import networkx as nx
import numpy as np
from collections import Counter
-from ...common.database import db
+from ...common.database.database import memory_db as db
from ...chat.models.utils_model import LLMRequest
from src.common.logger_manager import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
@@ -19,9 +20,11 @@ from ..utils.chat_message_builder import (
build_readable_messages,
) # 导入 build_readable_messages
from ..utils.utils import translate_timestamp_to_human_readable
-from .memory_config import MemoryConfig
from rich.traceback import install
+from ...config.config import global_config
+from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
+
install(extra_lines=3)
@@ -192,21 +195,19 @@ class Hippocampus:
def __init__(self):
self.memory_graph = MemoryGraph()
self.llm_topic_judge = None
- self.llm_summary = None
+ self.model_summary = None
self.entorhinal_cortex = None
self.parahippocampal_gyrus = None
- self.config = None
- def initialize(self, global_config):
- # 使用导入的 MemoryConfig dataclass 和其 from_global_config 方法
- self.config = MemoryConfig.from_global_config(global_config)
+ def initialize(self):
# 初始化子组件
self.entorhinal_cortex = EntorhinalCortex(self)
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
- self.llm_topic_judge = LLMRequest(self.config.llm_topic_judge, request_type="memory")
- self.llm_summary = LLMRequest(self.config.llm_summary, request_type="memory")
+ # TODO: API-Adapter修改标记
+ self.llm_topic_judge = LLMRequest(global_config.model.topic_judge, request_type="memory")
+ self.model_summary = LLMRequest(global_config.model.summary, request_type="memory")
def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表"""
@@ -792,7 +793,6 @@ class EntorhinalCortex:
def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
- self.config = hippocampus.config
def get_memory_sample(self):
"""从数据库获取记忆样本"""
@@ -801,13 +801,13 @@ class EntorhinalCortex:
# 创建双峰分布的记忆调度器
sample_scheduler = MemoryBuildScheduler(
- n_hours1=self.config.memory_build_distribution[0],
- std_hours1=self.config.memory_build_distribution[1],
- weight1=self.config.memory_build_distribution[2],
- n_hours2=self.config.memory_build_distribution[3],
- std_hours2=self.config.memory_build_distribution[4],
- weight2=self.config.memory_build_distribution[5],
- total_samples=self.config.build_memory_sample_num,
+ n_hours1=global_config.memory.memory_build_distribution[0],
+ std_hours1=global_config.memory.memory_build_distribution[1],
+ weight1=global_config.memory.memory_build_distribution[2],
+ n_hours2=global_config.memory.memory_build_distribution[3],
+ std_hours2=global_config.memory.memory_build_distribution[4],
+ weight2=global_config.memory.memory_build_distribution[5],
+ total_samples=global_config.memory.memory_build_sample_num,
)
timestamps = sample_scheduler.get_timestamp_array()
@@ -818,7 +818,7 @@ class EntorhinalCortex:
for timestamp in timestamps:
# 调用修改后的 random_get_msg_snippet
messages = self.random_get_msg_snippet(
- timestamp, self.config.build_memory_sample_length, max_memorized_time_per_msg
+ timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg
)
if messages:
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
@@ -858,11 +858,12 @@ class EntorhinalCortex:
if all_valid:
# 更新数据库中的记忆次数
for message in messages:
- # 确保在更新前获取最新的 memorized_times,以防万一
+ # 确保在更新前获取最新的 memorized_times
current_memorized_times = message.get("memorized_times", 0)
- db.messages.update_one(
- {"_id": message["_id"]}, {"$set": {"memorized_times": current_memorized_times + 1}}
- )
+ # 使用 Peewee 更新记录
+ Messages.update(memorized_times=current_memorized_times + 1).where(
+ Messages.message_id == message["message_id"]
+ ).execute()
return messages # 直接返回原始的消息列表
# 如果获取失败或消息无效,增加尝试次数
@@ -875,12 +876,9 @@ class EntorhinalCortex:
async def sync_memory_to_db(self):
"""将记忆图同步到数据库"""
# 获取数据库中所有节点和内存中所有节点
- db_nodes = list(db.graph_data.nodes.find())
+ db_nodes = {node.concept: node for node in GraphNodes.select()}
memory_nodes = list(self.memory_graph.G.nodes(data=True))
- # 转换数据库节点为字典格式,方便查找
- db_nodes_dict = {node["concept"]: node for node in db_nodes}
-
# 检查并更新节点
for concept, data in memory_nodes:
memory_items = data.get("memory_items", [])
@@ -894,44 +892,39 @@ class EntorhinalCortex:
created_time = data.get("created_time", datetime.datetime.now().timestamp())
last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
- if concept not in db_nodes_dict:
+ # 将memory_items转换为JSON字符串
+ memory_items_json = json.dumps(memory_items, ensure_ascii=False)
+
+ if concept not in db_nodes:
# 数据库中缺少的节点,添加
- node_data = {
- "concept": concept,
- "memory_items": memory_items,
- "hash": memory_hash,
- "created_time": created_time,
- "last_modified": last_modified,
- }
- db.graph_data.nodes.insert_one(node_data)
+ GraphNodes.create(
+ concept=concept,
+ memory_items=memory_items_json,
+ hash=memory_hash,
+ created_time=created_time,
+ last_modified=last_modified,
+ )
else:
# 获取数据库中节点的特征值
- db_node = db_nodes_dict[concept]
- db_hash = db_node.get("hash", None)
+ db_node = db_nodes[concept]
+ db_hash = db_node.hash
# 如果特征值不同,则更新节点
if db_hash != memory_hash:
- db.graph_data.nodes.update_one(
- {"concept": concept},
- {
- "$set": {
- "memory_items": memory_items,
- "hash": memory_hash,
- "created_time": created_time,
- "last_modified": last_modified,
- }
- },
- )
+ db_node.memory_items = memory_items_json
+ db_node.hash = memory_hash
+ db_node.last_modified = last_modified
+ db_node.save()
# 处理边的信息
- db_edges = list(db.graph_data.edges.find())
+ db_edges = list(GraphEdges.select())
memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典
db_edge_dict = {}
for edge in db_edges:
- edge_hash = self.hippocampus.calculate_edge_hash(edge["source"], edge["target"])
- db_edge_dict[(edge["source"], edge["target"])] = {"hash": edge_hash, "strength": edge.get("strength", 1)}
+ edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target)
+ db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength}
# 检查并更新边
for source, target, data in memory_edges:
@@ -945,29 +938,22 @@ class EntorhinalCortex:
if edge_key not in db_edge_dict:
# 添加新边
- edge_data = {
- "source": source,
- "target": target,
- "strength": strength,
- "hash": edge_hash,
- "created_time": created_time,
- "last_modified": last_modified,
- }
- db.graph_data.edges.insert_one(edge_data)
+ GraphEdges.create(
+ source=source,
+ target=target,
+ strength=strength,
+ hash=edge_hash,
+ created_time=created_time,
+ last_modified=last_modified,
+ )
else:
# 检查边的特征值是否变化
if db_edge_dict[edge_key]["hash"] != edge_hash:
- db.graph_data.edges.update_one(
- {"source": source, "target": target},
- {
- "$set": {
- "hash": edge_hash,
- "strength": strength,
- "created_time": created_time,
- "last_modified": last_modified,
- }
- },
- )
+ edge = GraphEdges.get(GraphEdges.source == source, GraphEdges.target == target)
+ edge.hash = edge_hash
+ edge.strength = strength
+ edge.last_modified = last_modified
+ edge.save()
def sync_memory_from_db(self):
"""从数据库同步数据到内存中的图结构"""
@@ -978,29 +964,29 @@ class EntorhinalCortex:
self.memory_graph.G.clear()
# 从数据库加载所有节点
- nodes = list(db.graph_data.nodes.find())
+ nodes = list(GraphNodes.select())
for node in nodes:
- concept = node["concept"]
- memory_items = node.get("memory_items", [])
+ concept = node.concept
+ memory_items = json.loads(node.memory_items)
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 检查时间字段是否存在
- if "created_time" not in node or "last_modified" not in node:
+ if not node.created_time or not node.last_modified:
need_update = True
# 更新数据库中的节点
update_data = {}
- if "created_time" not in node:
+ if not node.created_time:
update_data["created_time"] = current_time
- if "last_modified" not in node:
+ if not node.last_modified:
update_data["last_modified"] = current_time
- db.graph_data.nodes.update_one({"concept": concept}, {"$set": update_data})
+ GraphNodes.update(**update_data).where(GraphNodes.concept == concept).execute()
logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间)
- created_time = node.get("created_time", current_time)
- last_modified = node.get("last_modified", current_time)
+ created_time = node.created_time or current_time
+ last_modified = node.last_modified or current_time
# 添加节点到图中
self.memory_graph.G.add_node(
@@ -1008,28 +994,30 @@ class EntorhinalCortex:
)
# 从数据库加载所有边
- edges = list(db.graph_data.edges.find())
+ edges = list(GraphEdges.select())
for edge in edges:
- source = edge["source"]
- target = edge["target"]
- strength = edge.get("strength", 1)
+ source = edge.source
+ target = edge.target
+ strength = edge.strength
# 检查时间字段是否存在
- if "created_time" not in edge or "last_modified" not in edge:
+ if not edge.created_time or not edge.last_modified:
need_update = True
# 更新数据库中的边
update_data = {}
- if "created_time" not in edge:
+ if not edge.created_time:
update_data["created_time"] = current_time
- if "last_modified" not in edge:
+ if not edge.last_modified:
update_data["last_modified"] = current_time
- db.graph_data.edges.update_one({"source": source, "target": target}, {"$set": update_data})
+ GraphEdges.update(**update_data).where(
+ (GraphEdges.source == source) & (GraphEdges.target == target)
+ ).execute()
logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间)
- created_time = edge.get("created_time", current_time)
- last_modified = edge.get("last_modified", current_time)
+ created_time = edge.created_time or current_time
+ last_modified = edge.last_modified or current_time
# 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G:
@@ -1047,8 +1035,8 @@ class EntorhinalCortex:
# 清空数据库
clear_start = time.time()
- db.graph_data.nodes.delete_many({})
- db.graph_data.edges.delete_many({})
+ GraphNodes.delete().execute()
+ GraphEdges.delete().execute()
clear_end = time.time()
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒")
@@ -1063,29 +1051,27 @@ class EntorhinalCortex:
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
- node_data = {
- "concept": concept,
- "memory_items": memory_items,
- "hash": self.hippocampus.calculate_node_hash(concept, memory_items),
- "created_time": data.get("created_time", datetime.datetime.now().timestamp()),
- "last_modified": data.get("last_modified", datetime.datetime.now().timestamp()),
- }
- db.graph_data.nodes.insert_one(node_data)
+ GraphNodes.create(
+ concept=concept,
+ memory_items=json.dumps(memory_items),
+ hash=self.hippocampus.calculate_node_hash(concept, memory_items),
+ created_time=data.get("created_time", datetime.datetime.now().timestamp()),
+ last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
+ )
node_end = time.time()
logger.info(f"[数据库] 写入 {len(memory_nodes)} 个节点耗时: {node_end - node_start:.2f}秒")
# 重新写入边
edge_start = time.time()
for source, target, data in memory_edges:
- edge_data = {
- "source": source,
- "target": target,
- "strength": data.get("strength", 1),
- "hash": self.hippocampus.calculate_edge_hash(source, target),
- "created_time": data.get("created_time", datetime.datetime.now().timestamp()),
- "last_modified": data.get("last_modified", datetime.datetime.now().timestamp()),
- }
- db.graph_data.edges.insert_one(edge_data)
+ GraphEdges.create(
+ source=source,
+ target=target,
+ strength=data.get("strength", 1),
+ hash=self.hippocampus.calculate_edge_hash(source, target),
+ created_time=data.get("created_time", datetime.datetime.now().timestamp()),
+ last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
+ )
edge_end = time.time()
logger.info(f"[数据库] 写入 {len(memory_edges)} 条边耗时: {edge_end - edge_start:.2f}秒")
@@ -1099,7 +1085,6 @@ class ParahippocampalGyrus:
def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
- self.config = hippocampus.config
async def memory_compress(self, messages: list, compress_rate=0.1):
"""压缩和总结消息内容,生成记忆主题和摘要。
@@ -1159,7 +1144,7 @@ class ParahippocampalGyrus:
# 3. 过滤掉包含禁用关键词的topic
filtered_topics = [
- topic for topic in topics if not any(keyword in topic for keyword in self.config.memory_ban_words)
+ topic for topic in topics if not any(keyword in topic for keyword in global_config.memory.memory_ban_words)
]
logger.debug(f"过滤后话题: {filtered_topics}")
@@ -1170,7 +1155,7 @@ class ParahippocampalGyrus:
# 调用修改后的 topic_what,不再需要 time_info
topic_what_prompt = self.hippocampus.topic_what(input_text, topic)
try:
- task = self.hippocampus.llm_summary.generate_response_async(topic_what_prompt)
+ task = self.hippocampus.model_summary.generate_response_async(topic_what_prompt)
tasks.append((topic.strip(), task))
except Exception as e:
logger.error(f"生成话题 '{topic}' 的摘要时发生错误: {e}")
@@ -1222,7 +1207,7 @@ class ParahippocampalGyrus:
bar = "█" * filled_length + "-" * (bar_length - filled_length)
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
- compress_rate = self.config.memory_compress_rate
+ compress_rate = global_config.memory.memory_compress_rate
try:
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
except Exception as e:
@@ -1322,7 +1307,7 @@ class ParahippocampalGyrus:
edge_data = self.memory_graph.G[source][target]
last_modified = edge_data.get("last_modified")
- if current_time - last_modified > 3600 * self.config.memory_forget_time:
+ if current_time - last_modified > 3600 * global_config.memory.memory_forget_time:
current_strength = edge_data.get("strength", 1)
new_strength = current_strength - 1
@@ -1430,8 +1415,8 @@ class ParahippocampalGyrus:
async def operation_consolidate_memory(self):
"""整合记忆:合并节点内相似的记忆项"""
start_time = time.time()
- percentage = self.config.consolidate_memory_percentage
- similarity_threshold = self.config.consolidation_similarity_threshold
+ percentage = global_config.memory.consolidate_memory_percentage
+ similarity_threshold = global_config.memory.consolidation_similarity_threshold
logger.info(f"[整合] 开始检查记忆节点... 检查比例: {percentage:.2%}, 合并阈值: {similarity_threshold}")
# 获取所有至少有2条记忆项的节点
@@ -1544,7 +1529,6 @@ class ParahippocampalGyrus:
class HippocampusManager:
_instance = None
_hippocampus = None
- _global_config = None
_initialized = False
@classmethod
@@ -1559,19 +1543,15 @@ class HippocampusManager:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return cls._hippocampus
- def initialize(self, global_config):
+ def initialize(self):
"""初始化海马体实例"""
if self._initialized:
return self._hippocampus
- self._global_config = global_config
self._hippocampus = Hippocampus()
- self._hippocampus.initialize(global_config)
+ self._hippocampus.initialize()
self._initialized = True
- # 输出记忆系统参数信息
- config = self._hippocampus.config
-
# 输出记忆图统计信息
memory_graph = self._hippocampus.memory_graph.G
node_count = len(memory_graph.nodes())
@@ -1579,9 +1559,9 @@ class HippocampusManager:
logger.success(f"""--------------------------------
记忆系统参数配置:
- 构建间隔: {global_config.build_memory_interval}秒|样本数: {config.build_memory_sample_num},长度: {config.build_memory_sample_length}|压缩率: {config.memory_compress_rate}
- 记忆构建分布: {config.memory_build_distribution}
- 遗忘间隔: {global_config.forget_memory_interval}秒|遗忘比例: {global_config.memory_forget_percentage}|遗忘: {config.memory_forget_time}小时之后
+ 构建间隔: {global_config.memory.memory_build_interval}秒|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate}
+ 记忆构建分布: {global_config.memory.memory_build_distribution}
+ 遗忘间隔: {global_config.memory.forget_memory_interval}秒|遗忘比例: {global_config.memory.memory_forget_percentage}|遗忘: {global_config.memory.memory_forget_time}小时之后
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
--------------------------------""") # noqa: E501
diff --git a/src/chat/memory_system/debug_memory.py b/src/chat/memory_system/debug_memory.py
index baf745409..b09e703a1 100644
--- a/src/chat/memory_system/debug_memory.py
+++ b/src/chat/memory_system/debug_memory.py
@@ -7,7 +7,6 @@ import os
# 添加项目根目录到系统路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
from src.chat.memory_system.Hippocampus import HippocampusManager
-from src.config.config import global_config
from rich.traceback import install
install(extra_lines=3)
@@ -19,7 +18,7 @@ async def test_memory_system():
# 初始化记忆系统
print("开始初始化记忆系统...")
hippocampus_manager = HippocampusManager.get_instance()
- hippocampus_manager.initialize(global_config=global_config)
+ hippocampus_manager.initialize()
print("记忆系统初始化完成")
# 测试记忆构建
diff --git a/src/chat/memory_system/manually_alter_memory.py b/src/chat/memory_system/manually_alter_memory.py
index ce5abbba7..9bbf59f5b 100644
--- a/src/chat/memory_system/manually_alter_memory.py
+++ b/src/chat/memory_system/manually_alter_memory.py
@@ -34,7 +34,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.logger import get_module_logger # noqa E402
-from src.common.database import db # noqa E402
+from common.database.database import db # noqa E402
logger = get_module_logger("mem_alter")
console = Console()
diff --git a/src/chat/memory_system/memory_config.py b/src/chat/memory_system/memory_config.py
deleted file mode 100644
index b82e54ec1..000000000
--- a/src/chat/memory_system/memory_config.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from dataclasses import dataclass
-from typing import List
-
-
-@dataclass
-class MemoryConfig:
- """记忆系统配置类"""
-
- # 记忆构建相关配置
- memory_build_distribution: List[float] # 记忆构建的时间分布参数
- build_memory_sample_num: int # 每次构建记忆的样本数量
- build_memory_sample_length: int # 每个样本的消息长度
- memory_compress_rate: float # 记忆压缩率
-
- # 记忆遗忘相关配置
- memory_forget_time: int # 记忆遗忘时间(小时)
-
- # 记忆过滤相关配置
- memory_ban_words: List[str] # 记忆过滤词列表
-
- # 新增:记忆整合相关配置
- consolidation_similarity_threshold: float # 相似度阈值
- consolidate_memory_percentage: float # 检查节点比例
- consolidate_memory_interval: int # 记忆整合间隔
-
- llm_topic_judge: str # 话题判断模型
- llm_summary: str # 话题总结模型
-
- @classmethod
- def from_global_config(cls, global_config):
- """从全局配置创建记忆系统配置"""
- # 使用 getattr 提供默认值,防止全局配置缺少这些项
- return cls(
- memory_build_distribution=getattr(
- global_config, "memory_build_distribution", (24, 12, 0.5, 168, 72, 0.5)
- ), # 添加默认值
- build_memory_sample_num=getattr(global_config, "build_memory_sample_num", 5),
- build_memory_sample_length=getattr(global_config, "build_memory_sample_length", 30),
- memory_compress_rate=getattr(global_config, "memory_compress_rate", 0.1),
- memory_forget_time=getattr(global_config, "memory_forget_time", 24 * 7),
- memory_ban_words=getattr(global_config, "memory_ban_words", []),
- # 新增加载整合配置,并提供默认值
- consolidation_similarity_threshold=getattr(global_config, "consolidation_similarity_threshold", 0.7),
- consolidate_memory_percentage=getattr(global_config, "consolidate_memory_percentage", 0.01),
- consolidate_memory_interval=getattr(global_config, "consolidate_memory_interval", 1000),
- llm_topic_judge=getattr(global_config, "llm_topic_judge", "default_judge_model"), # 添加默认模型名
- llm_summary=getattr(global_config, "llm_summary", "default_summary_model"), # 添加默认模型名
- )
diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py
index 3c9e4420c..3b9a6f929 100644
--- a/src/chat/message_receive/bot.py
+++ b/src/chat/message_receive/bot.py
@@ -38,10 +38,10 @@ class ChatBot:
async def _create_pfc_chat(self, message: MessageRecv):
try:
- chat_id = str(message.chat_stream.stream_id)
- private_name = str(message.message_info.user_info.user_nickname)
+ if global_config.experimental.pfc_chatting:
+ chat_id = str(message.chat_stream.stream_id)
+ private_name = str(message.message_info.user_info.user_nickname)
- if global_config.enable_pfc_chatting:
await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
except Exception as e:
@@ -72,27 +72,11 @@ class ChatBot:
message_data["message_info"]["user_info"]["user_id"] = str(
message_data["message_info"]["user_info"]["user_id"]
)
+ # print(message_data)
logger.trace(f"处理消息:{str(message_data)[:120]}...")
message = MessageRecv(message_data)
- groupinfo = message.message_info.group_info
- userinfo = message.message_info.user_info
-
- # 用户黑名单拦截
- if userinfo.user_id in global_config.ban_user_id:
- logger.debug(f"用户{userinfo.user_id}被禁止回复")
- return
-
- if groupinfo is None:
- logger.trace("检测到私聊消息,检查")
- # 好友黑名单拦截
- if userinfo.user_id not in global_config.talk_allowed_private:
- logger.debug(f"用户{userinfo.user_id}没有私聊权限")
- return
-
- # 群聊黑名单拦截
- if groupinfo is not None and groupinfo.group_id not in global_config.talk_allowed_groups:
- logger.trace(f"群{groupinfo.group_id}被禁止回复")
- return
+ group_info = message.message_info.group_info
+ user_info = message.message_info.user_info
# 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default:
@@ -109,33 +93,27 @@ class ChatBot:
async def preprocess():
logger.trace("开始预处理消息...")
# 如果在私聊中
- if groupinfo is None:
+ if group_info is None:
logger.trace("检测到私聊消息")
- # 是否在配置信息中开启私聊模式
- if global_config.enable_friend_chat:
- logger.trace("私聊模式已启用")
- # 是否进入PFC
- if global_config.enable_pfc_chatting:
- logger.trace("进入PFC私聊处理流程")
- userinfo = message.message_info.user_info
- messageinfo = message.message_info
- # 创建聊天流
- logger.trace(f"为{userinfo.user_id}创建/获取聊天流")
- chat = await chat_manager.get_or_create_stream(
- platform=messageinfo.platform,
- user_info=userinfo,
- group_info=groupinfo,
- )
- message.update_chat_stream(chat)
- await self.only_process_chat.process_message(message)
- await self._create_pfc_chat(message)
- # 禁止PFC,进入普通的心流消息处理逻辑
- else:
- logger.trace("进入普通心流私聊处理")
- await self.heartflow_processor.process_message(message_data)
+ if global_config.experimental.pfc_chatting:
+ logger.trace("进入PFC私聊处理流程")
+ # 创建聊天流
+ logger.trace(f"为{user_info.user_id}创建/获取聊天流")
+ chat = await chat_manager.get_or_create_stream(
+ platform=message.message_info.platform,
+ user_info=user_info,
+ group_info=group_info,
+ )
+ message.update_chat_stream(chat)
+ await self.only_process_chat.process_message(message)
+ await self._create_pfc_chat(message)
+ # 禁止PFC,进入普通的心流消息处理逻辑
+ else:
+ logger.trace("进入普通心流私聊处理")
+ await self.heartflow_processor.process_message(message_data)
# 群聊默认进入心流消息处理逻辑
else:
- logger.trace(f"检测到群聊消息,群ID: {groupinfo.group_id}")
+ logger.trace(f"检测到群聊消息,群ID: {group_info.group_id}")
await self.heartflow_processor.process_message(message_data)
if template_group_name:
diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py
index 53ebd5026..e00fc7370 100644
--- a/src/chat/message_receive/chat_stream.py
+++ b/src/chat/message_receive/chat_stream.py
@@ -5,7 +5,8 @@ import copy
from typing import Dict, Optional
-from ...common.database import db
+from ...common.database.database import db
+from ...common.database.database_model import ChatStreams # 新增导入
from maim_message import GroupInfo, UserInfo
from src.common.logger_manager import get_logger
@@ -38,7 +39,7 @@ class ChatStream:
def to_dict(self) -> dict:
"""转换为字典格式"""
- result = {
+ return {
"stream_id": self.stream_id,
"platform": self.platform,
"user_info": self.user_info.to_dict() if self.user_info else None,
@@ -46,7 +47,6 @@ class ChatStream:
"create_time": self.create_time,
"last_active_time": self.last_active_time,
}
- return result
@classmethod
def from_dict(cls, data: dict) -> "ChatStream":
@@ -82,7 +82,13 @@ class ChatManager:
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
- self._ensure_collection()
+ try:
+ db.connect(reuse_if_open=True)
+ # 确保 ChatStreams 表存在
+ db.create_tables([ChatStreams], safe=True)
+ except Exception as e:
+ logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
+
self._initialized = True
# 在事件循环中启动初始化
# asyncio.create_task(self._initialize())
@@ -107,15 +113,6 @@ class ChatManager:
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
- @staticmethod
- def _ensure_collection():
- """确保数据库集合存在并创建索引"""
- if "chat_streams" not in db.list_collection_names():
- db.create_collection("chat_streams")
- # 创建索引
- db.chat_streams.create_index([("stream_id", 1)], unique=True)
- db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])
-
@staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID"""
@@ -151,16 +148,43 @@ class ChatManager:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.update_active_time()
- stream = copy.deepcopy(stream)
+ stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
stream.user_info = user_info
if group_info:
stream.group_info = group_info
return stream
# 检查数据库中是否存在
- data = db.chat_streams.find_one({"stream_id": stream_id})
- if data:
- stream = ChatStream.from_dict(data)
+ def _db_find_stream_sync(s_id: str):
+ return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
+
+ model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
+
+ if model_instance:
+ # 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
+ user_info_data = {
+ "platform": model_instance.user_platform,
+ "user_id": model_instance.user_id,
+ "user_nickname": model_instance.user_nickname,
+ "user_cardname": model_instance.user_cardname or "",
+ }
+ group_info_data = None
+ if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
+ group_info_data = {
+ "platform": model_instance.group_platform,
+ "group_id": model_instance.group_id,
+ "group_name": model_instance.group_name,
+ }
+
+ data_for_from_dict = {
+ "stream_id": model_instance.stream_id,
+ "platform": model_instance.platform,
+ "user_info": user_info_data,
+ "group_info": group_info_data,
+ "create_time": model_instance.create_time,
+ "last_active_time": model_instance.last_active_time,
+ }
+ stream = ChatStream.from_dict(data_for_from_dict)
# 更新用户信息和群组信息
stream.user_info = user_info
if group_info:
@@ -175,7 +199,7 @@ class ChatManager:
group_info=group_info,
)
except Exception as e:
- logger.error(f"创建聊天流失败: {e}")
+ logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
raise e
# 保存到内存和数据库
@@ -205,15 +229,39 @@ class ChatManager:
elif stream.user_info and stream.user_info.user_nickname:
return f"{stream.user_info.user_nickname}的私聊"
else:
- # 如果没有群名或用户昵称,返回 None 或其他默认值
return None
@staticmethod
async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库"""
- if not stream.saved:
- db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True)
+ if stream.saved:
+ return
+ stream_data_dict = stream.to_dict()
+
+ def _db_save_stream_sync(s_data_dict: dict):
+ user_info_d = s_data_dict.get("user_info")
+ group_info_d = s_data_dict.get("group_info")
+
+ fields_to_save = {
+ "platform": s_data_dict["platform"],
+ "create_time": s_data_dict["create_time"],
+ "last_active_time": s_data_dict["last_active_time"],
+ "user_platform": user_info_d["platform"] if user_info_d else "",
+ "user_id": user_info_d["user_id"] if user_info_d else "",
+ "user_nickname": user_info_d["user_nickname"] if user_info_d else "",
+ "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
+ "group_platform": group_info_d["platform"] if group_info_d else "",
+ "group_id": group_info_d["group_id"] if group_info_d else "",
+ "group_name": group_info_d["group_name"] if group_info_d else "",
+ }
+
+ ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
+
+ try:
+ await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
stream.saved = True
+ except Exception as e:
+ logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
async def _save_all_streams(self):
"""保存所有聊天流"""
@@ -222,10 +270,44 @@ class ChatManager:
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
- all_streams = db.chat_streams.find({})
- for data in all_streams:
- stream = ChatStream.from_dict(data)
- self.streams[stream.stream_id] = stream
+
+ def _db_load_all_streams_sync():
+ loaded_streams_data = []
+ for model_instance in ChatStreams.select():
+ user_info_data = {
+ "platform": model_instance.user_platform,
+ "user_id": model_instance.user_id,
+ "user_nickname": model_instance.user_nickname,
+ "user_cardname": model_instance.user_cardname or "",
+ }
+ group_info_data = None
+ if model_instance.group_id:
+ group_info_data = {
+ "platform": model_instance.group_platform,
+ "group_id": model_instance.group_id,
+ "group_name": model_instance.group_name,
+ }
+
+ data_for_from_dict = {
+ "stream_id": model_instance.stream_id,
+ "platform": model_instance.platform,
+ "user_info": user_info_data,
+ "group_info": group_info_data,
+ "create_time": model_instance.create_time,
+ "last_active_time": model_instance.last_active_time,
+ }
+ loaded_streams_data.append(data_for_from_dict)
+ return loaded_streams_data
+
+ try:
+ all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
+ self.streams.clear()
+ for data in all_streams_data_list:
+ stream = ChatStream.from_dict(data)
+ stream.saved = True
+ self.streams[stream.stream_id] = stream
+ except Exception as e:
+ logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
# 创建全局单例
diff --git a/src/chat/message_receive/message_buffer.py b/src/chat/message_receive/message_buffer.py
index f3cf63d0a..2df256ce5 100644
--- a/src/chat/message_receive/message_buffer.py
+++ b/src/chat/message_receive/message_buffer.py
@@ -38,7 +38,7 @@ class MessageBuffer:
async def start_caching_messages(self, message: MessageRecv):
"""添加消息,启动缓冲"""
- if not global_config.message_buffer:
+ if not global_config.chat.message_buffer:
person_id = person_info_manager.get_person_id(
message.message_info.user_info.platform, message.message_info.user_info.user_id
)
@@ -107,7 +107,7 @@ class MessageBuffer:
async def query_buffer_result(self, message: MessageRecv) -> bool:
"""查询缓冲结果,并清理"""
- if not global_config.message_buffer:
+ if not global_config.chat.message_buffer:
return True
person_id_ = self.get_person_id_(
message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
diff --git a/src/chat/message_receive/message_sender.py b/src/chat/message_receive/message_sender.py
index 5db34fdea..cf5877989 100644
--- a/src/chat/message_receive/message_sender.py
+++ b/src/chat/message_receive/message_sender.py
@@ -279,7 +279,7 @@ class MessageManager:
)
# 检查是否超时
- if thinking_time > global_config.thinking_timeout:
+ if thinking_time > global_config.normal_chat.thinking_timeout:
logger.warning(
f"[{chat_id}] 消息思考超时 ({thinking_time:.1f}秒),移除消息 {message_earliest.message_info.message_id}"
)
diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py
index cae029a11..d0041cd51 100644
--- a/src/chat/message_receive/storage.py
+++ b/src/chat/message_receive/storage.py
@@ -1,9 +1,10 @@
import re
from typing import Union
-from ...common.database import db
+# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
+from ...common.database.database_model import Messages, RecalledMessages # Import Peewee models
from src.common.logger import get_module_logger
logger = get_module_logger("message_storage")
@@ -29,42 +30,66 @@ class MessageStorage:
else:
filtered_detailed_plain_text = ""
- message_data = {
- "message_id": message.message_info.message_id,
- "time": message.message_info.time,
- "chat_id": chat_stream.stream_id,
- "chat_info": chat_stream.to_dict(),
- "user_info": message.message_info.user_info.to_dict(),
- # 使用过滤后的文本
- "processed_plain_text": filtered_processed_plain_text,
- "detailed_plain_text": filtered_detailed_plain_text,
- "memorized_times": message.memorized_times,
- }
- db.messages.insert_one(message_data)
+ chat_info_dict = chat_stream.to_dict()
+ user_info_dict = message.message_info.user_info.to_dict()
+
+ # message_id 现在是 TextField,直接使用字符串值
+ msg_id = message.message_info.message_id
+
+ # 安全地获取 group_info, 如果为 None 则视为空字典
+ group_info_from_chat = chat_info_dict.get("group_info") or {}
+ # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
+ user_info_from_chat = chat_info_dict.get("user_info") or {}
+
+ Messages.create(
+ message_id=msg_id,
+ time=float(message.message_info.time),
+ chat_id=chat_stream.stream_id,
+ # Flattened chat_info
+ chat_info_stream_id=chat_info_dict.get("stream_id"),
+ chat_info_platform=chat_info_dict.get("platform"),
+ chat_info_user_platform=user_info_from_chat.get("platform"),
+ chat_info_user_id=user_info_from_chat.get("user_id"),
+ chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
+ chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
+ chat_info_group_platform=group_info_from_chat.get("platform"),
+ chat_info_group_id=group_info_from_chat.get("group_id"),
+ chat_info_group_name=group_info_from_chat.get("group_name"),
+ chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
+ chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
+ # Flattened user_info (message sender)
+ user_platform=user_info_dict.get("platform"),
+ user_id=user_info_dict.get("user_id"),
+ user_nickname=user_info_dict.get("user_nickname"),
+ user_cardname=user_info_dict.get("user_cardname"),
+ # Text content
+ processed_plain_text=filtered_processed_plain_text,
+ detailed_plain_text=filtered_detailed_plain_text,
+ memorized_times=message.memorized_times,
+ )
except Exception:
logger.exception("存储消息失败")
@staticmethod
async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None:
"""存储撤回消息到数据库"""
- if "recalled_messages" not in db.list_collection_names():
- db.create_collection("recalled_messages")
- else:
- try:
- message_data = {
- "message_id": message_id,
- "time": time,
- "stream_id": chat_stream.stream_id,
- }
- db.recalled_messages.insert_one(message_data)
- except Exception:
- logger.exception("存储撤回消息失败")
+ # Table creation is handled by initialize_database in database_model.py
+ try:
+ RecalledMessages.create(
+ message_id=message_id,
+ time=float(time), # Assuming time is a string representing a float timestamp
+ stream_id=chat_stream.stream_id,
+ )
+ except Exception:
+ logger.exception("存储撤回消息失败")
@staticmethod
async def remove_recalled_message(time: str) -> None:
"""删除撤回消息"""
try:
- db.recalled_messages.delete_many({"time": {"$lt": time - 300}})
+ # Assuming input 'time' is a string timestamp that can be converted to float
+ current_time_float = float(time)
+ RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute()
except Exception:
logger.exception("删除撤回消息失败")
diff --git a/src/chat/models/utils_model.py b/src/chat/models/utils_model.py
index e662a8e33..f6528856d 100644
--- a/src/chat/models/utils_model.py
+++ b/src/chat/models/utils_model.py
@@ -12,7 +12,8 @@ import base64
from PIL import Image
import io
import os
-from ...common.database import db
+from src.common.database.database import db # 确保 db 被导入用于 create_tables
+from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
from ...config.config import global_config
from rich.traceback import install
@@ -85,8 +86,6 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
- # if isinstance(content, str) and len(content) > 100:
- # payload["messages"][0]["content"] = content[:100]
return payload
@@ -111,8 +110,8 @@ class LLMRequest:
def __init__(self, model: dict, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值
try:
- self.api_key = os.environ[model["key"]]
- self.base_url = os.environ[model["base_url"]]
+ self.api_key = os.environ[f"{model['provider']}_KEY"]
+ self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
except AttributeError as e:
logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
@@ -134,13 +133,11 @@ class LLMRequest:
def _init_database():
"""初始化数据库集合"""
try:
- # 创建llm_usage集合的索引
- db.llm_usage.create_index([("timestamp", 1)])
- db.llm_usage.create_index([("model_name", 1)])
- db.llm_usage.create_index([("user_id", 1)])
- db.llm_usage.create_index([("request_type", 1)])
+ # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误
+ db.create_tables([LLMUsage], safe=True)
+ logger.debug("LLMUsage 表已初始化/确保存在。")
except Exception as e:
- logger.error(f"创建数据库索引失败: {str(e)}")
+ logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def _record_usage(
self,
@@ -165,19 +162,19 @@ class LLMRequest:
request_type = self.request_type
try:
- usage_data = {
- "model_name": self.model_name,
- "user_id": user_id,
- "request_type": request_type,
- "endpoint": endpoint,
- "prompt_tokens": prompt_tokens,
- "completion_tokens": completion_tokens,
- "total_tokens": total_tokens,
- "cost": self._calculate_cost(prompt_tokens, completion_tokens),
- "status": "success",
- "timestamp": datetime.now(),
- }
- db.llm_usage.insert_one(usage_data)
+ # 使用 Peewee 模型创建记录
+ LLMUsage.create(
+ model_name=self.model_name,
+ user_id=user_id,
+ request_type=request_type,
+ endpoint=endpoint,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ cost=self._calculate_cost(prompt_tokens, completion_tokens),
+ status="success",
+ timestamp=datetime.now(), # Peewee 会处理 DateTimeField
+ )
logger.trace(
f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "
@@ -500,11 +497,11 @@ class LLMRequest:
logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
# 对全局配置进行更新
- if global_config.llm_normal.get("name") == old_model_name:
- global_config.llm_normal["name"] = self.model_name
+ if global_config.model.normal.get("name") == old_model_name:
+ global_config.model.normal["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
- if global_config.llm_reasoning.get("name") == old_model_name:
- global_config.llm_reasoning["name"] = self.model_name
+ if global_config.model.reasoning.get("name") == old_model_name:
+ global_config.model.reasoning["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
if payload and "model" in payload:
@@ -636,7 +633,7 @@ class LLMRequest:
**params_copy,
}
if "max_tokens" not in payload and "max_completion_tokens" not in payload:
- payload["max_tokens"] = global_config.model_max_output_length
+ payload["max_tokens"] = global_config.model.model_max_output_length
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens")
diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py
index 9dc2454ff..bd5322137 100644
--- a/src/chat/normal_chat/normal_chat.py
+++ b/src/chat/normal_chat/normal_chat.py
@@ -22,11 +22,11 @@ from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.normal_chat.willing.willing_manager import willing_manager
from src.config.config import global_config
-logger = get_logger("chat")
+logger = get_logger("normal_chat")
class NormalChat:
- def __init__(self, chat_stream: ChatStream, interest_dict: dict = None):
+ def __init__(self, chat_stream: ChatStream, interest_dict: dict = {}):
"""初始化 NormalChat 实例。只进行同步操作。"""
# Basic info from chat_stream (sync)
@@ -73,8 +73,8 @@ class NormalChat:
messageinfo = message.message_info
bot_user_info = UserInfo(
- user_id=global_config.BOT_QQ,
- user_nickname=global_config.BOT_NICKNAME,
+ user_id=global_config.bot.qq_account,
+ user_nickname=global_config.bot.nickname,
platform=messageinfo.platform,
)
@@ -121,8 +121,8 @@ class NormalChat:
message_id=thinking_id,
chat_stream=self.chat_stream, # 使用 self.chat_stream
bot_user_info=UserInfo(
- user_id=global_config.BOT_QQ,
- user_nickname=global_config.BOT_NICKNAME,
+ user_id=global_config.bot.qq_account,
+ user_nickname=global_config.bot.nickname,
platform=message.message_info.platform,
),
sender_info=message.message_info.user_info,
@@ -147,7 +147,7 @@ class NormalChat:
# 改为实例方法
async def _handle_emoji(self, message: MessageRecv, response: str):
"""处理表情包"""
- if random() < global_config.emoji_chance:
+ if random() < global_config.normal_chat.emoji_chance:
emoji_raw = await emoji_manager.get_emoji_for_text(response)
if emoji_raw:
emoji_path, description = emoji_raw
@@ -160,8 +160,8 @@ class NormalChat:
message_id="mt" + str(thinking_time_point),
chat_stream=self.chat_stream, # 使用 self.chat_stream
bot_user_info=UserInfo(
- user_id=global_config.BOT_QQ,
- user_nickname=global_config.BOT_NICKNAME,
+ user_id=global_config.bot.qq_account,
+ user_nickname=global_config.bot.nickname,
platform=message.message_info.platform,
),
sender_info=message.message_info.user_info,
@@ -186,7 +186,7 @@ class NormalChat:
label=emotion,
stance=stance, # 使用 self.chat_stream
)
- self.mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor)
+ self.mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor)
async def _reply_interested_message(self) -> None:
"""
@@ -200,7 +200,7 @@ class NormalChat:
logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出")
break
- # 获取待处理消息列表
+
items_to_process = list(self.interest_dict.items())
if not items_to_process:
continue
@@ -430,7 +430,7 @@ class NormalChat:
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息中是否包含过滤词"""
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
- for word in global_config.ban_words:
+ for word in global_config.chat.ban_words:
if word in text:
logger.info(
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"
@@ -445,7 +445,7 @@ class NormalChat:
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息是否匹配过滤正则表达式"""
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
- for pattern in global_config.ban_msgs_regex:
+ for pattern in global_config.chat.ban_msgs_regex:
if pattern.search(text):
logger.info(
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"
@@ -481,7 +481,7 @@ class NormalChat:
try:
if exc := task.exception():
logger.error(f"[{self.stream_name}] 任务异常: {exc}")
- logger.error(traceback.format_exc())
+ traceback.print_exc()
except asyncio.CancelledError:
logger.debug(f"[{self.stream_name}] 任务已取消")
except Exception as e:
@@ -522,4 +522,4 @@ class NormalChat:
logger.info(f"[{self.stream_name}] 清理了 {len(thinking_messages)} 条未处理的思考消息。")
except Exception as e:
logger.error(f"[{self.stream_name}] 清理思考消息时出错: {e}")
- logger.error(traceback.format_exc())
+ traceback.print_exc()
diff --git a/src/chat/normal_chat/normal_chat_generator.py b/src/chat/normal_chat/normal_chat_generator.py
index aec65ed1d..631f7baa5 100644
--- a/src/chat/normal_chat/normal_chat_generator.py
+++ b/src/chat/normal_chat/normal_chat_generator.py
@@ -15,21 +15,22 @@ logger = get_logger("llm")
class NormalChatGenerator:
def __init__(self):
+ # TODO: API-Adapter修改标记
self.model_reasoning = LLMRequest(
- model=global_config.llm_reasoning,
+ model=global_config.model.reasoning,
temperature=0.7,
max_tokens=3000,
request_type="response_reasoning",
)
self.model_normal = LLMRequest(
- model=global_config.llm_normal,
- temperature=global_config.llm_normal["temp"],
+ model=global_config.model.normal,
+ temperature=global_config.model.normal["temp"],
max_tokens=256,
request_type="response_reasoning",
)
self.model_sum = LLMRequest(
- model=global_config.llm_summary, temperature=0.7, max_tokens=3000, request_type="relation"
+ model=global_config.model.summary, temperature=0.7, max_tokens=3000, request_type="relation"
)
self.current_model_type = "r1" # 默认使用 R1
self.current_model_name = "unknown model"
@@ -37,7 +38,7 @@ class NormalChatGenerator:
async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型
- if random.random() < global_config.model_reasoning_probability:
+ if random.random() < global_config.normal_chat.reasoning_model_probability:
self.current_model_type = "深深地"
current_model = self.model_reasoning
else:
@@ -51,7 +52,7 @@ class NormalChatGenerator:
model_response = await self._generate_response_with_model(message, current_model, thinking_id)
if model_response:
- logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}")
+ logger.info(f"{global_config.bot.nickname}的回复是:{model_response}")
model_response = await self._process_response(model_response)
return model_response
@@ -113,7 +114,7 @@ class NormalChatGenerator:
- "中立":不表达明确立场或无关回应
2. 从"开心,愤怒,悲伤,惊讶,平静,害羞,恐惧,厌恶,困惑"中选出最匹配的1个情感标签
3. 按照"立场-情绪"的格式直接输出结果,例如:"反对-愤怒"
- 4. 考虑回复者的人格设定为{global_config.personality_core}
+ 4. 考虑回复者的人格设定为{global_config.personality.personality_core}
对话示例:
被回复:「A就是笨」
diff --git a/src/chat/normal_chat/willing/mode_classical.py b/src/chat/normal_chat/willing/mode_classical.py
index e96aa77a7..a9f04273a 100644
--- a/src/chat/normal_chat/willing/mode_classical.py
+++ b/src/chat/normal_chat/willing/mode_classical.py
@@ -1,18 +1,20 @@
import asyncio
+
+from src.config.config import global_config
from .willing_manager import BaseWillingManager
class ClassicalWillingManager(BaseWillingManager):
def __init__(self):
super().__init__()
- self._decay_task: asyncio.Task = None
+ self._decay_task: asyncio.Task | None = None
async def _decay_reply_willing(self):
"""定期衰减回复意愿"""
while True:
await asyncio.sleep(1)
for chat_id in self.chat_reply_willing:
- self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.9)
+ self.chat_reply_willing[chat_id] = max(0.0, self.chat_reply_willing[chat_id] * 0.9)
async def async_task_starter(self):
if self._decay_task is None:
@@ -23,35 +25,33 @@ class ClassicalWillingManager(BaseWillingManager):
chat_id = willing_info.chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
- interested_rate = willing_info.interested_rate * self.global_config.response_interested_rate_amplifier
+ interested_rate = willing_info.interested_rate * global_config.normal_chat.response_interested_rate_amplifier
if interested_rate > 0.4:
current_willing += interested_rate - 0.3
- if willing_info.is_mentioned_bot and current_willing < 1.0:
- current_willing += 1
- elif willing_info.is_mentioned_bot:
- current_willing += 0.05
+ if willing_info.is_mentioned_bot:
+ current_willing += 1 if current_willing < 1.0 else 0.05
is_emoji_not_reply = False
if willing_info.is_emoji:
- if self.global_config.emoji_response_penalty != 0:
- current_willing *= self.global_config.emoji_response_penalty
+ if global_config.normal_chat.emoji_response_penalty != 0:
+ current_willing *= global_config.normal_chat.emoji_response_penalty
else:
is_emoji_not_reply = True
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
reply_probability = min(
- max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1
+ max((current_willing - 0.5), 0.01) * global_config.normal_chat.response_willing_amplifier * 2, 1
)
# 检查群组权限(如果是群聊)
if (
willing_info.group_info
- and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups
+ and willing_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups
):
- reply_probability = reply_probability / self.global_config.down_frequency_rate
+ reply_probability = reply_probability / global_config.normal_chat.down_frequency_rate
if is_emoji_not_reply:
reply_probability = 0
@@ -61,7 +61,7 @@ class ClassicalWillingManager(BaseWillingManager):
async def before_generate_reply_handle(self, message_id):
chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
- self.chat_reply_willing[chat_id] = max(0, current_willing - 1.8)
+ self.chat_reply_willing[chat_id] = max(0.0, current_willing - 1.8)
async def after_generate_reply_handle(self, message_id):
if message_id not in self.ongoing_messages:
@@ -70,7 +70,7 @@ class ClassicalWillingManager(BaseWillingManager):
chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
if current_willing < 1:
- self.chat_reply_willing[chat_id] = min(1, current_willing + 0.4)
+ self.chat_reply_willing[chat_id] = min(1.0, current_willing + 0.4)
async def bombing_buffer_message_handle(self, message_id):
return await super().bombing_buffer_message_handle(message_id)
diff --git a/src/chat/normal_chat/willing/mode_mxp.py b/src/chat/normal_chat/willing/mode_mxp.py
index 78120ac53..1e7d5856d 100644
--- a/src/chat/normal_chat/willing/mode_mxp.py
+++ b/src/chat/normal_chat/willing/mode_mxp.py
@@ -19,6 +19,7 @@ Mxp 模式:梦溪畔独家赞助
下下策是询问一个菜鸟(@梦溪畔)
"""
+from src.config.config import global_config
from .willing_manager import BaseWillingManager
from typing import Dict
import asyncio
@@ -50,8 +51,6 @@ class MxpWillingManager(BaseWillingManager):
self.mention_willing_gain = 0.6 # 提及意愿增益
self.interest_willing_gain = 0.3 # 兴趣意愿增益
- self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚
- self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数
self.single_chat_gain = 0.12 # 单聊增益
self.fatigue_messages_triggered_num = self.expected_replies_per_min # 疲劳消息触发数量(int)
@@ -179,10 +178,10 @@ class MxpWillingManager(BaseWillingManager):
probability = self._willing_to_probability(current_willing)
if w_info.is_emoji:
- probability *= self.emoji_response_penalty
+ probability *= global_config.normal_chat.emoji_response_penalty
- if w_info.group_info and w_info.group_info.group_id in self.global_config.talk_frequency_down_groups:
- probability /= self.down_frequency_rate
+ if w_info.group_info and w_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups:
+ probability /= global_config.normal_chat.down_frequency_rate
self.temporary_willing = current_willing
diff --git a/src/chat/normal_chat/willing/willing_manager.py b/src/chat/normal_chat/willing/willing_manager.py
index 37e623d11..bbc5dcc0a 100644
--- a/src/chat/normal_chat/willing/willing_manager.py
+++ b/src/chat/normal_chat/willing/willing_manager.py
@@ -1,6 +1,6 @@
from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger
from dataclasses import dataclass
-from src.config.config import global_config, BotConfig
+from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv
from src.chat.person_info.person_info import person_info_manager, PersonInfoManager
@@ -93,7 +93,6 @@ class BaseWillingManager(ABC):
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id)
self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id)
self.lock = asyncio.Lock()
- self.global_config: BotConfig = global_config
self.logger: LoguruLogger = logger
def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
@@ -173,7 +172,7 @@ def init_willing_manager() -> BaseWillingManager:
Returns:
对应mode的WillingManager实例
"""
- mode = global_config.willing_mode.lower()
+ mode = global_config.normal_chat.willing_mode.lower()
return BaseWillingManager.create(mode)
diff --git a/src/chat/person_info/person_info.py b/src/chat/person_info/person_info.py
index 605b86b23..562cdc235 100644
--- a/src/chat/person_info/person_info.py
+++ b/src/chat/person_info/person_info.py
@@ -1,5 +1,6 @@
from src.common.logger_manager import get_logger
-from ...common.database import db
+from ...common.database.database import db
+from ...common.database.database_model import PersonInfo # 新增导入
import copy
import hashlib
from typing import Any, Callable, Dict
@@ -16,7 +17,7 @@ matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
-import json
+import json # 新增导入
import re
@@ -38,47 +39,49 @@ logger = get_logger("person_info")
person_info_default = {
"person_id": None,
- "person_name": None,
+ "person_name": None, # 模型中已设为 null=True,此默认值OK
"name_reason": None,
- "platform": None,
- "user_id": None,
- "nickname": None,
- # "age" : 0,
+ "platform": "unknown", # 提供非None的默认值
+ "user_id": "unknown", # 提供非None的默认值
+ "nickname": "Unknown", # 提供非None的默认值
"relationship_value": 0,
- # "saved" : True,
- # "impression" : None,
- # "gender" : Unkown,
- "konw_time": 0,
+ "know_time": 0, # 修正拼写:konw_time -> know_time
"msg_interval": 2000,
- "msg_interval_list": [],
- "user_cardname": None, # 添加群名片
- "user_avatar": None, # 添加头像信息(例如URL或标识符)
-} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
+ "msg_interval_list": [], # 将作为 JSON 字符串存储在 Peewee 的 TextField
+ "user_cardname": None, # 注意:此字段不在 PersonInfo Peewee 模型中
+ "user_avatar": None, # 注意:此字段不在 PersonInfo Peewee 模型中
+}
class PersonInfoManager:
def __init__(self):
self.person_name_list = {}
+ # TODO: API-Adapter修改标记
self.qv_name_llm = LLMRequest(
- model=global_config.llm_normal,
+ model=global_config.model.normal,
max_tokens=256,
request_type="qv_name",
)
- if "person_info" not in db.list_collection_names():
- db.create_collection("person_info")
- db.person_info.create_index("person_id", unique=True)
+ try:
+ db.connect(reuse_if_open=True)
+ db.create_tables([PersonInfo], safe=True)
+ except Exception as e:
+ logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
# 初始化时读取所有person_name
- cursor = db.person_info.find({"person_name": {"$exists": True}}, {"person_id": 1, "person_name": 1, "_id": 0})
- for doc in cursor:
- if doc.get("person_name"):
- self.person_name_list[doc["person_id"]] = doc["person_name"]
- logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称")
+ try:
+ for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
+ PersonInfo.person_name.is_null(False)
+ ):
+ if record.person_name:
+ self.person_name_list[record.person_id] = record.person_name
+ logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
+ except Exception as e:
+ logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
@staticmethod
def get_person_id(platform: str, user_id: int):
"""获取唯一id"""
- # 如果platform中存在-,就截取-后面的部分
if "-" in platform:
platform = platform.split("-")[1]
@@ -86,15 +89,27 @@ class PersonInfoManager:
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
- def is_person_known(self, platform: str, user_id: int):
+ async def is_person_known(self, platform: str, user_id: int):
"""判断是否认识某人"""
person_id = self.get_person_id(platform, user_id)
- document = db.person_info.find_one({"person_id": person_id})
- if document:
- return True
- else:
+
+ def _db_check_known_sync(p_id: str):
+ return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None
+
+ try:
+ return await asyncio.to_thread(_db_check_known_sync, person_id)
+ except Exception as e:
+ logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
return False
+ def get_person_id_by_person_name(self, person_name: str):
+ """根据用户名获取用户ID"""
+ document = db.person_info.find_one({"person_name": person_name})
+ if document:
+ return document["person_id"]
+ else:
+ return ""
+
@staticmethod
async def create_person_info(person_id: str, data: dict = None):
"""创建一个项"""
@@ -103,73 +118,111 @@ class PersonInfoManager:
return
_person_info_default = copy.deepcopy(person_info_default)
- _person_info_default["person_id"] = person_id
+ model_fields = PersonInfo._meta.fields.keys()
+
+ final_data = {"person_id": person_id}
if data:
- for key in _person_info_default:
- if key != "person_id" and key in data:
- _person_info_default[key] = data[key]
+ for key, value in data.items():
+ if key in model_fields:
+ final_data[key] = value
- db.person_info.insert_one(_person_info_default)
+ for key, default_value in _person_info_default.items():
+ if key in model_fields and key not in final_data:
+ final_data[key] = default_value
+
+ if "msg_interval_list" in final_data and isinstance(final_data["msg_interval_list"], list):
+ final_data["msg_interval_list"] = json.dumps(final_data["msg_interval_list"])
+ elif "msg_interval_list" not in final_data and "msg_interval_list" in model_fields:
+ final_data["msg_interval_list"] = json.dumps([])
+
+ def _db_create_sync(p_data: dict):
+ try:
+ PersonInfo.create(**p_data)
+ return True
+ except Exception as e:
+ logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}")
+ return False
+
+ await asyncio.to_thread(_db_create_sync, final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
"""更新某一个字段,会补全"""
- if field_name not in person_info_default.keys():
- logger.debug(f"更新'{field_name}'失败,未定义的字段")
+ if field_name not in PersonInfo._meta.fields:
+ if field_name in person_info_default:
+ logger.debug(f"更新'{field_name}'跳过,字段存在于默认配置但不在 PersonInfo Peewee 模型中。")
+ return
+ logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
return
- document = db.person_info.find_one({"person_id": person_id})
+ def _db_update_sync(p_id: str, f_name: str, val):
+ record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
+ if record:
+ if f_name == "msg_interval_list" and isinstance(val, list):
+ setattr(record, f_name, json.dumps(val))
+ else:
+ setattr(record, f_name, val)
+ record.save()
+ return True, False
+ return False, True
- if document:
- db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}})
- else:
- data[field_name] = value
- logger.debug(f"更新时{person_id}不存在,已新建")
- await self.create_person_info(person_id, data)
+ found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, value)
+
+ if needs_creation:
+ logger.debug(f"更新时 {person_id} 不存在,将新建。")
+ creation_data = data if data is not None else {}
+ creation_data[field_name] = value
+ if "platform" not in creation_data or "user_id" not in creation_data:
+ logger.warning(f"为 {person_id} 创建记录时,platform/user_id 可能缺失。")
+
+ await self.create_person_info(person_id, creation_data)
@staticmethod
async def has_one_field(person_id: str, field_name: str):
"""判断是否存在某一个字段"""
- document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
- if document:
- return True
- else:
+ if field_name not in PersonInfo._meta.fields:
+ logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
+ return False
+
+ def _db_has_field_sync(p_id: str, f_name: str):
+ record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
+ if record:
+ return True
+ return False
+
+ try:
+ return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
+ except Exception as e:
+ logger.error(f"检查字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
return False
@staticmethod
def _extract_json_from_text(text: str) -> dict:
"""从文本中提取JSON数据的高容错方法"""
try:
- # 尝试直接解析
parsed_json = json.loads(text)
- # 如果解析结果是列表,尝试取第一个元素
if isinstance(parsed_json, list):
- if parsed_json: # 检查列表是否为空
+ if parsed_json:
parsed_json = parsed_json[0]
- else: # 如果列表为空,重置为 None,走后续逻辑
+ else:
parsed_json = None
- # 确保解析结果是字典
if isinstance(parsed_json, dict):
return parsed_json
except json.JSONDecodeError:
- # 解析失败,继续尝试其他方法
pass
except Exception as e:
logger.warning(f"尝试直接解析JSON时发生意外错误: {e}")
- pass # 继续尝试其他方法
+ pass
- # 如果直接解析失败或结果不是字典
try:
- # 尝试找到JSON对象格式的部分
json_pattern = r"\{[^{}]*\}"
matches = re.findall(json_pattern, text)
if matches:
parsed_obj = json.loads(matches[0])
- if isinstance(parsed_obj, dict): # 确保是字典
+ if isinstance(parsed_obj, dict):
return parsed_obj
- # 如果上面都失败了,尝试提取键值对
nickname_pattern = r'"nickname"[:\s]+"([^"]+)"'
reason_pattern = r'"reason"[:\s]+"([^"]+)"'
@@ -184,7 +237,6 @@ class PersonInfoManager:
except Exception as e:
logger.error(f"后备JSON提取失败: {str(e)}")
- # 如果所有方法都失败了,返回默认字典
logger.warning(f"无法从文本中提取有效的JSON字典: {text}")
return {"nickname": "", "reason": ""}
@@ -199,9 +251,11 @@ class PersonInfoManager:
old_name = await self.get_value(person_id, "person_name")
old_reason = await self.get_value(person_id, "name_reason")
- max_retries = 5 # 最大重试次数
+ max_retries = 5
current_try = 0
- existing_names = ""
+ existing_names_str = ""
+ current_name_set = set(self.person_name_list.values())
+
while current_try < max_retries:
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=2, level=1)
@@ -216,45 +270,58 @@ class PersonInfoManager:
qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason},"
qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸"
-
qv_name_prompt += (
"\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改"
)
- if existing_names:
- qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}。\n"
+
+ if existing_names_str:
+ qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}。\n"
+
+ if len(current_name_set) < 50 and current_name_set:
+ qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n"
+
qv_name_prompt += "请用json给出你的想法,并给出理由,示例如下:"
qv_name_prompt += """{
"nickname": "昵称",
"reason": "理由"
}"""
- # logger.debug(f"取名提示词:{qv_name_prompt}")
response = await self.qv_name_llm.generate_response(qv_name_prompt)
logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}")
result = self._extract_json_from_text(response[0])
- if not result["nickname"]:
- logger.error("生成的昵称为空,重试中...")
+ if not result or not result.get("nickname"):
+ logger.error("生成的昵称为空或结果格式不正确,重试中...")
current_try += 1
continue
- # 检查生成的昵称是否已存在
- if result["nickname"] not in self.person_name_list.values():
- # 更新数据库和内存中的列表
- await self.update_one_field(person_id, "person_name", result["nickname"])
- # await self.update_one_field(person_id, "nickname", user_nickname)
- # await self.update_one_field(person_id, "avatar", user_avatar)
- await self.update_one_field(person_id, "name_reason", result["reason"])
+ generated_nickname = result["nickname"]
- self.person_name_list[person_id] = result["nickname"]
- # logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}")
+ is_duplicate = False
+ if generated_nickname in current_name_set:
+ is_duplicate = True
+ else:
+
+ def _db_check_name_exists_sync(name_to_check):
+ return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
+
+ if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
+ is_duplicate = True
+ current_name_set.add(generated_nickname)
+
+ if not is_duplicate:
+ await self.update_one_field(person_id, "person_name", generated_nickname)
+ await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由"))
+
+ self.person_name_list[person_id] = generated_nickname
return result
else:
- existing_names += f"{result['nickname']}、"
+ if existing_names_str:
+ existing_names_str += "、"
+ existing_names_str += generated_nickname
+ logger.debug(f"生成的昵称 {generated_nickname} 已存在,重试中...")
+ current_try += 1
- logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...")
- current_try += 1
-
- logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称")
+ logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称 for {person_id}")
return None
@staticmethod
@@ -264,30 +331,56 @@ class PersonInfoManager:
logger.debug("删除失败:person_id 不能为空")
return
- result = db.person_info.delete_one({"person_id": person_id})
- if result.deleted_count > 0:
- logger.debug(f"删除成功:person_id={person_id}")
+ def _db_delete_sync(p_id: str):
+ try:
+ query = PersonInfo.delete().where(PersonInfo.person_id == p_id)
+ deleted_count = query.execute()
+ return deleted_count
+ except Exception as e:
+ logger.error(f"删除 PersonInfo {p_id} 失败 (Peewee): {e}")
+ return 0
+
+ deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
+
+ if deleted_count > 0:
+ logger.debug(f"删除成功:person_id={person_id} (Peewee)")
else:
- logger.debug(f"删除失败:未找到 person_id={person_id}")
+ logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)")
@staticmethod
async def get_value(person_id: str, field_name: str):
"""获取指定person_id文档的字段值,若不存在该字段,则返回该字段的全局默认值"""
if not person_id:
logger.debug("get_value获取失败:person_id不能为空")
+ return person_info_default.get(field_name)
+
+ if field_name not in PersonInfo._meta.fields:
+ if field_name in person_info_default:
+ logger.trace(f"字段'{field_name}'不在Peewee模型中,但存在于默认配置中。返回配置默认值。")
+ return copy.deepcopy(person_info_default[field_name])
+ logger.debug(f"get_value获取失败:字段'{field_name}'未在Peewee模型和默认配置中定义。")
return None
- if field_name not in person_info_default:
- logger.debug(f"get_value获取失败:字段'{field_name}'未定义")
+ def _db_get_value_sync(p_id: str, f_name: str):
+ record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
+ if record:
+ val = getattr(record, f_name)
+ if f_name == "msg_interval_list" and isinstance(val, str):
+ try:
+ return json.loads(val)
+ except json.JSONDecodeError:
+ logger.warning(f"无法解析 {p_id} 的 msg_interval_list JSON: {val}")
+ return copy.deepcopy(person_info_default.get(f_name, []))
+ return val
return None
- document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
+ value = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
- if document and field_name in document:
- return document[field_name]
+ if value is not None:
+ return value
else:
- default_value = copy.deepcopy(person_info_default[field_name])
- logger.trace(f"获取{person_id}的{field_name}失败,已返回默认值{default_value}")
+ default_value = copy.deepcopy(person_info_default.get(field_name))
+ logger.trace(f"获取{person_id}的{field_name}失败或值为None,已返回默认值{default_value} (Peewee)")
return default_value
@staticmethod
@@ -297,93 +390,84 @@ class PersonInfoManager:
logger.debug("get_values获取失败:person_id不能为空")
return {}
- # 检查所有字段是否有效
- for field in field_names:
- if field not in person_info_default:
- logger.debug(f"get_values获取失败:字段'{field}'未定义")
- return {}
-
- # 构建查询投影(所有字段都有效才会执行到这里)
- projection = {field: 1 for field in field_names}
-
- document = db.person_info.find_one({"person_id": person_id}, projection)
-
result = {}
- for field in field_names:
- result[field] = copy.deepcopy(
- document.get(field, person_info_default[field]) if document else person_info_default[field]
- )
+
+ def _db_get_record_sync(p_id: str):
+ return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
+
+ record = await asyncio.to_thread(_db_get_record_sync, person_id)
+
+ for field_name in field_names:
+ if field_name not in PersonInfo._meta.fields:
+ if field_name in person_info_default:
+ result[field_name] = copy.deepcopy(person_info_default[field_name])
+ logger.trace(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。")
+ else:
+ logger.debug(f"get_values查询失败:字段'{field_name}'未在Peewee模型和默认配置中定义。")
+ result[field_name] = None
+ continue
+
+ if record:
+ value = getattr(record, field_name)
+ if field_name == "msg_interval_list" and isinstance(value, str):
+ try:
+ result[field_name] = json.loads(value)
+ except json.JSONDecodeError:
+ logger.warning(f"无法解析 {person_id} 的 msg_interval_list JSON: {value}")
+ result[field_name] = copy.deepcopy(person_info_default.get(field_name, []))
+ elif value is not None:
+ result[field_name] = value
+ else:
+ result[field_name] = copy.deepcopy(person_info_default.get(field_name))
+ else:
+ result[field_name] = copy.deepcopy(person_info_default.get(field_name))
return result
@staticmethod
async def del_all_undefined_field():
- """删除所有项里的未定义字段"""
- # 获取所有已定义的字段名
- defined_fields = set(person_info_default.keys())
-
- try:
- # 遍历集合中的所有文档
- for document in db.person_info.find({}):
- # 找出文档中未定义的字段
- undefined_fields = set(document.keys()) - defined_fields - {"_id"}
-
- if undefined_fields:
- # 构建更新操作,使用$unset删除未定义字段
- update_result = db.person_info.update_one(
- {"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}}
- )
-
- if update_result.modified_count > 0:
- logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}")
-
- return
-
- except Exception as e:
- logger.error(f"清理未定义字段时出错: {e}")
- return
+ """删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。"""
+ logger.info(
+ "del_all_undefined_field: 对于使用Peewee的SQL数据库,此操作通常不适用或不需要,因为表结构是预定义的。"
+ )
+ return
@staticmethod
async def get_specific_value_list(
field_name: str,
- way: Callable[[Any], bool], # 接受任意类型值
+ way: Callable[[Any], bool],
) -> Dict[str, Any]:
"""
获取满足条件的字段值字典
-
- Args:
- field_name: 目标字段名
- way: 判断函数 (value: Any) -> bool
-
- Returns:
- {person_id: value} | {}
-
- Example:
- # 查找所有nickname包含"admin"的用户
- result = manager.specific_value_list(
- "nickname",
- lambda x: "admin" in x.lower()
- )
"""
- if field_name not in person_info_default:
- logger.error(f"字段检查失败:'{field_name}'未定义")
+ if field_name not in PersonInfo._meta.fields:
+ logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义")
return {}
+ def _db_get_specific_sync(f_name: str):
+ found_results = {}
+ try:
+ for record in PersonInfo.select(PersonInfo.person_id, getattr(PersonInfo, f_name)):
+ value = getattr(record, f_name)
+ if f_name == "msg_interval_list" and isinstance(value, str):
+ try:
+ processed_value = json.loads(value)
+ except json.JSONDecodeError:
+ logger.warning(f"跳过记录 {record.person_id},无法解析 msg_interval_list: {value}")
+ continue
+ else:
+ processed_value = value
+
+ if way(processed_value):
+ found_results[record.person_id] = processed_value
+ except Exception as e_query:
+ logger.error(f"数据库查询失败 (Peewee specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
+ return found_results
+
try:
- result = {}
- for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}):
- try:
- value = doc[field_name]
- if way(value):
- result[doc["person_id"]] = value
- except (KeyError, TypeError, ValueError) as e:
- logger.debug(f"记录{doc.get('person_id')}处理失败: {str(e)}")
- continue
-
- return result
-
+ return await asyncio.to_thread(_db_get_specific_sync, field_name)
except Exception as e:
- logger.error(f"数据库查询失败: {str(e)}", exc_info=True)
+ logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
return {}
async def personal_habit_deduction(self):
@@ -391,35 +475,31 @@ class PersonInfoManager:
try:
while 1:
await asyncio.sleep(600)
- current_time = datetime.datetime.now()
- logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
+ current_time_dt = datetime.datetime.now()
+ logger.info(f"个人信息推断启动: {current_time_dt.strftime('%Y-%m-%d %H:%M:%S')}")
- # "msg_interval"推断
- msg_interval_map = False
- msg_interval_lists = await self.get_specific_value_list(
+ msg_interval_map_generated = False
+ msg_interval_lists_map = await self.get_specific_value_list(
"msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
)
- for person_id, msg_interval_list_ in msg_interval_lists.items():
+
+ for person_id, actual_msg_interval_list in msg_interval_lists_map.items():
await asyncio.sleep(0.3)
try:
time_interval = []
- for t1, t2 in zip(msg_interval_list_, msg_interval_list_[1:]):
+ for t1, t2 in zip(actual_msg_interval_list, actual_msg_interval_list[1:]):
delta = t2 - t1
if delta > 0:
time_interval.append(delta)
time_interval = [t for t in time_interval if 200 <= t <= 8000]
- # --- 修改后的逻辑 ---
- # 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断)
- if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条
- time_interval.sort()
- # 画图(log) - 这部分保留
- msg_interval_map = True
+ if len(time_interval) >= 30 + 10:
+ time_interval.sort()
+ msg_interval_map_generated = True
log_dir = Path("logs/person_info")
log_dir.mkdir(parents=True, exist_ok=True)
plt.figure(figsize=(10, 6))
- # 使用截断前的数据画图,更能反映原始分布
time_series_original = pd.Series(time_interval)
plt.hist(
time_series_original,
@@ -441,34 +521,29 @@ class PersonInfoManager:
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
plt.savefig(img_path)
plt.close()
- # 画图结束
- # 去掉头尾各 5 个数据点
trimmed_interval = time_interval[5:-5]
-
- # 计算截断后数据的 37% 分位数
- if trimmed_interval: # 确保截断后列表不为空
- msg_interval = int(round(np.percentile(trimmed_interval, 37)))
- # 更新数据库
- await self.update_one_field(person_id, "msg_interval", msg_interval)
- logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}")
+ if trimmed_interval:
+ msg_interval_val = int(round(np.percentile(trimmed_interval, 37)))
+ await self.update_one_field(person_id, "msg_interval", msg_interval_val)
+ logger.trace(
+ f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}"
+ )
else:
logger.trace(f"用户{person_id}截断后数据为空,无法计算msg_interval")
else:
logger.trace(
f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)"
)
- # --- 修改结束 ---
- except Exception as e:
- logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}")
+ except Exception as e_inner:
+ logger.trace(f"用户{person_id}消息间隔计算失败: {type(e_inner).__name__}: {str(e_inner)}")
continue
- # 其他...
-
- if msg_interval_map:
+ if msg_interval_map_generated:
logger.trace("已保存分布图到: logs/person_info")
- current_time = datetime.datetime.now()
- logger.trace(f"个人信息推断结束: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
+
+ current_time_dt_end = datetime.datetime.now()
+ logger.trace(f"个人信息推断结束: {current_time_dt_end.strftime('%Y-%m-%d %H:%M:%S')}")
await asyncio.sleep(86400)
except Exception as e:
@@ -481,41 +556,27 @@ class PersonInfoManager:
"""
根据 platform 和 user_id 获取 person_id。
如果对应的用户不存在,则使用提供的可选信息创建新用户。
-
- Args:
- platform: 平台标识
- user_id: 用户在该平台上的ID
- nickname: 用户的昵称 (可选,用于创建新用户)
- user_cardname: 用户的群名片 (可选,用于创建新用户)
- user_avatar: 用户的头像信息 (可选,用于创建新用户)
-
- Returns:
- 对应的 person_id。
"""
person_id = self.get_person_id(platform, user_id)
- # 检查用户是否已存在
- # 使用静态方法 get_person_id,因此可以直接调用 db
- document = db.person_info.find_one({"person_id": person_id})
+ def _db_check_exists_sync(p_id: str):
+ return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
- if document is None:
- logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。")
+ record = await asyncio.to_thread(_db_check_exists_sync, person_id)
+
+ if record is None:
+ logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
initial_data = {
"platform": platform,
- "user_id": user_id,
+ "user_id": str(user_id),
"nickname": nickname,
- "konw_time": int(datetime.datetime.now().timestamp()), # 添加初次认识时间
- # 注意:这里没有添加 user_cardname 和 user_avatar,因为它们不在 person_info_default 中
- # 如果需要存储它们,需要先在 person_info_default 中定义
+ "know_time": int(datetime.datetime.now().timestamp()), # 修正拼写:konw_time -> know_time
}
- # 过滤掉值为 None 的初始数据
- initial_data = {k: v for k, v in initial_data.items() if v is not None}
+ model_fields = PersonInfo._meta.fields.keys()
+ filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
- # 注意:create_person_info 是静态方法
- await PersonInfoManager.create_person_info(person_id, data=initial_data)
- # 创建后,可以考虑立即为其取名,但这可能会增加延迟
- # await self.qv_person_name(person_id, nickname, user_cardname, user_avatar)
- logger.debug(f"已为 {person_id} 创建新记录,初始数据: {initial_data}")
+ await self.create_person_info(person_id, data=filtered_initial_data)
+ logger.debug(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
return person_id
@@ -525,35 +586,55 @@ class PersonInfoManager:
logger.debug("get_person_info_by_name 获取失败:person_name 不能为空")
return None
- # 优先从内存缓存查找 person_id
found_person_id = None
- for pid, name in self.person_name_list.items():
- if name == person_name:
+ for pid, name_in_cache in self.person_name_list.items():
+ if name_in_cache == person_name:
found_person_id = pid
- break # 找到第一个匹配就停止
+ break
if not found_person_id:
- # 如果内存没有,尝试数据库查询(可能内存未及时更新或启动时未加载)
- document = db.person_info.find_one({"person_name": person_name})
- if document:
- found_person_id = document.get("person_id")
- else:
- logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户")
- return None # 数据库也找不到
- # 根据找到的 person_id 获取所需信息
- if found_person_id:
- required_fields = ["person_id", "platform", "user_id", "nickname", "user_cardname", "user_avatar"]
- person_data = await self.get_values(found_person_id, required_fields)
- if person_data: # 确保 get_values 成功返回
- return person_data
+ def _db_find_by_name_sync(p_name_to_find: str):
+ return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find)
+
+ record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
+ if record:
+ found_person_id = record.person_id
+ if (
+ found_person_id not in self.person_name_list
+ or self.person_name_list[found_person_id] != person_name
+ ):
+ self.person_name_list[found_person_id] = person_name
else:
- logger.warning(f"找到了 person_id '{found_person_id}' 但获取详细信息失败")
+ logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)")
return None
- else:
- # 这理论上不应该发生,因为上面已经处理了找不到的情况
- logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id")
- return None
+
+ if found_person_id:
+ required_fields = [
+ "person_id",
+ "platform",
+ "user_id",
+ "nickname",
+ "user_cardname",
+ "user_avatar",
+ "person_name",
+ "name_reason",
+ ]
+ valid_fields_to_get = [
+ f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default
+ ]
+
+ person_data = await self.get_values(found_person_id, valid_fields_to_get)
+
+ if person_data:
+ final_result = {key: person_data.get(key) for key in required_fields}
+ return final_result
+ else:
+ logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)")
+ return None
+
+ logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)")
+ return None
person_info_manager = PersonInfoManager()
diff --git a/src/chat/person_info/relationship_manager.py b/src/chat/person_info/relationship_manager.py
index c8a443857..a23780c0e 100644
--- a/src/chat/person_info/relationship_manager.py
+++ b/src/chat/person_info/relationship_manager.py
@@ -77,7 +77,7 @@ class RelationshipManager:
@staticmethod
async def is_known_some_one(platform, user_id):
"""判断是否认识某人"""
- is_known = person_info_manager.is_person_known(platform, user_id)
+ is_known = await person_info_manager.is_person_known(platform, user_id)
return is_known
@staticmethod
diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py
index 15b1e4fc6..d662d8c0a 100644
--- a/src/chat/utils/chat_message_builder.py
+++ b/src/chat/utils/chat_message_builder.py
@@ -174,6 +174,16 @@ async def _build_readable_messages_internal(
# 1 & 2: 获取发送者信息并提取消息组件
for msg in messages:
+ # 检查并修复缺少的user_info字段
+ if "user_info" not in msg:
+ # 创建user_info字段
+ msg["user_info"] = {
+ "platform": msg.get("user_platform", ""),
+ "user_id": msg.get("user_id", ""),
+ "user_nickname": msg.get("user_nickname", ""),
+ "user_cardname": msg.get("user_cardname", ""),
+ }
+
user_info = msg.get("user_info", {})
platform = user_info.get("platform")
user_id = user_info.get("user_id")
@@ -190,8 +200,8 @@ async def _build_readable_messages_internal(
person_id = person_info_manager.get_person_id(platform, user_id)
# 根据 replace_bot_name 参数决定是否替换机器人名称
- if replace_bot_name and user_id == global_config.BOT_QQ:
- person_name = f"{global_config.BOT_NICKNAME}(你)"
+ if replace_bot_name and user_id == global_config.bot.qq_account:
+ person_name = f"{global_config.bot.nickname}(你)"
else:
person_name = await person_info_manager.get_value(person_id, "person_name")
@@ -427,7 +437,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
output_lines = []
def get_anon_name(platform, user_id):
- if user_id == global_config.BOT_QQ:
+ if user_id == global_config.bot.qq_account:
return "SELF"
person_id = person_info_manager.get_person_id(platform, user_id)
if person_id not in person_map:
@@ -451,10 +461,10 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# 处理 回复
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
- def reply_replacer(match):
+ def reply_replacer(match, platform=platform):
# aaa = match.group(1)
bbb = match.group(2)
- anon_reply = get_anon_name(platform, bbb)
+ anon_reply = get_anon_name(platform, bbb) # noqa
return f"回复 {anon_reply}"
content = re.sub(reply_pattern, reply_replacer, content, count=1)
@@ -462,10 +472,10 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# 处理 @
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
- def at_replacer(match):
+ def at_replacer(match, platform=platform):
# aaa = match.group(1)
bbb = match.group(2)
- anon_at = get_anon_name(platform, bbb)
+ anon_at = get_anon_name(platform, bbb) # noqa
return f"@{anon_at}"
content = re.sub(at_pattern, at_replacer, content)
@@ -501,7 +511,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
user_id = user_info.get("user_id")
# 检查必要信息是否存在 且 不是机器人自己
- if not all([platform, user_id]) or user_id == global_config.BOT_QQ:
+ if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue
person_id = person_info_manager.get_person_id(platform, user_id)
diff --git a/src/chat/utils/info_catcher.py b/src/chat/utils/info_catcher.py
index 174bb5b49..93cda5113 100644
--- a/src/chat/utils/info_catcher.py
+++ b/src/chat/utils/info_catcher.py
@@ -1,15 +1,15 @@
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv, MessageSending, Message
-from src.common.database import db
+from src.common.database.database_model import Messages, ThinkingLog
import time
import traceback
from typing import List
+import json
class InfoCatcher:
def __init__(self):
self.chat_history = [] # 聊天历史,长度为三倍使用的上下文喵~
- self.context_length = global_config.observation_context_size
self.chat_history_in_thinking = [] # 思考期间的聊天内容喵~
self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文喵~
@@ -60,8 +60,6 @@ class InfoCatcher:
def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
self.timing_results["sub_heartflow_observe_time"] = obs_duration
- # def catch_shf
-
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
self.timing_results["sub_heartflow_step_time"] = step_duration
if len(past_mind) > 1:
@@ -72,25 +70,10 @@ class InfoCatcher:
self.heartflow_data["sub_heartflow_now"] = current_mind
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
- # if self.response_mode == "heart_flow": # 条件判断不需要了喵~
- # self.heartflow_data["prompt"] = prompt
- # self.heartflow_data["response"] = response
- # self.heartflow_data["model"] = model_name
- # elif self.response_mode == "reasoning": # 条件判断不需要了喵~
- # self.reasoning_data["thinking_log"] = reasoning_content
- # self.reasoning_data["prompt"] = prompt
- # self.reasoning_data["response"] = response
- # self.reasoning_data["model"] = model_name
-
- # 直接记录信息喵~
self.reasoning_data["thinking_log"] = reasoning_content
self.reasoning_data["prompt"] = prompt
self.reasoning_data["response"] = response
self.reasoning_data["model"] = model_name
- # 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~
- # self.heartflow_data["prompt"] = prompt
- # self.heartflow_data["response"] = response
- # self.heartflow_data["model"] = model_name
self.response_text = response
@@ -102,6 +85,7 @@ class InfoCatcher:
):
self.timing_results["make_response_time"] = response_duration
self.response_time = time.time()
+ self.response_messages = []
for msg in response_message:
self.response_messages.append(msg)
@@ -112,107 +96,112 @@ class InfoCatcher:
@staticmethod
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
try:
- # 从数据库中获取消息的时间戳
time_start = message_start.message_info.time
time_end = message_end.message_info.time
chat_id = message_start.chat_stream.stream_id
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
- # 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据
- messages_between = db.messages.find(
- {"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}}
- ).sort("time", -1)
+ messages_between_query = (
+ Messages.select()
+ .where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end))
+ .order_by(Messages.time.desc())
+ )
- result = list(messages_between)
+ result = list(messages_between_query)
print(f"查询结果数量: {len(result)}")
if result:
- print(f"第一条消息时间: {result[0]['time']}")
- print(f"最后一条消息时间: {result[-1]['time']}")
+ print(f"第一条消息时间: {result[0].time}")
+ print(f"最后一条消息时间: {result[-1].time}")
return result
except Exception as e:
print(f"获取消息时出错: {str(e)}")
+ print(traceback.format_exc())
return []
def get_message_from_db_before_msg(self, message: MessageRecv):
- # 从数据库中获取消息
- message_id = message.message_info.message_id
- chat_id = message.chat_stream.stream_id
+ message_id_val = message.message_info.message_id
+ chat_id_val = message.chat_stream.stream_id
- # 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据
- messages_before = (
- db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
- .sort("time", -1)
- .limit(self.context_length * 3)
- ) # 获取更多历史信息
+ messages_before_query = (
+ Messages.select()
+ .where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val))
+ .order_by(Messages.time.desc())
+ .limit(global_config.chat.observation_context_size * 3)
+ )
- return list(messages_before)
+ return list(messages_before_query)
def message_list_to_dict(self, message_list):
- # 存储简化的聊天记录
result = []
- for message in message_list:
- if not isinstance(message, dict):
- message = self.message_to_dict(message)
- # print(message)
+ for msg_item in message_list:
+ processed_msg_item = msg_item
+ if not isinstance(msg_item, dict):
+ processed_msg_item = self.message_to_dict(msg_item)
+
+ if not processed_msg_item:
+ continue
lite_message = {
- "time": message["time"],
- "user_nickname": message["user_info"]["user_nickname"],
- "processed_plain_text": message["processed_plain_text"],
+ "time": processed_msg_item.get("time"),
+ "user_nickname": processed_msg_item.get("user_nickname"),
+ "processed_plain_text": processed_msg_item.get("processed_plain_text"),
}
result.append(lite_message)
-
return result
@staticmethod
- def message_to_dict(message):
- if not message:
+ def message_to_dict(msg_obj):
+ if not msg_obj:
return None
- if isinstance(message, dict):
- return message
- return {
- # "message_id": message.message_info.message_id,
- "time": message.message_info.time,
- "user_id": message.message_info.user_info.user_id,
- "user_nickname": message.message_info.user_info.user_nickname,
- "processed_plain_text": message.processed_plain_text,
- # "detailed_plain_text": message.detailed_plain_text
- }
+ if isinstance(msg_obj, dict):
+ return msg_obj
- def done_catch(self):
- """将收集到的信息存储到数据库的 thinking_log 集合中喵~"""
- try:
- # 将消息对象转换为可序列化的字典喵~
-
- thinking_log_data = {
- "chat_id": self.chat_id,
- "trigger_text": self.trigger_response_text,
- "response_text": self.response_text,
- "trigger_info": {
- "time": self.trigger_response_time,
- "message": self.message_to_dict(self.trigger_response_message),
- },
- "response_info": {
- "time": self.response_time,
- "message": self.response_messages,
- },
- "timing_results": self.timing_results,
- "chat_history": self.message_list_to_dict(self.chat_history),
- "chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
- "chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
- "heartflow_data": self.heartflow_data,
- "reasoning_data": self.reasoning_data,
+ if isinstance(msg_obj, Messages):
+ return {
+ "time": msg_obj.time,
+ "user_id": msg_obj.user_id,
+ "user_nickname": msg_obj.user_nickname,
+ "processed_plain_text": msg_obj.processed_plain_text,
}
- # 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~
- # if self.response_mode == "heart_flow":
- # thinking_log_data["mode_specific_data"] = self.heartflow_data
- # elif self.response_mode == "reasoning":
- # thinking_log_data["mode_specific_data"] = self.reasoning_data
+ if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"):
+ return {
+ "time": msg_obj.message_info.time,
+ "user_id": msg_obj.message_info.user_info.user_id,
+ "user_nickname": msg_obj.message_info.user_info.user_nickname,
+ "processed_plain_text": msg_obj.processed_plain_text,
+ }
- # 将数据插入到 thinking_log 集合中喵~
- db.thinking_log.insert_one(thinking_log_data)
+ print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}")
+ return {}
+
+ def done_catch(self):
+ """将收集到的信息存储到数据库的 thinking_log 表中喵~"""
+ try:
+ trigger_info_dict = self.message_to_dict(self.trigger_response_message)
+ response_info_dict = {
+ "time": self.response_time,
+ "message": self.response_messages,
+ }
+ chat_history_list = self.message_list_to_dict(self.chat_history)
+ chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking)
+ chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response)
+
+ log_entry = ThinkingLog(
+ chat_id=self.chat_id,
+ trigger_text=self.trigger_response_text,
+ response_text=self.response_text,
+ trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None,
+ response_info_json=json.dumps(response_info_dict),
+ timing_results_json=json.dumps(self.timing_results),
+ chat_history_json=json.dumps(chat_history_list),
+ chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
+ chat_history_after_response_json=json.dumps(chat_history_after_response_list),
+ heartflow_data_json=json.dumps(self.heartflow_data),
+ reasoning_data_json=json.dumps(self.reasoning_data),
+ )
+ log_entry.save()
return True
except Exception as e:
diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py
index 3f9832926..a657ae85b 100644
--- a/src/chat/utils/statistic.py
+++ b/src/chat/utils/statistic.py
@@ -2,10 +2,12 @@ from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
+
from src.common.logger import get_module_logger
from src.manager.async_task_manager import AsyncTask
-from ...common.database import db
+from ...common.database.database import db # This db is the Peewee database instance
+from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
from src.manager.local_store_manager import local_storage
logger = get_module_logger("maibot_statistic")
@@ -39,7 +41,7 @@ class OnlineTimeRecordTask(AsyncTask):
def __init__(self):
super().__init__(task_name="Online Time Record Task", run_interval=60)
- self.record_id: str | None = None
+ self.record_id: int | None = None # Changed to int for Peewee's default ID
"""记录ID"""
self._init_database() # 初始化数据库
@@ -47,49 +49,46 @@ class OnlineTimeRecordTask(AsyncTask):
@staticmethod
def _init_database():
"""初始化数据库"""
- if "online_time" not in db.list_collection_names():
- # 初始化数据库(在线时长)
- db.create_collection("online_time")
- # 创建索引
- if ("end_timestamp", 1) not in db.online_time.list_indexes():
- db.online_time.create_index([("end_timestamp", 1)])
+ with db.atomic(): # Use atomic operations for schema changes
+ OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
async def run(self):
try:
+ current_time = datetime.now()
+ extended_end_time = current_time + timedelta(minutes=1)
+
if self.record_id:
# 如果有记录,则更新结束时间
- db.online_time.update_one(
- {"_id": self.record_id},
- {
- "$set": {
- "end_timestamp": datetime.now() + timedelta(minutes=1),
- }
- },
- )
- else:
+ query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
+ updated_rows = query.execute()
+ if updated_rows == 0:
+ # Record might have been deleted or ID is stale, try to find/create
+ self.record_id = None # Reset record_id to trigger find/create logic below
+
+ if not self.record_id: # Check again if record_id was reset or initially None
# 如果没有记录,检查一分钟以内是否已有记录
- current_time = datetime.now()
- if recent_record := db.online_time.find_one(
- {"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}}
- ):
+ # Look for a record whose end_timestamp is recent enough to be considered ongoing
+ recent_record = (
+ OnlineTime.select()
+ .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
+ .order_by(OnlineTime.end_timestamp.desc())
+ .first()
+ )
+
+ if recent_record:
# 如果有记录,则更新结束时间
- self.record_id = recent_record["_id"]
- db.online_time.update_one(
- {"_id": self.record_id},
- {
- "$set": {
- "end_timestamp": current_time + timedelta(minutes=1),
- }
- },
- )
+ self.record_id = recent_record.id
+ recent_record.end_timestamp = extended_end_time
+ recent_record.save()
else:
# 若没有记录,则插入新的在线时间记录
- self.record_id = db.online_time.insert_one(
- {
- "start_timestamp": current_time,
- "end_timestamp": current_time + timedelta(minutes=1),
- }
- ).inserted_id
+ new_record = OnlineTime.create(
+ timestamp=current_time.timestamp(), # 添加此行
+ start_timestamp=current_time,
+ end_timestamp=extended_end_time,
+ duration=5, # 初始时长为5分钟
+ )
+ self.record_id = new_record.id
except Exception as e:
logger.error(f"在线时间记录失败,错误信息:{e}")
@@ -201,35 +200,28 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段
"""
- if len(collect_period) <= 0:
+ if not collect_period:
return {}
- else:
- # 排序-按照时间段开始时间降序排列(最晚的时间段在前)
- collect_period.sort(key=lambda x: x[1], reverse=True)
+
+ # 排序-按照时间段开始时间降序排列(最晚的时间段在前)
+ collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
- # 总LLM请求数
TOTAL_REQ_CNT: 0,
- # 请求次数统计
REQ_CNT_BY_TYPE: defaultdict(int),
REQ_CNT_BY_USER: defaultdict(int),
REQ_CNT_BY_MODEL: defaultdict(int),
- # 输入Token数
IN_TOK_BY_TYPE: defaultdict(int),
IN_TOK_BY_USER: defaultdict(int),
IN_TOK_BY_MODEL: defaultdict(int),
- # 输出Token数
OUT_TOK_BY_TYPE: defaultdict(int),
OUT_TOK_BY_USER: defaultdict(int),
OUT_TOK_BY_MODEL: defaultdict(int),
- # 总Token数
TOTAL_TOK_BY_TYPE: defaultdict(int),
TOTAL_TOK_BY_USER: defaultdict(int),
TOTAL_TOK_BY_MODEL: defaultdict(int),
- # 总开销
TOTAL_COST: 0.0,
- # 请求开销统计
COST_BY_TYPE: defaultdict(float),
COST_BY_USER: defaultdict(float),
COST_BY_MODEL: defaultdict(float),
@@ -238,26 +230,26 @@ class StatisticOutputTask(AsyncTask):
}
# 以最早的时间戳为起始时间获取记录
- for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}):
- record_timestamp = record.get("timestamp")
+ # Assuming LLMUsage.timestamp is a DateTimeField
+ query_start_time = collect_period[-1][1]
+ for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
+ record_timestamp = record.timestamp # This is already a datetime object
for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
- # 如果记录时间在当前时间段内,则它一定在更早的时间段内
- # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1
- request_type = record.get("request_type", "unknown") # 请求类型
- user_id = str(record.get("user_id", "unknown")) # 用户ID
- model_name = record.get("model_name", "unknown") # 模型名称
+ request_type = record.request_type or "unknown"
+ user_id = record.user_id or "unknown" # user_id is TextField, already string
+ model_name = record.model_name or "unknown"
stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
stats[period_key][REQ_CNT_BY_USER][user_id] += 1
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
- prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数
- completion_tokens = record.get("completion_tokens", 0) # 输出Token数
- total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数
+ prompt_tokens = record.prompt_tokens or 0
+ completion_tokens = record.completion_tokens or 0
+ total_tokens = prompt_tokens + completion_tokens
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
@@ -271,13 +263,12 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
- cost = record.get("cost", 0.0)
+ cost = record.cost or 0.0
stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost
stats[period_key][COST_BY_MODEL][model_name] += cost
- break # 取消更早时间段的判断
-
+ break
return stats
@staticmethod
@@ -287,39 +278,38 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段
"""
- if len(collect_period) <= 0:
+ if not collect_period:
return {}
- else:
- # 排序-按照时间段开始时间降序排列(最晚的时间段在前)
- collect_period.sort(key=lambda x: x[1], reverse=True)
+
+ collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
- # 在线时间统计
ONLINE_TIME: 0.0,
}
for period_key, _ in collect_period
}
- # 统计在线时间
- for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}):
- end_timestamp: datetime = record.get("end_timestamp")
- for idx, (_, period_start) in enumerate(collect_period):
- if end_timestamp >= period_start:
- # 由于end_timestamp会超前标记时间,所以我们需要判断是否晚于当前时间,如果是,则使用当前时间作为结束时间
- end_timestamp = min(end_timestamp, now)
- # 如果记录时间在当前时间段内,则它一定在更早的时间段内
- # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
- for period_key, _period_start in collect_period[idx:]:
- start_timestamp: datetime = record.get("start_timestamp")
- if start_timestamp < _period_start:
- # 如果开始时间在查询边界之前,则使用开始时间
- stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds()
- else:
- # 否则,使用开始时间
- stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds()
- break # 取消更早时间段的判断
+ query_start_time = collect_period[-1][1]
+ # Assuming OnlineTime.end_timestamp is a DateTimeField
+ for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
+ # record.end_timestamp and record.start_timestamp are datetime objects
+ record_end_timestamp = record.end_timestamp
+ record_start_timestamp = record.start_timestamp
+ for idx, (_, period_boundary_start) in enumerate(collect_period):
+ if record_end_timestamp >= period_boundary_start:
+ # Calculate effective end time for this record in relation to 'now'
+ effective_end_time = min(record_end_timestamp, now)
+
+ for period_key, current_period_start_time in collect_period[idx:]:
+ # Determine the portion of the record that falls within this specific statistical period
+ overlap_start = max(record_start_timestamp, current_period_start_time)
+ overlap_end = effective_end_time # Already capped by 'now' and record's own end
+
+ if overlap_end > overlap_start:
+ stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
+ break
return stats
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
@@ -328,55 +318,57 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段
"""
- if len(collect_period) <= 0:
+ if not collect_period:
return {}
- else:
- # 排序-按照时间段开始时间降序排列(最晚的时间段在前)
- collect_period.sort(key=lambda x: x[1], reverse=True)
+
+ collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
- # 消息统计
TOTAL_MSG_CNT: 0,
MSG_CNT_BY_CHAT: defaultdict(int),
}
for period_key, _ in collect_period
}
- # 统计消息量
- for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}):
- chat_info = message.get("chat_info", None) # 聊天信息
- user_info = message.get("user_info", None) # 用户信息(消息发送人)
- message_time = message.get("time", 0) # 消息时间
+ query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
+ for message in Messages.select().where(Messages.time >= query_start_timestamp):
+ message_time_ts = message.time # This is a float timestamp
- group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息
- if group_info is not None:
- # 若有群聊信息
- chat_id = f"g{group_info.get('group_id')}"
- chat_name = group_info.get("group_name", f"群{group_info.get('group_id')}")
- elif user_info:
- # 若没有群聊信息,则尝试获取用户信息
- chat_id = f"u{user_info['user_id']}"
- chat_name = user_info["user_nickname"]
+ chat_id = None
+ chat_name = None
+
+ # Logic based on Peewee model structure, aiming to replicate original intent
+ if message.chat_info_group_id:
+ chat_id = f"g{message.chat_info_group_id}"
+ chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}"
+ elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
+ # This uses the message SENDER's ID as per original logic's fallback
+ chat_id = f"u{message.user_id}" # SENDER's user_id
+ chat_name = message.user_nickname # SENDER's nickname
else:
- continue # 如果没有群组信息也没有用户信息,则跳过
+ # If neither group_id nor sender_id is available for chat identification
+ logger.warning(
+ f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
+ )
+ continue
+ if not chat_id: # Should not happen if above logic is correct
+ continue
+
+ # Update name_mapping
if chat_id in self.name_mapping:
- if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]:
- # 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称
- self.name_mapping[chat_id] = (chat_name, message_time)
+ if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
+ self.name_mapping[chat_id] = (chat_name, message_time_ts)
else:
- self.name_mapping[chat_id] = (chat_name, message_time)
+ self.name_mapping[chat_id] = (chat_name, message_time_ts)
- for idx, (_, period_start) in enumerate(collect_period):
- if message_time >= period_start.timestamp():
- # 如果记录时间在当前时间段内,则它一定在更早的时间段内
- # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
+ for idx, (_, period_start_dt) in enumerate(collect_period):
+ if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_MSG_CNT] += 1
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
break
-
return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py
index 8fe8334b8..6d9ce0719 100644
--- a/src/chat/utils/utils.py
+++ b/src/chat/utils/utils.py
@@ -13,8 +13,10 @@ from src.manager.mood_manager import mood_manager
from ..message_receive.message import MessageRecv
from ..models.utils_model import LLMRequest
from .typo_generator import ChineseTypoGenerator
-from ...common.database import db
+from ...common.database.database import db
from ...config.config import global_config
+from ...common.database.database_model import Messages
+from ...common.message_repository import find_messages, count_messages
logger = get_module_logger("chat_utils")
@@ -43,8 +45,8 @@ def db_message_to_str(message_dict: dict) -> str:
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
"""检查消息是否提到了机器人"""
- keywords = [global_config.BOT_NICKNAME]
- nicknames = global_config.BOT_ALIAS_NAMES
+ keywords = [global_config.bot.nickname]
+ nicknames = global_config.bot.alias_names
reply_probability = 0.0
is_at = False
is_mentioned = False
@@ -64,18 +66,18 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
)
# 判断是否被@
- if re.search(f"@[\s\S]*?(id:{global_config.BOT_QQ})", message.processed_plain_text):
+ if re.search(f"@[\s\S]*?(id:{global_config.bot.qq_account})", message.processed_plain_text):
is_at = True
is_mentioned = True
- if is_at and global_config.at_bot_inevitable_reply:
+ if is_at and global_config.normal_chat.at_bot_inevitable_reply:
reply_probability = 1.0
logger.info("被@,回复概率设置为100%")
else:
if not is_mentioned:
# 判断是否被回复
if re.match(
- f"\[回复 [\s\S]*?\({str(global_config.BOT_QQ)}\):[\s\S]*?],说:", message.processed_plain_text
+ f"\[回复 [\s\S]*?\({str(global_config.bot.qq_account)}\):[\s\S]*?],说:", message.processed_plain_text
):
is_mentioned = True
else:
@@ -88,7 +90,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
for nickname in nicknames:
if nickname in message_content:
is_mentioned = True
- if is_mentioned and global_config.mentioned_bot_inevitable_reply:
+ if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply:
reply_probability = 1.0
logger.info("被提及,回复概率设置为100%")
return is_mentioned, reply_probability
@@ -96,7 +98,8 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
async def get_embedding(text, request_type="embedding"):
"""获取文本的embedding向量"""
- llm = LLMRequest(model=global_config.embedding, request_type=request_type)
+ # TODO: API-Adapter修改标记
+ llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
# return llm.get_embedding_sync(text)
try:
embedding = await llm.get_embedding(text)
@@ -107,20 +110,12 @@ async def get_embedding(text, request_type="embedding"):
def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, combine=False):
- recent_messages = list(
- db.messages.find(
- {"chat_id": chat_stream_id},
- {
- "time": 1, # 返回时间字段
- "chat_id": 1,
- "chat_info": 1,
- "user_info": 1,
- "message_id": 1, # 返回消息ID字段
- "detailed_plain_text": 1, # 返回处理后的文本字段
- },
- )
- .sort("time", -1)
- .limit(limit)
+ filter_query = {"chat_id": chat_stream_id}
+ sort_order = [("time", -1)]
+ recent_messages = find_messages(
+ message_filter=filter_query,
+ sort=sort_order,
+ limit=limit
)
if not recent_messages:
@@ -142,17 +137,14 @@ def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, c
return message_detailed_plain_text_list
-def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> list:
+def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
# 获取当前群聊记录内发言的人
- recent_messages = list(
- db.messages.find(
- {"chat_id": chat_stream_id},
- {
- "user_info": 1,
- },
- )
- .sort("time", -1)
- .limit(limit)
+ filter_query = {"chat_id": chat_stream_id}
+ sort_order = [("time", -1)]
+ recent_messages = find_messages(
+ message_filter=filter_query,
+ sort=sort_order,
+ limit=limit
)
if not recent_messages:
@@ -160,10 +152,15 @@ def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> li
who_chat_in_group = []
for msg_db_data in recent_messages:
- user_info = UserInfo.from_dict(msg_db_data["user_info"])
+ user_info = UserInfo.from_dict({
+ "platform": msg_db_data["user_platform"],
+ "user_id": msg_db_data["user_id"],
+ "user_nickname": msg_db_data["user_nickname"],
+ "user_cardname": msg_db_data.get("user_cardname", "")
+ })
if (
(user_info.platform, user_info.user_id) != sender
- and user_info.user_id != global_config.BOT_QQ
+ and user_info.user_id != global_config.bot.qq_account
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
and len(who_chat_in_group) < 5
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
@@ -321,7 +318,7 @@ def random_remove_punctuation(text: str) -> str:
def process_llm_response(text: str) -> list[str]:
# 先保护颜文字
- if global_config.enable_kaomoji_protection:
+ if global_config.response_splitter.enable_kaomoji_protection:
protected_text, kaomoji_mapping = protect_kaomoji(text)
logger.trace(f"保护颜文字后的文本: {protected_text}")
else:
@@ -340,8 +337,8 @@ def process_llm_response(text: str) -> list[str]:
logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")
# 对清理后的文本进行进一步处理
- max_length = global_config.response_max_length * 2
- max_sentence_num = global_config.response_max_sentence_num
+ max_length = global_config.response_splitter.max_length * 2
+ max_sentence_num = global_config.response_splitter.max_sentence_num
# 如果基本上是中文,则进行长度过滤
if get_western_ratio(cleaned_text) < 0.1:
if len(cleaned_text) > max_length:
@@ -349,20 +346,20 @@ def process_llm_response(text: str) -> list[str]:
return ["懒得说"]
typo_generator = ChineseTypoGenerator(
- error_rate=global_config.chinese_typo_error_rate,
- min_freq=global_config.chinese_typo_min_freq,
- tone_error_rate=global_config.chinese_typo_tone_error_rate,
- word_replace_rate=global_config.chinese_typo_word_replace_rate,
+ error_rate=global_config.chinese_typo.error_rate,
+ min_freq=global_config.chinese_typo.min_freq,
+ tone_error_rate=global_config.chinese_typo.tone_error_rate,
+ word_replace_rate=global_config.chinese_typo.word_replace_rate,
)
- if global_config.enable_response_splitter:
+ if global_config.response_splitter.enable:
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
else:
split_sentences = [cleaned_text]
sentences = []
for sentence in split_sentences:
- if global_config.chinese_typo_enable:
+ if global_config.chinese_typo.enable:
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
sentences.append(typoed_text)
if typo_corrections:
@@ -372,14 +369,14 @@ def process_llm_response(text: str) -> list[str]:
if len(sentences) > max_sentence_num:
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
- return [f"{global_config.BOT_NICKNAME}不知道哦"]
+ return [f"{global_config.bot.nickname}不知道哦"]
# if extracted_contents:
# for content in extracted_contents:
# sentences.append(content)
# 在所有句子处理完毕后,对包含占位符的列表进行恢复
- if global_config.enable_kaomoji_protection:
+ if global_config.response_splitter.enable_kaomoji_protection:
sentences = recover_kaomoji(sentences, kaomoji_mapping)
return sentences
@@ -580,26 +577,23 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
logger.error("stream_id 不能为空")
return 0, 0
- # 直接查询时间范围内的消息
- # time > start_time AND time <= end_time
- query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
+ # 使用message_repository中的count_messages和find_messages函数
+
+
+ # 构建查询条件
+ filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
try:
- # 执行查询
- messages_cursor = db.messages.find(query)
+ # 先获取消息数量
+ count = count_messages(filter_query)
+
+ # 获取消息内容计算总长度
+ messages = find_messages(message_filter=filter_query)
+ total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages)
- # 遍历结果计算数量和长度
- for msg in messages_cursor:
- count += 1
- total_length += len(msg.get("processed_plain_text", ""))
-
- # logger.debug(f"查询范围 ({start_time}, {end_time}] 内找到 {count} 条消息,总长度 {total_length}")
return count, total_length
- except PyMongoError as e:
- logger.error(f"查询 stream_id={stream_id} 在 ({start_time}, {end_time}] 范围内的消息时出错: {e}")
- return 0, 0
- except Exception as e: # 保留一个通用异常捕获以防万一
+ except Exception as e:
logger.error(f"计算消息数量时发生意外错误: {e}")
return 0, 0
diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py
index 455038246..c317fbbd6 100644
--- a/src/chat/utils/utils_image.py
+++ b/src/chat/utils/utils_image.py
@@ -8,7 +8,8 @@ import io
import numpy as np
-from ...common.database import db
+from ...common.database.database import db
+from ...common.database.database_model import Images, ImageDescriptions
from ...config.config import global_config
from ..models.utils_model import LLMRequest
@@ -32,40 +33,23 @@ class ImageManager:
def __init__(self):
if not self._initialized:
- self._ensure_image_collection()
- self._ensure_description_collection()
self._ensure_image_dir()
+
+ self._initialized = True
+ self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
+
+ try:
+ db.connect(reuse_if_open=True)
+ db.create_tables([Images, ImageDescriptions], safe=True)
+ except Exception as e:
+ logger.error(f"数据库连接或表创建失败: {e}")
+
self._initialized = True
- self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
def _ensure_image_dir(self):
"""确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True)
- @staticmethod
- def _ensure_image_collection():
- """确保images集合存在并创建索引"""
- if "images" not in db.list_collection_names():
- db.create_collection("images")
-
- # 删除旧索引
- db.images.drop_indexes()
- # 创建新的复合索引
- db.images.create_index([("hash", 1), ("type", 1)], unique=True)
- db.images.create_index([("url", 1)])
- db.images.create_index([("path", 1)])
-
- @staticmethod
- def _ensure_description_collection():
- """确保image_descriptions集合存在并创建索引"""
- if "image_descriptions" not in db.list_collection_names():
- db.create_collection("image_descriptions")
-
- # 删除旧索引
- db.image_descriptions.drop_indexes()
- # 创建新的复合索引
- db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
-
@staticmethod
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
@@ -77,8 +61,14 @@ class ImageManager:
Returns:
Optional[str]: 描述文本,如果不存在则返回None
"""
- result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
- return result["description"] if result else None
+ try:
+ record = ImageDescriptions.get_or_none(
+ (ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
+ )
+ return record.description if record else None
+ except Exception as e:
+ logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
+ return None
@staticmethod
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
@@ -90,20 +80,17 @@ class ImageManager:
description_type: 描述类型 ('emoji' 或 'image')
"""
try:
- db.image_descriptions.update_one(
- {"hash": image_hash, "type": description_type},
- {
- "$set": {
- "description": description,
- "timestamp": int(time.time()),
- "hash": image_hash, # 确保hash字段存在
- "type": description_type, # 确保type字段存在
- }
- },
- upsert=True,
+ current_timestamp = time.time()
+ defaults = {"description": description, "timestamp": current_timestamp}
+ desc_obj, created = ImageDescriptions.get_or_create(
+ hash=image_hash, type=description_type, defaults=defaults
)
+ if not created: # 如果记录已存在,则更新
+ desc_obj.description = description
+ desc_obj.timestamp = current_timestamp
+ desc_obj.save()
except Exception as e:
- logger.error(f"保存描述到数据库失败: {str(e)}")
+ logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能"""
@@ -116,51 +103,64 @@ class ImageManager:
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
- # logger.debug(f"缓存表情包描述: {cached_description}")
return f"[表情包,含义看起来是:{cached_description}]"
# 调用AI获取描述
if image_format == "gif" or image_format == "GIF":
- image_base64 = self.transform_gif(image_base64)
+ image_base64_processed = self.transform_gif(image_base64)
+ if image_base64_processed is None:
+ logger.warning("GIF转换失败,无法获取描述")
+ return "[表情包(GIF处理失败)]"
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些"
- description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg")
+ description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
else:
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
+ if description is None:
+ logger.warning("AI未能生成表情包描述")
+ return "[表情包(描述生成失败)]"
+
+ # 再次检查缓存,防止并发写入时重复生成
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
return f"[表情包,含义看起来是:{cached_description}]"
# 根据配置决定是否保存图片
- if global_config.save_emoji:
+ if global_config.emoji.save_emoji:
# 生成文件名和路径
- timestamp = int(time.time())
- filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
- if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")):
- os.makedirs(os.path.join(self.IMAGE_DIR, "emoji"))
- file_path = os.path.join(self.IMAGE_DIR, "emoji", filename)
+ current_timestamp = time.time()
+ filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
+ emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
+ os.makedirs(emoji_dir, exist_ok=True)
+ file_path = os.path.join(emoji_dir, filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
- # 保存到数据库
- image_doc = {
- "hash": image_hash,
- "path": file_path,
- "type": "emoji",
- "description": description,
- "timestamp": timestamp,
- }
- db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
- logger.trace(f"保存表情包: {file_path}")
+ # 保存到数据库 (Images表)
+ try:
+ img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
+ img_obj.path = file_path
+ img_obj.description = description
+ img_obj.timestamp = current_timestamp
+ img_obj.save()
+ except Images.DoesNotExist:
+ Images.create(
+ hash=image_hash,
+ path=file_path,
+ type="emoji",
+ description=description,
+ timestamp=current_timestamp,
+ )
+ logger.trace(f"保存表情包元数据: {file_path}")
except Exception as e:
- logger.error(f"保存表情包文件失败: {str(e)}")
+ logger.error(f"保存表情包文件或元数据失败: {str(e)}")
- # 保存描述到数据库
+ # 保存描述到数据库 (ImageDescriptions表)
self._save_description_to_db(image_hash, description, "emoji")
return f"[表情包:{description}]"
@@ -188,6 +188,11 @@ class ImageManager:
)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
+ if description is None:
+ logger.warning("AI未能生成图片描述")
+ return "[图片(描述生成失败)]"
+
+ # 再次检查缓存
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
@@ -195,38 +200,40 @@ class ImageManager:
logger.debug(f"描述是{description}")
- if description is None:
- logger.warning("AI未能生成图片描述")
- return "[图片]"
-
# 根据配置决定是否保存图片
- if global_config.save_pic:
+ if global_config.emoji.save_pic:
# 生成文件名和路径
- timestamp = int(time.time())
- filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
- if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")):
- os.makedirs(os.path.join(self.IMAGE_DIR, "image"))
- file_path = os.path.join(self.IMAGE_DIR, "image", filename)
+ current_timestamp = time.time()
+ filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
+ image_dir = os.path.join(self.IMAGE_DIR, "image")
+ os.makedirs(image_dir, exist_ok=True)
+ file_path = os.path.join(image_dir, filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
- # 保存到数据库
- image_doc = {
- "hash": image_hash,
- "path": file_path,
- "type": "image",
- "description": description,
- "timestamp": timestamp,
- }
- db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
- logger.trace(f"保存图片: {file_path}")
+ # 保存到数据库 (Images表)
+ try:
+ img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "image"))
+ img_obj.path = file_path
+ img_obj.description = description
+ img_obj.timestamp = current_timestamp
+ img_obj.save()
+ except Images.DoesNotExist:
+ Images.create(
+ hash=image_hash,
+ path=file_path,
+ type="image",
+ description=description,
+ timestamp=current_timestamp,
+ )
+ logger.trace(f"保存图片元数据: {file_path}")
except Exception as e:
- logger.error(f"保存图片文件失败: {str(e)}")
+ logger.error(f"保存图片文件或元数据失败: {str(e)}")
- # 保存描述到数据库
+ # 保存描述到数据库 (ImageDescriptions表)
self._save_description_to_db(image_hash, description, "image")
return f"[图片:{description}]"
diff --git a/src/chat/zhishi/knowledge_library.py b/src/chat/zhishi/knowledge_library.py
index 6fa1d3e1a..0068a153c 100644
--- a/src/chat/zhishi/knowledge_library.py
+++ b/src/chat/zhishi/knowledge_library.py
@@ -16,7 +16,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
# 现在可以导入src模块
-from src.common.database import db # noqa E402
+from common.database.database import db # noqa E402
# 加载根目录下的env.edv文件
diff --git a/src/common/database.py b/src/common/database/database.py
similarity index 81%
rename from src/common/database.py
rename to src/common/database/database.py
index 752f746db..a2dab739d 100644
--- a/src/common/database.py
+++ b/src/common/database/database.py
@@ -1,5 +1,6 @@
import os
from pymongo import MongoClient
+from peewee import SqliteDatabase
from pymongo.database import Database
from rich.traceback import install
@@ -57,4 +58,15 @@ class DBWrapper:
# 全局数据库访问点
-db: Database = DBWrapper()
+memory_db: Database = DBWrapper()
+
+# 定义数据库文件路径
+ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
+_DB_DIR = os.path.join(ROOT_PATH, "data")
+_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
+
+# 确保数据库目录存在
+os.makedirs(_DB_DIR, exist_ok=True)
+
+# 全局 Peewee SQLite 数据库访问点
+db = SqliteDatabase(_DB_FILE)
diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py
new file mode 100644
index 000000000..3544a8be0
--- /dev/null
+++ b/src/common/database/database_model.py
@@ -0,0 +1,393 @@
+from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField
+from .database import db
+import datetime
+from ..logger_manager import get_logger
+
+logger = get_logger("database_model")
+# 请在此处定义您的数据库实例。
+# 您需要取消注释并配置适合您的数据库的部分。
+# 例如,对于 SQLite:
+# db = SqliteDatabase('MaiBot.db')
+#
+# 对于 PostgreSQL:
+# db = PostgresqlDatabase('your_db_name', user='your_user', password='your_password',
+# host='localhost', port=5432)
+#
+# 对于 MySQL:
+# db = MySQLDatabase('your_db_name', user='your_user', password='your_password',
+# host='localhost', port=3306)
+
+
+# 定义一个基础模型是一个好习惯,所有其他模型都应继承自它。
+# 这允许您在一个地方为所有模型指定数据库。
+class BaseModel(Model):
+ class Meta:
+ # 将下面的 'db' 替换为您实际的数据库实例变量名。
+ database = db # 例如: database = my_actual_db_instance
+ pass # 在用户定义数据库实例之前,此处为占位符
+
+
+class ChatStreams(BaseModel):
+ """
+ 用于存储流式记录数据的模型,类似于提供的 MongoDB 结构。
+ """
+
+ # stream_id: "a544edeb1a9b73e3e1d77dff36e41264"
+ # 假设 stream_id 是唯一的,并为其创建索引以提高查询性能。
+ stream_id = TextField(unique=True, index=True)
+
+ # create_time: 1746096761.4490178 (时间戳,精确到小数点后7位)
+ # DoubleField 用于存储浮点数,适合此类时间戳。
+ create_time = DoubleField()
+
+ # group_info 字段:
+ # platform: "qq"
+ # group_id: "941657197"
+ # group_name: "测试"
+ group_platform = TextField()
+ group_id = TextField()
+ group_name = TextField()
+
+ # last_active_time: 1746623771.4825106 (时间戳,精确到小数点后7位)
+ last_active_time = DoubleField()
+
+ # platform: "qq" (顶层平台字段)
+ platform = TextField()
+
+ # user_info 字段:
+ # platform: "qq"
+ # user_id: "1787882683"
+ # user_nickname: "墨梓柒(IceSakurary)"
+ # user_cardname: ""
+ user_platform = TextField()
+ user_id = TextField()
+ user_nickname = TextField()
+ # user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。
+ user_cardname = TextField(null=True)
+
+ class Meta:
+ # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
+ # 如果不使用带有数据库实例的 BaseModel,或者想覆盖它,
+ # 请取消注释并在下面设置数据库实例:
+ # database = db
+ table_name = "chat_streams" # 可选:明确指定数据库中的表名
+
+
+class LLMUsage(BaseModel):
+ """
+ 用于存储 API 使用日志数据的模型。
+ """
+
+ model_name = TextField(index=True) # 添加索引
+ user_id = TextField(index=True) # 添加索引
+ request_type = TextField(index=True) # 添加索引
+ endpoint = TextField()
+ prompt_tokens = IntegerField()
+ completion_tokens = IntegerField()
+ total_tokens = IntegerField()
+ cost = DoubleField()
+ status = TextField()
+ timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
+
+ class Meta:
+ # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
+ # database = db
+ table_name = "llm_usage"
+
+
+class Emoji(BaseModel):
+ """表情包"""
+
+ full_path = TextField(unique=True, index=True) # 文件的完整路径 (包括文件名)
+ format = TextField() # 图片格式
+ emoji_hash = TextField(index=True) # 表情包的哈希值
+ description = TextField() # 表情包的描述
+ query_count = IntegerField(default=0) # 查询次数(用于统计表情包被查询描述的次数)
+ is_registered = BooleanField(default=False) # 是否已注册
+ is_banned = BooleanField(default=False) # 是否被禁止注册
+ # emotion: list[str] # 表情包的情感标签 - 存储为文本,应用层处理序列化/反序列化
+ emotion = TextField(null=True)
+ record_time = FloatField() # 记录时间(被创建的时间)
+ register_time = FloatField(null=True) # 注册时间(被注册为可用表情包的时间)
+ usage_count = IntegerField(default=0) # 使用次数(被使用的次数)
+ last_used_time = FloatField(null=True) # 上次使用时间
+
+ class Meta:
+ # database = db # 继承自 BaseModel
+ table_name = "emoji"
+
+
+class Messages(BaseModel):
+ """
+ 用于存储消息数据的模型。
+ """
+
+ message_id = TextField(index=True) # 消息 ID (更改自 IntegerField)
+ time = DoubleField() # 消息时间戳
+
+ chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
+
+ # 从 chat_info 扁平化而来的字段
+ chat_info_stream_id = TextField()
+ chat_info_platform = TextField()
+ chat_info_user_platform = TextField()
+ chat_info_user_id = TextField()
+ chat_info_user_nickname = TextField()
+ chat_info_user_cardname = TextField(null=True)
+ chat_info_group_platform = TextField(null=True) # 群聊信息可能不存在
+ chat_info_group_id = TextField(null=True)
+ chat_info_group_name = TextField(null=True)
+ chat_info_create_time = DoubleField()
+ chat_info_last_active_time = DoubleField()
+
+ # 从顶层 user_info 扁平化而来的字段 (消息发送者信息)
+ user_platform = TextField()
+ user_id = TextField()
+ user_nickname = TextField()
+ user_cardname = TextField(null=True)
+
+ processed_plain_text = TextField(null=True) # 处理后的纯文本消息
+ detailed_plain_text = TextField(null=True) # 详细的纯文本消息
+ memorized_times = IntegerField(default=0) # 被记忆的次数
+
+ class Meta:
+ # database = db # 继承自 BaseModel
+ table_name = "messages"
+
+
+class Images(BaseModel):
+ """
+ 用于存储图像信息的模型。
+ """
+
+ emoji_hash = TextField(index=True) # 图像的哈希值
+ description = TextField(null=True) # 图像的描述
+ path = TextField(unique=True) # 图像文件的路径
+ timestamp = FloatField() # 时间戳
+ type = TextField() # 图像类型,例如 "emoji"
+
+ class Meta:
+ # database = db # 继承自 BaseModel
+ table_name = "images"
+
+
+class ImageDescriptions(BaseModel):
+ """
+ 用于存储图像描述信息的模型。
+ """
+
+ type = TextField() # 类型,例如 "emoji"
+ image_description_hash = TextField(index=True) # 图像的哈希值
+ description = TextField() # 图像的描述
+ timestamp = FloatField() # 时间戳
+
+ class Meta:
+ # database = db # 继承自 BaseModel
+ table_name = "image_descriptions"
+
+
+class OnlineTime(BaseModel):
+ """
+ 用于存储在线时长记录的模型。
+ """
+
+ # timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串)
+ timestamp = TextField(default=datetime.datetime.now) # 时间戳
+ duration = IntegerField() # 时长,单位分钟
+ start_timestamp = DateTimeField(default=datetime.datetime.now)
+ end_timestamp = DateTimeField(index=True)
+
+ class Meta:
+ # database = db # 继承自 BaseModel
+ table_name = "online_time"
+
+
+class PersonInfo(BaseModel):
+ """
+ 用于存储个人信息数据的模型。
+ """
+
+ person_id = TextField(unique=True, index=True) # 个人唯一ID
+ person_name = TextField(null=True) # 个人名称 (允许为空)
+ name_reason = TextField(null=True) # 名称设定的原因
+ platform = TextField() # 平台
+ user_id = TextField(index=True) # 用户ID
+ nickname = TextField() # 用户昵称
+ relationship_value = IntegerField(default=0) # 关系值
+ know_time = FloatField() # 认识时间 (时间戳)
+ msg_interval = IntegerField() # 消息间隔
+ # msg_interval_list: 存储为 JSON 字符串的列表
+ msg_interval_list = TextField(null=True)
+
+ class Meta:
+ # database = db # 继承自 BaseModel
+ table_name = "person_info"
+
+
+class Knowledges(BaseModel):
+ """
+ 用于存储知识库条目的模型。
+ """
+
+ content = TextField() # 知识内容的文本
+ embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
+ # 可以添加其他元数据字段,如 source, create_time 等
+
+ class Meta:
+ # database = db # 继承自 BaseModel
+ table_name = "knowledges"
+
+
+class ThinkingLog(BaseModel):
+ chat_id = TextField(index=True)
+ trigger_text = TextField(null=True)
+ response_text = TextField(null=True)
+
+ # Store complex dicts/lists as JSON strings
+ trigger_info_json = TextField(null=True)
+ response_info_json = TextField(null=True)
+ timing_results_json = TextField(null=True)
+ chat_history_json = TextField(null=True)
+ chat_history_in_thinking_json = TextField(null=True)
+ chat_history_after_response_json = TextField(null=True)
+ heartflow_data_json = TextField(null=True)
+ reasoning_data_json = TextField(null=True)
+
+ # Add a timestamp for the log entry itself
+ # Ensure you have: from peewee import DateTimeField
+ # And: import datetime
+ created_at = DateTimeField(default=datetime.datetime.now)
+
+ class Meta:
+ table_name = "thinking_logs"
+
+
+class RecalledMessages(BaseModel):
+ """
+ 用于存储撤回消息记录的模型。
+ """
+
+ message_id = TextField(index=True) # 被撤回的消息 ID
+ time = DoubleField() # 撤回操作发生的时间戳
+ stream_id = TextField() # 对应的 ChatStreams stream_id
+
+ class Meta:
+ table_name = "recalled_messages"
+
+
+class GraphNodes(BaseModel):
+ """
+ 用于存储记忆图节点的模型
+ """
+
+ concept = TextField(unique=True, index=True) # 节点概念
+ memory_items = TextField() # JSON格式存储的记忆列表
+ hash = TextField() # 节点哈希值
+ created_time = FloatField() # 创建时间戳
+ last_modified = FloatField() # 最后修改时间戳
+
+ class Meta:
+ table_name = "graph_nodes"
+
+
+class GraphEdges(BaseModel):
+ """
+ 用于存储记忆图边的模型
+ """
+
+ source = TextField(index=True) # 源节点
+ target = TextField(index=True) # 目标节点
+ strength = IntegerField() # 连接强度
+ hash = TextField() # 边哈希值
+ created_time = FloatField() # 创建时间戳
+ last_modified = FloatField() # 最后修改时间戳
+
+ class Meta:
+ table_name = "graph_edges"
+
+
+def create_tables():
+ """
+ 创建所有在模型中定义的数据库表。
+ """
+ with db:
+ db.create_tables(
+ [
+ ChatStreams,
+ LLMUsage,
+ Emoji,
+ Messages,
+ Images,
+ ImageDescriptions,
+ OnlineTime,
+ PersonInfo,
+ Knowledges,
+ ThinkingLog,
+ RecalledMessages, # 添加新模型
+ GraphNodes, # 添加图节点表
+ GraphEdges, # 添加图边表
+ ]
+ )
+
+
+def initialize_database():
+ """
+ 检查所有定义的表是否存在,如果不存在则创建它们。
+ 检查所有表的所有字段是否存在,如果缺失则警告用户并退出程序。
+ """
+ import sys
+
+ models = [
+ ChatStreams,
+ LLMUsage,
+ Emoji,
+ Messages,
+ Images,
+ ImageDescriptions,
+ OnlineTime,
+ PersonInfo,
+ Knowledges,
+ ThinkingLog,
+ RecalledMessages,
+ GraphNodes, # 添加图节点表
+ GraphEdges, # 添加图边表
+ ]
+
+ needs_creation = False
+ try:
+ with db: # 管理 table_exists 检查的连接
+ for model in models:
+ table_name = model._meta.table_name
+ if not db.table_exists(model):
+ logger.warning(f"表 '{table_name}' 未找到。")
+ needs_creation = True
+ break # 一个表丢失,无需进一步检查。
+ if not needs_creation:
+ # 检查字段
+ for model in models:
+ table_name = model._meta.table_name
+ cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
+ existing_columns = {row[1] for row in cursor.fetchall()}
+ model_fields = model._meta.fields
+ for field_name in model_fields:
+ if field_name not in existing_columns:
+ logger.error(f"表 '{table_name}' 缺失字段 '{field_name}',请手动迁移数据库结构后重启程序。")
+ sys.exit(1)
+ except Exception as e:
+ logger.exception(f"检查表或字段是否存在时出错: {e}")
+ # 如果检查失败(例如数据库不可用),则退出
+ return
+
+ if needs_creation:
+ logger.info("正在初始化数据库:一个或多个表丢失。正在尝试创建所有定义的表...")
+ try:
+ create_tables() # 此函数有其自己的 'with db:' 上下文管理。
+ logger.info("数据库表创建过程完成。")
+ except Exception as e:
+ logger.exception(f"创建表期间出错: {e}")
+ else:
+ logger.info("所有数据库表及字段均已存在。")
+
+
+# 模块加载时调用初始化函数
+initialize_database()
diff --git a/src/common/logger.py b/src/common/logger.py
index 61d4d3559..6c11b09d8 100644
--- a/src/common/logger.py
+++ b/src/common/logger.py
@@ -276,6 +276,40 @@ CHAT_STYLE_CONFIG = {
},
}
+# Topic日志样式配置
+NORMAL_CHAT_STYLE_CONFIG = {
+ "advanced": {
+ "console_format": (
+ "{time:YYYY-MM-DD HH:mm:ss} | "
+ "{level: <8} | "
+ "一般水群 | "
+ "{message}"
+ ),
+ "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 一般水群 | {message}",
+ },
+ "simple": {
+ "console_format": "{time:HH:mm:ss} | 一般水群 | {message}", # noqa: E501
+ "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 一般水群 | {message}",
+ },
+}
+
+# Topic日志样式配置
+FOCUS_CHAT_STYLE_CONFIG = {
+ "advanced": {
+ "console_format": (
+ "{time:YYYY-MM-DD HH:mm:ss} | "
+ "{level: <8} | "
+ "专注水群 | "
+ "{message}"
+ ),
+ "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 专注水群 | {message}",
+ },
+ "simple": {
+ "console_format": "{time:HH:mm:ss} | 专注水群 | {message}", # noqa: E501
+ "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 专注水群 | {message}",
+ },
+}
+
REMOTE_STYLE_CONFIG = {
"advanced": {
"console_format": (
@@ -629,22 +663,22 @@ PROCESSOR_STYLE_CONFIG = {
PLANNER_STYLE_CONFIG = {
"advanced": {
- "console_format": "{time:HH:mm:ss} | 规划器 | {message}",
+ "console_format": "{time:HH:mm:ss} | 规划器 | {message}",
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 规划器 | {message}",
},
"simple": {
- "console_format": "{time:HH:mm:ss} | 规划器 | {message}",
+ "console_format": "{time:HH:mm:ss} | 规划器 | {message}",
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 规划器 | {message}",
},
}
ACTION_TAKEN_STYLE_CONFIG = {
"advanced": {
- "console_format": "{time:HH:mm:ss} | 动作 | {message}",
+ "console_format": "{time:HH:mm:ss} | 动作 | {message}",
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 动作 | {message}",
},
"simple": {
- "console_format": "{time:HH:mm:ss} | 动作 | {message}",
+ "console_format": "{time:HH:mm:ss} | 动作 | {message}",
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 动作 | {message}",
},
}
@@ -935,6 +969,8 @@ MAIM_MESSAGE_STYLE_CONFIG = (
INTEREST_CHAT_STYLE_CONFIG = (
INTEREST_CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INTEREST_CHAT_STYLE_CONFIG["advanced"]
)
+NORMAL_CHAT_STYLE_CONFIG = NORMAL_CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else NORMAL_CHAT_STYLE_CONFIG["advanced"]
+FOCUS_CHAT_STYLE_CONFIG = FOCUS_CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else FOCUS_CHAT_STYLE_CONFIG["advanced"]
def is_registered_module(record: dict) -> bool:
diff --git a/src/common/logger_manager.py b/src/common/logger_manager.py
index 2b6f01ca0..de25b5ba7 100644
--- a/src/common/logger_manager.py
+++ b/src/common/logger_manager.py
@@ -21,6 +21,8 @@ from src.common.logger import (
WILLING_STYLE_CONFIG,
PFC_ACTION_PLANNER_STYLE_CONFIG,
MAI_STATE_CONFIG,
+ NORMAL_CHAT_STYLE_CONFIG,
+ FOCUS_CHAT_STYLE_CONFIG,
LPMM_STYLE_CONFIG,
HFC_STYLE_CONFIG,
OBSERVATION_STYLE_CONFIG,
@@ -95,7 +97,8 @@ MODULE_LOGGER_CONFIGS = {
"init": INIT_STYLE_CONFIG, # 初始化
"interest_chat": INTEREST_CHAT_STYLE_CONFIG, # 兴趣
"api": API_SERVER_STYLE_CONFIG, # API服务器
- "maim_message": MAIM_MESSAGE_STYLE_CONFIG, # 消息服务
+ "normal_chat": NORMAL_CHAT_STYLE_CONFIG, # 一般水群
+ "focus_chat": FOCUS_CHAT_STYLE_CONFIG, # 专注水群
# ...如有更多模块,继续添加...
}
diff --git a/src/common/message_repository.py b/src/common/message_repository.py
index 03f192cea..ee69b22b0 100644
--- a/src/common/message_repository.py
+++ b/src/common/message_repository.py
@@ -1,11 +1,19 @@
-from src.common.database import db
+from src.common.database.database_model import Messages # 更改导入
from src.common.logger import get_module_logger
import traceback
from typing import List, Any, Optional
+from peewee import Model # 添加 Peewee Model 导入
logger = get_module_logger(__name__)
+def _model_to_dict(model_instance: Model) -> dict[str, Any]:
+ """
+ 将 Peewee 模型实例转换为字典。
+ """
+ return model_instance.__data__
+
+
def find_messages(
message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None,
@@ -16,39 +24,84 @@ def find_messages(
根据提供的过滤器、排序和限制条件查找消息。
Args:
- message_filter: MongoDB 查询过滤器。
- sort: MongoDB 排序条件列表,例如 [('time', 1)]。仅在 limit 为 0 时生效。
+ message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
+ sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
limit: 返回的最大文档数,0表示不限制。
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。
Returns:
- 消息文档列表,如果出错则返回空列表。
+ 消息字典列表,如果出错则返回空列表。
"""
try:
- query = db.messages.find(message_filter)
+ query = Messages.select()
+
+ # 应用过滤器
+ if message_filter:
+ conditions = []
+ for key, value in message_filter.items():
+ if hasattr(Messages, key):
+ field = getattr(Messages, key)
+ if isinstance(value, dict):
+ # 处理 MongoDB 风格的操作符
+ for op, op_value in value.items():
+ if op == "$gt":
+ conditions.append(field > op_value)
+ elif op == "$lt":
+ conditions.append(field < op_value)
+ elif op == "$gte":
+ conditions.append(field >= op_value)
+ elif op == "$lte":
+ conditions.append(field <= op_value)
+ elif op == "$ne":
+ conditions.append(field != op_value)
+ elif op == "$in":
+ conditions.append(field.in_(op_value))
+ elif op == "$nin":
+ conditions.append(field.not_in(op_value))
+ else:
+ logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
+ else:
+ # 直接相等比较
+ conditions.append(field == value)
+ else:
+ logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
+ if conditions:
+ query = query.where(*conditions)
if limit > 0:
if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序
- query = query.sort([("time", 1)]).limit(limit)
- results = list(query)
+ query = query.order_by(Messages.time.asc()).limit(limit)
+ peewee_results = list(query)
else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录
- query = query.sort([("time", -1)]).limit(limit)
- latest_results = list(query)
+ query = query.order_by(Messages.time.desc()).limit(limit)
+ latest_results_peewee = list(query)
# 将结果按时间正序排列
- # 假设消息文档中总是有 'time' 字段且可排序
- results = sorted(latest_results, key=lambda msg: msg.get("time"))
+ peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
- query = query.sort(sort)
- results = list(query)
+ peewee_sort_terms = []
+ for field_name, direction in sort:
+ if hasattr(Messages, field_name):
+ field = getattr(Messages, field_name)
+ if direction == 1: # ASC
+ peewee_sort_terms.append(field.asc())
+ elif direction == -1: # DESC
+ peewee_sort_terms.append(field.desc())
+ else:
+ logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
+ else:
+ logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
+ if peewee_sort_terms:
+ query = query.order_by(*peewee_sort_terms)
+ peewee_results = list(query)
- return results
+ return [_model_to_dict(msg) for msg in peewee_results]
except Exception as e:
log_message = (
- f"查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc()
)
logger.error(log_message)
@@ -60,18 +113,57 @@ def count_messages(message_filter: dict[str, Any]) -> int:
根据提供的过滤器计算消息数量。
Args:
- message_filter: MongoDB 查询过滤器。
+ message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
Returns:
符合条件的消息数量,如果出错则返回 0。
"""
try:
- count = db.messages.count_documents(message_filter)
+ query = Messages.select()
+
+ # 应用过滤器
+ if message_filter:
+ conditions = []
+ for key, value in message_filter.items():
+ if hasattr(Messages, key):
+ field = getattr(Messages, key)
+ if isinstance(value, dict):
+ # 处理 MongoDB 风格的操作符
+ for op, op_value in value.items():
+ if op == "$gt":
+ conditions.append(field > op_value)
+ elif op == "$lt":
+ conditions.append(field < op_value)
+ elif op == "$gte":
+ conditions.append(field >= op_value)
+ elif op == "$lte":
+ conditions.append(field <= op_value)
+ elif op == "$ne":
+ conditions.append(field != op_value)
+ elif op == "$in":
+ conditions.append(field.in_(op_value))
+ elif op == "$nin":
+ conditions.append(field.not_in(op_value))
+ else:
+ logger.warning(
+ f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
+ )
+ else:
+ # 直接相等比较
+ conditions.append(field == value)
+ else:
+ logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
+ if conditions:
+ query = query.where(*conditions)
+
+ count = query.count()
return count
except Exception as e:
- log_message = f"计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc()
+ log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
logger.error(log_message)
return 0
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
+# 注意:对于 Peewee,插入操作通常是 Messages.create(...) 或 instance.save()。
+# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。
diff --git a/src/common/remote.py b/src/common/remote.py
index 1d26df01b..b1108be9c 100644
--- a/src/common/remote.py
+++ b/src/common/remote.py
@@ -35,7 +35,7 @@ class TelemetryHeartBeatTask(AsyncTask):
info_dict = {
"os_type": "Unknown",
"py_version": platform.python_version(),
- "mmc_version": global_config.MAI_VERSION,
+ "mmc_version": global_config.MMC_VERSION,
}
match platform.system():
@@ -133,10 +133,9 @@ class TelemetryHeartBeatTask(AsyncTask):
async def run(self):
# 发送心跳
- if global_config.remote_enable:
- if self.client_uuid is None:
- if not await self._req_uuid():
- logger.error("获取UUID失败,跳过此次心跳")
- return
+ if global_config.telemetry.enable:
+ if self.client_uuid is None and not await self._req_uuid():
+ logger.error("获取UUID失败,跳过此次心跳")
+ return
await self._send_heartbeat()
diff --git a/src/config/config.py b/src/config/config.py
index b186f3b83..e6b7c5326 100644
--- a/src/config/config.py
+++ b/src/config/config.py
@@ -1,64 +1,68 @@
import os
-import re
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional
+from dataclasses import field, dataclass
-import tomli
import tomlkit
import shutil
from datetime import datetime
-from pathlib import Path
-from packaging import version
-from packaging.version import Version, InvalidVersion
-from packaging.specifiers import SpecifierSet, InvalidSpecifier
+
+from tomlkit import TOMLDocument
+from tomlkit.items import Table
from src.common.logger_manager import get_logger
from rich.traceback import install
+from src.config.config_base import ConfigBase
+from src.config.official_configs import (
+ BotConfig,
+ ChatTargetConfig,
+ PersonalityConfig,
+ IdentityConfig,
+ PlatformsConfig,
+ ChatConfig,
+ NormalChatConfig,
+ FocusChatConfig,
+ EmojiConfig,
+ MemoryConfig,
+ MoodConfig,
+ KeywordReactionConfig,
+ ChineseTypoConfig,
+ ResponseSplitterConfig,
+ TelemetryConfig,
+ ExperimentalConfig,
+ ModelConfig,
+)
+
install(extra_lines=3)
# 配置主程序日志格式
logger = get_logger("config")
-# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
-is_test = True
-mai_version_main = "0.6.4"
-mai_version_fix = "snapshot-1"
+CONFIG_DIR = "config"
+TEMPLATE_DIR = "template"
-if mai_version_fix:
- if is_test:
- mai_version = f"test-{mai_version_main}-{mai_version_fix}"
- else:
- mai_version = f"{mai_version_main}-{mai_version_fix}"
-else:
- if is_test:
- mai_version = f"test-{mai_version_main}"
- else:
- mai_version = mai_version_main
+# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
+# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
+MMC_VERSION = "0.7.0-snapshot.1"
def update_config():
# 获取根目录路径
- root_dir = Path(__file__).parent.parent.parent
- template_dir = root_dir / "template"
- config_dir = root_dir / "config"
- old_config_dir = config_dir / "old"
+ old_config_dir = f"{CONFIG_DIR}/old"
# 定义文件路径
- template_path = template_dir / "bot_config_template.toml"
- old_config_path = config_dir / "bot_config.toml"
- new_config_path = config_dir / "bot_config.toml"
+ template_path = f"{TEMPLATE_DIR}/bot_config_template.toml"
+ old_config_path = f"{CONFIG_DIR}/bot_config.toml"
+ new_config_path = f"{CONFIG_DIR}/bot_config.toml"
# 检查配置文件是否存在
- if not old_config_path.exists():
+ if not os.path.exists(old_config_path):
logger.info("配置文件不存在,从模板创建新配置")
- # 创建文件夹
- old_config_dir.mkdir(parents=True, exist_ok=True)
- shutil.copy2(template_path, old_config_path)
+ os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
+ shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
# 如果是新创建的配置文件,直接返回
- return quit()
+ quit()
# 读取旧配置文件和模板文件
with open(old_config_path, "r", encoding="utf-8") as f:
@@ -75,13 +79,15 @@ def update_config():
return
else:
logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
+ else:
+ logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
# 创建old目录(如果不存在)
- old_config_dir.mkdir(exist_ok=True)
+ os.makedirs(old_config_dir, exist_ok=True)
# 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
+ old_backup_path = f"{old_config_dir}/bot_config_{timestamp}.toml"
# 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path)
@@ -91,24 +97,23 @@ def update_config():
shutil.copy2(template_path, new_config_path)
logger.info(f"已创建新配置文件: {new_config_path}")
- # 递归更新配置
- def update_dict(target, source):
+ def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
+ """
+ 将source字典的值更新到target字典中(如果target中存在相同的键)
+ """
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
- if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)):
+ if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
update_dict(target[key], value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
- if not value:
- target[key] = tomlkit.array()
- else:
- target[key] = tomlkit.array(value)
+ target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
@@ -123,619 +128,57 @@ def update_config():
# 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
- logger.info("配置文件更新完成")
+ logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
+ quit()
@dataclass
-class BotConfig:
- """机器人配置类"""
-
- INNER_VERSION: Version = None
- MAI_VERSION: str = mai_version # 硬编码的版本信息
-
- # bot
- BOT_QQ: Optional[str] = "114514"
- BOT_NICKNAME: Optional[str] = None
- BOT_ALIAS_NAMES: List[str] = field(default_factory=list) # 别名,可以通过这个叫它
-
- # group
- talk_allowed_groups = set()
- talk_frequency_down_groups = set()
- ban_user_id = set()
-
- # personality
- personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内,谁再写3000字小作文敲谁脑袋
- personality_sides: List[str] = field(
- default_factory=lambda: [
- "用一句话或几句话描述人格的一些侧面",
- "用一句话或几句话描述人格的一些侧面",
- "用一句话或几句话描述人格的一些侧面",
- ]
- )
- expression_style = "描述麦麦说话的表达风格,表达习惯"
- # identity
- identity_detail: List[str] = field(
- default_factory=lambda: [
- "身份特点",
- "身份特点",
- ]
- )
- height: int = 170 # 身高 单位厘米
- weight: int = 50 # 体重 单位千克
- age: int = 20 # 年龄 单位岁
- gender: str = "男" # 性别
- appearance: str = "用几句话描述外貌特征" # 外貌特征
-
- # chat
- allow_focus_mode: bool = True # 是否允许专注聊天状态
-
- base_normal_chat_num: int = 3 # 最多允许多少个群进行普通聊天
- base_focused_chat_num: int = 2 # 最多允许多少个群进行专注聊天
-
- observation_context_size: int = 12 # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩
-
- message_buffer: bool = True # 消息缓冲器
-
- ban_words = set()
- ban_msgs_regex = set()
-
- # focus_chat
- reply_trigger_threshold: float = 3.0 # 心流聊天触发阈值,越低越容易触发
- default_decay_rate_per_second: float = 0.98 # 默认衰减率,越大衰减越慢
- consecutive_no_reply_threshold = 3
-
- compressed_length: int = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5
- compress_length_limit: int = 5 # 最多压缩份数,超过该数值的压缩上下文会被删除
-
- # normal_chat
- model_reasoning_probability: float = 0.7 # 麦麦回答时选择推理模型(主要)模型概率
- model_normal_probability: float = 0.3 # 麦麦回答时选择一般模型(次要)模型概率
-
- emoji_chance: float = 0.2 # 发送表情包的基础概率
- thinking_timeout: int = 120 # 思考时间
-
- willing_mode: str = "classical" # 意愿模式
- response_willing_amplifier: float = 1.0 # 回复意愿放大系数
- response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数
- down_frequency_rate: float = 3 # 降低回复频率的群组回复意愿降低系数
- emoji_response_penalty: float = 0.0 # 表情包回复惩罚
- mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复
- at_bot_inevitable_reply: bool = False # @bot 必然回复
-
- # emoji
- max_emoji_num: int = 200 # 表情包最大数量
- max_reach_deletion: bool = True # 开启则在达到最大数量时删除表情包,关闭则不会继续收集表情包
- EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
-
- save_pic: bool = False # 是否保存图片
- save_emoji: bool = False # 是否保存表情包
- steal_emoji: bool = True # 是否偷取表情包,让麦麦可以发送她保存的这些表情包
-
- EMOJI_CHECK: bool = False # 是否开启过滤
- EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求
-
- # memory
- build_memory_interval: int = 600 # 记忆构建间隔(秒)
- memory_build_distribution: list = field(
- default_factory=lambda: [4, 2, 0.6, 24, 8, 0.4]
- ) # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重
- build_memory_sample_num: int = 10 # 记忆构建采样数量
- build_memory_sample_length: int = 20 # 记忆构建采样长度
- memory_compress_rate: float = 0.1 # 记忆压缩率
-
- forget_memory_interval: int = 600 # 记忆遗忘间隔(秒)
- memory_forget_time: int = 24 # 记忆遗忘时间(小时)
- memory_forget_percentage: float = 0.01 # 记忆遗忘比例
-
- consolidate_memory_interval: int = 1000 # 记忆整合间隔(秒)
- consolidation_similarity_threshold: float = 0.7 # 相似度阈值
- consolidate_memory_percentage: float = 0.01 # 检查节点比例
-
- memory_ban_words: list = field(
- default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
- ) # 添加新的配置项默认值
-
- # mood
- mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
- mood_decay_rate: float = 0.95 # 情绪衰减率
- mood_intensity_factor: float = 0.7 # 情绪强度因子
-
- # keywords
- keywords_reaction_rules = [] # 关键词回复规则
-
- # chinese_typo
- chinese_typo_enable = True # 是否启用中文错别字生成器
- chinese_typo_error_rate = 0.03 # 单字替换概率
- chinese_typo_min_freq = 7 # 最小字频阈值
- chinese_typo_tone_error_rate = 0.2 # 声调错误概率
- chinese_typo_word_replace_rate = 0.02 # 整词替换概率
-
- # response_splitter
- enable_kaomoji_protection = False # 是否启用颜文字保护
- enable_response_splitter = True # 是否启用回复分割器
- response_max_length = 100 # 回复允许的最大长度
- response_max_sentence_num = 3 # 回复允许的最大句子数
-
- model_max_output_length: int = 800 # 最大回复长度
-
- # remote
- remote_enable: bool = True # 是否启用远程控制
-
- # experimental
- enable_friend_chat: bool = False # 是否启用好友聊天
- # enable_think_flow: bool = False # 是否启用思考流程
- talk_allowed_private = set()
- enable_pfc_chatting: bool = False # 是否启用PFC聊天
-
- # 模型配置
- llm_reasoning: dict[str, str] = field(default_factory=lambda: {})
- # llm_reasoning_minor: dict[str, str] = field(default_factory=lambda: {})
- llm_normal: Dict[str, str] = field(default_factory=lambda: {})
- llm_topic_judge: Dict[str, str] = field(default_factory=lambda: {})
- llm_summary: Dict[str, str] = field(default_factory=lambda: {})
- embedding: Dict[str, str] = field(default_factory=lambda: {})
- vlm: Dict[str, str] = field(default_factory=lambda: {})
- moderation: Dict[str, str] = field(default_factory=lambda: {})
-
- llm_observation: Dict[str, str] = field(default_factory=lambda: {})
- llm_sub_heartflow: Dict[str, str] = field(default_factory=lambda: {})
- llm_heartflow: Dict[str, str] = field(default_factory=lambda: {})
- llm_tool_use: Dict[str, str] = field(default_factory=lambda: {})
- llm_plan: Dict[str, str] = field(default_factory=lambda: {})
-
- api_urls: Dict[str, str] = field(default_factory=lambda: {})
-
- @staticmethod
- def get_config_dir() -> str:
- """获取配置文件目录"""
- current_dir = os.path.dirname(os.path.abspath(__file__))
- root_dir = os.path.abspath(os.path.join(current_dir, "..", ".."))
- config_dir = os.path.join(root_dir, "config")
- if not os.path.exists(config_dir):
- os.makedirs(config_dir)
- return config_dir
-
- @classmethod
- def convert_to_specifierset(cls, value: str) -> SpecifierSet:
- """将 字符串 版本表达式转换成 SpecifierSet
- Args:
- value[str]: 版本表达式(字符串)
- Returns:
- SpecifierSet
- """
-
- try:
- converted = SpecifierSet(value)
- except InvalidSpecifier:
- logger.error(f"{value} 分类使用了错误的版本约束表达式\n", "请阅读 https://semver.org/lang/zh-CN/ 修改代码")
- exit(1)
-
- return converted
-
- @classmethod
- def get_config_version(cls, toml: dict) -> Version:
- """提取配置文件的 SpecifierSet 版本数据
- Args:
- toml[dict]: 输入的配置文件字典
- Returns:
- Version
- """
-
- if "inner" in toml:
- try:
- config_version: str = toml["inner"]["version"]
- except KeyError as e:
- logger.error("配置文件中 inner 段 不存在, 这是错误的配置文件")
- raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") from e
- else:
- toml["inner"] = {"version": "0.0.0"}
- config_version = toml["inner"]["version"]
-
- try:
- ver = version.parse(config_version)
- except InvalidVersion as e:
- logger.error(
- "配置文件中 inner段 的 version 键是错误的版本描述\n"
- "请阅读 https://semver.org/lang/zh-CN/ 修改配置,并参考本项目指定的模板进行修改\n"
- "本项目在不同的版本下有不同的模板,请注意识别"
- )
- raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n") from e
-
- return ver
-
- @classmethod
- def load_config(cls, config_path: str = None) -> "BotConfig":
- """从TOML配置文件加载配置"""
- config = cls()
-
- def personality(parent: dict):
- personality_config = parent["personality"]
- if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
- config.personality_core = personality_config.get("personality_core", config.personality_core)
- config.personality_sides = personality_config.get("personality_sides", config.personality_sides)
- if config.INNER_VERSION in SpecifierSet(">=1.7.0"):
- config.expression_style = personality_config.get("expression_style", config.expression_style)
-
- def identity(parent: dict):
- identity_config = parent["identity"]
- if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
- config.identity_detail = identity_config.get("identity_detail", config.identity_detail)
- config.height = identity_config.get("height", config.height)
- config.weight = identity_config.get("weight", config.weight)
- config.age = identity_config.get("age", config.age)
- config.gender = identity_config.get("gender", config.gender)
- config.appearance = identity_config.get("appearance", config.appearance)
-
- def emoji(parent: dict):
- emoji_config = parent["emoji"]
- config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
- config.EMOJI_CHECK_PROMPT = emoji_config.get("check_prompt", config.EMOJI_CHECK_PROMPT)
- config.EMOJI_CHECK = emoji_config.get("enable_check", config.EMOJI_CHECK)
- if config.INNER_VERSION in SpecifierSet(">=1.1.1"):
- config.max_emoji_num = emoji_config.get("max_emoji_num", config.max_emoji_num)
- config.max_reach_deletion = emoji_config.get("max_reach_deletion", config.max_reach_deletion)
- if config.INNER_VERSION in SpecifierSet(">=1.4.2"):
- config.save_pic = emoji_config.get("save_pic", config.save_pic)
- config.save_emoji = emoji_config.get("save_emoji", config.save_emoji)
- config.steal_emoji = emoji_config.get("steal_emoji", config.steal_emoji)
-
- def bot(parent: dict):
- # 机器人基础配置
- bot_config = parent["bot"]
- bot_qq = bot_config.get("qq")
- config.BOT_QQ = str(bot_qq)
- config.BOT_NICKNAME = bot_config.get("nickname", config.BOT_NICKNAME)
- config.BOT_ALIAS_NAMES = bot_config.get("alias_names", config.BOT_ALIAS_NAMES)
-
- def chat(parent: dict):
- chat_config = parent["chat"]
- config.allow_focus_mode = chat_config.get("allow_focus_mode", config.allow_focus_mode)
- config.base_normal_chat_num = chat_config.get("base_normal_chat_num", config.base_normal_chat_num)
- config.base_focused_chat_num = chat_config.get("base_focused_chat_num", config.base_focused_chat_num)
- config.observation_context_size = chat_config.get(
- "observation_context_size", config.observation_context_size
- )
- config.message_buffer = chat_config.get("message_buffer", config.message_buffer)
- config.ban_words = chat_config.get("ban_words", config.ban_words)
- for r in chat_config.get("ban_msgs_regex", config.ban_msgs_regex):
- config.ban_msgs_regex.add(re.compile(r))
-
- def normal_chat(parent: dict):
- normal_chat_config = parent["normal_chat"]
- config.model_reasoning_probability = normal_chat_config.get(
- "model_reasoning_probability", config.model_reasoning_probability
- )
- config.model_normal_probability = normal_chat_config.get(
- "model_normal_probability", config.model_normal_probability
- )
- config.emoji_chance = normal_chat_config.get("emoji_chance", config.emoji_chance)
- config.thinking_timeout = normal_chat_config.get("thinking_timeout", config.thinking_timeout)
-
- config.willing_mode = normal_chat_config.get("willing_mode", config.willing_mode)
- config.response_willing_amplifier = normal_chat_config.get(
- "response_willing_amplifier", config.response_willing_amplifier
- )
- config.response_interested_rate_amplifier = normal_chat_config.get(
- "response_interested_rate_amplifier", config.response_interested_rate_amplifier
- )
- config.down_frequency_rate = normal_chat_config.get("down_frequency_rate", config.down_frequency_rate)
- config.emoji_response_penalty = normal_chat_config.get(
- "emoji_response_penalty", config.emoji_response_penalty
- )
-
- config.mentioned_bot_inevitable_reply = normal_chat_config.get(
- "mentioned_bot_inevitable_reply", config.mentioned_bot_inevitable_reply
- )
- config.at_bot_inevitable_reply = normal_chat_config.get(
- "at_bot_inevitable_reply", config.at_bot_inevitable_reply
- )
-
- def focus_chat(parent: dict):
- focus_chat_config = parent["focus_chat"]
- config.compressed_length = focus_chat_config.get("compressed_length", config.compressed_length)
- config.compress_length_limit = focus_chat_config.get("compress_length_limit", config.compress_length_limit)
- config.reply_trigger_threshold = focus_chat_config.get(
- "reply_trigger_threshold", config.reply_trigger_threshold
- )
- config.default_decay_rate_per_second = focus_chat_config.get(
- "default_decay_rate_per_second", config.default_decay_rate_per_second
- )
- config.consecutive_no_reply_threshold = focus_chat_config.get(
- "consecutive_no_reply_threshold", config.consecutive_no_reply_threshold
- )
-
- def model(parent: dict):
- # 加载模型配置
- model_config: dict = parent["model"]
-
- config_list = [
- "llm_reasoning",
- # "llm_reasoning_minor",
- "llm_normal",
- "llm_topic_judge",
- "llm_summary",
- "vlm",
- "embedding",
- "llm_tool_use",
- "llm_observation",
- "llm_sub_heartflow",
- "llm_plan",
- "llm_heartflow",
- "llm_PFC_action_planner",
- "llm_PFC_chat",
- "llm_PFC_reply_checker",
- ]
-
- for item in config_list:
- if item in model_config:
- cfg_item: dict = model_config[item]
-
- # base_url 的例子: SILICONFLOW_BASE_URL
- # key 的例子: SILICONFLOW_KEY
- cfg_target = {
- "name": "",
- "base_url": "",
- "key": "",
- "stream": False,
- "pri_in": 0,
- "pri_out": 0,
- "temp": 0.7,
- }
-
- if config.INNER_VERSION in SpecifierSet("<=0.0.0"):
- cfg_target = cfg_item
-
- elif config.INNER_VERSION in SpecifierSet(">=0.0.1"):
- stable_item = ["name", "pri_in", "pri_out"]
-
- stream_item = ["stream"]
- if config.INNER_VERSION in SpecifierSet(">=1.0.1"):
- stable_item.append("stream")
-
- pricing_item = ["pri_in", "pri_out"]
-
- # 从配置中原始拷贝稳定字段
- for i in stable_item:
- # 如果 字段 属于计费项 且获取不到,那默认值是 0
- if i in pricing_item and i not in cfg_item:
- cfg_target[i] = 0
-
- if i in stream_item and i not in cfg_item:
- cfg_target[i] = False
-
- else:
- # 没有特殊情况则原样复制
- try:
- cfg_target[i] = cfg_item[i]
- except KeyError as e:
- logger.error(f"{item} 中的必要字段不存在,请检查")
- raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查") from e
-
- # 如果配置中有temp参数,就使用配置中的值
- if "temp" in cfg_item:
- cfg_target["temp"] = cfg_item["temp"]
- else:
- # 如果没有temp参数,就删除默认值
- cfg_target.pop("temp", None)
-
- provider = cfg_item.get("provider")
- if provider is None:
- logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
- raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")
-
- cfg_target["base_url"] = f"{provider}_BASE_URL"
- cfg_target["key"] = f"{provider}_KEY"
-
- # 如果 列表中的项目在 model_config 中,利用反射来设置对应项目
- setattr(config, item, cfg_target)
- else:
- logger.error(f"模型 {item} 在config中不存在,请检查,或尝试更新配置文件")
- raise KeyError(f"模型 {item} 在config中不存在,请检查,或尝试更新配置文件")
-
- def memory(parent: dict):
- memory_config = parent["memory"]
- config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
- config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval)
- config.memory_ban_words = set(memory_config.get("memory_ban_words", []))
- config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time)
- config.memory_forget_percentage = memory_config.get(
- "memory_forget_percentage", config.memory_forget_percentage
- )
- config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate)
- if config.INNER_VERSION in SpecifierSet(">=0.0.11"):
- config.memory_build_distribution = memory_config.get(
- "memory_build_distribution", config.memory_build_distribution
- )
- config.build_memory_sample_num = memory_config.get(
- "build_memory_sample_num", config.build_memory_sample_num
- )
- config.build_memory_sample_length = memory_config.get(
- "build_memory_sample_length", config.build_memory_sample_length
- )
- if config.INNER_VERSION in SpecifierSet(">=1.5.1"):
- config.consolidate_memory_interval = memory_config.get(
- "consolidate_memory_interval", config.consolidate_memory_interval
- )
- config.consolidation_similarity_threshold = memory_config.get(
- "consolidation_similarity_threshold", config.consolidation_similarity_threshold
- )
- config.consolidate_memory_percentage = memory_config.get(
- "consolidate_memory_percentage", config.consolidate_memory_percentage
- )
-
- def remote(parent: dict):
- remote_config = parent["remote"]
- config.remote_enable = remote_config.get("enable", config.remote_enable)
-
- def mood(parent: dict):
- mood_config = parent["mood"]
- config.mood_update_interval = mood_config.get("mood_update_interval", config.mood_update_interval)
- config.mood_decay_rate = mood_config.get("mood_decay_rate", config.mood_decay_rate)
- config.mood_intensity_factor = mood_config.get("mood_intensity_factor", config.mood_intensity_factor)
-
- def keywords_reaction(parent: dict):
- keywords_reaction_config = parent["keywords_reaction"]
- if keywords_reaction_config.get("enable", False):
- config.keywords_reaction_rules = keywords_reaction_config.get("rules", config.keywords_reaction_rules)
- for rule in config.keywords_reaction_rules:
- if rule.get("enable", False) and "regex" in rule:
- rule["regex"] = [re.compile(r) for r in rule.get("regex", [])]
-
- def chinese_typo(parent: dict):
- chinese_typo_config = parent["chinese_typo"]
- config.chinese_typo_enable = chinese_typo_config.get("enable", config.chinese_typo_enable)
- config.chinese_typo_error_rate = chinese_typo_config.get("error_rate", config.chinese_typo_error_rate)
- config.chinese_typo_min_freq = chinese_typo_config.get("min_freq", config.chinese_typo_min_freq)
- config.chinese_typo_tone_error_rate = chinese_typo_config.get(
- "tone_error_rate", config.chinese_typo_tone_error_rate
- )
- config.chinese_typo_word_replace_rate = chinese_typo_config.get(
- "word_replace_rate", config.chinese_typo_word_replace_rate
- )
-
- def response_splitter(parent: dict):
- response_splitter_config = parent["response_splitter"]
- config.enable_response_splitter = response_splitter_config.get(
- "enable_response_splitter", config.enable_response_splitter
- )
- config.response_max_length = response_splitter_config.get("response_max_length", config.response_max_length)
- config.response_max_sentence_num = response_splitter_config.get(
- "response_max_sentence_num", config.response_max_sentence_num
- )
- if config.INNER_VERSION in SpecifierSet(">=1.4.2"):
- config.enable_kaomoji_protection = response_splitter_config.get(
- "enable_kaomoji_protection", config.enable_kaomoji_protection
- )
- if config.INNER_VERSION in SpecifierSet(">=1.6.0"):
- config.model_max_output_length = response_splitter_config.get(
- "model_max_output_length", config.model_max_output_length
- )
-
- def groups(parent: dict):
- groups_config = parent["groups"]
- # config.talk_allowed_groups = set(groups_config.get("talk_allowed", []))
- config.talk_allowed_groups = set(str(group) for group in groups_config.get("talk_allowed", []))
- # config.talk_frequency_down_groups = set(groups_config.get("talk_frequency_down", []))
- config.talk_frequency_down_groups = set(
- str(group) for group in groups_config.get("talk_frequency_down", [])
- )
- # config.ban_user_id = set(groups_config.get("ban_user_id", []))
- config.ban_user_id = set(str(user) for user in groups_config.get("ban_user_id", []))
-
- def experimental(parent: dict):
- experimental_config = parent["experimental"]
- config.enable_friend_chat = experimental_config.get("enable_friend_chat", config.enable_friend_chat)
- # config.enable_think_flow = experimental_config.get("enable_think_flow", config.enable_think_flow)
- config.talk_allowed_private = set(str(user) for user in experimental_config.get("talk_allowed_private", []))
- if config.INNER_VERSION in SpecifierSet(">=1.1.0"):
- config.enable_pfc_chatting = experimental_config.get("pfc_chatting", config.enable_pfc_chatting)
-
- # 版本表达式:>=1.0.0,<2.0.0
- # 允许字段:func: method, support: str, notice: str, necessary: bool
- # 如果使用 notice 字段,在该组配置加载时,会展示该字段对用户的警示
- # 例如:"notice": "personality 将在 1.3.2 后被移除",那么在有效版本中的用户就会虽然可以
- # 正常执行程序,但是会看到这条自定义提示
-
- # 版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
- # 主版本号:当你做了不兼容的 API 修改,
- # 次版本号:当你做了向下兼容的功能性新增,
- # 修订号:当你做了向下兼容的问题修正。
- # 先行版本号及版本编译信息可以加到"主版本号.次版本号.修订号"的后面,作为延伸。
-
- # 如果你做了break的修改,就应该改动主版本号
- # 如果做了一个兼容修改,就不应该要求这个选项是必须的!
- include_configs = {
- "bot": {"func": bot, "support": ">=0.0.0"},
- "groups": {"func": groups, "support": ">=0.0.0"},
- "personality": {"func": personality, "support": ">=0.0.0"},
- "identity": {"func": identity, "support": ">=1.2.4"},
- "emoji": {"func": emoji, "support": ">=0.0.0"},
- "model": {"func": model, "support": ">=0.0.0"},
- "memory": {"func": memory, "support": ">=0.0.0", "necessary": False},
- "mood": {"func": mood, "support": ">=0.0.0"},
- "remote": {"func": remote, "support": ">=0.0.10", "necessary": False},
- "keywords_reaction": {"func": keywords_reaction, "support": ">=0.0.2", "necessary": False},
- "chinese_typo": {"func": chinese_typo, "support": ">=0.0.3", "necessary": False},
- "response_splitter": {"func": response_splitter, "support": ">=0.0.11", "necessary": False},
- "experimental": {"func": experimental, "support": ">=0.0.11", "necessary": False},
- "chat": {"func": chat, "support": ">=1.6.0", "necessary": False},
- "normal_chat": {"func": normal_chat, "support": ">=1.6.0", "necessary": False},
- "focus_chat": {"func": focus_chat, "support": ">=1.6.0", "necessary": False},
- }
-
- # 原地修改,将 字符串版本表达式 转换成 版本对象
- for key in include_configs:
- item_support = include_configs[key]["support"]
- include_configs[key]["support"] = cls.convert_to_specifierset(item_support)
-
- if os.path.exists(config_path):
- with open(config_path, "rb") as f:
- try:
- toml_dict = tomli.load(f)
- except tomli.TOMLDecodeError as e:
- logger.critical(f"配置文件bot_config.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}")
- exit(1)
-
- # 获取配置文件版本
- config.INNER_VERSION = cls.get_config_version(toml_dict)
-
- # 如果在配置中找到了需要的项,调用对应项的闭包函数处理
- for key in include_configs:
- if key in toml_dict:
- group_specifierset: SpecifierSet = include_configs[key]["support"]
-
- # 检查配置文件版本是否在支持范围内
- if config.INNER_VERSION in group_specifierset:
- # 如果版本在支持范围内,检查是否存在通知
- if "notice" in include_configs[key]:
- logger.warning(include_configs[key]["notice"])
-
- include_configs[key]["func"](toml_dict)
-
- else:
- # 如果版本不在支持范围内,崩溃并提示用户
- logger.error(
- f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
- f"当前程序仅支持以下版本范围: {group_specifierset}"
- )
- raise InvalidVersion(f"当前程序仅支持以下版本范围: {group_specifierset}")
-
- # 如果 necessary 项目存在,而且显式声明是 False,进入特殊处理
- elif "necessary" in include_configs[key] and include_configs[key].get("necessary") is False:
- # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
- if key == "keywords_reaction":
- pass
-
- else:
- # 如果用户根本没有需要的配置项,提示缺少配置
- logger.error(f"配置文件中缺少必需的字段: '{key}'")
- raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
-
- # identity_detail字段非空检查
- if not config.identity_detail:
- logger.error("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串")
- raise ValueError("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串")
-
- logger.success(f"成功加载配置文件: {config_path}")
-
- return config
+class Config(ConfigBase):
+ """总配置类"""
+
+ MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
+
+ bot: BotConfig
+ chat_target: ChatTargetConfig
+ personality: PersonalityConfig
+ identity: IdentityConfig
+ platforms: PlatformsConfig
+ chat: ChatConfig
+ normal_chat: NormalChatConfig
+ focus_chat: FocusChatConfig
+ emoji: EmojiConfig
+ memory: MemoryConfig
+ mood: MoodConfig
+ keyword_reaction: KeywordReactionConfig
+ chinese_typo: ChineseTypoConfig
+ response_splitter: ResponseSplitterConfig
+ telemetry: TelemetryConfig
+ experimental: ExperimentalConfig
+ model: ModelConfig
+
+
+def load_config(config_path: str) -> Config:
+ """
+ 加载配置文件
+ :param config_path: 配置文件路径
+ :return: Config对象
+ """
+ # 读取配置文件
+ with open(config_path, "r", encoding="utf-8") as f:
+ config_data = tomlkit.load(f)
+
+ # 创建Config对象
+ try:
+ return Config.from_dict(config_data)
+ except Exception as e:
+ logger.critical("配置文件解析失败")
+ raise e
# 获取配置文件路径
-logger.info(f"MaiCore当前版本: {mai_version}")
+logger.info(f"MaiCore当前版本: {MMC_VERSION}")
update_config()
-bot_config_floder_path = BotConfig.get_config_dir()
-logger.info(f"正在品鉴配置文件目录: {bot_config_floder_path}")
-
-bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
-
-if os.path.exists(bot_config_path):
- # 如果开发环境配置文件不存在,则使用默认配置文件
- logger.info(f"异常的新鲜,异常的美味: {bot_config_path}")
-else:
- # 配置文件不存在
- logger.error("配置文件不存在,请检查路径: {bot_config_path}")
- raise FileNotFoundError(f"配置文件不存在: {bot_config_path}")
-
-global_config = BotConfig.load_config(config_path=bot_config_path)
+logger.info("正在品鉴配置文件...")
+global_config = load_config(config_path=f"{CONFIG_DIR}/bot_config.toml")
+logger.info("非常的新鲜,非常的美味!")
diff --git a/src/config/config_base.py b/src/config/config_base.py
new file mode 100644
index 000000000..92f6cf9d4
--- /dev/null
+++ b/src/config/config_base.py
@@ -0,0 +1,116 @@
+from dataclasses import dataclass, fields, MISSING
+from typing import TypeVar, Type, Any, get_origin, get_args
+
+T = TypeVar("T", bound="ConfigBase")
+
+TOML_DICT_TYPE = {
+ int,
+ float,
+ str,
+ bool,
+ list,
+ dict,
+}
+
+
+@dataclass
+class ConfigBase:
+ """配置类的基类"""
+
+ @classmethod
+ def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
+ """从字典加载配置字段"""
+ if not isinstance(data, dict):
+ raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
+
+ init_args: dict[str, Any] = {}
+
+ for f in fields(cls):
+ field_name = f.name
+
+ if field_name.startswith("_"):
+ # 跳过以 _ 开头的字段
+ continue
+
+ if field_name not in data:
+ if f.default is not MISSING or f.default_factory is not MISSING:
+ # 跳过未提供且有默认值/默认构造方法的字段
+ continue
+ else:
+ raise ValueError(f"Missing required field: '{field_name}'")
+
+ value = data[field_name]
+ field_type = f.type
+
+ try:
+ init_args[field_name] = cls._convert_field(value, field_type)
+ except TypeError as e:
+ raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
+ except Exception as e:
+ raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
+
+ return cls(**init_args)
+
+ @classmethod
+ def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
+ """
+ 转换字段值为指定类型
+
+ 1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法
+ 2. 对于泛型集合类型(list, set, tuple),递归转换每个元素
+ 3. 对于基础类型(int, str, float, bool),直接转换
+ 4. 对于其他类型,尝试直接转换,如果失败则抛出异常
+ """
+
+ # 如果是嵌套的 dataclass,递归调用 from_dict 方法
+ if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
+ if not isinstance(value, dict):
+ raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
+ return field_type.from_dict(value)
+
+ # 处理泛型集合类型(list, set, tuple)
+ field_origin_type = get_origin(field_type)
+ field_type_args = get_args(field_type)
+
+ if field_origin_type in {list, set, tuple}:
+ # 检查提供的value是否为list
+ if not isinstance(value, list):
+ raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
+
+ if field_origin_type is list:
+ return [cls._convert_field(item, field_type_args[0]) for item in value]
+ elif field_origin_type is set:
+ return {cls._convert_field(item, field_type_args[0]) for item in value}
+ elif field_origin_type is tuple:
+ # 检查提供的value长度是否与类型参数一致
+ if len(value) != len(field_type_args):
+ raise TypeError(
+ f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
+ )
+ return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args))
+
+ if field_origin_type is dict:
+ # 检查提供的value是否为dict
+ if not isinstance(value, dict):
+ raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
+
+ # 检查字典的键值类型
+ if len(field_type_args) != 2:
+ raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
+ key_type, value_type = field_type_args
+
+ return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
+
+ # 处理基础类型,例如 int, str 等
+ if field_type is Any or isinstance(value, field_type):
+ return value
+
+ # 其他类型,尝试直接转换
+ try:
+ return field_type(value)
+ except (ValueError, TypeError) as e:
+ raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
+
+ def __str__(self):
+ """返回配置类的字符串表示"""
+ return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"
diff --git a/src/config/official_configs.py b/src/config/official_configs.py
new file mode 100644
index 000000000..6ad4648ba
--- /dev/null
+++ b/src/config/official_configs.py
@@ -0,0 +1,399 @@
+from dataclasses import dataclass, field
+from typing import Any
+
+from src.config.config_base import ConfigBase
+
+"""
+须知:
+1. 本文件中记录了所有的配置项
+2. 所有新增的class都需要继承自ConfigBase
+3. 所有新增的class都应在config.py中的Config类中添加字段
+4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default
+"""
+
+
+@dataclass
+class BotConfig(ConfigBase):
+ """QQ机器人配置类"""
+
+ qq_account: str
+ """QQ账号"""
+
+ nickname: str
+ """昵称"""
+
+ alias_names: list[str] = field(default_factory=lambda: [])
+ """别名列表"""
+
+
+@dataclass
+class ChatTargetConfig(ConfigBase):
+ """
+ 聊天目标配置类
+ 此类中有聊天的群组和用户配置
+ """
+
+ talk_allowed_groups: set[str] = field(default_factory=lambda: set())
+ """允许聊天的群组列表"""
+
+ talk_frequency_down_groups: set[str] = field(default_factory=lambda: set())
+ """降低聊天频率的群组列表"""
+
+ ban_user_id: set[str] = field(default_factory=lambda: set())
+ """禁止聊天的用户列表"""
+
+
+@dataclass
+class PersonalityConfig(ConfigBase):
+ """人格配置类"""
+
+ personality_core: str
+ """核心人格"""
+
+ expression_style: str
+ """表达风格"""
+
+ personality_sides: list[str] = field(default_factory=lambda: [])
+ """人格侧写"""
+
+
+@dataclass
+class IdentityConfig(ConfigBase):
+ """个体特征配置类"""
+
+ height: int = 170
+ """身高(单位:厘米)"""
+
+ weight: float = 50
+ """体重(单位:千克)"""
+
+ age: int = 18
+ """年龄(单位:岁)"""
+
+ gender: str = "女"
+ """性别(男/女)"""
+
+ appearance: str = "可爱"
+ """外貌描述"""
+
+ identity_detail: list[str] = field(default_factory=lambda: [])
+ """身份特征"""
+
+
+@dataclass
+class PlatformsConfig(ConfigBase):
+ """平台配置类"""
+
+ qq: str
+ """QQ适配器连接URL配置"""
+
+
+@dataclass
+class ChatConfig(ConfigBase):
+ """聊天配置类"""
+
+ allow_focus_mode: bool = True
+ """是否允许专注聊天状态"""
+
+ base_normal_chat_num: int = 3
+ """最多允许多少个群进行普通聊天"""
+
+ base_focused_chat_num: int = 2
+ """最多允许多少个群进行专注聊天"""
+
+ observation_context_size: int = 12
+ """可观察到的最长上下文大小,超过这个值的上下文会被压缩"""
+
+ message_buffer: bool = True
+ """消息缓冲器"""
+
+ ban_words: set[str] = field(default_factory=lambda: set())
+ """过滤词列表"""
+
+ ban_msgs_regex: set[str] = field(default_factory=lambda: set())
+ """过滤正则表达式列表"""
+
+
+@dataclass
+class NormalChatConfig(ConfigBase):
+ """普通聊天配置类"""
+
+ reasoning_model_probability: float = 0.3
+ """
+ 发言时选择推理模型的概率(0-1之间)
+ 选择普通模型的概率为 1 - reasoning_normal_model_probability
+ """
+
+ emoji_chance: float = 0.2
+ """发送表情包的基础概率"""
+
+ thinking_timeout: int = 120
+ """最长思考时间"""
+
+ willing_mode: str = "classical"
+ """意愿模式"""
+
+ response_willing_amplifier: float = 1.0
+ """回复意愿放大系数"""
+
+ response_interested_rate_amplifier: float = 1.0
+ """回复兴趣度放大系数"""
+
+ down_frequency_rate: float = 3.0
+ """降低回复频率的群组回复意愿降低系数"""
+
+ emoji_response_penalty: float = 0.0
+ """表情包回复惩罚系数"""
+
+ mentioned_bot_inevitable_reply: bool = False
+ """提及 bot 必然回复"""
+
+ at_bot_inevitable_reply: bool = False
+ """@bot 必然回复"""
+
+
+@dataclass
+class FocusChatConfig(ConfigBase):
+ """专注聊天配置类"""
+
+ reply_trigger_threshold: float = 3.0
+ """心流聊天触发阈值,越低越容易触发"""
+
+ default_decay_rate_per_second: float = 0.98
+ """默认衰减率,越大衰减越快"""
+
+ consecutive_no_reply_threshold: int = 3
+ """连续不回复的次数阈值"""
+
+ compressed_length: int = 5
+ """心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5"""
+
+ compress_length_limit: int = 5
+ """最多压缩份数,超过该数值的压缩上下文会被删除"""
+
+
+@dataclass
+class EmojiConfig(ConfigBase):
+ """表情包配置类"""
+
+ max_reg_num: int = 200
+ """表情包最大注册数量"""
+
+ do_replace: bool = True
+ """达到最大注册数量时替换旧表情包"""
+
+ check_interval: int = 120
+ """表情包检查间隔(分钟)"""
+
+ save_pic: bool = False
+ """是否保存图片"""
+
+ cache_emoji: bool = True
+ """是否缓存表情包"""
+
+ steal_emoji: bool = True
+ """是否偷取表情包,让麦麦可以发送她保存的这些表情包"""
+
+ content_filtration: bool = False
+ """是否开启表情包过滤"""
+
+ filtration_prompt: str = "符合公序良俗"
+ """表情包过滤要求"""
+
+
+@dataclass
+class MemoryConfig(ConfigBase):
+ """记忆配置类"""
+
+ memory_build_interval: int = 600
+ """记忆构建间隔(秒)"""
+
+ memory_build_distribution: tuple[
+ float,
+ float,
+ float,
+ float,
+ float,
+ float,
+ ] = field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4))
+ """记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重"""
+
+ memory_build_sample_num: int = 8
+ """记忆构建采样数量"""
+
+ memory_build_sample_length: int = 40
+ """记忆构建采样长度"""
+
+ memory_compress_rate: float = 0.1
+ """记忆压缩率"""
+
+ forget_memory_interval: int = 1000
+ """记忆遗忘间隔(秒)"""
+
+ memory_forget_time: int = 24
+ """记忆遗忘时间(小时)"""
+
+ memory_forget_percentage: float = 0.01
+ """记忆遗忘比例"""
+
+ consolidate_memory_interval: int = 1000
+ """记忆整合间隔(秒)"""
+
+ consolidation_similarity_threshold: float = 0.7
+ """整合相似度阈值"""
+
+ consolidate_memory_percentage: float = 0.01
+ """整合检查节点比例"""
+
+ memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
+ """不允许记忆的词列表"""
+
+
+@dataclass
+class MoodConfig(ConfigBase):
+ """情绪配置类"""
+
+ mood_update_interval: int = 1
+ """情绪更新间隔(秒)"""
+
+ mood_decay_rate: float = 0.95
+ """情绪衰减率"""
+
+ mood_intensity_factor: float = 0.7
+ """情绪强度因子"""
+
+
+@dataclass
+class KeywordRuleConfig(ConfigBase):
+ """关键词规则配置类"""
+
+ enable: bool = True
+ """是否启用关键词规则"""
+
+ keywords: list[str] = field(default_factory=lambda: [])
+ """关键词列表"""
+
+ regex: list[str] = field(default_factory=lambda: [])
+ """正则表达式列表"""
+
+ reaction: str = ""
+ """关键词触发的反应"""
+
+
+@dataclass
+class KeywordReactionConfig(ConfigBase):
+ """关键词配置类"""
+
+ enable: bool = True
+ """是否启用关键词反应"""
+
+ rules: list[KeywordRuleConfig] = field(default_factory=lambda: [])
+ """关键词反应规则列表"""
+
+
+@dataclass
+class ChineseTypoConfig(ConfigBase):
+ """中文错别字配置类"""
+
+ enable: bool = True
+ """是否启用中文错别字生成器"""
+
+ error_rate: float = 0.01
+ """单字替换概率"""
+
+ min_freq: int = 9
+ """最小字频阈值"""
+
+ tone_error_rate: float = 0.1
+ """声调错误概率"""
+
+ word_replace_rate: float = 0.006
+ """整词替换概率"""
+
+
+@dataclass
+class ResponseSplitterConfig(ConfigBase):
+ """回复分割器配置类"""
+
+ enable: bool = True
+ """是否启用回复分割器"""
+
+ max_length: int = 256
+ """回复允许的最大长度"""
+
+ max_sentence_num: int = 3
+ """回复允许的最大句子数"""
+
+ enable_kaomoji_protection: bool = False
+ """是否启用颜文字保护"""
+
+
+@dataclass
+class TelemetryConfig(ConfigBase):
+ """遥测配置类"""
+
+ enable: bool = True
+ """是否启用遥测"""
+
+
+@dataclass
+class ExperimentalConfig(ConfigBase):
+ """实验功能配置类"""
+
+ # enable_friend_chat: bool = False
+ # """是否启用好友聊天"""
+
+ # talk_allowed_private: set[str] = field(default_factory=lambda: set())
+ # """允许聊天的私聊列表"""
+
+ pfc_chatting: bool = False
+ """是否启用PFC"""
+
+
+@dataclass
+class ModelConfig(ConfigBase):
+ """模型配置类"""
+
+ model_max_output_length: int = 800 # 最大回复长度
+
+ reasoning: dict[str, Any] = field(default_factory=lambda: {})
+ """推理模型配置"""
+
+ normal: dict[str, Any] = field(default_factory=lambda: {})
+ """普通模型配置"""
+
+ topic_judge: dict[str, Any] = field(default_factory=lambda: {})
+ """主题判断模型配置"""
+
+ summary: dict[str, Any] = field(default_factory=lambda: {})
+ """摘要模型配置"""
+
+ vlm: dict[str, Any] = field(default_factory=lambda: {})
+ """视觉语言模型配置"""
+
+ heartflow: dict[str, Any] = field(default_factory=lambda: {})
+ """心流模型配置"""
+
+ observation: dict[str, Any] = field(default_factory=lambda: {})
+ """观察模型配置"""
+
+ sub_heartflow: dict[str, Any] = field(default_factory=lambda: {})
+ """子心流模型配置"""
+
+ plan: dict[str, Any] = field(default_factory=lambda: {})
+ """计划模型配置"""
+
+ embedding: dict[str, Any] = field(default_factory=lambda: {})
+ """嵌入模型配置"""
+
+ pfc_action_planner: dict[str, Any] = field(default_factory=lambda: {})
+ """PFC动作规划模型配置"""
+
+ pfc_chat: dict[str, Any] = field(default_factory=lambda: {})
+ """PFC聊天模型配置"""
+
+ pfc_reply_checker: dict[str, Any] = field(default_factory=lambda: {})
+ """PFC回复检查模型配置"""
+
+ tool_use: dict[str, Any] = field(default_factory=lambda: {})
+ """工具使用模型配置"""
diff --git a/src/experimental/PFC/action_planner.py b/src/experimental/PFC/action_planner.py
index b4182c9aa..c0bff5887 100644
--- a/src/experimental/PFC/action_planner.py
+++ b/src/experimental/PFC/action_planner.py
@@ -114,7 +114,7 @@ class ActionPlanner:
request_type="action_planning",
)
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
- self.name = global_config.BOT_NICKNAME
+ self.name = global_config.bot.nickname
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
# self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量
@@ -140,7 +140,7 @@ class ActionPlanner:
# (这部分逻辑不变)
time_since_last_bot_message_info = ""
try:
- bot_id = str(global_config.BOT_QQ)
+ bot_id = str(global_config.bot.qq_account)
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
for i in range(len(observation_info.chat_history) - 1, -1, -1):
msg = observation_info.chat_history[i]
diff --git a/src/experimental/PFC/chat_observer.py b/src/experimental/PFC/chat_observer.py
index 704eeb330..55914d800 100644
--- a/src/experimental/PFC/chat_observer.py
+++ b/src/experimental/PFC/chat_observer.py
@@ -10,7 +10,7 @@ from src.experimental.PFC.chat_states import (
create_new_message_notification,
create_cold_chat_notification,
)
-from src.experimental.PFC.message_storage import MongoDBMessageStorage
+from src.experimental.PFC.message_storage import PeeweeMessageStorage
from rich.traceback import install
install(extra_lines=3)
@@ -53,7 +53,7 @@ class ChatObserver:
self.stream_id = stream_id
self.private_name = private_name
- self.message_storage = MongoDBMessageStorage()
+ self.message_storage = PeeweeMessageStorage()
# self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
# self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
@@ -323,7 +323,7 @@ class ChatObserver:
for msg in messages:
try:
user_info = UserInfo.from_dict(msg.get("user_info", {}))
- if user_info.user_id == global_config.BOT_QQ:
+ if user_info.user_id == global_config.bot.qq_account:
self.update_bot_speak_time(msg["time"])
else:
self.update_user_speak_time(msg["time"])
diff --git a/src/experimental/PFC/message_sender.py b/src/experimental/PFC/message_sender.py
index 181bf171b..4b193a41d 100644
--- a/src/experimental/PFC/message_sender.py
+++ b/src/experimental/PFC/message_sender.py
@@ -42,8 +42,8 @@ class DirectMessageSender:
# 获取麦麦的信息
bot_user_info = UserInfo(
- user_id=global_config.BOT_QQ,
- user_nickname=global_config.BOT_NICKNAME,
+ user_id=global_config.bot.qq_account,
+ user_nickname=global_config.bot.nickname,
platform=chat_stream.platform,
)
diff --git a/src/experimental/PFC/message_storage.py b/src/experimental/PFC/message_storage.py
index cd6a01e34..e2e1dd052 100644
--- a/src/experimental/PFC/message_storage.py
+++ b/src/experimental/PFC/message_storage.py
@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
-from src.common.database import db
+
+# from src.common.database.database import db # Peewee db 导入
+from src.common.database.database_model import Messages # Peewee Messages 模型导入
+from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典
class MessageStorage(ABC):
@@ -47,28 +50,35 @@ class MessageStorage(ABC):
pass
-class MongoDBMessageStorage(MessageStorage):
- """MongoDB消息存储实现"""
+class PeeweeMessageStorage(MessageStorage):
+ """Peewee消息存储实现"""
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
- query = {"chat_id": chat_id, "time": {"$gt": message_time}}
- # print(f"storage_check_message: {message_time}")
+ query = (
+ Messages.select()
+ .where((Messages.chat_id == chat_id) & (Messages.time > message_time))
+ .order_by(Messages.time.asc())
+ )
- return list(db.messages.find(query).sort("time", 1))
+ # print(f"storage_check_message: {message_time}")
+ messages_models = list(query)
+ return [model_to_dict(msg) for msg in messages_models]
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
- query = {"chat_id": chat_id, "time": {"$lt": time_point}}
-
- messages = list(db.messages.find(query).sort("time", -1).limit(limit))
+ query = (
+ Messages.select()
+ .where((Messages.chat_id == chat_id) & (Messages.time < time_point))
+ .order_by(Messages.time.desc())
+ .limit(limit)
+ )
+ messages_models = list(query)
# 将消息按时间正序排列
- messages.reverse()
- return messages
+ messages_models.reverse()
+ return [model_to_dict(msg) for msg in messages_models]
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
- query = {"chat_id": chat_id, "time": {"$gt": after_time}}
-
- return db.messages.find_one(query) is not None
+ return Messages.select().where((Messages.chat_id == chat_id) & (Messages.time > after_time)).exists()
# # 创建一个内存消息存储实现,用于测试
diff --git a/src/experimental/PFC/pfc.py b/src/experimental/PFC/pfc.py
index 84fb9f8dc..80e75c5bf 100644
--- a/src/experimental/PFC/pfc.py
+++ b/src/experimental/PFC/pfc.py
@@ -42,13 +42,14 @@ class GoalAnalyzer:
"""对话目标分析器"""
def __init__(self, stream_id: str, private_name: str):
+ # TODO: API-Adapter修改标记
self.llm = LLMRequest(
- model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
+ model=global_config.model.normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
)
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
- self.name = global_config.BOT_NICKNAME
- self.nick_name = global_config.BOT_ALIAS_NAMES
+ self.name = global_config.bot.nickname
+ self.nick_name = global_config.bot.alias_names
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
@@ -315,7 +316,7 @@ class GoalAnalyzer:
# message_segment = Seg(type="text", data=content)
# bot_user_info = UserInfo(
# user_id=global_config.BOT_QQ,
-# user_nickname=global_config.BOT_NICKNAME,
+# user_nickname=global_config.bot.nickname,
# platform=chat_stream.platform,
# )
diff --git a/src/experimental/PFC/pfc_KnowledgeFetcher.py b/src/experimental/PFC/pfc_KnowledgeFetcher.py
index 8ebc307e2..4c1d8c759 100644
--- a/src/experimental/PFC/pfc_KnowledgeFetcher.py
+++ b/src/experimental/PFC/pfc_KnowledgeFetcher.py
@@ -14,9 +14,10 @@ class KnowledgeFetcher:
"""知识调取器"""
def __init__(self, private_name: str):
+ # TODO: API-Adapter修改标记
self.llm = LLMRequest(
- model=global_config.llm_normal,
- temperature=global_config.llm_normal["temp"],
+ model=global_config.model.normal,
+ temperature=global_config.model.normal["temp"],
max_tokens=1000,
request_type="knowledge_fetch",
)
diff --git a/src/experimental/PFC/reply_checker.py b/src/experimental/PFC/reply_checker.py
index a76e8a0da..5bca9d601 100644
--- a/src/experimental/PFC/reply_checker.py
+++ b/src/experimental/PFC/reply_checker.py
@@ -16,7 +16,7 @@ class ReplyChecker:
self.llm = LLMRequest(
model=global_config.llm_PFC_reply_checker, temperature=0.50, max_tokens=1000, request_type="reply_check"
)
- self.name = global_config.BOT_NICKNAME
+ self.name = global_config.bot.nickname
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
self.max_retries = 3 # 最大重试次数
@@ -43,7 +43,7 @@ class ReplyChecker:
bot_messages = []
for msg in reversed(chat_history):
user_info = UserInfo.from_dict(msg.get("user_info", {}))
- if str(user_info.user_id) == str(global_config.BOT_QQ): # 确保比较的是字符串
+ if str(user_info.user_id) == str(global_config.bot.qq_account): # 确保比较的是字符串
bot_messages.append(msg.get("processed_plain_text", ""))
if len(bot_messages) >= 2: # 只和最近的两条比较
break
diff --git a/src/experimental/PFC/reply_generator.py b/src/experimental/PFC/reply_generator.py
index 6dcda69af..bac8a769f 100644
--- a/src/experimental/PFC/reply_generator.py
+++ b/src/experimental/PFC/reply_generator.py
@@ -93,7 +93,7 @@ class ReplyGenerator:
request_type="reply_generation",
)
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
- self.name = global_config.BOT_NICKNAME
+ self.name = global_config.bot.nickname
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
self.reply_checker = ReplyChecker(stream_id, private_name)
diff --git a/src/experimental/PFC/waiter.py b/src/experimental/PFC/waiter.py
index af5cf7ad0..452446589 100644
--- a/src/experimental/PFC/waiter.py
+++ b/src/experimental/PFC/waiter.py
@@ -19,7 +19,7 @@ class Waiter:
def __init__(self, stream_id: str, private_name: str):
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
- self.name = global_config.BOT_NICKNAME
+ self.name = global_config.bot.nickname
self.private_name = private_name
# self.wait_accumulated_time = 0 # 不再需要累加计时
diff --git a/src/experimental/only_message_process.py b/src/experimental/only_message_process.py
index 3d1432703..62f73c700 100644
--- a/src/experimental/only_message_process.py
+++ b/src/experimental/only_message_process.py
@@ -16,7 +16,7 @@ class MessageProcessor:
@staticmethod
def _check_ban_words(text: str, chat, userinfo) -> bool:
"""检查消息中是否包含过滤词"""
- for word in global_config.ban_words:
+ for word in global_config.chat.ban_words:
if word in text:
logger.info(
f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
@@ -28,7 +28,7 @@ class MessageProcessor:
@staticmethod
def _check_ban_regex(text: str, chat, userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式"""
- for pattern in global_config.ban_msgs_regex:
+ for pattern in global_config.chat.ban_msgs_regex:
if pattern.search(text):
logger.info(
f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
diff --git a/src/main.py b/src/main.py
index 34b7eda3d..4f8af28ef 100644
--- a/src/main.py
+++ b/src/main.py
@@ -40,7 +40,7 @@ class MainSystem:
async def initialize(self):
"""初始化系统组件"""
- logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
+ logger.debug(f"正在唤醒{global_config.bot.nickname}......")
# 其他初始化任务
await asyncio.gather(self._init_components())
@@ -84,7 +84,7 @@ class MainSystem:
asyncio.create_task(chat_manager._auto_save_task())
# 使用HippocampusManager初始化海马体
- self.hippocampus_manager.initialize(global_config=global_config)
+ self.hippocampus_manager.initialize()
# await asyncio.sleep(0.5) #防止logger输出飞了
# 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中
@@ -92,15 +92,15 @@ class MainSystem:
# 初始化个体特征
self.individuality.initialize(
- bot_nickname=global_config.BOT_NICKNAME,
- personality_core=global_config.personality_core,
- personality_sides=global_config.personality_sides,
- identity_detail=global_config.identity_detail,
- height=global_config.height,
- weight=global_config.weight,
- age=global_config.age,
- gender=global_config.gender,
- appearance=global_config.appearance,
+ bot_nickname=global_config.bot.nickname,
+ personality_core=global_config.personality.personality_core,
+ personality_sides=global_config.personality.personality_sides,
+ identity_detail=global_config.identity.identity_detail,
+ height=global_config.identity.height,
+ weight=global_config.identity.weight,
+ age=global_config.identity.age,
+ gender=global_config.identity.gender,
+ appearance=global_config.identity.appearance,
)
logger.success("个体特征初始化成功")
@@ -141,7 +141,7 @@ class MainSystem:
async def build_memory_task():
"""记忆构建任务"""
while True:
- await asyncio.sleep(global_config.build_memory_interval)
+ await asyncio.sleep(global_config.memory.memory_build_interval)
logger.info("正在进行记忆构建")
await HippocampusManager.get_instance().build_memory()
@@ -149,16 +149,18 @@ class MainSystem:
async def forget_memory_task():
"""记忆遗忘任务"""
while True:
- await asyncio.sleep(global_config.forget_memory_interval)
+ await asyncio.sleep(global_config.memory.forget_memory_interval)
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
- await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage)
+ await HippocampusManager.get_instance().forget_memory(
+ percentage=global_config.memory.memory_forget_percentage
+ )
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
@staticmethod
async def consolidate_memory_task():
"""记忆整合任务"""
while True:
- await asyncio.sleep(global_config.consolidate_memory_interval)
+ await asyncio.sleep(global_config.memory.consolidate_memory_interval)
print("\033[1;32m[记忆整合]\033[0m 开始整合记忆...")
await HippocampusManager.get_instance().consolidate_memory()
print("\033[1;32m[记忆整合]\033[0m 记忆整合完成")
diff --git a/src/manager/mood_manager.py b/src/manager/mood_manager.py
index 42677d4e1..c83fbeb7c 100644
--- a/src/manager/mood_manager.py
+++ b/src/manager/mood_manager.py
@@ -34,14 +34,14 @@ class MoodUpdateTask(AsyncTask):
def __init__(self):
super().__init__(
task_name="Mood Update Task",
- wait_before_start=global_config.mood_update_interval,
- run_interval=global_config.mood_update_interval,
+ wait_before_start=global_config.mood.mood_update_interval,
+ run_interval=global_config.mood.mood_update_interval,
)
# 从配置文件获取衰减率
- self.decay_rate_valence: float = 1 - global_config.mood_decay_rate
+ self.decay_rate_valence: float = 1 - global_config.mood.mood_decay_rate
"""愉悦度衰减率"""
- self.decay_rate_arousal: float = 1 - global_config.mood_decay_rate
+ self.decay_rate_arousal: float = 1 - global_config.mood.mood_decay_rate
"""唤醒度衰减率"""
self.last_update = time.time()
diff --git a/src/plugins.md b/src/plugins.md
new file mode 100644
index 000000000..71ca741a6
--- /dev/null
+++ b/src/plugins.md
@@ -0,0 +1,101 @@
+# 如何编写MaiBot插件
+
+## 基本步骤
+
+1. 在`src/plugins/你的插件名/actions/`目录下创建插件文件
+2. 继承`PluginAction`基类
+3. 实现`process`方法
+
+## 插件结构示例
+
+```python
+from src.common.logger_manager import get_logger
+from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action
+from typing import Tuple
+
+logger = get_logger("your_action_name")
+
+@register_action
+class YourAction(PluginAction):
+ """你的动作描述"""
+
+ action_name = "your_action_name" # 动作名称,必须唯一
+ action_description = "这个动作的详细描述,会展示给用户"
+ action_parameters = {
+ "param1": "参数1的说明(可选)",
+ "param2": "参数2的说明(可选)"
+ }
+ action_require = [
+ "使用场景1",
+ "使用场景2"
+ ]
+ default = False # 是否默认启用
+
+ async def process(self) -> Tuple[bool, str]:
+ """插件核心逻辑"""
+ # 你的代码逻辑...
+ return True, "执行结果"
+```
+
+## 可用的API方法
+
+插件可以使用`PluginAction`基类提供的以下API:
+
+### 1. 发送消息
+
+```python
+await self.send_message("要发送的文本", target="可选的回复目标")
+```
+
+### 2. 获取聊天类型
+
+```python
+chat_type = self.get_chat_type() # 返回 "group" 或 "private" 或 "unknown"
+```
+
+### 3. 获取最近消息
+
+```python
+messages = self.get_recent_messages(count=5) # 获取最近5条消息
+# 返回格式: [{"sender": "发送者", "content": "内容", "timestamp": 时间戳}, ...]
+```
+
+### 4. 获取动作参数
+
+```python
+param_value = self.action_data.get("param_name", "默认值")
+```
+
+### 5. 日志记录
+
+```python
+logger.info(f"{self.log_prefix} 你的日志信息")
+logger.warning("警告信息")
+logger.error("错误信息")
+```
+
+## 返回值说明
+
+`process`方法必须返回一个元组,包含两个元素:
+- 第一个元素(bool): 表示动作是否执行成功
+- 第二个元素(str): 执行结果的文本描述
+
+```python
+return True, "执行成功的消息"
+# 或
+return False, "执行失败的原因"
+```
+
+## 最佳实践
+
+1. 使用`action_parameters`清晰定义你的动作需要的参数
+2. 使用`action_require`描述何时应该使用你的动作
+3. 使用`action_description`准确描述你的动作功能
+4. 使用`logger`记录重要信息,方便调试
+5. 避免操作底层系统,尽量使用`PluginAction`提供的API
+
+## 注册与加载
+
+插件会在系统启动时自动加载,只要放在正确的目录并添加了`@register_action`装饰器。
+
+若设置`default = True`,插件会自动添加到默认动作集;否则需要在系统中手动启用。
diff --git a/src/plugins/__init__.py b/src/plugins/__init__.py
new file mode 100644
index 000000000..0b0692d42
--- /dev/null
+++ b/src/plugins/__init__.py
@@ -0,0 +1 @@
+"""插件系统包"""
diff --git a/src/plugins/test_plugin/__init__.py b/src/plugins/test_plugin/__init__.py
new file mode 100644
index 000000000..b5fefb97e
--- /dev/null
+++ b/src/plugins/test_plugin/__init__.py
@@ -0,0 +1,5 @@
+"""测试插件包"""
+
+"""
+这是一个测试插件
+"""
diff --git a/src/plugins/test_plugin/actions/__init__.py b/src/plugins/test_plugin/actions/__init__.py
new file mode 100644
index 000000000..a87c0b523
--- /dev/null
+++ b/src/plugins/test_plugin/actions/__init__.py
@@ -0,0 +1,7 @@
+"""测试插件动作模块"""
+
+# 导入所有动作模块以确保装饰器被执行
+# from . import test_action # noqa
+
+# from . import online_action # noqa
+# from . import mute_action # noqa
diff --git a/src/plugins/test_plugin/actions/mute_action.py b/src/plugins/test_plugin/actions/mute_action.py
new file mode 100644
index 000000000..c96204172
--- /dev/null
+++ b/src/plugins/test_plugin/actions/mute_action.py
@@ -0,0 +1,51 @@
+from src.common.logger_manager import get_logger
+from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action
+from typing import Tuple
+
+logger = get_logger("mute_action")
+
+
+@register_action
+class MuteAction(PluginAction):
+ """测试动作处理类"""
+
+ action_name = "mute_action"
+ action_description = (
+ "如果某人违反了公序良俗,或者别人戳你太多,,或者某人刷屏,一定要禁言某人,如果你很生气,可以禁言某人"
+ )
+ action_parameters = {
+ "target": "禁言对象,输入你要禁言的对象的名字,必填,",
+ "duration": "禁言时长,输入你要禁言的时长,单位为秒,必填",
+ }
+ action_require = [
+ "当有人违反了公序良俗时使用",
+ "当有人刷屏时使用",
+ "当有人要求禁言自己时使用",
+ "当有人戳你两次以上时,防止刷屏,禁言他,必须牢记",
+ "当千石可乐或可乐酱要求你禁言时使用",
+ "当你想回避某个话题时使用",
+ ]
+ default = True # 不是默认动作,需要手动添加到使用集
+
+ async def process(self) -> Tuple[bool, str]:
+ """处理测试动作"""
+ logger.info(f"{self.log_prefix} 执行online动作: {self.reasoning}")
+
+ # 发送测试消息
+ target = self.action_data.get("target")
+ duration = self.action_data.get("duration")
+ reason = self.action_data.get("reason")
+ platform, user_id = await self.get_user_id_by_person_name(target)
+
+ await self.send_message_by_expressor(f"我要禁言{target},{platform},时长{duration}秒,理由{reason},表达情绪")
+
+ try:
+ await self.send_message(f"[command]mute,{user_id},{duration}")
+
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 执行mute动作时出错: {e}")
+ await self.send_message_by_expressor(f"执行mute动作时出错: {e}")
+
+ return False, "执行mute动作时出错"
+
+ return True, "测试动作执行成功"
diff --git a/src/plugins/test_plugin/actions/online_action.py b/src/plugins/test_plugin/actions/online_action.py
new file mode 100644
index 000000000..4f49045f2
--- /dev/null
+++ b/src/plugins/test_plugin/actions/online_action.py
@@ -0,0 +1,43 @@
+from src.common.logger_manager import get_logger
+from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action
+from typing import Tuple
+
+logger = get_logger("check_online_action")
+
+
+@register_action
+class CheckOnlineAction(PluginAction):
+ """测试动作处理类"""
+
+ action_name = "check_online_action"
+ action_description = "这是一个检查在线状态的动作,当有人要求你检查Maibot(麦麦 机器人)在线状态时使用"
+ action_parameters = {"mode": "查看模式"}
+ action_require = [
+ "当有人要求你检查Maibot(麦麦 机器人)在线状态时使用",
+ "mode参数为version时查看在线版本状态,默认用这种",
+ "mode参数为type时查看在线系统类型分布",
+ ]
+ default = True # 不是默认动作,需要手动添加到使用集
+
+ async def process(self) -> Tuple[bool, str]:
+ """处理测试动作"""
+ logger.info(f"{self.log_prefix} 执行online动作: {self.reasoning}")
+
+ # 发送测试消息
+ mode = self.action_data.get("mode", "type")
+
+ await self.send_message_by_expressor("我看看")
+
+ try:
+ if mode == "type":
+ await self.send_message("#online detail")
+ elif mode == "version":
+ await self.send_message("#online")
+
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 执行online动作时出错: {e}")
+ await self.send_message_by_expressor("执行online动作时出错: {e}")
+
+ return False, "执行online动作时出错"
+
+ return True, "测试动作执行成功"
diff --git a/src/plugins/test_plugin/actions/test_action.py b/src/plugins/test_plugin/actions/test_action.py
new file mode 100644
index 000000000..995dd918a
--- /dev/null
+++ b/src/plugins/test_plugin/actions/test_action.py
@@ -0,0 +1,37 @@
+from src.common.logger_manager import get_logger
+from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action
+from typing import Tuple
+
+logger = get_logger("test_action")
+
+
+@register_action
+class TestAction(PluginAction):
+ """测试动作处理类"""
+
+ action_name = "test_action"
+ action_description = "这是一个测试动作,当有人要求你测试插件系统时使用"
+ action_parameters = {"test_param": "测试参数(可选)"}
+ action_require = [
+ "测试情况下使用",
+ "想测试插件动作加载时使用",
+ ]
+ default = False # 不是默认动作,需要手动添加到使用集
+
+ async def process(self) -> Tuple[bool, str]:
+ """处理测试动作"""
+ logger.info(f"{self.log_prefix} 执行测试动作: {self.reasoning}")
+
+ # 获取聊天类型
+ chat_type = self.get_chat_type()
+ logger.info(f"{self.log_prefix} 当前聊天类型: {chat_type}")
+
+ # 获取最近消息
+ recent_messages = self.get_recent_messages(3)
+ logger.info(f"{self.log_prefix} 最近3条消息: {recent_messages}")
+
+ # 发送测试消息
+ test_param = self.action_data.get("test_param", "默认参数")
+ await self.send_message_by_expressor(f"测试动作执行成功,参数: {test_param}")
+
+ return True, "测试动作执行成功"
diff --git a/src/tools/not_used/change_mood.py b/src/tools/not_used/change_mood.py
index c34bebb93..69fc3bb78 100644
--- a/src/tools/not_used/change_mood.py
+++ b/src/tools/not_used/change_mood.py
@@ -44,7 +44,7 @@ class ChangeMoodTool(BaseTool):
_ori_response = ",".join(response_set)
# _stance, emotion = await gpt._get_emotion_tags(ori_response, message_processed_plain_text)
emotion = "平静"
- mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor)
+ mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor)
return {"name": "change_mood", "content": f"你的心情刚刚变化了,现在的心情是: {emotion}"}
except Exception as e:
logger.error(f"心情改变工具执行失败: {str(e)}")
diff --git a/src/tools/tool_can_use/get_knowledge.py b/src/tools/tool_can_use/get_knowledge.py
index 65acd55c0..fd37f11e7 100644
--- a/src/tools/tool_can_use/get_knowledge.py
+++ b/src/tools/tool_can_use/get_knowledge.py
@@ -1,8 +1,10 @@
from src.tools.tool_can_use.base_tool import BaseTool
from src.chat.utils.utils import get_embedding
-from src.common.database import db
+from src.common.database.database_model import Knowledges # Updated import
from src.common.logger_manager import get_logger
-from typing import Any, Union
+from typing import Any, Union, List # Added List
+import json # Added for parsing embedding
+import math # Added for cosine similarity
logger = get_logger("get_knowledge_tool")
@@ -30,6 +32,7 @@ class SearchKnowledgeTool(BaseTool):
Returns:
dict: 工具执行结果
"""
+ query = "" # Initialize query to ensure it's defined in except block
try:
query = function_args.get("query")
threshold = function_args.get("threshold", 0.4)
@@ -48,9 +51,19 @@ class SearchKnowledgeTool(BaseTool):
logger.error(f"知识库搜索工具执行失败: {str(e)}")
return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"}
+ @staticmethod
+ def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
+ """计算两个向量之间的余弦相似度"""
+ dot_product = sum(p * q for p, q in zip(vec1, vec2))
+ magnitude1 = math.sqrt(sum(p * p for p in vec1))
+ magnitude2 = math.sqrt(sum(q * q for q in vec2))
+ if magnitude1 == 0 or magnitude2 == 0:
+ return 0.0
+ return dot_product / (magnitude1 * magnitude2)
+
@staticmethod
def get_info_from_db(
- query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
+ query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
"""从数据库中获取相关信息
@@ -66,66 +79,51 @@ class SearchKnowledgeTool(BaseTool):
if not query_embedding:
return "" if not return_raw else []
- # 使用余弦相似度计算
- pipeline = [
- {
- "$addFields": {
- "dotProduct": {
- "$reduce": {
- "input": {"$range": [0, {"$size": "$embedding"}]},
- "initialValue": 0,
- "in": {
- "$add": [
- "$$value",
- {
- "$multiply": [
- {"$arrayElemAt": ["$embedding", "$$this"]},
- {"$arrayElemAt": [query_embedding, "$$this"]},
- ]
- },
- ]
- },
- }
- },
- "magnitude1": {
- "$sqrt": {
- "$reduce": {
- "input": "$embedding",
- "initialValue": 0,
- "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
- }
- }
- },
- "magnitude2": {
- "$sqrt": {
- "$reduce": {
- "input": query_embedding,
- "initialValue": 0,
- "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
- }
- }
- },
- }
- },
- {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
- {
- "$match": {
- "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
- }
- },
- {"$sort": {"similarity": -1}},
- {"$limit": limit},
- {"$project": {"content": 1, "similarity": 1}},
- ]
+ similar_items = []
+ try:
+ all_knowledges = Knowledges.select()
+ for item in all_knowledges:
+ try:
+ item_embedding_str = item.embedding
+ if not item_embedding_str:
+ logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
+ continue
+ item_embedding = json.loads(item_embedding_str)
+ if not isinstance(item_embedding, list) or not all(
+ isinstance(x, (int, float)) for x in item_embedding
+ ):
+ logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
+ continue
+ except json.JSONDecodeError:
+ logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}")
+ continue
+ except AttributeError:
+ logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.")
+ continue
- results = list(db.knowledges.aggregate(pipeline))
- logger.debug(f"知识库查询结果数量: {len(results)}")
+ similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding)
+
+ if similarity >= threshold:
+ similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item})
+
+ # 按相似度降序排序
+ similar_items.sort(key=lambda x: x["similarity"], reverse=True)
+
+ # 应用限制
+ results = similar_items[:limit]
+ logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}")
+
+ except Exception as e:
+ logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
+ return "" if not return_raw else []
if not results:
return "" if not return_raw else []
if return_raw:
- return results
+ # Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理
+ # 这里返回包含内容和相似度的字典列表
+ return [{"content": r["content"], "similarity": r["similarity"]} for r in results]
else:
# 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results)
diff --git a/src/tools/tool_use.py b/src/tools/tool_use.py
index c55170b88..ff36085d6 100644
--- a/src/tools/tool_use.py
+++ b/src/tools/tool_use.py
@@ -15,7 +15,7 @@ logger = get_logger("tool_use")
class ToolUser:
def __init__(self):
self.llm_model_tool = LLMRequest(
- model=global_config.llm_tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use"
+ model=global_config.model.tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use"
)
@staticmethod
@@ -37,7 +37,7 @@ class ToolUser:
# print(f"intol111111111111111111111111111111111222222222222mid_memory_info:{mid_memory_info}")
# 这些信息应该从调用者传入,而不是从self获取
- bot_name = global_config.BOT_NICKNAME
+ bot_name = global_config.bot.nickname
prompt = ""
prompt += mid_memory_info
prompt += "你正在思考如何回复群里的消息。\n"
diff --git a/template/bot_config_meta.toml b/template/bot_config_meta.toml
deleted file mode 100644
index c3541baad..000000000
--- a/template/bot_config_meta.toml
+++ /dev/null
@@ -1,104 +0,0 @@
-[inner.version]
-describe = "版本号"
-important = true
-can_edit = false
-
-[bot.qq]
-describe = "机器人的QQ号"
-important = true
-can_edit = true
-
-[bot.nickname]
-describe = "机器人的昵称"
-important = true
-can_edit = true
-
-[bot.alias_names]
-describe = "机器人的别名列表,该选项还在调试中,暂时未生效"
-important = false
-can_edit = true
-
-[groups.talk_allowed]
-describe = "可以回复消息的群号码列表"
-important = true
-can_edit = true
-
-[groups.talk_frequency_down]
-describe = "降低回复频率的群号码列表"
-important = false
-can_edit = true
-
-[groups.ban_user_id]
-describe = "禁止回复和读取消息的QQ号列表"
-important = false
-can_edit = true
-
-[personality.personality_core]
-describe = "用一句话或几句话描述人格的核心特点,建议20字以内"
-important = true
-can_edit = true
-
-[personality.personality_sides]
-describe = "用一句话或几句话描述人格的一些细节,条数任意,不能为0,该选项还在调试中"
-important = false
-can_edit = true
-
-[identity.identity_detail]
-describe = "身份特点列表,条数任意,不能为0,该选项还在调试中"
-important = false
-can_edit = true
-
-[identity.age]
-describe = "年龄,单位岁"
-important = false
-can_edit = true
-
-[identity.gender]
-describe = "性别"
-important = false
-can_edit = true
-
-[identity.appearance]
-describe = "外貌特征描述,该选项还在调试中,暂时未生效"
-important = false
-can_edit = true
-
-[platforms.nonebot-qq]
-describe = "nonebot-qq适配器提供的链接"
-important = true
-can_edit = true
-
-[chat.allow_focus_mode]
-describe = "是否允许专注聊天状态"
-important = false
-can_edit = true
-
-[chat.base_normal_chat_num]
-describe = "最多允许多少个群进行普通聊天"
-important = false
-can_edit = true
-
-[chat.base_focused_chat_num]
-describe = "最多允许多少个群进行专注聊天"
-important = false
-can_edit = true
-
-[chat.observation_context_size]
-describe = "观察到的最长上下文大小,建议15,太短太长都会导致脑袋尖尖"
-important = false
-can_edit = true
-
-[chat.message_buffer]
-describe = "启用消息缓冲器,启用此项以解决消息的拆分问题,但会使麦麦的回复延迟"
-important = false
-can_edit = true
-
-[chat.ban_words]
-describe = "需要过滤的消息列表"
-important = false
-can_edit = true
-
-[chat.ban_msgs_regex]
-describe = "需要过滤的消息(原始消息)匹配的正则表达式,匹配到的消息将被过滤(支持CQ码)"
-important = false
-can_edit = true
\ No newline at end of file
diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml
index 931afe2ed..943422029 100644
--- a/template/bot_config_template.toml
+++ b/template/bot_config_template.toml
@@ -1,18 +1,10 @@
[inner]
-version = "1.7.0"
+version = "2.2.0"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件,请在修改后将version的值进行变更
-#如果新增项目,请在BotConfig类下新增相应的变量
-#1.如果你修改的是[]层级项目,例如你新增了 [memory],那么请在config.py的 load_config函数中的include_configs字典中新增"内容":{
-#"func":memory,
-#"support":">=0.0.0", #新的版本号
-#"necessary":False #是否必须
-#}
-#2.如果你修改的是[]下的项目,例如你新增了[memory]下的 memory_ban_words ,那么请在config.py的 load_config函数中的 memory函数下新增版本判断:
- # if config.INNER_VERSION in SpecifierSet(">=0.0.2"):
- # config.memory_ban_words = set(memory_config.get("memory_ban_words", []))
-
+#如果新增项目,请阅读src/config/official_configs.py中的说明
+#
# 版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
# 主版本号:当你做了不兼容的 API 修改,
# 次版本号:当你做了向下兼容的功能性新增,
@@ -21,17 +13,12 @@ version = "1.7.0"
#----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
[bot]
-qq = 1145141919810
+qq_account = 1145141919810
nickname = "麦麦"
alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效
-[groups]
-talk_allowed = [
- 123,
- 123,
-] #可以回复消息的群号码
-talk_frequency_down = [] #降低回复频率的群号码
-ban_user_id = [] #禁止回复和读取消息的QQ号
+[chat_target]
+talk_frequency_down_groups = [] #降低回复频率的群号码
[personality] #未完善
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内,谁再写3000字小作文敲谁脑袋
@@ -53,10 +40,13 @@ identity_detail = [
"身份特点",
"身份特点",
]# 条数任意,不能为0, 该选项还在调试中
+
#外貌特征
-age = 20 # 年龄 单位岁
-gender = "男" # 性别
-appearance = "用几句话描述外貌特征" # 外貌特征 该选项还在调试中,暂时未生效
+age = 18 # 年龄 单位岁
+gender = "女" # 性别
+height = "170" # 身高(单位cm)
+weight = "50" # 体重(单位kg)
+appearance = "用一句或几句话描述外貌特征" # 外貌特征 该选项还在调试中,暂时未生效
[platforms] # 必填项目,填写每个平台适配器提供的链接
qq="http://127.0.0.1:18002/api/message"
@@ -65,10 +55,8 @@ qq="http://127.0.0.1:18002/api/message"
allow_focus_mode = false # 是否允许专注聊天状态
# 是否启用heart_flowC(HFC)模式
# 启用后麦麦会自主选择进入heart_flowC模式(持续一段时间),进行主动的观察和回复,并给出回复,比较消耗token
-base_normal_chat_num = 999 # 最多允许多少个群进行普通聊天
-base_focused_chat_num = 4 # 最多允许多少个群进行专注聊天
-observation_context_size = 15 # 观察到的最长上下文大小,建议15,太短太长都会导致脑袋尖尖
+chat.observation_context_size = 15 # 观察到的最长上下文大小,建议15,太短太长都会导致脑袋尖尖
message_buffer = true # 启用消息缓冲器?启用此项以解决消息的拆分问题,但会使麦麦的回复延迟
# 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息
@@ -85,11 +73,10 @@ ban_msgs_regex = [
[normal_chat] #普通聊天
#一般回复参数
-model_reasoning_probability = 0.7 # 麦麦回答时选择推理模型 模型的概率
-model_normal_probability = 0.3 # 麦麦回答时选择一般模型 模型的概率
+reasoning_model_probability = 0.3 # 麦麦回答时选择推理模型的概率(与之相对的,普通模型的概率为1 - reasoning_model_probability)
emoji_chance = 0.2 # 麦麦一般回复时使用表情包的概率,设置为1让麦麦自己决定发不发
-thinking_timeout = 100 # 麦麦最长思考时间,超过这个时间的思考会放弃(往往是api反应太慢)
+thinking_timeout = 120 # 麦麦最长思考时间,超过这个时间的思考会放弃(往往是api反应太慢)
willing_mode = "classical" # 回复意愿模式 —— 经典模式:classical,mxp模式:mxp,自定义模式:custom(需要你自己实现)
response_willing_amplifier = 1 # 麦麦回复意愿放大系数,一般为1
@@ -100,34 +87,34 @@ mentioned_bot_inevitable_reply = false # 提及 bot 必然回复
at_bot_inevitable_reply = false # @bot 必然回复
[focus_chat] #专注聊天
-reply_trigger_threshold = 3.6 # 专注聊天触发阈值,越低越容易进入专注聊天
-default_decay_rate_per_second = 0.95 # 默认衰减率,越大衰减越快,越高越难进入专注聊天
+reply_trigger_threshold = 3.0 # 专注聊天触发阈值,越低越容易进入专注聊天
+default_decay_rate_per_second = 0.98 # 默认衰减率,越大衰减越快,越高越难进入专注聊天
consecutive_no_reply_threshold = 3 # 连续不回复的阈值,越低越容易结束专注聊天
# 以下选项暂时无效
-compressed_length = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5
+compressed_length = 5 # 不能大于chat.observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5
compress_length_limit = 5 #最多压缩份数,超过该数值的压缩上下文会被删除
[emoji]
-max_emoji_num = 40 # 表情包最大数量
-max_reach_deletion = true # 开启则在达到最大数量时删除表情包,关闭则达到最大数量时不删除,只是不会继续收集表情包
-check_interval = 10 # 检查表情包(注册,破损,删除)的时间间隔(分钟)
+max_reg_num = 40 # 表情包最大注册数量
+do_replace = true # 开启则在达到最大数量时删除(替换)表情包,关闭则达到最大数量时不会继续收集表情包
+check_interval = 120 # 检查表情包(注册,破损,删除)的时间间隔(分钟)
save_pic = false # 是否保存图片
-save_emoji = false # 是否保存表情包
+cache_emoji = true # 是否缓存表情包
steal_emoji = true # 是否偷取表情包,让麦麦可以发送她保存的这些表情包
-enable_check = false # 是否启用表情包过滤,只有符合该要求的表情包才会被保存
-check_prompt = "符合公序良俗" # 表情包过滤要求,只有符合该要求的表情包才会被保存
+content_filtration = false # 是否启用表情包过滤,只有符合该要求的表情包才会被保存
+filtration_prompt = "符合公序良俗" # 表情包过滤要求,只有符合该要求的表情包才会被保存
[memory]
-build_memory_interval = 2000 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多
-build_memory_distribution = [6.0,3.0,0.6,32.0,12.0,0.4] # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重
-build_memory_sample_num = 8 # 采样数量,数值越高记忆采样次数越多
-build_memory_sample_length = 40 # 采样长度,数值越高一段记忆内容越丰富
+memory_build_interval = 2000 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多
+memory_build_distribution = [6.0, 3.0, 0.6, 32.0, 12.0, 0.4] # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重
+memory_build_sample_num = 8 # 采样数量,数值越高记忆采样次数越多
+memory_build_sample_length = 40 # 采样长度,数值越高一段记忆内容越丰富
memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多
forget_memory_interval = 1000 # 记忆遗忘间隔 单位秒 间隔越低,麦麦遗忘越频繁,记忆更精简,但更难学习
-memory_forget_time = 24 #多长时间后的记忆会被遗忘 单位小时
+memory_forget_time = 24 #多长时间后的记忆会被遗忘 单位小时
memory_forget_percentage = 0.01 # 记忆遗忘比例 控制记忆遗忘程度 越大遗忘越多 建议保持默认
consolidate_memory_interval = 1000 # 记忆整合间隔 单位秒 间隔越低,麦麦整合越频繁,记忆更精简
@@ -135,54 +122,48 @@ consolidation_similarity_threshold = 0.7 # 相似度阈值
consolidation_check_percentage = 0.01 # 检查节点比例
#不希望记忆的词,已经记忆的不会受到影响
-memory_ban_words = [
- # "403","张三"
-]
+memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ]
[mood]
mood_update_interval = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate = 0.95 # 情绪衰减率
mood_intensity_factor = 1.0 # 情绪强度因子
-[keywords_reaction] # 针对某个关键词作出反应
+[keyword_reaction] # 针对某个关键词作出反应
enable = true # 关键词反应功能的总开关
-[[keywords_reaction.rules]] # 如果想要新增多个关键词,直接复制本条,修改keywords和reaction即可
+[[keyword_reaction.rules]] # 如果想要新增多个关键词,直接复制本条,修改keywords和reaction即可
enable = true # 是否启用此条(为了人类在未来AI战争能更好地识别AI(bushi),默认开启)
keywords = ["人机", "bot", "机器", "入机", "robot", "机器人","ai","AI"] # 会触发反应的关键词
reaction = "有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认" # 触发之后添加的提示词
-[[keywords_reaction.rules]] # 就像这样复制
+[[keyword_reaction.rules]] # 就像这样复制
enable = false # 仅作示例,不会触发
keywords = ["测试关键词回复","test",""]
reaction = "回答“测试成功”" # 修复错误的引号
-[[keywords_reaction.rules]] # 使用正则表达式匹配句式
+[[keyword_reaction.rules]] # 使用正则表达式匹配句式
enable = false # 仅作示例,不会触发
regex = ["^(?P\\S{1,20})是这样的$"] # 将匹配到的词汇命名为n,反应中对应的[n]会被替换为匹配到的内容,若不了解正则表达式请勿编写
reaction = "请按照以下模板造句:[n]是这样的,xx只要xx就可以,可是[n]要考虑的事情就很多了,比如什么时候xx,什么时候xx,什么时候xx。(请自由发挥替换xx部分,只需保持句式结构,同时表达一种将[n]过度重视的反讽意味)"
[chinese_typo]
enable = true # 是否启用中文错别字生成器
-error_rate=0.001 # 单字替换概率
+error_rate=0.01 # 单字替换概率
min_freq=9 # 最小字频阈值
tone_error_rate=0.1 # 声调错误概率
word_replace_rate=0.006 # 整词替换概率
[response_splitter]
-enable_response_splitter = true # 是否启用回复分割器
-response_max_length = 256 # 回复允许的最大长度
-response_max_sentence_num = 4 # 回复允许的最大句子数
+enable = true # 是否启用回复分割器
+max_length = 256 # 回复允许的最大长度
+max_sentence_num = 4 # 回复允许的最大句子数
enable_kaomoji_protection = false # 是否启用颜文字保护
-model_max_output_length = 256 # 模型单次返回的最大token数
-
-[remote] #发送统计信息,主要是看全球有多少只麦麦
+[telemetry] #发送统计信息,主要是看全球有多少只麦麦
enable = true
[experimental] #实验性功能
-enable_friend_chat = false # 是否启用好友聊天
-talk_allowed_private = [] # 可以回复消息的QQ号
pfc_chatting = false # 是否启用PFC聊天,该功能仅作用于私聊,与回复模式独立
#下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env自定义的宏,使用自定义模型则选择定位相似的模型自己填写
@@ -194,14 +175,17 @@ pfc_chatting = false # 是否启用PFC聊天,该功能仅作用于私聊,与
# stream = : 用于指定模型是否是使用流式输出
# 如果不指定,则该项是 False
+[model]
+model_max_output_length = 800 # 模型单次返回的最大token数
+
#这个模型必须是推理模型
-[model.llm_reasoning] # 一般聊天模式的推理回复模型
+[model.reasoning] # 一般聊天模式的推理回复模型
name = "Pro/deepseek-ai/DeepSeek-R1"
provider = "SILICONFLOW"
pri_in = 1.0 #模型的输入价格(非必填,可以记录消耗)
pri_out = 4.0 #模型的输出价格(非必填,可以记录消耗)
-[model.llm_normal] #V3 回复模型 专注和一般聊天模式共用的回复模型
+[model.normal] #V3 回复模型 专注和一般聊天模式共用的回复模型
name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW"
pri_in = 2 #模型的输入价格(非必填,可以记录消耗)
@@ -209,13 +193,13 @@ pri_out = 8 #模型的输出价格(非必填,可以记录消耗)
#默认temp 0.2 如果你使用的是老V3或者其他模型,请自己修改temp参数
temp = 0.2 #模型的温度,新V3建议0.1-0.3
-[model.llm_topic_judge] #主题判断模型:建议使用qwen2.5 7b
+[model.topic_judge] #主题判断模型:建议使用qwen2.5 7b
name = "Pro/Qwen/Qwen2.5-7B-Instruct"
provider = "SILICONFLOW"
pri_in = 0.35
pri_out = 0.35
-[model.llm_summary] #概括模型,建议使用qwen2.5 32b 及以上
+[model.summary] #概括模型,建议使用qwen2.5 32b 及以上
name = "Qwen/Qwen2.5-32B-Instruct"
provider = "SILICONFLOW"
pri_in = 1.26
@@ -227,27 +211,27 @@ provider = "SILICONFLOW"
pri_in = 0.35
pri_out = 0.35
-[model.llm_heartflow] # 用于控制麦麦是否参与聊天的模型
+[model.heartflow] # 用于控制麦麦是否参与聊天的模型
name = "Qwen/Qwen2.5-32B-Instruct"
provider = "SILICONFLOW"
pri_in = 1.26
pri_out = 1.26
-[model.llm_observation] #观察模型,压缩聊天内容,建议用免费的
+[model.observation] #观察模型,压缩聊天内容,建议用免费的
# name = "Pro/Qwen/Qwen2.5-7B-Instruct"
name = "Qwen/Qwen2.5-7B-Instruct"
provider = "SILICONFLOW"
pri_in = 0
pri_out = 0
-[model.llm_sub_heartflow] #心流:认真水群时,生成麦麦的内心想法,必须使用具有工具调用能力的模型
+[model.sub_heartflow] #心流:认真聊天时,生成麦麦的内心想法,必须使用具有工具调用能力的模型
name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW"
pri_in = 2
pri_out = 8
temp = 0.3 #模型的温度,新V3建议0.1-0.3
-[model.llm_plan] #决策:认真水群时,负责决定麦麦该做什么
+[model.plan] #决策:认真聊天时,负责决定麦麦该做什么
name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW"
pri_in = 2
@@ -265,7 +249,7 @@ pri_out = 0
#私聊PFC:需要开启PFC功能,默认三个模型均为硅基流动v3,如果需要支持多人同时私聊或频繁调用,建议把其中的一个或两个换成官方v3或其它模型,以免撞到429
#PFC决策模型
-[model.llm_PFC_action_planner]
+[model.pfc_action_planner]
name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW"
temp = 0.3
@@ -273,7 +257,7 @@ pri_in = 2
pri_out = 8
#PFC聊天模型
-[model.llm_PFC_chat]
+[model.pfc_chat]
name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW"
temp = 0.3
@@ -281,7 +265,7 @@ pri_in = 2
pri_out = 8
#PFC检查模型
-[model.llm_PFC_reply_checker]
+[model.pfc_reply_checker]
name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW"
pri_in = 2
@@ -294,7 +278,7 @@ pri_out = 8
#以下模型暂时没有使用!!
#以下模型暂时没有使用!!
-[model.llm_tool_use] #工具调用模型,需要使用支持工具调用的模型,建议使用qwen2.5 32b
+[model.tool_use] #工具调用模型,需要使用支持工具调用的模型,建议使用qwen2.5 32b
name = "Qwen/Qwen2.5-32B-Instruct"
provider = "SILICONFLOW"
pri_in = 1.26
diff --git a/tests/common/test_message_repository.py b/tests/common/test_message_repository.py
new file mode 100644
index 000000000..798fa16b1
--- /dev/null
+++ b/tests/common/test_message_repository.py
@@ -0,0 +1,172 @@
+import unittest
+from unittest.mock import patch, MagicMock
+import datetime
+import sys
+import os
+
+# 添加项目根目录到Python路径
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
+
+from peewee import SqliteDatabase
+from src.common.database.database_model import Messages, BaseModel
+from src.common.message_repository import find_messages
+
+
+class TestMessageRepository(unittest.TestCase):
+ def setUp(self):
+ # 创建内存中的SQLite数据库用于测试
+ self.test_db = SqliteDatabase(":memory:")
+
+ # 覆盖原有数据库连接
+ BaseModel._meta.database = self.test_db
+ Messages._meta.database = self.test_db
+
+ # 创建表
+ self.test_db.create_tables([Messages])
+
+ # 添加测试数据
+ current_time = datetime.datetime.now().timestamp()
+ self.test_messages = [
+ {
+ "message_id": "msg1",
+ "time": current_time - 3600, # 1小时前
+ "chat_id": "5ed68437e28644da51f314f37df68d18",
+ "chat_info_stream_id": "stream1",
+ "chat_info_platform": "qq",
+ "chat_info_user_platform": "qq",
+ "chat_info_user_id": "user1",
+ "chat_info_user_nickname": "用户1",
+ "chat_info_user_cardname": "卡片名1",
+ "chat_info_group_platform": "qq",
+ "chat_info_group_id": "group1",
+ "chat_info_group_name": "群组1",
+ "chat_info_create_time": current_time - 7200, # 2小时前
+ "chat_info_last_active_time": current_time - 1800, # 30分钟前
+ "user_platform": "qq",
+ "user_id": "user1",
+ "user_nickname": "用户1",
+ "user_cardname": "卡片名1",
+ "processed_plain_text": "你好",
+ "detailed_plain_text": "你好",
+ "memorized_times": 1,
+ },
+ {
+ "message_id": "msg2",
+ "time": current_time - 1800, # 30分钟前
+ "chat_id": "chat1",
+ "chat_info_stream_id": "stream1",
+ "chat_info_platform": "qq",
+ "chat_info_user_platform": "qq",
+ "chat_info_user_id": "user1",
+ "chat_info_user_nickname": "用户1",
+ "chat_info_user_cardname": "卡片名1",
+ "chat_info_group_platform": "qq",
+ "chat_info_group_id": "group1",
+ "chat_info_group_name": "群组1",
+ "chat_info_create_time": current_time - 7200,
+ "chat_info_last_active_time": current_time - 900, # 15分钟前
+ "user_platform": "qq",
+ "user_id": "user1",
+ "user_nickname": "用户1",
+ "user_cardname": "卡片名1",
+ "processed_plain_text": "世界",
+ "detailed_plain_text": "世界",
+ "memorized_times": 2,
+ },
+ {
+ "message_id": "msg3",
+ "time": current_time - 900, # 15分钟前
+ "chat_id": "chat2",
+ "chat_info_stream_id": "stream2",
+ "chat_info_platform": "wechat",
+ "chat_info_user_platform": "wechat",
+ "chat_info_user_id": "user2",
+ "chat_info_user_nickname": "用户2",
+ "chat_info_user_cardname": "卡片名2",
+ "chat_info_group_platform": "wechat",
+ "chat_info_group_id": "group2",
+ "chat_info_group_name": "群组2",
+ "chat_info_create_time": current_time - 3600,
+ "chat_info_last_active_time": current_time - 600, # 10分钟前
+ "user_platform": "wechat",
+ "user_id": "user2",
+ "user_nickname": "用户2",
+ "user_cardname": "卡片名2",
+ "processed_plain_text": "测试",
+ "detailed_plain_text": "测试",
+ "memorized_times": 0,
+ },
+ ]
+
+ for msg_data in self.test_messages:
+ Messages.create(**msg_data)
+
+ def tearDown(self):
+ # 关闭测试数据库连接
+ self.test_db.close()
+
+ def test_find_messages_no_filter(self):
+ """测试不带过滤器的查询"""
+ results = find_messages({})
+ self.assertEqual(len(results), 3)
+ # 验证结果是否按时间升序排列
+ self.assertEqual(results[0]["message_id"], "msg1")
+ self.assertEqual(results[1]["message_id"], "msg2")
+ self.assertEqual(results[2]["message_id"], "msg3")
+
+ def test_find_messages_with_filter(self):
+ """测试带过滤器的查询"""
+ results = find_messages({"chat_id": "chat1"})
+ self.assertEqual(len(results), 2)
+ self.assertEqual(results[0]["message_id"], "msg1")
+ self.assertEqual(results[1]["message_id"], "msg2")
+
+ results = find_messages({"user_id": "user2"})
+ self.assertEqual(len(results), 1)
+ self.assertEqual(results[0]["message_id"], "msg3")
+
+ def test_find_messages_with_operators(self):
+ """测试带操作符的查询"""
+ results = find_messages({"memorized_times": {"$gt": 0}})
+ self.assertEqual(len(results), 2)
+ self.assertEqual(results[0]["message_id"], "msg1")
+ self.assertEqual(results[1]["message_id"], "msg2")
+
+ results = find_messages({"memorized_times": {"$gte": 2}})
+ self.assertEqual(len(results), 1)
+ self.assertEqual(results[0]["message_id"], "msg2")
+
+ def test_find_messages_with_sort(self):
+ """测试带排序的查询"""
+ results = find_messages({}, sort=[("memorized_times", -1)])
+ self.assertEqual(len(results), 3)
+ # 验证结果是否按memorized_times降序排列
+ self.assertEqual(results[0]["message_id"], "msg2") # memorized_times = 2
+ self.assertEqual(results[1]["message_id"], "msg1") # memorized_times = 1
+ self.assertEqual(results[2]["message_id"], "msg3") # memorized_times = 0
+
+ def test_find_messages_with_limit(self):
+ """测试带限制的查询"""
+ # 默认limit_mode为latest,应返回最新的2条记录
+ results = find_messages({}, limit=2)
+ self.assertEqual(len(results), 2)
+ self.assertEqual(results[0]["message_id"], "msg2")
+ self.assertEqual(results[1]["message_id"], "msg3")
+
+ # 使用earliest模式,应返回最早的2条记录
+ results = find_messages({}, limit=2, limit_mode="earliest")
+ self.assertEqual(len(results), 2)
+ self.assertEqual(results[0]["message_id"], "msg1")
+ self.assertEqual(results[1]["message_id"], "msg2")
+
+ def test_find_messages_with_combined_criteria(self):
+ """测试组合查询条件"""
+ results = find_messages(
+ {"chat_info_platform": "qq", "memorized_times": {"$gt": 0}}, sort=[("time", 1)], limit=1
+ )
+ self.assertEqual(len(results), 1)
+ self.assertEqual(results[0]["message_id"], "msg2")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_build_readable_messages.py b/tests/test_build_readable_messages.py
new file mode 100644
index 000000000..71d91a46d
--- /dev/null
+++ b/tests/test_build_readable_messages.py
@@ -0,0 +1,173 @@
+import unittest
+import sys
+import os
+import datetime
+import time
+import asyncio
+import traceback
+import json
+import copy
+
+# 添加项目根目录到Python路径
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat, build_readable_messages
+from src.common.logger import get_module_logger
+
+# 创建测试日志记录器
+logger = get_module_logger("test_readable_msg")
+
+
+class TestBuildReadableMessages(unittest.TestCase):
+ def setUp(self):
+ # 准备测试数据:从真实数据库获取消息
+ self.chat_id = "5ed68437e28644da51f314f37df68d18"
+ self.current_time = time.time()
+ self.thirty_days_ago = self.current_time - (30 * 24 * 60 * 60) # 30天前的时间戳
+
+ # 获取最新的10条消息
+ try:
+ self.messages = get_raw_msg_by_timestamp_with_chat(
+ chat_id=self.chat_id,
+ timestamp_start=self.thirty_days_ago,
+ timestamp_end=self.current_time,
+ limit=10,
+ limit_mode="latest",
+ )
+ logger.info(f"已获取 {len(self.messages)} 条测试消息")
+
+ # 打印消息样例
+ if self.messages:
+ sample_msg = self.messages[0]
+ logger.info(f"消息样例: {list(sample_msg.keys())}")
+ logger.info(f"消息内容: {sample_msg.get('processed_plain_text', '无文本内容')[:50]}...")
+ except Exception as e:
+ logger.error(f"获取消息失败: {e}")
+ logger.error(traceback.format_exc())
+ self.messages = []
+
+ def test_manual_fix_messages(self):
+ """创建一个手动修复版本的消息进行测试"""
+ if not self.messages:
+ self.skipTest("没有测试消息,跳过测试")
+ return
+
+ logger.info("开始手动修复消息...")
+
+ # 创建修复版本的消息列表
+ fixed_messages = []
+
+ for msg in self.messages:
+ # 深拷贝以避免修改原始数据
+ fixed_msg = copy.deepcopy(msg)
+
+ # 构建 user_info 对象
+ if "user_info" not in fixed_msg:
+ user_info = {
+ "platform": fixed_msg.get("user_platform", "qq"),
+ "user_id": fixed_msg.get("user_id", "10000"),
+ "user_nickname": fixed_msg.get("user_nickname", "测试用户"),
+ "user_cardname": fixed_msg.get("user_cardname", ""),
+ }
+ fixed_msg["user_info"] = user_info
+ logger.info(f"为消息 {fixed_msg.get('message_id')} 添加了 user_info")
+
+ fixed_messages.append(fixed_msg)
+
+ logger.info(f"已修复 {len(fixed_messages)} 条消息")
+
+ try:
+ # 使用修复后的消息尝试格式化
+ formatted_text = asyncio.run(
+ build_readable_messages(
+ messages=fixed_messages,
+ replace_bot_name=True,
+ merge_messages=False,
+ timestamp_mode="absolute",
+ read_mark=0.0,
+ truncate=False,
+ )
+ )
+
+ logger.info("使用修复后的消息格式化完成")
+ logger.info(f"格式化结果长度: {len(formatted_text)}")
+ if formatted_text:
+ logger.info(f"格式化结果预览: {formatted_text[:200]}...")
+ else:
+ logger.warning("格式化结果为空")
+
+ # 断言
+ self.assertNotEqual(formatted_text, "", "有消息时不应返回空字符串")
+ except Exception as e:
+ logger.error(f"使用修复后的消息格式化失败: {e}")
+ logger.error(traceback.format_exc())
+ raise
+
+ def test_debug_build_messages_internal(self):
+ """调试_build_readable_messages_internal函数"""
+ if not self.messages:
+ self.skipTest("没有测试消息,跳过测试")
+ return
+
+ logger.info("开始调试内部构建函数...")
+
+ try:
+ # 直接导入内部函数进行测试
+ from src.chat.utils.chat_message_builder import _build_readable_messages_internal
+
+ # 手动创建一个简单的测试消息列表
+ test_msg = self.messages[0].copy() # 使用第一条消息作为模板
+
+ # 检查消息结构
+ logger.info(f"测试消息keys: {list(test_msg.keys())}")
+ logger.info(f"user_info存在: {'user_info' in test_msg}")
+
+ # 修复缺少的user_info字段
+ if "user_info" not in test_msg:
+ logger.warning("消息中缺少user_info字段,添加模拟数据")
+ test_msg["user_info"] = {
+ "platform": test_msg.get("user_platform", "qq"),
+ "user_id": test_msg.get("user_id", "10000"),
+ "user_nickname": test_msg.get("user_nickname", "测试用户"),
+ "user_cardname": test_msg.get("user_cardname", ""),
+ }
+ logger.info(f"添加的user_info: {test_msg['user_info']}")
+
+ simple_msgs = [test_msg]
+
+ # 运行内部函数
+ result_text, result_details = asyncio.run(
+ _build_readable_messages_internal(
+ simple_msgs, replace_bot_name=True, merge_messages=False, timestamp_mode="absolute", truncate=False
+ )
+ )
+
+ logger.info(f"内部函数返回结果: {result_text[:200] if result_text else '空'}")
+ logger.info(f"详情列表长度: {len(result_details)}")
+
+ # 显示处理过程中的变量
+ if not result_text and len(simple_msgs) > 0:
+ logger.warning("消息处理可能有问题,检查关键步骤")
+ msg = simple_msgs[0]
+
+ # 打印关键变量的值
+ user_info = msg.get("user_info", {})
+ platform = user_info.get("platform")
+ user_id = user_info.get("user_id")
+ timestamp = msg.get("time")
+ content = msg.get("processed_plain_text", "")
+
+ logger.warning(f"平台: {platform}, 用户ID: {user_id}, 时间戳: {timestamp}")
+ logger.warning(f"内容: {content[:50]}...")
+
+ # 检查必要信息是否完整
+ logger.warning(f"必要信息完整性检查: {all([platform, user_id, timestamp is not None])}")
+
+ except Exception as e:
+ logger.error(f"调试内部函数失败: {e}")
+ logger.error(traceback.format_exc())
+ raise
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 000000000..1a1239601
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,7 @@
+from src.config.config import global_config
+
+
+class TestConfig:
+ def test_load(self):
+ config = global_config
+ print(config)
diff --git a/tests/test_extract_messages.py b/tests/test_extract_messages.py
new file mode 100644
index 000000000..95ddb523f
--- /dev/null
+++ b/tests/test_extract_messages.py
@@ -0,0 +1,83 @@
+import unittest
+import sys
+import os
+import datetime
+import time
+
+# 添加项目根目录到Python路径
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+from src.common.message_repository import find_messages
+from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
+from peewee import SqliteDatabase
+from src.common.database.database import db # 导入实际的数据库连接
+
+
+class TestExtractMessages(unittest.TestCase):
+ def setUp(self):
+ # 这个测试使用真实的数据库,所以不需要创建测试数据
+ pass
+
+ def test_extract_latest_messages_direct(self):
+ """测试直接使用message_repository.find_messages函数"""
+ chat_id = "5ed68437e28644da51f314f37df68d18"
+
+ # 提取最新的10条消息
+ results = find_messages({"chat_id": chat_id}, limit=10)
+
+ # 打印结果数量
+ print(f"\n直接使用find_messages,找到 {len(results)} 条消息")
+
+ # 如果有结果,打印一些信息
+ if results:
+ print("\n消息时间顺序:")
+ for idx, msg in enumerate(results):
+ msg_time = datetime.datetime.fromtimestamp(msg["time"]).strftime("%Y-%m-%d %H:%M:%S")
+ print(f"{idx + 1}. ID: {msg['message_id']}, 时间: {msg_time}")
+ print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...")
+
+ # 验证结果按时间排序
+ times = [msg["time"] for msg in results]
+ self.assertEqual(times, sorted(times), "消息应该按时间升序排列")
+ else:
+ print(f"未找到chat_id为 {chat_id} 的消息")
+
+ # 最基本的断言,确保测试有效
+ self.assertIsInstance(results, list, "结果应该是一个列表")
+
+ def test_extract_latest_messages_via_builder(self):
+ """使用chat_message_builder中的函数测试从真实数据库提取消息"""
+ chat_id = "5ed68437e28644da51f314f37df68d18"
+
+ # 设置时间范围为过去30天到现在
+ current_time = time.time()
+ thirty_days_ago = current_time - (30 * 24 * 60 * 60) # 30天前的时间戳
+
+ # 使用chat_message_builder中的函数
+ results = get_raw_msg_by_timestamp_with_chat(
+ chat_id=chat_id, timestamp_start=thirty_days_ago, timestamp_end=current_time, limit=10, limit_mode="latest"
+ )
+
+ # 打印结果数量
+ print(f"\n使用get_raw_msg_by_timestamp_with_chat,找到 {len(results)} 条消息")
+
+ # 如果有结果,打印一些信息
+ if results:
+ print("\n消息时间顺序:")
+ for idx, msg in enumerate(results):
+ msg_time = datetime.datetime.fromtimestamp(msg["time"]).strftime("%Y-%m-%d %H:%M:%S")
+ print(f"{idx + 1}. ID: {msg['message_id']}, 时间: {msg_time}")
+ print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...")
+
+ # 验证结果按时间排序
+ times = [msg["time"] for msg in results]
+ self.assertEqual(times, sorted(times), "消息应该按时间升序排列")
+ else:
+ print(f"未找到chat_id为 {chat_id} 的消息")
+
+ # 最基本的断言,确保测试有效
+ self.assertIsInstance(results, list, "结果应该是一个列表")
+
+
+if __name__ == "__main__":
+ unittest.main()