Commit by SengokuCola, 2025-04-08 23:12:00 +08:00
45 changed files with 1177 additions and 1224 deletions

CLAUDE.md (new file, 20 lines)

@@ -0,0 +1,20 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Commands
- **Run Bot**: `python bot.py`
- **Lint**: `ruff check --fix .` or `ruff format .`
- **Run Tests**: `python -m unittest discover -v`
- **Run Single Test**: `python -m unittest src/plugins/message/test.py`
## Code Style
- **Formatting**: Line length 120 chars, use double quotes for strings
- **Imports**: Group standard library, external packages, then internal imports
- **Naming**: snake_case for functions/variables, PascalCase for classes
- **Error Handling**: Use try/except blocks with specific exceptions
- **Types**: Use type hints where possible
- **Docstrings**: Document classes and complex functions
- **Linting**: Follow ruff rules (E, F, B) with ignores E711, E501
The codebase uses Ruff for linting and formatting; after making changes, run `ruff check --fix .` so the code follows the style guidelines.
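A small hypothetical snippet illustrating these conventions (the class and file path are invented for this example, not taken from the repository):

```python
import json
from pathlib import Path


class ConfigLoader:
    """Load bot configuration from a JSON file."""

    def __init__(self, config_path: Path) -> None:
        self.config_path = config_path

    def load(self) -> dict:
        """Return the parsed config, or an empty dict if the file is missing."""
        try:
            return json.loads(self.config_path.read_text(encoding="utf-8"))
        except FileNotFoundError:
            return {}
```

Note the grouped imports, snake_case/PascalCase naming, double quotes, type hints, docstrings, and catching a specific `FileNotFoundError` rather than a bare `except`.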

README.md (159 lines changed)

@@ -1,24 +1,66 @@
# 麦麦MaiCore-MaiMBot (work in progress)
<br />
<div align="center">
![Python Version](https://img.shields.io/badge/Python-3.9+-blue)
![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议)
![Status](https://img.shields.io/badge/状态-开发中-yellow)
![Contributors](https://img.shields.io/github/contributors/MaiM-with-u/MaiBot.svg?style=flat&label=贡献者)
![forks](https://img.shields.io/github/forks/MaiM-with-u/MaiBot.svg?style=flat&label=分支数)
![stars](https://img.shields.io/github/stars/MaiM-with-u/MaiBot?style=flat&label=星标数)
![issues](https://img.shields.io/github/issues/MaiM-with-u/MaiBot)
</div>
<p align="center">
<a href="https://github.com/MaiM-with-u/MaiBot/">
<img src="depends-data/maimai.png" alt="Logo" width="200">
</a>
<br />
<a href="https://space.bilibili.com/1344099355">
Artist: 略nd
</a>
<h3 align="center">MaiBot(麦麦)</h3>
<p align="center">
A cyber netizen dedicated to<strong> group chats </strong>
<br />
<a href="https://docs.mai-mai.org"><strong>Explore the project docs »</strong></a>
<br />
<br />
<!-- <a href="https://github.com/shaojintian/Best_README_template">View Demo</a>
· -->
<a href="https://github.com/MaiM-with-u/MaiBot/issues">Report a Bug</a>
·
<a href="https://github.com/MaiM-with-u/MaiBot/issues">Request a Feature</a>
</p>
</p>
## Before deploying the new v0.6.0, first read https://docs.mai-mai.org/manual/usage/mmc_q_a
<div align="center">
![Python Version](https://img.shields.io/badge/Python-3.9+-blue)
![License](https://img.shields.io/github/license/SengokuCola/MaiMBot)
![Status](https://img.shields.io/badge/状态-开发中-yellow)
</div>
## 📝 Project Introduction
**🍔 MaiCore is an interactive agent powered by large language models**
- 💭 **Smart dialogue system**: LLM-based natural-language interaction
- 🤔 **Real-time thought system**: simulates a human thinking process
- 💝 **Emotion expression system**: rich stickers and mood expression
- 🧠 **Persistent memory system**: MongoDB-backed long-term memory storage
- 🔄 **Dynamic personality system**: adaptive character traits
<div align="center">
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
<img src="depends-data/video.png" width="200" alt="MaiMai demo video">
<br>
👆 Click to watch the MaiMai demo video 👆
</a>
</div>
### 📢 Version Info
**Latest version: v0.6.0** ([view changelog](changelogs/changelog.md))
> [!WARNING]
@@ -28,19 +70,12 @@
> From this version on, MaiBot runs on MaiCore and no longer depends on nonebot components.
> MaiBot talks to nonebot through a nonebot plugin, and nonebot talks to QQ, which together enable MaiBot-QQ interaction.
**Branches:**
- `main`: stable release
- `dev`: development build (if you don't know what this means, don't download it)
- `classical`: versions before 0.6.0
> [!WARNING]
> - The project is under active development, and the code may change at any time
@@ -49,6 +84,12 @@
> - Continuous iteration means there may be known or unknown bugs
> - Development builds may consume a fairly large number of tokens
### ⚠️ Important Notes
- Before upgrading to v0.6.0, be sure to read the [upgrade guide](https://docs.mai-mai.org/manual/usage/mmc_q_a)
- This version is rebuilt on MaiCore and interacts with the QQ platform through a nonebot plugin
- The project is under active development; features and APIs may change at any time
### 💬 Chat groups (for development and suggestion discussions; replies are not guaranteed, since docs and code take priority)
- [Group 5](https://qm.qq.com/q/JxvHZnxyec) 1022489779
- [Group 1](https://qm.qq.com/q/VQ3XZrWgMs) 766798517 [full]
@@ -72,55 +113,35 @@
## 🎯 Features
| Module | Key Features | Highlights |
|------|---------|------|
| 💬 Chat system | • Thought-flow / reasoning chat<br>• Keyword-triggered proactive replies<br>• Multi-model support<br>• Dynamic prompt building<br>• Private chat (PFC) | Human-like interaction |
| 🧠 Thought-flow system | • Real-time thought generation<br>• Automatic start/stop<br>• Schedule-system integration | Smart decision-making |
| 🧠 Memory system 2.0 | • Improved memory extraction<br>• Hippocampus-style memory mechanism<br>• Chat-log summarization | Persistent memory |
| 😊 Sticker system | • Mood-matched stickers<br>• GIF support<br>• Automatic collection and review | Rich expression |
| 📅 Schedule system | • Dynamic schedule generation<br>• Configurable imagination level<br>• Thought-flow integration | Smart planning |
| 👥 Relationship system 2.0 | • Better relationship management<br>• Rich interfaces<br>• Personalized interaction | Deep social ties |
| 📊 Statistics system | • Usage statistics<br>• LLM call records<br>• Live console display | Data visibility |
| 🔧 System features | • Graceful shutdown<br>• Automatic data saving<br>• Robust exception handling | Stability and reliability |
## 📐 Project Architecture
```mermaid
graph TD
    A[MaiCore] --> B[Dialogue system]
    A --> C[Thought-flow system]
    A --> D[Memory system]
    A --> E[Emotion system]
    B --> F[Multi-model support]
    B --> G[Dynamic prompts]
    C --> H[Real-time thinking]
    C --> I[Schedule integration]
    D --> J[Memory storage]
    D --> K[Memory retrieval]
    E --> L[Sticker management]
    E --> M[Emotion recognition]
```
## Development Plan (TODO list)

depends-data/maimai.png (new binary file, 455 KiB; content not shown)

depends-data/video.png (new binary file, 62 KiB; content not shown)


@@ -24,10 +24,10 @@
# # Flag indicating whether the GUI is running
# self.is_running = True
# # Cleanup when the program window is closed
# self.protocol("WM_DELETE_WINDOW", self._on_closing)
# # Initialize the process, log queue, log data, and related variables
# self.process = None
# self.log_queue = queue.Queue()
@@ -236,7 +236,7 @@
# while not self.log_queue.empty():
# line = self.log_queue.get()
# self.process_log_line(line)
# # Keep processing the queue only while the GUI is still running
# if self.is_running:
# self.after(100, self.process_log_queue)
@@ -245,11 +245,11 @@
# """Parse one log line and update the log data and filters"""
# match = re.match(
# r"""^
# (?:(?P<time>\d{2}:\d{2}(?::\d{2})?)\s*\|\s*)?
# (?P<level>\w+)\s*\|\s*
# (?P<module>.*?)
# \s*[-|]\s*
# (?P<message>.*)
# $""",
# line.strip(),
# re.VERBOSE,
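For reference, the commented-out pattern above parses lines like `12:34:56 | INFO | module - message`, with the timestamp optional. A standalone check (the sample input is my own, not from the repo's logs):

```python
import re

# Same pattern as in the GUI's log parser above
LOG_PATTERN = re.compile(
    r"""^
    (?:(?P<time>\d{2}:\d{2}(?::\d{2})?)\s*\|\s*)?
    (?P<level>\w+)\s*\|\s*
    (?P<module>.*?)
    \s*[-|]\s*
    (?P<message>.*)
    $""",
    re.VERBOSE,
)

match = LOG_PATTERN.match("12:34:56 | INFO | bot.chat - hello world")
```

The named groups make the parsed fields easy to read back; the optional `time` group lets the same pattern accept lines with no timestamp prefix.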
@@ -354,10 +354,10 @@
# """Handle the window close event and clean up resources safely"""
# # Mark the GUI as closed
# self.is_running = False
# # Stop the log process
# self.stop_process()
# # Safely clean up tkinter variables
# for attr_name in list(self.__dict__.keys()):
# if isinstance(getattr(self, attr_name), (ctk.Variable, ctk.StringVar, ctk.IntVar, ctk.DoubleVar, ctk.BooleanVar)):
@@ -367,7 +367,7 @@
# except Exception:
# pass
# setattr(self, attr_name, None)
# self.quit()
# sys.exit(0)


@@ -127,7 +127,7 @@
# """Handle the window close event"""
# # Mark the GUI as closed so background threads stop touching tkinter objects
# self.is_running = False
# # Safely clean up any tkinter variables
# for attr_name in list(self.__dict__.keys()):
# if isinstance(getattr(self, attr_name), (ctk.Variable, ctk.StringVar, ctk.IntVar, ctk.DoubleVar, ctk.BooleanVar)):
@@ -138,7 +138,7 @@
# except Exception:
# pass
# setattr(self, attr_name, None)
# # Exit
# self.root.quit()
# sys.exit(0)
@@ -259,7 +259,7 @@
# while True:
# if not self.is_running:
# break  # Stop the thread once the GUI has closed
# try:
# # Fetch the latest data from the database, only records after the startup time
# query = {"time": {"$gt": self.start_timestamp}}


@@ -42,7 +42,6 @@ class Heartflow:
self._subheartflows = {}
self.active_subheartflows_nums = 0
async def _cleanup_inactive_subheartflows(self):
"""Periodically clean up inactive sub-heartflows"""
while True:
@@ -84,25 +83,22 @@ class Heartflow:
# Start building the prompt
prompt_personality = ""
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
personality_info = prompt_personality
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
related_memory_info = "memory"
@@ -146,22 +142,20 @@ class Heartflow:
async def minds_summary(self, minds_str):
# Start building the prompt
prompt_personality = ""
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
personality_info = prompt_personality
mood_info = self.current_state.mood
@@ -183,7 +177,7 @@ class Heartflow:
Add a SubHeartflow instance to the self._subheartflows dict,
and create an observation object for the sub-heartflow based on its subheartflow_id
"""
try:
if subheartflow_id not in self._subheartflows:
logger.debug(f"创建 subheartflow: {subheartflow_id}")
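The personality-prompt construction above is repeated nearly verbatim across several methods in this diff (`minds_summary`, `do_thinking_before_reply`, `do_thinking_after_reply`, `judge_willing`). A hypothetical helper capturing the pattern (not actual repo code; `random.choice` replaces the shuffle-then-index idiom and avoids mutating the underlying lists):

```python
import random


def build_personality_prompt(individuality) -> str:
    """Combine the personality core with one random personality side
    and one random identity detail, joined by commas.

    Mirrors the inline pattern repeated in the Heartflow/SubHeartflow methods.
    """
    parts = [individuality.personality.personality_core]
    sides = list(individuality.personality.personality_sides)
    if sides:
        parts.append(random.choice(sides))
    details = list(individuality.identity.identity_detail)
    if details:
        parts.append(random.choice(details))
    return ",".join(parts)
```

Each call site could then reduce to `personality_info = build_personality_prompt(Individuality.get_instance())`.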


@@ -7,6 +7,7 @@ from src.common.database import db
from src.individuality.individuality import Individuality
import random
# Base class for all observations
class Observation:
def __init__(self, observe_type, observe_id):
@@ -24,7 +25,7 @@ class ChattingObservation(Observation):
self.talking_message = []
self.talking_message_str = ""
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
@@ -57,7 +58,7 @@ class ChattingObservation(Observation):
for msg in new_messages:
if "detailed_plain_text" in msg:
new_messages_str += f"{msg['detailed_plain_text']}"
# print(f"new_messages_str{new_messages_str}")
# Append new messages to talking_message while keeping the list at no more than 20 entries
@@ -117,26 +118,22 @@ class ChattingObservation(Observation):
# print(f"更新聊天总结:{self.talking_summary}")
# Start building the prompt
prompt_personality = ""
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
personality_info = prompt_personality
prompt = ""
prompt += f"{personality_info},请注意识别你自己的聊天发言"
prompt += f"你的名字叫:{self.name},你的昵称是:{self.nick_name}\n"
@@ -148,7 +145,6 @@ class ChattingObservation(Observation):
self.observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt)
print(f"prompt{prompt}")
print(f"self.observe_info{self.observe_info}")
def translate_message_list_to_str(self):
self.talking_message_str = ""


@@ -53,11 +53,10 @@ class SubHeartflow:
if not self.current_mind:
self.current_mind = "你什么也没想"
self.is_active = False
self.observations: list[Observation] = []
self.running_knowledges = []
def add_observation(self, observation: Observation):
@@ -86,7 +85,9 @@ class SubHeartflow:
async def subheartflow_start_working(self):
while True:
current_time = time.time()
if (
current_time - self.last_reply_time > global_config.sub_heart_flow_freeze_time
):  # Freeze after 120 seconds with no reply / nobody present
self.is_active = False
await asyncio.sleep(global_config.sub_heart_flow_update_interval)  # Check every 60 seconds
else:
@@ -100,7 +101,9 @@ class SubHeartflow:
await asyncio.sleep(global_config.sub_heart_flow_update_interval)
# Check whether it has gone more than 10 minutes without activation
if (
current_time - self.last_active_time > global_config.sub_heart_flow_stop_time
):  # Destroy after 5 minutes with no reply / nobody present
logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活正在销毁...")
break  # Exit the loop and destroy this sub-heartflow
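The freeze/destroy logic above reduces to two timestamp comparisons. A condensed sketch, where the 120 s and 300 s thresholds are assumptions taken from the inline comments (the real values come from `global_config`):

```python
import time
from typing import Optional

# Threshold names mirror the config fields in the diff; the values are assumptions
FREEZE_AFTER = 120  # sub_heart_flow_freeze_time: seconds without a reply before freezing
STOP_AFTER = 300    # sub_heart_flow_stop_time: seconds without activation before destruction


def classify_subheartflow(last_reply_time: float, last_active_time: float,
                          now: Optional[float] = None) -> str:
    """Return 'destroy', 'frozen', or 'active' based on the two inactivity timers."""
    now = time.time() if now is None else now
    if now - last_active_time > STOP_AFTER:
        return "destroy"
    if now - last_reply_time > FREEZE_AFTER:
        return "frozen"
    return "active"
```

Unlike the loop above, this sketch checks the destroy condition first, which makes the three outcomes mutually exclusive at a single point in time.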
@@ -147,11 +150,11 @@ class SubHeartflow:
# self.current_mind = reponse
# logger.debug(f"prompt:\n{prompt}\n")
# logger.info(f"麦麦的脑内状态:{self.current_mind}")
async def do_observe(self):
observation = self.observations[0]
await observation.observe()
async def do_thinking_before_reply(self, message_txt):
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
@@ -162,23 +165,20 @@ class SubHeartflow:
# Start building the prompt
prompt_personality = ""
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
# Retrieve related memories
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
@@ -191,7 +191,7 @@ class SubHeartflow:
else:
related_memory_info = ""
related_info, grouped_results = await self.get_prompt_info(chat_observe_info + message_txt, 0.4)
# print(related_info)
for _topic, results in grouped_results.items():
for result in results:
@@ -227,25 +227,23 @@ class SubHeartflow:
async def do_thinking_after_reply(self, reply_content, chat_talking_prompt):
# print("麦麦回复之后脑袋转起来了")
# Start building the prompt
prompt_personality = ""
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
@@ -279,22 +277,20 @@ class SubHeartflow:
async def judge_willing(self):
# Start building the prompt
prompt_personality = ""
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
# print("麦麦闹情绪了1")
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
@@ -320,13 +316,12 @@ class SubHeartflow:
def update_current_mind(self, reponse):
self.past_mind.append(self.current_mind)
self.current_mind = reponse
async def get_prompt_info(self, message: str, threshold: float):
start_time = time.time()
related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. First get topics from the LLM, similar to the memory system's approach
topics = []
# try:
@@ -334,7 +329,7 @@ class SubHeartflow:
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
# # Extract keywords
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
@@ -345,7 +340,7 @@ class SubHeartflow:
# for topic in ",".join(topics).replace("", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
@@ -353,7 +348,7 @@ class SubHeartflow:
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# If no topics could be extracted, query with the whole message instead
if not topics:
logger.debug("未能提取到任何主题,使用整个消息进行查询")
@@ -361,26 +356,26 @@ class SubHeartflow:
if not embedding:
logger.error("获取消息嵌入向量失败")
return ""
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}")
return related_info, {}
# 2. Query the knowledge base for each topic
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
# Batch the embedding requests to reduce API calls
embeddings = {}
topics_batch = [topic for topic in topics if len(topic) > 0]
if message:  # Make sure the message is non-empty
topics_batch.append(message)
# Batch-fetch the embedding vectors
embed_start_time = time.time()
for text in topics_batch:
if not text or len(text.strip()) == 0:
continue
try:
embedding = await get_embedding(text, request_type="info_retrieval")
if embedding:
@@ -389,17 +384,17 @@ class SubHeartflow:
logger.warning(f"获取'{text}'的嵌入向量失败")
except Exception as e:
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}")
if not embeddings:
logger.error("所有嵌入向量获取失败")
return ""
# 3. Query the knowledge base for each topic
all_results = []
query_start_time = time.time()
# First add the query results for the original message
if message in embeddings:
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
@@ -408,12 +403,12 @@ class SubHeartflow:
result["topic"] = "原始消息"
all_results.extend(original_results)
logger.info(f"原始消息查询到{len(original_results)}条结果")
# Then add the query results for each topic
for topic in topics:
if not topic or topic not in embeddings:
continue
try:
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
if topic_results:
@@ -424,9 +419,9 @@ class SubHeartflow:
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
except Exception as e:
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
# 4. Deduplicate and filter
process_start_time = time.time()
unique_contents = set()
@@ -436,14 +431,16 @@ class SubHeartflow:
if content not in unique_contents:
unique_contents.add(content)
filtered_results.append(result)
# 5. Sort by similarity
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
# 6. Cap the total at 10 results
filtered_results = filtered_results[:10]
logger.info(
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
)
# 7. Format the output
if filtered_results:
format_start_time = time.time()
@@ -453,7 +450,7 @@ class SubHeartflow:
if topic not in grouped_results:
grouped_results[topic] = []
grouped_results[topic].append(result)
# Organize the output by topic
for topic, results in grouped_results.items():
related_info += f"【主题: {topic}\n"
@@ -464,13 +461,15 @@ class SubHeartflow:
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n"
related_info += "\n"
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}")
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info, grouped_results
def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []
# Cosine-similarity lookup
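A minimal sketch of the cosine-similarity lookup that `get_info_from_db` implies (the real method queries MongoDB; this in-memory version is an assumption for illustration only):

```python
import math


def cosine_similarity(a: list, b: list) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k_by_similarity(query: list, docs: dict,
                        limit: int = 1, threshold: float = 0.5) -> list:
    """Return up to `limit` (doc_id, score) pairs whose similarity meets `threshold`."""
    scored = [(doc_id, cosine_similarity(query, emb)) for doc_id, emb in docs.items()]
    scored = [(d, s) for d, s in scored if s >= threshold]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:limit]
```

The threshold-then-sort-then-cap order mirrors steps 4 to 6 in the method above: filter below-threshold hits, rank the rest by similarity, and keep only the top few.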


@@ -2,27 +2,36 @@ from dataclasses import dataclass
from typing import List
import random
@dataclass
class Identity:
"""Identity-trait class"""
identity_detail: List[str]  # identity detail descriptions
height: int  # height (cm)
weight: int  # weight (kg)
age: int  # age
gender: str  # gender
appearance: str  # appearance traits
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(
self,
identity_detail: List[str] = None,
height: int = 0,
weight: int = 0,
age: int = 0,
gender: str = "",
appearance: str = "",
):
"""Initialize identity traits
Args:
identity_detail: list of identity detail descriptions
height: height (cm)
@@ -39,23 +48,24 @@ class Identity:
self.age = age
self.gender = gender
self.appearance = appearance
@classmethod
def get_instance(cls) -> "Identity":
"""Get the Identity singleton instance
Returns:
Identity: the singleton instance
"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod @classmethod
def initialize(cls, identity_detail: List[str], height: int, weight: int, def initialize(
age: int, gender: str, appearance: str) -> 'Identity': cls, identity_detail: List[str], height: int, weight: int, age: int, gender: str, appearance: str
) -> "Identity":
"""初始化身份特征 """初始化身份特征
Args: Args:
identity_detail: 身份细节描述列表 identity_detail: 身份细节描述列表
height: 身高(厘米) height: 身高(厘米)
@@ -63,7 +73,7 @@ class Identity:
age: 年龄 age: 年龄
gender: 性别 gender: 性别
appearance: 外貌特征 appearance: 外貌特征
Returns: Returns:
Identity: 初始化后的身份特征实例 Identity: 初始化后的身份特征实例
""" """
@@ -75,8 +85,8 @@ class Identity:
instance.gender = gender instance.gender = gender
instance.appearance = appearance instance.appearance = appearance
return instance return instance
def get_prompt(self,x_person,level): def get_prompt(self, x_person, level):
""" """
获取身份特征的prompt 获取身份特征的prompt
""" """
@@ -86,7 +96,7 @@ class Identity:
prompt_identity = "" prompt_identity = ""
else: else:
prompt_identity = "" prompt_identity = ""
if level == 1: if level == 1:
identity_detail = self.identity_detail identity_detail = self.identity_detail
random.shuffle(identity_detail) random.shuffle(identity_detail)
@@ -96,7 +106,7 @@ class Identity:
prompt_identity += f",{detail}" prompt_identity += f",{detail}"
prompt_identity += "" prompt_identity += ""
return prompt_identity return prompt_identity
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""将身份特征转换为字典格式""" """将身份特征转换为字典格式"""
return { return {
@@ -105,13 +115,13 @@ class Identity:
"weight": self.weight, "weight": self.weight,
"age": self.age, "age": self.age,
"gender": self.gender, "gender": self.gender,
"appearance": self.appearance "appearance": self.appearance,
} }
@classmethod @classmethod
def from_dict(cls, data: dict) -> 'Identity': def from_dict(cls, data: dict) -> "Identity":
"""从字典创建身份特征实例""" """从字典创建身份特征实例"""
instance = cls.get_instance() instance = cls.get_instance()
for key, value in data.items(): for key, value in data.items():
setattr(instance, key, value) setattr(instance, key, value)
return instance return instance
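`Identity`, `Personality`, and `Individuality` all share the same `__new__`-based singleton pattern with a `get_instance()` accessor. A minimal self-contained sketch of just that pattern (the `Singleton` name is illustrative):

```python
class Singleton:
    """Minimal sketch of the __new__-based singleton used by the trait classes."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Create the single instance lazily; later calls return the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def get_instance(cls) -> "Singleton":
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance


a = Singleton.get_instance()
b = Singleton()  # constructing the class again yields the same object
```

One caveat of this pattern worth keeping in mind: `__init__` still runs on every `Singleton()` call, so re-instantiating with default arguments can silently reset fields on the shared instance, which is why initialization is usually funneled through a dedicated `initialize()` classmethod instead.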

View File

@@ -2,35 +2,46 @@ from typing import Optional
from .personality import Personality
from .identity import Identity


class Individuality:
    """Individual trait manager class"""

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        self.personality: Optional[Personality] = None
        self.identity: Optional[Identity] = None

    @classmethod
    def get_instance(cls) -> "Individuality":
        """Get the Individuality singleton instance

        Returns:
            Individuality: the singleton instance
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def initialize(
        self,
        bot_nickname: str,
        personality_core: str,
        personality_sides: list,
        identity_detail: list,
        height: int,
        weight: int,
        age: int,
        gender: str,
        appearance: str,
    ) -> None:
        """Initialize the individual traits

        Args:
            bot_nickname: bot nickname
            personality_core: core personality trait
@@ -44,50 +55,43 @@ class Individuality:
        """
        # Initialize the personality
        self.personality = Personality.initialize(
            bot_nickname=bot_nickname, personality_core=personality_core, personality_sides=personality_sides
        )

        # Initialize the identity
        self.identity = Identity.initialize(
            identity_detail=identity_detail, height=height, weight=weight, age=age, gender=gender, appearance=appearance
        )

    def to_dict(self) -> dict:
        """Convert the individual traits to a dict"""
        return {
            "personality": self.personality.to_dict() if self.personality else None,
            "identity": self.identity.to_dict() if self.identity else None,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "Individuality":
        """Create an individual trait instance from a dict"""
        instance = cls.get_instance()
        if data.get("personality"):
            instance.personality = Personality.from_dict(data["personality"])
        if data.get("identity"):
            instance.identity = Identity.from_dict(data["identity"])
        return instance

    def get_prompt(self, type, x_person, level):
        """
        Get the prompt text for the individual traits
        """
        if type == "personality":
            return self.personality.get_prompt(x_person, level)
        elif type == "identity":
            return self.identity.get_prompt(x_person, level)
        else:
            return ""

    def get_traits(self, factor):
        """
        Get an individual trait factor
        """
@@ -101,5 +105,3 @@ class Individuality:
            return self.personality.agreeableness
        elif factor == "neuroticism":
            return self.personality.neuroticism
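The `to_dict`/`from_dict` pair used throughout these trait classes is a simple dict round-trip: serialize attributes to a plain dict, then rebuild by `setattr` over the items. A tiny stand-in (the `Traits` class and its two fields are illustrative, not from the repo):

```python
from dataclasses import dataclass, asdict


@dataclass
class Traits:
    """Tiny stand-in for the Personality/Identity to_dict/from_dict round-trip."""

    openness: float = 0.5
    extraversion: float = 0.5

    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "Traits":
        # Rebuild by setting each serialized attribute, as the real classes do.
        instance = cls()
        for key, value in data.items():
            setattr(instance, key, value)
        return instance


restored = Traits.from_dict({"openness": 0.8, "extraversion": 0.3})
```

Because `from_dict` assigns whatever keys the dict contains, it tolerates partial data (missing fields keep their defaults) but also silently accepts unknown keys, a trade-off of the `setattr` approach.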

View File

@@ -17,9 +17,9 @@ with open(config_path, "r", encoding="utf-8") as f:
config = toml.load(f)

# Now the src modules can be imported
from src.individuality.scene import get_scene_by_factor, PERSONALITY_SCENES  # noqa E402
from src.individuality.questionnaire import FACTOR_DESCRIPTIONS  # noqa E402
from src.individuality.offline_llm import LLM_request_off  # noqa E402

# Load environment variables
env_path = os.path.join(root_path, ".env")
@@ -32,13 +32,12 @@ else:
def adapt_scene(scene: str) -> str:
    personality_core = config["personality"]["personality_core"]
    personality_sides = config["personality"]["personality_sides"]
    personality_side = random.choice(personality_sides)

    identity_details = config["identity"]["identity_detail"]
    identity_detail = random.choice(identity_details)
    """
    Adapt the scene according to the attributes in config so it better fits the current character
@@ -51,10 +50,10 @@ def adapt_scene(scene: str) -> str:
    try:
        prompt = f"""
这是一个参与人格测评的角色形象:
- 昵称: {config["bot"]["nickname"]}
- 性别: {config["identity"]["gender"]}
- 年龄: {config["identity"]["age"]}
- 外貌: {config["identity"]["appearance"]}
- 性格核心: {personality_core}
- 性格侧面: {personality_side}
- 身份细节: {identity_detail}
@@ -62,18 +61,18 @@ def adapt_scene(scene: str) -> str:
请根据上述形象,改编以下场景,在测评中,用户将根据该场景给出上述角色形象的反应:
{scene}

保持场景的本质不变,但最好贴近生活且具体,并且让它更适合这个角色。
改编后的场景应该自然、连贯,并考虑角色的年龄、身份和性格特点。只返回改编后的场景描述,不要包含其他说明。注意{config["bot"]["nickname"]}是面对这个场景的人,而不是场景的其他人。场景中不会有其描述,
现在,请你给出改编后的场景描述:
"""

        llm = LLM_request_off(model_name=config["model"]["llm_normal"]["name"])
        adapted_scene, _ = llm.generate_response(prompt)

        # Check whether the returned scene is empty or an error message
        if not adapted_scene or "错误" in adapted_scene or "失败" in adapted_scene:
            print("场景改编失败,将使用原始场景")
            return scene

        return adapted_scene
    except Exception as e:
        print(f"场景改编过程出错:{str(e)},将使用原始场景")
@@ -169,7 +168,7 @@ class PersonalityEvaluator_direct:
        except Exception as e:
            print(f"评估过程出错:{str(e)}")
            return {dim: 3.5 for dim in dimensions}

    def run_evaluation(self):
        """
        Run the whole evaluation process
@@ -185,18 +184,23 @@ class PersonalityEvaluator_direct:
        print(f"- 身份细节:{config['identity']['identity_detail']}")
        print("\n准备好了吗?按回车键开始...")
        input()

        total_scenarios = len(self.scenarios)
        progress_bar = tqdm(
            total=total_scenarios,
            desc="场景进度",
            ncols=100,
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
        )

        for _i, scenario_data in enumerate(self.scenarios, 1):
            # print(f"\n{'-' * 20} 场景 {i}/{total_scenarios} - {scenario_data['场景编号']} {'-' * 20}")

            # Adapt the scene so it better fits the current character
            print(f"{config['bot']['nickname']}祈祷中...")
            adapted_scene = adapt_scene(scenario_data["场景"])
            scenario_data["改编场景"] = adapted_scene
            print(adapted_scene)

            print(f"\n请描述{config['bot']['nickname']}在这种情况下会如何反应:")
            response = input().strip()
@@ -220,13 +224,13 @@ class PersonalityEvaluator_direct:
            # Update the progress bar
            progress_bar.update(1)

            # if i < total_scenarios:
            #     print("\n按回车键继续下一个场景...")
            #     input()

        progress_bar.close()

        # Compute the average scores
        for dimension in self.final_scores:
            if self.dimension_counts[dimension] > 0:
@@ -241,26 +245,26 @@ class PersonalityEvaluator_direct:
        # Return the evaluation result
        return self.get_result()

    def get_result(self):
        """
        Get the evaluation result
        """
        return {
            "final_scores": self.final_scores,
            "dimension_counts": self.dimension_counts,
            "scenarios": self.scenarios,
            "bot_info": {
                "nickname": config["bot"]["nickname"],
                "gender": config["identity"]["gender"],
                "age": config["identity"]["age"],
                "height": config["identity"]["height"],
                "weight": config["identity"]["weight"],
                "appearance": config["identity"]["appearance"],
                "personality_core": config["personality"]["personality_core"],
                "personality_sides": config["personality"]["personality_sides"],
                "identity_detail": config["identity"]["identity_detail"],
            },
        }
@@ -275,28 +279,28 @@ def main():
        "extraversion": round(result["final_scores"]["外向性"] / 6, 1),
        "agreeableness": round(result["final_scores"]["宜人性"] / 6, 1),
        "neuroticism": round(result["final_scores"]["神经质"] / 6, 1),
        "bot_nickname": config["bot"]["nickname"],
    }

    # Make sure the directory exists
    save_dir = os.path.join(root_path, "data", "personality")
    os.makedirs(save_dir, exist_ok=True)

    # Build the file name, replacing characters that may be illegal
    bot_name = config["bot"]["nickname"]
    # Replace characters not allowed in Windows file names
    for char in ["\\", "/", ":", "*", "?", '"', "<", ">", "|"]:
        bot_name = bot_name.replace(char, "_")
    file_name = f"{bot_name}_personality.per"
    save_path = os.path.join(save_dir, file_name)

    # Save the simplified result
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(simplified_result, f, ensure_ascii=False, indent=4)
    print(f"\n结果已保存到 {save_path}")

    # Also save the full result to the results directory
    os.makedirs("results", exist_ok=True)
    with open("results/personality_result.json", "w", encoding="utf-8") as f:
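The character-replacement loop `main()` uses before building the `.per` file name can be factored into a small helper. A minimal sketch (the `sanitize_filename` name is illustrative; the character set matches the loop above):

```python
def sanitize_filename(name: str) -> str:
    """Replace characters Windows forbids in file names with underscores."""
    for char in ["\\", "/", ":", "*", "?", '"', "<", ">", "|"]:
        name = name.replace(char, "_")
    return name


safe = sanitize_filename("bot:name?")
```

Each forbidden character maps to `_`, so distinct inputs can collide (`a:b` and `a?b` both become `a_b`), which is acceptable here since the name is only used for a per-bot save file.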

View File

@@ -4,9 +4,11 @@ import json
from pathlib import Path
import random


@dataclass
class Personality:
    """Personality trait class"""

    openness: float  # openness
    conscientiousness: float  # conscientiousness
    extraversion: float  # extraversion
@@ -15,45 +17,45 @@ class Personality:
    bot_nickname: str  # bot nickname
    personality_core: str  # core personality trait
    personality_sides: List[str]  # personality side descriptions

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, personality_core: str = "", personality_sides: List[str] = None):
        if personality_sides is None:
            personality_sides = []
        self.personality_core = personality_core
        self.personality_sides = personality_sides

    @classmethod
    def get_instance(cls) -> "Personality":
        """Get the Personality singleton instance

        Returns:
            Personality: the singleton instance
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def _init_big_five_personality(self):
        """Initialize the Big Five personality traits"""
        # Build the file path
        personality_file = Path("data/personality") / f"{self.bot_nickname}_personality.per"

        # If the file exists, read it
        if personality_file.exists():
            with open(personality_file, "r", encoding="utf-8") as f:
                personality_data = json.load(f)
                self.openness = personality_data.get("openness", 0.5)
                self.conscientiousness = personality_data.get("conscientiousness", 0.5)
                self.extraversion = personality_data.get("extraversion", 0.5)
                self.agreeableness = personality_data.get("agreeableness", 0.5)
                self.neuroticism = personality_data.get("neuroticism", 0.5)
        else:
            # If the file does not exist, set the Big Five traits from personality_core and personality_sides
            if "活泼" in self.personality_core or "开朗" in self.personality_sides:
@@ -62,31 +64,31 @@ class Personality:
            else:
                self.extraversion = 0.3
            self.neuroticism = 0.5

            if "认真" in self.personality_core or "负责" in self.personality_sides:
                self.conscientiousness = 0.9
            else:
                self.conscientiousness = 0.5

            if "友善" in self.personality_core or "温柔" in self.personality_sides:
                self.agreeableness = 0.9
            else:
                self.agreeableness = 0.5

            if "创新" in self.personality_core or "开放" in self.personality_sides:
                self.openness = 0.8
            else:
                self.openness = 0.5

    @classmethod
    def initialize(cls, bot_nickname: str, personality_core: str, personality_sides: List[str]) -> "Personality":
        """Initialize the personality traits

        Args:
            bot_nickname: bot nickname
            personality_core: core personality trait
            personality_sides: personality side descriptions

        Returns:
            Personality: the initialized personality trait instance
        """
@@ -96,7 +98,7 @@ class Personality:
        instance.personality_sides = personality_sides
        instance._init_big_five_personality()
        return instance

    def to_dict(self) -> Dict:
        """Convert the personality traits to a dict"""
        return {
@@ -107,18 +109,18 @@ class Personality:
            "neuroticism": self.neuroticism,
            "bot_nickname": self.bot_nickname,
            "personality_core": self.personality_core,
            "personality_sides": self.personality_sides,
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "Personality":
        """Create a personality trait instance from a dict"""
        instance = cls.get_instance()
        for key, value in data.items():
            setattr(instance, key, value)
        return instance

    def get_prompt(self, x_person, level):
        # Start building the prompt
        if x_person == 2:
            prompt_personality = ""
@@ -126,10 +128,10 @@ class Personality:
            prompt_personality = ""
        else:
            prompt_personality = ""

        # person
        prompt_personality += self.personality_core
        if level == 2:
            personality_sides = self.personality_sides
            random.shuffle(personality_sides)
@@ -140,5 +142,5 @@ class Personality:
                prompt_personality += f",{side}"
            prompt_personality += ""
        return prompt_personality
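The keyword fallback in `_init_big_five_personality` can be summarized as a pure function returning the five scores. This is a sketch mirroring the branch structure above; note the hunk elides the if-branch value for extraversion, so the 0.8 used here is an assumption, and `big_five_from_keywords` is an illustrative name, not a repo function:

```python
def big_five_from_keywords(core, sides):
    """Keyword fallback used when no saved .per file exists (values illustrative)."""
    return {
        # The 0.8 here is an assumed value; the original if-branch is elided in the diff.
        "extraversion": 0.8 if ("活泼" in core or "开朗" in sides) else 0.3,
        "conscientiousness": 0.9 if ("认真" in core or "负责" in sides) else 0.5,
        "agreeableness": 0.9 if ("友善" in core or "温柔" in sides) else 0.5,
        "openness": 0.8 if ("创新" in core or "开放" in sides) else 0.5,
        "neuroticism": 0.5,
    }


traits = big_five_from_keywords("活泼认真", ["温柔"])
```

As in the original, the `core` check is a substring match while the `sides` check is exact list membership, so a side like "很温柔" would not trigger the "温柔" branch.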

View File

@@ -2,6 +2,7 @@ import json
from typing import Dict
import os


def load_scenes() -> Dict:
    """
    Load scene data from the JSON file
@@ -10,13 +11,15 @@ def load_scenes() -> Dict:
        Dict: a dict containing all scenes
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    json_path = os.path.join(current_dir, "template_scene.json")
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)


PERSONALITY_SCENES = load_scenes()


def get_scene_by_factor(factor: str) -> Dict:
    """
    Get the scenario test that corresponds to a personality factor
View File

@@ -100,7 +100,7 @@ class MainSystem:
            weight=global_config.weight,
            age=global_config.age,
            gender=global_config.gender,
            appearance=global_config.appearance,
        )
        logger.success("个体特征初始化成功")
@@ -135,7 +135,6 @@ class MainSystem:
            await asyncio.sleep(global_config.build_memory_interval)
            logger.info("正在进行记忆构建")
            await HippocampusManager.get_instance().build_memory()

    async def forget_memory_task(self):
        """Memory-forgetting task"""
@@ -144,7 +143,6 @@ class MainSystem:
            print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
            await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage)
            print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")

    async def print_mood_task(self):
        """Print the mood state"""

View File

@@ -1,6 +1,6 @@
import time import time
import asyncio import asyncio
from typing import Optional, Dict, Any, List, Tuple from typing import Optional, Dict, Any, List, Tuple
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from ..message.message_base import UserInfo from ..message.message_base import UserInfo
from ..config.config import global_config from ..config.config import global_config
@@ -9,16 +9,17 @@ from .message_storage import MessageStorage, MongoDBMessageStorage
logger = get_module_logger("chat_observer") logger = get_module_logger("chat_observer")
class ChatObserver: class ChatObserver:
"""聊天状态观察器""" """聊天状态观察器"""
# 类级别的实例管理 # 类级别的实例管理
_instances: Dict[str, 'ChatObserver'] = {} _instances: Dict[str, "ChatObserver"] = {}
@classmethod @classmethod
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver': def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver':
"""获取或创建观察器实例 """获取或创建观察器实例
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
message_storage: 消息存储实现如果为None则使用MongoDB实现 message_storage: 消息存储实现如果为None则使用MongoDB实现
@@ -32,14 +33,14 @@ class ChatObserver:
def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None): def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None):
"""初始化观察器 """初始化观察器
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
message_storage: 消息存储实现如果为None则使用MongoDB实现 message_storage: 消息存储实现如果为None则使用MongoDB实现
""" """
if stream_id in self._instances: if stream_id in self._instances:
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.") raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
self.stream_id = stream_id self.stream_id = stream_id
self.message_storage = message_storage or MongoDBMessageStorage() self.message_storage = message_storage or MongoDBMessageStorage()
@@ -53,9 +54,9 @@ class ChatObserver:
# 消息历史记录 # 消息历史记录
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史 self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
self.last_message_id: Optional[str] = None # 最后一条消息的ID self.last_message_id: Optional[str] = None # 最后一条消息的ID
self.message_count: int = 0 # 消息计数 self.message_count: int = 0 # 消息计数
# 运行状态 # 运行状态
self._running: bool = False self._running: bool = False
self._task: Optional[asyncio.Task] = None self._task: Optional[asyncio.Task] = None
@@ -77,7 +78,7 @@ class ChatObserver:
async def check(self) -> bool: async def check(self) -> bool:
"""检查距离上一次观察之后是否有了新消息 """检查距离上一次观察之后是否有了新消息
Returns: Returns:
bool: 是否有新消息 bool: 是否有新消息
""" """
@@ -91,7 +92,7 @@ class ChatObserver:
if new_message_exists: if new_message_exists:
logger.debug("发现新消息") logger.debug("发现新消息")
self.last_check_time = time.time() self.last_check_time = time.time()
return new_message_exists return new_message_exists
async def _add_message_to_history(self, message: Dict[str, Any]): async def _add_message_to_history(self, message: Dict[str, Any]):
@@ -104,7 +105,7 @@ class ChatObserver:
self.last_message_id = message["message_id"] self.last_message_id = message["message_id"]
self.last_message_time = message["time"] # 更新最后消息时间 self.last_message_time = message["time"] # 更新最后消息时间
self.message_count += 1 self.message_count += 1
# 更新说话时间 # 更新说话时间
user_info = UserInfo.from_dict(message.get("user_info", {})) user_info = UserInfo.from_dict(message.get("user_info", {}))
if user_info.user_id == global_config.BOT_QQ: if user_info.user_id == global_config.BOT_QQ:
@@ -186,41 +187,40 @@ class ChatObserver:
start_time: Optional[float] = None, start_time: Optional[float] = None,
end_time: Optional[float] = None, end_time: Optional[float] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
user_id: Optional[str] = None user_id: Optional[str] = None,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""获取消息历史 """获取消息历史
Args: Args:
start_time: 开始时间戳 start_time: 开始时间戳
end_time: 结束时间戳 end_time: 结束时间戳
limit: 限制返回消息数量 limit: 限制返回消息数量
user_id: 指定用户ID user_id: 指定用户ID
Returns: Returns:
List[Dict[str, Any]]: 消息列表 List[Dict[str, Any]]: 消息列表
""" """
filtered_messages = self.message_history filtered_messages = self.message_history
if start_time is not None: if start_time is not None:
filtered_messages = [m for m in filtered_messages if m["time"] >= start_time] filtered_messages = [m for m in filtered_messages if m["time"] >= start_time]
if end_time is not None: if end_time is not None:
filtered_messages = [m for m in filtered_messages if m["time"] <= end_time] filtered_messages = [m for m in filtered_messages if m["time"] <= end_time]
if user_id is not None: if user_id is not None:
filtered_messages = [ filtered_messages = [
m for m in filtered_messages m for m in filtered_messages if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
] ]
if limit is not None: if limit is not None:
filtered_messages = filtered_messages[-limit:] filtered_messages = filtered_messages[-limit:]
return filtered_messages return filtered_messages
async def _fetch_new_messages(self) -> List[Dict[str, Any]]: async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
"""获取新消息 """获取新消息
Returns: Returns:
List[Dict[str, Any]]: 新消息列表 List[Dict[str, Any]]: 新消息列表
""" """
@@ -231,15 +231,15 @@ class ChatObserver:
if new_messages: if new_messages:
self.last_message_read = new_messages[-1]["message_id"] self.last_message_read = new_messages[-1]["message_id"]
return new_messages return new_messages
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]: async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
"""获取指定时间点之前的消息 """获取指定时间点之前的消息
Args: Args:
time_point: 时间戳 time_point: 时间戳
Returns: Returns:
List[Dict[str, Any]]: 最多5条消息 List[Dict[str, Any]]: 最多5条消息
""" """
@@ -250,7 +250,7 @@ class ChatObserver:
if new_messages: if new_messages:
self.last_message_read = new_messages[-1]["message_id"] self.last_message_read = new_messages[-1]["message_id"]
return new_messages return new_messages
'''主要观察循环''' '''主要观察循环'''
@@ -263,7 +263,7 @@ class ChatObserver:
await self._add_message_to_history(message) await self._add_message_to_history(message)
except Exception as e: except Exception as e:
logger.error(f"缓冲消息出错: {e}") logger.error(f"缓冲消息出错: {e}")
while self._running: while self._running:
try: try:
# 等待事件或超时1秒 # 等待事件或超时1秒
@@ -271,13 +271,13 @@ class ChatObserver:
await asyncio.wait_for(self._update_event.wait(), timeout=1) await asyncio.wait_for(self._update_event.wait(), timeout=1)
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass # 超时后也执行一次检查 pass # 超时后也执行一次检查
self._update_event.clear() # 重置触发事件 self._update_event.clear() # 重置触发事件
self._update_complete.clear() # 重置完成事件 self._update_complete.clear() # 重置完成事件
# 获取新消息 # 获取新消息
new_messages = await self._fetch_new_messages() new_messages = await self._fetch_new_messages()
if new_messages: if new_messages:
# 处理新消息 # 处理新消息
for message in new_messages: for message in new_messages:
@@ -285,21 +285,21 @@ class ChatObserver:
# 设置完成事件 # 设置完成事件
self._update_complete.set() self._update_complete.set()
except Exception as e: except Exception as e:
logger.error(f"更新循环出错: {e}") logger.error(f"更新循环出错: {e}")
self._update_complete.set() # 即使出错也要设置完成事件 self._update_complete.set() # 即使出错也要设置完成事件
def trigger_update(self): def trigger_update(self):
"""触发一次立即更新""" """触发一次立即更新"""
self._update_event.set() self._update_event.set()
async def wait_for_update(self, timeout: float = 5.0) -> bool: async def wait_for_update(self, timeout: float = 5.0) -> bool:
"""等待更新完成 """等待更新完成
Args: Args:
timeout: 超时时间(秒) timeout: 超时时间(秒)
Returns: Returns:
bool: 是否成功完成更新False表示超时 bool: 是否成功完成更新False表示超时
""" """
@@ -309,16 +309,16 @@ class ChatObserver:
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning(f"等待更新完成超时({timeout}秒)") logger.warning(f"等待更新完成超时({timeout}秒)")
return False return False
def start(self): def start(self):
"""启动观察器""" """启动观察器"""
if self._running: if self._running:
return return
self._running = True self._running = True
self._task = asyncio.create_task(self._update_loop()) self._task = asyncio.create_task(self._update_loop())
logger.info(f"ChatObserver for {self.stream_id} started") logger.info(f"ChatObserver for {self.stream_id} started")
def stop(self): def stop(self):
"""停止观察器""" """停止观察器"""
self._running = False self._running = False
@@ -327,15 +327,15 @@ class ChatObserver:
if self._task:
self._task.cancel()
logger.info(f"ChatObserver for {self.stream_id} stopped")
async def process_chat_history(self, messages: list):
"""处理聊天历史
Args:
messages: 消息列表
"""
self.update_check_time()
for msg in messages:
try:
user_info = UserInfo.from_dict(msg.get("user_info", {}))
@@ -345,33 +345,33 @@ class ChatObserver:
self.update_user_speak_time(msg["time"])
except Exception as e:
logger.warning(f"处理消息时间时出错: {e}")
continue
def update_check_time(self):
"""更新查看时间"""
self.last_check_time = time.time()
def update_bot_speak_time(self, speak_time: Optional[float] = None):
"""更新机器人说话时间"""
self.last_bot_speak_time = speak_time or time.time()
def update_user_speak_time(self, speak_time: Optional[float] = None):
"""更新用户说话时间"""
self.last_user_speak_time = speak_time or time.time()
def get_time_info(self) -> str:
"""获取时间信息文本"""
current_time = time.time()
time_info = ""
if self.last_bot_speak_time:
bot_speak_ago = current_time - self.last_bot_speak_time
time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}秒"
if self.last_user_speak_time:
user_speak_ago = current_time - self.last_user_speak_time
time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}秒"
return time_info
def start_periodic_update(self):
@@ -1,5 +1,5 @@
# Programmable Friendly Conversationalist
# Prefrontal cortex
import datetime
import asyncio
from typing import List, Optional, Tuple, TYPE_CHECKING
@@ -29,20 +29,17 @@ logger = get_module_logger("pfc")
class GoalAnalyzer:
"""对话目标分析器"""
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
self.chat_observer = ChatObserver.get_instance(stream_id)
# 多目标存储结构
self.goals = []  # 存储多个目标
self.max_goals = 3  # 同时保持的最大目标数量
@@ -50,10 +47,10 @@ class GoalAnalyzer:
async def analyze_goal(self) -> Tuple[str, str, str]:
"""分析对话历史并设定目标
Args:
chat_history: 聊天历史记录列表
Returns:
Tuple[str, str, str]: (目标, 方法, 原因)
"""
@@ -70,16 +67,16 @@ class GoalAnalyzer:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
personality_text = f"你的名字是{self.name},{self.personality_info}"
# 构建当前已有目标的文本
existing_goals_text = ""
if self.goals:
existing_goals_text = "当前已有的对话目标:\n"
for i, (goal, _, reason) in enumerate(self.goals):
existing_goals_text += f"{i + 1}. 目标: {goal}, 原因: {reason}\n"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。
这些目标应该反映出对话的不同方面和意图。
@@ -107,46 +104,44 @@ class GoalAnalyzer:
logger.debug(f"发送到LLM的提示词: {prompt}")
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
)
if not success:
logger.error(f"无法解析JSON,重试第{retry + 1}次")
continue
goal = result["goal"]
reasoning = result["reasoning"]
# 使用默认的方法
method = "以友好的态度回应"
# 更新目标列表
await self._update_goals(goal, method, reasoning)
# 返回当前最主要的目标
if self.goals:
current_goal, current_method, current_reasoning = self.goals[0]
return current_goal, current_method, current_reasoning
else:
return goal, method, reasoning
except Exception as e:
logger.error(f"分析对话目标时出错: {str(e)},重试第{retry + 1}次")
if retry == max_retries - 1:
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
continue
# 所有重试都失败后的默认返回
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表
Args:
new_goal: 新的目标
method: 实现目标的方法
@@ -160,23 +155,23 @@ class GoalAnalyzer:
# 将此目标移到列表前面(最主要的位置)
self.goals.insert(0, self.goals.pop(i))
return
# 添加新目标到列表前面
self.goals.insert(0, (new_goal, method, reasoning))
# 限制目标数量
if len(self.goals) > self.max_goals:
self.goals.pop()  # 移除最老的目标
def _calculate_similarity(self, goal1: str, goal2: str) -> float:
"""简单计算两个目标之间的相似度
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
Args:
goal1: 第一个目标
goal2: 第二个目标
Returns:
float: 相似度得分 (0-1)
"""
@@ -186,18 +181,18 @@ class GoalAnalyzer:
overlap = len(words1.intersection(words2))
total = len(words1.union(words2))
return overlap / total if total > 0 else 0
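`_calculate_similarity` above is a plain Jaccard word-overlap score. A standalone sketch of the same computation follows; note that `str.split()` assumes whitespace-separated words, so unsegmented Chinese goals would need a tokenizer such as jieba first (an assumption, not something the source handles):

```python
def jaccard_similarity(goal1: str, goal2: str) -> float:
    """Word-overlap (Jaccard) similarity, mirroring _calculate_similarity.

    Caveat: str.split() only separates on whitespace; unsegmented Chinese
    text would need a word tokenizer before this comparison is meaningful.
    """
    words1 = set(goal1.split())
    words2 = set(goal2.split())
    overlap = len(words1.intersection(words2))
    total = len(words1.union(words2))
    return overlap / total if total > 0 else 0


print(jaccard_similarity("keep the chat friendly", "keep the chat going"))  # 0.6
```

Three shared words out of five distinct words gives 3/5 = 0.6, which is how near-duplicate goals get merged instead of appended in `_update_goals`.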
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
"""获取所有当前目标
Returns:
List[Tuple[str, str, str]]: 目标列表,每项为(目标, 方法, 原因)
"""
return self.goals.copy()
async def get_alternative_goals(self) -> List[Tuple[str, str, str]]:
"""获取除了当前主要目标外的其他备选目标
Returns:
List[Tuple[str, str, str]]: 备选目标列表
"""
@@ -215,9 +210,9 @@ class GoalAnalyzer:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
personality_text = f"你的名字是{self.name},{self.personality_info}"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天
当前对话目标:{goal}
产生该对话目标的原因:{reasoning}
@@ -247,7 +242,7 @@ class GoalAnalyzer:
"goal_achieved", "stop_conversation", "reason",
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str}
)
if not success:
logger.error("无法解析对话分析结果JSON")
return False, False, "解析结果失败"
@@ -265,14 +260,15 @@ class GoalAnalyzer:
class Waiter:
"""快 速 等 待"""
def __init__(self, stream_id: str):
self.chat_observer = ChatObserver.get_instance(stream_id)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
async def wait(self) -> bool:
"""等待
Returns:
bool: 是否超时,True表示超时
"""
@@ -298,7 +294,7 @@ class Waiter:
class DirectMessageSender:
"""直接发送消息到平台的发送器"""
def __init__(self):
self.logger = get_module_logger("direct_sender")
self.storage = MessageStorage()
@@ -310,7 +306,7 @@ class DirectMessageSender:
reply_to_message: Optional[Message] = None,
) -> None:
"""直接发送消息到平台
Args:
chat_stream: 聊天流
content: 消息内容
@@ -323,7 +319,7 @@ class DirectMessageSender:
user_nickname=global_config.BOT_NICKNAME,
platform=chat_stream.platform,
)
message = MessageSending(
message_id=f"dm{round(time.time(), 2)}",
chat_stream=chat_stream,
@@ -343,18 +339,17 @@ class DirectMessageSender:
try:
message_json = message.to_dict()
end_point = global_config.api_urls.get(chat_stream.platform, None)
if not end_point:
raise ValueError(f"未找到平台:{chat_stream.platform} 的url配置")
await global_api.send_message_REST(end_point, message_json)
# 存储消息
await self.storage.store_message(message, message.chat_stream)
self.logger.info(f"直接发送消息成功: {content[:30]}...")
except Exception as e:
self.logger.error(f"直接发送消息失败: {str(e)}")
raise
@@ -7,24 +7,22 @@ from ..chat.message import Message
logger = get_module_logger("knowledge_fetcher")
class KnowledgeFetcher:
"""知识调取器"""
def __init__(self):
self.llm = LLM_request(
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="knowledge_fetch"
)
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
"""获取相关知识
Args:
query: 查询内容
chat_history: 聊天历史
Returns:
Tuple[str, str]: (获取的知识, 知识来源)
"""
@@ -33,16 +31,16 @@ class KnowledgeFetcher:
for msg in chat_history:
# sender = msg.message_info.user_info.user_nickname or f"用户{msg.message_info.user_info.user_id}"
chat_history_text += f"{msg.detailed_plain_text}\n"
# 从记忆中获取相关知识
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=f"{query}\n{chat_history_text}",
max_memory_num=3,
max_memory_length=2,
max_depth=3,
fast_retrieval=False,
)
if related_memory:
knowledge = ""
sources = []
@@ -50,5 +48,5 @@ class KnowledgeFetcher:
knowledge += memory[1] + "\n"
sources.append(f"记忆片段{memory[0]}")
return knowledge.strip(), "".join(sources)
return "未找到相关知识", "无记忆匹配"
@@ -5,36 +5,37 @@ from src.common.logger import get_module_logger
logger = get_module_logger("pfc_utils")
def get_items_from_json(
content: str,
*items: str,
default_values: Optional[Dict[str, Any]] = None,
required_types: Optional[Dict[str, type]] = None,
) -> Tuple[bool, Dict[str, Any]]:
"""从文本中提取JSON内容并获取指定字段
Args:
content: 包含JSON的文本
*items: 要提取的字段名
default_values: 字段的默认值,格式为 {字段名: 默认值}
required_types: 字段的必需类型,格式为 {字段名: 类型}
Returns:
Tuple[bool, Dict[str, Any]]: (是否成功, 提取的字段字典)
"""
content = content.strip()
result = {}
# 设置默认值
if default_values:
result.update(default_values)
# 尝试解析JSON
try:
json_data = json.loads(content)
except json.JSONDecodeError:
# 如果直接解析失败,尝试查找和提取JSON部分
json_pattern = r"\{[^{}]*\}"
json_match = re.search(json_pattern, content)
if json_match:
try:
@@ -45,28 +46,28 @@ def get_items_from_json(
else:
logger.error("无法在返回内容中找到有效的JSON")
return False, result
# 提取字段
for item in items:
if item in json_data:
result[item] = json_data[item]
# 验证必需字段
if not all(item in result for item in items):
logger.error(f"JSON缺少必要字段,实际内容: {json_data}")
return False, result
# 验证字段类型
if required_types:
for field, expected_type in required_types.items():
if field in result and not isinstance(result[field], expected_type):
logger.error(f"{field} 必须是 {expected_type.__name__} 类型")
return False, result
# 验证字符串字段不为空
for field in items:
if isinstance(result[field], str) and not result[field].strip():
logger.error(f"{field} 不能为空")
return False, result
return True, result
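The parse-then-regex-fallback strategy of `get_items_from_json` can be condensed into a standalone sketch. `extract_json_fields` is a hypothetical name, and this version omits the source's default-value, type, and empty-string checks:

```python
import json
import re


def extract_json_fields(content: str, *items: str):
    """Sketch of get_items_from_json's core: parse the whole string as JSON,
    else pull the first {...} block out of surrounding LLM chatter with a regex."""
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        # Same pattern as the source; [^{}]* means nested braces are not supported.
        match = re.search(r"\{[^{}]*\}", content)
        if not match:
            return False, {}
        try:
            data = json.loads(match.group(0))
        except json.JSONDecodeError:
            return False, {}
    result = {k: data[k] for k in items if k in data}
    if not all(k in result for k in items):
        return False, result
    return True, result


ok, fields = extract_json_fields('回复:{"goal": "greet", "reasoning": "new user"}', "goal", "reasoning")
print(ok, fields)  # True {'goal': 'greet', 'reasoning': 'new user'}
```

The flat `\{[^{}]*\}` pattern is a deliberate trade-off: it is robust against prose around the JSON but will truncate any payload containing nested objects.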
@@ -9,33 +9,26 @@ from ..message.message_base import UserInfo
logger = get_module_logger("reply_checker")
class ReplyChecker:
"""回复检查器"""
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="reply_check"
)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
self.max_retries = 2  # 最大重试次数
async def check(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
"""检查生成的回复是否合适
Args:
reply: 生成的回复
goal: 对话目标
retry_count: 当前重试次数
Returns:
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
"""
@@ -49,7 +42,7 @@ class ReplyChecker:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
prompt = f"""请检查以下回复是否合适:
当前对话目标:{goal}
@@ -83,7 +76,7 @@ class ReplyChecker:
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"检查回复的原始返回: {content}")
# 清理内容,尝试提取JSON部分
content = content.strip()
try:
@@ -92,7 +85,8 @@ class ReplyChecker:
except json.JSONDecodeError:
# 如果直接解析失败,尝试查找和提取JSON部分
import re
json_pattern = r"\{[^{}]*\}"
json_match = re.search(json_pattern, content)
if json_match:
try:
@@ -109,33 +103,33 @@ class ReplyChecker:
reason = content[:100] if content else "无法解析响应"
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
return is_suitable, reason, need_replan
# 验证JSON字段
suitable = result.get("suitable", None)
reason = result.get("reason", "未提供原因")
need_replan = result.get("need_replan", False)
# 如果suitable字段是字符串,转换为布尔值
if isinstance(suitable, str):
suitable = suitable.lower() == "true"
# 如果suitable字段不存在或不是布尔值,从reason中判断
if suitable is None:
suitable = "不合适" not in reason.lower() and "违规" not in reason.lower()
# 如果不合适且未达到最大重试次数,返回需要重试
if not suitable and retry_count < self.max_retries:
return False, reason, False
# 如果不合适且已达到最大重试次数,返回需要重新规划
if not suitable and retry_count >= self.max_retries:
return False, f"多次重试后仍不合适: {reason}", True
return suitable, reason, need_replan
except Exception as e:
logger.error(f"检查回复时出错: {e}")
# 如果出错且已达到最大重试次数,建议重新规划
if retry_count >= self.max_retries:
return False, "多次检查失败,建议重新规划", True
return False, f"检查过程出错,建议重试: {str(e)}", False
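ReplyChecker's tolerant reading of the LLM's `suitable` field — real booleans, `"true"`/`"false"` strings, or inference from the reason text when the field is missing — can be isolated as a small helper. `coerce_suitable` is an illustrative name, and dropping the source's `.lower()` on the Chinese reason text is a simplification:

```python
def coerce_suitable(value, reason: str) -> bool:
    """Sketch of ReplyChecker's handling of the LLM's "suitable" field:
    accept booleans, "true"/"false" strings, or fall back to scanning the reason."""
    if isinstance(value, str):
        # String form from the LLM: only the literal "true" counts as suitable.
        return value.lower() == "true"
    if value is None:
        # Field missing entirely: infer from the reason text, as the source does.
        return "不合适" not in reason and "违规" not in reason
    return bool(value)


print(coerce_suitable("True", ""))  # True
print(coerce_suitable(None, "回复不合适:过于生硬"))  # False
```

Normalizing the field in one place keeps the retry/replan branching below it purely boolean.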
@@ -12,5 +12,5 @@ __all__ = [
"chat_manager",
"message_manager",
"MessageStorage",
"auto_speak_manager",
]
@@ -44,11 +44,11 @@ class ChatBot:
async def _create_PFC_chat(self, message: MessageRecv):
try:
chat_id = str(message.chat_stream.stream_id)
if global_config.enable_pfc_chatting:
await self.pfc_manager.get_or_create_conversation(chat_id)
except Exception as e:
logger.error(f"创建PFC聊天失败: {e}")
@@ -59,16 +59,16 @@ class ChatBot:
- 包含思维流状态管理
- 在回复前进行观察和状态更新
- 回复后更新思维流状态
2. reasoning模式:使用推理系统进行回复
- 直接使用意愿管理器计算回复概率
- 没有思维流相关的状态管理
- 更简单直接的回复逻辑
3. pfc_chatting模式:仅进行消息处理
- 不进行任何回复
- 只处理和存储消息
所有模式都包含:
- 消息过滤
- 记忆激活
@@ -89,7 +89,7 @@ class ChatBot:
if userinfo.user_id in global_config.ban_user_id:
logger.debug(f"用户{userinfo.user_id}被禁止回复")
return
if global_config.enable_pfc_chatting:
try:
if groupinfo is None and global_config.enable_friend_chat:
@@ -118,7 +118,7 @@ class ChatBot:
logger.error(f"处理PFC消息失败: {e}")
else:
if groupinfo is None and global_config.enable_friend_chat:
# 私聊处理流程
# await self._handle_private_chat(message)
if global_config.response_mode == "heart_flow":
await self.think_flow_chat.process_message(message_data)
@@ -38,11 +38,11 @@ class EmojiManager:
self.llm_emotion_judge = LLM_request(
model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8, request_type="emoji"
)  # 更高的温度,更少的token(后续可以根据情绪来调整温度)
self.emoji_num = 0
self.emoji_num_max = global_config.max_emoji_num
self.emoji_num_max_reach_deletion = global_config.max_reach_deletion
logger.info("启动表情包管理器")
def _ensure_emoji_dir(self):
@@ -51,7 +51,7 @@ class EmojiManager:
def _update_emoji_count(self):
"""更新表情包数量统计
检查数据库中的表情包数量并更新到 self.emoji_num
"""
try:
@@ -376,7 +376,6 @@ class EmojiManager:
except Exception:
logger.exception("[错误] 扫描表情包失败")
def check_emoji_file_integrity(self):
"""检查表情包文件完整性
@@ -451,7 +450,7 @@ class EmojiManager:
def check_emoji_file_full(self):
"""检查表情包文件是否完整,如果数量超出限制且允许删除,则删除多余的表情包
删除规则:
1. 优先删除创建时间更早的表情包
2. 优先删除使用次数少的表情包,但使用次数多的也有小概率被删除
@@ -460,23 +459,23 @@ class EmojiManager:
self._ensure_db()
# 更新表情包数量
self._update_emoji_count()
# 检查是否超出限制
if self.emoji_num <= self.emoji_num_max:
return
# 如果超出限制但不允许删除,则只记录警告
if not global_config.max_reach_deletion:
logger.warning(f"[警告] 表情包数量({self.emoji_num})超出限制({self.emoji_num_max}),但未开启自动删除")
return
# 计算需要删除的数量
delete_count = self.emoji_num - self.emoji_num_max
logger.info(f"[清理] 需要删除 {delete_count} 个表情包")
# 获取所有表情包,按时间戳升序(旧的在前)排序
all_emojis = list(db.emoji.find().sort([("timestamp", 1)]))
# 计算权重:使用次数越多,被删除的概率越小
weights = []
max_usage = max((emoji.get("usage_count", 0) for emoji in all_emojis), default=1)
@@ -485,11 +484,11 @@ class EmojiManager:
# 使用指数衰减函数计算权重,使用次数越多权重越小
weight = 1.0 / (1.0 + usage_count / max(1, max_usage))
weights.append(weight)
# 根据权重随机选择要删除的表情包
to_delete = []
remaining_indices = list(range(len(all_emojis)))
while len(to_delete) < delete_count and remaining_indices:
# 计算当前剩余表情包的权重
current_weights = [weights[i] for i in remaining_indices]
@@ -497,13 +496,13 @@ class EmojiManager:
total_weight = sum(current_weights)
if total_weight == 0:
break
normalized_weights = [w / total_weight for w in current_weights]
# 随机选择一个表情包
selected_idx = random.choices(remaining_indices, weights=normalized_weights, k=1)[0]
to_delete.append(all_emojis[selected_idx])
remaining_indices.remove(selected_idx)
# 删除选中的表情包
deleted_count = 0
for emoji in to_delete:
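The selection logic in `check_emoji_file_full` — weight each emoji by `1/(1 + usage/max_usage)`, then draw without replacement until enough are chosen — can be sketched on bare usage counts. `pick_for_deletion` and the `seed` parameter are illustrative additions (the seed only makes the sketch reproducible):

```python
import random


def pick_for_deletion(usage_counts, delete_count, seed=None):
    """Sketch of the emoji-cleanup selection: rarely-used items are likelier
    to be chosen, but heavily-used ones keep a small deletion probability."""
    rng = random.Random(seed)
    max_usage = max(usage_counts, default=1)
    # Decay weight: more usage -> smaller weight, as in the source.
    weights = [1.0 / (1.0 + u / max(1, max_usage)) for u in usage_counts]
    remaining = list(range(len(usage_counts)))
    chosen = []
    while len(chosen) < delete_count and remaining:
        current = [weights[i] for i in remaining]
        total = sum(current)
        if total == 0:
            break
        # One weighted draw, then remove the winner: sampling without replacement.
        idx = rng.choices(remaining, weights=[w / total for w in current], k=1)[0]
        chosen.append(idx)
        remaining.remove(idx)
    return chosen


print(pick_for_deletion([0, 5, 100, 2], delete_count=2, seed=42))
```

Removing the winner and renormalizing each round is what turns `random.choices` (which samples with replacement) into a without-replacement draw, so the same emoji can never be deleted twice.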
@@ -512,26 +511,26 @@ class EmojiManager:
if "path" in emoji and os.path.exists(emoji["path"]):
os.remove(emoji["path"])
logger.info(f"[删除] 文件: {emoji['path']} (使用次数: {emoji.get('usage_count', 0)})")
# 删除数据库记录
db.emoji.delete_one({"_id": emoji["_id"]})
deleted_count += 1
# 同时从images集合中删除
if "hash" in emoji:
db.images.delete_one({"hash": emoji["hash"]})
except Exception as e:
logger.error(f"[错误] 删除表情包失败: {str(e)}")
continue
# 更新表情包数量
self._update_emoji_count()
logger.success(f"[清理] 已删除 {deleted_count} 个表情包,当前数量: {self.emoji_num}")
except Exception as e:
logger.error(f"[错误] 检查表情包数量失败: {str(e)}")
async def start_periodic_check_register(self):
"""定期检查表情包完整性和数量"""
while True:
@@ -542,7 +541,7 @@ class EmojiManager:
logger.info("[扫描] 开始扫描新表情包...")
if self.emoji_num < self.emoji_num_max:
await self.scan_new_emojis()
if self.emoji_num > self.emoji_num_max:
logger.warning(f"[警告] 表情包数量超过最大限制: {self.emoji_num} > {self.emoji_num_max},跳过注册")
if not global_config.max_reach_deletion:
logger.warning("表情包数量超过最大限制,终止注册")
@@ -551,7 +550,7 @@ class EmojiManager:
logger.warning("表情包数量超过最大限制,开始删除表情包")
self.check_emoji_file_full()
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
async def delete_all_images(self):
"""删除 data/image 目录下的所有文件"""
try:
@@ -559,10 +558,10 @@ class EmojiManager:
if not os.path.exists(image_dir):
logger.warning(f"[警告] 目录不存在: {image_dir}")
return
deleted_count = 0
failed_count = 0
# 遍历目录下的所有文件
for filename in os.listdir(image_dir):
file_path = os.path.join(image_dir, filename)
@@ -574,11 +573,12 @@ class EmojiManager:
except Exception as e:
failed_count += 1
logger.error(f"[错误] 删除文件失败 {file_path}: {str(e)}")
logger.success(f"[清理] 已删除 {deleted_count} 个文件,失败 {failed_count} 个")
except Exception as e:
logger.error(f"[错误] 删除图片目录失败: {str(e)}")
# 创建全局单例
emoji_manager = EmojiManager()
@@ -13,9 +13,10 @@ from ..config.config import global_config
logger = get_module_logger("message_buffer")
@dataclass
class CacheMessages:
message: MessageRecv
cache_determination: asyncio.Event = field(default_factory=asyncio.Event)  # 判断缓冲是否产生结果
result: str = "U"
@@ -25,7 +26,7 @@ class MessageBuffer:
self.buffer_pool: Dict[str, OrderedDict[str, CacheMessages]] = {}
self.lock = asyncio.Lock()
def get_person_id_(self, platform: str, user_id: str, group_info: GroupInfo):
"""获取唯一id"""
if group_info:
group_id = group_info.group_id
@@ -34,16 +35,17 @@ class MessageBuffer:
key = f"{platform}_{user_id}_{group_id}" key = f"{platform}_{user_id}_{group_id}"
return hashlib.md5(key.encode()).hexdigest() return hashlib.md5(key.encode()).hexdigest()
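The buffer key above is just an MD5 digest of a `platform_user_group` composite string. A standalone sketch of the same idea (the `GroupInfo` stand-in and the `"private"` fallback for the no-group branch are assumptions; that branch is elided from the hunk):

```python
import hashlib
from dataclasses import dataclass
from typing import Optional


@dataclass
class GroupInfo:  # minimal stand-in for the project's GroupInfo
    group_id: str


def get_person_id(platform: str, user_id: str, group_info: Optional[GroupInfo]) -> str:
    """Stable buffer key derived from platform, user and (optional) group."""
    group_id = group_info.group_id if group_info else "private"  # fallback is assumed
    key = f"{platform}_{user_id}_{group_id}"
    return hashlib.md5(key.encode()).hexdigest()


print(get_person_id("qq", "123456", GroupInfo(group_id="42")))  # 32-char hex digest
```

The same (platform, user, group) triple always hashes to the same key, which is what lets later lookups find the right per-sender buffer.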
    async def start_caching_messages(self, message: MessageRecv):
        """添加消息,启动缓冲"""
        if not global_config.message_buffer:
            person_id = person_info_manager.get_person_id(
                message.message_info.user_info.platform, message.message_info.user_info.user_id
            )
            asyncio.create_task(self.save_message_interval(person_id, message.message_info))
            return
        person_id_ = self.get_person_id_(
            message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
        )

        async with self.lock:
            if person_id_ not in self.buffer_pool:
@@ -64,25 +66,24 @@ class MessageBuffer:
                    break
                elif msg.result == "F":
                    recent_F_count += 1

            # 判断条件最近T之后有超过3-5条F
            if recent_F_count >= random.randint(3, 5):
                new_msg = CacheMessages(message=message, result="T")
                new_msg.cache_determination.set()
                self.buffer_pool[person_id_][message.message_info.message_id] = new_msg
                logger.debug(f"快速处理消息(已堆积{recent_F_count}条F): {message.message_info.message_id}")
                return

            # 添加新消息
            self.buffer_pool[person_id_][message.message_info.message_id] = CacheMessages(message=message)

        # 启动3秒缓冲计时器
        person_id = person_info_manager.get_person_id(
            message.message_info.user_info.platform, message.message_info.user_info.user_id
        )
        asyncio.create_task(self.save_message_interval(person_id, message.message_info))
        asyncio.create_task(self._debounce_processor(person_id_, message.message_info.message_id, person_id))

    async def _debounce_processor(self, person_id_: str, message_id: str, person_id: str):
        """等待3秒无新消息"""
@@ -92,36 +93,33 @@ class MessageBuffer:
            return
        interval_time = max(0.5, int(interval_time) / 1000)
        await asyncio.sleep(interval_time)

        async with self.lock:
            if person_id_ not in self.buffer_pool or message_id not in self.buffer_pool[person_id_]:
                logger.debug(f"消息已被清理msgid: {message_id}")
                return

            cache_msg = self.buffer_pool[person_id_][message_id]
            if cache_msg.result == "U":
                cache_msg.result = "T"
                cache_msg.cache_determination.set()
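`_debounce_processor` is an async debounce: sleep for the adaptive interval, then flush only if no newer message from the same sender superseded this one in the meantime. A minimal generic version of the same pattern (all names here are illustrative, not from the project):

```python
import asyncio


class Debouncer:
    """Run a callback only after `delay` seconds pass with no further triggers."""

    def __init__(self, delay: float):
        self.delay = delay
        self._version = 0  # bumped on every trigger; stale sleepers bail out

    async def trigger(self, callback):
        self._version += 1
        my_version = self._version
        await asyncio.sleep(self.delay)
        if my_version == self._version:  # no newer trigger arrived while sleeping
            callback()


async def main():
    fired = []
    d = Debouncer(0.1)
    t1 = asyncio.ensure_future(d.trigger(lambda: fired.append("a")))
    await asyncio.sleep(0.02)  # a second message arrives before the window closes
    t2 = asyncio.ensure_future(d.trigger(lambda: fired.append("b")))
    await asyncio.gather(t1, t2)
    return fired


print(asyncio.run(main()))  # only the last trigger fires: ['b']
```

The project's version uses the buffer pool plus a result flag instead of a version counter, but the shape is the same: every sleeper re-checks on wake-up whether it is still the newest.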
    async def query_buffer_result(self, message: MessageRecv) -> bool:
        """查询缓冲结果,并清理"""
        if not global_config.message_buffer:
            return True
        person_id_ = self.get_person_id_(
            message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
        )
        async with self.lock:
            user_msgs = self.buffer_pool.get(person_id_, {})
            cache_msg = user_msgs.get(message.message_info.message_id)
            if not cache_msg:
                logger.debug(f"查询异常消息不存在msgid: {message.message_info.message_id}")
                return False  # 消息不存在或已清理

        try:
            await asyncio.wait_for(cache_msg.cache_determination.wait(), timeout=10)
            result = cache_msg.result == "T"
@@ -144,9 +142,8 @@ class MessageBuffer:
                    keep_msgs[msg_id] = msg
                elif msg.result == "F":
                    # 收集F消息的文本内容
                    if hasattr(msg.message, "processed_plain_text") and msg.message.processed_plain_text:
                        if msg.message.message_segment.type == "text":
                            combined_text.append(msg.message.processed_plain_text)
                        elif msg.message.message_segment.type != "text":
                            is_update = False
@@ -157,20 +154,20 @@ class MessageBuffer:
            if combined_text and combined_text[0] != message.processed_plain_text and is_update:
                if type == "text":
                    message.processed_plain_text = "".join(combined_text)
                    logger.debug(f"整合了{len(combined_text) - 1}条F消息的内容到当前消息")
                elif type == "emoji":
                    combined_text.pop()
                    message.processed_plain_text = "".join(combined_text)
                    message.is_emoji = False
                    logger.debug(f"整合了{len(combined_text) - 1}条F消息的内容覆盖当前emoji消息")

            self.buffer_pool[person_id_] = keep_msgs
            return result
        except asyncio.TimeoutError:
            logger.debug(f"查询超时消息id {message.message_info.message_id}")
            return False

    async def save_message_interval(self, person_id: str, message: BaseMessageInfo):
        message_interval_list = await person_info_manager.get_value(person_id, "msg_interval_list")
        now_time_ms = int(round(time.time() * 1000))
        if len(message_interval_list) < 1000:
@@ -179,12 +176,12 @@ class MessageBuffer:
            message_interval_list.pop(0)
        message_interval_list.append(now_time_ms)
        data = {
            "platform": message.platform,
            "user_id": message.user_info.user_id,
            "nickname": message.user_info.user_nickname,
            "konw_time": int(time.time()),
        }
        await person_info_manager.update_one_field(person_id, "msg_interval_list", message_interval_list, data)


message_buffer = MessageBuffer()
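`save_message_interval` maintains a rolling window of arrival timestamps, capped at 1000 entries with the oldest popped first. `collections.deque` with `maxlen` gives equivalent behavior without the manual pop; a sketch (the class is illustrative, not the project's API):

```python
import time
from collections import deque
from typing import Optional


class IntervalTracker:
    """Rolling window of message arrival times (milliseconds), newest last."""

    def __init__(self, maxlen: int = 1000):
        self.times = deque(maxlen=maxlen)  # oldest entry is evicted automatically

    def record(self, now_ms: Optional[int] = None):
        self.times.append(now_ms if now_ms is not None else int(round(time.time() * 1000)))

    def intervals(self):
        """Gaps between consecutive arrivals, in ms."""
        ts = list(self.times)
        return [b - a for a, b in zip(ts, ts[1:])]


tracker = IntervalTracker(maxlen=3)
for t in (0, 100, 250, 600):
    tracker.record(t)
print(list(tracker.times))  # [100, 250, 600] (the first entry was evicted)
print(tracker.intervals())  # [150, 350]
```

These gaps are presumably what `_debounce_processor` turns into its per-person `interval_time`, so a fast typist gets a shorter buffer window than a slow one.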
@@ -68,7 +68,8 @@ class Message_Sender:
            typing_time = calculate_typing_time(
                input_string=message.processed_plain_text,
                thinking_start_time=message.thinking_start_time,
                is_emoji=message.is_emoji,
            )
            logger.debug(f"{message.processed_plain_text},{typing_time},计算输入时间结束")
            await asyncio.sleep(typing_time)
            logger.debug(f"{message.processed_plain_text},{typing_time},等待输入时间结束")
@@ -227,7 +228,7 @@ class MessageManager:
                await message_earliest.process()
                # print(f"message_earliest.thinking_start_tim22222e:{message_earliest.thinking_start_time}")
                await message_sender.send_message(message_earliest)

                await self.storage.store_message(message_earliest, message_earliest.chat_stream)

@@ -56,14 +56,13 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
            logger.info("被@回复概率设置为100%")
    else:
        if not is_mentioned:
            # 判断是否被回复
            if re.match(f"回复[\s\S]*?\({global_config.BOT_QQ}\)的消息,说:", message.processed_plain_text):
                is_mentioned = True
            # 判断内容中是否被提及
            message_content = re.sub(r"\@[\s\S]*?(\d+)", "", message.processed_plain_text)
            message_content = re.sub(r"回复[\s\S]*?\((\d+)\)的消息,说: ", "", message_content)
            for keyword in keywords:
                if keyword in message_content:
                    is_mentioned = True
@@ -359,7 +358,13 @@ def process_llm_response(text: str) -> List[str]:
    return sentences


def calculate_typing_time(
    input_string: str,
    thinking_start_time: float,
    chinese_time: float = 0.2,
    english_time: float = 0.1,
    is_emoji: bool = False,
) -> float:
    """
    计算输入字符串所需的时间,中文和英文字符有不同的输入时间
    input_string (str): 输入的字符串
@@ -393,19 +398,18 @@ def calculate_typing_time(input_string: str, thinking_start_time: float, chinese
            total_time += chinese_time
        else:  # 其他字符(如英文)
            total_time += english_time

    if is_emoji:
        total_time = 1

    if time.time() - thinking_start_time > 10:
        total_time = 1

    # print(f"thinking_start_time:{thinking_start_time}")
    # print(f"nowtime:{time.time()}")
    # print(f"nowtime - thinking_start_time:{time.time() - thinking_start_time}")
    # print(f"{total_time}")

    return total_time  # 加上回车时间
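The costing in `calculate_typing_time` charges more per Chinese character than per ASCII character, then clamps to a flat 1 second for emoji or for replies that already took over 10 seconds to think about. A simplified sketch of that shape (the hunk elides the real character check, so the CJK-range test below is an assumption):

```python
import time


def typing_time(text: str, thinking_start: float,
                chinese_time: float = 0.2, english_time: float = 0.1,
                is_emoji: bool = False) -> float:
    """Simulated typing delay, shaped like calculate_typing_time."""
    total = 0.0
    for ch in text:
        # Rough "is this a Chinese character" test via the CJK Unified Ideographs block
        total += chinese_time if "\u4e00" <= ch <= "\u9fff" else english_time
    if is_emoji:
        total = 1.0  # emoji get a flat one-second delay
    if time.time() - thinking_start > 10:
        total = 1.0  # reply already took too long; don't stall further
    return total


print(typing_time("hi你好", time.time()))  # 2 * 0.1 + 2 * 0.2, i.e. about 0.6
```

The two clamps keep the simulated typing from compounding with slow LLM generation into an unnaturally long silence.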
@@ -535,39 +539,32 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
    try:
        # 获取开始时间之前最新的一条消息
        start_message = db.messages.find_one(
            {"chat_id": stream_id, "time": {"$lte": start_time}},
            sort=[("time", -1), ("_id", -1)],  # 按时间倒序_id倒序最后插入的在前
        )

        # 获取结束时间最近的一条消息
        # 先找到结束时间点的所有消息
        end_time_messages = list(
            db.messages.find(
                {"chat_id": stream_id, "time": {"$lte": end_time}},
                sort=[("time", -1)],  # 先按时间倒序
            ).limit(10)
        )  # 限制查询数量,避免性能问题

        if not end_time_messages:
            logger.warning(f"未找到结束时间 {end_time} 之前的消息")
            return 0, 0

        # 找到最大时间
        max_time = end_time_messages[0]["time"]
        # 在最大时间的消息中找最后插入的_id最大的
        end_message = max([msg for msg in end_time_messages if msg["time"] == max_time], key=lambda x: x["_id"])

        if not start_message:
            logger.warning(f"未找到开始时间 {start_time} 之前的消息")
            return 0, 0

        # 调试输出
        # print("\n=== 消息范围信息 ===")
        # print("Start message:", {
@@ -587,20 +584,16 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
        # 如果结束消息的时间等于开始时间返回0
        if end_message["time"] == start_message["time"]:
            return 0, 0

        # 获取并打印这个时间范围内的所有消息
        # print("\n=== 时间范围内的所有消息 ===")
        all_messages = list(
            db.messages.find(
                {"chat_id": stream_id, "time": {"$gte": start_message["time"], "$lte": end_message["time"]}},
                sort=[("time", 1), ("_id", 1)],  # 按时间正序_id正序
            )
        )

        count = 0
        total_length = 0
        for msg in all_messages:
@@ -615,10 +608,10 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
            # "text_length": text_length,
            # "_id": str(msg.get("_id"))
            # })

        # 如果时间不同需要把end_message本身也计入
        return count - 1, total_length

    except Exception as e:
        logger.error(f"计算消息数量时出错: {str(e)}")
        return 0, 0
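The boundary logic above anchors on the newest message at or before each endpoint, then counts everything after the start anchor (hence the `count - 1`). The same logic can be exercised on an in-memory list, standing in for the MongoDB queries (the `_id` tiebreak for equal timestamps is omitted here):

```python
def count_between(messages, start_time, end_time):
    """messages: dicts with 'time' and 'text', pre-sorted by (time, insertion order)."""
    before_start = [m for m in messages if m["time"] <= start_time]
    before_end = [m for m in messages if m["time"] <= end_time]
    if not before_start or not before_end:
        return 0, 0
    start_msg, end_msg = before_start[-1], before_end[-1]  # newest at/before each endpoint
    if end_msg["time"] == start_msg["time"]:
        return 0, 0
    window = [m for m in messages if start_msg["time"] <= m["time"] <= end_msg["time"]]
    count = len(window) - 1  # exclude the start anchor itself, matching `count - 1`
    total_length = sum(len(m["text"]) for m in window)
    return count, total_length


msgs = [{"time": t, "text": "x" * t} for t in (1, 2, 3, 5)]
print(count_between(msgs, 2, 5))  # (2, 10): t=3 and t=5 fall after the t=2 anchor
```

Note that, like the original, `total_length` sums over the whole window including the start anchor, while `count` excludes it.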
@@ -239,13 +239,13 @@ class ImageManager:
            # 解码base64
            gif_data = base64.b64decode(gif_base64)
            gif = Image.open(io.BytesIO(gif_data))

            # 收集所有帧
            frames = []
            try:
                while True:
                    gif.seek(len(frames))
                    frame = gif.convert("RGB")
                    frames.append(frame.copy())
            except EOFError:
                pass
@@ -264,18 +264,19 @@ class ImageManager:
            # 获取单帧的尺寸
            frame_width, frame_height = selected_frames[0].size

            # 计算目标尺寸,保持宽高比
            target_height = 200  # 固定高度
            target_width = int((target_height / frame_height) * frame_width)

            # 调整所有帧的大小
            resized_frames = [
                frame.resize((target_width, target_height), Image.Resampling.LANCZOS) for frame in selected_frames
            ]

            # 创建拼接图像
            total_width = target_width * len(resized_frames)
            combined_image = Image.new("RGB", (total_width, target_height))

            # 水平拼接图像
            for idx, frame in enumerate(resized_frames):
@@ -283,11 +284,11 @@ class ImageManager:
            # 转换为base64
            buffer = io.BytesIO()
            combined_image.save(buffer, format="JPEG", quality=85)
            result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

            return result_base64

        except Exception as e:
            logger.error(f"GIF转换失败: {str(e)}")
            return None
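Before tiling, the code samples a handful of frames into `selected_frames`; the hunk does not show how they are chosen, so the even-spacing rule below is an assumption, but it reduces the step to a pure index calculation:

```python
def sample_indices(n_frames: int, max_samples: int = 4):
    """Indices of up to max_samples evenly spaced frames out of n_frames."""
    if n_frames <= max_samples:
        return list(range(n_frames))
    step = n_frames / max_samples
    return [int(i * step) for i in range(max_samples)]


print(sample_indices(10))  # [0, 2, 5, 7]
print(sample_indices(3))   # [0, 1, 2]
```

Sampling a fixed number of frames keeps the resulting horizontal strip (here four 200px-tall tiles) bounded regardless of how long the GIF is.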
@@ -7,12 +7,13 @@ from datetime import datetime
logger = get_module_logger("pfc_message_processor")


class MessageProcessor:
    """消息处理器,负责处理接收到的消息并存储"""

    def __init__(self):
        self.storage = MessageStorage()

    def _check_ban_words(self, text: str, chat, userinfo) -> bool:
        """检查消息中是否包含过滤词"""
        for word in global_config.ban_words:
@@ -34,10 +35,10 @@ class MessageProcessor:
                logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
                return True
        return False

    async def process_message(self, message: MessageRecv) -> None:
        """处理消息并存储

        Args:
            message: 消息对象
        """
@@ -55,12 +56,9 @@ class MessageProcessor:
        # 存储消息
        await self.storage.store_message(message, chat)

        # 打印消息信息
        mes_name = chat.group_info.group_name if chat.group_info else "私聊"
        # 将时间戳转换为datetime对象
        current_time = datetime.fromtimestamp(message.message_info.time).strftime("%H:%M:%S")
        logger.info(f"[{current_time}][{mes_name}]{chat.user_info.user_nickname}: {message.processed_plain_text}")

@@ -27,6 +27,7 @@ chat_config = LogConfig(
logger = get_module_logger("reasoning_chat", config=chat_config)


class ReasoningChat:
    def __init__(self):
        self.storage = MessageStorage()
@@ -224,13 +225,13 @@ class ReasoningChat:
        do_reply = False
        if random() < reply_probability:
            do_reply = True

            # 创建思考消息
            timer1 = time.time()
            thinking_id = await self._create_thinking_message(message, chat, userinfo, messageinfo)
            timer2 = time.time()
            timing_results["创建思考消息"] = timer2 - timer1

            # 生成回复
            timer1 = time.time()
            response_set = await self.gpt.generate_response(message)

@@ -40,7 +40,7 @@ class ResponseGenerator:
    async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
        """根据当前模型类型选择对应的生成函数"""
        # 从global_config中获取模型概率值并选择模型
        if random.random() < global_config.MODEL_R1_PROBABILITY:
            self.current_model_type = "深深地"
            current_model = self.model_reasoning
@@ -51,7 +51,6 @@ class ResponseGenerator:
        logger.info(
            f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
        )  # noqa: E501

        model_response = await self._generate_response_with_model(message, current_model)
@@ -189,4 +188,4 @@ class ResponseGenerator:
        # print(f"得到了处理后的llm返回{processed_response}")
        return processed_response

@@ -24,35 +24,32 @@ class PromptBuilder:
    async def _build_prompt(
        self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
    ) -> tuple[str, str]:
        # 开始构建prompt
        prompt_personality = ""
        # person
        individuality = Individuality.get_instance()
        personality_core = individuality.personality.personality_core
        prompt_personality += personality_core

        personality_sides = individuality.personality.personality_sides
        random.shuffle(personality_sides)
        prompt_personality += f",{personality_sides[0]}"

        identity_detail = individuality.identity.identity_detail
        random.shuffle(identity_detail)
        prompt_personality += f",{identity_detail[0]}"

        # 关系
        who_chat_in_group = [
            (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
        ]
        who_chat_in_group += get_recent_group_speaker(
            stream_id,
            (chat_stream.user_info.platform, chat_stream.user_info.user_id),
            limit=global_config.MAX_CONTEXT_SIZE,
        )

        relation_prompt = ""
        for person in who_chat_in_group:
            relation_prompt += await relationship_manager.build_relationship_info(person)
@@ -67,7 +64,7 @@ class PromptBuilder:
        mood_prompt = mood_manager.get_prompt()
        # logger.info(f"心情prompt: {mood_prompt}")

        # 调取记忆
        memory_prompt = ""
        related_memory = await HippocampusManager.get_instance().get_memory_from_text(
@@ -84,7 +81,7 @@ class PromptBuilder:
        # print(f"相关记忆:{related_memory_info}")

        # 日程构建
        schedule_prompt = f"""你现在正在做的事情是:{bot_schedule.get_current_num_task(num=1, time_info=False)}"""

        # 获取聊天上下文
        chat_in_group = True
@@ -143,7 +140,7 @@ class PromptBuilder:
        涉及政治敏感以及违法违规的内容请规避。"""

        logger.info("开始构建prompt")

        prompt = f"""
{relation_prompt_all}
{memory_prompt}
@@ -165,7 +162,7 @@ class PromptBuilder:
        start_time = time.time()
        related_info = ""
        logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")

        # 1. 先从LLM获取主题类似于记忆系统的做法
        topics = []
        # try:
@@ -173,7 +170,7 @@ class PromptBuilder:
        # hippocampus = HippocampusManager.get_instance()._hippocampus
        # topic_num = min(5, max(1, int(len(message) * 0.1)))
        # topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))

        # # 提取关键词
        # topics = re.findall(r"<([^>]+)>", topics_response[0])
        # if not topics:
@@ -184,7 +181,7 @@ class PromptBuilder:
        # for topic in ",".join(topics).replace("", ",").replace("、", ",").replace(" ", ",").split(",")
        # if topic.strip()
        # ]
        # logger.info(f"从LLM提取的主题: {', '.join(topics)}")

        # except Exception as e:
        # logger.error(f"从LLM提取主题失败: {str(e)}")
@@ -192,7 +189,7 @@ class PromptBuilder:
        # words = jieba.cut(message)
        # topics = [word for word in words if len(word) > 1][:5]
        # logger.info(f"使用jieba提取的主题: {', '.join(topics)}")

        # 如果无法提取到主题,直接使用整个消息
        if not topics:
            logger.info("未能提取到任何主题,使用整个消息进行查询")
@@ -200,26 +197,26 @@ class PromptBuilder:
            if not embedding:
                logger.error("获取消息嵌入向量失败")
                return ""

            related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
            logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒")
            return related_info

        # 2. 对每个主题进行知识库查询
        logger.info(f"开始处理{len(topics)}个主题的知识库查询")

        # 优化批量获取嵌入向量减少API调用
        embeddings = {}
        topics_batch = [topic for topic in topics if len(topic) > 0]
        if message:  # 确保消息非空
            topics_batch.append(message)

        # 批量获取嵌入向量
        embed_start_time = time.time()

        for text in topics_batch:
            if not text or len(text.strip()) == 0:
                continue
            try:
                embedding = await get_embedding(text, request_type="prompt_build")
                if embedding:
@@ -228,17 +225,17 @@ class PromptBuilder:
                    logger.warning(f"获取'{text}'的嵌入向量失败")
            except Exception as e:
                logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")

        logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒")

        if not embeddings:
            logger.error("所有嵌入向量获取失败")
            return ""

        # 3. 对每个主题进行知识库查询
        all_results = []
        query_start_time = time.time()

        # 首先添加原始消息的查询结果
        if message in embeddings:
            original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
@@ -247,12 +244,12 @@ class PromptBuilder:
                    result["topic"] = "原始消息"
                all_results.extend(original_results)
                logger.info(f"原始消息查询到{len(original_results)}条结果")

        # 然后添加每个主题的查询结果
        for topic in topics:
            if not topic or topic not in embeddings:
                continue
            try:
                topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
                if topic_results:
@@ -263,9 +260,9 @@ class PromptBuilder:
                    logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
            except Exception as e:
                logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")

        logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")

        # 4. 去重和过滤
        process_start_time = time.time()
        unique_contents = set()
@@ -275,14 +272,16 @@ class PromptBuilder:
if content not in unique_contents: if content not in unique_contents:
unique_contents.add(content) unique_contents.add(content)
        filtered_results.append(result)
        # 5. 按相似度排序
        filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
        # 6. 限制总数量最多10条
        filtered_results = filtered_results[:10]
-       logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果")
+       logger.info(
+           f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
+       )
        # 7. 格式化输出
        if filtered_results:
            format_start_time = time.time()
@@ -292,7 +291,7 @@ class PromptBuilder:
                if topic not in grouped_results:
                    grouped_results[topic] = []
                grouped_results[topic].append(result)
            # 按主题组织输出
            for topic, results in grouped_results.items():
                related_info += f"【主题: {topic}】\n"
@@ -303,13 +302,15 @@ class PromptBuilder:
                # related_info += f"{i}. [{similarity:.2f}] {content}\n"
                related_info += f"{content}\n"
                related_info += "\n"
        logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
        logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
        return related_info

-   def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
+   def get_info_from_db(
+       self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
+   ) -> Union[str, list]:
        if not query_embedding:
            return "" if not return_raw else []
        # 使用余弦相似度计算
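The hunk above reformats `get_info_from_db`, whose surrounding code filters knowledge-base hits by cosine similarity, sorts them in descending order, and caps the list at 10 entries. A minimal standalone sketch of that retrieval pattern follows; the function names and the pure-Python similarity computation are illustrative, not code from this repository:

```python
import math


def cosine_similarity(a: list, b: list) -> float:
    # Cosine similarity: dot product divided by the product of vector norms.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def filter_and_rank(results: list, threshold: float = 0.5, limit: int = 10) -> list:
    # Drop hits below the threshold, sort by similarity (highest first), keep at most `limit`.
    kept = [r for r in results if r["similarity"] >= threshold]
    kept.sort(key=lambda x: x["similarity"], reverse=True)
    return kept[:limit]
```

The `threshold=0.5` and `limit=10` defaults mirror the values visible in the diff.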


@@ -28,6 +28,7 @@ chat_config = LogConfig(
logger = get_module_logger("think_flow_chat", config=chat_config)

class ThinkFlowChat:
    def __init__(self):
        self.storage = MessageStorage()
@@ -96,7 +97,7 @@ class ThinkFlowChat:
                )
                if not mark_head:
                    mark_head = True
                # print(f"thinking_start_time:{bot_message.thinking_start_time}")
                message_set.add_message(bot_message)
        message_manager.add_message(message_set)
@@ -110,7 +111,7 @@ class ThinkFlowChat:
        if emoji_raw:
            emoji_path, description = emoji_raw
            emoji_cq = image_path_to_base64(emoji_path)
            # logger.info(emoji_cq)
            thinking_time_point = round(message.message_info.time, 2)
@@ -130,7 +131,7 @@ class ThinkFlowChat:
                is_head=False,
                is_emoji=True,
            )
            # logger.info("22222222222222")
            message_manager.add_message(bot_message)
@@ -180,7 +181,7 @@ class ThinkFlowChat:
        await message.process()
        logger.debug(f"消息处理成功{message.processed_plain_text}")
        # 过滤词/正则表达式过滤
        if self._check_ban_words(message.processed_plain_text, chat, userinfo) or self._check_ban_regex(
            message.raw_message, chat, userinfo
@@ -190,7 +191,7 @@ class ThinkFlowChat:
        await self.storage.store_message(message, chat)
        logger.debug(f"存储成功{message.processed_plain_text}")
        # 记忆激活
        timer1 = time.time()
        interested_rate = await HippocampusManager.get_instance().get_activate_from_text(
@@ -214,15 +215,13 @@ class ThinkFlowChat:
        # 处理提及
        is_mentioned, reply_probability = is_mentioned_bot_in_message(message)
        # 计算回复意愿
        current_willing_old = willing_manager.get_willing(chat_stream=chat)
        # current_willing_new = (heartflow.get_subheartflow(chat.stream_id).current_state.willing - 5) / 4
        # current_willing = (current_willing_old + current_willing_new) / 2
        # 有点bug
        current_willing = current_willing_old
        willing_manager.set_willing(chat.stream_id, current_willing)
        # 意愿激活
@@ -258,7 +257,7 @@ class ThinkFlowChat:
        if random() < reply_probability:
            try:
                do_reply = True
                # 创建思考消息
                try:
                    timer1 = time.time()
@@ -267,9 +266,9 @@ class ThinkFlowChat:
                    timing_results["创建思考消息"] = timer2 - timer1
                except Exception as e:
                    logger.error(f"心流创建思考消息失败: {e}")
                try:
                    # 观察
                    timer1 = time.time()
                    await heartflow.get_subheartflow(chat.stream_id).do_observe()
                    timer2 = time.time()
@@ -280,12 +279,14 @@ class ThinkFlowChat:
                # 思考前脑内状态
                try:
                    timer1 = time.time()
-                   await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(message.processed_plain_text)
+                   await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(
+                       message.processed_plain_text
+                   )
                    timer2 = time.time()
                    timing_results["思考前脑内状态"] = timer2 - timer1
                except Exception as e:
                    logger.error(f"心流思考前脑内状态失败: {e}")
                # 生成回复
                timer1 = time.time()
                response_set = await self.gpt.generate_response(message)


@@ -35,7 +35,6 @@ class ResponseGenerator:
    async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
        """根据当前模型类型选择对应的生成函数"""
        logger.info(
            f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
        )
@@ -178,4 +177,3 @@ class ResponseGenerator:
        # print(f"得到了处理后的llm返回{processed_response}")
        return processed_response


@@ -21,22 +21,21 @@ class PromptBuilder:
    async def _build_prompt(
        self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
    ) -> tuple[str, str]:
        current_mind_info = heartflow.get_subheartflow(stream_id).current_mind
        individuality = Individuality.get_instance()
-       prompt_personality = individuality.get_prompt(type = "personality",x_person = 2,level = 1)
-       prompt_identity = individuality.get_prompt(type = "identity",x_person = 2,level = 1)
+       prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1)
+       prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1)
        # 关系
-       who_chat_in_group = [(chat_stream.user_info.platform,
-           chat_stream.user_info.user_id,
-           chat_stream.user_info.user_nickname)]
+       who_chat_in_group = [
+           (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
+       ]
        who_chat_in_group += get_recent_group_speaker(
            stream_id,
            (chat_stream.user_info.platform, chat_stream.user_info.user_id),
            limit=global_config.MAX_CONTEXT_SIZE,
        )
        relation_prompt = ""
        for person in who_chat_in_group:
            relation_prompt += await relationship_manager.build_relationship_info(person)
@@ -100,7 +99,7 @@ class PromptBuilder:
涉及政治敏感以及违法违规的内容请规避。"""
        logger.info("开始构建prompt")
        prompt = f"""
{relation_prompt_all}\n
{chat_target}
@@ -114,7 +113,7 @@ class PromptBuilder:
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。"""
        return prompt


@@ -3,6 +3,7 @@ import tomlkit
from pathlib import Path
from datetime import datetime

def update_config():
    print("开始更新配置文件...")
    # 获取根目录路径
@@ -25,11 +26,11 @@ def update_config():
        print(f"发现旧配置文件: {old_config_path}")
        with open(old_config_path, "r", encoding="utf-8") as f:
            old_config = tomlkit.load(f)
        # 生成带时间戳的新文件名
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
        # 移动旧配置文件到old目录
        shutil.move(old_config_path, old_backup_path)
        print(f"已备份旧配置文件到: {old_backup_path}")
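The migration script above backs up the old config under a timestamped name before writing a new one, so no previous configuration is ever overwritten. A self-contained sketch of that backup step follows; the helper name is illustrative, and only the stdlib pieces visible in the diff (`shutil.move`, `datetime.strftime`, `pathlib`) are used:

```python
import shutil
from datetime import datetime
from pathlib import Path


def backup_old_config(old_config_path: Path, backup_dir: Path) -> Path:
    # Move an existing config file into backup_dir under a timestamped name,
    # e.g. bot_config.toml -> old/bot_config_20250408_231200.toml
    backup_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = backup_dir / f"{old_config_path.stem}_{timestamp}{old_config_path.suffix}"
    shutil.move(str(old_config_path), backup_path)
    return backup_path
```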


@@ -28,6 +28,7 @@ logger = get_module_logger("config", config=config_config)
is_test = True
mai_version_main = "0.6.2"
mai_version_fix = "snapshot-1"

if mai_version_fix:
    if is_test:
        mai_version = f"test-{mai_version_main}-{mai_version_fix}"
@@ -39,6 +40,7 @@ else:
    else:
        mai_version = mai_version_main

def update_config():
    # 获取根目录路径
    root_dir = Path(__file__).parent.parent.parent.parent
@@ -54,7 +56,7 @@ def update_config():
    # 检查配置文件是否存在
    if not old_config_path.exists():
        logger.info("配置文件不存在,从模板创建新配置")
-       #创建文件夹
+       # 创建文件夹
        old_config_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(template_path, old_config_path)
        logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
@@ -84,7 +86,7 @@ def update_config():
    # 生成带时间戳的新文件名
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
    # 移动旧配置文件到old目录
    shutil.move(old_config_path, old_backup_path)
    logger.info(f"已备份旧配置文件到: {old_backup_path}")
@@ -127,6 +129,7 @@ def update_config():
        f.write(tomlkit.dumps(new_config))
    logger.info("配置文件更新完成")

logger = get_module_logger("config")
@@ -148,17 +151,21 @@ class BotConfig:
    ban_user_id = set()
    # personality
    personality_core = "用一句话或几句话描述人格的核心特点"  # 建议20字以内谁再写3000字小作文敲谁脑袋
-   personality_sides: List[str] = field(default_factory=lambda: [
-       "用一句话或几句话描述人格的一些侧面",
-       "用一句话或几句话描述人格的一些侧面",
-       "用一句话或几句话描述人格的一些侧面"
-   ])
+   personality_sides: List[str] = field(
+       default_factory=lambda: [
+           "用一句话或几句话描述人格的一些侧面",
+           "用一句话或几句话描述人格的一些侧面",
+           "用一句话或几句话描述人格的一些侧面",
+       ]
+   )
    # identity
-   identity_detail: List[str] = field(default_factory=lambda: [
-       "身份特点",
-       "身份特点",
-   ])
+   identity_detail: List[str] = field(
+       default_factory=lambda: [
+           "身份特点",
+           "身份特点",
+       ]
+   )
    height: int = 170  # 身高 单位厘米
    weight: int = 50  # 体重 单位千克
    age: int = 20  # 年龄 单位岁
@@ -181,22 +188,22 @@ class BotConfig:
    ban_words = set()
    ban_msgs_regex = set()

-   #heartflow
+   # heartflow
    # enable_heartflow: bool = False  # 是否启用心流
    sub_heart_flow_update_interval: int = 60  # 子心流更新频率,间隔 单位秒
    sub_heart_flow_freeze_time: int = 120  # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒
    sub_heart_flow_stop_time: int = 600  # 子心流停止时间,超过这个时间没有回复,子心流会停止,间隔 单位秒
    heart_flow_update_interval: int = 300  # 心流更新频率,间隔 单位秒
    # willing
    willing_mode: str = "classical"  # 意愿模式
    response_willing_amplifier: float = 1.0  # 回复意愿放大系数
    response_interested_rate_amplifier: float = 1.0  # 回复兴趣度放大系数
    down_frequency_rate: float = 3  # 降低回复频率的群组回复意愿降低系数
    emoji_response_penalty: float = 0.0  # 表情包回复惩罚
    mentioned_bot_inevitable_reply: bool = False  # 提及 bot 必然回复
    at_bot_inevitable_reply: bool = False  # @bot 必然回复
    # response
    response_mode: str = "heart_flow"  # 回复策略
@@ -354,7 +361,6 @@ class BotConfig:
        """从TOML配置文件加载配置"""
        config = cls()

        def personality(parent: dict):
            personality_config = parent["personality"]
            if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
@@ -418,13 +424,21 @@ class BotConfig:
            config.max_response_length = response_config.get("max_response_length", config.max_response_length)
            if config.INNER_VERSION in SpecifierSet(">=1.0.4"):
                config.response_mode = response_config.get("response_mode", config.response_mode)

        def heartflow(parent: dict):
            heartflow_config = parent["heartflow"]
-           config.sub_heart_flow_update_interval = heartflow_config.get("sub_heart_flow_update_interval", config.sub_heart_flow_update_interval)
-           config.sub_heart_flow_freeze_time = heartflow_config.get("sub_heart_flow_freeze_time", config.sub_heart_flow_freeze_time)
-           config.sub_heart_flow_stop_time = heartflow_config.get("sub_heart_flow_stop_time", config.sub_heart_flow_stop_time)
-           config.heart_flow_update_interval = heartflow_config.get("heart_flow_update_interval", config.heart_flow_update_interval)
+           config.sub_heart_flow_update_interval = heartflow_config.get(
+               "sub_heart_flow_update_interval", config.sub_heart_flow_update_interval
+           )
+           config.sub_heart_flow_freeze_time = heartflow_config.get(
+               "sub_heart_flow_freeze_time", config.sub_heart_flow_freeze_time
+           )
+           config.sub_heart_flow_stop_time = heartflow_config.get(
+               "sub_heart_flow_stop_time", config.sub_heart_flow_stop_time
+           )
+           config.heart_flow_update_interval = heartflow_config.get(
+               "heart_flow_update_interval", config.heart_flow_update_interval
+           )

        def willing(parent: dict):
            willing_config = parent["willing"]

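The `heartflow(parent)` loader above repeatedly applies the same pattern: read a key from a parsed TOML table with `dict.get`, falling back to the value already on the config object. That merge-over-defaults idea can be sketched generically; the function name and `DEFAULTS` table are illustrative, not part of the repository:

```python
# Default values mirroring two of the [heartflow] settings shown in the diff.
DEFAULTS = {"sub_heart_flow_update_interval": 60, "sub_heart_flow_freeze_time": 120}


def load_section(parent: dict, name: str, defaults: dict) -> dict:
    # Overlay a parsed config table on top of defaults; missing keys keep their default.
    section = parent.get(name, {})
    return {key: section.get(key, fallback) for key, fallback in defaults.items()}
```

Unknown keys in the file are silently ignored, which matches the tolerant loading style of the diff.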

@@ -14,6 +14,7 @@ from src.common.logger import get_module_logger, LogConfig, MEMORY_STYLE_CONFIG
from src.plugins.memory_system.sample_distribution import MemoryBuildScheduler  # 分布生成器
from .memory_config import MemoryConfig

def get_closest_chat_from_db(length: int, timestamp: str):
    # print(f"获取最接近指定时间戳的聊天记录,长度: {length}, 时间戳: {timestamp}")
    # print(f"当前时间: {timestamp},转换后时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))}")


@@ -179,7 +179,6 @@ class LLM_request:
        # logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
        # logger.info(f"使用模型: {self.model_name}")
        # 构建请求体
        if image_base64:
            payload = await self._build_payload(prompt, image_base64, image_format)
@@ -205,13 +204,17 @@ class LLM_request:
                    # 处理需要重试的状态码
                    if response.status in policy["retry_codes"]:
                        wait_time = policy["base_wait"] * (2**retry)
-                       logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
+                       logger.warning(
+                           f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试"
+                       )
                        if response.status == 413:
                            logger.warning("请求体过大,尝试压缩...")
                            image_base64 = compress_base64_image_by_scale(image_base64)
                            payload = await self._build_payload(prompt, image_base64, image_format)
                        elif response.status in [500, 503]:
-                           logger.error(f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}")
+                           logger.error(
+                               f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
+                           )
                            raise RuntimeError("服务器负载过高模型恢复失败QAQ")
                        else:
                            logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
@@ -219,7 +222,9 @@ class LLM_request:
                        await asyncio.sleep(wait_time)
                        continue
                    elif response.status in policy["abort_codes"]:
-                       logger.error(f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}")
+                       logger.error(
+                           f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
+                       )
                        # 尝试获取并记录服务器返回的详细错误信息
                        try:
                            error_json = await response.json()
@@ -257,7 +262,9 @@ class LLM_request:
                        ):
                            old_model_name = self.model_name
                            self.model_name = self.model_name[4:]  # 移除"Pro/"前缀
-                           logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}")
+                           logger.warning(
+                               f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}"
+                           )
                            # 对全局配置进行更新
                            if global_config.llm_normal.get("name") == old_model_name:
@@ -266,7 +273,9 @@ class LLM_request:
                            if global_config.llm_reasoning.get("name") == old_model_name:
                                global_config.llm_reasoning["name"] = self.model_name
-                               logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
+                               logger.warning(
+                                   f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}"
+                               )
                            # 更新payload中的模型名
                            if payload and "model" in payload:
@@ -328,7 +337,14 @@ class LLM_request:
                                await response.release()
                                # 返回已经累积的内容
                                result = {
-                                   "choices": [{"message": {"content": accumulated_content, "reasoning_content": reasoning_content}}],
+                                   "choices": [
+                                       {
+                                           "message": {
+                                               "content": accumulated_content,
+                                               "reasoning_content": reasoning_content,
+                                           }
+                                       }
+                                   ],
                                    "usage": usage,
                                }
                                return (
@@ -345,7 +361,14 @@ class LLM_request:
                                    logger.error(f"清理资源时发生错误: {cleanup_error}")
                                # 返回已经累积的内容
                                result = {
-                                   "choices": [{"message": {"content": accumulated_content, "reasoning_content": reasoning_content}}],
+                                   "choices": [
+                                       {
+                                           "message": {
+                                               "content": accumulated_content,
+                                               "reasoning_content": reasoning_content,
+                                           }
+                                       }
+                                   ],
                                    "usage": usage,
                                }
                                return (
@@ -360,7 +383,9 @@ class LLM_request:
                        content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
                        # 构造一个伪result以便调用自定义响应处理器或默认处理器
                        result = {
-                           "choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}],
+                           "choices": [
+                               {"message": {"content": content, "reasoning_content": reasoning_content}}
+                           ],
                            "usage": usage,
                        }
                        return (
@@ -394,7 +419,9 @@ class LLM_request:
            # 处理aiohttp抛出的响应错误
            if retry < policy["max_retries"] - 1:
                wait_time = policy["base_wait"] * (2**retry)
-               logger.error(f"模型 {self.model_name} HTTP响应错误等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}")
+               logger.error(
+                   f"模型 {self.model_name} HTTP响应错误等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}"
+               )
                try:
                    if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
                        error_text = await e.response.text()
@@ -419,13 +446,17 @@ class LLM_request:
                    else:
                        logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
                except (json.JSONDecodeError, TypeError) as json_err:
-                   logger.warning(f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}")
+                   logger.warning(
+                       f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
+                   )
                except (AttributeError, TypeError, ValueError) as parse_err:
                    logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
                await asyncio.sleep(wait_time)
            else:
-               logger.critical(f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}")
+               logger.critical(
+                   f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}"
+               )
                # 安全地检查和记录请求详情
                if (
                    image_base64

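The `LLM_request` retry logic above computes its delay as `policy["base_wait"] * (2**retry)`, i.e. classic exponential backoff. A minimal standalone sketch of that schedule follows; the function name and the `max_wait` cap are illustrative additions, not part of the repository:

```python
def backoff_delay(base_wait: float, retry: int, max_wait: float = 60.0) -> float:
    # Exponential backoff: the wait doubles on every retry, capped at max_wait.
    # retry=0 -> base_wait, retry=1 -> 2*base_wait, retry=2 -> 4*base_wait, ...
    return min(base_wait * (2**retry), max_wait)
```

In the diff the delay is used uncapped; a cap like this is a common safeguard when `max_retries` is large.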

@@ -139,7 +139,7 @@ class MoodManager:
        # 神经质:影响情绪变化速度
        neuroticism_factor = 1 + (personality.neuroticism - 0.5) * 0.5
        agreeableness_factor = 1 + (personality.agreeableness - 0.5) * 0.5
        # 宜人性:影响情绪基准线
        if personality.agreeableness < 0.2:
            agreeableness_bias = (personality.agreeableness - 0.2) * 2
@@ -151,7 +151,7 @@ class MoodManager:
        # 分别计算正向和负向的衰减率
        if self.current_mood.valence >= 0:
            # 正向情绪衰减
-           decay_rate_positive = self.decay_rate_valence * (1/agreeableness_factor)
+           decay_rate_positive = self.decay_rate_valence * (1 / agreeableness_factor)
            valence_target = 0 + agreeableness_bias
            self.current_mood.valence = valence_target + (self.current_mood.valence - valence_target) * math.exp(
                -decay_rate_positive * time_diff * neuroticism_factor
            )
@@ -279,8 +279,9 @@ class MoodManager:
        # 限制范围
        self.current_mood.valence = max(-1.0, min(1.0, self.current_mood.valence))
        self.current_mood.arousal = max(0.0, min(1.0, self.current_mood.arousal))
        self._update_mood_text()
-       logger.info(f"[情绪变化] {emotion}(强度:{intensity:.2f}) | 愉悦度:{old_valence:.2f}->{self.current_mood.valence:.2f}, 唤醒度:{old_arousal:.2f}->{self.current_mood.arousal:.2f} | 心情:{old_mood}->{self.current_mood.text}")
+       logger.info(
+           f"[情绪变化] {emotion}(强度:{intensity:.2f}) | 愉悦度:{old_valence:.2f}->{self.current_mood.valence:.2f}, 唤醒度:{old_arousal:.2f}->{self.current_mood.arousal:.2f} | 心情:{old_mood}->{self.current_mood.text}"
+       )
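The mood decay in `MoodManager` follows the formula visible in the hunk: the new valence is `target + (value - target) * exp(-rate * dt)`, so the gap to the baseline shrinks exponentially with elapsed time. A self-contained sketch of that update rule (the helper name is illustrative):

```python
import math


def decay_toward(value: float, target: float, rate: float, dt: float) -> float:
    # Exponentially decay `value` toward `target`: after time dt the remaining
    # gap is multiplied by exp(-rate * dt), so larger rate or dt means faster decay.
    return target + (value - target) * math.exp(-rate * dt)
```

In the diff, `rate` is further scaled by personality factors (neuroticism speeds decay up or down, agreeableness shifts the baseline `target`).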


@@ -8,7 +8,8 @@ import asyncio
import numpy as np import numpy as np
import matplotlib import matplotlib
matplotlib.use('Agg')
matplotlib.use("Agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from pathlib import Path from pathlib import Path
import pandas as pd import pandas as pd
@@ -30,38 +31,39 @@ PersonInfoManager 类方法功能摘要:
logger = get_module_logger("person_info") logger = get_module_logger("person_info")
person_info_default = { person_info_default = {
"person_id" : None, "person_id": None,
"platform" : None, "platform": None,
"user_id" : None, "user_id": None,
"nickname" : None, "nickname": None,
# "age" : 0, # "age" : 0,
"relationship_value" : 0, "relationship_value": 0,
# "saved" : True, # "saved" : True,
# "impression" : None, # "impression" : None,
# "gender" : Unkown, # "gender" : Unkown,
"konw_time" : 0, "konw_time": 0,
"msg_interval": 3000, "msg_interval": 3000,
"msg_interval_list": [] "msg_interval_list": [],
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项 } # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
class PersonInfoManager: class PersonInfoManager:
def __init__(self): def __init__(self):
if "person_info" not in db.list_collection_names(): if "person_info" not in db.list_collection_names():
db.create_collection("person_info") db.create_collection("person_info")
db.person_info.create_index("person_id", unique=True) db.person_info.create_index("person_id", unique=True)
def get_person_id(self, platform:str, user_id:int): def get_person_id(self, platform: str, user_id: int):
"""获取唯一id""" """获取唯一id"""
components = [platform, str(user_id)] components = [platform, str(user_id)]
key = "_".join(components) key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest() return hashlib.md5(key.encode()).hexdigest()
async def create_person_info(self, person_id:str, data:dict = None): async def create_person_info(self, person_id: str, data: dict = None):
"""创建一个项""" """创建一个项"""
if not person_id: if not person_id:
logger.debug("创建失败personid不存在") logger.debug("创建失败personid不存在")
return return
_person_info_default = copy.deepcopy(person_info_default) _person_info_default = copy.deepcopy(person_info_default)
_person_info_default["person_id"] = person_id _person_info_default["person_id"] = person_id
@@ -72,19 +74,16 @@ class PersonInfoManager:
db.person_info.insert_one(_person_info_default) db.person_info.insert_one(_person_info_default)
async def update_one_field(self, person_id:str, field_name:str, value, Data:dict = None): async def update_one_field(self, person_id: str, field_name: str, value, Data: dict = None):
"""更新某一个字段,会补全""" """更新某一个字段,会补全"""
if field_name not in person_info_default.keys(): if field_name not in person_info_default.keys():
logger.debug(f"更新'{field_name}'失败,未定义的字段") logger.debug(f"更新'{field_name}'失败,未定义的字段")
return return
document = db.person_info.find_one({"person_id": person_id}) document = db.person_info.find_one({"person_id": person_id})
if document: if document:
db.person_info.update_one( db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}})
{"person_id": person_id},
{"$set": {field_name: value}}
)
else: else:
Data[field_name] = value Data[field_name] = value
logger.debug(f"更新时{person_id}不存在,已新建") logger.debug(f"更新时{person_id}不存在,已新建")
@@ -107,23 +106,20 @@ class PersonInfoManager:
if not person_id: if not person_id:
logger.debug("get_value获取失败person_id不能为空") logger.debug("get_value获取失败person_id不能为空")
return None return None
if field_name not in person_info_default: if field_name not in person_info_default:
logger.debug(f"get_value获取失败字段'{field_name}'未定义") logger.debug(f"get_value获取失败字段'{field_name}'未定义")
return None return None
document = db.person_info.find_one( document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
{"person_id": person_id},
{field_name: 1}
)
if document and field_name in document: if document and field_name in document:
return document[field_name] return document[field_name]
else: else:
default_value = copy.deepcopy(person_info_default[field_name]) default_value = copy.deepcopy(person_info_default[field_name])
logger.debug(f"获取{person_id}{field_name}失败,已返回默认值{default_value}") logger.debug(f"获取{person_id}{field_name}失败,已返回默认值{default_value}")
return default_value return default_value
async def get_values(self, person_id: str, field_names: list) -> dict: async def get_values(self, person_id: str, field_names: list) -> dict:
"""获取指定person_id文档的多个字段值若不存在该字段则返回该字段的全局默认值""" """获取指定person_id文档的多个字段值若不存在该字段则返回该字段的全局默认值"""
if not person_id: if not person_id:
@@ -139,62 +135,57 @@ class PersonInfoManager:
# 构建查询投影(所有字段都有效才会执行到这里) # 构建查询投影(所有字段都有效才会执行到这里)
projection = {field: 1 for field in field_names} projection = {field: 1 for field in field_names}
document = db.person_info.find_one( document = db.person_info.find_one({"person_id": person_id}, projection)
{"person_id": person_id},
projection
)
result = {} result = {}
for field in field_names: for field in field_names:
result[field] = copy.deepcopy( result[field] = copy.deepcopy(
document.get(field, person_info_default[field]) document.get(field, person_info_default[field]) if document else person_info_default[field]
if document else person_info_default[field]
) )
return result return result
async def del_all_undefined_field(self): async def del_all_undefined_field(self):
"""删除所有项里的未定义字段""" """删除所有项里的未定义字段"""
# 获取所有已定义的字段名 # 获取所有已定义的字段名
defined_fields = set(person_info_default.keys()) defined_fields = set(person_info_default.keys())
try: try:
# 遍历集合中的所有文档 # 遍历集合中的所有文档
for document in db.person_info.find({}): for document in db.person_info.find({}):
# 找出文档中未定义的字段 # 找出文档中未定义的字段
undefined_fields = set(document.keys()) - defined_fields - {'_id'} undefined_fields = set(document.keys()) - defined_fields - {"_id"}
if undefined_fields: if undefined_fields:
# 构建更新操作,使用$unset删除未定义字段 # 构建更新操作,使用$unset删除未定义字段
update_result = db.person_info.update_one( update_result = db.person_info.update_one(
{'_id': document['_id']}, {"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}}
{'$unset': {field: 1 for field in undefined_fields}}
) )
if update_result.modified_count > 0: if update_result.modified_count > 0:
logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}") logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}")
return return
except Exception as e: except Exception as e:
logger.error(f"清理未定义字段时出错: {e}") logger.error(f"清理未定义字段时出错: {e}")
return return
async def get_specific_value_list( async def get_specific_value_list(
self, self,
field_name: str, field_name: str,
way: Callable[[Any], bool], # 接受任意类型值 way: Callable[[Any], bool], # 接受任意类型值
) ->Dict[str, Any]: ) -> Dict[str, Any]:
""" """
获取满足条件的字段值字典 获取满足条件的字段值字典
Args: Args:
field_name: 目标字段名 field_name: 目标字段名
way: 判断函数 (value: Any) -> bool way: 判断函数 (value: Any) -> bool
Returns: Returns:
{person_id: value} | {} {person_id: value} | {}
Example: Example:
# 查找所有nickname包含"admin"的用户 # 查找所有nickname包含"admin"的用户
result = manager.specific_value_list( result = manager.specific_value_list(
@@ -208,10 +199,7 @@ class PersonInfoManager:
try: try:
result = {} result = {}
for doc in db.person_info.find( for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}):
{field_name: {"$exists": True}},
{"person_id": 1, field_name: 1, "_id": 0}
):
try: try:
value = doc[field_name] value = doc[field_name]
if way(value): if way(value):
@@ -225,11 +213,11 @@ class PersonInfoManager:
        except Exception as e:
            logger.error(f"数据库查询失败: {str(e)}", exc_info=True)
            return {}
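`get_specific_value_list` turns the query into a generic filter: the `way` predicate decides, per value, whether the `{person_id: value}` pair is kept. A self-contained sketch of the same filtering over an in-memory list of documents (the sample data is invented):

```python
from typing import Any, Callable, Dict


def specific_value_list(docs, field_name: str, way: Callable[[Any], bool]) -> Dict[str, Any]:
    """In-memory version of get_specific_value_list: {person_id: value} for matching docs."""
    result = {}
    for doc in docs:
        # mirror the {"$exists": True} projection: skip docs without the field
        if field_name in doc and way(doc[field_name]):
            result[doc["person_id"]] = doc[field_name]
    return result


docs = [
    {"person_id": "p1", "nickname": "admin_alice"},
    {"person_id": "p2", "nickname": "bob"},
    {"person_id": "p3"},  # field missing, skipped
]
matches = specific_value_list(docs, "nickname", lambda v: "admin" in v)
```

The docstring's "nickname contains admin" example corresponds exactly to this lambda; `personal_habit_deduction` below uses the same mechanism with a list-length predicate.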
    async def personal_habit_deduction(self):
        """Start personal-info inference; runs once a day when its conditions are met"""
        try:
            while 1:
                await asyncio.sleep(60)
                current_time = datetime.datetime.now()
                logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
@@ -237,8 +225,7 @@ class PersonInfoManager:
# "msg_interval"推断 # "msg_interval"推断
msg_interval_map = False msg_interval_map = False
msg_interval_lists = await self.get_specific_value_list( msg_interval_lists = await self.get_specific_value_list(
"msg_interval_list", "msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
lambda x: isinstance(x, list) and len(x) >= 100
) )
for person_id, msg_interval_list_ in msg_interval_lists.items(): for person_id, msg_interval_list_ in msg_interval_lists.items():
try: try:
@@ -258,23 +245,23 @@ class PersonInfoManager:
                        log_dir.mkdir(parents=True, exist_ok=True)
                        plt.figure(figsize=(10, 6))
                        time_series = pd.Series(time_interval)
                        plt.hist(time_series, bins=50, density=True, alpha=0.4, color="pink", label="Histogram")
                        time_series.plot(kind="kde", color="mediumpurple", linewidth=1, label="Density")
                        plt.grid(True, alpha=0.2)
                        plt.xlim(0, 8000)
                        plt.title(f"Message Interval Distribution (User: {person_id[:8]}...)")
                        plt.xlabel("Interval (ms)")
                        plt.ylabel("Density")
                        plt.legend(framealpha=0.9, facecolor="white")
                        img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
                        plt.savefig(img_path)
                        plt.close()
                        # plotting done
                        q25, q75 = np.percentile(time_interval, [25, 75])
                        iqr = q75 - q25
                        filtered = [x for x in time_interval if (q25 - 1.5 * iqr) <= x <= (q75 + 1.5 * iqr)]
                        msg_interval = int(round(np.percentile(filtered, 80)))
                        await self.update_one_field(person_id, "msg_interval", msg_interval)
                        logger.debug(f"用户{person_id}的msg_interval已经被更新为{msg_interval}")


@@ -12,6 +12,7 @@ relationship_config = LogConfig(
)

logger = get_module_logger("rel_manager", config=relationship_config)
class RelationshipManager:
    def __init__(self):
        self.positive_feedback_value = 0  # positive-feedback system
@@ -22,6 +23,7 @@ class RelationshipManager:
    def mood_manager(self):
        if self._mood_manager is None:
            from ..moods.moods import MoodManager  # deferred import

            self._mood_manager = MoodManager.get_instance()
        return self._mood_manager
@@ -51,27 +53,27 @@ class RelationshipManager:
            self.positive_feedback_value -= 1
        elif self.positive_feedback_value > 0:
            self.positive_feedback_value = 0

        if abs(self.positive_feedback_value) > 1:
            logger.info(f"触发mood变更增益当前增益系数{self.gain_coefficient[abs(self.positive_feedback_value)]}")
    def mood_feedback(self, value):
        """Mood feedback"""
        mood_manager = self.mood_manager
        mood_gain = (mood_manager.get_current_mood().valence) ** 2 * math.copysign(
            1, value * mood_manager.get_current_mood().valence
        )
        value += value * mood_gain
        logger.info(f"当前relationship增益系数{mood_gain:.3f}")
        return value
    def feedback_to_mood(self, mood_value):
        """Feedback applied back to mood"""
        coefficient = self.gain_coefficient[abs(self.positive_feedback_value)]
        if mood_value > 0 and self.positive_feedback_value > 0 or mood_value < 0 and self.positive_feedback_value < 0:
            return mood_value * coefficient
        else:
            return mood_value / coefficient
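`mood_feedback` scales a relationship delta by the square of the current mood valence, with the sign chosen by `math.copysign` so that a mood pointing the same way as the delta amplifies it and an opposing mood damps it. A standalone sketch of just that formula (the valence values are invented):

```python
import math


def apply_mood_gain(value: float, valence: float) -> float:
    """Same formula as mood_feedback: gain magnitude is valence**2, sign follows value * valence."""
    mood_gain = valence ** 2 * math.copysign(1, value * valence)
    return value + value * mood_gain


boosted = apply_mood_gain(10.0, 0.5)   # aligned mood and delta → amplified to 12.5
damped = apply_mood_gain(10.0, -0.5)   # opposing mood → damped to 7.5
```

Squaring the valence keeps the gain in [0, 1] and makes weak moods nearly neutral, so only a strongly positive or negative mood noticeably bends the relationship update.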
    async def calculate_update_relationship_value(self, chat_stream: ChatStream, label: str, stance: str) -> None:
        """Compute and update the relationship value
@@ -88,7 +90,7 @@ class RelationshipManager:
"中立": 1, "中立": 1,
"反对": 2, "反对": 2,
} }
valuedict = { valuedict = {
"开心": 1.5, "开心": 1.5,
"愤怒": -2.0, "愤怒": -2.0,
@@ -103,10 +105,10 @@ class RelationshipManager:
        person_id = person_info_manager.get_person_id(chat_stream.user_info.platform, chat_stream.user_info.user_id)
        data = {
            "platform": chat_stream.user_info.platform,
            "user_id": chat_stream.user_info.user_id,
            "nickname": chat_stream.user_info.user_nickname,
            "konw_time": int(time.time()),
        }

        old_value = await person_info_manager.get_value(person_id, "relationship_value")
        old_value = self.ensure_float(old_value, person_id)
@@ -200,4 +202,5 @@ class RelationshipManager:
        logger.warning(f"[关系管理] {person_id}值转换失败(原始值:{value}已重置为0")
        return 0.0


relationship_manager = RelationshipManager()


@@ -14,7 +14,7 @@ from src.common.logger import get_module_logger, SCHEDULE_STYLE_CONFIG, LogConfi
from src.plugins.models.utils_model import LLM_request  # noqa: E402
from src.plugins.config.config import global_config  # noqa: E402

TIME_ZONE = tz.gettz(global_config.TIME_ZONE)  # set the time zone

schedule_config = LogConfig(
@@ -31,10 +31,16 @@ class ScheduleGenerator:
    def __init__(self):
        # use the offline LLM models
        self.llm_scheduler_all = LLM_request(
            model=global_config.llm_reasoning,
            temperature=global_config.SCHEDULE_TEMPERATURE,
            max_tokens=7000,
            request_type="schedule",
        )
        self.llm_scheduler_doing = LLM_request(
            model=global_config.llm_normal,
            temperature=global_config.SCHEDULE_TEMPERATURE,
            max_tokens=2048,
            request_type="schedule",
        )

        self.today_schedule_text = ""


@@ -2,7 +2,7 @@ import threading
import time
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, List

from src.common.logger import get_module_logger

from ...common.database import db
@@ -22,6 +22,7 @@ class LLMStatistics:
        self.stats_thread = None
        self.console_thread = None
        self._init_database()
        self.name_dict: Dict[List] = {}
    def _init_database(self):
        """Initialize the database collections"""
@@ -137,16 +138,24 @@ class LLMStatistics:
            # user_id = str(doc.get("user_info", {}).get("user_id", "unknown"))
            chat_info = doc.get("chat_info", {})
            user_info = doc.get("user_info", {})
            message_time = doc.get("time", 0)

            group_info = chat_info.get("group_info") if chat_info else {}
            # print(f"group_info: {group_info}")
            group_name = None
            if group_info:
                group_id = f"g{group_info.get('group_id')}"
                group_name = group_info.get("group_name", f"{group_info.get('group_id')}")
            if user_info and not group_name:
                group_id = f"u{user_info['user_id']}"
                group_name = user_info["user_nickname"]

            if self.name_dict.get(group_id):
                if message_time > self.name_dict.get(group_id)[1]:
                    self.name_dict[group_id] = [group_name, message_time]
            else:
                self.name_dict[group_id] = [group_name, message_time]
            # print(f"group_name: {group_name}")

            stats["messages_by_user"][user_id] += 1
            stats["messages_by_chat"][group_id] += 1

        return stats
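The new `name_dict` bookkeeping keeps, per chat id, the display name attached to the newest message timestamp, so a renamed group is reported under its latest name. The update rule reduced to a pure function:

```python
def remember_latest_name(name_dict: dict, group_id: str, group_name: str, message_time: float) -> None:
    """Store [name, time] per id; overwrite only when the new message is more recent (mirrors the loop above)."""
    entry = name_dict.get(group_id)
    if entry:
        if message_time > entry[1]:
            name_dict[group_id] = [group_name, message_time]
    else:
        name_dict[group_id] = [group_name, message_time]


names = {}
remember_latest_name(names, "g123", "Old Name", 100)
remember_latest_name(names, "g123", "New Name", 200)
remember_latest_name(names, "g123", "Stale Name", 150)  # older than 200, ignored
```

Counting by `group_id` while displaying via `name_dict` also means two groups that happen to share a name no longer have their message counts merged.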
@@ -187,7 +196,7 @@ class LLMStatistics:
            tokens = stats["tokens_by_model"][model_name]
            cost = stats["costs_by_model"][model_name]
            output.append(
                data_fmt.format(model_name[:30] + ".." if len(model_name) > 32 else model_name, count, tokens, cost)
            )

        output.append("")
@@ -221,8 +230,8 @@ class LLMStatistics:
        # add chat statistics
        output.append("群组统计:")
        output.append(("群组名称 消息数量"))
        for group_id, count in sorted(stats["messages_by_chat"].items()):
            output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}")

        return "\n".join(output)
@@ -250,7 +259,7 @@ class LLMStatistics:
            tokens = stats["tokens_by_model"][model_name]
            cost = stats["costs_by_model"][model_name]
            output.append(
                data_fmt.format(model_name[:30] + ".." if len(model_name) > 32 else model_name, count, tokens, cost)
            )

        output.append("")
@@ -284,8 +293,8 @@ class LLMStatistics:
        # add chat statistics
        output.append("群组统计:")
        output.append(("群组名称 消息数量"))
        for group_id, count in sorted(stats["messages_by_chat"].items()):
            output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}")

        return "\n".join(output)


@@ -53,18 +53,18 @@ class KnowledgeLibrary:
        # split the content on blank lines
        paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]

        chunks = []
        for para in paragraphs:
            para_length = len(para)

            # if the paragraph fits within the maximum length, add it as-is
            if para_length <= max_length:
                chunks.append(para)
            else:
                # otherwise slice the paragraph into max_length-sized pieces
                for i in range(0, para_length, max_length):
                    chunks.append(para[i : i + max_length])

        return chunks

    def get_embedding(self, text: str) -> list:
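The chunker above splits the knowledge text on blank lines and then hard-slices any paragraph longer than `max_length`, so every chunk handed to the embedding model fits the limit. The same logic as a standalone function:

```python
def split_content(content: str, max_length: int) -> list:
    """Blank-line split, then slice oversized paragraphs (same steps as the hunk above)."""
    paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
    chunks = []
    for para in paragraphs:
        if len(para) <= max_length:
            chunks.append(para)
        else:
            # hard slicing may cut mid-sentence; acceptable for embedding input
            for i in range(0, len(para), max_length):
                chunks.append(para[i : i + max_length])
    return chunks


chunks = split_content("short para\n\n" + "x" * 25, max_length=10)
```

A 25-character paragraph with `max_length=10` yields two full slices and one 5-character remainder, so the call above returns four chunks in total.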

File diff suppressed because it is too large