diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 000000000..02fe9f821
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,20 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Commands
+- **Run Bot**: `python bot.py`
+- **Lint**: `ruff check --fix .` or `ruff format .`
+- **Run Tests**: `python -m unittest discover -v`
+- **Run Single Test**: `python -m unittest src/plugins/message/test.py`
+
+## Code Style
+- **Formatting**: Line length 120 chars, use double quotes for strings
+- **Imports**: Group standard library, external packages, then internal imports
+- **Naming**: snake_case for functions/variables, PascalCase for classes
+- **Error Handling**: Use try/except blocks with specific exceptions
+- **Types**: Use type hints where possible
+- **Docstrings**: Document classes and complex functions
+- **Linting**: Follow ruff rules (E, F, B) with ignores E711, E501
+
+When making changes, run `ruff check --fix .` to ensure code follows style guidelines. The codebase uses Ruff for linting and formatting.
\ No newline at end of file
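
As a quick illustration of the conventions CLAUDE.md describes (grouped imports, type hints, double-quoted strings, snake_case/PascalCase naming, docstrings, and try/except with a specific exception), here is a minimal hypothetical snippet; it is not code from this repository.

```python
# Hypothetical style example only, not code from this repository.
import json
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


class ConfigLoader:
    """Load a JSON config file, falling back to an empty dict on failure."""

    def load(self, path: Path) -> dict:
        try:
            return json.loads(path.read_text(encoding="utf-8"))
        except FileNotFoundError:
            logger.warning("config file not found: %s", path)
            return {}
```
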
diff --git a/README.md b/README.md
index fa97fec14..67aacb8e0 100644
--- a/README.md
+++ b/README.md
@@ -1,24 +1,66 @@
# 麦麦!MaiCore-MaiMBot (编辑中)
+
+
+ 画师:略nd
+
+
+
MaiBot(麦麦)
+
+ 一款专注于 群组聊天 的赛博网友
+
+ 探索本项目的文档 »
+
+
+
+ 报告Bug
+ ·
+ 提出新特性
+
+
+
## 新版0.6.0部署前先阅读:https://docs.mai-mai.org/manual/usage/mmc_q_a
-
-
-
-
-
-
-
## 📝 项目简介
**🍔MaiCore是一个基于大语言模型的可交互智能体**
-- LLM 提供对话能力
-- 动态Prompt构建器
-- 实时的思维系统
-- MongoDB 提供数据持久化支持
-- 可扩展,可支持多种平台和多种功能
+
+- 💭 **智能对话系统**:基于LLM的自然语言交互
+- 🤔 **实时思维系统**:模拟人类思考过程
+- 💝 **情感表达系统**:丰富的表情包和情绪表达
+- 🧠 **持久记忆系统**:基于MongoDB的长期记忆存储
+- 🔄 **动态人格系统**:自适应的性格特征
+
+
+
+
+### 📢 版本信息
**最新版本: v0.6.0** ([查看更新日志](changelogs/changelog.md))
> [!WARNING]
@@ -28,19 +70,12 @@
> 此版本MaiBot将基于MaiCore运行,不再依赖于nonebot相关组件运行。
> MaiBot将通过nonebot的插件与nonebot建立联系,然后nonebot与QQ建立联系,实现MaiBot与QQ的交互
-**分支介绍:**
-- main 稳定版本
-- dev 开发版(不知道什么意思就别下)
-- classical 0.6.0以前的版本
+**分支说明:**
+- `main`: 稳定发布版本
+- `dev`: 开发测试版本(不知道什么意思就别下)
+- `classical`: 0.6.0之前的版本
-
> [!WARNING]
> - 项目处于活跃开发阶段,代码可能随时更改
@@ -49,6 +84,12 @@
> - 由于持续迭代,可能存在一些已知或未知的bug
> - 由于开发中,可能消耗较多token
+### ⚠️ 重要提示
+
+- 升级到v0.6.0版本前请务必阅读:[升级指南](https://docs.mai-mai.org/manual/usage/mmc_q_a)
+- 本版本基于MaiCore重构,通过nonebot插件与QQ平台交互
+- 项目处于活跃开发阶段,功能和API可能随时调整
+
### 💬交流群(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
- [五群](https://qm.qq.com/q/JxvHZnxyec) 1022489779
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) 766798517 【已满】
@@ -72,55 +113,35 @@
## 🎯 功能介绍
-### 💬 聊天功能
-- 提供思维流(心流)聊天和推理聊天两种对话逻辑
-- 支持关键词检索主动发言:对消息的话题topic进行识别,如果检测到麦麦存储过的话题就会主动进行发言
-- 支持bot名字呼唤发言:检测到"麦麦"会主动发言,可配置
-- 支持多模型,多厂商自定义配置
-- 动态的prompt构建器,更拟人
-- 支持图片,转发消息,回复消息的识别
-- 支持私聊功能,可使用PFC模式的有目的多轮对话(实验性)
+| 模块 | 主要功能 | 特点 |
+|------|---------|------|
+| 💬 聊天系统 | • 思维流/推理聊天 • 关键词主动发言 • 多模型支持 • 动态prompt构建 • 私聊功能(PFC) | 拟人化交互 |
+| 🧠 思维流系统 | • 实时思考生成 • 自动启停机制 • 日程系统联动 | 智能化决策 |
+| 🧠 记忆系统 2.0 | • 优化记忆抽取 • 海马体记忆机制 • 聊天记录概括 | 持久化记忆 |
+| 😊 表情包系统 | • 情绪匹配发送 • GIF支持 • 自动收集与审查 | 丰富表达 |
+| 📅 日程系统 | • 动态日程生成 • 自定义想象力 • 思维流联动 | 智能规划 |
+| 👥 关系系统 2.0 | • 关系管理优化 • 丰富接口支持 • 个性化交互 | 深度社交 |
+| 📊 统计系统 | • 使用数据统计 • LLM调用记录 • 实时控制台显示 | 数据可视 |
+| 🔧 系统功能 | • 优雅关闭机制 • 自动数据保存 • 异常处理完善 | 稳定可靠 |
-### 🧠 思维流系统
-- 思维流能够在回复前后进行思考,生成实时想法
-- 思维流自动启停机制,提升资源利用效率
-- 思维流与日程系统联动,实现动态日程生成
+## 📐 项目架构
-### 🧠 记忆系统 2.0
-- 优化记忆抽取策略和prompt结构
-- 改进海马体记忆提取机制,提升自然度
-- 对聊天记录进行概括存储,在需要时调用
+```mermaid
+graph TD
+ A[MaiCore] --> B[对话系统]
+ A --> C[思维流系统]
+ A --> D[记忆系统]
+ A --> E[情感系统]
+ B --> F[多模型支持]
+ B --> G[动态Prompt]
+ C --> H[实时思考]
+ C --> I[日程联动]
+ D --> J[记忆存储]
+ D --> K[记忆检索]
+ E --> L[表情管理]
+ E --> M[情绪识别]
+```
-### 😊 表情包系统
-- 支持根据发言内容发送对应情绪的表情包
-- 支持识别和处理gif表情包
-- 会自动偷群友的表情包
-- 表情包审查功能
-- 表情包文件完整性自动检查
-- 自动清理缓存图片
-
-### 📅 日程系统
-- 动态更新的日程生成
-- 可自定义想象力程度
-- 与聊天情况交互(思维流模式下)
-
-### 👥 关系系统 2.0
-- 优化关系管理系统,适用于新版本
-- 提供更丰富的关系接口
-- 针对每个用户创建"关系",实现个性化回复
-
-### 📊 统计系统
-- 详细的使用数据统计
-- LLM调用统计
-- 在控制台显示统计信息
-
-### 🔧 系统功能
-- 支持优雅的shutdown机制
-- 自动保存功能,定期保存聊天记录和关系数据
-- 完善的异常处理机制
-- 可自定义时区设置
-- 优化的日志输出格式
-- 配置自动更新功能
## 开发计划TODO:LIST
diff --git a/depends-data/maimai.png b/depends-data/maimai.png
new file mode 100644
index 000000000..faccb856b
Binary files /dev/null and b/depends-data/maimai.png differ
diff --git a/depends-data/video.png b/depends-data/video.png
new file mode 100644
index 000000000..84176b2d9
Binary files /dev/null and b/depends-data/video.png differ
diff --git a/src/gui/logger_gui.py b/src/gui/logger_gui.py
index 9488446c4..ad6edafb8 100644
--- a/src/gui/logger_gui.py
+++ b/src/gui/logger_gui.py
@@ -24,10 +24,10 @@
# # 标记GUI是否运行中
# self.is_running = True
-
+
# # 程序关闭时的清理操作
# self.protocol("WM_DELETE_WINDOW", self._on_closing)
-
+
# # 初始化进程、日志队列、日志数据等变量
# self.process = None
# self.log_queue = queue.Queue()
@@ -236,7 +236,7 @@
# while not self.log_queue.empty():
# line = self.log_queue.get()
# self.process_log_line(line)
-
+
# # 仅在GUI仍在运行时继续处理队列
# if self.is_running:
# self.after(100, self.process_log_queue)
@@ -245,11 +245,11 @@
# """解析单行日志并更新日志数据和筛选器"""
# match = re.match(
# r"""^
-# (?:(?P\d{2}:\d{2}(?::\d{2})?)\s*\|\s*)?
-# (?P\w+)\s*\|\s*
-# (?P.*?)
-# \s*[-|]\s*
-# (?P.*)
+# (?:(?P\d{2}:\d{2}(?::\d{2})?)\s*\|\s*)?
+# (?P\w+)\s*\|\s*
+# (?P.*?)
+# \s*[-|]\s*
+# (?P.*)
# $""",
# line.strip(),
# re.VERBOSE,
@@ -354,10 +354,10 @@
# """处理窗口关闭事件,安全清理资源"""
# # 标记GUI已关闭
# self.is_running = False
-
+
# # 停止日志进程
# self.stop_process()
-
+
# # 安全清理tkinter变量
# for attr_name in list(self.__dict__.keys()):
# if isinstance(getattr(self, attr_name), (ctk.Variable, ctk.StringVar, ctk.IntVar, ctk.DoubleVar, ctk.BooleanVar)):
@@ -367,7 +367,7 @@
# except Exception:
# pass
# setattr(self, attr_name, None)
-
+
# self.quit()
# sys.exit(0)
diff --git a/src/gui/reasoning_gui.py b/src/gui/reasoning_gui.py
index ffb270fd8..7fa99df2e 100644
--- a/src/gui/reasoning_gui.py
+++ b/src/gui/reasoning_gui.py
@@ -127,7 +127,7 @@
# """处理窗口关闭事件"""
# # 标记GUI已关闭,防止后台线程继续访问tkinter对象
# self.is_running = False
-
+
# # 安全清理所有可能的tkinter变量
# for attr_name in list(self.__dict__.keys()):
# if isinstance(getattr(self, attr_name), (ctk.Variable, ctk.StringVar, ctk.IntVar, ctk.DoubleVar, ctk.BooleanVar)):
@@ -138,7 +138,7 @@
# except Exception:
# pass
# setattr(self, attr_name, None)
-
+
# # 退出
# self.root.quit()
# sys.exit(0)
@@ -259,7 +259,7 @@
# while True:
# if not self.is_running:
# break # 如果GUI已关闭,停止线程
-
+
# try:
# # 从数据库获取最新数据,只获取启动时间之后的记录
# query = {"time": {"$gt": self.start_timestamp}}
diff --git a/src/heart_flow/heartflow.py b/src/heart_flow/heartflow.py
index d5542587c..9cf8d4674 100644
--- a/src/heart_flow/heartflow.py
+++ b/src/heart_flow/heartflow.py
@@ -42,7 +42,6 @@ class Heartflow:
self._subheartflows = {}
self.active_subheartflows_nums = 0
-
async def _cleanup_inactive_subheartflows(self):
"""定期清理不活跃的子心流"""
while True:
@@ -84,25 +83,22 @@ class Heartflow:
# 开始构建prompt
prompt_personality = "你"
- #person
+ # person
individuality = Individuality.get_instance()
-
+
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
-
+
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
-
+
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
-
-
-
+
personality_info = prompt_personality
-
-
+
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
related_memory_info = "memory"
@@ -146,22 +142,20 @@ class Heartflow:
async def minds_summary(self, minds_str):
# 开始构建prompt
prompt_personality = "你"
- #person
+ # person
individuality = Individuality.get_instance()
-
+
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
-
+
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
-
+
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
-
-
-
+
personality_info = prompt_personality
mood_info = self.current_state.mood
@@ -183,7 +177,7 @@ class Heartflow:
添加一个SubHeartflow实例到self._subheartflows字典中
并根据subheartflow_id为子心流创建一个观察对象
"""
-
+
try:
if subheartflow_id not in self._subheartflows:
logger.debug(f"创建 subheartflow: {subheartflow_id}")
diff --git a/src/heart_flow/observation.py b/src/heart_flow/observation.py
index f4a082a4e..5befd7322 100644
--- a/src/heart_flow/observation.py
+++ b/src/heart_flow/observation.py
@@ -7,6 +7,7 @@ from src.common.database import db
from src.individuality.individuality import Individuality
import random
+
# 所有观察的基类
class Observation:
def __init__(self, observe_type, observe_id):
@@ -24,7 +25,7 @@ class ChattingObservation(Observation):
self.talking_message = []
self.talking_message_str = ""
-
+
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
@@ -57,7 +58,7 @@ class ChattingObservation(Observation):
for msg in new_messages:
if "detailed_plain_text" in msg:
new_messages_str += f"{msg['detailed_plain_text']}"
-
+
# print(f"new_messages_str:{new_messages_str}")
# 将新消息添加到talking_message,同时保持列表长度不超过20条
@@ -117,26 +118,22 @@ class ChattingObservation(Observation):
# print(f"更新聊天总结:{self.talking_summary}")
# 开始构建prompt
prompt_personality = "你"
- #person
+ # person
individuality = Individuality.get_instance()
-
+
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
-
+
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
-
+
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
-
-
-
+
personality_info = prompt_personality
-
-
-
+
prompt = ""
prompt += f"{personality_info},请注意识别你自己的聊天发言"
prompt += f"你的名字叫:{self.name},你的昵称是:{self.nick_name}\n"
@@ -148,7 +145,6 @@ class ChattingObservation(Observation):
self.observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt)
print(f"prompt:{prompt}")
print(f"self.observe_info:{self.observe_info}")
-
def translate_message_list_to_str(self):
self.talking_message_str = ""
diff --git a/src/heart_flow/sub_heartflow.py b/src/heart_flow/sub_heartflow.py
index 583a7d561..a2ba023e2 100644
--- a/src/heart_flow/sub_heartflow.py
+++ b/src/heart_flow/sub_heartflow.py
@@ -53,11 +53,10 @@ class SubHeartflow:
if not self.current_mind:
self.current_mind = "你什么也没想"
-
self.is_active = False
self.observations: list[Observation] = []
-
+
self.running_knowledges = []
def add_observation(self, observation: Observation):
@@ -86,7 +85,9 @@ class SubHeartflow:
async def subheartflow_start_working(self):
while True:
current_time = time.time()
- if current_time - self.last_reply_time > global_config.sub_heart_flow_freeze_time: # 120秒无回复/不在场,冻结
+ if (
+ current_time - self.last_reply_time > global_config.sub_heart_flow_freeze_time
+ ): # 120秒无回复/不在场,冻结
self.is_active = False
await asyncio.sleep(global_config.sub_heart_flow_update_interval) # 每60秒检查一次
else:
@@ -100,7 +101,9 @@ class SubHeartflow:
await asyncio.sleep(global_config.sub_heart_flow_update_interval)
# 检查是否超过10分钟没有激活
- if current_time - self.last_active_time > global_config.sub_heart_flow_stop_time: # 5分钟无回复/不在场,销毁
+ if (
+ current_time - self.last_active_time > global_config.sub_heart_flow_stop_time
+ ): # 5分钟无回复/不在场,销毁
logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活,正在销毁...")
break # 退出循环,销毁自己
@@ -147,11 +150,11 @@ class SubHeartflow:
# self.current_mind = reponse
# logger.debug(f"prompt:\n{prompt}\n")
# logger.info(f"麦麦的脑内状态:{self.current_mind}")
-
+
async def do_observe(self):
observation = self.observations[0]
await observation.observe()
-
+
async def do_thinking_before_reply(self, message_txt):
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
@@ -162,23 +165,20 @@ class SubHeartflow:
# 开始构建prompt
prompt_personality = "你"
- #person
+ # person
individuality = Individuality.get_instance()
-
+
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
-
+
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
-
+
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
-
-
-
# 调取记忆
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
@@ -191,7 +191,7 @@ class SubHeartflow:
else:
related_memory_info = ""
- related_info,grouped_results = await self.get_prompt_info(chat_observe_info + message_txt, 0.4)
+ related_info, grouped_results = await self.get_prompt_info(chat_observe_info + message_txt, 0.4)
# print(related_info)
for _topic, results in grouped_results.items():
for result in results:
@@ -227,25 +227,23 @@ class SubHeartflow:
async def do_thinking_after_reply(self, reply_content, chat_talking_prompt):
# print("麦麦回复之后脑袋转起来了")
-
+
# 开始构建prompt
prompt_personality = "你"
- #person
+ # person
individuality = Individuality.get_instance()
-
+
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
-
+
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
-
+
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
-
-
-
+
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
@@ -279,22 +277,20 @@ class SubHeartflow:
async def judge_willing(self):
# 开始构建prompt
prompt_personality = "你"
- #person
+ # person
individuality = Individuality.get_instance()
-
+
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
-
+
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
-
+
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
-
-
-
+
# print("麦麦闹情绪了1")
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
@@ -320,13 +316,12 @@ class SubHeartflow:
def update_current_mind(self, reponse):
self.past_mind.append(self.current_mind)
self.current_mind = reponse
-
-
+
async def get_prompt_info(self, message: str, threshold: float):
start_time = time.time()
related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
-
+
# 1. 先从LLM获取主题,类似于记忆系统的做法
topics = []
# try:
@@ -334,7 +329,7 @@ class SubHeartflow:
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
-
+
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
@@ -345,7 +340,7 @@ class SubHeartflow:
# for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
-
+
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
@@ -353,7 +348,7 @@ class SubHeartflow:
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
-
+
# 如果无法提取到主题,直接使用整个消息
if not topics:
logger.debug("未能提取到任何主题,使用整个消息进行查询")
@@ -361,26 +356,26 @@ class SubHeartflow:
if not embedding:
logger.error("获取消息嵌入向量失败")
return ""
-
+
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒")
return related_info, {}
-
+
# 2. 对每个主题进行知识库查询
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
-
+
# 优化:批量获取嵌入向量,减少API调用
embeddings = {}
topics_batch = [topic for topic in topics if len(topic) > 0]
if message: # 确保消息非空
topics_batch.append(message)
-
+
# 批量获取嵌入向量
embed_start_time = time.time()
for text in topics_batch:
if not text or len(text.strip()) == 0:
continue
-
+
try:
embedding = await get_embedding(text, request_type="info_retrieval")
if embedding:
@@ -389,17 +384,17 @@ class SubHeartflow:
logger.warning(f"获取'{text}'的嵌入向量失败")
except Exception as e:
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
-
+
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒")
-
+
if not embeddings:
logger.error("所有嵌入向量获取失败")
return ""
-
+
# 3. 对每个主题进行知识库查询
all_results = []
query_start_time = time.time()
-
+
# 首先添加原始消息的查询结果
if message in embeddings:
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
@@ -408,12 +403,12 @@ class SubHeartflow:
result["topic"] = "原始消息"
all_results.extend(original_results)
logger.info(f"原始消息查询到{len(original_results)}条结果")
-
+
# 然后添加每个主题的查询结果
for topic in topics:
if not topic or topic not in embeddings:
continue
-
+
try:
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
if topic_results:
@@ -424,9 +419,9 @@ class SubHeartflow:
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
except Exception as e:
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
-
+
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
-
+
# 4. 去重和过滤
process_start_time = time.time()
unique_contents = set()
@@ -436,14 +431,16 @@ class SubHeartflow:
if content not in unique_contents:
unique_contents.add(content)
filtered_results.append(result)
-
+
# 5. 按相似度排序
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
-
+
# 6. 限制总数量(最多10条)
filtered_results = filtered_results[:10]
- logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果")
-
+ logger.info(
+ f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
+ )
+
# 7. 格式化输出
if filtered_results:
format_start_time = time.time()
@@ -453,7 +450,7 @@ class SubHeartflow:
if topic not in grouped_results:
grouped_results[topic] = []
grouped_results[topic].append(result)
-
+
# 按主题组织输出
for topic, results in grouped_results.items():
related_info += f"【主题: {topic}】\n"
@@ -464,13 +461,15 @@ class SubHeartflow:
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n"
related_info += "\n"
-
- logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
-
- logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
- return related_info,grouped_results
- def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
+ logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
+
+ logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
+ return related_info, grouped_results
+
+ def get_info_from_db(
+ self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
+ ) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []
# 使用余弦相似度计算
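
The hunk above ends at the「使用余弦相似度计算」comment inside `get_info_from_db`. Below is a minimal, self-contained sketch of that cosine-similarity filtering step, under the assumption that each stored knowledge entry carries an `embedding` field; it is illustrative, not the repository's exact implementation.

```python
# Sketch of cosine-similarity scoring for knowledge retrieval.
# Assumed schema: each entry has "content" and "embedding" fields (assumption, not from the repo).
import math
from typing import Any, Dict, List


def cosine_similarity(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def rank_by_similarity(
    entries: List[Dict[str, Any]], query_embedding: List[float], threshold: float, limit: int
) -> List[Dict[str, Any]]:
    # Keep only entries similar enough to the query, best matches first.
    scored = []
    for entry in entries:
        similarity = cosine_similarity(query_embedding, entry["embedding"])
        if similarity >= threshold:
            scored.append({**entry, "similarity": similarity})
    scored.sort(key=lambda e: e["similarity"], reverse=True)
    return scored[:limit]
```
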
diff --git a/src/individuality/identity.py b/src/individuality/identity.py
index 6704562ec..89cba98d1 100644
--- a/src/individuality/identity.py
+++ b/src/individuality/identity.py
@@ -2,27 +2,36 @@ from dataclasses import dataclass
from typing import List
import random
+
@dataclass
class Identity:
"""身份特征类"""
+
identity_detail: List[str] # 身份细节描述
height: int # 身高(厘米)
weight: int # 体重(千克)
age: int # 年龄
gender: str # 性别
appearance: str # 外貌特征
-
+
_instance = None
-
+
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
-
- def __init__(self, identity_detail: List[str] = None, height: int = 0, weight: int = 0,
- age: int = 0, gender: str = "", appearance: str = ""):
+
+ def __init__(
+ self,
+ identity_detail: List[str] = None,
+ height: int = 0,
+ weight: int = 0,
+ age: int = 0,
+ gender: str = "",
+ appearance: str = "",
+ ):
"""初始化身份特征
-
+
Args:
identity_detail: 身份细节描述列表
height: 身高(厘米)
@@ -39,23 +48,24 @@ class Identity:
self.age = age
self.gender = gender
self.appearance = appearance
-
+
@classmethod
- def get_instance(cls) -> 'Identity':
+ def get_instance(cls) -> "Identity":
"""获取Identity单例实例
-
+
Returns:
Identity: 单例实例
"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
-
+
@classmethod
- def initialize(cls, identity_detail: List[str], height: int, weight: int,
- age: int, gender: str, appearance: str) -> 'Identity':
+ def initialize(
+ cls, identity_detail: List[str], height: int, weight: int, age: int, gender: str, appearance: str
+ ) -> "Identity":
"""初始化身份特征
-
+
Args:
identity_detail: 身份细节描述列表
height: 身高(厘米)
@@ -63,7 +73,7 @@ class Identity:
age: 年龄
gender: 性别
appearance: 外貌特征
-
+
Returns:
Identity: 初始化后的身份特征实例
"""
@@ -75,8 +85,8 @@ class Identity:
instance.gender = gender
instance.appearance = appearance
return instance
-
- def get_prompt(self,x_person,level):
+
+ def get_prompt(self, x_person, level):
"""
获取身份特征的prompt
"""
@@ -86,7 +96,7 @@ class Identity:
prompt_identity = "我"
else:
prompt_identity = "他"
-
+
if level == 1:
identity_detail = self.identity_detail
random.shuffle(identity_detail)
@@ -96,7 +106,7 @@ class Identity:
prompt_identity += f",{detail}"
prompt_identity += "。"
return prompt_identity
-
+
def to_dict(self) -> dict:
"""将身份特征转换为字典格式"""
return {
@@ -105,13 +115,13 @@ class Identity:
"weight": self.weight,
"age": self.age,
"gender": self.gender,
- "appearance": self.appearance
+ "appearance": self.appearance,
}
-
+
@classmethod
- def from_dict(cls, data: dict) -> 'Identity':
+ def from_dict(cls, data: dict) -> "Identity":
"""从字典创建身份特征实例"""
instance = cls.get_instance()
for key, value in data.items():
setattr(instance, key, value)
- return instance
\ No newline at end of file
+ return instance
diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py
index b491ed308..e7616ec27 100644
--- a/src/individuality/individuality.py
+++ b/src/individuality/individuality.py
@@ -2,35 +2,46 @@ from typing import Optional
from .personality import Personality
from .identity import Identity
+
class Individuality:
"""个体特征管理类"""
+
_instance = None
-
+
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
-
+
def __init__(self):
self.personality: Optional[Personality] = None
self.identity: Optional[Identity] = None
-
+
@classmethod
- def get_instance(cls) -> 'Individuality':
+ def get_instance(cls) -> "Individuality":
"""获取Individuality单例实例
-
+
Returns:
Individuality: 单例实例
"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
-
- def initialize(self, bot_nickname: str, personality_core: str, personality_sides: list,
- identity_detail: list, height: int, weight: int, age: int,
- gender: str, appearance: str) -> None:
+
+ def initialize(
+ self,
+ bot_nickname: str,
+ personality_core: str,
+ personality_sides: list,
+ identity_detail: list,
+ height: int,
+ weight: int,
+ age: int,
+ gender: str,
+ appearance: str,
+ ) -> None:
"""初始化个体特征
-
+
Args:
bot_nickname: 机器人昵称
personality_core: 人格核心特点
@@ -44,50 +55,43 @@ class Individuality:
"""
# 初始化人格
self.personality = Personality.initialize(
- bot_nickname=bot_nickname,
- personality_core=personality_core,
- personality_sides=personality_sides
+ bot_nickname=bot_nickname, personality_core=personality_core, personality_sides=personality_sides
)
-
+
# 初始化身份
self.identity = Identity.initialize(
- identity_detail=identity_detail,
- height=height,
- weight=weight,
- age=age,
- gender=gender,
- appearance=appearance
+ identity_detail=identity_detail, height=height, weight=weight, age=age, gender=gender, appearance=appearance
)
-
+
def to_dict(self) -> dict:
"""将个体特征转换为字典格式"""
return {
"personality": self.personality.to_dict() if self.personality else None,
- "identity": self.identity.to_dict() if self.identity else None
+ "identity": self.identity.to_dict() if self.identity else None,
}
-
+
@classmethod
- def from_dict(cls, data: dict) -> 'Individuality':
+ def from_dict(cls, data: dict) -> "Individuality":
"""从字典创建个体特征实例"""
instance = cls.get_instance()
if data.get("personality"):
instance.personality = Personality.from_dict(data["personality"])
if data.get("identity"):
instance.identity = Identity.from_dict(data["identity"])
- return instance
-
- def get_prompt(self,type,x_person,level):
+ return instance
+
+ def get_prompt(self, type, x_person, level):
"""
获取个体特征的prompt
"""
if type == "personality":
- return self.personality.get_prompt(x_person,level)
+ return self.personality.get_prompt(x_person, level)
elif type == "identity":
- return self.identity.get_prompt(x_person,level)
+ return self.identity.get_prompt(x_person, level)
else:
return ""
-
- def get_traits(self,factor):
+
+ def get_traits(self, factor):
"""
获取个体特征的特质
"""
@@ -101,5 +105,3 @@ class Individuality:
return self.personality.agreeableness
elif factor == "neuroticism":
return self.personality.neuroticism
-
-
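
For orientation, a hypothetical usage sketch of the `Individuality` singleton, pieced together from call sites visible elsewhere in this diff (for example `GoalAnalyzer` in `pfc.py`); all argument values below are placeholders, not real config values.

```python
# Hypothetical usage sketch; argument values are placeholders.
from src.individuality.individuality import Individuality

individuality = Individuality.get_instance()
individuality.initialize(
    bot_nickname="麦麦",
    personality_core="活泼开朗",
    personality_sides=["喜欢开玩笑"],
    identity_detail=["大学生"],
    height=160,
    weight=45,
    age=18,
    gender="女",
    appearance="黑发",
)

# Second-person ("你...") personality description at detail level 2,
# mirroring the call used by GoalAnalyzer in pfc.py.
personality_prompt = individuality.get_prompt(type="personality", x_person=2, level=2)
```
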
diff --git a/src/individuality/per_bf_gen.py b/src/individuality/per_bf_gen.py
index 0a8b2e4a7..d898ea5e3 100644
--- a/src/individuality/per_bf_gen.py
+++ b/src/individuality/per_bf_gen.py
@@ -17,9 +17,9 @@ with open(config_path, "r", encoding="utf-8") as f:
config = toml.load(f)
# 现在可以导入src模块
-from src.individuality.scene import get_scene_by_factor, PERSONALITY_SCENES #noqa E402
-from src.individuality.questionnaire import FACTOR_DESCRIPTIONS #noqa E402
-from src.individuality.offline_llm import LLM_request_off #noqa E402
+from src.individuality.scene import get_scene_by_factor, PERSONALITY_SCENES # noqa E402
+from src.individuality.questionnaire import FACTOR_DESCRIPTIONS # noqa E402
+from src.individuality.offline_llm import LLM_request_off # noqa E402
# 加载环境变量
env_path = os.path.join(root_path, ".env")
@@ -32,13 +32,12 @@ else:
def adapt_scene(scene: str) -> str:
-
- personality_core = config['personality']['personality_core']
- personality_sides = config['personality']['personality_sides']
+ personality_core = config["personality"]["personality_core"]
+ personality_sides = config["personality"]["personality_sides"]
personality_side = random.choice(personality_sides)
- identity_details = config['identity']['identity_detail']
+ identity_details = config["identity"]["identity_detail"]
identity_detail = random.choice(identity_details)
-
+
"""
根据config中的属性,改编场景使其更适合当前角色
@@ -51,10 +50,10 @@ def adapt_scene(scene: str) -> str:
try:
prompt = f"""
这是一个参与人格测评的角色形象:
-- 昵称: {config['bot']['nickname']}
-- 性别: {config['identity']['gender']}
-- 年龄: {config['identity']['age']}岁
-- 外貌: {config['identity']['appearance']}
+- 昵称: {config["bot"]["nickname"]}
+- 性别: {config["identity"]["gender"]}
+- 年龄: {config["identity"]["age"]}岁
+- 外貌: {config["identity"]["appearance"]}
- 性格核心: {personality_core}
- 性格侧面: {personality_side}
- 身份细节: {identity_detail}
@@ -62,18 +61,18 @@ def adapt_scene(scene: str) -> str:
请根据上述形象,改编以下场景,在测评中,用户将根据该场景给出上述角色形象的反应:
{scene}
保持场景的本质不变,但最好贴近生活且具体,并且让它更适合这个角色。
-改编后的场景应该自然、连贯,并考虑角色的年龄、身份和性格特点。只返回改编后的场景描述,不要包含其他说明。注意{config['bot']['nickname']}是面对这个场景的人,而不是场景的其他人。场景中不会有其描述,
+改编后的场景应该自然、连贯,并考虑角色的年龄、身份和性格特点。只返回改编后的场景描述,不要包含其他说明。注意{config["bot"]["nickname"]}是面对这个场景的人,而不是场景的其他人。场景中不会有其描述,
现在,请你给出改编后的场景描述
"""
- llm = LLM_request_off(model_name=config['model']['llm_normal']['name'])
+ llm = LLM_request_off(model_name=config["model"]["llm_normal"]["name"])
adapted_scene, _ = llm.generate_response(prompt)
-
+
# 检查返回的场景是否为空或错误信息
if not adapted_scene or "错误" in adapted_scene or "失败" in adapted_scene:
print("场景改编失败,将使用原始场景")
return scene
-
+
return adapted_scene
except Exception as e:
print(f"场景改编过程出错:{str(e)},将使用原始场景")
@@ -169,7 +168,7 @@ class PersonalityEvaluator_direct:
except Exception as e:
print(f"评估过程出错:{str(e)}")
return {dim: 3.5 for dim in dimensions}
-
+
def run_evaluation(self):
"""
运行整个评估过程
@@ -185,18 +184,23 @@ class PersonalityEvaluator_direct:
print(f"- 身份细节:{config['identity']['identity_detail']}")
print("\n准备好了吗?按回车键开始...")
input()
-
+
total_scenarios = len(self.scenarios)
- progress_bar = tqdm(total=total_scenarios, desc="场景进度", ncols=100, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')
+ progress_bar = tqdm(
+ total=total_scenarios,
+ desc="场景进度",
+ ncols=100,
+ bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
+ )
for _i, scenario_data in enumerate(self.scenarios, 1):
# print(f"\n{'-' * 20} 场景 {i}/{total_scenarios} - {scenario_data['场景编号']} {'-' * 20}")
-
+
# 改编场景,使其更适合当前角色
print(f"{config['bot']['nickname']}祈祷中...")
adapted_scene = adapt_scene(scenario_data["场景"])
scenario_data["改编场景"] = adapted_scene
-
+
print(adapted_scene)
print(f"\n请描述{config['bot']['nickname']}在这种情况下会如何反应:")
response = input().strip()
@@ -220,13 +224,13 @@ class PersonalityEvaluator_direct:
# 更新进度条
progress_bar.update(1)
-
+
# if i < total_scenarios:
- # print("\n按回车键继续下一个场景...")
- # input()
-
+ # print("\n按回车键继续下一个场景...")
+ # input()
+
progress_bar.close()
-
+
# 计算平均分
for dimension in self.final_scores:
if self.dimension_counts[dimension] > 0:
@@ -241,26 +245,26 @@ class PersonalityEvaluator_direct:
# 返回评估结果
return self.get_result()
-
+
def get_result(self):
"""
获取评估结果
"""
return {
- "final_scores": self.final_scores,
- "dimension_counts": self.dimension_counts,
+ "final_scores": self.final_scores,
+ "dimension_counts": self.dimension_counts,
"scenarios": self.scenarios,
"bot_info": {
- "nickname": config['bot']['nickname'],
- "gender": config['identity']['gender'],
- "age": config['identity']['age'],
- "height": config['identity']['height'],
- "weight": config['identity']['weight'],
- "appearance": config['identity']['appearance'],
- "personality_core": config['personality']['personality_core'],
- "personality_sides": config['personality']['personality_sides'],
- "identity_detail": config['identity']['identity_detail']
- }
+ "nickname": config["bot"]["nickname"],
+ "gender": config["identity"]["gender"],
+ "age": config["identity"]["age"],
+ "height": config["identity"]["height"],
+ "weight": config["identity"]["weight"],
+ "appearance": config["identity"]["appearance"],
+ "personality_core": config["personality"]["personality_core"],
+ "personality_sides": config["personality"]["personality_sides"],
+ "identity_detail": config["identity"]["identity_detail"],
+ },
}
@@ -275,28 +279,28 @@ def main():
"extraversion": round(result["final_scores"]["外向性"] / 6, 1),
"agreeableness": round(result["final_scores"]["宜人性"] / 6, 1),
"neuroticism": round(result["final_scores"]["神经质"] / 6, 1),
- "bot_nickname": config['bot']['nickname']
+ "bot_nickname": config["bot"]["nickname"],
}
# 确保目录存在
save_dir = os.path.join(root_path, "data", "personality")
os.makedirs(save_dir, exist_ok=True)
-
+
# 创建文件名,替换可能的非法字符
- bot_name = config['bot']['nickname']
+ bot_name = config["bot"]["nickname"]
# 替换Windows文件名中不允许的字符
- for char in ['\\', '/', ':', '*', '?', '"', '<', '>', '|']:
- bot_name = bot_name.replace(char, '_')
-
+ for char in ["\\", "/", ":", "*", "?", '"', "<", ">", "|"]:
+ bot_name = bot_name.replace(char, "_")
+
file_name = f"{bot_name}_personality.per"
save_path = os.path.join(save_dir, file_name)
-
+
# 保存简化的结果
with open(save_path, "w", encoding="utf-8") as f:
json.dump(simplified_result, f, ensure_ascii=False, indent=4)
print(f"\n结果已保存到 {save_path}")
-
+
# 同时保存完整结果到results目录
os.makedirs("results", exist_ok=True)
with open("results/personality_result.json", "w", encoding="utf-8") as f:
diff --git a/src/individuality/personality.py b/src/individuality/personality.py
index eb822ab1f..1b05f2d91 100644
--- a/src/individuality/personality.py
+++ b/src/individuality/personality.py
@@ -4,9 +4,11 @@ import json
from pathlib import Path
import random
+
@dataclass
class Personality:
"""人格特质类"""
+
openness: float # 开放性
conscientiousness: float # 尽责性
extraversion: float # 外向性
@@ -15,45 +17,45 @@ class Personality:
bot_nickname: str # 机器人昵称
personality_core: str # 人格核心特点
personality_sides: List[str] # 人格侧面描述
-
+
_instance = None
-
+
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
-
+
def __init__(self, personality_core: str = "", personality_sides: List[str] = None):
if personality_sides is None:
personality_sides = []
self.personality_core = personality_core
self.personality_sides = personality_sides
-
+
@classmethod
- def get_instance(cls) -> 'Personality':
+ def get_instance(cls) -> "Personality":
"""获取Personality单例实例
-
+
Returns:
Personality: 单例实例
"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
-
+
def _init_big_five_personality(self):
"""初始化大五人格特质"""
# 构建文件路径
personality_file = Path("data/personality") / f"{self.bot_nickname}_personality.per"
-
+
# 如果文件存在,读取文件
if personality_file.exists():
- with open(personality_file, 'r', encoding='utf-8') as f:
+ with open(personality_file, "r", encoding="utf-8") as f:
personality_data = json.load(f)
- self.openness = personality_data.get('openness', 0.5)
- self.conscientiousness = personality_data.get('conscientiousness', 0.5)
- self.extraversion = personality_data.get('extraversion', 0.5)
- self.agreeableness = personality_data.get('agreeableness', 0.5)
- self.neuroticism = personality_data.get('neuroticism', 0.5)
+ self.openness = personality_data.get("openness", 0.5)
+ self.conscientiousness = personality_data.get("conscientiousness", 0.5)
+ self.extraversion = personality_data.get("extraversion", 0.5)
+ self.agreeableness = personality_data.get("agreeableness", 0.5)
+ self.neuroticism = personality_data.get("neuroticism", 0.5)
else:
# 如果文件不存在,根据personality_core和personality_sides来设置大五人格特质
if "活泼" in self.personality_core or "开朗" in self.personality_sides:
@@ -62,31 +64,31 @@ class Personality:
else:
self.extraversion = 0.3
self.neuroticism = 0.5
-
+
if "认真" in self.personality_core or "负责" in self.personality_sides:
self.conscientiousness = 0.9
else:
self.conscientiousness = 0.5
-
+
if "友善" in self.personality_core or "温柔" in self.personality_sides:
self.agreeableness = 0.9
else:
self.agreeableness = 0.5
-
+
if "创新" in self.personality_core or "开放" in self.personality_sides:
self.openness = 0.8
else:
self.openness = 0.5
-
+
@classmethod
- def initialize(cls, bot_nickname: str, personality_core: str, personality_sides: List[str]) -> 'Personality':
+ def initialize(cls, bot_nickname: str, personality_core: str, personality_sides: List[str]) -> "Personality":
"""初始化人格特质
-
+
Args:
bot_nickname: 机器人昵称
personality_core: 人格核心特点
personality_sides: 人格侧面描述
-
+
Returns:
Personality: 初始化后的人格特质实例
"""
@@ -96,7 +98,7 @@ class Personality:
instance.personality_sides = personality_sides
instance._init_big_five_personality()
return instance
-
+
def to_dict(self) -> Dict:
"""将人格特质转换为字典格式"""
return {
@@ -107,18 +109,18 @@ class Personality:
"neuroticism": self.neuroticism,
"bot_nickname": self.bot_nickname,
"personality_core": self.personality_core,
- "personality_sides": self.personality_sides
+ "personality_sides": self.personality_sides,
}
-
+
@classmethod
- def from_dict(cls, data: Dict) -> 'Personality':
+ def from_dict(cls, data: Dict) -> "Personality":
"""从字典创建人格特质实例"""
instance = cls.get_instance()
for key, value in data.items():
setattr(instance, key, value)
- return instance
-
- def get_prompt(self,x_person,level):
+ return instance
+
+ def get_prompt(self, x_person, level):
# 开始构建prompt
if x_person == 2:
prompt_personality = "你"
@@ -126,10 +128,10 @@ class Personality:
prompt_personality = "我"
else:
prompt_personality = "他"
- #person
-
+ # person
+
prompt_personality += self.personality_core
-
+
if level == 2:
personality_sides = self.personality_sides
random.shuffle(personality_sides)
@@ -140,5 +142,5 @@ class Personality:
prompt_personality += f",{side}"
prompt_personality += "。"
-
+
return prompt_personality
diff --git a/src/individuality/scene.py b/src/individuality/scene.py
index b94d55046..76304dbbd 100644
--- a/src/individuality/scene.py
+++ b/src/individuality/scene.py
@@ -2,6 +2,7 @@ import json
from typing import Dict
import os
+
def load_scenes() -> Dict:
"""
从JSON文件加载场景数据
@@ -10,13 +11,15 @@ def load_scenes() -> Dict:
Dict: 包含所有场景的字典
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
- json_path = os.path.join(current_dir, 'template_scene.json')
-
- with open(json_path, 'r', encoding='utf-8') as f:
+ json_path = os.path.join(current_dir, "template_scene.json")
+
+ with open(json_path, "r", encoding="utf-8") as f:
return json.load(f)
+
PERSONALITY_SCENES = load_scenes()
+
def get_scene_by_factor(factor: str) -> Dict:
"""
根据人格因子获取对应的情景测试
diff --git a/src/main.py b/src/main.py
index b3cf07e12..aa6f908bf 100644
--- a/src/main.py
+++ b/src/main.py
@@ -100,7 +100,7 @@ class MainSystem:
weight=global_config.weight,
age=global_config.age,
gender=global_config.gender,
- appearance=global_config.appearance
+ appearance=global_config.appearance,
)
logger.success("个体特征初始化成功")
@@ -135,7 +135,6 @@ class MainSystem:
await asyncio.sleep(global_config.build_memory_interval)
logger.info("正在进行记忆构建")
await HippocampusManager.get_instance().build_memory()
-
async def forget_memory_task(self):
"""记忆遗忘任务"""
@@ -144,7 +143,6 @@ class MainSystem:
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage)
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
-
async def print_mood_task(self):
"""打印情绪状态"""
diff --git a/src/plugins/PFC/chat_observer.py b/src/plugins/PFC/chat_observer.py
index 0a6b3bfb6..395b14043 100644
--- a/src/plugins/PFC/chat_observer.py
+++ b/src/plugins/PFC/chat_observer.py
@@ -1,6 +1,6 @@
import time
import asyncio
-from typing import Optional, Dict, Any, List, Tuple
+from typing import Optional, Dict, Any, List, Tuple
from src.common.logger import get_module_logger
from ..message.message_base import UserInfo
from ..config.config import global_config
@@ -9,16 +9,17 @@ from .message_storage import MessageStorage, MongoDBMessageStorage
logger = get_module_logger("chat_observer")
+
class ChatObserver:
"""聊天状态观察器"""
-
+
# 类级别的实例管理
- _instances: Dict[str, 'ChatObserver'] = {}
-
+ _instances: Dict[str, "ChatObserver"] = {}
+
@classmethod
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver':
"""获取或创建观察器实例
-
+
Args:
stream_id: 聊天流ID
message_storage: 消息存储实现,如果为None则使用MongoDB实现
@@ -32,14 +33,14 @@ class ChatObserver:
def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None):
"""初始化观察器
-
+
Args:
stream_id: 聊天流ID
message_storage: 消息存储实现,如果为None则使用MongoDB实现
"""
if stream_id in self._instances:
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
-
+
self.stream_id = stream_id
self.message_storage = message_storage or MongoDBMessageStorage()
@@ -53,9 +54,9 @@ class ChatObserver:
# 消息历史记录
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
- self.last_message_id: Optional[str] = None # 最后一条消息的ID
- self.message_count: int = 0 # 消息计数
-
+ self.last_message_id: Optional[str] = None # 最后一条消息的ID
+ self.message_count: int = 0 # 消息计数
+
# 运行状态
self._running: bool = False
self._task: Optional[asyncio.Task] = None
@@ -77,7 +78,7 @@ class ChatObserver:
async def check(self) -> bool:
"""检查距离上一次观察之后是否有了新消息
-
+
Returns:
bool: 是否有新消息
"""
@@ -91,7 +92,7 @@ class ChatObserver:
if new_message_exists:
logger.debug("发现新消息")
self.last_check_time = time.time()
-
+
return new_message_exists
async def _add_message_to_history(self, message: Dict[str, Any]):
@@ -104,7 +105,7 @@ class ChatObserver:
self.last_message_id = message["message_id"]
self.last_message_time = message["time"] # 更新最后消息时间
self.message_count += 1
-
+
# 更新说话时间
user_info = UserInfo.from_dict(message.get("user_info", {}))
if user_info.user_id == global_config.BOT_QQ:
@@ -186,41 +187,40 @@ class ChatObserver:
start_time: Optional[float] = None,
end_time: Optional[float] = None,
limit: Optional[int] = None,
- user_id: Optional[str] = None
+ user_id: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""获取消息历史
-
+
Args:
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回消息数量
user_id: 指定用户ID
-
+
Returns:
List[Dict[str, Any]]: 消息列表
"""
filtered_messages = self.message_history
-
+
if start_time is not None:
filtered_messages = [m for m in filtered_messages if m["time"] >= start_time]
-
+
if end_time is not None:
filtered_messages = [m for m in filtered_messages if m["time"] <= end_time]
-
+
if user_id is not None:
filtered_messages = [
- m for m in filtered_messages
- if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
+ m for m in filtered_messages if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
]
-
+
if limit is not None:
filtered_messages = filtered_messages[-limit:]
-
+
return filtered_messages
-
+
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
"""获取新消息
-
+
Returns:
List[Dict[str, Any]]: 新消息列表
"""
@@ -231,15 +231,15 @@ class ChatObserver:
if new_messages:
self.last_message_read = new_messages[-1]["message_id"]
-
+
return new_messages
-
+
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
"""获取指定时间点之前的消息
-
+
Args:
time_point: 时间戳
-
+
Returns:
List[Dict[str, Any]]: 最多5条消息
"""
@@ -250,7 +250,7 @@ class ChatObserver:
if new_messages:
self.last_message_read = new_messages[-1]["message_id"]
-
+
return new_messages
'''主要观察循环'''
@@ -263,7 +263,7 @@ class ChatObserver:
await self._add_message_to_history(message)
except Exception as e:
logger.error(f"缓冲消息出错: {e}")
-
+
while self._running:
try:
# 等待事件或超时(1秒)
@@ -271,13 +271,13 @@ class ChatObserver:
await asyncio.wait_for(self._update_event.wait(), timeout=1)
except asyncio.TimeoutError:
pass # 超时后也执行一次检查
-
+
self._update_event.clear() # 重置触发事件
self._update_complete.clear() # 重置完成事件
-
+
# 获取新消息
new_messages = await self._fetch_new_messages()
-
+
if new_messages:
# 处理新消息
for message in new_messages:
@@ -285,21 +285,21 @@ class ChatObserver:
# 设置完成事件
self._update_complete.set()
-
+
except Exception as e:
logger.error(f"更新循环出错: {e}")
self._update_complete.set() # 即使出错也要设置完成事件
-
+
def trigger_update(self):
"""触发一次立即更新"""
self._update_event.set()
-
+
async def wait_for_update(self, timeout: float = 5.0) -> bool:
"""等待更新完成
-
+
Args:
timeout: 超时时间(秒)
-
+
Returns:
bool: 是否成功完成更新(False表示超时)
"""
@@ -309,16 +309,16 @@ class ChatObserver:
except asyncio.TimeoutError:
logger.warning(f"等待更新完成超时({timeout}秒)")
return False
-
+
def start(self):
"""启动观察器"""
if self._running:
return
-
+
self._running = True
self._task = asyncio.create_task(self._update_loop())
logger.info(f"ChatObserver for {self.stream_id} started")
-
+
def stop(self):
"""停止观察器"""
self._running = False
@@ -327,15 +327,15 @@ class ChatObserver:
if self._task:
self._task.cancel()
logger.info(f"ChatObserver for {self.stream_id} stopped")
-
+
async def process_chat_history(self, messages: list):
"""处理聊天历史
-
+
Args:
messages: 消息列表
"""
self.update_check_time()
-
+
for msg in messages:
try:
user_info = UserInfo.from_dict(msg.get("user_info", {}))
@@ -345,33 +345,33 @@ class ChatObserver:
self.update_user_speak_time(msg["time"])
except Exception as e:
logger.warning(f"处理消息时间时出错: {e}")
- continue
-
+ continue
+
def update_check_time(self):
"""更新查看时间"""
self.last_check_time = time.time()
-
+
def update_bot_speak_time(self, speak_time: Optional[float] = None):
"""更新机器人说话时间"""
self.last_bot_speak_time = speak_time or time.time()
-
+
def update_user_speak_time(self, speak_time: Optional[float] = None):
"""更新用户说话时间"""
self.last_user_speak_time = speak_time or time.time()
-
+
def get_time_info(self) -> str:
"""获取时间信息文本"""
current_time = time.time()
time_info = ""
-
+
if self.last_bot_speak_time:
bot_speak_ago = current_time - self.last_bot_speak_time
time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}秒"
-
+
if self.last_user_speak_time:
user_speak_ago = current_time - self.last_user_speak_time
time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}秒"
-
+
return time_info
def start_periodic_update(self):
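
As a quick orientation, a hypothetical usage sketch of `ChatObserver` that uses only the methods shown in this file; the stream id is a placeholder and error handling is omitted.

```python
# Hypothetical usage sketch of ChatObserver; "example-stream" is a placeholder id.
import asyncio

from src.plugins.PFC.chat_observer import ChatObserver


async def watch_stream() -> None:
    observer = ChatObserver.get_instance("example-stream")  # one instance per chat stream
    observer.start()  # launch the background update loop

    observer.trigger_update()  # ask the loop to poll for new messages right away
    if await observer.wait_for_update(timeout=5.0):
        if await observer.check():  # any new messages since the last check?
            print(observer.get_time_info())

    observer.stop()


# asyncio.run(watch_stream())
```
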
diff --git a/src/plugins/PFC/pfc.py b/src/plugins/PFC/pfc.py
index 65a2ac982..96c147650 100644
--- a/src/plugins/PFC/pfc.py
+++ b/src/plugins/PFC/pfc.py
@@ -1,5 +1,5 @@
-#Programmable Friendly Conversationalist
-#Prefrontal cortex
+# Programmable Friendly Conversationalist
+# Prefrontal cortex
import datetime
import asyncio
from typing import List, Optional, Tuple, TYPE_CHECKING
@@ -29,20 +29,17 @@ logger = get_module_logger("pfc")
class GoalAnalyzer:
"""对话目标分析器"""
-
+
def __init__(self, stream_id: str):
self.llm = LLM_request(
- model=global_config.llm_normal,
- temperature=0.7,
- max_tokens=1000,
- request_type="conversation_goal"
+ model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
)
-
- self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
+
+ self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
self.chat_observer = ChatObserver.get_instance(stream_id)
-
+
# 多目标存储结构
self.goals = [] # 存储多个目标
self.max_goals = 3 # 同时保持的最大目标数量
@@ -50,10 +47,10 @@ class GoalAnalyzer:
async def analyze_goal(self) -> Tuple[str, str, str]:
"""分析对话历史并设定目标
-
+
Args:
chat_history: 聊天历史记录列表
-
+
Returns:
Tuple[str, str, str]: (目标, 方法, 原因)
"""
@@ -70,16 +67,16 @@ class GoalAnalyzer:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
-
+
personality_text = f"你的名字是{self.name},{self.personality_info}"
-
+
# 构建当前已有目标的文本
existing_goals_text = ""
if self.goals:
existing_goals_text = "当前已有的对话目标:\n"
for i, (goal, _, reason) in enumerate(self.goals):
- existing_goals_text += f"{i+1}. 目标: {goal}, 原因: {reason}\n"
-
+ existing_goals_text += f"{i + 1}. 目标: {goal}, 原因: {reason}\n"
+
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。
这些目标应该反映出对话的不同方面和意图。
@@ -107,46 +104,44 @@ class GoalAnalyzer:
logger.debug(f"发送到LLM的提示词: {prompt}")
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
-
+
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
- content,
- "goal", "reasoning",
- required_types={"goal": str, "reasoning": str}
+ content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
)
-
+
if not success:
logger.error(f"无法解析JSON,重试第{retry + 1}次")
continue
-
+
goal = result["goal"]
reasoning = result["reasoning"]
-
+
# 使用默认的方法
method = "以友好的态度回应"
-
+
# 更新目标列表
await self._update_goals(goal, method, reasoning)
-
+
# 返回当前最主要的目标
if self.goals:
current_goal, current_method, current_reasoning = self.goals[0]
return current_goal, current_method, current_reasoning
else:
return goal, method, reasoning
-
+
except Exception as e:
logger.error(f"分析对话目标时出错: {str(e)},重试第{retry + 1}次")
if retry == max_retries - 1:
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
continue
-
+
# 所有重试都失败后的默认返回
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
-
+
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表
-
+
Args:
new_goal: 新的目标
method: 实现目标的方法
@@ -160,23 +155,23 @@ class GoalAnalyzer:
# 将此目标移到列表前面(最主要的位置)
self.goals.insert(0, self.goals.pop(i))
return
-
+
# 添加新目标到列表前面
self.goals.insert(0, (new_goal, method, reasoning))
-
+
# 限制目标数量
if len(self.goals) > self.max_goals:
self.goals.pop() # 移除最老的目标
-
+
def _calculate_similarity(self, goal1: str, goal2: str) -> float:
"""简单计算两个目标之间的相似度
-
+
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
-
+
Args:
goal1: 第一个目标
goal2: 第二个目标
-
+
Returns:
float: 相似度得分 (0-1)
"""
@@ -186,18 +181,18 @@ class GoalAnalyzer:
overlap = len(words1.intersection(words2))
total = len(words1.union(words2))
return overlap / total if total > 0 else 0
-
+
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
"""获取所有当前目标
-
+
Returns:
List[Tuple[str, str, str]]: 目标列表,每项为(目标, 方法, 原因)
"""
return self.goals.copy()
-
+
async def get_alternative_goals(self) -> List[Tuple[str, str, str]]:
"""获取除了当前主要目标外的其他备选目标
-
+
Returns:
List[Tuple[str, str, str]]: 备选目标列表
"""
@@ -215,9 +210,9 @@ class GoalAnalyzer:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
-
+
personality_text = f"你的名字是{self.name},{self.personality_info}"
-
+
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,
当前对话目标:{goal}
产生该对话目标的原因:{reasoning}
@@ -247,7 +242,7 @@ class GoalAnalyzer:
"goal_achieved", "stop_conversation", "reason",
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str}
)
-
+
if not success:
logger.error("无法解析对话分析结果JSON")
return False, False, "解析结果失败"
@@ -265,14 +260,15 @@ class GoalAnalyzer:
class Waiter:
"""快 速 等 待"""
+
def __init__(self, stream_id: str):
self.chat_observer = ChatObserver.get_instance(stream_id)
- self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
+ self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
-
+
async def wait(self) -> bool:
"""等待
-
+
Returns:
bool: 是否超时(True表示超时)
"""
@@ -298,7 +294,7 @@ class Waiter:
class DirectMessageSender:
"""直接发送消息到平台的发送器"""
-
+
def __init__(self):
self.logger = get_module_logger("direct_sender")
self.storage = MessageStorage()
@@ -310,7 +306,7 @@ class DirectMessageSender:
reply_to_message: Optional[Message] = None,
) -> None:
"""直接发送消息到平台
-
+
Args:
chat_stream: 聊天流
content: 消息内容
@@ -323,7 +319,7 @@ class DirectMessageSender:
user_nickname=global_config.BOT_NICKNAME,
platform=chat_stream.platform,
)
-
+
message = MessageSending(
message_id=f"dm{round(time.time(), 2)}",
chat_stream=chat_stream,
@@ -343,18 +339,17 @@ class DirectMessageSender:
try:
message_json = message.to_dict()
end_point = global_config.api_urls.get(chat_stream.platform, None)
-
+
if not end_point:
raise ValueError(f"未找到平台:{chat_stream.platform} 的url配置")
-
+
await global_api.send_message_REST(end_point, message_json)
-
+
# 存储消息
await self.storage.store_message(message, message.chat_stream)
-
+
self.logger.info(f"直接发送消息成功: {content[:30]}...")
-
+
except Exception as e:
self.logger.error(f"直接发送消息失败: {str(e)}")
raise
-
diff --git a/src/plugins/PFC/pfc_KnowledgeFetcher.py b/src/plugins/PFC/pfc_KnowledgeFetcher.py
index 560283f25..b4041bb34 100644
--- a/src/plugins/PFC/pfc_KnowledgeFetcher.py
+++ b/src/plugins/PFC/pfc_KnowledgeFetcher.py
@@ -7,24 +7,22 @@ from ..chat.message import Message
logger = get_module_logger("knowledge_fetcher")
+
class KnowledgeFetcher:
"""知识调取器"""
-
+
def __init__(self):
self.llm = LLM_request(
- model=global_config.llm_normal,
- temperature=0.7,
- max_tokens=1000,
- request_type="knowledge_fetch"
+ model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="knowledge_fetch"
)
-
+
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
"""获取相关知识
-
+
Args:
query: 查询内容
chat_history: 聊天历史
-
+
Returns:
Tuple[str, str]: (获取的知识, 知识来源)
"""
@@ -33,16 +31,16 @@ class KnowledgeFetcher:
for msg in chat_history:
# sender = msg.message_info.user_info.user_nickname or f"用户{msg.message_info.user_info.user_id}"
chat_history_text += f"{msg.detailed_plain_text}\n"
-
+
# 从记忆中获取相关知识
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=f"{query}\n{chat_history_text}",
max_memory_num=3,
max_memory_length=2,
max_depth=3,
- fast_retrieval=False
+ fast_retrieval=False,
)
-
+
if related_memory:
knowledge = ""
sources = []
@@ -50,5 +48,5 @@ class KnowledgeFetcher:
knowledge += memory[1] + "\n"
sources.append(f"记忆片段{memory[0]}")
return knowledge.strip(), ",".join(sources)
-
- return "未找到相关知识", "无记忆匹配"
\ No newline at end of file
+
+ return "未找到相关知识", "无记忆匹配"
diff --git a/src/plugins/PFC/pfc_utils.py b/src/plugins/PFC/pfc_utils.py
index 9d0278b02..633d9016e 100644
--- a/src/plugins/PFC/pfc_utils.py
+++ b/src/plugins/PFC/pfc_utils.py
@@ -5,36 +5,37 @@ from src.common.logger import get_module_logger
logger = get_module_logger("pfc_utils")
+
def get_items_from_json(
content: str,
*items: str,
default_values: Optional[Dict[str, Any]] = None,
- required_types: Optional[Dict[str, type]] = None
+ required_types: Optional[Dict[str, type]] = None,
) -> Tuple[bool, Dict[str, Any]]:
"""从文本中提取JSON内容并获取指定字段
-
+
Args:
content: 包含JSON的文本
*items: 要提取的字段名
default_values: 字段的默认值,格式为 {字段名: 默认值}
required_types: 字段的必需类型,格式为 {字段名: 类型}
-
+
Returns:
Tuple[bool, Dict[str, Any]]: (是否成功, 提取的字段字典)
"""
content = content.strip()
result = {}
-
+
# 设置默认值
if default_values:
result.update(default_values)
-
+
# 尝试解析JSON
try:
json_data = json.loads(content)
except json.JSONDecodeError:
# 如果直接解析失败,尝试查找和提取JSON部分
- json_pattern = r'\{[^{}]*\}'
+ json_pattern = r"\{[^{}]*\}"
json_match = re.search(json_pattern, content)
if json_match:
try:
@@ -45,28 +46,28 @@ def get_items_from_json(
else:
logger.error("无法在返回内容中找到有效的JSON")
return False, result
-
+
# 提取字段
for item in items:
if item in json_data:
result[item] = json_data[item]
-
+
# 验证必需字段
if not all(item in result for item in items):
logger.error(f"JSON缺少必要字段,实际内容: {json_data}")
return False, result
-
+
# 验证字段类型
if required_types:
for field, expected_type in required_types.items():
if field in result and not isinstance(result[field], expected_type):
logger.error(f"{field} 必须是 {expected_type.__name__} 类型")
return False, result
-
+
# 验证字符串字段不为空
for field in items:
if isinstance(result[field], str) and not result[field].strip():
logger.error(f"{field} 不能为空")
return False, result
-
- return True, result
\ No newline at end of file
+
+ return True, result
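
For reference, a hypothetical call to `get_items_from_json` mirroring how `GoalAnalyzer.analyze_goal` uses it earlier in this diff; the LLM output string below is invented for the example.

```python
# Hypothetical example; llm_output is invented, the call mirrors the usage in pfc.py.
from src.plugins.PFC.pfc_utils import get_items_from_json

llm_output = '好的,我的分析如下 {"goal": "了解对方的近况", "reasoning": "对方刚提到考试结束"} 以上。'

success, result = get_items_from_json(
    llm_output,
    "goal",
    "reasoning",
    required_types={"goal": str, "reasoning": str},
)

if success:
    goal, reasoning = result["goal"], result["reasoning"]
else:
    goal, reasoning = "保持友好的对话", "解析失败时使用默认目标"
```
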
diff --git a/src/plugins/PFC/reply_checker.py b/src/plugins/PFC/reply_checker.py
index 3d8c743f2..c53feba9b 100644
--- a/src/plugins/PFC/reply_checker.py
+++ b/src/plugins/PFC/reply_checker.py
@@ -9,33 +9,26 @@ from ..message.message_base import UserInfo
logger = get_module_logger("reply_checker")
+
class ReplyChecker:
"""回复检查器"""
-
+
def __init__(self, stream_id: str):
self.llm = LLM_request(
- model=global_config.llm_normal,
- temperature=0.7,
- max_tokens=1000,
- request_type="reply_check"
+ model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="reply_check"
)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
self.max_retries = 2 # 最大重试次数
-
- async def check(
- self,
- reply: str,
- goal: str,
- retry_count: int = 0
- ) -> Tuple[bool, str, bool]:
+
+ async def check(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
"""检查生成的回复是否合适
-
+
Args:
reply: 生成的回复
goal: 对话目标
retry_count: 当前重试次数
-
+
Returns:
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
"""
@@ -49,7 +42,7 @@ class ReplyChecker:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
-
+
prompt = f"""请检查以下回复是否合适:
当前对话目标:{goal}
@@ -83,7 +76,7 @@ class ReplyChecker:
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"检查回复的原始返回: {content}")
-
+
# 清理内容,尝试提取JSON部分
content = content.strip()
try:
@@ -92,7 +85,8 @@ class ReplyChecker:
except json.JSONDecodeError:
# 如果直接解析失败,尝试查找和提取JSON部分
import re
- json_pattern = r'\{[^{}]*\}'
+
+ json_pattern = r"\{[^{}]*\}"
json_match = re.search(json_pattern, content)
if json_match:
try:
@@ -109,33 +103,33 @@ class ReplyChecker:
reason = content[:100] if content else "无法解析响应"
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
return is_suitable, reason, need_replan
-
+
# 验证JSON字段
suitable = result.get("suitable", None)
reason = result.get("reason", "未提供原因")
need_replan = result.get("need_replan", False)
-
+
# 如果suitable字段是字符串,转换为布尔值
if isinstance(suitable, str):
suitable = suitable.lower() == "true"
-
+
# 如果suitable字段不存在或不是布尔值,从reason中判断
if suitable is None:
suitable = "不合适" not in reason.lower() and "违规" not in reason.lower()
-
+
# 如果不合适且未达到最大重试次数,返回需要重试
if not suitable and retry_count < self.max_retries:
return False, reason, False
-
+
# 如果不合适且已达到最大重试次数,返回需要重新规划
if not suitable and retry_count >= self.max_retries:
return False, f"多次重试后仍不合适: {reason}", True
-
+
return suitable, reason, need_replan
-
+
except Exception as e:
logger.error(f"检查回复时出错: {e}")
# 如果出错且已达到最大重试次数,建议重新规划
if retry_count >= self.max_retries:
return False, "多次检查失败,建议重新规划", True
- return False, f"检查过程出错,建议重试: {str(e)}", False
\ No newline at end of file
+ return False, f"检查过程出错,建议重试: {str(e)}", False
diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py
index e5cef56a5..6d2455202 100644
--- a/src/plugins/chat/__init__.py
+++ b/src/plugins/chat/__init__.py
@@ -12,5 +12,5 @@ __all__ = [
"chat_manager",
"message_manager",
"MessageStorage",
- "auto_speak_manager"
+ "auto_speak_manager",
]
diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py
index 119d1aa01..40a00a3ab 100644
--- a/src/plugins/chat/bot.py
+++ b/src/plugins/chat/bot.py
@@ -44,11 +44,11 @@ class ChatBot:
async def _create_PFC_chat(self, message: MessageRecv):
try:
chat_id = str(message.chat_stream.stream_id)
-
+
if global_config.enable_pfc_chatting:
await self.pfc_manager.get_or_create_conversation(chat_id)
-
+
except Exception as e:
logger.error(f"创建PFC聊天失败: {e}")
@@ -59,16 +59,16 @@ class ChatBot:
- 包含思维流状态管理
- 在回复前进行观察和状态更新
- 回复后更新思维流状态
-
+
2. reasoning模式:使用推理系统进行回复
- 直接使用意愿管理器计算回复概率
- 没有思维流相关的状态管理
- 更简单直接的回复逻辑
-
+
3. pfc_chatting模式:仅进行消息处理
- 不进行任何回复
- 只处理和存储消息
-
+
所有模式都包含:
- 消息过滤
- 记忆激活
@@ -89,7 +89,7 @@ class ChatBot:
if userinfo.user_id in global_config.ban_user_id:
logger.debug(f"用户{userinfo.user_id}被禁止回复")
return
-
+
if global_config.enable_pfc_chatting:
try:
if groupinfo is None and global_config.enable_friend_chat:
@@ -118,7 +118,7 @@ class ChatBot:
logger.error(f"处理PFC消息失败: {e}")
else:
if groupinfo is None and global_config.enable_friend_chat:
- # 私聊处理流程
+ # 私聊处理流程
# await self._handle_private_chat(message)
if global_config.response_mode == "heart_flow":
await self.think_flow_chat.process_message(message_data)
diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py
index 6247bf405..6d070c83f 100644
--- a/src/plugins/chat/emoji_manager.py
+++ b/src/plugins/chat/emoji_manager.py
@@ -38,11 +38,11 @@ class EmojiManager:
self.llm_emotion_judge = LLM_request(
model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8, request_type="emoji"
) # 更高的温度,更少的token(后续可以根据情绪来调整温度)
-
+
self.emoji_num = 0
self.emoji_num_max = global_config.max_emoji_num
self.emoji_num_max_reach_deletion = global_config.max_reach_deletion
-
+
logger.info("启动表情包管理器")
def _ensure_emoji_dir(self):
@@ -51,7 +51,7 @@ class EmojiManager:
def _update_emoji_count(self):
"""更新表情包数量统计
-
+
检查数据库中的表情包数量并更新到 self.emoji_num
"""
try:
@@ -376,7 +376,6 @@ class EmojiManager:
except Exception:
logger.exception("[错误] 扫描表情包失败")
-
def check_emoji_file_integrity(self):
"""检查表情包文件完整性
@@ -451,7 +450,7 @@ class EmojiManager:
def check_emoji_file_full(self):
"""检查表情包文件是否完整,如果数量超出限制且允许删除,则删除多余的表情包
-
+
删除规则:
1. 优先删除创建时间更早的表情包
2. 优先删除使用次数少的表情包,但使用次数多的也有小概率被删除
@@ -460,23 +459,23 @@ class EmojiManager:
self._ensure_db()
# 更新表情包数量
self._update_emoji_count()
-
+
# 检查是否超出限制
if self.emoji_num <= self.emoji_num_max:
return
-
+
# 如果超出限制但不允许删除,则只记录警告
if not global_config.max_reach_deletion:
logger.warning(f"[警告] 表情包数量({self.emoji_num})超出限制({self.emoji_num_max}),但未开启自动删除")
return
-
+
# 计算需要删除的数量
delete_count = self.emoji_num - self.emoji_num_max
logger.info(f"[清理] 需要删除 {delete_count} 个表情包")
-
+
# 获取所有表情包,按时间戳升序(旧的在前)排序
all_emojis = list(db.emoji.find().sort([("timestamp", 1)]))
-
+
# 计算权重:使用次数越多,被删除的概率越小
weights = []
max_usage = max((emoji.get("usage_count", 0) for emoji in all_emojis), default=1)
@@ -485,11 +484,11 @@ class EmojiManager:
# 使用指数衰减函数计算权重,使用次数越多权重越小
weight = 1.0 / (1.0 + usage_count / max(1, max_usage))
weights.append(weight)
-
+
# 根据权重随机选择要删除的表情包
to_delete = []
remaining_indices = list(range(len(all_emojis)))
-
+
while len(to_delete) < delete_count and remaining_indices:
# 计算当前剩余表情包的权重
current_weights = [weights[i] for i in remaining_indices]
@@ -497,13 +496,13 @@ class EmojiManager:
total_weight = sum(current_weights)
if total_weight == 0:
break
- normalized_weights = [w/total_weight for w in current_weights]
-
+ normalized_weights = [w / total_weight for w in current_weights]
+
# 随机选择一个表情包
selected_idx = random.choices(remaining_indices, weights=normalized_weights, k=1)[0]
to_delete.append(all_emojis[selected_idx])
remaining_indices.remove(selected_idx)
-
+
# 删除选中的表情包
deleted_count = 0
for emoji in to_delete:
@@ -512,26 +511,26 @@ class EmojiManager:
if "path" in emoji and os.path.exists(emoji["path"]):
os.remove(emoji["path"])
logger.info(f"[删除] 文件: {emoji['path']} (使用次数: {emoji.get('usage_count', 0)})")
-
+
# 删除数据库记录
db.emoji.delete_one({"_id": emoji["_id"]})
deleted_count += 1
-
+
# 同时从images集合中删除
if "hash" in emoji:
db.images.delete_one({"hash": emoji["hash"]})
-
+
except Exception as e:
logger.error(f"[错误] 删除表情包失败: {str(e)}")
continue
-
+
# 更新表情包数量
self._update_emoji_count()
logger.success(f"[清理] 已删除 {deleted_count} 个表情包,当前数量: {self.emoji_num}")
-
+
except Exception as e:
logger.error(f"[错误] 检查表情包数量失败: {str(e)}")
-
+
async def start_periodic_check_register(self):
"""定期检查表情包完整性和数量"""
while True:
@@ -542,7 +541,7 @@ class EmojiManager:
logger.info("[扫描] 开始扫描新表情包...")
if self.emoji_num < self.emoji_num_max:
await self.scan_new_emojis()
- if (self.emoji_num > self.emoji_num_max):
+ if self.emoji_num > self.emoji_num_max:
logger.warning(f"[警告] 表情包数量超过最大限制: {self.emoji_num} > {self.emoji_num_max},跳过注册")
if not global_config.max_reach_deletion:
logger.warning("表情包数量超过最大限制,终止注册")
@@ -551,7 +550,7 @@ class EmojiManager:
logger.warning("表情包数量超过最大限制,开始删除表情包")
self.check_emoji_file_full()
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
-
+
async def delete_all_images(self):
"""删除 data/image 目录下的所有文件"""
try:
@@ -559,10 +558,10 @@ class EmojiManager:
if not os.path.exists(image_dir):
logger.warning(f"[警告] 目录不存在: {image_dir}")
return
-
+
deleted_count = 0
failed_count = 0
-
+
# 遍历目录下的所有文件
for filename in os.listdir(image_dir):
file_path = os.path.join(image_dir, filename)
@@ -574,11 +573,12 @@ class EmojiManager:
except Exception as e:
failed_count += 1
logger.error(f"[错误] 删除文件失败 {file_path}: {str(e)}")
-
+
logger.success(f"[清理] 已删除 {deleted_count} 个文件,失败 {failed_count} 个")
-
+
except Exception as e:
logger.error(f"[错误] 删除图片目录失败: {str(e)}")
+
# 创建全局单例
emoji_manager = EmojiManager()
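
The cleanup logic in `check_emoji_file_full` above is weighted sampling without replacement: each emoji gets a weight that shrinks as its usage count grows, and `random.choices` repeatedly draws from the remaining pool until enough deletion candidates are queued. A self-contained sketch of just that selection step (the helper name and demo data are made up for illustration):

```python
import random


def pick_emojis_to_delete(emojis: list[dict], delete_count: int) -> list[dict]:
    """Weighted sampling without replacement: rarely used emojis are the most likely to go."""
    max_usage = max((e.get("usage_count", 0) for e in emojis), default=1)
    # More usage -> smaller weight -> smaller deletion probability.
    weights = [1.0 / (1.0 + e.get("usage_count", 0) / max(1, max_usage)) for e in emojis]

    to_delete, remaining = [], list(range(len(emojis)))
    while len(to_delete) < delete_count and remaining:
        current = [weights[i] for i in remaining]
        total = sum(current)
        if total == 0:
            break
        normalized = [w / total for w in current]
        chosen = random.choices(remaining, weights=normalized, k=1)[0]
        to_delete.append(emojis[chosen])
        remaining.remove(chosen)
    return to_delete


if __name__ == "__main__":
    pool = [{"path": f"emoji_{i}.png", "usage_count": i * 3} for i in range(10)]
    print([e["path"] for e in pick_emojis_to_delete(pool, 3)])
```
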
diff --git a/src/plugins/chat/message_buffer.py b/src/plugins/chat/message_buffer.py
index a87ed4e9d..f62e015b4 100644
--- a/src/plugins/chat/message_buffer.py
+++ b/src/plugins/chat/message_buffer.py
@@ -13,9 +13,10 @@ from ..config.config import global_config
logger = get_module_logger("message_buffer")
+
@dataclass
class CacheMessages:
- message: MessageRecv
+ message: MessageRecv
cache_determination: asyncio.Event = field(default_factory=asyncio.Event) # 判断缓冲是否产生结果
result: str = "U"
@@ -25,7 +26,7 @@ class MessageBuffer:
self.buffer_pool: Dict[str, OrderedDict[str, CacheMessages]] = {}
self.lock = asyncio.Lock()
- def get_person_id_(self, platform:str, user_id:str, group_info:GroupInfo):
+ def get_person_id_(self, platform: str, user_id: str, group_info: GroupInfo):
"""获取唯一id"""
if group_info:
group_id = group_info.group_id
@@ -34,16 +35,17 @@ class MessageBuffer:
key = f"{platform}_{user_id}_{group_id}"
return hashlib.md5(key.encode()).hexdigest()
- async def start_caching_messages(self, message:MessageRecv):
+ async def start_caching_messages(self, message: MessageRecv):
"""添加消息,启动缓冲"""
if not global_config.message_buffer:
- person_id = person_info_manager.get_person_id(message.message_info.user_info.platform,
- message.message_info.user_info.user_id)
+ person_id = person_info_manager.get_person_id(
+ message.message_info.user_info.platform, message.message_info.user_info.user_id
+ )
asyncio.create_task(self.save_message_interval(person_id, message.message_info))
return
- person_id_ = self.get_person_id_(message.message_info.platform,
- message.message_info.user_info.user_id,
- message.message_info.group_info)
+ person_id_ = self.get_person_id_(
+ message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
+ )
async with self.lock:
if person_id_ not in self.buffer_pool:
@@ -64,25 +66,24 @@ class MessageBuffer:
break
elif msg.result == "F":
recent_F_count += 1
-
+
# 判断条件:最近T之后有超过3-5条F
- if (recent_F_count >= random.randint(3, 5)):
+ if recent_F_count >= random.randint(3, 5):
new_msg = CacheMessages(message=message, result="T")
new_msg.cache_determination.set()
self.buffer_pool[person_id_][message.message_info.message_id] = new_msg
logger.debug(f"快速处理消息(已堆积{recent_F_count}条F): {message.message_info.message_id}")
return
-
+
# 添加新消息
self.buffer_pool[person_id_][message.message_info.message_id] = CacheMessages(message=message)
-
+
# 启动3秒缓冲计时器
- person_id = person_info_manager.get_person_id(message.message_info.user_info.platform,
- message.message_info.user_info.user_id)
+ person_id = person_info_manager.get_person_id(
+ message.message_info.user_info.platform, message.message_info.user_info.user_id
+ )
asyncio.create_task(self.save_message_interval(person_id, message.message_info))
- asyncio.create_task(self._debounce_processor(person_id_,
- message.message_info.message_id,
- person_id))
+ asyncio.create_task(self._debounce_processor(person_id_, message.message_info.message_id, person_id))
async def _debounce_processor(self, person_id_: str, message_id: str, person_id: str):
"""等待3秒无新消息"""
@@ -92,36 +93,33 @@ class MessageBuffer:
return
interval_time = max(0.5, int(interval_time) / 1000)
await asyncio.sleep(interval_time)
-
+
async with self.lock:
- if (person_id_ not in self.buffer_pool or
- message_id not in self.buffer_pool[person_id_]):
+ if person_id_ not in self.buffer_pool or message_id not in self.buffer_pool[person_id_]:
logger.debug(f"消息已被清理,msgid: {message_id}")
return
-
+
cache_msg = self.buffer_pool[person_id_][message_id]
if cache_msg.result == "U":
cache_msg.result = "T"
cache_msg.cache_determination.set()
-
- async def query_buffer_result(self, message:MessageRecv) -> bool:
+ async def query_buffer_result(self, message: MessageRecv) -> bool:
"""查询缓冲结果,并清理"""
if not global_config.message_buffer:
return True
- person_id_ = self.get_person_id_(message.message_info.platform,
- message.message_info.user_info.user_id,
- message.message_info.group_info)
-
-
+ person_id_ = self.get_person_id_(
+ message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
+ )
+
async with self.lock:
user_msgs = self.buffer_pool.get(person_id_, {})
cache_msg = user_msgs.get(message.message_info.message_id)
-
+
if not cache_msg:
logger.debug(f"查询异常,消息不存在,msgid: {message.message_info.message_id}")
return False # 消息不存在或已清理
-
+
try:
await asyncio.wait_for(cache_msg.cache_determination.wait(), timeout=10)
result = cache_msg.result == "T"
@@ -144,9 +142,8 @@ class MessageBuffer:
keep_msgs[msg_id] = msg
elif msg.result == "F":
# 收集F消息的文本内容
- if (hasattr(msg.message, 'processed_plain_text')
- and msg.message.processed_plain_text):
- if msg.message.message_segment.type == "text":
+ if hasattr(msg.message, "processed_plain_text") and msg.message.processed_plain_text:
+ if msg.message.message_segment.type == "text":
combined_text.append(msg.message.processed_plain_text)
elif msg.message.message_segment.type != "text":
is_update = False
@@ -157,20 +154,20 @@ class MessageBuffer:
if combined_text and combined_text[0] != message.processed_plain_text and is_update:
if type == "text":
message.processed_plain_text = "".join(combined_text)
- logger.debug(f"整合了{len(combined_text)-1}条F消息的内容到当前消息")
+ logger.debug(f"整合了{len(combined_text) - 1}条F消息的内容到当前消息")
elif type == "emoji":
combined_text.pop()
message.processed_plain_text = "".join(combined_text)
message.is_emoji = False
- logger.debug(f"整合了{len(combined_text)-1}条F消息的内容,覆盖当前emoji消息")
+ logger.debug(f"整合了{len(combined_text) - 1}条F消息的内容,覆盖当前emoji消息")
self.buffer_pool[person_id_] = keep_msgs
return result
except asyncio.TimeoutError:
logger.debug(f"查询超时消息id: {message.message_info.message_id}")
return False
-
- async def save_message_interval(self, person_id:str, message:BaseMessageInfo):
+
+ async def save_message_interval(self, person_id: str, message: BaseMessageInfo):
message_interval_list = await person_info_manager.get_value(person_id, "msg_interval_list")
now_time_ms = int(round(time.time() * 1000))
if len(message_interval_list) < 1000:
@@ -179,12 +176,12 @@ class MessageBuffer:
message_interval_list.pop(0)
message_interval_list.append(now_time_ms)
data = {
- "platform" : message.platform,
- "user_id" : message.user_info.user_id,
- "nickname" : message.user_info.user_nickname,
- "konw_time" : int(time.time())
+ "platform": message.platform,
+ "user_id": message.user_info.user_id,
+ "nickname": message.user_info.user_nickname,
+ "konw_time": int(time.time()),
}
await person_info_manager.update_one_field(person_id, "msg_interval_list", message_interval_list, data)
-message_buffer = MessageBuffer()
\ No newline at end of file
+message_buffer = MessageBuffer()
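
`message_buffer.py` implements a per-user debounce: each incoming message starts as undetermined ("U"), a background task waits the learned interval (at least 0.5 s) and promotes the message to "T" (process) only if nothing newer superseded it, while superseded messages end up as "F" and get merged into the newest one. A simplified asyncio sketch of that idea — the class and the rule that marks older messages "F" on arrival are assumptions made for the demo, not the project's exact flow:

```python
import asyncio


class Debouncer:
    """Marks a message 'T' (process) only if no newer message arrives during its wait window."""

    def __init__(self) -> None:
        self.results: dict[str, str] = {}  # message_id -> "U" / "T" / "F"
        self.lock = asyncio.Lock()

    async def add(self, message_id: str, wait_s: float = 3.0) -> None:
        async with self.lock:
            # Any older undetermined message is superseded -> "F" (folded into the newest one).
            for mid, result in self.results.items():
                if result == "U":
                    self.results[mid] = "F"
            self.results[message_id] = "U"
        asyncio.create_task(self._debounce(message_id, wait_s))

    async def _debounce(self, message_id: str, wait_s: float) -> None:
        await asyncio.sleep(wait_s)
        async with self.lock:
            if self.results.get(message_id) == "U":
                self.results[message_id] = "T"


async def main() -> None:
    buf = Debouncer()
    await buf.add("m1", wait_s=0.2)
    await asyncio.sleep(0.1)
    await buf.add("m2", wait_s=0.2)  # arrives before m1's window closes
    await asyncio.sleep(0.3)
    print(buf.results)  # {'m1': 'F', 'm2': 'T'}


if __name__ == "__main__":
    asyncio.run(main())
```
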
diff --git a/src/plugins/chat/message_sender.py b/src/plugins/chat/message_sender.py
index 566fe295e..9f547ed10 100644
--- a/src/plugins/chat/message_sender.py
+++ b/src/plugins/chat/message_sender.py
@@ -68,7 +68,8 @@ class Message_Sender:
typing_time = calculate_typing_time(
input_string=message.processed_plain_text,
thinking_start_time=message.thinking_start_time,
- is_emoji=message.is_emoji)
+ is_emoji=message.is_emoji,
+ )
logger.debug(f"{message.processed_plain_text},{typing_time},计算输入时间结束")
await asyncio.sleep(typing_time)
logger.debug(f"{message.processed_plain_text},{typing_time},等待输入时间结束")
@@ -227,7 +228,7 @@ class MessageManager:
await message_earliest.process()
# print(f"message_earliest.thinking_start_tim22222e:{message_earliest.thinking_start_time}")
-
+
await message_sender.send_message(message_earliest)
await self.storage.store_message(message_earliest, message_earliest.chat_stream)
diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py
index 1b9196e14..b7cc32e2f 100644
--- a/src/plugins/chat/utils.py
+++ b/src/plugins/chat/utils.py
@@ -56,14 +56,13 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
logger.info("被@,回复概率设置为100%")
else:
if not is_mentioned:
-
# 判断是否被回复
if re.match(f"回复[\s\S]*?\({global_config.BOT_QQ}\)的消息,说:", message.processed_plain_text):
is_mentioned = True
-
+
# 判断内容中是否被提及
- message_content = re.sub(r'\@[\s\S]*?((\d+))','', message.processed_plain_text)
- message_content = re.sub(r'回复[\s\S]*?\((\d+)\)的消息,说: ','', message_content)
+ message_content = re.sub(r"\@[\s\S]*?((\d+))", "", message.processed_plain_text)
+ message_content = re.sub(r"回复[\s\S]*?\((\d+)\)的消息,说: ", "", message_content)
for keyword in keywords:
if keyword in message_content:
is_mentioned = True
@@ -359,7 +358,13 @@ def process_llm_response(text: str) -> List[str]:
return sentences
-def calculate_typing_time(input_string: str, thinking_start_time: float, chinese_time: float = 0.2, english_time: float = 0.1, is_emoji: bool = False) -> float:
+def calculate_typing_time(
+ input_string: str,
+ thinking_start_time: float,
+ chinese_time: float = 0.2,
+ english_time: float = 0.1,
+ is_emoji: bool = False,
+) -> float:
"""
计算输入字符串所需的时间,中文和英文字符有不同的输入时间
input_string (str): 输入的字符串
@@ -393,19 +398,18 @@ def calculate_typing_time(input_string: str, thinking_start_time: float, chinese
total_time += chinese_time
else: # 其他字符(如英文)
total_time += english_time
-
-
+
if is_emoji:
total_time = 1
-
+
if time.time() - thinking_start_time > 10:
total_time = 1
-
+
# print(f"thinking_start_time:{thinking_start_time}")
# print(f"nowtime:{time.time()}")
# print(f"nowtime - thinking_start_time:{time.time() - thinking_start_time}")
# print(f"{total_time}")
-
+
return total_time # 加上回车时间
@@ -535,39 +539,32 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
try:
# 获取开始时间之前最新的一条消息
start_message = db.messages.find_one(
- {
- "chat_id": stream_id,
- "time": {"$lte": start_time}
- },
- sort=[("time", -1), ("_id", -1)] # 按时间倒序,_id倒序(最后插入的在前)
+ {"chat_id": stream_id, "time": {"$lte": start_time}},
+ sort=[("time", -1), ("_id", -1)], # 按时间倒序,_id倒序(最后插入的在前)
)
-
+
# 获取结束时间最近的一条消息
# 先找到结束时间点的所有消息
- end_time_messages = list(db.messages.find(
- {
- "chat_id": stream_id,
- "time": {"$lte": end_time}
- },
- sort=[("time", -1)] # 先按时间倒序
- ).limit(10)) # 限制查询数量,避免性能问题
-
+ end_time_messages = list(
+ db.messages.find(
+ {"chat_id": stream_id, "time": {"$lte": end_time}},
+ sort=[("time", -1)], # 先按时间倒序
+ ).limit(10)
+ ) # 限制查询数量,避免性能问题
+
if not end_time_messages:
logger.warning(f"未找到结束时间 {end_time} 之前的消息")
return 0, 0
-
+
# 找到最大时间
max_time = end_time_messages[0]["time"]
# 在最大时间的消息中找最后插入的(_id最大的)
- end_message = max(
- [msg for msg in end_time_messages if msg["time"] == max_time],
- key=lambda x: x["_id"]
- )
-
+ end_message = max([msg for msg in end_time_messages if msg["time"] == max_time], key=lambda x: x["_id"])
+
if not start_message:
logger.warning(f"未找到开始时间 {start_time} 之前的消息")
return 0, 0
-
+
# 调试输出
# print("\n=== 消息范围信息 ===")
# print("Start message:", {
@@ -587,20 +584,16 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
# 如果结束消息的时间等于开始时间,返回0
if end_message["time"] == start_message["time"]:
return 0, 0
-
+
# 获取并打印这个时间范围内的所有消息
# print("\n=== 时间范围内的所有消息 ===")
- all_messages = list(db.messages.find(
- {
- "chat_id": stream_id,
- "time": {
- "$gte": start_message["time"],
- "$lte": end_message["time"]
- }
- },
- sort=[("time", 1), ("_id", 1)] # 按时间正序,_id正序
- ))
-
+ all_messages = list(
+ db.messages.find(
+ {"chat_id": stream_id, "time": {"$gte": start_message["time"], "$lte": end_message["time"]}},
+ sort=[("time", 1), ("_id", 1)], # 按时间正序,_id正序
+ )
+ )
+
count = 0
total_length = 0
for msg in all_messages:
@@ -615,10 +608,10 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
# "text_length": text_length,
# "_id": str(msg.get("_id"))
# })
-
+
# 如果时间不同,需要把end_message本身也计入
return count - 1, total_length
-
+
except Exception as e:
logger.error(f"计算消息数量时出错: {str(e)}")
return 0, 0
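
`calculate_typing_time`, reformatted earlier in this file's diff, charges 0.2 s per Chinese character and 0.1 s per other character, then clamps the delay to 1 s for emoji replies or when more than 10 s have already elapsed since thinking started. A standalone sketch — the CJK range test used to classify characters is an assumption, since only the accumulation and clamping appear in the shown hunks:

```python
import time


def typing_time_sketch(
    input_string: str,
    thinking_start_time: float,
    chinese_time: float = 0.2,
    english_time: float = 0.1,
    is_emoji: bool = False,
) -> float:
    """Estimate how long a human would take to type the reply."""
    total = 0.0
    for char in input_string:
        # CJK Unified Ideographs are treated as "Chinese" and typed more slowly (assumed test).
        if "\u4e00" <= char <= "\u9fff":
            total += chinese_time
        else:
            total += english_time

    # Emoji replies and replies whose "thinking" already took >10s are capped at 1 second.
    if is_emoji or time.time() - thinking_start_time > 10:
        total = 1.0
    return total


if __name__ == "__main__":
    print(typing_time_sketch("你好 world", thinking_start_time=time.time()))  # ~1.0
```
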
diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py
index 7c930f6dc..ed78dc17e 100644
--- a/src/plugins/chat/utils_image.py
+++ b/src/plugins/chat/utils_image.py
@@ -239,13 +239,13 @@ class ImageManager:
# 解码base64
gif_data = base64.b64decode(gif_base64)
gif = Image.open(io.BytesIO(gif_data))
-
+
# 收集所有帧
frames = []
try:
while True:
gif.seek(len(frames))
- frame = gif.convert('RGB')
+ frame = gif.convert("RGB")
frames.append(frame.copy())
except EOFError:
pass
@@ -264,18 +264,19 @@ class ImageManager:
# 获取单帧的尺寸
frame_width, frame_height = selected_frames[0].size
-
+
# 计算目标尺寸,保持宽高比
target_height = 200 # 固定高度
target_width = int((target_height / frame_height) * frame_width)
-
+
# 调整所有帧的大小
- resized_frames = [frame.resize((target_width, target_height), Image.Resampling.LANCZOS)
- for frame in selected_frames]
+ resized_frames = [
+ frame.resize((target_width, target_height), Image.Resampling.LANCZOS) for frame in selected_frames
+ ]
# 创建拼接图像
total_width = target_width * len(resized_frames)
- combined_image = Image.new('RGB', (total_width, target_height))
+ combined_image = Image.new("RGB", (total_width, target_height))
# 水平拼接图像
for idx, frame in enumerate(resized_frames):
@@ -283,11 +284,11 @@ class ImageManager:
# 转换为base64
buffer = io.BytesIO()
- combined_image.save(buffer, format='JPEG', quality=85)
- result_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
-
+ combined_image.save(buffer, format="JPEG", quality=85)
+ result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+
return result_base64
-
+
except Exception as e:
logger.error(f"GIF转换失败: {str(e)}")
return None
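
The GIF handling reformatted above decodes a base64 GIF, walks its frames with `seek()` until `EOFError`, resizes a sample of frames to a fixed 200 px height while keeping aspect ratio, pastes them side by side into one RGB image, and re-encodes the strip as a base64 JPEG. A Pillow sketch of the whole pipeline — the even-sampling step (`max_frames`) is an assumption, since the frame-selection logic falls outside the shown hunks:

```python
import base64
import io

from PIL import Image


def gif_to_frame_strip(gif_base64: str, target_height: int = 200, max_frames: int = 4) -> str:
    """Sample frames from a GIF and join them into one horizontal JPEG strip (base64 in, base64 out)."""
    gif = Image.open(io.BytesIO(base64.b64decode(gif_base64)))

    # Collect every frame as an RGB copy.
    frames = []
    try:
        while True:
            gif.seek(len(frames))
            frames.append(gif.convert("RGB").copy())
    except EOFError:
        pass

    # Evenly sample up to max_frames frames (assumed selection rule).
    step = max(1, len(frames) // max_frames)
    selected = frames[::step][:max_frames]

    # Resize to a fixed height, preserving aspect ratio.
    frame_width, frame_height = selected[0].size
    target_width = int((target_height / frame_height) * frame_width)
    resized = [f.resize((target_width, target_height), Image.Resampling.LANCZOS) for f in selected]

    # Paste the frames side by side and re-encode.
    combined = Image.new("RGB", (target_width * len(resized), target_height))
    for idx, frame in enumerate(resized):
        combined.paste(frame, (idx * target_width, 0))

    buffer = io.BytesIO()
    combined.save(buffer, format="JPEG", quality=85)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
```

The result is a single still image, which keeps animated stickers usable by any downstream step that only accepts one picture.
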
diff --git a/src/plugins/chat_module/only_process/only_message_process.py b/src/plugins/chat_module/only_process/only_message_process.py
index 4c1e7d5e1..6da19efe7 100644
--- a/src/plugins/chat_module/only_process/only_message_process.py
+++ b/src/plugins/chat_module/only_process/only_message_process.py
@@ -7,12 +7,13 @@ from datetime import datetime
logger = get_module_logger("pfc_message_processor")
+
class MessageProcessor:
"""消息处理器,负责处理接收到的消息并存储"""
-
+
def __init__(self):
self.storage = MessageStorage()
-
+
def _check_ban_words(self, text: str, chat, userinfo) -> bool:
"""检查消息中是否包含过滤词"""
for word in global_config.ban_words:
@@ -34,10 +35,10 @@ class MessageProcessor:
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
return True
return False
-
+
async def process_message(self, message: MessageRecv) -> None:
"""处理消息并存储
-
+
Args:
message: 消息对象
"""
@@ -55,12 +56,9 @@ class MessageProcessor:
# 存储消息
await self.storage.store_message(message, chat)
-
+
# 打印消息信息
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
# 将时间戳转换为datetime对象
current_time = datetime.fromtimestamp(message.message_info.time).strftime("%H:%M:%S")
- logger.info(
- f"[{current_time}][{mes_name}]"
- f"{chat.user_info.user_nickname}: {message.processed_plain_text}"
- )
\ No newline at end of file
+ logger.info(f"[{current_time}][{mes_name}]{chat.user_info.user_nickname}: {message.processed_plain_text}")
diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_chat.py b/src/plugins/chat_module/reasoning_chat/reasoning_chat.py
index aa00992a6..683bef463 100644
--- a/src/plugins/chat_module/reasoning_chat/reasoning_chat.py
+++ b/src/plugins/chat_module/reasoning_chat/reasoning_chat.py
@@ -27,6 +27,7 @@ chat_config = LogConfig(
logger = get_module_logger("reasoning_chat", config=chat_config)
+
class ReasoningChat:
def __init__(self):
self.storage = MessageStorage()
@@ -224,13 +225,13 @@ class ReasoningChat:
do_reply = False
if random() < reply_probability:
do_reply = True
-
+
# 创建思考消息
timer1 = time.time()
thinking_id = await self._create_thinking_message(message, chat, userinfo, messageinfo)
timer2 = time.time()
timing_results["创建思考消息"] = timer2 - timer1
-
+
# 生成回复
timer1 = time.time()
response_set = await self.gpt.generate_response(message)
diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_generator.py b/src/plugins/chat_module/reasoning_chat/reasoning_generator.py
index 688d09f03..eca5d0956 100644
--- a/src/plugins/chat_module/reasoning_chat/reasoning_generator.py
+++ b/src/plugins/chat_module/reasoning_chat/reasoning_generator.py
@@ -40,7 +40,7 @@ class ResponseGenerator:
async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数"""
- #从global_config中获取模型概率值并选择模型
+ # 从global_config中获取模型概率值并选择模型
if random.random() < global_config.MODEL_R1_PROBABILITY:
self.current_model_type = "深深地"
current_model = self.model_reasoning
@@ -51,7 +51,6 @@ class ResponseGenerator:
logger.info(
f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
) # noqa: E501
-
model_response = await self._generate_response_with_model(message, current_model)
@@ -189,4 +188,4 @@ class ResponseGenerator:
# print(f"得到了处理后的llm返回{processed_response}")
- return processed_response
\ No newline at end of file
+ return processed_response
diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py
index 3a9f0dc46..a379fa6d5 100644
--- a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py
+++ b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py
@@ -24,35 +24,32 @@ class PromptBuilder:
async def _build_prompt(
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
) -> tuple[str, str]:
-
# 开始构建prompt
prompt_personality = "你"
- #person
+ # person
individuality = Individuality.get_instance()
-
+
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
-
+
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
-
+
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
-
-
-
+
# 关系
- who_chat_in_group = [(chat_stream.user_info.platform,
- chat_stream.user_info.user_id,
- chat_stream.user_info.user_nickname)]
+ who_chat_in_group = [
+ (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
+ ]
who_chat_in_group += get_recent_group_speaker(
stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id),
limit=global_config.MAX_CONTEXT_SIZE,
)
-
+
relation_prompt = ""
for person in who_chat_in_group:
relation_prompt += await relationship_manager.build_relationship_info(person)
@@ -67,7 +64,7 @@ class PromptBuilder:
mood_prompt = mood_manager.get_prompt()
# logger.info(f"心情prompt: {mood_prompt}")
-
+
# 调取记忆
memory_prompt = ""
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
@@ -84,7 +81,7 @@ class PromptBuilder:
# print(f"相关记忆:{related_memory_info}")
# 日程构建
- schedule_prompt = f'''你现在正在做的事情是:{bot_schedule.get_current_num_task(num = 1,time_info = False)}'''
+ schedule_prompt = f"""你现在正在做的事情是:{bot_schedule.get_current_num_task(num=1, time_info=False)}"""
# 获取聊天上下文
chat_in_group = True
@@ -143,7 +140,7 @@ class PromptBuilder:
涉及政治敏感以及违法违规的内容请规避。"""
logger.info("开始构建prompt")
-
+
prompt = f"""
{relation_prompt_all}
{memory_prompt}
@@ -165,7 +162,7 @@ class PromptBuilder:
start_time = time.time()
related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
-
+
# 1. 先从LLM获取主题,类似于记忆系统的做法
topics = []
# try:
@@ -173,7 +170,7 @@ class PromptBuilder:
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
-
+
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
@@ -184,7 +181,7 @@ class PromptBuilder:
# for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
-
+
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
@@ -192,7 +189,7 @@ class PromptBuilder:
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
-
+
# 如果无法提取到主题,直接使用整个消息
if not topics:
logger.info("未能提取到任何主题,使用整个消息进行查询")
@@ -200,26 +197,26 @@ class PromptBuilder:
if not embedding:
logger.error("获取消息嵌入向量失败")
return ""
-
+
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒")
return related_info
-
+
# 2. 对每个主题进行知识库查询
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
-
+
# 优化:批量获取嵌入向量,减少API调用
embeddings = {}
topics_batch = [topic for topic in topics if len(topic) > 0]
if message: # 确保消息非空
topics_batch.append(message)
-
+
# 批量获取嵌入向量
embed_start_time = time.time()
for text in topics_batch:
if not text or len(text.strip()) == 0:
continue
-
+
try:
embedding = await get_embedding(text, request_type="prompt_build")
if embedding:
@@ -228,17 +225,17 @@ class PromptBuilder:
logger.warning(f"获取'{text}'的嵌入向量失败")
except Exception as e:
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
-
+
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒")
-
+
if not embeddings:
logger.error("所有嵌入向量获取失败")
return ""
-
+
# 3. 对每个主题进行知识库查询
all_results = []
query_start_time = time.time()
-
+
# 首先添加原始消息的查询结果
if message in embeddings:
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
@@ -247,12 +244,12 @@ class PromptBuilder:
result["topic"] = "原始消息"
all_results.extend(original_results)
logger.info(f"原始消息查询到{len(original_results)}条结果")
-
+
# 然后添加每个主题的查询结果
for topic in topics:
if not topic or topic not in embeddings:
continue
-
+
try:
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
if topic_results:
@@ -263,9 +260,9 @@ class PromptBuilder:
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
except Exception as e:
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
-
+
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
-
+
# 4. 去重和过滤
process_start_time = time.time()
unique_contents = set()
@@ -275,14 +272,16 @@ class PromptBuilder:
if content not in unique_contents:
unique_contents.add(content)
filtered_results.append(result)
-
+
# 5. 按相似度排序
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
-
+
# 6. 限制总数量(最多10条)
filtered_results = filtered_results[:10]
- logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果")
-
+ logger.info(
+ f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
+ )
+
# 7. 格式化输出
if filtered_results:
format_start_time = time.time()
@@ -292,7 +291,7 @@ class PromptBuilder:
if topic not in grouped_results:
grouped_results[topic] = []
grouped_results[topic].append(result)
-
+
# 按主题组织输出
for topic, results in grouped_results.items():
related_info += f"【主题: {topic}】\n"
@@ -303,13 +302,15 @@ class PromptBuilder:
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n"
related_info += "\n"
-
+
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
-
+
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
return related_info
- def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
+ def get_info_from_db(
+ self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
+ ) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []
# 使用余弦相似度计算
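
Steps 4–7 of the knowledge-base lookup in `reasoning_prompt_builder.py` — de-duplicate by content, sort by similarity, cap at 10 results, then group the output under `【主题: …】` headers — are easy to isolate. A sketch over hypothetical input dicts shaped like the `get_info_from_db(..., return_raw=True)` results (`content`, `similarity`, `topic`):

```python
def merge_topic_results(all_results: list[dict], max_items: int = 10) -> str:
    """De-duplicate retrieved snippets, sort by similarity, cap the count, and group output by topic."""
    seen, filtered = set(), []
    for result in all_results:
        content = result["content"]
        if content not in seen:
            seen.add(content)
            filtered.append(result)

    filtered.sort(key=lambda x: x["similarity"], reverse=True)
    filtered = filtered[:max_items]

    grouped: dict[str, list[dict]] = {}
    for result in filtered:
        grouped.setdefault(result.get("topic", "原始消息"), []).append(result)

    lines = []
    for topic, results in grouped.items():
        lines.append(f"【主题: {topic}】")
        lines.extend(r["content"] for r in results)
        lines.append("")
    return "\n".join(lines)


if __name__ == "__main__":
    demo = [
        {"topic": "天气", "content": "今天下雨", "similarity": 0.9},
        {"topic": "原始消息", "content": "今天下雨", "similarity": 0.8},  # duplicate content, dropped
        {"topic": "天气", "content": "明天转晴", "similarity": 0.7},
    ]
    print(merge_topic_results(demo))
```
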
diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_chat.py b/src/plugins/chat_module/think_flow_chat/think_flow_chat.py
index c0af9d6b5..f845770d3 100644
--- a/src/plugins/chat_module/think_flow_chat/think_flow_chat.py
+++ b/src/plugins/chat_module/think_flow_chat/think_flow_chat.py
@@ -28,6 +28,7 @@ chat_config = LogConfig(
logger = get_module_logger("think_flow_chat", config=chat_config)
+
class ThinkFlowChat:
def __init__(self):
self.storage = MessageStorage()
@@ -96,7 +97,7 @@ class ThinkFlowChat:
)
if not mark_head:
mark_head = True
-
+
# print(f"thinking_start_time:{bot_message.thinking_start_time}")
message_set.add_message(bot_message)
message_manager.add_message(message_set)
@@ -110,7 +111,7 @@ class ThinkFlowChat:
if emoji_raw:
emoji_path, description = emoji_raw
emoji_cq = image_path_to_base64(emoji_path)
-
+
# logger.info(emoji_cq)
thinking_time_point = round(message.message_info.time, 2)
@@ -130,7 +131,7 @@ class ThinkFlowChat:
is_head=False,
is_emoji=True,
)
-
+
# logger.info("22222222222222")
message_manager.add_message(bot_message)
@@ -180,7 +181,7 @@ class ThinkFlowChat:
await message.process()
logger.debug(f"消息处理成功{message.processed_plain_text}")
-
+
# 过滤词/正则表达式过滤
if self._check_ban_words(message.processed_plain_text, chat, userinfo) or self._check_ban_regex(
message.raw_message, chat, userinfo
@@ -190,7 +191,7 @@ class ThinkFlowChat:
await self.storage.store_message(message, chat)
logger.debug(f"存储成功{message.processed_plain_text}")
-
+
# 记忆激活
timer1 = time.time()
interested_rate = await HippocampusManager.get_instance().get_activate_from_text(
@@ -214,15 +215,13 @@ class ThinkFlowChat:
# 处理提及
is_mentioned, reply_probability = is_mentioned_bot_in_message(message)
-
# 计算回复意愿
current_willing_old = willing_manager.get_willing(chat_stream=chat)
# current_willing_new = (heartflow.get_subheartflow(chat.stream_id).current_state.willing - 5) / 4
- # current_willing = (current_willing_old + current_willing_new) / 2
+ # current_willing = (current_willing_old + current_willing_new) / 2
# 有点bug
current_willing = current_willing_old
-
willing_manager.set_willing(chat.stream_id, current_willing)
# 意愿激活
@@ -258,7 +257,7 @@ class ThinkFlowChat:
if random() < reply_probability:
try:
do_reply = True
-
+
# 创建思考消息
try:
timer1 = time.time()
@@ -267,9 +266,9 @@ class ThinkFlowChat:
timing_results["创建思考消息"] = timer2 - timer1
except Exception as e:
logger.error(f"心流创建思考消息失败: {e}")
-
+
try:
- # 观察
+ # 观察
timer1 = time.time()
await heartflow.get_subheartflow(chat.stream_id).do_observe()
timer2 = time.time()
@@ -280,12 +279,14 @@ class ThinkFlowChat:
# 思考前脑内状态
try:
timer1 = time.time()
- await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(message.processed_plain_text)
+ await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(
+ message.processed_plain_text
+ )
timer2 = time.time()
timing_results["思考前脑内状态"] = timer2 - timer1
except Exception as e:
logger.error(f"心流思考前脑内状态失败: {e}")
-
+
# 生成回复
timer1 = time.time()
response_set = await self.gpt.generate_response(message)
diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py
index d7240d9a6..4087b0b89 100644
--- a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py
+++ b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py
@@ -35,7 +35,6 @@ class ResponseGenerator:
async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数"""
-
logger.info(
f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
)
@@ -178,4 +177,3 @@ class ResponseGenerator:
# print(f"得到了处理后的llm返回{processed_response}")
return processed_response
-
diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py
index b6fe9fb89..fc52a6151 100644
--- a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py
+++ b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py
@@ -21,22 +21,21 @@ class PromptBuilder:
async def _build_prompt(
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
) -> tuple[str, str]:
-
current_mind_info = heartflow.get_subheartflow(stream_id).current_mind
-
+
individuality = Individuality.get_instance()
- prompt_personality = individuality.get_prompt(type = "personality",x_person = 2,level = 1)
- prompt_identity = individuality.get_prompt(type = "identity",x_person = 2,level = 1)
+ prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1)
+ prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1)
# 关系
- who_chat_in_group = [(chat_stream.user_info.platform,
- chat_stream.user_info.user_id,
- chat_stream.user_info.user_nickname)]
+ who_chat_in_group = [
+ (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
+ ]
who_chat_in_group += get_recent_group_speaker(
stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id),
limit=global_config.MAX_CONTEXT_SIZE,
)
-
+
relation_prompt = ""
for person in who_chat_in_group:
relation_prompt += await relationship_manager.build_relationship_info(person)
@@ -100,7 +99,7 @@ class PromptBuilder:
涉及政治敏感以及违法违规的内容请规避。"""
logger.info("开始构建prompt")
-
+
prompt = f"""
{relation_prompt_all}\n
{chat_target}
@@ -114,7 +113,7 @@ class PromptBuilder:
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。"""
-
+
return prompt
diff --git a/src/plugins/config/auto_update.py b/src/plugins/config/auto_update.py
index 9c4264233..04b4b3ced 100644
--- a/src/plugins/config/auto_update.py
+++ b/src/plugins/config/auto_update.py
@@ -3,6 +3,7 @@ import tomlkit
from pathlib import Path
from datetime import datetime
+
def update_config():
print("开始更新配置文件...")
# 获取根目录路径
@@ -25,11 +26,11 @@ def update_config():
print(f"发现旧配置文件: {old_config_path}")
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
-
+
# 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
-
+
# 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path)
print(f"已备份旧配置文件到: {old_backup_path}")
diff --git a/src/plugins/config/config.py b/src/plugins/config/config.py
index c16d1360b..eccb3bc0b 100644
--- a/src/plugins/config/config.py
+++ b/src/plugins/config/config.py
@@ -28,6 +28,7 @@ logger = get_module_logger("config", config=config_config)
is_test = True
mai_version_main = "0.6.2"
mai_version_fix = "snapshot-1"
+
if mai_version_fix:
if is_test:
mai_version = f"test-{mai_version_main}-{mai_version_fix}"
@@ -39,6 +40,7 @@ else:
else:
mai_version = mai_version_main
+
def update_config():
# 获取根目录路径
root_dir = Path(__file__).parent.parent.parent.parent
@@ -54,7 +56,7 @@ def update_config():
# 检查配置文件是否存在
if not old_config_path.exists():
logger.info("配置文件不存在,从模板创建新配置")
- #创建文件夹
+ # 创建文件夹
old_config_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(template_path, old_config_path)
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
@@ -84,7 +86,7 @@ def update_config():
# 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
-
+
# 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path)
logger.info(f"已备份旧配置文件到: {old_backup_path}")
@@ -127,6 +129,7 @@ def update_config():
f.write(tomlkit.dumps(new_config))
logger.info("配置文件更新完成")
+
logger = get_module_logger("config")
@@ -148,17 +151,21 @@ class BotConfig:
ban_user_id = set()
# personality
- personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内,谁再写3000字小作文敲谁脑袋
- personality_sides: List[str] = field(default_factory=lambda: [
- "用一句话或几句话描述人格的一些侧面",
- "用一句话或几句话描述人格的一些侧面",
- "用一句话或几句话描述人格的一些侧面"
- ])
+ personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内,谁再写3000字小作文敲谁脑袋
+ personality_sides: List[str] = field(
+ default_factory=lambda: [
+ "用一句话或几句话描述人格的一些侧面",
+ "用一句话或几句话描述人格的一些侧面",
+ "用一句话或几句话描述人格的一些侧面",
+ ]
+ )
# identity
- identity_detail: List[str] = field(default_factory=lambda: [
- "身份特点",
- "身份特点",
- ])
+ identity_detail: List[str] = field(
+ default_factory=lambda: [
+ "身份特点",
+ "身份特点",
+ ]
+ )
height: int = 170 # 身高 单位厘米
weight: int = 50 # 体重 单位千克
age: int = 20 # 年龄 单位岁
@@ -181,22 +188,22 @@ class BotConfig:
ban_words = set()
ban_msgs_regex = set()
-
- #heartflow
+
+ # heartflow
# enable_heartflow: bool = False # 是否启用心流
sub_heart_flow_update_interval: int = 60 # 子心流更新频率,间隔 单位秒
sub_heart_flow_freeze_time: int = 120 # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒
sub_heart_flow_stop_time: int = 600 # 子心流停止时间,超过这个时间没有回复,子心流会停止,间隔 单位秒
heart_flow_update_interval: int = 300 # 心流更新频率,间隔 单位秒
-
+
# willing
willing_mode: str = "classical" # 意愿模式
response_willing_amplifier: float = 1.0 # 回复意愿放大系数
response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数
down_frequency_rate: float = 3 # 降低回复频率的群组回复意愿降低系数
emoji_response_penalty: float = 0.0 # 表情包回复惩罚
- mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复
- at_bot_inevitable_reply: bool = False # @bot 必然回复
+ mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复
+ at_bot_inevitable_reply: bool = False # @bot 必然回复
# response
response_mode: str = "heart_flow" # 回复策略
@@ -354,7 +361,6 @@ class BotConfig:
"""从TOML配置文件加载配置"""
config = cls()
-
def personality(parent: dict):
personality_config = parent["personality"]
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
@@ -418,13 +424,21 @@ class BotConfig:
config.max_response_length = response_config.get("max_response_length", config.max_response_length)
if config.INNER_VERSION in SpecifierSet(">=1.0.4"):
config.response_mode = response_config.get("response_mode", config.response_mode)
-
+
def heartflow(parent: dict):
heartflow_config = parent["heartflow"]
- config.sub_heart_flow_update_interval = heartflow_config.get("sub_heart_flow_update_interval", config.sub_heart_flow_update_interval)
- config.sub_heart_flow_freeze_time = heartflow_config.get("sub_heart_flow_freeze_time", config.sub_heart_flow_freeze_time)
- config.sub_heart_flow_stop_time = heartflow_config.get("sub_heart_flow_stop_time", config.sub_heart_flow_stop_time)
- config.heart_flow_update_interval = heartflow_config.get("heart_flow_update_interval", config.heart_flow_update_interval)
+ config.sub_heart_flow_update_interval = heartflow_config.get(
+ "sub_heart_flow_update_interval", config.sub_heart_flow_update_interval
+ )
+ config.sub_heart_flow_freeze_time = heartflow_config.get(
+ "sub_heart_flow_freeze_time", config.sub_heart_flow_freeze_time
+ )
+ config.sub_heart_flow_stop_time = heartflow_config.get(
+ "sub_heart_flow_stop_time", config.sub_heart_flow_stop_time
+ )
+ config.heart_flow_update_interval = heartflow_config.get(
+ "heart_flow_update_interval", config.heart_flow_update_interval
+ )
def willing(parent: dict):
willing_config = parent["willing"]
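
The config loader reformatted above gates each field behind a version check such as `config.INNER_VERSION in SpecifierSet(">=1.2.4")`, so older config files simply keep the dataclass defaults. A minimal sketch of that pattern using `packaging` (the version value and section dict here are placeholders, not the real TOML contents):

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Placeholder values: the real loader reads INNER_VERSION and the [heartflow] table from the TOML file.
inner_version = Version("1.2.4")
heartflow_table = {"sub_heart_flow_update_interval": 90}

sub_heart_flow_update_interval = 60  # dataclass default

# Only config schemas new enough to define the field may override the default.
if inner_version in SpecifierSet(">=1.2.4"):
    sub_heart_flow_update_interval = heartflow_table.get(
        "sub_heart_flow_update_interval", sub_heart_flow_update_interval
    )

print(sub_heart_flow_update_interval)  # 90
```
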
diff --git a/src/plugins/memory_system/Hippocampus.py b/src/plugins/memory_system/Hippocampus.py
index 7f781ac31..717cebe17 100644
--- a/src/plugins/memory_system/Hippocampus.py
+++ b/src/plugins/memory_system/Hippocampus.py
@@ -14,6 +14,7 @@ from src.common.logger import get_module_logger, LogConfig, MEMORY_STYLE_CONFIG
from src.plugins.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
from .memory_config import MemoryConfig
+
def get_closest_chat_from_db(length: int, timestamp: str):
# print(f"获取最接近指定时间戳的聊天记录,长度: {length}, 时间戳: {timestamp}")
# print(f"当前时间: {timestamp},转换后时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))}")
diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py
index 852bba412..784bfa1db 100644
--- a/src/plugins/models/utils_model.py
+++ b/src/plugins/models/utils_model.py
@@ -179,7 +179,6 @@ class LLM_request:
# logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
# logger.info(f"使用模型: {self.model_name}")
-
# 构建请求体
if image_base64:
payload = await self._build_payload(prompt, image_base64, image_format)
@@ -205,13 +204,17 @@ class LLM_request:
# 处理需要重试的状态码
if response.status in policy["retry_codes"]:
wait_time = policy["base_wait"] * (2**retry)
- logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
+ logger.warning(
+ f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试"
+ )
if response.status == 413:
logger.warning("请求体过大,尝试压缩...")
image_base64 = compress_base64_image_by_scale(image_base64)
payload = await self._build_payload(prompt, image_base64, image_format)
elif response.status in [500, 503]:
- logger.error(f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}")
+ logger.error(
+ f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
+ )
raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
else:
logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
@@ -219,7 +222,9 @@ class LLM_request:
await asyncio.sleep(wait_time)
continue
elif response.status in policy["abort_codes"]:
- logger.error(f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}")
+ logger.error(
+ f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
+ )
# 尝试获取并记录服务器返回的详细错误信息
try:
error_json = await response.json()
@@ -257,7 +262,9 @@ class LLM_request:
):
old_model_name = self.model_name
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
- logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
+ logger.warning(
+ f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}"
+ )
# 对全局配置进行更新
if global_config.llm_normal.get("name") == old_model_name:
@@ -266,7 +273,9 @@ class LLM_request:
if global_config.llm_reasoning.get("name") == old_model_name:
global_config.llm_reasoning["name"] = self.model_name
- logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
+ logger.warning(
+ f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}"
+ )
# 更新payload中的模型名
if payload and "model" in payload:
@@ -328,7 +337,14 @@ class LLM_request:
await response.release()
# 返回已经累积的内容
result = {
- "choices": [{"message": {"content": accumulated_content, "reasoning_content": reasoning_content}}],
+ "choices": [
+ {
+ "message": {
+ "content": accumulated_content,
+ "reasoning_content": reasoning_content,
+ }
+ }
+ ],
"usage": usage,
}
return (
@@ -345,7 +361,14 @@ class LLM_request:
logger.error(f"清理资源时发生错误: {cleanup_error}")
# 返回已经累积的内容
result = {
- "choices": [{"message": {"content": accumulated_content, "reasoning_content": reasoning_content}}],
+ "choices": [
+ {
+ "message": {
+ "content": accumulated_content,
+ "reasoning_content": reasoning_content,
+ }
+ }
+ ],
"usage": usage,
}
return (
@@ -360,7 +383,9 @@ class LLM_request:
content = re.sub(r".*? ", "", content, flags=re.DOTALL).strip()
# 构造一个伪result以便调用自定义响应处理器或默认处理器
result = {
- "choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}],
+ "choices": [
+ {"message": {"content": content, "reasoning_content": reasoning_content}}
+ ],
"usage": usage,
}
return (
@@ -394,7 +419,9 @@ class LLM_request:
# 处理aiohttp抛出的响应错误
if retry < policy["max_retries"] - 1:
wait_time = policy["base_wait"] * (2**retry)
- logger.error(f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}")
+ logger.error(
+ f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}"
+ )
try:
if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
error_text = await e.response.text()
@@ -419,13 +446,17 @@ class LLM_request:
else:
logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
except (json.JSONDecodeError, TypeError) as json_err:
- logger.warning(f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}")
+ logger.warning(
+ f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
+ )
except (AttributeError, TypeError, ValueError) as parse_err:
logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
await asyncio.sleep(wait_time)
else:
- logger.critical(f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}")
+ logger.critical(
+ f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}"
+ )
# 安全地检查和记录请求详情
if (
image_base64
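
The retry handling in `utils_model.py` follows a standard exponential backoff: on a retryable status code it waits `base_wait * 2**retry` seconds, on an abort code it raises immediately, and it gives up after `max_retries`. A compact sketch of that policy loop — the `send` callable and the demo status codes are stand-ins, not the project's request API:

```python
import asyncio


async def request_with_backoff(send, policy):
    """Retry a request coroutine with exponential backoff on transient status codes (sketch)."""
    for retry in range(policy["max_retries"]):
        status, body = await send()
        if status == 200:
            return body
        if status in policy["retry_codes"]:
            wait_time = policy["base_wait"] * (2 ** retry)  # base, 2*base, 4*base, ...
            await asyncio.sleep(wait_time)
            continue
        if status in policy["abort_codes"]:
            raise RuntimeError(f"aborting on status {status}")
    raise RuntimeError("max retries exceeded")


async def main():
    attempts = {"n": 0}

    async def flaky():
        attempts["n"] += 1
        return (429, None) if attempts["n"] < 3 else (200, "ok")

    policy = {"max_retries": 5, "base_wait": 0.01, "retry_codes": {429}, "abort_codes": {403}}
    print(await request_with_backoff(flaky, policy))  # "ok" after two 429s


if __name__ == "__main__":
    asyncio.run(main())
```
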
diff --git a/src/plugins/moods/moods.py b/src/plugins/moods/moods.py
index e7b6261a6..61b211523 100644
--- a/src/plugins/moods/moods.py
+++ b/src/plugins/moods/moods.py
@@ -139,7 +139,7 @@ class MoodManager:
# 神经质:影响情绪变化速度
neuroticism_factor = 1 + (personality.neuroticism - 0.5) * 0.5
agreeableness_factor = 1 + (personality.agreeableness - 0.5) * 0.5
-
+
# 宜人性:影响情绪基准线
if personality.agreeableness < 0.2:
agreeableness_bias = (personality.agreeableness - 0.2) * 2
@@ -151,7 +151,7 @@ class MoodManager:
# 分别计算正向和负向的衰减率
if self.current_mood.valence >= 0:
# 正向情绪衰减
- decay_rate_positive = self.decay_rate_valence * (1/agreeableness_factor)
+ decay_rate_positive = self.decay_rate_valence * (1 / agreeableness_factor)
valence_target = 0 + agreeableness_bias
self.current_mood.valence = valence_target + (self.current_mood.valence - valence_target) * math.exp(
-decay_rate_positive * time_diff * neuroticism_factor
@@ -279,8 +279,9 @@ class MoodManager:
# 限制范围
self.current_mood.valence = max(-1.0, min(1.0, self.current_mood.valence))
self.current_mood.arousal = max(0.0, min(1.0, self.current_mood.arousal))
-
+
self._update_mood_text()
- logger.info(f"[情绪变化] {emotion}(强度:{intensity:.2f}) | 愉悦度:{old_valence:.2f}->{self.current_mood.valence:.2f}, 唤醒度:{old_arousal:.2f}->{self.current_mood.arousal:.2f} | 心情:{old_mood}->{self.current_mood.text}")
-
+ logger.info(
+ f"[情绪变化] {emotion}(强度:{intensity:.2f}) | 愉悦度:{old_valence:.2f}->{self.current_mood.valence:.2f}, 唤醒度:{old_arousal:.2f}->{self.current_mood.arousal:.2f} | 心情:{old_mood}->{self.current_mood.text}"
+ )
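
The mood decay reformatted above relaxes valence toward a personality-biased baseline with an exponential factor, `target + (valence - target) * exp(-rate * dt * neuroticism_factor)`, where agreeableness slows the loss of positive mood and shifts the baseline at low values. A sketch of the positive-valence branch shown in the hunk — the zero-bias middle range and the exact behaviour for negative valence are guesses, since those branches fall outside the shown lines:

```python
import math


def decay_valence(
    valence: float,
    time_diff: float,
    decay_rate: float = 0.1,
    agreeableness: float = 0.7,
    neuroticism: float = 0.5,
) -> float:
    """Relax mood valence toward a personality-biased baseline (positive-valence branch)."""
    neuroticism_factor = 1 + (neuroticism - 0.5) * 0.5      # higher neuroticism -> faster change
    agreeableness_factor = 1 + (agreeableness - 0.5) * 0.5  # higher agreeableness -> slower decay of positive mood

    # Baseline bias only kicks in at very low agreeableness in the shown hunk; assumed 0 elsewhere.
    bias = (agreeableness - 0.2) * 2 if agreeableness < 0.2 else 0.0

    rate = decay_rate * (1 / agreeableness_factor)
    target = bias
    return target + (valence - target) * math.exp(-rate * time_diff * neuroticism_factor)


if __name__ == "__main__":
    for dt in (10, 60, 300):
        print(dt, round(decay_valence(0.9, dt), 3))
```
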
diff --git a/src/plugins/person_info/person_info.py b/src/plugins/person_info/person_info.py
index 4dbdcd65f..4c1f9c688 100644
--- a/src/plugins/person_info/person_info.py
+++ b/src/plugins/person_info/person_info.py
@@ -8,7 +8,8 @@ import asyncio
import numpy as np
import matplotlib
-matplotlib.use('Agg')
+
+matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
@@ -30,38 +31,39 @@ PersonInfoManager 类方法功能摘要:
logger = get_module_logger("person_info")
person_info_default = {
- "person_id" : None,
- "platform" : None,
- "user_id" : None,
- "nickname" : None,
+ "person_id": None,
+ "platform": None,
+ "user_id": None,
+ "nickname": None,
# "age" : 0,
- "relationship_value" : 0,
+ "relationship_value": 0,
# "saved" : True,
# "impression" : None,
# "gender" : Unkown,
- "konw_time" : 0,
+ "konw_time": 0,
"msg_interval": 3000,
- "msg_interval_list": []
+ "msg_interval_list": [],
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
+
class PersonInfoManager:
def __init__(self):
if "person_info" not in db.list_collection_names():
db.create_collection("person_info")
db.person_info.create_index("person_id", unique=True)
- def get_person_id(self, platform:str, user_id:int):
+ def get_person_id(self, platform: str, user_id: int):
"""获取唯一id"""
components = [platform, str(user_id)]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
- async def create_person_info(self, person_id:str, data:dict = None):
+ async def create_person_info(self, person_id: str, data: dict = None):
"""创建一个项"""
if not person_id:
logger.debug("创建失败,personid不存在")
return
-
+
_person_info_default = copy.deepcopy(person_info_default)
_person_info_default["person_id"] = person_id
@@ -72,19 +74,16 @@ class PersonInfoManager:
db.person_info.insert_one(_person_info_default)
- async def update_one_field(self, person_id:str, field_name:str, value, Data:dict = None):
+ async def update_one_field(self, person_id: str, field_name: str, value, Data: dict = None):
"""更新某一个字段,会补全"""
if field_name not in person_info_default.keys():
logger.debug(f"更新'{field_name}'失败,未定义的字段")
return
-
+
document = db.person_info.find_one({"person_id": person_id})
if document:
- db.person_info.update_one(
- {"person_id": person_id},
- {"$set": {field_name: value}}
- )
+ db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}})
else:
Data[field_name] = value
logger.debug(f"更新时{person_id}不存在,已新建")
@@ -107,23 +106,20 @@ class PersonInfoManager:
if not person_id:
logger.debug("get_value获取失败:person_id不能为空")
return None
-
+
if field_name not in person_info_default:
logger.debug(f"get_value获取失败:字段'{field_name}'未定义")
return None
-
- document = db.person_info.find_one(
- {"person_id": person_id},
- {field_name: 1}
- )
-
+
+ document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
+
if document and field_name in document:
return document[field_name]
else:
default_value = copy.deepcopy(person_info_default[field_name])
logger.debug(f"获取{person_id}的{field_name}失败,已返回默认值{default_value}")
return default_value
-
+
async def get_values(self, person_id: str, field_names: list) -> dict:
"""获取指定person_id文档的多个字段值,若不存在该字段,则返回该字段的全局默认值"""
if not person_id:
@@ -139,62 +135,57 @@ class PersonInfoManager:
# 构建查询投影(所有字段都有效才会执行到这里)
projection = {field: 1 for field in field_names}
- document = db.person_info.find_one(
- {"person_id": person_id},
- projection
- )
+ document = db.person_info.find_one({"person_id": person_id}, projection)
result = {}
for field in field_names:
result[field] = copy.deepcopy(
- document.get(field, person_info_default[field])
- if document else person_info_default[field]
+ document.get(field, person_info_default[field]) if document else person_info_default[field]
)
return result
-
+
async def del_all_undefined_field(self):
"""删除所有项里的未定义字段"""
# 获取所有已定义的字段名
defined_fields = set(person_info_default.keys())
-
+
try:
# 遍历集合中的所有文档
for document in db.person_info.find({}):
# 找出文档中未定义的字段
- undefined_fields = set(document.keys()) - defined_fields - {'_id'}
-
+ undefined_fields = set(document.keys()) - defined_fields - {"_id"}
+
if undefined_fields:
# 构建更新操作,使用$unset删除未定义字段
update_result = db.person_info.update_one(
- {'_id': document['_id']},
- {'$unset': {field: 1 for field in undefined_fields}}
+ {"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}}
)
-
+
if update_result.modified_count > 0:
logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}")
-
+
return
-
+
except Exception as e:
logger.error(f"清理未定义字段时出错: {e}")
return
-
+
async def get_specific_value_list(
- self,
- field_name: str,
- way: Callable[[Any], bool], # 接受任意类型值
-) ->Dict[str, Any]:
+ self,
+ field_name: str,
+ way: Callable[[Any], bool], # 接受任意类型值
+ ) -> Dict[str, Any]:
"""
获取满足条件的字段值字典
-
+
Args:
field_name: 目标字段名
way: 判断函数 (value: Any) -> bool
-
+
Returns:
{person_id: value} | {}
-
+
Example:
# 查找所有nickname包含"admin"的用户
result = manager.specific_value_list(
@@ -208,10 +199,7 @@ class PersonInfoManager:
try:
result = {}
- for doc in db.person_info.find(
- {field_name: {"$exists": True}},
- {"person_id": 1, field_name: 1, "_id": 0}
- ):
+ for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}):
try:
value = doc[field_name]
if way(value):
@@ -225,11 +213,11 @@ class PersonInfoManager:
except Exception as e:
logger.error(f"数据库查询失败: {str(e)}", exc_info=True)
return {}
-
+
async def personal_habit_deduction(self):
"""启动个人信息推断,每天根据一定条件推断一次"""
try:
- while(1):
+ while 1:
await asyncio.sleep(60)
current_time = datetime.datetime.now()
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
@@ -237,8 +225,7 @@ class PersonInfoManager:
# "msg_interval"推断
msg_interval_map = False
msg_interval_lists = await self.get_specific_value_list(
- "msg_interval_list",
- lambda x: isinstance(x, list) and len(x) >= 100
+ "msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
)
for person_id, msg_interval_list_ in msg_interval_lists.items():
try:
@@ -258,23 +245,23 @@ class PersonInfoManager:
log_dir.mkdir(parents=True, exist_ok=True)
plt.figure(figsize=(10, 6))
time_series = pd.Series(time_interval)
- plt.hist(time_series, bins=50, density=True, alpha=0.4, color='pink', label='Histogram')
- time_series.plot(kind='kde', color='mediumpurple', linewidth=1, label='Density')
+ plt.hist(time_series, bins=50, density=True, alpha=0.4, color="pink", label="Histogram")
+ time_series.plot(kind="kde", color="mediumpurple", linewidth=1, label="Density")
plt.grid(True, alpha=0.2)
plt.xlim(0, 8000)
plt.title(f"Message Interval Distribution (User: {person_id[:8]}...)")
plt.xlabel("Interval (ms)")
plt.ylabel("Density")
- plt.legend(framealpha=0.9, facecolor='white')
+ plt.legend(framealpha=0.9, facecolor="white")
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
plt.savefig(img_path)
plt.close()
# 画图
-
+
q25, q75 = np.percentile(time_interval, [25, 75])
iqr = q75 - q25
- filtered = [x for x in time_interval if (q25 - 1.5*iqr) <= x <= (q75 + 1.5*iqr)]
-
+ filtered = [x for x in time_interval if (q25 - 1.5 * iqr) <= x <= (q75 + 1.5 * iqr)]
+
msg_interval = int(round(np.percentile(filtered, 80)))
await self.update_one_field(person_id, "msg_interval", msg_interval)
logger.debug(f"用户{person_id}的msg_interval已经被更新为{msg_interval}")
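
The `msg_interval` inference above is a small robust-statistics recipe: compute the interquartile range of the user's message gaps, drop anything outside `[q25 - 1.5*IQR, q75 + 1.5*IQR]`, and take the 80th percentile of what remains. A standalone sketch with synthetic data:

```python
import numpy as np


def infer_msg_interval(intervals_ms: list[int]) -> int:
    """Typical message interval: drop IQR outliers, then take the 80th percentile of the rest."""
    q25, q75 = np.percentile(intervals_ms, [25, 75])
    iqr = q75 - q25
    filtered = [x for x in intervals_ms if (q25 - 1.5 * iqr) <= x <= (q75 + 1.5 * iqr)]
    return int(round(np.percentile(filtered, 80)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    bulk = rng.normal(3000, 800, size=200).clip(200, None).astype(int).tolist()
    print(infer_msg_interval(bulk + [60_000, 90_000]))  # the two huge gaps are discarded as outliers
```
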
diff --git a/src/plugins/person_info/relationship_manager.py b/src/plugins/person_info/relationship_manager.py
index 9bbcf4e19..f64e2851c 100644
--- a/src/plugins/person_info/relationship_manager.py
+++ b/src/plugins/person_info/relationship_manager.py
@@ -12,6 +12,7 @@ relationship_config = LogConfig(
)
logger = get_module_logger("rel_manager", config=relationship_config)
+
class RelationshipManager:
def __init__(self):
self.positive_feedback_value = 0 # 正反馈系统
@@ -22,6 +23,7 @@ class RelationshipManager:
def mood_manager(self):
if self._mood_manager is None:
from ..moods.moods import MoodManager # 延迟导入
+
self._mood_manager = MoodManager.get_instance()
return self._mood_manager
@@ -51,27 +53,27 @@ class RelationshipManager:
self.positive_feedback_value -= 1
elif self.positive_feedback_value > 0:
self.positive_feedback_value = 0
-
+
if abs(self.positive_feedback_value) > 1:
logger.info(f"触发mood变更增益,当前增益系数:{self.gain_coefficient[abs(self.positive_feedback_value)]}")
def mood_feedback(self, value):
"""情绪反馈"""
mood_manager = self.mood_manager
- mood_gain = (mood_manager.get_current_mood().valence) ** 2 \
- * math.copysign(1, value * mood_manager.get_current_mood().valence)
+ mood_gain = (mood_manager.get_current_mood().valence) ** 2 * math.copysign(
+ 1, value * mood_manager.get_current_mood().valence
+ )
value += value * mood_gain
logger.info(f"当前relationship增益系数:{mood_gain:.3f}")
return value
-
+
def feedback_to_mood(self, mood_value):
"""对情绪的反馈"""
coefficient = self.gain_coefficient[abs(self.positive_feedback_value)]
- if (mood_value > 0 and self.positive_feedback_value > 0
- or mood_value < 0 and self.positive_feedback_value < 0):
- return mood_value*coefficient
+ if mood_value > 0 and self.positive_feedback_value > 0 or mood_value < 0 and self.positive_feedback_value < 0:
+ return mood_value * coefficient
else:
- return mood_value/coefficient
+ return mood_value / coefficient
async def calculate_update_relationship_value(self, chat_stream: ChatStream, label: str, stance: str) -> None:
"""计算并变更关系值
@@ -88,7 +90,7 @@ class RelationshipManager:
"中立": 1,
"反对": 2,
}
-
+
valuedict = {
"开心": 1.5,
"愤怒": -2.0,
@@ -103,10 +105,10 @@ class RelationshipManager:
person_id = person_info_manager.get_person_id(chat_stream.user_info.platform, chat_stream.user_info.user_id)
data = {
- "platform" : chat_stream.user_info.platform,
- "user_id" : chat_stream.user_info.user_id,
- "nickname" : chat_stream.user_info.user_nickname,
- "konw_time" : int(time.time())
+ "platform": chat_stream.user_info.platform,
+ "user_id": chat_stream.user_info.user_id,
+ "nickname": chat_stream.user_info.user_nickname,
+ "konw_time": int(time.time()),
}
old_value = await person_info_manager.get_value(person_id, "relationship_value")
old_value = self.ensure_float(old_value, person_id)
@@ -200,4 +202,5 @@ class RelationshipManager:
logger.warning(f"[关系管理] {person_id}值转换失败(原始值:{value}),已重置为0")
return 0.0
+
relationship_manager = RelationshipManager()
diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py
index e8999099b..ccab662d1 100644
--- a/src/plugins/schedule/schedule_generator.py
+++ b/src/plugins/schedule/schedule_generator.py
@@ -14,7 +14,7 @@ from src.common.logger import get_module_logger, SCHEDULE_STYLE_CONFIG, LogConfi
from src.plugins.models.utils_model import LLM_request # noqa: E402
from src.plugins.config.config import global_config # noqa: E402
-TIME_ZONE = tz.gettz(global_config.TIME_ZONE) # 设置时区
+TIME_ZONE = tz.gettz(global_config.TIME_ZONE) # 设置时区
schedule_config = LogConfig(
@@ -31,10 +31,16 @@ class ScheduleGenerator:
def __init__(self):
# 使用离线LLM模型
self.llm_scheduler_all = LLM_request(
- model=global_config.llm_reasoning, temperature=global_config.SCHEDULE_TEMPERATURE, max_tokens=7000, request_type="schedule"
+ model=global_config.llm_reasoning,
+ temperature=global_config.SCHEDULE_TEMPERATURE,
+ max_tokens=7000,
+ request_type="schedule",
)
self.llm_scheduler_doing = LLM_request(
- model=global_config.llm_normal, temperature=global_config.SCHEDULE_TEMPERATURE, max_tokens=2048, request_type="schedule"
+ model=global_config.llm_normal,
+ temperature=global_config.SCHEDULE_TEMPERATURE,
+ max_tokens=2048,
+ request_type="schedule",
)
self.today_schedule_text = ""
diff --git a/src/plugins/utils/statistic.py b/src/plugins/utils/statistic.py
index eef10c01d..4b9afff39 100644
--- a/src/plugins/utils/statistic.py
+++ b/src/plugins/utils/statistic.py
@@ -2,7 +2,7 @@ import threading
import time
from collections import defaultdict
from datetime import datetime, timedelta
-from typing import Any, Dict
+from typing import Any, Dict, List
from src.common.logger import get_module_logger
from ...common.database import db
@@ -22,6 +22,7 @@ class LLMStatistics:
self.stats_thread = None
self.console_thread = None
self._init_database()
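+        # cache of the latest display name seen for each chat: {chat_id: [name, last_message_time]}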
+        self.name_dict: Dict[str, List] = {}
def _init_database(self):
"""初始化数据库集合"""
@@ -137,16 +138,24 @@ class LLMStatistics:
# user_id = str(doc.get("user_info", {}).get("user_id", "unknown"))
chat_info = doc.get("chat_info", {})
user_info = doc.get("user_info", {})
+ message_time = doc.get("time", 0)
group_info = chat_info.get("group_info") if chat_info else {}
# print(f"group_info: {group_info}")
group_name = None
if group_info:
+ group_id = f"g{group_info.get('group_id')}"
group_name = group_info.get("group_name", f"群{group_info.get('group_id')}")
if user_info and not group_name:
+ group_id = f"u{user_info['user_id']}"
group_name = user_info["user_nickname"]
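+            # remember the most recently observed display name for this chat (newer message time wins)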
+ if self.name_dict.get(group_id):
+ if message_time > self.name_dict.get(group_id)[1]:
+ self.name_dict[group_id] = [group_name, message_time]
+ else:
+ self.name_dict[group_id] = [group_name, message_time]
# print(f"group_name: {group_name}")
stats["messages_by_user"][user_id] += 1
- stats["messages_by_chat"][group_name] += 1
+ stats["messages_by_chat"][group_id] += 1
return stats
@@ -187,7 +196,7 @@ class LLMStatistics:
tokens = stats["tokens_by_model"][model_name]
cost = stats["costs_by_model"][model_name]
output.append(
- data_fmt.format(model_name[:32] + ".." if len(model_name) > 32 else model_name, count, tokens, cost)
+ data_fmt.format(model_name[:30] + ".." if len(model_name) > 32 else model_name, count, tokens, cost)
)
output.append("")
@@ -221,8 +230,8 @@ class LLMStatistics:
# 添加聊天统计
output.append("群组统计:")
output.append(("群组名称 消息数量"))
- for group_name, count in sorted(stats["messages_by_chat"].items()):
- output.append(f"{group_name[:32]:<32} {count:>10}")
+ for group_id, count in sorted(stats["messages_by_chat"].items()):
+ output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}")
return "\n".join(output)
@@ -250,7 +259,7 @@ class LLMStatistics:
tokens = stats["tokens_by_model"][model_name]
cost = stats["costs_by_model"][model_name]
output.append(
- data_fmt.format(model_name[:32] + ".." if len(model_name) > 32 else model_name, count, tokens, cost)
+ data_fmt.format(model_name[:30] + ".." if len(model_name) > 32 else model_name, count, tokens, cost)
)
output.append("")
@@ -284,8 +293,8 @@ class LLMStatistics:
# 添加聊天统计
output.append("群组统计:")
output.append(("群组名称 消息数量"))
- for group_name, count in sorted(stats["messages_by_chat"].items()):
- output.append(f"{group_name[:32]:<32} {count:>10}")
+ for group_id, count in sorted(stats["messages_by_chat"].items()):
+ output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}")
return "\n".join(output)
diff --git a/src/plugins/zhishi/knowledge_library.py b/src/plugins/zhishi/knowledge_library.py
index cf38874ce..6d046e025 100644
--- a/src/plugins/zhishi/knowledge_library.py
+++ b/src/plugins/zhishi/knowledge_library.py
@@ -53,18 +53,18 @@ class KnowledgeLibrary:
# 按空行分割内容
paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
chunks = []
-
+
for para in paragraphs:
para_length = len(para)
-
+
# 如果段落长度小于等于最大长度,直接添加
if para_length <= max_length:
chunks.append(para)
else:
# 如果段落超过最大长度,则按最大长度切分
for i in range(0, para_length, max_length):
- chunks.append(para[i:i + max_length])
-
+ chunks.append(para[i : i + max_length])
+
return chunks
def get_embedding(self, text: str) -> list:
diff --git a/temp_utils_ui/temp_ui.py b/temp_utils_ui/temp_ui.py
index f81fbdc44..3e0e1b5a5 100644
--- a/temp_utils_ui/temp_ui.py
+++ b/temp_utils_ui/temp_ui.py
@@ -32,7 +32,7 @@ SECTION_TRANSLATIONS = {
"response_spliter": "回复分割器",
"remote": "远程设置",
"experimental": "实验功能",
- "model": "模型设置"
+ "model": "模型设置",
}
# 配置项的中文描述
@@ -41,16 +41,13 @@ CONFIG_DESCRIPTIONS = {
"bot.qq": "机器人的QQ号码",
"bot.nickname": "机器人的昵称",
"bot.alias_names": "机器人的别名列表",
-
# 群组设置
"groups.talk_allowed": "允许机器人回复消息的群号列表",
"groups.talk_frequency_down": "降低回复频率的群号列表",
"groups.ban_user_id": "禁止回复和读取消息的QQ号列表",
-
# 人格设置
"personality.personality_core": "人格核心描述,建议20字以内",
"personality.personality_sides": "人格特点列表",
-
# 身份设置
"identity.identity_detail": "身份细节描述列表",
"identity.height": "身高(厘米)",
@@ -58,28 +55,23 @@ CONFIG_DESCRIPTIONS = {
"identity.age": "年龄",
"identity.gender": "性别",
"identity.appearance": "外貌特征",
-
# 日程设置
"schedule.enable_schedule_gen": "是否启用日程表生成",
"schedule.prompt_schedule_gen": "日程表生成提示词",
"schedule.schedule_doing_update_interval": "日程表更新间隔(秒)",
"schedule.schedule_temperature": "日程表温度,建议0.3-0.6",
"schedule.time_zone": "时区设置",
-
# 平台设置
"platforms.nonebot-qq": "QQ平台适配器链接",
-
# 回复设置
"response.response_mode": "回复策略(heart_flow:心流,reasoning:推理)",
"response.model_r1_probability": "主要回复模型使用概率",
"response.model_v3_probability": "次要回复模型使用概率",
-
# 心流设置
"heartflow.sub_heart_flow_update_interval": "子心流更新频率(秒)",
"heartflow.sub_heart_flow_freeze_time": "子心流冻结时间(秒)",
"heartflow.sub_heart_flow_stop_time": "子心流停止时间(秒)",
"heartflow.heart_flow_update_interval": "心流更新频率(秒)",
-
# 消息设置
"message.max_context_size": "获取的上下文数量",
"message.emoji_chance": "使用表情包的概率",
@@ -88,14 +80,12 @@ CONFIG_DESCRIPTIONS = {
"message.message_buffer": "是否启用消息缓冲器",
"message.ban_words": "禁用词列表",
"message.ban_msgs_regex": "禁用消息正则表达式列表",
-
# 意愿设置
"willing.willing_mode": "回复意愿模式",
"willing.response_willing_amplifier": "回复意愿放大系数",
"willing.response_interested_rate_amplifier": "回复兴趣度放大系数",
"willing.down_frequency_rate": "降低回复频率的群组回复意愿降低系数",
"willing.emoji_response_penalty": "表情包回复惩罚系数",
-
# 表情设置
"emoji.max_emoji_num": "表情包最大数量",
"emoji.max_reach_deletion": "达到最大数量时是否删除表情包",
@@ -103,7 +93,6 @@ CONFIG_DESCRIPTIONS = {
"emoji.auto_save": "是否保存表情包和图片",
"emoji.enable_check": "是否启用表情包过滤",
"emoji.check_prompt": "表情包过滤要求",
-
# 记忆设置
"memory.build_memory_interval": "记忆构建间隔(秒)",
"memory.build_memory_distribution": "记忆构建分布参数",
@@ -114,130 +103,118 @@ CONFIG_DESCRIPTIONS = {
"memory.memory_forget_time": "记忆遗忘时间(小时)",
"memory.memory_forget_percentage": "记忆遗忘比例",
"memory.memory_ban_words": "记忆禁用词列表",
-
# 情绪设置
"mood.mood_update_interval": "情绪更新间隔(秒)",
"mood.mood_decay_rate": "情绪衰减率",
"mood.mood_intensity_factor": "情绪强度因子",
-
# 关键词反应
"keywords_reaction.enable": "是否启用关键词反应功能",
-
# 中文错别字
"chinese_typo.enable": "是否启用中文错别字生成器",
"chinese_typo.error_rate": "单字替换概率",
"chinese_typo.min_freq": "最小字频阈值",
"chinese_typo.tone_error_rate": "声调错误概率",
"chinese_typo.word_replace_rate": "整词替换概率",
-
# 回复分割器
"response_spliter.enable_response_spliter": "是否启用回复分割器",
"response_spliter.response_max_length": "回复允许的最大长度",
"response_spliter.response_max_sentence_num": "回复允许的最大句子数",
-
# 远程设置
"remote.enable": "是否启用远程统计",
-
# 实验功能
"experimental.enable_friend_chat": "是否启用好友聊天",
"experimental.pfc_chatting": "是否启用PFC聊天",
-
# 模型设置
"model.llm_reasoning.name": "推理模型名称",
"model.llm_reasoning.provider": "推理模型提供商",
"model.llm_reasoning.pri_in": "推理模型输入价格",
"model.llm_reasoning.pri_out": "推理模型输出价格",
-
"model.llm_normal.name": "回复模型名称",
"model.llm_normal.provider": "回复模型提供商",
"model.llm_normal.pri_in": "回复模型输入价格",
"model.llm_normal.pri_out": "回复模型输出价格",
-
"model.llm_emotion_judge.name": "表情判断模型名称",
"model.llm_emotion_judge.provider": "表情判断模型提供商",
"model.llm_emotion_judge.pri_in": "表情判断模型输入价格",
"model.llm_emotion_judge.pri_out": "表情判断模型输出价格",
-
"model.llm_topic_judge.name": "主题判断模型名称",
"model.llm_topic_judge.provider": "主题判断模型提供商",
"model.llm_topic_judge.pri_in": "主题判断模型输入价格",
"model.llm_topic_judge.pri_out": "主题判断模型输出价格",
-
"model.llm_summary_by_topic.name": "概括模型名称",
"model.llm_summary_by_topic.provider": "概括模型提供商",
"model.llm_summary_by_topic.pri_in": "概括模型输入价格",
"model.llm_summary_by_topic.pri_out": "概括模型输出价格",
-
"model.moderation.name": "内容审核模型名称",
"model.moderation.provider": "内容审核模型提供商",
"model.moderation.pri_in": "内容审核模型输入价格",
"model.moderation.pri_out": "内容审核模型输出价格",
-
"model.vlm.name": "图像识别模型名称",
"model.vlm.provider": "图像识别模型提供商",
"model.vlm.pri_in": "图像识别模型输入价格",
"model.vlm.pri_out": "图像识别模型输出价格",
-
"model.embedding.name": "嵌入模型名称",
"model.embedding.provider": "嵌入模型提供商",
"model.embedding.pri_in": "嵌入模型输入价格",
"model.embedding.pri_out": "嵌入模型输出价格",
-
"model.llm_observation.name": "观察模型名称",
"model.llm_observation.provider": "观察模型提供商",
"model.llm_observation.pri_in": "观察模型输入价格",
"model.llm_observation.pri_out": "观察模型输出价格",
-
"model.llm_sub_heartflow.name": "子心流模型名称",
"model.llm_sub_heartflow.provider": "子心流模型提供商",
"model.llm_sub_heartflow.pri_in": "子心流模型输入价格",
"model.llm_sub_heartflow.pri_out": "子心流模型输出价格",
-
"model.llm_heartflow.name": "心流模型名称",
"model.llm_heartflow.provider": "心流模型提供商",
"model.llm_heartflow.pri_in": "心流模型输入价格",
"model.llm_heartflow.pri_out": "心流模型输出价格",
}
+
# 获取翻译
def get_translation(key):
return SECTION_TRANSLATIONS.get(key, key)
+
# 获取配置项描述
def get_description(key):
return CONFIG_DESCRIPTIONS.get(key, "")
+
# 获取根目录路径
def get_root_dir():
try:
# 获取当前脚本所在目录
- if getattr(sys, 'frozen', False):
+ if getattr(sys, "frozen", False):
# 如果是打包后的应用
current_dir = os.path.dirname(sys.executable)
else:
# 如果是脚本运行
current_dir = os.path.dirname(os.path.abspath(__file__))
-
+
# 获取根目录(假设当前脚本在temp_utils_ui目录下或者是可执行文件在根目录)
if os.path.basename(current_dir) == "temp_utils_ui":
root_dir = os.path.dirname(current_dir)
else:
root_dir = current_dir
-
+
# 检查是否存在config目录
config_dir = os.path.join(root_dir, "config")
if not os.path.exists(config_dir):
os.makedirs(config_dir, exist_ok=True)
-
+
return root_dir
except Exception as e:
print(f"获取根目录路径失败: {e}")
# 返回当前目录作为备选
return os.getcwd()
+
# 配置文件路径
CONFIG_PATH = os.path.join(get_root_dir(), "config", "bot_config.toml")
+
# 保存配置
def save_config(config_data):
try:
@@ -247,17 +224,17 @@ def save_config(config_data):
backup_dir = os.path.join(os.path.dirname(CONFIG_PATH), "old")
if not os.path.exists(backup_dir):
os.makedirs(backup_dir)
-
+
# 生成备份文件名(使用时间戳)
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
backup_filename = f"bot_config_{timestamp}.toml.bak"
backup_path = os.path.join(backup_dir, backup_filename)
-
+
# 复制文件
with open(CONFIG_PATH, "r", encoding="utf-8") as src:
with open(backup_path, "w", encoding="utf-8") as dst:
dst.write(src.read())
-
+
# 保存新配置
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
toml.dump(config_data, f)
@@ -266,6 +243,7 @@ def save_config(config_data):
print(f"保存配置失败: {e}")
return False
+
# 加载配置
def load_config():
try:
@@ -279,80 +257,82 @@ def load_config():
print(f"加载配置失败: {e}")
return {}
+
# 多行文本输入框
class ScrollableTextFrame(ctk.CTkFrame):
def __init__(self, master, initial_text="", height=100, width=400, **kwargs):
super().__init__(master, **kwargs)
-
+
self.text_var = StringVar(value=initial_text)
-
+
# 文本框
self.text_box = ctk.CTkTextbox(self, height=height, width=width, wrap="word")
self.text_box.pack(fill="both", expand=True, padx=5, pady=5)
self.text_box.insert("1.0", initial_text)
-
+
# 绑定更改事件
self.text_box.bind("", self.update_var)
-
+
def update_var(self, event=None):
self.text_var.set(self.text_box.get("1.0", "end-1c"))
-
+
def get(self):
return self.text_box.get("1.0", "end-1c")
-
+
def set(self, text):
self.text_box.delete("1.0", "end")
self.text_box.insert("1.0", text)
self.update_var()
+
# 配置UI
class ConfigUI(ctk.CTk):
def __init__(self):
super().__init__()
-
+
# 窗口设置
self.title("麦麦配置修改器")
self.geometry("1100x750")
-
+
# 加载配置
self.config_data = load_config()
if not self.config_data:
messagebox.showerror("错误", "无法加载配置文件!将创建空白配置文件。")
# 如果配置加载失败,创建一个最小化的空配置
self.config_data = {"inner": {"version": "1.0.0"}}
-
+
# 保存原始配置,用于检测变更
self.original_config = json.dumps(self.config_data, sort_keys=True)
-
+
# 自动保存状态
self.auto_save = ctk.BooleanVar(value=False)
-
+
# 创建主框架
self.main_frame = ctk.CTkFrame(self)
self.main_frame.pack(padx=10, pady=10, fill="both", expand=True)
-
+
# 创建顶部工具栏
self.create_toolbar()
-
+
# 创建标签和输入框的字典,用于后续保存配置
self.config_vars = {}
-
+
# 创建左侧导航和右侧内容区域
self.create_split_view()
-
+
# 创建底部状态栏
self.status_label = ctk.CTkLabel(self, text="就绪", anchor="w")
self.status_label.pack(fill="x", padx=10, pady=(0, 5))
-
+
# 绑定关闭事件
self.protocol("WM_DELETE_WINDOW", self.on_closing)
-
+
# 设置最小窗口大小
self.minsize(800, 600)
-
+
# 居中显示窗口
self.center_window()
-
+
def center_window(self):
"""将窗口居中显示"""
try:
@@ -366,112 +346,112 @@ class ConfigUI(ctk.CTk):
print(f"居中窗口时出错: {e}")
# 使用默认位置
pass
-
+
def create_toolbar(self):
toolbar = ctk.CTkFrame(self.main_frame, height=40)
toolbar.pack(fill="x", padx=5, pady=5)
-
+
# 保存按钮
save_btn = ctk.CTkButton(toolbar, text="保存配置", command=self.save_config, width=100)
save_btn.pack(side="left", padx=5)
-
+
# 自动保存选项
auto_save_cb = ctk.CTkCheckBox(toolbar, text="自动保存", variable=self.auto_save)
auto_save_cb.pack(side="left", padx=15)
-
+
# 重新加载按钮
reload_btn = ctk.CTkButton(toolbar, text="重新加载", command=self.reload_config, width=100)
reload_btn.pack(side="left", padx=5)
-
+
# 手动备份按钮
backup_btn = ctk.CTkButton(toolbar, text="手动备份", command=self.backup_config, width=100)
backup_btn.pack(side="left", padx=5)
-
+
# 查看备份按钮
view_backup_btn = ctk.CTkButton(toolbar, text="查看备份", command=self.view_backups, width=100)
view_backup_btn.pack(side="left", padx=5)
-
+
# 导入导出菜单按钮
import_export_btn = ctk.CTkButton(toolbar, text="导入/导出", command=self.show_import_export_menu, width=100)
import_export_btn.pack(side="left", padx=5)
-
+
# 关于按钮
about_btn = ctk.CTkButton(toolbar, text="关于", command=self.show_about, width=80)
about_btn.pack(side="right", padx=5)
-
+
def create_split_view(self):
# 创建分隔视图框架
split_frame = ctk.CTkFrame(self.main_frame)
split_frame.pack(fill="both", expand=True, padx=5, pady=5)
-
+
# 左侧分类列表
self.category_frame = ctk.CTkFrame(split_frame, width=220)
self.category_frame.pack(side="left", fill="y", padx=(0, 5), pady=0)
self.category_frame.pack_propagate(False) # 固定宽度
-
+
# 右侧内容区域
self.content_frame = ctk.CTkScrollableFrame(split_frame)
self.content_frame.pack(side="right", fill="both", expand=True)
-
+
# 创建类别列表
self.create_category_list()
-
+
def create_category_list(self):
# 标题和搜索框
header_frame = ctk.CTkFrame(self.category_frame)
header_frame.pack(fill="x", padx=5, pady=(10, 5))
-
+
ctk.CTkLabel(header_frame, text="配置分类", font=("Arial", 14, "bold")).pack(side="left", padx=5, pady=5)
-
+
# 搜索按钮
search_btn = ctk.CTkButton(
- header_frame,
- text="🔍",
- width=30,
+ header_frame,
+ text="🔍",
+ width=30,
command=self.show_search_dialog,
fg_color="transparent",
- hover_color=("gray80", "gray30")
+ hover_color=("gray80", "gray30"),
)
search_btn.pack(side="right", padx=5, pady=5)
-
+
# 分类按钮
self.category_buttons = {}
self.active_category = None
-
+
# 分类按钮容器
buttons_frame = ctk.CTkScrollableFrame(self.category_frame, height=600)
buttons_frame.pack(fill="both", expand=True, padx=5, pady=5)
-
+
for section in self.config_data:
# 跳过inner部分,这个不应该被用户修改
if section == "inner":
continue
-
+
# 获取翻译
section_name = f"{section} ({get_translation(section)})"
-
+
btn = ctk.CTkButton(
- buttons_frame,
+ buttons_frame,
text=section_name,
fg_color="transparent",
text_color=("gray10", "gray90"),
anchor="w",
height=35,
- command=lambda s=section: self.show_category(s)
+ command=lambda s=section: self.show_category(s),
)
btn.pack(fill="x", padx=5, pady=2)
self.category_buttons[section] = btn
-
+
# 默认显示第一个分类
first_section = next((s for s in self.config_data.keys() if s != "inner"), None)
if first_section:
self.show_category(first_section)
-
+
def show_category(self, category):
# 清除当前内容
for widget in self.content_frame.winfo_children():
widget.destroy()
-
+
# 更新按钮状态
for section, btn in self.category_buttons.items():
if section == category:
@@ -479,49 +459,38 @@ class ConfigUI(ctk.CTk):
self.active_category = section
else:
btn.configure(fg_color="transparent")
-
+
# 获取翻译
category_name = f"{category} ({get_translation(category)})"
-
+
# 添加标题
- ctk.CTkLabel(
- self.content_frame,
- text=f"{category_name} 配置",
- font=("Arial", 16, "bold")
- ).pack(anchor="w", padx=10, pady=(5, 15))
-
- # 添加配置项
- self.add_config_section(
- self.content_frame,
- category,
- self.config_data[category]
+ ctk.CTkLabel(self.content_frame, text=f"{category_name} 配置", font=("Arial", 16, "bold")).pack(
+ anchor="w", padx=10, pady=(5, 15)
)
-
+
+ # 添加配置项
+ self.add_config_section(self.content_frame, category, self.config_data[category])
+
def add_config_section(self, parent, section_path, section_data, indent=0):
# 递归添加配置项
for key, value in section_data.items():
full_path = f"{section_path}.{key}" if indent > 0 else f"{section_path}.{key}"
-
+
# 获取描述
description = get_description(full_path)
-
+
if isinstance(value, dict):
# 如果是字典,创建一个分组框架并递归添加子项
group_frame = ctk.CTkFrame(parent)
group_frame.pack(fill="x", expand=True, padx=10, pady=10)
-
+
# 添加标题
header_frame = ctk.CTkFrame(group_frame, fg_color=("gray85", "gray25"))
header_frame.pack(fill="x", padx=0, pady=0)
-
- label = ctk.CTkLabel(
- header_frame,
- text=f"{key}",
- font=("Arial", 13, "bold"),
- anchor="w"
- )
+
+ label = ctk.CTkLabel(header_frame, text=f"{key}", font=("Arial", 13, "bold"), anchor="w")
label.pack(anchor="w", padx=10, pady=5)
-
+
# 如果有描述,添加提示图标
if description:
# 创建工具提示窗口显示函数
@@ -529,221 +498,179 @@ class ConfigUI(ctk.CTk):
x, y, _, _ = widget.bbox("all")
x += widget.winfo_rootx() + 25
y += widget.winfo_rooty() + 25
-
+
# 创建工具提示窗口
tipwindow = ctk.CTkToplevel(widget)
tipwindow.wm_overrideredirect(True)
tipwindow.wm_geometry(f"+{x}+{y}")
tipwindow.lift()
-
- label = ctk.CTkLabel(
- tipwindow,
- text=text,
- justify="left",
- wraplength=300
- )
+
+ label = ctk.CTkLabel(tipwindow, text=text, justify="left", wraplength=300)
label.pack(padx=5, pady=5)
-
+
# 自动关闭
def close_tooltip():
tipwindow.destroy()
-
+
widget.after(3000, close_tooltip)
return tipwindow
-
+
# 在标题后添加提示图标
tip_label = ctk.CTkLabel(
- header_frame,
- text="ℹ️",
- font=("Arial", 12),
- text_color="light blue",
- width=20
+ header_frame, text="ℹ️", font=("Arial", 12), text_color="light blue", width=20
)
tip_label.pack(side="right", padx=5)
-
+
# 绑定鼠标悬停事件
                     tip_label.bind("<Enter>", lambda e, t=description, w=tip_label: show_tooltip(e, t, w))
-
+
# 添加内容
content_frame = ctk.CTkFrame(group_frame)
content_frame.pack(fill="x", expand=True, padx=5, pady=5)
-
- self.add_config_section(content_frame, full_path, value, indent+1)
-
+
+ self.add_config_section(content_frame, full_path, value, indent + 1)
+
elif isinstance(value, list):
# 如果是列表,创建一个文本框用于编辑JSON格式的列表
frame = ctk.CTkFrame(parent)
frame.pack(fill="x", expand=True, padx=5, pady=5)
-
+
# 标签和输入框在一行
label_frame = ctk.CTkFrame(frame)
label_frame.pack(fill="x", padx=5, pady=(5, 0))
-
+
# 标签包含描述提示
label_text = f"{key}:"
if description:
label_text = f"{key}: ({description})"
-
- label = ctk.CTkLabel(
- label_frame,
- text=label_text,
- font=("Arial", 12),
- anchor="w"
- )
- label.pack(anchor="w", padx=5 + indent*10, pady=0)
-
+
+ label = ctk.CTkLabel(label_frame, text=label_text, font=("Arial", 12), anchor="w")
+ label.pack(anchor="w", padx=5 + indent * 10, pady=0)
+
# 添加提示信息
- info_label = ctk.CTkLabel(
- label_frame,
- text="(列表格式: JSON)",
- font=("Arial", 9),
- text_color="gray50"
- )
- info_label.pack(anchor="w", padx=5 + indent*10, pady=(0, 5))
-
+ info_label = ctk.CTkLabel(label_frame, text="(列表格式: JSON)", font=("Arial", 9), text_color="gray50")
+ info_label.pack(anchor="w", padx=5 + indent * 10, pady=(0, 5))
+
# 确定文本框高度,根据列表项数量决定
list_height = max(100, min(len(value) * 20 + 40, 200))
-
+
# 将列表转换为JSON字符串,美化格式
json_str = json.dumps(value, ensure_ascii=False, indent=2)
-
+
# 使用多行文本框
- text_frame = ScrollableTextFrame(
- frame,
- initial_text=json_str,
- height=list_height,
- width=550
- )
- text_frame.pack(fill="x", padx=10 + indent*10, pady=5)
-
+ text_frame = ScrollableTextFrame(frame, initial_text=json_str, height=list_height, width=550)
+ text_frame.pack(fill="x", padx=10 + indent * 10, pady=5)
+
self.config_vars[full_path] = (text_frame.text_var, "list")
-
+
# 绑定变更事件,用于自动保存
text_frame.text_box.bind("", lambda e, path=full_path: self.on_field_change(path))
-
+
elif isinstance(value, bool):
# 如果是布尔值,创建一个复选框
frame = ctk.CTkFrame(parent)
frame.pack(fill="x", expand=True, padx=5, pady=5)
-
+
var = ctk.BooleanVar(value=value)
self.config_vars[full_path] = (var, "bool")
-
+
# 复选框文本包含描述
checkbox_text = key
if description:
checkbox_text = f"{key} ({description})"
-
+
checkbox = ctk.CTkCheckBox(
- frame,
- text=checkbox_text,
- variable=var,
- command=lambda path=full_path: self.on_field_change(path)
+ frame, text=checkbox_text, variable=var, command=lambda path=full_path: self.on_field_change(path)
)
- checkbox.pack(anchor="w", padx=10 + indent*10, pady=5)
-
+ checkbox.pack(anchor="w", padx=10 + indent * 10, pady=5)
+
elif isinstance(value, (int, float)):
# 如果是数字,创建一个数字输入框
frame = ctk.CTkFrame(parent)
frame.pack(fill="x", expand=True, padx=5, pady=5)
-
+
# 标签包含描述
label_text = f"{key}:"
if description:
label_text = f"{key}: ({description})"
-
- label = ctk.CTkLabel(
- frame,
- text=label_text,
- font=("Arial", 12),
- anchor="w"
- )
- label.pack(anchor="w", padx=10 + indent*10, pady=(5, 0))
-
+
+ label = ctk.CTkLabel(frame, text=label_text, font=("Arial", 12), anchor="w")
+ label.pack(anchor="w", padx=10 + indent * 10, pady=(5, 0))
+
var = StringVar(value=str(value))
self.config_vars[full_path] = (var, "number", type(value))
-
+
# 判断数值的长度,决定输入框宽度
entry_width = max(200, min(len(str(value)) * 15, 300))
-
+
entry = ctk.CTkEntry(frame, width=entry_width, textvariable=var)
- entry.pack(anchor="w", padx=10 + indent*10, pady=5)
-
+ entry.pack(anchor="w", padx=10 + indent * 10, pady=5)
+
# 绑定变更事件,用于自动保存
entry.bind("", lambda e, path=full_path: self.on_field_change(path))
-
+
else:
# 对于字符串,创建一个文本输入框
frame = ctk.CTkFrame(parent)
frame.pack(fill="x", expand=True, padx=5, pady=5)
-
+
# 标签包含描述
label_text = f"{key}:"
if description:
label_text = f"{key}: ({description})"
-
- label = ctk.CTkLabel(
- frame,
- text=label_text,
- font=("Arial", 12),
- anchor="w"
- )
- label.pack(anchor="w", padx=10 + indent*10, pady=(5, 0))
-
+
+ label = ctk.CTkLabel(frame, text=label_text, font=("Arial", 12), anchor="w")
+ label.pack(anchor="w", padx=10 + indent * 10, pady=(5, 0))
+
var = StringVar(value=str(value))
self.config_vars[full_path] = (var, "string")
-
+
# 判断文本长度,决定输入框的类型和大小
text_len = len(str(value))
-
- if text_len > 80 or '\n' in str(value):
+
+ if text_len > 80 or "\n" in str(value):
# 对于长文本或多行文本,使用多行文本框
- text_height = max(80, min(str(value).count('\n') * 20 + 40, 150))
-
- text_frame = ScrollableTextFrame(
- frame,
- initial_text=str(value),
- height=text_height,
- width=550
- )
- text_frame.pack(fill="x", padx=10 + indent*10, pady=5)
+ text_height = max(80, min(str(value).count("\n") * 20 + 40, 150))
+
+ text_frame = ScrollableTextFrame(frame, initial_text=str(value), height=text_height, width=550)
+ text_frame.pack(fill="x", padx=10 + indent * 10, pady=5)
self.config_vars[full_path] = (text_frame.text_var, "string")
-
+
# 绑定变更事件,用于自动保存
text_frame.text_box.bind("", lambda e, path=full_path: self.on_field_change(path))
else:
# 对于短文本,使用单行输入框
# 根据内容长度动态调整输入框宽度
entry_width = max(400, min(text_len * 10, 550))
-
+
entry = ctk.CTkEntry(frame, width=entry_width, textvariable=var)
- entry.pack(anchor="w", padx=10 + indent*10, pady=5, fill="x")
-
+ entry.pack(anchor="w", padx=10 + indent * 10, pady=5, fill="x")
+
# 绑定变更事件,用于自动保存
entry.bind("", lambda e, path=full_path: self.on_field_change(path))
-
+
def on_field_change(self, path):
"""当字段值改变时调用,用于自动保存"""
if self.auto_save.get():
self.save_config(show_message=False)
self.status_label.configure(text=f"已自动保存更改 ({path})")
-
+
def save_config(self, show_message=True):
"""保存配置文件"""
# 更新配置数据
updated = False
_error_path = None
-
+
for path, (var, var_type, *args) in self.config_vars.items():
parts = path.split(".")
-
+
# 如果路径有多层级
target = self.config_data
for p in parts[:-1]:
if p not in target:
target[p] = {}
target = target[p]
-
+
# 根据变量类型更新值
try:
if var_type == "bool":
@@ -751,7 +678,6 @@ class ConfigUI(ctk.CTk):
target[parts[-1]] = var.get()
updated = True
elif var_type == "number":
-
# 获取原始类型(int或float)
num_type = args[0] if args else int
new_value = num_type(var.get())
@@ -760,7 +686,6 @@ class ConfigUI(ctk.CTk):
updated = True
elif var_type == "list":
-
# 解析JSON字符串为列表
new_value = json.loads(var.get())
if json.dumps(target[parts[-1]], sort_keys=True) != json.dumps(new_value, sort_keys=True):
@@ -777,11 +702,11 @@ class ConfigUI(ctk.CTk):
else:
self.status_label.configure(text=f"保存失败: {e}")
return False
-
+
if not updated and show_message:
self.status_label.configure(text="无更改,无需保存")
return True
-
+
# 保存配置
if save_config(self.config_data):
if show_message:
@@ -794,7 +719,7 @@ class ConfigUI(ctk.CTk):
else:
self.status_label.configure(text="保存失败!")
return False
-
+
def reload_config(self):
"""重新加载配置"""
if self.check_unsaved_changes():
@@ -802,28 +727,28 @@ class ConfigUI(ctk.CTk):
if not self.config_data:
messagebox.showerror("错误", "无法加载配置文件!")
return
-
+
# 保存原始配置,用于检测变更
self.original_config = json.dumps(self.config_data, sort_keys=True)
-
+
# 重新显示当前分类
self.show_category(self.active_category)
-
+
self.status_label.configure(text="配置已重新加载")
-
+
def check_unsaved_changes(self):
"""检查是否有未保存的更改"""
# 临时更新配置数据以进行比较
temp_config = self.config_data.copy()
-
+
try:
for path, (var, var_type, *args) in self.config_vars.items():
parts = path.split(".")
-
+
target = temp_config
for p in parts[:-1]:
target = target[p]
-
+
if var_type == "bool":
target[parts[-1]] = var.get()
elif var_type == "number":
@@ -836,24 +761,20 @@ class ConfigUI(ctk.CTk):
except (ValueError, json.JSONDecodeError):
# 如果有无效输入,认为有未保存更改
return False
-
+
# 比较原始配置和当前配置
current_config = json.dumps(temp_config, sort_keys=True)
-
+
if current_config != self.original_config:
- result = messagebox.askyesnocancel(
- "未保存的更改",
- "有未保存的更改,是否保存?",
- icon="warning"
- )
-
+ result = messagebox.askyesnocancel("未保存的更改", "有未保存的更改,是否保存?", icon="warning")
+
if result is None: # 取消
return False
elif result: # 是
return self.save_config()
-
+
return True
-
+
def show_about(self):
"""显示关于对话框"""
about_window = ctk.CTkToplevel(self)
@@ -861,37 +782,25 @@ class ConfigUI(ctk.CTk):
about_window.geometry("400x200")
about_window.resizable(False, False)
about_window.grab_set() # 模态对话框
-
+
# 居中
x = self.winfo_x() + (self.winfo_width() - 400) // 2
y = self.winfo_y() + (self.winfo_height() - 200) // 2
about_window.geometry(f"+{x}+{y}")
-
+
# 内容
- ctk.CTkLabel(
- about_window,
- text="麦麦配置修改器",
- font=("Arial", 16, "bold")
- ).pack(pady=(20, 10))
-
- ctk.CTkLabel(
- about_window,
- text="用于修改MaiBot-Core的配置文件\n配置文件路径: config/bot_config.toml"
- ).pack(pady=5)
-
- ctk.CTkLabel(
- about_window,
- text="注意: 修改配置前请备份原始配置文件",
- text_color=("red", "light coral")
- ).pack(pady=5)
-
- ctk.CTkButton(
- about_window,
- text="确定",
- command=about_window.destroy,
- width=100
- ).pack(pady=15)
-
+ ctk.CTkLabel(about_window, text="麦麦配置修改器", font=("Arial", 16, "bold")).pack(pady=(20, 10))
+
+ ctk.CTkLabel(about_window, text="用于修改MaiBot-Core的配置文件\n配置文件路径: config/bot_config.toml").pack(
+ pady=5
+ )
+
+ ctk.CTkLabel(about_window, text="注意: 修改配置前请备份原始配置文件", text_color=("red", "light coral")).pack(
+ pady=5
+ )
+
+ ctk.CTkButton(about_window, text="确定", command=about_window.destroy, width=100).pack(pady=15)
+
def on_closing(self):
"""关闭窗口前检查未保存更改"""
if self.check_unsaved_changes():
@@ -904,22 +813,22 @@ class ConfigUI(ctk.CTk):
if not os.path.exists(CONFIG_PATH):
messagebox.showerror("错误", "配置文件不存在!")
return False
-
+
# 创建备份目录
backup_dir = os.path.join(os.path.dirname(CONFIG_PATH), "old")
if not os.path.exists(backup_dir):
os.makedirs(backup_dir)
-
+
# 生成备份文件名(使用时间戳)
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
backup_filename = f"bot_config_{timestamp}.toml.bak"
backup_path = os.path.join(backup_dir, backup_filename)
-
+
# 复制文件
with open(CONFIG_PATH, "r", encoding="utf-8") as src:
with open(backup_path, "w", encoding="utf-8") as dst:
dst.write(src.read())
-
+
messagebox.showinfo("成功", f"配置已备份到:\n{backup_path}")
self.status_label.configure(text=f"手动备份已创建: {backup_filename}")
return True
@@ -933,7 +842,7 @@ class ConfigUI(ctk.CTk):
backup_dir = os.path.join(os.path.dirname(CONFIG_PATH), "old")
if not os.path.exists(backup_dir):
os.makedirs(backup_dir)
-
+
# 查找备份文件
backup_files = []
for filename in os.listdir(backup_dir):
@@ -941,109 +850,92 @@ class ConfigUI(ctk.CTk):
backup_path = os.path.join(backup_dir, filename)
mod_time = os.path.getmtime(backup_path)
backup_files.append((filename, backup_path, mod_time))
-
+
if not backup_files:
messagebox.showinfo("提示", "未找到备份文件")
return
-
+
# 按修改时间排序,最新的在前
backup_files.sort(key=lambda x: x[2], reverse=True)
-
+
# 创建备份查看窗口
backup_window = ctk.CTkToplevel(self)
backup_window.title("备份文件")
backup_window.geometry("600x400")
backup_window.grab_set() # 模态对话框
-
+
# 居中
x = self.winfo_x() + (self.winfo_width() - 600) // 2
y = self.winfo_y() + (self.winfo_height() - 400) // 2
backup_window.geometry(f"+{x}+{y}")
-
+
# 创建说明标签
- ctk.CTkLabel(
- backup_window,
- text="备份文件列表 (双击可恢复)",
- font=("Arial", 14, "bold")
- ).pack(pady=(10, 5), padx=10, anchor="w")
-
+ ctk.CTkLabel(backup_window, text="备份文件列表 (双击可恢复)", font=("Arial", 14, "bold")).pack(
+ pady=(10, 5), padx=10, anchor="w"
+ )
+
# 创建列表框
backup_frame = ctk.CTkScrollableFrame(backup_window, width=580, height=300)
backup_frame.pack(padx=10, pady=10, fill="both", expand=True)
-
+
# 添加备份文件项
for _i, (filename, filepath, mod_time) in enumerate(backup_files):
# 格式化时间为可读格式
time_str = datetime.datetime.fromtimestamp(mod_time).strftime("%Y-%m-%d %H:%M:%S")
-
+
# 创建一个框架用于每个备份项
item_frame = ctk.CTkFrame(backup_frame)
item_frame.pack(fill="x", padx=5, pady=5)
-
+
# 显示备份文件信息
- ctk.CTkLabel(
- item_frame,
- text=f"{time_str}",
- font=("Arial", 12, "bold"),
- width=200
- ).pack(side="left", padx=10, pady=10)
-
- # 文件名
- name_label = ctk.CTkLabel(
- item_frame,
- text=filename,
- font=("Arial", 11)
+ ctk.CTkLabel(item_frame, text=f"{time_str}", font=("Arial", 12, "bold"), width=200).pack(
+ side="left", padx=10, pady=10
)
+
+ # 文件名
+ name_label = ctk.CTkLabel(item_frame, text=filename, font=("Arial", 11))
name_label.pack(side="left", fill="x", expand=True, padx=5, pady=10)
-
+
# 恢复按钮
restore_btn = ctk.CTkButton(
- item_frame,
- text="恢复",
- width=80,
- command=lambda path=filepath: self.restore_backup(path)
+ item_frame, text="恢复", width=80, command=lambda path=filepath: self.restore_backup(path)
)
restore_btn.pack(side="right", padx=10, pady=10)
-
+
# 绑定双击事件
for widget in (item_frame, name_label):
                     widget.bind("<Double-Button-1>", lambda e, path=filepath: self.restore_backup(path))
-
+
# 关闭按钮
- ctk.CTkButton(
- backup_window,
- text="关闭",
- command=backup_window.destroy,
- width=100
- ).pack(pady=10)
+ ctk.CTkButton(backup_window, text="关闭", command=backup_window.destroy, width=100).pack(pady=10)
def restore_backup(self, backup_path):
"""从备份文件恢复配置"""
if not os.path.exists(backup_path):
messagebox.showerror("错误", "备份文件不存在!")
return False
-
+
# 确认还原
confirm = messagebox.askyesno(
- "确认",
+ "确认",
f"确定要从以下备份文件恢复配置吗?\n{os.path.basename(backup_path)}\n\n这将覆盖当前的配置!",
- icon="warning"
+ icon="warning",
)
-
+
if not confirm:
return False
-
+
try:
# 先备份当前配置
self.backup_config()
-
+
# 恢复配置
with open(backup_path, "r", encoding="utf-8") as src:
with open(CONFIG_PATH, "w", encoding="utf-8") as dst:
dst.write(src.read())
-
+
messagebox.showinfo("成功", "配置已从备份恢复!")
-
+
# 重新加载配置
self.reload_config()
return True
@@ -1058,143 +950,141 @@ class ConfigUI(ctk.CTk):
search_window.title("搜索配置项")
search_window.geometry("500x400")
search_window.grab_set() # 模态对话框
-
+
# 居中
x = self.winfo_x() + (self.winfo_width() - 500) // 2
y = self.winfo_y() + (self.winfo_height() - 400) // 2
search_window.geometry(f"+{x}+{y}")
-
+
# 搜索框
search_frame = ctk.CTkFrame(search_window)
search_frame.pack(fill="x", padx=10, pady=10)
-
+
search_var = StringVar()
- search_entry = ctk.CTkEntry(search_frame, placeholder_text="输入关键词搜索...", width=380, textvariable=search_var)
+ search_entry = ctk.CTkEntry(
+ search_frame, placeholder_text="输入关键词搜索...", width=380, textvariable=search_var
+ )
search_entry.pack(side="left", padx=5, pady=5, fill="x", expand=True)
-
+
# 结果列表框
results_frame = ctk.CTkScrollableFrame(search_window, width=480, height=300)
results_frame.pack(padx=10, pady=5, fill="both", expand=True)
-
+
# 搜索结果标签
results_label = ctk.CTkLabel(results_frame, text="请输入关键词进行搜索", anchor="w")
results_label.pack(fill="x", padx=10, pady=10)
-
+
# 结果项列表
results_items = []
-
+
# 搜索函数
def perform_search():
# 清除之前的结果
for item in results_items:
item.destroy()
results_items.clear()
-
+
keyword = search_var.get().lower()
if not keyword:
results_label.configure(text="请输入关键词进行搜索")
return
-
+
# 收集所有匹配的配置项
matches = []
-
+
def search_config(section_path, config_data):
for key, value in config_data.items():
full_path = f"{section_path}.{key}" if section_path else key
-
+
# 检查键名是否匹配
if keyword in key.lower():
matches.append((full_path, value))
-
+
# 检查描述是否匹配
description = get_description(full_path)
if description and keyword in description.lower():
matches.append((full_path, value))
-
+
# 检查值是否匹配(仅字符串类型)
if isinstance(value, str) and keyword in value.lower():
matches.append((full_path, value))
-
+
# 递归搜索子项
if isinstance(value, dict):
search_config(full_path, value)
-
+
# 开始搜索
search_config("", self.config_data)
-
+
if not matches:
results_label.configure(text=f"未找到包含 '{keyword}' 的配置项")
return
-
+
results_label.configure(text=f"找到 {len(matches)} 个匹配项")
-
+
# 显示搜索结果
for full_path, value in matches:
# 创建一个框架用于每个结果项
item_frame = ctk.CTkFrame(results_frame)
item_frame.pack(fill="x", padx=5, pady=3)
results_items.append(item_frame)
-
+
# 配置项路径
path_parts = full_path.split(".")
section = path_parts[0] if len(path_parts) > 0 else ""
_key = path_parts[-1] if len(path_parts) > 0 else ""
-
+
# 获取描述
description = get_description(full_path)
desc_text = f" ({description})" if description else ""
-
+
# 显示完整路径
path_label = ctk.CTkLabel(
- item_frame,
- text=f"{full_path}{desc_text}",
+ item_frame,
+ text=f"{full_path}{desc_text}",
font=("Arial", 11, "bold"),
anchor="w",
- wraplength=450
+ wraplength=450,
)
path_label.pack(anchor="w", padx=10, pady=(5, 0), fill="x")
-
+
# 显示值的预览(截断过长的值)
value_str = str(value)
if len(value_str) > 50:
value_str = value_str[:50] + "..."
-
+
value_label = ctk.CTkLabel(
- item_frame,
- text=f"值: {value_str}",
- font=("Arial", 10),
- anchor="w",
- wraplength=450
+ item_frame, text=f"值: {value_str}", font=("Arial", 10), anchor="w", wraplength=450
)
value_label.pack(anchor="w", padx=10, pady=(0, 5), fill="x")
-
+
# 添加"转到"按钮
goto_btn = ctk.CTkButton(
- item_frame,
- text="转到",
+ item_frame,
+ text="转到",
width=60,
height=25,
- command=lambda s=section: self.goto_config_item(s, search_window)
+ command=lambda s=section: self.goto_config_item(s, search_window),
)
goto_btn.pack(side="right", padx=10, pady=5)
-
+
# 绑定双击事件
for widget in (item_frame, path_label, value_label):
                         widget.bind("<Double-Button-1>", lambda e, s=section: self.goto_config_item(s, search_window))
-
+
# 搜索按钮
search_button = ctk.CTkButton(search_frame, text="搜索", width=80, command=perform_search)
search_button.pack(side="right", padx=5, pady=5)
-
+
# 绑定回车键
             search_entry.bind("<Return>", lambda e: perform_search())
-
+
# 初始聚焦到搜索框
search_window.after(100, lambda: self.safe_focus(search_entry))
except Exception as e:
print(f"显示搜索对话框出错: {e}")
messagebox.showerror("错误", f"显示搜索对话框失败: {e}")
-
+
def safe_focus(self, widget):
"""安全地设置焦点,避免应用崩溃"""
try:
@@ -1208,7 +1098,7 @@ class ConfigUI(ctk.CTk):
"""跳转到指定的配置项"""
if dialog:
dialog.destroy()
-
+
# 切换到相应的分类
if section in self.category_buttons:
self.show_category(section)
@@ -1220,44 +1110,29 @@ class ConfigUI(ctk.CTk):
menu_window.geometry("300x200")
menu_window.resizable(False, False)
menu_window.grab_set() # 模态对话框
-
+
# 居中
x = self.winfo_x() + (self.winfo_width() - 300) // 2
y = self.winfo_y() + (self.winfo_height() - 200) // 2
menu_window.geometry(f"+{x}+{y}")
-
+
# 创建按钮
- ctk.CTkLabel(
- menu_window,
- text="配置导入导出",
- font=("Arial", 16, "bold")
- ).pack(pady=(20, 10))
-
+ ctk.CTkLabel(menu_window, text="配置导入导出", font=("Arial", 16, "bold")).pack(pady=(20, 10))
+
# 导出按钮
export_btn = ctk.CTkButton(
- menu_window,
- text="导出配置到文件",
- command=lambda: self.export_config(menu_window),
- width=200
+ menu_window, text="导出配置到文件", command=lambda: self.export_config(menu_window), width=200
)
export_btn.pack(pady=10)
-
+
# 导入按钮
import_btn = ctk.CTkButton(
- menu_window,
- text="从文件导入配置",
- command=lambda: self.import_config(menu_window),
- width=200
+ menu_window, text="从文件导入配置", command=lambda: self.import_config(menu_window), width=200
)
import_btn.pack(pady=10)
-
+
# 取消按钮
- cancel_btn = ctk.CTkButton(
- menu_window,
- text="取消",
- command=menu_window.destroy,
- width=100
- )
+ cancel_btn = ctk.CTkButton(menu_window, text="取消", command=menu_window.destroy, width=100)
cancel_btn.pack(pady=10)
def export_config(self, parent_window=None):
@@ -1268,31 +1143,31 @@ class ConfigUI(ctk.CTk):
pass
else:
return
-
+
# 选择保存位置
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
default_filename = f"bot_config_export_{timestamp}.toml"
-
+
file_path = filedialog.asksaveasfilename(
title="导出配置",
filetypes=[("TOML 文件", "*.toml"), ("所有文件", "*.*")],
defaultextension=".toml",
- initialfile=default_filename
+ initialfile=default_filename,
)
-
+
if not file_path:
return
-
+
try:
# 复制当前配置文件到选择的位置
shutil.copy2(CONFIG_PATH, file_path)
-
+
messagebox.showinfo("成功", f"配置已导出到:\n{file_path}")
self.status_label.configure(text=f"配置已导出到: {file_path}")
-
+
if parent_window:
parent_window.destroy()
-
+
return True
except Exception as e:
messagebox.showerror("导出失败", f"导出配置失败: {e}")
@@ -1303,57 +1178,55 @@ class ConfigUI(ctk.CTk):
# 先检查是否有未保存的更改
if not self.check_unsaved_changes():
return
-
+
# 选择要导入的文件
file_path = filedialog.askopenfilename(
- title="导入配置",
- filetypes=[("TOML 文件", "*.toml"), ("所有文件", "*.*")]
+ title="导入配置", filetypes=[("TOML 文件", "*.toml"), ("所有文件", "*.*")]
)
-
+
if not file_path:
return
-
+
try:
# 尝试加载TOML文件以验证格式
with open(file_path, "r", encoding="utf-8") as f:
import_data = toml.load(f)
-
+
# 验证导入文件的基本结构
if "inner" not in import_data:
raise ValueError("导入的配置文件没有inner部分,格式不正确")
-
+
if "version" not in import_data["inner"]:
raise ValueError("导入的配置文件没有版本信息,格式不正确")
-
+
# 确认导入
confirm = messagebox.askyesno(
- "确认导入",
- f"确定要导入此配置文件吗?\n{file_path}\n\n这将替换当前的配置!",
- icon="warning"
+ "确认导入", f"确定要导入此配置文件吗?\n{file_path}\n\n这将替换当前的配置!", icon="warning"
)
-
+
if not confirm:
return
-
+
# 先备份当前配置
self.backup_config()
-
+
# 复制导入的文件到配置位置
shutil.copy2(file_path, CONFIG_PATH)
-
+
messagebox.showinfo("成功", "配置已导入,请重新加载以应用更改")
-
+
# 重新加载配置
self.reload_config()
-
+
if parent_window:
parent_window.destroy()
-
+
return True
except Exception as e:
messagebox.showerror("导入失败", f"导入配置失败: {e}")
return False
+
# 主函数
def main():
try:
@@ -1365,6 +1238,7 @@ def main():
import tkinter as tk
from tkinter import messagebox
+
root = tk.Tk()
root.withdraw()
messagebox.showerror("程序错误", f"程序运行时发生错误:\n{e}")