Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
20
CLAUDE.md
Normal file
20
CLAUDE.md
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# CLAUDE.md
|
||||||
|
|
||||||
|
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
- **Run Bot**: `python bot.py`
|
||||||
|
- **Lint**: `ruff check --fix .` or `ruff format .`
|
||||||
|
- **Run Tests**: `python -m unittest discover -v`
|
||||||
|
- **Run Single Test**: `python -m unittest src/plugins/message/test.py`
|
||||||
|
|
||||||
|
## Code Style
|
||||||
|
- **Formatting**: Line length 120 chars, use double quotes for strings
|
||||||
|
- **Imports**: Group standard library, external packages, then internal imports
|
||||||
|
- **Naming**: snake_case for functions/variables, PascalCase for classes
|
||||||
|
- **Error Handling**: Use try/except blocks with specific exceptions
|
||||||
|
- **Types**: Use type hints where possible
|
||||||
|
- **Docstrings**: Document classes and complex functions
|
||||||
|
- **Linting**: Follow ruff rules (E, F, B) with ignores E711, E501
|
||||||
|
|
||||||
|
When making changes, run `ruff check --fix .` to ensure code follows style guidelines. The codebase uses Ruff for linting and formatting.
|
||||||
159
README.md
159
README.md
@@ -1,24 +1,66 @@
|
|||||||
# 麦麦!MaiCore-MaiMBot (编辑中)
|
# 麦麦!MaiCore-MaiMBot (编辑中)
|
||||||
|
<br />
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<a href="https://github.com/MaiM-with-u/MaiBot/">
|
||||||
|
<img src="depends-data/maimai.png" alt="Logo" width="200">
|
||||||
|
</a>
|
||||||
|
<br />
|
||||||
|
<a href="https://space.bilibili.com/1344099355">
|
||||||
|
画师:略nd
|
||||||
|
</a>
|
||||||
|
|
||||||
|
<h3 align="center">MaiBot(麦麦)</h3>
|
||||||
|
<p align="center">
|
||||||
|
一款专注于<strong> 群组聊天 </strong>的赛博网友
|
||||||
|
<br />
|
||||||
|
<a href="https://docs.mai-mai.org"><strong>探索本项目的文档 »</strong></a>
|
||||||
|
<br />
|
||||||
|
<br />
|
||||||
|
<!-- <a href="https://github.com/shaojintian/Best_README_template">查看Demo</a>
|
||||||
|
· -->
|
||||||
|
<a href="https://github.com/MaiM-with-u/MaiBot/issues">报告Bug</a>
|
||||||
|
·
|
||||||
|
<a href="https://github.com/MaiM-with-u/MaiBot/issues">提出新特性</a>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
</p>
|
||||||
|
|
||||||
## 新版0.6.0部署前先阅读:https://docs.mai-mai.org/manual/usage/mmc_q_a
|
## 新版0.6.0部署前先阅读:https://docs.mai-mai.org/manual/usage/mmc_q_a
|
||||||
|
|
||||||
<div align="center">
|
|
||||||
|
|
||||||

|
|
||||||

|
|
||||||

|
|
||||||
|
|
||||||
</div>
|
|
||||||
|
|
||||||
## 📝 项目简介
|
## 📝 项目简介
|
||||||
|
|
||||||
**🍔MaiCore是一个基于大语言模型的可交互智能体**
|
**🍔MaiCore是一个基于大语言模型的可交互智能体**
|
||||||
|
|
||||||
- LLM 提供对话能力
|
|
||||||
- 动态Prompt构建器
|
- 💭 **智能对话系统**:基于LLM的自然语言交互
|
||||||
- 实时的思维系统
|
- 🤔 **实时思维系统**:模拟人类思考过程
|
||||||
- MongoDB 提供数据持久化支持
|
- 💝 **情感表达系统**:丰富的表情包和情绪表达
|
||||||
- 可扩展,可支持多种平台和多种功能
|
- 🧠 **持久记忆系统**:基于MongoDB的长期记忆存储
|
||||||
|
- 🔄 **动态人格系统**:自适应的性格特征
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
|
||||||
|
<img src="depends-data/video.png" width="200" alt="麦麦演示视频">
|
||||||
|
<br>
|
||||||
|
👆 点击观看麦麦演示视频 👆
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
### 📢 版本信息
|
||||||
|
|
||||||
**最新版本: v0.6.0** ([查看更新日志](changelogs/changelog.md))
|
**最新版本: v0.6.0** ([查看更新日志](changelogs/changelog.md))
|
||||||
> [!WARNING]
|
> [!WARNING]
|
||||||
@@ -28,19 +70,12 @@
|
|||||||
> 次版本MaiBot将基于MaiCore运行,不再依赖于nonebot相关组件运行。
|
> 次版本MaiBot将基于MaiCore运行,不再依赖于nonebot相关组件运行。
|
||||||
> MaiBot将通过nonebot的插件与nonebot建立联系,然后nonebot与QQ建立联系,实现MaiBot与QQ的交互
|
> MaiBot将通过nonebot的插件与nonebot建立联系,然后nonebot与QQ建立联系,实现MaiBot与QQ的交互
|
||||||
|
|
||||||
**分支介绍:**
|
**分支说明:**
|
||||||
- main 稳定版本
|
- `main`: 稳定发布版本
|
||||||
- dev 开发版(不知道什么意思就别下)
|
- `dev`: 开发测试版本(不知道什么意思就别下)
|
||||||
- classical 0.6.0以前的版本
|
- `classical`: 0.6.0之前的版本
|
||||||
|
|
||||||
<div align="center">
|
|
||||||
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
|
|
||||||
<img src="docs/pic/video.png" width="300" alt="麦麦演示视频">
|
|
||||||
<br>
|
|
||||||
👆 点击观看麦麦演示视频 👆
|
|
||||||
|
|
||||||
</a>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
> [!WARNING]
|
> [!WARNING]
|
||||||
> - 项目处于活跃开发阶段,代码可能随时更改
|
> - 项目处于活跃开发阶段,代码可能随时更改
|
||||||
@@ -49,6 +84,12 @@
|
|||||||
> - 由于持续迭代,可能存在一些已知或未知的bug
|
> - 由于持续迭代,可能存在一些已知或未知的bug
|
||||||
> - 由于开发中,可能消耗较多token
|
> - 由于开发中,可能消耗较多token
|
||||||
|
|
||||||
|
### ⚠️ 重要提示
|
||||||
|
|
||||||
|
- 升级到v0.6.0版本前请务必阅读:[升级指南](https://docs.mai-mai.org/manual/usage/mmc_q_a)
|
||||||
|
- 本版本基于MaiCore重构,通过nonebot插件与QQ平台交互
|
||||||
|
- 项目处于活跃开发阶段,功能和API可能随时调整
|
||||||
|
|
||||||
### 💬交流群(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
|
### 💬交流群(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
|
||||||
- [五群](https://qm.qq.com/q/JxvHZnxyec) 1022489779
|
- [五群](https://qm.qq.com/q/JxvHZnxyec) 1022489779
|
||||||
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) 766798517 【已满】
|
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) 766798517 【已满】
|
||||||
@@ -72,55 +113,35 @@
|
|||||||
|
|
||||||
## 🎯 功能介绍
|
## 🎯 功能介绍
|
||||||
|
|
||||||
### 💬 聊天功能
|
| 模块 | 主要功能 | 特点 |
|
||||||
- 提供思维流(心流)聊天和推理聊天两种对话逻辑
|
|------|---------|------|
|
||||||
- 支持关键词检索主动发言:对消息的话题topic进行识别,如果检测到麦麦存储过的话题就会主动进行发言
|
| 💬 聊天系统 | • 思维流/推理聊天<br>• 关键词主动发言<br>• 多模型支持<br>• 动态prompt构建<br>• 私聊功能(PFC) | 拟人化交互 |
|
||||||
- 支持bot名字呼唤发言:检测到"麦麦"会主动发言,可配置
|
| 🧠 思维流系统 | • 实时思考生成<br>• 自动启停机制<br>• 日程系统联动 | 智能化决策 |
|
||||||
- 支持多模型,多厂商自定义配置
|
| 🧠 记忆系统 2.0 | • 优化记忆抽取<br>• 海马体记忆机制<br>• 聊天记录概括 | 持久化记忆 |
|
||||||
- 动态的prompt构建器,更拟人
|
| 😊 表情包系统 | • 情绪匹配发送<br>• GIF支持<br>• 自动收集与审查 | 丰富表达 |
|
||||||
- 支持图片,转发消息,回复消息的识别
|
| 📅 日程系统 | • 动态日程生成<br>• 自定义想象力<br>• 思维流联动 | 智能规划 |
|
||||||
- 支持私聊功能,可使用PFC模式的有目的多轮对话(实验性)
|
| 👥 关系系统 2.0 | • 关系管理优化<br>• 丰富接口支持<br>• 个性化交互 | 深度社交 |
|
||||||
|
| 📊 统计系统 | • 使用数据统计<br>• LLM调用记录<br>• 实时控制台显示 | 数据可视 |
|
||||||
|
| 🔧 系统功能 | • 优雅关闭机制<br>• 自动数据保存<br>• 异常处理完善 | 稳定可靠 |
|
||||||
|
|
||||||
### 🧠 思维流系统
|
## 📐 项目架构
|
||||||
- 思维流能够在回复前后进行思考,生成实时想法
|
|
||||||
- 思维流自动启停机制,提升资源利用效率
|
|
||||||
- 思维流与日程系统联动,实现动态日程生成
|
|
||||||
|
|
||||||
### 🧠 记忆系统 2.0
|
```mermaid
|
||||||
- 优化记忆抽取策略和prompt结构
|
graph TD
|
||||||
- 改进海马体记忆提取机制,提升自然度
|
A[MaiCore] --> B[对话系统]
|
||||||
- 对聊天记录进行概括存储,在需要时调用
|
A --> C[思维流系统]
|
||||||
|
A --> D[记忆系统]
|
||||||
|
A --> E[情感系统]
|
||||||
|
B --> F[多模型支持]
|
||||||
|
B --> G[动态Prompt]
|
||||||
|
C --> H[实时思考]
|
||||||
|
C --> I[日程联动]
|
||||||
|
D --> J[记忆存储]
|
||||||
|
D --> K[记忆检索]
|
||||||
|
E --> L[表情管理]
|
||||||
|
E --> M[情绪识别]
|
||||||
|
```
|
||||||
|
|
||||||
### 😊 表情包系统
|
|
||||||
- 支持根据发言内容发送对应情绪的表情包
|
|
||||||
- 支持识别和处理gif表情包
|
|
||||||
- 会自动偷群友的表情包
|
|
||||||
- 表情包审查功能
|
|
||||||
- 表情包文件完整性自动检查
|
|
||||||
- 自动清理缓存图片
|
|
||||||
|
|
||||||
### 📅 日程系统
|
|
||||||
- 动态更新的日程生成
|
|
||||||
- 可自定义想象力程度
|
|
||||||
- 与聊天情况交互(思维流模式下)
|
|
||||||
|
|
||||||
### 👥 关系系统 2.0
|
|
||||||
- 优化关系管理系统,适用于新版本
|
|
||||||
- 提供更丰富的关系接口
|
|
||||||
- 针对每个用户创建"关系",实现个性化回复
|
|
||||||
|
|
||||||
### 📊 统计系统
|
|
||||||
- 详细的使用数据统计
|
|
||||||
- LLM调用统计
|
|
||||||
- 在控制台显示统计信息
|
|
||||||
|
|
||||||
### 🔧 系统功能
|
|
||||||
- 支持优雅的shutdown机制
|
|
||||||
- 自动保存功能,定期保存聊天记录和关系数据
|
|
||||||
- 完善的异常处理机制
|
|
||||||
- 可自定义时区设置
|
|
||||||
- 优化的日志输出格式
|
|
||||||
- 配置自动更新功能
|
|
||||||
|
|
||||||
## 开发计划TODO:LIST
|
## 开发计划TODO:LIST
|
||||||
|
|
||||||
|
|||||||
BIN
depends-data/maimai.png
Normal file
BIN
depends-data/maimai.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 455 KiB |
BIN
depends-data/video.png
Normal file
BIN
depends-data/video.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 62 KiB |
@@ -24,10 +24,10 @@
|
|||||||
|
|
||||||
# # 标记GUI是否运行中
|
# # 标记GUI是否运行中
|
||||||
# self.is_running = True
|
# self.is_running = True
|
||||||
|
|
||||||
# # 程序关闭时的清理操作
|
# # 程序关闭时的清理操作
|
||||||
# self.protocol("WM_DELETE_WINDOW", self._on_closing)
|
# self.protocol("WM_DELETE_WINDOW", self._on_closing)
|
||||||
|
|
||||||
# # 初始化进程、日志队列、日志数据等变量
|
# # 初始化进程、日志队列、日志数据等变量
|
||||||
# self.process = None
|
# self.process = None
|
||||||
# self.log_queue = queue.Queue()
|
# self.log_queue = queue.Queue()
|
||||||
@@ -236,7 +236,7 @@
|
|||||||
# while not self.log_queue.empty():
|
# while not self.log_queue.empty():
|
||||||
# line = self.log_queue.get()
|
# line = self.log_queue.get()
|
||||||
# self.process_log_line(line)
|
# self.process_log_line(line)
|
||||||
|
|
||||||
# # 仅在GUI仍在运行时继续处理队列
|
# # 仅在GUI仍在运行时继续处理队列
|
||||||
# if self.is_running:
|
# if self.is_running:
|
||||||
# self.after(100, self.process_log_queue)
|
# self.after(100, self.process_log_queue)
|
||||||
@@ -245,11 +245,11 @@
|
|||||||
# """解析单行日志并更新日志数据和筛选器"""
|
# """解析单行日志并更新日志数据和筛选器"""
|
||||||
# match = re.match(
|
# match = re.match(
|
||||||
# r"""^
|
# r"""^
|
||||||
# (?:(?P<time>\d{2}:\d{2}(?::\d{2})?)\s*\|\s*)?
|
# (?:(?P<time>\d{2}:\d{2}(?::\d{2})?)\s*\|\s*)?
|
||||||
# (?P<level>\w+)\s*\|\s*
|
# (?P<level>\w+)\s*\|\s*
|
||||||
# (?P<module>.*?)
|
# (?P<module>.*?)
|
||||||
# \s*[-|]\s*
|
# \s*[-|]\s*
|
||||||
# (?P<message>.*)
|
# (?P<message>.*)
|
||||||
# $""",
|
# $""",
|
||||||
# line.strip(),
|
# line.strip(),
|
||||||
# re.VERBOSE,
|
# re.VERBOSE,
|
||||||
@@ -354,10 +354,10 @@
|
|||||||
# """处理窗口关闭事件,安全清理资源"""
|
# """处理窗口关闭事件,安全清理资源"""
|
||||||
# # 标记GUI已关闭
|
# # 标记GUI已关闭
|
||||||
# self.is_running = False
|
# self.is_running = False
|
||||||
|
|
||||||
# # 停止日志进程
|
# # 停止日志进程
|
||||||
# self.stop_process()
|
# self.stop_process()
|
||||||
|
|
||||||
# # 安全清理tkinter变量
|
# # 安全清理tkinter变量
|
||||||
# for attr_name in list(self.__dict__.keys()):
|
# for attr_name in list(self.__dict__.keys()):
|
||||||
# if isinstance(getattr(self, attr_name), (ctk.Variable, ctk.StringVar, ctk.IntVar, ctk.DoubleVar, ctk.BooleanVar)):
|
# if isinstance(getattr(self, attr_name), (ctk.Variable, ctk.StringVar, ctk.IntVar, ctk.DoubleVar, ctk.BooleanVar)):
|
||||||
@@ -367,7 +367,7 @@
|
|||||||
# except Exception:
|
# except Exception:
|
||||||
# pass
|
# pass
|
||||||
# setattr(self, attr_name, None)
|
# setattr(self, attr_name, None)
|
||||||
|
|
||||||
# self.quit()
|
# self.quit()
|
||||||
# sys.exit(0)
|
# sys.exit(0)
|
||||||
|
|
||||||
|
|||||||
@@ -127,7 +127,7 @@
|
|||||||
# """处理窗口关闭事件"""
|
# """处理窗口关闭事件"""
|
||||||
# # 标记GUI已关闭,防止后台线程继续访问tkinter对象
|
# # 标记GUI已关闭,防止后台线程继续访问tkinter对象
|
||||||
# self.is_running = False
|
# self.is_running = False
|
||||||
|
|
||||||
# # 安全清理所有可能的tkinter变量
|
# # 安全清理所有可能的tkinter变量
|
||||||
# for attr_name in list(self.__dict__.keys()):
|
# for attr_name in list(self.__dict__.keys()):
|
||||||
# if isinstance(getattr(self, attr_name), (ctk.Variable, ctk.StringVar, ctk.IntVar, ctk.DoubleVar, ctk.BooleanVar)):
|
# if isinstance(getattr(self, attr_name), (ctk.Variable, ctk.StringVar, ctk.IntVar, ctk.DoubleVar, ctk.BooleanVar)):
|
||||||
@@ -138,7 +138,7 @@
|
|||||||
# except Exception:
|
# except Exception:
|
||||||
# pass
|
# pass
|
||||||
# setattr(self, attr_name, None)
|
# setattr(self, attr_name, None)
|
||||||
|
|
||||||
# # 退出
|
# # 退出
|
||||||
# self.root.quit()
|
# self.root.quit()
|
||||||
# sys.exit(0)
|
# sys.exit(0)
|
||||||
@@ -259,7 +259,7 @@
|
|||||||
# while True:
|
# while True:
|
||||||
# if not self.is_running:
|
# if not self.is_running:
|
||||||
# break # 如果GUI已关闭,停止线程
|
# break # 如果GUI已关闭,停止线程
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
# # 从数据库获取最新数据,只获取启动时间之后的记录
|
# # 从数据库获取最新数据,只获取启动时间之后的记录
|
||||||
# query = {"time": {"$gt": self.start_timestamp}}
|
# query = {"time": {"$gt": self.start_timestamp}}
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ class Heartflow:
|
|||||||
self._subheartflows = {}
|
self._subheartflows = {}
|
||||||
self.active_subheartflows_nums = 0
|
self.active_subheartflows_nums = 0
|
||||||
|
|
||||||
|
|
||||||
async def _cleanup_inactive_subheartflows(self):
|
async def _cleanup_inactive_subheartflows(self):
|
||||||
"""定期清理不活跃的子心流"""
|
"""定期清理不活跃的子心流"""
|
||||||
while True:
|
while True:
|
||||||
@@ -84,25 +83,22 @@ class Heartflow:
|
|||||||
|
|
||||||
# 开始构建prompt
|
# 开始构建prompt
|
||||||
prompt_personality = "你"
|
prompt_personality = "你"
|
||||||
#person
|
# person
|
||||||
individuality = Individuality.get_instance()
|
individuality = Individuality.get_instance()
|
||||||
|
|
||||||
personality_core = individuality.personality.personality_core
|
personality_core = individuality.personality.personality_core
|
||||||
prompt_personality += personality_core
|
prompt_personality += personality_core
|
||||||
|
|
||||||
personality_sides = individuality.personality.personality_sides
|
personality_sides = individuality.personality.personality_sides
|
||||||
random.shuffle(personality_sides)
|
random.shuffle(personality_sides)
|
||||||
prompt_personality += f",{personality_sides[0]}"
|
prompt_personality += f",{personality_sides[0]}"
|
||||||
|
|
||||||
identity_detail = individuality.identity.identity_detail
|
identity_detail = individuality.identity.identity_detail
|
||||||
random.shuffle(identity_detail)
|
random.shuffle(identity_detail)
|
||||||
prompt_personality += f",{identity_detail[0]}"
|
prompt_personality += f",{identity_detail[0]}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
personality_info = prompt_personality
|
personality_info = prompt_personality
|
||||||
|
|
||||||
|
|
||||||
current_thinking_info = self.current_mind
|
current_thinking_info = self.current_mind
|
||||||
mood_info = self.current_state.mood
|
mood_info = self.current_state.mood
|
||||||
related_memory_info = "memory"
|
related_memory_info = "memory"
|
||||||
@@ -146,22 +142,20 @@ class Heartflow:
|
|||||||
async def minds_summary(self, minds_str):
|
async def minds_summary(self, minds_str):
|
||||||
# 开始构建prompt
|
# 开始构建prompt
|
||||||
prompt_personality = "你"
|
prompt_personality = "你"
|
||||||
#person
|
# person
|
||||||
individuality = Individuality.get_instance()
|
individuality = Individuality.get_instance()
|
||||||
|
|
||||||
personality_core = individuality.personality.personality_core
|
personality_core = individuality.personality.personality_core
|
||||||
prompt_personality += personality_core
|
prompt_personality += personality_core
|
||||||
|
|
||||||
personality_sides = individuality.personality.personality_sides
|
personality_sides = individuality.personality.personality_sides
|
||||||
random.shuffle(personality_sides)
|
random.shuffle(personality_sides)
|
||||||
prompt_personality += f",{personality_sides[0]}"
|
prompt_personality += f",{personality_sides[0]}"
|
||||||
|
|
||||||
identity_detail = individuality.identity.identity_detail
|
identity_detail = individuality.identity.identity_detail
|
||||||
random.shuffle(identity_detail)
|
random.shuffle(identity_detail)
|
||||||
prompt_personality += f",{identity_detail[0]}"
|
prompt_personality += f",{identity_detail[0]}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
personality_info = prompt_personality
|
personality_info = prompt_personality
|
||||||
mood_info = self.current_state.mood
|
mood_info = self.current_state.mood
|
||||||
|
|
||||||
@@ -183,7 +177,7 @@ class Heartflow:
|
|||||||
添加一个SubHeartflow实例到self._subheartflows字典中
|
添加一个SubHeartflow实例到self._subheartflows字典中
|
||||||
并根据subheartflow_id为子心流创建一个观察对象
|
并根据subheartflow_id为子心流创建一个观察对象
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if subheartflow_id not in self._subheartflows:
|
if subheartflow_id not in self._subheartflows:
|
||||||
logger.debug(f"创建 subheartflow: {subheartflow_id}")
|
logger.debug(f"创建 subheartflow: {subheartflow_id}")
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from src.common.database import db
|
|||||||
from src.individuality.individuality import Individuality
|
from src.individuality.individuality import Individuality
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
||||||
# 所有观察的基类
|
# 所有观察的基类
|
||||||
class Observation:
|
class Observation:
|
||||||
def __init__(self, observe_type, observe_id):
|
def __init__(self, observe_type, observe_id):
|
||||||
@@ -24,7 +25,7 @@ class ChattingObservation(Observation):
|
|||||||
|
|
||||||
self.talking_message = []
|
self.talking_message = []
|
||||||
self.talking_message_str = ""
|
self.talking_message_str = ""
|
||||||
|
|
||||||
self.name = global_config.BOT_NICKNAME
|
self.name = global_config.BOT_NICKNAME
|
||||||
self.nick_name = global_config.BOT_ALIAS_NAMES
|
self.nick_name = global_config.BOT_ALIAS_NAMES
|
||||||
|
|
||||||
@@ -57,7 +58,7 @@ class ChattingObservation(Observation):
|
|||||||
for msg in new_messages:
|
for msg in new_messages:
|
||||||
if "detailed_plain_text" in msg:
|
if "detailed_plain_text" in msg:
|
||||||
new_messages_str += f"{msg['detailed_plain_text']}"
|
new_messages_str += f"{msg['detailed_plain_text']}"
|
||||||
|
|
||||||
# print(f"new_messages_str:{new_messages_str}")
|
# print(f"new_messages_str:{new_messages_str}")
|
||||||
|
|
||||||
# 将新消息添加到talking_message,同时保持列表长度不超过20条
|
# 将新消息添加到talking_message,同时保持列表长度不超过20条
|
||||||
@@ -117,26 +118,22 @@ class ChattingObservation(Observation):
|
|||||||
# print(f"更新聊天总结:{self.talking_summary}")
|
# print(f"更新聊天总结:{self.talking_summary}")
|
||||||
# 开始构建prompt
|
# 开始构建prompt
|
||||||
prompt_personality = "你"
|
prompt_personality = "你"
|
||||||
#person
|
# person
|
||||||
individuality = Individuality.get_instance()
|
individuality = Individuality.get_instance()
|
||||||
|
|
||||||
personality_core = individuality.personality.personality_core
|
personality_core = individuality.personality.personality_core
|
||||||
prompt_personality += personality_core
|
prompt_personality += personality_core
|
||||||
|
|
||||||
personality_sides = individuality.personality.personality_sides
|
personality_sides = individuality.personality.personality_sides
|
||||||
random.shuffle(personality_sides)
|
random.shuffle(personality_sides)
|
||||||
prompt_personality += f",{personality_sides[0]}"
|
prompt_personality += f",{personality_sides[0]}"
|
||||||
|
|
||||||
identity_detail = individuality.identity.identity_detail
|
identity_detail = individuality.identity.identity_detail
|
||||||
random.shuffle(identity_detail)
|
random.shuffle(identity_detail)
|
||||||
prompt_personality += f",{identity_detail[0]}"
|
prompt_personality += f",{identity_detail[0]}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
personality_info = prompt_personality
|
personality_info = prompt_personality
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
prompt = ""
|
prompt = ""
|
||||||
prompt += f"{personality_info},请注意识别你自己的聊天发言"
|
prompt += f"{personality_info},请注意识别你自己的聊天发言"
|
||||||
prompt += f"你的名字叫:{self.name},你的昵称是:{self.nick_name}\n"
|
prompt += f"你的名字叫:{self.name},你的昵称是:{self.nick_name}\n"
|
||||||
@@ -148,7 +145,6 @@ class ChattingObservation(Observation):
|
|||||||
self.observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt)
|
self.observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt)
|
||||||
print(f"prompt:{prompt}")
|
print(f"prompt:{prompt}")
|
||||||
print(f"self.observe_info:{self.observe_info}")
|
print(f"self.observe_info:{self.observe_info}")
|
||||||
|
|
||||||
|
|
||||||
def translate_message_list_to_str(self):
|
def translate_message_list_to_str(self):
|
||||||
self.talking_message_str = ""
|
self.talking_message_str = ""
|
||||||
|
|||||||
@@ -53,11 +53,10 @@ class SubHeartflow:
|
|||||||
if not self.current_mind:
|
if not self.current_mind:
|
||||||
self.current_mind = "你什么也没想"
|
self.current_mind = "你什么也没想"
|
||||||
|
|
||||||
|
|
||||||
self.is_active = False
|
self.is_active = False
|
||||||
|
|
||||||
self.observations: list[Observation] = []
|
self.observations: list[Observation] = []
|
||||||
|
|
||||||
self.running_knowledges = []
|
self.running_knowledges = []
|
||||||
|
|
||||||
def add_observation(self, observation: Observation):
|
def add_observation(self, observation: Observation):
|
||||||
@@ -86,7 +85,9 @@ class SubHeartflow:
|
|||||||
async def subheartflow_start_working(self):
|
async def subheartflow_start_working(self):
|
||||||
while True:
|
while True:
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
if current_time - self.last_reply_time > global_config.sub_heart_flow_freeze_time: # 120秒无回复/不在场,冻结
|
if (
|
||||||
|
current_time - self.last_reply_time > global_config.sub_heart_flow_freeze_time
|
||||||
|
): # 120秒无回复/不在场,冻结
|
||||||
self.is_active = False
|
self.is_active = False
|
||||||
await asyncio.sleep(global_config.sub_heart_flow_update_interval) # 每60秒检查一次
|
await asyncio.sleep(global_config.sub_heart_flow_update_interval) # 每60秒检查一次
|
||||||
else:
|
else:
|
||||||
@@ -100,7 +101,9 @@ class SubHeartflow:
|
|||||||
await asyncio.sleep(global_config.sub_heart_flow_update_interval)
|
await asyncio.sleep(global_config.sub_heart_flow_update_interval)
|
||||||
|
|
||||||
# 检查是否超过10分钟没有激活
|
# 检查是否超过10分钟没有激活
|
||||||
if current_time - self.last_active_time > global_config.sub_heart_flow_stop_time: # 5分钟无回复/不在场,销毁
|
if (
|
||||||
|
current_time - self.last_active_time > global_config.sub_heart_flow_stop_time
|
||||||
|
): # 5分钟无回复/不在场,销毁
|
||||||
logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活,正在销毁...")
|
logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活,正在销毁...")
|
||||||
break # 退出循环,销毁自己
|
break # 退出循环,销毁自己
|
||||||
|
|
||||||
@@ -147,11 +150,11 @@ class SubHeartflow:
|
|||||||
# self.current_mind = reponse
|
# self.current_mind = reponse
|
||||||
# logger.debug(f"prompt:\n{prompt}\n")
|
# logger.debug(f"prompt:\n{prompt}\n")
|
||||||
# logger.info(f"麦麦的脑内状态:{self.current_mind}")
|
# logger.info(f"麦麦的脑内状态:{self.current_mind}")
|
||||||
|
|
||||||
async def do_observe(self):
|
async def do_observe(self):
|
||||||
observation = self.observations[0]
|
observation = self.observations[0]
|
||||||
await observation.observe()
|
await observation.observe()
|
||||||
|
|
||||||
async def do_thinking_before_reply(self, message_txt):
|
async def do_thinking_before_reply(self, message_txt):
|
||||||
current_thinking_info = self.current_mind
|
current_thinking_info = self.current_mind
|
||||||
mood_info = self.current_state.mood
|
mood_info = self.current_state.mood
|
||||||
@@ -162,23 +165,20 @@ class SubHeartflow:
|
|||||||
|
|
||||||
# 开始构建prompt
|
# 开始构建prompt
|
||||||
prompt_personality = "你"
|
prompt_personality = "你"
|
||||||
#person
|
# person
|
||||||
individuality = Individuality.get_instance()
|
individuality = Individuality.get_instance()
|
||||||
|
|
||||||
personality_core = individuality.personality.personality_core
|
personality_core = individuality.personality.personality_core
|
||||||
prompt_personality += personality_core
|
prompt_personality += personality_core
|
||||||
|
|
||||||
personality_sides = individuality.personality.personality_sides
|
personality_sides = individuality.personality.personality_sides
|
||||||
random.shuffle(personality_sides)
|
random.shuffle(personality_sides)
|
||||||
prompt_personality += f",{personality_sides[0]}"
|
prompt_personality += f",{personality_sides[0]}"
|
||||||
|
|
||||||
identity_detail = individuality.identity.identity_detail
|
identity_detail = individuality.identity.identity_detail
|
||||||
random.shuffle(identity_detail)
|
random.shuffle(identity_detail)
|
||||||
prompt_personality += f",{identity_detail[0]}"
|
prompt_personality += f",{identity_detail[0]}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 调取记忆
|
# 调取记忆
|
||||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||||
text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||||
@@ -191,7 +191,7 @@ class SubHeartflow:
|
|||||||
else:
|
else:
|
||||||
related_memory_info = ""
|
related_memory_info = ""
|
||||||
|
|
||||||
related_info,grouped_results = await self.get_prompt_info(chat_observe_info + message_txt, 0.4)
|
related_info, grouped_results = await self.get_prompt_info(chat_observe_info + message_txt, 0.4)
|
||||||
# print(related_info)
|
# print(related_info)
|
||||||
for _topic, results in grouped_results.items():
|
for _topic, results in grouped_results.items():
|
||||||
for result in results:
|
for result in results:
|
||||||
@@ -227,25 +227,23 @@ class SubHeartflow:
|
|||||||
|
|
||||||
async def do_thinking_after_reply(self, reply_content, chat_talking_prompt):
|
async def do_thinking_after_reply(self, reply_content, chat_talking_prompt):
|
||||||
# print("麦麦回复之后脑袋转起来了")
|
# print("麦麦回复之后脑袋转起来了")
|
||||||
|
|
||||||
# 开始构建prompt
|
# 开始构建prompt
|
||||||
prompt_personality = "你"
|
prompt_personality = "你"
|
||||||
#person
|
# person
|
||||||
individuality = Individuality.get_instance()
|
individuality = Individuality.get_instance()
|
||||||
|
|
||||||
personality_core = individuality.personality.personality_core
|
personality_core = individuality.personality.personality_core
|
||||||
prompt_personality += personality_core
|
prompt_personality += personality_core
|
||||||
|
|
||||||
personality_sides = individuality.personality.personality_sides
|
personality_sides = individuality.personality.personality_sides
|
||||||
random.shuffle(personality_sides)
|
random.shuffle(personality_sides)
|
||||||
prompt_personality += f",{personality_sides[0]}"
|
prompt_personality += f",{personality_sides[0]}"
|
||||||
|
|
||||||
identity_detail = individuality.identity.identity_detail
|
identity_detail = individuality.identity.identity_detail
|
||||||
random.shuffle(identity_detail)
|
random.shuffle(identity_detail)
|
||||||
prompt_personality += f",{identity_detail[0]}"
|
prompt_personality += f",{identity_detail[0]}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
current_thinking_info = self.current_mind
|
current_thinking_info = self.current_mind
|
||||||
mood_info = self.current_state.mood
|
mood_info = self.current_state.mood
|
||||||
|
|
||||||
@@ -279,22 +277,20 @@ class SubHeartflow:
|
|||||||
async def judge_willing(self):
|
async def judge_willing(self):
|
||||||
# 开始构建prompt
|
# 开始构建prompt
|
||||||
prompt_personality = "你"
|
prompt_personality = "你"
|
||||||
#person
|
# person
|
||||||
individuality = Individuality.get_instance()
|
individuality = Individuality.get_instance()
|
||||||
|
|
||||||
personality_core = individuality.personality.personality_core
|
personality_core = individuality.personality.personality_core
|
||||||
prompt_personality += personality_core
|
prompt_personality += personality_core
|
||||||
|
|
||||||
personality_sides = individuality.personality.personality_sides
|
personality_sides = individuality.personality.personality_sides
|
||||||
random.shuffle(personality_sides)
|
random.shuffle(personality_sides)
|
||||||
prompt_personality += f",{personality_sides[0]}"
|
prompt_personality += f",{personality_sides[0]}"
|
||||||
|
|
||||||
identity_detail = individuality.identity.identity_detail
|
identity_detail = individuality.identity.identity_detail
|
||||||
random.shuffle(identity_detail)
|
random.shuffle(identity_detail)
|
||||||
prompt_personality += f",{identity_detail[0]}"
|
prompt_personality += f",{identity_detail[0]}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# print("麦麦闹情绪了1")
|
# print("麦麦闹情绪了1")
|
||||||
current_thinking_info = self.current_mind
|
current_thinking_info = self.current_mind
|
||||||
mood_info = self.current_state.mood
|
mood_info = self.current_state.mood
|
||||||
@@ -320,13 +316,12 @@ class SubHeartflow:
|
|||||||
def update_current_mind(self, reponse):
|
def update_current_mind(self, reponse):
|
||||||
self.past_mind.append(self.current_mind)
|
self.past_mind.append(self.current_mind)
|
||||||
self.current_mind = reponse
|
self.current_mind = reponse
|
||||||
|
|
||||||
|
|
||||||
async def get_prompt_info(self, message: str, threshold: float):
|
async def get_prompt_info(self, message: str, threshold: float):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
related_info = ""
|
related_info = ""
|
||||||
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||||
|
|
||||||
# 1. 先从LLM获取主题,类似于记忆系统的做法
|
# 1. 先从LLM获取主题,类似于记忆系统的做法
|
||||||
topics = []
|
topics = []
|
||||||
# try:
|
# try:
|
||||||
@@ -334,7 +329,7 @@ class SubHeartflow:
|
|||||||
# hippocampus = HippocampusManager.get_instance()._hippocampus
|
# hippocampus = HippocampusManager.get_instance()._hippocampus
|
||||||
# topic_num = min(5, max(1, int(len(message) * 0.1)))
|
# topic_num = min(5, max(1, int(len(message) * 0.1)))
|
||||||
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
|
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
|
||||||
|
|
||||||
# # 提取关键词
|
# # 提取关键词
|
||||||
# topics = re.findall(r"<([^>]+)>", topics_response[0])
|
# topics = re.findall(r"<([^>]+)>", topics_response[0])
|
||||||
# if not topics:
|
# if not topics:
|
||||||
@@ -345,7 +340,7 @@ class SubHeartflow:
|
|||||||
# for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
# for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||||
# if topic.strip()
|
# if topic.strip()
|
||||||
# ]
|
# ]
|
||||||
|
|
||||||
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
|
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
|
||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# logger.error(f"从LLM提取主题失败: {str(e)}")
|
# logger.error(f"从LLM提取主题失败: {str(e)}")
|
||||||
@@ -353,7 +348,7 @@ class SubHeartflow:
|
|||||||
# words = jieba.cut(message)
|
# words = jieba.cut(message)
|
||||||
# topics = [word for word in words if len(word) > 1][:5]
|
# topics = [word for word in words if len(word) > 1][:5]
|
||||||
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
|
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
|
||||||
|
|
||||||
# 如果无法提取到主题,直接使用整个消息
|
# 如果无法提取到主题,直接使用整个消息
|
||||||
if not topics:
|
if not topics:
|
||||||
logger.debug("未能提取到任何主题,使用整个消息进行查询")
|
logger.debug("未能提取到任何主题,使用整个消息进行查询")
|
||||||
@@ -361,26 +356,26 @@ class SubHeartflow:
|
|||||||
if not embedding:
|
if not embedding:
|
||||||
logger.error("获取消息嵌入向量失败")
|
logger.error("获取消息嵌入向量失败")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
|
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
|
||||||
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒")
|
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒")
|
||||||
return related_info, {}
|
return related_info, {}
|
||||||
|
|
||||||
# 2. 对每个主题进行知识库查询
|
# 2. 对每个主题进行知识库查询
|
||||||
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
|
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
|
||||||
|
|
||||||
# 优化:批量获取嵌入向量,减少API调用
|
# 优化:批量获取嵌入向量,减少API调用
|
||||||
embeddings = {}
|
embeddings = {}
|
||||||
topics_batch = [topic for topic in topics if len(topic) > 0]
|
topics_batch = [topic for topic in topics if len(topic) > 0]
|
||||||
if message: # 确保消息非空
|
if message: # 确保消息非空
|
||||||
topics_batch.append(message)
|
topics_batch.append(message)
|
||||||
|
|
||||||
# 批量获取嵌入向量
|
# 批量获取嵌入向量
|
||||||
embed_start_time = time.time()
|
embed_start_time = time.time()
|
||||||
for text in topics_batch:
|
for text in topics_batch:
|
||||||
if not text or len(text.strip()) == 0:
|
if not text or len(text.strip()) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
embedding = await get_embedding(text, request_type="info_retrieval")
|
embedding = await get_embedding(text, request_type="info_retrieval")
|
||||||
if embedding:
|
if embedding:
|
||||||
@@ -389,17 +384,17 @@ class SubHeartflow:
|
|||||||
logger.warning(f"获取'{text}'的嵌入向量失败")
|
logger.warning(f"获取'{text}'的嵌入向量失败")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
|
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
|
||||||
|
|
||||||
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒")
|
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒")
|
||||||
|
|
||||||
if not embeddings:
|
if not embeddings:
|
||||||
logger.error("所有嵌入向量获取失败")
|
logger.error("所有嵌入向量获取失败")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# 3. 对每个主题进行知识库查询
|
# 3. 对每个主题进行知识库查询
|
||||||
all_results = []
|
all_results = []
|
||||||
query_start_time = time.time()
|
query_start_time = time.time()
|
||||||
|
|
||||||
# 首先添加原始消息的查询结果
|
# 首先添加原始消息的查询结果
|
||||||
if message in embeddings:
|
if message in embeddings:
|
||||||
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
|
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
|
||||||
@@ -408,12 +403,12 @@ class SubHeartflow:
|
|||||||
result["topic"] = "原始消息"
|
result["topic"] = "原始消息"
|
||||||
all_results.extend(original_results)
|
all_results.extend(original_results)
|
||||||
logger.info(f"原始消息查询到{len(original_results)}条结果")
|
logger.info(f"原始消息查询到{len(original_results)}条结果")
|
||||||
|
|
||||||
# 然后添加每个主题的查询结果
|
# 然后添加每个主题的查询结果
|
||||||
for topic in topics:
|
for topic in topics:
|
||||||
if not topic or topic not in embeddings:
|
if not topic or topic not in embeddings:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
|
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
|
||||||
if topic_results:
|
if topic_results:
|
||||||
@@ -424,9 +419,9 @@ class SubHeartflow:
|
|||||||
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
|
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
|
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
|
||||||
|
|
||||||
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
|
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
|
||||||
|
|
||||||
# 4. 去重和过滤
|
# 4. 去重和过滤
|
||||||
process_start_time = time.time()
|
process_start_time = time.time()
|
||||||
unique_contents = set()
|
unique_contents = set()
|
||||||
@@ -436,14 +431,16 @@ class SubHeartflow:
|
|||||||
if content not in unique_contents:
|
if content not in unique_contents:
|
||||||
unique_contents.add(content)
|
unique_contents.add(content)
|
||||||
filtered_results.append(result)
|
filtered_results.append(result)
|
||||||
|
|
||||||
# 5. 按相似度排序
|
# 5. 按相似度排序
|
||||||
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
|
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
|
||||||
|
|
||||||
# 6. 限制总数量(最多10条)
|
# 6. 限制总数量(最多10条)
|
||||||
filtered_results = filtered_results[:10]
|
filtered_results = filtered_results[:10]
|
||||||
logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果")
|
logger.info(
|
||||||
|
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
|
||||||
|
)
|
||||||
|
|
||||||
# 7. 格式化输出
|
# 7. 格式化输出
|
||||||
if filtered_results:
|
if filtered_results:
|
||||||
format_start_time = time.time()
|
format_start_time = time.time()
|
||||||
@@ -453,7 +450,7 @@ class SubHeartflow:
|
|||||||
if topic not in grouped_results:
|
if topic not in grouped_results:
|
||||||
grouped_results[topic] = []
|
grouped_results[topic] = []
|
||||||
grouped_results[topic].append(result)
|
grouped_results[topic].append(result)
|
||||||
|
|
||||||
# 按主题组织输出
|
# 按主题组织输出
|
||||||
for topic, results in grouped_results.items():
|
for topic, results in grouped_results.items():
|
||||||
related_info += f"【主题: {topic}】\n"
|
related_info += f"【主题: {topic}】\n"
|
||||||
@@ -464,13 +461,15 @@ class SubHeartflow:
|
|||||||
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
|
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
|
||||||
related_info += f"{content}\n"
|
related_info += f"{content}\n"
|
||||||
related_info += "\n"
|
related_info += "\n"
|
||||||
|
|
||||||
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
|
|
||||||
|
|
||||||
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
|
|
||||||
return related_info,grouped_results
|
|
||||||
|
|
||||||
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
|
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
|
||||||
|
|
||||||
|
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
|
||||||
|
return related_info, grouped_results
|
||||||
|
|
||||||
|
def get_info_from_db(
|
||||||
|
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||||
|
) -> Union[str, list]:
|
||||||
if not query_embedding:
|
if not query_embedding:
|
||||||
return "" if not return_raw else []
|
return "" if not return_raw else []
|
||||||
# 使用余弦相似度计算
|
# 使用余弦相似度计算
|
||||||
|
|||||||
@@ -2,27 +2,36 @@ from dataclasses import dataclass
|
|||||||
from typing import List
|
from typing import List
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Identity:
|
class Identity:
|
||||||
"""身份特征类"""
|
"""身份特征类"""
|
||||||
|
|
||||||
identity_detail: List[str] # 身份细节描述
|
identity_detail: List[str] # 身份细节描述
|
||||||
height: int # 身高(厘米)
|
height: int # 身高(厘米)
|
||||||
weight: int # 体重(千克)
|
weight: int # 体重(千克)
|
||||||
age: int # 年龄
|
age: int # 年龄
|
||||||
gender: str # 性别
|
gender: str # 性别
|
||||||
appearance: str # 外貌特征
|
appearance: str # 外貌特征
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self, identity_detail: List[str] = None, height: int = 0, weight: int = 0,
|
def __init__(
|
||||||
age: int = 0, gender: str = "", appearance: str = ""):
|
self,
|
||||||
|
identity_detail: List[str] = None,
|
||||||
|
height: int = 0,
|
||||||
|
weight: int = 0,
|
||||||
|
age: int = 0,
|
||||||
|
gender: str = "",
|
||||||
|
appearance: str = "",
|
||||||
|
):
|
||||||
"""初始化身份特征
|
"""初始化身份特征
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
identity_detail: 身份细节描述列表
|
identity_detail: 身份细节描述列表
|
||||||
height: 身高(厘米)
|
height: 身高(厘米)
|
||||||
@@ -39,23 +48,24 @@ class Identity:
|
|||||||
self.age = age
|
self.age = age
|
||||||
self.gender = gender
|
self.gender = gender
|
||||||
self.appearance = appearance
|
self.appearance = appearance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls) -> 'Identity':
|
def get_instance(cls) -> "Identity":
|
||||||
"""获取Identity单例实例
|
"""获取Identity单例实例
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Identity: 单例实例
|
Identity: 单例实例
|
||||||
"""
|
"""
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = cls()
|
cls._instance = cls()
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize(cls, identity_detail: List[str], height: int, weight: int,
|
def initialize(
|
||||||
age: int, gender: str, appearance: str) -> 'Identity':
|
cls, identity_detail: List[str], height: int, weight: int, age: int, gender: str, appearance: str
|
||||||
|
) -> "Identity":
|
||||||
"""初始化身份特征
|
"""初始化身份特征
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
identity_detail: 身份细节描述列表
|
identity_detail: 身份细节描述列表
|
||||||
height: 身高(厘米)
|
height: 身高(厘米)
|
||||||
@@ -63,7 +73,7 @@ class Identity:
|
|||||||
age: 年龄
|
age: 年龄
|
||||||
gender: 性别
|
gender: 性别
|
||||||
appearance: 外貌特征
|
appearance: 外貌特征
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Identity: 初始化后的身份特征实例
|
Identity: 初始化后的身份特征实例
|
||||||
"""
|
"""
|
||||||
@@ -75,8 +85,8 @@ class Identity:
|
|||||||
instance.gender = gender
|
instance.gender = gender
|
||||||
instance.appearance = appearance
|
instance.appearance = appearance
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
def get_prompt(self,x_person,level):
|
def get_prompt(self, x_person, level):
|
||||||
"""
|
"""
|
||||||
获取身份特征的prompt
|
获取身份特征的prompt
|
||||||
"""
|
"""
|
||||||
@@ -86,7 +96,7 @@ class Identity:
|
|||||||
prompt_identity = "我"
|
prompt_identity = "我"
|
||||||
else:
|
else:
|
||||||
prompt_identity = "他"
|
prompt_identity = "他"
|
||||||
|
|
||||||
if level == 1:
|
if level == 1:
|
||||||
identity_detail = self.identity_detail
|
identity_detail = self.identity_detail
|
||||||
random.shuffle(identity_detail)
|
random.shuffle(identity_detail)
|
||||||
@@ -96,7 +106,7 @@ class Identity:
|
|||||||
prompt_identity += f",{detail}"
|
prompt_identity += f",{detail}"
|
||||||
prompt_identity += "。"
|
prompt_identity += "。"
|
||||||
return prompt_identity
|
return prompt_identity
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""将身份特征转换为字典格式"""
|
"""将身份特征转换为字典格式"""
|
||||||
return {
|
return {
|
||||||
@@ -105,13 +115,13 @@ class Identity:
|
|||||||
"weight": self.weight,
|
"weight": self.weight,
|
||||||
"age": self.age,
|
"age": self.age,
|
||||||
"gender": self.gender,
|
"gender": self.gender,
|
||||||
"appearance": self.appearance
|
"appearance": self.appearance,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict) -> 'Identity':
|
def from_dict(cls, data: dict) -> "Identity":
|
||||||
"""从字典创建身份特征实例"""
|
"""从字典创建身份特征实例"""
|
||||||
instance = cls.get_instance()
|
instance = cls.get_instance()
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
setattr(instance, key, value)
|
setattr(instance, key, value)
|
||||||
return instance
|
return instance
|
||||||
|
|||||||
@@ -2,35 +2,46 @@ from typing import Optional
|
|||||||
from .personality import Personality
|
from .personality import Personality
|
||||||
from .identity import Identity
|
from .identity import Identity
|
||||||
|
|
||||||
|
|
||||||
class Individuality:
|
class Individuality:
|
||||||
"""个体特征管理类"""
|
"""个体特征管理类"""
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.personality: Optional[Personality] = None
|
self.personality: Optional[Personality] = None
|
||||||
self.identity: Optional[Identity] = None
|
self.identity: Optional[Identity] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls) -> 'Individuality':
|
def get_instance(cls) -> "Individuality":
|
||||||
"""获取Individuality单例实例
|
"""获取Individuality单例实例
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Individuality: 单例实例
|
Individuality: 单例实例
|
||||||
"""
|
"""
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = cls()
|
cls._instance = cls()
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def initialize(self, bot_nickname: str, personality_core: str, personality_sides: list,
|
def initialize(
|
||||||
identity_detail: list, height: int, weight: int, age: int,
|
self,
|
||||||
gender: str, appearance: str) -> None:
|
bot_nickname: str,
|
||||||
|
personality_core: str,
|
||||||
|
personality_sides: list,
|
||||||
|
identity_detail: list,
|
||||||
|
height: int,
|
||||||
|
weight: int,
|
||||||
|
age: int,
|
||||||
|
gender: str,
|
||||||
|
appearance: str,
|
||||||
|
) -> None:
|
||||||
"""初始化个体特征
|
"""初始化个体特征
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
bot_nickname: 机器人昵称
|
bot_nickname: 机器人昵称
|
||||||
personality_core: 人格核心特点
|
personality_core: 人格核心特点
|
||||||
@@ -44,50 +55,43 @@ class Individuality:
|
|||||||
"""
|
"""
|
||||||
# 初始化人格
|
# 初始化人格
|
||||||
self.personality = Personality.initialize(
|
self.personality = Personality.initialize(
|
||||||
bot_nickname=bot_nickname,
|
bot_nickname=bot_nickname, personality_core=personality_core, personality_sides=personality_sides
|
||||||
personality_core=personality_core,
|
|
||||||
personality_sides=personality_sides
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 初始化身份
|
# 初始化身份
|
||||||
self.identity = Identity.initialize(
|
self.identity = Identity.initialize(
|
||||||
identity_detail=identity_detail,
|
identity_detail=identity_detail, height=height, weight=weight, age=age, gender=gender, appearance=appearance
|
||||||
height=height,
|
|
||||||
weight=weight,
|
|
||||||
age=age,
|
|
||||||
gender=gender,
|
|
||||||
appearance=appearance
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""将个体特征转换为字典格式"""
|
"""将个体特征转换为字典格式"""
|
||||||
return {
|
return {
|
||||||
"personality": self.personality.to_dict() if self.personality else None,
|
"personality": self.personality.to_dict() if self.personality else None,
|
||||||
"identity": self.identity.to_dict() if self.identity else None
|
"identity": self.identity.to_dict() if self.identity else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict) -> 'Individuality':
|
def from_dict(cls, data: dict) -> "Individuality":
|
||||||
"""从字典创建个体特征实例"""
|
"""从字典创建个体特征实例"""
|
||||||
instance = cls.get_instance()
|
instance = cls.get_instance()
|
||||||
if data.get("personality"):
|
if data.get("personality"):
|
||||||
instance.personality = Personality.from_dict(data["personality"])
|
instance.personality = Personality.from_dict(data["personality"])
|
||||||
if data.get("identity"):
|
if data.get("identity"):
|
||||||
instance.identity = Identity.from_dict(data["identity"])
|
instance.identity = Identity.from_dict(data["identity"])
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
def get_prompt(self,type,x_person,level):
|
def get_prompt(self, type, x_person, level):
|
||||||
"""
|
"""
|
||||||
获取个体特征的prompt
|
获取个体特征的prompt
|
||||||
"""
|
"""
|
||||||
if type == "personality":
|
if type == "personality":
|
||||||
return self.personality.get_prompt(x_person,level)
|
return self.personality.get_prompt(x_person, level)
|
||||||
elif type == "identity":
|
elif type == "identity":
|
||||||
return self.identity.get_prompt(x_person,level)
|
return self.identity.get_prompt(x_person, level)
|
||||||
else:
|
else:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def get_traits(self,factor):
|
def get_traits(self, factor):
|
||||||
"""
|
"""
|
||||||
获取个体特征的特质
|
获取个体特征的特质
|
||||||
"""
|
"""
|
||||||
@@ -101,5 +105,3 @@ class Individuality:
|
|||||||
return self.personality.agreeableness
|
return self.personality.agreeableness
|
||||||
elif factor == "neuroticism":
|
elif factor == "neuroticism":
|
||||||
return self.personality.neuroticism
|
return self.personality.neuroticism
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ with open(config_path, "r", encoding="utf-8") as f:
|
|||||||
config = toml.load(f)
|
config = toml.load(f)
|
||||||
|
|
||||||
# 现在可以导入src模块
|
# 现在可以导入src模块
|
||||||
from src.individuality.scene import get_scene_by_factor, PERSONALITY_SCENES #noqa E402
|
from src.individuality.scene import get_scene_by_factor, PERSONALITY_SCENES # noqa E402
|
||||||
from src.individuality.questionnaire import FACTOR_DESCRIPTIONS #noqa E402
|
from src.individuality.questionnaire import FACTOR_DESCRIPTIONS # noqa E402
|
||||||
from src.individuality.offline_llm import LLM_request_off #noqa E402
|
from src.individuality.offline_llm import LLM_request_off # noqa E402
|
||||||
|
|
||||||
# 加载环境变量
|
# 加载环境变量
|
||||||
env_path = os.path.join(root_path, ".env")
|
env_path = os.path.join(root_path, ".env")
|
||||||
@@ -32,13 +32,12 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
def adapt_scene(scene: str) -> str:
|
def adapt_scene(scene: str) -> str:
|
||||||
|
personality_core = config["personality"]["personality_core"]
|
||||||
personality_core = config['personality']['personality_core']
|
personality_sides = config["personality"]["personality_sides"]
|
||||||
personality_sides = config['personality']['personality_sides']
|
|
||||||
personality_side = random.choice(personality_sides)
|
personality_side = random.choice(personality_sides)
|
||||||
identity_details = config['identity']['identity_detail']
|
identity_details = config["identity"]["identity_detail"]
|
||||||
identity_detail = random.choice(identity_details)
|
identity_detail = random.choice(identity_details)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
根据config中的属性,改编场景使其更适合当前角色
|
根据config中的属性,改编场景使其更适合当前角色
|
||||||
|
|
||||||
@@ -51,10 +50,10 @@ def adapt_scene(scene: str) -> str:
|
|||||||
try:
|
try:
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
这是一个参与人格测评的角色形象:
|
这是一个参与人格测评的角色形象:
|
||||||
- 昵称: {config['bot']['nickname']}
|
- 昵称: {config["bot"]["nickname"]}
|
||||||
- 性别: {config['identity']['gender']}
|
- 性别: {config["identity"]["gender"]}
|
||||||
- 年龄: {config['identity']['age']}岁
|
- 年龄: {config["identity"]["age"]}岁
|
||||||
- 外貌: {config['identity']['appearance']}
|
- 外貌: {config["identity"]["appearance"]}
|
||||||
- 性格核心: {personality_core}
|
- 性格核心: {personality_core}
|
||||||
- 性格侧面: {personality_side}
|
- 性格侧面: {personality_side}
|
||||||
- 身份细节: {identity_detail}
|
- 身份细节: {identity_detail}
|
||||||
@@ -62,18 +61,18 @@ def adapt_scene(scene: str) -> str:
|
|||||||
请根据上述形象,改编以下场景,在测评中,用户将根据该场景给出上述角色形象的反应:
|
请根据上述形象,改编以下场景,在测评中,用户将根据该场景给出上述角色形象的反应:
|
||||||
{scene}
|
{scene}
|
||||||
保持场景的本质不变,但最好贴近生活且具体,并且让它更适合这个角色。
|
保持场景的本质不变,但最好贴近生活且具体,并且让它更适合这个角色。
|
||||||
改编后的场景应该自然、连贯,并考虑角色的年龄、身份和性格特点。只返回改编后的场景描述,不要包含其他说明。注意{config['bot']['nickname']}是面对这个场景的人,而不是场景的其他人。场景中不会有其描述,
|
改编后的场景应该自然、连贯,并考虑角色的年龄、身份和性格特点。只返回改编后的场景描述,不要包含其他说明。注意{config["bot"]["nickname"]}是面对这个场景的人,而不是场景的其他人。场景中不会有其描述,
|
||||||
现在,请你给出改编后的场景描述
|
现在,请你给出改编后的场景描述
|
||||||
"""
|
"""
|
||||||
|
|
||||||
llm = LLM_request_off(model_name=config['model']['llm_normal']['name'])
|
llm = LLM_request_off(model_name=config["model"]["llm_normal"]["name"])
|
||||||
adapted_scene, _ = llm.generate_response(prompt)
|
adapted_scene, _ = llm.generate_response(prompt)
|
||||||
|
|
||||||
# 检查返回的场景是否为空或错误信息
|
# 检查返回的场景是否为空或错误信息
|
||||||
if not adapted_scene or "错误" in adapted_scene or "失败" in adapted_scene:
|
if not adapted_scene or "错误" in adapted_scene or "失败" in adapted_scene:
|
||||||
print("场景改编失败,将使用原始场景")
|
print("场景改编失败,将使用原始场景")
|
||||||
return scene
|
return scene
|
||||||
|
|
||||||
return adapted_scene
|
return adapted_scene
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"场景改编过程出错:{str(e)},将使用原始场景")
|
print(f"场景改编过程出错:{str(e)},将使用原始场景")
|
||||||
@@ -169,7 +168,7 @@ class PersonalityEvaluator_direct:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"评估过程出错:{str(e)}")
|
print(f"评估过程出错:{str(e)}")
|
||||||
return {dim: 3.5 for dim in dimensions}
|
return {dim: 3.5 for dim in dimensions}
|
||||||
|
|
||||||
def run_evaluation(self):
|
def run_evaluation(self):
|
||||||
"""
|
"""
|
||||||
运行整个评估过程
|
运行整个评估过程
|
||||||
@@ -185,18 +184,23 @@ class PersonalityEvaluator_direct:
|
|||||||
print(f"- 身份细节:{config['identity']['identity_detail']}")
|
print(f"- 身份细节:{config['identity']['identity_detail']}")
|
||||||
print("\n准备好了吗?按回车键开始...")
|
print("\n准备好了吗?按回车键开始...")
|
||||||
input()
|
input()
|
||||||
|
|
||||||
total_scenarios = len(self.scenarios)
|
total_scenarios = len(self.scenarios)
|
||||||
progress_bar = tqdm(total=total_scenarios, desc="场景进度", ncols=100, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')
|
progress_bar = tqdm(
|
||||||
|
total=total_scenarios,
|
||||||
|
desc="场景进度",
|
||||||
|
ncols=100,
|
||||||
|
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
|
||||||
|
)
|
||||||
|
|
||||||
for _i, scenario_data in enumerate(self.scenarios, 1):
|
for _i, scenario_data in enumerate(self.scenarios, 1):
|
||||||
# print(f"\n{'-' * 20} 场景 {i}/{total_scenarios} - {scenario_data['场景编号']} {'-' * 20}")
|
# print(f"\n{'-' * 20} 场景 {i}/{total_scenarios} - {scenario_data['场景编号']} {'-' * 20}")
|
||||||
|
|
||||||
# 改编场景,使其更适合当前角色
|
# 改编场景,使其更适合当前角色
|
||||||
print(f"{config['bot']['nickname']}祈祷中...")
|
print(f"{config['bot']['nickname']}祈祷中...")
|
||||||
adapted_scene = adapt_scene(scenario_data["场景"])
|
adapted_scene = adapt_scene(scenario_data["场景"])
|
||||||
scenario_data["改编场景"] = adapted_scene
|
scenario_data["改编场景"] = adapted_scene
|
||||||
|
|
||||||
print(adapted_scene)
|
print(adapted_scene)
|
||||||
print(f"\n请描述{config['bot']['nickname']}在这种情况下会如何反应:")
|
print(f"\n请描述{config['bot']['nickname']}在这种情况下会如何反应:")
|
||||||
response = input().strip()
|
response = input().strip()
|
||||||
@@ -220,13 +224,13 @@ class PersonalityEvaluator_direct:
|
|||||||
|
|
||||||
# 更新进度条
|
# 更新进度条
|
||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
|
|
||||||
# if i < total_scenarios:
|
# if i < total_scenarios:
|
||||||
# print("\n按回车键继续下一个场景...")
|
# print("\n按回车键继续下一个场景...")
|
||||||
# input()
|
# input()
|
||||||
|
|
||||||
progress_bar.close()
|
progress_bar.close()
|
||||||
|
|
||||||
# 计算平均分
|
# 计算平均分
|
||||||
for dimension in self.final_scores:
|
for dimension in self.final_scores:
|
||||||
if self.dimension_counts[dimension] > 0:
|
if self.dimension_counts[dimension] > 0:
|
||||||
@@ -241,26 +245,26 @@ class PersonalityEvaluator_direct:
|
|||||||
|
|
||||||
# 返回评估结果
|
# 返回评估结果
|
||||||
return self.get_result()
|
return self.get_result()
|
||||||
|
|
||||||
def get_result(self):
|
def get_result(self):
|
||||||
"""
|
"""
|
||||||
获取评估结果
|
获取评估结果
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"final_scores": self.final_scores,
|
"final_scores": self.final_scores,
|
||||||
"dimension_counts": self.dimension_counts,
|
"dimension_counts": self.dimension_counts,
|
||||||
"scenarios": self.scenarios,
|
"scenarios": self.scenarios,
|
||||||
"bot_info": {
|
"bot_info": {
|
||||||
"nickname": config['bot']['nickname'],
|
"nickname": config["bot"]["nickname"],
|
||||||
"gender": config['identity']['gender'],
|
"gender": config["identity"]["gender"],
|
||||||
"age": config['identity']['age'],
|
"age": config["identity"]["age"],
|
||||||
"height": config['identity']['height'],
|
"height": config["identity"]["height"],
|
||||||
"weight": config['identity']['weight'],
|
"weight": config["identity"]["weight"],
|
||||||
"appearance": config['identity']['appearance'],
|
"appearance": config["identity"]["appearance"],
|
||||||
"personality_core": config['personality']['personality_core'],
|
"personality_core": config["personality"]["personality_core"],
|
||||||
"personality_sides": config['personality']['personality_sides'],
|
"personality_sides": config["personality"]["personality_sides"],
|
||||||
"identity_detail": config['identity']['identity_detail']
|
"identity_detail": config["identity"]["identity_detail"],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -275,28 +279,28 @@ def main():
|
|||||||
"extraversion": round(result["final_scores"]["外向性"] / 6, 1),
|
"extraversion": round(result["final_scores"]["外向性"] / 6, 1),
|
||||||
"agreeableness": round(result["final_scores"]["宜人性"] / 6, 1),
|
"agreeableness": round(result["final_scores"]["宜人性"] / 6, 1),
|
||||||
"neuroticism": round(result["final_scores"]["神经质"] / 6, 1),
|
"neuroticism": round(result["final_scores"]["神经质"] / 6, 1),
|
||||||
"bot_nickname": config['bot']['nickname']
|
"bot_nickname": config["bot"]["nickname"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# 确保目录存在
|
# 确保目录存在
|
||||||
save_dir = os.path.join(root_path, "data", "personality")
|
save_dir = os.path.join(root_path, "data", "personality")
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
# 创建文件名,替换可能的非法字符
|
# 创建文件名,替换可能的非法字符
|
||||||
bot_name = config['bot']['nickname']
|
bot_name = config["bot"]["nickname"]
|
||||||
# 替换Windows文件名中不允许的字符
|
# 替换Windows文件名中不允许的字符
|
||||||
for char in ['\\', '/', ':', '*', '?', '"', '<', '>', '|']:
|
for char in ["\\", "/", ":", "*", "?", '"', "<", ">", "|"]:
|
||||||
bot_name = bot_name.replace(char, '_')
|
bot_name = bot_name.replace(char, "_")
|
||||||
|
|
||||||
file_name = f"{bot_name}_personality.per"
|
file_name = f"{bot_name}_personality.per"
|
||||||
save_path = os.path.join(save_dir, file_name)
|
save_path = os.path.join(save_dir, file_name)
|
||||||
|
|
||||||
# 保存简化的结果
|
# 保存简化的结果
|
||||||
with open(save_path, "w", encoding="utf-8") as f:
|
with open(save_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(simplified_result, f, ensure_ascii=False, indent=4)
|
json.dump(simplified_result, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
print(f"\n结果已保存到 {save_path}")
|
print(f"\n结果已保存到 {save_path}")
|
||||||
|
|
||||||
# 同时保存完整结果到results目录
|
# 同时保存完整结果到results目录
|
||||||
os.makedirs("results", exist_ok=True)
|
os.makedirs("results", exist_ok=True)
|
||||||
with open("results/personality_result.json", "w", encoding="utf-8") as f:
|
with open("results/personality_result.json", "w", encoding="utf-8") as f:
|
||||||
|
|||||||
@@ -4,9 +4,11 @@ import json
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Personality:
|
class Personality:
|
||||||
"""人格特质类"""
|
"""人格特质类"""
|
||||||
|
|
||||||
openness: float # 开放性
|
openness: float # 开放性
|
||||||
conscientiousness: float # 尽责性
|
conscientiousness: float # 尽责性
|
||||||
extraversion: float # 外向性
|
extraversion: float # 外向性
|
||||||
@@ -15,45 +17,45 @@ class Personality:
|
|||||||
bot_nickname: str # 机器人昵称
|
bot_nickname: str # 机器人昵称
|
||||||
personality_core: str # 人格核心特点
|
personality_core: str # 人格核心特点
|
||||||
personality_sides: List[str] # 人格侧面描述
|
personality_sides: List[str] # 人格侧面描述
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self, personality_core: str = "", personality_sides: List[str] = None):
|
def __init__(self, personality_core: str = "", personality_sides: List[str] = None):
|
||||||
if personality_sides is None:
|
if personality_sides is None:
|
||||||
personality_sides = []
|
personality_sides = []
|
||||||
self.personality_core = personality_core
|
self.personality_core = personality_core
|
||||||
self.personality_sides = personality_sides
|
self.personality_sides = personality_sides
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls) -> 'Personality':
|
def get_instance(cls) -> "Personality":
|
||||||
"""获取Personality单例实例
|
"""获取Personality单例实例
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Personality: 单例实例
|
Personality: 单例实例
|
||||||
"""
|
"""
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = cls()
|
cls._instance = cls()
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def _init_big_five_personality(self):
|
def _init_big_five_personality(self):
|
||||||
"""初始化大五人格特质"""
|
"""初始化大五人格特质"""
|
||||||
# 构建文件路径
|
# 构建文件路径
|
||||||
personality_file = Path("data/personality") / f"{self.bot_nickname}_personality.per"
|
personality_file = Path("data/personality") / f"{self.bot_nickname}_personality.per"
|
||||||
|
|
||||||
# 如果文件存在,读取文件
|
# 如果文件存在,读取文件
|
||||||
if personality_file.exists():
|
if personality_file.exists():
|
||||||
with open(personality_file, 'r', encoding='utf-8') as f:
|
with open(personality_file, "r", encoding="utf-8") as f:
|
||||||
personality_data = json.load(f)
|
personality_data = json.load(f)
|
||||||
self.openness = personality_data.get('openness', 0.5)
|
self.openness = personality_data.get("openness", 0.5)
|
||||||
self.conscientiousness = personality_data.get('conscientiousness', 0.5)
|
self.conscientiousness = personality_data.get("conscientiousness", 0.5)
|
||||||
self.extraversion = personality_data.get('extraversion', 0.5)
|
self.extraversion = personality_data.get("extraversion", 0.5)
|
||||||
self.agreeableness = personality_data.get('agreeableness', 0.5)
|
self.agreeableness = personality_data.get("agreeableness", 0.5)
|
||||||
self.neuroticism = personality_data.get('neuroticism', 0.5)
|
self.neuroticism = personality_data.get("neuroticism", 0.5)
|
||||||
else:
|
else:
|
||||||
# 如果文件不存在,根据personality_core和personality_core来设置大五人格特质
|
# 如果文件不存在,根据personality_core和personality_core来设置大五人格特质
|
||||||
if "活泼" in self.personality_core or "开朗" in self.personality_sides:
|
if "活泼" in self.personality_core or "开朗" in self.personality_sides:
|
||||||
@@ -62,31 +64,31 @@ class Personality:
|
|||||||
else:
|
else:
|
||||||
self.extraversion = 0.3
|
self.extraversion = 0.3
|
||||||
self.neuroticism = 0.5
|
self.neuroticism = 0.5
|
||||||
|
|
||||||
if "认真" in self.personality_core or "负责" in self.personality_sides:
|
if "认真" in self.personality_core or "负责" in self.personality_sides:
|
||||||
self.conscientiousness = 0.9
|
self.conscientiousness = 0.9
|
||||||
else:
|
else:
|
||||||
self.conscientiousness = 0.5
|
self.conscientiousness = 0.5
|
||||||
|
|
||||||
if "友善" in self.personality_core or "温柔" in self.personality_sides:
|
if "友善" in self.personality_core or "温柔" in self.personality_sides:
|
||||||
self.agreeableness = 0.9
|
self.agreeableness = 0.9
|
||||||
else:
|
else:
|
||||||
self.agreeableness = 0.5
|
self.agreeableness = 0.5
|
||||||
|
|
||||||
if "创新" in self.personality_core or "开放" in self.personality_sides:
|
if "创新" in self.personality_core or "开放" in self.personality_sides:
|
||||||
self.openness = 0.8
|
self.openness = 0.8
|
||||||
else:
|
else:
|
||||||
self.openness = 0.5
|
self.openness = 0.5
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize(cls, bot_nickname: str, personality_core: str, personality_sides: List[str]) -> 'Personality':
|
def initialize(cls, bot_nickname: str, personality_core: str, personality_sides: List[str]) -> "Personality":
|
||||||
"""初始化人格特质
|
"""初始化人格特质
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
bot_nickname: 机器人昵称
|
bot_nickname: 机器人昵称
|
||||||
personality_core: 人格核心特点
|
personality_core: 人格核心特点
|
||||||
personality_sides: 人格侧面描述
|
personality_sides: 人格侧面描述
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Personality: 初始化后的人格特质实例
|
Personality: 初始化后的人格特质实例
|
||||||
"""
|
"""
|
||||||
@@ -96,7 +98,7 @@ class Personality:
|
|||||||
instance.personality_sides = personality_sides
|
instance.personality_sides = personality_sides
|
||||||
instance._init_big_five_personality()
|
instance._init_big_five_personality()
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
def to_dict(self) -> Dict:
|
||||||
"""将人格特质转换为字典格式"""
|
"""将人格特质转换为字典格式"""
|
||||||
return {
|
return {
|
||||||
@@ -107,18 +109,18 @@ class Personality:
|
|||||||
"neuroticism": self.neuroticism,
|
"neuroticism": self.neuroticism,
|
||||||
"bot_nickname": self.bot_nickname,
|
"bot_nickname": self.bot_nickname,
|
||||||
"personality_core": self.personality_core,
|
"personality_core": self.personality_core,
|
||||||
"personality_sides": self.personality_sides
|
"personality_sides": self.personality_sides,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict) -> 'Personality':
|
def from_dict(cls, data: Dict) -> "Personality":
|
||||||
"""从字典创建人格特质实例"""
|
"""从字典创建人格特质实例"""
|
||||||
instance = cls.get_instance()
|
instance = cls.get_instance()
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
setattr(instance, key, value)
|
setattr(instance, key, value)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
def get_prompt(self,x_person,level):
|
def get_prompt(self, x_person, level):
|
||||||
# 开始构建prompt
|
# 开始构建prompt
|
||||||
if x_person == 2:
|
if x_person == 2:
|
||||||
prompt_personality = "你"
|
prompt_personality = "你"
|
||||||
@@ -126,10 +128,10 @@ class Personality:
|
|||||||
prompt_personality = "我"
|
prompt_personality = "我"
|
||||||
else:
|
else:
|
||||||
prompt_personality = "他"
|
prompt_personality = "他"
|
||||||
#person
|
# person
|
||||||
|
|
||||||
prompt_personality += self.personality_core
|
prompt_personality += self.personality_core
|
||||||
|
|
||||||
if level == 2:
|
if level == 2:
|
||||||
personality_sides = self.personality_sides
|
personality_sides = self.personality_sides
|
||||||
random.shuffle(personality_sides)
|
random.shuffle(personality_sides)
|
||||||
@@ -140,5 +142,5 @@ class Personality:
|
|||||||
prompt_personality += f",{side}"
|
prompt_personality += f",{side}"
|
||||||
|
|
||||||
prompt_personality += "。"
|
prompt_personality += "。"
|
||||||
|
|
||||||
return prompt_personality
|
return prompt_personality
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import json
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def load_scenes() -> Dict:
|
def load_scenes() -> Dict:
|
||||||
"""
|
"""
|
||||||
从JSON文件加载场景数据
|
从JSON文件加载场景数据
|
||||||
@@ -10,13 +11,15 @@ def load_scenes() -> Dict:
|
|||||||
Dict: 包含所有场景的字典
|
Dict: 包含所有场景的字典
|
||||||
"""
|
"""
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
json_path = os.path.join(current_dir, 'template_scene.json')
|
json_path = os.path.join(current_dir, "template_scene.json")
|
||||||
|
|
||||||
with open(json_path, 'r', encoding='utf-8') as f:
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
PERSONALITY_SCENES = load_scenes()
|
PERSONALITY_SCENES = load_scenes()
|
||||||
|
|
||||||
|
|
||||||
def get_scene_by_factor(factor: str) -> Dict:
|
def get_scene_by_factor(factor: str) -> Dict:
|
||||||
"""
|
"""
|
||||||
根据人格因子获取对应的情景测试
|
根据人格因子获取对应的情景测试
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class MainSystem:
|
|||||||
weight=global_config.weight,
|
weight=global_config.weight,
|
||||||
age=global_config.age,
|
age=global_config.age,
|
||||||
gender=global_config.gender,
|
gender=global_config.gender,
|
||||||
appearance=global_config.appearance
|
appearance=global_config.appearance,
|
||||||
)
|
)
|
||||||
logger.success("个体特征初始化成功")
|
logger.success("个体特征初始化成功")
|
||||||
|
|
||||||
@@ -135,7 +135,6 @@ class MainSystem:
|
|||||||
await asyncio.sleep(global_config.build_memory_interval)
|
await asyncio.sleep(global_config.build_memory_interval)
|
||||||
logger.info("正在进行记忆构建")
|
logger.info("正在进行记忆构建")
|
||||||
await HippocampusManager.get_instance().build_memory()
|
await HippocampusManager.get_instance().build_memory()
|
||||||
|
|
||||||
|
|
||||||
async def forget_memory_task(self):
|
async def forget_memory_task(self):
|
||||||
"""记忆遗忘任务"""
|
"""记忆遗忘任务"""
|
||||||
@@ -144,7 +143,6 @@ class MainSystem:
|
|||||||
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
|
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
|
||||||
await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage)
|
await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage)
|
||||||
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
|
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
|
||||||
|
|
||||||
|
|
||||||
async def print_mood_task(self):
|
async def print_mood_task(self):
|
||||||
"""打印情绪状态"""
|
"""打印情绪状态"""
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Optional, Dict, Any, List, Tuple
|
from typing import Optional, Dict, Any, List, Tuple
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from ..message.message_base import UserInfo
|
from ..message.message_base import UserInfo
|
||||||
from ..config.config import global_config
|
from ..config.config import global_config
|
||||||
@@ -9,16 +9,17 @@ from .message_storage import MessageStorage, MongoDBMessageStorage
|
|||||||
|
|
||||||
logger = get_module_logger("chat_observer")
|
logger = get_module_logger("chat_observer")
|
||||||
|
|
||||||
|
|
||||||
class ChatObserver:
|
class ChatObserver:
|
||||||
"""聊天状态观察器"""
|
"""聊天状态观察器"""
|
||||||
|
|
||||||
# 类级别的实例管理
|
# 类级别的实例管理
|
||||||
_instances: Dict[str, 'ChatObserver'] = {}
|
_instances: Dict[str, "ChatObserver"] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver':
|
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver':
|
||||||
"""获取或创建观察器实例
|
"""获取或创建观察器实例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stream_id: 聊天流ID
|
stream_id: 聊天流ID
|
||||||
message_storage: 消息存储实现,如果为None则使用MongoDB实现
|
message_storage: 消息存储实现,如果为None则使用MongoDB实现
|
||||||
@@ -32,14 +33,14 @@ class ChatObserver:
|
|||||||
|
|
||||||
def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None):
|
def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None):
|
||||||
"""初始化观察器
|
"""初始化观察器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stream_id: 聊天流ID
|
stream_id: 聊天流ID
|
||||||
message_storage: 消息存储实现,如果为None则使用MongoDB实现
|
message_storage: 消息存储实现,如果为None则使用MongoDB实现
|
||||||
"""
|
"""
|
||||||
if stream_id in self._instances:
|
if stream_id in self._instances:
|
||||||
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
|
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
|
||||||
|
|
||||||
self.stream_id = stream_id
|
self.stream_id = stream_id
|
||||||
self.message_storage = message_storage or MongoDBMessageStorage()
|
self.message_storage = message_storage or MongoDBMessageStorage()
|
||||||
|
|
||||||
@@ -53,9 +54,9 @@ class ChatObserver:
|
|||||||
|
|
||||||
# 消息历史记录
|
# 消息历史记录
|
||||||
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
|
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
|
||||||
self.last_message_id: Optional[str] = None # 最后一条消息的ID
|
self.last_message_id: Optional[str] = None # 最后一条消息的ID
|
||||||
self.message_count: int = 0 # 消息计数
|
self.message_count: int = 0 # 消息计数
|
||||||
|
|
||||||
# 运行状态
|
# 运行状态
|
||||||
self._running: bool = False
|
self._running: bool = False
|
||||||
self._task: Optional[asyncio.Task] = None
|
self._task: Optional[asyncio.Task] = None
|
||||||
@@ -77,7 +78,7 @@ class ChatObserver:
|
|||||||
|
|
||||||
async def check(self) -> bool:
|
async def check(self) -> bool:
|
||||||
"""检查距离上一次观察之后是否有了新消息
|
"""检查距离上一次观察之后是否有了新消息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 是否有新消息
|
bool: 是否有新消息
|
||||||
"""
|
"""
|
||||||
@@ -91,7 +92,7 @@ class ChatObserver:
|
|||||||
if new_message_exists:
|
if new_message_exists:
|
||||||
logger.debug("发现新消息")
|
logger.debug("发现新消息")
|
||||||
self.last_check_time = time.time()
|
self.last_check_time = time.time()
|
||||||
|
|
||||||
return new_message_exists
|
return new_message_exists
|
||||||
|
|
||||||
async def _add_message_to_history(self, message: Dict[str, Any]):
|
async def _add_message_to_history(self, message: Dict[str, Any]):
|
||||||
@@ -104,7 +105,7 @@ class ChatObserver:
|
|||||||
self.last_message_id = message["message_id"]
|
self.last_message_id = message["message_id"]
|
||||||
self.last_message_time = message["time"] # 更新最后消息时间
|
self.last_message_time = message["time"] # 更新最后消息时间
|
||||||
self.message_count += 1
|
self.message_count += 1
|
||||||
|
|
||||||
# 更新说话时间
|
# 更新说话时间
|
||||||
user_info = UserInfo.from_dict(message.get("user_info", {}))
|
user_info = UserInfo.from_dict(message.get("user_info", {}))
|
||||||
if user_info.user_id == global_config.BOT_QQ:
|
if user_info.user_id == global_config.BOT_QQ:
|
||||||
@@ -186,41 +187,40 @@ class ChatObserver:
|
|||||||
start_time: Optional[float] = None,
|
start_time: Optional[float] = None,
|
||||||
end_time: Optional[float] = None,
|
end_time: Optional[float] = None,
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""获取消息历史
|
"""获取消息历史
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
start_time: 开始时间戳
|
start_time: 开始时间戳
|
||||||
end_time: 结束时间戳
|
end_time: 结束时间戳
|
||||||
limit: 限制返回消息数量
|
limit: 限制返回消息数量
|
||||||
user_id: 指定用户ID
|
user_id: 指定用户ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Dict[str, Any]]: 消息列表
|
List[Dict[str, Any]]: 消息列表
|
||||||
"""
|
"""
|
||||||
filtered_messages = self.message_history
|
filtered_messages = self.message_history
|
||||||
|
|
||||||
if start_time is not None:
|
if start_time is not None:
|
||||||
filtered_messages = [m for m in filtered_messages if m["time"] >= start_time]
|
filtered_messages = [m for m in filtered_messages if m["time"] >= start_time]
|
||||||
|
|
||||||
if end_time is not None:
|
if end_time is not None:
|
||||||
filtered_messages = [m for m in filtered_messages if m["time"] <= end_time]
|
filtered_messages = [m for m in filtered_messages if m["time"] <= end_time]
|
||||||
|
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
filtered_messages = [
|
filtered_messages = [
|
||||||
m for m in filtered_messages
|
m for m in filtered_messages if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
|
||||||
if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if limit is not None:
|
if limit is not None:
|
||||||
filtered_messages = filtered_messages[-limit:]
|
filtered_messages = filtered_messages[-limit:]
|
||||||
|
|
||||||
return filtered_messages
|
return filtered_messages
|
||||||
|
|
||||||
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
|
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
|
||||||
"""获取新消息
|
"""获取新消息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Dict[str, Any]]: 新消息列表
|
List[Dict[str, Any]]: 新消息列表
|
||||||
"""
|
"""
|
||||||
@@ -231,15 +231,15 @@ class ChatObserver:
|
|||||||
|
|
||||||
if new_messages:
|
if new_messages:
|
||||||
self.last_message_read = new_messages[-1]["message_id"]
|
self.last_message_read = new_messages[-1]["message_id"]
|
||||||
|
|
||||||
return new_messages
|
return new_messages
|
||||||
|
|
||||||
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
|
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
|
||||||
"""获取指定时间点之前的消息
|
"""获取指定时间点之前的消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
time_point: 时间戳
|
time_point: 时间戳
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Dict[str, Any]]: 最多5条消息
|
List[Dict[str, Any]]: 最多5条消息
|
||||||
"""
|
"""
|
||||||
@@ -250,7 +250,7 @@ class ChatObserver:
|
|||||||
|
|
||||||
if new_messages:
|
if new_messages:
|
||||||
self.last_message_read = new_messages[-1]["message_id"]
|
self.last_message_read = new_messages[-1]["message_id"]
|
||||||
|
|
||||||
return new_messages
|
return new_messages
|
||||||
|
|
||||||
'''主要观察循环'''
|
'''主要观察循环'''
|
||||||
@@ -263,7 +263,7 @@ class ChatObserver:
|
|||||||
await self._add_message_to_history(message)
|
await self._add_message_to_history(message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"缓冲消息出错: {e}")
|
logger.error(f"缓冲消息出错: {e}")
|
||||||
|
|
||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
# 等待事件或超时(1秒)
|
# 等待事件或超时(1秒)
|
||||||
@@ -271,13 +271,13 @@ class ChatObserver:
|
|||||||
await asyncio.wait_for(self._update_event.wait(), timeout=1)
|
await asyncio.wait_for(self._update_event.wait(), timeout=1)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
pass # 超时后也执行一次检查
|
pass # 超时后也执行一次检查
|
||||||
|
|
||||||
self._update_event.clear() # 重置触发事件
|
self._update_event.clear() # 重置触发事件
|
||||||
self._update_complete.clear() # 重置完成事件
|
self._update_complete.clear() # 重置完成事件
|
||||||
|
|
||||||
# 获取新消息
|
# 获取新消息
|
||||||
new_messages = await self._fetch_new_messages()
|
new_messages = await self._fetch_new_messages()
|
||||||
|
|
||||||
if new_messages:
|
if new_messages:
|
||||||
# 处理新消息
|
# 处理新消息
|
||||||
for message in new_messages:
|
for message in new_messages:
|
||||||
@@ -285,21 +285,21 @@ class ChatObserver:
|
|||||||
|
|
||||||
# 设置完成事件
|
# 设置完成事件
|
||||||
self._update_complete.set()
|
self._update_complete.set()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新循环出错: {e}")
|
logger.error(f"更新循环出错: {e}")
|
||||||
self._update_complete.set() # 即使出错也要设置完成事件
|
self._update_complete.set() # 即使出错也要设置完成事件
|
||||||
|
|
||||||
def trigger_update(self):
|
def trigger_update(self):
|
||||||
"""触发一次立即更新"""
|
"""触发一次立即更新"""
|
||||||
self._update_event.set()
|
self._update_event.set()
|
||||||
|
|
||||||
async def wait_for_update(self, timeout: float = 5.0) -> bool:
|
async def wait_for_update(self, timeout: float = 5.0) -> bool:
|
||||||
"""等待更新完成
|
"""等待更新完成
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
timeout: 超时时间(秒)
|
timeout: 超时时间(秒)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 是否成功完成更新(False表示超时)
|
bool: 是否成功完成更新(False表示超时)
|
||||||
"""
|
"""
|
||||||
@@ -309,16 +309,16 @@ class ChatObserver:
|
|||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning(f"等待更新完成超时({timeout}秒)")
|
logger.warning(f"等待更新完成超时({timeout}秒)")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""启动观察器"""
|
"""启动观察器"""
|
||||||
if self._running:
|
if self._running:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
self._task = asyncio.create_task(self._update_loop())
|
self._task = asyncio.create_task(self._update_loop())
|
||||||
logger.info(f"ChatObserver for {self.stream_id} started")
|
logger.info(f"ChatObserver for {self.stream_id} started")
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""停止观察器"""
|
"""停止观察器"""
|
||||||
self._running = False
|
self._running = False
|
||||||
@@ -327,15 +327,15 @@ class ChatObserver:
|
|||||||
if self._task:
|
if self._task:
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
logger.info(f"ChatObserver for {self.stream_id} stopped")
|
logger.info(f"ChatObserver for {self.stream_id} stopped")
|
||||||
|
|
||||||
async def process_chat_history(self, messages: list):
|
async def process_chat_history(self, messages: list):
|
||||||
"""处理聊天历史
|
"""处理聊天历史
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 消息列表
|
messages: 消息列表
|
||||||
"""
|
"""
|
||||||
self.update_check_time()
|
self.update_check_time()
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
try:
|
try:
|
||||||
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
||||||
@@ -345,33 +345,33 @@ class ChatObserver:
|
|||||||
self.update_user_speak_time(msg["time"])
|
self.update_user_speak_time(msg["time"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"处理消息时间时出错: {e}")
|
logger.warning(f"处理消息时间时出错: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
def update_check_time(self):
|
def update_check_time(self):
|
||||||
"""更新查看时间"""
|
"""更新查看时间"""
|
||||||
self.last_check_time = time.time()
|
self.last_check_time = time.time()
|
||||||
|
|
||||||
def update_bot_speak_time(self, speak_time: Optional[float] = None):
|
def update_bot_speak_time(self, speak_time: Optional[float] = None):
|
||||||
"""更新机器人说话时间"""
|
"""更新机器人说话时间"""
|
||||||
self.last_bot_speak_time = speak_time or time.time()
|
self.last_bot_speak_time = speak_time or time.time()
|
||||||
|
|
||||||
def update_user_speak_time(self, speak_time: Optional[float] = None):
|
def update_user_speak_time(self, speak_time: Optional[float] = None):
|
||||||
"""更新用户说话时间"""
|
"""更新用户说话时间"""
|
||||||
self.last_user_speak_time = speak_time or time.time()
|
self.last_user_speak_time = speak_time or time.time()
|
||||||
|
|
||||||
def get_time_info(self) -> str:
|
def get_time_info(self) -> str:
|
||||||
"""获取时间信息文本"""
|
"""获取时间信息文本"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
time_info = ""
|
time_info = ""
|
||||||
|
|
||||||
if self.last_bot_speak_time:
|
if self.last_bot_speak_time:
|
||||||
bot_speak_ago = current_time - self.last_bot_speak_time
|
bot_speak_ago = current_time - self.last_bot_speak_time
|
||||||
time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}秒"
|
time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}秒"
|
||||||
|
|
||||||
if self.last_user_speak_time:
|
if self.last_user_speak_time:
|
||||||
user_speak_ago = current_time - self.last_user_speak_time
|
user_speak_ago = current_time - self.last_user_speak_time
|
||||||
time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}秒"
|
time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}秒"
|
||||||
|
|
||||||
return time_info
|
return time_info
|
||||||
|
|
||||||
def start_periodic_update(self):
|
def start_periodic_update(self):
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#Programmable Friendly Conversationalist
|
# Programmable Friendly Conversationalist
|
||||||
#Prefrontal cortex
|
# Prefrontal cortex
|
||||||
import datetime
|
import datetime
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Optional, Tuple, TYPE_CHECKING
|
from typing import List, Optional, Tuple, TYPE_CHECKING
|
||||||
@@ -29,20 +29,17 @@ logger = get_module_logger("pfc")
|
|||||||
|
|
||||||
class GoalAnalyzer:
|
class GoalAnalyzer:
|
||||||
"""对话目标分析器"""
|
"""对话目标分析器"""
|
||||||
|
|
||||||
def __init__(self, stream_id: str):
|
def __init__(self, stream_id: str):
|
||||||
self.llm = LLM_request(
|
self.llm = LLM_request(
|
||||||
model=global_config.llm_normal,
|
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=1000,
|
|
||||||
request_type="conversation_goal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
|
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
|
||||||
self.name = global_config.BOT_NICKNAME
|
self.name = global_config.BOT_NICKNAME
|
||||||
self.nick_name = global_config.BOT_ALIAS_NAMES
|
self.nick_name = global_config.BOT_ALIAS_NAMES
|
||||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||||
|
|
||||||
# 多目标存储结构
|
# 多目标存储结构
|
||||||
self.goals = [] # 存储多个目标
|
self.goals = [] # 存储多个目标
|
||||||
self.max_goals = 3 # 同时保持的最大目标数量
|
self.max_goals = 3 # 同时保持的最大目标数量
|
||||||
@@ -50,10 +47,10 @@ class GoalAnalyzer:
|
|||||||
|
|
||||||
async def analyze_goal(self) -> Tuple[str, str, str]:
|
async def analyze_goal(self) -> Tuple[str, str, str]:
|
||||||
"""分析对话历史并设定目标
|
"""分析对话历史并设定目标
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_history: 聊天历史记录列表
|
chat_history: 聊天历史记录列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[str, str, str]: (目标, 方法, 原因)
|
Tuple[str, str, str]: (目标, 方法, 原因)
|
||||||
"""
|
"""
|
||||||
@@ -70,16 +67,16 @@ class GoalAnalyzer:
|
|||||||
if sender == self.name:
|
if sender == self.name:
|
||||||
sender = "你说"
|
sender = "你说"
|
||||||
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
|
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
|
||||||
|
|
||||||
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
||||||
|
|
||||||
# 构建当前已有目标的文本
|
# 构建当前已有目标的文本
|
||||||
existing_goals_text = ""
|
existing_goals_text = ""
|
||||||
if self.goals:
|
if self.goals:
|
||||||
existing_goals_text = "当前已有的对话目标:\n"
|
existing_goals_text = "当前已有的对话目标:\n"
|
||||||
for i, (goal, _, reason) in enumerate(self.goals):
|
for i, (goal, _, reason) in enumerate(self.goals):
|
||||||
existing_goals_text += f"{i+1}. 目标: {goal}, 原因: {reason}\n"
|
existing_goals_text += f"{i + 1}. 目标: {goal}, 原因: {reason}\n"
|
||||||
|
|
||||||
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。
|
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。
|
||||||
这些目标应该反映出对话的不同方面和意图。
|
这些目标应该反映出对话的不同方面和意图。
|
||||||
|
|
||||||
@@ -107,46 +104,44 @@ class GoalAnalyzer:
|
|||||||
logger.debug(f"发送到LLM的提示词: {prompt}")
|
logger.debug(f"发送到LLM的提示词: {prompt}")
|
||||||
content, _ = await self.llm.generate_response_async(prompt)
|
content, _ = await self.llm.generate_response_async(prompt)
|
||||||
logger.debug(f"LLM原始返回内容: {content}")
|
logger.debug(f"LLM原始返回内容: {content}")
|
||||||
|
|
||||||
# 使用简化函数提取JSON内容
|
# 使用简化函数提取JSON内容
|
||||||
success, result = get_items_from_json(
|
success, result = get_items_from_json(
|
||||||
content,
|
content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
|
||||||
"goal", "reasoning",
|
|
||||||
required_types={"goal": str, "reasoning": str}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
logger.error(f"无法解析JSON,重试第{retry + 1}次")
|
logger.error(f"无法解析JSON,重试第{retry + 1}次")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
goal = result["goal"]
|
goal = result["goal"]
|
||||||
reasoning = result["reasoning"]
|
reasoning = result["reasoning"]
|
||||||
|
|
||||||
# 使用默认的方法
|
# 使用默认的方法
|
||||||
method = "以友好的态度回应"
|
method = "以友好的态度回应"
|
||||||
|
|
||||||
# 更新目标列表
|
# 更新目标列表
|
||||||
await self._update_goals(goal, method, reasoning)
|
await self._update_goals(goal, method, reasoning)
|
||||||
|
|
||||||
# 返回当前最主要的目标
|
# 返回当前最主要的目标
|
||||||
if self.goals:
|
if self.goals:
|
||||||
current_goal, current_method, current_reasoning = self.goals[0]
|
current_goal, current_method, current_reasoning = self.goals[0]
|
||||||
return current_goal, current_method, current_reasoning
|
return current_goal, current_method, current_reasoning
|
||||||
else:
|
else:
|
||||||
return goal, method, reasoning
|
return goal, method, reasoning
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"分析对话目标时出错: {str(e)},重试第{retry + 1}次")
|
logger.error(f"分析对话目标时出错: {str(e)},重试第{retry + 1}次")
|
||||||
if retry == max_retries - 1:
|
if retry == max_retries - 1:
|
||||||
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
|
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 所有重试都失败后的默认返回
|
# 所有重试都失败后的默认返回
|
||||||
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
|
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
|
||||||
|
|
||||||
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
|
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
|
||||||
"""更新目标列表
|
"""更新目标列表
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
new_goal: 新的目标
|
new_goal: 新的目标
|
||||||
method: 实现目标的方法
|
method: 实现目标的方法
|
||||||
@@ -160,23 +155,23 @@ class GoalAnalyzer:
|
|||||||
# 将此目标移到列表前面(最主要的位置)
|
# 将此目标移到列表前面(最主要的位置)
|
||||||
self.goals.insert(0, self.goals.pop(i))
|
self.goals.insert(0, self.goals.pop(i))
|
||||||
return
|
return
|
||||||
|
|
||||||
# 添加新目标到列表前面
|
# 添加新目标到列表前面
|
||||||
self.goals.insert(0, (new_goal, method, reasoning))
|
self.goals.insert(0, (new_goal, method, reasoning))
|
||||||
|
|
||||||
# 限制目标数量
|
# 限制目标数量
|
||||||
if len(self.goals) > self.max_goals:
|
if len(self.goals) > self.max_goals:
|
||||||
self.goals.pop() # 移除最老的目标
|
self.goals.pop() # 移除最老的目标
|
||||||
|
|
||||||
def _calculate_similarity(self, goal1: str, goal2: str) -> float:
|
def _calculate_similarity(self, goal1: str, goal2: str) -> float:
|
||||||
"""简单计算两个目标之间的相似度
|
"""简单计算两个目标之间的相似度
|
||||||
|
|
||||||
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
|
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
goal1: 第一个目标
|
goal1: 第一个目标
|
||||||
goal2: 第二个目标
|
goal2: 第二个目标
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
float: 相似度得分 (0-1)
|
float: 相似度得分 (0-1)
|
||||||
"""
|
"""
|
||||||
@@ -186,18 +181,18 @@ class GoalAnalyzer:
|
|||||||
overlap = len(words1.intersection(words2))
|
overlap = len(words1.intersection(words2))
|
||||||
total = len(words1.union(words2))
|
total = len(words1.union(words2))
|
||||||
return overlap / total if total > 0 else 0
|
return overlap / total if total > 0 else 0
|
||||||
|
|
||||||
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
|
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
|
||||||
"""获取所有当前目标
|
"""获取所有当前目标
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Tuple[str, str, str]]: 目标列表,每项为(目标, 方法, 原因)
|
List[Tuple[str, str, str]]: 目标列表,每项为(目标, 方法, 原因)
|
||||||
"""
|
"""
|
||||||
return self.goals.copy()
|
return self.goals.copy()
|
||||||
|
|
||||||
async def get_alternative_goals(self) -> List[Tuple[str, str, str]]:
|
async def get_alternative_goals(self) -> List[Tuple[str, str, str]]:
|
||||||
"""获取除了当前主要目标外的其他备选目标
|
"""获取除了当前主要目标外的其他备选目标
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Tuple[str, str, str]]: 备选目标列表
|
List[Tuple[str, str, str]]: 备选目标列表
|
||||||
"""
|
"""
|
||||||
@@ -215,9 +210,9 @@ class GoalAnalyzer:
|
|||||||
if sender == self.name:
|
if sender == self.name:
|
||||||
sender = "你说"
|
sender = "你说"
|
||||||
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
|
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
|
||||||
|
|
||||||
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
||||||
|
|
||||||
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,
|
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,
|
||||||
当前对话目标:{goal}
|
当前对话目标:{goal}
|
||||||
产生该对话目标的原因:{reasoning}
|
产生该对话目标的原因:{reasoning}
|
||||||
@@ -247,7 +242,7 @@ class GoalAnalyzer:
|
|||||||
"goal_achieved", "stop_conversation", "reason",
|
"goal_achieved", "stop_conversation", "reason",
|
||||||
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str}
|
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str}
|
||||||
)
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
logger.error("无法解析对话分析结果JSON")
|
logger.error("无法解析对话分析结果JSON")
|
||||||
return False, False, "解析结果失败"
|
return False, False, "解析结果失败"
|
||||||
@@ -265,14 +260,15 @@ class GoalAnalyzer:
|
|||||||
|
|
||||||
class Waiter:
|
class Waiter:
|
||||||
"""快 速 等 待"""
|
"""快 速 等 待"""
|
||||||
|
|
||||||
def __init__(self, stream_id: str):
|
def __init__(self, stream_id: str):
|
||||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||||
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
|
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
|
||||||
self.name = global_config.BOT_NICKNAME
|
self.name = global_config.BOT_NICKNAME
|
||||||
|
|
||||||
async def wait(self) -> bool:
|
async def wait(self) -> bool:
|
||||||
"""等待
|
"""等待
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 是否超时(True表示超时)
|
bool: 是否超时(True表示超时)
|
||||||
"""
|
"""
|
||||||
@@ -298,7 +294,7 @@ class Waiter:
|
|||||||
|
|
||||||
class DirectMessageSender:
|
class DirectMessageSender:
|
||||||
"""直接发送消息到平台的发送器"""
|
"""直接发送消息到平台的发送器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.logger = get_module_logger("direct_sender")
|
self.logger = get_module_logger("direct_sender")
|
||||||
self.storage = MessageStorage()
|
self.storage = MessageStorage()
|
||||||
@@ -310,7 +306,7 @@ class DirectMessageSender:
|
|||||||
reply_to_message: Optional[Message] = None,
|
reply_to_message: Optional[Message] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""直接发送消息到平台
|
"""直接发送消息到平台
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_stream: 聊天流
|
chat_stream: 聊天流
|
||||||
content: 消息内容
|
content: 消息内容
|
||||||
@@ -323,7 +319,7 @@ class DirectMessageSender:
|
|||||||
user_nickname=global_config.BOT_NICKNAME,
|
user_nickname=global_config.BOT_NICKNAME,
|
||||||
platform=chat_stream.platform,
|
platform=chat_stream.platform,
|
||||||
)
|
)
|
||||||
|
|
||||||
message = MessageSending(
|
message = MessageSending(
|
||||||
message_id=f"dm{round(time.time(), 2)}",
|
message_id=f"dm{round(time.time(), 2)}",
|
||||||
chat_stream=chat_stream,
|
chat_stream=chat_stream,
|
||||||
@@ -343,18 +339,17 @@ class DirectMessageSender:
|
|||||||
try:
|
try:
|
||||||
message_json = message.to_dict()
|
message_json = message.to_dict()
|
||||||
end_point = global_config.api_urls.get(chat_stream.platform, None)
|
end_point = global_config.api_urls.get(chat_stream.platform, None)
|
||||||
|
|
||||||
if not end_point:
|
if not end_point:
|
||||||
raise ValueError(f"未找到平台:{chat_stream.platform} 的url配置")
|
raise ValueError(f"未找到平台:{chat_stream.platform} 的url配置")
|
||||||
|
|
||||||
await global_api.send_message_REST(end_point, message_json)
|
await global_api.send_message_REST(end_point, message_json)
|
||||||
|
|
||||||
# 存储消息
|
# 存储消息
|
||||||
await self.storage.store_message(message, message.chat_stream)
|
await self.storage.store_message(message, message.chat_stream)
|
||||||
|
|
||||||
self.logger.info(f"直接发送消息成功: {content[:30]}...")
|
self.logger.info(f"直接发送消息成功: {content[:30]}...")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"直接发送消息失败: {str(e)}")
|
self.logger.error(f"直接发送消息失败: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|||||||
@@ -7,24 +7,22 @@ from ..chat.message import Message
|
|||||||
|
|
||||||
logger = get_module_logger("knowledge_fetcher")
|
logger = get_module_logger("knowledge_fetcher")
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeFetcher:
|
class KnowledgeFetcher:
|
||||||
"""知识调取器"""
|
"""知识调取器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.llm = LLM_request(
|
self.llm = LLM_request(
|
||||||
model=global_config.llm_normal,
|
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="knowledge_fetch"
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=1000,
|
|
||||||
request_type="knowledge_fetch"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
|
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
|
||||||
"""获取相关知识
|
"""获取相关知识
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: 查询内容
|
query: 查询内容
|
||||||
chat_history: 聊天历史
|
chat_history: 聊天历史
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[str, str]: (获取的知识, 知识来源)
|
Tuple[str, str]: (获取的知识, 知识来源)
|
||||||
"""
|
"""
|
||||||
@@ -33,16 +31,16 @@ class KnowledgeFetcher:
|
|||||||
for msg in chat_history:
|
for msg in chat_history:
|
||||||
# sender = msg.message_info.user_info.user_nickname or f"用户{msg.message_info.user_info.user_id}"
|
# sender = msg.message_info.user_info.user_nickname or f"用户{msg.message_info.user_info.user_id}"
|
||||||
chat_history_text += f"{msg.detailed_plain_text}\n"
|
chat_history_text += f"{msg.detailed_plain_text}\n"
|
||||||
|
|
||||||
# 从记忆中获取相关知识
|
# 从记忆中获取相关知识
|
||||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||||
text=f"{query}\n{chat_history_text}",
|
text=f"{query}\n{chat_history_text}",
|
||||||
max_memory_num=3,
|
max_memory_num=3,
|
||||||
max_memory_length=2,
|
max_memory_length=2,
|
||||||
max_depth=3,
|
max_depth=3,
|
||||||
fast_retrieval=False
|
fast_retrieval=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if related_memory:
|
if related_memory:
|
||||||
knowledge = ""
|
knowledge = ""
|
||||||
sources = []
|
sources = []
|
||||||
@@ -50,5 +48,5 @@ class KnowledgeFetcher:
|
|||||||
knowledge += memory[1] + "\n"
|
knowledge += memory[1] + "\n"
|
||||||
sources.append(f"记忆片段{memory[0]}")
|
sources.append(f"记忆片段{memory[0]}")
|
||||||
return knowledge.strip(), ",".join(sources)
|
return knowledge.strip(), ",".join(sources)
|
||||||
|
|
||||||
return "未找到相关知识", "无记忆匹配"
|
return "未找到相关知识", "无记忆匹配"
|
||||||
|
|||||||
@@ -5,36 +5,37 @@ from src.common.logger import get_module_logger
|
|||||||
|
|
||||||
logger = get_module_logger("pfc_utils")
|
logger = get_module_logger("pfc_utils")
|
||||||
|
|
||||||
|
|
||||||
def get_items_from_json(
|
def get_items_from_json(
|
||||||
content: str,
|
content: str,
|
||||||
*items: str,
|
*items: str,
|
||||||
default_values: Optional[Dict[str, Any]] = None,
|
default_values: Optional[Dict[str, Any]] = None,
|
||||||
required_types: Optional[Dict[str, type]] = None
|
required_types: Optional[Dict[str, type]] = None,
|
||||||
) -> Tuple[bool, Dict[str, Any]]:
|
) -> Tuple[bool, Dict[str, Any]]:
|
||||||
"""从文本中提取JSON内容并获取指定字段
|
"""从文本中提取JSON内容并获取指定字段
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 包含JSON的文本
|
content: 包含JSON的文本
|
||||||
*items: 要提取的字段名
|
*items: 要提取的字段名
|
||||||
default_values: 字段的默认值,格式为 {字段名: 默认值}
|
default_values: 字段的默认值,格式为 {字段名: 默认值}
|
||||||
required_types: 字段的必需类型,格式为 {字段名: 类型}
|
required_types: 字段的必需类型,格式为 {字段名: 类型}
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, Dict[str, Any]]: (是否成功, 提取的字段字典)
|
Tuple[bool, Dict[str, Any]]: (是否成功, 提取的字段字典)
|
||||||
"""
|
"""
|
||||||
content = content.strip()
|
content = content.strip()
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
# 设置默认值
|
# 设置默认值
|
||||||
if default_values:
|
if default_values:
|
||||||
result.update(default_values)
|
result.update(default_values)
|
||||||
|
|
||||||
# 尝试解析JSON
|
# 尝试解析JSON
|
||||||
try:
|
try:
|
||||||
json_data = json.loads(content)
|
json_data = json.loads(content)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# 如果直接解析失败,尝试查找和提取JSON部分
|
# 如果直接解析失败,尝试查找和提取JSON部分
|
||||||
json_pattern = r'\{[^{}]*\}'
|
json_pattern = r"\{[^{}]*\}"
|
||||||
json_match = re.search(json_pattern, content)
|
json_match = re.search(json_pattern, content)
|
||||||
if json_match:
|
if json_match:
|
||||||
try:
|
try:
|
||||||
@@ -45,28 +46,28 @@ def get_items_from_json(
|
|||||||
else:
|
else:
|
||||||
logger.error("无法在返回内容中找到有效的JSON")
|
logger.error("无法在返回内容中找到有效的JSON")
|
||||||
return False, result
|
return False, result
|
||||||
|
|
||||||
# 提取字段
|
# 提取字段
|
||||||
for item in items:
|
for item in items:
|
||||||
if item in json_data:
|
if item in json_data:
|
||||||
result[item] = json_data[item]
|
result[item] = json_data[item]
|
||||||
|
|
||||||
# 验证必需字段
|
# 验证必需字段
|
||||||
if not all(item in result for item in items):
|
if not all(item in result for item in items):
|
||||||
logger.error(f"JSON缺少必要字段,实际内容: {json_data}")
|
logger.error(f"JSON缺少必要字段,实际内容: {json_data}")
|
||||||
return False, result
|
return False, result
|
||||||
|
|
||||||
# 验证字段类型
|
# 验证字段类型
|
||||||
if required_types:
|
if required_types:
|
||||||
for field, expected_type in required_types.items():
|
for field, expected_type in required_types.items():
|
||||||
if field in result and not isinstance(result[field], expected_type):
|
if field in result and not isinstance(result[field], expected_type):
|
||||||
logger.error(f"{field} 必须是 {expected_type.__name__} 类型")
|
logger.error(f"{field} 必须是 {expected_type.__name__} 类型")
|
||||||
return False, result
|
return False, result
|
||||||
|
|
||||||
# 验证字符串字段不为空
|
# 验证字符串字段不为空
|
||||||
for field in items:
|
for field in items:
|
||||||
if isinstance(result[field], str) and not result[field].strip():
|
if isinstance(result[field], str) and not result[field].strip():
|
||||||
logger.error(f"{field} 不能为空")
|
logger.error(f"{field} 不能为空")
|
||||||
return False, result
|
return False, result
|
||||||
|
|
||||||
return True, result
|
return True, result
|
||||||
|
|||||||
@@ -9,33 +9,26 @@ from ..message.message_base import UserInfo
|
|||||||
|
|
||||||
logger = get_module_logger("reply_checker")
|
logger = get_module_logger("reply_checker")
|
||||||
|
|
||||||
|
|
||||||
class ReplyChecker:
|
class ReplyChecker:
|
||||||
"""回复检查器"""
|
"""回复检查器"""
|
||||||
|
|
||||||
def __init__(self, stream_id: str):
|
def __init__(self, stream_id: str):
|
||||||
self.llm = LLM_request(
|
self.llm = LLM_request(
|
||||||
model=global_config.llm_normal,
|
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="reply_check"
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=1000,
|
|
||||||
request_type="reply_check"
|
|
||||||
)
|
)
|
||||||
self.name = global_config.BOT_NICKNAME
|
self.name = global_config.BOT_NICKNAME
|
||||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||||
self.max_retries = 2 # 最大重试次数
|
self.max_retries = 2 # 最大重试次数
|
||||||
|
|
||||||
async def check(
|
async def check(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
|
||||||
self,
|
|
||||||
reply: str,
|
|
||||||
goal: str,
|
|
||||||
retry_count: int = 0
|
|
||||||
) -> Tuple[bool, str, bool]:
|
|
||||||
"""检查生成的回复是否合适
|
"""检查生成的回复是否合适
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
reply: 生成的回复
|
reply: 生成的回复
|
||||||
goal: 对话目标
|
goal: 对话目标
|
||||||
retry_count: 当前重试次数
|
retry_count: 当前重试次数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
|
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
|
||||||
"""
|
"""
|
||||||
@@ -49,7 +42,7 @@ class ReplyChecker:
|
|||||||
if sender == self.name:
|
if sender == self.name:
|
||||||
sender = "你说"
|
sender = "你说"
|
||||||
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
|
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
|
||||||
|
|
||||||
prompt = f"""请检查以下回复是否合适:
|
prompt = f"""请检查以下回复是否合适:
|
||||||
|
|
||||||
当前对话目标:{goal}
|
当前对话目标:{goal}
|
||||||
@@ -83,7 +76,7 @@ class ReplyChecker:
|
|||||||
try:
|
try:
|
||||||
content, _ = await self.llm.generate_response_async(prompt)
|
content, _ = await self.llm.generate_response_async(prompt)
|
||||||
logger.debug(f"检查回复的原始返回: {content}")
|
logger.debug(f"检查回复的原始返回: {content}")
|
||||||
|
|
||||||
# 清理内容,尝试提取JSON部分
|
# 清理内容,尝试提取JSON部分
|
||||||
content = content.strip()
|
content = content.strip()
|
||||||
try:
|
try:
|
||||||
@@ -92,7 +85,8 @@ class ReplyChecker:
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# 如果直接解析失败,尝试查找和提取JSON部分
|
# 如果直接解析失败,尝试查找和提取JSON部分
|
||||||
import re
|
import re
|
||||||
json_pattern = r'\{[^{}]*\}'
|
|
||||||
|
json_pattern = r"\{[^{}]*\}"
|
||||||
json_match = re.search(json_pattern, content)
|
json_match = re.search(json_pattern, content)
|
||||||
if json_match:
|
if json_match:
|
||||||
try:
|
try:
|
||||||
@@ -109,33 +103,33 @@ class ReplyChecker:
|
|||||||
reason = content[:100] if content else "无法解析响应"
|
reason = content[:100] if content else "无法解析响应"
|
||||||
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
|
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
|
||||||
return is_suitable, reason, need_replan
|
return is_suitable, reason, need_replan
|
||||||
|
|
||||||
# 验证JSON字段
|
# 验证JSON字段
|
||||||
suitable = result.get("suitable", None)
|
suitable = result.get("suitable", None)
|
||||||
reason = result.get("reason", "未提供原因")
|
reason = result.get("reason", "未提供原因")
|
||||||
need_replan = result.get("need_replan", False)
|
need_replan = result.get("need_replan", False)
|
||||||
|
|
||||||
# 如果suitable字段是字符串,转换为布尔值
|
# 如果suitable字段是字符串,转换为布尔值
|
||||||
if isinstance(suitable, str):
|
if isinstance(suitable, str):
|
||||||
suitable = suitable.lower() == "true"
|
suitable = suitable.lower() == "true"
|
||||||
|
|
||||||
# 如果suitable字段不存在或不是布尔值,从reason中判断
|
# 如果suitable字段不存在或不是布尔值,从reason中判断
|
||||||
if suitable is None:
|
if suitable is None:
|
||||||
suitable = "不合适" not in reason.lower() and "违规" not in reason.lower()
|
suitable = "不合适" not in reason.lower() and "违规" not in reason.lower()
|
||||||
|
|
||||||
# 如果不合适且未达到最大重试次数,返回需要重试
|
# 如果不合适且未达到最大重试次数,返回需要重试
|
||||||
if not suitable and retry_count < self.max_retries:
|
if not suitable and retry_count < self.max_retries:
|
||||||
return False, reason, False
|
return False, reason, False
|
||||||
|
|
||||||
# 如果不合适且已达到最大重试次数,返回需要重新规划
|
# 如果不合适且已达到最大重试次数,返回需要重新规划
|
||||||
if not suitable and retry_count >= self.max_retries:
|
if not suitable and retry_count >= self.max_retries:
|
||||||
return False, f"多次重试后仍不合适: {reason}", True
|
return False, f"多次重试后仍不合适: {reason}", True
|
||||||
|
|
||||||
return suitable, reason, need_replan
|
return suitable, reason, need_replan
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"检查回复时出错: {e}")
|
logger.error(f"检查回复时出错: {e}")
|
||||||
# 如果出错且已达到最大重试次数,建议重新规划
|
# 如果出错且已达到最大重试次数,建议重新规划
|
||||||
if retry_count >= self.max_retries:
|
if retry_count >= self.max_retries:
|
||||||
return False, "多次检查失败,建议重新规划", True
|
return False, "多次检查失败,建议重新规划", True
|
||||||
return False, f"检查过程出错,建议重试: {str(e)}", False
|
return False, f"检查过程出错,建议重试: {str(e)}", False
|
||||||
|
|||||||
@@ -12,5 +12,5 @@ __all__ = [
|
|||||||
"chat_manager",
|
"chat_manager",
|
||||||
"message_manager",
|
"message_manager",
|
||||||
"MessageStorage",
|
"MessageStorage",
|
||||||
"auto_speak_manager"
|
"auto_speak_manager",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -44,11 +44,11 @@ class ChatBot:
|
|||||||
async def _create_PFC_chat(self, message: MessageRecv):
|
async def _create_PFC_chat(self, message: MessageRecv):
|
||||||
try:
|
try:
|
||||||
chat_id = str(message.chat_stream.stream_id)
|
chat_id = str(message.chat_stream.stream_id)
|
||||||
|
|
||||||
if global_config.enable_pfc_chatting:
|
if global_config.enable_pfc_chatting:
|
||||||
|
|
||||||
await self.pfc_manager.get_or_create_conversation(chat_id)
|
await self.pfc_manager.get_or_create_conversation(chat_id)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建PFC聊天失败: {e}")
|
logger.error(f"创建PFC聊天失败: {e}")
|
||||||
|
|
||||||
@@ -59,16 +59,16 @@ class ChatBot:
|
|||||||
- 包含思维流状态管理
|
- 包含思维流状态管理
|
||||||
- 在回复前进行观察和状态更新
|
- 在回复前进行观察和状态更新
|
||||||
- 回复后更新思维流状态
|
- 回复后更新思维流状态
|
||||||
|
|
||||||
2. reasoning模式:使用推理系统进行回复
|
2. reasoning模式:使用推理系统进行回复
|
||||||
- 直接使用意愿管理器计算回复概率
|
- 直接使用意愿管理器计算回复概率
|
||||||
- 没有思维流相关的状态管理
|
- 没有思维流相关的状态管理
|
||||||
- 更简单直接的回复逻辑
|
- 更简单直接的回复逻辑
|
||||||
|
|
||||||
3. pfc_chatting模式:仅进行消息处理
|
3. pfc_chatting模式:仅进行消息处理
|
||||||
- 不进行任何回复
|
- 不进行任何回复
|
||||||
- 只处理和存储消息
|
- 只处理和存储消息
|
||||||
|
|
||||||
所有模式都包含:
|
所有模式都包含:
|
||||||
- 消息过滤
|
- 消息过滤
|
||||||
- 记忆激活
|
- 记忆激活
|
||||||
@@ -89,7 +89,7 @@ class ChatBot:
|
|||||||
if userinfo.user_id in global_config.ban_user_id:
|
if userinfo.user_id in global_config.ban_user_id:
|
||||||
logger.debug(f"用户{userinfo.user_id}被禁止回复")
|
logger.debug(f"用户{userinfo.user_id}被禁止回复")
|
||||||
return
|
return
|
||||||
|
|
||||||
if global_config.enable_pfc_chatting:
|
if global_config.enable_pfc_chatting:
|
||||||
try:
|
try:
|
||||||
if groupinfo is None and global_config.enable_friend_chat:
|
if groupinfo is None and global_config.enable_friend_chat:
|
||||||
@@ -118,7 +118,7 @@ class ChatBot:
|
|||||||
logger.error(f"处理PFC消息失败: {e}")
|
logger.error(f"处理PFC消息失败: {e}")
|
||||||
else:
|
else:
|
||||||
if groupinfo is None and global_config.enable_friend_chat:
|
if groupinfo is None and global_config.enable_friend_chat:
|
||||||
# 私聊处理流程
|
# 私聊处理流程
|
||||||
# await self._handle_private_chat(message)
|
# await self._handle_private_chat(message)
|
||||||
if global_config.response_mode == "heart_flow":
|
if global_config.response_mode == "heart_flow":
|
||||||
await self.think_flow_chat.process_message(message_data)
|
await self.think_flow_chat.process_message(message_data)
|
||||||
|
|||||||
@@ -38,11 +38,11 @@ class EmojiManager:
|
|||||||
self.llm_emotion_judge = LLM_request(
|
self.llm_emotion_judge = LLM_request(
|
||||||
model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8, request_type="emoji"
|
model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8, request_type="emoji"
|
||||||
) # 更高的温度,更少的token(后续可以根据情绪来调整温度)
|
) # 更高的温度,更少的token(后续可以根据情绪来调整温度)
|
||||||
|
|
||||||
self.emoji_num = 0
|
self.emoji_num = 0
|
||||||
self.emoji_num_max = global_config.max_emoji_num
|
self.emoji_num_max = global_config.max_emoji_num
|
||||||
self.emoji_num_max_reach_deletion = global_config.max_reach_deletion
|
self.emoji_num_max_reach_deletion = global_config.max_reach_deletion
|
||||||
|
|
||||||
logger.info("启动表情包管理器")
|
logger.info("启动表情包管理器")
|
||||||
|
|
||||||
def _ensure_emoji_dir(self):
|
def _ensure_emoji_dir(self):
|
||||||
@@ -51,7 +51,7 @@ class EmojiManager:
|
|||||||
|
|
||||||
def _update_emoji_count(self):
|
def _update_emoji_count(self):
|
||||||
"""更新表情包数量统计
|
"""更新表情包数量统计
|
||||||
|
|
||||||
检查数据库中的表情包数量并更新到 self.emoji_num
|
检查数据库中的表情包数量并更新到 self.emoji_num
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@@ -376,7 +376,6 @@ class EmojiManager:
|
|||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[错误] 扫描表情包失败")
|
logger.exception("[错误] 扫描表情包失败")
|
||||||
|
|
||||||
|
|
||||||
def check_emoji_file_integrity(self):
|
def check_emoji_file_integrity(self):
|
||||||
"""检查表情包文件完整性
|
"""检查表情包文件完整性
|
||||||
@@ -451,7 +450,7 @@ class EmojiManager:
|
|||||||
|
|
||||||
def check_emoji_file_full(self):
|
def check_emoji_file_full(self):
|
||||||
"""检查表情包文件是否完整,如果数量超出限制且允许删除,则删除多余的表情包
|
"""检查表情包文件是否完整,如果数量超出限制且允许删除,则删除多余的表情包
|
||||||
|
|
||||||
删除规则:
|
删除规则:
|
||||||
1. 优先删除创建时间更早的表情包
|
1. 优先删除创建时间更早的表情包
|
||||||
2. 优先删除使用次数少的表情包,但使用次数多的也有小概率被删除
|
2. 优先删除使用次数少的表情包,但使用次数多的也有小概率被删除
|
||||||
@@ -460,23 +459,23 @@ class EmojiManager:
|
|||||||
self._ensure_db()
|
self._ensure_db()
|
||||||
# 更新表情包数量
|
# 更新表情包数量
|
||||||
self._update_emoji_count()
|
self._update_emoji_count()
|
||||||
|
|
||||||
# 检查是否超出限制
|
# 检查是否超出限制
|
||||||
if self.emoji_num <= self.emoji_num_max:
|
if self.emoji_num <= self.emoji_num_max:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 如果超出限制但不允许删除,则只记录警告
|
# 如果超出限制但不允许删除,则只记录警告
|
||||||
if not global_config.max_reach_deletion:
|
if not global_config.max_reach_deletion:
|
||||||
logger.warning(f"[警告] 表情包数量({self.emoji_num})超出限制({self.emoji_num_max}),但未开启自动删除")
|
logger.warning(f"[警告] 表情包数量({self.emoji_num})超出限制({self.emoji_num_max}),但未开启自动删除")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 计算需要删除的数量
|
# 计算需要删除的数量
|
||||||
delete_count = self.emoji_num - self.emoji_num_max
|
delete_count = self.emoji_num - self.emoji_num_max
|
||||||
logger.info(f"[清理] 需要删除 {delete_count} 个表情包")
|
logger.info(f"[清理] 需要删除 {delete_count} 个表情包")
|
||||||
|
|
||||||
# 获取所有表情包,按时间戳升序(旧的在前)排序
|
# 获取所有表情包,按时间戳升序(旧的在前)排序
|
||||||
all_emojis = list(db.emoji.find().sort([("timestamp", 1)]))
|
all_emojis = list(db.emoji.find().sort([("timestamp", 1)]))
|
||||||
|
|
||||||
# 计算权重:使用次数越多,被删除的概率越小
|
# 计算权重:使用次数越多,被删除的概率越小
|
||||||
weights = []
|
weights = []
|
||||||
max_usage = max((emoji.get("usage_count", 0) for emoji in all_emojis), default=1)
|
max_usage = max((emoji.get("usage_count", 0) for emoji in all_emojis), default=1)
|
||||||
@@ -485,11 +484,11 @@ class EmojiManager:
|
|||||||
# 使用指数衰减函数计算权重,使用次数越多权重越小
|
# 使用指数衰减函数计算权重,使用次数越多权重越小
|
||||||
weight = 1.0 / (1.0 + usage_count / max(1, max_usage))
|
weight = 1.0 / (1.0 + usage_count / max(1, max_usage))
|
||||||
weights.append(weight)
|
weights.append(weight)
|
||||||
|
|
||||||
# 根据权重随机选择要删除的表情包
|
# 根据权重随机选择要删除的表情包
|
||||||
to_delete = []
|
to_delete = []
|
||||||
remaining_indices = list(range(len(all_emojis)))
|
remaining_indices = list(range(len(all_emojis)))
|
||||||
|
|
||||||
while len(to_delete) < delete_count and remaining_indices:
|
while len(to_delete) < delete_count and remaining_indices:
|
||||||
# 计算当前剩余表情包的权重
|
# 计算当前剩余表情包的权重
|
||||||
current_weights = [weights[i] for i in remaining_indices]
|
current_weights = [weights[i] for i in remaining_indices]
|
||||||
@@ -497,13 +496,13 @@ class EmojiManager:
|
|||||||
total_weight = sum(current_weights)
|
total_weight = sum(current_weights)
|
||||||
if total_weight == 0:
|
if total_weight == 0:
|
||||||
break
|
break
|
||||||
normalized_weights = [w/total_weight for w in current_weights]
|
normalized_weights = [w / total_weight for w in current_weights]
|
||||||
|
|
||||||
# 随机选择一个表情包
|
# 随机选择一个表情包
|
||||||
selected_idx = random.choices(remaining_indices, weights=normalized_weights, k=1)[0]
|
selected_idx = random.choices(remaining_indices, weights=normalized_weights, k=1)[0]
|
||||||
to_delete.append(all_emojis[selected_idx])
|
to_delete.append(all_emojis[selected_idx])
|
||||||
remaining_indices.remove(selected_idx)
|
remaining_indices.remove(selected_idx)
|
||||||
|
|
||||||
# 删除选中的表情包
|
# 删除选中的表情包
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
for emoji in to_delete:
|
for emoji in to_delete:
|
||||||
@@ -512,26 +511,26 @@ class EmojiManager:
|
|||||||
if "path" in emoji and os.path.exists(emoji["path"]):
|
if "path" in emoji and os.path.exists(emoji["path"]):
|
||||||
os.remove(emoji["path"])
|
os.remove(emoji["path"])
|
||||||
logger.info(f"[删除] 文件: {emoji['path']} (使用次数: {emoji.get('usage_count', 0)})")
|
logger.info(f"[删除] 文件: {emoji['path']} (使用次数: {emoji.get('usage_count', 0)})")
|
||||||
|
|
||||||
# 删除数据库记录
|
# 删除数据库记录
|
||||||
db.emoji.delete_one({"_id": emoji["_id"]})
|
db.emoji.delete_one({"_id": emoji["_id"]})
|
||||||
deleted_count += 1
|
deleted_count += 1
|
||||||
|
|
||||||
# 同时从images集合中删除
|
# 同时从images集合中删除
|
||||||
if "hash" in emoji:
|
if "hash" in emoji:
|
||||||
db.images.delete_one({"hash": emoji["hash"]})
|
db.images.delete_one({"hash": emoji["hash"]})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[错误] 删除表情包失败: {str(e)}")
|
logger.error(f"[错误] 删除表情包失败: {str(e)}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 更新表情包数量
|
# 更新表情包数量
|
||||||
self._update_emoji_count()
|
self._update_emoji_count()
|
||||||
logger.success(f"[清理] 已删除 {deleted_count} 个表情包,当前数量: {self.emoji_num}")
|
logger.success(f"[清理] 已删除 {deleted_count} 个表情包,当前数量: {self.emoji_num}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[错误] 检查表情包数量失败: {str(e)}")
|
logger.error(f"[错误] 检查表情包数量失败: {str(e)}")
|
||||||
|
|
||||||
async def start_periodic_check_register(self):
|
async def start_periodic_check_register(self):
|
||||||
"""定期检查表情包完整性和数量"""
|
"""定期检查表情包完整性和数量"""
|
||||||
while True:
|
while True:
|
||||||
@@ -542,7 +541,7 @@ class EmojiManager:
|
|||||||
logger.info("[扫描] 开始扫描新表情包...")
|
logger.info("[扫描] 开始扫描新表情包...")
|
||||||
if self.emoji_num < self.emoji_num_max:
|
if self.emoji_num < self.emoji_num_max:
|
||||||
await self.scan_new_emojis()
|
await self.scan_new_emojis()
|
||||||
if (self.emoji_num > self.emoji_num_max):
|
if self.emoji_num > self.emoji_num_max:
|
||||||
logger.warning(f"[警告] 表情包数量超过最大限制: {self.emoji_num} > {self.emoji_num_max},跳过注册")
|
logger.warning(f"[警告] 表情包数量超过最大限制: {self.emoji_num} > {self.emoji_num_max},跳过注册")
|
||||||
if not global_config.max_reach_deletion:
|
if not global_config.max_reach_deletion:
|
||||||
logger.warning("表情包数量超过最大限制,终止注册")
|
logger.warning("表情包数量超过最大限制,终止注册")
|
||||||
@@ -551,7 +550,7 @@ class EmojiManager:
|
|||||||
logger.warning("表情包数量超过最大限制,开始删除表情包")
|
logger.warning("表情包数量超过最大限制,开始删除表情包")
|
||||||
self.check_emoji_file_full()
|
self.check_emoji_file_full()
|
||||||
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
|
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
|
||||||
|
|
||||||
async def delete_all_images(self):
|
async def delete_all_images(self):
|
||||||
"""删除 data/image 目录下的所有文件"""
|
"""删除 data/image 目录下的所有文件"""
|
||||||
try:
|
try:
|
||||||
@@ -559,10 +558,10 @@ class EmojiManager:
|
|||||||
if not os.path.exists(image_dir):
|
if not os.path.exists(image_dir):
|
||||||
logger.warning(f"[警告] 目录不存在: {image_dir}")
|
logger.warning(f"[警告] 目录不存在: {image_dir}")
|
||||||
return
|
return
|
||||||
|
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
failed_count = 0
|
failed_count = 0
|
||||||
|
|
||||||
# 遍历目录下的所有文件
|
# 遍历目录下的所有文件
|
||||||
for filename in os.listdir(image_dir):
|
for filename in os.listdir(image_dir):
|
||||||
file_path = os.path.join(image_dir, filename)
|
file_path = os.path.join(image_dir, filename)
|
||||||
@@ -574,11 +573,12 @@ class EmojiManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
failed_count += 1
|
failed_count += 1
|
||||||
logger.error(f"[错误] 删除文件失败 {file_path}: {str(e)}")
|
logger.error(f"[错误] 删除文件失败 {file_path}: {str(e)}")
|
||||||
|
|
||||||
logger.success(f"[清理] 已删除 {deleted_count} 个文件,失败 {failed_count} 个")
|
logger.success(f"[清理] 已删除 {deleted_count} 个文件,失败 {failed_count} 个")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[错误] 删除图片目录失败: {str(e)}")
|
logger.error(f"[错误] 删除图片目录失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
# 创建全局单例
|
# 创建全局单例
|
||||||
emoji_manager = EmojiManager()
|
emoji_manager = EmojiManager()
|
||||||
|
|||||||
@@ -13,9 +13,10 @@ from ..config.config import global_config
|
|||||||
|
|
||||||
logger = get_module_logger("message_buffer")
|
logger = get_module_logger("message_buffer")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CacheMessages:
|
class CacheMessages:
|
||||||
message: MessageRecv
|
message: MessageRecv
|
||||||
cache_determination: asyncio.Event = field(default_factory=asyncio.Event) # 判断缓冲是否产生结果
|
cache_determination: asyncio.Event = field(default_factory=asyncio.Event) # 判断缓冲是否产生结果
|
||||||
result: str = "U"
|
result: str = "U"
|
||||||
|
|
||||||
@@ -25,7 +26,7 @@ class MessageBuffer:
|
|||||||
self.buffer_pool: Dict[str, OrderedDict[str, CacheMessages]] = {}
|
self.buffer_pool: Dict[str, OrderedDict[str, CacheMessages]] = {}
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
|
|
||||||
def get_person_id_(self, platform:str, user_id:str, group_info:GroupInfo):
|
def get_person_id_(self, platform: str, user_id: str, group_info: GroupInfo):
|
||||||
"""获取唯一id"""
|
"""获取唯一id"""
|
||||||
if group_info:
|
if group_info:
|
||||||
group_id = group_info.group_id
|
group_id = group_info.group_id
|
||||||
@@ -34,16 +35,17 @@ class MessageBuffer:
|
|||||||
key = f"{platform}_{user_id}_{group_id}"
|
key = f"{platform}_{user_id}_{group_id}"
|
||||||
return hashlib.md5(key.encode()).hexdigest()
|
return hashlib.md5(key.encode()).hexdigest()
|
||||||
|
|
||||||
async def start_caching_messages(self, message:MessageRecv):
|
async def start_caching_messages(self, message: MessageRecv):
|
||||||
"""添加消息,启动缓冲"""
|
"""添加消息,启动缓冲"""
|
||||||
if not global_config.message_buffer:
|
if not global_config.message_buffer:
|
||||||
person_id = person_info_manager.get_person_id(message.message_info.user_info.platform,
|
person_id = person_info_manager.get_person_id(
|
||||||
message.message_info.user_info.user_id)
|
message.message_info.user_info.platform, message.message_info.user_info.user_id
|
||||||
|
)
|
||||||
asyncio.create_task(self.save_message_interval(person_id, message.message_info))
|
asyncio.create_task(self.save_message_interval(person_id, message.message_info))
|
||||||
return
|
return
|
||||||
person_id_ = self.get_person_id_(message.message_info.platform,
|
person_id_ = self.get_person_id_(
|
||||||
message.message_info.user_info.user_id,
|
message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
|
||||||
message.message_info.group_info)
|
)
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
if person_id_ not in self.buffer_pool:
|
if person_id_ not in self.buffer_pool:
|
||||||
@@ -64,25 +66,24 @@ class MessageBuffer:
|
|||||||
break
|
break
|
||||||
elif msg.result == "F":
|
elif msg.result == "F":
|
||||||
recent_F_count += 1
|
recent_F_count += 1
|
||||||
|
|
||||||
# 判断条件:最近T之后有超过3-5条F
|
# 判断条件:最近T之后有超过3-5条F
|
||||||
if (recent_F_count >= random.randint(3, 5)):
|
if recent_F_count >= random.randint(3, 5):
|
||||||
new_msg = CacheMessages(message=message, result="T")
|
new_msg = CacheMessages(message=message, result="T")
|
||||||
new_msg.cache_determination.set()
|
new_msg.cache_determination.set()
|
||||||
self.buffer_pool[person_id_][message.message_info.message_id] = new_msg
|
self.buffer_pool[person_id_][message.message_info.message_id] = new_msg
|
||||||
logger.debug(f"快速处理消息(已堆积{recent_F_count}条F): {message.message_info.message_id}")
|
logger.debug(f"快速处理消息(已堆积{recent_F_count}条F): {message.message_info.message_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 添加新消息
|
# 添加新消息
|
||||||
self.buffer_pool[person_id_][message.message_info.message_id] = CacheMessages(message=message)
|
self.buffer_pool[person_id_][message.message_info.message_id] = CacheMessages(message=message)
|
||||||
|
|
||||||
# 启动3秒缓冲计时器
|
# 启动3秒缓冲计时器
|
||||||
person_id = person_info_manager.get_person_id(message.message_info.user_info.platform,
|
person_id = person_info_manager.get_person_id(
|
||||||
message.message_info.user_info.user_id)
|
message.message_info.user_info.platform, message.message_info.user_info.user_id
|
||||||
|
)
|
||||||
asyncio.create_task(self.save_message_interval(person_id, message.message_info))
|
asyncio.create_task(self.save_message_interval(person_id, message.message_info))
|
||||||
asyncio.create_task(self._debounce_processor(person_id_,
|
asyncio.create_task(self._debounce_processor(person_id_, message.message_info.message_id, person_id))
|
||||||
message.message_info.message_id,
|
|
||||||
person_id))
|
|
||||||
|
|
||||||
async def _debounce_processor(self, person_id_: str, message_id: str, person_id: str):
|
async def _debounce_processor(self, person_id_: str, message_id: str, person_id: str):
|
||||||
"""等待3秒无新消息"""
|
"""等待3秒无新消息"""
|
||||||
@@ -92,36 +93,33 @@ class MessageBuffer:
|
|||||||
return
|
return
|
||||||
interval_time = max(0.5, int(interval_time) / 1000)
|
interval_time = max(0.5, int(interval_time) / 1000)
|
||||||
await asyncio.sleep(interval_time)
|
await asyncio.sleep(interval_time)
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
if (person_id_ not in self.buffer_pool or
|
if person_id_ not in self.buffer_pool or message_id not in self.buffer_pool[person_id_]:
|
||||||
message_id not in self.buffer_pool[person_id_]):
|
|
||||||
logger.debug(f"消息已被清理,msgid: {message_id}")
|
logger.debug(f"消息已被清理,msgid: {message_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
cache_msg = self.buffer_pool[person_id_][message_id]
|
cache_msg = self.buffer_pool[person_id_][message_id]
|
||||||
if cache_msg.result == "U":
|
if cache_msg.result == "U":
|
||||||
cache_msg.result = "T"
|
cache_msg.result = "T"
|
||||||
cache_msg.cache_determination.set()
|
cache_msg.cache_determination.set()
|
||||||
|
|
||||||
|
async def query_buffer_result(self, message: MessageRecv) -> bool:
|
||||||
async def query_buffer_result(self, message:MessageRecv) -> bool:
|
|
||||||
"""查询缓冲结果,并清理"""
|
"""查询缓冲结果,并清理"""
|
||||||
if not global_config.message_buffer:
|
if not global_config.message_buffer:
|
||||||
return True
|
return True
|
||||||
person_id_ = self.get_person_id_(message.message_info.platform,
|
person_id_ = self.get_person_id_(
|
||||||
message.message_info.user_info.user_id,
|
message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
|
||||||
message.message_info.group_info)
|
)
|
||||||
|
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
user_msgs = self.buffer_pool.get(person_id_, {})
|
user_msgs = self.buffer_pool.get(person_id_, {})
|
||||||
cache_msg = user_msgs.get(message.message_info.message_id)
|
cache_msg = user_msgs.get(message.message_info.message_id)
|
||||||
|
|
||||||
if not cache_msg:
|
if not cache_msg:
|
||||||
logger.debug(f"查询异常,消息不存在,msgid: {message.message_info.message_id}")
|
logger.debug(f"查询异常,消息不存在,msgid: {message.message_info.message_id}")
|
||||||
return False # 消息不存在或已清理
|
return False # 消息不存在或已清理
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(cache_msg.cache_determination.wait(), timeout=10)
|
await asyncio.wait_for(cache_msg.cache_determination.wait(), timeout=10)
|
||||||
result = cache_msg.result == "T"
|
result = cache_msg.result == "T"
|
||||||
@@ -144,9 +142,8 @@ class MessageBuffer:
|
|||||||
keep_msgs[msg_id] = msg
|
keep_msgs[msg_id] = msg
|
||||||
elif msg.result == "F":
|
elif msg.result == "F":
|
||||||
# 收集F消息的文本内容
|
# 收集F消息的文本内容
|
||||||
if (hasattr(msg.message, 'processed_plain_text')
|
if hasattr(msg.message, "processed_plain_text") and msg.message.processed_plain_text:
|
||||||
and msg.message.processed_plain_text):
|
if msg.message.message_segment.type == "text":
|
||||||
if msg.message.message_segment.type == "text":
|
|
||||||
combined_text.append(msg.message.processed_plain_text)
|
combined_text.append(msg.message.processed_plain_text)
|
||||||
elif msg.message.message_segment.type != "text":
|
elif msg.message.message_segment.type != "text":
|
||||||
is_update = False
|
is_update = False
|
||||||
@@ -157,20 +154,20 @@ class MessageBuffer:
|
|||||||
if combined_text and combined_text[0] != message.processed_plain_text and is_update:
|
if combined_text and combined_text[0] != message.processed_plain_text and is_update:
|
||||||
if type == "text":
|
if type == "text":
|
||||||
message.processed_plain_text = "".join(combined_text)
|
message.processed_plain_text = "".join(combined_text)
|
||||||
logger.debug(f"整合了{len(combined_text)-1}条F消息的内容到当前消息")
|
logger.debug(f"整合了{len(combined_text) - 1}条F消息的内容到当前消息")
|
||||||
elif type == "emoji":
|
elif type == "emoji":
|
||||||
combined_text.pop()
|
combined_text.pop()
|
||||||
message.processed_plain_text = "".join(combined_text)
|
message.processed_plain_text = "".join(combined_text)
|
||||||
message.is_emoji = False
|
message.is_emoji = False
|
||||||
logger.debug(f"整合了{len(combined_text)-1}条F消息的内容,覆盖当前emoji消息")
|
logger.debug(f"整合了{len(combined_text) - 1}条F消息的内容,覆盖当前emoji消息")
|
||||||
|
|
||||||
self.buffer_pool[person_id_] = keep_msgs
|
self.buffer_pool[person_id_] = keep_msgs
|
||||||
return result
|
return result
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.debug(f"查询超时消息id: {message.message_info.message_id}")
|
logger.debug(f"查询超时消息id: {message.message_info.message_id}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def save_message_interval(self, person_id:str, message:BaseMessageInfo):
|
async def save_message_interval(self, person_id: str, message: BaseMessageInfo):
|
||||||
message_interval_list = await person_info_manager.get_value(person_id, "msg_interval_list")
|
message_interval_list = await person_info_manager.get_value(person_id, "msg_interval_list")
|
||||||
now_time_ms = int(round(time.time() * 1000))
|
now_time_ms = int(round(time.time() * 1000))
|
||||||
if len(message_interval_list) < 1000:
|
if len(message_interval_list) < 1000:
|
||||||
@@ -179,12 +176,12 @@ class MessageBuffer:
|
|||||||
message_interval_list.pop(0)
|
message_interval_list.pop(0)
|
||||||
message_interval_list.append(now_time_ms)
|
message_interval_list.append(now_time_ms)
|
||||||
data = {
|
data = {
|
||||||
"platform" : message.platform,
|
"platform": message.platform,
|
||||||
"user_id" : message.user_info.user_id,
|
"user_id": message.user_info.user_id,
|
||||||
"nickname" : message.user_info.user_nickname,
|
"nickname": message.user_info.user_nickname,
|
||||||
"konw_time" : int(time.time())
|
"konw_time": int(time.time()),
|
||||||
}
|
}
|
||||||
await person_info_manager.update_one_field(person_id, "msg_interval_list", message_interval_list, data)
|
await person_info_manager.update_one_field(person_id, "msg_interval_list", message_interval_list, data)
|
||||||
|
|
||||||
|
|
||||||
message_buffer = MessageBuffer()
|
message_buffer = MessageBuffer()
|
||||||
|
|||||||
@@ -68,7 +68,8 @@ class Message_Sender:
|
|||||||
typing_time = calculate_typing_time(
|
typing_time = calculate_typing_time(
|
||||||
input_string=message.processed_plain_text,
|
input_string=message.processed_plain_text,
|
||||||
thinking_start_time=message.thinking_start_time,
|
thinking_start_time=message.thinking_start_time,
|
||||||
is_emoji=message.is_emoji)
|
is_emoji=message.is_emoji,
|
||||||
|
)
|
||||||
logger.debug(f"{message.processed_plain_text},{typing_time},计算输入时间结束")
|
logger.debug(f"{message.processed_plain_text},{typing_time},计算输入时间结束")
|
||||||
await asyncio.sleep(typing_time)
|
await asyncio.sleep(typing_time)
|
||||||
logger.debug(f"{message.processed_plain_text},{typing_time},等待输入时间结束")
|
logger.debug(f"{message.processed_plain_text},{typing_time},等待输入时间结束")
|
||||||
@@ -227,7 +228,7 @@ class MessageManager:
|
|||||||
await message_earliest.process()
|
await message_earliest.process()
|
||||||
|
|
||||||
# print(f"message_earliest.thinking_start_tim22222e:{message_earliest.thinking_start_time}")
|
# print(f"message_earliest.thinking_start_tim22222e:{message_earliest.thinking_start_time}")
|
||||||
|
|
||||||
await message_sender.send_message(message_earliest)
|
await message_sender.send_message(message_earliest)
|
||||||
|
|
||||||
await self.storage.store_message(message_earliest, message_earliest.chat_stream)
|
await self.storage.store_message(message_earliest, message_earliest.chat_stream)
|
||||||
|
|||||||
@@ -56,14 +56,13 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
|
|||||||
logger.info("被@,回复概率设置为100%")
|
logger.info("被@,回复概率设置为100%")
|
||||||
else:
|
else:
|
||||||
if not is_mentioned:
|
if not is_mentioned:
|
||||||
|
|
||||||
# 判断是否被回复
|
# 判断是否被回复
|
||||||
if re.match(f"回复[\s\S]*?\({global_config.BOT_QQ}\)的消息,说:", message.processed_plain_text):
|
if re.match(f"回复[\s\S]*?\({global_config.BOT_QQ}\)的消息,说:", message.processed_plain_text):
|
||||||
is_mentioned = True
|
is_mentioned = True
|
||||||
|
|
||||||
# 判断内容中是否被提及
|
# 判断内容中是否被提及
|
||||||
message_content = re.sub(r'\@[\s\S]*?((\d+))','', message.processed_plain_text)
|
message_content = re.sub(r"\@[\s\S]*?((\d+))", "", message.processed_plain_text)
|
||||||
message_content = re.sub(r'回复[\s\S]*?\((\d+)\)的消息,说: ','', message_content)
|
message_content = re.sub(r"回复[\s\S]*?\((\d+)\)的消息,说: ", "", message_content)
|
||||||
for keyword in keywords:
|
for keyword in keywords:
|
||||||
if keyword in message_content:
|
if keyword in message_content:
|
||||||
is_mentioned = True
|
is_mentioned = True
|
||||||
@@ -359,7 +358,13 @@ def process_llm_response(text: str) -> List[str]:
|
|||||||
return sentences
|
return sentences
|
||||||
|
|
||||||
|
|
||||||
def calculate_typing_time(input_string: str, thinking_start_time: float, chinese_time: float = 0.2, english_time: float = 0.1, is_emoji: bool = False) -> float:
|
def calculate_typing_time(
|
||||||
|
input_string: str,
|
||||||
|
thinking_start_time: float,
|
||||||
|
chinese_time: float = 0.2,
|
||||||
|
english_time: float = 0.1,
|
||||||
|
is_emoji: bool = False,
|
||||||
|
) -> float:
|
||||||
"""
|
"""
|
||||||
计算输入字符串所需的时间,中文和英文字符有不同的输入时间
|
计算输入字符串所需的时间,中文和英文字符有不同的输入时间
|
||||||
input_string (str): 输入的字符串
|
input_string (str): 输入的字符串
|
||||||
@@ -393,19 +398,18 @@ def calculate_typing_time(input_string: str, thinking_start_time: float, chinese
|
|||||||
total_time += chinese_time
|
total_time += chinese_time
|
||||||
else: # 其他字符(如英文)
|
else: # 其他字符(如英文)
|
||||||
total_time += english_time
|
total_time += english_time
|
||||||
|
|
||||||
|
|
||||||
if is_emoji:
|
if is_emoji:
|
||||||
total_time = 1
|
total_time = 1
|
||||||
|
|
||||||
if time.time() - thinking_start_time > 10:
|
if time.time() - thinking_start_time > 10:
|
||||||
total_time = 1
|
total_time = 1
|
||||||
|
|
||||||
# print(f"thinking_start_time:{thinking_start_time}")
|
# print(f"thinking_start_time:{thinking_start_time}")
|
||||||
# print(f"nowtime:{time.time()}")
|
# print(f"nowtime:{time.time()}")
|
||||||
# print(f"nowtime - thinking_start_time:{time.time() - thinking_start_time}")
|
# print(f"nowtime - thinking_start_time:{time.time() - thinking_start_time}")
|
||||||
# print(f"{total_time}")
|
# print(f"{total_time}")
|
||||||
|
|
||||||
return total_time # 加上回车时间
|
return total_time # 加上回车时间
|
||||||
|
|
||||||
|
|
||||||
@@ -535,39 +539,32 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
|
|||||||
try:
|
try:
|
||||||
# 获取开始时间之前最新的一条消息
|
# 获取开始时间之前最新的一条消息
|
||||||
start_message = db.messages.find_one(
|
start_message = db.messages.find_one(
|
||||||
{
|
{"chat_id": stream_id, "time": {"$lte": start_time}},
|
||||||
"chat_id": stream_id,
|
sort=[("time", -1), ("_id", -1)], # 按时间倒序,_id倒序(最后插入的在前)
|
||||||
"time": {"$lte": start_time}
|
|
||||||
},
|
|
||||||
sort=[("time", -1), ("_id", -1)] # 按时间倒序,_id倒序(最后插入的在前)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取结束时间最近的一条消息
|
# 获取结束时间最近的一条消息
|
||||||
# 先找到结束时间点的所有消息
|
# 先找到结束时间点的所有消息
|
||||||
end_time_messages = list(db.messages.find(
|
end_time_messages = list(
|
||||||
{
|
db.messages.find(
|
||||||
"chat_id": stream_id,
|
{"chat_id": stream_id, "time": {"$lte": end_time}},
|
||||||
"time": {"$lte": end_time}
|
sort=[("time", -1)], # 先按时间倒序
|
||||||
},
|
).limit(10)
|
||||||
sort=[("time", -1)] # 先按时间倒序
|
) # 限制查询数量,避免性能问题
|
||||||
).limit(10)) # 限制查询数量,避免性能问题
|
|
||||||
|
|
||||||
if not end_time_messages:
|
if not end_time_messages:
|
||||||
logger.warning(f"未找到结束时间 {end_time} 之前的消息")
|
logger.warning(f"未找到结束时间 {end_time} 之前的消息")
|
||||||
return 0, 0
|
return 0, 0
|
||||||
|
|
||||||
# 找到最大时间
|
# 找到最大时间
|
||||||
max_time = end_time_messages[0]["time"]
|
max_time = end_time_messages[0]["time"]
|
||||||
# 在最大时间的消息中找最后插入的(_id最大的)
|
# 在最大时间的消息中找最后插入的(_id最大的)
|
||||||
end_message = max(
|
end_message = max([msg for msg in end_time_messages if msg["time"] == max_time], key=lambda x: x["_id"])
|
||||||
[msg for msg in end_time_messages if msg["time"] == max_time],
|
|
||||||
key=lambda x: x["_id"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if not start_message:
|
if not start_message:
|
||||||
logger.warning(f"未找到开始时间 {start_time} 之前的消息")
|
logger.warning(f"未找到开始时间 {start_time} 之前的消息")
|
||||||
return 0, 0
|
return 0, 0
|
||||||
|
|
||||||
# 调试输出
|
# 调试输出
|
||||||
# print("\n=== 消息范围信息 ===")
|
# print("\n=== 消息范围信息 ===")
|
||||||
# print("Start message:", {
|
# print("Start message:", {
|
||||||
@@ -587,20 +584,16 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
|
|||||||
# 如果结束消息的时间等于开始时间,返回0
|
# 如果结束消息的时间等于开始时间,返回0
|
||||||
if end_message["time"] == start_message["time"]:
|
if end_message["time"] == start_message["time"]:
|
||||||
return 0, 0
|
return 0, 0
|
||||||
|
|
||||||
# 获取并打印这个时间范围内的所有消息
|
# 获取并打印这个时间范围内的所有消息
|
||||||
# print("\n=== 时间范围内的所有消息 ===")
|
# print("\n=== 时间范围内的所有消息 ===")
|
||||||
all_messages = list(db.messages.find(
|
all_messages = list(
|
||||||
{
|
db.messages.find(
|
||||||
"chat_id": stream_id,
|
{"chat_id": stream_id, "time": {"$gte": start_message["time"], "$lte": end_message["time"]}},
|
||||||
"time": {
|
sort=[("time", 1), ("_id", 1)], # 按时间正序,_id正序
|
||||||
"$gte": start_message["time"],
|
)
|
||||||
"$lte": end_message["time"]
|
)
|
||||||
}
|
|
||||||
},
|
|
||||||
sort=[("time", 1), ("_id", 1)] # 按时间正序,_id正序
|
|
||||||
))
|
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
total_length = 0
|
total_length = 0
|
||||||
for msg in all_messages:
|
for msg in all_messages:
|
||||||
@@ -615,10 +608,10 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
|
|||||||
# "text_length": text_length,
|
# "text_length": text_length,
|
||||||
# "_id": str(msg.get("_id"))
|
# "_id": str(msg.get("_id"))
|
||||||
# })
|
# })
|
||||||
|
|
||||||
# 如果时间不同,需要把end_message本身也计入
|
# 如果时间不同,需要把end_message本身也计入
|
||||||
return count - 1, total_length
|
return count - 1, total_length
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"计算消息数量时出错: {str(e)}")
|
logger.error(f"计算消息数量时出错: {str(e)}")
|
||||||
return 0, 0
|
return 0, 0
|
||||||
|
|||||||
@@ -239,13 +239,13 @@ class ImageManager:
|
|||||||
# 解码base64
|
# 解码base64
|
||||||
gif_data = base64.b64decode(gif_base64)
|
gif_data = base64.b64decode(gif_base64)
|
||||||
gif = Image.open(io.BytesIO(gif_data))
|
gif = Image.open(io.BytesIO(gif_data))
|
||||||
|
|
||||||
# 收集所有帧
|
# 收集所有帧
|
||||||
frames = []
|
frames = []
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
gif.seek(len(frames))
|
gif.seek(len(frames))
|
||||||
frame = gif.convert('RGB')
|
frame = gif.convert("RGB")
|
||||||
frames.append(frame.copy())
|
frames.append(frame.copy())
|
||||||
except EOFError:
|
except EOFError:
|
||||||
pass
|
pass
|
||||||
@@ -264,18 +264,19 @@ class ImageManager:
|
|||||||
|
|
||||||
# 获取单帧的尺寸
|
# 获取单帧的尺寸
|
||||||
frame_width, frame_height = selected_frames[0].size
|
frame_width, frame_height = selected_frames[0].size
|
||||||
|
|
||||||
# 计算目标尺寸,保持宽高比
|
# 计算目标尺寸,保持宽高比
|
||||||
target_height = 200 # 固定高度
|
target_height = 200 # 固定高度
|
||||||
target_width = int((target_height / frame_height) * frame_width)
|
target_width = int((target_height / frame_height) * frame_width)
|
||||||
|
|
||||||
# 调整所有帧的大小
|
# 调整所有帧的大小
|
||||||
resized_frames = [frame.resize((target_width, target_height), Image.Resampling.LANCZOS)
|
resized_frames = [
|
||||||
for frame in selected_frames]
|
frame.resize((target_width, target_height), Image.Resampling.LANCZOS) for frame in selected_frames
|
||||||
|
]
|
||||||
|
|
||||||
# 创建拼接图像
|
# 创建拼接图像
|
||||||
total_width = target_width * len(resized_frames)
|
total_width = target_width * len(resized_frames)
|
||||||
combined_image = Image.new('RGB', (total_width, target_height))
|
combined_image = Image.new("RGB", (total_width, target_height))
|
||||||
|
|
||||||
# 水平拼接图像
|
# 水平拼接图像
|
||||||
for idx, frame in enumerate(resized_frames):
|
for idx, frame in enumerate(resized_frames):
|
||||||
@@ -283,11 +284,11 @@ class ImageManager:
|
|||||||
|
|
||||||
# 转换为base64
|
# 转换为base64
|
||||||
buffer = io.BytesIO()
|
buffer = io.BytesIO()
|
||||||
combined_image.save(buffer, format='JPEG', quality=85)
|
combined_image.save(buffer, format="JPEG", quality=85)
|
||||||
result_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
return result_base64
|
return result_base64
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"GIF转换失败: {str(e)}")
|
logger.error(f"GIF转换失败: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -7,12 +7,13 @@ from datetime import datetime
|
|||||||
|
|
||||||
logger = get_module_logger("pfc_message_processor")
|
logger = get_module_logger("pfc_message_processor")
|
||||||
|
|
||||||
|
|
||||||
class MessageProcessor:
|
class MessageProcessor:
|
||||||
"""消息处理器,负责处理接收到的消息并存储"""
|
"""消息处理器,负责处理接收到的消息并存储"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.storage = MessageStorage()
|
self.storage = MessageStorage()
|
||||||
|
|
||||||
def _check_ban_words(self, text: str, chat, userinfo) -> bool:
|
def _check_ban_words(self, text: str, chat, userinfo) -> bool:
|
||||||
"""检查消息中是否包含过滤词"""
|
"""检查消息中是否包含过滤词"""
|
||||||
for word in global_config.ban_words:
|
for word in global_config.ban_words:
|
||||||
@@ -34,10 +35,10 @@ class MessageProcessor:
|
|||||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def process_message(self, message: MessageRecv) -> None:
|
async def process_message(self, message: MessageRecv) -> None:
|
||||||
"""处理消息并存储
|
"""处理消息并存储
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: 消息对象
|
message: 消息对象
|
||||||
"""
|
"""
|
||||||
@@ -55,12 +56,9 @@ class MessageProcessor:
|
|||||||
|
|
||||||
# 存储消息
|
# 存储消息
|
||||||
await self.storage.store_message(message, chat)
|
await self.storage.store_message(message, chat)
|
||||||
|
|
||||||
# 打印消息信息
|
# 打印消息信息
|
||||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||||
# 将时间戳转换为datetime对象
|
# 将时间戳转换为datetime对象
|
||||||
current_time = datetime.fromtimestamp(message.message_info.time).strftime("%H:%M:%S")
|
current_time = datetime.fromtimestamp(message.message_info.time).strftime("%H:%M:%S")
|
||||||
logger.info(
|
logger.info(f"[{current_time}][{mes_name}]{chat.user_info.user_nickname}: {message.processed_plain_text}")
|
||||||
f"[{current_time}][{mes_name}]"
|
|
||||||
f"{chat.user_info.user_nickname}: {message.processed_plain_text}"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ chat_config = LogConfig(
|
|||||||
|
|
||||||
logger = get_module_logger("reasoning_chat", config=chat_config)
|
logger = get_module_logger("reasoning_chat", config=chat_config)
|
||||||
|
|
||||||
|
|
||||||
class ReasoningChat:
|
class ReasoningChat:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.storage = MessageStorage()
|
self.storage = MessageStorage()
|
||||||
@@ -224,13 +225,13 @@ class ReasoningChat:
|
|||||||
do_reply = False
|
do_reply = False
|
||||||
if random() < reply_probability:
|
if random() < reply_probability:
|
||||||
do_reply = True
|
do_reply = True
|
||||||
|
|
||||||
# 创建思考消息
|
# 创建思考消息
|
||||||
timer1 = time.time()
|
timer1 = time.time()
|
||||||
thinking_id = await self._create_thinking_message(message, chat, userinfo, messageinfo)
|
thinking_id = await self._create_thinking_message(message, chat, userinfo, messageinfo)
|
||||||
timer2 = time.time()
|
timer2 = time.time()
|
||||||
timing_results["创建思考消息"] = timer2 - timer1
|
timing_results["创建思考消息"] = timer2 - timer1
|
||||||
|
|
||||||
# 生成回复
|
# 生成回复
|
||||||
timer1 = time.time()
|
timer1 = time.time()
|
||||||
response_set = await self.gpt.generate_response(message)
|
response_set = await self.gpt.generate_response(message)
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class ResponseGenerator:
|
|||||||
|
|
||||||
async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
|
async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
|
||||||
"""根据当前模型类型选择对应的生成函数"""
|
"""根据当前模型类型选择对应的生成函数"""
|
||||||
#从global_config中获取模型概率值并选择模型
|
# 从global_config中获取模型概率值并选择模型
|
||||||
if random.random() < global_config.MODEL_R1_PROBABILITY:
|
if random.random() < global_config.MODEL_R1_PROBABILITY:
|
||||||
self.current_model_type = "深深地"
|
self.current_model_type = "深深地"
|
||||||
current_model = self.model_reasoning
|
current_model = self.model_reasoning
|
||||||
@@ -51,7 +51,6 @@ class ResponseGenerator:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
|
f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
|
||||||
) # noqa: E501
|
) # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
model_response = await self._generate_response_with_model(message, current_model)
|
model_response = await self._generate_response_with_model(message, current_model)
|
||||||
|
|
||||||
@@ -189,4 +188,4 @@ class ResponseGenerator:
|
|||||||
|
|
||||||
# print(f"得到了处理后的llm返回{processed_response}")
|
# print(f"得到了处理后的llm返回{processed_response}")
|
||||||
|
|
||||||
return processed_response
|
return processed_response
|
||||||
|
|||||||
@@ -24,35 +24,32 @@ class PromptBuilder:
|
|||||||
async def _build_prompt(
|
async def _build_prompt(
|
||||||
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
|
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
|
|
||||||
# 开始构建prompt
|
# 开始构建prompt
|
||||||
prompt_personality = "你"
|
prompt_personality = "你"
|
||||||
#person
|
# person
|
||||||
individuality = Individuality.get_instance()
|
individuality = Individuality.get_instance()
|
||||||
|
|
||||||
personality_core = individuality.personality.personality_core
|
personality_core = individuality.personality.personality_core
|
||||||
prompt_personality += personality_core
|
prompt_personality += personality_core
|
||||||
|
|
||||||
personality_sides = individuality.personality.personality_sides
|
personality_sides = individuality.personality.personality_sides
|
||||||
random.shuffle(personality_sides)
|
random.shuffle(personality_sides)
|
||||||
prompt_personality += f",{personality_sides[0]}"
|
prompt_personality += f",{personality_sides[0]}"
|
||||||
|
|
||||||
identity_detail = individuality.identity.identity_detail
|
identity_detail = individuality.identity.identity_detail
|
||||||
random.shuffle(identity_detail)
|
random.shuffle(identity_detail)
|
||||||
prompt_personality += f",{identity_detail[0]}"
|
prompt_personality += f",{identity_detail[0]}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 关系
|
# 关系
|
||||||
who_chat_in_group = [(chat_stream.user_info.platform,
|
who_chat_in_group = [
|
||||||
chat_stream.user_info.user_id,
|
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
|
||||||
chat_stream.user_info.user_nickname)]
|
]
|
||||||
who_chat_in_group += get_recent_group_speaker(
|
who_chat_in_group += get_recent_group_speaker(
|
||||||
stream_id,
|
stream_id,
|
||||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id),
|
(chat_stream.user_info.platform, chat_stream.user_info.user_id),
|
||||||
limit=global_config.MAX_CONTEXT_SIZE,
|
limit=global_config.MAX_CONTEXT_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
relation_prompt = ""
|
relation_prompt = ""
|
||||||
for person in who_chat_in_group:
|
for person in who_chat_in_group:
|
||||||
relation_prompt += await relationship_manager.build_relationship_info(person)
|
relation_prompt += await relationship_manager.build_relationship_info(person)
|
||||||
@@ -67,7 +64,7 @@ class PromptBuilder:
|
|||||||
mood_prompt = mood_manager.get_prompt()
|
mood_prompt = mood_manager.get_prompt()
|
||||||
|
|
||||||
# logger.info(f"心情prompt: {mood_prompt}")
|
# logger.info(f"心情prompt: {mood_prompt}")
|
||||||
|
|
||||||
# 调取记忆
|
# 调取记忆
|
||||||
memory_prompt = ""
|
memory_prompt = ""
|
||||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||||
@@ -84,7 +81,7 @@ class PromptBuilder:
|
|||||||
# print(f"相关记忆:{related_memory_info}")
|
# print(f"相关记忆:{related_memory_info}")
|
||||||
|
|
||||||
# 日程构建
|
# 日程构建
|
||||||
schedule_prompt = f'''你现在正在做的事情是:{bot_schedule.get_current_num_task(num = 1,time_info = False)}'''
|
schedule_prompt = f"""你现在正在做的事情是:{bot_schedule.get_current_num_task(num=1, time_info=False)}"""
|
||||||
|
|
||||||
# 获取聊天上下文
|
# 获取聊天上下文
|
||||||
chat_in_group = True
|
chat_in_group = True
|
||||||
@@ -143,7 +140,7 @@ class PromptBuilder:
|
|||||||
涉及政治敏感以及违法违规的内容请规避。"""
|
涉及政治敏感以及违法违规的内容请规避。"""
|
||||||
|
|
||||||
logger.info("开始构建prompt")
|
logger.info("开始构建prompt")
|
||||||
|
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
{relation_prompt_all}
|
{relation_prompt_all}
|
||||||
{memory_prompt}
|
{memory_prompt}
|
||||||
@@ -165,7 +162,7 @@ class PromptBuilder:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
related_info = ""
|
related_info = ""
|
||||||
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||||
|
|
||||||
# 1. 先从LLM获取主题,类似于记忆系统的做法
|
# 1. 先从LLM获取主题,类似于记忆系统的做法
|
||||||
topics = []
|
topics = []
|
||||||
# try:
|
# try:
|
||||||
@@ -173,7 +170,7 @@ class PromptBuilder:
|
|||||||
# hippocampus = HippocampusManager.get_instance()._hippocampus
|
# hippocampus = HippocampusManager.get_instance()._hippocampus
|
||||||
# topic_num = min(5, max(1, int(len(message) * 0.1)))
|
# topic_num = min(5, max(1, int(len(message) * 0.1)))
|
||||||
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
|
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
|
||||||
|
|
||||||
# # 提取关键词
|
# # 提取关键词
|
||||||
# topics = re.findall(r"<([^>]+)>", topics_response[0])
|
# topics = re.findall(r"<([^>]+)>", topics_response[0])
|
||||||
# if not topics:
|
# if not topics:
|
||||||
@@ -184,7 +181,7 @@ class PromptBuilder:
|
|||||||
# for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
# for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||||
# if topic.strip()
|
# if topic.strip()
|
||||||
# ]
|
# ]
|
||||||
|
|
||||||
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
|
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
|
||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# logger.error(f"从LLM提取主题失败: {str(e)}")
|
# logger.error(f"从LLM提取主题失败: {str(e)}")
|
||||||
@@ -192,7 +189,7 @@ class PromptBuilder:
|
|||||||
# words = jieba.cut(message)
|
# words = jieba.cut(message)
|
||||||
# topics = [word for word in words if len(word) > 1][:5]
|
# topics = [word for word in words if len(word) > 1][:5]
|
||||||
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
|
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
|
||||||
|
|
||||||
# 如果无法提取到主题,直接使用整个消息
|
# 如果无法提取到主题,直接使用整个消息
|
||||||
if not topics:
|
if not topics:
|
||||||
logger.info("未能提取到任何主题,使用整个消息进行查询")
|
logger.info("未能提取到任何主题,使用整个消息进行查询")
|
||||||
@@ -200,26 +197,26 @@ class PromptBuilder:
|
|||||||
if not embedding:
|
if not embedding:
|
||||||
logger.error("获取消息嵌入向量失败")
|
logger.error("获取消息嵌入向量失败")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
|
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
|
||||||
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒")
|
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒")
|
||||||
return related_info
|
return related_info
|
||||||
|
|
||||||
# 2. 对每个主题进行知识库查询
|
# 2. 对每个主题进行知识库查询
|
||||||
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
|
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
|
||||||
|
|
||||||
# 优化:批量获取嵌入向量,减少API调用
|
# 优化:批量获取嵌入向量,减少API调用
|
||||||
embeddings = {}
|
embeddings = {}
|
||||||
topics_batch = [topic for topic in topics if len(topic) > 0]
|
topics_batch = [topic for topic in topics if len(topic) > 0]
|
||||||
if message: # 确保消息非空
|
if message: # 确保消息非空
|
||||||
topics_batch.append(message)
|
topics_batch.append(message)
|
||||||
|
|
||||||
# 批量获取嵌入向量
|
# 批量获取嵌入向量
|
||||||
embed_start_time = time.time()
|
embed_start_time = time.time()
|
||||||
for text in topics_batch:
|
for text in topics_batch:
|
||||||
if not text or len(text.strip()) == 0:
|
if not text or len(text.strip()) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
embedding = await get_embedding(text, request_type="prompt_build")
|
embedding = await get_embedding(text, request_type="prompt_build")
|
||||||
if embedding:
|
if embedding:
|
||||||
@@ -228,17 +225,17 @@ class PromptBuilder:
|
|||||||
logger.warning(f"获取'{text}'的嵌入向量失败")
|
logger.warning(f"获取'{text}'的嵌入向量失败")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
|
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
|
||||||
|
|
||||||
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒")
|
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒")
|
||||||
|
|
||||||
if not embeddings:
|
if not embeddings:
|
||||||
logger.error("所有嵌入向量获取失败")
|
logger.error("所有嵌入向量获取失败")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# 3. 对每个主题进行知识库查询
|
# 3. 对每个主题进行知识库查询
|
||||||
all_results = []
|
all_results = []
|
||||||
query_start_time = time.time()
|
query_start_time = time.time()
|
||||||
|
|
||||||
# 首先添加原始消息的查询结果
|
# 首先添加原始消息的查询结果
|
||||||
if message in embeddings:
|
if message in embeddings:
|
||||||
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
|
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
|
||||||
@@ -247,12 +244,12 @@ class PromptBuilder:
|
|||||||
result["topic"] = "原始消息"
|
result["topic"] = "原始消息"
|
||||||
all_results.extend(original_results)
|
all_results.extend(original_results)
|
||||||
logger.info(f"原始消息查询到{len(original_results)}条结果")
|
logger.info(f"原始消息查询到{len(original_results)}条结果")
|
||||||
|
|
||||||
# 然后添加每个主题的查询结果
|
# 然后添加每个主题的查询结果
|
||||||
for topic in topics:
|
for topic in topics:
|
||||||
if not topic or topic not in embeddings:
|
if not topic or topic not in embeddings:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
|
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
|
||||||
if topic_results:
|
if topic_results:
|
||||||
@@ -263,9 +260,9 @@ class PromptBuilder:
|
|||||||
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
|
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
|
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
|
||||||
|
|
||||||
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
|
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
|
||||||
|
|
||||||
# 4. 去重和过滤
|
# 4. 去重和过滤
|
||||||
process_start_time = time.time()
|
process_start_time = time.time()
|
||||||
unique_contents = set()
|
unique_contents = set()
|
||||||
@@ -275,14 +272,16 @@ class PromptBuilder:
|
|||||||
if content not in unique_contents:
|
if content not in unique_contents:
|
||||||
unique_contents.add(content)
|
unique_contents.add(content)
|
||||||
filtered_results.append(result)
|
filtered_results.append(result)
|
||||||
|
|
||||||
# 5. 按相似度排序
|
# 5. 按相似度排序
|
||||||
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
|
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
|
||||||
|
|
||||||
# 6. 限制总数量(最多10条)
|
# 6. 限制总数量(最多10条)
|
||||||
filtered_results = filtered_results[:10]
|
filtered_results = filtered_results[:10]
|
||||||
logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果")
|
logger.info(
|
||||||
|
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
|
||||||
|
)
|
||||||
|
|
||||||
# 7. 格式化输出
|
# 7. 格式化输出
|
||||||
if filtered_results:
|
if filtered_results:
|
||||||
format_start_time = time.time()
|
format_start_time = time.time()
|
||||||
@@ -292,7 +291,7 @@ class PromptBuilder:
|
|||||||
if topic not in grouped_results:
|
if topic not in grouped_results:
|
||||||
grouped_results[topic] = []
|
grouped_results[topic] = []
|
||||||
grouped_results[topic].append(result)
|
grouped_results[topic].append(result)
|
||||||
|
|
||||||
# 按主题组织输出
|
# 按主题组织输出
|
||||||
for topic, results in grouped_results.items():
|
for topic, results in grouped_results.items():
|
||||||
related_info += f"【主题: {topic}】\n"
|
related_info += f"【主题: {topic}】\n"
|
||||||
@@ -303,13 +302,15 @@ class PromptBuilder:
|
|||||||
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
|
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
|
||||||
related_info += f"{content}\n"
|
related_info += f"{content}\n"
|
||||||
related_info += "\n"
|
related_info += "\n"
|
||||||
|
|
||||||
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
|
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
|
||||||
|
|
||||||
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
|
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
|
||||||
return related_info
|
return related_info
|
||||||
|
|
||||||
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
|
def get_info_from_db(
|
||||||
|
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||||
|
) -> Union[str, list]:
|
||||||
if not query_embedding:
|
if not query_embedding:
|
||||||
return "" if not return_raw else []
|
return "" if not return_raw else []
|
||||||
# 使用余弦相似度计算
|
# 使用余弦相似度计算
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ chat_config = LogConfig(
|
|||||||
|
|
||||||
logger = get_module_logger("think_flow_chat", config=chat_config)
|
logger = get_module_logger("think_flow_chat", config=chat_config)
|
||||||
|
|
||||||
|
|
||||||
class ThinkFlowChat:
|
class ThinkFlowChat:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.storage = MessageStorage()
|
self.storage = MessageStorage()
|
||||||
@@ -96,7 +97,7 @@ class ThinkFlowChat:
|
|||||||
)
|
)
|
||||||
if not mark_head:
|
if not mark_head:
|
||||||
mark_head = True
|
mark_head = True
|
||||||
|
|
||||||
# print(f"thinking_start_time:{bot_message.thinking_start_time}")
|
# print(f"thinking_start_time:{bot_message.thinking_start_time}")
|
||||||
message_set.add_message(bot_message)
|
message_set.add_message(bot_message)
|
||||||
message_manager.add_message(message_set)
|
message_manager.add_message(message_set)
|
||||||
@@ -110,7 +111,7 @@ class ThinkFlowChat:
|
|||||||
if emoji_raw:
|
if emoji_raw:
|
||||||
emoji_path, description = emoji_raw
|
emoji_path, description = emoji_raw
|
||||||
emoji_cq = image_path_to_base64(emoji_path)
|
emoji_cq = image_path_to_base64(emoji_path)
|
||||||
|
|
||||||
# logger.info(emoji_cq)
|
# logger.info(emoji_cq)
|
||||||
|
|
||||||
thinking_time_point = round(message.message_info.time, 2)
|
thinking_time_point = round(message.message_info.time, 2)
|
||||||
@@ -130,7 +131,7 @@ class ThinkFlowChat:
|
|||||||
is_head=False,
|
is_head=False,
|
||||||
is_emoji=True,
|
is_emoji=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# logger.info("22222222222222")
|
# logger.info("22222222222222")
|
||||||
message_manager.add_message(bot_message)
|
message_manager.add_message(bot_message)
|
||||||
|
|
||||||
@@ -180,7 +181,7 @@ class ThinkFlowChat:
|
|||||||
|
|
||||||
await message.process()
|
await message.process()
|
||||||
logger.debug(f"消息处理成功{message.processed_plain_text}")
|
logger.debug(f"消息处理成功{message.processed_plain_text}")
|
||||||
|
|
||||||
# 过滤词/正则表达式过滤
|
# 过滤词/正则表达式过滤
|
||||||
if self._check_ban_words(message.processed_plain_text, chat, userinfo) or self._check_ban_regex(
|
if self._check_ban_words(message.processed_plain_text, chat, userinfo) or self._check_ban_regex(
|
||||||
message.raw_message, chat, userinfo
|
message.raw_message, chat, userinfo
|
||||||
@@ -190,7 +191,7 @@ class ThinkFlowChat:
|
|||||||
|
|
||||||
await self.storage.store_message(message, chat)
|
await self.storage.store_message(message, chat)
|
||||||
logger.debug(f"存储成功{message.processed_plain_text}")
|
logger.debug(f"存储成功{message.processed_plain_text}")
|
||||||
|
|
||||||
# 记忆激活
|
# 记忆激活
|
||||||
timer1 = time.time()
|
timer1 = time.time()
|
||||||
interested_rate = await HippocampusManager.get_instance().get_activate_from_text(
|
interested_rate = await HippocampusManager.get_instance().get_activate_from_text(
|
||||||
@@ -214,15 +215,13 @@ class ThinkFlowChat:
|
|||||||
# 处理提及
|
# 处理提及
|
||||||
is_mentioned, reply_probability = is_mentioned_bot_in_message(message)
|
is_mentioned, reply_probability = is_mentioned_bot_in_message(message)
|
||||||
|
|
||||||
|
|
||||||
# 计算回复意愿
|
# 计算回复意愿
|
||||||
current_willing_old = willing_manager.get_willing(chat_stream=chat)
|
current_willing_old = willing_manager.get_willing(chat_stream=chat)
|
||||||
# current_willing_new = (heartflow.get_subheartflow(chat.stream_id).current_state.willing - 5) / 4
|
# current_willing_new = (heartflow.get_subheartflow(chat.stream_id).current_state.willing - 5) / 4
|
||||||
# current_willing = (current_willing_old + current_willing_new) / 2
|
# current_willing = (current_willing_old + current_willing_new) / 2
|
||||||
# 有点bug
|
# 有点bug
|
||||||
current_willing = current_willing_old
|
current_willing = current_willing_old
|
||||||
|
|
||||||
|
|
||||||
willing_manager.set_willing(chat.stream_id, current_willing)
|
willing_manager.set_willing(chat.stream_id, current_willing)
|
||||||
|
|
||||||
# 意愿激活
|
# 意愿激活
|
||||||
@@ -258,7 +257,7 @@ class ThinkFlowChat:
|
|||||||
if random() < reply_probability:
|
if random() < reply_probability:
|
||||||
try:
|
try:
|
||||||
do_reply = True
|
do_reply = True
|
||||||
|
|
||||||
# 创建思考消息
|
# 创建思考消息
|
||||||
try:
|
try:
|
||||||
timer1 = time.time()
|
timer1 = time.time()
|
||||||
@@ -267,9 +266,9 @@ class ThinkFlowChat:
|
|||||||
timing_results["创建思考消息"] = timer2 - timer1
|
timing_results["创建思考消息"] = timer2 - timer1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"心流创建思考消息失败: {e}")
|
logger.error(f"心流创建思考消息失败: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 观察
|
# 观察
|
||||||
timer1 = time.time()
|
timer1 = time.time()
|
||||||
await heartflow.get_subheartflow(chat.stream_id).do_observe()
|
await heartflow.get_subheartflow(chat.stream_id).do_observe()
|
||||||
timer2 = time.time()
|
timer2 = time.time()
|
||||||
@@ -280,12 +279,14 @@ class ThinkFlowChat:
|
|||||||
# 思考前脑内状态
|
# 思考前脑内状态
|
||||||
try:
|
try:
|
||||||
timer1 = time.time()
|
timer1 = time.time()
|
||||||
await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(message.processed_plain_text)
|
await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(
|
||||||
|
message.processed_plain_text
|
||||||
|
)
|
||||||
timer2 = time.time()
|
timer2 = time.time()
|
||||||
timing_results["思考前脑内状态"] = timer2 - timer1
|
timing_results["思考前脑内状态"] = timer2 - timer1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"心流思考前脑内状态失败: {e}")
|
logger.error(f"心流思考前脑内状态失败: {e}")
|
||||||
|
|
||||||
# 生成回复
|
# 生成回复
|
||||||
timer1 = time.time()
|
timer1 = time.time()
|
||||||
response_set = await self.gpt.generate_response(message)
|
response_set = await self.gpt.generate_response(message)
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ class ResponseGenerator:
|
|||||||
async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
|
async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
|
||||||
"""根据当前模型类型选择对应的生成函数"""
|
"""根据当前模型类型选择对应的生成函数"""
|
||||||
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
|
f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
|
||||||
)
|
)
|
||||||
@@ -178,4 +177,3 @@ class ResponseGenerator:
|
|||||||
# print(f"得到了处理后的llm返回{processed_response}")
|
# print(f"得到了处理后的llm返回{processed_response}")
|
||||||
|
|
||||||
return processed_response
|
return processed_response
|
||||||
|
|
||||||
|
|||||||
@@ -21,22 +21,21 @@ class PromptBuilder:
|
|||||||
async def _build_prompt(
|
async def _build_prompt(
|
||||||
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
|
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
|
|
||||||
current_mind_info = heartflow.get_subheartflow(stream_id).current_mind
|
current_mind_info = heartflow.get_subheartflow(stream_id).current_mind
|
||||||
|
|
||||||
individuality = Individuality.get_instance()
|
individuality = Individuality.get_instance()
|
||||||
prompt_personality = individuality.get_prompt(type = "personality",x_person = 2,level = 1)
|
prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1)
|
||||||
prompt_identity = individuality.get_prompt(type = "identity",x_person = 2,level = 1)
|
prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1)
|
||||||
# 关系
|
# 关系
|
||||||
who_chat_in_group = [(chat_stream.user_info.platform,
|
who_chat_in_group = [
|
||||||
chat_stream.user_info.user_id,
|
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
|
||||||
chat_stream.user_info.user_nickname)]
|
]
|
||||||
who_chat_in_group += get_recent_group_speaker(
|
who_chat_in_group += get_recent_group_speaker(
|
||||||
stream_id,
|
stream_id,
|
||||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id),
|
(chat_stream.user_info.platform, chat_stream.user_info.user_id),
|
||||||
limit=global_config.MAX_CONTEXT_SIZE,
|
limit=global_config.MAX_CONTEXT_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
relation_prompt = ""
|
relation_prompt = ""
|
||||||
for person in who_chat_in_group:
|
for person in who_chat_in_group:
|
||||||
relation_prompt += await relationship_manager.build_relationship_info(person)
|
relation_prompt += await relationship_manager.build_relationship_info(person)
|
||||||
@@ -100,7 +99,7 @@ class PromptBuilder:
|
|||||||
涉及政治敏感以及违法违规的内容请规避。"""
|
涉及政治敏感以及违法违规的内容请规避。"""
|
||||||
|
|
||||||
logger.info("开始构建prompt")
|
logger.info("开始构建prompt")
|
||||||
|
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
{relation_prompt_all}\n
|
{relation_prompt_all}\n
|
||||||
{chat_target}
|
{chat_target}
|
||||||
@@ -114,7 +113,7 @@ class PromptBuilder:
|
|||||||
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话
|
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话
|
||||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。"""
|
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。"""
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import tomlkit
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
def update_config():
|
def update_config():
|
||||||
print("开始更新配置文件...")
|
print("开始更新配置文件...")
|
||||||
# 获取根目录路径
|
# 获取根目录路径
|
||||||
@@ -25,11 +26,11 @@ def update_config():
|
|||||||
print(f"发现旧配置文件: {old_config_path}")
|
print(f"发现旧配置文件: {old_config_path}")
|
||||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||||
old_config = tomlkit.load(f)
|
old_config = tomlkit.load(f)
|
||||||
|
|
||||||
# 生成带时间戳的新文件名
|
# 生成带时间戳的新文件名
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
|
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
|
||||||
|
|
||||||
# 移动旧配置文件到old目录
|
# 移动旧配置文件到old目录
|
||||||
shutil.move(old_config_path, old_backup_path)
|
shutil.move(old_config_path, old_backup_path)
|
||||||
print(f"已备份旧配置文件到: {old_backup_path}")
|
print(f"已备份旧配置文件到: {old_backup_path}")
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ logger = get_module_logger("config", config=config_config)
|
|||||||
is_test = True
|
is_test = True
|
||||||
mai_version_main = "0.6.2"
|
mai_version_main = "0.6.2"
|
||||||
mai_version_fix = "snapshot-1"
|
mai_version_fix = "snapshot-1"
|
||||||
|
|
||||||
if mai_version_fix:
|
if mai_version_fix:
|
||||||
if is_test:
|
if is_test:
|
||||||
mai_version = f"test-{mai_version_main}-{mai_version_fix}"
|
mai_version = f"test-{mai_version_main}-{mai_version_fix}"
|
||||||
@@ -39,6 +40,7 @@ else:
|
|||||||
else:
|
else:
|
||||||
mai_version = mai_version_main
|
mai_version = mai_version_main
|
||||||
|
|
||||||
|
|
||||||
def update_config():
|
def update_config():
|
||||||
# 获取根目录路径
|
# 获取根目录路径
|
||||||
root_dir = Path(__file__).parent.parent.parent.parent
|
root_dir = Path(__file__).parent.parent.parent.parent
|
||||||
@@ -54,7 +56,7 @@ def update_config():
|
|||||||
# 检查配置文件是否存在
|
# 检查配置文件是否存在
|
||||||
if not old_config_path.exists():
|
if not old_config_path.exists():
|
||||||
logger.info("配置文件不存在,从模板创建新配置")
|
logger.info("配置文件不存在,从模板创建新配置")
|
||||||
#创建文件夹
|
# 创建文件夹
|
||||||
old_config_dir.mkdir(parents=True, exist_ok=True)
|
old_config_dir.mkdir(parents=True, exist_ok=True)
|
||||||
shutil.copy2(template_path, old_config_path)
|
shutil.copy2(template_path, old_config_path)
|
||||||
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
|
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
|
||||||
@@ -84,7 +86,7 @@ def update_config():
|
|||||||
# 生成带时间戳的新文件名
|
# 生成带时间戳的新文件名
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
|
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
|
||||||
|
|
||||||
# 移动旧配置文件到old目录
|
# 移动旧配置文件到old目录
|
||||||
shutil.move(old_config_path, old_backup_path)
|
shutil.move(old_config_path, old_backup_path)
|
||||||
logger.info(f"已备份旧配置文件到: {old_backup_path}")
|
logger.info(f"已备份旧配置文件到: {old_backup_path}")
|
||||||
@@ -127,6 +129,7 @@ def update_config():
|
|||||||
f.write(tomlkit.dumps(new_config))
|
f.write(tomlkit.dumps(new_config))
|
||||||
logger.info("配置文件更新完成")
|
logger.info("配置文件更新完成")
|
||||||
|
|
||||||
|
|
||||||
logger = get_module_logger("config")
|
logger = get_module_logger("config")
|
||||||
|
|
||||||
|
|
||||||
@@ -148,17 +151,21 @@ class BotConfig:
|
|||||||
ban_user_id = set()
|
ban_user_id = set()
|
||||||
|
|
||||||
# personality
|
# personality
|
||||||
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内,谁再写3000字小作文敲谁脑袋
|
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内,谁再写3000字小作文敲谁脑袋
|
||||||
personality_sides: List[str] = field(default_factory=lambda: [
|
personality_sides: List[str] = field(
|
||||||
"用一句话或几句话描述人格的一些侧面",
|
default_factory=lambda: [
|
||||||
"用一句话或几句话描述人格的一些侧面",
|
"用一句话或几句话描述人格的一些侧面",
|
||||||
"用一句话或几句话描述人格的一些侧面"
|
"用一句话或几句话描述人格的一些侧面",
|
||||||
])
|
"用一句话或几句话描述人格的一些侧面",
|
||||||
|
]
|
||||||
|
)
|
||||||
# identity
|
# identity
|
||||||
identity_detail: List[str] = field(default_factory=lambda: [
|
identity_detail: List[str] = field(
|
||||||
"身份特点",
|
default_factory=lambda: [
|
||||||
"身份特点",
|
"身份特点",
|
||||||
])
|
"身份特点",
|
||||||
|
]
|
||||||
|
)
|
||||||
height: int = 170 # 身高 单位厘米
|
height: int = 170 # 身高 单位厘米
|
||||||
weight: int = 50 # 体重 单位千克
|
weight: int = 50 # 体重 单位千克
|
||||||
age: int = 20 # 年龄 单位岁
|
age: int = 20 # 年龄 单位岁
|
||||||
@@ -181,22 +188,22 @@ class BotConfig:
|
|||||||
|
|
||||||
ban_words = set()
|
ban_words = set()
|
||||||
ban_msgs_regex = set()
|
ban_msgs_regex = set()
|
||||||
|
|
||||||
#heartflow
|
# heartflow
|
||||||
# enable_heartflow: bool = False # 是否启用心流
|
# enable_heartflow: bool = False # 是否启用心流
|
||||||
sub_heart_flow_update_interval: int = 60 # 子心流更新频率,间隔 单位秒
|
sub_heart_flow_update_interval: int = 60 # 子心流更新频率,间隔 单位秒
|
||||||
sub_heart_flow_freeze_time: int = 120 # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒
|
sub_heart_flow_freeze_time: int = 120 # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒
|
||||||
sub_heart_flow_stop_time: int = 600 # 子心流停止时间,超过这个时间没有回复,子心流会停止,间隔 单位秒
|
sub_heart_flow_stop_time: int = 600 # 子心流停止时间,超过这个时间没有回复,子心流会停止,间隔 单位秒
|
||||||
heart_flow_update_interval: int = 300 # 心流更新频率,间隔 单位秒
|
heart_flow_update_interval: int = 300 # 心流更新频率,间隔 单位秒
|
||||||
|
|
||||||
# willing
|
# willing
|
||||||
willing_mode: str = "classical" # 意愿模式
|
willing_mode: str = "classical" # 意愿模式
|
||||||
response_willing_amplifier: float = 1.0 # 回复意愿放大系数
|
response_willing_amplifier: float = 1.0 # 回复意愿放大系数
|
||||||
response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数
|
response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数
|
||||||
down_frequency_rate: float = 3 # 降低回复频率的群组回复意愿降低系数
|
down_frequency_rate: float = 3 # 降低回复频率的群组回复意愿降低系数
|
||||||
emoji_response_penalty: float = 0.0 # 表情包回复惩罚
|
emoji_response_penalty: float = 0.0 # 表情包回复惩罚
|
||||||
mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复
|
mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复
|
||||||
at_bot_inevitable_reply: bool = False # @bot 必然回复
|
at_bot_inevitable_reply: bool = False # @bot 必然回复
|
||||||
|
|
||||||
# response
|
# response
|
||||||
response_mode: str = "heart_flow" # 回复策略
|
response_mode: str = "heart_flow" # 回复策略
|
||||||
@@ -354,7 +361,6 @@ class BotConfig:
|
|||||||
"""从TOML配置文件加载配置"""
|
"""从TOML配置文件加载配置"""
|
||||||
config = cls()
|
config = cls()
|
||||||
|
|
||||||
|
|
||||||
def personality(parent: dict):
|
def personality(parent: dict):
|
||||||
personality_config = parent["personality"]
|
personality_config = parent["personality"]
|
||||||
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
|
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
|
||||||
@@ -418,13 +424,21 @@ class BotConfig:
|
|||||||
config.max_response_length = response_config.get("max_response_length", config.max_response_length)
|
config.max_response_length = response_config.get("max_response_length", config.max_response_length)
|
||||||
if config.INNER_VERSION in SpecifierSet(">=1.0.4"):
|
if config.INNER_VERSION in SpecifierSet(">=1.0.4"):
|
||||||
config.response_mode = response_config.get("response_mode", config.response_mode)
|
config.response_mode = response_config.get("response_mode", config.response_mode)
|
||||||
|
|
||||||
def heartflow(parent: dict):
|
def heartflow(parent: dict):
|
||||||
heartflow_config = parent["heartflow"]
|
heartflow_config = parent["heartflow"]
|
||||||
config.sub_heart_flow_update_interval = heartflow_config.get("sub_heart_flow_update_interval", config.sub_heart_flow_update_interval)
|
config.sub_heart_flow_update_interval = heartflow_config.get(
|
||||||
config.sub_heart_flow_freeze_time = heartflow_config.get("sub_heart_flow_freeze_time", config.sub_heart_flow_freeze_time)
|
"sub_heart_flow_update_interval", config.sub_heart_flow_update_interval
|
||||||
config.sub_heart_flow_stop_time = heartflow_config.get("sub_heart_flow_stop_time", config.sub_heart_flow_stop_time)
|
)
|
||||||
config.heart_flow_update_interval = heartflow_config.get("heart_flow_update_interval", config.heart_flow_update_interval)
|
config.sub_heart_flow_freeze_time = heartflow_config.get(
|
||||||
|
"sub_heart_flow_freeze_time", config.sub_heart_flow_freeze_time
|
||||||
|
)
|
||||||
|
config.sub_heart_flow_stop_time = heartflow_config.get(
|
||||||
|
"sub_heart_flow_stop_time", config.sub_heart_flow_stop_time
|
||||||
|
)
|
||||||
|
config.heart_flow_update_interval = heartflow_config.get(
|
||||||
|
"heart_flow_update_interval", config.heart_flow_update_interval
|
||||||
|
)
|
||||||
|
|
||||||
def willing(parent: dict):
|
def willing(parent: dict):
|
||||||
willing_config = parent["willing"]
|
willing_config = parent["willing"]
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from src.common.logger import get_module_logger, LogConfig, MEMORY_STYLE_CONFIG
|
|||||||
from src.plugins.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
|
from src.plugins.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
|
||||||
from .memory_config import MemoryConfig
|
from .memory_config import MemoryConfig
|
||||||
|
|
||||||
|
|
||||||
def get_closest_chat_from_db(length: int, timestamp: str):
|
def get_closest_chat_from_db(length: int, timestamp: str):
|
||||||
# print(f"获取最接近指定时间戳的聊天记录,长度: {length}, 时间戳: {timestamp}")
|
# print(f"获取最接近指定时间戳的聊天记录,长度: {length}, 时间戳: {timestamp}")
|
||||||
# print(f"当前时间: {timestamp},转换后时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))}")
|
# print(f"当前时间: {timestamp},转换后时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))}")
|
||||||
|
|||||||
@@ -179,7 +179,6 @@ class LLM_request:
|
|||||||
# logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
|
# logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
|
||||||
# logger.info(f"使用模型: {self.model_name}")
|
# logger.info(f"使用模型: {self.model_name}")
|
||||||
|
|
||||||
|
|
||||||
# 构建请求体
|
# 构建请求体
|
||||||
if image_base64:
|
if image_base64:
|
||||||
payload = await self._build_payload(prompt, image_base64, image_format)
|
payload = await self._build_payload(prompt, image_base64, image_format)
|
||||||
@@ -205,13 +204,17 @@ class LLM_request:
|
|||||||
# 处理需要重试的状态码
|
# 处理需要重试的状态码
|
||||||
if response.status in policy["retry_codes"]:
|
if response.status in policy["retry_codes"]:
|
||||||
wait_time = policy["base_wait"] * (2**retry)
|
wait_time = policy["base_wait"] * (2**retry)
|
||||||
logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
|
logger.warning(
|
||||||
|
f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试"
|
||||||
|
)
|
||||||
if response.status == 413:
|
if response.status == 413:
|
||||||
logger.warning("请求体过大,尝试压缩...")
|
logger.warning("请求体过大,尝试压缩...")
|
||||||
image_base64 = compress_base64_image_by_scale(image_base64)
|
image_base64 = compress_base64_image_by_scale(image_base64)
|
||||||
payload = await self._build_payload(prompt, image_base64, image_format)
|
payload = await self._build_payload(prompt, image_base64, image_format)
|
||||||
elif response.status in [500, 503]:
|
elif response.status in [500, 503]:
|
||||||
logger.error(f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}")
|
logger.error(
|
||||||
|
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
|
||||||
|
)
|
||||||
raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
|
raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
|
logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
|
||||||
@@ -219,7 +222,9 @@ class LLM_request:
|
|||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
continue
|
continue
|
||||||
elif response.status in policy["abort_codes"]:
|
elif response.status in policy["abort_codes"]:
|
||||||
logger.error(f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}")
|
logger.error(
|
||||||
|
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
|
||||||
|
)
|
||||||
# 尝试获取并记录服务器返回的详细错误信息
|
# 尝试获取并记录服务器返回的详细错误信息
|
||||||
try:
|
try:
|
||||||
error_json = await response.json()
|
error_json = await response.json()
|
||||||
@@ -257,7 +262,9 @@ class LLM_request:
|
|||||||
):
|
):
|
||||||
old_model_name = self.model_name
|
old_model_name = self.model_name
|
||||||
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
|
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
|
||||||
logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
|
logger.warning(
|
||||||
|
f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}"
|
||||||
|
)
|
||||||
|
|
||||||
# 对全局配置进行更新
|
# 对全局配置进行更新
|
||||||
if global_config.llm_normal.get("name") == old_model_name:
|
if global_config.llm_normal.get("name") == old_model_name:
|
||||||
@@ -266,7 +273,9 @@ class LLM_request:
|
|||||||
|
|
||||||
if global_config.llm_reasoning.get("name") == old_model_name:
|
if global_config.llm_reasoning.get("name") == old_model_name:
|
||||||
global_config.llm_reasoning["name"] = self.model_name
|
global_config.llm_reasoning["name"] = self.model_name
|
||||||
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
|
logger.warning(
|
||||||
|
f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}"
|
||||||
|
)
|
||||||
|
|
||||||
# 更新payload中的模型名
|
# 更新payload中的模型名
|
||||||
if payload and "model" in payload:
|
if payload and "model" in payload:
|
||||||
@@ -328,7 +337,14 @@ class LLM_request:
|
|||||||
await response.release()
|
await response.release()
|
||||||
# 返回已经累积的内容
|
# 返回已经累积的内容
|
||||||
result = {
|
result = {
|
||||||
"choices": [{"message": {"content": accumulated_content, "reasoning_content": reasoning_content}}],
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": accumulated_content,
|
||||||
|
"reasoning_content": reasoning_content,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
"usage": usage,
|
"usage": usage,
|
||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
@@ -345,7 +361,14 @@ class LLM_request:
|
|||||||
logger.error(f"清理资源时发生错误: {cleanup_error}")
|
logger.error(f"清理资源时发生错误: {cleanup_error}")
|
||||||
# 返回已经累积的内容
|
# 返回已经累积的内容
|
||||||
result = {
|
result = {
|
||||||
"choices": [{"message": {"content": accumulated_content, "reasoning_content": reasoning_content}}],
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": accumulated_content,
|
||||||
|
"reasoning_content": reasoning_content,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
"usage": usage,
|
"usage": usage,
|
||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
@@ -360,7 +383,9 @@ class LLM_request:
|
|||||||
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
|
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
|
||||||
# 构造一个伪result以便调用自定义响应处理器或默认处理器
|
# 构造一个伪result以便调用自定义响应处理器或默认处理器
|
||||||
result = {
|
result = {
|
||||||
"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}],
|
"choices": [
|
||||||
|
{"message": {"content": content, "reasoning_content": reasoning_content}}
|
||||||
|
],
|
||||||
"usage": usage,
|
"usage": usage,
|
||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
@@ -394,7 +419,9 @@ class LLM_request:
|
|||||||
# 处理aiohttp抛出的响应错误
|
# 处理aiohttp抛出的响应错误
|
||||||
if retry < policy["max_retries"] - 1:
|
if retry < policy["max_retries"] - 1:
|
||||||
wait_time = policy["base_wait"] * (2**retry)
|
wait_time = policy["base_wait"] * (2**retry)
|
||||||
logger.error(f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}")
|
logger.error(
|
||||||
|
f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
|
if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
|
||||||
error_text = await e.response.text()
|
error_text = await e.response.text()
|
||||||
@@ -419,13 +446,17 @@ class LLM_request:
|
|||||||
else:
|
else:
|
||||||
logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
|
logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
|
||||||
except (json.JSONDecodeError, TypeError) as json_err:
|
except (json.JSONDecodeError, TypeError) as json_err:
|
||||||
logger.warning(f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}")
|
logger.warning(
|
||||||
|
f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
|
||||||
|
)
|
||||||
except (AttributeError, TypeError, ValueError) as parse_err:
|
except (AttributeError, TypeError, ValueError) as parse_err:
|
||||||
logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
|
logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
|
||||||
|
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
logger.critical(f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}")
|
logger.critical(
|
||||||
|
f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}"
|
||||||
|
)
|
||||||
# 安全地检查和记录请求详情
|
# 安全地检查和记录请求详情
|
||||||
if (
|
if (
|
||||||
image_base64
|
image_base64
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ class MoodManager:
|
|||||||
# 神经质:影响情绪变化速度
|
# 神经质:影响情绪变化速度
|
||||||
neuroticism_factor = 1 + (personality.neuroticism - 0.5) * 0.5
|
neuroticism_factor = 1 + (personality.neuroticism - 0.5) * 0.5
|
||||||
agreeableness_factor = 1 + (personality.agreeableness - 0.5) * 0.5
|
agreeableness_factor = 1 + (personality.agreeableness - 0.5) * 0.5
|
||||||
|
|
||||||
# 宜人性:影响情绪基准线
|
# 宜人性:影响情绪基准线
|
||||||
if personality.agreeableness < 0.2:
|
if personality.agreeableness < 0.2:
|
||||||
agreeableness_bias = (personality.agreeableness - 0.2) * 2
|
agreeableness_bias = (personality.agreeableness - 0.2) * 2
|
||||||
@@ -151,7 +151,7 @@ class MoodManager:
|
|||||||
# 分别计算正向和负向的衰减率
|
# 分别计算正向和负向的衰减率
|
||||||
if self.current_mood.valence >= 0:
|
if self.current_mood.valence >= 0:
|
||||||
# 正向情绪衰减
|
# 正向情绪衰减
|
||||||
decay_rate_positive = self.decay_rate_valence * (1/agreeableness_factor)
|
decay_rate_positive = self.decay_rate_valence * (1 / agreeableness_factor)
|
||||||
valence_target = 0 + agreeableness_bias
|
valence_target = 0 + agreeableness_bias
|
||||||
self.current_mood.valence = valence_target + (self.current_mood.valence - valence_target) * math.exp(
|
self.current_mood.valence = valence_target + (self.current_mood.valence - valence_target) * math.exp(
|
||||||
-decay_rate_positive * time_diff * neuroticism_factor
|
-decay_rate_positive * time_diff * neuroticism_factor
|
||||||
@@ -279,8 +279,9 @@ class MoodManager:
|
|||||||
# 限制范围
|
# 限制范围
|
||||||
self.current_mood.valence = max(-1.0, min(1.0, self.current_mood.valence))
|
self.current_mood.valence = max(-1.0, min(1.0, self.current_mood.valence))
|
||||||
self.current_mood.arousal = max(0.0, min(1.0, self.current_mood.arousal))
|
self.current_mood.arousal = max(0.0, min(1.0, self.current_mood.arousal))
|
||||||
|
|
||||||
self._update_mood_text()
|
self._update_mood_text()
|
||||||
|
|
||||||
logger.info(f"[情绪变化] {emotion}(强度:{intensity:.2f}) | 愉悦度:{old_valence:.2f}->{self.current_mood.valence:.2f}, 唤醒度:{old_arousal:.2f}->{self.current_mood.arousal:.2f} | 心情:{old_mood}->{self.current_mood.text}")
|
logger.info(
|
||||||
|
f"[情绪变化] {emotion}(强度:{intensity:.2f}) | 愉悦度:{old_valence:.2f}->{self.current_mood.valence:.2f}, 唤醒度:{old_arousal:.2f}->{self.current_mood.arousal:.2f} | 心情:{old_mood}->{self.current_mood.text}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ import asyncio
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use('Agg')
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -30,38 +31,39 @@ PersonInfoManager 类方法功能摘要:
|
|||||||
logger = get_module_logger("person_info")
|
logger = get_module_logger("person_info")
|
||||||
|
|
||||||
person_info_default = {
|
person_info_default = {
|
||||||
"person_id" : None,
|
"person_id": None,
|
||||||
"platform" : None,
|
"platform": None,
|
||||||
"user_id" : None,
|
"user_id": None,
|
||||||
"nickname" : None,
|
"nickname": None,
|
||||||
# "age" : 0,
|
# "age" : 0,
|
||||||
"relationship_value" : 0,
|
"relationship_value": 0,
|
||||||
# "saved" : True,
|
# "saved" : True,
|
||||||
# "impression" : None,
|
# "impression" : None,
|
||||||
# "gender" : Unkown,
|
# "gender" : Unkown,
|
||||||
"konw_time" : 0,
|
"konw_time": 0,
|
||||||
"msg_interval": 3000,
|
"msg_interval": 3000,
|
||||||
"msg_interval_list": []
|
"msg_interval_list": [],
|
||||||
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
|
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
|
||||||
|
|
||||||
|
|
||||||
class PersonInfoManager:
|
class PersonInfoManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if "person_info" not in db.list_collection_names():
|
if "person_info" not in db.list_collection_names():
|
||||||
db.create_collection("person_info")
|
db.create_collection("person_info")
|
||||||
db.person_info.create_index("person_id", unique=True)
|
db.person_info.create_index("person_id", unique=True)
|
||||||
|
|
||||||
def get_person_id(self, platform:str, user_id:int):
|
def get_person_id(self, platform: str, user_id: int):
|
||||||
"""获取唯一id"""
|
"""获取唯一id"""
|
||||||
components = [platform, str(user_id)]
|
components = [platform, str(user_id)]
|
||||||
key = "_".join(components)
|
key = "_".join(components)
|
||||||
return hashlib.md5(key.encode()).hexdigest()
|
return hashlib.md5(key.encode()).hexdigest()
|
||||||
|
|
||||||
async def create_person_info(self, person_id:str, data:dict = None):
|
async def create_person_info(self, person_id: str, data: dict = None):
|
||||||
"""创建一个项"""
|
"""创建一个项"""
|
||||||
if not person_id:
|
if not person_id:
|
||||||
logger.debug("创建失败,personid不存在")
|
logger.debug("创建失败,personid不存在")
|
||||||
return
|
return
|
||||||
|
|
||||||
_person_info_default = copy.deepcopy(person_info_default)
|
_person_info_default = copy.deepcopy(person_info_default)
|
||||||
_person_info_default["person_id"] = person_id
|
_person_info_default["person_id"] = person_id
|
||||||
|
|
||||||
@@ -72,19 +74,16 @@ class PersonInfoManager:
|
|||||||
|
|
||||||
db.person_info.insert_one(_person_info_default)
|
db.person_info.insert_one(_person_info_default)
|
||||||
|
|
||||||
async def update_one_field(self, person_id:str, field_name:str, value, Data:dict = None):
|
async def update_one_field(self, person_id: str, field_name: str, value, Data: dict = None):
|
||||||
"""更新某一个字段,会补全"""
|
"""更新某一个字段,会补全"""
|
||||||
if field_name not in person_info_default.keys():
|
if field_name not in person_info_default.keys():
|
||||||
logger.debug(f"更新'{field_name}'失败,未定义的字段")
|
logger.debug(f"更新'{field_name}'失败,未定义的字段")
|
||||||
return
|
return
|
||||||
|
|
||||||
document = db.person_info.find_one({"person_id": person_id})
|
document = db.person_info.find_one({"person_id": person_id})
|
||||||
|
|
||||||
if document:
|
if document:
|
||||||
db.person_info.update_one(
|
db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}})
|
||||||
{"person_id": person_id},
|
|
||||||
{"$set": {field_name: value}}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
Data[field_name] = value
|
Data[field_name] = value
|
||||||
logger.debug(f"更新时{person_id}不存在,已新建")
|
logger.debug(f"更新时{person_id}不存在,已新建")
|
||||||
@@ -107,23 +106,20 @@ class PersonInfoManager:
|
|||||||
if not person_id:
|
if not person_id:
|
||||||
logger.debug("get_value获取失败:person_id不能为空")
|
logger.debug("get_value获取失败:person_id不能为空")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if field_name not in person_info_default:
|
if field_name not in person_info_default:
|
||||||
logger.debug(f"get_value获取失败:字段'{field_name}'未定义")
|
logger.debug(f"get_value获取失败:字段'{field_name}'未定义")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
document = db.person_info.find_one(
|
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
|
||||||
{"person_id": person_id},
|
|
||||||
{field_name: 1}
|
|
||||||
)
|
|
||||||
|
|
||||||
if document and field_name in document:
|
if document and field_name in document:
|
||||||
return document[field_name]
|
return document[field_name]
|
||||||
else:
|
else:
|
||||||
default_value = copy.deepcopy(person_info_default[field_name])
|
default_value = copy.deepcopy(person_info_default[field_name])
|
||||||
logger.debug(f"获取{person_id}的{field_name}失败,已返回默认值{default_value}")
|
logger.debug(f"获取{person_id}的{field_name}失败,已返回默认值{default_value}")
|
||||||
return default_value
|
return default_value
|
||||||
|
|
||||||
async def get_values(self, person_id: str, field_names: list) -> dict:
|
async def get_values(self, person_id: str, field_names: list) -> dict:
|
||||||
"""获取指定person_id文档的多个字段值,若不存在该字段,则返回该字段的全局默认值"""
|
"""获取指定person_id文档的多个字段值,若不存在该字段,则返回该字段的全局默认值"""
|
||||||
if not person_id:
|
if not person_id:
|
||||||
@@ -139,62 +135,57 @@ class PersonInfoManager:
|
|||||||
# 构建查询投影(所有字段都有效才会执行到这里)
|
# 构建查询投影(所有字段都有效才会执行到这里)
|
||||||
projection = {field: 1 for field in field_names}
|
projection = {field: 1 for field in field_names}
|
||||||
|
|
||||||
document = db.person_info.find_one(
|
document = db.person_info.find_one({"person_id": person_id}, projection)
|
||||||
{"person_id": person_id},
|
|
||||||
projection
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {}
|
result = {}
|
||||||
for field in field_names:
|
for field in field_names:
|
||||||
result[field] = copy.deepcopy(
|
result[field] = copy.deepcopy(
|
||||||
document.get(field, person_info_default[field])
|
document.get(field, person_info_default[field]) if document else person_info_default[field]
|
||||||
if document else person_info_default[field]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def del_all_undefined_field(self):
|
async def del_all_undefined_field(self):
|
||||||
"""删除所有项里的未定义字段"""
|
"""删除所有项里的未定义字段"""
|
||||||
# 获取所有已定义的字段名
|
# 获取所有已定义的字段名
|
||||||
defined_fields = set(person_info_default.keys())
|
defined_fields = set(person_info_default.keys())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 遍历集合中的所有文档
|
# 遍历集合中的所有文档
|
||||||
for document in db.person_info.find({}):
|
for document in db.person_info.find({}):
|
||||||
# 找出文档中未定义的字段
|
# 找出文档中未定义的字段
|
||||||
undefined_fields = set(document.keys()) - defined_fields - {'_id'}
|
undefined_fields = set(document.keys()) - defined_fields - {"_id"}
|
||||||
|
|
||||||
if undefined_fields:
|
if undefined_fields:
|
||||||
# 构建更新操作,使用$unset删除未定义字段
|
# 构建更新操作,使用$unset删除未定义字段
|
||||||
update_result = db.person_info.update_one(
|
update_result = db.person_info.update_one(
|
||||||
{'_id': document['_id']},
|
{"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}}
|
||||||
{'$unset': {field: 1 for field in undefined_fields}}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if update_result.modified_count > 0:
|
if update_result.modified_count > 0:
|
||||||
logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}")
|
logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}")
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"清理未定义字段时出错: {e}")
|
logger.error(f"清理未定义字段时出错: {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
async def get_specific_value_list(
|
async def get_specific_value_list(
|
||||||
self,
|
self,
|
||||||
field_name: str,
|
field_name: str,
|
||||||
way: Callable[[Any], bool], # 接受任意类型值
|
way: Callable[[Any], bool], # 接受任意类型值
|
||||||
) ->Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取满足条件的字段值字典
|
获取满足条件的字段值字典
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
field_name: 目标字段名
|
field_name: 目标字段名
|
||||||
way: 判断函数 (value: Any) -> bool
|
way: 判断函数 (value: Any) -> bool
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
{person_id: value} | {}
|
{person_id: value} | {}
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
# 查找所有nickname包含"admin"的用户
|
# 查找所有nickname包含"admin"的用户
|
||||||
result = manager.specific_value_list(
|
result = manager.specific_value_list(
|
||||||
@@ -208,10 +199,7 @@ class PersonInfoManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result = {}
|
result = {}
|
||||||
for doc in db.person_info.find(
|
for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}):
|
||||||
{field_name: {"$exists": True}},
|
|
||||||
{"person_id": 1, field_name: 1, "_id": 0}
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
value = doc[field_name]
|
value = doc[field_name]
|
||||||
if way(value):
|
if way(value):
|
||||||
@@ -225,11 +213,11 @@ class PersonInfoManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"数据库查询失败: {str(e)}", exc_info=True)
|
logger.error(f"数据库查询失败: {str(e)}", exc_info=True)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def personal_habit_deduction(self):
|
async def personal_habit_deduction(self):
|
||||||
"""启动个人信息推断,每天根据一定条件推断一次"""
|
"""启动个人信息推断,每天根据一定条件推断一次"""
|
||||||
try:
|
try:
|
||||||
while(1):
|
while 1:
|
||||||
await asyncio.sleep(60)
|
await asyncio.sleep(60)
|
||||||
current_time = datetime.datetime.now()
|
current_time = datetime.datetime.now()
|
||||||
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
@@ -237,8 +225,7 @@ class PersonInfoManager:
|
|||||||
# "msg_interval"推断
|
# "msg_interval"推断
|
||||||
msg_interval_map = False
|
msg_interval_map = False
|
||||||
msg_interval_lists = await self.get_specific_value_list(
|
msg_interval_lists = await self.get_specific_value_list(
|
||||||
"msg_interval_list",
|
"msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
|
||||||
lambda x: isinstance(x, list) and len(x) >= 100
|
|
||||||
)
|
)
|
||||||
for person_id, msg_interval_list_ in msg_interval_lists.items():
|
for person_id, msg_interval_list_ in msg_interval_lists.items():
|
||||||
try:
|
try:
|
||||||
@@ -258,23 +245,23 @@ class PersonInfoManager:
|
|||||||
log_dir.mkdir(parents=True, exist_ok=True)
|
log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
time_series = pd.Series(time_interval)
|
time_series = pd.Series(time_interval)
|
||||||
plt.hist(time_series, bins=50, density=True, alpha=0.4, color='pink', label='Histogram')
|
plt.hist(time_series, bins=50, density=True, alpha=0.4, color="pink", label="Histogram")
|
||||||
time_series.plot(kind='kde', color='mediumpurple', linewidth=1, label='Density')
|
time_series.plot(kind="kde", color="mediumpurple", linewidth=1, label="Density")
|
||||||
plt.grid(True, alpha=0.2)
|
plt.grid(True, alpha=0.2)
|
||||||
plt.xlim(0, 8000)
|
plt.xlim(0, 8000)
|
||||||
plt.title(f"Message Interval Distribution (User: {person_id[:8]}...)")
|
plt.title(f"Message Interval Distribution (User: {person_id[:8]}...)")
|
||||||
plt.xlabel("Interval (ms)")
|
plt.xlabel("Interval (ms)")
|
||||||
plt.ylabel("Density")
|
plt.ylabel("Density")
|
||||||
plt.legend(framealpha=0.9, facecolor='white')
|
plt.legend(framealpha=0.9, facecolor="white")
|
||||||
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
|
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
|
||||||
plt.savefig(img_path)
|
plt.savefig(img_path)
|
||||||
plt.close()
|
plt.close()
|
||||||
# 画图
|
# 画图
|
||||||
|
|
||||||
q25, q75 = np.percentile(time_interval, [25, 75])
|
q25, q75 = np.percentile(time_interval, [25, 75])
|
||||||
iqr = q75 - q25
|
iqr = q75 - q25
|
||||||
filtered = [x for x in time_interval if (q25 - 1.5*iqr) <= x <= (q75 + 1.5*iqr)]
|
filtered = [x for x in time_interval if (q25 - 1.5 * iqr) <= x <= (q75 + 1.5 * iqr)]
|
||||||
|
|
||||||
msg_interval = int(round(np.percentile(filtered, 80)))
|
msg_interval = int(round(np.percentile(filtered, 80)))
|
||||||
await self.update_one_field(person_id, "msg_interval", msg_interval)
|
await self.update_one_field(person_id, "msg_interval", msg_interval)
|
||||||
logger.debug(f"用户{person_id}的msg_interval已经被更新为{msg_interval}")
|
logger.debug(f"用户{person_id}的msg_interval已经被更新为{msg_interval}")
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ relationship_config = LogConfig(
|
|||||||
)
|
)
|
||||||
logger = get_module_logger("rel_manager", config=relationship_config)
|
logger = get_module_logger("rel_manager", config=relationship_config)
|
||||||
|
|
||||||
|
|
||||||
class RelationshipManager:
|
class RelationshipManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.positive_feedback_value = 0 # 正反馈系统
|
self.positive_feedback_value = 0 # 正反馈系统
|
||||||
@@ -22,6 +23,7 @@ class RelationshipManager:
|
|||||||
def mood_manager(self):
|
def mood_manager(self):
|
||||||
if self._mood_manager is None:
|
if self._mood_manager is None:
|
||||||
from ..moods.moods import MoodManager # 延迟导入
|
from ..moods.moods import MoodManager # 延迟导入
|
||||||
|
|
||||||
self._mood_manager = MoodManager.get_instance()
|
self._mood_manager = MoodManager.get_instance()
|
||||||
return self._mood_manager
|
return self._mood_manager
|
||||||
|
|
||||||
@@ -51,27 +53,27 @@ class RelationshipManager:
|
|||||||
self.positive_feedback_value -= 1
|
self.positive_feedback_value -= 1
|
||||||
elif self.positive_feedback_value > 0:
|
elif self.positive_feedback_value > 0:
|
||||||
self.positive_feedback_value = 0
|
self.positive_feedback_value = 0
|
||||||
|
|
||||||
if abs(self.positive_feedback_value) > 1:
|
if abs(self.positive_feedback_value) > 1:
|
||||||
logger.info(f"触发mood变更增益,当前增益系数:{self.gain_coefficient[abs(self.positive_feedback_value)]}")
|
logger.info(f"触发mood变更增益,当前增益系数:{self.gain_coefficient[abs(self.positive_feedback_value)]}")
|
||||||
|
|
||||||
def mood_feedback(self, value):
|
def mood_feedback(self, value):
|
||||||
"""情绪反馈"""
|
"""情绪反馈"""
|
||||||
mood_manager = self.mood_manager
|
mood_manager = self.mood_manager
|
||||||
mood_gain = (mood_manager.get_current_mood().valence) ** 2 \
|
mood_gain = (mood_manager.get_current_mood().valence) ** 2 * math.copysign(
|
||||||
* math.copysign(1, value * mood_manager.get_current_mood().valence)
|
1, value * mood_manager.get_current_mood().valence
|
||||||
|
)
|
||||||
value += value * mood_gain
|
value += value * mood_gain
|
||||||
logger.info(f"当前relationship增益系数:{mood_gain:.3f}")
|
logger.info(f"当前relationship增益系数:{mood_gain:.3f}")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def feedback_to_mood(self, mood_value):
|
def feedback_to_mood(self, mood_value):
|
||||||
"""对情绪的反馈"""
|
"""对情绪的反馈"""
|
||||||
coefficient = self.gain_coefficient[abs(self.positive_feedback_value)]
|
coefficient = self.gain_coefficient[abs(self.positive_feedback_value)]
|
||||||
if (mood_value > 0 and self.positive_feedback_value > 0
|
if mood_value > 0 and self.positive_feedback_value > 0 or mood_value < 0 and self.positive_feedback_value < 0:
|
||||||
or mood_value < 0 and self.positive_feedback_value < 0):
|
return mood_value * coefficient
|
||||||
return mood_value*coefficient
|
|
||||||
else:
|
else:
|
||||||
return mood_value/coefficient
|
return mood_value / coefficient
|
||||||
|
|
||||||
async def calculate_update_relationship_value(self, chat_stream: ChatStream, label: str, stance: str) -> None:
|
async def calculate_update_relationship_value(self, chat_stream: ChatStream, label: str, stance: str) -> None:
|
||||||
"""计算并变更关系值
|
"""计算并变更关系值
|
||||||
@@ -88,7 +90,7 @@ class RelationshipManager:
|
|||||||
"中立": 1,
|
"中立": 1,
|
||||||
"反对": 2,
|
"反对": 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
valuedict = {
|
valuedict = {
|
||||||
"开心": 1.5,
|
"开心": 1.5,
|
||||||
"愤怒": -2.0,
|
"愤怒": -2.0,
|
||||||
@@ -103,10 +105,10 @@ class RelationshipManager:
|
|||||||
|
|
||||||
person_id = person_info_manager.get_person_id(chat_stream.user_info.platform, chat_stream.user_info.user_id)
|
person_id = person_info_manager.get_person_id(chat_stream.user_info.platform, chat_stream.user_info.user_id)
|
||||||
data = {
|
data = {
|
||||||
"platform" : chat_stream.user_info.platform,
|
"platform": chat_stream.user_info.platform,
|
||||||
"user_id" : chat_stream.user_info.user_id,
|
"user_id": chat_stream.user_info.user_id,
|
||||||
"nickname" : chat_stream.user_info.user_nickname,
|
"nickname": chat_stream.user_info.user_nickname,
|
||||||
"konw_time" : int(time.time())
|
"konw_time": int(time.time()),
|
||||||
}
|
}
|
||||||
old_value = await person_info_manager.get_value(person_id, "relationship_value")
|
old_value = await person_info_manager.get_value(person_id, "relationship_value")
|
||||||
old_value = self.ensure_float(old_value, person_id)
|
old_value = self.ensure_float(old_value, person_id)
|
||||||
@@ -200,4 +202,5 @@ class RelationshipManager:
|
|||||||
logger.warning(f"[关系管理] {person_id}值转换失败(原始值:{value}),已重置为0")
|
logger.warning(f"[关系管理] {person_id}值转换失败(原始值:{value}),已重置为0")
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
relationship_manager = RelationshipManager()
|
relationship_manager = RelationshipManager()
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from src.common.logger import get_module_logger, SCHEDULE_STYLE_CONFIG, LogConfi
|
|||||||
from src.plugins.models.utils_model import LLM_request # noqa: E402
|
from src.plugins.models.utils_model import LLM_request # noqa: E402
|
||||||
from src.plugins.config.config import global_config # noqa: E402
|
from src.plugins.config.config import global_config # noqa: E402
|
||||||
|
|
||||||
TIME_ZONE = tz.gettz(global_config.TIME_ZONE) # 设置时区
|
TIME_ZONE = tz.gettz(global_config.TIME_ZONE) # 设置时区
|
||||||
|
|
||||||
|
|
||||||
schedule_config = LogConfig(
|
schedule_config = LogConfig(
|
||||||
@@ -31,10 +31,16 @@ class ScheduleGenerator:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 使用离线LLM模型
|
# 使用离线LLM模型
|
||||||
self.llm_scheduler_all = LLM_request(
|
self.llm_scheduler_all = LLM_request(
|
||||||
model=global_config.llm_reasoning, temperature=global_config.SCHEDULE_TEMPERATURE, max_tokens=7000, request_type="schedule"
|
model=global_config.llm_reasoning,
|
||||||
|
temperature=global_config.SCHEDULE_TEMPERATURE,
|
||||||
|
max_tokens=7000,
|
||||||
|
request_type="schedule",
|
||||||
)
|
)
|
||||||
self.llm_scheduler_doing = LLM_request(
|
self.llm_scheduler_doing = LLM_request(
|
||||||
model=global_config.llm_normal, temperature=global_config.SCHEDULE_TEMPERATURE, max_tokens=2048, request_type="schedule"
|
model=global_config.llm_normal,
|
||||||
|
temperature=global_config.SCHEDULE_TEMPERATURE,
|
||||||
|
max_tokens=2048,
|
||||||
|
request_type="schedule",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.today_schedule_text = ""
|
self.today_schedule_text = ""
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, List
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
@@ -22,6 +22,7 @@ class LLMStatistics:
|
|||||||
self.stats_thread = None
|
self.stats_thread = None
|
||||||
self.console_thread = None
|
self.console_thread = None
|
||||||
self._init_database()
|
self._init_database()
|
||||||
|
self.name_dict: Dict[List] = {}
|
||||||
|
|
||||||
def _init_database(self):
|
def _init_database(self):
|
||||||
"""初始化数据库集合"""
|
"""初始化数据库集合"""
|
||||||
@@ -137,16 +138,24 @@ class LLMStatistics:
|
|||||||
# user_id = str(doc.get("user_info", {}).get("user_id", "unknown"))
|
# user_id = str(doc.get("user_info", {}).get("user_id", "unknown"))
|
||||||
chat_info = doc.get("chat_info", {})
|
chat_info = doc.get("chat_info", {})
|
||||||
user_info = doc.get("user_info", {})
|
user_info = doc.get("user_info", {})
|
||||||
|
message_time = doc.get("time", 0)
|
||||||
group_info = chat_info.get("group_info") if chat_info else {}
|
group_info = chat_info.get("group_info") if chat_info else {}
|
||||||
# print(f"group_info: {group_info}")
|
# print(f"group_info: {group_info}")
|
||||||
group_name = None
|
group_name = None
|
||||||
if group_info:
|
if group_info:
|
||||||
|
group_id = f"g{group_info.get('group_id')}"
|
||||||
group_name = group_info.get("group_name", f"群{group_info.get('group_id')}")
|
group_name = group_info.get("group_name", f"群{group_info.get('group_id')}")
|
||||||
if user_info and not group_name:
|
if user_info and not group_name:
|
||||||
|
group_id = f"u{user_info['user_id']}"
|
||||||
group_name = user_info["user_nickname"]
|
group_name = user_info["user_nickname"]
|
||||||
|
if self.name_dict.get(group_id):
|
||||||
|
if message_time > self.name_dict.get(group_id)[1]:
|
||||||
|
self.name_dict[group_id] = [group_name, message_time]
|
||||||
|
else:
|
||||||
|
self.name_dict[group_id] = [group_name, message_time]
|
||||||
# print(f"group_name: {group_name}")
|
# print(f"group_name: {group_name}")
|
||||||
stats["messages_by_user"][user_id] += 1
|
stats["messages_by_user"][user_id] += 1
|
||||||
stats["messages_by_chat"][group_name] += 1
|
stats["messages_by_chat"][group_id] += 1
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
@@ -187,7 +196,7 @@ class LLMStatistics:
|
|||||||
tokens = stats["tokens_by_model"][model_name]
|
tokens = stats["tokens_by_model"][model_name]
|
||||||
cost = stats["costs_by_model"][model_name]
|
cost = stats["costs_by_model"][model_name]
|
||||||
output.append(
|
output.append(
|
||||||
data_fmt.format(model_name[:32] + ".." if len(model_name) > 32 else model_name, count, tokens, cost)
|
data_fmt.format(model_name[:30] + ".." if len(model_name) > 32 else model_name, count, tokens, cost)
|
||||||
)
|
)
|
||||||
output.append("")
|
output.append("")
|
||||||
|
|
||||||
@@ -221,8 +230,8 @@ class LLMStatistics:
|
|||||||
# 添加聊天统计
|
# 添加聊天统计
|
||||||
output.append("群组统计:")
|
output.append("群组统计:")
|
||||||
output.append(("群组名称 消息数量"))
|
output.append(("群组名称 消息数量"))
|
||||||
for group_name, count in sorted(stats["messages_by_chat"].items()):
|
for group_id, count in sorted(stats["messages_by_chat"].items()):
|
||||||
output.append(f"{group_name[:32]:<32} {count:>10}")
|
output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}")
|
||||||
|
|
||||||
return "\n".join(output)
|
return "\n".join(output)
|
||||||
|
|
||||||
@@ -250,7 +259,7 @@ class LLMStatistics:
|
|||||||
tokens = stats["tokens_by_model"][model_name]
|
tokens = stats["tokens_by_model"][model_name]
|
||||||
cost = stats["costs_by_model"][model_name]
|
cost = stats["costs_by_model"][model_name]
|
||||||
output.append(
|
output.append(
|
||||||
data_fmt.format(model_name[:32] + ".." if len(model_name) > 32 else model_name, count, tokens, cost)
|
data_fmt.format(model_name[:30] + ".." if len(model_name) > 32 else model_name, count, tokens, cost)
|
||||||
)
|
)
|
||||||
output.append("")
|
output.append("")
|
||||||
|
|
||||||
@@ -284,8 +293,8 @@ class LLMStatistics:
|
|||||||
# 添加聊天统计
|
# 添加聊天统计
|
||||||
output.append("群组统计:")
|
output.append("群组统计:")
|
||||||
output.append(("群组名称 消息数量"))
|
output.append(("群组名称 消息数量"))
|
||||||
for group_name, count in sorted(stats["messages_by_chat"].items()):
|
for group_id, count in sorted(stats["messages_by_chat"].items()):
|
||||||
output.append(f"{group_name[:32]:<32} {count:>10}")
|
output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}")
|
||||||
|
|
||||||
return "\n".join(output)
|
return "\n".join(output)
|
||||||
|
|
||||||
|
|||||||
@@ -53,18 +53,18 @@ class KnowledgeLibrary:
|
|||||||
# 按空行分割内容
|
# 按空行分割内容
|
||||||
paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
|
paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
|
||||||
chunks = []
|
chunks = []
|
||||||
|
|
||||||
for para in paragraphs:
|
for para in paragraphs:
|
||||||
para_length = len(para)
|
para_length = len(para)
|
||||||
|
|
||||||
# 如果段落长度小于等于最大长度,直接添加
|
# 如果段落长度小于等于最大长度,直接添加
|
||||||
if para_length <= max_length:
|
if para_length <= max_length:
|
||||||
chunks.append(para)
|
chunks.append(para)
|
||||||
else:
|
else:
|
||||||
# 如果段落超过最大长度,则按最大长度切分
|
# 如果段落超过最大长度,则按最大长度切分
|
||||||
for i in range(0, para_length, max_length):
|
for i in range(0, para_length, max_length):
|
||||||
chunks.append(para[i:i + max_length])
|
chunks.append(para[i : i + max_length])
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
def get_embedding(self, text: str) -> list:
|
def get_embedding(self, text: str) -> list:
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user