This commit is contained in:
tcmofashi
2025-05-18 21:17:22 +08:00
114 changed files with 6207 additions and 4079 deletions

2
.gitignore vendored
View File

@@ -301,3 +301,5 @@ $RECYCLE.BIN/
# Windows shortcuts
*.lnk
src/chat/focus_chat/working_memory/test/test1.txt
src/chat/focus_chat/working_memory/test/test4.txt

237
README.md
View File

@@ -1,177 +1,95 @@
# 麦麦MaiCore-MaiMBot (编辑中)
<br />
<div style="text-align: center">
<picture>
<source media="(max-width: 600px)" srcset="depends-data/maimai.png" width="100%">
<img alt="MaiBot" src="depends-data/maimai.png" title="作者:略nd" align="right" width="30%">
</picture>
![Python Version](https://img.shields.io/badge/Python-3.10+-blue)
![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议)
![Status](https://img.shields.io/badge/状态-开发中-yellow)
![Contributors](https://img.shields.io/github/contributors/MaiM-with-u/MaiBot.svg?style=flat&label=贡献者)
![forks](https://img.shields.io/github/forks/MaiM-with-u/MaiBot.svg?style=flat&label=分支数)
![stars](https://img.shields.io/github/stars/MaiM-with-u/MaiBot?style=flat&label=星标数)
![issues](https://img.shields.io/github/issues/MaiM-with-u/MaiBot)
# 麦麦MaiCore-MaiBot (编辑中)
</div>
![Python Version](https://img.shields.io/badge/Python-3.10+-blue)
![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议)
![Status](https://img.shields.io/badge/状态-开发中-yellow)
![Contributors](https://img.shields.io/github/contributors/MaiM-with-u/MaiBot.svg?style=flat&label=贡献者)
![forks](https://img.shields.io/github/forks/MaiM-with-u/MaiBot.svg?style=flat&label=分支数)
![stars](https://img.shields.io/github/stars/MaiM-with-u/MaiBot?style=flat&label=星标数)
![issues](https://img.shields.io/github/issues/MaiM-with-u/MaiBot)
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/DrSmoothl/MaiBot)
<p style="text-align: center">
<a href="https://github.com/MaiM-with-u/MaiBot/">
<img src="depends-data/maimai.png" alt="Logo" style="width: 200px">
</a>
<br />
<a href="https://space.bilibili.com/1344099355">
画师略nd
</a>
<strong>
<a href="https://www.bilibili.com/video/BV1amAneGE3P">🌟 演示视频</a> |
<a href="#-更新和安装">🚀 快速入门</a> |
<a href="#-文档">📃 教程</a> |
<a href="#-讨论">💬 讨论</a> |
<a href="#-贡献和致谢">🙋 贡献指南</a>
</strong>
<h3 style="text-align: center">MaiBot(麦麦)</h3>
<p style="text-align: center">
一款专注于<strong> 群组聊天 </strong>的赛博网友
<br />
<a href="https://docs.mai-mai.org"><strong>探索本项目的文档 »</strong></a>
<br />
<br />
<!-- <a href="https://github.com/shaojintian/Best_README_template">查看Demo</a>
· -->
<a href="https://github.com/MaiM-with-u/MaiBot/issues">报告Bug</a>
·
<a href="https://github.com/MaiM-with-u/MaiBot/issues">提出新特性</a>
</p>
</p>
## 🎉 介绍
## 新版0.6.x部署前先阅读https://docs.mai-mai.org/faq/maibot/backup_update.html
**🍔MaiCore 是一个基于大语言模型的可交互智能体**
## 📝 项目简介
**🍔MaiCore是一个基于大语言模型的可交互智能体**
- 💭 **智能对话系统**基于LLM的自然语言交互
- 🤔 **实时思维系统**:模拟人类思考过程
- 💝 **情感表达系统**:丰富的表情包和情绪表达
- 🧠 **持久记忆系统**基于MongoDB的长期记忆存储
- 🔄 **动态人格系统**:自适应的性格特征
- 💭 **智能对话系统**:基于 LLM 的自然语言交互。
- 🤔 **实时思维系统**:模拟人类思考过程。
- 💝 **情感表达系统**:丰富的表情包和情绪表达。
- 🧠 **持久记忆系统**:基于 MongoDB 的长期记忆存储。
- 🔄 **动态人格系统**:自适应的性格特征。
<div style="text-align: center">
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
<img src="depends-data/video.png" style="max-width: 200px" alt="麦麦演示视频">
<br>
<picture>
<source media="(max-width: 600px)" srcset="depends-data/video.png" width="100%">
<img src="depends-data/video.png" width="30%" alt="麦麦演示视频">
</picture>
<br />
👆 点击观看麦麦演示视频 👆
</a>
</div>
## 🔥 更新和安装
### 📢 版本信息
**最新版本: v0.6.3** ([查看更新日志](changelogs/changelog.md))
> [!WARNING]
> 请阅读教程后更新!!!!!!!
> 请阅读教程后更新!!!!!!!
> 请阅读教程后更新!!!!!!!
> 次版本MaiBot将基于MaiCore运行不再依赖于nonebot相关组件运行。
> MaiBot将通过nonebot的插件与nonebot建立联系然后nonebot与QQ建立联系实现MaiBot与QQ的交互
**分支说明:**
- `main`: 稳定发布版本
- `dev`: 开发测试版本(不知道什么意思就别下)
- `classical`: 0.6.0之前的版本
**最新版本: v0.6.3** ([更新日志](changelogs/changelog.md))
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
**GitHub 分支说明:**
- `main`: 稳定发布版本(推荐)
- `dev`: 开发测试版本(不稳定)
- `classical`: 旧版本(停止维护)
### 最新版本部署教程 (MaiCore 版本)
- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容)
> [!WARNING]
> - 项目处于活跃开发阶段,代码可能随时更改
> - 文档未完善,有问题可以提交 Issue 或者 Discussion
> - QQ机器人存在被限制风险请自行了解谨慎使用
> - 由于持续迭代可能存在一些已知或未知的bug
> - 由于开发中可能消耗较多token
### ⚠️ 重要提示
- 升级到v0.6.x版本前请务必阅读[升级指南](https://docs.mai-mai.org/faq/maibot/backup_update.html)
- 本版本基于MaiCore重构通过nonebot插件与QQ平台交互
- 项目处于活跃开发阶段功能和API可能随时调整
### 💬交流群(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) 766798517
- [二群](https://qm.qq.com/q/RzmCiRtHEW) 571780722
- [五群](https://qm.qq.com/q/JxvHZnxyec) 1022489779
- [三群](https://qm.qq.com/q/wlH5eT8OmQ) 1035228475【已满】
- [四群](https://qm.qq.com/q/wGePTl1UyY) 729957033【已满】
> - 从 0.5.x 旧版本升级前请务必阅读:[升级指南](https://docs.mai-mai.org/faq/maibot/backup_update.html)
> - 项目处于活跃开发阶段,功能和 API 可能随时调整。
> - 文档未完善,有问题可以提交 Issue 或者 Discussion。
> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
> - 由于持续迭代,可能存在一些已知或未知的 bug。
> - 由于程序处于开发中,可能消耗较多 token。
## 💬 讨论
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) |
[二群](https://qm.qq.com/q/RzmCiRtHEW) |
[五群](https://qm.qq.com/q/JxvHZnxyec) |
[三群](https://qm.qq.com/q/wlH5eT8OmQ)(已满)|
[四群](https://qm.qq.com/q/wGePTl1UyY)(已满)
## 📚 文档
**部分内容可能更新不够及时,请注意版本对应**
### (部分内容可能过时,请注意版本对应)
- [📚 核心 Wiki 文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切。
### 核心文档
- [📚 核心Wiki文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切
### 最新版本部署教程(MaiCore版本)
- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于MaiCore的新版本部署方式与旧版本不兼容
## 🎯 0.6.3 功能介绍
| 模块 | 主要功能 | 特点 |
|----------|------------------------------------------------------------------|-------|
| 💬 聊天系统 | • **统一调控不同回复逻辑**<br>• 智能交互模式 (普通聊天/专注聊天)<br>• 关键词主动发言<br>• 多模型支持<br>• 动态prompt构建<br>• 私聊功能(PFC)增强 | 拟人化交互 |
| 🧠 心流系统 | • 实时思考生成<br>**智能状态管理**<br>**概率回复机制**<br>• 自动启停机制<br>• 日程系统联动<br>**上下文感知工具调用** | 智能化决策 |
| 🧠 记忆系统 | • **记忆整合与提取**<br>• 海马体记忆机制<br>• 聊天记录概括 | 持久化记忆 |
| 😊 表情系统 | • **全新表情包系统**<br>**优化选择逻辑**<br>• 情绪匹配发送<br>• GIF支持<br>• 自动收集与审查 | 丰富表达 |
| 📅 日程系统 | • 动态日程生成<br>• 自定义想象力<br>• 思维流联动 | 智能规划 |
| 👥 关系系统 | • **工具调用动态更新**<br>• 关系管理优化<br>• 丰富接口支持<br>• 个性化交互 | 深度社交 |
| 📊 统计系统 | • 使用数据统计<br>• LLM调用记录<br>• 实时控制台显示 | 数据可视 |
| 🛠️ 工具系统 | • **LPMM知识库集成**<br>**上下文感知调用**<br>• 知识获取工具<br>• 自动注册机制<br>• 多工具支持 | 扩展功能 |
| 📚 **知识库(LPMM)** | • **全新LPMM系统**<br>**强大的信息检索能力** | 知识增强 |
| ✨ **昵称系统** | • **自动为群友取昵称**<br>**降低认错人概率** (早期阶段) | 身份识别 |
## 📐 项目架构
```mermaid
graph TD
A[MaiCore] --> B[对话系统]
A --> C[心流系统]
A --> D[记忆系统]
A --> E[情感系统]
B --> F[多模型支持]
B --> G[动态Prompt]
C --> H[实时思考]
C --> I[日程联动]
D --> J[记忆存储]
D --> K[记忆检索]
E --> L[表情管理]
E --> M[情绪识别]
```
## ✍如何给本项目报告BUG/提交建议/做贡献
MaiCore是一个开源项目我们非常欢迎你的参与。你的贡献无论是提交bug报告、功能需求还是代码pr都对项目非常宝贵。我们非常感谢你的支持🎉 但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](depends-data/CONTRIBUTE.md)(待补完)
## 设计理念(原始时代的火花)
### 设计理念(原始时代的火花)
> **千石可乐说:**
> - 这个项目最初只是为了给牛牛bot添加一点额外的功能但是功能越写越多最后决定重写。其目的是为了创造一个活跃在QQ群聊的"生命体"。可以目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
> - 程序的功能设计理念基于一个核心的原则:"最像而不是好"
> - 如果人类真的需要一个AI来陪伴自己并不是所有人都需要一个完美的能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。
> - 代码会保持开源和开放但个人希望MaiMbot的运行时数据保持封闭尽量避免以显式命令来对其进行控制和调试.我认为一个你无法完全掌控的个体才更能让你感觉到它的自主性,而视其成为一个对话机器.
> - SengokuCola~~纯编程外行面向cursor编程很多代码写得不好多多包涵~~已得到大脑升级
> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
> - 程序的功能设计理念基于一个核心的原则:"最像而不是好"
> - 如果人类真的需要一个 AI 来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。
> - 代码会保持开源和开放,但个人希望 MaiMbot 的运行时数据保持封闭,尽量避免以显式命令来对其进行控制和调试我认为一个你无法完全掌控的个体才更能让你感觉到它的自主性,而视其成为一个对话机器
> - SengokuCola~~纯编程外行,面向 cursor 编程,很多代码写得不好多多包涵~~已得到大脑升级
## 📌 注意事项
> [!WARNING]
> 使用本项目前必须阅读和同意用户协议和隐私协议
> 本应用生成内容来自人工智能模型,由 AI 生成请仔细甄别请勿用于违反法律的用途AI生成内容不代表本人观点和立场。
## 致谢
- [NapCat](https://github.com/NapNeko/NapCatQQ): 现代化的基于 NTQQ 的 Bot 协议端实现
## 麦麦仓库状态
![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "Repobeats analytics image")
## 🙋 贡献和致谢
你可以阅读[开发文档](https://docs.mai-mai.org/develop/)来更好的了解麦麦!
MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr都对项目非常宝贵。我们非常感谢你的支持🎉
但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)。(待补完)
### 贡献者
@@ -181,8 +99,27 @@ MaiCore是一个开源项目我们非常欢迎你的参与。你的贡献
<img alt="contributors" src="https://contrib.rocks/image?repo=MaiM-with-u/MaiBot" />
</a>
**也感谢每一位给麦麦发展提出宝贵意见与建议的用户,感谢陪伴麦麦走到现在的你们**
### 致谢
## Stargazers over time
- [略nd](https://space.bilibili.com/1344099355): 为麦麦绘制人设。
- [NapCat](https://github.com/NapNeko/NapCatQQ): 现代化的基于 NTQQ 的 Bot 协议端实现。
[![Stargazers over time](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot)
**也感谢每一位给麦麦发展提出宝贵意见与建议的用户,感谢陪伴麦麦走到现在的你们!**
## 📌 注意事项
> [!WARNING]
> 使用本项目前必须阅读和同意[用户协议](EULA.md)和[隐私协议](PRIVACY.md)。
> 本应用生成内容来自人工智能模型,由 AI 生成请仔细甄别请勿用于违反法律的用途AI 生成内容不代表本项目团队的观点和立场。
## 麦麦仓库状态
![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "麦麦仓库状态")
### Star 趋势
[![Star 趋势](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot)
## License
GPL-3.0

View File

@@ -1,27 +0,0 @@
这里放置了测试版本的细节更新
## [test-0.6.1-snapshot-1] - 2025-4-5
- 修复pfc回复出错bug
- 修复表情包打字时间,不会卡表情包
- 改进了知识库的提取
- 提供了新的数据库连接方式
- 修复了ban_user无效的问题
## [test-0.6.0-snapshot-9] - 2025-4-4
- 可以识别gif表情包
## [test-0.6.0-snapshot-8] - 2025-4-3
- 修复了表情包的注册,获取和发送逻辑
- 表情包增加存储上限
- 更改了回复引用的逻辑,从基于时间改为基于新消息
- 增加了调试信息
- 自动清理缓存图片
- 修复并重启了关系系统
## [test-0.6.0-snapshot-7] - 2025-4-2
- 修改版本号命名test-前缀为测试版,无前缀为正式版
- 提供私聊的PFC模式可以进行有目的自由多轮对话
## [0.6.0-mmc-4] - 2025-4-1
- 提供两种聊天逻辑思维流聊天ThinkFlowChat 和 推理聊天ReasoningChat
- 从结构上可支持多种回复消息逻辑

View File

@@ -149,7 +149,7 @@ c HeartFChatting工作方式
- **状态及含义**:
- `ChatState.ABSENT` (不参与/没在看): 初始或停用状态。子心流不观察新信息,不进行思考,也不回复。
- `ChatState.CHAT` (随便看看/水群): 普通聊天模式。激活 `NormalChatInstance`
* `ChatState.FOCUSED` (专注/认真水群): 专注聊天模式。激活 `HeartFlowChatInstance`
* `ChatState.FOCUSED` (专注/认真聊天): 专注聊天模式。激活 `HeartFlowChatInstance`
- **选择**: 子心流可以根据外部指令(来自 `SubHeartflowManager`)或内部逻辑(未来的扩展)选择进入 `ABSENT` 状态(不回复不观察),或进入 `CHAT` / `FOCUSED` 中的一种回复模式。
- **状态转换机制** (由 `SubHeartflowManager` 驱动,更细致的说明):
- **初始状态**: 新创建的 `SubHeartflow` 默认为 `ABSENT` 状态。

Binary file not shown.

View File

@@ -41,7 +41,7 @@ class APIBotConfig:
allow_focus_mode: bool # 是否允许专注聊天状态
base_normal_chat_num: int # 最多允许多少个群进行普通聊天
base_focused_chat_num: int # 最多允许多少个群进行专注聊天
observation_context_size: int # 观察到的最长上下文大小
chat.observation_context_size: int # 观察到的最长上下文大小
message_buffer: bool # 是否启用消息缓冲
ban_words: List[str] # 禁止词列表
ban_msgs_regex: List[str] # 禁止消息的正则表达式列表
@@ -128,7 +128,7 @@ class APIBotConfig:
llm_reasoning: Dict[str, Any] # 推理模型配置
llm_normal: Dict[str, Any] # 普通模型配置
llm_topic_judge: Dict[str, Any] # 主题判断模型配置
llm_summary: Dict[str, Any] # 总结模型配置
model.summary: Dict[str, Any] # 总结模型配置
vlm: Dict[str, Any] # VLM模型配置
llm_heartflow: Dict[str, Any] # 心流模型配置
llm_observation: Dict[str, Any] # 观察模型配置
@@ -203,7 +203,7 @@ class APIBotConfig:
"llm_reasoning",
"llm_normal",
"llm_topic_judge",
"llm_summary",
"model.summary",
"vlm",
"llm_heartflow",
"llm_observation",

View File

@@ -1,6 +1,6 @@
from fastapi import HTTPException
from rich.traceback import install
from src.config.config import BotConfig
from src.config.config import Config
from src.common.logger_manager import get_logger
import os
@@ -14,8 +14,8 @@ async def reload_config():
from src.config import config as config_module
logger.debug("正在重载配置文件...")
bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml")
config_module.global_config = BotConfig.load_config(config_path=bot_config_path)
bot_config_path = os.path.join(Config.get_config_dir(), "bot_config.toml")
config_module.global_config = Config.load_config(config_path=bot_config_path)
logger.debug("配置文件重载成功")
return {"status": "reloaded"}
except FileNotFoundError as e:

View File

@@ -5,12 +5,15 @@ import os
import random
import time
import traceback
from typing import Optional, Tuple
from typing import Optional, Tuple, List, Any
from PIL import Image
import io
import re
from ...common.database import db
# from gradio_client import file
from ...common.database.database_model import Emoji
from ...common.database.database import db as peewee_db
from ...config.config import global_config
from ..utils.utils_image import image_path_to_base64, image_manager
from ..models.utils_model import LLMRequest
@@ -51,7 +54,7 @@ class MaiEmoji:
self.is_deleted = False # 标记是否已被删除
self.format = ""
async def initialize_hash_format(self):
async def initialize_hash_format(self) -> Optional[bool]:
"""从文件创建表情包实例, 计算哈希值和格式"""
try:
# 使用 full_path 检查文件是否存在
@@ -104,7 +107,7 @@ class MaiEmoji:
self.is_deleted = True
return None
async def register_to_db(self):
async def register_to_db(self) -> bool:
"""
注册表情包
将表情包对应的文件从当前路径移动到EMOJI_REGISTED_DIR目录下
@@ -143,22 +146,22 @@ class MaiEmoji:
# --- 数据库操作 ---
try:
# 准备数据库记录 for emoji collection
emoji_record = {
"filename": self.filename,
"path": self.path, # 存储目录路径
"full_path": self.full_path, # 存储完整文件路径
"embedding": self.embedding,
"description": self.description,
"emotion": self.emotion,
"hash": self.hash,
"format": self.format,
"timestamp": int(self.register_time),
"usage_count": self.usage_count,
"last_used_time": self.last_used_time,
}
emotion_str = ",".join(self.emotion) if self.emotion else ""
# 使用upsert确保记录存在或被更新
db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True)
Emoji.create(
hash=self.hash,
full_path=self.full_path,
format=self.format,
description=self.description,
emotion=emotion_str, # Store as comma-separated string
query_count=0, # Default value
is_registered=True,
is_banned=False, # Default value
record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time
register_time=self.register_time,
usage_count=self.usage_count,
last_used_time=self.last_used_time,
)
logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -166,14 +169,6 @@ class MaiEmoji:
except Exception as db_error:
logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}")
# 数据库保存失败,是否需要将文件移回?为了简化,暂时只记录错误
# 可以考虑在这里尝试删除已移动的文件,避免残留
try:
if os.path.exists(self.full_path): # full_path 此时是目标路径
os.remove(self.full_path)
logger.warning(f"[回滚] 已删除移动失败后残留的文件: {self.full_path}")
except Exception as remove_error:
logger.error(f"[错误] 回滚删除文件失败: {remove_error}")
return False
except Exception as e:
@@ -181,7 +176,7 @@ class MaiEmoji:
logger.error(traceback.format_exc())
return False
async def delete(self):
async def delete(self) -> bool:
"""删除表情包
删除表情包的文件和数据库记录
@@ -201,10 +196,14 @@ class MaiEmoji:
# 文件删除失败,但仍然尝试删除数据库记录
# 2. 删除数据库记录
result = db.emoji.delete_one({"hash": self.hash})
deleted_in_db = result.deleted_count > 0
try:
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
except Emoji.DoesNotExist:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
if deleted_in_db:
if result > 0:
logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})")
# 3. 标记对象已被删除
self.is_deleted = True
@@ -224,7 +223,7 @@ class MaiEmoji:
return False
def _emoji_objects_to_readable_list(emoji_objects):
def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]:
"""将表情包对象列表转换为可读的字符串列表
参数:
@@ -243,47 +242,48 @@ def _emoji_objects_to_readable_list(emoji_objects):
return emoji_info_list
def _to_emoji_objects(data):
def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji_objects = []
load_errors = 0
# data is now an iterable of Peewee Emoji model instances
emoji_data_list = list(data)
for emoji_data in emoji_data_list:
full_path = emoji_data.get("full_path")
for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
full_path = emoji_data.full_path
if not full_path:
logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: {emoji_data.get('_id')}")
logger.warning(
f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}"
)
load_errors += 1
continue # 跳过缺少 full_path 的记录
continue
try:
# 使用 full_path 初始化 MaiEmoji 对象
emoji = MaiEmoji(full_path=full_path)
# 设置从数据库加载的属性
emoji.hash = emoji_data.get("hash", "")
# 如果 hash 为空,也跳过?取决于业务逻辑
emoji.hash = emoji_data.emoji_hash
if not emoji.hash:
logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}")
load_errors += 1
continue
emoji.description = emoji_data.get("description", "")
emoji.emotion = emoji_data.get("emotion", [])
emoji.usage_count = emoji_data.get("usage_count", 0)
# 优先使用 last_used_time否则用 timestamp最后用当前时间
last_used = emoji_data.get("last_used_time")
timestamp = emoji_data.get("timestamp")
emoji.last_used_time = (
last_used if last_used is not None else (timestamp if timestamp is not None else time.time())
)
emoji.register_time = timestamp if timestamp is not None else time.time()
emoji.format = emoji_data.get("format", "") # 加载格式
emoji.description = emoji_data.description
# Deserialize emotion string from DB to list
emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else []
emoji.usage_count = emoji_data.usage_count
# 不需要再手动设置 path 和 filename__init__ 会自动处理
db_last_used_time = emoji_data.last_used_time
db_register_time = emoji_data.register_time
# If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time
emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time
# If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time())
emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time
emoji.format = emoji_data.format
emoji_objects.append(emoji)
except ValueError as ve: # 捕获 __init__ 可能的错误
except ValueError as ve:
logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}")
load_errors += 1
except Exception as e:
@@ -292,13 +292,13 @@ def _to_emoji_objects(data):
return emoji_objects, load_errors
def _ensure_emoji_dir():
def _ensure_emoji_dir() -> None:
"""确保表情存储目录存在"""
os.makedirs(EMOJI_DIR, exist_ok=True)
os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True)
async def clear_temp_emoji():
async def clear_temp_emoji() -> None:
"""清理临时表情包
清理/data/emoji和/data/image目录下的所有文件
当目录中文件数超过100时会全部删除
@@ -320,7 +320,7 @@ async def clear_temp_emoji():
logger.success("[清理] 完成")
async def clean_unused_emojis(emoji_dir, emoji_objects):
async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"]) -> None:
"""清理指定目录中未被 emoji_objects 追踪的表情包文件"""
if not os.path.exists(emoji_dir):
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
@@ -360,74 +360,52 @@ async def clean_unused_emojis(emoji_dir, emoji_objects):
class EmojiManager:
_instance = None
def __new__(cls):
def __new__(cls) -> "EmojiManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
def __init__(self) -> None:
self._initialized = None
self._scan_task = None
self.vlm = LLMRequest(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
self.llm_emotion_judge = LLMRequest(
model=global_config.llm_normal, max_tokens=600, request_type="emoji"
model=global_config.model.normal, max_tokens=600, request_type="emoji"
) # 更高的温度更少的token后续可以根据情绪来调整温度
self.emoji_num = 0
self.emoji_num_max = global_config.max_emoji_num
self.emoji_num_max_reach_deletion = global_config.max_reach_deletion
self.emoji_num_max = global_config.emoji.max_reg_num
self.emoji_num_max_reach_deletion = global_config.emoji.do_replace
self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表使用类型注解明确列表元素类型
logger.info("启动表情包管理器")
def initialize(self):
def initialize(self) -> None:
"""初始化数据库连接和表情目录"""
if not self._initialized:
try:
self._ensure_emoji_collection()
peewee_db.connect(reuse_if_open=True)
if peewee_db.is_closed():
raise RuntimeError("数据库连接失败")
_ensure_emoji_dir()
self._initialized = True
# 更新表情包数量
# 启动时执行一次完整性检查
# await self.check_emoji_file_integrity()
except Exception as e:
logger.exception(f"初始化表情管理器失败: {e}")
Emoji.create_table(safe=True) # Ensures table exists
def _ensure_db(self):
def _ensure_db(self) -> None:
"""确保数据库已初始化"""
if not self._initialized:
self.initialize()
if not self._initialized:
raise RuntimeError("EmojiManager not initialized")
@staticmethod
def _ensure_emoji_collection():
"""确保emoji集合存在并创建索引
这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
索引的作用是加快数据库查询速度:
- embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包
- tags字段的普通索引: 加快按标签搜索表情包的速度
- filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
"""
if "emoji" not in db.list_collection_names():
db.create_collection("emoji")
db.emoji.create_index([("embedding", "2dsphere")])
db.emoji.create_index([("filename", 1)], unique=True)
def record_usage(self, emoji_hash: str):
def record_usage(self, emoji_hash: str) -> None:
"""记录表情使用次数"""
try:
db.emoji.update_one({"hash": emoji_hash}, {"$inc": {"usage_count": 1}})
for emoji in self.emoji_objects:
if emoji.hash == emoji_hash:
emoji.usage_count += 1
break
emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash)
emoji_update.usage_count += 1
emoji_update.last_used_time = time.time() # Update last used time
emoji_update.save() # Persist changes to DB
except Emoji.DoesNotExist:
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
@@ -447,7 +425,6 @@ class EmojiManager:
if not all_emojis:
logger.warning("内存中没有任何表情包对象")
# 可以考虑再查一次数据库?或者依赖定期任务更新
return None
# 计算每个表情包与输入文本的最大情感相似度
@@ -463,40 +440,38 @@ class EmojiManager:
# 计算与每个emotion标签的相似度取最大值
max_similarity = 0
best_matching_emotion = "" # 记录最匹配的 emotion 喵~
best_matching_emotion = ""
for emotion in emotions:
# 使用编辑距离计算相似度
distance = self._levenshtein_distance(text_emotion, emotion)
max_len = max(len(text_emotion), len(emotion))
similarity = 1 - (distance / max_len if max_len > 0 else 0)
if similarity > max_similarity: # 如果找到更相似的喵~
if similarity > max_similarity:
max_similarity = similarity
best_matching_emotion = emotion # 就记下这个 emotion 喵~
best_matching_emotion = emotion
if best_matching_emotion: # 确保有匹配的情感才添加喵~
emoji_similarities.append((emoji, max_similarity, best_matching_emotion)) # 把 emotion 也存起来喵~
if best_matching_emotion:
emoji_similarities.append((emoji, max_similarity, best_matching_emotion))
# 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前10个最相似的表情包
top_emojis = (
emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities
) # 改个名字,更清晰喵~
top_emojis = emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities
if not top_emojis:
logger.warning("未找到匹配的表情包")
return None
# 从前几个中随机选择一个
selected_emoji, similarity, matched_emotion = random.choice(top_emojis) # 把匹配的 emotion 也拿出来喵~
selected_emoji, similarity, matched_emotion = random.choice(top_emojis)
# 更新使用次数
self.record_usage(selected_emoji.hash)
self.record_usage(selected_emoji.emoji_hash)
_time_end = time.time()
logger.info( # 使用匹配到的 emotion 记录日志喵~
logger.info(
f"为[{text_emotion}]找到表情包: {matched_emotion} ({selected_emoji.filename}), Similarity: {similarity:.4f}"
)
# 返回完整文件路径和描述
@@ -534,7 +509,7 @@ class EmojiManager:
return previous_row[-1]
async def check_emoji_file_integrity(self):
async def check_emoji_file_integrity(self) -> None:
"""检查表情包文件完整性
遍历self.emoji_objects中的所有对象检查文件是否存在
如果文件已被删除,则执行对象的删除方法并从列表中移除
@@ -599,7 +574,7 @@ class EmojiManager:
logger.error(f"[错误] 检查表情包完整性失败: {str(e)}")
logger.error(traceback.format_exc())
async def start_periodic_check_register(self):
async def start_periodic_check_register(self) -> None:
"""定期检查表情包完整性和数量"""
await self.get_all_emoji_from_db()
while True:
@@ -613,18 +588,18 @@ class EmojiManager:
logger.warning(f"[警告] 表情包目录不存在: {EMOJI_DIR}")
os.makedirs(EMOJI_DIR, exist_ok=True)
logger.info(f"[创建] 已创建表情包目录: {EMOJI_DIR}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
await asyncio.sleep(global_config.emoji.check_interval * 60)
continue
# 检查目录是否为空
files = os.listdir(EMOJI_DIR)
if not files:
logger.warning(f"[警告] 表情包目录为空: {EMOJI_DIR}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
await asyncio.sleep(global_config.emoji.check_interval * 60)
continue
# 检查是否需要处理表情包(数量超过最大值或不足)
if (self.emoji_num > self.emoji_num_max and global_config.max_reach_deletion) or (
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
self.emoji_num < self.emoji_num_max
):
try:
@@ -651,15 +626,16 @@ class EmojiManager:
except Exception as e:
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
await asyncio.sleep(global_config.emoji.check_interval * 60)
async def get_all_emoji_from_db(self):
async def get_all_emoji_from_db(self) -> None:
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
try:
self._ensure_db()
logger.info("[数据库] 开始加载所有表情包记录...")
logger.info("[数据库] 开始加载所有表情包记录 (Peewee)...")
emoji_objects, load_errors = _to_emoji_objects(db.emoji.find())
emoji_peewee_instances = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
# 更新内存中的列表和数量
self.emoji_objects = emoji_objects
@@ -674,7 +650,7 @@ class EmojiManager:
self.emoji_objects = [] # 加载失败则清空列表
self.emoji_num = 0
async def get_emoji_from_db(self, emoji_hash=None):
async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]:
"""获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找)
参数:
@@ -686,15 +662,16 @@ class EmojiManager:
try:
self._ensure_db()
query = {}
if emoji_hash:
query = {"hash": emoji_hash}
query = Emoji.select().where(Emoji.emoji_hash == emoji_hash)
else:
logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
)
query = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(db.emoji.find(query))
emoji_peewee_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
if load_errors > 0:
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
@@ -705,7 +682,7 @@ class EmojiManager:
logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}")
return []
async def get_emoji_from_manager(self, emoji_hash) -> Optional[MaiEmoji]:
async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
"""从内存中的 emoji_objects 列表获取表情包
参数:
@@ -758,7 +735,7 @@ class EmojiManager:
logger.error(traceback.format_exc())
return False
async def replace_a_emoji(self, new_emoji: MaiEmoji):
async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
"""替换一个表情包
Args:
@@ -788,7 +765,7 @@ class EmojiManager:
# 构建提示词
prompt = (
f"{global_config.BOT_NICKNAME}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})"
f"{global_config.bot.nickname}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})"
f"需要决定是否删除一个旧表情包来为新表情包腾出空间。\n\n"
f"新表情包信息:\n"
f"描述: {new_emoji.description}\n\n"
@@ -819,7 +796,7 @@ class EmojiManager:
# 删除选定的表情包
logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}")
delete_success = await self.delete_emoji(emoji_to_delete.hash)
delete_success = await self.delete_emoji(emoji_to_delete.emoji_hash)
if delete_success:
# 修复:等待异步注册完成
@@ -847,7 +824,7 @@ class EmojiManager:
logger.error(traceback.format_exc())
return False
async def build_emoji_description(self, image_base64: str) -> Tuple[str, list]:
async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]:
"""获取表情包描述和情感列表
Args:
@@ -871,10 +848,10 @@ class EmojiManager:
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
# 审核表情包
if global_config.EMOJI_CHECK:
if global_config.emoji.content_filtration:
prompt = f'''
这是一个表情包,请对这个表情包进行审核,标准如下:
1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求
1. 必须符合"{global_config.emoji.filtration_prompt}"的要求
2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗
3. 不能是任何形式的截图,聊天记录或视频截图
4. 不要出现5个以上文字

View File

@@ -10,7 +10,6 @@ from src.config.config import global_config
from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder
from src.chat.focus_chat.heartFC_sender import HeartFCSender
from src.chat.utils.utils import process_llm_response
from src.chat.utils.info_catcher import info_catcher_manager
@@ -18,16 +17,69 @@ from src.manager.mood_manager import mood_manager
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
from src.individuality.individuality import Individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
import time
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
import random
logger = get_logger("expressor")
def init_prompt():
Prompt(
"""
你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
{style_habbits}
你现在正在群里聊天,以下是群里正在进行的聊天内容:
{chat_info}
以上是聊天内容,你需要了解聊天记录中的内容
{chat_target}
你的名字是{bot_name}{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。
请你根据情景使用以下句法:
{grammar_habbits}
回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。
回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 ),只输出一条回复就好。
现在,你说:
""",
"default_expressor_prompt",
)
Prompt(
"""
你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
{style_habbits}
你现在正在群里聊天,以下是群里正在进行的聊天内容:
{chat_info}
以上是聊天内容,你需要了解聊天记录中的内容
{chat_target}
你的名字是{bot_name}{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。
请你根据情景使用以下句法:
{grammar_habbits}
回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。
回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 ),只输出一条回复就好。
现在,你说:
""",
"default_expressor_private_prompt", # New template for private FOCUSED chat
)
class DefaultExpressor:
def __init__(self, chat_id: str):
self.log_prefix = "expressor"
# TODO: API-Adapter修改标记
self.express_model = LLMRequest(
model=global_config.llm_normal,
temperature=global_config.llm_normal["temp"],
model=global_config.model.normal,
temperature=global_config.model.normal["temp"],
max_tokens=256,
request_type="response_heartflow",
)
@@ -51,8 +103,8 @@ class DefaultExpressor:
messageinfo = anchor_message.message_info
thinking_time_point = parse_thinking_id_to_timestamp(thinking_id)
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=messageinfo.platform,
)
# logger.debug(f"创建思考消息:{anchor_message}")
@@ -66,7 +118,7 @@ class DefaultExpressor:
reply=anchor_message, # 回复的是锚点消息
thinking_start_time=thinking_time_point,
)
logger.debug(f"创建思考消息thinking_message{thinking_message}")
# logger.debug(f"创建思考消息thinking_message{thinking_message}")
await self.heart_fc_sender.register_thinking(thinking_message)
@@ -106,7 +158,7 @@ class DefaultExpressor:
if reply:
with Timer("发送消息", cycle_timers):
sent_msg_list = await self._send_response_messages(
sent_msg_list = await self.send_response_messages(
anchor_message=anchor_message,
thinking_id=thinking_id,
response_set=reply,
@@ -141,7 +193,7 @@ class DefaultExpressor:
try:
# 1. 获取情绪影响因子并调整模型温度
arousal_multiplier = mood_manager.get_arousal_multiplier()
current_temp = float(global_config.llm_normal["temp"]) * arousal_multiplier
current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier
self.express_model.params["temperature"] = current_temp # 动态调整温度
# 2. 获取信息捕捉器
@@ -162,13 +214,10 @@ class DefaultExpressor:
# 3. 构建 Prompt
with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await prompt_builder.build_prompt(
build_mode="focus",
prompt = await self.build_prompt_focus(
chat_stream=self.chat_stream, # Pass the stream object
in_mind_reply=in_mind_reply,
reason=reason,
current_mind_info="",
structured_info="",
sender_name=sender_name_for_prompt, # Pass determined name
target_message=target_message,
)
@@ -183,10 +232,11 @@ class DefaultExpressor:
try:
with Timer("LLM生成", {}): # 内部计时器,可选保留
# TODO: API-Adapter修改标记
# logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n")
content, reasoning_content, model_name = await self.express_model.generate_response(prompt)
logger.info(f"{self.log_prefix}\nPrompt:\n{prompt}\n---------------------------\n")
# logger.info(f"{self.log_prefix}\nPrompt:\n{prompt}\n---------------------------\n")
logger.info(f"想要表达:{in_mind_reply}")
logger.info(f"理由:{reason}")
@@ -223,10 +273,108 @@ class DefaultExpressor:
traceback.print_exc()
return None
async def build_prompt_focus(
self,
reason,
chat_stream,
sender_name,
in_mind_reply,
target_message,
) -> str:
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=0, level=2)
# Determine if it's a group chat
is_group_chat = bool(chat_stream.group_info)
# Use sender_name passed from caller for private chat, otherwise use a default for group
# Default sender_name for group chat isn't used in the group prompt template, but set for consistency
effective_sender_name = sender_name if not is_group_chat else "某人"
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=global_config.chat.observation_context_size,
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
replace_bot_name=True,
merge_messages=True,
timestamp_mode="relative",
read_mark=0.0,
truncate=True,
)
(
learnt_style_expressions,
learnt_grammar_expressions,
personality_expressions,
) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id)
style_habbits = []
grammar_habbits = []
# 1. learnt_expressions加权随机选3条
if learnt_style_expressions:
weights = [expr["count"] for expr in learnt_style_expressions]
selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 3)
for expr in selected_learnt:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
style_habbits.append(f"{expr['situation']}时,使用 {expr['style']}")
# 2. learnt_grammar_expressions加权随机选3条
if learnt_grammar_expressions:
weights = [expr["count"] for expr in learnt_grammar_expressions]
selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 3)
for expr in selected_learnt:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
grammar_habbits.append(f"{expr['situation']}时,使用 {expr['style']}")
# 3. personality_expressions随机选1条
if personality_expressions:
expr = random.choice(personality_expressions)
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
style_habbits.append(f"{expr['situation']}时,使用 {expr['style']}")
style_habbits_str = "\n".join(style_habbits)
grammar_habbits_str = "\n".join(grammar_habbits)
logger.debug("开始构建 focus prompt")
# --- Choose template based on chat type ---
if is_group_chat:
template_name = "default_expressor_prompt"
# Group specific formatting variables (already fetched or default)
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
# chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
prompt = await global_prompt_manager.format_prompt(
template_name,
style_habbits=style_habbits_str,
grammar_habbits=grammar_habbits_str,
chat_target=chat_target_1,
chat_info=chat_talking_prompt,
bot_name=global_config.bot.nickname,
prompt_personality="",
reason=reason,
in_mind_reply=in_mind_reply,
target_message=target_message,
)
else: # Private chat
template_name = "default_expressor_private_prompt"
prompt = await global_prompt_manager.format_prompt(
template_name,
sender_name=effective_sender_name, # Used in private template
chat_talking_prompt=chat_talking_prompt,
bot_name=global_config.bot.nickname,
prompt_personality=prompt_personality,
reason=reason,
moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
)
return prompt
# --- 发送器 (Sender) --- #
async def _send_response_messages(
self, anchor_message: Optional[MessageRecv], response_set: List[Tuple[str, str]], thinking_id: str
async def send_response_messages(
self, anchor_message: Optional[MessageRecv], response_set: List[Tuple[str, str]], thinking_id: str = ""
) -> Optional[MessageSending]:
"""发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
chat = self.chat_stream
@@ -241,7 +389,11 @@ class DefaultExpressor:
stream_name = chat_manager.get_stream_name(chat_id) or chat_id # 获取流名称用于日志
# 检查思考过程是否仍在进行,并获取开始时间
if thinking_id:
thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id)
else:
thinking_id = "ds" + str(round(time.time(), 2))
thinking_start_time = time.time()
if thinking_start_time is None:
logger.error(f"[{stream_name}]思考过程未找到或已结束,无法发送回复。")
@@ -274,6 +426,7 @@ class DefaultExpressor:
reply_to=reply_to,
is_emoji=is_emoji,
thinking_id=thinking_id,
thinking_start_time=thinking_start_time,
)
try:
@@ -295,6 +448,7 @@ class DefaultExpressor:
except Exception as e:
logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}")
traceback.print_exc()
# 这里可以选择是继续发送下一个片段还是中止
# 在尝试发送完所有片段后,完成原始的 thinking_id 状态
@@ -325,13 +479,13 @@ class DefaultExpressor:
reply_to: bool,
is_emoji: bool,
thinking_id: str,
thinking_start_time: float,
) -> MessageSending:
"""构建单个发送消息"""
thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(self.chat_id, thinking_id)
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=self.chat_stream.platform,
)
@@ -348,3 +502,40 @@ class DefaultExpressor:
)
return bot_message
def weighted_sample_no_replacement(items, weights, k) -> list:
"""
加权且不放回地随机抽取k个元素。
参数:
items: 待抽取的元素列表
weights: 每个元素对应的权重与items等长且为正数
k: 需要抽取的元素个数
返回:
selected: 按权重加权且不重复抽取的k个元素组成的列表
如果items中的元素不足k就只会返回所有可用的元素
实现思路:
每次从当前池中按权重加权随机选出一个元素选中后将其从池中移除重复k次。
这样保证了:
1. count越大被选中概率越高
2. 不会重复选中同一个元素
"""
selected = []
pool = list(zip(items, weights))
for _ in range(min(k, len(pool))):
total = sum(w for _, w in pool)
r = random.uniform(0, total)
upto = 0
for idx, (item, weight) in enumerate(pool):
upto += weight
if upto >= r:
selected.append(item)
pool.pop(idx)
break
return selected
init_prompt()

View File

@@ -77,8 +77,9 @@ def init_prompt() -> None:
class ExpressionLearner:
def __init__(self) -> None:
# TODO: API-Adapter修改标记
self.express_learn_model: LLMRequest = LLMRequest(
model=global_config.llm_normal,
model=global_config.model.normal,
temperature=0.1,
max_tokens=256,
request_type="response_heartflow",
@@ -289,7 +290,7 @@ class ExpressionLearner:
# 构建prompt
prompt = await global_prompt_manager.format_prompt(
"personality_expression_prompt",
personality=global_config.expression_style,
personality=global_config.personality.expression_style,
)
# logger.info(f"个性表达方式提取prompt: {prompt}")

View File

@@ -14,15 +14,17 @@ from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.info_processors.chattinginfo_processor import ChattingInfoProcessor
from src.chat.focus_chat.info_processors.mind_processor import MindProcessor
from src.chat.heart_flow.observation.memory_observation import MemoryObservation
from src.chat.focus_chat.info_processors.working_memory_processor import WorkingMemoryProcessor
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from src.chat.heart_flow.observation.working_observation import WorkingObservation
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
from src.chat.focus_chat.memory_activator import MemoryActivator
from src.chat.focus_chat.info_processors.base_processor import BaseProcessor
from src.chat.focus_chat.info_processors.self_processor import SelfProcessor
from src.chat.focus_chat.planners.planner import ActionPlanner
from src.chat.focus_chat.planners.action_factory import ActionManager
from src.chat.focus_chat.planners.action_manager import ActionManager
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
install(extra_lines=3)
@@ -57,7 +59,7 @@ async def _handle_cycle_delay(action_taken_this_cycle: bool, cycle_start_time: f
class HeartFChatting:
"""
管理一个连续的Plan-Replier-Sender循环
管理一个连续的Focus Chat循环
用于在特定聊天流中生成回复。
其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。
"""
@@ -66,7 +68,6 @@ class HeartFChatting:
self,
chat_id: str,
observations: list[Observation],
on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]],
):
"""
HeartFChatting 初始化函数
@@ -74,24 +75,27 @@ class HeartFChatting:
参数:
chat_id: 聊天流唯一标识符(如stream_id)
observations: 关联的观察列表
on_consecutive_no_reply_callback: 连续不回复达到阈值时调用的异步回调函数
"""
# 基础属性
self.stream_id: str = chat_id # 聊天流ID
self.chat_stream: Optional[ChatStream] = None # 关联的聊天流
self.observations: List[Observation] = observations # 关联的观察列表,用于监控聊天流状态
self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback
self.log_prefix: str = str(chat_id) # Initial default, will be updated
self.memory_observation = MemoryObservation(observe_id=self.stream_id)
self.hfcloop_observation = HFCloopObservation(observe_id=self.stream_id)
self.working_observation = WorkingObservation(observe_id=self.stream_id)
self.chatting_observation = observations[0]
self.memory_activator = MemoryActivator()
self.working_memory = WorkingMemory(chat_id=self.stream_id)
self.working_observation = WorkingMemoryObservation(
observe_id=self.stream_id, working_memory=self.working_memory
)
self.expressor = DefaultExpressor(chat_id=self.stream_id)
self.action_manager = ActionManager()
self.action_planner = ActionPlanner(log_prefix=self.log_prefix, action_manager=self.action_manager)
self.hfcloop_observation.set_action_manager(self.action_manager)
self.all_observations = observations
# --- 处理器列表 ---
self.processors: List[BaseProcessor] = []
self._register_default_processors()
@@ -108,9 +112,7 @@ class HeartFChatting:
self._cycle_counter = 0
self._cycle_history: Deque[CycleDetail] = deque(maxlen=10) # 保留最近10个循环的信息
self._current_cycle: Optional[CycleDetail] = None
self.total_no_reply_count: int = 0 # 连续不回复计数器
self._shutting_down: bool = False # 关闭标志位
self.total_waiting_time: float = 0.0 # 累计等待时间
async def _initialize(self) -> bool:
"""
@@ -151,6 +153,8 @@ class HeartFChatting:
self.processors.append(ChattingInfoProcessor())
self.processors.append(MindProcessor(subheartflow_id=self.stream_id))
self.processors.append(ToolProcessor(subheartflow_id=self.stream_id))
self.processors.append(WorkingMemoryProcessor(subheartflow_id=self.stream_id))
self.processors.append(SelfProcessor(subheartflow_id=self.stream_id))
logger.info(f"{self.log_prefix} 已注册默认处理器: {[p.__class__.__name__ for p in self.processors]}")
async def start(self):
@@ -158,7 +162,7 @@ class HeartFChatting:
启动 HeartFChatting 的主循环。
注意:调用此方法前必须确保已经成功初始化。
"""
logger.info(f"{self.log_prefix} 开始认真水群(HFC)...")
logger.info(f"{self.log_prefix} 开始认真聊天(HFC)...")
await self._start_loop_if_needed()
async def _start_loop_if_needed(self):
@@ -328,6 +332,7 @@ class HeartFChatting:
f"{self.log_prefix} 处理器 {processor_name} 执行失败,耗时 (自并行开始): {duration_since_parallel_start:.2f}秒. 错误: {e}",
exc_info=True,
)
traceback.print_exc()
# 即使出错,也认为该任务结束了,已从 pending_tasks 中移除
if pending_tasks:
@@ -349,13 +354,12 @@ class HeartFChatting:
async def _observe_process_plan_action_loop(self, cycle_timers: dict, thinking_id: str) -> tuple[bool, str]:
try:
with Timer("观察", cycle_timers):
await self.observations[0].observe()
await self.memory_observation.observe()
# await self.observations[0].observe()
await self.chatting_observation.observe()
await self.working_observation.observe()
await self.hfcloop_observation.observe()
observations: List[Observation] = []
observations.append(self.observations[0])
observations.append(self.memory_observation)
observations.append(self.chatting_observation)
observations.append(self.working_observation)
observations.append(self.hfcloop_observation)
@@ -363,6 +367,8 @@ class HeartFChatting:
"observations": observations,
}
self.all_observations = observations
with Timer("回忆", cycle_timers):
running_memorys = await self.memory_activator.activate_memory(observations)
@@ -395,8 +401,7 @@ class HeartFChatting:
elif action_type == "no_reply":
action_str = "不回复"
else:
action_type = "unknown"
action_str = "未知动作"
action_str = action_type
logger.info(f"{self.log_prefix} 麦麦决定'{action_str}', 原因'{reasoning}'")
@@ -452,14 +457,10 @@ class HeartFChatting:
reasoning=reasoning,
cycle_timers=cycle_timers,
thinking_id=thinking_id,
observations=self.observations,
observations=self.all_observations,
expressor=self.expressor,
chat_stream=self.chat_stream,
current_cycle=self._current_cycle,
log_prefix=self.log_prefix,
on_consecutive_no_reply_callback=self.on_consecutive_no_reply_callback,
total_no_reply_count=self.total_no_reply_count,
total_waiting_time=self.total_waiting_time,
shutting_down=self._shutting_down,
)
@@ -470,14 +471,6 @@ class HeartFChatting:
# 处理动作并获取结果
success, reply_text = await action_handler.handle_action()
# 更新状态计数器
if action == "no_reply":
self.total_no_reply_count = getattr(action_handler, "total_no_reply_count", self.total_no_reply_count)
self.total_waiting_time = getattr(action_handler, "total_waiting_time", self.total_waiting_time)
elif action == "reply":
self.total_no_reply_count = 0
self.total_waiting_time = 0.0
return success, reply_text
except Exception as e:
@@ -526,5 +519,3 @@ class HeartFChatting:
if last_n is not None:
history = history[-last_n:]
return [cycle.to_dict() for cycle in history]

View File

@@ -106,6 +106,7 @@ class HeartFCSender:
and not message.is_private_message()
and message.reply.processed_plain_text != "[System Trigger Context]"
):
message.set_reply(message.reply)
logger.debug(f"[{chat_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}...")
await message.process()

View File

@@ -112,7 +112,7 @@ def _check_ban_words(text: str, chat, userinfo) -> bool:
Returns:
bool: 是否包含过滤词
"""
for word in global_config.ban_words:
for word in global_config.chat.ban_words:
if word in text:
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
@@ -132,7 +132,7 @@ def _check_ban_regex(text: str, chat, userinfo) -> bool:
Returns:
bool: 是否匹配过滤正则
"""
for pattern in global_config.ban_msgs_regex:
for pattern in global_config.chat.ban_msgs_regex:
if pattern.search(text):
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")

View File

@@ -6,43 +6,21 @@ from src.chat.utils.chat_message_builder import build_readable_messages, get_raw
from src.chat.person_info.relationship_manager import relationship_manager
from src.chat.utils.utils import get_embedding
import time
from typing import Union, Optional, Dict, Any
from src.common.database import db
from typing import Union, Optional
from src.chat.utils.utils import get_recent_group_speaker
from src.manager.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import HippocampusManager
from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
import traceback
import random
import json
import math
from src.common.database.database_model import Knowledges
logger = get_logger("prompt")
def init_prompt():
Prompt(
"""
你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
{style_habbits}
你现在正在群里聊天,以下是群里正在进行的聊天内容:
{chat_info}
以上是聊天内容,你需要了解聊天记录中的内容
{chat_target}
你的名字是{bot_name}{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。
请你根据情景使用以下句法:
{grammar_habbits}
回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。
回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 ),只输出一条回复就好。
现在,你说:
""",
"heart_flow_prompt",
)
Prompt(
"""
你有以下信息可供参考:
@@ -69,7 +47,7 @@ def init_prompt():
你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt}{reply_style1}
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}{prompt_ger}
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)只输出回复内容。
请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情at或 @等 )。只输出回复内容。
{moderation_prompt}
不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出回复内容""",
"reasoning_prompt_main",
@@ -82,29 +60,6 @@ def init_prompt():
Prompt("\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
# --- Template for HeartFChatting (FOCUSED mode) ---
Prompt(
"""
{info_from_tools}
你正在和 {sender_name} 私聊。
聊天记录如下:
{chat_talking_prompt}
现在你想要回复。
你需要扮演一位网名叫{bot_name}的人进行回复,这个人的特点是:"{prompt_personality}"
你正在和 {sender_name} 私聊, 现在请你读读你们之前的聊天记录,然后给出日常且口语化的回复,平淡一些。
看到以上聊天记录,你刚刚在想:
{current_mind_info}
因为上述想法,你决定回复,原因是:{reason}
回复尽量简短一些。请注意把握聊天内容,{reply_style2}{prompt_ger},不要复读自己说的话
{reply_style1},说中文,不要刻意突出自身学科背景,注意只输出回复内容。
{moderation_prompt}。注意:回复不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""",
"heart_flow_private_prompt", # New template for private FOCUSED chat
)
# --- Template for NormalChat (CHAT mode) ---
Prompt(
"""
{memory_prompt}
@@ -126,118 +81,6 @@ def init_prompt():
)
async def _build_prompt_focus(
reason, current_mind_info, structured_info, chat_stream, sender_name, in_mind_reply, target_message
) -> str:
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=0, level=2)
# Determine if it's a group chat
is_group_chat = bool(chat_stream.group_info)
# Use sender_name passed from caller for private chat, otherwise use a default for group
# Default sender_name for group chat isn't used in the group prompt template, but set for consistency
effective_sender_name = sender_name if not is_group_chat else "某人"
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=global_config.observation_context_size,
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
replace_bot_name=True,
merge_messages=True,
timestamp_mode="relative",
read_mark=0.0,
truncate=True,
)
if structured_info:
structured_info_prompt = await global_prompt_manager.format_prompt(
"info_from_tools", structured_info=structured_info
)
else:
structured_info_prompt = ""
# 从/data/expression/对应chat_id/expressions.json中读取表达方式
(
learnt_style_expressions,
learnt_grammar_expressions,
personality_expressions,
) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id)
style_habbits = []
grammar_habbits = []
# 1. learnt_expressions加权随机选3条
if learnt_style_expressions:
weights = [expr["count"] for expr in learnt_style_expressions]
selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 3)
for expr in selected_learnt:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
style_habbits.append(f"{expr['situation']}时,使用 {expr['style']}")
# 2. learnt_grammar_expressions加权随机选3条
if learnt_grammar_expressions:
weights = [expr["count"] for expr in learnt_grammar_expressions]
selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 3)
for expr in selected_learnt:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
grammar_habbits.append(f"{expr['situation']}时,使用 {expr['style']}")
# 3. personality_expressions随机选1条
if personality_expressions:
expr = random.choice(personality_expressions)
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
style_habbits.append(f"{expr['situation']}时,使用 {expr['style']}")
style_habbits_str = "\n".join(style_habbits)
grammar_habbits_str = "\n".join(grammar_habbits)
logger.debug("开始构建 focus prompt")
# --- Choose template based on chat type ---
if is_group_chat:
template_name = "heart_flow_prompt"
# Group specific formatting variables (already fetched or default)
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
# chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
prompt = await global_prompt_manager.format_prompt(
template_name,
# info_from_tools=structured_info_prompt,
style_habbits=style_habbits_str,
grammar_habbits=grammar_habbits_str,
chat_target=chat_target_1, # Used in group template
# chat_talking_prompt=chat_talking_prompt,
chat_info=chat_talking_prompt,
bot_name=global_config.BOT_NICKNAME,
# prompt_personality=prompt_personality,
prompt_personality="",
reason=reason,
in_mind_reply=in_mind_reply,
target_message=target_message,
# moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
# sender_name is not used in the group template
)
else: # Private chat
template_name = "heart_flow_private_prompt"
prompt = await global_prompt_manager.format_prompt(
template_name,
info_from_tools=structured_info_prompt,
sender_name=effective_sender_name, # Used in private template
chat_talking_prompt=chat_talking_prompt,
bot_name=global_config.BOT_NICKNAME,
prompt_personality=prompt_personality,
# chat_target and chat_target_2 are not used in private template
current_mind_info=current_mind_info,
reason=reason,
moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
)
# --- End choosing template ---
# logger.debug(f"focus_chat_prompt (is_group={is_group_chat}): \n{prompt}")
return prompt
class PromptBuilder:
def __init__(self):
self.prompt_built = ""
@@ -257,17 +100,6 @@ class PromptBuilder:
) -> Optional[str]:
if build_mode == "normal":
return await self._build_prompt_normal(chat_stream, message_txt or "", sender_name)
elif build_mode == "focus":
return await _build_prompt_focus(
reason,
current_mind_info,
structured_info,
chat_stream,
sender_name,
in_mind_reply,
target_message,
)
return None
async def _build_prompt_normal(self, chat_stream, message_txt: str, sender_name: str = "某人") -> str:
@@ -280,7 +112,7 @@ class PromptBuilder:
who_chat_in_group = get_recent_group_speaker(
chat_stream.stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
limit=global_config.observation_context_size,
limit=global_config.chat.observation_context_size,
)
elif chat_stream.user_info:
who_chat_in_group.append(
@@ -328,7 +160,7 @@ class PromptBuilder:
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=global_config.observation_context_size,
limit=global_config.chat.observation_context_size,
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
@@ -340,18 +172,15 @@ class PromptBuilder:
# 关键词检测与反应
keywords_reaction_prompt = ""
for rule in global_config.keywords_reaction_rules:
if rule.get("enable", False):
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
logger.info(
f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}"
)
keywords_reaction_prompt += rule.get("reaction", "") + ""
for rule in global_config.keyword_reaction.rules:
if rule.enable:
if any(keyword in message_txt for keyword in rule.keywords):
logger.info(f"检测到以下关键词之一:{rule.keywords},触发反应:{rule.reaction}")
keywords_reaction_prompt += f"{rule.reaction}"
else:
for pattern in rule.get("regex", []):
result = pattern.search(message_txt)
if result:
reaction = rule.get("reaction", "")
for pattern in rule.regex:
if result := pattern.search(message_txt):
reaction = rule.reaction
for name, content in result.groupdict().items():
reaction = reaction.replace(f"[{name}]", content)
logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}")
@@ -397,15 +226,16 @@ class PromptBuilder:
chat_target_2=chat_target_2,
chat_talking_prompt=chat_talking_prompt,
message_txt=message_txt,
bot_name=global_config.BOT_NICKNAME,
bot_other_names="/".join(global_config.BOT_ALIAS_NAMES),
bot_name=global_config.bot.nickname,
bot_other_names="/".join(global_config.bot.alias_names),
prompt_personality=prompt_personality,
mood_prompt=mood_prompt,
reply_style1=reply_style1_chosen,
reply_style2=reply_style2_chosen,
keywords_reaction_prompt=keywords_reaction_prompt,
prompt_ger=prompt_ger,
moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
# moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
moderation_prompt="",
)
else:
template_name = "reasoning_prompt_private_main"
@@ -419,15 +249,16 @@ class PromptBuilder:
prompt_info=prompt_info,
chat_talking_prompt=chat_talking_prompt,
message_txt=message_txt,
bot_name=global_config.BOT_NICKNAME,
bot_other_names="/".join(global_config.BOT_ALIAS_NAMES),
bot_name=global_config.bot.nickname,
bot_other_names="/".join(global_config.bot.alias_names),
prompt_personality=prompt_personality,
mood_prompt=mood_prompt,
reply_style1=reply_style1_chosen,
reply_style2=reply_style2_chosen,
keywords_reaction_prompt=keywords_reaction_prompt,
prompt_ger=prompt_ger,
moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
# moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
moderation_prompt="",
)
# --- End choosing template ---
@@ -439,30 +270,6 @@ class PromptBuilder:
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. 先从LLM获取主题类似于记忆系统的做法
topics = []
# try:
# # 先尝试使用记忆系统的方法获取主题
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
# topics = []
# else:
# topics = [
# topic.strip()
# for topic in ",".join(topics).replace("", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
# # 如果LLM提取失败使用jieba分词提取关键词作为备选
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# 如果无法提取到主题,直接使用整个消息
if not topics:
@@ -572,8 +379,6 @@ class PromptBuilder:
for _i, result in enumerate(results, 1):
_similarity = result["similarity"]
content = result["content"].strip()
# 调试:为内容添加序号和相似度信息
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n"
related_info += "\n"
@@ -602,14 +407,14 @@ class PromptBuilder:
return related_info
else:
logger.debug("从LPMM知识库获取知识失败使用旧版数据库进行检索")
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
try:
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old
logger.debug(
f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}"
@@ -625,103 +430,69 @@ class PromptBuilder:
) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{
"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]},
]
},
]
},
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{
"$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}},
]
results = list(db.knowledges.aggregate(pipeline))
logger.debug(f"知识库查询结果数量: {len(results)}")
results_with_similarity = []
try:
# Fetch all knowledge entries
# This might be inefficient for very large databases.
# Consider strategies like FAISS or other vector search libraries if performance becomes an issue.
all_knowledges = Knowledges.select()
if not results:
if not all_knowledges:
return [] if return_raw else ""
query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding))
if query_embedding_magnitude == 0: # Avoid division by zero
return "" if not return_raw else []
for knowledge_item in all_knowledges:
try:
db_embedding_str = knowledge_item.embedding
db_embedding = json.loads(db_embedding_str)
if len(db_embedding) != len(query_embedding):
logger.warning(
f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping."
)
continue
# Calculate Cosine Similarity
dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding))
db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding))
if db_embedding_magnitude == 0: # Avoid division by zero
similarity = 0.0
else:
similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude)
if similarity >= threshold:
results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity})
except json.JSONDecodeError:
logger.error(
f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}"
)
except Exception as e:
logger.error(f"Error processing knowledge item: {e}")
# Sort by similarity in descending order
results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
# Limit results
limited_results = results_with_similarity[:limit]
logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}")
if not limited_results:
return "" if not return_raw else []
if return_raw:
return results
return limited_results
else:
# 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results)
return "\n".join(str(result["content"]) for result in limited_results)
def weighted_sample_no_replacement(items, weights, k) -> list:
"""
加权且不放回地随机抽取k个元素。
参数:
items: 待抽取的元素列表
weights: 每个元素对应的权重与items等长且为正数
k: 需要抽取的元素个数
返回:
selected: 按权重加权且不重复抽取的k个元素组成的列表
如果items中的元素不足k就只会返回所有可用的元素
实现思路:
每次从当前池中按权重加权随机选出一个元素选中后将其从池中移除重复k次。
这样保证了:
1. count越大被选中概率越高
2. 不会重复选中同一个元素
"""
selected = []
pool = list(zip(items, weights))
for _ in range(min(k, len(pool))):
total = sum(w for _, w in pool)
r = random.uniform(0, total)
upto = 0
for idx, (item, weight) in enumerate(pool):
upto += weight
if upto >= r:
selected.append(item)
pool.pop(idx)
break
return selected
except Exception as e:
logger.error(f"Error querying Knowledges with Peewee: {e}")
return "" if not return_raw else []
init_prompt()

View File

@@ -0,0 +1,83 @@
from typing import Dict, Optional, Any, List
from dataclasses import dataclass
from .info_base import InfoBase
@dataclass
class ActionInfo(InfoBase):
"""动作信息类
用于管理和记录动作的变更信息,包括需要添加或移除的动作。
继承自 InfoBase 类,使用字典存储具体数据。
Attributes:
type (str): 信息类型标识符,固定为 "action"
Data Fields:
add_actions (List[str]): 需要添加的动作列表
remove_actions (List[str]): 需要移除的动作列表
reason (str): 变更原因说明
"""
type: str = "action"
def get_type(self) -> str:
"""获取信息类型"""
return self.type
def get_data(self) -> Dict[str, Any]:
"""获取信息数据"""
return self.data
def set_action_changes(self, action_changes: Dict[str, List[str]]) -> None:
"""设置动作变更信息
Args:
action_changes (Dict[str, List[str]]): 包含要增加和删除的动作列表
{
"add": ["action1", "action2"],
"remove": ["action3"]
}
"""
self.data["add_actions"] = action_changes.get("add", [])
self.data["remove_actions"] = action_changes.get("remove", [])
def set_reason(self, reason: str) -> None:
"""设置变更原因
Args:
reason (str): 动作变更的原因说明
"""
self.data["reason"] = reason
def get_add_actions(self) -> List[str]:
"""获取需要添加的动作列表
Returns:
List[str]: 需要添加的动作列表
"""
return self.data.get("add_actions", [])
def get_remove_actions(self) -> List[str]:
"""获取需要移除的动作列表
Returns:
List[str]: 需要移除的动作列表
"""
return self.data.get("remove_actions", [])
def get_reason(self) -> Optional[str]:
"""获取变更原因
Returns:
Optional[str]: 动作变更的原因说明,如果未设置则返回 None
"""
return self.data.get("reason")
def has_changes(self) -> bool:
"""检查是否有动作变更
Returns:
bool: 如果有任何动作需要添加或移除则返回True
"""
return bool(self.get_add_actions() or self.get_remove_actions())

View File

@@ -17,6 +17,7 @@ class InfoBase:
type: str = "base"
data: Dict[str, Any] = field(default_factory=dict)
processed_info: str = ""
def get_type(self) -> str:
"""获取信息类型
@@ -58,3 +59,11 @@ class InfoBase:
if isinstance(value, list):
return value
return []
def get_processed_info(self) -> str:
"""获取处理后的信息
Returns:
str: 处理后的信息字符串
"""
return self.processed_info

View File

@@ -0,0 +1,40 @@
from dataclasses import dataclass
from .info_base import InfoBase
@dataclass
class SelfInfo(InfoBase):
"""思维信息类
用于存储和管理当前思维状态的信息。
Attributes:
type (str): 信息类型标识符,默认为 "mind"
data (Dict[str, Any]): 包含 current_mind 的数据字典
"""
type: str = "self"
def get_self_info(self) -> str:
"""获取当前思维状态
Returns:
str: 当前思维状态
"""
return self.get_info("self_info") or ""
def set_self_info(self, self_info: str) -> None:
"""设置当前思维状态
Args:
self_info: 要设置的思维状态
"""
self.data["self_info"] = self_info
def get_processed_info(self) -> str:
"""获取处理后的信息
Returns:
str: 处理后的信息
"""
return self.get_self_info()

View File

@@ -0,0 +1,89 @@
from typing import Dict, Optional, List
from dataclasses import dataclass
from .info_base import InfoBase
@dataclass
class WorkingMemoryInfo(InfoBase):
type: str = "workingmemory"
processed_info: str = ""
def set_talking_message(self, message: str) -> None:
"""设置说话消息
Args:
message (str): 说话消息内容
"""
self.data["talking_message"] = message
def set_working_memory(self, working_memory: List[str]) -> None:
"""设置工作记忆
Args:
working_memory (str): 工作记忆内容
"""
self.data["working_memory"] = working_memory
def add_working_memory(self, working_memory: str) -> None:
"""添加工作记忆
Args:
working_memory (str): 工作记忆内容
"""
working_memory_list = self.data.get("working_memory", [])
# print(f"working_memory_list: {working_memory_list}")
working_memory_list.append(working_memory)
# print(f"working_memory_list: {working_memory_list}")
self.data["working_memory"] = working_memory_list
def get_working_memory(self) -> List[str]:
"""获取工作记忆
Returns:
List[str]: 工作记忆内容
"""
return self.data.get("working_memory", [])
def get_type(self) -> str:
"""获取信息类型
Returns:
str: 当前信息对象的类型标识符
"""
return self.type
def get_data(self) -> Dict[str, str]:
"""获取所有信息数据
Returns:
Dict[str, str]: 包含所有信息数据的字典
"""
return self.data
def get_info(self, key: str) -> Optional[str]:
"""获取特定属性的信息
Args:
key: 要获取的属性键名
Returns:
Optional[str]: 属性值,如果键不存在则返回 None
"""
return self.data.get(key)
def get_processed_info(self) -> Dict[str, str]:
"""获取处理后的信息
Returns:
Dict[str, str]: 处理后的信息数据
"""
all_memory = self.get_working_memory()
# print(f"all_memory: {all_memory}")
memory_str = ""
for memory in all_memory:
memory_str += f"{memory}\n"
self.processed_info = memory_str
return self.processed_info

View File

@@ -0,0 +1,126 @@
from typing import List, Optional, Any
from src.chat.focus_chat.info.obs_info import ObsInfo
from src.chat.heart_flow.observation.observation import Observation
from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.info.action_info import ActionInfo
from .base_processor import BaseProcessor
from src.common.logger_manager import get_logger
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from src.chat.focus_chat.info.cycle_info import CycleInfo
from datetime import datetime
from typing import Dict
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
import random
logger = get_logger("processor")
class ActionProcessor(BaseProcessor):
"""动作处理器
用于处理Observation对象将其转换为ObsInfo对象。
"""
log_prefix = "聊天信息处理"
def __init__(self):
"""初始化观察处理器"""
super().__init__()
# TODO: API-Adapter修改标记
self.model_summary = LLMRequest(
model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
)
async def process_info(
self,
observations: Optional[List[Observation]] = None,
running_memorys: Optional[List[Dict]] = None,
**kwargs: Any,
) -> List[InfoBase]:
"""处理Observation对象
Args:
infos: InfoBase对象列表
observations: 可选的Observation对象列表
**kwargs: 其他可选参数
Returns:
List[InfoBase]: 处理后的ObsInfo实例列表
"""
# print(f"observations: {observations}")
processed_infos = []
# 处理Observation对象
if observations:
for obs in observations:
if isinstance(obs, HFCloopObservation):
# 创建动作信息
action_info = ActionInfo()
action_changes = await self.analyze_loop_actions(obs)
if action_changes["add"] or action_changes["remove"]:
action_info.set_action_changes(action_changes)
# 设置变更原因
reasons = []
if action_changes["add"]:
reasons.append(f"添加动作{action_changes['add']}因为检测到大量无回复")
if action_changes["remove"]:
reasons.append(f"移除动作{action_changes['remove']}因为检测到连续回复")
action_info.set_reason(" | ".join(reasons))
processed_infos.append(action_info)
return processed_infos
async def analyze_loop_actions(self, obs: HFCloopObservation) -> Dict[str, List[str]]:
"""分析最近的循环内容并决定动作的增减
Returns:
Dict[str, List[str]]: 包含要增加和删除的动作
{
"add": ["action1", "action2"],
"remove": ["action3"]
}
"""
result = {"add": [], "remove": []}
# 获取最近10次循环
recent_cycles = obs.history_loop[-10:] if len(obs.history_loop) > 10 else obs.history_loop
if not recent_cycles:
return result
# 统计no_reply的数量
no_reply_count = 0
reply_sequence = [] # 记录最近的动作序列
for cycle in recent_cycles:
action_type = cycle.loop_plan_info["action_result"]["action_type"]
if action_type == "no_reply":
no_reply_count += 1
reply_sequence.append(action_type == "reply")
# 检查no_reply比例
if len(recent_cycles) >= 5 and (no_reply_count / len(recent_cycles)) >= 0.8:
result["add"].append("exit_focus_chat")
# 获取最近三次的reply状态
last_three = reply_sequence[-3:] if len(reply_sequence) >= 3 else reply_sequence
# 根据最近的reply情况决定是否移除reply动作
if len(last_three) >= 3 and all(last_three):
# 如果最近三次都是reply直接移除
result["remove"].append("reply")
elif len(last_three) >= 2 and all(last_three[-2:]):
# 如果最近两次都是reply40%概率移除
if random.random() < 0.4:
result["remove"].append("reply")
elif last_three and last_three[-1]:
# 如果最近一次是reply20%概率移除
if random.random() < 0.2:
result["remove"].append("reply")
return result

View File

@@ -26,8 +26,9 @@ class ChattingInfoProcessor(BaseProcessor):
def __init__(self):
"""初始化观察处理器"""
super().__init__()
self.llm_summary = LLMRequest(
model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
# TODO: API-Adapter修改标记
self.model_summary = LLMRequest(
model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
)
async def process_info(
@@ -54,19 +55,24 @@ class ChattingInfoProcessor(BaseProcessor):
for obs in observations:
# print(f"obs: {obs}")
if isinstance(obs, ChattingObservation):
# print("1111111111111111111111读取111111111111111")
obs_info = ObsInfo()
await self.chat_compress(obs)
# 设置说话消息
if hasattr(obs, "talking_message_str"):
# print(f"设置说话消息obs.talking_message_str: {obs.talking_message_str}")
obs_info.set_talking_message(obs.talking_message_str)
# 设置截断后的说话消息
if hasattr(obs, "talking_message_str_truncate"):
# print(f"设置截断后的说话消息obs.talking_message_str_truncate: {obs.talking_message_str_truncate}")
obs_info.set_talking_message_str_truncate(obs.talking_message_str_truncate)
if hasattr(obs, "mid_memory_info"):
# print(f"设置之前聊天信息obs.mid_memory_info: {obs.mid_memory_info}")
obs_info.set_previous_chat_info(obs.mid_memory_info)
# 设置聊天类型
@@ -91,7 +97,7 @@ class ChattingInfoProcessor(BaseProcessor):
async def chat_compress(self, obs: ChattingObservation):
if obs.compressor_prompt:
try:
summary_result, _, _ = await self.llm_summary.generate_response(obs.compressor_prompt)
summary_result, _, _ = await self.model_summary.generate_response(obs.compressor_prompt)
summary = "没有主题的闲聊" # 默认值
if summary_result: # 确保结果不为空
summary = summary_result
@@ -108,12 +114,12 @@ class ChattingInfoProcessor(BaseProcessor):
"created_at": datetime.now().timestamp(),
}
obs.mid_memorys.append(mid_memory)
if len(obs.mid_memorys) > obs.max_mid_memory_len:
obs.mid_memorys.pop(0) # 移除最旧的
obs.mid_memories.append(mid_memory)
if len(obs.mid_memories) > obs.max_mid_memory_len:
obs.mid_memories.pop(0) # 移除最旧的
mid_memory_str = "之前聊天的内容概述是:\n"
for mid_memory_item in obs.mid_memorys: # 重命名循环变量以示区分
for mid_memory_item in obs.mid_memories: # 重命名循环变量以示区分
time_diff = int((datetime.now().timestamp() - mid_memory_item["created_at"]) / 60)
mid_memory_str += (
f"距离现在{time_diff}分钟前(聊天记录id:{mid_memory_item['id']}){mid_memory_item['theme']}\n"

View File

@@ -6,21 +6,14 @@ import time
import traceback
from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality
import random
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import safe_json_dumps
from src.chat.message_receive.chat_stream import chat_manager
import difflib
from src.chat.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor
from src.chat.focus_chat.info.mind_info import MindInfo
from typing import List, Optional
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from src.chat.focus_chat.info_processors.processor_utils import (
calculate_similarity,
calculate_replacement_probability,
get_spark,
)
from typing import Dict
from src.chat.focus_chat.info.info_base import InfoBase
@@ -28,7 +21,6 @@ logger = get_logger("processor")
def init_prompt():
# --- Group Chat Prompt ---
group_prompt = """
你的名字是{bot_name}
{memory_str}
@@ -44,31 +36,29 @@ def init_prompt():
现在请你继续输出观察和规划,输出要求:
1. 先关注未读新消息的内容和近期回复历史
2. 根据新信息,修改和删除之前的观察和规划
3. 根据聊天内容继续输出观察和规划{hf_do_next}
3. 根据聊天内容继续输出观察和规划
4. 注意群聊的时间线索,话题由谁发起,进展状况如何,思考聊天的时间线。
6. 语言简洁自然,不要分点,不要浮夸,不要修辞,仅输出思考内容就好"""
Prompt(group_prompt, "sub_heartflow_prompt_before")
# --- Private Chat Prompt ---
private_prompt = """
你的名字是{bot_name}
{memory_str}
{extra_info}
{relation_prompt}
你的名字是{bot_name},{prompt_personality},你现在{mood_info}
{cycle_info_block}
现在是{time_now},你正在上网,和 {chat_target_name} 私聊,以下是你们的聊天内容:
现在是{time_now},你正在上网,和qq群里的网友们聊天以下是正在进行的聊天内容:
{chat_observe_info}
以下是你之前对聊天的观察和规划:
以下是你之前对聊天的观察和规划,你的名字是{bot_name}
{last_mind}
请仔细阅读聊天内容,想想你和 {chat_target_name} 的关系,回顾你们刚刚的交流,你刚刚发言和对方的反应,思考聊天的主题。
请思考你要不要回复以及如何回复对方。
思考并输出你的内心想法
输出要求:
1. 根据聊天内容生成你的想法,{hf_do_next}
2. 不要分点、不要使用表情符号
3. 避免多余符号(冒号、引号、括号等)
4. 语言简洁自然,不要浮夸
5. 如果你刚发言,对方没有回复你,请谨慎回复"""
现在请你继续输出观察和规划,输出要求:
1. 先关注未读新消息的内容和近期回复历史
2. 根据新信息,修改和删除之前的观察和规划
3. 根据聊天内容继续输出观察和规划
4. 注意群聊的时间线索,话题由谁发起,进展状况如何,思考聊天的时间线。
6. 语言简洁自然,不要分点,不要浮夸,不要修辞,仅输出思考内容就好"""
Prompt(private_prompt, "sub_heartflow_prompt_private_before")
@@ -81,8 +71,8 @@ class MindProcessor(BaseProcessor):
self.subheartflow_id = subheartflow_id
self.llm_model = LLMRequest(
model=global_config.llm_sub_heartflow,
temperature=global_config.llm_sub_heartflow["temp"],
model=global_config.model.sub_heartflow,
temperature=global_config.model.sub_heartflow["temp"],
max_tokens=800,
request_type="sub_heart_flow",
)
@@ -210,45 +200,26 @@ class MindProcessor(BaseProcessor):
for person in person_list:
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
# 构建个性部分
# prompt_personality = individuality.get_prompt(x_person=2, level=2)
# 获取当前时间
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
spark_prompt = get_spark()
# ---------- 5. 构建最终提示词 ----------
template_name = "sub_heartflow_prompt_before" if is_group_chat else "sub_heartflow_prompt_private_before"
logger.debug(f"{self.log_prefix} 使用{'群聊' if is_group_chat else '私聊'}思考模板")
prompt = (await global_prompt_manager.get_prompt_async(template_name)).format(
bot_name=individuality.name,
memory_str=memory_str,
extra_info=self.structured_info_str,
# prompt_personality=prompt_personality,
relation_prompt=relation_prompt,
bot_name=individuality.name,
time_now=time_now,
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
chat_observe_info=chat_observe_info,
# mood_info="mood_info",
hf_do_next=spark_prompt,
last_mind=previous_mind,
cycle_info_block=hfcloop_observe_info,
chat_target_name=chat_target_name,
)
# 在构建完提示词后生成最终的prompt字符串
final_prompt = prompt
content = "" # 初始化内容变量
content = "(不知道该想些什么...)"
try:
# 调用LLM生成响应
response, _ = await self.llm_model.generate_response_async(prompt=final_prompt)
# 直接使用LLM返回的文本响应作为 content
content = response if response else ""
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
if not content:
logger.warning(f"{self.log_prefix} LLM返回空结果思考失败。")
except Exception as e:
# 处理总体异常
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
@@ -256,16 +227,8 @@ class MindProcessor(BaseProcessor):
content = "思考过程中出现错误"
# 记录初步思考结果
logger.debug(f"{self.log_prefix} 思考prompt: \n{final_prompt}\n")
# 处理空响应情况
if not content:
content = "(不知道该想些什么...)"
logger.warning(f"{self.log_prefix} LLM返回空结果思考失败。")
# ---------- 8. 更新思考状态并返回结果 ----------
logger.debug(f"{self.log_prefix} 思考prompt: \n{prompt}\n")
logger.info(f"{self.log_prefix} 思考结果: {content}")
# 更新当前思考内容
self.update_current_mind(content)
return content
@@ -275,138 +238,5 @@ class MindProcessor(BaseProcessor):
self.past_mind.append(self.current_mind)
self.current_mind = response
def de_similar(self, previous_mind, new_content):
try:
similarity = calculate_similarity(previous_mind, new_content)
replacement_prob = calculate_replacement_probability(similarity)
logger.debug(f"{self.log_prefix} 新旧想法相似度: {similarity:.2f}, 替换概率: {replacement_prob:.2f}")
# 定义词语列表 (移到判断之前)
yu_qi_ci_liebiao = ["", "", "", "", "", ""]
zhuan_zhe_liebiao = ["但是", "不过", "然而", "可是", "只是"]
cheng_jie_liebiao = ["然后", "接着", "此外", "而且", "另外"]
zhuan_jie_ci_liebiao = zhuan_zhe_liebiao + cheng_jie_liebiao
if random.random() < replacement_prob:
# 相似度非常高时,尝试去重或特殊处理
if similarity == 1.0:
logger.debug(f"{self.log_prefix} 想法完全重复 (相似度 1.0),执行特殊处理...")
# 随机截取大约一半内容
if len(new_content) > 1: # 避免内容过短无法截取
split_point = max(
1, len(new_content) // 2 + random.randint(-len(new_content) // 4, len(new_content) // 4)
)
truncated_content = new_content[:split_point]
else:
truncated_content = new_content # 如果只有一个字符或者为空,就不截取了
# 添加语气词和转折/承接词
yu_qi_ci = random.choice(yu_qi_ci_liebiao)
zhuan_jie_ci = random.choice(zhuan_jie_ci_liebiao)
content = f"{yu_qi_ci}{zhuan_jie_ci}{truncated_content}"
logger.debug(f"{self.log_prefix} 想法重复,特殊处理后: {content}")
else:
# 相似度较高但非100%,执行标准去重逻辑
logger.debug(f"{self.log_prefix} 执行概率性去重 (概率: {replacement_prob:.2f})...")
logger.debug(
f"{self.log_prefix} previous_mind类型: {type(previous_mind)}, new_content类型: {type(new_content)}"
)
matcher = difflib.SequenceMatcher(None, previous_mind, new_content)
logger.debug(f"{self.log_prefix} matcher类型: {type(matcher)}")
deduplicated_parts = []
last_match_end_in_b = 0
# 获取并记录所有匹配块
matching_blocks = matcher.get_matching_blocks()
logger.debug(f"{self.log_prefix} 匹配块数量: {len(matching_blocks)}")
logger.debug(
f"{self.log_prefix} 匹配块示例(前3个): {matching_blocks[:3] if len(matching_blocks) > 3 else matching_blocks}"
)
# get_matching_blocks()返回形如[(i, j, n), ...]的列表其中i是a中的索引j是b中的索引n是匹配的长度
for idx, match in enumerate(matching_blocks):
if not isinstance(match, tuple):
logger.error(f"{self.log_prefix} 匹配块 {idx} 不是元组类型,而是 {type(match)}: {match}")
continue
try:
_i, j, n = match # 解包元组为三个变量
logger.debug(f"{self.log_prefix} 匹配块 {idx}: i={_i}, j={j}, n={n}")
if last_match_end_in_b < j:
# 确保添加的是字符串,而不是元组
try:
non_matching_part = new_content[last_match_end_in_b:j]
logger.debug(
f"{self.log_prefix} 添加非匹配部分: '{non_matching_part}', 类型: {type(non_matching_part)}"
)
if not isinstance(non_matching_part, str):
logger.warning(
f"{self.log_prefix} 非匹配部分不是字符串类型: {type(non_matching_part)}"
)
non_matching_part = str(non_matching_part)
deduplicated_parts.append(non_matching_part)
except Exception as e:
logger.error(f"{self.log_prefix} 处理非匹配部分时出错: {e}")
logger.error(traceback.format_exc())
last_match_end_in_b = j + n
except Exception as e:
logger.error(f"{self.log_prefix} 处理匹配块时出错: {e}")
logger.error(traceback.format_exc())
logger.debug(f"{self.log_prefix} 去重前部分列表: {deduplicated_parts}")
logger.debug(f"{self.log_prefix} 列表元素类型: {[type(part) for part in deduplicated_parts]}")
# 确保所有元素都是字符串
deduplicated_parts = [str(part) for part in deduplicated_parts]
# 防止列表为空
if not deduplicated_parts:
logger.warning(f"{self.log_prefix} 去重后列表为空,添加空字符串")
deduplicated_parts = [""]
logger.debug(f"{self.log_prefix} 处理后的部分列表: {deduplicated_parts}")
try:
deduplicated_content = "".join(deduplicated_parts).strip()
logger.debug(f"{self.log_prefix} 拼接后的去重内容: '{deduplicated_content}'")
except Exception as e:
logger.error(f"{self.log_prefix} 拼接去重内容时出错: {e}")
logger.error(traceback.format_exc())
deduplicated_content = ""
if deduplicated_content:
# 根据概率决定是否添加词语
prefix_str = ""
if random.random() < 0.3: # 30% 概率添加语气词
prefix_str += random.choice(yu_qi_ci_liebiao)
if random.random() < 0.7: # 70% 概率添加转折/承接词
prefix_str += random.choice(zhuan_jie_ci_liebiao)
# 组合最终结果
if prefix_str:
content = f"{prefix_str}{deduplicated_content}" # 更新 content
logger.debug(f"{self.log_prefix} 去重并添加引导词后: {content}")
else:
content = deduplicated_content # 更新 content
logger.debug(f"{self.log_prefix} 去重后 (未添加引导词): {content}")
else:
logger.warning(f"{self.log_prefix} 去重后内容为空保留原始LLM输出: {new_content}")
content = new_content # 保留原始 content
else:
logger.debug(f"{self.log_prefix} 未执行概率性去重 (概率: {replacement_prob:.2f})")
# content 保持 new_content 不变
except Exception as e:
logger.error(f"{self.log_prefix} 应用概率性去重或特殊处理时出错: {e}")
logger.error(traceback.format_exc())
# 出错时保留原始 content
content = new_content
return content
init_prompt()

View File

@@ -1,56 +0,0 @@
import difflib
import random
import time
def calculate_similarity(text_a: str, text_b: str) -> float:
"""
计算两个文本字符串的相似度。
"""
if not text_a or not text_b:
return 0.0
matcher = difflib.SequenceMatcher(None, text_a, text_b)
return matcher.ratio()
def calculate_replacement_probability(similarity: float) -> float:
"""
根据相似度计算替换的概率。
规则:
- 相似度 <= 0.4: 概率 = 0
- 相似度 >= 0.9: 概率 = 1
- 相似度 == 0.6: 概率 = 0.7
- 0.4 < 相似度 <= 0.6: 线性插值 (0.4, 0) 到 (0.6, 0.7)
- 0.6 < 相似度 < 0.9: 线性插值 (0.6, 0.7) 到 (0.9, 1.0)
"""
if similarity <= 0.4:
return 0.0
elif similarity >= 0.9:
return 1.0
elif 0.4 < similarity <= 0.6:
# p = 3.5 * s - 1.4
probability = 3.5 * similarity - 1.4
return max(0.0, probability)
else: # 0.6 < similarity < 0.9
# p = s + 0.1
probability = similarity + 0.1
return min(1.0, max(0.0, probability))
def get_spark():
local_random = random.Random()
current_minute = int(time.strftime("%M"))
local_random.seed(current_minute)
hf_options = [
("可以参考之前的想法,在原来想法的基础上继续思考", 0.2),
("可以参考之前的想法,在原来的想法上尝试新的话题", 0.4),
("不要太深入", 0.2),
("进行深入思考", 0.2),
]
# 加权随机选择思考指导
hf_do_next = local_random.choices(
[option[0] for option in hf_options], weights=[option[1] for option in hf_options], k=1
)[0]
return hf_do_next

View File

@@ -0,0 +1,164 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.observation import Observation
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
import time
import traceback
from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import chat_manager
from src.chat.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor
from typing import List, Optional
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from typing import Dict
from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.info.self_info import SelfInfo
logger = get_logger("processor")
def init_prompt():
indentify_prompt = """
你的名字是{bot_name},你是一个AI但是你伪装成了一个人类你的人格是{prompt_personality}
你的头像形象是一只橙色的鱼,头上有绿色的树叶。
{relation_prompt}
{memory_str}
现在是{time_now}你正在上网和qq群里的网友们聊天以下是正在进行的聊天内容
{chat_observe_info}
现在请你根据现有的信息,思考自我认同
1. 你是一个什么样的人,你和群里的人关系如何
2. 思考有没有人提到你,或者图片与你有关
3. 你的自我认同是否有助于你的回答,如果你需要自我相关的信息来帮你参与聊天,请输出,否则请输出十个字以内的简短自我认同
4. 一般情况下不用输出自我认同,只需要输出十几个字的简短自我认同就好,除非有明显需要自我认同的场景
请回复的平淡一些,简短一些,说中文,不要浮夸,平淡一些。
请注意不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出内容。
"""
Prompt(indentify_prompt, "indentify_prompt")
class SelfProcessor(BaseProcessor):
log_prefix = "自我认同"
def __init__(self, subheartflow_id: str):
super().__init__()
self.subheartflow_id = subheartflow_id
self.llm_model = LLMRequest(
model=global_config.model.sub_heartflow,
temperature=global_config.model.sub_heartflow["temp"],
max_tokens=800,
request_type="self_identify",
)
name = chat_manager.get_stream_name(self.subheartflow_id)
self.log_prefix = f"[{name}] "
async def process_info(
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
) -> List[InfoBase]:
"""处理信息对象
Args:
*infos: 可变数量的InfoBase类型的信息对象
Returns:
List[InfoBase]: 处理后的结构化信息列表
"""
self_info_str = await self.self_indentify(observations, running_memorys)
if self_info_str:
self_info = SelfInfo()
self_info.set_self_info(self_info_str)
else:
self_info = None
return None
return [self_info]
async def self_indentify(
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None
):
"""
在回复前进行思考,生成内心想法并收集工具调用结果
参数:
observations: 观察信息
返回:
如果return_prompt为False:
tuple: (current_mind, past_mind) 当前想法和过去的想法列表
如果return_prompt为True:
tuple: (current_mind, past_mind, prompt) 当前想法、过去的想法列表和使用的prompt
"""
memory_str = ""
if running_memorys:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memorys:
memory_str += f"{running_memory['topic']}: {running_memory['content']}\n"
if observations is None:
observations = []
for observation in observations:
if isinstance(observation, ChattingObservation):
# 获取聊天元信息
is_group_chat = observation.is_group_chat
chat_target_info = observation.chat_target_info
chat_target_name = "对方" # 私聊默认名称
if not is_group_chat and chat_target_info:
# 优先使用person_name其次user_nickname最后回退到默认值
chat_target_name = (
chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or chat_target_name
)
# 获取聊天内容
chat_observe_info = observation.get_observe_info()
person_list = observation.person_list
if isinstance(observation, HFCloopObservation):
# hfcloop_observe_info = observation.get_observe_info()
pass
individuality = Individuality.get_instance()
personality_block = individuality.get_prompt(x_person=2, level=2)
relation_prompt = ""
for person in person_list:
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format(
bot_name=individuality.name,
prompt_personality=personality_block,
memory_str=memory_str,
relation_prompt=relation_prompt,
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
chat_observe_info=chat_observe_info,
)
content = ""
try:
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
if not content:
logger.warning(f"{self.log_prefix} LLM返回空结果自我识别失败。")
except Exception as e:
# 处理总体异常
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
logger.error(traceback.format_exc())
content = "自我识别过程中出现错误"
if content == "None":
content = ""
# 记录初步思考结果
logger.debug(f"{self.log_prefix} 自我识别prompt: \n{prompt}\n")
logger.info(f"{self.log_prefix} 自我识别结果: {content}")
return content
init_prompt()

View File

@@ -11,8 +11,8 @@ from src.chat.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor
from typing import List, Optional, Dict
from src.chat.heart_flow.observation.observation import Observation
from src.chat.heart_flow.observation.working_observation import WorkingObservation
from src.chat.focus_chat.info.structured_info import StructuredInfo
from src.chat.heart_flow.observation.structure_observation import StructureObservation
logger = get_logger("processor")
@@ -24,9 +24,6 @@ def init_prompt():
tool_executor_prompt = """
你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}
你要在群聊中扮演以下角色:
{prompt_personality}
你当前的额外信息:
{memory_str}
@@ -52,7 +49,7 @@ class ToolProcessor(BaseProcessor):
self.subheartflow_id = subheartflow_id
self.log_prefix = f"[{subheartflow_id}:ToolExecutor] "
self.llm_model = LLMRequest(
model=global_config.llm_tool_use,
model=global_config.model.tool_use,
max_tokens=500,
request_type="tool_execution",
)
@@ -70,6 +67,8 @@ class ToolProcessor(BaseProcessor):
list: 处理后的结构化信息列表
"""
working_infos = []
if observations:
for observation in observations:
if isinstance(observation, ChattingObservation):
@@ -77,7 +76,7 @@ class ToolProcessor(BaseProcessor):
# 更新WorkingObservation中的结构化信息
for observation in observations:
if isinstance(observation, WorkingObservation):
if isinstance(observation, StructureObservation):
for structured_info in result:
logger.debug(f"{self.log_prefix} 更新WorkingObservation中的结构化信息: {structured_info}")
observation.add_structured_info(structured_info)
@@ -86,6 +85,7 @@ class ToolProcessor(BaseProcessor):
logger.debug(f"{self.log_prefix} 获取更新后WorkingObservation中的结构化信息: {working_infos}")
structured_info = StructuredInfo()
if working_infos:
for working_info in working_infos:
structured_info.set_info(working_info.get("type"), working_info.get("content"))
@@ -134,7 +134,7 @@ class ToolProcessor(BaseProcessor):
# 获取个性信息
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=2, level=2)
# prompt_personality = individuality.get_prompt(x_person=2, level=2)
# 获取时间信息
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
@@ -148,14 +148,14 @@ class ToolProcessor(BaseProcessor):
# chat_target_name=chat_target_name,
is_group_chat=is_group_chat,
# relation_prompt=relation_prompt,
prompt_personality=prompt_personality,
# prompt_personality=prompt_personality,
# mood_info=mood_info,
bot_name=individuality.name,
time_now=time_now,
)
# 调用LLM专注于工具使用
logger.debug(f"开始执行工具调用{prompt}")
# logger.debug(f"开始执行工具调用{prompt}")
response, _, tool_calls = await self.llm_model.generate_response_tool_async(prompt=prompt, tools=tools)
logger.debug(f"获取到工具原始输出:\n{tool_calls}")

View File

@@ -0,0 +1,236 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.observation import Observation
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
import time
import traceback
from src.common.logger_manager import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import chat_manager
from .base_processor import BaseProcessor
from src.chat.focus_chat.info.mind_info import MindInfo
from typing import List, Optional
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
from typing import Dict
from src.chat.focus_chat.info.info_base import InfoBase
from json_repair import repair_json
from src.chat.focus_chat.info.workingmemory_info import WorkingMemoryInfo
import asyncio
import json
logger = get_logger("processor")
def init_prompt():
memory_proces_prompt = """
你的名字是{bot_name}
现在是{time_now}你正在上网和qq群里的网友们聊天以下是正在进行的聊天内容
{chat_observe_info}
以下是你已经总结的记忆摘要你可以调取这些记忆查看内容来帮助你聊天不要一次调取太多记忆最多调取3个左右记忆
{memory_str}
观察聊天内容和已经总结的记忆,思考是否有新内容需要总结成记忆,如果有,就输出 true否则输出 false
如果当前聊天记录的内容已经被总结千万不要总结新记忆输出false
如果已经总结的记忆包含了当前聊天记录的内容千万不要总结新记忆输出false
如果已经总结的记忆摘要,包含了当前聊天记录的内容千万不要总结新记忆输出false
如果有相近的记忆请合并记忆输出merge_memory格式为[["id1", "id2"], ["id3", "id4"],...]你可以进行多组合并但是每组合并只能有两个记忆id不要输出其他内容
请根据聊天内容选择你需要调取的记忆并考虑是否添加新记忆以JSON格式输出格式如下
```json
{{
"selected_memory_ids": ["id1", "id2", ...],
"new_memory": "true" or "false",
"merge_memory": [["id1", "id2"], ["id3", "id4"],...]
}}
```
"""
Prompt(memory_proces_prompt, "prompt_memory_proces")
class WorkingMemoryProcessor(BaseProcessor):
log_prefix = "工作记忆"
def __init__(self, subheartflow_id: str):
super().__init__()
self.subheartflow_id = subheartflow_id
self.llm_model = LLMRequest(
model=global_config.model.sub_heartflow,
temperature=global_config.model.sub_heartflow["temp"],
max_tokens=800,
request_type="working_memory",
)
name = chat_manager.get_stream_name(self.subheartflow_id)
self.log_prefix = f"[{name}] "
async def process_info(
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
) -> List[InfoBase]:
"""处理信息对象
Args:
*infos: 可变数量的InfoBase类型的信息对象
Returns:
List[InfoBase]: 处理后的结构化信息列表
"""
working_memory = None
chat_info = ""
try:
for observation in observations:
if isinstance(observation, WorkingMemoryObservation):
working_memory = observation.get_observe_info()
# working_memory_obs = observation
if isinstance(observation, ChattingObservation):
chat_info = observation.get_observe_info()
# chat_info_truncate = observation.talking_message_str_truncate
if not working_memory:
logger.warning(f"{self.log_prefix} 没有找到工作记忆对象")
mind_info = MindInfo()
return [mind_info]
except Exception as e:
logger.error(f"{self.log_prefix} 处理观察时出错: {e}")
logger.error(traceback.format_exc())
return []
all_memory = working_memory.get_all_memories()
memory_prompts = []
for memory in all_memory:
# memory_content = memory.data
memory_summary = memory.summary
memory_id = memory.id
memory_brief = memory_summary.get("brief")
# memory_detailed = memory_summary.get("detailed")
memory_keypoints = memory_summary.get("keypoints")
memory_events = memory_summary.get("events")
memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n"
memory_prompts.append(memory_single_prompt)
memory_choose_str = "".join(memory_prompts)
# 使用提示模板进行处理
prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format(
bot_name=global_config.bot.nickname,
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
chat_observe_info=chat_info,
memory_str=memory_choose_str,
)
# 调用LLM处理记忆
content = ""
try:
logger.debug(f"{self.log_prefix} 处理工作记忆的prompt: {prompt}")
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
if not content:
logger.warning(f"{self.log_prefix} LLM返回空结果处理工作记忆失败。")
except Exception as e:
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
logger.error(traceback.format_exc())
# 解析LLM返回的JSON
try:
result = repair_json(content)
if isinstance(result, str):
result = json.loads(result)
if not isinstance(result, dict):
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败结果不是字典类型: {type(result)}")
return []
selected_memory_ids = result.get("selected_memory_ids", [])
new_memory = result.get("new_memory", "")
merge_memory = result.get("merge_memory", [])
except Exception as e:
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}")
logger.error(traceback.format_exc())
return []
logger.debug(f"{self.log_prefix} 解析LLM返回的JSON成功: {result}")
# 根据selected_memory_ids调取记忆
memory_str = ""
if selected_memory_ids:
for memory_id in selected_memory_ids:
memory = await working_memory.retrieve_memory(memory_id)
if memory:
# memory_content = memory.data
memory_summary = memory.summary
memory_id = memory.id
memory_brief = memory_summary.get("brief")
# memory_detailed = memory_summary.get("detailed")
memory_keypoints = memory_summary.get("keypoints")
memory_events = memory_summary.get("events")
for keypoint in memory_keypoints:
memory_str += f"记忆要点:{keypoint}\n"
for event in memory_events:
memory_str += f"记忆事件:{event}\n"
# memory_str += f"记忆摘要:{memory_detailed}\n"
# memory_str += f"记忆主题:{memory_brief}\n"
working_memory_info = WorkingMemoryInfo()
if memory_str:
working_memory_info.add_working_memory(memory_str)
logger.debug(f"{self.log_prefix} 取得工作记忆: {memory_str}")
else:
logger.warning(f"{self.log_prefix} 没有找到工作记忆")
# 根据聊天内容添加新记忆
if new_memory:
# 使用异步方式添加新记忆,不阻塞主流程
logger.debug(f"{self.log_prefix} {new_memory}新记忆: ")
asyncio.create_task(self.add_memory_async(working_memory, chat_info))
if merge_memory:
for merge_pairs in merge_memory:
memory1 = await working_memory.retrieve_memory(merge_pairs[0])
memory2 = await working_memory.retrieve_memory(merge_pairs[1])
if memory1 and memory2:
memory_str = f"记忆id:{memory1.id},记忆摘要:{memory1.summary.get('brief')}\n"
memory_str += f"记忆id:{memory2.id},记忆摘要:{memory2.summary.get('brief')}\n"
asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1]))
return [working_memory_info]
async def add_memory_async(self, working_memory: WorkingMemory, content: str):
"""异步添加记忆,不阻塞主流程
Args:
working_memory: 工作记忆对象
content: 记忆内容
"""
try:
await working_memory.add_memory(content=content, from_source="chat_text")
logger.debug(f"{self.log_prefix} 异步添加新记忆成功: {content[:30]}...")
except Exception as e:
logger.error(f"{self.log_prefix} 异步添加新记忆失败: {e}")
logger.error(traceback.format_exc())
async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str):
"""异步合并记忆,不阻塞主流程
Args:
working_memory: 工作记忆对象
memory_str: 记忆内容
"""
try:
merged_memory = await working_memory.merge_memory(memory_id1, memory_id2)
logger.debug(f"{self.log_prefix} 异步合并记忆成功: {memory_id1}{memory_id2}...")
logger.debug(f"{self.log_prefix} 合并后的记忆梗概: {merged_memory.summary.get('brief')}")
logger.debug(f"{self.log_prefix} 合并后的记忆详情: {merged_memory.summary.get('detailed')}")
logger.debug(f"{self.log_prefix} 合并后的记忆要点: {merged_memory.summary.get('keypoints')}")
logger.debug(f"{self.log_prefix} 合并后的记忆事件: {merged_memory.summary.get('events')}")
except Exception as e:
logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}")
logger.error(traceback.format_exc())
init_prompt()

View File

@@ -1,5 +1,5 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.working_observation import WorkingObservation
from src.chat.heart_flow.observation.structure_observation import StructureObservation
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
@@ -34,8 +34,9 @@ def init_prompt():
class MemoryActivator:
def __init__(self):
# TODO: API-Adapter修改标记
self.summary_model = LLMRequest(
model=global_config.llm_summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
model=global_config.model.summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
)
self.running_memory = []
@@ -53,7 +54,7 @@ class MemoryActivator:
for observation in observations:
if isinstance(observation, ChattingObservation):
obs_info_text += observation.get_observe_info()
elif isinstance(observation, WorkingObservation):
elif isinstance(observation, StructureObservation):
working_info = observation.get_observe_info()
for working_info_item in working_info:
obs_info_text += f"{working_info_item['type']}: {working_info_item['content']}\n"

View File

@@ -1,18 +1,18 @@
from typing import Dict, List, Optional, Callable, Coroutine, Type, Any, Union
import os
import importlib
from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY, _DEFAULT_ACTIONS
from typing import Dict, List, Optional, Callable, Coroutine, Type, Any
from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY
from src.chat.heart_flow.observation.observation import Observation
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from src.common.logger_manager import get_logger
import importlib
import pkgutil
import os
# 导入动作类,确保装饰器被执行
from src.chat.focus_chat.planners.actions.reply_action import ReplyAction
from src.chat.focus_chat.planners.actions.no_reply_action import NoReplyAction
import src.chat.focus_chat.planners.actions # noqa
logger = get_logger("action_factory")
logger = get_logger("action_manager")
# 定义动作信息类型
ActionInfo = Dict[str, Any]
@@ -38,14 +38,12 @@ class ActionManager:
# 加载所有已注册动作
self._load_registered_actions()
# 加载插件动作
self._load_plugin_actions()
# 初始化时将默认动作加载到使用中的动作
self._using_actions = self._default_actions.copy()
# logger.info(f"当前可用动作: {list(self._using_actions.keys())}")
# for action_name, action_info in self._using_actions.items():
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
def _load_registered_actions(self) -> None:
"""
加载所有通过装饰器注册的动作
@@ -54,23 +52,24 @@ class ActionManager:
# 从_ACTION_REGISTRY获取所有已注册动作
for action_name, action_class in _ACTION_REGISTRY.items():
# 获取动作相关信息
action_description:str = getattr(action_class, "action_description", "")
action_parameters:dict[str:str] = getattr(action_class, "action_parameters", {})
action_require:list[str] = getattr(action_class, "action_require", [])
is_default:bool = getattr(action_class, "default", False)
# 不读取插件动作和基类
if action_name == "base_action" or action_name == "plugin_action":
continue
action_description: str = getattr(action_class, "action_description", "")
action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {})
action_require: list[str] = getattr(action_class, "action_require", [])
is_default: bool = getattr(action_class, "default", False)
if action_name and action_description:
# 创建动作信息字典
action_info = {
"description": action_description,
"parameters": action_parameters,
"require": action_require
"require": action_require,
}
# 注册2
print("注册2")
print(action_info)
# 添加到所有已注册的动作
self._registered_actions[action_name] = action_info
@@ -78,14 +77,56 @@ class ActionManager:
if is_default:
self._default_actions[action_name] = action_info
logger.info(f"所有注册动作: {list(self._registered_actions.keys())}")
logger.info(f"默认动作: {list(self._default_actions.keys())}")
# logger.info(f"所有注册动作: {list(self._registered_actions.keys())}")
# logger.info(f"默认动作: {list(self._default_actions.keys())}")
# for action_name, action_info in self._default_actions.items():
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
except Exception as e:
logger.error(f"加载已注册动作失败: {e}")
def _load_plugin_actions(self) -> None:
"""
加载所有插件目录中的动作
"""
try:
# 检查插件目录是否存在
plugin_path = "src.plugins"
plugin_dir = plugin_path.replace(".", os.path.sep)
if not os.path.exists(plugin_dir):
logger.info(f"插件目录 {plugin_dir} 不存在,跳过插件动作加载")
return
# 导入插件包
try:
plugins_package = importlib.import_module(plugin_path)
except ImportError as e:
logger.error(f"导入插件包失败: {e}")
return
# 遍历插件包中的所有子包
for _, plugin_name, is_pkg in pkgutil.iter_modules(
plugins_package.__path__, plugins_package.__name__ + "."
):
if not is_pkg:
continue
# 检查插件是否有actions子包
plugin_actions_path = f"{plugin_name}.actions"
try:
# 尝试导入插件的actions包
importlib.import_module(plugin_actions_path)
logger.info(f"成功加载插件动作模块: {plugin_actions_path}")
except ImportError as e:
logger.debug(f"插件 {plugin_name} 没有actions子包或导入失败: {e}")
continue
# 再次从_ACTION_REGISTRY获取所有动作包括刚刚从插件加载的
self._load_registered_actions()
except Exception as e:
logger.error(f"加载插件动作失败: {e}")
def create_action(
self,
action_name: str,
@@ -96,11 +137,7 @@ class ActionManager:
observations: List[Observation],
expressor: DefaultExpressor,
chat_stream: ChatStream,
current_cycle: CycleDetail,
log_prefix: str,
on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]],
total_no_reply_count: int = 0,
total_waiting_time: float = 0.0,
shutting_down: bool = False,
) -> Optional[BaseAction]:
"""
@@ -115,11 +152,7 @@ class ActionManager:
observations: 观察列表
expressor: 表达器
chat_stream: 聊天流
current_cycle: 当前循环信息
log_prefix: 日志前缀
on_consecutive_no_reply_callback: 连续不回复回调
total_no_reply_count: 连续不回复计数
total_waiting_time: 累计等待时间
shutting_down: 是否正在关闭
Returns:
@@ -136,22 +169,17 @@ class ActionManager:
return None
try:
# 创建动作实例并传递所有必要参数
# 创建动作实例
instance = handler_class(
action_name=action_name,
action_data=action_data,
reasoning=reasoning,
cycle_timers=cycle_timers,
thinking_id=thinking_id,
observations=observations,
on_consecutive_no_reply_callback=on_consecutive_no_reply_callback,
current_cycle=current_cycle,
log_prefix=log_prefix,
total_no_reply_count=total_no_reply_count,
total_waiting_time=total_waiting_time,
shutting_down=shutting_down,
expressor=expressor,
chat_stream=chat_stream,
log_prefix=log_prefix,
shutting_down=shutting_down,
)
return instance
@@ -233,11 +261,7 @@ class ActionManager:
if require is None:
require = []
action_info = {
"description": description,
"parameters": parameters,
"require": require
}
action_info = {"description": description, "parameters": parameters, "require": require}
self._registered_actions[action_name] = action_info
return True
@@ -281,7 +305,3 @@ class ActionManager:
Optional[Type[BaseAction]]: 动作处理器类如果不存在则返回None
"""
return _ACTION_REGISTRY.get(action_name)
# 创建全局实例
ActionFactory = ActionManager()

View File

@@ -0,0 +1,5 @@
# 导入所有动作模块以确保装饰器被执行
from . import reply_action # noqa
from . import no_reply_action # noqa
# 在此处添加更多动作模块导入

View File

@@ -25,8 +25,8 @@ def register_action(cls):
logger.error(f"动作类 {cls.__name__} 缺少必要的属性: action_name 或 action_description")
return cls
action_name = getattr(cls, "action_name")
action_description = getattr(cls, "action_description")
action_name = cls.action_name
action_description = cls.action_description
is_default = getattr(cls, "default", False)
if not action_name or not action_description:
@@ -60,14 +60,13 @@ class BaseAction(ABC):
cycle_timers: 计时器字典
thinking_id: 思考ID
"""
#每个动作必须实现
self.action_name:str = "base_action"
self.action_description:str = "基础动作"
self.action_parameters:dict = {}
self.action_require:list[str] = []
self.default:bool = False
# 每个动作必须实现
self.action_name: str = "base_action"
self.action_description: str = "基础动作"
self.action_parameters: dict = {}
self.action_require: list[str] = []
self.default: bool = False
self.action_data = action_data
self.reasoning = reasoning

View File

@@ -0,0 +1,108 @@
import asyncio
import traceback
from src.common.logger_manager import get_logger
from src.chat.utils.timer_calculator import Timer
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
from typing import Tuple, List, Callable, Coroutine
from src.chat.heart_flow.observation.observation import Observation
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.sub_heartflow import SubHeartFlow
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.heart_flow.heartflow import heartflow
from src.chat.heart_flow.sub_heartflow import ChatState
logger = get_logger("action_taken")
@register_action
class ExitFocusChatAction(BaseAction):
"""退出专注聊天动作处理类
处理决定退出专注聊天的动作。
执行后会将所属的sub heartflow转变为normal_chat状态。
"""
action_name = "exit_focus_chat"
action_description = "退出专注聊天,转为普通聊天模式"
action_parameters = {}
action_require = [
"很长时间没有回复,你决定退出专注聊天",
"当前内容不需要持续专注关注,你决定退出专注聊天",
"聊天内容已经完成,你决定退出专注聊天",
]
default = True
def __init__(
self,
action_data: dict,
reasoning: str,
cycle_timers: dict,
thinking_id: str,
observations: List[Observation],
log_prefix: str,
chat_stream: ChatStream,
shutting_down: bool = False,
**kwargs,
):
"""初始化退出专注聊天动作处理器
Args:
action_data: 动作数据
reasoning: 执行该动作的理由
cycle_timers: 计时器字典
thinking_id: 思考ID
observations: 观察列表
log_prefix: 日志前缀
shutting_down: 是否正在关闭
"""
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
self.observations = observations
self.log_prefix = log_prefix
self._shutting_down = shutting_down
self.chat_id = chat_stream.stream_id
async def handle_action(self) -> Tuple[bool, str]:
"""
处理退出专注聊天的情况
工作流程:
1. 将sub heartflow转换为normal_chat状态
2. 等待新消息、超时或关闭信号
3. 根据等待结果更新连续不回复计数
4. 如果达到阈值,触发回调
Returns:
Tuple[bool, str]: (是否执行成功, 状态转换消息)
"""
try:
# 转换状态
status_message = ""
self.sub_heartflow = await heartflow.get_or_create_subheartflow(self.chat_id)
if self.sub_heartflow:
try:
# 转换为normal_chat状态
await self.sub_heartflow.change_chat_state(ChatState.NORMAL_CHAT)
status_message = "已成功切换到普通聊天模式"
logger.info(f"{self.log_prefix} {status_message}")
except Exception as e:
error_msg = f"切换到普通聊天模式失败: {str(e)}"
logger.error(f"{self.log_prefix} {error_msg}")
return False, error_msg
else:
warning_msg = "未找到有效的sub heartflow实例无法切换状态"
logger.warning(f"{self.log_prefix} {warning_msg}")
return False, warning_msg
return True, status_message
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} 处理 'exit_focus_chat' 时等待被中断 (CancelledError)")
raise
except Exception as e:
error_msg = f"处理 'exit_focus_chat' 时发生错误: {str(e)}"
logger.error(f"{self.log_prefix} {error_msg}")
logger.error(traceback.format_exc())
return False, error_msg

View File

@@ -6,14 +6,12 @@ from src.chat.focus_chat.planners.actions.base_action import BaseAction, registe
from typing import Tuple, List, Callable, Coroutine
from src.chat.heart_flow.observation.observation import Observation
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
logger = get_logger("action_taken")
# 常量定义
WAITING_TIME_THRESHOLD = 300 # 等待新消息时间阈值,单位秒
CONSECUTIVE_NO_REPLY_THRESHOLD = 3 # 连续不回复的阈值
@register_action
@@ -29,7 +27,7 @@ class NoReplyAction(BaseAction):
action_require = [
"话题无关/无聊/不感兴趣/不懂",
"最后一条消息是你自己发的且无人回应你",
"你发送了太多消息,且无人回复"
"你发送了太多消息,且无人回复",
]
default = True
@@ -40,13 +38,9 @@ class NoReplyAction(BaseAction):
cycle_timers: dict,
thinking_id: str,
observations: List[Observation],
on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]],
current_cycle: CycleDetail,
log_prefix: str,
total_no_reply_count: int = 0,
total_waiting_time: float = 0.0,
shutting_down: bool = False,
**kwargs
**kwargs,
):
"""初始化不回复动作处理器
@@ -57,20 +51,12 @@ class NoReplyAction(BaseAction):
cycle_timers: 计时器字典
thinking_id: 思考ID
observations: 观察列表
on_consecutive_no_reply_callback: 连续不回复达到阈值时调用的回调函数
current_cycle: 当前循环信息
log_prefix: 日志前缀
total_no_reply_count: 连续不回复计数
total_waiting_time: 累计等待时间
shutting_down: 是否正在关闭
"""
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
self.observations = observations
self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback
self._current_cycle = current_cycle
self.log_prefix = log_prefix
self.total_no_reply_count = total_no_reply_count
self.total_waiting_time = total_waiting_time
self._shutting_down = shutting_down
async def handle_action(self) -> Tuple[bool, str]:
@@ -93,37 +79,6 @@ class NoReplyAction(BaseAction):
with Timer("等待新消息", self.cycle_timers):
# 等待新消息、超时或关闭信号,并获取结果
await self._wait_for_new_message(observation, self.thinking_id, self.log_prefix)
# 从计时器获取实际等待时间
current_waiting = self.cycle_timers.get("等待新消息", 0.0)
if not self._shutting_down:
self.total_no_reply_count += 1
self.total_waiting_time += current_waiting # 累加等待时间
logger.debug(
f"{self.log_prefix} 连续不回复计数增加: {self.total_no_reply_count}/{CONSECUTIVE_NO_REPLY_THRESHOLD}, "
f"本次等待: {current_waiting:.2f}秒, 累计等待: {self.total_waiting_time:.2f}"
)
# 检查是否同时达到次数和时间阈值
time_threshold = 0.66 * WAITING_TIME_THRESHOLD * CONSECUTIVE_NO_REPLY_THRESHOLD
if (
self.total_no_reply_count >= CONSECUTIVE_NO_REPLY_THRESHOLD
and self.total_waiting_time >= time_threshold
):
logger.info(
f"{self.log_prefix} 连续不回复达到阈值 ({self.total_no_reply_count}次) "
f"且累计等待时间达到 {self.total_waiting_time:.2f}秒 (阈值 {time_threshold}秒)"
f"调用回调请求状态转换"
)
# 调用回调。注意:这里不重置计数器和时间,依赖回调函数成功改变状态来隐式重置上下文。
await self.on_consecutive_no_reply_callback()
elif self.total_no_reply_count >= CONSECUTIVE_NO_REPLY_THRESHOLD:
# 仅次数达到阈值,但时间未达到
logger.debug(
f"{self.log_prefix} 连续不回复次数达到阈值 ({self.total_no_reply_count}次) "
f"但累计等待时间 {self.total_waiting_time:.2f}秒 未达到时间阈值 ({time_threshold}秒),暂不调用回调"
)
# else: 次数和时间都未达到阈值,不做处理
return True, "" # 不回复动作没有回复文本

View File

@@ -0,0 +1,203 @@
import traceback
from typing import Tuple, Dict, List, Any, Optional
from src.chat.focus_chat.planners.actions.base_action import BaseAction
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
from src.common.logger_manager import get_logger
from src.chat.person_info.person_info import person_info_manager
from abc import abstractmethod
logger = get_logger("plugin_action")
class PluginAction(BaseAction):
"""插件动作基类
封装了主程序内部依赖提供简化的API接口给插件开发者
"""
def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str, **kwargs):
"""初始化插件动作基类"""
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
# 存储内部服务和对象引用
self._services = {}
# 从kwargs提取必要的内部服务
if "observations" in kwargs:
self._services["observations"] = kwargs["observations"]
if "expressor" in kwargs:
self._services["expressor"] = kwargs["expressor"]
if "chat_stream" in kwargs:
self._services["chat_stream"] = kwargs["chat_stream"]
self.log_prefix = kwargs.get("log_prefix", "")
async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]:
"""根据用户名获取用户ID"""
person_id = person_info_manager.get_person_id_by_person_name(person_name)
user_id = await person_info_manager.get_value(person_id, "user_id")
platform = await person_info_manager.get_value(person_id, "platform")
return platform, user_id
# 提供简化的API方法
async def send_message(self, text: str, target: Optional[str] = None) -> bool:
"""发送消息的简化方法
Args:
text: 要发送的消息文本
target: 目标消息(可选)
Returns:
bool: 是否发送成功
"""
try:
expressor = self._services.get("expressor")
chat_stream = self._services.get("chat_stream")
if not expressor or not chat_stream:
logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
return False
# 构造简化的动作数据
reply_data = {"text": text, "target": target or "", "emojis": []}
# 获取锚定消息(如果有)
observations = self._services.get("observations", [])
chatting_observation: ChattingObservation = next(
obs for obs in observations if isinstance(obs, ChattingObservation)
)
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
# 如果没有找到锚点消息,创建一个占位符
if not anchor_message:
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
anchor_message = await create_empty_anchor_message(
chat_stream.platform, chat_stream.group_info, chat_stream
)
else:
anchor_message.update_chat_stream(chat_stream)
response_set = [
("text", text),
]
# 调用内部方法发送消息
success = await expressor.send_response_messages(
anchor_message=anchor_message,
response_set=response_set,
)
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送消息时出错: {e}")
traceback.print_exc()
return False
async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool:
"""发送消息的简化方法
Args:
text: 要发送的消息文本
target: 目标消息(可选)
Returns:
bool: 是否发送成功
"""
try:
expressor = self._services.get("expressor")
chat_stream = self._services.get("chat_stream")
if not expressor or not chat_stream:
logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
return False
# 构造简化的动作数据
reply_data = {"text": text, "target": target or "", "emojis": []}
# 获取锚定消息(如果有)
observations = self._services.get("observations", [])
chatting_observation: ChattingObservation = next(
obs for obs in observations if isinstance(obs, ChattingObservation)
)
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
# 如果没有找到锚点消息,创建一个占位符
if not anchor_message:
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
anchor_message = await create_empty_anchor_message(
chat_stream.platform, chat_stream.group_info, chat_stream
)
else:
anchor_message.update_chat_stream(chat_stream)
# 调用内部方法发送消息
success, _ = await expressor.deal_reply(
cycle_timers=self.cycle_timers,
action_data=reply_data,
anchor_message=anchor_message,
reasoning=self.reasoning,
thinking_id=self.thinking_id,
)
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送消息时出错: {e}")
return False
def get_chat_type(self) -> str:
"""获取当前聊天类型
Returns:
str: 聊天类型 ("group""private")
"""
chat_stream = self._services.get("chat_stream")
if chat_stream and hasattr(chat_stream, "group_info"):
return "group" if chat_stream.group_info else "private"
return "unknown"
def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]:
"""获取最近的消息
Args:
count: 要获取的消息数量
Returns:
List[Dict]: 消息列表,每个消息包含发送者、内容等信息
"""
messages = []
observations = self._services.get("observations", [])
if observations and len(observations) > 0:
obs = observations[0]
if hasattr(obs, "get_talking_message"):
raw_messages = obs.get_talking_message()
# 转换为简化格式
for msg in raw_messages[-count:]:
simple_msg = {
"sender": msg.get("sender", "未知"),
"content": msg.get("content", ""),
"timestamp": msg.get("timestamp", 0),
}
messages.append(simple_msg)
return messages
@abstractmethod
async def process(self) -> Tuple[bool, str]:
"""插件处理逻辑,子类必须实现此方法
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
pass
async def handle_action(self) -> Tuple[bool, str]:
"""实现BaseAction的抽象方法调用子类的process方法
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
return await self.process()

View File

@@ -1,14 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from src.common.logger_manager import get_logger
from src.chat.utils.timer_calculator import Timer
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
from typing import Tuple, List, Optional
from typing import Tuple, List
from src.chat.heart_flow.observation.observation import Observation
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
@@ -22,29 +19,27 @@ class ReplyAction(BaseAction):
处理构建和发送消息回复的动作。
"""
action_name:str = "reply"
action_description:str = "表达想法,可以只包含文本、表情或两者都有"
action_parameters:dict[str:str] = {
action_name: str = "reply"
action_description: str = "表达想法,可以只包含文本、表情或两者都有"
action_parameters: dict[str:str] = {
"text": "你想要表达的内容(可选)",
"emojis": "描述当前使用表情包的场景(可选)",
"emojis": "描述当前使用表情包的场景,一段话描述(可选)",
"target": "你想要回复的原始文本内容(非必须,仅文本,不包含发送者)(可选)",
}
action_require:list[str] = [
action_require: list[str] = [
"有实质性内容需要表达",
"有人提到你,但你还没有回应他",
"在合适的时候添加表情(不要总是添加)",
"如果你要回复特定某人的某句话或者你想回复较早的消息请在target中指定那句话的原始文本",
"除非有明确的回复目标如果选择了target不用特别提到某个人的人名",
"在合适的时候添加表情(不要总是添加),表情描述要详细,描述当前场景,一段话描述",
"如果你有明确的,要回复特定某人的某句话或者你想回复较早的消息请在target中指定那句话的原始文本",
"一次只回复一个人,一次只回复一个话题,突出重点",
"如果是自己发的消息想继续,需自然衔接",
"避免重复或评价自己的发言,不要和自己聊天",
"注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。"
"注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。不要有额外的符号,尽量简单简短",
]
default = True
def __init__(
self,
action_name: str,
action_data: dict,
reasoning: str,
cycle_timers: dict,
@@ -52,9 +47,8 @@ class ReplyAction(BaseAction):
observations: List[Observation],
expressor: DefaultExpressor,
chat_stream: ChatStream,
current_cycle: CycleDetail,
log_prefix: str,
**kwargs
**kwargs,
):
"""初始化回复动作处理器
@@ -67,14 +61,12 @@ class ReplyAction(BaseAction):
observations: 观察列表
expressor: 表达器
chat_stream: 聊天流
current_cycle: 当前循环信息
log_prefix: 日志前缀
"""
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
self.observations = observations
self.expressor = expressor
self.chat_stream = chat_stream
self._current_cycle = current_cycle
self.log_prefix = log_prefix
async def handle_action(self) -> Tuple[bool, str]:
@@ -89,7 +81,7 @@ class ReplyAction(BaseAction):
reasoning=self.reasoning,
reply_data=self.action_data,
cycle_timers=self.cycle_timers,
thinking_id=self.thinking_id
thinking_id=self.thinking_id,
)
async def _handle_reply(
@@ -105,13 +97,15 @@ class ReplyAction(BaseAction):
"emojis": "微笑" # 表情关键词列表(可选)
}
"""
# 重置连续不回复计数器
self.total_no_reply_count = 0
self.total_waiting_time = 0.0
# 从聊天观察获取锚定消息
observations: ChattingObservation = self.observations[0]
anchor_message = observations.serch_message_by_text(reply_data["target"])
chatting_observation: ChattingObservation = next(
obs for obs in self.observations if isinstance(obs, ChattingObservation)
)
if reply_data.get("target"):
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
else:
anchor_message = None
# 如果没有找到锚点消息,创建一个占位符
if not anchor_message:

View File

@@ -4,25 +4,30 @@ from typing import List, Dict, Any, Optional
from rich.traceback import install
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder
from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.info.obs_info import ObsInfo
from src.chat.focus_chat.info.cycle_info import CycleInfo
from src.chat.focus_chat.info.mind_info import MindInfo
from src.chat.focus_chat.info.action_info import ActionInfo
from src.chat.focus_chat.info.structured_info import StructuredInfo
from src.common.logger_manager import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.individuality.individuality import Individuality
from src.chat.focus_chat.planners.action_factory import ActionManager
from src.chat.focus_chat.planners.action_factory import ActionInfo
from src.chat.focus_chat.planners.action_manager import ActionManager
logger = get_logger("planner")
install(extra_lines=3)
def init_prompt():
Prompt(
"""你的名字是{bot_name},{prompt_personality}{chat_context_description}。需要基于以下信息决定如何参与对话:
"""{extra_info_block}
你需要基于以下信息决定如何参与对话
这些信息可能会有冲突请你整合这些信息并选择一个最合适的action
{chat_content_block}
{mind_info_block}
{cycle_info_block}
@@ -44,17 +49,17 @@ def init_prompt():
}}
请输出你的决策 JSON""",
"planner_prompt",)
"planner_prompt",
)
Prompt(
"""
action_name: {action_name}
描述:{action_description}
参数:
{action_parameters}
{action_parameters}
动作要求:
{action_require}
""",
{action_require}""",
"action_prompt",
)
@@ -64,7 +69,7 @@ class ActionPlanner:
self.log_prefix = log_prefix
# LLM规划器配置
self.planner_llm = LLMRequest(
model=global_config.llm_plan,
model=global_config.model.plan,
max_tokens=1000,
request_type="action_planning", # 用于动作规划
)
@@ -82,31 +87,69 @@ class ActionPlanner:
action = "no_reply" # 默认动作
reasoning = "规划器初始化默认"
action_data = {}
try:
# 获取观察信息
extra_info: list[str] = []
# 首先处理动作变更
for info in all_plan_info:
if isinstance(info, ActionInfo) and info.has_changes():
add_actions = info.get_add_actions()
remove_actions = info.get_remove_actions()
reason = info.get_reason()
# 处理动作的增加
for action_name in add_actions:
if action_name in self.action_manager.get_registered_actions():
self.action_manager.add_action_to_using(action_name)
logger.debug(f"{self.log_prefix}添加动作: {action_name}, 原因: {reason}")
# 处理动作的移除
for action_name in remove_actions:
self.action_manager.remove_action_from_using(action_name)
logger.debug(f"{self.log_prefix}移除动作: {action_name}, 原因: {reason}")
# 如果当前选择的动作被移除了更新为no_reply
if action in remove_actions:
action = "no_reply"
reasoning = f"之前选择的动作{action}已被移除,原因: {reason}"
# 继续处理其他信息
for info in all_plan_info:
if isinstance(info, ObsInfo):
logger.debug(f"{self.log_prefix} 观察信息: {info}")
observed_messages = info.get_talking_message()
observed_messages_str = info.get_talking_message_str_truncate()
chat_type = info.get_chat_type()
if chat_type == "group":
is_group_chat = True
else:
is_group_chat = False
is_group_chat = (chat_type == "group")
elif isinstance(info, MindInfo):
logger.debug(f"{self.log_prefix} 思维信息: {info}")
current_mind = info.get_current_mind()
elif isinstance(info, CycleInfo):
logger.debug(f"{self.log_prefix} 循环信息: {info}")
cycle_info = info.get_observe_info()
elif isinstance(info, StructuredInfo):
logger.debug(f"{self.log_prefix} 结构化信息: {info}")
structured_info = info.get_data()
_structured_info = info.get_data()
elif not isinstance(info, ActionInfo): # 跳过已处理的ActionInfo
extra_info.append(info.get_processed_info())
# 获取当前可用的动作
current_available_actions = self.action_manager.get_using_actions()
# 如果没有可用动作直接返回no_reply
if not current_available_actions:
logger.warning(f"{self.log_prefix}没有可用的动作将使用no_reply")
action = "no_reply"
reasoning = "没有可用的动作"
return {
"action_result": {
"action_type": action,
"action_data": action_data,
"reasoning": reasoning
},
"current_mind": current_mind,
"observed_messages": observed_messages
}
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
prompt = await self.build_planner_prompt(
is_group_chat=is_group_chat, # <-- Pass HFC state
@@ -116,6 +159,7 @@ class ActionPlanner:
# structured_info=structured_info, # <-- Pass SubMind info
current_available_actions=current_available_actions, # <-- Pass determined actions
cycle_info=cycle_info, # <-- Pass cycle info
extra_info=extra_info,
)
# --- 调用 LLM (普通文本生成) ---
@@ -142,15 +186,13 @@ class ActionPlanner:
extracted_action = parsed_json.get("action", "no_reply")
extracted_reasoning = parsed_json.get("reasoning", "LLM未提供理由")
# 新的reply格式
if extracted_action == "reply":
action_data = {
"text": parsed_json.get("text", []),
"emojis": parsed_json.get("emojis", []),
"target": parsed_json.get("target", ""),
}
else:
action_data = {} # 其他动作可能不需要额外数据
# 将所有其他属性添加到action_data
action_data = {}
for key, value in parsed_json.items():
if key not in ["action", "reasoning"]:
action_data[key] = value
# 对于reply动作不需要额外处理因为相关字段已经在上面的循环中添加到action_data
if extracted_action not in current_available_actions:
logger.warning(
@@ -173,7 +215,7 @@ class ActionPlanner:
except Exception as outer_e:
logger.error(f"{self.log_prefix}Planner 处理过程中发生意外错误,规划失败,将执行 no_reply: {outer_e}")
traceback.print_exc()
action = "no_reply" # 发生未知错误,标记为 error 动作
action = "no_reply"
reasoning = f"Planner 内部处理错误: {outer_e}"
logger.debug(
@@ -194,10 +236,8 @@ class ActionPlanner:
"observed_messages": observed_messages,
}
# 返回结果字典
return plan_result
async def build_planner_prompt(
self,
is_group_chat: bool, # Now passed as argument
@@ -206,6 +246,7 @@ class ActionPlanner:
current_mind: Optional[str],
current_available_actions: Dict[str, ActionInfo],
cycle_info: Optional[str],
extra_info: list[str],
) -> str:
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try:
@@ -218,7 +259,6 @@ class ActionPlanner:
)
chat_context_description = f"你正在和 {chat_target_name} 私聊"
chat_content_block = ""
if observed_messages_str:
chat_content_block = f"聊天记录:\n{observed_messages_str}"
@@ -234,7 +274,6 @@ class ActionPlanner:
individuality = Individuality.get_instance()
personality_block = individuality.get_prompt(x_person=2, level=2)
action_options_block = ""
for using_actions_name, using_actions_info in current_available_actions.items():
# print(using_actions_name)
@@ -247,11 +286,11 @@ class ActionPlanner:
param_text = ""
for param_name, param_description in using_actions_info["parameters"].items():
param_text += f"{param_name}: {param_description}\n"
param_text += f" {param_name}: {param_description}\n"
require_text = ""
for require_item in using_actions_info["require"]:
require_text += f"- {require_item}\n"
require_text += f" - {require_item}\n"
using_action_prompt = using_action_prompt.format(
action_name=using_actions_name,
@@ -262,18 +301,19 @@ class ActionPlanner:
action_options_block += using_action_prompt
extra_info_block = "\n".join(extra_info)
extra_info_block = f"以下是一些额外的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是一些额外的信息,现在请你阅读以下内容,进行决策"
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format(
bot_name=global_config.BOT_NICKNAME,
bot_name=global_config.bot.nickname,
prompt_personality=personality_block,
chat_context_description=chat_context_description,
chat_content_block=chat_content_block,
mind_info_block=mind_info_block,
cycle_info_block=cycle_info,
action_options_text=action_options_block,
extra_info_block=extra_info_block,
)
return prompt

View File

@@ -0,0 +1,112 @@
from typing import Dict, Any, List, Optional, Set, Tuple
import time
import random
import string
class MemoryItem:
"""记忆项类,用于存储单个记忆的所有相关信息"""
def __init__(self, data: Any, from_source: str = "", tags: Optional[List[str]] = None):
"""
初始化记忆项
Args:
data: 记忆数据
from_source: 数据来源
tags: 数据标签列表
"""
# 生成可读ID时间戳_随机字符串
timestamp = int(time.time())
random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
self.id = f"{timestamp}_{random_str}"
self.data = data
self.data_type = type(data)
self.from_source = from_source
self.tags = set(tags) if tags else set()
self.timestamp = time.time()
# 修改summary的结构说明用于存储可能的总结信息
# summary结构{
# "brief": "记忆内容主题",
# "detailed": "记忆内容概括",
# "keypoints": ["关键概念1", "关键概念2"],
# "events": ["事件1", "事件2"]
# }
self.summary = None
# 记忆精简次数
self.compress_count = 0
# 记忆提取次数
self.retrieval_count = 0
# 记忆强度 (初始为10)
self.memory_strength = 10.0
# 记忆操作历史记录
# 格式: [(操作类型, 时间戳, 当时精简次数, 当时强度), ...]
self.history = [("create", self.timestamp, self.compress_count, self.memory_strength)]
def add_tag(self, tag: str) -> None:
"""添加标签"""
self.tags.add(tag)
def remove_tag(self, tag: str) -> None:
"""移除标签"""
if tag in self.tags:
self.tags.remove(tag)
def has_tag(self, tag: str) -> bool:
"""检查是否有特定标签"""
return tag in self.tags
def has_all_tags(self, tags: List[str]) -> bool:
"""检查是否有所有指定的标签"""
return all(tag in self.tags for tag in tags)
def matches_source(self, source: str) -> bool:
"""检查来源是否匹配"""
return self.from_source == source
def set_summary(self, summary: Dict[str, Any]) -> None:
"""设置总结信息"""
self.summary = summary
def increase_strength(self, amount: float) -> None:
"""增加记忆强度"""
self.memory_strength = min(10.0, self.memory_strength + amount)
# 记录操作历史
self.record_operation("strengthen")
def decrease_strength(self, amount: float) -> None:
"""减少记忆强度"""
self.memory_strength = max(0.1, self.memory_strength - amount)
# 记录操作历史
self.record_operation("weaken")
def increase_compress_count(self) -> None:
"""增加精简次数并减弱记忆强度"""
self.compress_count += 1
# 记录操作历史
self.record_operation("compress")
def record_retrieval(self) -> None:
"""记录记忆被提取的情况"""
self.retrieval_count += 1
# 提取后强度翻倍
self.memory_strength = min(10.0, self.memory_strength * 2)
# 记录操作历史
self.record_operation("retrieval")
def record_operation(self, operation_type: str) -> None:
"""记录操作历史"""
current_time = time.time()
self.history.append((operation_type, current_time, self.compress_count, self.memory_strength))
def to_tuple(self) -> Tuple[Any, str, Set[str], float, str]:
"""转换为元组格式(为了兼容性)"""
return (self.data, self.from_source, self.tags, self.timestamp, self.id)
def is_memory_valid(self) -> bool:
"""检查记忆是否有效强度是否大于等于1"""
return self.memory_strength >= 1.0

View File

@@ -0,0 +1,781 @@
from typing import Dict, Any, Type, TypeVar, List, Optional
import traceback
from json_repair import repair_json
from rich.traceback import install
from src.common.logger_manager import get_logger
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
import json # 添加json模块导入
install(extra_lines=3)
logger = get_logger("working_memory")
T = TypeVar("T")
class MemoryManager:
def __init__(self, chat_id: str):
"""
初始化工作记忆
Args:
chat_id: 关联的聊天ID用于标识该工作记忆属于哪个聊天
"""
# 关联的聊天ID
self._chat_id = chat_id
# 主存储: 数据类型 -> 记忆项列表
self._memory: Dict[Type, List[MemoryItem]] = {}
# ID到记忆项的映射
self._id_map: Dict[str, MemoryItem] = {}
self.llm_summarizer = LLMRequest(
model=global_config.model.summary, temperature=0.3, max_tokens=512, request_type="memory_summarization"
)
@property
def chat_id(self) -> str:
"""获取关联的聊天ID"""
return self._chat_id
@chat_id.setter
def chat_id(self, value: str):
"""设置关联的聊天ID"""
self._chat_id = value
def push_item(self, memory_item: MemoryItem) -> str:
"""
推送一个已创建的记忆项到工作记忆中
Args:
memory_item: 要存储的记忆项
Returns:
记忆项的ID
"""
data_type = memory_item.data_type
# 确保存在该类型的存储列表
if data_type not in self._memory:
self._memory[data_type] = []
# 添加到内存和ID映射
self._memory[data_type].append(memory_item)
self._id_map[memory_item.id] = memory_item
return memory_item.id
async def push_with_summary(self, data: T, from_source: str = "", tags: Optional[List[str]] = None) -> MemoryItem:
"""
推送一段有类型的信息到工作记忆中,并自动生成总结
Args:
data: 要存储的数据
from_source: 数据来源
tags: 数据标签列表
Returns:
包含原始数据和总结信息的字典
"""
# 如果数据是字符串类型,则先进行总结
if isinstance(data, str):
# 先生成总结
summary = await self.summarize_memory_item(data)
# 准备标签
memory_tags = list(tags) if tags else []
# 创建记忆项
memory_item = MemoryItem(data, from_source, memory_tags)
# 将总结信息保存到记忆项中
memory_item.set_summary(summary)
# 推送记忆项
self.push_item(memory_item)
return memory_item
else:
# 非字符串类型,直接创建并推送记忆项
memory_item = MemoryItem(data, from_source, tags)
self.push_item(memory_item)
return memory_item
def get_by_id(self, memory_id: str) -> Optional[MemoryItem]:
"""
通过ID获取记忆项
Args:
memory_id: 记忆项ID
Returns:
找到的记忆项如果不存在则返回None
"""
memory_item = self._id_map.get(memory_id)
if memory_item:
# 检查记忆强度如果小于1则删除
if not memory_item.is_memory_valid():
print(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除")
self.delete(memory_id)
return None
return memory_item
def get_all_items(self) -> List[MemoryItem]:
"""获取所有记忆项"""
return list(self._id_map.values())
def find_items(
self,
data_type: Optional[Type] = None,
source: Optional[str] = None,
tags: Optional[List[str]] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
memory_id: Optional[str] = None,
limit: Optional[int] = None,
newest_first: bool = False,
min_strength: float = 0.0,
) -> List[MemoryItem]:
"""
按条件查找记忆项
Args:
data_type: 要查找的数据类型
source: 数据来源
tags: 必须包含的标签列表
start_time: 开始时间戳
end_time: 结束时间戳
memory_id: 特定记忆项ID
limit: 返回结果的最大数量
newest_first: 是否按最新优先排序
min_strength: 最小记忆强度
Returns:
符合条件的记忆项列表
"""
# 如果提供了特定ID直接查找
if memory_id:
item = self.get_by_id(memory_id)
return [item] if item else []
results = []
# 确定要搜索的类型列表
types_to_search = [data_type] if data_type else list(self._memory.keys())
# 对每个类型进行搜索
for typ in types_to_search:
if typ not in self._memory:
continue
# 获取该类型的所有项目
items = self._memory[typ]
# 如果需要最新优先,则反转遍历顺序
if newest_first:
items_to_check = list(reversed(items))
else:
items_to_check = items
# 遍历项目
for item in items_to_check:
# 检查来源是否匹配
if source is not None and not item.matches_source(source):
continue
# 检查标签是否匹配
if tags is not None and not item.has_all_tags(tags):
continue
# 检查时间范围
if start_time is not None and item.timestamp < start_time:
continue
if end_time is not None and item.timestamp > end_time:
continue
# 检查记忆强度
if min_strength > 0 and item.memory_strength < min_strength:
continue
# 所有条件都满足,添加到结果中
results.append(item)
# 如果达到限制数量,提前返回
if limit is not None and len(results) >= limit:
return results
return results
async def summarize_memory_item(self, content: str) -> Dict[str, Any]:
"""
使用LLM总结记忆项
Args:
content: 需要总结的内容
Returns:
包含总结、概括、关键概念和事件的字典
"""
prompt = f"""请对以下内容进行总结,总结成记忆,输出四部分:
1. 记忆内容主题精简20字以内让用户可以一眼看出记忆内容是什么
2. 记忆内容概括200字以内让用户可以了解记忆内容的大致内容
3. 关键概念和知识keypoints多条提取关键的概念、知识点和关键词要包含对概念的解释
4. 事件描述events多条描述谁人物在什么时候时间做了什么事件
内容:
{content}
请按以下JSON格式输出
```json
{{
"brief": "记忆内容主题20字以内",
"detailed": "记忆内容概括200字以内",
"keypoints": [
"概念1解释",
"概念2解释",
...
],
"events": [
"事件1谁在什么时候做了什么",
"事件2谁在什么时候做了什么",
...
]
}}
```
请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
"""
default_summary = {
"brief": "主题未知的记忆",
"detailed": "大致内容未知的记忆",
"keypoints": ["未知的概念"],
"events": ["未知的事件"],
}
try:
# 调用LLM生成总结
response, _ = await self.llm_summarizer.generate_response_async(prompt)
# 使用repair_json解析响应
try:
# 使用repair_json修复JSON格式
fixed_json_string = repair_json(response)
# 如果repair_json返回的是字符串需要解析为Python对象
if isinstance(fixed_json_string, str):
try:
json_result = json.loads(fixed_json_string)
except json.JSONDecodeError as decode_error:
logger.error(f"JSON解析错误: {str(decode_error)}")
return default_summary
else:
# 如果repair_json直接返回了字典对象直接使用
json_result = fixed_json_string
# 进行额外的类型检查
if not isinstance(json_result, dict):
logger.error(f"修复后的JSON不是字典类型: {type(json_result)}")
return default_summary
# 确保所有必要字段都存在且类型正确
if "brief" not in json_result or not isinstance(json_result["brief"], str):
json_result["brief"] = "主题未知的记忆"
if "detailed" not in json_result or not isinstance(json_result["detailed"], str):
json_result["detailed"] = "大致内容未知的记忆"
# 处理关键概念
if "keypoints" not in json_result or not isinstance(json_result["keypoints"], list):
json_result["keypoints"] = ["未知的概念"]
else:
# 确保keypoints中的每个项目都是字符串
json_result["keypoints"] = [str(point) for point in json_result["keypoints"] if point is not None]
if not json_result["keypoints"]:
json_result["keypoints"] = ["未知的概念"]
# 处理事件
if "events" not in json_result or not isinstance(json_result["events"], list):
json_result["events"] = ["未知的事件"]
else:
# 确保events中的每个项目都是字符串
json_result["events"] = [str(event) for event in json_result["events"] if event is not None]
if not json_result["events"]:
json_result["events"] = ["未知的事件"]
# 兼容旧版将keypoints和events合并到key_points中
json_result["key_points"] = json_result["keypoints"] + json_result["events"]
return json_result
except Exception as json_error:
logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要")
# 返回默认结构
return default_summary
except Exception as e:
# 出错时返回简单的结构
logger.error(f"生成总结时出错: {str(e)}")
return default_summary
async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]:
"""
对记忆进行精简操作,根据要求修改要点、总结和概括
Args:
memory_id: 记忆ID
requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点
Returns:
修改后的记忆总结字典
"""
# 获取指定ID的记忆项
logger.info(f"精简记忆: {memory_id}")
memory_item = self.get_by_id(memory_id)
if not memory_item:
raise ValueError(f"未找到ID为{memory_id}的记忆项")
# 增加精简次数
memory_item.increase_compress_count()
summary = memory_item.summary
# 使用LLM根据要求对总结、概括和要点进行精简修改
prompt = f"""
请根据以下要求,对记忆内容的主题、概括、关键概念和事件进行精简,模拟记忆的遗忘过程:
要求:{requirements}
你可以随机对关键概念和事件进行压缩,模糊或者丢弃,修改后,同样修改主题和概括
目前主题:{summary["brief"]}
目前概括:{summary["detailed"]}
目前关键概念:
{chr(10).join([f"- {point}" for point in summary.get("keypoints", [])])}
目前事件:
{chr(10).join([f"- {point}" for point in summary.get("events", [])])}
请生成修改后的主题、概括、关键概念和事件,遵循以下格式:
```json
{{
"brief": "修改后的主题20字以内",
"detailed": "修改后的概括200字以内",
"keypoints": [
"修改后的概念1解释",
"修改后的概念2解释"
],
"events": [
"修改后的事件1谁在什么时候做了什么",
"修改后的事件2谁在什么时候做了什么"
]
}}
```
请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
"""
# 检查summary中是否有旧版结构转换为新版结构
if "keypoints" not in summary and "events" not in summary and "key_points" in summary:
# 尝试区分key_points中的keypoints和events
# 简单地将前半部分视为keypoints后半部分视为events
key_points = summary.get("key_points", [])
halfway = len(key_points) // 2
summary["keypoints"] = key_points[:halfway] or ["未知的概念"]
summary["events"] = key_points[halfway:] or ["未知的事件"]
# 定义默认的精简结果
default_refined = {
"brief": summary["brief"],
"detailed": summary["detailed"],
"keypoints": summary.get("keypoints", ["未知的概念"])[:1], # 默认只保留第一个关键概念
"events": summary.get("events", ["未知的事件"])[:1], # 默认只保留第一个事件
}
try:
# 调用LLM修改总结、概括和要点
response, _ = await self.llm_summarizer.generate_response_async(prompt)
logger.info(f"精简记忆响应: {response}")
# 使用repair_json处理响应
try:
# 修复JSON格式
fixed_json_string = repair_json(response)
# 将修复后的字符串解析为Python对象
if isinstance(fixed_json_string, str):
try:
refined_data = json.loads(fixed_json_string)
except json.JSONDecodeError as decode_error:
logger.error(f"JSON解析错误: {str(decode_error)}")
refined_data = default_refined
else:
# 如果repair_json直接返回了字典对象直接使用
refined_data = fixed_json_string
# 确保是字典类型
if not isinstance(refined_data, dict):
logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}")
refined_data = default_refined
# 更新总结、概括
summary["brief"] = refined_data.get("brief", "主题未知的记忆")
summary["detailed"] = refined_data.get("detailed", "大致内容未知的记忆")
# 更新关键概念
keypoints = refined_data.get("keypoints", [])
if isinstance(keypoints, list) and keypoints:
# 确保所有关键概念都是字符串
summary["keypoints"] = [str(point) for point in keypoints if point is not None]
else:
# 如果keypoints不是列表或为空使用默认值
summary["keypoints"] = ["主要概念已遗忘"]
# 更新事件
events = refined_data.get("events", [])
if isinstance(events, list) and events:
# 确保所有事件都是字符串
summary["events"] = [str(event) for event in events if event is not None]
else:
# 如果events不是列表或为空使用默认值
summary["events"] = ["事件细节已遗忘"]
# 兼容旧版维护key_points
summary["key_points"] = summary["keypoints"] + summary["events"]
except Exception as e:
logger.error(f"精简记忆出错: {str(e)}")
traceback.print_exc()
# 出错时使用简化的默认精简
summary["brief"] = summary["brief"] + " (已简化)"
summary["keypoints"] = summary.get("keypoints", ["未知的概念"])[:1]
summary["events"] = summary.get("events", ["未知的事件"])[:1]
summary["key_points"] = summary["keypoints"] + summary["events"]
except Exception as e:
logger.error(f"精简记忆调用LLM出错: {str(e)}")
traceback.print_exc()
# 更新原记忆项的总结
memory_item.set_summary(summary)
return memory_item
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
"""
使单个记忆衰减
Args:
memory_id: 记忆ID
decay_factor: 衰减因子(0-1之间)
Returns:
是否成功衰减
"""
memory_item = self.get_by_id(memory_id)
if not memory_item:
return False
# 计算衰减量(当前强度 * (1-衰减因子)
old_strength = memory_item.memory_strength
decay_amount = old_strength * (1 - decay_factor)
# 更新强度
memory_item.memory_strength = decay_amount
return True
def delete(self, memory_id: str) -> bool:
"""
删除指定ID的记忆项
Args:
memory_id: 要删除的记忆项ID
Returns:
是否成功删除
"""
if memory_id not in self._id_map:
return False
# 获取要删除的项
item = self._id_map[memory_id]
# 从内存中删除
data_type = item.data_type
if data_type in self._memory:
self._memory[data_type] = [i for i in self._memory[data_type] if i.id != memory_id]
# 从ID映射中删除
del self._id_map[memory_id]
return True
def clear(self, data_type: Optional[Type] = None) -> None:
"""
清除记忆中的数据
Args:
data_type: 要清除的数据类型如果为None则清除所有数据
"""
if data_type is None:
# 清除所有数据
self._memory.clear()
self._id_map.clear()
elif data_type in self._memory:
# 清除指定类型的数据
for item in self._memory[data_type]:
if item.id in self._id_map:
del self._id_map[item.id]
del self._memory[data_type]
async def merge_memories(
self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
) -> MemoryItem:
"""
合并两个记忆项
Args:
memory_id1: 第一个记忆项ID
memory_id2: 第二个记忆项ID
reason: 合并原因
delete_originals: 是否删除原始记忆默认为True
Returns:
包含合并后的记忆信息的字典
"""
# 获取两个记忆项
memory_item1 = self.get_by_id(memory_id1)
memory_item2 = self.get_by_id(memory_id2)
if not memory_item1 or not memory_item2:
raise ValueError("无法找到指定的记忆项")
content1 = memory_item1.data
content2 = memory_item2.data
# 获取记忆的摘要信息(如果有)
summary1 = memory_item1.summary
summary2 = memory_item2.summary
# 构建合并提示
prompt = f"""
请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。
合并时保留两段记忆的重要信息,避免重复,确保生成的内容连贯、自然。
合并原因:{reason}
"""
# 如果有摘要信息,添加到提示中
if summary1:
prompt += f"记忆1主题{summary1['brief']}\n"
prompt += f"记忆1概括{summary1['detailed']}\n"
if "keypoints" in summary1:
prompt += "记忆1关键概念\n" + "\n".join([f"- {point}" for point in summary1["keypoints"]]) + "\n\n"
if "events" in summary1:
prompt += "记忆1事件\n" + "\n".join([f"- {point}" for point in summary1["events"]]) + "\n\n"
elif "key_points" in summary1:
prompt += "记忆1要点\n" + "\n".join([f"- {point}" for point in summary1["key_points"]]) + "\n\n"
if summary2:
prompt += f"记忆2主题{summary2['brief']}\n"
prompt += f"记忆2概括{summary2['detailed']}\n"
if "keypoints" in summary2:
prompt += "记忆2关键概念\n" + "\n".join([f"- {point}" for point in summary2["keypoints"]]) + "\n\n"
if "events" in summary2:
prompt += "记忆2事件\n" + "\n".join([f"- {point}" for point in summary2["events"]]) + "\n\n"
elif "key_points" in summary2:
prompt += "记忆2要点\n" + "\n".join([f"- {point}" for point in summary2["key_points"]]) + "\n\n"
# 添加记忆原始内容
prompt += f"""
记忆1原始内容
{content1}
记忆2原始内容
{content2}
请按以下JSON格式输出合并结果
```json
{{
"content": "合并后的记忆内容文本(尽可能保留原信息,但去除重复)",
"brief": "合并后的主题20字以内",
"detailed": "合并后的概括200字以内",
"keypoints": [
"合并后的概念1解释",
"合并后的概念2解释",
"合并后的概念3解释"
],
"events": [
"合并后的事件1谁在什么时候做了什么",
"合并后的事件2谁在什么时候做了什么"
]
}}
```
请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
"""
# 默认合并结果
default_merged = {
"content": f"{content1}\n\n{content2}",
"brief": f"合并:{summary1['brief']} + {summary2['brief']}",
"detailed": f"合并了两个记忆:{summary1['detailed']} 以及 {summary2['detailed']}",
"keypoints": [],
"events": [],
}
# 合并旧版key_points
if "key_points" in summary1:
default_merged["keypoints"].extend(summary1.get("keypoints", []))
default_merged["events"].extend(summary1.get("events", []))
# 如果没有新的结构,尝试从旧结构分离
if not default_merged["keypoints"] and not default_merged["events"] and "key_points" in summary1:
key_points = summary1["key_points"]
halfway = len(key_points) // 2
default_merged["keypoints"].extend(key_points[:halfway])
default_merged["events"].extend(key_points[halfway:])
if "key_points" in summary2:
default_merged["keypoints"].extend(summary2.get("keypoints", []))
default_merged["events"].extend(summary2.get("events", []))
# 如果没有新的结构,尝试从旧结构分离
if not default_merged["keypoints"] and not default_merged["events"] and "key_points" in summary2:
key_points = summary2["key_points"]
halfway = len(key_points) // 2
default_merged["keypoints"].extend(key_points[:halfway])
default_merged["events"].extend(key_points[halfway:])
# 确保列表不为空
if not default_merged["keypoints"]:
default_merged["keypoints"] = ["合并的关键概念"]
if not default_merged["events"]:
default_merged["events"] = ["合并的事件"]
# 添加key_points兼容
default_merged["key_points"] = default_merged["keypoints"] + default_merged["events"]
try:
# 调用LLM合并记忆
response, _ = await self.llm_summarizer.generate_response_async(prompt)
# 处理LLM返回的合并结果
try:
# 修复JSON格式
fixed_json_string = repair_json(response)
# 将修复后的字符串解析为Python对象
if isinstance(fixed_json_string, str):
try:
merged_data = json.loads(fixed_json_string)
except json.JSONDecodeError as decode_error:
logger.error(f"JSON解析错误: {str(decode_error)}")
merged_data = default_merged
else:
# 如果repair_json直接返回了字典对象直接使用
merged_data = fixed_json_string
# 确保是字典类型
if not isinstance(merged_data, dict):
logger.error(f"修复后的JSON不是字典类型: {type(merged_data)}")
merged_data = default_merged
# 确保所有必要字段都存在且类型正确
if "content" not in merged_data or not isinstance(merged_data["content"], str):
merged_data["content"] = default_merged["content"]
if "brief" not in merged_data or not isinstance(merged_data["brief"], str):
merged_data["brief"] = default_merged["brief"]
if "detailed" not in merged_data or not isinstance(merged_data["detailed"], str):
merged_data["detailed"] = default_merged["detailed"]
# 处理关键概念
if "keypoints" not in merged_data or not isinstance(merged_data["keypoints"], list):
merged_data["keypoints"] = default_merged["keypoints"]
else:
# 确保keypoints中的每个项目都是字符串
merged_data["keypoints"] = [str(point) for point in merged_data["keypoints"] if point is not None]
if not merged_data["keypoints"]:
merged_data["keypoints"] = ["合并的关键概念"]
# 处理事件
if "events" not in merged_data or not isinstance(merged_data["events"], list):
merged_data["events"] = default_merged["events"]
else:
# 确保events中的每个项目都是字符串
merged_data["events"] = [str(event) for event in merged_data["events"] if event is not None]
if not merged_data["events"]:
merged_data["events"] = ["合并的事件"]
# 添加key_points兼容
merged_data["key_points"] = merged_data["keypoints"] + merged_data["events"]
except Exception as e:
logger.error(f"合并记忆时处理JSON出错: {str(e)}")
traceback.print_exc()
merged_data = default_merged
except Exception as e:
logger.error(f"合并记忆调用LLM出错: {str(e)}")
traceback.print_exc()
merged_data = default_merged
# 创建新的记忆项
# 合并记忆项的标签
merged_tags = memory_item1.tags.union(memory_item2.tags)
# 取两个记忆项中更强的来源
merged_source = (
memory_item1.from_source
if memory_item1.memory_strength >= memory_item2.memory_strength
else memory_item2.from_source
)
# 创建新的记忆项
merged_memory = MemoryItem(data=merged_data["content"], from_source=merged_source, tags=list(merged_tags))
# 设置合并后的摘要
summary = {
"brief": merged_data["brief"],
"detailed": merged_data["detailed"],
"keypoints": merged_data["keypoints"],
"events": merged_data["events"],
"key_points": merged_data["key_points"],
}
merged_memory.set_summary(summary)
# 记忆强度取两者最大值
merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
# 添加到存储中
self.push_item(merged_memory)
# 如果需要,删除原始记忆
if delete_originals:
self.delete(memory_id1)
self.delete(memory_id2)
return merged_memory
def delete_earliest_memory(self) -> bool:
"""
删除最早的记忆项
Returns:
是否成功删除
"""
# 获取所有记忆项
all_memories = self.get_all_items()
if not all_memories:
return False
# 按时间戳排序,找到最早的记忆项
earliest_memory = min(all_memories, key=lambda item: item.timestamp)
# 删除最早的记忆项
return self.delete(earliest_memory.id)

View File

@@ -0,0 +1,192 @@
from typing import List, Any, Optional
import asyncio
import random
from src.common.logger_manager import get_logger
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
logger = get_logger(__name__)
# 问题是我不知道这个manager是不是需要和其他manager统一管理因为这个manager是从属于每一个聊天流都有自己的定时任务
class WorkingMemory:
"""
工作记忆,负责协调和运作记忆
从属于特定的流用chat_id来标识
"""
def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
"""
初始化工作记忆管理器
Args:
max_memories_per_chat: 每个聊天的最大记忆数量
auto_decay_interval: 自动衰减记忆的时间间隔(秒)
"""
self.memory_manager = MemoryManager(chat_id)
# 记忆容量上限
self.max_memories_per_chat = max_memories_per_chat
# 自动衰减间隔
self.auto_decay_interval = auto_decay_interval
# 衰减任务
self.decay_task = None
# 启动自动衰减任务
self._start_auto_decay()
def _start_auto_decay(self):
"""启动自动衰减任务"""
if self.decay_task is None:
self.decay_task = asyncio.create_task(self._auto_decay_loop())
async def _auto_decay_loop(self):
"""自动衰减循环"""
while True:
await asyncio.sleep(self.auto_decay_interval)
try:
await self.decay_all_memories()
except Exception as e:
print(f"自动衰减记忆时出错: {str(e)}")
async def add_memory(self, content: Any, from_source: str = "", tags: Optional[List[str]] = None):
"""
添加一段记忆到指定聊天
Args:
content: 记忆内容
from_source: 数据来源
tags: 数据标签列表
Returns:
包含记忆信息的字典
"""
memory = await self.memory_manager.push_with_summary(content, from_source, tags)
if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat:
self.remove_earliest_memory()
return memory
def remove_earliest_memory(self):
"""
删除最早的记忆
"""
return self.memory_manager.delete_earliest_memory()
async def retrieve_memory(self, memory_id: str) -> Optional[MemoryItem]:
"""
检索记忆
Args:
chat_id: 聊天ID
memory_id: 记忆ID
Returns:
检索到的记忆项如果不存在则返回None
"""
memory_item = self.memory_manager.get_by_id(memory_id)
if memory_item:
memory_item.retrieval_count += 1
memory_item.increase_strength(5)
return memory_item
return None
async def decay_all_memories(self, decay_factor: float = 0.5):
"""
对所有聊天的所有记忆进行衰减
衰减对记忆进行refine压缩强度会变为原先的0.5
Args:
decay_factor: 衰减因子(0-1之间)
"""
logger.debug(f"开始对所有记忆进行衰减,衰减因子: {decay_factor}")
all_memories = self.memory_manager.get_all_items()
for memory_item in all_memories:
# 如果压缩完小于1会被删除
memory_id = memory_item.id
self.memory_manager.decay_memory(memory_id, decay_factor)
if memory_item.memory_strength < 1:
self.memory_manager.delete(memory_id)
continue
# 计算衰减量
if memory_item.memory_strength < 5:
await self.memory_manager.refine_memory(
memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩"
)
async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem:
"""合并记忆
Args:
memory_str: 记忆内容
"""
return await self.memory_manager.merge_memories(
memory_id1=memory_id1, memory_id2=memory_id2, reason="两端记忆有重复的内容"
)
# 暂时没用,先留着
async def simulate_memory_blur(self, chat_id: str, blur_rate: float = 0.2):
"""
模拟记忆模糊过程,随机选择一部分记忆进行精简
Args:
chat_id: 聊天ID
blur_rate: 模糊比率(0-1之间),表示有多少比例的记忆会被精简
"""
memory = self.get_memory(chat_id)
# 获取所有字符串类型且有总结的记忆
all_summarized_memories = []
for type_items in memory._memory.values():
for item in type_items:
if isinstance(item.data, str) and hasattr(item, "summary") and item.summary:
all_summarized_memories.append(item)
if not all_summarized_memories:
return
# 计算要模糊的记忆数量
blur_count = max(1, int(len(all_summarized_memories) * blur_rate))
# 随机选择要模糊的记忆
memories_to_blur = random.sample(all_summarized_memories, min(blur_count, len(all_summarized_memories)))
# 对选中的记忆进行精简
for memory_item in memories_to_blur:
try:
# 根据记忆强度决定模糊程度
if memory_item.memory_strength > 7:
requirement = "保留所有重要信息,仅略微精简"
elif memory_item.memory_strength > 4:
requirement = "保留核心要点,适度精简细节"
else:
requirement = "只保留最关键的1-2个要点大幅精简内容"
# 进行精简
await memory.refine_memory(memory_item.id, requirement)
print(f"已模糊记忆 {memory_item.id},强度: {memory_item.memory_strength}, 要求: {requirement}")
except Exception as e:
print(f"模糊记忆 {memory_item.id} 时出错: {str(e)}")
async def shutdown(self) -> None:
"""关闭管理器,停止所有任务"""
if self.decay_task and not self.decay_task.done():
self.decay_task.cancel()
try:
await self.decay_task
except asyncio.CancelledError:
pass
def get_all_memories(self) -> List[MemoryItem]:
"""
获取所有记忆项目
Returns:
List[MemoryItem]: 当前工作记忆中的所有记忆项目列表
"""
return self.memory_manager.get_all_items()

View File

@@ -1,13 +1,9 @@
import asyncio
import traceback
from typing import Optional, Coroutine, Callable, Any, List
from src.common.logger_manager import get_logger
# Need manager types for dependency injection
from src.chat.heart_flow.mai_state_manager import MaiStateManager, MaiStateInfo
from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
from src.chat.heart_flow.interest_logger import InterestLogger
logger = get_logger("background_tasks")
@@ -62,23 +58,18 @@ class BackgroundTaskManager:
mai_state_info: MaiStateInfo, # Needs current state info
mai_state_manager: MaiStateManager,
subheartflow_manager: SubHeartflowManager,
interest_logger: InterestLogger,
):
self.mai_state_info = mai_state_info
self.mai_state_manager = mai_state_manager
self.subheartflow_manager = subheartflow_manager
self.interest_logger = interest_logger
# Task references
self._state_update_task: Optional[asyncio.Task] = None
self._cleanup_task: Optional[asyncio.Task] = None
self._logging_task: Optional[asyncio.Task] = None
self._normal_chat_timeout_check_task: Optional[asyncio.Task] = None
self._hf_judge_state_update_task: Optional[asyncio.Task] = None
self._into_focus_task: Optional[asyncio.Task] = None
self._private_chat_activation_task: Optional[asyncio.Task] = None # 新增私聊激活任务引用
self._tasks: List[Optional[asyncio.Task]] = [] # Keep track of all tasks
self._detect_command_from_gui_task: Optional[asyncio.Task] = None # 新增GUI命令检测任务引用
async def start_tasks(self):
"""启动所有后台任务
@@ -97,30 +88,12 @@ class BackgroundTaskManager:
f"聊天状态更新任务已启动 间隔:{STATE_UPDATE_INTERVAL_SECONDS}s",
"_state_update_task",
),
(
lambda: self._run_normal_chat_timeout_check_cycle(NORMAL_CHAT_TIMEOUT_CHECK_INTERVAL_SECONDS),
"debug",
f"聊天超时检查任务已启动 间隔:{NORMAL_CHAT_TIMEOUT_CHECK_INTERVAL_SECONDS}s",
"_normal_chat_timeout_check_task",
),
(
lambda: self._run_absent_into_chat(HF_JUDGE_STATE_UPDATE_INTERVAL_SECONDS),
"debug",
f"状态评估任务已启动 间隔:{HF_JUDGE_STATE_UPDATE_INTERVAL_SECONDS}s",
"_hf_judge_state_update_task",
),
(
self._run_cleanup_cycle,
"info",
f"清理任务已启动 间隔:{CLEANUP_INTERVAL_SECONDS}s",
"_cleanup_task",
),
(
self._run_logging_cycle,
"info",
f"日志任务已启动 间隔:{LOG_INTERVAL_SECONDS}s",
"_logging_task",
),
# 新增兴趣评估任务配置
(
self._run_into_focus_cycle,
@@ -136,13 +109,6 @@ class BackgroundTaskManager:
f"私聊激活检查任务已启动 间隔:{PRIVATE_CHAT_ACTIVATION_CHECK_INTERVAL_SECONDS}s",
"_private_chat_activation_task",
),
# 新增GUI命令检测任务配置
# (
# lambda: self._run_detect_command_from_gui_cycle(3),
# "debug",
# f"GUI命令检测任务已启动 间隔:{3}s",
# "_detect_command_from_gui_task",
# ),
]
# 统一启动所有任务
@@ -207,7 +173,6 @@ class BackgroundTaskManager:
if state_changed:
current_state = self.mai_state_info.get_current_state()
await self.subheartflow_manager.enforce_subheartflow_limits()
# 状态转换处理
@@ -218,15 +183,6 @@ class BackgroundTaskManager:
logger.info("检测到离线,停用所有子心流")
await self.subheartflow_manager.deactivate_all_subflows()
async def _perform_absent_into_chat(self):
"""调用llm检测是否转换ABSENT-CHAT状态"""
logger.debug("[状态评估任务] 开始基于LLM评估子心流状态...")
await self.subheartflow_manager.sbhf_absent_into_chat()
async def _normal_chat_timeout_check_work(self):
"""检查处于CHAT状态的子心流是否因长时间未发言而超时并将其转为ABSENT"""
logger.debug("[聊天超时检查] 开始检查处于CHAT状态的子心流...")
await self.subheartflow_manager.sbhf_chat_into_absent()
async def _perform_cleanup_work(self):
"""执行子心流清理任务
@@ -253,9 +209,6 @@ class BackgroundTaskManager:
# 记录最终清理结果
logger.info(f"[清理任务] 清理完成, 共停止 {stopped_count}/{len(flows_to_stop)} 个子心流")
async def _perform_logging_work(self):
"""执行一轮状态日志记录。"""
await self.interest_logger.log_all_states()
# --- 新增兴趣评估工作函数 ---
async def _perform_into_focus_work(self):
@@ -263,32 +216,16 @@ class BackgroundTaskManager:
# 直接调用 subheartflow_manager 的方法,并传递当前状态信息
await self.subheartflow_manager.sbhf_absent_into_focus()
# --- 结束新增 ---
# --- 结束新增 ---
# --- Specific Task Runners --- #
async def _run_state_update_cycle(self, interval: int):
await _run_periodic_loop(task_name="State Update", interval=interval, task_func=self._perform_state_update_work)
async def _run_absent_into_chat(self, interval: int):
await _run_periodic_loop(task_name="Into Chat", interval=interval, task_func=self._perform_absent_into_chat)
async def _run_normal_chat_timeout_check_cycle(self, interval: int):
await _run_periodic_loop(
task_name="Normal Chat Timeout Check", interval=interval, task_func=self._normal_chat_timeout_check_work
)
async def _run_cleanup_cycle(self):
await _run_periodic_loop(
task_name="Subflow Cleanup", interval=CLEANUP_INTERVAL_SECONDS, task_func=self._perform_cleanup_work
)
async def _run_logging_cycle(self):
await _run_periodic_loop(
task_name="State Logging", interval=LOG_INTERVAL_SECONDS, task_func=self._perform_logging_work
)
# --- 新增兴趣评估任务运行器 ---
async def _run_into_focus_cycle(self):
await _run_periodic_loop(
@@ -304,11 +241,3 @@ class BackgroundTaskManager:
interval=interval,
task_func=self.subheartflow_manager.sbhf_absent_private_into_focus,
)
# # 有api之后删除
# async def _run_detect_command_from_gui_cycle(self, interval: int):
# await _run_periodic_loop(
# task_name="Detect Command from GUI",
# interval=interval,
# task_func=self.subheartflow_manager.detect_command_from_gui,
# )

View File

@@ -4,10 +4,8 @@ from src.config.config import global_config
from src.common.logger_manager import get_logger
from typing import Any, Optional
from src.tools.tool_use import ToolUser
from src.chat.person_info.relationship_manager import relationship_manager # Module instance
from src.chat.heart_flow.mai_state_manager import MaiStateInfo, MaiStateManager
from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
from src.chat.heart_flow.interest_logger import InterestLogger # Import InterestLogger
from src.chat.heart_flow.background_tasks import BackgroundTaskManager # Import BackgroundTaskManager
logger = get_logger("heartflow")
@@ -17,16 +15,10 @@ class Heartflow:
"""主心流协调器,负责初始化并协调各个子系统:
- 状态管理 (MaiState)
- 子心流管理 (SubHeartflow)
- 思考过程 (Mind)
- 日志记录 (InterestLogger)
- 后台任务 (BackgroundTaskManager)
"""
def __init__(self):
# 核心状态
self.current_mind = "什么也没想" # 当前主心流想法
self.past_mind = [] # 历史想法记录
# 状态管理相关
self.current_state: MaiStateInfo = MaiStateInfo() # 当前状态信息
self.mai_state_manager: MaiStateManager = MaiStateManager() # 状态决策管理器
@@ -34,23 +26,11 @@ class Heartflow:
# 子心流管理 (在初始化时传入 current_state)
self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager(self.current_state)
# LLM模型配置
self.llm_model = LLMRequest(
model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow"
)
# 外部依赖模块
self.tool_user_instance = ToolUser() # 工具使用模块
self.relationship_manager_instance = relationship_manager # 关系管理模块
self.interest_logger: InterestLogger = InterestLogger(self.subheartflow_manager, self) # 兴趣日志记录器
# 后台任务管理器 (整合所有定时任务)
self.background_task_manager: BackgroundTaskManager = BackgroundTaskManager(
mai_state_info=self.current_state,
mai_state_manager=self.mai_state_manager,
subheartflow_manager=self.subheartflow_manager,
interest_logger=self.interest_logger,
)
async def get_or_create_subheartflow(self, subheartflow_id: Any) -> Optional["SubHeartflow"]:

View File

@@ -20,9 +20,9 @@ MAX_REPLY_PROBABILITY = 1
class InterestChatting:
def __init__(
self,
decay_rate=global_config.default_decay_rate_per_second,
decay_rate=global_config.focus_chat.default_decay_rate_per_second,
max_interest=MAX_INTEREST,
trigger_threshold=global_config.reply_trigger_threshold,
trigger_threshold=global_config.focus_chat.reply_trigger_threshold,
max_probability=MAX_REPLY_PROBABILITY,
):
# 基础属性初始化

View File

@@ -1,212 +0,0 @@
import asyncio
import time
import json
import os
import traceback
from typing import TYPE_CHECKING, Dict, List
from src.common.logger_manager import get_logger
# Need chat_manager to get stream names
from src.chat.message_receive.chat_stream import chat_manager
if TYPE_CHECKING:
from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
from src.chat.heart_flow.sub_heartflow import SubHeartflow
from src.chat.heart_flow.heartflow import Heartflow # 导入 Heartflow 类型
logger = get_logger("interest")
# Consider moving log directory/filename constants here
LOG_DIRECTORY = "logs/interest"
HISTORY_LOG_FILENAME = "interest_history.log"
def _ensure_log_directory():
"""确保日志目录存在。"""
os.makedirs(LOG_DIRECTORY, exist_ok=True)
logger.info(f"已确保日志目录 '{LOG_DIRECTORY}' 存在")
def _clear_and_create_log_file():
"""清除日志文件并创建新的日志文件。"""
if os.path.exists(os.path.join(LOG_DIRECTORY, HISTORY_LOG_FILENAME)):
os.remove(os.path.join(LOG_DIRECTORY, HISTORY_LOG_FILENAME))
with open(os.path.join(LOG_DIRECTORY, HISTORY_LOG_FILENAME), "w", encoding="utf-8") as f:
f.write("")
class InterestLogger:
"""负责定期记录主心流和所有子心流的状态到日志文件。"""
def __init__(self, subheartflow_manager: "SubHeartflowManager", heartflow: "Heartflow"):
"""
初始化 InterestLogger。
Args:
subheartflow_manager: 子心流管理器实例。
heartflow: 主心流实例,用于获取主心流状态。
"""
self.subheartflow_manager = subheartflow_manager
self.heartflow = heartflow # 存储 Heartflow 实例
self._history_log_file_path = os.path.join(LOG_DIRECTORY, HISTORY_LOG_FILENAME)
_ensure_log_directory()
_clear_and_create_log_file()
async def get_all_subflow_states(self) -> Dict[str, Dict]:
"""并发获取所有活跃子心流的当前完整状态。"""
all_flows: List["SubHeartflow"] = self.subheartflow_manager.get_all_subheartflows()
tasks = []
results = {}
if not all_flows:
# logger.debug("未找到任何子心流状态")
return results
for subheartflow in all_flows:
if await self.subheartflow_manager.get_or_create_subheartflow(subheartflow.subheartflow_id):
tasks.append(
asyncio.create_task(subheartflow.get_full_state(), name=f"get_state_{subheartflow.subheartflow_id}")
)
else:
logger.warning(f"子心流 {subheartflow.subheartflow_id} 在创建任务前已消失")
if tasks:
done, pending = await asyncio.wait(tasks, timeout=5.0)
if pending:
logger.warning(f"获取子心流状态超时,有 {len(pending)} 个任务未完成")
for task in pending:
task.cancel()
for task in done:
stream_id_str = task.get_name().split("get_state_")[-1]
stream_id = stream_id_str
if task.cancelled():
logger.warning(f"获取子心流 {stream_id} 状态的任务已取消(超时)", exc_info=False)
elif task.exception():
exc = task.exception()
logger.warning(f"获取子心流 {stream_id} 状态出错: {exc}")
else:
result = task.result()
results[stream_id] = result
logger.trace(f"成功获取 {len(results)} 个子心流的完整状态")
return results
async def log_all_states(self):
"""获取主心流状态和所有子心流的完整状态并写入日志文件。"""
try:
current_timestamp = time.time()
# main_mind = self.heartflow.current_mind
# 获取 Mai 状态名称
mai_state_name = self.heartflow.current_state.get_current_state().name
all_subflow_states = await self.get_all_subflow_states()
log_entry_base = {
"timestamp": round(current_timestamp, 2),
# "main_mind": main_mind,
"mai_state": mai_state_name,
"subflow_count": len(all_subflow_states),
"subflows": [],
}
if not all_subflow_states:
# logger.debug("没有获取到任何子心流状态,仅记录主心流状态")
with open(self._history_log_file_path, "a", encoding="utf-8") as f:
f.write(json.dumps(log_entry_base, ensure_ascii=False) + "\n")
return
subflow_details = []
items_snapshot = list(all_subflow_states.items())
for stream_id, state in items_snapshot:
group_name = stream_id
try:
chat_stream = chat_manager.get_stream(stream_id)
if chat_stream:
if chat_stream.group_info:
group_name = chat_stream.group_info.group_name
elif chat_stream.user_info:
group_name = f"私聊_{chat_stream.user_info.user_nickname}"
except Exception as e:
logger.trace(f"无法获取 stream_id {stream_id} 的群组名: {e}")
interest_state = state.get("interest_state", {})
subflow_entry = {
"stream_id": stream_id,
"group_name": group_name,
"sub_mind": state.get("current_mind", "未知"),
"sub_chat_state": state.get("chat_state", "未知"),
"interest_level": interest_state.get("interest_level", 0.0),
"start_hfc_probability": interest_state.get("start_hfc_probability", 0.0),
# "is_above_threshold": interest_state.get("is_above_threshold", False),
}
subflow_details.append(subflow_entry)
log_entry_base["subflows"] = subflow_details
with open(self._history_log_file_path, "a", encoding="utf-8") as f:
f.write(json.dumps(log_entry_base, ensure_ascii=False) + "\n")
except IOError as e:
logger.error(f"写入状态日志到 {self._history_log_file_path} 出错: {e}")
except Exception as e:
logger.error(f"记录状态时发生意外错误: {e}")
logger.error(traceback.format_exc())
async def api_get_all_states(self):
"""获取主心流和所有子心流的状态。"""
try:
current_timestamp = time.time()
# main_mind = self.heartflow.current_mind
# 获取 Mai 状态名称
mai_state_name = self.heartflow.current_state.get_current_state().name
all_subflow_states = await self.get_all_subflow_states()
log_entry_base = {
"timestamp": round(current_timestamp, 2),
# "main_mind": main_mind,
"mai_state": mai_state_name,
"subflow_count": len(all_subflow_states),
"subflows": [],
}
subflow_details = []
items_snapshot = list(all_subflow_states.items())
for stream_id, state in items_snapshot:
group_name = stream_id
try:
chat_stream = chat_manager.get_stream(stream_id)
if chat_stream:
if chat_stream.group_info:
group_name = chat_stream.group_info.group_name
elif chat_stream.user_info:
group_name = f"私聊_{chat_stream.user_info.user_nickname}"
except Exception as e:
logger.trace(f"无法获取 stream_id {stream_id} 的群组名: {e}")
interest_state = state.get("interest_state", {})
subflow_entry = {
"stream_id": stream_id,
"group_name": group_name,
"sub_mind": state.get("current_mind", "未知"),
"sub_chat_state": state.get("chat_state", "未知"),
"interest_level": interest_state.get("interest_level", 0.0),
"start_hfc_probability": interest_state.get("start_hfc_probability", 0.0),
# "is_above_threshold": interest_state.get("is_above_threshold", False),
}
subflow_details.append(subflow_entry)
log_entry_base["subflows"] = subflow_details
return subflow_details
except Exception as e:
logger.error(f"记录状态时发生意外错误: {e}")
logger.error(traceback.format_exc())

View File

@@ -13,72 +13,24 @@ logger = get_logger("mai_state")
# The line `enable_unlimited_hfc_chat = False` is setting a configuration parameter that controls
# whether a specific debugging feature is enabled or not. When `enable_unlimited_hfc_chat` is set to
# `False`, it means that the debugging feature for unlimited focused chatting is disabled.
enable_unlimited_hfc_chat = True # 调试用:无限专注聊天
# enable_unlimited_hfc_chat = False
# enable_unlimited_hfc_chat = True # 调试用:无限专注聊天
enable_unlimited_hfc_chat = False
prevent_offline_state = True
# 目前默认不启用OFFLINE状
# 不同状态下普通聊天的最大消息数
base_normal_chat_num = global_config.base_normal_chat_num
base_focused_chat_num = global_config.base_focused_chat_num
MAX_NORMAL_CHAT_NUM_PEEKING = int(base_normal_chat_num / 2)
MAX_NORMAL_CHAT_NUM_NORMAL = base_normal_chat_num
MAX_NORMAL_CHAT_NUM_FOCUSED = base_normal_chat_num + 1
# 不同状态下专注聊天的最大消息数
MAX_FOCUSED_CHAT_NUM_PEEKING = int(base_focused_chat_num / 2)
MAX_FOCUSED_CHAT_NUM_NORMAL = base_focused_chat_num
MAX_FOCUSED_CHAT_NUM_FOCUSED = base_focused_chat_num + 2
# -- 状态定义 --
# 目前默认不启用OFFLINE状
class MaiState(enum.Enum):
"""
聊天状态:
OFFLINE: 不在线:回复概率极低,不会进行任何聊天
PEEKING: 看一眼手机:回复概率较低,会进行一些普通聊天
NORMAL_CHAT: 正常看手机:回复概率较高,会进行一些普通聊天和少量的专注聊天
FOCUSED_CHAT: 专注聊天:回复概率极高,会进行专注聊天和少量的普通聊天
"""
OFFLINE = "不在线"
PEEKING = "看一眼手机"
NORMAL_CHAT = "正常看手机"
FOCUSED_CHAT = "专心看手机"
def get_normal_chat_max_num(self):
# 调试用
if enable_unlimited_hfc_chat:
return 1000
if self == MaiState.OFFLINE:
return 0
elif self == MaiState.PEEKING:
return MAX_NORMAL_CHAT_NUM_PEEKING
elif self == MaiState.NORMAL_CHAT:
return MAX_NORMAL_CHAT_NUM_NORMAL
elif self == MaiState.FOCUSED_CHAT:
return MAX_NORMAL_CHAT_NUM_FOCUSED
return None
def get_focused_chat_max_num(self):
# 调试用
if enable_unlimited_hfc_chat:
return 1000
if self == MaiState.OFFLINE:
return 0
elif self == MaiState.PEEKING:
return MAX_FOCUSED_CHAT_NUM_PEEKING
elif self == MaiState.NORMAL_CHAT:
return MAX_FOCUSED_CHAT_NUM_NORMAL
elif self == MaiState.FOCUSED_CHAT:
return MAX_FOCUSED_CHAT_NUM_FOCUSED
return None
class MaiStateInfo:
def __init__(self):
@@ -148,34 +100,18 @@ class MaiStateManager:
_time_since_last_min_check = current_time - current_state_info.last_min_check_time
next_state: Optional[MaiState] = None
# 辅助函数:根据 prevent_offline_state 标志调整目标状态
def _resolve_offline(candidate_state: MaiState) -> MaiState:
# 现在不再切换到OFFLINE直接返回当前状态
if candidate_state == MaiState.OFFLINE:
return current_status
return candidate_state
if current_status == MaiState.OFFLINE:
logger.info("当前[离线],没看手机,思考要不要上线看看......")
elif current_status == MaiState.PEEKING:
logger.info("当前[看一眼手机],思考要不要继续聊下去......")
elif current_status == MaiState.NORMAL_CHAT:
logger.info("当前在[正常看手机]思考要不要继续聊下去......")
elif current_status == MaiState.FOCUSED_CHAT:
logger.info("当前在[专心看手机]思考要不要继续聊下去......")
# 1. 移除每分钟概率切换到OFFLINE的逻辑
# if time_since_last_min_check >= 60:
# if current_status != MaiState.OFFLINE:
# if random.random() < 0.03: # 3% 概率切换到 OFFLINE
# potential_next = MaiState.OFFLINE
# resolved_next = _resolve_offline(potential_next)
# logger.debug(f"概率触发下线resolve 为 {resolved_next.value}")
# # 只有当解析后的状态与当前状态不同时才设置 next_state
# if resolved_next != current_status:
# next_state = resolved_next
# 2. 状态持续时间规则 (只有在规则1没有触发状态改变时才检查)
if next_state is None:
time_limit_exceeded = False
choices_list = []
@@ -183,44 +119,33 @@ class MaiStateManager:
rule_id = ""
if current_status == MaiState.OFFLINE:
# OFFLINE 状态不再自动切换,直接返回 None
return None
elif current_status == MaiState.PEEKING:
if time_in_current_status >= 600: # PEEKING 最多持续 600 秒
time_limit_exceeded = True
rule_id = "2.2 (From PEEKING)"
weights = [50, 50]
choices_list = [MaiState.NORMAL_CHAT, MaiState.FOCUSED_CHAT]
elif current_status == MaiState.NORMAL_CHAT:
if time_in_current_status >= 300: # NORMAL_CHAT 最多持续 300 秒
time_limit_exceeded = True
rule_id = "2.3 (From NORMAL_CHAT)"
weights = [50, 50]
choices_list = [MaiState.PEEKING, MaiState.FOCUSED_CHAT]
weights = [100]
choices_list = [MaiState.FOCUSED_CHAT]
elif current_status == MaiState.FOCUSED_CHAT:
if time_in_current_status >= 600: # FOCUSED_CHAT 最多持续 600 秒
time_limit_exceeded = True
rule_id = "2.4 (From FOCUSED_CHAT)"
weights = [50, 50]
choices_list = [MaiState.NORMAL_CHAT, MaiState.PEEKING]
weights = [100]
choices_list = [MaiState.NORMAL_CHAT]
if time_limit_exceeded:
next_state_candidate = random.choices(choices_list, weights=weights, k=1)[0]
resolved_candidate = _resolve_offline(next_state_candidate)
logger.debug(
f"规则{rule_id}:时间到,随机选择 {next_state_candidate.value}resolve 为 {resolved_candidate.value}"
f"规则{rule_id}:时间到,切换到 {next_state_candidate.value}resolve 为 {resolved_candidate.value}"
)
next_state = resolved_candidate # 直接使用解析后的状态
next_state = resolved_candidate
# 注意enable_unlimited_hfc_chat 优先级高于 prevent_offline_state
# 如果触发了这个它会覆盖上面规则2设置的 next_state
if enable_unlimited_hfc_chat:
logger.debug("调试用:开挂了,强制切换到专注聊天")
next_state = MaiState.FOCUSED_CHAT
# --- 最终决策 --- #
# 如果决定了下一个状态,且这个状态与当前状态不同,则返回下一个状态
if next_state is not None and next_state != current_status:
return next_state
else:
return None # 没有状态转换发生或无需重置计时器
return None

View File

@@ -14,6 +14,7 @@ from typing import Optional
import difflib
from src.chat.message_receive.message import MessageRecv # 添加 MessageRecv 导入
from src.chat.heart_flow.observation.observation import Observation
from src.common.logger_manager import get_logger
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import Prompt
@@ -43,6 +44,7 @@ class ChattingObservation(Observation):
def __init__(self, chat_id):
super().__init__(chat_id)
self.chat_id = chat_id
self.platform = "qq"
# --- Initialize attributes (defaults) ---
self.is_group_chat: bool = False
@@ -53,19 +55,20 @@ class ChattingObservation(Observation):
self.talking_message = []
self.talking_message_str = ""
self.talking_message_str_truncate = ""
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
self.max_now_obs_len = global_config.observation_context_size
self.overlap_len = global_config.compressed_length
self.mid_memorys = []
self.max_mid_memory_len = global_config.compress_length_limit
self.name = global_config.bot.nickname
self.nick_name = global_config.bot.alias_names
self.max_now_obs_len = global_config.chat.observation_context_size
self.overlap_len = global_config.focus_chat.compressed_length
self.mid_memories = []
self.max_mid_memory_len = global_config.focus_chat.compress_length_limit
self.mid_memory_info = ""
self.person_list = []
self.oldest_messages = []
self.oldest_messages_str = ""
self.compressor_prompt = ""
self.llm_summary = LLMRequest(
model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
# TODO: API-Adapter修改标记
self.model_summary = LLMRequest(
model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
)
async def initialize(self):
@@ -83,7 +86,7 @@ class ChattingObservation(Observation):
for id in ids:
print(f"id{id}")
try:
for mid_memory in self.mid_memorys:
for mid_memory in self.mid_memories:
if mid_memory["id"] == id:
mid_memory_by_id = mid_memory
msg_str = ""
@@ -101,11 +104,11 @@ class ChattingObservation(Observation):
else:
mid_memory_str = "之前的聊天内容:\n"
for mid_memory in self.mid_memorys:
for mid_memory in self.mid_memories:
mid_memory_str += f"{mid_memory['theme']}\n"
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str
def serch_message_by_text(self, text: str) -> Optional[MessageRecv]:
def search_message_by_text(self, text: str) -> Optional[MessageRecv]:
"""
根据回复的纯文本
1. 在talking_message中查找最新的最匹配的消息
@@ -118,12 +121,12 @@ class ChattingObservation(Observation):
for message in reverse_talking_message:
if message["processed_plain_text"] == text:
find_msg = message
logger.debug(f"找到的锚定消息find_msg: {find_msg}")
# logger.debug(f"找到的锚定消息find_msg: {find_msg}")
break
else:
similarity = difflib.SequenceMatcher(None, text, message["processed_plain_text"]).ratio()
msg_list.append({"message": message, "similarity": similarity})
logger.debug(f"对锚定消息检查message: {message['processed_plain_text']},similarity: {similarity}")
# logger.debug(f"对锚定消息检查message: {message['processed_plain_text']},similarity: {similarity}")
if not find_msg:
if msg_list:
msg_list.sort(key=lambda x: x["similarity"], reverse=True)
@@ -137,8 +140,23 @@ class ChattingObservation(Observation):
return None
# logger.debug(f"找到的锚定消息find_msg: {find_msg}")
group_info = find_msg.get("chat_info", {}).get("group_info")
user_info = find_msg.get("chat_info", {}).get("user_info")
# 创建所需的user_info字段
user_info = {
"platform": find_msg.get("user_platform", ""),
"user_id": find_msg.get("user_id", ""),
"user_nickname": find_msg.get("user_nickname", ""),
"user_cardname": find_msg.get("user_cardname", ""),
}
# 创建所需的group_info字段如果是群聊的话
group_info = {}
if find_msg.get("chat_info_group_id"):
group_info = {
"platform": find_msg.get("chat_info_group_platform", ""),
"group_id": find_msg.get("chat_info_group_id", ""),
"group_name": find_msg.get("chat_info_group_name", ""),
}
content_format = ""
accept_format = ""
@@ -150,7 +168,7 @@ class ChattingObservation(Observation):
}
message_info = {
"platform": find_msg.get("platform"),
"platform": self.platform,
"message_id": find_msg.get("message_id"),
"time": find_msg.get("time"),
"group_info": group_info,
@@ -179,6 +197,8 @@ class ChattingObservation(Observation):
limit_mode="latest",
)
# print(f"new_messages_list: {new_messages_list}")
last_obs_time_mark = self.last_observe_time
if new_messages_list:
self.last_observe_time = new_messages_list[-1]["time"]
@@ -190,6 +210,7 @@ class ChattingObservation(Observation):
oldest_messages = self.talking_message[:messages_to_remove_count]
self.talking_message = self.talking_message[messages_to_remove_count:] # 保留后半部分,即最新的
# print(f"压缩中oldest_messages: {oldest_messages}")
oldest_messages_str = await build_readable_messages(
messages=oldest_messages, timestamp_mode="normal", read_mark=0
)
@@ -232,21 +253,24 @@ class ChattingObservation(Observation):
self.oldest_messages = oldest_messages
self.oldest_messages_str = oldest_messages_str
# 构建中
# print(f"构建中self.talking_message: {self.talking_message}")
self.talking_message_str = await build_readable_messages(
messages=self.talking_message,
timestamp_mode="lite",
read_mark=last_obs_time_mark,
)
# print(f"构建中self.talking_message_str: {self.talking_message_str}")
self.talking_message_str_truncate = await build_readable_messages(
messages=self.talking_message,
timestamp_mode="normal",
read_mark=last_obs_time_mark,
truncate=True,
)
# print(f"构建中self.talking_message_str_truncate: {self.talking_message_str_truncate}")
self.person_list = await get_person_id_list(self.talking_message)
# print(f"self.11111person_list: {self.person_list}")
# print(f"构建中self.person_list: {self.person_list}")
logger.trace(
f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}"

View File

@@ -3,6 +3,7 @@
from datetime import datetime
from src.common.logger_manager import get_logger
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from src.chat.focus_chat.planners.action_manager import ActionManager
from typing import List
# Import the new utility function
@@ -16,15 +17,20 @@ class HFCloopObservation:
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
self.history_loop: List[CycleDetail] = []
self.action_manager: ActionManager = None
self.all_actions = {}
def get_observe_info(self):
return self.observe_info
def add_loop_info(self, loop_info: CycleDetail):
# logger.debug(f"添加循环信息111111111111111111111111111111111111: {loop_info}")
# print(f"添加循环信息111111111111111111111111111111111111: {loop_info}")
self.history_loop.append(loop_info)
def set_action_manager(self, action_manager: ActionManager):
self.action_manager = action_manager
self.all_actions = self.action_manager.get_registered_actions()
async def observe(self):
recent_active_cycles: List[CycleDetail] = []
for cycle in reversed(self.history_loop):
@@ -62,7 +68,6 @@ class HFCloopObservation:
if cycle_info_block:
cycle_info_block = f"\n你最近的回复\n{cycle_info_block}\n"
else:
# 如果最近的活动循环不是文本回复,或者没有活动循环
cycle_info_block = "\n"
# 获取history_loop中最新添加的
@@ -72,8 +77,16 @@ class HFCloopObservation:
end_time = last_loop.end_time
if start_time is not None and end_time is not None:
time_diff = int(end_time - start_time)
cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}分钟\n"
if time_diff > 60:
cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff / 60}分钟\n"
else:
cycle_info_block += "\n无法获取上一次阅读消息的时间\n"
cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}\n"
else:
cycle_info_block += "\n你还没看过消息\n"
using_actions = self.action_manager.get_using_actions()
for action_name, action_info in using_actions.items():
action_description = action_info["description"]
cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n"
self.observe_info = cycle_info_block

View File

@@ -1,55 +0,0 @@
from src.chat.heart_flow.observation.observation import Observation
from datetime import datetime
from src.common.logger_manager import get_logger
import traceback
# Import the new utility function
from src.chat.memory_system.Hippocampus import HippocampusManager
import jieba
from typing import List
logger = get_logger("memory")
class MemoryObservation(Observation):
def __init__(self, observe_id):
super().__init__(observe_id)
self.observe_info: str = ""
self.context: str = ""
self.running_memory: List[dict] = []
def get_observe_info(self):
for memory in self.running_memory:
self.observe_info += f"{memory['topic']}:{memory['content']}\n"
return self.observe_info
async def observe(self):
# ---------- 2. 获取记忆 ----------
try:
# 从聊天内容中提取关键词
chat_words = set(jieba.cut(self.context))
# 过滤掉停用词和单字词
keywords = [word for word in chat_words if len(word) > 1]
# 去重并限制数量
keywords = list(set(keywords))[:5]
logger.debug(f"取的关键词: {keywords}")
# 调用记忆系统获取相关记忆
related_memory = await HippocampusManager.get_instance().get_memory_from_topic(
valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
)
logger.debug(f"获取到的记忆: {related_memory}")
if related_memory:
for topic, memory in related_memory:
# 将记忆添加到 running_memory
self.running_memory.append(
{"topic": topic, "content": memory, "timestamp": datetime.now().isoformat()}
)
logger.debug(f"添加新记忆: {topic} - {memory}")
except Exception as e:
logger.error(f"观察 记忆时出错: {e}")
logger.error(traceback.format_exc())

View File

@@ -0,0 +1,32 @@
from datetime import datetime
from src.common.logger_manager import get_logger
# Import the new utility function
logger = get_logger("observation")
# 所有观察的基类
class StructureObservation:
def __init__(self, observe_id):
self.observe_info = ""
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
self.history_loop = []
self.structured_info = []
def get_observe_info(self):
return self.structured_info
def add_structured_info(self, structured_info: dict):
self.structured_info.append(structured_info)
async def observe(self):
observed_structured_infos = []
for structured_info in self.structured_info:
if structured_info.get("ttl") > 0:
structured_info["ttl"] -= 1
observed_structured_infos.append(structured_info)
logger.debug(f"观察到结构化信息仍旧在: {structured_info}")
self.structured_info = observed_structured_infos

View File

@@ -2,33 +2,33 @@
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
from datetime import datetime
from src.common.logger_manager import get_logger
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
from typing import List
# Import the new utility function
logger = get_logger("observation")
# 所有观察的基类
class WorkingObservation:
def __init__(self, observe_id):
class WorkingMemoryObservation:
def __init__(self, observe_id, working_memory: WorkingMemory):
self.observe_info = ""
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
self.history_loop = []
self.structured_info = []
self.last_observe_time = datetime.now().timestamp()
self.working_memory = working_memory
self.retrieved_working_memory = []
def get_observe_info(self):
return self.structured_info
return self.working_memory
def add_structured_info(self, structured_info: dict):
self.structured_info.append(structured_info)
def add_retrieved_working_memory(self, retrieved_working_memory: List[MemoryItem]):
self.retrieved_working_memory.append(retrieved_working_memory)
def get_retrieved_working_memory(self):
return self.retrieved_working_memory
async def observe(self):
observed_structured_infos = []
for structured_info in self.structured_info:
if structured_info.get("ttl") > 0:
structured_info["ttl"] -= 1
observed_structured_infos.append(structured_info)
logger.debug(f"观察到结构化信息仍旧在: {structured_info}")
self.structured_info = observed_structured_infos
pass

View File

@@ -89,6 +89,14 @@ class SubHeartflow:
await self.interest_chatting.initialize()
logger.debug(f"{self.log_prefix} InterestChatting 实例已初始化。")
# 创建并初始化 normal_chat_instance
chat_stream = chat_manager.get_stream(self.chat_id)
if chat_stream:
self.normal_chat_instance = NormalChat(chat_stream=chat_stream,interest_dict=self.get_interest_dict())
await self.normal_chat_instance.initialize()
await self.normal_chat_instance.start_chat()
logger.info(f"{self.log_prefix} NormalChat 实例已创建并启动。")
def update_last_chat_state_time(self):
self.chat_state_last_time = time.time() - self.chat_state_changed_time
@@ -181,8 +189,7 @@ class SubHeartflow:
# 创建 HeartFChatting 实例,并传递 从构造函数传入的 回调函数
self.heart_fc_instance = HeartFChatting(
chat_id=self.subheartflow_id,
observations=self.observations, # 传递所有观察者
on_consecutive_no_reply_callback=self.hfc_no_reply_callback, # <-- Use stored callback
observations=self.observations,
)
# 初始化并启动 HeartFChatting
@@ -200,55 +207,41 @@ class SubHeartflow:
self.heart_fc_instance = None # 创建或初始化异常,清理实例
return False
async def change_chat_state(self, new_state: "ChatState"):
"""更新sub_heartflow的聊天状态并管理 HeartFChatting 和 NormalChat 实例及任务"""
async def change_chat_state(self, new_state: ChatState) -> None:
"""
改变聊天状态。
如果转换到CHAT或FOCUSED状态时超过限制会保持当前状态。
"""
current_state = self.chat_state.chat_status
state_changed = False
log_prefix = f"[{self.log_prefix}]"
if current_state == new_state:
if new_state == ChatState.CHAT:
logger.debug(f"{log_prefix} 准备进入或保持 普通聊天 状态")
if await self._start_normal_chat():
logger.debug(f"{log_prefix} 成功进入或保持 NormalChat 状态。")
state_changed = True
else:
logger.error(f"{log_prefix} 启动 NormalChat 失败,无法进入 CHAT 状态。")
# 启动失败时,保持当前状态
return
log_prefix = self.log_prefix
state_changed = False # 标记状态是否实际发生改变
# --- 状态转换逻辑 ---
if new_state == ChatState.CHAT:
# 移除限额检查逻辑
logger.debug(f"{log_prefix} 准备进入或保持 聊天 状态")
if current_state == ChatState.FOCUSED:
if await self._start_normal_chat(rewind=False):
# logger.info(f"{log_prefix} 成功进入或保持 NormalChat 状态。")
state_changed = True
else:
logger.error(f"{log_prefix} 从FOCUSED状态启动 NormalChat 失败,无法进入 CHAT 状态。")
# 考虑是否需要回滚状态或采取其他措施
return # 启动失败,不改变状态
else:
if await self._start_normal_chat(rewind=True):
# logger.info(f"{log_prefix} 成功进入或保持 NormalChat 状态。")
state_changed = True
else:
logger.error(f"{log_prefix} 从ABSENT状态启动 NormalChat 失败,无法进入 CHAT 状态。")
# 考虑是否需要回滚状态或采取其他措施
return # 启动失败,不改变状态
elif new_state == ChatState.FOCUSED:
# 移除限额检查逻辑
logger.debug(f"{log_prefix} 准备进入或保持 专注聊天 状态")
if await self._start_heart_fc_chat():
logger.debug(f"{log_prefix} 成功进入或保持 HeartFChatting 状态。")
state_changed = True
else:
logger.error(f"{log_prefix} 启动 HeartFChatting 失败,无法进入 FOCUSED 状态。")
# 启动失败状态回滚到之前的状态或ABSENT这里保持不改变
return # 启动失败,不改变状态
# 启动失败时,保持当前状态
return
elif new_state == ChatState.ABSENT:
logger.info(f"{log_prefix} 进入 ABSENT 状态,停止所有聊天活动...")
self.clear_interest_dict()
await self._stop_normal_chat()
await self._stop_heart_fc_chat()
state_changed = True # 总是可以成功转换到 ABSENT
state_changed = True
# --- 更新状态和最后活动时间 ---
if state_changed:
@@ -263,7 +256,6 @@ class SubHeartflow:
self.chat_state_last_time = 0
self.chat_state_changed_time = time.time()
else:
# 如果因为某些原因(如启动失败)没有成功改变状态,记录一下
logger.debug(
f"{log_prefix} 尝试将状态从 {current_state.value} 变为 {new_state.value},但未成功或未执行更改。"
)

View File

@@ -1,26 +1,14 @@
import asyncio
import time
import random
from typing import Dict, Any, Optional, List, Tuple
import json # 导入 json 模块
import functools # <-- 新增导入
# 导入日志模块
from typing import Dict, Any, Optional, List
import functools
from src.common.logger_manager import get_logger
# 导入聊天流管理模块
from src.chat.message_receive.chat_stream import chat_manager
# 导入心流相关类
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
from src.chat.heart_flow.mai_state_manager import MaiStateInfo
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
# 导入LLM请求工具
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
from src.individuality.individuality import Individuality
import traceback
# 初始化日志记录器
@@ -74,14 +62,6 @@ class SubHeartflowManager:
self._lock = asyncio.Lock() # 用于保护 self.subheartflows 的访问
self.mai_state_info: MaiStateInfo = mai_state_info # 存储传入的 MaiStateInfo 实例
# 为 LLM 状态评估创建一个 LLMRequest 实例
# 使用与 Heartflow 相同的模型和参数
self.llm_state_evaluator = LLMRequest(
model=global_config.llm_heartflow, # 与 Heartflow 一致
temperature=0.6, # 与 Heartflow 一致
max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多)
request_type="subheartflow_state_eval", # 保留特定的请求类型
)
async def force_change_state(self, subflow_id: Any, target_state: ChatState) -> bool:
"""强制改变指定子心流的状态"""
@@ -155,10 +135,6 @@ class SubHeartflowManager:
logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True)
return None
# --- 新增:内部方法,用于尝试将单个子心流设置为 ABSENT ---
# --- 结束新增 ---
async def sleep_subheartflow(self, subheartflow_id: Any, reason: str) -> bool:
"""停止指定的子心流并将其状态设置为 ABSENT"""
log_prefix = "[子心流管理]"
@@ -189,54 +165,6 @@ class SubHeartflowManager:
return flows_to_stop
async def enforce_subheartflow_limits(self):
"""根据主状态限制停止超额子心流(优先停不活跃的)"""
# 使用 self.mai_state_info 获取当前状态和限制
current_mai_state = self.mai_state_info.get_current_state()
normal_limit = current_mai_state.get_normal_chat_max_num()
focused_limit = current_mai_state.get_focused_chat_max_num()
logger.debug(f"[限制] 状态:{current_mai_state.value}, 普通限:{normal_limit}, 专注限:{focused_limit}")
# 分类统计当前子心流
normal_flows = []
focused_flows = []
for flow_id, flow in list(self.subheartflows.items()):
if flow.chat_state.chat_status == ChatState.CHAT:
normal_flows.append((flow_id, getattr(flow, "last_active_time", 0)))
elif flow.chat_state.chat_status == ChatState.FOCUSED:
focused_flows.append((flow_id, getattr(flow, "last_active_time", 0)))
logger.debug(f"[限制] 当前数量 - 普通:{len(normal_flows)}, 专注:{len(focused_flows)}")
stopped = 0
# 处理普通聊天超额
if len(normal_flows) > normal_limit:
excess = len(normal_flows) - normal_limit
logger.info(f"[限制] 普通聊天超额({len(normal_flows)}>{normal_limit}), 停止{excess}")
normal_flows.sort(key=lambda x: x[1])
for flow_id, _ in normal_flows[:excess]:
if await self.sleep_subheartflow(flow_id, f"普通聊天超额(限{normal_limit})"):
stopped += 1
# 处理专注聊天超额(需重新统计)
focused_flows = [
(fid, t)
for fid, f in list(self.subheartflows.items())
if (t := getattr(f, "last_active_time", 0)) and f.chat_state.chat_status == ChatState.FOCUSED
]
if len(focused_flows) > focused_limit:
excess = len(focused_flows) - focused_limit
logger.info(f"[限制] 专注聊天超额({len(focused_flows)}>{focused_limit}), 停止{excess}")
focused_flows.sort(key=lambda x: x[1])
for flow_id, _ in focused_flows[:excess]:
if await self.sleep_subheartflow(flow_id, f"专注聊天超额(限{focused_limit})"):
stopped += 1
if stopped:
logger.info(f"[限制] 已停止{stopped}个子心流, 剩余:{len(self.subheartflows)}")
else:
logger.debug(f"[限制] 无需停止, 当前总数:{len(self.subheartflows)}")
async def deactivate_all_subflows(self):
"""将所有子心流的状态更改为 ABSENT (例如主状态变为OFFLINE时调用)"""
log_prefix = "[停用]"
@@ -272,27 +200,14 @@ class SubHeartflowManager:
)
async def sbhf_absent_into_focus(self):
"""评估子心流兴趣度,满足条件且未达上限则提升到FOCUSED状态基于start_hfc_probability"""
"""评估子心流兴趣度满足条件则提升到FOCUSED状态基于start_hfc_probability"""
try:
current_state = self.mai_state_info.get_current_state()
focused_limit = current_state.get_focused_chat_max_num()
# --- 新增:检查是否允许进入 FOCUS 模式 --- #
if not global_config.allow_focus_mode:
# 检查是否允许进入 FOCUS 模式
if not global_config.chat.allow_focus_mode:
if int(time.time()) % 60 == 0: # 每60秒输出一次日志避免刷屏
logger.trace("未开启 FOCUSED 状态 (allow_focus_mode=False)")
return # 如果不允许,直接返回
# --- 结束新增 ---
logger.info(f"当前状态 ({current_state.value}) 可以在{focused_limit}个群 专注聊天")
if focused_limit <= 0:
# logger.debug(f"{log_prefix} 当前状态 ({current_state.value}) 不允许 FOCUSED 子心流")
return
current_focused_count = self.count_subflows_by_state(ChatState.FOCUSED)
if current_focused_count >= focused_limit:
logger.debug(f"已达专注上限 ({current_focused_count}/{focused_limit})")
return
for sub_hf in list(self.subheartflows.values()):
@@ -320,11 +235,6 @@ class SubHeartflowManager:
if random.random() >= sub_hf.interest_chatting.start_hfc_probability:
continue
# 再次检查是否达到上限
if current_focused_count >= focused_limit:
logger.debug(f"{stream_name} 已达专注上限")
break
# 获取最新状态并执行提升
current_subflow = self.subheartflows.get(flow_id)
if not current_subflow:
@@ -337,283 +247,57 @@ class SubHeartflowManager:
# 执行状态提升
await current_subflow.change_chat_state(ChatState.FOCUSED)
# 验证提升结果
if (
final_subflow := self.subheartflows.get(flow_id)
) and final_subflow.chat_state.chat_status == ChatState.FOCUSED:
current_focused_count += 1
except Exception as e:
logger.error(f"启动HFC 兴趣评估失败: {e}", exc_info=True)
async def sbhf_absent_into_chat(self):
async def sbhf_focus_into_absent_or_chat(self, subflow_id: Any):
"""
随机选一个 ABSENT 状态的 *群聊* 子心流,评估是否应转换为 CHAT 状态
每次调用最多转换一个
私聊会被忽略
"""
current_mai_state = self.mai_state_info.get_current_state()
chat_limit = current_mai_state.get_normal_chat_max_num()
async with self._lock:
# 1. 筛选出所有 ABSENT 状态的 *群聊* 子心流
absent_group_subflows = [
hf
for hf in self.subheartflows.values()
if hf.chat_state.chat_status == ChatState.ABSENT and hf.is_group_chat
]
if not absent_group_subflows:
# logger.debug("没有摸鱼的群聊子心流可以评估。") # 日志太频繁
return # 没有目标,直接返回
# 2. 随机选一个幸运儿
sub_hf_to_evaluate = random.choice(absent_group_subflows)
flow_id = sub_hf_to_evaluate.subheartflow_id
stream_name = chat_manager.get_stream_name(flow_id) or flow_id
log_prefix = f"[{stream_name}]"
# 3. 检查 CHAT 上限
current_chat_count = self.count_subflows_by_state_nolock(ChatState.CHAT)
if current_chat_count >= chat_limit:
logger.info(f"{log_prefix} 想看看能不能聊,但是聊天太多了, ({current_chat_count}/{chat_limit}) 满了。")
return # 满了,这次就算了
# --- 获取 FOCUSED 计数 ---
current_focused_count = self.count_subflows_by_state_nolock(ChatState.FOCUSED)
focused_limit = current_mai_state.get_focused_chat_max_num()
# --- 新增:获取聊天和专注群名 ---
chatting_group_names = []
focused_group_names = []
for flow_id, hf in self.subheartflows.items():
stream_name = chat_manager.get_stream_name(flow_id) or str(flow_id) # 保证有名字
if hf.chat_state.chat_status == ChatState.CHAT:
chatting_group_names.append(stream_name)
elif hf.chat_state.chat_status == ChatState.FOCUSED:
focused_group_names.append(stream_name)
# --- 结束新增 ---
# --- 获取观察信息和构建 Prompt ---
first_observation = sub_hf_to_evaluate.observations[0] # 喵~第一个观察者肯定存在的说
await first_observation.observe()
current_chat_log = first_observation.talking_message_str or "当前没啥聊天内容。"
_observation_summary = f"在[{stream_name}]这个群中,你最近看群友聊了这些:\n{current_chat_log}"
_mai_state_description = f"你当前状态: {current_mai_state.value}"
individuality = Individuality.get_instance()
personality_prompt = individuality.get_prompt(x_person=2, level=2)
prompt_personality = f"你正在扮演名为{individuality.name}的人类,{personality_prompt}"
# --- 修改:在 prompt 中加入当前聊天计数和群名信息 (条件显示) ---
chat_status_lines = []
if chatting_group_names:
chat_status_lines.append(
f"正在这些群闲聊 ({current_chat_count}/{chat_limit}): {', '.join(chatting_group_names)}"
)
if focused_group_names:
chat_status_lines.append(
f"正在这些群专注的聊天 ({current_focused_count}/{focused_limit}): {', '.join(focused_group_names)}"
)
chat_status_prompt = "当前没有在任何群聊中。" # 默认消息喵~
if chat_status_lines:
chat_status_prompt = "当前聊天情况,你已经参与了下面这几个群的聊天:\n" + "\n".join(
chat_status_lines
) # 拼接状态信息
prompt = (
f"{prompt_personality}\n"
f"{chat_status_prompt}\n" # <-- 喵!用了新的状态信息~
f"你当前尚未加入 [{stream_name}] 群聊天。\n"
f"{_observation_summary}\n---\n"
f"基于以上信息,你想不想开始在这个群闲聊?\n"
f"请说明理由,并以 JSON 格式回答,包含 'decision' (布尔值) 和 'reason' (字符串)。\n"
f'例如:{{"decision": true, "reason": "看起来挺热闹的,插个话"}}\n'
f'例如:{{"decision": false, "reason": "已经聊了好多,休息一下"}}\n'
f"请只输出有效的 JSON 对象。"
)
# --- 结束修改 ---
# --- 4. LLM 评估是否想聊 ---
yao_kai_shi_liao_ma, reason = await self._llm_evaluate_state_transition(prompt)
if reason:
if yao_kai_shi_liao_ma:
logger.info(f"{log_prefix} 打算开始聊,原因是: {reason}")
else:
logger.info(f"{log_prefix} 不打算聊,原因是: {reason}")
else:
logger.info(f"{log_prefix} 结果: {yao_kai_shi_liao_ma}")
if yao_kai_shi_liao_ma is None:
logger.debug(f"{log_prefix} 问AI想不想聊失败了这次算了。")
return # 评估失败,结束
if not yao_kai_shi_liao_ma:
# logger.info(f"{log_prefix} 现在不想聊这个群。")
return # 不想聊,结束
# --- 5. AI想聊再次检查额度并尝试转换 ---
# 再次检查以防万一
current_chat_count_before_change = self.count_subflows_by_state_nolock(ChatState.CHAT)
if current_chat_count_before_change < chat_limit:
logger.info(
f"{log_prefix} 想聊,而且还有精力 ({current_chat_count_before_change}/{chat_limit}),这就去聊!"
)
await sub_hf_to_evaluate.change_chat_state(ChatState.CHAT)
# 确认转换成功
if sub_hf_to_evaluate.chat_state.chat_status == ChatState.CHAT:
logger.debug(f"{log_prefix} 成功进入聊天状态!本次评估圆满结束。")
else:
logger.warning(
f"{log_prefix} 奇怪,尝试进入聊天状态失败了。当前状态: {sub_hf_to_evaluate.chat_state.chat_status.value}"
)
else:
logger.warning(
f"{log_prefix} AI说想聊但是刚问完就没空位了 ({current_chat_count_before_change}/{chat_limit})。真不巧,下次再说吧。"
)
# 无论转换成功与否,本次评估都结束了
# 锁在这里自动释放
# --- 新增:单独检查 CHAT 状态超时的任务 ---
async def sbhf_chat_into_absent(self):
"""定期检查处于 CHAT 状态的子心流是否因长时间未发言而超时,并将其转为 ABSENT。"""
log_prefix_task = "[聊天超时检查]"
transitioned_to_absent = 0
checked_count = 0
async with self._lock:
subflows_snapshot = list(self.subheartflows.values())
checked_count = len(subflows_snapshot)
if not subflows_snapshot:
return
for sub_hf in subflows_snapshot:
# 只检查 CHAT 状态的子心流
if sub_hf.chat_state.chat_status != ChatState.CHAT:
continue
flow_id = sub_hf.subheartflow_id
stream_name = chat_manager.get_stream_name(flow_id) or flow_id
log_prefix = f"[{stream_name}]({log_prefix_task})"
should_deactivate = False
reason = ""
try:
last_bot_dong_zuo_time = sub_hf.get_normal_chat_last_speak_time()
if last_bot_dong_zuo_time > 0:
current_time = time.time()
time_since_last_bb = current_time - last_bot_dong_zuo_time
minutes_since_last_bb = time_since_last_bb / 60
# 60分钟强制退出
if minutes_since_last_bb >= 60:
should_deactivate = True
reason = "超过60分钟未发言强制退出"
else:
# 根据时间区间确定退出概率
exit_probability = 0
if minutes_since_last_bb < 5:
exit_probability = 0.01 # 1%
elif minutes_since_last_bb < 15:
exit_probability = 0.02 # 2%
elif minutes_since_last_bb < 30:
exit_probability = 0.04 # 4%
else:
exit_probability = 0.08 # 8%
# 随机判断是否退出
if random.random() < exit_probability:
should_deactivate = True
reason = f"{minutes_since_last_bb:.1f}分钟未发言,触发{exit_probability * 100:.0f}%退出概率"
except AttributeError:
logger.error(
f"{log_prefix} 无法获取 Bot 最后 BB 时间,请确保 SubHeartflow 相关实现正确。跳过超时检查。"
)
except Exception as e:
logger.error(f"{log_prefix} 检查 Bot 超时状态时出错: {e}", exc_info=True)
# 执行状态转换(如果超时)
if should_deactivate:
logger.debug(f"{log_prefix} 因超时 ({reason}),尝试转换为 ABSENT 状态。")
await sub_hf.change_chat_state(ChatState.ABSENT)
# 再次检查确保状态已改变
if sub_hf.chat_state.chat_status == ChatState.ABSENT:
transitioned_to_absent += 1
logger.info(f"{log_prefix} 不看了。")
else:
logger.warning(f"{log_prefix} 尝试因超时转换为 ABSENT 失败。")
if transitioned_to_absent > 0:
logger.debug(
f"{log_prefix_task} 完成,共检查 {checked_count} 个子心流,{transitioned_to_absent} 个因超时转为 ABSENT。"
)
# --- 结束新增 ---
async def _llm_evaluate_state_transition(self, prompt: str) -> Tuple[Optional[bool], Optional[str]]:
"""
使用 LLM 评估是否应进行状态转换,期望 LLM 返回 JSON 格式。
接收来自 HeartFChatting 的请求,将特定子心流的状态转换为 CHAT。
通常在连续多次 "no_reply" 后被调用
对于私聊和群聊,都转换为 CHAT
Args:
prompt: 提供给 LLM 的提示信息,要求返回 {"decision": true/false}
Returns:
Optional[bool]: 如果成功解析 LLM 的 JSON 响应并提取了 'decision' 键的值,则返回该布尔值。
如果 LLM 调用失败、返回无效 JSON 或 JSON 中缺少 'decision' 键或其值不是布尔型,则返回 None。
subflow_id: 需要转换状态的子心流 ID
"""
log_prefix = "[LLM状态评估]"
async with self._lock:
subflow = self.subheartflows.get(subflow_id)
if not subflow:
logger.warning(f"[状态转换请求] 尝试转换不存在的子心流 {subflow_id} 到 CHAT")
return
stream_name = chat_manager.get_stream_name(subflow_id) or subflow_id
current_state = subflow.chat_state.chat_status
if current_state == ChatState.FOCUSED:
target_state = ChatState.CHAT
log_reason = "转为CHAT"
logger.info(
f"[状态转换请求] 接收到请求,将 {stream_name} (当前: {current_state.value}) 尝试转换为 {target_state.value} ({log_reason})"
)
try:
# --- 真实的 LLM 调用 ---
response_text, _ = await self.llm_state_evaluator.generate_response_async(prompt)
# logger.debug(f"{log_prefix} 使用模型 {self.llm_state_evaluator.model_name} 评估")
logger.debug(f"{log_prefix} 原始输入: {prompt}")
logger.debug(f"{log_prefix} 原始评估结果: {response_text}")
# --- 解析 JSON 响应 ---
try:
# 尝试去除可能的Markdown代码块标记
cleaned_response = response_text.strip().strip("`").strip()
if cleaned_response.startswith("json"):
cleaned_response = cleaned_response[4:].strip()
data = json.loads(cleaned_response)
decision = data.get("decision") # 使用 .get() 避免 KeyError
reason = data.get("reason")
if isinstance(decision, bool):
logger.debug(f"{log_prefix} LLM评估结果 (来自JSON): {'建议转换' if decision else '建议不转换'}")
return decision, reason
# 从HFC到CHAT时清空兴趣字典
subflow.clear_interest_dict()
await subflow.change_chat_state(target_state)
final_state = subflow.chat_state.chat_status
if final_state == target_state:
logger.debug(f"[状态转换请求] {stream_name} 状态已成功转换为 {final_state.value}")
else:
logger.warning(
f"{log_prefix} LLM 返回的 JSON 中 'decision' 键的值不是布尔型: {decision}。响应: {response_text}"
f"[状态转换请求] 尝试将 {stream_name} 转换为 {target_state.value} 后,状态实际为 {final_state.value}"
)
return None, None # 值类型不正确
except json.JSONDecodeError as json_err:
logger.warning(f"{log_prefix} LLM 返回的响应不是有效的 JSON: {json_err}。响应: {response_text}")
# 尝试在非JSON响应中查找关键词作为后备方案 (可选)
if "true" in response_text.lower():
logger.debug(f"{log_prefix} 在非JSON响应中找到 'true',解释为建议转换")
return True, None
if "false" in response_text.lower():
logger.debug(f"{log_prefix} 在非JSON响应中找到 'false',解释为建议不转换")
return False, None
return None, None # JSON 解析失败,也未找到关键词
except Exception as parse_err: # 捕获其他可能的解析错误
logger.warning(f"{log_prefix} 解析 LLM JSON 响应时发生意外错误: {parse_err}。响应: {response_text}")
return None, None
except Exception as e:
logger.error(f"{log_prefix} 调用 LLM 或处理其响应时出错: {e}", exc_info=True)
traceback.print_exc()
return None, None # LLM 调用或处理失败
logger.error(
f"[状态转换请求] 转换 {stream_name}{target_state.value} 时出错: {e}", exc_info=True
)
elif current_state == ChatState.ABSENT:
logger.debug(f"[状态转换请求] {stream_name} 处于 ABSENT 状态,尝试转为 CHAT")
await subflow.change_chat_state(ChatState.CHAT)
else:
logger.debug(
f"[状态转换请求] {stream_name} 当前状态为 {current_state.value},无需转换"
)
def count_subflows_by_state(self, state: ChatState) -> int:
"""统计指定状态的子心流数量"""
@@ -636,23 +320,6 @@ class SubHeartflowManager:
count += 1
return count
def get_active_subflow_minds(self) -> List[str]:
"""获取所有活跃(非ABSENT)子心流的当前想法"""
minds = []
for subheartflow in self.subheartflows.values():
# 检查子心流是否活跃(非ABSENT状态)
if subheartflow.chat_state.chat_status != ChatState.ABSENT:
minds.append(subheartflow.sub_mind.current_mind)
return minds
def update_main_mind_in_subflows(self, main_mind: str):
"""更新所有子心流的主心流想法"""
updated_count = sum(
1
for _, subheartflow in list(self.subheartflows.items())
if subheartflow.subheartflow_id in self.subheartflows
)
logger.debug(f"[子心流管理器] 更新了{updated_count}个子心流的主想法")
async def delete_subflow(self, subheartflow_id: Any):
"""删除指定的子心流。"""
@@ -669,91 +336,13 @@ class SubHeartflowManager:
else:
logger.warning(f"尝试删除不存在的 SubHeartflow: {subheartflow_id}")
# --- 新增:处理 HFC 无回复回调的专用方法 --- #
async def _handle_hfc_no_reply(self, subheartflow_id: Any):
"""处理来自 HeartFChatting 的连续无回复信号 (通过 partial 绑定 ID)"""
# 注意:这里不需要再获取锁,因为 sbhf_focus_into_absent 内部会处理锁
# 注意:这里不需要再获取锁,因为 sbhf_focus_into_absent_or_chat 内部会处理锁
logger.debug(f"[管理器 HFC 处理器] 接收到来自 {subheartflow_id} 的 HFC 无回复信号")
await self.sbhf_focus_into_absent_or_chat(subheartflow_id)
# --- 结束新增 --- #
# --- 新增:处理来自 HeartFChatting 的状态转换请求 --- #
async def sbhf_focus_into_absent_or_chat(self, subflow_id: Any):
"""
接收来自 HeartFChatting 的请求,将特定子心流的状态转换为 ABSENT 或 CHAT。
通常在连续多次 "no_reply" 后被调用。
对于私聊,总是转换为 ABSENT。
对于群聊,随机决定转换为 ABSENT 或 CHAT (如果 CHAT 未达上限)。
Args:
subflow_id: 需要转换状态的子心流 ID。
"""
async with self._lock:
subflow = self.subheartflows.get(subflow_id)
if not subflow:
logger.warning(f"[状态转换请求] 尝试转换不存在的子心流 {subflow_id} 到 ABSENT/CHAT")
return
stream_name = chat_manager.get_stream_name(subflow_id) or subflow_id
current_state = subflow.chat_state.chat_status
if current_state == ChatState.FOCUSED:
target_state = ChatState.ABSENT # Default target
log_reason = "默认转换 (私聊或群聊)"
# --- Modify logic based on chat type --- #
if subflow.is_group_chat:
# Group chat: Decide between ABSENT or CHAT
if random.random() < 0.5: # 50% chance to try CHAT
current_mai_state = self.mai_state_info.get_current_state()
chat_limit = current_mai_state.get_normal_chat_max_num()
current_chat_count = self.count_subflows_by_state_nolock(ChatState.CHAT)
if current_chat_count < chat_limit:
target_state = ChatState.CHAT
log_reason = f"群聊随机选择 CHAT (当前 {current_chat_count}/{chat_limit})"
else:
target_state = ChatState.ABSENT # Fallback to ABSENT if CHAT limit reached
log_reason = (
f"群聊随机选择 CHAT 但已达上限 ({current_chat_count}/{chat_limit}),转为 ABSENT"
)
else: # 50% chance to go directly to ABSENT
target_state = ChatState.ABSENT
log_reason = "群聊随机选择 ABSENT"
else:
# Private chat: Always go to ABSENT
target_state = ChatState.ABSENT
log_reason = "私聊退出 FOCUSED转为 ABSENT"
# --- End modification --- #
logger.info(
f"[状态转换请求] 接收到请求,将 {stream_name} (当前: {current_state.value}) 尝试转换为 {target_state.value} ({log_reason})"
)
try:
# 从HFC到CHAT时清空兴趣字典
subflow.clear_interest_dict()
await subflow.change_chat_state(target_state)
final_state = subflow.chat_state.chat_status
if final_state == target_state:
logger.debug(f"[状态转换请求] {stream_name} 状态已成功转换为 {final_state.value}")
else:
logger.warning(
f"[状态转换请求] 尝试将 {stream_name} 转换为 {target_state.value} 后,状态实际为 {final_state.value}"
)
except Exception as e:
logger.error(
f"[状态转换请求] 转换 {stream_name}{target_state.value} 时出错: {e}", exc_info=True
)
elif current_state == ChatState.ABSENT:
logger.debug(f"[状态转换请求] {stream_name} 已处于 ABSENT 状态,无需转换")
else:
logger.warning(
f"[状态转换请求] 收到对 {stream_name} 的请求,但其状态为 {current_state.value} (非 FOCUSED),不执行转换"
)
# --- 结束新增 --- #
# --- 新增:处理私聊从 ABSENT 直接到 FOCUSED 的逻辑 --- #
async def sbhf_absent_private_into_focus(self):
"""检查 ABSENT 状态的私聊子心流是否有新活动,若有且未达 FOCUSED 上限,则直接转换为 FOCUSED。"""
@@ -761,19 +350,8 @@ class SubHeartflowManager:
transitioned_count = 0
checked_count = 0
# --- 获取当前状态和 FOCUSED 上限 --- #
current_mai_state = self.mai_state_info.get_current_state()
focused_limit = current_mai_state.get_focused_chat_max_num()
# --- 检查是否允许 FOCUS 模式 --- #
if not global_config.allow_focus_mode:
# Log less frequently to avoid spam
# if int(time.time()) % 60 == 0:
# logger.debug(f"{log_prefix_task} 配置不允许进入 FOCUSED 状态")
return
if focused_limit <= 0:
# logger.debug(f"{log_prefix_task} 当前状态 ({current_mai_state.value}) 不允许 FOCUSED 子心流")
if not global_config.chat.allow_focus_mode:
return
async with self._lock:
@@ -794,12 +372,6 @@ class SubHeartflowManager:
# --- 遍历评估每个符合条件的私聊 --- #
for sub_hf in eligible_subflows:
# --- 再次检查 FOCUSED 上限,因为可能有多个同时激活 --- #
if current_focused_count >= focused_limit:
logger.debug(
f"{log_prefix_task} 已达专注上限 ({current_focused_count}/{focused_limit}),停止检查后续私聊。"
)
break # 已满,无需再检查其他私聊
flow_id = sub_hf.subheartflow_id
stream_name = chat_manager.get_stream_name(flow_id) or flow_id
@@ -823,9 +395,6 @@ class SubHeartflowManager:
# --- 如果活跃且未达上限,则尝试转换 --- #
if is_active:
logger.info(
f"{log_prefix} 检测到活跃且未达专注上限 ({current_focused_count}/{focused_limit}),尝试转换为 FOCUSED。"
)
await sub_hf.change_chat_state(ChatState.FOCUSED)
# 确认转换成功
if sub_hf.chat_state.chat_status == ChatState.FOCUSED:

View File

@@ -4,13 +4,14 @@ import math
import random
import time
import re
import json
from itertools import combinations
import jieba
import networkx as nx
import numpy as np
from collections import Counter
from ...common.database import db
from ...common.database.database import memory_db as db
from ...chat.models.utils_model import LLMRequest
from src.common.logger_manager import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
@@ -19,9 +20,11 @@ from ..utils.chat_message_builder import (
build_readable_messages,
) # 导入 build_readable_messages
from ..utils.utils import translate_timestamp_to_human_readable
from .memory_config import MemoryConfig
from rich.traceback import install
from ...config.config import global_config
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
install(extra_lines=3)
@@ -192,21 +195,19 @@ class Hippocampus:
def __init__(self):
self.memory_graph = MemoryGraph()
self.llm_topic_judge = None
self.llm_summary = None
self.model_summary = None
self.entorhinal_cortex = None
self.parahippocampal_gyrus = None
self.config = None
def initialize(self, global_config):
# 使用导入的 MemoryConfig dataclass 和其 from_global_config 方法
self.config = MemoryConfig.from_global_config(global_config)
def initialize(self):
# 初始化子组件
self.entorhinal_cortex = EntorhinalCortex(self)
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
self.llm_topic_judge = LLMRequest(self.config.llm_topic_judge, request_type="memory")
self.llm_summary = LLMRequest(self.config.llm_summary, request_type="memory")
# TODO: API-Adapter修改标记
self.llm_topic_judge = LLMRequest(global_config.model.topic_judge, request_type="memory")
self.model_summary = LLMRequest(global_config.model.summary, request_type="memory")
def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表"""
@@ -792,7 +793,6 @@ class EntorhinalCortex:
def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
self.config = hippocampus.config
def get_memory_sample(self):
"""从数据库获取记忆样本"""
@@ -801,13 +801,13 @@ class EntorhinalCortex:
# 创建双峰分布的记忆调度器
sample_scheduler = MemoryBuildScheduler(
n_hours1=self.config.memory_build_distribution[0],
std_hours1=self.config.memory_build_distribution[1],
weight1=self.config.memory_build_distribution[2],
n_hours2=self.config.memory_build_distribution[3],
std_hours2=self.config.memory_build_distribution[4],
weight2=self.config.memory_build_distribution[5],
total_samples=self.config.build_memory_sample_num,
n_hours1=global_config.memory.memory_build_distribution[0],
std_hours1=global_config.memory.memory_build_distribution[1],
weight1=global_config.memory.memory_build_distribution[2],
n_hours2=global_config.memory.memory_build_distribution[3],
std_hours2=global_config.memory.memory_build_distribution[4],
weight2=global_config.memory.memory_build_distribution[5],
total_samples=global_config.memory.memory_build_sample_num,
)
timestamps = sample_scheduler.get_timestamp_array()
@@ -818,7 +818,7 @@ class EntorhinalCortex:
for timestamp in timestamps:
# 调用修改后的 random_get_msg_snippet
messages = self.random_get_msg_snippet(
timestamp, self.config.build_memory_sample_length, max_memorized_time_per_msg
timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg
)
if messages:
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
@@ -858,11 +858,12 @@ class EntorhinalCortex:
if all_valid:
# 更新数据库中的记忆次数
for message in messages:
# 确保在更新前获取最新的 memorized_times,以防万一
# 确保在更新前获取最新的 memorized_times
current_memorized_times = message.get("memorized_times", 0)
db.messages.update_one(
{"_id": message["_id"]}, {"$set": {"memorized_times": current_memorized_times + 1}}
)
# 使用 Peewee 更新记录
Messages.update(memorized_times=current_memorized_times + 1).where(
Messages.message_id == message["message_id"]
).execute()
return messages # 直接返回原始的消息列表
# 如果获取失败或消息无效,增加尝试次数
@@ -875,12 +876,9 @@ class EntorhinalCortex:
async def sync_memory_to_db(self):
"""将记忆图同步到数据库"""
# 获取数据库中所有节点和内存中所有节点
db_nodes = list(db.graph_data.nodes.find())
db_nodes = {node.concept: node for node in GraphNodes.select()}
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找
db_nodes_dict = {node["concept"]: node for node in db_nodes}
# 检查并更新节点
for concept, data in memory_nodes:
memory_items = data.get("memory_items", [])
@@ -894,44 +892,39 @@ class EntorhinalCortex:
created_time = data.get("created_time", datetime.datetime.now().timestamp())
last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
if concept not in db_nodes_dict:
# 将memory_items转换为JSON字符串
memory_items_json = json.dumps(memory_items, ensure_ascii=False)
if concept not in db_nodes:
# 数据库中缺少的节点,添加
node_data = {
"concept": concept,
"memory_items": memory_items,
"hash": memory_hash,
"created_time": created_time,
"last_modified": last_modified,
}
db.graph_data.nodes.insert_one(node_data)
GraphNodes.create(
concept=concept,
memory_items=memory_items_json,
hash=memory_hash,
created_time=created_time,
last_modified=last_modified,
)
else:
# 获取数据库中节点的特征值
db_node = db_nodes_dict[concept]
db_hash = db_node.get("hash", None)
db_node = db_nodes[concept]
db_hash = db_node.hash
# 如果特征值不同,则更新节点
if db_hash != memory_hash:
db.graph_data.nodes.update_one(
{"concept": concept},
{
"$set": {
"memory_items": memory_items,
"hash": memory_hash,
"created_time": created_time,
"last_modified": last_modified,
}
},
)
db_node.memory_items = memory_items_json
db_node.hash = memory_hash
db_node.last_modified = last_modified
db_node.save()
# 处理边的信息
db_edges = list(db.graph_data.edges.find())
db_edges = list(GraphEdges.select())
memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典
db_edge_dict = {}
for edge in db_edges:
edge_hash = self.hippocampus.calculate_edge_hash(edge["source"], edge["target"])
db_edge_dict[(edge["source"], edge["target"])] = {"hash": edge_hash, "strength": edge.get("strength", 1)}
edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target)
db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength}
# 检查并更新边
for source, target, data in memory_edges:
@@ -945,29 +938,22 @@ class EntorhinalCortex:
if edge_key not in db_edge_dict:
# 添加新边
edge_data = {
"source": source,
"target": target,
"strength": strength,
"hash": edge_hash,
"created_time": created_time,
"last_modified": last_modified,
}
db.graph_data.edges.insert_one(edge_data)
GraphEdges.create(
source=source,
target=target,
strength=strength,
hash=edge_hash,
created_time=created_time,
last_modified=last_modified,
)
else:
# 检查边的特征值是否变化
if db_edge_dict[edge_key]["hash"] != edge_hash:
db.graph_data.edges.update_one(
{"source": source, "target": target},
{
"$set": {
"hash": edge_hash,
"strength": strength,
"created_time": created_time,
"last_modified": last_modified,
}
},
)
edge = GraphEdges.get(GraphEdges.source == source, GraphEdges.target == target)
edge.hash = edge_hash
edge.strength = strength
edge.last_modified = last_modified
edge.save()
def sync_memory_from_db(self):
"""从数据库同步数据到内存中的图结构"""
@@ -978,29 +964,29 @@ class EntorhinalCortex:
self.memory_graph.G.clear()
# 从数据库加载所有节点
nodes = list(db.graph_data.nodes.find())
nodes = list(GraphNodes.select())
for node in nodes:
concept = node["concept"]
memory_items = node.get("memory_items", [])
concept = node.concept
memory_items = json.loads(node.memory_items)
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 检查时间字段是否存在
if "created_time" not in node or "last_modified" not in node:
if not node.created_time or not node.last_modified:
need_update = True
# 更新数据库中的节点
update_data = {}
if "created_time" not in node:
if not node.created_time:
update_data["created_time"] = current_time
if "last_modified" not in node:
if not node.last_modified:
update_data["last_modified"] = current_time
db.graph_data.nodes.update_one({"concept": concept}, {"$set": update_data})
GraphNodes.update(**update_data).where(GraphNodes.concept == concept).execute()
logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间)
created_time = node.get("created_time", current_time)
last_modified = node.get("last_modified", current_time)
created_time = node.created_time or current_time
last_modified = node.last_modified or current_time
# 添加节点到图中
self.memory_graph.G.add_node(
@@ -1008,28 +994,30 @@ class EntorhinalCortex:
)
# 从数据库加载所有边
edges = list(db.graph_data.edges.find())
edges = list(GraphEdges.select())
for edge in edges:
source = edge["source"]
target = edge["target"]
strength = edge.get("strength", 1)
source = edge.source
target = edge.target
strength = edge.strength
# 检查时间字段是否存在
if "created_time" not in edge or "last_modified" not in edge:
if not edge.created_time or not edge.last_modified:
need_update = True
# 更新数据库中的边
update_data = {}
if "created_time" not in edge:
if not edge.created_time:
update_data["created_time"] = current_time
if "last_modified" not in edge:
if not edge.last_modified:
update_data["last_modified"] = current_time
db.graph_data.edges.update_one({"source": source, "target": target}, {"$set": update_data})
GraphEdges.update(**update_data).where(
(GraphEdges.source == source) & (GraphEdges.target == target)
).execute()
logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间)
created_time = edge.get("created_time", current_time)
last_modified = edge.get("last_modified", current_time)
created_time = edge.created_time or current_time
last_modified = edge.last_modified or current_time
# 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G:
@@ -1047,8 +1035,8 @@ class EntorhinalCortex:
# 清空数据库
clear_start = time.time()
db.graph_data.nodes.delete_many({})
db.graph_data.edges.delete_many({})
GraphNodes.delete().execute()
GraphEdges.delete().execute()
clear_end = time.time()
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}")
@@ -1063,29 +1051,27 @@ class EntorhinalCortex:
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
node_data = {
"concept": concept,
"memory_items": memory_items,
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
"created_time": data.get("created_time", datetime.datetime.now().timestamp()),
"last_modified": data.get("last_modified", datetime.datetime.now().timestamp()),
}
db.graph_data.nodes.insert_one(node_data)
GraphNodes.create(
concept=concept,
memory_items=json.dumps(memory_items),
hash=self.hippocampus.calculate_node_hash(concept, memory_items),
created_time=data.get("created_time", datetime.datetime.now().timestamp()),
last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
)
node_end = time.time()
logger.info(f"[数据库] 写入 {len(memory_nodes)} 个节点耗时: {node_end - node_start:.2f}")
# 重新写入边
edge_start = time.time()
for source, target, data in memory_edges:
edge_data = {
"source": source,
"target": target,
"strength": data.get("strength", 1),
"hash": self.hippocampus.calculate_edge_hash(source, target),
"created_time": data.get("created_time", datetime.datetime.now().timestamp()),
"last_modified": data.get("last_modified", datetime.datetime.now().timestamp()),
}
db.graph_data.edges.insert_one(edge_data)
GraphEdges.create(
source=source,
target=target,
strength=data.get("strength", 1),
hash=self.hippocampus.calculate_edge_hash(source, target),
created_time=data.get("created_time", datetime.datetime.now().timestamp()),
last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
)
edge_end = time.time()
logger.info(f"[数据库] 写入 {len(memory_edges)} 条边耗时: {edge_end - edge_start:.2f}")
@@ -1099,7 +1085,6 @@ class ParahippocampalGyrus:
def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
self.config = hippocampus.config
async def memory_compress(self, messages: list, compress_rate=0.1):
"""压缩和总结消息内容,生成记忆主题和摘要。
@@ -1159,7 +1144,7 @@ class ParahippocampalGyrus:
# 3. 过滤掉包含禁用关键词的topic
filtered_topics = [
topic for topic in topics if not any(keyword in topic for keyword in self.config.memory_ban_words)
topic for topic in topics if not any(keyword in topic for keyword in global_config.memory.memory_ban_words)
]
logger.debug(f"过滤后话题: {filtered_topics}")
@@ -1170,7 +1155,7 @@ class ParahippocampalGyrus:
# 调用修改后的 topic_what不再需要 time_info
topic_what_prompt = self.hippocampus.topic_what(input_text, topic)
try:
task = self.hippocampus.llm_summary.generate_response_async(topic_what_prompt)
task = self.hippocampus.model_summary.generate_response_async(topic_what_prompt)
tasks.append((topic.strip(), task))
except Exception as e:
logger.error(f"生成话题 '{topic}' 的摘要时发生错误: {e}")
@@ -1222,7 +1207,7 @@ class ParahippocampalGyrus:
bar = "" * filled_length + "-" * (bar_length - filled_length)
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
compress_rate = self.config.memory_compress_rate
compress_rate = global_config.memory.memory_compress_rate
try:
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
except Exception as e:
@@ -1322,7 +1307,7 @@ class ParahippocampalGyrus:
edge_data = self.memory_graph.G[source][target]
last_modified = edge_data.get("last_modified")
if current_time - last_modified > 3600 * self.config.memory_forget_time:
if current_time - last_modified > 3600 * global_config.memory.memory_forget_time:
current_strength = edge_data.get("strength", 1)
new_strength = current_strength - 1
@@ -1430,8 +1415,8 @@ class ParahippocampalGyrus:
async def operation_consolidate_memory(self):
"""整合记忆:合并节点内相似的记忆项"""
start_time = time.time()
percentage = self.config.consolidate_memory_percentage
similarity_threshold = self.config.consolidation_similarity_threshold
percentage = global_config.memory.consolidate_memory_percentage
similarity_threshold = global_config.memory.consolidation_similarity_threshold
logger.info(f"[整合] 开始检查记忆节点... 检查比例: {percentage:.2%}, 合并阈值: {similarity_threshold}")
# 获取所有至少有2条记忆项的节点
@@ -1544,7 +1529,6 @@ class ParahippocampalGyrus:
class HippocampusManager:
_instance = None
_hippocampus = None
_global_config = None
_initialized = False
@classmethod
@@ -1559,19 +1543,15 @@ class HippocampusManager:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return cls._hippocampus
def initialize(self, global_config):
def initialize(self):
"""初始化海马体实例"""
if self._initialized:
return self._hippocampus
self._global_config = global_config
self._hippocampus = Hippocampus()
self._hippocampus.initialize(global_config)
self._hippocampus.initialize()
self._initialized = True
# 输出记忆系统参数信息
config = self._hippocampus.config
# 输出记忆图统计信息
memory_graph = self._hippocampus.memory_graph.G
node_count = len(memory_graph.nodes())
@@ -1579,9 +1559,9 @@ class HippocampusManager:
logger.success(f"""--------------------------------
记忆系统参数配置:
构建间隔: {global_config.build_memory_interval}秒|样本数: {config.build_memory_sample_num},长度: {config.build_memory_sample_length}|压缩率: {config.memory_compress_rate}
记忆构建分布: {config.memory_build_distribution}
遗忘间隔: {global_config.forget_memory_interval}秒|遗忘比例: {global_config.memory_forget_percentage}|遗忘: {config.memory_forget_time}小时之后
构建间隔: {global_config.memory.memory_build_interval}秒|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate}
记忆构建分布: {global_config.memory.memory_build_distribution}
遗忘间隔: {global_config.memory.forget_memory_interval}秒|遗忘比例: {global_config.memory.memory_forget_percentage}|遗忘: {global_config.memory.memory_forget_time}小时之后
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
--------------------------------""") # noqa: E501

View File

@@ -7,7 +7,6 @@ import os
# 添加项目根目录到系统路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
from src.chat.memory_system.Hippocampus import HippocampusManager
from src.config.config import global_config
from rich.traceback import install
install(extra_lines=3)
@@ -19,7 +18,7 @@ async def test_memory_system():
# 初始化记忆系统
print("开始初始化记忆系统...")
hippocampus_manager = HippocampusManager.get_instance()
hippocampus_manager.initialize(global_config=global_config)
hippocampus_manager.initialize()
print("记忆系统初始化完成")
# 测试记忆构建

View File

@@ -34,7 +34,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.logger import get_module_logger # noqa E402
from src.common.database import db # noqa E402
from common.database.database import db # noqa E402
logger = get_module_logger("mem_alter")
console = Console()

View File

@@ -1,48 +0,0 @@
from dataclasses import dataclass
from typing import List
@dataclass
class MemoryConfig:
"""记忆系统配置类"""
# 记忆构建相关配置
memory_build_distribution: List[float] # 记忆构建的时间分布参数
build_memory_sample_num: int # 每次构建记忆的样本数量
build_memory_sample_length: int # 每个样本的消息长度
memory_compress_rate: float # 记忆压缩率
# 记忆遗忘相关配置
memory_forget_time: int # 记忆遗忘时间(小时)
# 记忆过滤相关配置
memory_ban_words: List[str] # 记忆过滤词列表
# 新增:记忆整合相关配置
consolidation_similarity_threshold: float # 相似度阈值
consolidate_memory_percentage: float # 检查节点比例
consolidate_memory_interval: int # 记忆整合间隔
llm_topic_judge: str # 话题判断模型
llm_summary: str # 话题总结模型
@classmethod
def from_global_config(cls, global_config):
"""从全局配置创建记忆系统配置"""
# 使用 getattr 提供默认值,防止全局配置缺少这些项
return cls(
memory_build_distribution=getattr(
global_config, "memory_build_distribution", (24, 12, 0.5, 168, 72, 0.5)
), # 添加默认值
build_memory_sample_num=getattr(global_config, "build_memory_sample_num", 5),
build_memory_sample_length=getattr(global_config, "build_memory_sample_length", 30),
memory_compress_rate=getattr(global_config, "memory_compress_rate", 0.1),
memory_forget_time=getattr(global_config, "memory_forget_time", 24 * 7),
memory_ban_words=getattr(global_config, "memory_ban_words", []),
# 新增加载整合配置,并提供默认值
consolidation_similarity_threshold=getattr(global_config, "consolidation_similarity_threshold", 0.7),
consolidate_memory_percentage=getattr(global_config, "consolidate_memory_percentage", 0.01),
consolidate_memory_interval=getattr(global_config, "consolidate_memory_interval", 1000),
llm_topic_judge=getattr(global_config, "llm_topic_judge", "default_judge_model"), # 添加默认模型名
llm_summary=getattr(global_config, "llm_summary", "default_summary_model"), # 添加默认模型名
)

View File

@@ -38,10 +38,10 @@ class ChatBot:
async def _create_pfc_chat(self, message: MessageRecv):
try:
if global_config.experimental.pfc_chatting:
chat_id = str(message.chat_stream.stream_id)
private_name = str(message.message_info.user_info.user_nickname)
if global_config.enable_pfc_chatting:
await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
except Exception as e:
@@ -72,27 +72,11 @@ class ChatBot:
message_data["message_info"]["user_info"]["user_id"] = str(
message_data["message_info"]["user_info"]["user_id"]
)
# print(message_data)
logger.trace(f"处理消息:{str(message_data)[:120]}...")
message = MessageRecv(message_data)
groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info
# 用户黑名单拦截
if userinfo.user_id in global_config.ban_user_id:
logger.debug(f"用户{userinfo.user_id}被禁止回复")
return
if groupinfo is None:
logger.trace("检测到私聊消息,检查")
# 好友黑名单拦截
if userinfo.user_id not in global_config.talk_allowed_private:
logger.debug(f"用户{userinfo.user_id}没有私聊权限")
return
# 群聊黑名单拦截
if groupinfo is not None and groupinfo.group_id not in global_config.talk_allowed_groups:
logger.trace(f"{groupinfo.group_id}被禁止回复")
return
group_info = message.message_info.group_info
user_info = message.message_info.user_info
# 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default:
@@ -109,22 +93,16 @@ class ChatBot:
async def preprocess():
logger.trace("开始预处理消息...")
# 如果在私聊中
if groupinfo is None:
if group_info is None:
logger.trace("检测到私聊消息")
# 是否在配置信息中开启私聊模式
if global_config.enable_friend_chat:
logger.trace("私聊模式已启用")
# 是否进入PFC
if global_config.enable_pfc_chatting:
if global_config.experimental.pfc_chatting:
logger.trace("进入PFC私聊处理流程")
userinfo = message.message_info.user_info
messageinfo = message.message_info
# 创建聊天流
logger.trace(f"{userinfo.user_id}创建/获取聊天流")
logger.trace(f"{user_info.user_id}创建/获取聊天流")
chat = await chat_manager.get_or_create_stream(
platform=messageinfo.platform,
user_info=userinfo,
group_info=groupinfo,
platform=message.message_info.platform,
user_info=user_info,
group_info=group_info,
)
message.update_chat_stream(chat)
await self.only_process_chat.process_message(message)
@@ -135,7 +113,7 @@ class ChatBot:
await self.heartflow_processor.process_message(message_data)
# 群聊默认进入心流消息处理逻辑
else:
logger.trace(f"检测到群聊消息群ID: {groupinfo.group_id}")
logger.trace(f"检测到群聊消息群ID: {group_info.group_id}")
await self.heartflow_processor.process_message(message_data)
if template_group_name:

View File

@@ -5,7 +5,8 @@ import copy
from typing import Dict, Optional
from ...common.database import db
from ...common.database.database import db
from ...common.database.database_model import ChatStreams # 新增导入
from maim_message import GroupInfo, UserInfo
from src.common.logger_manager import get_logger
@@ -38,7 +39,7 @@ class ChatStream:
def to_dict(self) -> dict:
"""转换为字典格式"""
result = {
return {
"stream_id": self.stream_id,
"platform": self.platform,
"user_info": self.user_info.to_dict() if self.user_info else None,
@@ -46,7 +47,6 @@ class ChatStream:
"create_time": self.create_time,
"last_active_time": self.last_active_time,
}
return result
@classmethod
def from_dict(cls, data: dict) -> "ChatStream":
@@ -82,7 +82,13 @@ class ChatManager:
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self._ensure_collection()
try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
db.create_tables([ChatStreams], safe=True)
except Exception as e:
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
self._initialized = True
# 在事件循环中启动初始化
# asyncio.create_task(self._initialize())
@@ -107,15 +113,6 @@ class ChatManager:
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
@staticmethod
def _ensure_collection():
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in db.list_collection_names():
db.create_collection("chat_streams")
# 创建索引
db.chat_streams.create_index([("stream_id", 1)], unique=True)
db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])
@staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID"""
@@ -151,16 +148,43 @@ class ChatManager:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.update_active_time()
stream = copy.deepcopy(stream)
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
stream.user_info = user_info
if group_info:
stream.group_info = group_info
return stream
# 检查数据库中是否存在
data = db.chat_streams.find_one({"stream_id": stream_id})
if data:
stream = ChatStream.from_dict(data)
def _db_find_stream_sync(s_id: str):
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
stream = ChatStream.from_dict(data_for_from_dict)
# 更新用户信息和群组信息
stream.user_info = user_info
if group_info:
@@ -175,7 +199,7 @@ class ChatManager:
group_info=group_info,
)
except Exception as e:
logger.error(f"创建聊天流失败: {e}")
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
raise e
# 保存到内存和数据库
@@ -205,15 +229,39 @@ class ChatManager:
elif stream.user_info and stream.user_info.user_nickname:
return f"{stream.user_info.user_nickname}的私聊"
else:
# 如果没有群名或用户昵称,返回 None 或其他默认值
return None
@staticmethod
async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True)
if stream.saved:
return
stream_data_dict = stream.to_dict()
def _db_save_stream_sync(s_data_dict: dict):
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
"last_active_time": s_data_dict["last_active_time"],
"user_platform": user_info_d["platform"] if user_info_d else "",
"user_id": user_info_d["user_id"] if user_info_d else "",
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
"group_platform": group_info_d["platform"] if group_info_d else "",
"group_id": group_info_d["group_id"] if group_info_d else "",
"group_name": group_info_d["group_name"] if group_info_d else "",
}
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
stream.saved = True
except Exception as e:
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
async def _save_all_streams(self):
"""保存所有聊天流"""
@@ -222,10 +270,44 @@ class ChatManager:
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = db.chat_streams.find({})
for data in all_streams:
def _db_load_all_streams_sync():
loaded_streams_data = []
for model_instance in ChatStreams.select():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id:
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
loaded_streams_data.append(data_for_from_dict)
return loaded_streams_data
try:
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
self.streams.clear()
for data in all_streams_data_list:
stream = ChatStream.from_dict(data)
stream.saved = True
self.streams[stream.stream_id] = stream
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
# 创建全局单例

View File

@@ -38,7 +38,7 @@ class MessageBuffer:
async def start_caching_messages(self, message: MessageRecv):
"""添加消息,启动缓冲"""
if not global_config.message_buffer:
if not global_config.chat.message_buffer:
person_id = person_info_manager.get_person_id(
message.message_info.user_info.platform, message.message_info.user_info.user_id
)
@@ -107,7 +107,7 @@ class MessageBuffer:
async def query_buffer_result(self, message: MessageRecv) -> bool:
"""查询缓冲结果,并清理"""
if not global_config.message_buffer:
if not global_config.chat.message_buffer:
return True
person_id_ = self.get_person_id_(
message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info

View File

@@ -279,7 +279,7 @@ class MessageManager:
)
# 检查是否超时
if thinking_time > global_config.thinking_timeout:
if thinking_time > global_config.normal_chat.thinking_timeout:
logger.warning(
f"[{chat_id}] 消息思考超时 ({thinking_time:.1f}秒),移除消息 {message_earliest.message_info.message_id}"
)

View File

@@ -1,9 +1,10 @@
import re
from typing import Union
from ...common.database import db
# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from ...common.database.database_model import Messages, RecalledMessages # Import Peewee models
from src.common.logger import get_module_logger
logger = get_module_logger("message_storage")
@@ -29,34 +30,56 @@ class MessageStorage:
else:
filtered_detailed_plain_text = ""
message_data = {
"message_id": message.message_info.message_id,
"time": message.message_info.time,
"chat_id": chat_stream.stream_id,
"chat_info": chat_stream.to_dict(),
"user_info": message.message_info.user_info.to_dict(),
# 使用过滤后的文本
"processed_plain_text": filtered_processed_plain_text,
"detailed_plain_text": filtered_detailed_plain_text,
"memorized_times": message.memorized_times,
}
db.messages.insert_one(message_data)
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict()
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = chat_info_dict.get("group_info") or {}
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
Messages.create(
message_id=msg_id,
time=float(message.message_info.time),
chat_id=chat_stream.stream_id,
# Flattened chat_info
chat_info_stream_id=chat_info_dict.get("stream_id"),
chat_info_platform=chat_info_dict.get("platform"),
chat_info_user_platform=user_info_from_chat.get("platform"),
chat_info_user_id=user_info_from_chat.get("user_id"),
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
chat_info_group_platform=group_info_from_chat.get("platform"),
chat_info_group_id=group_info_from_chat.get("group_id"),
chat_info_group_name=group_info_from_chat.get("group_name"),
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
# Flattened user_info (message sender)
user_platform=user_info_dict.get("platform"),
user_id=user_info_dict.get("user_id"),
user_nickname=user_info_dict.get("user_nickname"),
user_cardname=user_info_dict.get("user_cardname"),
# Text content
processed_plain_text=filtered_processed_plain_text,
detailed_plain_text=filtered_detailed_plain_text,
memorized_times=message.memorized_times,
)
except Exception:
logger.exception("存储消息失败")
@staticmethod
async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None:
"""存储撤回消息到数据库"""
if "recalled_messages" not in db.list_collection_names():
db.create_collection("recalled_messages")
else:
# Table creation is handled by initialize_database in database_model.py
try:
message_data = {
"message_id": message_id,
"time": time,
"stream_id": chat_stream.stream_id,
}
db.recalled_messages.insert_one(message_data)
RecalledMessages.create(
message_id=message_id,
time=float(time), # Assuming time is a string representing a float timestamp
stream_id=chat_stream.stream_id,
)
except Exception:
logger.exception("存储撤回消息失败")
@@ -64,7 +87,9 @@ class MessageStorage:
async def remove_recalled_message(time: str) -> None:
"""删除撤回消息"""
try:
db.recalled_messages.delete_many({"time": {"$lt": time - 300}})
# Assuming input 'time' is a string timestamp that can be converted to float
current_time_float = float(time)
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute()
except Exception:
logger.exception("删除撤回消息失败")

View File

@@ -12,7 +12,8 @@ import base64
from PIL import Image
import io
import os
from ...common.database import db
from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
from ...config.config import global_config
from rich.traceback import install
@@ -85,8 +86,6 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
# if isinstance(content, str) and len(content) > 100:
# payload["messages"][0]["content"] = content[:100]
return payload
@@ -111,8 +110,8 @@ class LLMRequest:
def __init__(self, model: dict, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值
try:
self.api_key = os.environ[model["key"]]
self.base_url = os.environ[model["base_url"]]
self.api_key = os.environ[f"{model['provider']}_KEY"]
self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
except AttributeError as e:
logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
@@ -134,13 +133,11 @@ class LLMRequest:
def _init_database():
"""初始化数据库集合"""
try:
# 创建llm_usage集合的索引
db.llm_usage.create_index([("timestamp", 1)])
db.llm_usage.create_index([("model_name", 1)])
db.llm_usage.create_index([("user_id", 1)])
db.llm_usage.create_index([("request_type", 1)])
# 使用 Peewee 创建表safe=True 表示如果表已存在则不会抛出错误
db.create_tables([LLMUsage], safe=True)
logger.debug("LLMUsage 表已初始化/确保存在。")
except Exception as e:
logger.error(f"创建数据库索引失败: {str(e)}")
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def _record_usage(
self,
@@ -165,19 +162,19 @@ class LLMRequest:
request_type = self.request_type
try:
usage_data = {
"model_name": self.model_name,
"user_id": user_id,
"request_type": request_type,
"endpoint": endpoint,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"cost": self._calculate_cost(prompt_tokens, completion_tokens),
"status": "success",
"timestamp": datetime.now(),
}
db.llm_usage.insert_one(usage_data)
# 使用 Peewee 模型创建记录
LLMUsage.create(
model_name=self.model_name,
user_id=user_id,
request_type=request_type,
endpoint=endpoint,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost=self._calculate_cost(prompt_tokens, completion_tokens),
status="success",
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
)
logger.trace(
f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "
@@ -500,11 +497,11 @@ class LLMRequest:
logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}")
# 对全局配置进行更新
if global_config.llm_normal.get("name") == old_model_name:
global_config.llm_normal["name"] = self.model_name
if global_config.model.normal.get("name") == old_model_name:
global_config.model.normal["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
if global_config.llm_reasoning.get("name") == old_model_name:
global_config.llm_reasoning["name"] = self.model_name
if global_config.model.reasoning.get("name") == old_model_name:
global_config.model.reasoning["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
if payload and "model" in payload:
@@ -636,7 +633,7 @@ class LLMRequest:
**params_copy,
}
if "max_tokens" not in payload and "max_completion_tokens" not in payload:
payload["max_tokens"] = global_config.model_max_output_length
payload["max_tokens"] = global_config.model.model_max_output_length
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens")

View File

@@ -22,11 +22,11 @@ from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.normal_chat.willing.willing_manager import willing_manager
from src.config.config import global_config
logger = get_logger("chat")
logger = get_logger("normal_chat")
class NormalChat:
def __init__(self, chat_stream: ChatStream, interest_dict: dict = None):
def __init__(self, chat_stream: ChatStream, interest_dict: dict = {}):
"""初始化 NormalChat 实例。只进行同步操作。"""
# Basic info from chat_stream (sync)
@@ -73,8 +73,8 @@ class NormalChat:
messageinfo = message.message_info
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=messageinfo.platform,
)
@@ -121,8 +121,8 @@ class NormalChat:
message_id=thinking_id,
chat_stream=self.chat_stream, # 使用 self.chat_stream
bot_user_info=UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=message.message_info.platform,
),
sender_info=message.message_info.user_info,
@@ -147,7 +147,7 @@ class NormalChat:
# 改为实例方法
async def _handle_emoji(self, message: MessageRecv, response: str):
"""处理表情包"""
if random() < global_config.emoji_chance:
if random() < global_config.normal_chat.emoji_chance:
emoji_raw = await emoji_manager.get_emoji_for_text(response)
if emoji_raw:
emoji_path, description = emoji_raw
@@ -160,8 +160,8 @@ class NormalChat:
message_id="mt" + str(thinking_time_point),
chat_stream=self.chat_stream, # 使用 self.chat_stream
bot_user_info=UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=message.message_info.platform,
),
sender_info=message.message_info.user_info,
@@ -186,7 +186,7 @@ class NormalChat:
label=emotion,
stance=stance, # 使用 self.chat_stream
)
self.mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor)
self.mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor)
async def _reply_interested_message(self) -> None:
"""
@@ -200,7 +200,7 @@ class NormalChat:
logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出")
break
# 获取待处理消息列表
items_to_process = list(self.interest_dict.items())
if not items_to_process:
continue
@@ -430,7 +430,7 @@ class NormalChat:
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息中是否包含过滤词"""
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
for word in global_config.ban_words:
for word in global_config.chat.ban_words:
if word in text:
logger.info(
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"
@@ -445,7 +445,7 @@ class NormalChat:
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息是否匹配过滤正则表达式"""
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
for pattern in global_config.ban_msgs_regex:
for pattern in global_config.chat.ban_msgs_regex:
if pattern.search(text):
logger.info(
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"
@@ -481,7 +481,7 @@ class NormalChat:
try:
if exc := task.exception():
logger.error(f"[{self.stream_name}] 任务异常: {exc}")
logger.error(traceback.format_exc())
traceback.print_exc()
except asyncio.CancelledError:
logger.debug(f"[{self.stream_name}] 任务已取消")
except Exception as e:
@@ -522,4 +522,4 @@ class NormalChat:
logger.info(f"[{self.stream_name}] 清理了 {len(thinking_messages)} 条未处理的思考消息。")
except Exception as e:
logger.error(f"[{self.stream_name}] 清理思考消息时出错: {e}")
logger.error(traceback.format_exc())
traceback.print_exc()

View File

@@ -15,21 +15,22 @@ logger = get_logger("llm")
class NormalChatGenerator:
def __init__(self):
# TODO: API-Adapter修改标记
self.model_reasoning = LLMRequest(
model=global_config.llm_reasoning,
model=global_config.model.reasoning,
temperature=0.7,
max_tokens=3000,
request_type="response_reasoning",
)
self.model_normal = LLMRequest(
model=global_config.llm_normal,
temperature=global_config.llm_normal["temp"],
model=global_config.model.normal,
temperature=global_config.model.normal["temp"],
max_tokens=256,
request_type="response_reasoning",
)
self.model_sum = LLMRequest(
model=global_config.llm_summary, temperature=0.7, max_tokens=3000, request_type="relation"
model=global_config.model.summary, temperature=0.7, max_tokens=3000, request_type="relation"
)
self.current_model_type = "r1" # 默认使用 R1
self.current_model_name = "unknown model"
@@ -37,7 +38,7 @@ class NormalChatGenerator:
async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型
if random.random() < global_config.model_reasoning_probability:
if random.random() < global_config.normal_chat.reasoning_model_probability:
self.current_model_type = "深深地"
current_model = self.model_reasoning
else:
@@ -51,7 +52,7 @@ class NormalChatGenerator:
model_response = await self._generate_response_with_model(message, current_model, thinking_id)
if model_response:
logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}")
logger.info(f"{global_config.bot.nickname}的回复是:{model_response}")
model_response = await self._process_response(model_response)
return model_response
@@ -113,7 +114,7 @@ class NormalChatGenerator:
- "中立":不表达明确立场或无关回应
2. 从"开心,愤怒,悲伤,惊讶,平静,害羞,恐惧,厌恶,困惑"中选出最匹配的1个情感标签
3. 按照"立场-情绪"的格式直接输出结果,例如:"反对-愤怒"
4. 考虑回复者的人格设定为{global_config.personality_core}
4. 考虑回复者的人格设定为{global_config.personality.personality_core}
对话示例:
被回复「A就是笨」

View File

@@ -1,18 +1,20 @@
import asyncio
from src.config.config import global_config
from .willing_manager import BaseWillingManager
class ClassicalWillingManager(BaseWillingManager):
def __init__(self):
super().__init__()
self._decay_task: asyncio.Task = None
self._decay_task: asyncio.Task | None = None
async def _decay_reply_willing(self):
"""定期衰减回复意愿"""
while True:
await asyncio.sleep(1)
for chat_id in self.chat_reply_willing:
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.9)
self.chat_reply_willing[chat_id] = max(0.0, self.chat_reply_willing[chat_id] * 0.9)
async def async_task_starter(self):
if self._decay_task is None:
@@ -23,35 +25,33 @@ class ClassicalWillingManager(BaseWillingManager):
chat_id = willing_info.chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
interested_rate = willing_info.interested_rate * self.global_config.response_interested_rate_amplifier
interested_rate = willing_info.interested_rate * global_config.normal_chat.response_interested_rate_amplifier
if interested_rate > 0.4:
current_willing += interested_rate - 0.3
if willing_info.is_mentioned_bot and current_willing < 1.0:
current_willing += 1
elif willing_info.is_mentioned_bot:
current_willing += 0.05
if willing_info.is_mentioned_bot:
current_willing += 1 if current_willing < 1.0 else 0.05
is_emoji_not_reply = False
if willing_info.is_emoji:
if self.global_config.emoji_response_penalty != 0:
current_willing *= self.global_config.emoji_response_penalty
if global_config.normal_chat.emoji_response_penalty != 0:
current_willing *= global_config.normal_chat.emoji_response_penalty
else:
is_emoji_not_reply = True
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
reply_probability = min(
max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1
max((current_willing - 0.5), 0.01) * global_config.normal_chat.response_willing_amplifier * 2, 1
)
# 检查群组权限(如果是群聊)
if (
willing_info.group_info
and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups
and willing_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups
):
reply_probability = reply_probability / self.global_config.down_frequency_rate
reply_probability = reply_probability / global_config.normal_chat.down_frequency_rate
if is_emoji_not_reply:
reply_probability = 0
@@ -61,7 +61,7 @@ class ClassicalWillingManager(BaseWillingManager):
async def before_generate_reply_handle(self, message_id):
chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
self.chat_reply_willing[chat_id] = max(0, current_willing - 1.8)
self.chat_reply_willing[chat_id] = max(0.0, current_willing - 1.8)
async def after_generate_reply_handle(self, message_id):
if message_id not in self.ongoing_messages:
@@ -70,7 +70,7 @@ class ClassicalWillingManager(BaseWillingManager):
chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
if current_willing < 1:
self.chat_reply_willing[chat_id] = min(1, current_willing + 0.4)
self.chat_reply_willing[chat_id] = min(1.0, current_willing + 0.4)
async def bombing_buffer_message_handle(self, message_id):
return await super().bombing_buffer_message_handle(message_id)

View File

@@ -19,6 +19,7 @@ Mxp 模式:梦溪畔独家赞助
下下策是询问一个菜鸟(@梦溪畔)
"""
from src.config.config import global_config
from .willing_manager import BaseWillingManager
from typing import Dict
import asyncio
@@ -50,8 +51,6 @@ class MxpWillingManager(BaseWillingManager):
self.mention_willing_gain = 0.6 # 提及意愿增益
self.interest_willing_gain = 0.3 # 兴趣意愿增益
self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚
self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数
self.single_chat_gain = 0.12 # 单聊增益
self.fatigue_messages_triggered_num = self.expected_replies_per_min # 疲劳消息触发数量(int)
@@ -179,10 +178,10 @@ class MxpWillingManager(BaseWillingManager):
probability = self._willing_to_probability(current_willing)
if w_info.is_emoji:
probability *= self.emoji_response_penalty
probability *= global_config.normal_chat.emoji_response_penalty
if w_info.group_info and w_info.group_info.group_id in self.global_config.talk_frequency_down_groups:
probability /= self.down_frequency_rate
if w_info.group_info and w_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups:
probability /= global_config.normal_chat.down_frequency_rate
self.temporary_willing = current_willing

View File

@@ -1,6 +1,6 @@
from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger
from dataclasses import dataclass
from src.config.config import global_config, BotConfig
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv
from src.chat.person_info.person_info import person_info_manager, PersonInfoManager
@@ -93,7 +93,6 @@ class BaseWillingManager(ABC):
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id)
self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id)
self.lock = asyncio.Lock()
self.global_config: BotConfig = global_config
self.logger: LoguruLogger = logger
def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
@@ -173,7 +172,7 @@ def init_willing_manager() -> BaseWillingManager:
Returns:
对应mode的WillingManager实例
"""
mode = global_config.willing_mode.lower()
mode = global_config.normal_chat.willing_mode.lower()
return BaseWillingManager.create(mode)

View File

@@ -1,5 +1,6 @@
from src.common.logger_manager import get_logger
from ...common.database import db
from ...common.database.database import db
from ...common.database.database_model import PersonInfo # 新增导入
import copy
import hashlib
from typing import Any, Callable, Dict
@@ -16,7 +17,7 @@ matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import json
import json # 新增导入
import re
@@ -38,47 +39,49 @@ logger = get_logger("person_info")
person_info_default = {
"person_id": None,
"person_name": None,
"person_name": None, # 模型中已设为 null=True此默认值OK
"name_reason": None,
"platform": None,
"user_id": None,
"nickname": None,
# "age" : 0,
"platform": "unknown", # 提供非None的默认值
"user_id": "unknown", # 提供非None的默认值
"nickname": "Unknown", # 提供非None的默认值
"relationship_value": 0,
# "saved" : True,
# "impression" : None,
# "gender" : Unkown,
"konw_time": 0,
"know_time": 0, # 修正拼写konw_time -> know_time
"msg_interval": 2000,
"msg_interval_list": [],
"user_cardname": None, # 添加群名片
"user_avatar": None, # 添加头像信息例如URL或标识符
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
"msg_interval_list": [], # 将作为 JSON 字符串存储在 Peewee 的 TextField
"user_cardname": None, # 注意:此字段不在 PersonInfo Peewee 模型中
"user_avatar": None, # 注意:此字段不在 PersonInfo Peewee 模型中
}
class PersonInfoManager:
def __init__(self):
self.person_name_list = {}
# TODO: API-Adapter修改标记
self.qv_name_llm = LLMRequest(
model=global_config.llm_normal,
model=global_config.model.normal,
max_tokens=256,
request_type="qv_name",
)
if "person_info" not in db.list_collection_names():
db.create_collection("person_info")
db.person_info.create_index("person_id", unique=True)
try:
db.connect(reuse_if_open=True)
db.create_tables([PersonInfo], safe=True)
except Exception as e:
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
# 初始化时读取所有person_name
cursor = db.person_info.find({"person_name": {"$exists": True}}, {"person_id": 1, "person_name": 1, "_id": 0})
for doc in cursor:
if doc.get("person_name"):
self.person_name_list[doc["person_id"]] = doc["person_name"]
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称")
try:
for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
PersonInfo.person_name.is_null(False)
):
if record.person_name:
self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
except Exception as e:
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
@staticmethod
def get_person_id(platform: str, user_id: int):
"""获取唯一id"""
# 如果platform中存在-,就截取-后面的部分
if "-" in platform:
platform = platform.split("-")[1]
@@ -86,15 +89,27 @@ class PersonInfoManager:
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
def is_person_known(self, platform: str, user_id: int):
async def is_person_known(self, platform: str, user_id: int):
"""判断是否认识某人"""
person_id = self.get_person_id(platform, user_id)
document = db.person_info.find_one({"person_id": person_id})
if document:
return True
else:
def _db_check_known_sync(p_id: str):
return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None
try:
return await asyncio.to_thread(_db_check_known_sync, person_id)
except Exception as e:
logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
return False
def get_person_id_by_person_name(self, person_name: str):
"""根据用户名获取用户ID"""
document = db.person_info.find_one({"person_name": person_name})
if document:
return document["person_id"]
else:
return ""
@staticmethod
async def create_person_info(person_id: str, data: dict = None):
"""创建一个项"""
@@ -103,73 +118,111 @@ class PersonInfoManager:
return
_person_info_default = copy.deepcopy(person_info_default)
_person_info_default["person_id"] = person_id
model_fields = PersonInfo._meta.fields.keys()
final_data = {"person_id": person_id}
if data:
for key in _person_info_default:
if key != "person_id" and key in data:
_person_info_default[key] = data[key]
for key, value in data.items():
if key in model_fields:
final_data[key] = value
db.person_info.insert_one(_person_info_default)
for key, default_value in _person_info_default.items():
if key in model_fields and key not in final_data:
final_data[key] = default_value
if "msg_interval_list" in final_data and isinstance(final_data["msg_interval_list"], list):
final_data["msg_interval_list"] = json.dumps(final_data["msg_interval_list"])
elif "msg_interval_list" not in final_data and "msg_interval_list" in model_fields:
final_data["msg_interval_list"] = json.dumps([])
def _db_create_sync(p_data: dict):
try:
PersonInfo.create(**p_data)
return True
except Exception as e:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}")
return False
await asyncio.to_thread(_db_create_sync, final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
"""更新某一个字段,会补全"""
if field_name not in person_info_default.keys():
logger.debug(f"更新'{field_name}'失败,未定义的字段")
if field_name not in PersonInfo._meta.fields:
if field_name in person_info_default:
logger.debug(f"更新'{field_name}'跳过,字段存在于默认配置但不在 PersonInfo Peewee 模型中。")
return
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
return
document = db.person_info.find_one({"person_id": person_id})
if document:
db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}})
def _db_update_sync(p_id: str, f_name: str, val):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
if f_name == "msg_interval_list" and isinstance(val, list):
setattr(record, f_name, json.dumps(val))
else:
data[field_name] = value
logger.debug(f"更新时{person_id}不存在,已新建")
await self.create_person_info(person_id, data)
setattr(record, f_name, val)
record.save()
return True, False
return False, True
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, value)
if needs_creation:
logger.debug(f"更新时 {person_id} 不存在,将新建。")
creation_data = data if data is not None else {}
creation_data[field_name] = value
if "platform" not in creation_data or "user_id" not in creation_data:
logger.warning(f"{person_id} 创建记录时platform/user_id 可能缺失。")
await self.create_person_info(person_id, creation_data)
@staticmethod
async def has_one_field(person_id: str, field_name: str):
"""判断是否存在某一个字段"""
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
if document:
if field_name not in PersonInfo._meta.fields:
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
return False
def _db_has_field_sync(p_id: str, f_name: str):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
return True
else:
return False
try:
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
except Exception as e:
logger.error(f"检查字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
return False
@staticmethod
def _extract_json_from_text(text: str) -> dict:
"""从文本中提取JSON数据的高容错方法"""
try:
# 尝试直接解析
parsed_json = json.loads(text)
# 如果解析结果是列表,尝试取第一个元素
if isinstance(parsed_json, list):
if parsed_json: # 检查列表是否为空
if parsed_json:
parsed_json = parsed_json[0]
else: # 如果列表为空,重置为 None走后续逻辑
else:
parsed_json = None
# 确保解析结果是字典
if isinstance(parsed_json, dict):
return parsed_json
except json.JSONDecodeError:
# 解析失败,继续尝试其他方法
pass
except Exception as e:
logger.warning(f"尝试直接解析JSON时发生意外错误: {e}")
pass # 继续尝试其他方法
pass
# 如果直接解析失败或结果不是字典
try:
# 尝试找到JSON对象格式的部分
json_pattern = r"\{[^{}]*\}"
matches = re.findall(json_pattern, text)
if matches:
parsed_obj = json.loads(matches[0])
if isinstance(parsed_obj, dict): # 确保是字典
if isinstance(parsed_obj, dict):
return parsed_obj
# 如果上面都失败了,尝试提取键值对
nickname_pattern = r'"nickname"[:\s]+"([^"]+)"'
reason_pattern = r'"reason"[:\s]+"([^"]+)"'
@@ -184,7 +237,6 @@ class PersonInfoManager:
except Exception as e:
logger.error(f"后备JSON提取失败: {str(e)}")
# 如果所有方法都失败了,返回默认字典
logger.warning(f"无法从文本中提取有效的JSON字典: {text}")
return {"nickname": "", "reason": ""}
@@ -199,9 +251,11 @@ class PersonInfoManager:
old_name = await self.get_value(person_id, "person_name")
old_reason = await self.get_value(person_id, "name_reason")
max_retries = 5 # 最大重试次数
max_retries = 5
current_try = 0
existing_names = ""
existing_names_str = ""
current_name_set = set(self.person_name_list.values())
while current_try < max_retries:
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=2, level=1)
@@ -216,45 +270,58 @@ class PersonInfoManager:
qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}"
qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸"
qv_name_prompt += (
"\n请根据以上用户信息想想你叫他什么比较好不要太浮夸请最好使用用户的qq昵称可以稍作修改"
)
if existing_names:
qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}\n"
if existing_names_str:
qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}\n"
if len(current_name_set) < 50 and current_name_set:
qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n"
qv_name_prompt += "请用json给出你的想法并给出理由示例如下"
qv_name_prompt += """{
"nickname": "昵称",
"reason": "理由"
}"""
# logger.debug(f"取名提示词:{qv_name_prompt}")
response = await self.qv_name_llm.generate_response(qv_name_prompt)
logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}")
result = self._extract_json_from_text(response[0])
if not result["nickname"]:
logger.error("生成的昵称为空,重试中...")
if not result or not result.get("nickname"):
logger.error("生成的昵称为空或结果格式不正确,重试中...")
current_try += 1
continue
# 检查生成的昵称是否已存在
if result["nickname"] not in self.person_name_list.values():
# 更新数据库和内存中的列表
await self.update_one_field(person_id, "person_name", result["nickname"])
# await self.update_one_field(person_id, "nickname", user_nickname)
# await self.update_one_field(person_id, "avatar", user_avatar)
await self.update_one_field(person_id, "name_reason", result["reason"])
generated_nickname = result["nickname"]
self.person_name_list[person_id] = result["nickname"]
# logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}")
is_duplicate = False
if generated_nickname in current_name_set:
is_duplicate = True
else:
def _db_check_name_exists_sync(name_to_check):
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
is_duplicate = True
current_name_set.add(generated_nickname)
if not is_duplicate:
await self.update_one_field(person_id, "person_name", generated_nickname)
await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由"))
self.person_name_list[person_id] = generated_nickname
return result
else:
existing_names += f"{result['nickname']}"
logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...")
if existing_names_str:
existing_names_str += ""
existing_names_str += generated_nickname
logger.debug(f"生成的昵称 {generated_nickname} 已存在,重试中...")
current_try += 1
logger.error(f"{max_retries}次尝试后仍未能生成唯一昵称")
logger.error(f"{max_retries}次尝试后仍未能生成唯一昵称 for {person_id}")
return None
@staticmethod
@@ -264,30 +331,56 @@ class PersonInfoManager:
logger.debug("删除失败person_id 不能为空")
return
result = db.person_info.delete_one({"person_id": person_id})
if result.deleted_count > 0:
logger.debug(f"删除成功person_id={person_id}")
def _db_delete_sync(p_id: str):
try:
query = PersonInfo.delete().where(PersonInfo.person_id == p_id)
deleted_count = query.execute()
return deleted_count
except Exception as e:
logger.error(f"删除 PersonInfo {p_id} 失败 (Peewee): {e}")
return 0
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
if deleted_count > 0:
logger.debug(f"删除成功person_id={person_id} (Peewee)")
else:
logger.debug(f"删除失败:未找到 person_id={person_id}")
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)")
@staticmethod
async def get_value(person_id: str, field_name: str):
"""获取指定person_id文档的字段值若不存在该字段则返回该字段的全局默认值"""
if not person_id:
logger.debug("get_value获取失败person_id不能为空")
return person_info_default.get(field_name)
if field_name not in PersonInfo._meta.fields:
if field_name in person_info_default:
logger.trace(f"字段'{field_name}'不在Peewee模型中但存在于默认配置中。返回配置默认值。")
return copy.deepcopy(person_info_default[field_name])
logger.debug(f"get_value获取失败字段'{field_name}'未在Peewee模型和默认配置中定义。")
return None
if field_name not in person_info_default:
logger.debug(f"get_value获取失败字段'{field_name}'未定义")
def _db_get_value_sync(p_id: str, f_name: str):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
val = getattr(record, f_name)
if f_name == "msg_interval_list" and isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"无法解析 {p_id} 的 msg_interval_list JSON: {val}")
return copy.deepcopy(person_info_default.get(f_name, []))
return val
return None
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
value = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
if document and field_name in document:
return document[field_name]
if value is not None:
return value
else:
default_value = copy.deepcopy(person_info_default[field_name])
logger.trace(f"获取{person_id}{field_name}失败,已返回默认值{default_value}")
default_value = copy.deepcopy(person_info_default.get(field_name))
logger.trace(f"获取{person_id}{field_name}失败或值为None,已返回默认值{default_value} (Peewee)")
return default_value
@staticmethod
@@ -297,93 +390,84 @@ class PersonInfoManager:
logger.debug("get_values获取失败person_id不能为空")
return {}
# 检查所有字段是否有效
for field in field_names:
if field not in person_info_default:
logger.debug(f"get_values获取失败字段'{field}'未定义")
return {}
# 构建查询投影(所有字段都有效才会执行到这里)
projection = {field: 1 for field in field_names}
document = db.person_info.find_one({"person_id": person_id}, projection)
result = {}
for field in field_names:
result[field] = copy.deepcopy(
document.get(field, person_info_default[field]) if document else person_info_default[field]
)
def _db_get_record_sync(p_id: str):
return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
record = await asyncio.to_thread(_db_get_record_sync, person_id)
for field_name in field_names:
if field_name not in PersonInfo._meta.fields:
if field_name in person_info_default:
result[field_name] = copy.deepcopy(person_info_default[field_name])
logger.trace(f"字段'{field_name}'不在Peewee模型中使用默认配置值。")
else:
logger.debug(f"get_values查询失败字段'{field_name}'未在Peewee模型和默认配置中定义。")
result[field_name] = None
continue
if record:
value = getattr(record, field_name)
if field_name == "msg_interval_list" and isinstance(value, str):
try:
result[field_name] = json.loads(value)
except json.JSONDecodeError:
logger.warning(f"无法解析 {person_id} 的 msg_interval_list JSON: {value}")
result[field_name] = copy.deepcopy(person_info_default.get(field_name, []))
elif value is not None:
result[field_name] = value
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
return result
@staticmethod
async def del_all_undefined_field():
"""删除所有项里的未定义字段"""
# 获取所有已定义的字段名
defined_fields = set(person_info_default.keys())
try:
# 遍历集合中的所有文档
for document in db.person_info.find({}):
# 找出文档中未定义的字段
undefined_fields = set(document.keys()) - defined_fields - {"_id"}
if undefined_fields:
# 构建更新操作,使用$unset删除未定义字段
update_result = db.person_info.update_one(
{"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}}
"""删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。"""
logger.info(
"del_all_undefined_field: 对于使用Peewee的SQL数据库此操作通常不适用或不需要因为表结构是预定义的。"
)
if update_result.modified_count > 0:
logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}")
return
except Exception as e:
logger.error(f"清理未定义字段时出错: {e}")
return
@staticmethod
async def get_specific_value_list(
field_name: str,
way: Callable[[Any], bool], # 接受任意类型值
way: Callable[[Any], bool],
) -> Dict[str, Any]:
"""
获取满足条件的字段值字典
Args:
field_name: 目标字段名
way: 判断函数 (value: Any) -> bool
Returns:
{person_id: value} | {}
Example:
# 查找所有nickname包含"admin"的用户
result = manager.specific_value_list(
"nickname",
lambda x: "admin" in x.lower()
)
"""
if field_name not in person_info_default:
logger.error(f"字段检查失败:'{field_name}'未定义")
if field_name not in PersonInfo._meta.fields:
logger.error(f"字段检查失败:'{field_name}'在 PersonInfo Peewee 模型中定义")
return {}
def _db_get_specific_sync(f_name: str):
found_results = {}
try:
result = {}
for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}):
for record in PersonInfo.select(PersonInfo.person_id, getattr(PersonInfo, f_name)):
value = getattr(record, f_name)
if f_name == "msg_interval_list" and isinstance(value, str):
try:
value = doc[field_name]
if way(value):
result[doc["person_id"]] = value
except (KeyError, TypeError, ValueError) as e:
logger.debug(f"记录{doc.get('person_id')}处理失败: {str(e)}")
processed_value = json.loads(value)
except json.JSONDecodeError:
logger.warning(f"跳过记录 {record.person_id},无法解析 msg_interval_list: {value}")
continue
else:
processed_value = value
return result
if way(processed_value):
found_results[record.person_id] = processed_value
except Exception as e_query:
logger.error(f"数据库查询失败 (Peewee specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
return found_results
try:
return await asyncio.to_thread(_db_get_specific_sync, field_name)
except Exception as e:
logger.error(f"数据库查询失败: {str(e)}", exc_info=True)
logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
return {}
async def personal_habit_deduction(self):
@@ -391,35 +475,31 @@ class PersonInfoManager:
try:
while 1:
await asyncio.sleep(600)
current_time = datetime.datetime.now()
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
current_time_dt = datetime.datetime.now()
logger.info(f"个人信息推断启动: {current_time_dt.strftime('%Y-%m-%d %H:%M:%S')}")
# "msg_interval"推断
msg_interval_map = False
msg_interval_lists = await self.get_specific_value_list(
msg_interval_map_generated = False
msg_interval_lists_map = await self.get_specific_value_list(
"msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
)
for person_id, msg_interval_list_ in msg_interval_lists.items():
for person_id, actual_msg_interval_list in msg_interval_lists_map.items():
await asyncio.sleep(0.3)
try:
time_interval = []
for t1, t2 in zip(msg_interval_list_, msg_interval_list_[1:]):
for t1, t2 in zip(actual_msg_interval_list, actual_msg_interval_list[1:]):
delta = t2 - t1
if delta > 0:
time_interval.append(delta)
time_interval = [t for t in time_interval if 200 <= t <= 8000]
# --- 修改后的逻辑 ---
# 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断)
if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条
time_interval.sort()
# 画图(log) - 这部分保留
msg_interval_map = True
if len(time_interval) >= 30 + 10:
time_interval.sort()
msg_interval_map_generated = True
log_dir = Path("logs/person_info")
log_dir.mkdir(parents=True, exist_ok=True)
plt.figure(figsize=(10, 6))
# 使用截断前的数据画图,更能反映原始分布
time_series_original = pd.Series(time_interval)
plt.hist(
time_series_original,
@@ -441,34 +521,29 @@ class PersonInfoManager:
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
plt.savefig(img_path)
plt.close()
# 画图结束
# 去掉头尾各 5 个数据点
trimmed_interval = time_interval[5:-5]
# 计算截断后数据的 37% 分位数
if trimmed_interval: # 确保截断后列表不为空
msg_interval = int(round(np.percentile(trimmed_interval, 37)))
# 更新数据库
await self.update_one_field(person_id, "msg_interval", msg_interval)
logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}")
if trimmed_interval:
msg_interval_val = int(round(np.percentile(trimmed_interval, 37)))
await self.update_one_field(person_id, "msg_interval", msg_interval_val)
logger.trace(
f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}"
)
else:
logger.trace(f"用户{person_id}截断后数据为空无法计算msg_interval")
else:
logger.trace(
f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)"
)
# --- 修改结束 ---
except Exception as e:
logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}")
except Exception as e_inner:
logger.trace(f"用户{person_id}消息间隔计算失败: {type(e_inner).__name__}: {str(e_inner)}")
continue
# 其他...
if msg_interval_map:
if msg_interval_map_generated:
logger.trace("已保存分布图到: logs/person_info")
current_time = datetime.datetime.now()
logger.trace(f"个人信息推断结束: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
current_time_dt_end = datetime.datetime.now()
logger.trace(f"个人信息推断结束: {current_time_dt_end.strftime('%Y-%m-%d %H:%M:%S')}")
await asyncio.sleep(86400)
except Exception as e:
@@ -481,41 +556,27 @@ class PersonInfoManager:
"""
根据 platform 和 user_id 获取 person_id。
如果对应的用户不存在,则使用提供的可选信息创建新用户。
Args:
platform: 平台标识
user_id: 用户在该平台上的ID
nickname: 用户的昵称 (可选,用于创建新用户)
user_cardname: 用户的群名片 (可选,用于创建新用户)
user_avatar: 用户的头像信息 (可选,用于创建新用户)
Returns:
对应的 person_id。
"""
person_id = self.get_person_id(platform, user_id)
# 检查用户是否已存在
# 使用静态方法 get_person_id因此可以直接调用 db
document = db.person_info.find_one({"person_id": person_id})
def _db_check_exists_sync(p_id: str):
return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if document is None:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。")
record = await asyncio.to_thread(_db_check_exists_sync, person_id)
if record is None:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
initial_data = {
"platform": platform,
"user_id": user_id,
"user_id": str(user_id),
"nickname": nickname,
"konw_time": int(datetime.datetime.now().timestamp()), # 添加初次认识时间
# 注意:这里没有添加 user_cardname 和 user_avatar因为它们不在 person_info_default 中
# 如果需要存储它们,需要先在 person_info_default 中定义
"know_time": int(datetime.datetime.now().timestamp()), # 修正拼写konw_time -> know_time
}
# 过滤掉值为 None 的初始数据
initial_data = {k: v for k, v in initial_data.items() if v is not None}
model_fields = PersonInfo._meta.fields.keys()
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
# 注意create_person_info 是静态方法
await PersonInfoManager.create_person_info(person_id, data=initial_data)
# 创建后,可以考虑立即为其取名,但这可能会增加延迟
# await self.qv_person_name(person_id, nickname, user_cardname, user_avatar)
logger.debug(f"已为 {person_id} 创建新记录,初始数据: {initial_data}")
await self.create_person_info(person_id, data=filtered_initial_data)
logger.debug(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
return person_id
@@ -525,34 +586,54 @@ class PersonInfoManager:
logger.debug("get_person_info_by_name 获取失败person_name 不能为空")
return None
# 优先从内存缓存查找 person_id
found_person_id = None
for pid, name in self.person_name_list.items():
if name == person_name:
for pid, name_in_cache in self.person_name_list.items():
if name_in_cache == person_name:
found_person_id = pid
break # 找到第一个匹配就停止
break
if not found_person_id:
# 如果内存没有,尝试数据库查询(可能内存未及时更新或启动时未加载)
document = db.person_info.find_one({"person_name": person_name})
if document:
found_person_id = document.get("person_id")
else:
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户")
return None # 数据库也找不到
# 根据找到的 person_id 获取所需信息
if found_person_id:
required_fields = ["person_id", "platform", "user_id", "nickname", "user_cardname", "user_avatar"]
person_data = await self.get_values(found_person_id, required_fields)
if person_data: # 确保 get_values 成功返回
return person_data
def _db_find_by_name_sync(p_name_to_find: str):
return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find)
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
if record:
found_person_id = record.person_id
if (
found_person_id not in self.person_name_list
or self.person_name_list[found_person_id] != person_name
):
self.person_name_list[found_person_id] = person_name
else:
logger.warning(f"找到了 person_id '{found_person_id}' 但获取详细信息失败")
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)")
return None
if found_person_id:
required_fields = [
"person_id",
"platform",
"user_id",
"nickname",
"user_cardname",
"user_avatar",
"person_name",
"name_reason",
]
valid_fields_to_get = [
f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default
]
person_data = await self.get_values(found_person_id, valid_fields_to_get)
if person_data:
final_result = {key: person_data.get(key) for key in required_fields}
return final_result
else:
# 这理论上不应该发生,因为上面已经处理了找不到的情况
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id")
logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)")
return None
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)")
return None

View File

@@ -77,7 +77,7 @@ class RelationshipManager:
@staticmethod
async def is_known_some_one(platform, user_id):
"""判断是否认识某人"""
is_known = person_info_manager.is_person_known(platform, user_id)
is_known = await person_info_manager.is_person_known(platform, user_id)
return is_known
@staticmethod

View File

@@ -174,6 +174,16 @@ async def _build_readable_messages_internal(
# 1 & 2: 获取发送者信息并提取消息组件
for msg in messages:
# 检查并修复缺少的user_info字段
if "user_info" not in msg:
# 创建user_info字段
msg["user_info"] = {
"platform": msg.get("user_platform", ""),
"user_id": msg.get("user_id", ""),
"user_nickname": msg.get("user_nickname", ""),
"user_cardname": msg.get("user_cardname", ""),
}
user_info = msg.get("user_info", {})
platform = user_info.get("platform")
user_id = user_info.get("user_id")
@@ -190,8 +200,8 @@ async def _build_readable_messages_internal(
person_id = person_info_manager.get_person_id(platform, user_id)
# 根据 replace_bot_name 参数决定是否替换机器人名称
if replace_bot_name and user_id == global_config.BOT_QQ:
person_name = f"{global_config.BOT_NICKNAME}(你)"
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
else:
person_name = await person_info_manager.get_value(person_id, "person_name")
@@ -427,7 +437,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
output_lines = []
def get_anon_name(platform, user_id):
if user_id == global_config.BOT_QQ:
if user_id == global_config.bot.qq_account:
return "SELF"
person_id = person_info_manager.get_person_id(platform, user_id)
if person_id not in person_map:
@@ -451,10 +461,10 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# 处理 回复<aaa:bbb>
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
def reply_replacer(match):
def reply_replacer(match, platform=platform):
# aaa = match.group(1)
bbb = match.group(2)
anon_reply = get_anon_name(platform, bbb)
anon_reply = get_anon_name(platform, bbb) # noqa
return f"回复 {anon_reply}"
content = re.sub(reply_pattern, reply_replacer, content, count=1)
@@ -462,10 +472,10 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# 处理 @<aaa:bbb>
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
def at_replacer(match):
def at_replacer(match, platform=platform):
# aaa = match.group(1)
bbb = match.group(2)
anon_at = get_anon_name(platform, bbb)
anon_at = get_anon_name(platform, bbb) # noqa
return f"@{anon_at}"
content = re.sub(at_pattern, at_replacer, content)
@@ -501,7 +511,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
user_id = user_info.get("user_id")
# 检查必要信息是否存在 且 不是机器人自己
if not all([platform, user_id]) or user_id == global_config.BOT_QQ:
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue
person_id = person_info_manager.get_person_id(platform, user_id)

View File

@@ -1,15 +1,15 @@
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv, MessageSending, Message
from src.common.database import db
from src.common.database.database_model import Messages, ThinkingLog
import time
import traceback
from typing import List
import json
class InfoCatcher:
def __init__(self):
self.chat_history = [] # 聊天历史,长度为三倍使用的上下文喵~
self.context_length = global_config.observation_context_size
self.chat_history_in_thinking = [] # 思考期间的聊天内容喵~
self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文喵~
@@ -60,8 +60,6 @@ class InfoCatcher:
def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
self.timing_results["sub_heartflow_observe_time"] = obs_duration
# def catch_shf
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
self.timing_results["sub_heartflow_step_time"] = step_duration
if len(past_mind) > 1:
@@ -72,25 +70,10 @@ class InfoCatcher:
self.heartflow_data["sub_heartflow_now"] = current_mind
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
# if self.response_mode == "heart_flow": # 条件判断不需要了喵~
# self.heartflow_data["prompt"] = prompt
# self.heartflow_data["response"] = response
# self.heartflow_data["model"] = model_name
# elif self.response_mode == "reasoning": # 条件判断不需要了喵~
# self.reasoning_data["thinking_log"] = reasoning_content
# self.reasoning_data["prompt"] = prompt
# self.reasoning_data["response"] = response
# self.reasoning_data["model"] = model_name
# 直接记录信息喵~
self.reasoning_data["thinking_log"] = reasoning_content
self.reasoning_data["prompt"] = prompt
self.reasoning_data["response"] = response
self.reasoning_data["model"] = model_name
# 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~
# self.heartflow_data["prompt"] = prompt
# self.heartflow_data["response"] = response
# self.heartflow_data["model"] = model_name
self.response_text = response
@@ -102,6 +85,7 @@ class InfoCatcher:
):
self.timing_results["make_response_time"] = response_duration
self.response_time = time.time()
self.response_messages = []
for msg in response_message:
self.response_messages.append(msg)
@@ -112,107 +96,112 @@ class InfoCatcher:
@staticmethod
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
try:
# 从数据库中获取消息的时间戳
time_start = message_start.message_info.time
time_end = message_end.message_info.time
chat_id = message_start.chat_stream.stream_id
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
# 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据
messages_between = db.messages.find(
{"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}}
).sort("time", -1)
messages_between_query = (
Messages.select()
.where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end))
.order_by(Messages.time.desc())
)
result = list(messages_between)
result = list(messages_between_query)
print(f"查询结果数量: {len(result)}")
if result:
print(f"第一条消息时间: {result[0]['time']}")
print(f"最后一条消息时间: {result[-1]['time']}")
print(f"第一条消息时间: {result[0].time}")
print(f"最后一条消息时间: {result[-1].time}")
return result
except Exception as e:
print(f"获取消息时出错: {str(e)}")
print(traceback.format_exc())
return []
def get_message_from_db_before_msg(self, message: MessageRecv):
# 从数据库中获取消息
message_id = message.message_info.message_id
chat_id = message.chat_stream.stream_id
message_id_val = message.message_info.message_id
chat_id_val = message.chat_stream.stream_id
# 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据
messages_before = (
db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
.sort("time", -1)
.limit(self.context_length * 3)
) # 获取更多历史信息
messages_before_query = (
Messages.select()
.where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val))
.order_by(Messages.time.desc())
.limit(global_config.chat.observation_context_size * 3)
)
return list(messages_before)
return list(messages_before_query)
def message_list_to_dict(self, message_list):
# 存储简化的聊天记录
result = []
for message in message_list:
if not isinstance(message, dict):
message = self.message_to_dict(message)
# print(message)
for msg_item in message_list:
processed_msg_item = msg_item
if not isinstance(msg_item, dict):
processed_msg_item = self.message_to_dict(msg_item)
if not processed_msg_item:
continue
lite_message = {
"time": message["time"],
"user_nickname": message["user_info"]["user_nickname"],
"processed_plain_text": message["processed_plain_text"],
"time": processed_msg_item.get("time"),
"user_nickname": processed_msg_item.get("user_nickname"),
"processed_plain_text": processed_msg_item.get("processed_plain_text"),
}
result.append(lite_message)
return result
@staticmethod
def message_to_dict(message):
if not message:
def message_to_dict(msg_obj):
if not msg_obj:
return None
if isinstance(message, dict):
return message
if isinstance(msg_obj, dict):
return msg_obj
if isinstance(msg_obj, Messages):
return {
# "message_id": message.message_info.message_id,
"time": message.message_info.time,
"user_id": message.message_info.user_info.user_id,
"user_nickname": message.message_info.user_info.user_nickname,
"processed_plain_text": message.processed_plain_text,
# "detailed_plain_text": message.detailed_plain_text
"time": msg_obj.time,
"user_id": msg_obj.user_id,
"user_nickname": msg_obj.user_nickname,
"processed_plain_text": msg_obj.processed_plain_text,
}
if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"):
return {
"time": msg_obj.message_info.time,
"user_id": msg_obj.message_info.user_info.user_id,
"user_nickname": msg_obj.message_info.user_info.user_nickname,
"processed_plain_text": msg_obj.processed_plain_text,
}
print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}")
return {}
def done_catch(self):
"""将收集到的信息存储到数据库的 thinking_log 集合中喵~"""
"""将收集到的信息存储到数据库的 thinking_log 中喵~"""
try:
# 将消息对象转换为可序列化的字典喵~
thinking_log_data = {
"chat_id": self.chat_id,
"trigger_text": self.trigger_response_text,
"response_text": self.response_text,
"trigger_info": {
"time": self.trigger_response_time,
"message": self.message_to_dict(self.trigger_response_message),
},
"response_info": {
trigger_info_dict = self.message_to_dict(self.trigger_response_message)
response_info_dict = {
"time": self.response_time,
"message": self.response_messages,
},
"timing_results": self.timing_results,
"chat_history": self.message_list_to_dict(self.chat_history),
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
"heartflow_data": self.heartflow_data,
"reasoning_data": self.reasoning_data,
}
chat_history_list = self.message_list_to_dict(self.chat_history)
chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking)
chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response)
# 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~
# if self.response_mode == "heart_flow":
# thinking_log_data["mode_specific_data"] = self.heartflow_data
# elif self.response_mode == "reasoning":
# thinking_log_data["mode_specific_data"] = self.reasoning_data
# 将数据插入到 thinking_log 集合中喵~
db.thinking_log.insert_one(thinking_log_data)
log_entry = ThinkingLog(
chat_id=self.chat_id,
trigger_text=self.trigger_response_text,
response_text=self.response_text,
trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None,
response_info_json=json.dumps(response_info_dict),
timing_results_json=json.dumps(self.timing_results),
chat_history_json=json.dumps(chat_history_list),
chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
chat_history_after_response_json=json.dumps(chat_history_after_response_list),
heartflow_data_json=json.dumps(self.heartflow_data),
reasoning_data_json=json.dumps(self.reasoning_data),
)
log_entry.save()
return True
except Exception as e:

View File

@@ -2,10 +2,12 @@ from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
from src.common.logger import get_module_logger
from src.manager.async_task_manager import AsyncTask
from ...common.database import db
from ...common.database.database import db # This db is the Peewee database instance
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
from src.manager.local_store_manager import local_storage
logger = get_module_logger("maibot_statistic")
@@ -39,7 +41,7 @@ class OnlineTimeRecordTask(AsyncTask):
def __init__(self):
super().__init__(task_name="Online Time Record Task", run_interval=60)
self.record_id: str | None = None
self.record_id: int | None = None # Changed to int for Peewee's default ID
"""记录ID"""
self._init_database() # 初始化数据库
@@ -47,49 +49,46 @@ class OnlineTimeRecordTask(AsyncTask):
@staticmethod
def _init_database():
"""初始化数据库"""
if "online_time" not in db.list_collection_names():
# 初始化数据库(在线时长)
db.create_collection("online_time")
# 创建索引
if ("end_timestamp", 1) not in db.online_time.list_indexes():
db.online_time.create_index([("end_timestamp", 1)])
with db.atomic(): # Use atomic operations for schema changes
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
async def run(self):
try:
current_time = datetime.now()
extended_end_time = current_time + timedelta(minutes=1)
if self.record_id:
# 如果有记录,则更新结束时间
db.online_time.update_one(
{"_id": self.record_id},
{
"$set": {
"end_timestamp": datetime.now() + timedelta(minutes=1),
}
},
)
else:
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
updated_rows = query.execute()
if updated_rows == 0:
# Record might have been deleted or ID is stale, try to find/create
self.record_id = None # Reset record_id to trigger find/create logic below
if not self.record_id: # Check again if record_id was reset or initially None
# 如果没有记录,检查一分钟以内是否已有记录
current_time = datetime.now()
if recent_record := db.online_time.find_one(
{"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}}
):
# 如果有记录,则更新结束时间
self.record_id = recent_record["_id"]
db.online_time.update_one(
{"_id": self.record_id},
{
"$set": {
"end_timestamp": current_time + timedelta(minutes=1),
}
},
# Look for a record whose end_timestamp is recent enough to be considered ongoing
recent_record = (
OnlineTime.select()
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
.order_by(OnlineTime.end_timestamp.desc())
.first()
)
if recent_record:
# 如果有记录,则更新结束时间
self.record_id = recent_record.id
recent_record.end_timestamp = extended_end_time
recent_record.save()
else:
# 若没有记录,则插入新的在线时间记录
self.record_id = db.online_time.insert_one(
{
"start_timestamp": current_time,
"end_timestamp": current_time + timedelta(minutes=1),
}
).inserted_id
new_record = OnlineTime.create(
timestamp=current_time.timestamp(), # 添加此行
start_timestamp=current_time,
end_timestamp=extended_end_time,
duration=5, # 初始时长为5分钟
)
self.record_id = new_record.id
except Exception as e:
logger.error(f"在线时间记录失败,错误信息:{e}")
@@ -201,35 +200,28 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段
"""
if len(collect_period) <= 0:
if not collect_period:
return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
# 总LLM请求数
TOTAL_REQ_CNT: 0,
# 请求次数统计
REQ_CNT_BY_TYPE: defaultdict(int),
REQ_CNT_BY_USER: defaultdict(int),
REQ_CNT_BY_MODEL: defaultdict(int),
# 输入Token数
IN_TOK_BY_TYPE: defaultdict(int),
IN_TOK_BY_USER: defaultdict(int),
IN_TOK_BY_MODEL: defaultdict(int),
# 输出Token数
OUT_TOK_BY_TYPE: defaultdict(int),
OUT_TOK_BY_USER: defaultdict(int),
OUT_TOK_BY_MODEL: defaultdict(int),
# 总Token数
TOTAL_TOK_BY_TYPE: defaultdict(int),
TOTAL_TOK_BY_USER: defaultdict(int),
TOTAL_TOK_BY_MODEL: defaultdict(int),
# 总开销
TOTAL_COST: 0.0,
# 请求开销统计
COST_BY_TYPE: defaultdict(float),
COST_BY_USER: defaultdict(float),
COST_BY_MODEL: defaultdict(float),
@@ -238,26 +230,26 @@ class StatisticOutputTask(AsyncTask):
}
# 以最早的时间戳为起始时间获取记录
for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}):
record_timestamp = record.get("timestamp")
# Assuming LLMUsage.timestamp is a DateTimeField
query_start_time = collect_period[-1][1]
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
record_timestamp = record.timestamp # This is already a datetime object
for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1
request_type = record.get("request_type", "unknown") # 请求类型
user_id = str(record.get("user_id", "unknown")) # 用户ID
model_name = record.get("model_name", "unknown") # 模型名称
request_type = record.request_type or "unknown"
user_id = record.user_id or "unknown" # user_id is TextField, already string
model_name = record.model_name or "unknown"
stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
stats[period_key][REQ_CNT_BY_USER][user_id] += 1
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数
completion_tokens = record.get("completion_tokens", 0) # 输出Token数
total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数
prompt_tokens = record.prompt_tokens or 0
completion_tokens = record.completion_tokens or 0
total_tokens = prompt_tokens + completion_tokens
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
@@ -271,13 +263,12 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
cost = record.get("cost", 0.0)
cost = record.cost or 0.0
stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost
stats[period_key][COST_BY_MODEL][model_name] += cost
break # 取消更早时间段的判断
break
return stats
@staticmethod
@@ -287,39 +278,38 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段
"""
if len(collect_period) <= 0:
if not collect_period:
return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
# 在线时间统计
ONLINE_TIME: 0.0,
}
for period_key, _ in collect_period
}
# 统计在线时间
for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}):
end_timestamp: datetime = record.get("end_timestamp")
for idx, (_, period_start) in enumerate(collect_period):
if end_timestamp >= period_start:
# 由于end_timestamp会超前标记时间所以我们需要判断是否晚于当前时间如果是则使用当前时间作为结束时间
end_timestamp = min(end_timestamp, now)
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for period_key, _period_start in collect_period[idx:]:
start_timestamp: datetime = record.get("start_timestamp")
if start_timestamp < _period_start:
# 如果开始时间在查询边界之前,则使用开始时间
stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds()
else:
# 否则,使用开始时间
stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds()
break # 取消更早时间段的判断
query_start_time = collect_period[-1][1]
# Assuming OnlineTime.end_timestamp is a DateTimeField
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
# record.end_timestamp and record.start_timestamp are datetime objects
record_end_timestamp = record.end_timestamp
record_start_timestamp = record.start_timestamp
for idx, (_, period_boundary_start) in enumerate(collect_period):
if record_end_timestamp >= period_boundary_start:
# Calculate effective end time for this record in relation to 'now'
effective_end_time = min(record_end_timestamp, now)
for period_key, current_period_start_time in collect_period[idx:]:
# Determine the portion of the record that falls within this specific statistical period
overlap_start = max(record_start_timestamp, current_period_start_time)
overlap_end = effective_end_time # Already capped by 'now' and record's own end
if overlap_end > overlap_start:
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
break
return stats
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
@@ -328,55 +318,57 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段
"""
if len(collect_period) <= 0:
if not collect_period:
return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
# 消息统计
TOTAL_MSG_CNT: 0,
MSG_CNT_BY_CHAT: defaultdict(int),
}
for period_key, _ in collect_period
}
# 统计消息量
for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}):
chat_info = message.get("chat_info", None) # 聊天信息
user_info = message.get("user_info", None) # 用户信息(消息发送人)
message_time = message.get("time", 0) # 消息时间
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp):
message_time_ts = message.time # This is a float timestamp
group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息
if group_info is not None:
# 若有群聊信息
chat_id = f"g{group_info.get('group_id')}"
chat_name = group_info.get("group_name", f"{group_info.get('group_id')}")
elif user_info:
# 若没有群聊信息,则尝试获取用户信息
chat_id = f"u{user_info['user_id']}"
chat_name = user_info["user_nickname"]
chat_id = None
chat_name = None
# Logic based on Peewee model structure, aiming to replicate original intent
if message.chat_info_group_id:
chat_id = f"g{message.chat_info_group_id}"
chat_name = message.chat_info_group_name or f"{message.chat_info_group_id}"
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
# This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message.user_id}" # SENDER's user_id
chat_name = message.user_nickname # SENDER's nickname
else:
continue # 如果没有群组信息也没有用户信息,则跳过
# If neither group_id nor sender_id is available for chat identification
logger.warning(
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
)
continue
if not chat_id: # Should not happen if above logic is correct
continue
# Update name_mapping
if chat_id in self.name_mapping:
if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]:
# 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称
self.name_mapping[chat_id] = (chat_name, message_time)
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
self.name_mapping[chat_id] = (chat_name, message_time_ts)
else:
self.name_mapping[chat_id] = (chat_name, message_time)
self.name_mapping[chat_id] = (chat_name, message_time_ts)
for idx, (_, period_start) in enumerate(collect_period):
if message_time >= period_start.timestamp():
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for idx, (_, period_start_dt) in enumerate(collect_period):
if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_MSG_CNT] += 1
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
break
return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:

View File

@@ -13,8 +13,10 @@ from src.manager.mood_manager import mood_manager
from ..message_receive.message import MessageRecv
from ..models.utils_model import LLMRequest
from .typo_generator import ChineseTypoGenerator
from ...common.database import db
from ...common.database.database import db
from ...config.config import global_config
from ...common.database.database_model import Messages
from ...common.message_repository import find_messages, count_messages
logger = get_module_logger("chat_utils")
@@ -43,8 +45,8 @@ def db_message_to_str(message_dict: dict) -> str:
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
nicknames = global_config.BOT_ALIAS_NAMES
keywords = [global_config.bot.nickname]
nicknames = global_config.bot.alias_names
reply_probability = 0.0
is_at = False
is_mentioned = False
@@ -64,18 +66,18 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
)
# 判断是否被@
if re.search(f"@[\s\S]*?id:{global_config.BOT_QQ}", message.processed_plain_text):
if re.search(f"@[\s\S]*?id:{global_config.bot.qq_account}", message.processed_plain_text):
is_at = True
is_mentioned = True
if is_at and global_config.at_bot_inevitable_reply:
if is_at and global_config.normal_chat.at_bot_inevitable_reply:
reply_probability = 1.0
logger.info("被@回复概率设置为100%")
else:
if not is_mentioned:
# 判断是否被回复
if re.match(
f"\[回复 [\s\S]*?\({str(global_config.BOT_QQ)}\)[\s\S]*?],说:", message.processed_plain_text
f"\[回复 [\s\S]*?\({str(global_config.bot.qq_account)}\)[\s\S]*?],说:", message.processed_plain_text
):
is_mentioned = True
else:
@@ -88,7 +90,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
for nickname in nicknames:
if nickname in message_content:
is_mentioned = True
if is_mentioned and global_config.mentioned_bot_inevitable_reply:
if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply:
reply_probability = 1.0
logger.info("被提及回复概率设置为100%")
return is_mentioned, reply_probability
@@ -96,7 +98,8 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
async def get_embedding(text, request_type="embedding"):
"""获取文本的embedding向量"""
llm = LLMRequest(model=global_config.embedding, request_type=request_type)
# TODO: API-Adapter修改标记
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
# return llm.get_embedding_sync(text)
try:
embedding = await llm.get_embedding(text)
@@ -107,20 +110,12 @@ async def get_embedding(text, request_type="embedding"):
def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, combine=False):
recent_messages = list(
db.messages.find(
{"chat_id": chat_stream_id},
{
"time": 1, # 返回时间字段
"chat_id": 1,
"chat_info": 1,
"user_info": 1,
"message_id": 1, # 返回消息ID字段
"detailed_plain_text": 1, # 返回处理后的文本字段
},
)
.sort("time", -1)
.limit(limit)
filter_query = {"chat_id": chat_stream_id}
sort_order = [("time", -1)]
recent_messages = find_messages(
message_filter=filter_query,
sort=sort_order,
limit=limit
)
if not recent_messages:
@@ -142,17 +137,14 @@ def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, c
return message_detailed_plain_text_list
def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> list:
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
# 获取当前群聊记录内发言的人
recent_messages = list(
db.messages.find(
{"chat_id": chat_stream_id},
{
"user_info": 1,
},
)
.sort("time", -1)
.limit(limit)
filter_query = {"chat_id": chat_stream_id}
sort_order = [("time", -1)]
recent_messages = find_messages(
message_filter=filter_query,
sort=sort_order,
limit=limit
)
if not recent_messages:
@@ -160,10 +152,15 @@ def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> li
who_chat_in_group = []
for msg_db_data in recent_messages:
user_info = UserInfo.from_dict(msg_db_data["user_info"])
user_info = UserInfo.from_dict({
"platform": msg_db_data["user_platform"],
"user_id": msg_db_data["user_id"],
"user_nickname": msg_db_data["user_nickname"],
"user_cardname": msg_db_data.get("user_cardname", "")
})
if (
(user_info.platform, user_info.user_id) != sender
and user_info.user_id != global_config.BOT_QQ
and user_info.user_id != global_config.bot.qq_account
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
and len(who_chat_in_group) < 5
): # 排除重复排除消息发送者排除bot限制加载的关系数目
@@ -321,7 +318,7 @@ def random_remove_punctuation(text: str) -> str:
def process_llm_response(text: str) -> list[str]:
# 先保护颜文字
if global_config.enable_kaomoji_protection:
if global_config.response_splitter.enable_kaomoji_protection:
protected_text, kaomoji_mapping = protect_kaomoji(text)
logger.trace(f"保护颜文字后的文本: {protected_text}")
else:
@@ -340,8 +337,8 @@ def process_llm_response(text: str) -> list[str]:
logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")
# 对清理后的文本进行进一步处理
max_length = global_config.response_max_length * 2
max_sentence_num = global_config.response_max_sentence_num
max_length = global_config.response_splitter.max_length * 2
max_sentence_num = global_config.response_splitter.max_sentence_num
# 如果基本上是中文,则进行长度过滤
if get_western_ratio(cleaned_text) < 0.1:
if len(cleaned_text) > max_length:
@@ -349,20 +346,20 @@ def process_llm_response(text: str) -> list[str]:
return ["懒得说"]
typo_generator = ChineseTypoGenerator(
error_rate=global_config.chinese_typo_error_rate,
min_freq=global_config.chinese_typo_min_freq,
tone_error_rate=global_config.chinese_typo_tone_error_rate,
word_replace_rate=global_config.chinese_typo_word_replace_rate,
error_rate=global_config.chinese_typo.error_rate,
min_freq=global_config.chinese_typo.min_freq,
tone_error_rate=global_config.chinese_typo.tone_error_rate,
word_replace_rate=global_config.chinese_typo.word_replace_rate,
)
if global_config.enable_response_splitter:
if global_config.response_splitter.enable:
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
else:
split_sentences = [cleaned_text]
sentences = []
for sentence in split_sentences:
if global_config.chinese_typo_enable:
if global_config.chinese_typo.enable:
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
sentences.append(typoed_text)
if typo_corrections:
@@ -372,14 +369,14 @@ def process_llm_response(text: str) -> list[str]:
if len(sentences) > max_sentence_num:
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
return [f"{global_config.BOT_NICKNAME}不知道哦"]
return [f"{global_config.bot.nickname}不知道哦"]
# if extracted_contents:
# for content in extracted_contents:
# sentences.append(content)
# 在所有句子处理完毕后,对包含占位符的列表进行恢复
if global_config.enable_kaomoji_protection:
if global_config.response_splitter.enable_kaomoji_protection:
sentences = recover_kaomoji(sentences, kaomoji_mapping)
return sentences
@@ -580,26 +577,23 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
logger.error("stream_id 不能为空")
return 0, 0
# 直接查询时间范围内的消息
# time > start_time AND time <= end_time
query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
# 使用message_repository中的count_messages和find_messages函数
# 构建查询条件
filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
try:
# 执行查询
messages_cursor = db.messages.find(query)
# 先获取消息数量
count = count_messages(filter_query)
# 遍历结果计算数量和长度
for msg in messages_cursor:
count += 1
total_length += len(msg.get("processed_plain_text", ""))
# 获取消息内容计算总长度
messages = find_messages(message_filter=filter_query)
total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages)
# logger.debug(f"查询范围 ({start_time}, {end_time}] 内找到 {count} 条消息,总长度 {total_length}")
return count, total_length
except PyMongoError as e:
logger.error(f"查询 stream_id={stream_id} 在 ({start_time}, {end_time}] 范围内的消息时出错: {e}")
return 0, 0
except Exception as e: # 保留一个通用异常捕获以防万一
except Exception as e:
logger.error(f"计算消息数量时发生意外错误: {e}")
return 0, 0

View File

@@ -8,7 +8,8 @@ import io
import numpy as np
from ...common.database import db
from ...common.database.database import db
from ...common.database.database_model import Images, ImageDescriptions
from ...config.config import global_config
from ..models.utils_model import LLMRequest
@@ -32,40 +33,23 @@ class ImageManager:
def __init__(self):
if not self._initialized:
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir()
self._initialized = True
self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions], safe=True)
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
self._initialized = True
self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
def _ensure_image_dir(self):
"""确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True)
@staticmethod
def _ensure_image_collection():
"""确保images集合存在并创建索引"""
if "images" not in db.list_collection_names():
db.create_collection("images")
# 删除旧索引
db.images.drop_indexes()
# 创建新的复合索引
db.images.create_index([("hash", 1), ("type", 1)], unique=True)
db.images.create_index([("url", 1)])
db.images.create_index([("path", 1)])
@staticmethod
def _ensure_description_collection():
"""确保image_descriptions集合存在并创建索引"""
if "image_descriptions" not in db.list_collection_names():
db.create_collection("image_descriptions")
# 删除旧索引
db.image_descriptions.drop_indexes()
# 创建新的复合索引
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
@staticmethod
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
@@ -77,8 +61,14 @@ class ImageManager:
Returns:
Optional[str]: 描述文本如果不存在则返回None
"""
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
return result["description"] if result else None
try:
record = ImageDescriptions.get_or_none(
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
)
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
return None
@staticmethod
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
@@ -90,20 +80,17 @@ class ImageManager:
description_type: 描述类型 ('emoji''image')
"""
try:
db.image_descriptions.update_one(
{"hash": image_hash, "type": description_type},
{
"$set": {
"description": description,
"timestamp": int(time.time()),
"hash": image_hash, # 确保hash字段存在
"type": description_type, # 确保type字段存在
}
},
upsert=True,
current_timestamp = time.time()
defaults = {"description": description, "timestamp": current_timestamp}
desc_obj, created = ImageDescriptions.get_or_create(
hash=image_hash, type=description_type, defaults=defaults
)
if not created: # 如果记录已存在,则更新
desc_obj.description = description
desc_obj.timestamp = current_timestamp
desc_obj.save()
except Exception as e:
logger.error(f"保存描述到数据库失败: {str(e)}")
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能"""
@@ -116,51 +103,64 @@ class ImageManager:
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
# logger.debug(f"缓存表情包描述: {cached_description}")
return f"[表情包,含义看起来是:{cached_description}]"
# 调用AI获取描述
if image_format == "gif" or image_format == "GIF":
image_base64 = self.transform_gif(image_base64)
image_base64_processed = self.transform_gif(image_base64)
if image_base64_processed is None:
logger.warning("GIF转换失败无法获取描述")
return "[表情包(GIF处理失败)]"
prompt = "这是一个动态图表情包每一张图代表了动态图的某一帧黑色背景代表透明使用1-2个词描述一下表情包表达的情感和内容简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg")
description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
else:
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成表情包描述")
return "[表情包(描述生成失败)]"
# 再次检查缓存,防止并发写入时重复生成
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
return f"[表情包,含义看起来是:{cached_description}]"
# 根据配置决定是否保存图片
if global_config.save_emoji:
if global_config.emoji.save_emoji:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")):
os.makedirs(os.path.join(self.IMAGE_DIR, "emoji"))
file_path = os.path.join(self.IMAGE_DIR, "emoji", filename)
current_timestamp = time.time()
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
os.makedirs(emoji_dir, exist_ok=True)
file_path = os.path.join(emoji_dir, filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
"hash": image_hash,
"path": file_path,
"type": "emoji",
"description": description,
"timestamp": timestamp,
}
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
logger.trace(f"保存表情包: {file_path}")
# 保存到数据库 (Images表)
try:
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
img_obj.path = file_path
img_obj.description = description
img_obj.timestamp = current_timestamp
img_obj.save()
except Images.DoesNotExist:
Images.create(
hash=image_hash,
path=file_path,
type="emoji",
description=description,
timestamp=current_timestamp,
)
logger.trace(f"保存表情包元数据: {file_path}")
except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}")
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
# 保存描述到数据库
# 保存描述到数据库 (ImageDescriptions表)
self._save_description_to_db(image_hash, description, "emoji")
return f"[表情包:{description}]"
@@ -188,6 +188,11 @@ class ImageManager:
)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片(描述生成失败)]"
# 再次检查缓存
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
@@ -195,38 +200,40 @@ class ImageManager:
logger.debug(f"描述是{description}")
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片]"
# 根据配置决定是否保存图片
if global_config.save_pic:
if global_config.emoji.save_pic:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")):
os.makedirs(os.path.join(self.IMAGE_DIR, "image"))
file_path = os.path.join(self.IMAGE_DIR, "image", filename)
current_timestamp = time.time()
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
image_dir = os.path.join(self.IMAGE_DIR, "image")
os.makedirs(image_dir, exist_ok=True)
file_path = os.path.join(image_dir, filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
"hash": image_hash,
"path": file_path,
"type": "image",
"description": description,
"timestamp": timestamp,
}
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
logger.trace(f"保存图片: {file_path}")
# 保存到数据库 (Images表)
try:
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "image"))
img_obj.path = file_path
img_obj.description = description
img_obj.timestamp = current_timestamp
img_obj.save()
except Images.DoesNotExist:
Images.create(
hash=image_hash,
path=file_path,
type="image",
description=description,
timestamp=current_timestamp,
)
logger.trace(f"保存图片元数据: {file_path}")
except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}")
logger.error(f"保存图片文件或元数据失败: {str(e)}")
# 保存描述到数据库
# 保存描述到数据库 (ImageDescriptions表)
self._save_description_to_db(image_hash, description, "image")
return f"[图片:{description}]"

View File

@@ -16,7 +16,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
# 现在可以导入src模块
from src.common.database import db # noqa E402
from common.database.database import db # noqa E402
# 加载根目录下的env.edv文件

View File

@@ -1,5 +1,6 @@
import os
from pymongo import MongoClient
from peewee import SqliteDatabase
from pymongo.database import Database
from rich.traceback import install
@@ -57,4 +58,15 @@ class DBWrapper:
# 全局数据库访问点
db: Database = DBWrapper()
memory_db: Database = DBWrapper()
# 定义数据库文件路径
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
_DB_DIR = os.path.join(ROOT_PATH, "data")
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
# 确保数据库目录存在
os.makedirs(_DB_DIR, exist_ok=True)
# 全局 Peewee SQLite 数据库访问点
db = SqliteDatabase(_DB_FILE)

View File

@@ -0,0 +1,393 @@
from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField
from .database import db
import datetime
from ..logger_manager import get_logger
logger = get_logger("database_model")
# 请在此处定义您的数据库实例。
# 您需要取消注释并配置适合您的数据库的部分。
# 例如,对于 SQLite:
# db = SqliteDatabase('MaiBot.db')
#
# 对于 PostgreSQL:
# db = PostgresqlDatabase('your_db_name', user='your_user', password='your_password',
# host='localhost', port=5432)
#
# 对于 MySQL:
# db = MySQLDatabase('your_db_name', user='your_user', password='your_password',
# host='localhost', port=3306)
# 定义一个基础模型是一个好习惯,所有其他模型都应继承自它。
# 这允许您在一个地方为所有模型指定数据库。
class BaseModel(Model):
class Meta:
# 将下面的 'db' 替换为您实际的数据库实例变量名。
database = db # 例如: database = my_actual_db_instance
pass # 在用户定义数据库实例之前,此处为占位符
class ChatStreams(BaseModel):
"""
用于存储流式记录数据的模型,类似于提供的 MongoDB 结构。
"""
# stream_id: "a544edeb1a9b73e3e1d77dff36e41264"
# 假设 stream_id 是唯一的,并为其创建索引以提高查询性能。
stream_id = TextField(unique=True, index=True)
# create_time: 1746096761.4490178 (时间戳精确到小数点后7位)
# DoubleField 用于存储浮点数,适合此类时间戳。
create_time = DoubleField()
# group_info 字段:
# platform: "qq"
# group_id: "941657197"
# group_name: "测试"
group_platform = TextField()
group_id = TextField()
group_name = TextField()
# last_active_time: 1746623771.4825106 (时间戳精确到小数点后7位)
last_active_time = DoubleField()
# platform: "qq" (顶层平台字段)
platform = TextField()
# user_info 字段:
# platform: "qq"
# user_id: "1787882683"
# user_nickname: "墨梓柒(IceSakurary)"
# user_cardname: ""
user_platform = TextField()
user_id = TextField()
user_nickname = TextField()
# user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。
user_cardname = TextField(null=True)
class Meta:
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# 如果不使用带有数据库实例的 BaseModel或者想覆盖它
# 请取消注释并在下面设置数据库实例:
# database = db
table_name = "chat_streams" # 可选:明确指定数据库中的表名
class LLMUsage(BaseModel):
"""
用于存储 API 使用日志数据的模型。
"""
model_name = TextField(index=True) # 添加索引
user_id = TextField(index=True) # 添加索引
request_type = TextField(index=True) # 添加索引
endpoint = TextField()
prompt_tokens = IntegerField()
completion_tokens = IntegerField()
total_tokens = IntegerField()
cost = DoubleField()
status = TextField()
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
class Meta:
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# database = db
table_name = "llm_usage"
class Emoji(BaseModel):
"""表情包"""
full_path = TextField(unique=True, index=True) # 文件的完整路径 (包括文件名)
format = TextField() # 图片格式
emoji_hash = TextField(index=True) # 表情包的哈希值
description = TextField() # 表情包的描述
query_count = IntegerField(default=0) # 查询次数(用于统计表情包被查询描述的次数)
is_registered = BooleanField(default=False) # 是否已注册
is_banned = BooleanField(default=False) # 是否被禁止注册
# emotion: list[str] # 表情包的情感标签 - 存储为文本,应用层处理序列化/反序列化
emotion = TextField(null=True)
record_time = FloatField() # 记录时间(被创建的时间)
register_time = FloatField(null=True) # 注册时间(被注册为可用表情包的时间)
usage_count = IntegerField(default=0) # 使用次数(被使用的次数)
last_used_time = FloatField(null=True) # 上次使用时间
class Meta:
# database = db # 继承自 BaseModel
table_name = "emoji"
class Messages(BaseModel):
"""
用于存储消息数据的模型。
"""
message_id = TextField(index=True) # 消息 ID (更改自 IntegerField)
time = DoubleField() # 消息时间戳
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
# 从 chat_info 扁平化而来的字段
chat_info_stream_id = TextField()
chat_info_platform = TextField()
chat_info_user_platform = TextField()
chat_info_user_id = TextField()
chat_info_user_nickname = TextField()
chat_info_user_cardname = TextField(null=True)
chat_info_group_platform = TextField(null=True) # 群聊信息可能不存在
chat_info_group_id = TextField(null=True)
chat_info_group_name = TextField(null=True)
chat_info_create_time = DoubleField()
chat_info_last_active_time = DoubleField()
# 从顶层 user_info 扁平化而来的字段 (消息发送者信息)
user_platform = TextField()
user_id = TextField()
user_nickname = TextField()
user_cardname = TextField(null=True)
processed_plain_text = TextField(null=True) # 处理后的纯文本消息
detailed_plain_text = TextField(null=True) # 详细的纯文本消息
memorized_times = IntegerField(default=0) # 被记忆的次数
class Meta:
# database = db # 继承自 BaseModel
table_name = "messages"
class Images(BaseModel):
"""
用于存储图像信息的模型。
"""
emoji_hash = TextField(index=True) # 图像的哈希值
description = TextField(null=True) # 图像的描述
path = TextField(unique=True) # 图像文件的路径
timestamp = FloatField() # 时间戳
type = TextField() # 图像类型,例如 "emoji"
class Meta:
# database = db # 继承自 BaseModel
table_name = "images"
class ImageDescriptions(BaseModel):
"""
用于存储图像描述信息的模型。
"""
type = TextField() # 类型,例如 "emoji"
image_description_hash = TextField(index=True) # 图像的哈希值
description = TextField() # 图像的描述
timestamp = FloatField() # 时间戳
class Meta:
# database = db # 继承自 BaseModel
table_name = "image_descriptions"
class OnlineTime(BaseModel):
"""
用于存储在线时长记录的模型。
"""
# timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串)
timestamp = TextField(default=datetime.datetime.now) # 时间戳
duration = IntegerField() # 时长,单位分钟
start_timestamp = DateTimeField(default=datetime.datetime.now)
end_timestamp = DateTimeField(index=True)
class Meta:
# database = db # 继承自 BaseModel
table_name = "online_time"
class PersonInfo(BaseModel):
"""
用于存储个人信息数据的模型。
"""
person_id = TextField(unique=True, index=True) # 个人唯一ID
person_name = TextField(null=True) # 个人名称 (允许为空)
name_reason = TextField(null=True) # 名称设定的原因
platform = TextField() # 平台
user_id = TextField(index=True) # 用户ID
nickname = TextField() # 用户昵称
relationship_value = IntegerField(default=0) # 关系值
know_time = FloatField() # 认识时间 (时间戳)
msg_interval = IntegerField() # 消息间隔
# msg_interval_list: 存储为 JSON 字符串的列表
msg_interval_list = TextField(null=True)
class Meta:
# database = db # 继承自 BaseModel
table_name = "person_info"
class Knowledges(BaseModel):
"""
用于存储知识库条目的模型。
"""
content = TextField() # 知识内容的文本
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
# 可以添加其他元数据字段,如 source, create_time 等
class Meta:
# database = db # 继承自 BaseModel
table_name = "knowledges"
class ThinkingLog(BaseModel):
chat_id = TextField(index=True)
trigger_text = TextField(null=True)
response_text = TextField(null=True)
# Store complex dicts/lists as JSON strings
trigger_info_json = TextField(null=True)
response_info_json = TextField(null=True)
timing_results_json = TextField(null=True)
chat_history_json = TextField(null=True)
chat_history_in_thinking_json = TextField(null=True)
chat_history_after_response_json = TextField(null=True)
heartflow_data_json = TextField(null=True)
reasoning_data_json = TextField(null=True)
# Add a timestamp for the log entry itself
# Ensure you have: from peewee import DateTimeField
# And: import datetime
created_at = DateTimeField(default=datetime.datetime.now)
class Meta:
table_name = "thinking_logs"
class RecalledMessages(BaseModel):
"""
用于存储撤回消息记录的模型。
"""
message_id = TextField(index=True) # 被撤回的消息 ID
time = DoubleField() # 撤回操作发生的时间戳
stream_id = TextField() # 对应的 ChatStreams stream_id
class Meta:
table_name = "recalled_messages"
class GraphNodes(BaseModel):
"""
用于存储记忆图节点的模型
"""
concept = TextField(unique=True, index=True) # 节点概念
memory_items = TextField() # JSON格式存储的记忆列表
hash = TextField() # 节点哈希值
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
class Meta:
table_name = "graph_nodes"
class GraphEdges(BaseModel):
"""
用于存储记忆图边的模型
"""
source = TextField(index=True) # 源节点
target = TextField(index=True) # 目标节点
strength = IntegerField() # 连接强度
hash = TextField() # 边哈希值
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
class Meta:
table_name = "graph_edges"
def create_tables():
"""
创建所有在模型中定义的数据库表。
"""
with db:
db.create_tables(
[
ChatStreams,
LLMUsage,
Emoji,
Messages,
Images,
ImageDescriptions,
OnlineTime,
PersonInfo,
Knowledges,
ThinkingLog,
RecalledMessages, # 添加新模型
GraphNodes, # 添加图节点表
GraphEdges, # 添加图边表
]
)
def initialize_database():
"""
检查所有定义的表是否存在,如果不存在则创建它们。
检查所有表的所有字段是否存在,如果缺失则警告用户并退出程序。
"""
import sys
models = [
ChatStreams,
LLMUsage,
Emoji,
Messages,
Images,
ImageDescriptions,
OnlineTime,
PersonInfo,
Knowledges,
ThinkingLog,
RecalledMessages,
GraphNodes, # 添加图节点表
GraphEdges, # 添加图边表
]
needs_creation = False
try:
with db: # 管理 table_exists 检查的连接
for model in models:
table_name = model._meta.table_name
if not db.table_exists(model):
logger.warning(f"'{table_name}' 未找到。")
needs_creation = True
break # 一个表丢失,无需进一步检查。
if not needs_creation:
# 检查字段
for model in models:
table_name = model._meta.table_name
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
existing_columns = {row[1] for row in cursor.fetchall()}
model_fields = model._meta.fields
for field_name in model_fields:
if field_name not in existing_columns:
logger.error(f"'{table_name}' 缺失字段 '{field_name}',请手动迁移数据库结构后重启程序。")
sys.exit(1)
except Exception as e:
logger.exception(f"检查表或字段是否存在时出错: {e}")
# 如果检查失败(例如数据库不可用),则退出
return
if needs_creation:
logger.info("正在初始化数据库:一个或多个表丢失。正在尝试创建所有定义的表...")
try:
create_tables() # 此函数有其自己的 'with db:' 上下文管理。
logger.info("数据库表创建过程完成。")
except Exception as e:
logger.exception(f"创建表期间出错: {e}")
else:
logger.info("所有数据库表及字段均已存在。")
# 模块加载时调用初始化函数
initialize_database()

View File

@@ -276,6 +276,40 @@ CHAT_STYLE_CONFIG = {
},
}
# Topic日志样式配置
NORMAL_CHAT_STYLE_CONFIG = {
"advanced": {
"console_format": (
"<white>{time:YYYY-MM-DD HH:mm:ss}</white> | "
"<level>{level: <8}</level> | "
"<green>一般水群</green> | "
"<level>{message}</level>"
),
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 一般水群 | {message}",
},
"simple": {
"console_format": "<level>{time:HH:mm:ss}</level> | <green>一般水群</green> | <green>{message}</green>", # noqa: E501
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 一般水群 | {message}",
},
}
# Topic日志样式配置
FOCUS_CHAT_STYLE_CONFIG = {
"advanced": {
"console_format": (
"<white>{time:YYYY-MM-DD HH:mm:ss}</white> | "
"<level>{level: <8}</level> | "
"<green>专注水群</green> | "
"<level>{message}</level>"
),
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 专注水群 | {message}",
},
"simple": {
"console_format": "<level>{time:HH:mm:ss}</level> | <green>专注水群</green> | <green>{message}</green>", # noqa: E501
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 专注水群 | {message}",
},
}
REMOTE_STYLE_CONFIG = {
"advanced": {
"console_format": (
@@ -629,22 +663,22 @@ PROCESSOR_STYLE_CONFIG = {
PLANNER_STYLE_CONFIG = {
"advanced": {
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #36DEFF>规划器</fg #36DEFF> | <fg #36DEFF>{message}</fg #36DEFF>",
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #4DCDFF>规划器</fg #4DCDFF> | <fg #4DCDFF>{message}</fg #4DCDFF>",
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 规划器 | {message}",
},
"simple": {
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #36DEFF>规划器</fg #36DEFF> | <fg #36DEFF>{message}</fg #36DEFF>",
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #4DCDFF>规划器</fg #4DCDFF> | <fg #4DCDFF>{message}</fg #4DCDFF>",
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 规划器 | {message}",
},
}
ACTION_TAKEN_STYLE_CONFIG = {
"advanced": {
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #22DAFF>动作</fg #22DAFF> | <fg #22DAFF>{message}</fg #22DAFF>",
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #FFA01F>动作</fg #FFA01F> | <fg #FFA01F>{message}</fg #FFA01F>",
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 动作 | {message}",
},
"simple": {
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #22DAFF>动作</fg #22DAFF> | <fg #22DAFF>{message}</fg #22DAFF>",
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #FFA01F>动作</fg #FFA01F> | <fg #FFA01F>{message}</fg #FFA01F>",
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 动作 | {message}",
},
}
@@ -935,6 +969,8 @@ MAIM_MESSAGE_STYLE_CONFIG = (
INTEREST_CHAT_STYLE_CONFIG = (
INTEREST_CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INTEREST_CHAT_STYLE_CONFIG["advanced"]
)
NORMAL_CHAT_STYLE_CONFIG = NORMAL_CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else NORMAL_CHAT_STYLE_CONFIG["advanced"]
FOCUS_CHAT_STYLE_CONFIG = FOCUS_CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else FOCUS_CHAT_STYLE_CONFIG["advanced"]
def is_registered_module(record: dict) -> bool:

View File

@@ -21,6 +21,8 @@ from src.common.logger import (
WILLING_STYLE_CONFIG,
PFC_ACTION_PLANNER_STYLE_CONFIG,
MAI_STATE_CONFIG,
NORMAL_CHAT_STYLE_CONFIG,
FOCUS_CHAT_STYLE_CONFIG,
LPMM_STYLE_CONFIG,
HFC_STYLE_CONFIG,
OBSERVATION_STYLE_CONFIG,
@@ -95,7 +97,8 @@ MODULE_LOGGER_CONFIGS = {
"init": INIT_STYLE_CONFIG, # 初始化
"interest_chat": INTEREST_CHAT_STYLE_CONFIG, # 兴趣
"api": API_SERVER_STYLE_CONFIG, # API服务器
"maim_message": MAIM_MESSAGE_STYLE_CONFIG, # 消息服务
"normal_chat": NORMAL_CHAT_STYLE_CONFIG, # 一般水群
"focus_chat": FOCUS_CHAT_STYLE_CONFIG, # 专注水群
# ...如有更多模块,继续添加...
}

View File

@@ -1,11 +1,19 @@
from src.common.database import db
from src.common.database.database_model import Messages # 更改导入
from src.common.logger import get_module_logger
import traceback
from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入
logger = get_module_logger(__name__)
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
"""
将 Peewee 模型实例转换为字典。
"""
return model_instance.__data__
def find_messages(
message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None,
@@ -16,39 +24,84 @@ def find_messages(
根据提供的过滤器、排序和限制条件查找消息。
Args:
message_filter: MongoDB 查询过滤器。
sort: MongoDB 排序条件列表,例如 [('time', 1)]。仅在 limit 为 0 时生效。
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
limit: 返回的最大文档数0表示不限制。
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'
Returns:
消息文档列表,如果出错则返回空列表。
消息字典列表,如果出错则返回空列表。
"""
try:
query = db.messages.find(message_filter)
query = Messages.select()
# 应用过滤器
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
# 直接相等比较
conditions.append(field == value)
else:
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
if limit > 0:
if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序
query = query.sort([("time", 1)]).limit(limit)
results = list(query)
query = query.order_by(Messages.time.asc()).limit(limit)
peewee_results = list(query)
else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录
query = query.sort([("time", -1)]).limit(limit)
latest_results = list(query)
query = query.order_by(Messages.time.desc()).limit(limit)
latest_results_peewee = list(query)
# 将结果按时间正序排列
# 假设消息文档中总是有 'time' 字段且可排序
results = sorted(latest_results, key=lambda msg: msg.get("time"))
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
query = query.sort(sort)
results = list(query)
peewee_sort_terms = []
for field_name, direction in sort:
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
peewee_sort_terms.append(field.asc())
elif direction == -1: # DESC
peewee_sort_terms.append(field.desc())
else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if peewee_sort_terms:
query = query.order_by(*peewee_sort_terms)
peewee_results = list(query)
return results
return [_model_to_dict(msg) for msg in peewee_results]
except Exception as e:
log_message = (
f"查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc()
)
logger.error(log_message)
@@ -60,18 +113,57 @@ def count_messages(message_filter: dict[str, Any]) -> int:
根据提供的过滤器计算消息数量。
Args:
message_filter: MongoDB 查询过滤器。
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
Returns:
符合条件的消息数量,如果出错则返回 0。
"""
try:
count = db.messages.count_documents(message_filter)
query = Messages.select()
# 应用过滤器
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(
f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
)
else:
# 直接相等比较
conditions.append(field == value)
else:
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
count = query.count()
return count
except Exception as e:
log_message = f"计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc()
log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
logger.error(log_message)
return 0
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 Peewee插入操作通常是 Messages.create(...) 或 instance.save()。
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。

View File

@@ -35,7 +35,7 @@ class TelemetryHeartBeatTask(AsyncTask):
info_dict = {
"os_type": "Unknown",
"py_version": platform.python_version(),
"mmc_version": global_config.MAI_VERSION,
"mmc_version": global_config.MMC_VERSION,
}
match platform.system():
@@ -133,9 +133,8 @@ class TelemetryHeartBeatTask(AsyncTask):
async def run(self):
# 发送心跳
if global_config.remote_enable:
if self.client_uuid is None:
if not await self._req_uuid():
if global_config.telemetry.enable:
if self.client_uuid is None and not await self._req_uuid():
logger.error("获取UUID失败跳过此次心跳")
return

View File

@@ -1,64 +1,68 @@
import os
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from dataclasses import field, dataclass
import tomli
import tomlkit
import shutil
from datetime import datetime
from pathlib import Path
from packaging import version
from packaging.version import Version, InvalidVersion
from packaging.specifiers import SpecifierSet, InvalidSpecifier
from tomlkit import TOMLDocument
from tomlkit.items import Table
from src.common.logger_manager import get_logger
from rich.traceback import install
from src.config.config_base import ConfigBase
from src.config.official_configs import (
BotConfig,
ChatTargetConfig,
PersonalityConfig,
IdentityConfig,
PlatformsConfig,
ChatConfig,
NormalChatConfig,
FocusChatConfig,
EmojiConfig,
MemoryConfig,
MoodConfig,
KeywordReactionConfig,
ChineseTypoConfig,
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
ModelConfig,
)
install(extra_lines=3)
# 配置主程序日志格式
logger = get_logger("config")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
is_test = True
mai_version_main = "0.6.4"
mai_version_fix = "snapshot-1"
CONFIG_DIR = "config"
TEMPLATE_DIR = "template"
if mai_version_fix:
if is_test:
mai_version = f"test-{mai_version_main}-{mai_version_fix}"
else:
mai_version = f"{mai_version_main}-{mai_version_fix}"
else:
if is_test:
mai_version = f"test-{mai_version_main}"
else:
mai_version = mai_version_main
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.7.0-snapshot.1"
def update_config():
# 获取根目录路径
root_dir = Path(__file__).parent.parent.parent
template_dir = root_dir / "template"
config_dir = root_dir / "config"
old_config_dir = config_dir / "old"
old_config_dir = f"{CONFIG_DIR}/old"
# 定义文件路径
template_path = template_dir / "bot_config_template.toml"
old_config_path = config_dir / "bot_config.toml"
new_config_path = config_dir / "bot_config.toml"
template_path = f"{TEMPLATE_DIR}/bot_config_template.toml"
old_config_path = f"{CONFIG_DIR}/bot_config.toml"
new_config_path = f"{CONFIG_DIR}/bot_config.toml"
# 检查配置文件是否存在
if not old_config_path.exists():
if not os.path.exists(old_config_path):
logger.info("配置文件不存在,从模板创建新配置")
# 创建文件夹
old_config_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(template_path, old_config_path)
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
# 如果是新创建的配置文件,直接返回
return quit()
quit()
# 读取旧配置文件和模板文件
with open(old_config_path, "r", encoding="utf-8") as f:
@@ -75,13 +79,15 @@ def update_config():
return
else:
logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
else:
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
# 创建old目录如果不存在
old_config_dir.mkdir(exist_ok=True)
os.makedirs(old_config_dir, exist_ok=True)
# 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
old_backup_path = f"{old_config_dir}/bot_config_{timestamp}.toml"
# 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path)
@@ -91,24 +97,23 @@ def update_config():
shutil.copy2(template_path, new_config_path)
logger.info(f"已创建新配置文件: {new_config_path}")
# 递归更新配置
def update_dict(target, source):
def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中如果target中存在相同的键
"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)):
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
update_dict(target[key], value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
if not value:
target[key] = tomlkit.array()
else:
target[key] = tomlkit.array(value)
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
@@ -123,619 +128,57 @@ def update_config():
# 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
logger.info("配置文件更新完成")
logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
quit()
@dataclass
class BotConfig:
"""机器人配置类"""
class Config(ConfigBase):
"""配置类"""
INNER_VERSION: Version = None
MAI_VERSION: str = mai_version # 硬编码的版本信息
MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
# bot
BOT_QQ: Optional[str] = "114514"
BOT_NICKNAME: Optional[str] = None
BOT_ALIAS_NAMES: List[str] = field(default_factory=list) # 别名,可以通过这个叫它
bot: BotConfig
chat_target: ChatTargetConfig
personality: PersonalityConfig
identity: IdentityConfig
platforms: PlatformsConfig
chat: ChatConfig
normal_chat: NormalChatConfig
focus_chat: FocusChatConfig
emoji: EmojiConfig
memory: MemoryConfig
mood: MoodConfig
keyword_reaction: KeywordReactionConfig
chinese_typo: ChineseTypoConfig
response_splitter: ResponseSplitterConfig
telemetry: TelemetryConfig
experimental: ExperimentalConfig
model: ModelConfig
# group
talk_allowed_groups = set()
talk_frequency_down_groups = set()
ban_user_id = set()
# personality
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内谁再写3000字小作文敲谁脑袋
personality_sides: List[str] = field(
default_factory=lambda: [
"用一句话或几句话描述人格的一些侧面",
"用一句话或几句话描述人格的一些侧面",
"用一句话或几句话描述人格的一些侧面",
]
)
expression_style = "描述麦麦说话的表达风格,表达习惯"
# identity
identity_detail: List[str] = field(
default_factory=lambda: [
"身份特点",
"身份特点",
]
)
height: int = 170 # 身高 单位厘米
weight: int = 50 # 体重 单位千克
age: int = 20 # 年龄 单位岁
gender: str = "" # 性别
appearance: str = "用几句话描述外貌特征" # 外貌特征
# chat
allow_focus_mode: bool = True # 是否允许专注聊天状态
base_normal_chat_num: int = 3 # 最多允许多少个群进行普通聊天
base_focused_chat_num: int = 2 # 最多允许多少个群进行专注聊天
observation_context_size: int = 12 # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩
message_buffer: bool = True # 消息缓冲器
ban_words = set()
ban_msgs_regex = set()
# focus_chat
reply_trigger_threshold: float = 3.0 # 心流聊天触发阈值,越低越容易触发
default_decay_rate_per_second: float = 0.98 # 默认衰减率,越大衰减越慢
consecutive_no_reply_threshold = 3
compressed_length: int = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度超过心流观察到的上下文长度会压缩最短压缩长度为5
compress_length_limit: int = 5 # 最多压缩份数,超过该数值的压缩上下文会被删除
# normal_chat
model_reasoning_probability: float = 0.7 # 麦麦回答时选择推理模型(主要)模型概率
model_normal_probability: float = 0.3 # 麦麦回答时选择一般模型(次要)模型概率
emoji_chance: float = 0.2 # 发送表情包的基础概率
thinking_timeout: int = 120 # 思考时间
willing_mode: str = "classical" # 意愿模式
response_willing_amplifier: float = 1.0 # 回复意愿放大系数
response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数
down_frequency_rate: float = 3 # 降低回复频率的群组回复意愿降低系数
emoji_response_penalty: float = 0.0 # 表情包回复惩罚
mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复
at_bot_inevitable_reply: bool = False # @bot 必然回复
# emoji
max_emoji_num: int = 200 # 表情包最大数量
max_reach_deletion: bool = True # 开启则在达到最大数量时删除表情包,关闭则不会继续收集表情包
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
save_pic: bool = False # 是否保存图片
save_emoji: bool = False # 是否保存表情包
steal_emoji: bool = True # 是否偷取表情包,让麦麦可以发送她保存的这些表情包
EMOJI_CHECK: bool = False # 是否开启过滤
EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求
# memory
build_memory_interval: int = 600 # 记忆构建间隔(秒)
memory_build_distribution: list = field(
default_factory=lambda: [4, 2, 0.6, 24, 8, 0.4]
) # 记忆构建分布参数分布1均值标准差权重分布2均值标准差权重
build_memory_sample_num: int = 10 # 记忆构建采样数量
build_memory_sample_length: int = 20 # 记忆构建采样长度
memory_compress_rate: float = 0.1 # 记忆压缩率
forget_memory_interval: int = 600 # 记忆遗忘间隔(秒)
memory_forget_time: int = 24 # 记忆遗忘时间(小时)
memory_forget_percentage: float = 0.01 # 记忆遗忘比例
consolidate_memory_interval: int = 1000 # 记忆整合间隔(秒)
consolidation_similarity_threshold: float = 0.7 # 相似度阈值
consolidate_memory_percentage: float = 0.01 # 检查节点比例
memory_ban_words: list = field(
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
) # 添加新的配置项默认值
# mood
mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate: float = 0.95 # 情绪衰减率
mood_intensity_factor: float = 0.7 # 情绪强度因子
# keywords
keywords_reaction_rules = [] # 关键词回复规则
# chinese_typo
chinese_typo_enable = True # 是否启用中文错别字生成器
chinese_typo_error_rate = 0.03 # 单字替换概率
chinese_typo_min_freq = 7 # 最小字频阈值
chinese_typo_tone_error_rate = 0.2 # 声调错误概率
chinese_typo_word_replace_rate = 0.02 # 整词替换概率
# response_splitter
enable_kaomoji_protection = False # 是否启用颜文字保护
enable_response_splitter = True # 是否启用回复分割器
response_max_length = 100 # 回复允许的最大长度
response_max_sentence_num = 3 # 回复允许的最大句子数
model_max_output_length: int = 800 # 最大回复长度
# remote
remote_enable: bool = True # 是否启用远程控制
# experimental
enable_friend_chat: bool = False # 是否启用好友聊天
# enable_think_flow: bool = False # 是否启用思考流程
talk_allowed_private = set()
enable_pfc_chatting: bool = False # 是否启用PFC聊天
# 模型配置
llm_reasoning: dict[str, str] = field(default_factory=lambda: {})
# llm_reasoning_minor: dict[str, str] = field(default_factory=lambda: {})
llm_normal: Dict[str, str] = field(default_factory=lambda: {})
llm_topic_judge: Dict[str, str] = field(default_factory=lambda: {})
llm_summary: Dict[str, str] = field(default_factory=lambda: {})
embedding: Dict[str, str] = field(default_factory=lambda: {})
vlm: Dict[str, str] = field(default_factory=lambda: {})
moderation: Dict[str, str] = field(default_factory=lambda: {})
llm_observation: Dict[str, str] = field(default_factory=lambda: {})
llm_sub_heartflow: Dict[str, str] = field(default_factory=lambda: {})
llm_heartflow: Dict[str, str] = field(default_factory=lambda: {})
llm_tool_use: Dict[str, str] = field(default_factory=lambda: {})
llm_plan: Dict[str, str] = field(default_factory=lambda: {})
api_urls: Dict[str, str] = field(default_factory=lambda: {})
@staticmethod
def get_config_dir() -> str:
"""获取配置文件目录"""
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, "..", ".."))
config_dir = os.path.join(root_dir, "config")
if not os.path.exists(config_dir):
os.makedirs(config_dir)
return config_dir
@classmethod
def convert_to_specifierset(cls, value: str) -> SpecifierSet:
"""将 字符串 版本表达式转换成 SpecifierSet
Args:
value[str]: 版本表达式(字符串)
Returns:
SpecifierSet
def load_config(config_path: str) -> Config:
"""
try:
converted = SpecifierSet(value)
except InvalidSpecifier:
logger.error(f"{value} 分类使用了错误的版本约束表达式\n", "请阅读 https://semver.org/lang/zh-CN/ 修改代码")
exit(1)
return converted
@classmethod
def get_config_version(cls, toml: dict) -> Version:
"""提取配置文件的 SpecifierSet 版本数据
Args:
toml[dict]: 输入的配置文件字典
Returns:
Version
加载配置文件
:param config_path: 配置文件路径
:return: Config对象
"""
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
if "inner" in toml:
# 创建Config对象
try:
config_version: str = toml["inner"]["version"]
except KeyError as e:
logger.error("配置文件中 inner 段 不存在, 这是错误的配置文件")
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") from e
else:
toml["inner"] = {"version": "0.0.0"}
config_version = toml["inner"]["version"]
try:
ver = version.parse(config_version)
except InvalidVersion as e:
logger.error(
"配置文件中 inner段 的 version 键是错误的版本描述\n"
"请阅读 https://semver.org/lang/zh-CN/ 修改配置,并参考本项目指定的模板进行修改\n"
"本项目在不同的版本下有不同的模板,请注意识别"
)
raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n") from e
return ver
@classmethod
def load_config(cls, config_path: str = None) -> "BotConfig":
"""从TOML配置文件加载配置"""
config = cls()
def personality(parent: dict):
personality_config = parent["personality"]
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
config.personality_core = personality_config.get("personality_core", config.personality_core)
config.personality_sides = personality_config.get("personality_sides", config.personality_sides)
if config.INNER_VERSION in SpecifierSet(">=1.7.0"):
config.expression_style = personality_config.get("expression_style", config.expression_style)
def identity(parent: dict):
identity_config = parent["identity"]
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
config.identity_detail = identity_config.get("identity_detail", config.identity_detail)
config.height = identity_config.get("height", config.height)
config.weight = identity_config.get("weight", config.weight)
config.age = identity_config.get("age", config.age)
config.gender = identity_config.get("gender", config.gender)
config.appearance = identity_config.get("appearance", config.appearance)
def emoji(parent: dict):
emoji_config = parent["emoji"]
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
config.EMOJI_CHECK_PROMPT = emoji_config.get("check_prompt", config.EMOJI_CHECK_PROMPT)
config.EMOJI_CHECK = emoji_config.get("enable_check", config.EMOJI_CHECK)
if config.INNER_VERSION in SpecifierSet(">=1.1.1"):
config.max_emoji_num = emoji_config.get("max_emoji_num", config.max_emoji_num)
config.max_reach_deletion = emoji_config.get("max_reach_deletion", config.max_reach_deletion)
if config.INNER_VERSION in SpecifierSet(">=1.4.2"):
config.save_pic = emoji_config.get("save_pic", config.save_pic)
config.save_emoji = emoji_config.get("save_emoji", config.save_emoji)
config.steal_emoji = emoji_config.get("steal_emoji", config.steal_emoji)
def bot(parent: dict):
# 机器人基础配置
bot_config = parent["bot"]
bot_qq = bot_config.get("qq")
config.BOT_QQ = str(bot_qq)
config.BOT_NICKNAME = bot_config.get("nickname", config.BOT_NICKNAME)
config.BOT_ALIAS_NAMES = bot_config.get("alias_names", config.BOT_ALIAS_NAMES)
def chat(parent: dict):
chat_config = parent["chat"]
config.allow_focus_mode = chat_config.get("allow_focus_mode", config.allow_focus_mode)
config.base_normal_chat_num = chat_config.get("base_normal_chat_num", config.base_normal_chat_num)
config.base_focused_chat_num = chat_config.get("base_focused_chat_num", config.base_focused_chat_num)
config.observation_context_size = chat_config.get(
"observation_context_size", config.observation_context_size
)
config.message_buffer = chat_config.get("message_buffer", config.message_buffer)
config.ban_words = chat_config.get("ban_words", config.ban_words)
for r in chat_config.get("ban_msgs_regex", config.ban_msgs_regex):
config.ban_msgs_regex.add(re.compile(r))
def normal_chat(parent: dict):
normal_chat_config = parent["normal_chat"]
config.model_reasoning_probability = normal_chat_config.get(
"model_reasoning_probability", config.model_reasoning_probability
)
config.model_normal_probability = normal_chat_config.get(
"model_normal_probability", config.model_normal_probability
)
config.emoji_chance = normal_chat_config.get("emoji_chance", config.emoji_chance)
config.thinking_timeout = normal_chat_config.get("thinking_timeout", config.thinking_timeout)
config.willing_mode = normal_chat_config.get("willing_mode", config.willing_mode)
config.response_willing_amplifier = normal_chat_config.get(
"response_willing_amplifier", config.response_willing_amplifier
)
config.response_interested_rate_amplifier = normal_chat_config.get(
"response_interested_rate_amplifier", config.response_interested_rate_amplifier
)
config.down_frequency_rate = normal_chat_config.get("down_frequency_rate", config.down_frequency_rate)
config.emoji_response_penalty = normal_chat_config.get(
"emoji_response_penalty", config.emoji_response_penalty
)
config.mentioned_bot_inevitable_reply = normal_chat_config.get(
"mentioned_bot_inevitable_reply", config.mentioned_bot_inevitable_reply
)
config.at_bot_inevitable_reply = normal_chat_config.get(
"at_bot_inevitable_reply", config.at_bot_inevitable_reply
)
def focus_chat(parent: dict):
focus_chat_config = parent["focus_chat"]
config.compressed_length = focus_chat_config.get("compressed_length", config.compressed_length)
config.compress_length_limit = focus_chat_config.get("compress_length_limit", config.compress_length_limit)
config.reply_trigger_threshold = focus_chat_config.get(
"reply_trigger_threshold", config.reply_trigger_threshold
)
config.default_decay_rate_per_second = focus_chat_config.get(
"default_decay_rate_per_second", config.default_decay_rate_per_second
)
config.consecutive_no_reply_threshold = focus_chat_config.get(
"consecutive_no_reply_threshold", config.consecutive_no_reply_threshold
)
def model(parent: dict):
# 加载模型配置
model_config: dict = parent["model"]
config_list = [
"llm_reasoning",
# "llm_reasoning_minor",
"llm_normal",
"llm_topic_judge",
"llm_summary",
"vlm",
"embedding",
"llm_tool_use",
"llm_observation",
"llm_sub_heartflow",
"llm_plan",
"llm_heartflow",
"llm_PFC_action_planner",
"llm_PFC_chat",
"llm_PFC_reply_checker",
]
for item in config_list:
if item in model_config:
cfg_item: dict = model_config[item]
# base_url 的例子: SILICONFLOW_BASE_URL
# key 的例子: SILICONFLOW_KEY
cfg_target = {
"name": "",
"base_url": "",
"key": "",
"stream": False,
"pri_in": 0,
"pri_out": 0,
"temp": 0.7,
}
if config.INNER_VERSION in SpecifierSet("<=0.0.0"):
cfg_target = cfg_item
elif config.INNER_VERSION in SpecifierSet(">=0.0.1"):
stable_item = ["name", "pri_in", "pri_out"]
stream_item = ["stream"]
if config.INNER_VERSION in SpecifierSet(">=1.0.1"):
stable_item.append("stream")
pricing_item = ["pri_in", "pri_out"]
# 从配置中原始拷贝稳定字段
for i in stable_item:
# 如果 字段 属于计费项 且获取不到,那默认值是 0
if i in pricing_item and i not in cfg_item:
cfg_target[i] = 0
if i in stream_item and i not in cfg_item:
cfg_target[i] = False
else:
# 没有特殊情况则原样复制
try:
cfg_target[i] = cfg_item[i]
except KeyError as e:
logger.error(f"{item} 中的必要字段不存在,请检查")
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查") from e
# 如果配置中有temp参数就使用配置中的值
if "temp" in cfg_item:
cfg_target["temp"] = cfg_item["temp"]
else:
# 如果没有temp参数就删除默认值
cfg_target.pop("temp", None)
provider = cfg_item.get("provider")
if provider is None:
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")
cfg_target["base_url"] = f"{provider}_BASE_URL"
cfg_target["key"] = f"{provider}_KEY"
# 如果 列表中的项目在 model_config 中,利用反射来设置对应项目
setattr(config, item, cfg_target)
else:
logger.error(f"模型 {item} 在config中不存在请检查或尝试更新配置文件")
raise KeyError(f"模型 {item} 在config中不存在请检查或尝试更新配置文件")
def memory(parent: dict):
memory_config = parent["memory"]
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval)
config.memory_ban_words = set(memory_config.get("memory_ban_words", []))
config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time)
config.memory_forget_percentage = memory_config.get(
"memory_forget_percentage", config.memory_forget_percentage
)
config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate)
if config.INNER_VERSION in SpecifierSet(">=0.0.11"):
config.memory_build_distribution = memory_config.get(
"memory_build_distribution", config.memory_build_distribution
)
config.build_memory_sample_num = memory_config.get(
"build_memory_sample_num", config.build_memory_sample_num
)
config.build_memory_sample_length = memory_config.get(
"build_memory_sample_length", config.build_memory_sample_length
)
if config.INNER_VERSION in SpecifierSet(">=1.5.1"):
config.consolidate_memory_interval = memory_config.get(
"consolidate_memory_interval", config.consolidate_memory_interval
)
config.consolidation_similarity_threshold = memory_config.get(
"consolidation_similarity_threshold", config.consolidation_similarity_threshold
)
config.consolidate_memory_percentage = memory_config.get(
"consolidate_memory_percentage", config.consolidate_memory_percentage
)
def remote(parent: dict):
remote_config = parent["remote"]
config.remote_enable = remote_config.get("enable", config.remote_enable)
def mood(parent: dict):
mood_config = parent["mood"]
config.mood_update_interval = mood_config.get("mood_update_interval", config.mood_update_interval)
config.mood_decay_rate = mood_config.get("mood_decay_rate", config.mood_decay_rate)
config.mood_intensity_factor = mood_config.get("mood_intensity_factor", config.mood_intensity_factor)
def keywords_reaction(parent: dict):
keywords_reaction_config = parent["keywords_reaction"]
if keywords_reaction_config.get("enable", False):
config.keywords_reaction_rules = keywords_reaction_config.get("rules", config.keywords_reaction_rules)
for rule in config.keywords_reaction_rules:
if rule.get("enable", False) and "regex" in rule:
rule["regex"] = [re.compile(r) for r in rule.get("regex", [])]
def chinese_typo(parent: dict):
chinese_typo_config = parent["chinese_typo"]
config.chinese_typo_enable = chinese_typo_config.get("enable", config.chinese_typo_enable)
config.chinese_typo_error_rate = chinese_typo_config.get("error_rate", config.chinese_typo_error_rate)
config.chinese_typo_min_freq = chinese_typo_config.get("min_freq", config.chinese_typo_min_freq)
config.chinese_typo_tone_error_rate = chinese_typo_config.get(
"tone_error_rate", config.chinese_typo_tone_error_rate
)
config.chinese_typo_word_replace_rate = chinese_typo_config.get(
"word_replace_rate", config.chinese_typo_word_replace_rate
)
def response_splitter(parent: dict):
response_splitter_config = parent["response_splitter"]
config.enable_response_splitter = response_splitter_config.get(
"enable_response_splitter", config.enable_response_splitter
)
config.response_max_length = response_splitter_config.get("response_max_length", config.response_max_length)
config.response_max_sentence_num = response_splitter_config.get(
"response_max_sentence_num", config.response_max_sentence_num
)
if config.INNER_VERSION in SpecifierSet(">=1.4.2"):
config.enable_kaomoji_protection = response_splitter_config.get(
"enable_kaomoji_protection", config.enable_kaomoji_protection
)
if config.INNER_VERSION in SpecifierSet(">=1.6.0"):
config.model_max_output_length = response_splitter_config.get(
"model_max_output_length", config.model_max_output_length
)
def groups(parent: dict):
groups_config = parent["groups"]
# config.talk_allowed_groups = set(groups_config.get("talk_allowed", []))
config.talk_allowed_groups = set(str(group) for group in groups_config.get("talk_allowed", []))
# config.talk_frequency_down_groups = set(groups_config.get("talk_frequency_down", []))
config.talk_frequency_down_groups = set(
str(group) for group in groups_config.get("talk_frequency_down", [])
)
# config.ban_user_id = set(groups_config.get("ban_user_id", []))
config.ban_user_id = set(str(user) for user in groups_config.get("ban_user_id", []))
def experimental(parent: dict):
experimental_config = parent["experimental"]
config.enable_friend_chat = experimental_config.get("enable_friend_chat", config.enable_friend_chat)
# config.enable_think_flow = experimental_config.get("enable_think_flow", config.enable_think_flow)
config.talk_allowed_private = set(str(user) for user in experimental_config.get("talk_allowed_private", []))
if config.INNER_VERSION in SpecifierSet(">=1.1.0"):
config.enable_pfc_chatting = experimental_config.get("pfc_chatting", config.enable_pfc_chatting)
# 版本表达式:>=1.0.0,<2.0.0
# 允许字段func: method, support: str, notice: str, necessary: bool
# 如果使用 notice 字段,在该组配置加载时,会展示该字段对用户的警示
# 例如:"notice": "personality 将在 1.3.2 后被移除",那么在有效版本中的用户就会虽然可以
# 正常执行程序,但是会看到这条自定义提示
# 版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
# 主版本号:当你做了不兼容的 API 修改,
# 次版本号:当你做了向下兼容的功能性新增,
# 修订号:当你做了向下兼容的问题修正。
# 先行版本号及版本编译信息可以加到"主版本号.次版本号.修订号"的后面,作为延伸。
# 如果你做了break的修改就应该改动主版本号
# 如果做了一个兼容修改,就不应该要求这个选项是必须的!
include_configs = {
"bot": {"func": bot, "support": ">=0.0.0"},
"groups": {"func": groups, "support": ">=0.0.0"},
"personality": {"func": personality, "support": ">=0.0.0"},
"identity": {"func": identity, "support": ">=1.2.4"},
"emoji": {"func": emoji, "support": ">=0.0.0"},
"model": {"func": model, "support": ">=0.0.0"},
"memory": {"func": memory, "support": ">=0.0.0", "necessary": False},
"mood": {"func": mood, "support": ">=0.0.0"},
"remote": {"func": remote, "support": ">=0.0.10", "necessary": False},
"keywords_reaction": {"func": keywords_reaction, "support": ">=0.0.2", "necessary": False},
"chinese_typo": {"func": chinese_typo, "support": ">=0.0.3", "necessary": False},
"response_splitter": {"func": response_splitter, "support": ">=0.0.11", "necessary": False},
"experimental": {"func": experimental, "support": ">=0.0.11", "necessary": False},
"chat": {"func": chat, "support": ">=1.6.0", "necessary": False},
"normal_chat": {"func": normal_chat, "support": ">=1.6.0", "necessary": False},
"focus_chat": {"func": focus_chat, "support": ">=1.6.0", "necessary": False},
}
# 原地修改,将 字符串版本表达式 转换成 版本对象
for key in include_configs:
item_support = include_configs[key]["support"]
include_configs[key]["support"] = cls.convert_to_specifierset(item_support)
if os.path.exists(config_path):
with open(config_path, "rb") as f:
try:
toml_dict = tomli.load(f)
except tomli.TOMLDecodeError as e:
logger.critical(f"配置文件bot_config.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}")
exit(1)
# 获取配置文件版本
config.INNER_VERSION = cls.get_config_version(toml_dict)
# 如果在配置中找到了需要的项,调用对应项的闭包函数处理
for key in include_configs:
if key in toml_dict:
group_specifierset: SpecifierSet = include_configs[key]["support"]
# 检查配置文件版本是否在支持范围内
if config.INNER_VERSION in group_specifierset:
# 如果版本在支持范围内,检查是否存在通知
if "notice" in include_configs[key]:
logger.warning(include_configs[key]["notice"])
include_configs[key]["func"](toml_dict)
else:
# 如果版本不在支持范围内,崩溃并提示用户
logger.error(
f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
f"当前程序仅支持以下版本范围: {group_specifierset}"
)
raise InvalidVersion(f"当前程序仅支持以下版本范围: {group_specifierset}")
# 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
elif "necessary" in include_configs[key] and include_configs[key].get("necessary") is False:
# 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
if key == "keywords_reaction":
pass
else:
# 如果用户根本没有需要的配置项,提示缺少配置
logger.error(f"配置文件中缺少必需的字段: '{key}'")
raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
# identity_detail字段非空检查
if not config.identity_detail:
logger.error("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串")
raise ValueError("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串")
logger.success(f"成功加载配置文件: {config_path}")
return config
return Config.from_dict(config_data)
except Exception as e:
logger.critical("配置文件解析失败")
raise e
# 获取配置文件路径
logger.info(f"MaiCore当前版本: {mai_version}")
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
update_config()
bot_config_floder_path = BotConfig.get_config_dir()
logger.info(f"正在品鉴配置文件目录: {bot_config_floder_path}")
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
if os.path.exists(bot_config_path):
# 如果开发环境配置文件不存在,则使用默认配置文件
logger.info(f"异常的新鲜,异常的美味: {bot_config_path}")
else:
# 配置文件不存在
logger.error("配置文件不存在,请检查路径: {bot_config_path}")
raise FileNotFoundError(f"配置文件不存在: {bot_config_path}")
global_config = BotConfig.load_config(config_path=bot_config_path)
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=f"{CONFIG_DIR}/bot_config.toml")
logger.info("非常的新鲜,非常的美味!")

116
src/config/config_base.py Normal file
View File

@@ -0,0 +1,116 @@
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args
T = TypeVar("T", bound="ConfigBase")
TOML_DICT_TYPE = {
int,
float,
str,
bool,
list,
dict,
}
@dataclass
class ConfigBase:
"""配置类的基类"""
@classmethod
def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
"""从字典加载配置字段"""
if not isinstance(data, dict):
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
init_args: dict[str, Any] = {}
for f in fields(cls):
field_name = f.name
if field_name.startswith("_"):
# 跳过以 _ 开头的字段
continue
if field_name not in data:
if f.default is not MISSING or f.default_factory is not MISSING:
# 跳过未提供且有默认值/默认构造方法的字段
continue
else:
raise ValueError(f"Missing required field: '{field_name}'")
value = data[field_name]
field_type = f.type
try:
init_args[field_name] = cls._convert_field(value, field_type)
except TypeError as e:
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
except Exception as e:
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
return cls(**init_args)
@classmethod
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
"""
转换字段值为指定类型
1. 对于嵌套的 dataclass递归调用相应的 from_dict 方法
2. 对于泛型集合类型list, set, tuple递归转换每个元素
3. 对于基础类型int, str, float, bool直接转换
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
"""
# 如果是嵌套的 dataclass递归调用 from_dict 方法
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
return field_type.from_dict(value)
# 处理泛型集合类型list, set, tuple
field_origin_type = get_origin(field_type)
field_type_args = get_args(field_type)
if field_origin_type in {list, set, tuple}:
# 检查提供的value是否为list
if not isinstance(value, list):
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
if field_origin_type is list:
return [cls._convert_field(item, field_type_args[0]) for item in value]
elif field_origin_type is set:
return {cls._convert_field(item, field_type_args[0]) for item in value}
elif field_origin_type is tuple:
# 检查提供的value长度是否与类型参数一致
if len(value) != len(field_type_args):
raise TypeError(
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
)
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args))
if field_origin_type is dict:
# 检查提供的value是否为dict
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
# 检查字典的键值类型
if len(field_type_args) != 2:
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
key_type, value_type = field_type_args
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
# 处理基础类型,例如 int, str 等
if field_type is Any or isinstance(value, field_type):
return value
# 其他类型,尝试直接转换
try:
return field_type(value)
except (ValueError, TypeError) as e:
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
def __str__(self):
"""返回配置类的字符串表示"""
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"

View File

@@ -0,0 +1,399 @@
from dataclasses import dataclass, field
from typing import Any
from src.config.config_base import ConfigBase
"""
须知:
1. 本文件中记录了所有的配置项
2. 所有新增的class都需要继承自ConfigBase
3. 所有新增的class都应在config.py中的Config类中添加字段
4. 对于新增的字段若为可选项则应在其后添加field()并设置default_factory或default
"""
@dataclass
class BotConfig(ConfigBase):
"""QQ机器人配置类"""
qq_account: str
"""QQ账号"""
nickname: str
"""昵称"""
alias_names: list[str] = field(default_factory=lambda: [])
"""别名列表"""
@dataclass
class ChatTargetConfig(ConfigBase):
"""
聊天目标配置类
此类中有聊天的群组和用户配置
"""
talk_allowed_groups: set[str] = field(default_factory=lambda: set())
"""允许聊天的群组列表"""
talk_frequency_down_groups: set[str] = field(default_factory=lambda: set())
"""降低聊天频率的群组列表"""
ban_user_id: set[str] = field(default_factory=lambda: set())
"""禁止聊天的用户列表"""
@dataclass
class PersonalityConfig(ConfigBase):
"""人格配置类"""
personality_core: str
"""核心人格"""
expression_style: str
"""表达风格"""
personality_sides: list[str] = field(default_factory=lambda: [])
"""人格侧写"""
@dataclass
class IdentityConfig(ConfigBase):
"""个体特征配置类"""
height: int = 170
"""身高(单位:厘米)"""
weight: float = 50
"""体重(单位:千克)"""
age: int = 18
"""年龄(单位:岁)"""
gender: str = ""
"""性别(男/女)"""
appearance: str = "可爱"
"""外貌描述"""
identity_detail: list[str] = field(default_factory=lambda: [])
"""身份特征"""
@dataclass
class PlatformsConfig(ConfigBase):
"""平台配置类"""
qq: str
"""QQ适配器连接URL配置"""
@dataclass
class ChatConfig(ConfigBase):
"""聊天配置类"""
allow_focus_mode: bool = True
"""是否允许专注聊天状态"""
base_normal_chat_num: int = 3
"""最多允许多少个群进行普通聊天"""
base_focused_chat_num: int = 2
"""最多允许多少个群进行专注聊天"""
observation_context_size: int = 12
"""可观察到的最长上下文大小,超过这个值的上下文会被压缩"""
message_buffer: bool = True
"""消息缓冲器"""
ban_words: set[str] = field(default_factory=lambda: set())
"""过滤词列表"""
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
"""过滤正则表达式列表"""
@dataclass
class NormalChatConfig(ConfigBase):
"""普通聊天配置类"""
reasoning_model_probability: float = 0.3
"""
发言时选择推理模型的概率0-1之间
选择普通模型的概率为 1 - reasoning_normal_model_probability
"""
emoji_chance: float = 0.2
"""发送表情包的基础概率"""
thinking_timeout: int = 120
"""最长思考时间"""
willing_mode: str = "classical"
"""意愿模式"""
response_willing_amplifier: float = 1.0
"""回复意愿放大系数"""
response_interested_rate_amplifier: float = 1.0
"""回复兴趣度放大系数"""
down_frequency_rate: float = 3.0
"""降低回复频率的群组回复意愿降低系数"""
emoji_response_penalty: float = 0.0
"""表情包回复惩罚系数"""
mentioned_bot_inevitable_reply: bool = False
"""提及 bot 必然回复"""
at_bot_inevitable_reply: bool = False
"""@bot 必然回复"""
@dataclass
class FocusChatConfig(ConfigBase):
"""专注聊天配置类"""
reply_trigger_threshold: float = 3.0
"""心流聊天触发阈值,越低越容易触发"""
default_decay_rate_per_second: float = 0.98
"""默认衰减率,越大衰减越快"""
consecutive_no_reply_threshold: int = 3
"""连续不回复的次数阈值"""
compressed_length: int = 5
"""心流上下文压缩的最短压缩长度超过心流观察到的上下文长度会压缩最短压缩长度为5"""
compress_length_limit: int = 5
"""最多压缩份数,超过该数值的压缩上下文会被删除"""
@dataclass
class EmojiConfig(ConfigBase):
"""表情包配置类"""
max_reg_num: int = 200
"""表情包最大注册数量"""
do_replace: bool = True
"""达到最大注册数量时替换旧表情包"""
check_interval: int = 120
"""表情包检查间隔(分钟)"""
save_pic: bool = False
"""是否保存图片"""
cache_emoji: bool = True
"""是否缓存表情包"""
steal_emoji: bool = True
"""是否偷取表情包,让麦麦可以发送她保存的这些表情包"""
content_filtration: bool = False
"""是否开启表情包过滤"""
filtration_prompt: str = "符合公序良俗"
"""表情包过滤要求"""
@dataclass
class MemoryConfig(ConfigBase):
"""记忆配置类"""
memory_build_interval: int = 600
"""记忆构建间隔(秒)"""
memory_build_distribution: tuple[
float,
float,
float,
float,
float,
float,
] = field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4))
"""记忆构建分布参数分布1均值标准差权重分布2均值标准差权重"""
memory_build_sample_num: int = 8
"""记忆构建采样数量"""
memory_build_sample_length: int = 40
"""记忆构建采样长度"""
memory_compress_rate: float = 0.1
"""记忆压缩率"""
forget_memory_interval: int = 1000
"""记忆遗忘间隔(秒)"""
memory_forget_time: int = 24
"""记忆遗忘时间(小时)"""
memory_forget_percentage: float = 0.01
"""记忆遗忘比例"""
consolidate_memory_interval: int = 1000
"""记忆整合间隔(秒)"""
consolidation_similarity_threshold: float = 0.7
"""整合相似度阈值"""
consolidate_memory_percentage: float = 0.01
"""整合检查节点比例"""
memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
"""不允许记忆的词列表"""
@dataclass
class MoodConfig(ConfigBase):
"""情绪配置类"""
mood_update_interval: int = 1
"""情绪更新间隔(秒)"""
mood_decay_rate: float = 0.95
"""情绪衰减率"""
mood_intensity_factor: float = 0.7
"""情绪强度因子"""
@dataclass
class KeywordRuleConfig(ConfigBase):
"""关键词规则配置类"""
enable: bool = True
"""是否启用关键词规则"""
keywords: list[str] = field(default_factory=lambda: [])
"""关键词列表"""
regex: list[str] = field(default_factory=lambda: [])
"""正则表达式列表"""
reaction: str = ""
"""关键词触发的反应"""
@dataclass
class KeywordReactionConfig(ConfigBase):
"""关键词配置类"""
enable: bool = True
"""是否启用关键词反应"""
rules: list[KeywordRuleConfig] = field(default_factory=lambda: [])
"""关键词反应规则列表"""
@dataclass
class ChineseTypoConfig(ConfigBase):
"""中文错别字配置类"""
enable: bool = True
"""是否启用中文错别字生成器"""
error_rate: float = 0.01
"""单字替换概率"""
min_freq: int = 9
"""最小字频阈值"""
tone_error_rate: float = 0.1
"""声调错误概率"""
word_replace_rate: float = 0.006
"""整词替换概率"""
@dataclass
class ResponseSplitterConfig(ConfigBase):
"""回复分割器配置类"""
enable: bool = True
"""是否启用回复分割器"""
max_length: int = 256
"""回复允许的最大长度"""
max_sentence_num: int = 3
"""回复允许的最大句子数"""
enable_kaomoji_protection: bool = False
"""是否启用颜文字保护"""
@dataclass
class TelemetryConfig(ConfigBase):
"""遥测配置类"""
enable: bool = True
"""是否启用遥测"""
@dataclass
class ExperimentalConfig(ConfigBase):
"""实验功能配置类"""
# enable_friend_chat: bool = False
# """是否启用好友聊天"""
# talk_allowed_private: set[str] = field(default_factory=lambda: set())
# """允许聊天的私聊列表"""
pfc_chatting: bool = False
"""是否启用PFC"""
@dataclass
class ModelConfig(ConfigBase):
"""模型配置类"""
model_max_output_length: int = 800 # 最大回复长度
reasoning: dict[str, Any] = field(default_factory=lambda: {})
"""推理模型配置"""
normal: dict[str, Any] = field(default_factory=lambda: {})
"""普通模型配置"""
topic_judge: dict[str, Any] = field(default_factory=lambda: {})
"""主题判断模型配置"""
summary: dict[str, Any] = field(default_factory=lambda: {})
"""摘要模型配置"""
vlm: dict[str, Any] = field(default_factory=lambda: {})
"""视觉语言模型配置"""
heartflow: dict[str, Any] = field(default_factory=lambda: {})
"""心流模型配置"""
observation: dict[str, Any] = field(default_factory=lambda: {})
"""观察模型配置"""
sub_heartflow: dict[str, Any] = field(default_factory=lambda: {})
"""子心流模型配置"""
plan: dict[str, Any] = field(default_factory=lambda: {})
"""计划模型配置"""
embedding: dict[str, Any] = field(default_factory=lambda: {})
"""嵌入模型配置"""
pfc_action_planner: dict[str, Any] = field(default_factory=lambda: {})
"""PFC动作规划模型配置"""
pfc_chat: dict[str, Any] = field(default_factory=lambda: {})
"""PFC聊天模型配置"""
pfc_reply_checker: dict[str, Any] = field(default_factory=lambda: {})
"""PFC回复检查模型配置"""
tool_use: dict[str, Any] = field(default_factory=lambda: {})
"""工具使用模型配置"""

View File

@@ -114,7 +114,7 @@ class ActionPlanner:
request_type="action_planning",
)
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
self.name = global_config.BOT_NICKNAME
self.name = global_config.bot.nickname
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
# self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量
@@ -140,7 +140,7 @@ class ActionPlanner:
# (这部分逻辑不变)
time_since_last_bot_message_info = ""
try:
bot_id = str(global_config.BOT_QQ)
bot_id = str(global_config.bot.qq_account)
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
for i in range(len(observation_info.chat_history) - 1, -1, -1):
msg = observation_info.chat_history[i]

View File

@@ -10,7 +10,7 @@ from src.experimental.PFC.chat_states import (
create_new_message_notification,
create_cold_chat_notification,
)
from src.experimental.PFC.message_storage import MongoDBMessageStorage
from src.experimental.PFC.message_storage import PeeweeMessageStorage
from rich.traceback import install
install(extra_lines=3)
@@ -53,7 +53,7 @@ class ChatObserver:
self.stream_id = stream_id
self.private_name = private_name
self.message_storage = MongoDBMessageStorage()
self.message_storage = PeeweeMessageStorage()
# self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
# self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
@@ -323,7 +323,7 @@ class ChatObserver:
for msg in messages:
try:
user_info = UserInfo.from_dict(msg.get("user_info", {}))
if user_info.user_id == global_config.BOT_QQ:
if user_info.user_id == global_config.bot.qq_account:
self.update_bot_speak_time(msg["time"])
else:
self.update_user_speak_time(msg["time"])

View File

@@ -42,8 +42,8 @@ class DirectMessageSender:
# 获取麦麦的信息
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=chat_stream.platform,
)

View File

@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from src.common.database import db
# from src.common.database.database import db # Peewee db 导入
from src.common.database.database_model import Messages # Peewee Messages 模型导入
from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典
class MessageStorage(ABC):
@@ -47,28 +50,35 @@ class MessageStorage(ABC):
pass
class MongoDBMessageStorage(MessageStorage):
"""MongoDB消息存储实现"""
class PeeweeMessageStorage(MessageStorage):
"""Peewee消息存储实现"""
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id, "time": {"$gt": message_time}}
# print(f"storage_check_message: {message_time}")
query = (
Messages.select()
.where((Messages.chat_id == chat_id) & (Messages.time > message_time))
.order_by(Messages.time.asc())
)
return list(db.messages.find(query).sort("time", 1))
# print(f"storage_check_message: {message_time}")
messages_models = list(query)
return [model_to_dict(msg) for msg in messages_models]
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id, "time": {"$lt": time_point}}
messages = list(db.messages.find(query).sort("time", -1).limit(limit))
query = (
Messages.select()
.where((Messages.chat_id == chat_id) & (Messages.time < time_point))
.order_by(Messages.time.desc())
.limit(limit)
)
messages_models = list(query)
# 将消息按时间正序排列
messages.reverse()
return messages
messages_models.reverse()
return [model_to_dict(msg) for msg in messages_models]
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
query = {"chat_id": chat_id, "time": {"$gt": after_time}}
return db.messages.find_one(query) is not None
return Messages.select().where((Messages.chat_id == chat_id) & (Messages.time > after_time)).exists()
# # 创建一个内存消息存储实现,用于测试

View File

@@ -42,13 +42,14 @@ class GoalAnalyzer:
"""对话目标分析器"""
def __init__(self, stream_id: str, private_name: str):
# TODO: API-Adapter修改标记
self.llm = LLMRequest(
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
model=global_config.model.normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
)
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
self.name = global_config.bot.nickname
self.nick_name = global_config.bot.alias_names
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
@@ -315,7 +316,7 @@ class GoalAnalyzer:
# message_segment = Seg(type="text", data=content)
# bot_user_info = UserInfo(
# user_id=global_config.BOT_QQ,
# user_nickname=global_config.BOT_NICKNAME,
# user_nickname=global_config.bot.nickname,
# platform=chat_stream.platform,
# )

View File

@@ -14,9 +14,10 @@ class KnowledgeFetcher:
"""知识调取器"""
def __init__(self, private_name: str):
# TODO: API-Adapter修改标记
self.llm = LLMRequest(
model=global_config.llm_normal,
temperature=global_config.llm_normal["temp"],
model=global_config.model.normal,
temperature=global_config.model.normal["temp"],
max_tokens=1000,
request_type="knowledge_fetch",
)

View File

@@ -16,7 +16,7 @@ class ReplyChecker:
self.llm = LLMRequest(
model=global_config.llm_PFC_reply_checker, temperature=0.50, max_tokens=1000, request_type="reply_check"
)
self.name = global_config.BOT_NICKNAME
self.name = global_config.bot.nickname
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
self.max_retries = 3 # 最大重试次数
@@ -43,7 +43,7 @@ class ReplyChecker:
bot_messages = []
for msg in reversed(chat_history):
user_info = UserInfo.from_dict(msg.get("user_info", {}))
if str(user_info.user_id) == str(global_config.BOT_QQ): # 确保比较的是字符串
if str(user_info.user_id) == str(global_config.bot.qq_account): # 确保比较的是字符串
bot_messages.append(msg.get("processed_plain_text", ""))
if len(bot_messages) >= 2: # 只和最近的两条比较
break

View File

@@ -93,7 +93,7 @@ class ReplyGenerator:
request_type="reply_generation",
)
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
self.name = global_config.BOT_NICKNAME
self.name = global_config.bot.nickname
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
self.reply_checker = ReplyChecker(stream_id, private_name)

View File

@@ -19,7 +19,7 @@ class Waiter:
def __init__(self, stream_id: str, private_name: str):
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
self.name = global_config.BOT_NICKNAME
self.name = global_config.bot.nickname
self.private_name = private_name
# self.wait_accumulated_time = 0 # 不再需要累加计时

View File

@@ -16,7 +16,7 @@ class MessageProcessor:
@staticmethod
def _check_ban_words(text: str, chat, userinfo) -> bool:
"""检查消息中是否包含过滤词"""
for word in global_config.ban_words:
for word in global_config.chat.ban_words:
if word in text:
logger.info(
f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
@@ -28,7 +28,7 @@ class MessageProcessor:
@staticmethod
def _check_ban_regex(text: str, chat, userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式"""
for pattern in global_config.ban_msgs_regex:
for pattern in global_config.chat.ban_msgs_regex:
if pattern.search(text):
logger.info(
f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"

View File

@@ -40,7 +40,7 @@ class MainSystem:
async def initialize(self):
"""初始化系统组件"""
logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
logger.debug(f"正在唤醒{global_config.bot.nickname}......")
# 其他初始化任务
await asyncio.gather(self._init_components())
@@ -84,7 +84,7 @@ class MainSystem:
asyncio.create_task(chat_manager._auto_save_task())
# 使用HippocampusManager初始化海马体
self.hippocampus_manager.initialize(global_config=global_config)
self.hippocampus_manager.initialize()
# await asyncio.sleep(0.5) #防止logger输出飞了
# 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中
@@ -92,15 +92,15 @@ class MainSystem:
# 初始化个体特征
self.individuality.initialize(
bot_nickname=global_config.BOT_NICKNAME,
personality_core=global_config.personality_core,
personality_sides=global_config.personality_sides,
identity_detail=global_config.identity_detail,
height=global_config.height,
weight=global_config.weight,
age=global_config.age,
gender=global_config.gender,
appearance=global_config.appearance,
bot_nickname=global_config.bot.nickname,
personality_core=global_config.personality.personality_core,
personality_sides=global_config.personality.personality_sides,
identity_detail=global_config.identity.identity_detail,
height=global_config.identity.height,
weight=global_config.identity.weight,
age=global_config.identity.age,
gender=global_config.identity.gender,
appearance=global_config.identity.appearance,
)
logger.success("个体特征初始化成功")
@@ -141,7 +141,7 @@ class MainSystem:
async def build_memory_task():
"""记忆构建任务"""
while True:
await asyncio.sleep(global_config.build_memory_interval)
await asyncio.sleep(global_config.memory.memory_build_interval)
logger.info("正在进行记忆构建")
await HippocampusManager.get_instance().build_memory()
@@ -149,16 +149,18 @@ class MainSystem:
async def forget_memory_task():
"""记忆遗忘任务"""
while True:
await asyncio.sleep(global_config.forget_memory_interval)
await asyncio.sleep(global_config.memory.forget_memory_interval)
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage)
await HippocampusManager.get_instance().forget_memory(
percentage=global_config.memory.memory_forget_percentage
)
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
@staticmethod
async def consolidate_memory_task():
"""记忆整合任务"""
while True:
await asyncio.sleep(global_config.consolidate_memory_interval)
await asyncio.sleep(global_config.memory.consolidate_memory_interval)
print("\033[1;32m[记忆整合]\033[0m 开始整合记忆...")
await HippocampusManager.get_instance().consolidate_memory()
print("\033[1;32m[记忆整合]\033[0m 记忆整合完成")

View File

@@ -34,14 +34,14 @@ class MoodUpdateTask(AsyncTask):
def __init__(self):
super().__init__(
task_name="Mood Update Task",
wait_before_start=global_config.mood_update_interval,
run_interval=global_config.mood_update_interval,
wait_before_start=global_config.mood.mood_update_interval,
run_interval=global_config.mood.mood_update_interval,
)
# 从配置文件获取衰减率
self.decay_rate_valence: float = 1 - global_config.mood_decay_rate
self.decay_rate_valence: float = 1 - global_config.mood.mood_decay_rate
"""愉悦度衰减率"""
self.decay_rate_arousal: float = 1 - global_config.mood_decay_rate
self.decay_rate_arousal: float = 1 - global_config.mood.mood_decay_rate
"""唤醒度衰减率"""
self.last_update = time.time()

101
src/plugins.md Normal file
View File

@@ -0,0 +1,101 @@
# 如何编写MaiBot插件
## 基本步骤
1.`src/plugins/你的插件名/actions/`目录下创建插件文件
2. 继承`PluginAction`基类
3. 实现`process`方法
## 插件结构示例
```python
from src.common.logger_manager import get_logger
from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action
from typing import Tuple
logger = get_logger("your_action_name")
@register_action
class YourAction(PluginAction):
"""你的动作描述"""
action_name = "your_action_name" # 动作名称,必须唯一
action_description = "这个动作的详细描述,会展示给用户"
action_parameters = {
"param1": "参数1的说明可选",
"param2": "参数2的说明可选"
}
action_require = [
"使用场景1",
"使用场景2"
]
default = False # 是否默认启用
async def process(self) -> Tuple[bool, str]:
"""插件核心逻辑"""
# 你的代码逻辑...
return True, "执行结果"
```
## 可用的API方法
插件可以使用`PluginAction`基类提供的以下API
### 1. 发送消息
```python
await self.send_message("要发送的文本", target="可选的回复目标")
```
### 2. 获取聊天类型
```python
chat_type = self.get_chat_type() # 返回 "group" 或 "private" 或 "unknown"
```
### 3. 获取最近消息
```python
messages = self.get_recent_messages(count=5) # 获取最近5条消息
# 返回格式: [{"sender": "发送者", "content": "内容", "timestamp": 时间戳}, ...]
```
### 4. 获取动作参数
```python
param_value = self.action_data.get("param_name", "默认值")
```
### 5. 日志记录
```python
logger.info(f"{self.log_prefix} 你的日志信息")
logger.warning("警告信息")
logger.error("错误信息")
```
## 返回值说明
`process`方法必须返回一个元组,包含两个元素:
- 第一个元素(bool): 表示动作是否执行成功
- 第二个元素(str): 执行结果的文本描述
```python
return True, "执行成功的消息"
# 或
return False, "执行失败的原因"
```
## 最佳实践
1. 使用`action_parameters`清晰定义你的动作需要的参数
2. 使用`action_require`描述何时应该使用你的动作
3. 使用`action_description`准确描述你的动作功能
4. 使用`logger`记录重要信息,方便调试
5. 避免操作底层系统,尽量使用`PluginAction`提供的API
## 注册与加载
插件会在系统启动时自动加载,只要放在正确的目录并添加了`@register_action`装饰器。
若设置`default = True`,插件会自动添加到默认动作集;否则需要在系统中手动启用。

1
src/plugins/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""插件系统包"""

Some files were not shown because too many files have changed in this diff Show More